use alloc::vec;
use alloc::vec::Vec;
use super::astrocyte::{AstrocyteGate, AstrocyteMode};
use super::eprop::{
compute_learning_signal_fixed, update_eligibility_fixed, update_output_weights_fixed,
update_pre_trace_fixed, update_weights_fixed,
};
use super::lif::{lif_step, surrogate_gradient_pwl};
use super::readout::ReadoutNeuron;
use super::spike_encoding::DeltaEncoderFixed;
/// Selects the arithmetic backend used by the network implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Precision {
    /// Floating-point (`f64`) arithmetic; preferred when `std` is available.
    Float,
    /// Q14 fixed-point (`i16`) arithmetic; the only option on `no_std` targets.
    Fixed,
}

impl Default for Precision {
    /// Picks `Float` on `std` builds and `Fixed` everywhere else.
    fn default() -> Self {
        // `cfg!` is evaluated at compile time; both arms always type-check,
        // which is equivalent to the attribute-per-block form.
        if cfg!(feature = "std") {
            Precision::Float
        } else {
            Precision::Fixed
        }
    }
}
/// Configuration for a fixed-point spiking network. The `i16` gain/threshold
/// fields are Q14 fixed-point values (16384 == 1.0; see `f64_to_q14`).
#[derive(Debug, Clone)]
pub struct SpikeNetFixedConfig {
/// Number of raw input channels (delta-encoded into `2 * n_input` spike channels).
pub n_input: usize,
/// Number of hidden LIF neurons.
pub n_hidden: usize,
/// Number of readout outputs.
pub n_output: usize,
/// Membrane / presynaptic-trace decay factor (Q14), used by `lif_step`
/// and `update_pre_trace_fixed`.
pub alpha: i16,
/// Eligibility-trace decay factor (Q14), used by `update_eligibility_fixed`.
pub kappa: i16,
/// Readout leak factor (Q14), handed to each `ReadoutNeuron`.
pub kappa_out: i16,
/// Learning rate (Q14).
pub eta: i16,
/// Hidden-neuron spike threshold (Q14).
pub v_thr: i16,
/// Surrogate-gradient steepness (Q14), used by `surrogate_gradient_pwl`.
pub gamma: i16,
/// Delta-encoder change threshold (Q14).
pub spike_threshold: i16,
/// Seed for weight initialization (0 is remapped to 1; 0 is a fixed
/// point of xorshift64).
pub seed: u64,
/// Weights are drawn uniformly from `[-weight_init_range, +weight_init_range]`.
pub weight_init_range: i16,
/// Enables the optional astrocyte gate.
pub use_astrocyte: bool,
/// Astrocyte time constant (only used when `use_astrocyte` is set).
pub astrocyte_tau: f64,
/// Astrocyte operating mode (only used when `use_astrocyte` is set).
pub astrocyte_mode: AstrocyteMode,
}
impl Default for SpikeNetFixedConfig {
    /// Sensible single-input defaults; Q14 constants (16384 == 1.0) are
    /// annotated with their approximate real values.
    fn default() -> Self {
        Self {
            n_input: 1,
            n_hidden: 64,
            n_output: 1,
            alpha: 15565,            // ~0.95 membrane decay
            kappa: 16220,            // ~0.99 eligibility decay
            kappa_out: 14746,        // ~0.90 readout leak
            eta: 16,                 // ~0.001 learning rate
            v_thr: 8192,             // 0.50 spike threshold
            gamma: 4915,             // ~0.30 surrogate steepness
            spike_threshold: 819,    // ~0.05 encoder threshold
            seed: 42,
            weight_init_range: 1638, // ~0.10
            use_astrocyte: false,
            astrocyte_tau: 1000.0,
            astrocyte_mode: AstrocyteMode::WeightMod,
        }
    }
}
use crate::rng::xorshift64;
/// Draws a pseudo-random `i16` uniformly from `[-|range|, +|range|]` using
/// the shared xorshift64 generator. `range == 0` always yields 0 (but the
/// generator is still advanced, preserving the original draw sequence).
///
/// Fixes over the naive version:
/// * `unsigned_abs()` instead of `-range`: negating `i16::MIN` panics in
///   debug builds (wraps in release). `i16::MIN` is clamped to `i16::MAX`
///   so the result always fits an `i16`.
/// * Re-centering is done in `i32`: `raw % modulus` can reach `2 * 32767`,
///   which does not fit in an `i16`, so the old `as i16` cast could wrap
///   (and then overflow again on subtraction) for ranges above 16383.
/// Outputs are bit-identical to the original for all ranges it handled
/// correctly (|range| <= 16383), so seeded determinism is preserved.
///
/// NOTE: `raw % modulus` carries a negligible modulo bias; acceptable for
/// weight initialization.
#[inline]
fn xorshift64_i16(state: &mut u64, range: i16) -> i16 {
    // Advance the RNG unconditionally so callers see the same stream as before.
    let raw = xorshift64(state);
    let abs_range = i32::from(range.unsigned_abs().min(i16::MAX as u16));
    if abs_range == 0 {
        return 0;
    }
    // 2 * 32767 + 1 = 65535 fits comfortably in i32/u64.
    let modulus = (2 * abs_range + 1) as u64;
    ((raw % modulus) as i32 - abs_range) as i16
}
/// Recurrent spiking network in Q14/`i16` fixed-point arithmetic, trained
/// online with e-prop-style eligibility traces and random feedback.
pub struct SpikeNetFixed {
// Construction-time configuration (kept for reset() and accessors).
config: SpikeNetFixedConfig,
// 2 * n_input: the delta encoder emits a channel pair per raw input.
n_input_encoded: usize,
// Hidden state: membrane potentials plus current and previous spike vectors
// (prev_spikes feeds the recurrent connections, see forward()).
membrane: Vec<i16>, spikes: Vec<u8>, prev_spikes: Vec<u8>,
// Low-pass presynaptic traces over encoded-input and hidden spikes.
pre_trace_in: Vec<i16>, pre_trace_hid: Vec<i16>,
// Weights (row-major, one row per postsynaptic neuron): input->hidden,
// hidden->hidden, hidden->output, and fixed random feedback used to form
// the per-neuron learning signal in train_step().
w_input: Vec<i16>, w_recurrent: Vec<i16>, w_output: Vec<i16>, feedback: Vec<i16>,
// Eligibility traces matching w_input / w_recurrent layouts.
elig_in: Vec<i16>, elig_rec: Vec<i16>,
readout: Vec<ReadoutNeuron>,
encoder: DeltaEncoderFixed,
// Scratch buffers reused each step: encoded spikes and per-output errors.
spike_buf: Vec<u8>,
error_buf: Vec<i16>,
// Optional astrocyte modulation (None when disabled in the config).
astrocyte: Option<AstrocyteGate>,
// Number of train_step() calls since construction or the last reset().
n_samples: u64,
}
// SAFETY(review): every visible field is an owned container (`Vec`, `Option`,
// plain integers); if `ReadoutNeuron`, `DeltaEncoderFixed`, and
// `AstrocyteGate` are themselves `Send + Sync`, these impls are redundant
// (the auto-traits would apply) and should be removed. Conversely, if any of
// those types is *not* auto-`Send`/`Sync` (e.g. holds an `Rc` or raw
// pointer), these blanket assertions are likely unsound. TODO: verify the
// field types and either delete these impls or document the exact invariant.
unsafe impl Send for SpikeNetFixed {}
unsafe impl Sync for SpikeNetFixed {}
impl SpikeNetFixed {
/// Builds a network from `config`: seeds the RNG, draws all weight blocks,
/// and zero-initializes every state buffer.
pub fn new(config: SpikeNetFixedConfig) -> Self {
    let n_in = config.n_input;
    let n_hid = config.n_hidden;
    let n_out = config.n_output;
    // The delta encoder expands each raw input into two spike channels.
    let n_enc = 2 * n_in;
    // Seed 0 is a fixed point of xorshift64, so it is remapped to 1.
    let mut rng_state = if config.seed == 0 { 1 } else { config.seed };
    let range = config.weight_init_range;
    // NOTE: the four weight blocks must be drawn in exactly this order so
    // that reset() reproduces the identical initialization from the seed.
    let mut draw_block = |count: usize| -> Vec<i16> {
        let mut block = Vec::with_capacity(count);
        for _ in 0..count {
            block.push(xorshift64_i16(&mut rng_state, range));
        }
        block
    };
    let w_input = draw_block(n_hid * n_enc);
    let w_recurrent = draw_block(n_hid * n_hid);
    let w_output = draw_block(n_out * n_hid);
    let feedback = draw_block(n_hid * n_out);
    let mut readout = Vec::with_capacity(n_out);
    for _ in 0..n_out {
        readout.push(ReadoutNeuron::new(config.kappa_out));
    }
    let encoder = DeltaEncoderFixed::new(n_in, config.spike_threshold);
    // Astrocyte gate only exists when explicitly enabled.
    let astrocyte = config
        .use_astrocyte
        .then(|| AstrocyteGate::with_mode(n_hid, config.astrocyte_tau, config.astrocyte_mode));
    Self {
        n_input_encoded: n_enc,
        membrane: vec![0; n_hid],
        spikes: vec![0; n_hid],
        prev_spikes: vec![0; n_hid],
        pre_trace_in: vec![0; n_enc],
        pre_trace_hid: vec![0; n_hid],
        w_input,
        w_recurrent,
        w_output,
        feedback,
        elig_in: vec![0; n_hid * n_enc],
        elig_rec: vec![0; n_hid * n_hid],
        readout,
        encoder,
        spike_buf: vec![0; n_enc],
        error_buf: vec![0; n_out],
        astrocyte,
        n_samples: 0,
        config,
    }
}
/// Runs one inference step: delta-encodes `input_i16`, advances every hidden
/// LIF neuron (feed-forward plus recurrent input from the *previous* step),
/// updates the optional astrocyte, then feeds spikes into the readouts.
/// `input_i16` must have `config.n_input` elements (the encoder's contract).
pub fn forward(&mut self, input_i16: &[i16]) {
let n_hid = self.config.n_hidden;
let n_enc = self.n_input_encoded;
// Encode raw inputs into 0/1 spikes (two encoded channels per input).
self.encoder.encode(input_i16, &mut self.spike_buf);
// Snapshot last step's spikes before overwriting: recurrence uses t-1.
self.prev_spikes.copy_from_slice(&self.spikes);
for j in 0..n_hid {
// Accumulate neuron j's synaptic current in i32 to avoid i16 overflow.
let mut current: i32 = 0;
let w_in_offset = j * n_enc;
for i in 0..n_enc {
if self.spike_buf[i] != 0 {
// In WeightMod mode the astrocyte scales *input* weights per neuron;
// recurrent weights below are deliberately not modulated.
let w = match &self.astrocyte {
Some(astro) if astro.mode() == AstrocyteMode::WeightMod => {
astro.modulate_weight(j, self.w_input[w_in_offset + i])
}
_ => self.w_input[w_in_offset + i],
};
current += w as i32;
}
}
// Recurrent drive from the previous step's hidden spikes.
let w_rec_offset = j * n_hid;
for i in 0..n_hid {
if self.prev_spikes[i] != 0 {
current += self.w_recurrent[w_rec_offset + i] as i32;
}
}
// LIF dynamics: decay by alpha, integrate current, threshold at v_thr.
let (v_new, spike) = lif_step(
self.membrane[j],
self.config.alpha,
current,
self.config.v_thr,
);
self.membrane[j] = v_new;
self.spikes[j] = spike as u8;
}
// Astrocyte tracks the freshly produced hidden activity.
if let Some(ref mut astro) = self.astrocyte {
astro.update(&self.spikes);
}
let n_out = self.config.n_output;
// Each readout integrates the summed output weights of active hidden
// neurons (leak is applied inside ReadoutNeuron::step, presumably with
// kappa_out — see ReadoutNeuron).
for k in 0..n_out {
let w_out_offset = k * n_hid;
let mut weighted_input: i32 = 0;
for j in 0..n_hid {
if self.spikes[j] != 0 {
weighted_input += self.w_output[w_out_offset + j] as i32;
}
}
self.readout[k].step(weighted_input);
}
}
/// One online e-prop training step on a single (input, target) pair.
///
/// Order matters: forward pass, then per-output errors, then presynaptic
/// traces, then per-neuron eligibility + weight updates, then output weights.
/// `target_i16` supplies up to `n_output` Q14 targets (extras are ignored).
pub fn train_step(&mut self, input_i16: &[i16], target_i16: &[i16]) {
let n_hid = self.config.n_hidden;
let n_enc = self.n_input_encoded;
let n_out = self.config.n_output;
self.forward(input_i16);
// Error per output: target - readout, saturating the i32 readout into the
// i16 range first so the subtraction cannot overflow.
for (k, &target_k) in target_i16.iter().enumerate().take(n_out) {
let readout_clamped = self.readout[k]
.output_i32()
.clamp(i16::MIN as i32, i16::MAX as i32) as i16;
self.error_buf[k] = target_k.saturating_sub(readout_clamped);
}
// Low-pass presynaptic traces over this step's encoded-input and hidden
// spikes; the eligibility updates below therefore see traces that already
// include time t.
update_pre_trace_fixed(&mut self.pre_trace_in, &self.spike_buf, self.config.alpha);
update_pre_trace_fixed(&mut self.pre_trace_hid, &self.spikes, self.config.alpha);
for j in 0..n_hid {
// Surrogate derivative of the spike nonlinearity at neuron j's membrane.
let psi =
surrogate_gradient_pwl(self.membrane[j], self.config.v_thr, self.config.gamma);
// Row j of the input eligibility matrix (one row per hidden neuron).
let elig_in_start = j * n_enc;
let elig_in_end = elig_in_start + n_enc;
update_eligibility_fixed(
&mut self.elig_in[elig_in_start..elig_in_end],
psi,
&self.pre_trace_in,
self.config.kappa,
);
// Row j of the recurrent eligibility matrix.
let elig_rec_start = j * n_hid;
let elig_rec_end = elig_rec_start + n_hid;
update_eligibility_fixed(
&mut self.elig_rec[elig_rec_start..elig_rec_end],
psi,
&self.pre_trace_hid,
self.config.kappa,
);
// Per-neuron learning signal: fixed random feedback row j dotted with
// the output errors (feedback weights are random and fixed, cf.
// feedback-alignment-style credit assignment).
let fb_start = j * n_out;
let fb_end = fb_start + n_out;
let learning_signal = compute_learning_signal_fixed(
&self.feedback[fb_start..fb_end],
&self.error_buf[..n_out],
);
// In LearningRateGate mode the astrocyte scales eta per neuron; any
// other mode (or no astrocyte) uses the configured eta unchanged.
let eta_j = match &self.astrocyte {
Some(astro) if astro.mode() == AstrocyteMode::LearningRateGate => {
astro.effective_eta_q14(j, self.config.eta)
}
_ => self.config.eta,
};
// Apply e-prop updates to neuron j's input and recurrent weight rows.
let w_in_start = j * n_enc;
let w_in_end = w_in_start + n_enc;
update_weights_fixed(
&mut self.w_input[w_in_start..w_in_end],
&self.elig_in[elig_in_start..elig_in_end],
learning_signal,
eta_j,
);
let w_rec_start = j * n_hid;
let w_rec_end = w_rec_start + n_hid;
update_weights_fixed(
&mut self.w_recurrent[w_rec_start..w_rec_end],
&self.elig_rec[elig_rec_start..elig_rec_end],
learning_signal,
eta_j,
);
}
// Output weights: delta-rule update from error x hidden spikes, always
// with the un-gated eta (the astrocyte only gates hidden-layer updates).
for k in 0..n_out {
let w_out_start = k * n_hid;
let w_out_end = w_out_start + n_hid;
update_output_weights_fixed(
&mut self.w_output[w_out_start..w_out_end],
self.error_buf[k],
&self.spikes,
self.config.eta,
);
}
self.n_samples += 1;
}
/// Raw `i32` readout accumulators, one entry per output neuron.
pub fn predict_raw(&self) -> Vec<i32> {
    let mut raw = Vec::with_capacity(self.readout.len());
    for neuron in &self.readout {
        raw.push(neuron.output_i32());
    }
    raw
}
/// First readout converted to `f64` via `output_scale`; returns 0.0 when
/// the network has no outputs.
pub fn predict_f64(&self, output_scale: f64) -> f64 {
    match self.readout.first() {
        Some(neuron) => neuron.output_f64(output_scale),
        None => 0.0,
    }
}
/// All readouts converted to `f64` via `output_scale`.
pub fn predict_all_f64(&self, output_scale: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(self.readout.len());
    for neuron in &self.readout {
        out.push(neuron.output_f64(output_scale));
    }
    out
}
/// Number of training samples seen since construction or the last `reset`.
pub fn n_samples_seen(&self) -> u64 {
self.n_samples
}
/// Borrows the active configuration.
pub fn config(&self) -> &SpikeNetFixedConfig {
&self.config
}
/// Hidden-layer size.
pub fn n_hidden(&self) -> usize {
self.config.n_hidden
}
/// Number of encoded input channels (`2 * n_input`).
pub fn n_input_encoded(&self) -> usize {
self.n_input_encoded
}
/// Hidden spike vector from the most recent step (0/1 per neuron).
pub fn hidden_spikes(&self) -> &[u8] {
&self.spikes
}
/// Hidden membrane potentials (Q14) from the most recent step.
pub fn hidden_membrane(&self) -> &[i16] {
&self.membrane
}
/// Estimated resident size in bytes: the struct itself plus the contents of
/// every heap buffer (lengths as allocated in `new`, so len == capacity).
///
/// NOTE(review): the two `n_input`-length encoder terms mirror the assumed
/// internals of `DeltaEncoderFixed` (previous-sample and threshold buffers);
/// re-confirm against that type if it changes.
pub fn memory_bytes(&self) -> usize {
    let n_hid = self.config.n_hidden;
    let n_enc = self.n_input_encoded;
    let n_out = self.config.n_output;
    let n_in = self.config.n_input;
    // Element counts grouped by element type; one term per buffer.
    let i16_elems = n_hid // membrane
        + n_enc // pre_trace_in
        + n_hid // pre_trace_hid
        + n_hid * n_enc // w_input
        + n_hid * n_hid // w_recurrent
        + n_out * n_hid // w_output
        + n_hid * n_out // feedback
        + n_hid * n_enc // elig_in
        + n_hid * n_hid // elig_rec
        + n_in // encoder: previous-sample buffer
        + n_in // encoder: per-channel thresholds
        + n_out; // error_buf
    let u8_elems = n_hid // spikes
        + n_hid // prev_spikes
        + n_enc; // spike_buf
    core::mem::size_of::<Self>()
        + i16_elems * core::mem::size_of::<i16>()
        + u8_elems * core::mem::size_of::<u8>()
        + n_out * core::mem::size_of::<ReadoutNeuron>()
}
/// Restores the freshly-constructed state: zeroes all dynamic state, resets
/// the readouts, encoder, and astrocyte, re-draws every weight block from
/// the configured seed, and clears the sample counter.
///
/// Idiom cleanup: the manual element-by-element zeroing loops are replaced
/// with `slice::fill`, and the four weight blocks are re-drawn through one
/// chained iterator — in the exact order used by `new` (input, recurrent,
/// output, feedback), so the RNG stream reproduces the initial weights.
pub fn reset(&mut self) {
    // Zero all per-step state buffers.
    self.membrane.fill(0);
    self.spikes.fill(0);
    self.prev_spikes.fill(0);
    self.pre_trace_in.fill(0);
    self.pre_trace_hid.fill(0);
    self.elig_in.fill(0);
    self.elig_rec.fill(0);
    for r in self.readout.iter_mut() {
        r.reset();
    }
    self.encoder.reset();
    self.spike_buf.fill(0);
    self.error_buf.fill(0);
    if let Some(ref mut astro) = self.astrocyte {
        astro.reset();
    }
    // Re-derive the RNG stream exactly as in new(): seed 0 maps to 1, and
    // the weight blocks are visited in the same order as at construction.
    let mut rng_state = if self.config.seed == 0 {
        1
    } else {
        self.config.seed
    };
    let range = self.config.weight_init_range;
    for w in self
        .w_input
        .iter_mut()
        .chain(self.w_recurrent.iter_mut())
        .chain(self.w_output.iter_mut())
        .chain(self.feedback.iter_mut())
    {
        *w = xorshift64_i16(&mut rng_state, range);
    }
    self.n_samples = 0;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::snn::lif::{f64_to_q14, Q14_ONE};
// Small deterministic config shared by most tests: 2 inputs (4 encoded
// channels), 8 hidden neurons, 1 output, astrocyte disabled. All i16 gains
// are Q14 values built via f64_to_q14.
fn default_small_config() -> SpikeNetFixedConfig {
SpikeNetFixedConfig {
n_input: 2,
n_hidden: 8,
n_output: 1,
alpha: f64_to_q14(0.95),
kappa: f64_to_q14(0.99),
kappa_out: f64_to_q14(0.9),
eta: f64_to_q14(0.01),
v_thr: f64_to_q14(0.5),
gamma: f64_to_q14(0.3),
spike_threshold: f64_to_q14(0.05),
seed: 42,
weight_init_range: f64_to_q14(0.1),
use_astrocyte: false,
astrocyte_tau: 1000.0,
astrocyte_mode: AstrocyteMode::WeightMod,
}
}
// All buffers must be sized from (n_input=2 -> n_enc=4, n_hidden=8, n_output=1).
#[test]
fn construction_initializes_all_buffers() {
let config = default_small_config();
let net = SpikeNetFixed::new(config);
assert_eq!(net.membrane.len(), 8);
assert_eq!(net.spikes.len(), 8);
assert_eq!(net.n_input_encoded(), 4);
assert_eq!(net.w_input.len(), 8 * 4);
assert_eq!(net.w_recurrent.len(), 8 * 8);
assert_eq!(net.w_output.len(), 8);
assert_eq!(net.feedback.len(), 8);
assert_eq!(net.elig_in.len(), 8 * 4);
assert_eq!(net.elig_rec.len(), 8 * 8);
assert_eq!(net.readout.len(), 1);
assert_eq!(net.n_samples_seen(), 0);
}
// Smoke test: two forward passes on Q14 inputs must not panic and must
// leave one readout output available.
#[test]
fn forward_does_not_crash() {
let config = default_small_config();
let mut net = SpikeNetFixed::new(config);
net.forward(&[f64_to_q14(0.5), f64_to_q14(-0.3)]);
net.forward(&[f64_to_q14(0.8), f64_to_q14(0.2)]);
let raw = net.predict_raw();
assert_eq!(raw.len(), 1, "should have one readout output");
}
// train_step must bump the sample counter exactly once per call.
#[test]
fn train_step_increments_counter() {
let config = default_small_config();
let mut net = SpikeNetFixed::new(config);
let input = [f64_to_q14(0.5), f64_to_q14(-0.3)];
let target = [f64_to_q14(0.7)];
net.train_step(&input, &target);
assert_eq!(net.n_samples_seen(), 1);
net.train_step(&input, &target);
assert_eq!(net.n_samples_seen(), 2);
}
// Learning sanity check: 200 alternating supervised steps must move the
// readout away from its pre-training value (uses a hotter config — larger
// eta, lower thresholds — so weight updates are guaranteed to be nonzero).
#[test]
fn predictions_change_after_training() {
let config = SpikeNetFixedConfig {
n_input: 2,
n_hidden: 16,
n_output: 1,
alpha: f64_to_q14(0.9),
kappa: f64_to_q14(0.95),
kappa_out: f64_to_q14(0.85),
eta: f64_to_q14(0.05), v_thr: f64_to_q14(0.3), gamma: f64_to_q14(0.5),
spike_threshold: f64_to_q14(0.01), seed: 12345,
weight_init_range: f64_to_q14(0.2),
use_astrocyte: false,
astrocyte_tau: 1000.0,
astrocyte_mode: AstrocyteMode::WeightMod,
};
let mut net = SpikeNetFixed::new(config);
let scale = 1.0 / Q14_ONE as f64;
net.forward(&[0, 0]);
let pred_before = net.predict_f64(scale);
// Alternate between two input/target pairs of opposite sign.
for step in 0..200 {
let x = if step % 2 == 0 {
[f64_to_q14(0.8), f64_to_q14(-0.5)]
} else {
[f64_to_q14(-0.3), f64_to_q14(0.6)]
};
let target = if step % 2 == 0 {
[f64_to_q14(1.0)]
} else {
[f64_to_q14(-1.0)]
};
net.train_step(&x, &target);
}
let pred_after = net.predict_f64(scale);
assert!(
(pred_after - pred_before).abs() > 1e-10,
"prediction should change after training: before={}, after={}",
pred_before,
pred_after
);
}
// reset() must reproduce a freshly-constructed network bit-for-bit: zeroed
// state and weights re-drawn from the same seed in the same order.
#[test]
fn reset_restores_initial_state() {
let config = default_small_config();
let mut net = SpikeNetFixed::new(config.clone());
let fresh = SpikeNetFixed::new(config);
net.train_step(&[1000, -500], &[2000]);
net.train_step(&[-1000, 500], &[-2000]);
assert!(net.n_samples_seen() > 0);
net.reset();
assert_eq!(net.n_samples_seen(), 0);
assert_eq!(net.membrane, fresh.membrane);
assert_eq!(net.spikes, fresh.spikes);
assert_eq!(
net.w_input, fresh.w_input,
"weights should be re-initialized from seed"
);
assert_eq!(net.w_recurrent, fresh.w_recurrent);
assert_eq!(net.w_output, fresh.w_output);
assert_eq!(net.feedback, fresh.feedback);
}
// Coarse bounds on the memory estimate for a 10-in/64-hid/1-out network
// (dominated by the 64x64 recurrent and 64x20 input weight matrices).
#[test]
fn memory_bytes_is_reasonable() {
let config = SpikeNetFixedConfig {
n_input: 10,
n_hidden: 64,
n_output: 1,
..SpikeNetFixedConfig::default()
};
let net = SpikeNetFixed::new(config);
let mem = net.memory_bytes();
assert!(
mem > 20_000,
"memory should be at least 20KB for 10-in/64-hid/1-out, got {}",
mem
);
assert!(
mem < 100_000,
"memory should be under 100KB for small network, got {}",
mem
);
}
// Fixed-point arithmetic plus a seeded RNG means two identically-configured
// networks must stay bit-identical through training.
#[test]
fn deterministic_with_same_seed() {
let config = default_small_config();
let mut net1 = SpikeNetFixed::new(config.clone());
let mut net2 = SpikeNetFixed::new(config);
let input = [f64_to_q14(0.3), f64_to_q14(-0.7)];
let target = [f64_to_q14(0.5)];
for _ in 0..10 {
net1.train_step(&input, &target);
net2.train_step(&input, &target);
}
let scale = 1.0 / Q14_ONE as f64;
let p1 = net1.predict_f64(scale);
let p2 = net2.predict_f64(scale);
assert_eq!(p1, p2, "same seed should produce identical predictions");
}
// A 3-output network must expose 3 raw and 3 scaled readout values.
#[test]
fn multi_output_network() {
let config = SpikeNetFixedConfig {
n_input: 3,
n_hidden: 8,
n_output: 3,
..SpikeNetFixedConfig::default()
};
let mut net = SpikeNetFixed::new(config);
net.forward(&[1000, -500, 200]);
net.forward(&[1500, 0, -300]);
let raw = net.predict_raw();
assert_eq!(raw.len(), 3, "should have 3 readout outputs");
let scale = 1.0 / Q14_ONE as f64;
let all = net.predict_all_f64(scale);
assert_eq!(all.len(), 3);
}
// train_step with a multi-element target must run without panicking.
#[test]
fn train_step_with_multi_output() {
let config = SpikeNetFixedConfig {
n_input: 2,
n_hidden: 8,
n_output: 2,
..SpikeNetFixedConfig::default()
};
let mut net = SpikeNetFixed::new(config);
net.train_step(&[1000, -500], &[2000, -1000]);
assert_eq!(net.n_samples_seen(), 1);
}
// Smoke test for the astrocyte path (default WeightMod mode from
// default_small_config): training must complete and readouts stay available.
#[test]
fn network_with_astrocyte_runs() {
let config = SpikeNetFixedConfig {
use_astrocyte: true,
astrocyte_tau: 100.0,
..default_small_config()
};
let mut net = SpikeNetFixed::new(config);
for _ in 0..50 {
net.train_step(&[1000, -500], &[2000]);
}
assert_eq!(net.n_samples_seen(), 50);
let raw = net.predict_raw();
assert_eq!(raw.len(), 1);
}
// LearningRateGate mode: the astrocyte gates eta in train_step (it does not
// touch the forward-pass weights), so prolonged training must still yield a
// finite prediction. The previous function-local
// `use crate::snn::lif::f64_to_q14;` was removed — it redundantly shadowed
// the module-level import.
#[test]
fn agmp_modulates_learning_rate_not_weights() {
    let config = SpikeNetFixedConfig {
        use_astrocyte: true,
        astrocyte_tau: 10.0,
        astrocyte_mode: AstrocyteMode::LearningRateGate,
        n_input: 2,
        n_hidden: 16,
        n_output: 1,
        ..SpikeNetFixedConfig::default()
    };
    let mut net = SpikeNetFixed::new(config);
    let input = [f64_to_q14(0.5), f64_to_q14(-0.3)];
    let target = [f64_to_q14(1.0)];
    for _ in 0..200 {
        net.train_step(&input, &target);
    }
    let scale = 1.0 / Q14_ONE as f64;
    let pred = net.predict_f64(scale);
    assert!(
        pred.is_finite(),
        "LearningRateGate network should produce finite prediction after training, got {pred}"
    );
    assert_eq!(net.n_samples_seen(), 200);
}
// hidden_spikes() must expose a binary (0/1) vector of length n_hidden.
#[test]
fn hidden_spikes_accessible() {
let config = default_small_config();
let mut net = SpikeNetFixed::new(config);
net.forward(&[0, 0]);
net.forward(&[Q14_ONE, -Q14_ONE]);
let spikes = net.hidden_spikes();
assert_eq!(spikes.len(), 8);
for &s in spikes {
assert!(s == 0 || s == 1, "spike should be 0 or 1, got {}", s);
}
}
// The stock Default config must have strictly positive core parameters.
#[test]
fn config_default_is_sensible() {
let config = SpikeNetFixedConfig::default();
assert!(config.alpha > 0, "alpha should be positive");
assert!(config.v_thr > 0, "v_thr should be positive");
assert!(config.eta > 0, "eta should be positive");
assert!(config.n_hidden > 0, "n_hidden should be positive");
}
// Without the `std` feature, `Precision::default()` is deliberately `Fixed`,
// so this assertion only holds on std builds. Gate the test accordingly so
// `--no-default-features` test runs do not fail spuriously.
#[test]
#[cfg(feature = "std")]
fn precision_default_is_float_in_std() {
    let p = Precision::default();
    assert_eq!(
        p,
        Precision::Float,
        "Precision::default() must be Float on std targets, got {p:?}"
    );
    assert_ne!(
        Precision::Float,
        Precision::Fixed,
        "Float and Fixed must be distinct"
    );
}
}