use rand::Rng; use serde::{Deserialize, Serialize};
use super::lif::LifNeuron; use super::izhikevich::IzhikevichNeuron; use super::rm_stdp::*; use super::modulators::NeuroModulators;
const WEIGHT_BUDGET: f32 = 2.0;
#[derive(Default, Serialize, Deserialize)]
pub struct SpikingNetwork {
pub neurons: Vec<LifNeuron>,
pub iz_neurons: Vec<IzhikevichNeuron>,
pub modulators: NeuroModulators,
pub global_step: i64,
pub input_spike_times: Vec<i64>,
pub predictive_state: [f32; crate::NUM_INPUT_CHANNELS],
}
impl SpikingNetwork {
pub fn new() -> Self {
let mut neurons: Vec<LifNeuron> = (0..16).map(|_| {
let mut n = LifNeuron::new();
n.weights = vec![1.5; crate::NUM_INPUT_CHANNELS];
n.last_spike_time = -1;
n
}).collect();
let mut rng = rand::thread_rng();
for i in 0..14 {
let ch = i / 2;
let neuron = &mut neurons[i];
neuron.weights[ch] = 0.8 + (rng.r#gen::<f32>() * 0.4);
if i % 2 == 0 {
neuron.threshold = 0.10 + (rng.r#gen::<f32>() * 0.04);
} else {
neuron.threshold = 0.06 + (rng.r#gen::<f32>() * 0.04);
}
neuron.base_threshold = neuron.threshold;
}
neurons[14].threshold = 0.50;
neurons[14].base_threshold = 0.50;
neurons[15].threshold = 0.50;
neurons[15].base_threshold = 0.50;
Self {
neurons,
iz_neurons: vec![IzhikevichNeuron::new_regular_spiking(); 5],
modulators: NeuroModulators::default(),
global_step: 0,
input_spike_times: vec![-1; crate::NUM_INPUT_CHANNELS],
predictive_state: [0.0; crate::NUM_INPUT_CHANNELS],
}
}
pub fn step(&mut self, stimuli: &[f32; crate::NUM_INPUT_CHANNELS], modulators: &NeuroModulators) -> Vec<usize> {
self.global_step += 1;
self.modulators = *modulators;
let stress_multiplier = (1.0 - self.modulators.cortisol).max(0.1);
let learning_rate = 0.5 * self.modulators.dopamine;
for neuron in &mut self.neurons {
let target_decay = 0.15 - (0.05 * self.modulators.acetylcholine);
neuron.decay_rate = target_decay;
let global_target = 0.20 - (0.05 * self.modulators.dopamine) + (0.15 * self.modulators.cortisol);
let target_threshold = (global_target + if neuron.last_spike { 0.005 } else { -0.001 })
.clamp(0.05, 0.50);
neuron.threshold += (target_threshold - neuron.threshold) * learning_rate;
neuron.threshold = neuron.threshold.clamp(0.05, 0.50);
}
const PRED_ALPHA: f32 = 0.1;
const PRED_ERR_WEIGHT: f32 = 0.5;
let mut pred_errors = [0.0_f32; crate::NUM_INPUT_CHANNELS];
for ch in 0..crate::NUM_INPUT_CHANNELS {
let s = stimuli[ch].abs().clamp(0.0, 1.0);
pred_errors[ch] = (s - self.predictive_state[ch]).abs();
self.predictive_state[ch] = PRED_ALPHA * s + (1.0 - PRED_ALPHA) * self.predictive_state[ch];
}
let mut rng = rand::thread_rng();
for (ch, &s) in stimuli.iter().enumerate() {
let abs_s = s.abs().clamp(0.0, 1.0);
if abs_s > 0.01 && rng.r#gen::<f32>() < abs_s {
self.input_spike_times[ch] = self.global_step;
}
}
for (i, neuron) in self.neurons.iter_mut().enumerate() {
let ch = i % crate::NUM_INPUT_CHANNELS;
let is_bull = i % 2 == 1;
let delta = stimuli[ch];
let polarity_match = if is_bull { delta > 0.0 } else { delta < 0.0 };
let mut total_current = 0.0;
if polarity_match {
let stim = delta.abs().clamp(0.0, 1.0);
let surprise = PRED_ERR_WEIGHT * pred_errors[ch];
total_current = neuron.weights[ch] * (stim + surprise) * 0.45 * stress_multiplier;
}
neuron.integrate(total_current);
}
let mut spike_ids = Vec::new();
for (i, neuron) in self.neurons.iter_mut().enumerate() {
if let Some(_peak_v) = neuron.check_fire() {
neuron.last_spike = true;
neuron.last_spike_time = self.global_step;
spike_ids.push(i);
} else {
neuron.last_spike = false;
}
}
if !spike_ids.is_empty() {
const INHIBITION_STRENGTH: f32 = 0.05;
for (i, neuron) in self.neurons.iter_mut().enumerate() {
if !spike_ids.contains(&i) {
neuron.membrane_potential = (neuron.membrane_potential - INHIBITION_STRENGTH).max(0.0);
}
}
}
let dopamine_lr = learning_rate;
self.apply_stdp(stimuli, dopamine_lr);
for neuron in &mut self.neurons {
let total: f32 = neuron.weights.iter().sum();
if total > 1e-6 {
let scale = WEIGHT_BUDGET / total;
for w in &mut neuron.weights {
*w *= scale;
*w = w.clamp(RM_STDP_W_MIN, RM_STDP_W_MAX);
}
}
}
let lif_mean = if !self.neurons.is_empty() {
let sum: f32 = self.neurons.iter().map(|n| n.membrane_potential).sum();
sum / self.neurons.len() as f32
} else {
0.0
};
let iz_drive = (lif_mean * 20.0 + self.modulators.dopamine * 5.0)
.clamp(0.0, 15.0);
for iz in &mut self.iz_neurons {
iz.step(iz_drive);
}
spike_ids
}
fn apply_stdp(&mut self, _stimuli: &[f32; crate::NUM_INPUT_CHANNELS], dopamine_lr: f32) {
if dopamine_lr < 1e-6 {
return;
}
let input_times = self.input_spike_times.clone();
for neuron in &mut self.neurons {
if neuron.last_spike_time < 0 {
continue;
}
for (ch, &pre_time) in input_times.iter().enumerate() {
if ch >= neuron.weights.len() || pre_time < 0 {
continue;
}
let post_time = neuron.last_spike_time;
if post_time < 0 {
continue;
}
let delta_t = (post_time - pre_time) as f32;
let dw = if delta_t >= 0.0 {
RM_STDP_A_PLUS * (-delta_t / RM_STDP_TAU_PLUS).exp()
} else {
-RM_STDP_A_MINUS * (delta_t / RM_STDP_TAU_MINUS).exp()
};
neuron.weights[ch] = (neuron.weights[ch] + dw * dopamine_lr)
.clamp(RM_STDP_W_MIN, RM_STDP_W_MAX);
}
}
}
pub fn get_membrane_potentials(&self) -> Vec<f32> {
self.neurons.iter().map(|n| n.membrane_potential).collect()
}
pub fn get_thresholds(&self) -> Vec<f32> {
self.neurons.iter().map(|n| n.threshold).collect()
}
pub fn reset(&mut self) {
self.global_step = 0;
self.input_spike_times = vec![-1; crate::NUM_INPUT_CHANNELS];
self.predictive_state = [0.0; crate::NUM_INPUT_CHANNELS];
for neuron in &mut self.neurons {
neuron.membrane_potential = 0.0;
neuron.last_spike = false;
neuron.last_spike_time = -1;
}
self.modulators = NeuroModulators::default();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_network_creation() {
let network = SpikingNetwork::new();
assert_eq!(network.neurons.len(), 16);
assert_eq!(network.iz_neurons.len(), 5);
assert_eq!(network.global_step, 0);
}
#[test]
fn test_network_step() {
let mut network = SpikingNetwork::new();
let stimuli = [0.5; crate::NUM_INPUT_CHANNELS];
let modulators = NeuroModulators::default();
let spikes = network.step(&stimuli, &modulators);
assert_eq!(network.global_step, 1);
assert!(spikes.len() <= 16);
}
#[test]
fn test_membrane_potentials() {
let network = SpikingNetwork::new();
let potentials = network.get_membrane_potentials();
assert_eq!(potentials.len(), 16);
for &p in &potentials {
assert_eq!(p, 0.0);
}
}
}