use crate::error::{NeuralError, Result};
use std::collections::VecDeque;
/// Single-exponential conductance-based synapse.
///
/// Each presynaptic spike increments the conductance `g` by `weight`. Between
/// spikes `g` decays exponentially with time constant `tau`. The synaptic
/// current is `g * (v_post - e_rev)`.
#[derive(Debug, Clone)]
pub struct ExponentialSynapse {
    /// Decay time constant, in the same time unit as the `dt` passed to [`Self::update`].
    pub tau: f32,
    /// Conductance increment applied on each presynaptic spike.
    pub weight: f32,
    /// Current conductance.
    pub g: f32,
    /// Reversal potential used by [`Self::current`].
    pub e_rev: f32,
}
impl ExponentialSynapse {
    /// Creates a synapse with zero initial conductance.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] if `tau` is not strictly
    /// positive. NaN is rejected too, because it would otherwise turn `g` into
    /// NaN on the first [`Self::update`].
    pub fn new(tau: f32, weight: f32, e_rev: f32) -> Result<Self> {
        if tau.is_nan() || tau <= 0.0 {
            return Err(NeuralError::InvalidArgument(format!(
                "tau must be > 0, got {tau}"
            )));
        }
        Ok(Self {
            tau,
            weight,
            g: 0.0,
            e_rev,
        })
    }

    /// Preset with `tau = 5.0` and `e_rev = 0.0` (named after AMPA receptors).
    pub fn ampa(weight: f32) -> Self {
        Self {
            tau: 5.0,
            weight,
            g: 0.0,
            e_rev: 0.0,
        }
    }

    /// Preset with `tau = 10.0` and `e_rev = -70.0` (named after GABA_A receptors).
    pub fn gaba_a(weight: f32) -> Self {
        Self {
            tau: 10.0,
            weight,
            g: 0.0,
            e_rev: -70.0,
        }
    }

    /// Advances the conductance by `dt`, then adds `weight` if `spike` is true.
    ///
    /// Decay uses the exact factor `exp(-dt / tau)`. The spike increment is
    /// applied after the decay, so the returned conductance already includes
    /// this step's spike.
    pub fn update(&mut self, spike: bool, dt: f32) -> f32 {
        self.g *= (-dt / self.tau).exp();
        if spike {
            self.g += self.weight;
        }
        self.g
    }

    /// Synaptic current `g * (v_post - e_rev)` at postsynaptic potential `v_post`.
    pub fn current(&self, v_post: f32) -> f32 {
        self.g * (v_post - self.e_rev)
    }
}
/// Dual-exponential ("alpha-like") conductance-based synapse.
///
/// A presynaptic spike increments the auxiliary variable `x` by `weight`.
/// `x` then drives the conductance `g` through the linear system
///
/// ```text
/// dx/dt = -x / tau_rise
/// dg/dt = -g / tau_decay + x
/// ```
///
/// so `g` rises and then decays after each spike. When
/// `tau_rise == tau_decay` the response is the alpha function `t * exp(-t / tau)`.
#[derive(Debug, Clone)]
pub struct AlphaSynapse {
    /// Decay time constant of the driving variable `x`.
    pub tau_rise: f32,
    /// Decay time constant of the conductance `g`.
    pub tau_decay: f32,
    /// Increment added to `x` on each presynaptic spike.
    pub weight: f32,
    /// Reversal potential used by [`Self::current`].
    pub e_rev: f32,
    /// Driving variable (rise stage).
    pub x: f32,
    /// Conductance (may be negative only if `weight` is negative).
    pub g: f32,
}
impl AlphaSynapse {
    /// Creates a synapse with zero initial state.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] if either time constant is not
    /// strictly positive. NaN is rejected as well.
    pub fn new(tau_rise: f32, tau_decay: f32, weight: f32, e_rev: f32) -> Result<Self> {
        if tau_rise.is_nan() || tau_rise <= 0.0 {
            return Err(NeuralError::InvalidArgument(format!(
                "tau_rise must be > 0, got {tau_rise}"
            )));
        }
        if tau_decay.is_nan() || tau_decay <= 0.0 {
            return Err(NeuralError::InvalidArgument(format!(
                "tau_decay must be > 0, got {tau_decay}"
            )));
        }
        Ok(Self {
            tau_rise,
            tau_decay,
            weight,
            e_rev,
            x: 0.0,
            g: 0.0,
        })
    }

    /// Advances the synapse by `dt`, then adds `weight` to `x` if `spike` is true.
    ///
    /// The linear system is integrated exactly over the step, so the update
    /// is stable for any `dt`. Forward Euler made `x` oscillate in sign
    /// without decaying whenever `dt >= 2 * tau_rise`, and it also disagreed
    /// with the exact exponential decay used by the other synapse models.
    ///
    /// The spike is applied after integration, so `g` only starts rising on
    /// the following step. Returns the conductance clamped to be non-negative.
    pub fn update(&mut self, spike: bool, dt: f32) -> f32 {
        let rise_factor = (-dt / self.tau_rise).exp();
        let decay_factor = (-dt / self.tau_decay).exp();
        // Exact contribution of the old `x` to `g` over one step:
        //   (e^{-dt/tau_decay} - e^{-dt/tau_rise}) / (1/tau_rise - 1/tau_decay)
        // The form below uses `exp_m1`, which stays accurate when the two time
        // constants are close. Equal time constants use the limit dt * e^{-dt/tau}.
        let delta = 1.0 / self.tau_rise - 1.0 / self.tau_decay;
        let coupling = if delta == 0.0 {
            dt * decay_factor
        } else {
            -decay_factor * (-delta * dt).exp_m1() / delta
        };
        // `g` must be advanced using `x` from the start of the step.
        self.g = self.g * decay_factor + self.x * coupling;
        self.x *= rise_factor;
        if spike {
            self.x += self.weight;
        }
        self.g.max(0.0)
    }

    /// Synaptic current at postsynaptic potential `v_post`.
    ///
    /// Uses the same non-negative conductance that [`Self::update`] returns.
    pub fn current(&self, v_post: f32) -> f32 {
        self.g.max(0.0) * (v_post - self.e_rev)
    }
}
/// Conductance-based synapse whose weight `w` is learned online with
/// pair-based, trace-driven spike-timing-dependent plasticity (STDP).
///
/// Traces and weight updates:
///
/// * `x`: presynaptic trace. Decays with `tau_plus`; `+1` per presynaptic spike.
/// * `y`: postsynaptic trace. Decays with `tau_minus`; `+1` per postsynaptic spike.
/// * On a postsynaptic spike: `w += a_plus * x` (pre-before-post pairing).
/// * On a presynaptic spike: `w += a_minus * y` (post-before-pre pairing).
///
/// Each amplitude is applied with the trace that decays with its matching
/// time constant, so `(a_plus, tau_plus)` describe one lobe of the STDP
/// window and `(a_minus, tau_minus)` the other. With the usual signs
/// (`a_plus > 0`, `a_minus < 0`), a presynaptic spike followed by a
/// postsynaptic one potentiates; the reverse order depresses. `w` is kept in
/// `[0, w_max]`.
///
/// Each presynaptic spike also adds the current `w` to the conductance `g`,
/// which decays with `tau_g`. [`Self::current`] returns `g * (v_post - e_rev)`.
#[derive(Debug, Clone)]
pub struct STDPSynapse {
    /// Synaptic weight, kept in `[0, w_max]`.
    pub w: f32,
    /// Decay time constant of the presynaptic trace `x` (pre-before-post lobe).
    pub tau_plus: f32,
    /// Decay time constant of the postsynaptic trace `y` (post-before-pre lobe).
    pub tau_minus: f32,
    /// Weight change per unit of `x`, applied on a postsynaptic spike.
    pub a_plus: f32,
    /// Weight change per unit of `y`, applied on a presynaptic spike (sign included).
    pub a_minus: f32,
    /// Upper bound on `w`.
    pub w_max: f32,
    /// Presynaptic trace.
    pub x: f32,
    /// Postsynaptic trace.
    pub y: f32,
    /// Reversal potential used by [`Self::current`]; `0.0` after [`Self::new`].
    pub e_rev: f32,
    /// Current conductance.
    pub g: f32,
    /// Conductance decay time constant; `5.0` after [`Self::new`].
    pub tau_g: f32,
}
impl STDPSynapse {
    /// Creates an STDP synapse with zero traces and zero conductance.
    ///
    /// The initial weight `w` is clamped into `[0, w_max]`.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] in any of these cases:
    /// - `tau_plus` or `tau_minus` is not strictly positive (or is NaN);
    /// - `w_max` is not strictly positive (or is NaN);
    /// - `w` is NaN.
    pub fn new(
        w: f32,
        tau_plus: f32,
        tau_minus: f32,
        a_plus: f32,
        a_minus: f32,
        w_max: f32,
    ) -> Result<Self> {
        if tau_plus.is_nan() || tau_minus.is_nan() || tau_plus <= 0.0 || tau_minus <= 0.0 {
            return Err(NeuralError::InvalidArgument(
                "STDP time constants must be > 0".into(),
            ));
        }
        // NaN must be caught here: `f32::clamp` panics when its bound is NaN.
        if w_max.is_nan() || w_max <= 0.0 {
            return Err(NeuralError::InvalidArgument(
                "w_max must be > 0".into(),
            ));
        }
        // `clamp` passes NaN through unchanged, which would poison `g` on the first pre spike.
        if w.is_nan() {
            return Err(NeuralError::InvalidArgument(
                "initial weight must not be NaN".into(),
            ));
        }
        Ok(Self {
            w: w.clamp(0.0, w_max),
            tau_plus,
            tau_minus,
            a_plus,
            a_minus,
            w_max,
            x: 0.0,
            y: 0.0,
            e_rev: 0.0,
            g: 0.0,
            tau_g: 5.0,
        })
    }

    /// Advances traces and conductance by `dt`, then applies this step's spikes.
    ///
    /// The presynaptic branch runs before the postsynaptic one. When both
    /// neurons spike in the same step, the postsynaptic update therefore sees
    /// the freshly incremented `x`, and the pair counts as pre-before-post.
    /// Returns the conductance after the update.
    pub fn update(&mut self, pre_spike: bool, post_spike: bool, dt: f32) -> f32 {
        self.x *= (-dt / self.tau_plus).exp();
        self.y *= (-dt / self.tau_minus).exp();
        self.g *= (-dt / self.tau_g).exp();
        if pre_spike {
            // Post-before-pre: scale by the postsynaptic trace (tau_minus lobe).
            self.w = (self.w + self.a_minus * self.y).clamp(0.0, self.w_max);
            self.x += 1.0;
            self.g += self.w;
        }
        if post_spike {
            // Pre-before-post: scale by the presynaptic trace (tau_plus lobe).
            self.w = (self.w + self.a_plus * self.x).clamp(0.0, self.w_max);
            self.y += 1.0;
        }
        self.g
    }

    /// Synaptic current `g * (v_post - e_rev)` at postsynaptic potential `v_post`.
    pub fn current(&self, v_post: f32) -> f32 {
        self.g * (v_post - self.e_rev)
    }
}
/// Fixed-length delay line for `f32` values.
///
/// A value passed to [`Self::push_pop`] is returned by the `push_pop` call
/// `delay` calls later. Until then the line emits `0.0`.
#[derive(Debug, Clone)]
pub struct SynapticDelay {
    buffer: VecDeque<f32>,
    delay: usize,
}
impl SynapticDelay {
    /// Creates a delay line of `delay` steps, pre-filled with `0.0`.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] if `delay` is zero.
    pub fn new(delay: usize) -> Result<Self> {
        if delay == 0 {
            return Err(NeuralError::InvalidArgument(
                "Synaptic delay must be at least 1 time step".into(),
            ));
        }
        // `push_pop` briefly holds `delay + 1` items, so reserve that up front
        // to avoid a reallocation on the first call.
        let mut buffer = VecDeque::with_capacity(delay + 1);
        buffer.resize(delay, 0.0);
        Ok(Self { buffer, delay })
    }

    /// Pushes `value` in and returns the value pushed `delay` calls earlier.
    pub fn push_pop(&mut self, value: f32) -> f32 {
        self.buffer.push_back(value);
        // The buffer always holds `delay >= 1` items before the push, so this never falls back.
        self.buffer.pop_front().unwrap_or(0.0)
    }

    /// Number of steps between pushing a value and receiving it back.
    pub fn delay(&self) -> usize {
        self.delay
    }

    /// Discards all in-flight values, replacing them with `0.0`.
    pub fn reset(&mut self) {
        for v in self.buffer.iter_mut() {
            *v = 0.0;
        }
    }
}
/// Fixed-length delay line for boolean spike events.
///
/// A spike passed to [`Self::push_pop`] is returned by the `push_pop` call
/// `delay` calls later. Until then the line emits `false`.
#[derive(Debug, Clone)]
pub struct SpikeBoolDelay {
    buffer: VecDeque<bool>,
    delay: usize,
}
impl SpikeBoolDelay {
    /// Creates a spike delay line of `delay` steps, pre-filled with `false`.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] if `delay` is zero.
    pub fn new(delay: usize) -> Result<Self> {
        if delay == 0 {
            return Err(NeuralError::InvalidArgument(
                "Spike delay must be ≥ 1".into(),
            ));
        }
        // Reserve `delay + 1` so the transient extra item in `push_pop` never reallocates.
        let mut buffer = VecDeque::with_capacity(delay + 1);
        buffer.resize(delay, false);
        Ok(Self { buffer, delay })
    }

    /// Pushes `spike` in and returns the spike pushed `delay` calls earlier.
    pub fn push_pop(&mut self, spike: bool) -> bool {
        self.buffer.push_back(spike);
        // The buffer always holds `delay >= 1` items before the push, so this never falls back.
        self.buffer.pop_front().unwrap_or(false)
    }

    /// Number of steps between pushing a spike and receiving it back.
    pub fn delay(&self) -> usize {
        self.delay
    }

    /// Discards all in-flight spikes.
    pub fn reset(&mut self) {
        for v in self.buffer.iter_mut() {
            *v = false;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn exponential_synapse_zero_without_spikes() {
        let mut s = ExponentialSynapse::ampa(1.0);
        for _ in 0..100 {
            s.update(false, 0.1);
        }
        assert!(s.g.abs() < 1e-6);
    }
    #[test]
    fn exponential_synapse_increases_on_spike() {
        let mut s = ExponentialSynapse::ampa(1.0);
        s.update(true, 0.1);
        assert!(s.g > 0.9, "g should be ~1.0 right after spike");
    }
    #[test]
    fn exponential_synapse_decays() {
        let mut s = ExponentialSynapse::ampa(1.0);
        s.update(true, 0.1);
        let g_after_spike = s.g;
        for _ in 0..1000 {
            s.update(false, 0.1);
        }
        assert!(s.g < g_after_spike * 0.01, "conductance should decay");
    }
    #[test]
    fn alpha_synapse_rises_then_decays() {
        let mut s = AlphaSynapse::new(1.0, 5.0, 1.0, 0.0).expect("operation should succeed");
        s.update(true, 0.1);
        let mut peak = 0.0_f32;
        for _ in 0..500 {
            let g = s.update(false, 0.1);
            peak = peak.max(g);
        }
        assert!(peak > 0.01, "alpha synapse should produce non-zero response");
        assert!(s.g < 1e-4, "conductance should decay to near zero");
    }
    #[test]
    fn stdp_potentiation_when_pre_before_post() {
        let mut s = STDPSynapse::new(0.5, 20.0, 20.0, 0.01, -0.01, 1.0)
            .expect("operation should succeed");
        s.update(true, false, 1.0);
        let w_before = s.w;
        s.update(false, true, 1.0);
        assert!(s.w > w_before, "LTP expected when post fires after pre");
    }
    #[test]
    fn stdp_depression_when_post_before_pre() {
        let mut s = STDPSynapse::new(0.5, 20.0, 20.0, 0.01, -0.01, 1.0)
            .expect("operation should succeed");
        s.update(false, true, 1.0);
        let w_before = s.w;
        s.update(true, false, 1.0);
        assert!(s.w < w_before, "LTD expected when pre fires after post");
    }
    #[test]
    fn stdp_rejects_nan_parameters() {
        assert!(STDPSynapse::new(0.5, 20.0, 20.0, 0.01, -0.01, f32::NAN).is_err());
        assert!(STDPSynapse::new(0.5, f32::NAN, 20.0, 0.01, -0.01, 1.0).is_err());
        assert!(STDPSynapse::new(f32::NAN, 20.0, 20.0, 0.01, -0.01, 1.0).is_err());
    }
    #[test]
    fn synaptic_delay_delays_by_correct_steps() {
        let mut d = SynapticDelay::new(3).expect("operation should succeed");
        let out0 = d.push_pop(1.0);
        let out1 = d.push_pop(0.0);
        let out2 = d.push_pop(0.0);
        let out3 = d.push_pop(0.0);
        assert_eq!(out0, 0.0);
        assert_eq!(out1, 0.0);
        assert_eq!(out2, 0.0);
        assert!((out3 - 1.0).abs() < 1e-6, "signal should arrive at step 3");
    }
    #[test]
    fn synaptic_delay_rejects_zero() {
        assert!(SynapticDelay::new(0).is_err());
    }
    #[test]
    fn synaptic_delay_reset_clears_pending() {
        let mut d = SynapticDelay::new(2).expect("operation should succeed");
        d.push_pop(1.0);
        d.reset();
        assert_eq!(d.push_pop(0.0), 0.0);
        assert_eq!(d.push_pop(0.0), 0.0);
        assert_eq!(d.delay(), 2);
    }
    #[test]
    fn spike_bool_delay_delays_spikes() {
        let mut d = SpikeBoolDelay::new(2).expect("operation should succeed");
        let o0 = d.push_pop(true);
        let o1 = d.push_pop(false);
        let o2 = d.push_pop(false);
        assert!(!o0);
        assert!(!o1);
        assert!(o2, "spike should emerge after 2 steps");
    }
}