use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{CoreResult, HiddenState, SignalPredictor};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
#[allow(unused_imports)]
use tracing::{debug, instrument, trace};
struct SeededRng {
state: u64,
}
impl SeededRng {
fn new(seed: u64) -> Self {
Self { state: seed.max(1) }
}
fn next_f32(&mut self) -> f32 {
self.state ^= self.state << 13;
self.state ^= self.state >> 7;
self.state ^= self.state << 17;
(self.state as f64 / u64::MAX as f64 * 2.0 - 1.0) as f32
}
}
/// How a neuron's membrane potential is handled after it fires (and while
/// it sits inside its refractory window).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ResetMode {
    /// Clamp the potential to 0 after a spike.
    HardReset,
    /// Subtract the firing threshold, preserving any overshoot.
    SoftReset,
    /// No explicit reset; the potential just keeps leaking.
    SubThreshold,
}
/// Hyper-parameters for a [`SpikingNeuralNetwork`] of LIF neurons.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpikingConfig {
    /// Width of the input vector fed to the first layer.
    pub input_dim: usize,
    /// Neurons per hidden layer.
    pub hidden_dim: usize,
    /// Width of the (non-spiking) linear readout.
    pub output_dim: usize,
    /// Number of LIF layers.
    pub num_layers: usize,
    /// Membrane potential at which a neuron fires (must be positive, finite).
    pub threshold: f32,
    /// Per-step membrane decay multiplier in [0, 1] (1.0 = no leak).
    pub leak_factor: f32,
    /// Steps a neuron stays silent after firing.
    pub refractory_period: usize,
    /// Post-spike reset behavior.
    pub reset_mode: ResetMode,
    /// Simulation time step (only used for spike-trace decay; must be > 0).
    pub dt: f32,
}
impl SpikingConfig {
    /// Create a configuration with sensible neuron-dynamics defaults:
    /// unit threshold, leak factor 0.9, a 2-step refractory period,
    /// soft reset, and a unit time step.
    pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize, num_layers: usize) -> Self {
        Self {
            input_dim,
            hidden_dim,
            output_dim,
            num_layers,
            threshold: 1.0,
            leak_factor: 0.9,
            refractory_period: 2,
            reset_mode: ResetMode::SoftReset,
            dt: 1.0,
        }
    }
    /// Validate every field.
    ///
    /// # Errors
    /// Returns an `invalid_config` error if any dimension is zero, the
    /// threshold is non-positive or non-finite, the leak factor lies
    /// outside [0, 1] (NaN included, since `contains` rejects NaN), or
    /// `dt` is non-positive or non-finite.
    pub fn validate(&self) -> ModelResult<()> {
        if self.input_dim == 0 {
            return Err(ModelError::invalid_config("input_dim must be > 0"));
        }
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.output_dim == 0 {
            return Err(ModelError::invalid_config("output_dim must be > 0"));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        if !self.threshold.is_finite() || self.threshold <= 0.0 {
            return Err(ModelError::invalid_config(
                "threshold must be positive and finite",
            ));
        }
        if !(0.0..=1.0).contains(&self.leak_factor) {
            return Err(ModelError::invalid_config(
                "leak_factor must be in [0.0, 1.0]",
            ));
        }
        // `dt <= 0.0` alone would let NaN slip through (every comparison
        // with NaN is false), so an explicit finiteness check is needed.
        if !self.dt.is_finite() || self.dt <= 0.0 {
            return Err(ModelError::invalid_config("dt must be positive and finite"));
        }
        Ok(())
    }
}
/// Per-layer neuron state carried between time steps.
#[derive(Debug, Clone)]
pub struct MembranePotential {
    /// Current membrane potential of each neuron.
    pub voltages: Array1<f32>,
    /// Remaining refractory steps per neuron (0 = excitable).
    pub refractory_countdown: Array1<u32>,
    /// Exponentially decaying spike trace per neuron (consumed by STDP).
    pub last_spike_trace: Array1<f32>,
}
impl MembranePotential {
    /// Allocate zeroed state for `n` neurons.
    pub fn new(n: usize) -> Self {
        let voltages = Array1::zeros(n);
        let refractory_countdown = Array1::zeros(n);
        let last_spike_trace = Array1::zeros(n);
        Self {
            voltages,
            refractory_countdown,
            last_spike_trace,
        }
    }
    /// Zero every component in place, keeping the existing allocations.
    pub fn reset(&mut self) {
        self.last_spike_trace.fill(0.0);
        self.refractory_countdown.fill(0);
        self.voltages.fill(0.0);
    }
}
/// One fully-connected layer of leaky integrate-and-fire neurons.
pub struct LifLayer {
    // Shared neuron hyper-parameters (cloned from the network config).
    config: SpikingConfig,
    // Synaptic weights, shape (output_neurons, input_dim).
    weights: Array2<f32>,
    // Per-neuron bias added to the synaptic current each step.
    bias: Array1<f32>,
    // Number of neurons in this layer (rows of `weights`).
    output_neurons: usize,
}
impl LifLayer {
    /// Build a dense LIF layer of `output_dim` neurons fed by `input_dim`
    /// inputs.
    ///
    /// Weights use a deterministic Xavier-style initialization seeded only
    /// from the layer dimensions, so layers with identical shapes get
    /// identical initial weights (reproducible by design).
    pub fn new(input_dim: usize, output_dim: usize, config: &SpikingConfig) -> ModelResult<Self> {
        if input_dim == 0 || output_dim == 0 {
            return Err(ModelError::invalid_config(
                "LIF layer dimensions must be > 0",
            ));
        }
        // Xavier/Glorot scale keeps synaptic current variance roughly
        // independent of fan-in + fan-out.
        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
        let mut rng =
            SeededRng::new(((input_dim + output_dim) as u64).wrapping_mul(6364136223846793005));
        let weights = Array2::from_shape_fn((output_dim, input_dim), |_| rng.next_f32() * scale);
        let bias = Array1::from_shape_fn(output_dim, |_| rng.next_f32() * 0.01);
        Ok(Self {
            config: config.clone(),
            weights,
            bias,
            output_neurons: output_dim,
        })
    }
    /// Advance the layer by one time step.
    ///
    /// Leaky-integrates `input` into the membrane potentials held in
    /// `state`, returns a 0/1 spike vector, and applies the configured
    /// reset and refractory behavior. Neurons inside their refractory
    /// window discard this step's synaptic current entirely.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error when `input` or `state` does not
    /// match the layer's shape.
    #[instrument(skip(self, input, state), fields(neurons = self.output_neurons))]
    pub fn step(
        &self,
        input: &Array1<f32>,
        state: &mut MembranePotential,
    ) -> ModelResult<Array1<f32>> {
        if input.len() != self.weights.ncols() {
            return Err(ModelError::dimension_mismatch(
                "LIF input",
                self.weights.ncols(),
                input.len(),
            ));
        }
        if state.voltages.len() != self.output_neurons {
            return Err(ModelError::dimension_mismatch(
                "LIF state",
                self.output_neurons,
                state.voltages.len(),
            ));
        }
        // Candidate update for every neuron: v' = leak * v + (W x + b).
        let synaptic_current = self.weights.dot(input) + &self.bias;
        let new_voltages = &state.voltages * self.config.leak_factor + &synaptic_current;
        let mut spikes = Array1::<f32>::zeros(self.output_neurons);
        let mut updated_voltages = new_voltages.clone();
        let mut updated_refractory = state.refractory_countdown.clone();
        let mut updated_traces = state.last_spike_trace.clone();
        // Spike traces decay exponentially with a fixed 20-step time
        // constant; they are read downstream by the STDP updater.
        let tau_trace = 20.0_f32;
        let trace_decay = (-self.config.dt / tau_trace).exp();
        updated_traces.mapv_inplace(|t| t * trace_decay);
        for i in 0..self.output_neurons {
            if updated_refractory[i] > 0 {
                // Refractory: count down, never spike, and override the
                // candidate voltage so this step's input is ignored.
                updated_refractory[i] -= 1;
                updated_voltages[i] = match self.config.reset_mode {
                    ResetMode::HardReset => 0.0,
                    _ => state.voltages[i] * self.config.leak_factor,
                };
            } else if new_voltages[i] >= self.config.threshold {
                // Fire: emit a spike, start the refractory window, bump
                // the trace, and reset per the configured mode.
                spikes[i] = 1.0;
                updated_refractory[i] = self.config.refractory_period as u32;
                updated_traces[i] += 1.0;
                updated_voltages[i] = match self.config.reset_mode {
                    ResetMode::HardReset => 0.0,
                    ResetMode::SoftReset => new_voltages[i] - self.config.threshold,
                    ResetMode::SubThreshold => new_voltages[i] * self.config.leak_factor,
                };
            }
        }
        state.voltages = updated_voltages;
        state.refractory_countdown = updated_refractory;
        state.last_spike_trace = updated_traces;
        Ok(spikes)
    }
    /// Fresh zeroed membrane state sized for this layer.
    pub fn init_state(&self) -> MembranePotential {
        MembranePotential::new(self.output_neurons)
    }
    /// Number of neurons in this layer.
    pub fn output_neurons(&self) -> usize {
        self.output_neurons
    }
    /// Mutable access to the weight matrix (used by STDP updates).
    pub fn weights_mut(&mut self) -> &mut Array2<f32> {
        &mut self.weights
    }
}
/// Parameters of the trace-based STDP learning rule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StdpConfig {
    /// Potentiation amplitude applied when a post-synaptic neuron spikes.
    pub a_plus: f32,
    /// Depression amplitude applied via the post-synaptic trace.
    pub a_minus: f32,
    /// Potentiation trace time constant (steps).
    /// NOTE(review): not read by `StdpUpdater::update_weights` in this file;
    /// the LIF layer uses a hard-coded 20-step trace constant — confirm intent.
    pub tau_plus: f32,
    /// Depression trace time constant (steps). Same note as `tau_plus`.
    pub tau_minus: f32,
}
impl Default for StdpConfig {
fn default() -> Self {
Self {
a_plus: 0.01,
a_minus: 0.012,
tau_plus: 20.0,
tau_minus: 20.0,
}
}
}
/// Applies trace-based STDP weight updates to a layer's weight matrix.
pub struct StdpUpdater {
    // Rule amplitudes/time constants; see `StdpConfig`.
    config: StdpConfig,
}
impl StdpUpdater {
pub fn new(config: StdpConfig) -> Self {
Self { config }
}
#[instrument(skip(self, weights, pre_traces, post_spikes, post_traces))]
pub fn update_weights(
&self,
weights: &mut Array2<f32>,
pre_traces: &Array1<f32>,
post_spikes: &Array1<f32>,
post_traces: &Array1<f32>,
) -> ModelResult<()> {
let (n_out, n_in) = weights.dim();
if pre_traces.len() != n_in {
return Err(ModelError::dimension_mismatch(
"STDP pre_traces",
n_in,
pre_traces.len(),
));
}
if post_spikes.len() != n_out {
return Err(ModelError::dimension_mismatch(
"STDP post_spikes",
n_out,
post_spikes.len(),
));
}
if post_traces.len() != n_out {
return Err(ModelError::dimension_mismatch(
"STDP post_traces",
n_out,
post_traces.len(),
));
}
for i in 0..n_out {
if post_spikes[i] > 0.0 {
for j in 0..n_in {
weights[[i, j]] += self.config.a_plus * pre_traces[j];
}
}
}
for i in 0..n_out {
for j in 0..n_in {
weights[[i, j]] -= self.config.a_minus * post_traces[i] * pre_traces[j];
}
}
weights.mapv_inplace(|w| w.clamp(-10.0, 10.0));
Ok(())
}
}
/// Multi-layer LIF spiking network with a linear readout and optional
/// online STDP on the final layer.
pub struct SpikingNeuralNetwork {
    /// Network hyper-parameters (validated at construction).
    pub config: SpikingConfig,
    // LIF layers: input -> hidden, then hidden -> hidden.
    layers: Vec<LifLayer>,
    // Non-spiking linear readout, shape (output_dim, hidden_dim).
    output_proj: Array2<f32>,
    // Readout bias, length output_dim.
    output_bias: Array1<f32>,
    // One membrane state per layer, carried across steps.
    layer_states: Vec<MembranePotential>,
    // Present only when built via `new_with_stdp`.
    stdp: Option<StdpUpdater>,
    // Steps taken since the last reset (denominator for firing rates).
    step_count: usize,
    // Accumulated spike counts per neuron, one array per layer.
    spike_accumulator: Vec<Array1<f32>>,
}
impl SpikingNeuralNetwork {
    /// Construct a network from a validated configuration (no plasticity).
    #[instrument(skip(config), fields(layers = config.num_layers, hidden = config.hidden_dim))]
    pub fn new(config: SpikingConfig) -> ModelResult<Self> {
        config.validate()?;
        Self::build(config, None)
    }
    /// Construct a network with online STDP applied to the final layer.
    pub fn new_with_stdp(config: SpikingConfig, stdp: StdpConfig) -> ModelResult<Self> {
        config.validate()?;
        Self::build(config, Some(StdpUpdater::new(stdp)))
    }
    /// Small 1-in/1-out, 2-layer, 64-hidden network for quick experiments.
    pub fn small() -> ModelResult<Self> {
        let config = SpikingConfig::new(1, 64, 1, 2);
        Self::new(config)
    }
    /// Shared constructor: allocate layers, the readout projection, and
    /// per-layer bookkeeping. `config` must already be validated.
    fn build(config: SpikingConfig, stdp: Option<StdpUpdater>) -> ModelResult<Self> {
        debug!(
            "Building SNN: layers={}, hidden={}",
            config.num_layers, config.hidden_dim
        );
        let mut layers = Vec::with_capacity(config.num_layers);
        layers.push(LifLayer::new(config.input_dim, config.hidden_dim, &config)?);
        for _ in 1..config.num_layers {
            layers.push(LifLayer::new(
                config.hidden_dim,
                config.hidden_dim,
                &config,
            )?);
        }
        // Deterministic Xavier-style init for the linear readout.
        let scale = (2.0 / (config.hidden_dim + config.output_dim) as f32).sqrt();
        let mut rng = SeededRng::new(
            ((config.hidden_dim * 1000 + config.output_dim) as u64)
                .wrapping_mul(2862933555777941757),
        );
        let output_proj = Array2::from_shape_fn((config.output_dim, config.hidden_dim), |_| {
            rng.next_f32() * scale
        });
        let output_bias = Array1::from_shape_fn(config.output_dim, |_| rng.next_f32() * 0.01);
        let layer_states: Vec<MembranePotential> = layers.iter().map(|l| l.init_state()).collect();
        let spike_accumulator: Vec<Array1<f32>> = layers
            .iter()
            .map(|l| Array1::zeros(l.output_neurons()))
            .collect();
        debug!("SNN built with {} layers", layers.len());
        Ok(Self {
            config,
            layers,
            output_proj,
            output_bias,
            layer_states,
            stdp,
            step_count: 0,
            spike_accumulator,
        })
    }
    /// Fresh zeroed membrane states, one per layer (does not modify the
    /// network's own internal states).
    pub fn init_layer_states(&self) -> Vec<MembranePotential> {
        self.layers.iter().map(|l| l.init_state()).collect()
    }
    /// Mean spikes per neuron per step, for each layer, since the last
    /// reset. All zeros before any step has run.
    pub fn average_firing_rate(&self) -> Vec<f32> {
        if self.step_count == 0 {
            return vec![0.0; self.layers.len()];
        }
        self.spike_accumulator
            .iter()
            .map(|acc| acc.mean().unwrap_or(0.0) / self.step_count as f32)
            .collect()
    }
    /// Propagate one input vector through every layer, optionally apply
    /// STDP to the final layer, and project the last spike vector through
    /// the linear readout.
    ///
    /// # Errors
    /// Dimension mismatches, a missing layer state, or a non-finite output
    /// produce a `ModelError`.
    fn forward_step(&mut self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        let mut current = input.clone();
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let state = self.layer_states.get_mut(layer_idx).ok_or_else(|| {
                ModelError::not_initialized(format!("layer state {} missing", layer_idx))
            })?;
            // Fixed: this argument was the mojibake `¤t` (a mangled
            // `&current`), which does not compile.
            let spikes = layer.step(&current, state)?;
            if let Some(acc) = self.spike_accumulator.get_mut(layer_idx) {
                *acc += &spikes;
            }
            current = spikes;
        }
        if let Some(stdp_updater) = &self.stdp {
            if self.layers.len() >= 2 {
                // Pre-synaptic traces come from the penultimate layer's
                // output; post spikes/traces from the final layer.
                let pre_traces = self
                    .layer_states
                    .get(self.layers.len() - 2)
                    .map(|s| s.last_spike_trace.clone())
                    .unwrap_or_else(|| Array1::zeros(self.config.hidden_dim));
                let post_spikes = current.clone();
                let post_traces = self
                    .layer_states
                    .last()
                    .map(|s| s.last_spike_trace.clone())
                    .unwrap_or_else(|| Array1::zeros(self.config.hidden_dim));
                let last_idx = self.layers.len() - 1;
                let weights = self.layers[last_idx].weights_mut();
                stdp_updater.update_weights(weights, &pre_traces, &post_spikes, &post_traces)?;
            }
        }
        // Fixed: `dot(¤t)` was also mangled from `dot(&current)`.
        let output = self.output_proj.dot(&current) + &self.output_bias;
        if output.iter().any(|v| !v.is_finite()) {
            return Err(ModelError::numerical_instability(
                "SNN output projection",
                "NaN or Inf detected",
            ));
        }
        self.step_count += 1;
        Ok(output)
    }
}
impl SignalPredictor for SpikingNeuralNetwork {
#[instrument(skip(self, input))]
fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
self.forward_step(input)
.map_err(|e| kizzasi_core::CoreError::Generic(e.to_string()))
}
#[instrument(skip(self))]
fn reset(&mut self) {
debug!("Resetting SNN membrane states");
for state in &mut self.layer_states {
state.reset();
}
for acc in &mut self.spike_accumulator {
acc.fill(0.0);
}
self.step_count = 0;
}
fn context_window(&self) -> usize {
usize::MAX
}
}
impl AutoregressiveModel for SpikingNeuralNetwork {
    /// Hidden width per layer.
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }
    /// Total flattened state size across layers.
    fn state_dim(&self) -> usize {
        self.config.hidden_dim * self.config.num_layers
    }
    fn num_layers(&self) -> usize {
        self.config.num_layers
    }
    fn model_type(&self) -> ModelType {
        ModelType::Snn
    }
    /// Snapshot each layer's membrane voltages as a 1-row hidden state.
    /// Refractory countdowns and spike traces are NOT captured, so a
    /// `set_states` round-trip restores voltages only.
    fn get_states(&self) -> Vec<HiddenState> {
        self.layer_states
            .iter()
            .map(|state| {
                let dim = state.voltages.len();
                let state_2d = state
                    .voltages
                    .clone()
                    .insert_axis(scirs2_core::ndarray::Axis(0));
                let mut hidden = HiddenState::new(dim, 1);
                hidden.update(state_2d);
                hidden
            })
            .collect()
    }
    /// Restore per-layer membrane voltages from `states` (row 0 of each).
    /// Empty hidden states are skipped; refractory countdowns and spike
    /// traces are left untouched.
    ///
    /// # Errors
    /// Returns an error if the number of states does not match the layer
    /// count, or if any non-empty state's width does not match that
    /// layer's neuron count. Validation happens before any assignment, so
    /// a failed call leaves the network unchanged.
    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "SNN",
                self.config.num_layers,
                states.len(),
            ));
        }
        // Validate all widths first so a mismatch cannot leave the network
        // partially updated (previously a wrong-width row was silently
        // assigned, deferring the failure to the next `step`).
        for (layer_state, hidden) in self.layer_states.iter().zip(states.iter()) {
            let state_2d = hidden.state();
            if state_2d.nrows() > 0
                && state_2d.ncols() > 0
                && state_2d.ncols() != layer_state.voltages.len()
            {
                return Err(ModelError::dimension_mismatch(
                    "SNN set_states voltages",
                    layer_state.voltages.len(),
                    state_2d.ncols(),
                ));
            }
        }
        for (layer_state, hidden) in self.layer_states.iter_mut().zip(states.iter()) {
            let state_2d = hidden.state();
            if state_2d.nrows() > 0 && state_2d.ncols() > 0 {
                layer_state.voltages = state_2d.row(0).to_owned();
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // 4-in/4-out, 8 hidden neurons, 2 layers: shared fixture config.
    fn small_config() -> SpikingConfig {
        SpikingConfig::new(4, 8, 4, 2)
    }
    // Layer construction succeeds and reports the requested neuron count.
    #[test]
    fn test_lif_layer_creation() {
        let config = small_config();
        let layer = LifLayer::new(4, 8, &config);
        assert!(layer.is_ok());
        let layer = layer.expect("LIF layer creation failed");
        assert_eq!(layer.output_neurons(), 8);
    }
    // A large input against a tiny threshold must produce at least one spike.
    #[test]
    fn test_lif_spike_generation() {
        let config = SpikingConfig {
            threshold: 0.1, leak_factor: 1.0, refractory_period: 0,
            ..SpikingConfig::new(4, 8, 4, 1)
        };
        let layer = LifLayer::new(4, 8, &config).expect("LIF layer creation failed");
        let mut state = layer.init_state();
        let input = Array1::from_vec(vec![10.0_f32; 4]);
        let spikes = layer.step(&input, &mut state).expect("step failed");
        let total_spikes: f32 = spikes.sum();
        assert!(
            total_spikes > 0.0,
            "expected spikes for large input, got {total_spikes}"
        );
    }
    // Neurons that fired on step 1 enter a 5-step refractory window, so
    // step 2 should not spike more than step 1 did.
    #[test]
    fn test_lif_refractory_period() {
        let config = SpikingConfig {
            threshold: 0.1,
            leak_factor: 1.0,
            refractory_period: 5, reset_mode: ResetMode::SoftReset,
            ..SpikingConfig::new(4, 8, 4, 1)
        };
        let layer = LifLayer::new(4, 8, &config).expect("LIF layer creation failed");
        let mut state = layer.init_state();
        let input = Array1::from_vec(vec![10.0_f32; 4]);
        let spikes1 = layer.step(&input, &mut state).expect("step 1 failed");
        let total_spikes1: f32 = spikes1.sum();
        let spikes2 = layer.step(&input, &mut state).expect("step 2 failed");
        let total_spikes2: f32 = spikes2.sum();
        if total_spikes1 > 0.0 {
            assert!(
                total_spikes2 <= total_spikes1 || total_spikes2 == 0.0,
                "refractory period not respected: step1={total_spikes1}, step2={total_spikes2}"
            );
        }
    }
    // With a huge threshold (no spikes) and leak 0.5, voltages must shrink
    // once input is removed. Note bias is still added on step 2, hence the
    // tolerance-free comparison on summed magnitudes.
    #[test]
    fn test_lif_membrane_decay() {
        let config = SpikingConfig {
            threshold: 1000.0, leak_factor: 0.5,
            refractory_period: 0,
            reset_mode: ResetMode::HardReset,
            ..SpikingConfig::new(4, 8, 4, 1)
        };
        let layer = LifLayer::new(4, 8, &config).expect("LIF layer creation failed");
        let mut state = layer.init_state();
        let input = Array1::from_vec(vec![0.1_f32; 4]);
        let _ = layer.step(&input, &mut state).expect("step 1 failed");
        let v1 = state.voltages.clone();
        let zero_input = Array1::zeros(4);
        let _ = layer.step(&zero_input, &mut state).expect("step 2 failed");
        let v2 = state.voltages.clone();
        let v1_norm: f32 = v1.iter().map(|x| x.abs()).sum();
        let v2_norm: f32 = v2.iter().map(|x| x.abs()).sum();
        if v1_norm > 1e-6 {
            assert!(
                v2_norm < v1_norm,
                "membrane voltages should decay, got v1={v1_norm}, v2={v2_norm}"
            );
        }
    }
    // Full forward pass: output has the configured width and is finite.
    #[test]
    fn test_snn_forward() {
        let config = small_config();
        let output_dim = config.output_dim;
        let mut model = SpikingNeuralNetwork::new(config).expect("SNN creation failed");
        let input = Array1::from_vec(vec![0.5_f32; 4]);
        let output = model.forward_step(&input).expect("forward pass failed");
        assert_eq!(output.len(), output_dim);
        assert!(
            output.iter().all(|v| v.is_finite()),
            "output must be finite"
        );
    }
    // Same as above but through the SignalPredictor trait entry point.
    #[test]
    fn test_snn_signal_predictor() {
        let config = small_config();
        let output_dim = config.output_dim;
        let mut model = SpikingNeuralNetwork::new(config).expect("SNN creation failed");
        let input = Array1::from_vec(vec![0.1_f32; 4]);
        let output = model.step(&input).expect("SignalPredictor::step failed");
        assert_eq!(output.len(), output_dim);
        assert!(output.iter().all(|v| v.is_finite()));
    }
    // reset() must zero the step counter and all membrane voltages.
    #[test]
    fn test_snn_reset() {
        let config = small_config();
        let mut model = SpikingNeuralNetwork::new(config).expect("SNN creation failed");
        let input = Array1::from_vec(vec![1.0_f32; 4]);
        for _ in 0..10 {
            let _ = model.step(&input).expect("step failed");
        }
        assert!(model.step_count > 0, "step_count should be > 0");
        model.reset();
        assert_eq!(model.step_count, 0, "step_count should be 0 after reset");
        for state in &model.layer_states {
            assert!(
                state.voltages.iter().all(|&v| v == 0.0),
                "voltages should be zero after reset"
            );
        }
    }
    // With default amplitudes, a spiking post-neuron's net weight change
    // (a_plus - a_minus * post_trace) should be non-negative here.
    #[test]
    fn test_stdp_update() {
        let stdp_config = StdpConfig::default();
        let updater = StdpUpdater::new(stdp_config);
        let mut weights = Array2::zeros((4, 4));
        let pre_traces = Array1::from_vec(vec![1.0_f32; 4]);
        let post_spikes = Array1::from_vec(vec![1.0_f32, 0.0, 1.0, 0.0]);
        let post_traces = Array1::from_vec(vec![0.5_f32; 4]);
        let result = updater.update_weights(&mut weights, &pre_traces, &post_spikes, &post_traces);
        assert!(result.is_ok(), "STDP update failed: {:?}", result);
        let w00 = weights[[0, 0]];
        assert!(
            w00 > -1e-6,
            "weight for spiking post-neuron should not strongly decrease: got {w00}"
        );
    }
}