/// `1.0` in Q14 fixed point: 14 fractional bits, so one least-significant unit is `2^-14`.
pub const Q14_ONE: i16 = 1 << 14;
/// `0.5` in Q14 fixed point.
pub const Q14_HALF: i16 = Q14_ONE / 2;
/// `0.25` in Q14 fixed point.
pub const Q14_QUARTER: i16 = Q14_ONE / 4;
/// Advances a leaky integrate-and-fire neuron by one time step in Q14 fixed point.
///
/// The membrane potential is first decayed by `alpha` (a Q14 factor, where `16384` is 1.0).
/// Then `input_current` is added. If the result reaches `v_thr`, the neuron spikes and is
/// reset by subtraction (`v - v_thr`) rather than to zero, so any overshoot carries over
/// to the next step.
///
/// # Arguments
/// * `membrane` - current membrane potential (Q14).
/// * `alpha` - decay factor (Q14).
/// * `input_current` - input added after decay. It is `i32`, wider than the membrane, and
///   is only clamped after thresholding.
/// * `v_thr` - spike threshold (Q14).
///
/// # Returns
/// `(new_membrane, spiked)`. The new membrane is clamped to the `i16` range.
///
/// Intermediate sums saturate at the `i32` bounds. Extreme `input_current` values
/// therefore cannot overflow (overflow would panic in debug builds and silently wrap
/// in release builds).
#[inline]
pub fn lif_step(membrane: i16, alpha: i16, input_current: i32, v_thr: i16) -> (i16, bool) {
    let v_thr = i32::from(v_thr);
    // An i16 * i16 product is at most 2^30 in magnitude, so it cannot overflow i32.
    // The arithmetic right shift rounds toward negative infinity.
    let decay = (i32::from(membrane) * i32::from(alpha)) >> 14;
    let v_new = decay.saturating_add(input_current);
    let spike = v_new >= v_thr;
    // A negative threshold turns the subtraction into an addition, so it must saturate too.
    let v_reset = if spike { v_new.saturating_sub(v_thr) } else { v_new };
    let clamped = v_reset.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16;
    (clamped, spike)
}
/// Piecewise-linear (triangular) surrogate gradient of the spike nonlinearity.
///
/// When `membrane` lies strictly within `v_thr` of the threshold, the result is
/// `gamma * (v_thr - |membrane - v_thr|) / v_thr`, using integer division that truncates
/// toward zero. Otherwise the result is `0`. The peak value `gamma` occurs at
/// `membrane == v_thr`, and the value falls linearly to zero at `0` and at `2 * v_thr`.
///
/// `v_thr` must be positive. This is checked only in debug builds. In release builds a
/// non-positive `v_thr` makes the window empty, so the function returns `0` and never
/// divides.
#[inline]
pub fn surrogate_gradient_pwl(membrane: i16, v_thr: i16, gamma: i16) -> i16 {
    debug_assert!(v_thr > 0, "v_thr must be positive for surrogate gradient");
    let threshold = i32::from(v_thr);
    let distance = (i32::from(membrane) - threshold).abs();
    // Outside the triangular window, the gradient is zero.
    if distance >= threshold {
        return 0;
    }
    // |result| <= |gamma|, so narrowing back to i16 cannot lose information.
    (i32::from(gamma) * (threshold - distance) / threshold) as i16
}
/// Converts a real number to Q14 fixed point, rounding to the nearest representable value.
///
/// Rounding uses [`f64::round`], which sends ties away from zero. Compared with plain
/// truncation, this halves the worst-case quantization error (from `2^-14` to `2^-15`) and
/// removes the systematic bias toward zero.
///
/// Values outside the representable range `[-2.0, 2.0 - 2^-14]`, including infinities,
/// saturate to `i16::MIN` / `i16::MAX`. `NaN` maps to `0`, because Rust's float-to-int
/// `as` cast sends NaN to zero.
#[inline]
pub fn f64_to_q14(value: f64) -> i16 {
    let scaled = (value * f64::from(Q14_ONE)).round();
    scaled.clamp(f64::from(i16::MIN), f64::from(i16::MAX)) as i16
}
/// Converts a Q14 fixed-point value back to `f64`.
///
/// The conversion is exact. Every `i16` is representable in `f64`, and dividing by a
/// power of two introduces no rounding.
#[inline]
pub fn q14_to_f64(value: i16) -> f64 {
    let scale = f64::from(Q14_ONE);
    f64::from(value) / scale
}
/// Multiplies two Q14 fixed-point values, saturating on overflow.
///
/// The product of two `i16` values has magnitude at most `2^30`, so it is computed
/// exactly in `i32` and then shifted back by 14 fractional bits. The arithmetic shift
/// rounds toward negative infinity.
///
/// Results outside the `i16` range saturate to `i16::MIN` / `i16::MAX` instead of
/// wrapping around to a value of the wrong sign. For example, `-2.0 * -2.0 = 4.0` is
/// not representable in Q14.
#[inline]
pub fn q14_mul(a: i16, b: i16) -> i16 {
    let product = (i32::from(a) * i32::from(b)) >> 14;
    product.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn subthreshold_input_no_spike() {
        let v_thr = Q14_HALF;
        let alpha = f64_to_q14(0.95);
        let input = 1000_i32;
        let (v_new, spike) = lif_step(0, alpha, input, v_thr);
        assert!(!spike, "should not spike with small input");
        assert_eq!(v_new, 1000, "membrane should be 0*alpha + 1000 = 1000");
    }

    #[test]
    fn suprathreshold_input_causes_spike_and_reset() {
        let v_thr = Q14_HALF;
        let alpha = f64_to_q14(0.95);
        let input = 10000_i32;
        let (v_new, spike) = lif_step(0, alpha, input, v_thr);
        assert!(spike, "should spike when input exceeds threshold");
        // 10000 - 8192 = 1808: reset is by subtraction, not to zero.
        assert_eq!(v_new, 1808, "membrane should be reset by subtracting v_thr");
    }

    #[test]
    fn decay_reduces_membrane_without_input() {
        let alpha = f64_to_q14(0.5);
        let v_thr = Q14_ONE;
        let membrane = Q14_HALF;
        let (v_new, spike) = lif_step(membrane, alpha, 0, v_thr);
        assert!(!spike, "should not spike without input");
        assert_eq!(v_new, Q14_QUARTER, "membrane should decay to 0.25");
    }

    #[test]
    fn accumulated_input_crosses_threshold() {
        let v_thr = Q14_HALF;
        let alpha = f64_to_q14(0.9);
        let input = 2000_i32;
        let mut membrane: i16 = 0;
        let mut total_spikes = 0;
        for _ in 0..20 {
            let (v, spike) = lif_step(membrane, alpha, input, v_thr);
            membrane = v;
            if spike {
                total_spikes += 1;
            }
        }
        assert!(
            total_spikes > 0,
            "should eventually spike after accumulating input over 20 steps"
        );
    }

    #[test]
    fn membrane_clamps_to_i16_range() {
        let alpha = Q14_ONE;
        let v_thr = Q14_ONE;
        let input = 60000_i32;
        // decay = 32767, v_new = 92767 -> spike, reset to 76383, which must clamp to i16::MAX.
        let (v_new, spike) = lif_step(i16::MAX, alpha, input, v_thr);
        assert!(spike, "membrane far above threshold should spike");
        assert_eq!(v_new, i16::MAX, "reset membrane should clamp to i16::MAX");
    }

    #[test]
    fn surrogate_gradient_peak_at_threshold() {
        let v_thr = Q14_HALF;
        let gamma = Q14_ONE;
        let psi = surrogate_gradient_pwl(v_thr, v_thr, gamma);
        assert_eq!(
            psi, gamma,
            "surrogate gradient should peak at threshold: got {}",
            psi
        );
    }

    #[test]
    fn surrogate_gradient_zero_far_from_threshold() {
        let v_thr = Q14_HALF;
        let gamma = Q14_ONE;
        let psi_high = surrogate_gradient_pwl(Q14_ONE + Q14_HALF, v_thr, gamma);
        assert_eq!(psi_high, 0, "should be zero far above threshold");
        let psi_low = surrogate_gradient_pwl(-Q14_HALF, v_thr, gamma);
        assert_eq!(psi_low, 0, "should be zero far below threshold");
    }

    #[test]
    fn surrogate_gradient_symmetric_around_threshold() {
        let v_thr = Q14_HALF;
        let gamma = Q14_ONE;
        let offset = 1000_i16;
        let psi_above = surrogate_gradient_pwl(v_thr + offset, v_thr, gamma);
        let psi_below = surrogate_gradient_pwl(v_thr - offset, v_thr, gamma);
        assert_eq!(
            psi_above, psi_below,
            "surrogate gradient should be symmetric around v_thr"
        );
    }

    #[test]
    fn q14_conversion_roundtrip() {
        let values = [0.0, 0.5, 1.0, -1.0, 0.95, -0.3];
        for &v in &values {
            let q = f64_to_q14(v);
            let back = q14_to_f64(q);
            assert!(
                (back - v).abs() < 0.001,
                "roundtrip failed for {}: got {}",
                v,
                back
            );
        }
    }

    #[test]
    fn q14_mul_correctness() {
        let result = q14_mul(Q14_HALF, Q14_HALF);
        assert_eq!(result, Q14_QUARTER, "0.5 * 0.5 should be 0.25");
        let result2 = q14_mul(Q14_ONE, Q14_HALF);
        assert_eq!(result2, Q14_HALF, "1.0 * 0.5 should be 0.5");
    }

    #[test]
    fn negative_membrane_decay() {
        let alpha = f64_to_q14(0.9);
        let v_thr = Q14_HALF;
        let membrane: i16 = -4000;
        let (v_new, spike) = lif_step(membrane, alpha, 0, v_thr);
        assert!(!spike, "negative membrane should not spike");
        let expected = ((-4000_i32 * alpha as i32) >> 14) as i16;
        assert_eq!(v_new, expected, "negative membrane should decay toward 0");
    }
}