use super::{SimTime, Spike};
use rayon::prelude::*;
use std::collections::VecDeque;
/// Population size at or above which `NeuronPopulation::step` integrates neurons
/// in parallel with rayon instead of sequentially.
const PARALLEL_THRESHOLD: usize = 2_000;
/// Parameters of a leaky integrate-and-fire neuron with an adaptive threshold.
///
/// The time constants are divided into the step size `dt` passed to
/// [`LIFNeuron::step`], so they are expressed in the same time units as `dt`.
#[derive(Debug, Clone)]
pub struct NeuronConfig {
    /// Membrane time constant; scales how fast `v` moves toward its drive.
    pub tau_membrane: f64,
    /// Resting potential that the membrane leaks toward.
    pub v_rest: f64,
    /// Potential the membrane is set to immediately after a spike.
    pub v_reset: f64,
    /// Base firing threshold. `with_config`/`reset` initialise the dynamic threshold
    /// to this value, and an elevated dynamic threshold relaxes back toward it.
    pub threshold: f64,
    /// Refractory period after a spike during which `step` ignores its input.
    pub t_refrac: f64,
    /// Membrane resistance multiplying the input current.
    pub resistance: f64,
    /// Amount added to the dynamic threshold on every spike.
    pub threshold_adapt: f64,
    /// Time constant with which an elevated dynamic threshold decays to `threshold`.
    pub tau_threshold: f64,
    /// Gates the homeostatic branch of [`LIFNeuron::step`], which compares the
    /// smoothed spike rate against `target_rate`.
    pub homeostatic: bool,
    /// Target value for `NeuronState::spike_rate` used by the homeostatic branch.
    pub target_rate: f64,
    /// Time constant of the smoothed spike-rate estimate and of the homeostatic term.
    pub tau_homeostatic: f64,
}
impl Default for NeuronConfig {
fn default() -> Self {
Self {
tau_membrane: 20.0,
v_rest: 0.0,
v_reset: 0.0,
threshold: 1.0,
t_refrac: 2.0,
resistance: 1.0,
threshold_adapt: 0.1,
tau_threshold: 100.0,
homeostatic: true,
target_rate: 0.01,
tau_homeostatic: 1000.0,
}
}
}
/// Mutable per-neuron simulation state, advanced by [`LIFNeuron::step`].
#[derive(Debug, Clone)]
pub struct NeuronState {
    /// Membrane potential.
    pub v: f64,
    /// Current (dynamic) firing threshold; raised by `threshold_adapt` on each spike.
    pub threshold: f64,
    /// Time left in the refractory period; the neuron is refractory while this is positive.
    pub refrac_remaining: f64,
    /// Time of the most recent spike; `f64::NEG_INFINITY` until a spike is recorded.
    pub last_spike_time: f64,
    /// Exponentially smoothed spike-rate estimate: bumped on spikes, decayed on steps
    /// where the neuron is neither spiking nor refractory.
    pub spike_rate: f64,
}
impl Default for NeuronState {
fn default() -> Self {
Self {
v: 0.0,
threshold: 1.0,
refrac_remaining: 0.0,
last_spike_time: f64::NEG_INFINITY,
spike_rate: 0.0,
}
}
}
/// A leaky integrate-and-fire neuron: an identifier, its parameters, and its state.
#[derive(Debug, Clone)]
pub struct LIFNeuron {
    /// Identifier; `NeuronPopulation` constructors assign each neuron its index.
    pub id: usize,
    /// Model parameters.
    pub config: NeuronConfig,
    /// Evolving simulation state.
    pub state: NeuronState,
}
impl LIFNeuron {
    /// Creates a neuron with [`NeuronConfig::default`] parameters and a default state.
    pub fn new(id: usize) -> Self {
        Self {
            id,
            config: NeuronConfig::default(),
            state: NeuronState::default(),
        }
    }

    /// Creates a neuron with custom parameters; the dynamic threshold starts at
    /// `config.threshold`.
    pub fn with_config(id: usize, config: NeuronConfig) -> Self {
        let state = NeuronState {
            threshold: config.threshold,
            ..NeuronState::default()
        };
        Self { id, config, state }
    }

    /// Restores the default state, with the dynamic threshold set to
    /// `config.threshold`. The configuration itself is left untouched.
    pub fn reset(&mut self) {
        self.state = NeuronState {
            threshold: self.config.threshold,
            ..NeuronState::default()
        };
    }

    /// Advances the neuron by one step of length `dt` with input `current`.
    ///
    /// Returns `true` if the neuron spiked; the spike is stamped with `time`.
    /// While refractory, input is ignored and only the refractory timer advances.
    pub fn step(&mut self, current: f64, dt: f64, time: SimTime) -> bool {
        if self.state.refrac_remaining > 0.0 {
            self.state.refrac_remaining -= dt;
            return false;
        }

        // Forward-Euler step of tau * dv/dt = -(v - v_rest) + R * I.
        let dv = (-self.state.v + self.config.v_rest + self.config.resistance * current)
            / self.config.tau_membrane
            * dt;
        self.state.v += dv;

        // Relax an elevated threshold back toward the base threshold. Clamp at the
        // base so a step longer than tau_threshold cannot overshoot below it (which
        // could trigger a spurious spike).
        let base = self.config.threshold;
        if self.state.threshold > base {
            let relaxed =
                self.state.threshold - (self.state.threshold - base) / self.config.tau_threshold * dt;
            self.state.threshold = relaxed.max(base);
        }

        // Exponential-smoothing factor for the rate estimate. It is clamped to 1 so a
        // step longer than tau_homeostatic cannot drive the estimate negative.
        let alpha = (dt / self.config.tau_homeostatic).min(1.0);

        if self.state.v >= self.state.threshold {
            self.state.v = self.config.v_reset;
            self.state.refrac_remaining = self.config.t_refrac;
            self.state.last_spike_time = time;
            self.state.threshold += self.config.threshold_adapt;
            self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
            return true;
        }

        self.state.spike_rate *= 1.0 - alpha;

        if self.config.homeostatic {
            // Homeostasis: raise the threshold while firing above target_rate and lower
            // it while firing below, pulling the long-run rate toward the target.
            let rate_error = self.state.spike_rate - self.config.target_rate;
            self.state.threshold += rate_error * dt / self.config.tau_homeostatic;
        }
        false
    }

    /// Records an externally injected spike at `time`.
    ///
    /// Only `last_spike_time` and the smoothed spike rate are updated. The membrane
    /// potential, refractory timer and threshold are untouched. The smoothing factor
    /// is `1 / tau_homeostatic`, which presumably assumes a unit time step.
    pub fn inject_spike(&mut self, time: SimTime) {
        self.state.last_spike_time = time;
        let alpha = (1.0 / self.config.tau_homeostatic).min(1.0);
        self.state.spike_rate = self.state.spike_rate * (1.0 - alpha) + alpha;
    }

    /// Time elapsed since the last spike; `f64::INFINITY` if the neuron never spiked.
    pub fn time_since_spike(&self, current_time: SimTime) -> f64 {
        current_time - self.state.last_spike_time
    }

    /// Whether the neuron is currently in its refractory period.
    pub fn is_refractory(&self) -> bool {
        self.state.refrac_remaining > 0.0
    }

    /// Current membrane potential.
    pub fn membrane_potential(&self) -> f64 {
        self.state.v
    }

    /// Overwrites the membrane potential.
    pub fn set_membrane_potential(&mut self, v: f64) {
        self.state.v = v;
    }

    /// Current dynamic firing threshold.
    pub fn threshold(&self) -> f64 {
        self.state.threshold
    }
}
/// Recorded spike times of a single neuron, pruned to a sliding time window.
#[derive(Debug, Clone)]
pub struct SpikeTrain {
    /// Identifier of the neuron these spikes belong to.
    pub neuron_id: usize,
    /// Spike times in the order they were recorded.
    pub spike_times: Vec<SimTime>,
    /// On each `record_spike(time)`, spikes earlier than `time - max_window` are dropped.
    pub max_window: f64,
}
impl SpikeTrain {
pub fn new(neuron_id: usize) -> Self {
Self {
neuron_id,
spike_times: Vec::new(),
max_window: 1000.0, }
}
pub fn with_window(neuron_id: usize, max_window: f64) -> Self {
Self {
neuron_id,
spike_times: Vec::new(),
max_window,
}
}
pub fn record_spike(&mut self, time: SimTime) {
self.spike_times.push(time);
let cutoff = time - self.max_window;
self.spike_times.retain(|&t| t >= cutoff);
}
pub fn clear(&mut self) {
self.spike_times.clear();
}
pub fn count(&self) -> usize {
self.spike_times.len()
}
pub fn spike_rate(&self, window: f64) -> f64 {
if self.spike_times.is_empty() {
return 0.0;
}
let latest = self.spike_times.last().copied().unwrap_or(0.0);
let count = self
.spike_times
.iter()
.filter(|&&t| t >= latest - window)
.count();
count as f64 / window
}
pub fn mean_isi(&self) -> Option<f64> {
if self.spike_times.len() < 2 {
return None;
}
let mut total_isi = 0.0;
for i in 1..self.spike_times.len() {
total_isi += self.spike_times[i] - self.spike_times[i - 1];
}
Some(total_isi / (self.spike_times.len() - 1) as f64)
}
pub fn cv_isi(&self) -> Option<f64> {
let mean = self.mean_isi()?;
if mean == 0.0 {
return None;
}
let mut variance = 0.0;
for i in 1..self.spike_times.len() {
let isi = self.spike_times[i] - self.spike_times[i - 1];
variance += (isi - mean).powi(2);
}
variance /= (self.spike_times.len() - 1) as f64;
Some(variance.sqrt() / mean)
}
pub fn to_pattern(&self, start: SimTime, bin_size: f64, num_bins: usize) -> Vec<bool> {
let mut pattern = vec![false; num_bins];
if bin_size <= 0.0 || num_bins == 0 {
return pattern;
}
let end_time = start + bin_size * num_bins as f64;
for &spike_time in &self.spike_times {
if spike_time >= start && spike_time < end_time {
let offset = spike_time - start;
let bin_f64 = offset / bin_size;
if bin_f64 >= 0.0 && bin_f64 < num_bins as f64 {
let bin = bin_f64 as usize;
if bin < num_bins {
pattern[bin] = true;
}
}
}
}
pattern
}
#[inline]
fn is_sorted(times: &[f64]) -> bool {
times.windows(2).all(|w| w[0] <= w[1])
}
pub fn cross_correlation(&self, other: &SpikeTrain, max_lag: f64, bin_size: f64) -> Vec<f64> {
if bin_size <= 0.0 || max_lag <= 0.0 {
return vec![0.0];
}
let num_bins_f64 = 2.0 * max_lag / bin_size + 1.0;
let num_bins = if num_bins_f64 > 0.0 && num_bins_f64 < usize::MAX as f64 {
(num_bins_f64 as usize).min(100_000) } else {
return vec![0.0];
};
let mut correlation = vec![0.0; num_bins];
if self.spike_times.is_empty() || other.spike_times.is_empty() {
return correlation;
}
let t1_owned: Vec<f64>;
let t2_owned: Vec<f64>;
let t1: &[f64] = if Self::is_sorted(&self.spike_times) {
&self.spike_times
} else {
t1_owned = {
let mut v = self.spike_times.clone();
v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
v
};
&t1_owned
};
let t2: &[f64] = if Self::is_sorted(&other.spike_times) {
&other.spike_times
} else {
t2_owned = {
let mut v = other.spike_times.clone();
v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
v
};
&t2_owned
};
let first_lower = t1[0] - max_lag;
let mut window_start = t2.partition_point(|&x| x < first_lower);
for &t1_spike in t1 {
let lower_bound = t1_spike - max_lag;
let upper_bound = t1_spike + max_lag;
while window_start < t2.len() && t2[window_start] < lower_bound {
window_start += 1;
}
let mut j = window_start;
while j < t2.len() && t2[j] <= upper_bound {
let lag = t1_spike - t2[j];
let bin = ((lag + max_lag) / bin_size) as usize;
if bin < num_bins {
correlation[bin] += 1.0;
}
j += 1;
}
}
let norm = ((self.count() * other.count()) as f64).sqrt();
if norm > 0.0 {
let inv_norm = 1.0 / norm;
for c in &mut correlation {
*c *= inv_norm;
}
}
correlation
}
}
/// A group of LIF neurons stepped together, with one spike train per neuron.
#[derive(Debug, Clone)]
pub struct NeuronPopulation {
    /// Neurons; the constructors give neuron `i` the id `i`.
    pub neurons: Vec<LIFNeuron>,
    /// Spike history; `step` records spikes of `neurons[i]` into `spike_trains[i]`.
    pub spike_trains: Vec<SpikeTrain>,
    /// Simulation time: advanced by `dt` on every `step`, set to `0.0` by `reset`.
    pub time: SimTime,
}
impl NeuronPopulation {
    /// Creates `n` neurons with [`NeuronConfig::default`] parameters.
    ///
    /// Neuron `i` has id `i` and an empty spike train with the default window.
    pub fn new(n: usize) -> Self {
        Self {
            neurons: (0..n).map(LIFNeuron::new).collect(),
            spike_trains: (0..n).map(SpikeTrain::new).collect(),
            time: 0.0,
        }
    }

    /// Creates `n` neurons sharing a copy of `config`.
    pub fn with_config(n: usize, config: NeuronConfig) -> Self {
        Self {
            neurons: (0..n)
                .map(|i| LIFNeuron::with_config(i, config.clone()))
                .collect(),
            spike_trains: (0..n).map(SpikeTrain::new).collect(),
            time: 0.0,
        }
    }

    /// Number of neurons in the population.
    pub fn size(&self) -> usize {
        self.neurons.len()
    }

    /// Advances the clock by `dt` and steps every neuron, returning the spikes emitted
    /// at the new time in neuron-index order.
    ///
    /// `currents[i]` drives neuron `i`; missing entries count as zero input. Spikes
    /// are also appended to the matching spike trains.
    pub fn step(&mut self, currents: &[f64], dt: f64) -> Vec<Spike> {
        self.time += dt;
        let time = self.time;
        if self.neurons.len() >= PARALLEL_THRESHOLD {
            // Integrate in parallel. Spikes are recorded afterwards on this thread, so
            // the spike trains are never mutated concurrently.
            let spike_flags: Vec<bool> = self
                .neurons
                .par_iter_mut()
                .enumerate()
                .map(|(i, neuron)| {
                    let current = currents.get(i).copied().unwrap_or(0.0);
                    neuron.step(current, dt, time)
                })
                .collect();
            let mut spikes = Vec::new();
            for (i, &spiked) in spike_flags.iter().enumerate() {
                if spiked {
                    spikes.push(Spike { neuron_id: i, time });
                    self.spike_trains[i].record_spike(time);
                }
            }
            spikes
        } else {
            let mut spikes = Vec::new();
            for (i, neuron) in self.neurons.iter_mut().enumerate() {
                let current = currents.get(i).copied().unwrap_or(0.0);
                if neuron.step(current, dt, time) {
                    spikes.push(Spike { neuron_id: i, time });
                    self.spike_trains[i].record_spike(time);
                }
            }
            spikes
        }
    }

    /// Resets the clock, every neuron's state, and every spike train.
    pub fn reset(&mut self) {
        self.time = 0.0;
        for neuron in &mut self.neurons {
            neuron.reset();
        }
        for train in &mut self.spike_trains {
            train.clear();
        }
    }

    /// Mean of the per-neuron `SpikeTrain::spike_rate(window)` values.
    ///
    /// Returns `0.0` for an empty population; previously this evaluated `0 / 0`
    /// and produced NaN.
    pub fn population_rate(&self, window: f64) -> f64 {
        if self.neurons.is_empty() {
            return 0.0;
        }
        let total: f64 = self.spike_trains.iter().map(|t| t.spike_rate(window)).sum();
        total / self.neurons.len() as f64
    }

    /// Synchrony of all spikes in the last `window` time units.
    ///
    /// Delegates to `super::compute_synchrony` with a bin width of `window / 10`.
    pub fn synchrony(&self, window: f64) -> f64 {
        let cutoff = self.time - window;
        let mut recent = Vec::new();
        for train in &self.spike_trains {
            for &t in &train.spike_times {
                if t >= cutoff {
                    recent.push(Spike {
                        neuron_id: train.neuron_id,
                        time: t,
                    });
                }
            }
        }
        super::compute_synchrony(&recent, window / 10.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lif_neuron_creation() {
        let neuron = LIFNeuron::new(0);
        assert_eq!(neuron.id, 0);
        assert_eq!(neuron.state.v, 0.0);
    }

    #[test]
    fn test_lif_neuron_spike() {
        let mut neuron = LIFNeuron::new(0);
        let mut spiked = false;
        for i in 0..100 {
            if neuron.step(2.0, 1.0, i as f64) {
                spiked = true;
                break;
            }
        }
        assert!(spiked);
        assert!(neuron.is_refractory());
    }

    #[test]
    fn test_spike_train() {
        let mut train = SpikeTrain::new(0);
        train.record_spike(10.0);
        train.record_spike(20.0);
        train.record_spike(30.0);
        assert_eq!(train.count(), 3);
        let mean_isi = train.mean_isi().unwrap();
        assert!((mean_isi - 10.0).abs() < 0.001);
    }

    #[test]
    fn test_spike_train_cv_regular() {
        let mut train = SpikeTrain::new(0);
        train.record_spike(10.0);
        train.record_spike(20.0);
        train.record_spike(30.0);
        assert!(train.cv_isi().unwrap().abs() < 1e-12);
    }

    #[test]
    fn test_neuron_population() {
        let mut pop = NeuronPopulation::new(100);
        let currents = vec![1.5; 100];
        let mut total_spikes = 0;
        for _ in 0..100 {
            // Mis-encoded `¤ts` (an HTML-entity corruption of `&currents`) fixed here.
            let spikes = pop.step(&currents, 1.0);
            total_spikes += spikes.len();
        }
        assert!(total_spikes > 0);
    }

    #[test]
    fn test_spike_train_pattern() {
        let mut train = SpikeTrain::new(0);
        train.record_spike(1.0);
        train.record_spike(3.0);
        train.record_spike(7.0);
        let pattern = train.to_pattern(0.0, 1.0, 10);
        assert_eq!(pattern.len(), 10);
        assert!(pattern[1]);
        assert!(pattern[3]);
        assert!(pattern[7]);
        assert!(!pattern[0]);
    }

    #[test]
    fn test_cross_correlation_identical_trains() {
        let mut train = SpikeTrain::new(0);
        for &t in &[10.0, 20.0, 30.0] {
            train.record_spike(t);
        }
        let corr = train.cross_correlation(&train, 5.0, 1.0);
        assert_eq!(corr.len(), 11);
        assert!((corr[5] - 1.0).abs() < 1e-12);
        assert!(corr
            .iter()
            .enumerate()
            .all(|(i, &c)| i == 5 || c == 0.0));
    }
}