1use alloc::vec;
39use alloc::vec::Vec;
40
41use super::eprop::{
42 compute_learning_signal_fixed, update_eligibility_fixed, update_output_weights_fixed,
43 update_pre_trace_fixed, update_weights_fixed,
44};
45use super::lif::{lif_step, surrogate_gradient_pwl};
46use super::readout::ReadoutNeuron;
47use super::spike_encoding::DeltaEncoderFixed;
48
/// Hyper-parameters for a [`SpikeNetFixed`] network.
///
/// Every dynamics parameter is an `i16` fixed-point value. The defaults are
/// consistent with Q14 scaling (1.0 == 16384, e.g. `v_thr: 8192` == 0.5).
/// NOTE(review): confirm the scaling against `snn::lif::Q14_ONE`.
#[derive(Debug, Clone)]
pub struct SpikeNetFixedConfig {
    /// Number of raw input channels handed to the delta encoder.
    pub n_input: usize,
    /// Number of recurrent hidden (LIF) neurons.
    pub n_hidden: usize,
    /// Number of readout neurons, i.e. network outputs.
    pub n_output: usize,
    /// Decay passed to the LIF membrane update and to the presynaptic traces.
    pub alpha: i16,
    /// Decay passed to the eligibility-trace update.
    pub kappa: i16,
    /// Decay handed to each `ReadoutNeuron` at construction.
    pub kappa_out: i16,
    /// Learning rate used by the input, recurrent and output weight updates.
    pub eta: i16,
    /// Firing threshold for the LIF step and the surrogate gradient.
    pub v_thr: i16,
    /// Parameter of the piecewise-linear surrogate gradient.
    pub gamma: i16,
    /// Change threshold handed to the delta encoder.
    pub spike_threshold: i16,
    /// Seed for deterministic weight initialisation (0 is remapped to 1).
    pub seed: u64,
    /// Initial weights are drawn from the integer range `[-r, r]`, `r = |weight_init_range|`.
    pub weight_init_range: i16,
}

impl Default for SpikeNetFixedConfig {
    /// A single-input, 64-hidden, single-output network.
    ///
    /// The trailing comments give the approximate real values under Q14 scaling.
    fn default() -> Self {
        SpikeNetFixedConfig {
            n_input: 1,
            n_hidden: 64,
            n_output: 1,
            alpha: 15565,           // ~0.95
            kappa: 16220,           // ~0.99
            kappa_out: 14746,       // ~0.90
            eta: 16,                // ~0.001
            v_thr: 8192,            // 0.5
            gamma: 4915,            // ~0.3
            spike_threshold: 819,   // ~0.05
            seed: 42,
            weight_init_range: 1638, // ~0.1
        }
    }
}
109
/// Advances a xorshift64 generator (shifts 13, 7, 17) and returns the new state.
///
/// A zero state is a fixed point (it only ever produces 0), so callers must
/// seed with a non-zero value.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x
}

/// Draws an integer in `[-|range|, |range|]` from the xorshift64 stream.
///
/// The generator is always advanced, even when `range == 0` (which yields 0).
/// `i16::MIN` is treated as `i16::MAX`, because its magnitude is not representable.
/// The reduction uses `%`, so it has a small modulo bias, which is fine for
/// weight initialisation.
#[inline]
fn xorshift64_i16(state: &mut u64, range: i16) -> i16 {
    let raw = xorshift64(state);
    // `saturating_abs` avoids the overflow of `-i16::MIN`.
    let abs_range = range.saturating_abs();
    if abs_range == 0 {
        return 0;
    }
    let modulus = 2 * abs_range as u64 + 1;
    // `offset` lies in [0, 2 * abs_range] and can exceed i16::MAX for large
    // ranges, so the re-centring is done in i32. The result lies in
    // [-abs_range, abs_range], which always fits in i16, so the cast is lossless.
    let offset = (raw % modulus) as i32;
    (offset - i32::from(abs_range)) as i16
}
137
138pub struct SpikeNetFixed {
149 config: SpikeNetFixedConfig,
150 n_input_encoded: usize, membrane: Vec<i16>, spikes: Vec<u8>, prev_spikes: Vec<u8>, pre_trace_in: Vec<i16>, pre_trace_hid: Vec<i16>, w_input: Vec<i16>, w_recurrent: Vec<i16>, w_output: Vec<i16>, feedback: Vec<i16>, elig_in: Vec<i16>, elig_rec: Vec<i16>, readout: Vec<ReadoutNeuron>, encoder: DeltaEncoderFixed,
176
177 spike_buf: Vec<u8>, error_buf: Vec<i16>, n_samples: u64,
185}
186
187unsafe impl Send for SpikeNetFixed {}
189unsafe impl Sync for SpikeNetFixed {}
190
191impl SpikeNetFixed {
192 pub fn new(config: SpikeNetFixedConfig) -> Self {
197 let n_in = config.n_input;
198 let n_hid = config.n_hidden;
199 let n_out = config.n_output;
200 let n_enc = 2 * n_in;
201
202 let mut rng_state = if config.seed == 0 { 1 } else { config.seed };
203 let range = config.weight_init_range;
204
205 let w_input: Vec<i16> = (0..n_hid * n_enc)
207 .map(|_| xorshift64_i16(&mut rng_state, range))
208 .collect();
209
210 let w_recurrent: Vec<i16> = (0..n_hid * n_hid)
212 .map(|_| xorshift64_i16(&mut rng_state, range))
213 .collect();
214
215 let w_output: Vec<i16> = (0..n_out * n_hid)
217 .map(|_| xorshift64_i16(&mut rng_state, range))
218 .collect();
219
220 let feedback: Vec<i16> = (0..n_hid * n_out)
222 .map(|_| xorshift64_i16(&mut rng_state, range))
223 .collect();
224
225 let readout: Vec<ReadoutNeuron> = (0..n_out)
226 .map(|_| ReadoutNeuron::new(config.kappa_out))
227 .collect();
228
229 let encoder = DeltaEncoderFixed::new(n_in, config.spike_threshold);
230
231 Self {
232 n_input_encoded: n_enc,
233 membrane: vec![0; n_hid],
234 spikes: vec![0; n_hid],
235 prev_spikes: vec![0; n_hid],
236 pre_trace_in: vec![0; n_enc],
237 pre_trace_hid: vec![0; n_hid],
238 w_input,
239 w_recurrent,
240 w_output,
241 feedback,
242 elig_in: vec![0; n_hid * n_enc],
243 elig_rec: vec![0; n_hid * n_hid],
244 readout,
245 encoder,
246 spike_buf: vec![0; n_enc],
247 error_buf: vec![0; n_out],
248 n_samples: 0,
249 config,
250 }
251 }
252
253 pub fn forward(&mut self, input_i16: &[i16]) {
262 let n_hid = self.config.n_hidden;
263 let n_enc = self.n_input_encoded;
264
265 self.encoder.encode(input_i16, &mut self.spike_buf);
267
268 self.prev_spikes.copy_from_slice(&self.spikes);
270
271 for j in 0..n_hid {
273 let mut current: i32 = 0;
275 let w_in_offset = j * n_enc;
276 for i in 0..n_enc {
277 if self.spike_buf[i] != 0 {
278 current += self.w_input[w_in_offset + i] as i32;
279 }
280 }
281
282 let w_rec_offset = j * n_hid;
284 for i in 0..n_hid {
285 if self.prev_spikes[i] != 0 {
286 current += self.w_recurrent[w_rec_offset + i] as i32;
287 }
288 }
289
290 let (v_new, spike) = lif_step(
292 self.membrane[j],
293 self.config.alpha,
294 current,
295 self.config.v_thr,
296 );
297 self.membrane[j] = v_new;
298 self.spikes[j] = spike as u8;
299 }
300
301 let n_out = self.config.n_output;
303 for k in 0..n_out {
304 let w_out_offset = k * n_hid;
305 let mut weighted_input: i32 = 0;
306 for j in 0..n_hid {
307 if self.spikes[j] != 0 {
308 weighted_input += self.w_output[w_out_offset + j] as i32;
309 }
310 }
311 self.readout[k].step(weighted_input);
312 }
313 }
314
315 pub fn train_step(&mut self, input_i16: &[i16], target_i16: &[i16]) {
325 let n_hid = self.config.n_hidden;
326 let n_enc = self.n_input_encoded;
327 let n_out = self.config.n_output;
328
329 self.forward(input_i16);
331
332 for (k, &target_k) in target_i16.iter().enumerate().take(n_out) {
334 let readout_clamped = self.readout[k]
335 .output_i32()
336 .clamp(i16::MIN as i32, i16::MAX as i32) as i16;
337 self.error_buf[k] = target_k.saturating_sub(readout_clamped);
338 }
339
340 update_pre_trace_fixed(&mut self.pre_trace_in, &self.spike_buf, self.config.alpha);
342 update_pre_trace_fixed(&mut self.pre_trace_hid, &self.spikes, self.config.alpha);
343
344 for j in 0..n_hid {
346 let psi =
348 surrogate_gradient_pwl(self.membrane[j], self.config.v_thr, self.config.gamma);
349
350 let elig_in_start = j * n_enc;
352 let elig_in_end = elig_in_start + n_enc;
353 update_eligibility_fixed(
354 &mut self.elig_in[elig_in_start..elig_in_end],
355 psi,
356 &self.pre_trace_in,
357 self.config.kappa,
358 );
359
360 let elig_rec_start = j * n_hid;
362 let elig_rec_end = elig_rec_start + n_hid;
363 update_eligibility_fixed(
364 &mut self.elig_rec[elig_rec_start..elig_rec_end],
365 psi,
366 &self.pre_trace_hid,
367 self.config.kappa,
368 );
369
370 let fb_start = j * n_out;
372 let fb_end = fb_start + n_out;
373 let learning_signal = compute_learning_signal_fixed(
374 &self.feedback[fb_start..fb_end],
375 &self.error_buf[..n_out],
376 );
377
378 let w_in_start = j * n_enc;
380 let w_in_end = w_in_start + n_enc;
381 update_weights_fixed(
382 &mut self.w_input[w_in_start..w_in_end],
383 &self.elig_in[elig_in_start..elig_in_end],
384 learning_signal,
385 self.config.eta,
386 );
387
388 let w_rec_start = j * n_hid;
390 let w_rec_end = w_rec_start + n_hid;
391 update_weights_fixed(
392 &mut self.w_recurrent[w_rec_start..w_rec_end],
393 &self.elig_rec[elig_rec_start..elig_rec_end],
394 learning_signal,
395 self.config.eta,
396 );
397 }
398
399 for k in 0..n_out {
401 let w_out_start = k * n_hid;
402 let w_out_end = w_out_start + n_hid;
403 update_output_weights_fixed(
404 &mut self.w_output[w_out_start..w_out_end],
405 self.error_buf[k],
406 &self.spikes,
407 self.config.eta,
408 );
409 }
410
411 self.n_samples += 1;
412 }
413
414 pub fn predict_raw(&self) -> Vec<i32> {
419 self.readout.iter().map(|r| r.output_i32()).collect()
420 }
421
422 pub fn predict_f64(&self, output_scale: f64) -> f64 {
428 if self.readout.is_empty() {
429 return 0.0;
430 }
431 self.readout[0].output_f64(output_scale)
432 }
433
434 pub fn predict_all_f64(&self, output_scale: f64) -> Vec<f64> {
440 self.readout
441 .iter()
442 .map(|r| r.output_f64(output_scale))
443 .collect()
444 }
445
446 pub fn n_samples_seen(&self) -> u64 {
448 self.n_samples
449 }
450
451 pub fn config(&self) -> &SpikeNetFixedConfig {
453 &self.config
454 }
455
456 pub fn n_hidden(&self) -> usize {
458 self.config.n_hidden
459 }
460
461 pub fn n_input_encoded(&self) -> usize {
463 self.n_input_encoded
464 }
465
466 pub fn hidden_spikes(&self) -> &[u8] {
468 &self.spikes
469 }
470
471 pub fn hidden_membrane(&self) -> &[i16] {
473 &self.membrane
474 }
475
476 pub fn memory_bytes(&self) -> usize {
480 let n_hid = self.config.n_hidden;
481 let n_enc = self.n_input_encoded;
482 let n_out = self.config.n_output;
483 let n_in = self.config.n_input;
484
485 let size_of_i16 = core::mem::size_of::<i16>();
486 let size_of_u8 = core::mem::size_of::<u8>();
487
488 let membrane = n_hid * size_of_i16;
490 let spikes = n_hid * size_of_u8;
491 let prev_spikes = n_hid * size_of_u8;
492
493 let pre_trace_in = n_enc * size_of_i16;
495 let pre_trace_hid = n_hid * size_of_i16;
496
497 let w_input = n_hid * n_enc * size_of_i16;
499 let w_recurrent = n_hid * n_hid * size_of_i16;
500 let w_output = n_out * n_hid * size_of_i16;
501 let feedback = n_hid * n_out * size_of_i16;
502
503 let elig_in = n_hid * n_enc * size_of_i16;
505 let elig_rec = n_hid * n_hid * size_of_i16;
506
507 let readout_size = n_out * core::mem::size_of::<ReadoutNeuron>();
509
510 let encoder_prev = n_in * size_of_i16;
512 let encoder_thr = n_in * size_of_i16;
513
514 let spike_buf = n_enc * size_of_u8;
516
517 let error_buf = n_out * size_of_i16;
519
520 let struct_overhead = core::mem::size_of::<Self>();
522
523 let vec_contents = membrane
525 + spikes
526 + prev_spikes
527 + pre_trace_in
528 + pre_trace_hid
529 + w_input
530 + w_recurrent
531 + w_output
532 + feedback
533 + elig_in
534 + elig_rec
535 + readout_size
536 + encoder_prev
537 + encoder_thr
538 + spike_buf
539 + error_buf;
540
541 struct_overhead + vec_contents
542 }
543
544 pub fn reset(&mut self) {
549 for v in self.membrane.iter_mut() {
551 *v = 0;
552 }
553 for s in self.spikes.iter_mut() {
554 *s = 0;
555 }
556 for s in self.prev_spikes.iter_mut() {
557 *s = 0;
558 }
559
560 for t in self.pre_trace_in.iter_mut() {
562 *t = 0;
563 }
564 for t in self.pre_trace_hid.iter_mut() {
565 *t = 0;
566 }
567 for e in self.elig_in.iter_mut() {
568 *e = 0;
569 }
570 for e in self.elig_rec.iter_mut() {
571 *e = 0;
572 }
573
574 for r in self.readout.iter_mut() {
576 r.reset();
577 }
578
579 self.encoder.reset();
581
582 for s in self.spike_buf.iter_mut() {
584 *s = 0;
585 }
586
587 for e in self.error_buf.iter_mut() {
589 *e = 0;
590 }
591
592 let mut rng_state = if self.config.seed == 0 {
594 1
595 } else {
596 self.config.seed
597 };
598 let range = self.config.weight_init_range;
599
600 for w in self.w_input.iter_mut() {
601 *w = xorshift64_i16(&mut rng_state, range);
602 }
603 for w in self.w_recurrent.iter_mut() {
604 *w = xorshift64_i16(&mut rng_state, range);
605 }
606 for w in self.w_output.iter_mut() {
607 *w = xorshift64_i16(&mut rng_state, range);
608 }
609 for w in self.feedback.iter_mut() {
610 *w = xorshift64_i16(&mut rng_state, range);
611 }
612
613 self.n_samples = 0;
614 }
615}
616
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snn::lif::{f64_to_q14, Q14_ONE};

    /// Small 2-input / 8-hidden / 1-output configuration shared by several tests.
    fn default_small_config() -> SpikeNetFixedConfig {
        SpikeNetFixedConfig {
            n_input: 2,
            n_hidden: 8,
            n_output: 1,
            alpha: f64_to_q14(0.95),
            kappa: f64_to_q14(0.99),
            kappa_out: f64_to_q14(0.9),
            eta: f64_to_q14(0.01),
            v_thr: f64_to_q14(0.5),
            gamma: f64_to_q14(0.3),
            spike_threshold: f64_to_q14(0.05),
            seed: 42,
            weight_init_range: f64_to_q14(0.1),
        }
    }

    #[test]
    fn construction_initializes_all_buffers() {
        let config = default_small_config();
        let net = SpikeNetFixed::new(config);

        assert_eq!(net.membrane.len(), 8);
        assert_eq!(net.spikes.len(), 8);
        assert_eq!(net.n_input_encoded(), 4);
        assert_eq!(net.w_input.len(), 8 * 4);
        assert_eq!(net.w_recurrent.len(), 8 * 8);
        assert_eq!(net.w_output.len(), 1 * 8);
        assert_eq!(net.feedback.len(), 8 * 1);
        assert_eq!(net.elig_in.len(), 8 * 4);
        assert_eq!(net.elig_rec.len(), 8 * 8);
        assert_eq!(net.readout.len(), 1);
        assert_eq!(net.n_samples_seen(), 0);
    }

    #[test]
    fn forward_does_not_crash() {
        let config = default_small_config();
        let mut net = SpikeNetFixed::new(config);

        net.forward(&[f64_to_q14(0.5), f64_to_q14(-0.3)]);
        net.forward(&[f64_to_q14(0.8), f64_to_q14(0.2)]);

        let raw = net.predict_raw();
        assert_eq!(raw.len(), 1, "should have one readout output");
    }

    #[test]
    fn train_step_increments_counter() {
        let config = default_small_config();
        let mut net = SpikeNetFixed::new(config);

        let input = [f64_to_q14(0.5), f64_to_q14(-0.3)];
        let target = [f64_to_q14(0.7)];

        net.train_step(&input, &target);
        assert_eq!(net.n_samples_seen(), 1);

        net.train_step(&input, &target);
        assert_eq!(net.n_samples_seen(), 2);
    }

    #[test]
    fn predictions_change_after_training() {
        let config = SpikeNetFixedConfig {
            n_input: 2,
            n_hidden: 16,
            n_output: 1,
            alpha: f64_to_q14(0.9),
            kappa: f64_to_q14(0.95),
            kappa_out: f64_to_q14(0.85),
            eta: f64_to_q14(0.05),
            v_thr: f64_to_q14(0.3),
            gamma: f64_to_q14(0.5),
            spike_threshold: f64_to_q14(0.01),
            seed: 12345,
            weight_init_range: f64_to_q14(0.2),
        };

        let mut net = SpikeNetFixed::new(config);
        let scale = 1.0 / Q14_ONE as f64;

        net.forward(&[0, 0]);
        let pred_before = net.predict_f64(scale);

        // Alternate between two input patterns with opposite targets.
        for step in 0..200 {
            let x = if step % 2 == 0 {
                [f64_to_q14(0.8), f64_to_q14(-0.5)]
            } else {
                [f64_to_q14(-0.3), f64_to_q14(0.6)]
            };
            let target = if step % 2 == 0 {
                [f64_to_q14(1.0)]
            } else {
                [f64_to_q14(-1.0)]
            };
            net.train_step(&x, &target);
        }

        let pred_after = net.predict_f64(scale);

        assert!(
            (pred_after - pred_before).abs() > 1e-10,
            "prediction should change after training: before={}, after={}",
            pred_before,
            pred_after
        );
    }

    #[test]
    fn reset_restores_initial_state() {
        let config = default_small_config();
        let mut net = SpikeNetFixed::new(config.clone());
        let fresh = SpikeNetFixed::new(config);

        net.train_step(&[1000, -500], &[2000]);
        net.train_step(&[-1000, 500], &[-2000]);
        assert!(net.n_samples_seen() > 0);

        net.reset();

        assert_eq!(net.n_samples_seen(), 0);
        assert_eq!(net.membrane, fresh.membrane);
        assert_eq!(net.spikes, fresh.spikes);
        assert_eq!(net.prev_spikes, fresh.prev_spikes);
        // Traces, eligibilities and scratch buffers must be cleared too;
        // otherwise the next training step would reuse stale learning state.
        assert_eq!(net.pre_trace_in, fresh.pre_trace_in);
        assert_eq!(net.pre_trace_hid, fresh.pre_trace_hid);
        assert_eq!(net.elig_in, fresh.elig_in);
        assert_eq!(net.elig_rec, fresh.elig_rec);
        assert_eq!(net.spike_buf, fresh.spike_buf);
        assert_eq!(net.error_buf, fresh.error_buf);
        assert_eq!(
            net.w_input, fresh.w_input,
            "weights should be re-initialized from seed"
        );
        assert_eq!(net.w_recurrent, fresh.w_recurrent);
        assert_eq!(net.w_output, fresh.w_output);
        assert_eq!(net.feedback, fresh.feedback);
    }

    #[test]
    fn memory_bytes_is_reasonable() {
        let config = SpikeNetFixedConfig {
            n_input: 10,
            n_hidden: 64,
            n_output: 1,
            ..SpikeNetFixedConfig::default()
        };
        let net = SpikeNetFixed::new(config);
        let mem = net.memory_bytes();

        // The recurrent weights and eligibilities alone are 2 * 64 * 64 * 2 bytes.
        assert!(
            mem > 20_000,
            "memory should be at least 20KB for 10-in/64-hid/1-out, got {}",
            mem
        );
        assert!(
            mem < 100_000,
            "memory should be under 100KB for small network, got {}",
            mem
        );
    }

    #[test]
    fn deterministic_with_same_seed() {
        let config = default_small_config();
        let mut net1 = SpikeNetFixed::new(config.clone());
        let mut net2 = SpikeNetFixed::new(config);

        let input = [f64_to_q14(0.3), f64_to_q14(-0.7)];
        let target = [f64_to_q14(0.5)];

        for _ in 0..10 {
            net1.train_step(&input, &target);
            net2.train_step(&input, &target);
        }

        let scale = 1.0 / Q14_ONE as f64;
        let p1 = net1.predict_f64(scale);
        let p2 = net2.predict_f64(scale);
        assert_eq!(p1, p2, "same seed should produce identical predictions");
    }

    #[test]
    fn multi_output_network() {
        let config = SpikeNetFixedConfig {
            n_input: 3,
            n_hidden: 8,
            n_output: 3,
            ..SpikeNetFixedConfig::default()
        };
        let mut net = SpikeNetFixed::new(config);

        net.forward(&[1000, -500, 200]);
        net.forward(&[1500, 0, -300]);

        let raw = net.predict_raw();
        assert_eq!(raw.len(), 3, "should have 3 readout outputs");

        let scale = 1.0 / Q14_ONE as f64;
        let all = net.predict_all_f64(scale);
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn train_step_with_multi_output() {
        let config = SpikeNetFixedConfig {
            n_input: 2,
            n_hidden: 8,
            n_output: 2,
            ..SpikeNetFixedConfig::default()
        };
        let mut net = SpikeNetFixed::new(config);

        net.train_step(&[1000, -500], &[2000, -1000]);
        assert_eq!(net.n_samples_seen(), 1);
    }

    #[test]
    fn hidden_spikes_accessible() {
        let config = default_small_config();
        let mut net = SpikeNetFixed::new(config);

        net.forward(&[0, 0]);
        net.forward(&[Q14_ONE, -Q14_ONE]);
        let spikes = net.hidden_spikes();
        assert_eq!(spikes.len(), 8);
        for &s in spikes {
            assert!(s == 0 || s == 1, "spike should be 0 or 1, got {}", s);
        }
    }

    #[test]
    fn config_default_is_sensible() {
        let config = SpikeNetFixedConfig::default();
        assert!(config.alpha > 0, "alpha should be positive");
        assert!(config.v_thr > 0, "v_thr should be positive");
        assert!(config.eta > 0, "eta should be positive");
        assert!(config.n_hidden > 0, "n_hidden should be positive");
    }
}