/// Updates presynaptic spike traces in place using Q14 fixed-point arithmetic.
///
/// For each channel: `trace[i] = trace[i] * alpha + spike[i]`, where `alpha`
/// is a Q14 decay factor and a non-zero spike flag contributes `Q14_ONE`.
/// The result saturates to the `i16` range instead of wrapping.
pub fn update_pre_trace_fixed(trace: &mut [i16], spikes: &[u8], alpha: i16) {
    debug_assert_eq!(trace.len(), spikes.len());
    let q14_one = super::lif::Q14_ONE as i32;
    // Zip pairs each trace with its spike flag and avoids per-element
    // bounds checks of the index-loop form.
    for (t, &s) in trace.iter_mut().zip(spikes) {
        // Q14 multiply: widen to i32, then shift the radix point back.
        let decayed = (*t as i32 * alpha as i32) >> 14;
        let spike_contrib = if s != 0 { q14_one } else { 0 };
        // Saturate so a spike landing on an already-high trace cannot wrap.
        *t = (decayed + spike_contrib).clamp(i16::MIN as i32, i16::MAX as i32) as i16;
    }
}
/// Updates eligibility traces in place using Q14 fixed-point arithmetic.
///
/// Implements the low-pass filter
/// `e[i] = kappa * e[i] + (1 - kappa) * psi * pre_traces[i]`,
/// where `kappa` (decay) and `psi` (surrogate-gradient factor) are Q14
/// values. Each multiply is rescaled by `>> 14` to stay in Q14, and the
/// result saturates to the `i16` range.
pub fn update_eligibility_fixed(elig: &mut [i16], psi: i16, pre_traces: &[i16], kappa: i16) {
    debug_assert_eq!(elig.len(), pre_traces.len());
    let q14_one = super::lif::Q14_ONE as i32;
    // Loop-invariant: hoist (1 - kappa) out of the per-element work.
    let one_minus_kappa = q14_one - kappa as i32;
    for (e, &pre) in elig.iter_mut().zip(pre_traces) {
        let decayed = (*e as i32 * kappa as i32) >> 14;
        // psi * pre, then scaled by (1 - kappa); each product shifts the
        // radix point back to Q14 before the next multiply.
        let psi_pre = (psi as i32 * pre as i32) >> 14;
        let contribution = (one_minus_kappa * psi_pre) >> 14;
        *e = (decayed + contribution).clamp(i16::MIN as i32, i16::MAX as i32) as i16;
    }
}
/// Computes the scalar learning signal as the Q14 dot product of feedback
/// weights and per-output errors, saturated to the `i16` range.
///
/// Each term is `(feedback[i] * error[i]) >> 14` (Q14 multiply). The sum is
/// accumulated in `i64`: a single shifted term can reach ~±2^16, so an `i32`
/// accumulator could silently wrap after ~2^15 worst-case terms before the
/// final clamp ever sees the value.
pub fn compute_learning_signal_fixed(feedback: &[i16], error: &[i16]) -> i16 {
    debug_assert_eq!(feedback.len(), error.len());
    let acc: i64 = feedback
        .iter()
        .zip(error)
        .map(|(&f, &e)| ((f as i32 * e as i32) >> 14) as i64)
        .sum();
    acc.clamp(i16::MIN as i64, i16::MAX as i64) as i16
}
/// Applies the three-factor weight update in Q14 fixed point.
///
/// Each weight moves by `eta * learning_signal * eligibility[i]`, with every
/// Q14 multiply rescaled by `>> 14`; the new weight saturates to the `i16`
/// range instead of wrapping.
pub fn update_weights_fixed(
    weights: &mut [i16],
    eligibilities: &[i16],
    learning_signal: i16,
    eta: i16,
) {
    debug_assert_eq!(weights.len(), eligibilities.len());
    // Widen the per-call scalars once instead of per element.
    let signal = learning_signal as i32;
    let rate = eta as i32;
    for (w, &e) in weights.iter_mut().zip(eligibilities) {
        // eligibility * signal, rescaled to Q14, then scaled by the rate.
        let scaled = (e as i32 * signal) >> 14;
        let step = (rate * scaled) >> 14;
        let updated = *w as i32 + step;
        *w = updated.clamp(i16::MIN as i32, i16::MAX as i32) as i16;
    }
}
/// Updates output-layer weights with a delta rule: every neuron that spiked
/// receives the same Q14 step `eta * error`; non-spiking neurons are left
/// untouched. The new weight saturates to the `i16` range.
pub fn update_output_weights_fixed(weights: &mut [i16], error: i16, spikes: &[u8], eta: i16) {
    debug_assert_eq!(weights.len(), spikes.len());
    // The step does not depend on the neuron index, so compute it once
    // instead of recomputing it for every spiking neuron.
    let delta = (eta as i32 * error as i32) >> 14;
    for (w, &s) in weights.iter_mut().zip(spikes) {
        if s != 0 {
            *w = (*w as i32 + delta).clamp(i16::MIN as i32, i16::MAX as i32) as i16;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snn::lif::{f64_to_q14, q14_to_f64, Q14_HALF, Q14_ONE};

    #[test]
    fn pre_trace_accumulates_spikes() {
        let alpha = f64_to_q14(0.9);
        let mut trace = [0_i16; 3];
        let spikes = [1_u8, 0, 1];
        update_pre_trace_fixed(&mut trace, &spikes, alpha);
        assert_eq!(
            trace[0], Q14_ONE,
            "spiking channel should accumulate to Q14_ONE"
        );
        assert_eq!(trace[1], 0, "non-spiking channel should stay at 0");
        assert_eq!(
            trace[2], Q14_ONE,
            "spiking channel should accumulate to Q14_ONE"
        );
    }

    #[test]
    fn pre_trace_decays_without_spikes() {
        let alpha = f64_to_q14(0.5);
        let mut trace = [Q14_ONE; 2];
        let no_spikes = [0_u8; 2];
        update_pre_trace_fixed(&mut trace, &no_spikes, alpha);
        assert_eq!(trace[0], Q14_HALF, "trace should decay by alpha");
        assert_eq!(trace[1], Q14_HALF, "trace should decay by alpha");
    }

    #[test]
    fn eligibility_accumulates_with_psi_and_trace() {
        // e = kappa * 0 + (1 - 0.9) * 0.5 * 1.0 = 0.05
        let kappa = f64_to_q14(0.9);
        let psi = f64_to_q14(0.5);
        let pre_traces = [Q14_ONE; 2];
        let mut elig = [0_i16; 2];
        update_eligibility_fixed(&mut elig, psi, &pre_traces, kappa);
        let expected = f64_to_q14(0.05);
        // Allow a few LSBs of rounding error from the chained Q14 shifts.
        let tolerance = 50;
        assert!(
            (elig[0] - expected).abs() < tolerance,
            "eligibility should be ~0.05, got {}",
            q14_to_f64(elig[0])
        );
    }

    #[test]
    fn eligibility_decays_with_kappa() {
        let kappa = f64_to_q14(0.5);
        let psi = 0_i16;
        let pre_traces = [0_i16; 1];
        let mut elig = [Q14_ONE; 1];
        update_eligibility_fixed(&mut elig, psi, &pre_traces, kappa);
        assert_eq!(
            elig[0], Q14_HALF,
            "eligibility should decay by kappa to 0.5"
        );
    }

    #[test]
    fn learning_signal_combines_feedback_and_error() {
        // 1.0 * 0.1 + (-0.5) * 0.2 = 0.0
        let feedback = [Q14_ONE, f64_to_q14(-0.5)];
        let error = [f64_to_q14(0.1), f64_to_q14(0.2)];
        let l = compute_learning_signal_fixed(&feedback, &error);
        let tolerance = 50;
        assert!(
            l.abs() < tolerance,
            "learning signal should be ~0.0, got {}",
            q14_to_f64(l)
        );
    }

    #[test]
    fn learning_signal_responds_to_error_direction() {
        let feedback = [Q14_ONE];
        let positive_error = [f64_to_q14(0.5)];
        let negative_error = [f64_to_q14(-0.5)];
        let l_pos = compute_learning_signal_fixed(&feedback, &positive_error);
        let l_neg = compute_learning_signal_fixed(&feedback, &negative_error);
        assert!(l_pos > 0, "positive error should produce positive signal");
        assert!(l_neg < 0, "negative error should produce negative signal");
    }

    #[test]
    fn weight_update_direction() {
        let eta = f64_to_q14(0.01);
        let learning_signal = f64_to_q14(0.5);
        let elig = [f64_to_q14(0.3), f64_to_q14(-0.2)];
        let mut weights = [0_i16; 2];
        update_weights_fixed(&mut weights, &elig, learning_signal, eta);
        assert!(
            weights[0] > 0,
            "positive elig + positive signal should increase weight"
        );
        assert!(
            weights[1] < 0,
            "negative elig + positive signal should decrease weight"
        );
    }

    #[test]
    fn weight_update_magnitude_scales_with_eta() {
        let elig = [Q14_ONE; 1];
        let learning_signal = Q14_ONE;
        let eta_small = f64_to_q14(0.001);
        let eta_large = f64_to_q14(0.1);
        let mut w_small = [0_i16; 1];
        let mut w_large = [0_i16; 1];
        update_weights_fixed(&mut w_small, &elig, learning_signal, eta_small);
        update_weights_fixed(&mut w_large, &elig, learning_signal, eta_large);
        assert!(
            w_large[0].abs() > w_small[0].abs(),
            "larger eta should produce larger update: small={}, large={}",
            w_small[0],
            w_large[0]
        );
    }

    #[test]
    fn output_weight_update_only_for_spiking_neurons() {
        let eta = f64_to_q14(0.01);
        let error = f64_to_q14(1.0);
        let spikes = [1_u8, 0, 1, 0];
        let mut weights = [0_i16; 4];
        update_output_weights_fixed(&mut weights, error, &spikes, eta);
        assert!(
            weights[0] != 0,
            "spiking neuron 0 should have weight update"
        );
        assert_eq!(weights[1], 0, "non-spiking neuron 1 should have no update");
        assert!(
            weights[2] != 0,
            "spiking neuron 2 should have weight update"
        );
        assert_eq!(weights[3], 0, "non-spiking neuron 3 should have no update");
    }

    #[test]
    fn output_weight_update_sign_matches_error() {
        let eta = f64_to_q14(0.1);
        let spikes = [1_u8; 1];
        let mut w_pos = [0_i16; 1];
        let mut w_neg = [0_i16; 1];
        update_output_weights_fixed(&mut w_pos, f64_to_q14(0.5), &spikes, eta);
        update_output_weights_fixed(&mut w_neg, f64_to_q14(-0.5), &spikes, eta);
        assert!(w_pos[0] > 0, "positive error should increase weight");
        assert!(w_neg[0] < 0, "negative error should decrease weight");
    }

    #[test]
    fn weights_clamp_to_i16_range() {
        // Maximal eligibility, signal, and rate push the raw update to
        // roughly 2 * i16::MAX; the update must saturate, not wrap.
        // (The original test only read the value and asserted nothing.)
        let eta = Q14_ONE;
        let learning_signal = Q14_ONE;
        let elig = [i16::MAX; 1];
        let mut weights = [i16::MAX; 1];
        update_weights_fixed(&mut weights, &elig, learning_signal, eta);
        assert_eq!(weights[0], i16::MAX, "weight should saturate at i16::MAX");
    }
}