irithyll_core/snn/readout.rs
1//! Non-spiking leaky integrator readout neuron.
2//!
3//! The readout neuron accumulates weighted spike inputs from the hidden layer
4//! without producing spikes itself. Its membrane potential serves as the
5//! continuous output prediction of the network.
6//!
7//! ```text
8//! y[t] = kappa_out * y[t-1] + sum(W_kj * z_j[t])
9//! ```
10//!
//! The membrane is stored as an i32 (the hidden layer uses i16) so the
//! accumulator has enough range that summing many weighted input spikes over
//! long sequences does not overflow.
13
/// Non-spiking leaky integrator readout neuron.
///
/// The readout accumulates weighted spike inputs with exponential decay.
/// Its membrane potential (an `i32` for extra headroom) is the network's
/// continuous output.
///
/// # Precision
///
/// While hidden layer neurons use `i16` membranes, the readout uses `i32` to
/// prevent overflow during long sequences. The decay factor `kappa` is still
/// Q1.14 (`i16`). The membrane uses the same Q1.14 scale (`16384 == 1.0`), so
/// it can represent magnitudes up to roughly `±131072.0`. That is about 17 bits
/// of headroom above 1.0. If the accumulator would leave the `i32` range, it
/// saturates at `i32::MIN` / `i32::MAX` instead of wrapping.
#[derive(Debug, Clone)]
pub struct ReadoutNeuron {
    /// Membrane potential (higher-range accumulator).
    ///
    /// Stored in Q1.14-compatible scale but with `i32` range.
    pub membrane: i32,
    /// Decay factor in Q1.14 (controls how fast past inputs are forgotten).
    pub kappa: i16,
}

impl ReadoutNeuron {
    /// Create a new readout neuron with the given decay factor and a zeroed
    /// membrane.
    ///
    /// # Arguments
    ///
    /// * `kappa` -- decay factor in Q1.14 (0 = no memory, Q14_ONE = perfect memory)
    #[must_use]
    pub fn new(kappa: i16) -> Self {
        Self { membrane: 0, kappa }
    }

    /// Advance the readout by one timestep.
    ///
    /// Decays the membrane potential and integrates new weighted input:
    ///
    /// ```text
    /// membrane = sat_i32(((membrane * kappa) >> 14) + weighted_input)
    /// ```
    ///
    /// The product is formed in `i64`, so it cannot overflow. Both the decayed
    /// value and the final sum are saturated to the `i32` range rather than
    /// wrapped. The decayed value can only leave that range in two cases:
    /// when `|kappa|` exceeds `Q14_ONE`, or when `kappa == -Q14_ONE` and
    /// `membrane == i32::MIN`.
    ///
    /// The `>> 14` is an arithmetic shift, so the decay rounds toward negative
    /// infinity. Small positive membranes decay to 0. Small negative ones do
    /// not: with `0 < kappa < Q14_ONE`, a membrane of `-1` stays at `-1`.
    ///
    /// # Arguments
    ///
    /// * `weighted_input` -- sum of `W_kj * z_j[t]` for active spikes, as i32
    #[inline]
    pub fn step(&mut self, weighted_input: i32) {
        // |membrane| <= 2^31 and |kappa| <= 2^15, so the product fits easily in i64.
        let decayed = (i64::from(self.membrane) * i64::from(self.kappa)) >> 14;
        // Clamp before narrowing: a bare `as i32` would silently wrap when
        // kappa > 1.0 (or kappa == -1.0 with membrane == i32::MIN).
        let decayed = decayed.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32;
        self.membrane = decayed.saturating_add(weighted_input);
    }

    /// Get the raw membrane potential as i32 (Q1.14 scale).
    #[inline]
    #[must_use]
    pub fn output_i32(&self) -> i32 {
        self.membrane
    }

    /// Dequantize the membrane potential to f64 by multiplying it by `scale`.
    ///
    /// # Arguments
    ///
    /// * `scale` -- output scaling factor (typically `1.0 / Q14_ONE as f64`)
    #[inline]
    #[must_use]
    pub fn output_f64(&self, scale: f64) -> f64 {
        f64::from(self.membrane) * scale
    }

    /// Reset the membrane potential to zero, leaving `kappa` unchanged.
    pub fn reset(&mut self) {
        self.membrane = 0;
    }
}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snn::lif::{f64_to_q14, Q14_HALF, Q14_ONE};

    #[test]
    fn new_readout_has_zero_membrane() {
        let r = ReadoutNeuron::new(Q14_HALF);
        assert_eq!(r.membrane, 0);
        assert_eq!(r.output_i32(), 0);
    }

    #[test]
    fn step_accumulates_input() {
        let mut r = ReadoutNeuron::new(Q14_ONE); // no decay
        r.step(1000);
        assert_eq!(r.output_i32(), 1000);
        r.step(500);
        assert_eq!(r.output_i32(), 1500, "should accumulate with kappa=1.0");
    }

    #[test]
    fn step_decays_membrane() {
        let mut r = ReadoutNeuron::new(Q14_HALF); // 0.5 decay
        r.step(1000);
        assert_eq!(r.output_i32(), 1000);

        // Next step with zero input: 1000 * 0.5 = 500
        r.step(0);
        assert_eq!(r.output_i32(), 500, "membrane should decay by kappa");

        // Another step: 500 * 0.5 = 250
        r.step(0);
        assert_eq!(r.output_i32(), 250, "membrane should continue decaying");
    }

    #[test]
    fn output_f64_dequantizes_correctly() {
        let mut r = ReadoutNeuron::new(Q14_ONE);
        r.membrane = Q14_ONE as i32; // set to 1.0 in Q1.14
        let scale = 1.0 / Q14_ONE as f64;
        let out = r.output_f64(scale);
        assert!(
            (out - 1.0).abs() < 0.001,
            "output_f64 should dequantize 16384 to ~1.0, got {}",
            out
        );
    }

    #[test]
    fn reset_clears_membrane() {
        let mut r = ReadoutNeuron::new(Q14_HALF);
        r.step(5000);
        assert!(r.output_i32() != 0);
        r.reset();
        assert_eq!(r.output_i32(), 0, "reset should zero the membrane");
    }

    #[test]
    fn no_overflow_with_large_accumulation() {
        let mut r = ReadoutNeuron::new(f64_to_q14(0.99));
        // Long-run stability check: the steady state is roughly
        // 10000 / (1 - 0.99) ~= 1e6, far below i32::MAX. The saturation tests
        // further down exercise the actual i32 limits.
        for _ in 0..1000 {
            r.step(10000);
        }
        // Should not panic or wrap around
        assert!(
            r.output_i32() > 0,
            "membrane should be positive after positive inputs"
        );
    }

    #[test]
    fn decay_with_negative_membrane() {
        let mut r = ReadoutNeuron::new(Q14_HALF);
        r.step(-2000);
        assert_eq!(r.output_i32(), -2000);

        r.step(0);
        assert_eq!(
            r.output_i32(),
            -1000,
            "negative membrane should decay toward 0"
        );
    }

    #[test]
    fn zero_kappa_forgets_immediately() {
        let mut r = ReadoutNeuron::new(0); // zero decay
        r.step(5000);
        assert_eq!(r.output_i32(), 5000);

        r.step(0);
        assert_eq!(
            r.output_i32(),
            0,
            "zero kappa should forget membrane immediately"
        );
    }

    #[test]
    fn saturates_at_i32_max_instead_of_wrapping() {
        let mut r = ReadoutNeuron::new(Q14_ONE);
        r.membrane = i32::MAX - 10;
        r.step(1000);
        assert_eq!(
            r.output_i32(),
            i32::MAX,
            "membrane should saturate at i32::MAX, not wrap"
        );
    }

    #[test]
    fn saturates_at_i32_min_instead_of_wrapping() {
        let mut r = ReadoutNeuron::new(Q14_ONE);
        r.membrane = i32::MIN + 10;
        r.step(-1000);
        assert_eq!(
            r.output_i32(),
            i32::MIN,
            "membrane should saturate at i32::MIN, not wrap"
        );
    }
}