// irithyll_core/snn/network_fixed.rs
//! Complete SNN combining LIF neurons, delta encoding, e-prop learning, and readout.
//!
//! [`SpikeNetFixed`] is the core spiking neural network implementation. It
//! manages all neuron state, weights, eligibility traces, and spike encoding
//! in a single struct. The "Fixed" name refers to fixed-point arithmetic
//! (Q1.14), not fixed array sizes -- the network uses `Vec` for runtime-sized
//! buffers but maintains a constant memory footprint after construction.
//!
//! # Architecture
//!
//! ```text
//! Raw input (i16)
//!   |
//!   v
//! Delta Encoder -> spike_buf (2*N_IN binary spikes)
//!   |
//!   v
//! Hidden Layer (N_HID LIF neurons, recurrently connected)
//!   |  - w_input: spike_buf -> hidden
//!   |  - w_recurrent: hidden -> hidden (from previous timestep)
//!   |
//!   v
//! Readout Layer (N_OUT leaky integrators)
//!   |  - w_output: hidden spikes -> readout
//!   |
//!   v
//! Output (i32 membrane potentials, dequantized to f64)
//! ```
//!
//! # Memory Layout
//!
//! All weight matrices are stored as flat `Vec<i16>` in row-major order:
//! - `w_input[j * n_enc + i]` = weight from encoded input `i` to hidden neuron `j`
//! - `w_recurrent[j * n_hid + i]` = weight from hidden `i` to hidden `j`
//! - `w_output[k * n_hid + j]` = weight from hidden `j` to readout `k`
//! - `feedback[j * n_out + k]` = fixed random feedback from output `k` to hidden `j`

use alloc::vec;
use alloc::vec::Vec;

use super::eprop::{
    compute_learning_signal_fixed, update_eligibility_fixed, update_output_weights_fixed,
    update_pre_trace_fixed, update_weights_fixed,
};
use super::lif::{lif_step, surrogate_gradient_pwl};
use super::readout::ReadoutNeuron;
use super::spike_encoding::DeltaEncoderFixed;
48
/// Configuration for a `SpikeNetFixed` network.
///
/// Every Q1.14 parameter is expected to lie in `[-2.0, +2.0)` once converted
/// from f64 via `f64_to_q14()`. Typical defaults:
///
/// | Parameter | f64 | Q1.14 |
/// |-----------|-----|-------|
/// | alpha | 0.95 | 15565 |
/// | kappa | 0.99 | 16220 |
/// | kappa_out | 0.90 | 14746 |
/// | eta | 0.001 | 16 |
/// | v_thr | 0.50 | 8192 |
/// | gamma | 0.30 | 4915 |
/// | spike_threshold | 0.05 | 819 |
#[derive(Debug, Clone)]
pub struct SpikeNetFixedConfig {
    /// Number of raw input features; the encoder doubles this into spike channels.
    pub n_input: usize,
    /// Number of hidden LIF neurons.
    pub n_hidden: usize,
    /// Number of output readout neurons.
    pub n_output: usize,
    /// Hidden-membrane decay factor (Q1.14).
    pub alpha: i16,
    /// Eligibility-trace decay factor (Q1.14).
    pub kappa: i16,
    /// Readout-membrane decay factor (Q1.14).
    pub kappa_out: i16,
    /// Learning rate (Q1.14).
    pub eta: i16,
    /// Hidden-neuron firing threshold (Q1.14).
    pub v_thr: i16,
    /// Surrogate-gradient dampening factor (Q1.14).
    pub gamma: i16,
    /// Delta-encoding threshold (Q1.14).
    pub spike_threshold: i16,
    /// PRNG seed for reproducible weight initialization.
    pub seed: u64,
    /// Initial weights are drawn uniformly from `[-range, +range]` (Q1.14).
    pub weight_init_range: i16,
}
90
91impl Default for SpikeNetFixedConfig {
92    fn default() -> Self {
93        Self {
94            n_input: 1,
95            n_hidden: 64,
96            n_output: 1,
97            alpha: 15565,         // 0.95
98            kappa: 16220,         // 0.99
99            kappa_out: 14746,     // 0.90
100            eta: 16,              // 0.001
101            v_thr: 8192,          // 0.50
102            gamma: 4915,          // 0.30
103            spike_threshold: 819, // 0.05
104            seed: 42,
105            weight_init_range: 1638, // 0.10
106        }
107    }
108}
109
/// Inline xorshift64 PRNG used for weight initialization.
///
/// Advances `state` in place and returns the new value. The all-zero state
/// is a fixed point of xorshift, so callers must seed with a nonzero value.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    *state
}
122
123/// Generate a random i16 in `[-range, +range]` from the PRNG.
124///
125/// `range` should be positive. If zero, returns 0.
126#[inline]
127fn xorshift64_i16(state: &mut u64, range: i16) -> i16 {
128    let raw = xorshift64(state);
129    let abs_range = if range < 0 { -range } else { range };
130    if abs_range == 0 {
131        return 0;
132    }
133    let abs_u64 = abs_range as u64;
134    let modulus = 2 * abs_u64 + 1;
135    ((raw % modulus) as i16) - abs_range
136}
137
138/// Complete spiking neural network with e-prop online learning.
139///
140/// Manages all neuron state, synaptic weights, eligibility traces, and spike
141/// encoding. After construction, memory is fixed -- no further allocations
142/// occur during `forward` or `train_step`.
143///
144/// # Thread Safety
145///
146/// `SpikeNetFixed` is `Send + Sync` because it contains only `Vec<T>` and
147/// primitive fields with no interior mutability.
148pub struct SpikeNetFixed {
149    config: SpikeNetFixedConfig,
150    n_input_encoded: usize, // 2 * n_input
151
152    // --- Neuron state ---
153    membrane: Vec<i16>,   // [n_hidden]
154    spikes: Vec<u8>,      // [n_hidden]
155    prev_spikes: Vec<u8>, // [n_hidden] previous timestep for recurrent
156
157    // --- Presynaptic traces ---
158    pre_trace_in: Vec<i16>,  // [n_input_encoded]
159    pre_trace_hid: Vec<i16>, // [n_hidden]
160
161    // --- Weights (row-major) ---
162    w_input: Vec<i16>,     // [n_hidden * n_input_encoded]
163    w_recurrent: Vec<i16>, // [n_hidden * n_hidden]
164    w_output: Vec<i16>,    // [n_output * n_hidden]
165    feedback: Vec<i16>,    // [n_hidden * n_output] (fixed random, never updated)
166
167    // --- Eligibility traces ---
168    elig_in: Vec<i16>,  // [n_hidden * n_input_encoded]
169    elig_rec: Vec<i16>, // [n_hidden * n_hidden]
170
171    // --- Readout ---
172    readout: Vec<ReadoutNeuron>, // [n_output]
173
174    // --- Encoder ---
175    encoder: DeltaEncoderFixed,
176
177    // --- Spike buffer ---
178    spike_buf: Vec<u8>, // [n_input_encoded]
179
180    // --- Error buffer (reusable) ---
181    error_buf: Vec<i16>, // [n_output]
182
183    // --- Counters ---
184    n_samples: u64,
185}
186
187// Safety: SpikeNetFixed contains only Vec, primitives, and other Send+Sync types
188unsafe impl Send for SpikeNetFixed {}
189unsafe impl Sync for SpikeNetFixed {}
190
191impl SpikeNetFixed {
192    /// Create a new SpikeNetFixed with the given configuration.
193    ///
194    /// Allocates all internal buffers and initializes weights from the PRNG.
195    /// No further allocations occur during operation.
196    pub fn new(config: SpikeNetFixedConfig) -> Self {
197        let n_in = config.n_input;
198        let n_hid = config.n_hidden;
199        let n_out = config.n_output;
200        let n_enc = 2 * n_in;
201
202        let mut rng_state = if config.seed == 0 { 1 } else { config.seed };
203        let range = config.weight_init_range;
204
205        // Initialize input weights
206        let w_input: Vec<i16> = (0..n_hid * n_enc)
207            .map(|_| xorshift64_i16(&mut rng_state, range))
208            .collect();
209
210        // Initialize recurrent weights
211        let w_recurrent: Vec<i16> = (0..n_hid * n_hid)
212            .map(|_| xorshift64_i16(&mut rng_state, range))
213            .collect();
214
215        // Initialize output weights
216        let w_output: Vec<i16> = (0..n_out * n_hid)
217            .map(|_| xorshift64_i16(&mut rng_state, range))
218            .collect();
219
220        // Initialize fixed random feedback weights
221        let feedback: Vec<i16> = (0..n_hid * n_out)
222            .map(|_| xorshift64_i16(&mut rng_state, range))
223            .collect();
224
225        let readout: Vec<ReadoutNeuron> = (0..n_out)
226            .map(|_| ReadoutNeuron::new(config.kappa_out))
227            .collect();
228
229        let encoder = DeltaEncoderFixed::new(n_in, config.spike_threshold);
230
231        Self {
232            n_input_encoded: n_enc,
233            membrane: vec![0; n_hid],
234            spikes: vec![0; n_hid],
235            prev_spikes: vec![0; n_hid],
236            pre_trace_in: vec![0; n_enc],
237            pre_trace_hid: vec![0; n_hid],
238            w_input,
239            w_recurrent,
240            w_output,
241            feedback,
242            elig_in: vec![0; n_hid * n_enc],
243            elig_rec: vec![0; n_hid * n_hid],
244            readout,
245            encoder,
246            spike_buf: vec![0; n_enc],
247            error_buf: vec![0; n_out],
248            n_samples: 0,
249            config,
250        }
251    }
252
253    /// Run one forward timestep without learning.
254    ///
255    /// Encodes input into spikes, updates hidden neuron states, and advances
256    /// the readout. Does NOT update weights or eligibility traces.
257    ///
258    /// # Arguments
259    ///
260    /// * `input_i16` -- raw input features in Q1.14, length must equal `config.n_input`
261    pub fn forward(&mut self, input_i16: &[i16]) {
262        let n_hid = self.config.n_hidden;
263        let n_enc = self.n_input_encoded;
264
265        // 1. Delta-encode input into spike buffer
266        self.encoder.encode(input_i16, &mut self.spike_buf);
267
268        // 2. Store previous spikes for recurrent computation
269        self.prev_spikes.copy_from_slice(&self.spikes);
270
271        // 3. Update hidden layer LIF neurons
272        for j in 0..n_hid {
273            // Compute input current from encoded spikes
274            let mut current: i32 = 0;
275            let w_in_offset = j * n_enc;
276            for i in 0..n_enc {
277                if self.spike_buf[i] != 0 {
278                    current += self.w_input[w_in_offset + i] as i32;
279                }
280            }
281
282            // Add recurrent current from previous hidden spikes
283            let w_rec_offset = j * n_hid;
284            for i in 0..n_hid {
285                if self.prev_spikes[i] != 0 {
286                    current += self.w_recurrent[w_rec_offset + i] as i32;
287                }
288            }
289
290            // LIF step
291            let (v_new, spike) = lif_step(
292                self.membrane[j],
293                self.config.alpha,
294                current,
295                self.config.v_thr,
296            );
297            self.membrane[j] = v_new;
298            self.spikes[j] = spike as u8;
299        }
300
301        // 4. Update readout neurons
302        let n_out = self.config.n_output;
303        for k in 0..n_out {
304            let w_out_offset = k * n_hid;
305            let mut weighted_input: i32 = 0;
306            for j in 0..n_hid {
307                if self.spikes[j] != 0 {
308                    weighted_input += self.w_output[w_out_offset + j] as i32;
309                }
310            }
311            self.readout[k].step(weighted_input);
312        }
313    }
314
315    /// Run one forward + learning timestep (e-prop three-factor rule).
316    ///
317    /// Performs a forward pass, then computes error signals and updates
318    /// all weights using the e-prop learning rule.
319    ///
320    /// # Arguments
321    ///
322    /// * `input_i16` -- raw input features in Q1.14
323    /// * `target_i16` -- target values in Q1.14, length must equal `config.n_output`
324    pub fn train_step(&mut self, input_i16: &[i16], target_i16: &[i16]) {
325        let n_hid = self.config.n_hidden;
326        let n_enc = self.n_input_encoded;
327        let n_out = self.config.n_output;
328
329        // 1. Forward pass
330        self.forward(input_i16);
331
332        // 2. Compute error signals: error = target - readout
333        for (k, &target_k) in target_i16.iter().enumerate().take(n_out) {
334            let readout_clamped = self.readout[k]
335                .output_i32()
336                .clamp(i16::MIN as i32, i16::MAX as i32) as i16;
337            self.error_buf[k] = target_k.saturating_sub(readout_clamped);
338        }
339
340        // 3. Update presynaptic traces
341        update_pre_trace_fixed(&mut self.pre_trace_in, &self.spike_buf, self.config.alpha);
342        update_pre_trace_fixed(&mut self.pre_trace_hid, &self.spikes, self.config.alpha);
343
344        // 4. For each hidden neuron: update eligibility, compute learning signal, update weights
345        for j in 0..n_hid {
346            // Surrogate gradient
347            let psi =
348                surrogate_gradient_pwl(self.membrane[j], self.config.v_thr, self.config.gamma);
349
350            // Update input eligibility traces for neuron j
351            let elig_in_start = j * n_enc;
352            let elig_in_end = elig_in_start + n_enc;
353            update_eligibility_fixed(
354                &mut self.elig_in[elig_in_start..elig_in_end],
355                psi,
356                &self.pre_trace_in,
357                self.config.kappa,
358            );
359
360            // Update recurrent eligibility traces for neuron j
361            let elig_rec_start = j * n_hid;
362            let elig_rec_end = elig_rec_start + n_hid;
363            update_eligibility_fixed(
364                &mut self.elig_rec[elig_rec_start..elig_rec_end],
365                psi,
366                &self.pre_trace_hid,
367                self.config.kappa,
368            );
369
370            // Compute learning signal via feedback alignment
371            let fb_start = j * n_out;
372            let fb_end = fb_start + n_out;
373            let learning_signal = compute_learning_signal_fixed(
374                &self.feedback[fb_start..fb_end],
375                &self.error_buf[..n_out],
376            );
377
378            // Update input weights for neuron j
379            let w_in_start = j * n_enc;
380            let w_in_end = w_in_start + n_enc;
381            update_weights_fixed(
382                &mut self.w_input[w_in_start..w_in_end],
383                &self.elig_in[elig_in_start..elig_in_end],
384                learning_signal,
385                self.config.eta,
386            );
387
388            // Update recurrent weights for neuron j
389            let w_rec_start = j * n_hid;
390            let w_rec_end = w_rec_start + n_hid;
391            update_weights_fixed(
392                &mut self.w_recurrent[w_rec_start..w_rec_end],
393                &self.elig_rec[elig_rec_start..elig_rec_end],
394                learning_signal,
395                self.config.eta,
396            );
397        }
398
399        // 5. Update output weights via delta rule
400        for k in 0..n_out {
401            let w_out_start = k * n_hid;
402            let w_out_end = w_out_start + n_hid;
403            update_output_weights_fixed(
404                &mut self.w_output[w_out_start..w_out_end],
405                self.error_buf[k],
406                &self.spikes,
407                self.config.eta,
408            );
409        }
410
411        self.n_samples += 1;
412    }
413
414    /// Get raw readout membrane potentials as i32.
415    ///
416    /// Returns a reference to an internal buffer that is updated on each
417    /// `forward` or `train_step` call.
418    pub fn predict_raw(&self) -> Vec<i32> {
419        self.readout.iter().map(|r| r.output_i32()).collect()
420    }
421
422    /// Get the first readout's membrane potential, dequantized to f64.
423    ///
424    /// # Arguments
425    ///
426    /// * `output_scale` -- scaling factor (typically `1.0 / Q14_ONE as f64`)
427    pub fn predict_f64(&self, output_scale: f64) -> f64 {
428        if self.readout.is_empty() {
429            return 0.0;
430        }
431        self.readout[0].output_f64(output_scale)
432    }
433
434    /// Get all readout membrane potentials, dequantized to f64.
435    ///
436    /// # Arguments
437    ///
438    /// * `output_scale` -- scaling factor per output
439    pub fn predict_all_f64(&self, output_scale: f64) -> Vec<f64> {
440        self.readout
441            .iter()
442            .map(|r| r.output_f64(output_scale))
443            .collect()
444    }
445
446    /// Number of training samples seen.
447    pub fn n_samples_seen(&self) -> u64 {
448        self.n_samples
449    }
450
451    /// Reference to the network configuration.
452    pub fn config(&self) -> &SpikeNetFixedConfig {
453        &self.config
454    }
455
456    /// Number of hidden neurons.
457    pub fn n_hidden(&self) -> usize {
458        self.config.n_hidden
459    }
460
461    /// Number of encoded input channels (2 * n_input).
462    pub fn n_input_encoded(&self) -> usize {
463        self.n_input_encoded
464    }
465
466    /// Get current hidden spike vector.
467    pub fn hidden_spikes(&self) -> &[u8] {
468        &self.spikes
469    }
470
471    /// Get current hidden membrane potentials.
472    pub fn hidden_membrane(&self) -> &[i16] {
473        &self.membrane
474    }
475
476    /// Compute total memory usage in bytes.
477    ///
478    /// Counts all Vec contents plus struct overhead.
479    pub fn memory_bytes(&self) -> usize {
480        let n_hid = self.config.n_hidden;
481        let n_enc = self.n_input_encoded;
482        let n_out = self.config.n_output;
483        let n_in = self.config.n_input;
484
485        let size_of_i16 = core::mem::size_of::<i16>();
486        let size_of_u8 = core::mem::size_of::<u8>();
487
488        // Neuron state
489        let membrane = n_hid * size_of_i16;
490        let spikes = n_hid * size_of_u8;
491        let prev_spikes = n_hid * size_of_u8;
492
493        // Presynaptic traces
494        let pre_trace_in = n_enc * size_of_i16;
495        let pre_trace_hid = n_hid * size_of_i16;
496
497        // Weights
498        let w_input = n_hid * n_enc * size_of_i16;
499        let w_recurrent = n_hid * n_hid * size_of_i16;
500        let w_output = n_out * n_hid * size_of_i16;
501        let feedback = n_hid * n_out * size_of_i16;
502
503        // Eligibility traces
504        let elig_in = n_hid * n_enc * size_of_i16;
505        let elig_rec = n_hid * n_hid * size_of_i16;
506
507        // Readout (membrane i32 + kappa i16 + padding)
508        let readout_size = n_out * core::mem::size_of::<ReadoutNeuron>();
509
510        // Encoder state
511        let encoder_prev = n_in * size_of_i16;
512        let encoder_thr = n_in * size_of_i16;
513
514        // Spike buffer
515        let spike_buf = n_enc * size_of_u8;
516
517        // Error buffer
518        let error_buf = n_out * size_of_i16;
519
520        // Struct overhead (approximate)
521        let struct_overhead = core::mem::size_of::<Self>();
522
523        // Total Vec contents
524        let vec_contents = membrane
525            + spikes
526            + prev_spikes
527            + pre_trace_in
528            + pre_trace_hid
529            + w_input
530            + w_recurrent
531            + w_output
532            + feedback
533            + elig_in
534            + elig_rec
535            + readout_size
536            + encoder_prev
537            + encoder_thr
538            + spike_buf
539            + error_buf;
540
541        struct_overhead + vec_contents
542    }
543
544    /// Reset all network state (neuron potentials, traces, readout) to zero.
545    ///
546    /// Weights are re-initialized from the original seed. The network behaves
547    /// as if freshly constructed after calling reset.
548    pub fn reset(&mut self) {
549        // Zero neuron state
550        for v in self.membrane.iter_mut() {
551            *v = 0;
552        }
553        for s in self.spikes.iter_mut() {
554            *s = 0;
555        }
556        for s in self.prev_spikes.iter_mut() {
557            *s = 0;
558        }
559
560        // Zero traces
561        for t in self.pre_trace_in.iter_mut() {
562            *t = 0;
563        }
564        for t in self.pre_trace_hid.iter_mut() {
565            *t = 0;
566        }
567        for e in self.elig_in.iter_mut() {
568            *e = 0;
569        }
570        for e in self.elig_rec.iter_mut() {
571            *e = 0;
572        }
573
574        // Reset readout
575        for r in self.readout.iter_mut() {
576            r.reset();
577        }
578
579        // Reset encoder
580        self.encoder.reset();
581
582        // Zero spike buffer
583        for s in self.spike_buf.iter_mut() {
584            *s = 0;
585        }
586
587        // Zero error buffer
588        for e in self.error_buf.iter_mut() {
589            *e = 0;
590        }
591
592        // Re-initialize weights from seed
593        let mut rng_state = if self.config.seed == 0 {
594            1
595        } else {
596            self.config.seed
597        };
598        let range = self.config.weight_init_range;
599
600        for w in self.w_input.iter_mut() {
601            *w = xorshift64_i16(&mut rng_state, range);
602        }
603        for w in self.w_recurrent.iter_mut() {
604            *w = xorshift64_i16(&mut rng_state, range);
605        }
606        for w in self.w_output.iter_mut() {
607            *w = xorshift64_i16(&mut rng_state, range);
608        }
609        for w in self.feedback.iter_mut() {
610            *w = xorshift64_i16(&mut rng_state, range);
611        }
612
613        self.n_samples = 0;
614    }
615}
616
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snn::lif::{f64_to_q14, Q14_ONE};

    /// Small 2-in / 8-hid / 1-out network shared by most tests.
    fn default_small_config() -> SpikeNetFixedConfig {
        SpikeNetFixedConfig {
            n_input: 2,
            n_hidden: 8,
            n_output: 1,
            alpha: f64_to_q14(0.95),
            kappa: f64_to_q14(0.99),
            kappa_out: f64_to_q14(0.9),
            eta: f64_to_q14(0.01),
            v_thr: f64_to_q14(0.5),
            gamma: f64_to_q14(0.3),
            spike_threshold: f64_to_q14(0.05),
            seed: 42,
            weight_init_range: f64_to_q14(0.1),
        }
    }

    #[test]
    fn construction_initializes_all_buffers() {
        let net = SpikeNetFixed::new(default_small_config());

        assert_eq!(net.membrane.len(), 8);
        assert_eq!(net.spikes.len(), 8);
        assert_eq!(net.n_input_encoded(), 4);
        assert_eq!(net.w_input.len(), 8 * 4);
        assert_eq!(net.w_recurrent.len(), 8 * 8);
        assert_eq!(net.w_output.len(), 8);
        assert_eq!(net.feedback.len(), 8);
        assert_eq!(net.elig_in.len(), 8 * 4);
        assert_eq!(net.elig_rec.len(), 8 * 8);
        assert_eq!(net.readout.len(), 1);
        assert_eq!(net.n_samples_seen(), 0);
    }

    #[test]
    fn forward_does_not_crash() {
        let mut net = SpikeNetFixed::new(default_small_config());

        // First call warms up the encoder; second can emit real spikes.
        net.forward(&[f64_to_q14(0.5), f64_to_q14(-0.3)]);
        net.forward(&[f64_to_q14(0.8), f64_to_q14(0.2)]);

        let raw = net.predict_raw();
        assert_eq!(raw.len(), 1, "should have one readout output");
    }

    #[test]
    fn train_step_increments_counter() {
        let mut net = SpikeNetFixed::new(default_small_config());

        let input = [f64_to_q14(0.5), f64_to_q14(-0.3)];
        let target = [f64_to_q14(0.7)];

        net.train_step(&input, &target);
        assert_eq!(net.n_samples_seen(), 1);

        net.train_step(&input, &target);
        assert_eq!(net.n_samples_seen(), 2);
    }

    #[test]
    fn predictions_change_after_training() {
        // Aggressive settings so learning is visible quickly.
        let config = SpikeNetFixedConfig {
            n_input: 2,
            n_hidden: 16,
            n_output: 1,
            alpha: f64_to_q14(0.9),
            kappa: f64_to_q14(0.95),
            kappa_out: f64_to_q14(0.85),
            eta: f64_to_q14(0.05),             // larger learning rate
            v_thr: f64_to_q14(0.3),            // lower threshold -> more spikes
            gamma: f64_to_q14(0.5),
            spike_threshold: f64_to_q14(0.01), // very sensitive encoding
            seed: 12345,
            weight_init_range: f64_to_q14(0.2),
        };

        let mut net = SpikeNetFixed::new(config);
        let scale = 1.0 / Q14_ONE as f64;

        // Warm up the encoder.
        net.forward(&[0, 0]);
        let pred_before = net.predict_f64(scale);

        // Alternate between two input/target pairs.
        for step in 0..200 {
            let even = step % 2 == 0;
            let x = if even {
                [f64_to_q14(0.8), f64_to_q14(-0.5)]
            } else {
                [f64_to_q14(-0.3), f64_to_q14(0.6)]
            };
            let target = if even {
                [f64_to_q14(1.0)]
            } else {
                [f64_to_q14(-1.0)]
            };
            net.train_step(&x, &target);
        }

        let pred_after = net.predict_f64(scale);

        assert!(
            (pred_after - pred_before).abs() > 1e-10,
            "prediction should change after training: before={}, after={}",
            pred_before,
            pred_after
        );
    }

    #[test]
    fn reset_restores_initial_state() {
        let config = default_small_config();
        let mut net = SpikeNetFixed::new(config.clone());
        let fresh = SpikeNetFixed::new(config);

        // Perturb the network with a couple of training steps.
        net.train_step(&[1000, -500], &[2000]);
        net.train_step(&[-1000, 500], &[-2000]);
        assert!(net.n_samples_seen() > 0);

        net.reset();

        // A reset network must match a freshly constructed one.
        assert_eq!(net.n_samples_seen(), 0);
        assert_eq!(net.membrane, fresh.membrane);
        assert_eq!(net.spikes, fresh.spikes);
        assert_eq!(
            net.w_input, fresh.w_input,
            "weights should be re-initialized from seed"
        );
        assert_eq!(net.w_recurrent, fresh.w_recurrent);
        assert_eq!(net.w_output, fresh.w_output);
        assert_eq!(net.feedback, fresh.feedback);
    }

    #[test]
    fn memory_bytes_is_reasonable() {
        let config = SpikeNetFixedConfig {
            n_input: 10,
            n_hidden: 64,
            n_output: 1,
            ..SpikeNetFixedConfig::default()
        };
        let net = SpikeNetFixed::new(config);
        let mem = net.memory_bytes();

        // Dominant terms: w_recurrent and elig_rec at 64*64*2 = 8192 bytes
        // each, plus w_input and elig_in at 64*20*2 = 2560 bytes each;
        // total Vec contents ~22KB plus struct overhead.
        assert!(
            mem > 20_000,
            "memory should be at least 20KB for 10-in/64-hid/1-out, got {}",
            mem
        );
        assert!(
            mem < 100_000,
            "memory should be under 100KB for small network, got {}",
            mem
        );
    }

    #[test]
    fn deterministic_with_same_seed() {
        let config = default_small_config();
        let mut net1 = SpikeNetFixed::new(config.clone());
        let mut net2 = SpikeNetFixed::new(config);

        let input = [f64_to_q14(0.3), f64_to_q14(-0.7)];
        let target = [f64_to_q14(0.5)];

        for _ in 0..10 {
            net1.train_step(&input, &target);
            net2.train_step(&input, &target);
        }

        let scale = 1.0 / Q14_ONE as f64;
        let p1 = net1.predict_f64(scale);
        let p2 = net2.predict_f64(scale);
        assert_eq!(p1, p2, "same seed should produce identical predictions");
    }

    #[test]
    fn multi_output_network() {
        let config = SpikeNetFixedConfig {
            n_input: 3,
            n_hidden: 8,
            n_output: 3,
            ..SpikeNetFixedConfig::default()
        };
        let mut net = SpikeNetFixed::new(config);

        net.forward(&[1000, -500, 200]);
        net.forward(&[1500, 0, -300]);

        let raw = net.predict_raw();
        assert_eq!(raw.len(), 3, "should have 3 readout outputs");

        let scale = 1.0 / Q14_ONE as f64;
        let all = net.predict_all_f64(scale);
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn train_step_with_multi_output() {
        let config = SpikeNetFixedConfig {
            n_input: 2,
            n_hidden: 8,
            n_output: 2,
            ..SpikeNetFixedConfig::default()
        };
        let mut net = SpikeNetFixed::new(config);

        // Must not panic with more than one output.
        net.train_step(&[1000, -500], &[2000, -1000]);
        assert_eq!(net.n_samples_seen(), 1);
    }

    #[test]
    fn hidden_spikes_accessible() {
        let mut net = SpikeNetFixed::new(default_small_config());

        net.forward(&[0, 0]);
        net.forward(&[Q14_ONE, -Q14_ONE]); // large jump to provoke spikes

        let spikes = net.hidden_spikes();
        assert_eq!(spikes.len(), 8);
        for &s in spikes {
            assert!(s == 0 || s == 1, "spike should be 0 or 1, got {}", s);
        }
    }

    #[test]
    fn config_default_is_sensible() {
        let config = SpikeNetFixedConfig::default();
        assert!(config.alpha > 0, "alpha should be positive");
        assert!(config.v_thr > 0, "v_thr should be positive");
        assert!(config.eta > 0, "eta should be positive");
        assert!(config.n_hidden > 0, "n_hidden should be positive");
    }
}