use super::plateau::PlateauPotential;
use std::collections::VecDeque;
/// One buffered input spike: which synapse fired and the timestamp it carried.
#[derive(Debug, Clone, Copy)]
struct SpikeEvent {
    /// Identifier of the synapse that delivered the spike.
    synapse_id: usize,
    /// Timestamp passed to `Dendrite::receive_spike`; compared against the
    /// coincidence window computed from `current_time` in `Dendrite::update`.
    /// Units presumably milliseconds (TODO confirm against callers).
    timestamp: u64,
}
/// A dendritic compartment that buffers incoming spikes and triggers a
/// plateau potential when enough distinct synapses fire within a
/// coincidence window.
#[derive(Debug, Clone)]
pub struct Dendrite {
    /// Membrane value: raised by 0.01 per received spike (capped at 1.0) and
    /// multiplicatively decayed on every `update`.
    membrane: f32,
    /// Calcium level: rises while the plateau is active (capped at 1.0) and
    /// decays while it is not.
    calcium: f32,
    /// Minimum number of distinct synapses in the window needed to trigger a
    /// plateau.
    nmda_threshold: u8,
    /// Plateau potential state driven by this dendrite.
    plateau: PlateauPotential,
    /// Buffered spike events, stored in arrival order.
    active_synapses: VecDeque<SpikeEvent>,
    /// Width of the coincidence window; truncated to an integer when it is
    /// subtracted from the `u64` current time.
    coincidence_window_ms: f32,
    /// Cap on the number of buffered spike events (not distinct synapses);
    /// the earliest-arrived event is evicted first.
    max_synapses: usize,
}
impl Dendrite {
    /// Plateau duration used by [`Dendrite::new`]; passed straight to
    /// `PlateauPotential::new` (units presumably milliseconds).
    const DEFAULT_PLATEAU_DURATION_MS: f32 = 200.0;

    /// Default cap on buffered spike events.
    const DEFAULT_MAX_SYNAPSES: usize = 1000;

    /// Creates a dendrite with the default plateau duration (200.0).
    ///
    /// * `nmda_threshold` – distinct synapses that must be active within the
    ///   coincidence window to trigger a plateau.
    /// * `coincidence_window_ms` – width of the coincidence window, measured
    ///   back from the `current_time` passed to [`Dendrite::update`].
    pub fn new(nmda_threshold: u8, coincidence_window_ms: f32) -> Self {
        Self::with_plateau_duration(
            nmda_threshold,
            coincidence_window_ms,
            Self::DEFAULT_PLATEAU_DURATION_MS,
        )
    }

    /// Creates a dendrite whose plateau potential is built with
    /// `plateau_duration_ms`; otherwise identical to [`Dendrite::new`].
    pub fn with_plateau_duration(
        nmda_threshold: u8,
        coincidence_window_ms: f32,
        plateau_duration_ms: f32,
    ) -> Self {
        Self {
            membrane: 0.0,
            calcium: 0.0,
            nmda_threshold,
            plateau: PlateauPotential::new(plateau_duration_ms),
            active_synapses: VecDeque::new(),
            coincidence_window_ms,
            max_synapses: Self::DEFAULT_MAX_SYNAPSES,
        }
    }

    /// Buffers a spike from `synapse_id` stamped `timestamp`.
    ///
    /// If the buffer then exceeds `max_synapses` events, the earliest-arrived
    /// event is dropped. Each spike also raises the membrane by 0.01, capped
    /// at 1.0.
    pub fn receive_spike(&mut self, synapse_id: usize, timestamp: u64) {
        self.active_synapses.push_back(SpikeEvent {
            synapse_id,
            timestamp,
        });
        if self.active_synapses.len() > self.max_synapses {
            self.active_synapses.pop_front();
        }
        self.membrane = (self.membrane + 0.01).min(1.0);
    }

    /// Advances the dendrite to `current_time` by a step of `dt`.
    ///
    /// Spikes stamped earlier than `current_time - coincidence_window_ms`
    /// are discarded. The window width is truncated to an integer and the
    /// subtraction saturates at 0. Spikes stamped later than `current_time`
    /// are kept and still count.
    ///
    /// A plateau is triggered when at least `nmda_threshold` distinct
    /// synapses remain buffered and no plateau is already active. With
    /// `nmda_threshold == 0`, this fires whenever no plateau is active. The
    /// plateau is then advanced by `dt`, the membrane leaks, and calcium
    /// rises (plateau active) or decays (inactive).
    ///
    /// Returns `true` only on the step that triggers a new plateau.
    pub fn update(&mut self, current_time: u64, dt: f32) -> bool {
        let window_start = current_time.saturating_sub(self.coincidence_window_ms as u64);
        // Spikes are buffered in arrival order, which need not be timestamp
        // order, so a stale spike can sit behind a newer one; `retain` removes
        // every expired spike rather than only a stale prefix.
        self.active_synapses
            .retain(|spike| spike.timestamp >= window_start);

        let plateau_triggered = self.active_synapse_count()
            >= usize::from(self.nmda_threshold)
            && !self.plateau.is_active();
        if plateau_triggered {
            self.plateau.trigger();
        }

        self.plateau.update(dt);
        // Leak: the membrane keeps 95% of its value per 10 units of `dt`.
        self.membrane *= 0.95_f32.powf(dt / 10.0);
        if self.plateau.is_active() {
            // Calcium accumulates linearly during the plateau, capped at 1.0.
            self.calcium = (self.calcium + 0.01 * dt).min(1.0);
        } else {
            // Without a plateau, calcium keeps 99% per 10 units of `dt`.
            self.calcium *= 0.99_f32.powf(dt / 10.0);
        }
        plateau_triggered
    }

    /// Returns `true` while the plateau potential reports itself active.
    pub fn has_plateau(&self) -> bool {
        self.plateau.is_active()
    }

    /// Current membrane value.
    pub fn membrane(&self) -> f32 {
        self.membrane
    }

    /// Current calcium level.
    pub fn calcium(&self) -> f32 {
        self.calcium
    }

    /// Number of distinct synapse IDs among the buffered spikes.
    ///
    /// Expired spikes are only removed by [`Dendrite::update`]. The count
    /// therefore reflects the window as of the last update plus any spikes
    /// received since.
    pub fn active_synapse_count(&self) -> usize {
        self.active_synapses
            .iter()
            .map(|spike| spike.synapse_id)
            .collect::<std::collections::HashSet<_>>()
            .len()
    }

    /// Amplitude reported by the underlying plateau potential.
    pub fn plateau_amplitude(&self) -> f32 {
        self.plateau.amplitude()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Delivers one spike from every synapse in `ids`, all stamped `timestamp`.
    fn burst(dendrite: &mut Dendrite, ids: std::ops::Range<usize>, timestamp: u64) {
        ids.for_each(|id| dendrite.receive_spike(id, timestamp));
    }

    #[test]
    fn test_dendrite_creation() {
        let d = Dendrite::new(5, 20.0);
        assert_eq!(d.nmda_threshold, 5);
        assert_eq!(d.coincidence_window_ms, 20.0);
    }

    #[test]
    fn test_single_spike_no_plateau() {
        let mut d = Dendrite::new(5, 20.0);
        d.receive_spike(0, 100);
        assert!(!d.update(100, 1.0));
        assert!(!d.has_plateau());
    }

    #[test]
    fn test_coincidence_triggers_plateau() {
        let mut d = Dendrite::new(5, 20.0);
        burst(&mut d, 0..6, 100);
        assert!(d.update(100, 1.0));
        assert!(d.has_plateau());
    }

    #[test]
    fn test_coincidence_window() {
        let mut d = Dendrite::new(5, 20.0);
        for (id, t) in [(0, 100), (1, 110), (2, 120), (3, 130), (4, 135)] {
            d.receive_spike(id, t);
        }
        // Window starts at 100; the spikes at 130 and 135 lie after the
        // current time but are not pruned, so all five synapses count.
        assert!(d.update(120, 1.0));
    }

    #[test]
    fn test_spikes_outside_window_ignored() {
        let mut d = Dendrite::new(5, 20.0);
        for (id, t) in [(0, 100), (1, 110), (2, 150), (3, 160), (4, 170)] {
            d.receive_spike(id, t);
        }
        // Window starts at 150: spikes at 100 and 110 are pruned, leaving 3 < 5.
        assert!(!d.update(170, 1.0));
    }

    #[test]
    fn test_active_synapse_count() {
        let mut d = Dendrite::new(5, 20.0);
        for (id, t) in [(0, 100), (0, 101), (1, 102), (2, 103)] {
            d.receive_spike(id, t);
        }
        d.update(103, 1.0);
        // Synapse 0 fired twice but is counted once.
        assert_eq!(d.active_synapse_count(), 3);
    }

    #[test]
    fn test_plateau_duration() {
        let mut d = Dendrite::with_plateau_duration(5, 20.0, 100.0);
        burst(&mut d, 0..6, 100);
        // (current_time, dt, plateau expected afterwards)
        for (now, step, expected) in [(100, 1.0, true), (150, 50.0, true), (210, 60.0, false)] {
            d.update(now, step);
            assert_eq!(d.has_plateau(), expected);
        }
    }

    #[test]
    fn test_calcium_during_plateau() {
        let mut d = Dendrite::new(5, 20.0);
        burst(&mut d, 0..6, 100);
        d.update(100, 1.0);
        let before = d.calcium();
        d.update(110, 10.0);
        assert!(d.calcium() > before);
    }
}