use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::density::DensityField;
use crate::neuron::NeuronArrays;
use crate::synapse::SynapseStore;
/// A scheduled delivery of synaptic current to one neuron.
///
/// Equality and ordering look at `arrival_time` **only**, and the ordering is
/// reversed: an earlier arrival compares *greater*. That lets the max-heap
/// `BinaryHeap<SpikeArrival>` pop the earliest event first. Two arrivals at the
/// same time compare equal even when their other fields differ.
#[derive(Clone, Copy, Debug)]
pub struct SpikeArrival {
    /// Index of the neuron that receives the current.
    pub target: u32,
    /// Current added (saturating) to the target's membrane value on delivery.
    pub current: i16,
    /// Simulation time at which the current is delivered.
    pub arrival_time: u64,
    /// Index of the firing neuron, or `u32::MAX` for externally injected input.
    pub source: u32,
}

impl PartialEq for SpikeArrival {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality can never disagree with the ordering.
        self.cmp(other).is_eq()
    }
}

impl Eq for SpikeArrival {}

impl PartialOrd for SpikeArrival {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for SpikeArrival {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed so the earliest arrival is the "largest" element of the heap.
        self.arrival_time.cmp(&other.arrival_time).reverse()
    }
}
/// Tunable parameters of a [`CascadePool`] simulation.
#[derive(Clone, Debug)]
pub struct CascadeConfig {
    /// Membrane value the leak decays toward; also what `reset` restores.
    pub resting_potential: i16,
    /// Membrane value assigned immediately after a neuron fires.
    pub reset_potential: i16,
    /// A neuron fires once its membrane reaches or exceeds this value.
    pub spike_threshold: i16,
    /// Multiplier from path length (times the local delay factor) to delivery
    /// delay. Despite the name, larger values mean *slower* propagation.
    pub propagation_speed: f32,
    /// Leak time constant, in the same units as `sim_time`. The membrane is fully
    /// back at rest after this long. `0` disables leak.
    pub leak_tau: u64,
    /// How long after firing incoming arrivals are discarded. Uses the same units
    /// as `sim_time`; the suffix suggests microseconds.
    pub refractory_us: u64,
    /// Factor applied to a synapse weight to obtain the delivered current (before attenuation).
    pub weight_scale: i16,
    /// NOTE(review): not read anywhere in this file; document its meaning where it is consumed.
    pub gray_threshold: f32,
}

impl Default for CascadeConfig {
    fn default() -> Self {
        // The potentials are exact multiples of 256 (-70, -65 and -55 times 256).
        // NOTE(review): presumably millivolts in 8.8 fixed point; confirm.
        Self {
            resting_potential: -70 * 256,
            reset_potential: -65 * 256,
            spike_threshold: -55 * 256,
            propagation_speed: 10.0,
            leak_tau: 20_000,
            refractory_us: 2_000,
            weight_scale: 64,
            gray_threshold: 0.5,
        }
    }
}
/// An event-driven ("cascade") simulation over a pool of neurons.
///
/// Work is driven by a priority queue of [`SpikeArrival`]s, popped earliest
/// first. Membrane leak is applied lazily. A neuron is only decayed when an
/// arrival reaches it, and the decay covers the time elapsed since its last update.
pub struct CascadePool {
    /// Human-readable pool name.
    pub name: String,
    /// Per-neuron state (membrane, trace, soma and axon-terminal positions).
    pub neurons: NeuronArrays,
    /// Synapses, looked up per firing neuron via `outgoing`.
    pub synapses: SynapseStore,
    /// 8x8x8 spatial grid queried for per-path delay factor and attenuation.
    pub density: DensityField,
    /// Scheduled, not-yet-processed arrivals; the earliest pops first.
    pending: BinaryHeap<SpikeArrival>,
    /// Per-neuron time at which leak was last applied.
    last_update: Vec<u64>,
    /// Per-neuron time before which incoming arrivals are discarded.
    refractory_until: Vec<u64>,
    /// Current simulation time: the latest arrival time processed so far.
    /// Never decreases between resets.
    pub sim_time: u64,
    /// Spikes fired since construction or the last `reset`.
    pub total_spikes: u64,
    /// Arrivals popped from the queue since construction or the last `reset`.
    pub total_events: u64,
    /// Simulation parameters.
    pub config: CascadeConfig,
}

impl CascadePool {
    /// Creates a pool over `neurons` and `synapses` with all clocks and counters at zero.
    ///
    /// The density field is an 8x8x8 grid spanning `bounds`. It is *not* filled
    /// from neuron positions here; call [`update_density`](Self::update_density) for that.
    pub fn new(
        name: impl Into<String>,
        neurons: NeuronArrays,
        synapses: SynapseStore,
        bounds: [f32; 3],
        config: CascadeConfig,
    ) -> Self {
        let n = neurons.len();
        let density = DensityField::new([8, 8, 8], bounds);
        Self {
            name: name.into(),
            neurons,
            synapses,
            density,
            pending: BinaryHeap::new(),
            last_update: vec![0; n],
            refractory_until: vec![0; n],
            sim_time: 0,
            total_spikes: 0,
            total_events: 0,
            config,
        }
    }

    /// Builds a cascade pool from a `NeuronPool`, taking over its name, neurons and
    /// synapses. Uses its spatial bounds, or `[10.0; 3]` when it has none.
    pub fn from_pool(pool: crate::pool::NeuronPool, config: CascadeConfig) -> Self {
        let bounds = pool.spatial_bounds.unwrap_or([10.0, 10.0, 10.0]);
        Self::new(pool.name, pool.neurons, pool.synapses, bounds, config)
    }

    /// Number of neurons in the pool.
    #[inline]
    pub fn n_neurons(&self) -> usize {
        self.neurons.len()
    }

    /// Number of arrivals still waiting in the queue.
    #[inline]
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Schedules an external current injection into `neuron` at `time`.
    ///
    /// The arrival's `source` is `u32::MAX`, marking it as external. If `time`
    /// is earlier than `sim_time`, the event is processed at the current time,
    /// because the clock never runs backwards. Out-of-range neurons are ignored
    /// when the event is processed.
    pub fn inject(&mut self, neuron: u32, current: i16, time: u64) {
        self.pending.push(SpikeArrival {
            target: neuron,
            current,
            arrival_time: time,
            source: u32::MAX,
        });
    }

    /// Processes every queued arrival with `arrival_time <= until_time`, including
    /// any spikes those arrivals trigger within the window. Returns how many were processed.
    pub fn run_until(&mut self, until_time: u64) -> usize {
        let mut processed = 0;
        while matches!(self.pending.peek(), Some(next) if next.arrival_time <= until_time) {
            self.step();
            processed += 1;
        }
        processed
    }

    /// Pops and processes the single earliest arrival, returning it.
    /// Returns `None` if the queue is empty.
    pub fn step(&mut self) -> Option<SpikeArrival> {
        let arrival = self.pending.pop()?;
        // Keep the clock monotonic. An arrival stamped before the current time
        // (e.g. injected after the pool already advanced past it) is handled "now".
        // Rewinding `sim_time` would also rewind `last_update`, so later leak
        // calculations would count the same interval twice.
        self.sim_time = self.sim_time.max(arrival.arrival_time);
        self.process_arrival(arrival);
        self.total_events += 1;
        Some(arrival)
    }

    /// Delivers one arrival at the current `sim_time`.
    ///
    /// The arrival is dropped if the target index is out of range or the target
    /// is refractory. Otherwise the pending leak is applied, the current is added
    /// (saturating), and the neuron fires if it reaches `spike_threshold`.
    fn process_arrival(&mut self, arrival: SpikeArrival) {
        let i = arrival.target as usize;
        if i >= self.neurons.len() {
            return;
        }
        // Arrivals during the refractory window are discarded, not deferred.
        if self.refractory_until[i] > self.sim_time {
            return;
        }
        // Lazy leak: cover the whole interval since this neuron was last touched.
        let dt = self.sim_time.saturating_sub(self.last_update[i]);
        self.apply_leak(i, dt);
        self.last_update[i] = self.sim_time;
        self.neurons.membrane[i] = self.neurons.membrane[i].saturating_add(arrival.current);
        if self.neurons.membrane[i] >= self.config.spike_threshold {
            self.fire(i);
        }
    }

    /// Moves the membrane linearly toward `resting_potential`.
    ///
    /// The step is the fraction `dt / leak_tau` of the current gap. Once
    /// `dt >= leak_tau` the membrane is fully at rest. No-op when `dt` or
    /// `leak_tau` is 0.
    fn apply_leak(&mut self, neuron: usize, dt: u64) {
        let tau = self.config.leak_tau;
        if dt == 0 || tau == 0 {
            return;
        }
        let membrane = i128::from(self.neurons.membrane[neuron]);
        let diff = membrane - i128::from(self.config.resting_potential);
        let effective_dt = dt.min(tau);
        // i128: |diff| < 2^17 and effective_dt <= u64::MAX, so the product cannot
        // overflow. An i32 product overflows once leak_tau exceeds roughly 42_000.
        // Division truncates toward zero, exactly as the i32 form did.
        let decay = diff * i128::from(effective_dt) / i128::from(tau);
        let leaked = (membrane - decay).clamp(i128::from(i16::MIN), i128::from(i16::MAX));
        self.neurons.membrane[neuron] = leaked as i16;
    }

    /// Fires `neuron`.
    ///
    /// Resets its membrane, starts its refractory window and bumps its trace by 30.
    /// Then it schedules one arrival per in-range outgoing synapse, with
    /// distance- and density-dependent delay and attenuation.
    fn fire(&mut self, neuron: usize) {
        self.neurons.membrane[neuron] = self.config.reset_potential;
        self.refractory_until[neuron] = self.sim_time.saturating_add(self.config.refractory_us);
        self.neurons.trace[neuron] = self.neurons.trace[neuron].saturating_add(30);
        let src_pos = self.neurons.axon_terminal[neuron];
        for syn in self.synapses.outgoing(neuron as u32) {
            let tgt = syn.target as usize;
            if tgt >= self.neurons.len() {
                continue;
            }
            let tgt_pos = self.neurons.soma_position[tgt];
            // One conductivity lookup serves both the delay and the attenuation.
            let cond = self.density.conductivity_at(Self::midpoint(src_pos, tgt_pos));
            let delay = self.path_delay(src_pos, tgt_pos, cond.delay_factor);
            // Saturate instead of overflowing i16 for large weights or weight_scale.
            let base_current = (syn.weight as i16).saturating_mul(self.config.weight_scale);
            let current = Self::scale_current(base_current, cond.attenuation);
            self.pending.push(SpikeArrival {
                target: syn.target as u32,
                current,
                arrival_time: self.sim_time.saturating_add(delay),
                source: neuron as u32,
            });
        }
        self.total_spikes += 1;
    }

    /// Delivery delay from `src`'s axon terminal to `tgt`'s soma.
    ///
    /// Computed as straight-line distance × `propagation_speed` × the delay
    /// factor sampled at the path midpoint. The result is at least 1.
    pub fn axon_delay(&self, src: usize, tgt: usize) -> u64 {
        let src_pos = self.neurons.axon_terminal[src];
        let tgt_pos = self.neurons.soma_position[tgt];
        let cond = self.density.conductivity_at(Self::midpoint(src_pos, tgt_pos));
        self.path_delay(src_pos, tgt_pos, cond.delay_factor)
    }

    /// Scales `current` by `1 - attenuation`, with attenuation sampled at the
    /// midpoint of the `src` → `tgt` path. The scale is clamped at 0, so the
    /// current is silenced rather than sign-flipped when attenuation exceeds 1.
    pub fn attenuate_current(&self, current: i16, src: usize, tgt: usize) -> i16 {
        let src_pos = self.neurons.axon_terminal[src];
        let tgt_pos = self.neurons.soma_position[tgt];
        let cond = self.density.conductivity_at(Self::midpoint(src_pos, tgt_pos));
        Self::scale_current(current, cond.attenuation)
    }

    /// Component-wise midpoint of two points; conductivity is sampled here for a path.
    fn midpoint(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
        [(a[0] + b[0]) / 2.0, (a[1] + b[1]) / 2.0, (a[2] + b[2]) / 2.0]
    }

    /// Delay for a path of the given endpoints under a known `delay_factor`.
    fn path_delay(&self, src_pos: [f32; 3], tgt_pos: [f32; 3], delay_factor: f32) -> u64 {
        let dx = tgt_pos[0] - src_pos[0];
        let dy = tgt_pos[1] - src_pos[1];
        let dz = tgt_pos[2] - src_pos[2];
        let distance = (dx * dx + dy * dy + dz * dz).sqrt();
        let delay = distance * self.config.propagation_speed * delay_factor;
        // Float-to-int `as` saturates (NaN and negatives become 0); every hop takes >= 1 tick.
        (delay as u64).max(1)
    }

    /// Applies an attenuation factor to a current, never reversing its sign.
    fn scale_current(current: i16, attenuation: f32) -> i16 {
        let scale = (1.0 - attenuation).max(0.0);
        (f32::from(current) * scale) as i16
    }

    /// Rebuilds the density field from the current soma positions.
    pub fn update_density(&mut self) {
        self.density.update_from_positions(&self.neurons.soma_position);
    }

    /// Moves `neuron` a fraction `rate` of the way toward the centroid of its
    /// in-range outgoing targets' somas.
    ///
    /// The axon terminal moves by half the soma's displacement. Does nothing for
    /// an out-of-range neuron or one with no in-range targets. The density field
    /// is not updated here; call [`update_density`](Self::update_density) afterwards.
    pub fn migrate(&mut self, neuron: usize, rate: f32) {
        if neuron >= self.neurons.len() {
            return;
        }
        let mut centroid = [0.0f32; 3];
        let mut count = 0.0f32;
        for syn in self.synapses.outgoing(neuron as u32) {
            let tgt = syn.target as usize;
            if tgt < self.neurons.len() {
                let p = self.neurons.soma_position[tgt];
                centroid[0] += p[0];
                centroid[1] += p[1];
                centroid[2] += p[2];
                count += 1.0;
            }
        }
        if count == 0.0 {
            return;
        }
        centroid[0] /= count;
        centroid[1] /= count;
        centroid[2] /= count;
        let pos = self.neurons.soma_position[neuron];
        let dx = (centroid[0] - pos[0]) * rate;
        let dy = (centroid[1] - pos[1]) * rate;
        let dz = (centroid[2] - pos[2]) * rate;
        self.neurons.soma_position[neuron][0] += dx;
        self.neurons.soma_position[neuron][1] += dy;
        self.neurons.soma_position[neuron][2] += dz;
        self.neurons.axon_terminal[neuron][0] += dx * 0.5;
        self.neurons.axon_terminal[neuron][1] += dy * 0.5;
        self.neurons.axon_terminal[neuron][2] += dz * 0.5;
    }

    /// Clears the queue, per-neuron clocks and counters, and sets every membrane
    /// to `resting_potential`.
    ///
    /// Trace values, positions, synapses and the density field are left untouched.
    pub fn reset(&mut self) {
        self.pending.clear();
        self.last_update.fill(0);
        self.refractory_until.fill(0);
        self.sim_time = 0;
        self.total_spikes = 0;
        self.total_events = 0;
        let resting = self.config.resting_potential;
        for m in &mut self.neurons.membrane {
            *m = resting;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::neuron::NeuronArrays;
    use crate::synapse::{Synapse, SynapseStore};

    /// Wraps `neurons`/`synapses` in a pool with 10x10x10 bounds and the default config.
    fn pool_from(neurons: NeuronArrays, synapses: SynapseStore) -> CascadePool {
        CascadePool::new("test", neurons, synapses, [10.0, 10.0, 10.0], CascadeConfig::default())
    }

    /// Three neurons with no synapses between them.
    fn make_test_pool() -> CascadePool {
        pool_from(NeuronArrays::new(3, 3, -17920, -14080), SynapseStore::empty(3))
    }

    #[test]
    fn test_cascade_creation() {
        let pool = make_test_pool();
        assert_eq!(pool.n_neurons(), 3);
        assert_eq!(pool.pending_count(), 0);
        assert_eq!(pool.sim_time, 0);
    }

    #[test]
    fn test_inject_and_process() {
        let mut pool = make_test_pool();
        pool.inject(0, 5000, 100);
        assert_eq!(pool.pending_count(), 1);
        assert_eq!(pool.run_until(100), 1);
        assert_eq!(pool.sim_time, 100);
    }

    #[test]
    fn test_spike_propagation() {
        let mut neurons = NeuronArrays::new(2, 2, -17920, -14080);
        // Neuron 0 sits at the origin with its axon terminal halfway toward neuron 1.
        neurons.soma_position[0] = [0.0, 0.0, 0.0];
        neurons.axon_terminal[0] = [0.5, 0.0, 0.0];
        neurons.soma_position[1] = [1.0, 0.0, 0.0];
        neurons.axon_terminal[1] = [1.0, 0.0, 0.0];
        let exc_flags =
            crate::neuron::flags::encode(false, crate::neuron::NeuronProfile::RegularSpiking);
        let synapses = SynapseStore::from_edges(2, vec![(0, Synapse::new(1, 100, 1, exc_flags))]);
        let mut pool = pool_from(neurons, synapses);
        pool.inject(0, 5000, 0);
        pool.run_until(0);
        assert!(pool.total_spikes > 0 || pool.pending_count() > 0);
    }

    #[test]
    fn test_spike_arrival_ordering() {
        let earlier = SpikeArrival { target: 0, current: 100, arrival_time: 50, source: 1 };
        let later = SpikeArrival { target: 1, current: 100, arrival_time: 100, source: 1 };
        // The earlier arrival must rank higher so the max-heap pops it first.
        assert!(earlier > later);
    }

    #[test]
    fn test_leak_decay() {
        let mut pool = make_test_pool();
        let resting = pool.config.resting_potential;
        let half_tau = pool.config.leak_tau / 2;
        pool.neurons.membrane[0] = -10000;
        pool.last_update[0] = 0;
        pool.apply_leak(0, half_tau);
        let after = pool.neurons.membrane[0];
        assert!(after < -10000);
        assert!(after > resting);
    }
}