use serde::{Deserialize, Serialize};
use crate::error::{MastishkError, validate_dt};
/// State and parameters of a single Izhikevich model neuron.
///
/// `v` and `u` are the two state variables integrated by `tick`; `a`, `b`, `c`
/// and `d` are the per-neuron model parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IzhikevichNeuron {
/// Membrane potential (units presumably mV); a spike is registered when it reaches 30.0.
pub v: f32,
/// Recovery variable; relaxes toward `b * v` at rate `a` and is increased by `d` after each spike.
pub u: f32,
/// Rate (per ms of simulated time) at which `u` relaxes toward `b * v`.
pub a: f32,
/// Coupling of the recovery variable `u` to `v`.
pub b: f32,
/// Value `v` is reset to immediately after a spike.
pub c: f32,
/// Amount added to `u` immediately after a spike.
pub d: f32,
}
impl IzhikevichNeuron {
    /// Longest forward-Euler sub-step, in ms, used by [`IzhikevichNeuron::tick`].
    const MAX_SUBSTEP_MS: f32 = 0.5;
    /// Value of `v` at or above which the neuron is considered to have spiked.
    const SPIKE_PEAK: f32 = 30.0;

    /// Builds a neuron with the given `(a, b, c, d)` parameters, starting from
    /// the state shared by every preset: `v = -65.0`, `u = -14.0`.
    fn preset(a: f32, b: f32, c: f32, d: f32) -> Self {
        Self {
            v: -65.0,
            u: -14.0,
            a,
            b,
            c,
            d,
        }
    }

    /// Regular-spiking preset: `a = 0.02`, `b = 0.2`, `c = -65.0`, `d = 8.0`.
    #[must_use]
    pub fn regular_spiking() -> Self {
        Self::preset(0.02, 0.2, -65.0, 8.0)
    }

    /// Fast-spiking preset: `a = 0.1`, `b = 0.2`, `c = -65.0`, `d = 2.0`.
    #[must_use]
    pub fn fast_spiking() -> Self {
        Self::preset(0.1, 0.2, -65.0, 2.0)
    }

    /// Chattering preset: `a = 0.02`, `b = 0.2`, `c = -50.0`, `d = 2.0`.
    #[must_use]
    pub fn chattering() -> Self {
        Self::preset(0.02, 0.2, -50.0, 2.0)
    }

    /// Intrinsically-bursting preset: `a = 0.02`, `b = 0.2`, `c = -55.0`, `d = 4.0`.
    #[must_use]
    pub fn intrinsically_bursting() -> Self {
        Self::preset(0.02, 0.2, -55.0, 4.0)
    }

    /// Integrates the model for `dt_ms` milliseconds under a constant `input`.
    ///
    /// Uses forward Euler on `v' = 0.04·v² + 5·v + 140 − u + input` and
    /// `u' = a·(b·v − u)`, split into sub-steps of at most 0.5 ms. Within each
    /// sub-step `u` is updated with the freshly computed `v`. Whenever `v`
    /// reaches 30.0 it is reset to `c` and `d` is added to `u`.
    ///
    /// Returns `true` if at least one spike occurred during the interval.
    /// A `dt_ms` that is zero, negative, NaN or infinite integrates nothing and
    /// returns `false`. The cost is linear in `dt_ms`.
    #[inline]
    pub fn tick(&mut self, input: f32, dt_ms: f32) -> bool {
        // An infinite dt would otherwise keep the sub-stepping loop alive forever;
        // NaN / non-positive dt already integrated nothing in the float-driven loop.
        if !dt_ms.is_finite() || dt_ms <= 0.0 {
            return false;
        }
        // Bound the loop by an integer count instead of `remaining > 0.0`: from
        // roughly 2^23 ms upward, `remaining - 0.5` rounds back to `remaining` in
        // f32 and a float-driven loop never terminates. Below that the subtraction
        // is exact, so the sub-step sequence (0.5, 0.5, ..., remainder) is the same.
        let substeps = (dt_ms / Self::MAX_SUBSTEP_MS).ceil() as u64;
        let mut spiked = false;
        let mut remaining = dt_ms;
        for _ in 0..substeps {
            let step = remaining.min(Self::MAX_SUBSTEP_MS);
            self.v += step * (0.04 * self.v * self.v + 5.0 * self.v + 140.0 - self.u + input);
            self.u += step * self.a * (self.b * self.v - self.u);
            if self.v >= Self::SPIKE_PEAK {
                self.v = self.c;
                self.u += self.d;
                spiked = true;
            }
            remaining -= step;
        }
        spiked
    }
}
/// State and parameters of a leaky integrate-and-fire (LIF) neuron.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LifNeuron {
/// Current membrane potential; compared against `v_thresh` after every step.
pub v: f32,
/// Potential the membrane leaks back toward in the absence of input.
pub v_rest: f32,
/// Firing threshold: reaching or exceeding it produces a spike.
pub v_thresh: f32,
/// Potential `v` is set to immediately after a spike.
pub v_reset: f32,
/// Membrane time constant in ms; `tick` floors it at `f32::EPSILON`.
pub tau_m: f32,
/// Membrane resistance scaling the input current (`r_m * input`).
pub r_m: f32,
}
impl LifNeuron {
#[must_use]
pub fn default_params() -> Self {
Self {
v: -65.0,
v_rest: -65.0,
v_thresh: -55.0,
v_reset: -70.0,
tau_m: 15.0,
r_m: 10.0,
}
}
#[inline]
pub fn tick(&mut self, input: f32, dt_ms: f32) -> bool {
let tau = self.tau_m.max(f32::EPSILON);
let dv = (-(self.v - self.v_rest) + self.r_m * input) / tau;
self.v += dv * dt_ms;
if self.v >= self.v_thresh {
self.v = self.v_reset;
return true;
}
false
}
}
/// A neuron of any supported model, so one network can mix model types.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SpikingNeuron {
/// Two-variable Izhikevich model.
Izhikevich(IzhikevichNeuron),
/// Leaky integrate-and-fire model.
Lif(LifNeuron),
}
impl SpikingNeuron {
#[inline]
pub fn tick(&mut self, input: f32, dt_ms: f32) -> bool {
match self {
Self::Izhikevich(n) => n.tick(input, dt_ms),
Self::Lif(n) => n.tick(input, dt_ms),
}
}
}
/// Directed, delayed connection between two neurons of a `SpikingNetwork`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpikingSynapse {
/// Index of the presynaptic (source) neuron.
pub from: usize,
/// Index of the postsynaptic (target) neuron.
pub to: usize,
/// Amount added to the target's input when a spike is delivered; adjusted by
/// STDP when the network has a rule configured.
pub weight: f32,
/// Delay in ms between the source spiking and the spike reaching the target.
pub delay_ms: f32,
}
/// Parameters of the pair-based spike-timing-dependent plasticity (STDP) rule
/// applied by `SpikingNetwork` when its `stdp` field is set.
///
/// With `dt = t_post - t_pre` (latest spike times of the two neurons), a
/// positive `dt` adds `a_plus * exp(-dt / tau_plus)` to the weight. A negative
/// `dt` subtracts `a_minus * exp(dt / tau_minus)`. Pairs more than three time
/// constants apart are ignored, and the weight is then clamped to [-2, 2].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StdpRule {
/// Potentiation amplitude (presynaptic spike precedes postsynaptic spike).
pub a_plus: f32,
/// Depression amplitude (postsynaptic spike precedes presynaptic spike).
pub a_minus: f32,
/// Potentiation time constant in ms.
pub tau_plus: f32,
/// Depression time constant in ms.
pub tau_minus: f32,
}
impl Default for StdpRule {
fn default() -> Self {
Self {
a_plus: 0.01,
a_minus: 0.012,
tau_plus: 20.0,
tau_minus: 20.0,
}
}
}
/// Parameters for a BCM-style plasticity rule.
///
/// NOTE(review): nothing in this module reads these fields; presumably they are
/// consumed elsewhere or reserved for future use — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BcmRule {
/// Modification threshold.
pub theta_m: f32,
/// Time constant of the threshold; units are not established in this module.
pub tau_theta: f32,
}
impl Default for BcmRule {
fn default() -> Self {
Self {
theta_m: 0.5,
tau_theta: 1000.0,
}
}
}
/// A population of spiking neurons connected by delayed synapses, simulated in
/// fixed steps via `tick`.
///
/// All fields, including the private bookkeeping buffers, are serialized by the
/// serde derive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpikingNetwork {
/// Neurons, addressed by index; `add_neuron` returns the index of a new one.
pub neurons: Vec<SpikingNeuron>,
/// Directed synapses between neuron indices.
pub synapses: Vec<SpikingSynapse>,
/// Start-of-step time (ms) of each neuron's most recent spike, `None` until it
/// first fires. Only the latest spike is kept, so a neuron that fires again
/// before its outgoing delay elapses overwrites the earlier in-flight spike.
spike_times: Vec<Option<f32>>,
/// Indices of the neurons that fired during the most recent `tick`.
last_spiked: Vec<usize>,
/// Per-neuron synaptic input for the current step; rebuilt at the start of each `tick`.
inputs: Vec<f32>,
/// Simulated time in ms; advanced by `dt_ms` at the end of every `tick`.
pub time_ms: f32,
/// Plasticity rule applied after steps in which some neuron fired; `None` disables learning.
pub stdp: Option<StdpRule>,
}
impl Default for SpikingNetwork {
fn default() -> Self {
Self::new()
}
}
impl SpikingNetwork {
    /// Creates an empty network at time zero with plasticity disabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            neurons: Vec::new(),
            synapses: Vec::new(),
            spike_times: Vec::new(),
            last_spiked: Vec::new(),
            inputs: Vec::new(),
            time_ms: 0.0,
            stdp: None,
        }
    }

    /// Appends `neuron` and returns its index, for use in [`Self::add_synapse`].
    pub fn add_neuron(&mut self, neuron: SpikingNeuron) -> usize {
        let idx = self.neurons.len();
        self.neurons.push(neuron);
        self.inputs.push(0.0);
        self.spike_times.push(None);
        idx
    }

    /// Connects neuron `from` to neuron `to` with the given `weight` and
    /// transmission delay `delay_ms`.
    ///
    /// # Errors
    ///
    /// Returns [`MastishkError::InvalidCircuit`] if:
    /// - either index is not a neuron of this network;
    /// - `weight` is not finite, since it would propagate NaN/inf into membrane
    ///   potentials;
    /// - `delay_ms` is negative or not finite, since such a synapse could never
    ///   deliver a spike.
    pub fn add_synapse(
        &mut self,
        from: usize,
        to: usize,
        weight: f32,
        delay_ms: f32,
    ) -> Result<(), MastishkError> {
        let len = self.neurons.len();
        if from >= len || to >= len {
            return Err(MastishkError::InvalidCircuit(format!(
                "spiking synapse {from}->{to} out of bounds (neuron count: {len})"
            )));
        }
        if !weight.is_finite() {
            return Err(MastishkError::InvalidCircuit(format!(
                "spiking synapse {from}->{to} has non-finite weight {weight}"
            )));
        }
        if !delay_ms.is_finite() || delay_ms < 0.0 {
            return Err(MastishkError::InvalidCircuit(format!(
                "spiking synapse {from}->{to} has invalid delay {delay_ms} ms (must be finite and >= 0)"
            )));
        }
        self.synapses.push(SpikingSynapse {
            from,
            to,
            weight,
            delay_ms,
        });
        Ok(())
    }

    /// Advances the whole network by one step of `dt_ms`.
    ///
    /// The step proceeds in this order:
    /// 1. Gather synaptic input from spikes whose delay expires in this step.
    /// 2. Tick every neuron. Each neuron that fires is stamped with the
    ///    start-of-step time and listed in [`Self::last_spikes`].
    /// 3. If STDP is enabled and something fired, apply plasticity.
    /// 4. Advance `time_ms`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_dt` (negative `dt_ms` is rejected,
    /// per the tests); the network is left untouched in that case.
    pub fn tick(&mut self, dt_ms: f32) -> Result<(), MastishkError> {
        validate_dt(dt_ms)?;
        // `neurons` is public, so it may have changed without `add_neuron`
        // (direct push, or a deserialized network). Keep every per-neuron
        // buffer the same length so the indexing below cannot panic.
        let neuron_count = self.neurons.len();
        self.inputs.clear();
        self.inputs.resize(neuron_count, 0.0);
        self.spike_times.resize(neuron_count, None);
        self.collect_synaptic_input(dt_ms);
        self.last_spiked.clear();
        for (i, neuron) in self.neurons.iter_mut().enumerate() {
            if neuron.tick(self.inputs[i], dt_ms) {
                self.spike_times[i] = Some(self.time_ms);
                self.last_spiked.push(i);
            }
        }
        if self.stdp.is_some() && !self.last_spiked.is_empty() {
            self.apply_stdp();
        }
        self.time_ms += dt_ms;
        tracing::trace!(
            time_ms = self.time_ms,
            spikes = self.last_spiked.len(),
            "spiking network tick"
        );
        Ok(())
    }

    /// Indices of the neurons that fired during the most recent [`Self::tick`].
    #[must_use]
    pub fn last_spikes(&self) -> &[usize] {
        &self.last_spiked
    }

    /// Adds to `inputs` the weight of every synapse whose presynaptic spike
    /// arrives during the step that starts at `time_ms`.
    ///
    /// A spike is delivered on the step whose elapsed time since the spike lies
    /// in `[delay, delay + dt_ms)`, up to floating-point rounding of `time_ms`.
    /// A spike fired during step k is recorded only after that step's inputs
    /// were gathered, so the earliest possible delivery is step k+1. Delays are
    /// therefore clamped to at least half a step. Without the clamp a zero delay
    /// selects a window that has already passed, and the spike is silently lost.
    /// For positive delays up to half a step the clamp picks the same (next) step
    /// as the unclamped window.
    fn collect_synaptic_input(&mut self, dt_ms: f32) {
        let min_delay = 0.5 * dt_ms;
        for syn in &self.synapses {
            let Some(&Some(t_spike)) = self.spike_times.get(syn.from) else {
                continue;
            };
            let Some(slot) = self.inputs.get_mut(syn.to) else {
                continue;
            };
            let delay = syn.delay_ms.max(min_delay);
            let elapsed = self.time_ms - t_spike;
            if elapsed >= delay && elapsed < delay + dt_ms {
                *slot += syn.weight;
            }
        }
    }

    /// Applies pair-based STDP to the synapses touched by this step's spikes.
    ///
    /// Only synapses whose pre- or postsynaptic neuron fired this step are
    /// updated, using the latest spike time of each side. A pair is therefore
    /// counted once, when its later spike happens. Re-applying stale pairs on
    /// every unrelated spike would make weights drift with network-wide activity.
    /// Updated weights are clamped to [-2, 2].
    fn apply_stdp(&mut self) {
        let Some(rule) = self.stdp.clone() else {
            return;
        };
        let mut fired = vec![false; self.spike_times.len()];
        for &i in &self.last_spiked {
            if let Some(flag) = fired.get_mut(i) {
                *flag = true;
            }
        }
        let fired_now = |i: usize| fired.get(i).copied().unwrap_or(false);
        for syn in &mut self.synapses {
            if !(fired_now(syn.from) || fired_now(syn.to)) {
                continue;
            }
            let (Some(&Some(t_pre)), Some(&Some(t_post))) =
                (self.spike_times.get(syn.from), self.spike_times.get(syn.to))
            else {
                continue;
            };
            let delta_t = t_post - t_pre;
            if delta_t > 0.0 && delta_t < rule.tau_plus * 3.0 {
                // Causal pair (pre before post): potentiate.
                syn.weight += rule.a_plus * (-delta_t / rule.tau_plus).exp();
            } else if delta_t < 0.0 && -delta_t < rule.tau_minus * 3.0 {
                // Anti-causal pair (post before pre): depress; delta_t < 0 here.
                syn.weight -= rule.a_minus * (delta_t / rule.tau_minus).exp();
            }
            syn.weight = syn.weight.clamp(-2.0, 2.0);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_izhikevich_rs_spikes() {
        let mut n = IzhikevichNeuron::regular_spiking();
        let mut spike_count = 0;
        for _ in 0..1000 {
            if n.tick(14.0, 0.5) {
                spike_count += 1;
            }
        }
        assert!(
            spike_count > 5,
            "RS neuron should spike with I=14, got {spike_count}"
        );
    }

    #[test]
    fn test_izhikevich_no_spike_without_input() {
        let mut n = IzhikevichNeuron::regular_spiking();
        let mut spiked = false;
        for _ in 0..1000 {
            if n.tick(0.0, 0.5) {
                spiked = true;
            }
        }
        assert!(!spiked, "RS neuron should not spike with I=0");
    }

    #[test]
    fn test_izhikevich_presets() {
        let rs = IzhikevichNeuron::regular_spiking();
        let fs = IzhikevichNeuron::fast_spiking();
        let ch = IzhikevichNeuron::chattering();
        let ib = IzhikevichNeuron::intrinsically_bursting();
        assert!((rs.a - 0.02).abs() < f32::EPSILON);
        assert!((fs.a - 0.1).abs() < f32::EPSILON);
        assert!((ch.c - (-50.0)).abs() < f32::EPSILON);
        assert!((ib.d - 4.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_lif_spikes() {
        let mut n = LifNeuron::default_params();
        let mut spike_count = 0;
        for _ in 0..1000 {
            if n.tick(2.0, 0.5) {
                spike_count += 1;
            }
        }
        assert!(spike_count > 0, "LIF should spike with I=2");
    }

    #[test]
    fn test_lif_no_spike_without_input() {
        let mut n = LifNeuron::default_params();
        let mut spiked = false;
        for _ in 0..1000 {
            if n.tick(0.0, 0.5) {
                spiked = true;
            }
        }
        assert!(!spiked);
    }

    #[test]
    fn test_spiking_neuron_enum() {
        let mut n = SpikingNeuron::Izhikevich(IzhikevichNeuron::regular_spiking());
        let mut spiked = false;
        for _ in 0..1000 {
            if n.tick(14.0, 0.5) {
                spiked = true;
                break;
            }
        }
        assert!(spiked);
    }

    #[test]
    fn test_spiking_network_propagates() {
        let mut net = SpikingNetwork::new();
        let a = net.add_neuron(SpikingNeuron::Izhikevich(
            IzhikevichNeuron::regular_spiking(),
        ));
        let b = net.add_neuron(SpikingNeuron::Lif(LifNeuron::default_params()));
        // One delivered spike must lift the LIF neuron from rest (-65) past its
        // threshold (-55) within a single 0.5 ms step:
        // r_m * weight / tau_m * dt = 10 * 40 / 15 * 0.5 ≈ 13.3.
        net.add_synapse(a, b, 40.0, 1.0).unwrap();
        // Start A just below its spike peak so it fires on the very first tick.
        net.neurons[a] = SpikingNeuron::Izhikevich(IzhikevichNeuron {
            v: 29.0,
            ..IzhikevichNeuron::regular_spiking()
        });
        let mut b_fired_at = None;
        for _ in 0..100 {
            let step_start = net.time_ms;
            net.tick(0.5).unwrap();
            if net.last_spikes().contains(&b) {
                b_fired_at = Some(step_start);
                break;
            }
        }
        let t = b_fired_at.expect("B should fire after receiving A's spike");
        assert!(
            t >= 1.0,
            "A's spike must not reach B before the 1 ms synaptic delay (B fired at {t} ms)"
        );
    }

    #[test]
    fn test_stdp_strengthens_causal() {
        let mut net = SpikingNetwork::new();
        let a = net.add_neuron(SpikingNeuron::Izhikevich(
            IzhikevichNeuron::regular_spiking(),
        ));
        let b = net.add_neuron(SpikingNeuron::Izhikevich(
            IzhikevichNeuron::regular_spiking(),
        ));
        net.add_synapse(a, b, 0.5, 0.5).unwrap();
        net.stdp = Some(StdpRule::default());
        let initial_weight = net.synapses[0].weight;
        net.neurons[a] = SpikingNeuron::Izhikevich(IzhikevichNeuron {
            v: 31.0,
            ..IzhikevichNeuron::regular_spiking()
        });
        net.tick(0.5).unwrap();
        net.neurons[b] = SpikingNeuron::Izhikevich(IzhikevichNeuron {
            v: 31.0,
            ..IzhikevichNeuron::regular_spiking()
        });
        net.tick(0.5).unwrap();
        // Pre fired 0.5 ms before post, so the potentiation term is strictly positive.
        assert!(
            net.synapses[0].weight > initial_weight,
            "STDP should strengthen causal synapse: {} vs {}",
            net.synapses[0].weight,
            initial_weight
        );
    }

    #[test]
    fn test_network_negative_dt_rejected() {
        let mut net = SpikingNetwork::new();
        assert!(net.tick(-1.0).is_err());
    }

    #[test]
    fn test_network_invalid_synapse() {
        let mut net = SpikingNetwork::new();
        net.add_neuron(SpikingNeuron::Lif(LifNeuron::default_params()));
        assert!(net.add_synapse(0, 99, 1.0, 1.0).is_err());
    }

    #[test]
    fn test_serde_roundtrip() {
        let mut net = SpikingNetwork::new();
        net.add_neuron(SpikingNeuron::Izhikevich(
            IzhikevichNeuron::regular_spiking(),
        ));
        net.add_neuron(SpikingNeuron::Lif(LifNeuron::default_params()));
        net.add_synapse(0, 1, 0.5, 1.0).unwrap();
        let json = serde_json::to_string(&net).unwrap();
        let net2: SpikingNetwork = serde_json::from_str(&json).unwrap();
        assert_eq!(net2.neurons.len(), 2);
        assert_eq!(net2.synapses.len(), 1);
    }
}