// irithyll_core/snn/network_fixed.rs
1//! Complete SNN combining LIF neurons, delta encoding, e-prop learning, and readout.
2//!
3//! [`SpikeNetFixed`] is the core spiking neural network implementation. It
4//! manages all neuron state, weights, eligibility traces, and spike encoding
5//! in a single struct. The "Fixed" name refers to fixed-point arithmetic
6//! (Q1.14), not fixed array sizes -- the network uses `Vec` for runtime-sized
7//! buffers but maintains a constant memory footprint after construction.
8//!
9//! # Architecture
10//!
11//! ```text
12//! Raw input (i16)
13//!   |
14//!   v
15//! Delta Encoder -> spike_buf (2*N_IN binary spikes)
16//!   |
17//!   v
18//! Hidden Layer (N_HID LIF neurons, recurrently connected)
19//!   |  - w_input: spike_buf -> hidden
20//!   |  - w_recurrent: hidden -> hidden (from previous timestep)
21//!   |
22//!   v
23//! Readout Layer (N_OUT leaky integrators)
24//!   |  - w_output: hidden spikes -> readout
25//!   |
26//!   v
27//! Output (i32 membrane potentials, dequantized to f64)
28//! ```
29//!
30//! # Memory Layout
31//!
32//! All weight matrices are stored as flat `Vec<i16>` in row-major order:
33//! - `w_input[j * n_enc + i]` = weight from encoded input `i` to hidden neuron `j`
34//! - `w_recurrent[j * n_hid + i]` = weight from hidden `i` to hidden `j`
35//! - `w_output[k * n_hid + j]` = weight from hidden `j` to readout `k`
36//! - `feedback[j * n_out + k]` = fixed random feedback from output `k` to hidden `j`
37
38use alloc::vec;
39use alloc::vec::Vec;
40
41use super::astrocyte::{AstrocyteGate, AstrocyteMode};
42use super::eprop::{
43    compute_learning_signal_fixed, update_eligibility_fixed, update_output_weights_fixed,
44    update_pre_trace_fixed, update_weights_fixed,
45};
46use super::lif::{lif_step, surrogate_gradient_pwl};
47use super::readout::ReadoutNeuron;
48use super::spike_encoding::DeltaEncoderFixed;
49
/// Compute precision for SNN weight and activation storage.
///
/// Selects floating-point vs fixed-point operation. The default is decided by
/// compile-time facts, not taste:
///
/// - `Float` (f32) is the default when `std` is available (desktop / server).
///   IEEE-754 arithmetic is fast wherever a hardware FPU exists, and f32 is
///   sufficient for online-learning workloads with continuously updated
///   weights.
/// - `Fixed` (Q1.14 `i16`) is the default on `no_std` (cortex_m / bare-metal)
///   builds. Fixed-point is mandatory on Cortex-M0+ and RISC-V parts without
///   an FPU, and halves SRAM use versus f32 (2 bytes per weight instead of 4).
///
/// A desktop build never needs the 2x SRAM saving; a bare-metal build never
/// has a hardware FPU to benefit from — so the selection criterion (memory
/// budget plus FPU availability) is unambiguous at compile time.
///
/// # Note on Current State
///
/// `SpikeNetFixed` computes in Q1.14 unconditionally today; this enum records
/// intended precision and is carried as a config field for future kernel
/// selection. Choosing [`Precision::Float`] on a `std` build currently behaves
/// identically to [`Precision::Fixed`]. Once an f32 kernel exists, `Float`
/// will activate it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Precision {
    /// 32-bit IEEE-754 floating-point. Default on `std` targets with an FPU.
    Float,
    /// Q1.14 fixed-point (`i16`). Default on `no_std` / microcontroller
    /// targets: no FPU required, minimum SRAM footprint.
    Fixed,
}
88
89// The Default impl uses conditional compilation to select Float (std/FPU targets) vs
90// Fixed (no_std/bare-metal). `#[derive(Default)]` cannot express this conditional —
91// it always picks a single default variant regardless of feature flags.
92#[allow(clippy::derivable_impls)]
93impl Default for Precision {
94    fn default() -> Self {
95        // std = hardware FPU available: Float is the principled choice.
96        // no_std = cortex_m / bare-metal: Fixed is mandatory.
97        #[cfg(feature = "std")]
98        {
99            Precision::Float
100        }
101        #[cfg(not(feature = "std"))]
102        {
103            Precision::Fixed
104        }
105    }
106}
107
/// Configuration for a `SpikeNetFixed` network.
///
/// All Q1.14 parameters should be in the range `[-2.0, +2.0)` when converted
/// from f64 via `f64_to_q14()`. Typical defaults:
///
/// | Parameter | f64 | Q1.14 |
/// |-----------|-----|-------|
/// | alpha | 0.95 | 15565 |
/// | kappa | 0.99 | 16220 |
/// | kappa_out | 0.90 | 14746 |
/// | eta | 0.001 | 16 |
/// | v_thr | 0.50 | 8192 |
/// | gamma | 0.30 | 4915 |
/// | spike_threshold | 0.05 | 819 |
#[derive(Debug, Clone)]
pub struct SpikeNetFixedConfig {
    /// Number of raw input features (encoded to 2x for on/off spike channels).
    pub n_input: usize,
    /// Number of hidden LIF neurons.
    pub n_hidden: usize,
    /// Number of output readout neurons.
    pub n_output: usize,
    /// Membrane decay factor in Q1.14.
    pub alpha: i16,
    /// Eligibility trace decay factor in Q1.14.
    pub kappa: i16,
    /// Readout membrane decay factor in Q1.14.
    pub kappa_out: i16,
    /// Learning rate in Q1.14.
    pub eta: i16,
    /// Firing threshold in Q1.14.
    pub v_thr: i16,
    /// Surrogate gradient dampening factor in Q1.14.
    pub gamma: i16,
    /// Delta encoding threshold in Q1.14.
    pub spike_threshold: i16,
    /// PRNG seed for reproducible weight initialization. A seed of 0 is
    /// remapped to 1 at construction (xorshift64 has a fixed point at zero).
    pub seed: u64,
    /// Weight initialization range: weights sampled from `[-range, +range]` in Q1.14.
    pub weight_init_range: i16,
    /// Enable astrocyte-gated modulation of input weights.
    pub use_astrocyte: bool,
    /// Astrocyte EWMA time constant (higher = slower adaptation). Default: 1000.
    pub astrocyte_tau: f64,
    /// Astrocyte gating mode (only used when `use_astrocyte` is true).
    ///
    /// `WeightMod` (default): scales forward-pass input weights.
    /// `LearningRateGate`: scales the per-neuron learning rate during updates
    /// (Dong & He 2025, Frontiers Neurosci, Eq. 4 -- AGMP proper).
    pub astrocyte_mode: AstrocyteMode,
}
159
impl Default for SpikeNetFixedConfig {
    /// Defaults sized for a small single-input / single-output network with
    /// 64 hidden neurons and astrocyte gating disabled. Q1.14 constants carry
    /// their f64 equivalents as comments.
    fn default() -> Self {
        Self {
            n_input: 1,
            n_hidden: 64,
            n_output: 1,
            alpha: 15565,         // 0.95
            kappa: 16220,         // 0.99
            kappa_out: 14746,     // 0.90
            eta: 16,              // 0.001
            v_thr: 8192,          // 0.50
            gamma: 4915,          // 0.30
            spike_threshold: 819, // 0.05
            seed: 42,
            weight_init_range: 1638, // 0.10
            use_astrocyte: false,
            astrocyte_tau: 1000.0,
            astrocyte_mode: AstrocyteMode::WeightMod,
        }
    }
}
181
182use crate::rng::xorshift64;
183
184/// Generate a random i16 in `[-range, +range]` from the PRNG.
185///
186/// `range` should be positive. If zero, returns 0.
187#[inline]
188fn xorshift64_i16(state: &mut u64, range: i16) -> i16 {
189    let raw = xorshift64(state);
190    let abs_range = if range < 0 { -range } else { range };
191    if abs_range == 0 {
192        return 0;
193    }
194    let abs_u64 = abs_range as u64;
195    let modulus = 2 * abs_u64 + 1;
196    ((raw % modulus) as i16) - abs_range
197}
198
/// Complete spiking neural network with e-prop online learning.
///
/// Manages all neuron state, synaptic weights, eligibility traces, and spike
/// encoding. After construction, memory is fixed -- no further allocations
/// occur during `forward` or `train_step` (`predict_raw`/`predict_all_f64`
/// do allocate, for their return `Vec`s).
///
/// # Thread Safety
///
/// `SpikeNetFixed` has no interior mutability, so `Send`/`Sync` hold
/// automatically whenever every field type (including `AstrocyteGate` and
/// `DeltaEncoderFixed`) is itself `Send + Sync`.
pub struct SpikeNetFixed {
    config: SpikeNetFixedConfig,
    n_input_encoded: usize, // 2 * n_input (on/off channel pair per raw input)

    // --- Neuron state ---
    membrane: Vec<i16>,   // [n_hidden] Q1.14 membrane potentials
    spikes: Vec<u8>,      // [n_hidden] 0/1 spikes from the current timestep
    prev_spikes: Vec<u8>, // [n_hidden] previous timestep for recurrent input

    // --- Presynaptic traces (low-pass filtered spike history) ---
    pre_trace_in: Vec<i16>,  // [n_input_encoded]
    pre_trace_hid: Vec<i16>, // [n_hidden]

    // --- Weights (row-major; see module docs for index formulas) ---
    w_input: Vec<i16>,     // [n_hidden * n_input_encoded]
    w_recurrent: Vec<i16>, // [n_hidden * n_hidden]
    w_output: Vec<i16>,    // [n_output * n_hidden]
    feedback: Vec<i16>,    // [n_hidden * n_output] (fixed random, never updated)

    // --- Eligibility traces (same layout as the matching weight matrix) ---
    elig_in: Vec<i16>,  // [n_hidden * n_input_encoded]
    elig_rec: Vec<i16>, // [n_hidden * n_hidden]

    // --- Readout ---
    readout: Vec<ReadoutNeuron>, // [n_output] leaky integrators

    // --- Encoder ---
    encoder: DeltaEncoderFixed,

    // --- Spike buffer (reused every timestep; avoids per-call allocation) ---
    spike_buf: Vec<u8>, // [n_input_encoded]

    // --- Error buffer (reusable) ---
    error_buf: Vec<i16>, // [n_output]

    // --- Astrocyte gating (None unless config.use_astrocyte) ---
    astrocyte: Option<AstrocyteGate>,

    // --- Counters ---
    n_samples: u64,
}
250
251// Safety: SpikeNetFixed contains only Vec, primitives, and other Send+Sync types
252unsafe impl Send for SpikeNetFixed {}
253unsafe impl Sync for SpikeNetFixed {}
254
255impl SpikeNetFixed {
256    /// Create a new SpikeNetFixed with the given configuration.
257    ///
258    /// Allocates all internal buffers and initializes weights from the PRNG.
259    /// No further allocations occur during operation.
260    pub fn new(config: SpikeNetFixedConfig) -> Self {
261        let n_in = config.n_input;
262        let n_hid = config.n_hidden;
263        let n_out = config.n_output;
264        let n_enc = 2 * n_in;
265
266        let mut rng_state = if config.seed == 0 { 1 } else { config.seed };
267        let range = config.weight_init_range;
268
269        // Initialize input weights
270        let w_input: Vec<i16> = (0..n_hid * n_enc)
271            .map(|_| xorshift64_i16(&mut rng_state, range))
272            .collect();
273
274        // Initialize recurrent weights
275        let w_recurrent: Vec<i16> = (0..n_hid * n_hid)
276            .map(|_| xorshift64_i16(&mut rng_state, range))
277            .collect();
278
279        // Initialize output weights
280        let w_output: Vec<i16> = (0..n_out * n_hid)
281            .map(|_| xorshift64_i16(&mut rng_state, range))
282            .collect();
283
284        // Initialize fixed random feedback weights
285        let feedback: Vec<i16> = (0..n_hid * n_out)
286            .map(|_| xorshift64_i16(&mut rng_state, range))
287            .collect();
288
289        let readout: Vec<ReadoutNeuron> = (0..n_out)
290            .map(|_| ReadoutNeuron::new(config.kappa_out))
291            .collect();
292
293        let encoder = DeltaEncoderFixed::new(n_in, config.spike_threshold);
294
295        let astrocyte = if config.use_astrocyte {
296            Some(AstrocyteGate::with_mode(
297                n_hid,
298                config.astrocyte_tau,
299                config.astrocyte_mode,
300            ))
301        } else {
302            None
303        };
304
305        Self {
306            n_input_encoded: n_enc,
307            membrane: vec![0; n_hid],
308            spikes: vec![0; n_hid],
309            prev_spikes: vec![0; n_hid],
310            pre_trace_in: vec![0; n_enc],
311            pre_trace_hid: vec![0; n_hid],
312            w_input,
313            w_recurrent,
314            w_output,
315            feedback,
316            elig_in: vec![0; n_hid * n_enc],
317            elig_rec: vec![0; n_hid * n_hid],
318            readout,
319            encoder,
320            spike_buf: vec![0; n_enc],
321            error_buf: vec![0; n_out],
322            astrocyte,
323            n_samples: 0,
324            config,
325        }
326    }
327
    /// Run one forward timestep without learning.
    ///
    /// Encodes input into spikes, updates hidden neuron states, and advances
    /// the readout. Does NOT update weights or eligibility traces.
    ///
    /// Statement order is load-bearing: `prev_spikes` is snapshotted BEFORE
    /// the hidden update (recurrent input must come from timestep t-1), and
    /// the astrocyte gate is updated AFTER the new spike vector is computed.
    ///
    /// # Arguments
    ///
    /// * `input_i16` -- raw input features in Q1.14, length must equal `config.n_input`
    pub fn forward(&mut self, input_i16: &[i16]) {
        let n_hid = self.config.n_hidden;
        let n_enc = self.n_input_encoded;

        // 1. Delta-encode input into spike buffer (2 channels per raw input)
        self.encoder.encode(input_i16, &mut self.spike_buf);

        // 2. Store previous spikes for recurrent computation
        self.prev_spikes.copy_from_slice(&self.spikes);

        // 3. Update hidden layer LIF neurons
        for j in 0..n_hid {
            // Accumulate input current in i32 so summing many Q1.14 weights
            // cannot overflow i16.
            let mut current: i32 = 0;
            let w_in_offset = j * n_enc;
            for i in 0..n_enc {
                if self.spike_buf[i] != 0 {
                    // Apply astrocyte weight modulation ONLY in WeightMod mode.
                    // In LearningRateGate mode (AGMP proper, Dong & He 2025 Eq. 4),
                    // the forward pass uses stored weights unmodified -- only the
                    // weight UPDATE is scaled by the gate (see train_step).
                    let w = match &self.astrocyte {
                        Some(astro) if astro.mode() == AstrocyteMode::WeightMod => {
                            astro.modulate_weight(j, self.w_input[w_in_offset + i])
                        }
                        _ => self.w_input[w_in_offset + i],
                    };
                    current += w as i32;
                }
            }

            // Add recurrent current from previous hidden spikes (timestep t-1)
            let w_rec_offset = j * n_hid;
            for i in 0..n_hid {
                if self.prev_spikes[i] != 0 {
                    current += self.w_recurrent[w_rec_offset + i] as i32;
                }
            }

            // LIF step: leaky integration, threshold, reset
            let (v_new, spike) = lif_step(
                self.membrane[j],
                self.config.alpha,
                current,
                self.config.v_thr,
            );
            self.membrane[j] = v_new;
            self.spikes[j] = spike as u8;
        }

        // 3b. Update astrocyte spike tracking after spikes are computed
        if let Some(ref mut astro) = self.astrocyte {
            astro.update(&self.spikes);
        }

        // 4. Update readout neurons (leaky integrators over hidden spikes)
        let n_out = self.config.n_output;
        for k in 0..n_out {
            let w_out_offset = k * n_hid;
            let mut weighted_input: i32 = 0;
            for j in 0..n_hid {
                if self.spikes[j] != 0 {
                    weighted_input += self.w_output[w_out_offset + j] as i32;
                }
            }
            self.readout[k].step(weighted_input);
        }
    }
404
    /// Run one forward + learning timestep (e-prop three-factor rule).
    ///
    /// Performs a forward pass, then computes error signals and updates
    /// all weights using the e-prop learning rule: eligibility traces
    /// (presynaptic trace x surrogate gradient) are combined with a learning
    /// signal obtained via fixed random feedback alignment, and the readout
    /// is trained with a plain delta rule.
    ///
    /// # Arguments
    ///
    /// * `input_i16` -- raw input features in Q1.14
    /// * `target_i16` -- target values in Q1.14, length must equal
    ///   `config.n_output`. NOTE(review): a shorter slice silently leaves
    ///   stale values in `error_buf` for the missing outputs -- confirm
    ///   callers always supply `n_output` targets.
    pub fn train_step(&mut self, input_i16: &[i16], target_i16: &[i16]) {
        let n_hid = self.config.n_hidden;
        let n_enc = self.n_input_encoded;
        let n_out = self.config.n_output;

        // 1. Forward pass
        self.forward(input_i16);

        // 2. Compute error signals: error = target - readout.
        // Readout potentials are i32; clamp into i16 range first so the
        // saturating subtraction stays within Q1.14 storage.
        for (k, &target_k) in target_i16.iter().enumerate().take(n_out) {
            let readout_clamped = self.readout[k]
                .output_i32()
                .clamp(i16::MIN as i32, i16::MAX as i32) as i16;
            self.error_buf[k] = target_k.saturating_sub(readout_clamped);
        }

        // 3. Update presynaptic traces (low-pass filtered spike history)
        update_pre_trace_fixed(&mut self.pre_trace_in, &self.spike_buf, self.config.alpha);
        update_pre_trace_fixed(&mut self.pre_trace_hid, &self.spikes, self.config.alpha);

        // 4. For each hidden neuron: update eligibility, compute learning signal, update weights
        for j in 0..n_hid {
            // Surrogate gradient (piecewise-linear pseudo-derivative of the spike)
            let psi =
                surrogate_gradient_pwl(self.membrane[j], self.config.v_thr, self.config.gamma);

            // Update input eligibility traces for neuron j
            let elig_in_start = j * n_enc;
            let elig_in_end = elig_in_start + n_enc;
            update_eligibility_fixed(
                &mut self.elig_in[elig_in_start..elig_in_end],
                psi,
                &self.pre_trace_in,
                self.config.kappa,
            );

            // Update recurrent eligibility traces for neuron j
            let elig_rec_start = j * n_hid;
            let elig_rec_end = elig_rec_start + n_hid;
            update_eligibility_fixed(
                &mut self.elig_rec[elig_rec_start..elig_rec_end],
                psi,
                &self.pre_trace_hid,
                self.config.kappa,
            );

            // Compute learning signal via feedback alignment
            let fb_start = j * n_out;
            let fb_end = fb_start + n_out;
            let learning_signal = compute_learning_signal_fixed(
                &self.feedback[fb_start..fb_end],
                &self.error_buf[..n_out],
            );

            // Compute effective learning rate for neuron j.
            // In LearningRateGate mode (AGMP, Dong & He 2025 Eq. 4):
            //   eta_eff(j) = eta * g_j   where g_j = σ(rate_j - target) ∈ (0,1)
            // In WeightMod mode or when no astrocyte, eta_eff = base eta.
            let eta_j = match &self.astrocyte {
                Some(astro) if astro.mode() == AstrocyteMode::LearningRateGate => {
                    astro.effective_eta_q14(j, self.config.eta)
                }
                _ => self.config.eta,
            };

            // Update input weights for neuron j
            let w_in_start = j * n_enc;
            let w_in_end = w_in_start + n_enc;
            update_weights_fixed(
                &mut self.w_input[w_in_start..w_in_end],
                &self.elig_in[elig_in_start..elig_in_end],
                learning_signal,
                eta_j,
            );

            // Update recurrent weights for neuron j
            let w_rec_start = j * n_hid;
            let w_rec_end = w_rec_start + n_hid;
            update_weights_fixed(
                &mut self.w_recurrent[w_rec_start..w_rec_end],
                &self.elig_rec[elig_rec_start..elig_rec_end],
                learning_signal,
                eta_j,
            );
        }

        // 5. Update output weights via delta rule. Note: uses the base eta,
        // not the astrocyte-gated eta_j -- gating applies to hidden weights only.
        for k in 0..n_out {
            let w_out_start = k * n_hid;
            let w_out_end = w_out_start + n_hid;
            update_output_weights_fixed(
                &mut self.w_output[w_out_start..w_out_end],
                self.error_buf[k],
                &self.spikes,
                self.config.eta,
            );
        }

        self.n_samples += 1;
    }
514
    /// Get raw readout membrane potentials as i32.
    ///
    /// Returns a freshly allocated `Vec` (one entry per readout neuron)
    /// snapshotting the values produced by the most recent `forward` or
    /// `train_step` call. Note that this allocates on every call, unlike the
    /// steady-state `forward`/`train_step` path. (The previous doc claimed a
    /// reference to an internal buffer was returned; that was incorrect.)
    pub fn predict_raw(&self) -> Vec<i32> {
        self.readout.iter().map(|r| r.output_i32()).collect()
    }
522
523    /// Get the first readout's membrane potential, dequantized to f64.
524    ///
525    /// # Arguments
526    ///
527    /// * `output_scale` -- scaling factor (typically `1.0 / Q14_ONE as f64`)
528    pub fn predict_f64(&self, output_scale: f64) -> f64 {
529        if self.readout.is_empty() {
530            return 0.0;
531        }
532        self.readout[0].output_f64(output_scale)
533    }
534
535    /// Get all readout membrane potentials, dequantized to f64.
536    ///
537    /// # Arguments
538    ///
539    /// * `output_scale` -- scaling factor per output
540    pub fn predict_all_f64(&self, output_scale: f64) -> Vec<f64> {
541        self.readout
542            .iter()
543            .map(|r| r.output_f64(output_scale))
544            .collect()
545    }
546
    /// Number of training samples seen (increments once per `train_step`).
    pub fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }

    /// Reference to the network configuration.
    pub fn config(&self) -> &SpikeNetFixedConfig {
        &self.config
    }

    /// Number of hidden neurons.
    pub fn n_hidden(&self) -> usize {
        self.config.n_hidden
    }

    /// Number of encoded input channels (2 * n_input).
    pub fn n_input_encoded(&self) -> usize {
        self.n_input_encoded
    }

    /// Current hidden spike vector (0/1 per neuron, from the last timestep).
    pub fn hidden_spikes(&self) -> &[u8] {
        &self.spikes
    }

    /// Current hidden membrane potentials in Q1.14.
    pub fn hidden_membrane(&self) -> &[i16] {
        &self.membrane
    }
576
577    /// Compute total memory usage in bytes.
578    ///
579    /// Counts all Vec contents plus struct overhead.
580    pub fn memory_bytes(&self) -> usize {
581        let n_hid = self.config.n_hidden;
582        let n_enc = self.n_input_encoded;
583        let n_out = self.config.n_output;
584        let n_in = self.config.n_input;
585
586        let size_of_i16 = core::mem::size_of::<i16>();
587        let size_of_u8 = core::mem::size_of::<u8>();
588
589        // Neuron state
590        let membrane = n_hid * size_of_i16;
591        let spikes = n_hid * size_of_u8;
592        let prev_spikes = n_hid * size_of_u8;
593
594        // Presynaptic traces
595        let pre_trace_in = n_enc * size_of_i16;
596        let pre_trace_hid = n_hid * size_of_i16;
597
598        // Weights
599        let w_input = n_hid * n_enc * size_of_i16;
600        let w_recurrent = n_hid * n_hid * size_of_i16;
601        let w_output = n_out * n_hid * size_of_i16;
602        let feedback = n_hid * n_out * size_of_i16;
603
604        // Eligibility traces
605        let elig_in = n_hid * n_enc * size_of_i16;
606        let elig_rec = n_hid * n_hid * size_of_i16;
607
608        // Readout (membrane i32 + kappa i16 + padding)
609        let readout_size = n_out * core::mem::size_of::<ReadoutNeuron>();
610
611        // Encoder state
612        let encoder_prev = n_in * size_of_i16;
613        let encoder_thr = n_in * size_of_i16;
614
615        // Spike buffer
616        let spike_buf = n_enc * size_of_u8;
617
618        // Error buffer
619        let error_buf = n_out * size_of_i16;
620
621        // Struct overhead (approximate)
622        let struct_overhead = core::mem::size_of::<Self>();
623
624        // Total Vec contents
625        let vec_contents = membrane
626            + spikes
627            + prev_spikes
628            + pre_trace_in
629            + pre_trace_hid
630            + w_input
631            + w_recurrent
632            + w_output
633            + feedback
634            + elig_in
635            + elig_rec
636            + readout_size
637            + encoder_prev
638            + encoder_thr
639            + spike_buf
640            + error_buf;
641
642        struct_overhead + vec_contents
643    }
644
645    /// Reset all network state (neuron potentials, traces, readout) to zero.
646    ///
647    /// Weights are re-initialized from the original seed. The network behaves
648    /// as if freshly constructed after calling reset.
649    pub fn reset(&mut self) {
650        // Zero neuron state
651        for v in self.membrane.iter_mut() {
652            *v = 0;
653        }
654        for s in self.spikes.iter_mut() {
655            *s = 0;
656        }
657        for s in self.prev_spikes.iter_mut() {
658            *s = 0;
659        }
660
661        // Zero traces
662        for t in self.pre_trace_in.iter_mut() {
663            *t = 0;
664        }
665        for t in self.pre_trace_hid.iter_mut() {
666            *t = 0;
667        }
668        for e in self.elig_in.iter_mut() {
669            *e = 0;
670        }
671        for e in self.elig_rec.iter_mut() {
672            *e = 0;
673        }
674
675        // Reset readout
676        for r in self.readout.iter_mut() {
677            r.reset();
678        }
679
680        // Reset encoder
681        self.encoder.reset();
682
683        // Zero spike buffer
684        for s in self.spike_buf.iter_mut() {
685            *s = 0;
686        }
687
688        // Zero error buffer
689        for e in self.error_buf.iter_mut() {
690            *e = 0;
691        }
692
693        // Reset astrocyte if present
694        if let Some(ref mut astro) = self.astrocyte {
695            astro.reset();
696        }
697
698        // Re-initialize weights from seed
699        let mut rng_state = if self.config.seed == 0 {
700            1
701        } else {
702            self.config.seed
703        };
704        let range = self.config.weight_init_range;
705
706        for w in self.w_input.iter_mut() {
707            *w = xorshift64_i16(&mut rng_state, range);
708        }
709        for w in self.w_recurrent.iter_mut() {
710            *w = xorshift64_i16(&mut rng_state, range);
711        }
712        for w in self.w_output.iter_mut() {
713            *w = xorshift64_i16(&mut rng_state, range);
714        }
715        for w in self.feedback.iter_mut() {
716            *w = xorshift64_i16(&mut rng_state, range);
717        }
718
719        self.n_samples = 0;
720    }
721}
722
723#[cfg(test)]
724mod tests {
725    use super::*;
726    use crate::snn::lif::{f64_to_q14, Q14_ONE};
727
    /// Shared 2-in / 8-hid / 1-out configuration for the small tests below.
    fn default_small_config() -> SpikeNetFixedConfig {
        SpikeNetFixedConfig {
            n_input: 2,
            n_hidden: 8,
            n_output: 1,
            alpha: f64_to_q14(0.95),
            kappa: f64_to_q14(0.99),
            kappa_out: f64_to_q14(0.9),
            eta: f64_to_q14(0.01),
            v_thr: f64_to_q14(0.5),
            gamma: f64_to_q14(0.3),
            spike_threshold: f64_to_q14(0.05),
            seed: 42,
            weight_init_range: f64_to_q14(0.1),
            use_astrocyte: false,
            astrocyte_tau: 1000.0,
            astrocyte_mode: AstrocyteMode::WeightMod,
        }
    }

    #[test]
    fn construction_initializes_all_buffers() {
        let config = default_small_config();
        let net = SpikeNetFixed::new(config);

        // n_input_encoded = 2 * n_input; weight matrices are flat row-major Vecs.
        assert_eq!(net.membrane.len(), 8);
        assert_eq!(net.spikes.len(), 8);
        assert_eq!(net.n_input_encoded(), 4);
        assert_eq!(net.w_input.len(), 8 * 4);
        assert_eq!(net.w_recurrent.len(), 8 * 8);
        assert_eq!(net.w_output.len(), 8);
        assert_eq!(net.feedback.len(), 8);
        assert_eq!(net.elig_in.len(), 8 * 4);
        assert_eq!(net.elig_rec.len(), 8 * 8);
        assert_eq!(net.readout.len(), 1);
        assert_eq!(net.n_samples_seen(), 0);
    }

    #[test]
    fn forward_does_not_crash() {
        let config = default_small_config();
        let mut net = SpikeNetFixed::new(config);

        // First call (encoder warmup)
        net.forward(&[f64_to_q14(0.5), f64_to_q14(-0.3)]);
        // Second call (actual spikes possible)
        net.forward(&[f64_to_q14(0.8), f64_to_q14(0.2)]);

        // Should produce some output
        let raw = net.predict_raw();
        assert_eq!(raw.len(), 1, "should have one readout output");
    }

    #[test]
    fn train_step_increments_counter() {
        let config = default_small_config();
        let mut net = SpikeNetFixed::new(config);

        let input = [f64_to_q14(0.5), f64_to_q14(-0.3)];
        let target = [f64_to_q14(0.7)];

        net.train_step(&input, &target);
        assert_eq!(net.n_samples_seen(), 1);

        net.train_step(&input, &target);
        assert_eq!(net.n_samples_seen(), 2);
    }
795
    #[test]
    fn predictions_change_after_training() {
        // Deliberately aggressive hyperparameters so weight drift is visible
        // within 200 steps.
        let config = SpikeNetFixedConfig {
            n_input: 2,
            n_hidden: 16,
            n_output: 1,
            alpha: f64_to_q14(0.9),
            kappa: f64_to_q14(0.95),
            kappa_out: f64_to_q14(0.85),
            eta: f64_to_q14(0.05),  // larger learning rate for visible change
            v_thr: f64_to_q14(0.3), // lower threshold for more spiking
            gamma: f64_to_q14(0.5),
            spike_threshold: f64_to_q14(0.01), // very sensitive encoding
            seed: 12345,
            weight_init_range: f64_to_q14(0.2),
            use_astrocyte: false,
            astrocyte_tau: 1000.0,
            astrocyte_mode: AstrocyteMode::WeightMod,
        };

        let mut net = SpikeNetFixed::new(config);
        let scale = 1.0 / Q14_ONE as f64;

        // Warm up encoder
        net.forward(&[0, 0]);
        let pred_before = net.predict_f64(scale);

        // Train on an alternating two-pattern task for many steps
        for step in 0..200 {
            let x = if step % 2 == 0 {
                [f64_to_q14(0.8), f64_to_q14(-0.5)]
            } else {
                [f64_to_q14(-0.3), f64_to_q14(0.6)]
            };
            let target = if step % 2 == 0 {
                [f64_to_q14(1.0)]
            } else {
                [f64_to_q14(-1.0)]
            };
            net.train_step(&x, &target);
        }

        let pred_after = net.predict_f64(scale);

        assert!(
            (pred_after - pred_before).abs() > 1e-10,
            "prediction should change after training: before={}, after={}",
            pred_before,
            pred_after
        );
    }

    #[test]
    fn reset_restores_initial_state() {
        let config = default_small_config();
        let mut net = SpikeNetFixed::new(config.clone());
        let fresh = SpikeNetFixed::new(config);

        // Train a few steps
        net.train_step(&[1000, -500], &[2000]);
        net.train_step(&[-1000, 500], &[-2000]);
        assert!(net.n_samples_seen() > 0);

        // Reset
        net.reset();

        // Compare with fresh network: state zeroed, weights re-drawn from seed
        assert_eq!(net.n_samples_seen(), 0);
        assert_eq!(net.membrane, fresh.membrane);
        assert_eq!(net.spikes, fresh.spikes);
        assert_eq!(
            net.w_input, fresh.w_input,
            "weights should be re-initialized from seed"
        );
        assert_eq!(net.w_recurrent, fresh.w_recurrent);
        assert_eq!(net.w_output, fresh.w_output);
        assert_eq!(net.feedback, fresh.feedback);
    }
874
    #[test]
    fn memory_bytes_is_reasonable() {
        let config = SpikeNetFixedConfig {
            n_input: 10,
            n_hidden: 64,
            n_output: 1,
            ..SpikeNetFixedConfig::default()
        };
        let net = SpikeNetFixed::new(config);
        let mem = net.memory_bytes();

        // Dominant terms: w_recurrent = 64*64*2 = 8192 bytes
        // w_input = 64*20*2 = 2560 bytes
        // elig_rec = 64*64*2 = 8192 bytes
        // elig_in = 64*20*2 = 2560 bytes
        // Total Vec contents should be ~22KB + struct overhead
        assert!(
            mem > 20_000,
            "memory should be at least 20KB for 10-in/64-hid/1-out, got {}",
            mem
        );
        assert!(
            mem < 100_000,
            "memory should be under 100KB for small network, got {}",
            mem
        );
    }

    #[test]
    fn deterministic_with_same_seed() {
        // Two networks from the same config must evolve identically.
        let config = default_small_config();
        let mut net1 = SpikeNetFixed::new(config.clone());
        let mut net2 = SpikeNetFixed::new(config);

        let input = [f64_to_q14(0.3), f64_to_q14(-0.7)];
        let target = [f64_to_q14(0.5)];

        for _ in 0..10 {
            net1.train_step(&input, &target);
            net2.train_step(&input, &target);
        }

        let scale = 1.0 / Q14_ONE as f64;
        let p1 = net1.predict_f64(scale);
        let p2 = net2.predict_f64(scale);
        assert_eq!(p1, p2, "same seed should produce identical predictions");
    }
922
923    #[test]
924    fn multi_output_network() {
925        let config = SpikeNetFixedConfig {
926            n_input: 3,
927            n_hidden: 8,
928            n_output: 3,
929            ..SpikeNetFixedConfig::default()
930        };
931        let mut net = SpikeNetFixed::new(config);
932
933        net.forward(&[1000, -500, 200]);
934        net.forward(&[1500, 0, -300]);
935
936        let raw = net.predict_raw();
937        assert_eq!(raw.len(), 3, "should have 3 readout outputs");
938
939        let scale = 1.0 / Q14_ONE as f64;
940        let all = net.predict_all_f64(scale);
941        assert_eq!(all.len(), 3);
942    }
943
944    #[test]
945    fn train_step_with_multi_output() {
946        let config = SpikeNetFixedConfig {
947            n_input: 2,
948            n_hidden: 8,
949            n_output: 2,
950            ..SpikeNetFixedConfig::default()
951        };
952        let mut net = SpikeNetFixed::new(config);
953
954        // Should not panic
955        net.train_step(&[1000, -500], &[2000, -1000]);
956        assert_eq!(net.n_samples_seen(), 1);
957    }
958
959    #[test]
960    fn network_with_astrocyte_runs() {
961        let config = SpikeNetFixedConfig {
962            use_astrocyte: true,
963            astrocyte_tau: 100.0,
964            ..default_small_config()
965        };
966        let mut net = SpikeNetFixed::new(config);
967        for _ in 0..50 {
968            net.train_step(&[1000, -500], &[2000]);
969        }
970        assert_eq!(net.n_samples_seen(), 50);
971        let raw = net.predict_raw();
972        assert_eq!(raw.len(), 1);
973    }
974
975    /// AGMP proper (Dong & He 2025 Eq. 4): `LearningRateGate` must modulate the
976    /// learning rate in the weight update, NOT the forward-pass weights.
977    ///
978    /// The observable invariant: a network running `LearningRateGate` completes
979    /// training, produces a finite prediction, and advances its sample counter.
980    /// This confirms (a) the forward pass ran unmodified (gate not applied to
981    /// weights in forward pass), (b) weight updates occurred (learning rate
982    /// is non-zero; gate only scales magnitude), and (c) no panic or degenerate
983    /// output.
984    ///
985    /// Note: we do NOT assert that LR-gate and WeightMod produce distinct
986    /// predictions from the same constant input. Both start from the same seed.
987    /// Weight trajectories diverge only when the WeightMod gate changes effective
988    /// forward-pass inputs enough to change the spike pattern -- unreliable in
989    /// a constant-input unit test. The `agmp_gates_learning_rate` test in
990    /// `astrocyte.rs` covers the gate mechanics directly via `effective_eta_q14`.
991    #[test]
992    fn agmp_modulates_learning_rate_not_weights() {
993        use crate::snn::lif::f64_to_q14;
994
995        let config = SpikeNetFixedConfig {
996            use_astrocyte: true,
997            astrocyte_tau: 10.0, // fast tau so gate activates quickly
998            astrocyte_mode: AstrocyteMode::LearningRateGate,
999            n_input: 2,
1000            n_hidden: 16,
1001            n_output: 1,
1002            ..SpikeNetFixedConfig::default()
1003        };
1004
1005        let mut net = SpikeNetFixed::new(config);
1006
1007        let input = [f64_to_q14(0.5), f64_to_q14(-0.3)];
1008        let target = [f64_to_q14(1.0)];
1009
1010        for _ in 0..200 {
1011            net.train_step(&input, &target);
1012        }
1013
1014        let scale = 1.0 / Q14_ONE as f64;
1015        let pred = net.predict_f64(scale);
1016
1017        // Finite prediction confirms the gate was not applied to weights in the
1018        // forward pass (which would cause weight-scale drift) and that the weight
1019        // update path (modulated by effective_eta_q14) ran without panicking.
1020        assert!(
1021            pred.is_finite(),
1022            "LearningRateGate network should produce finite prediction after training, got {pred}"
1023        );
1024
1025        // Confirm training advanced sample count.
1026        assert_eq!(net.n_samples_seen(), 200);
1027    }
1028
1029    #[test]
1030    fn hidden_spikes_accessible() {
1031        let config = default_small_config();
1032        let mut net = SpikeNetFixed::new(config);
1033
1034        net.forward(&[0, 0]);
1035        net.forward(&[Q14_ONE, -Q14_ONE]); // big change to trigger spikes
1036
1037        let spikes = net.hidden_spikes();
1038        assert_eq!(spikes.len(), 8);
1039        // Spikes are binary
1040        for &s in spikes {
1041            assert!(s == 0 || s == 1, "spike should be 0 or 1, got {}", s);
1042        }
1043    }
1044
1045    #[test]
1046    fn config_default_is_sensible() {
1047        let config = SpikeNetFixedConfig::default();
1048        assert!(config.alpha > 0, "alpha should be positive");
1049        assert!(config.v_thr > 0, "v_thr should be positive");
1050        assert!(config.eta > 0, "eta should be positive");
1051        assert!(config.n_hidden > 0, "n_hidden should be positive");
1052    }
1053
1054    /// `Precision` default is `Float` when `std` is available (hardware FPU,
1055    /// no SRAM pressure), and `Fixed` when `no_std` (cortex_m, no FPU).
1056    /// The selection criterion is principled: FPU availability and memory
1057    /// budget, both unambiguous at compile time.
1058    #[test]
1059    fn precision_default_is_float_in_std() {
1060        let p = Precision::default();
1061        // On std targets (where this test runs), the default must be Float.
1062        assert_eq!(
1063            p,
1064            Precision::Float,
1065            "Precision::default() must be Float on std targets, got {p:?}"
1066        );
1067
1068        // Both variants must be constructible and distinct.
1069        assert_ne!(
1070            Precision::Float,
1071            Precision::Fixed,
1072            "Float and Fixed must be distinct"
1073        );
1074    }
1075}