use super::{Compartment, Dendrite};
use crate::Result;
/// A neuron model made of a set of dendritic branches that feed one somatic compartment.
///
/// Spikes are delivered to individual synapses on individual branches via
/// `receive_input`; each call to `step` updates every branch and then integrates
/// their combined output into the soma.
#[derive(Debug, Clone)]
pub struct DendriticTree {
/// The dendritic branches; a branch index is a position in this vector.
branches: Vec<Dendrite>,
/// The somatic compartment that integrates the branches' output.
soma: Compartment,
/// Exclusive upper bound on the synapse index accepted for any branch.
synapses_per_branch: usize,
/// Threshold passed to the soma's `is_active` check when reporting whether the tree is spiking.
soma_threshold: f32,
}
impl DendriticTree {
    /// Creates a tree with `num_branches` branches using the default parameters.
    ///
    /// Equivalent to `with_parameters(num_branches, 5, 20.0, 100)`.
    pub fn new(num_branches: usize) -> Self {
        let default_nmda_threshold = 5;
        let default_coincidence_window_ms = 20.0;
        let default_synapses_per_branch = 100;
        Self::with_parameters(
            num_branches,
            default_nmda_threshold,
            default_coincidence_window_ms,
            default_synapses_per_branch,
        )
    }

    /// Creates a tree whose branches are all configured identically.
    ///
    /// * `num_branches` - how many [`Dendrite`] branches to create.
    /// * `nmda_threshold`, `coincidence_window_ms` - forwarded unchanged to
    ///   [`Dendrite::new`] for every branch.
    /// * `synapses_per_branch` - exclusive upper bound on the synapse index that
    ///   [`receive_input`](Self::receive_input) accepts.
    ///
    /// The soma starts from `Compartment::new()` and the spike threshold is fixed at `0.5`.
    pub fn with_parameters(
        num_branches: usize,
        nmda_threshold: u8,
        coincidence_window_ms: f32,
        synapses_per_branch: usize,
    ) -> Self {
        let mut branches = Vec::with_capacity(num_branches);
        branches.resize_with(num_branches, || {
            Dendrite::new(nmda_threshold, coincidence_window_ms)
        });

        Self {
            branches,
            soma: Compartment::new(),
            synapses_per_branch,
            soma_threshold: 0.5,
        }
    }

    /// Delivers a spike stamped `timestamp` to synapse `synapse` of branch `branch`.
    ///
    /// # Errors
    ///
    /// The branch index is validated before the synapse index:
    /// * `NervousSystemError::CompartmentOutOfBounds(branch)` if `branch >= self.num_branches()`;
    /// * otherwise `NervousSystemError::SynapseOutOfBounds(synapse)` if
    ///   `synapse >= synapses_per_branch`.
    ///
    /// On error no branch is touched.
    pub fn receive_input(&mut self, branch: usize, synapse: usize, timestamp: u64) -> Result<()> {
        let target = self
            .branches
            .get_mut(branch)
            .ok_or_else(|| crate::NervousSystemError::CompartmentOutOfBounds(branch))?;

        if synapse >= self.synapses_per_branch {
            return Err(crate::NervousSystemError::SynapseOutOfBounds(synapse));
        }

        target.receive_spike(synapse, timestamp);
        Ok(())
    }

    /// Advances the tree by one step and returns the resulting soma membrane value.
    ///
    /// Every branch is first updated with `(current_time, dt)`. Only after all
    /// updates are done is the somatic drive accumulated, branch by branch, as
    /// `0.1 * plateau_amplitude() + 0.01 * membrane()`. The soma is then stepped
    /// with that drive and `dt`.
    pub fn step(&mut self, current_time: u64, dt: f32) -> f32 {
        self.branches
            .iter_mut()
            .for_each(|dendrite| dendrite.update(current_time, dt));

        // Plateau amplitude is weighted ten times more heavily than the branch
        // membrane. The terms are added one at a time, in branch order, so the
        // float rounding matches a plain `+=` accumulation.
        let drive = self.branches.iter().fold(0.0, |acc, dendrite| {
            acc + dendrite.plateau_amplitude() * 0.1 + dendrite.membrane() * 0.01
        });

        self.soma.step(drive, dt);
        self.soma_membrane()
    }

    /// Reports whether the soma is active at `soma_threshold`.
    pub fn is_spiking(&self) -> bool {
        let threshold = self.soma_threshold;
        self.soma.is_active(threshold)
    }

    /// Current membrane value of the soma.
    pub fn soma_membrane(&self) -> f32 {
        self.soma.membrane()
    }

    /// Number of dendritic branches in the tree.
    pub fn num_branches(&self) -> usize {
        self.branches.len()
    }

    /// Borrows the branch at `index`, or returns `None` if `index` is out of range.
    pub fn branch(&self, index: usize) -> Option<&Dendrite> {
        self.branches.get(index)
    }

    /// Counts the branches whose `has_plateau()` currently returns `true`.
    pub fn active_branch_count(&self) -> usize {
        self.branches
            .iter()
            .map(|dendrite| usize::from(dendrite.has_plateau()))
            .sum()
    }

    /// Resets the somatic compartment.
    ///
    /// Only the soma is reset; the dendritic branches are left untouched.
    /// NOTE(review): any state the branches still hold (e.g. a plateau reported by
    /// `active_branch_count`) survives and may drive the soma again on the next
    /// `step`. Confirm that this soma-only (post-spike style) reset is intended.
    pub fn reset(&mut self) {
        self.soma.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tree_creation() {
        let tree = DendriticTree::new(5);
        assert_eq!(tree.num_branches(), 5);
        assert_eq!(tree.soma_membrane(), 0.0);
    }

    #[test]
    fn test_single_branch_input() {
        let mut tree = DendriticTree::new(3);
        for i in 0..6 {
            tree.receive_input(0, i, 100).unwrap();
        }
        let soma_out = tree.step(100, 1.0);
        assert!(soma_out > 0.0);
        assert_eq!(tree.active_branch_count(), 1);
    }

    #[test]
    fn test_multi_branch_integration() {
        let mut tree = DendriticTree::new(3);
        for branch in 0..3 {
            for synapse in 0..6 {
                tree.receive_input(branch, synapse, 100).unwrap();
            }
        }
        tree.step(100, 1.0);
        assert_eq!(tree.active_branch_count(), 3);
        assert!(tree.soma_membrane() > 0.0);
    }

    #[test]
    fn test_soma_spiking() {
        let mut tree = DendriticTree::new(10);
        for branch in 0..10 {
            for synapse in 0..6 {
                tree.receive_input(branch, synapse, 100).unwrap();
            }
        }
        for t in 0..20 {
            tree.step(100 + t * 10, 10.0);
        }
        assert!(tree.soma_membrane() > 0.3);
    }

    #[test]
    fn test_invalid_branch_index() {
        let mut tree = DendriticTree::new(3);
        // One past the last valid branch is the first index that must be rejected.
        assert!(matches!(
            tree.receive_input(3, 0, 100),
            Err(crate::NervousSystemError::CompartmentOutOfBounds(3))
        ));
        assert!(matches!(
            tree.receive_input(5, 0, 100),
            Err(crate::NervousSystemError::CompartmentOutOfBounds(5))
        ));
    }

    #[test]
    fn test_invalid_synapse_index() {
        let mut tree = DendriticTree::with_parameters(3, 5, 20.0, 50);
        // `synapses_per_branch` itself is already out of range.
        assert!(matches!(
            tree.receive_input(0, 50, 100),
            Err(crate::NervousSystemError::SynapseOutOfBounds(50))
        ));
        assert!(matches!(
            tree.receive_input(0, 100, 100),
            Err(crate::NervousSystemError::SynapseOutOfBounds(100))
        ));
    }

    #[test]
    fn test_branch_checked_before_synapse() {
        let mut tree = DendriticTree::new(3);
        // When both indices are invalid, the branch error wins.
        assert!(matches!(
            tree.receive_input(7, 1000, 100),
            Err(crate::NervousSystemError::CompartmentOutOfBounds(7))
        ));
    }

    #[test]
    fn test_branch_access() {
        let tree = DendriticTree::new(5);
        assert!(tree.branch(0).is_some());
        assert!(tree.branch(4).is_some());
        assert!(tree.branch(5).is_none());
    }

    #[test]
    fn test_temporal_integration() {
        let mut tree = DendriticTree::new(2);
        for i in 0..6 {
            tree.receive_input(0, i, 100).unwrap();
        }
        tree.step(100, 1.0);
        for i in 0..6 {
            tree.receive_input(1, i, 150).unwrap();
        }
        tree.step(150, 1.0);
        let active = tree.active_branch_count();
        assert!(active >= 1);
    }

    #[test]
    fn test_reset() {
        let mut tree = DendriticTree::new(3);
        for branch in 0..3 {
            for synapse in 0..6 {
                tree.receive_input(branch, synapse, 100).unwrap();
            }
        }
        tree.step(100, 1.0);
        tree.reset();
        assert_eq!(tree.soma_membrane(), 0.0);
    }
}