Skip to main content

irithyll_core/snn/
readout.rs

//! Non-spiking leaky integrator readout neuron.
//!
//! The readout neuron accumulates weighted spike inputs from the hidden layer
//! without producing spikes itself. Its membrane potential serves as the
//! continuous output prediction of the network.
//!
//! ```text
//! y[t] = kappa_out * y[t-1] + sum(W_kj * z_j[t])
//! ```
//!
//! The membrane uses i32 precision to avoid accumulation overflow from many
//! input spikes over long sequences.
13
/// Non-spiking leaky integrator readout neuron.
///
/// The readout accumulates weighted spike inputs with exponential decay and
/// never emits spikes itself. Its membrane potential is the network's
/// continuous output.
///
/// # Precision
///
/// The membrane is an `i32` accumulator kept in a Q1.14-compatible scale
/// (16384 represents 1.0), which gives it 16 more bits of range than an
/// `i16` Q1.14 value. The decay factor `kappa` is a plain Q1.14 `i16`.
///
/// # Saturation
///
/// [`step`](ReadoutNeuron::step) never wraps: if the updated membrane would
/// fall outside the `i32` range it is clamped to `i32::MIN` / `i32::MAX`.
pub struct ReadoutNeuron {
    /// Membrane potential (higher precision accumulator).
    /// Stored in Q1.14-compatible scale but with i32 range.
    pub membrane: i32,
    /// Decay factor in Q1.14 (controls how fast past inputs are forgotten).
    pub kappa: i16,
}

impl ReadoutNeuron {
    /// Create a new readout neuron with the given decay factor.
    ///
    /// The membrane starts at zero.
    ///
    /// # Arguments
    ///
    /// * `kappa` -- decay factor in Q1.14 (0 = no memory, Q14_ONE = perfect memory)
    pub fn new(kappa: i16) -> Self {
        Self { membrane: 0, kappa }
    }

    /// Advance the readout by one timestep.
    ///
    /// Decays the membrane potential and integrates new weighted input:
    ///
    /// ```text
    /// membrane = ((membrane * kappa) >> 14) + weighted_input
    /// ```
    ///
    /// The whole update is evaluated in `i64` and clamped to the `i32` range
    /// once at the end, so neither the decay product nor the addition can
    /// wrap, even for `kappa` magnitudes above 1.0 in Q1.14.
    ///
    /// The `>> 14` is an arithmetic shift and therefore rounds toward negative
    /// infinity: with a non-zero `kappa` below 1.0 and zero input, a positive
    /// membrane decays to `0` while a negative membrane settles at `-1`.
    ///
    /// # Arguments
    ///
    /// * `weighted_input` -- sum of `W_kj * z_j[t]` for active spikes, as i32
    #[inline]
    pub fn step(&mut self, weighted_input: i32) {
        // |membrane * kappa| <= 2^31 * 2^15 = 2^46, so the product is exact in i64.
        let decayed = (i64::from(self.membrane) * i64::from(self.kappa)) >> 14;
        // |decayed| <= 2^32 and |weighted_input| <= 2^31: the sum cannot overflow i64.
        let updated = decayed + i64::from(weighted_input);
        // After the clamp the value is guaranteed to fit, so the cast is lossless.
        self.membrane = updated.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32;
    }

    /// Get the raw membrane potential as i32.
    #[inline]
    pub fn output_i32(&self) -> i32 {
        self.membrane
    }

    /// Dequantize the membrane potential to f64.
    ///
    /// Returns `membrane as f64 * scale`.
    ///
    /// # Arguments
    ///
    /// * `scale` -- output scaling factor (typically `1.0 / Q14_ONE as f64`)
    #[inline]
    pub fn output_f64(&self, scale: f64) -> f64 {
        f64::from(self.membrane) * scale
    }

    /// Reset the membrane potential to zero.
    ///
    /// The decay factor `kappa` is left unchanged.
    pub fn reset(&mut self) {
        self.membrane = 0;
    }
}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snn::lif::{f64_to_q14, Q14_HALF, Q14_ONE};

    /// A freshly constructed readout reports an empty membrane through both
    /// the public field and the accessor.
    #[test]
    fn new_readout_has_zero_membrane() {
        let neuron = ReadoutNeuron::new(Q14_HALF);
        assert_eq!(neuron.membrane, 0);
        assert_eq!(neuron.output_i32(), 0);
    }

    /// With kappa = Q14_ONE nothing leaks away, so inputs simply add up.
    #[test]
    fn step_accumulates_input() {
        let mut neuron = ReadoutNeuron::new(Q14_ONE);

        neuron.step(1000);
        assert_eq!(neuron.output_i32(), 1000);

        neuron.step(500);
        assert_eq!(
            neuron.output_i32(),
            1500,
            "should accumulate with kappa=1.0"
        );
    }

    /// With kappa = Q14_HALF and no further input the membrane halves on
    /// every step: 1000 -> 500 -> 250.
    #[test]
    fn step_decays_membrane() {
        let mut neuron = ReadoutNeuron::new(Q14_HALF);

        neuron.step(1000);
        assert_eq!(neuron.output_i32(), 1000);

        neuron.step(0);
        assert_eq!(neuron.output_i32(), 500, "membrane should decay by kappa");

        neuron.step(0);
        assert_eq!(
            neuron.output_i32(),
            250,
            "membrane should continue decaying"
        );
    }

    /// A membrane holding Q14_ONE scaled by 1 / Q14_ONE reads back as ~1.0.
    #[test]
    fn output_f64_dequantizes_correctly() {
        let mut neuron = ReadoutNeuron::new(Q14_ONE);
        neuron.membrane = Q14_ONE as i32;

        let scale = 1.0 / Q14_ONE as f64;
        let out = neuron.output_f64(scale);
        let error = (out - 1.0).abs();

        assert!(
            error < 0.001,
            "output_f64 should dequantize 16384 to ~1.0, got {}",
            out
        );
    }

    /// `reset` discards whatever the membrane has accumulated.
    #[test]
    fn reset_clears_membrane() {
        let mut neuron = ReadoutNeuron::new(Q14_HALF);
        neuron.step(5000);
        assert!(neuron.output_i32() != 0);

        neuron.reset();
        assert_eq!(neuron.output_i32(), 0, "reset should zero the membrane");
    }

    /// Many large positive inputs must neither panic nor flip the sign of
    /// the accumulator.
    #[test]
    fn no_overflow_with_large_accumulation() {
        let mut neuron = ReadoutNeuron::new(f64_to_q14(0.99));

        (0..1000).for_each(|_| neuron.step(10000));

        let membrane = neuron.output_i32();
        assert!(
            membrane > 0,
            "membrane should be positive after positive inputs"
        );
    }

    /// Decay applies to negative membranes as well: -2000 -> -1000.
    #[test]
    fn decay_with_negative_membrane() {
        let mut neuron = ReadoutNeuron::new(Q14_HALF);

        neuron.step(-2000);
        assert_eq!(neuron.output_i32(), -2000);

        neuron.step(0);
        assert_eq!(
            neuron.output_i32(),
            -1000,
            "negative membrane should decay toward 0"
        );
    }

    /// kappa = 0 keeps nothing from the previous timestep.
    #[test]
    fn zero_kappa_forgets_immediately() {
        let mut neuron = ReadoutNeuron::new(0);

        neuron.step(5000);
        assert_eq!(neuron.output_i32(), 5000);

        neuron.step(0);
        assert_eq!(
            neuron.output_i32(),
            0,
            "zero kappa should forget membrane immediately"
        );
    }
}