use super::{NeuromorphicConfig, SpikeEvent};
use scirs2_core::error::CoreResult as Result;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::random::{Rng, RngExt};
use std::collections::VecDeque;
/// A recurrent network of [`SpikingNeuron`]s connected by plastic [`Synapse`]s.
///
/// Parameter vectors are written into the network as neuron input currents
/// (`encode_parameters`) and read back from `population_activity` (`decode_parameters`).
#[derive(Debug, Clone)]
pub struct SpikingNeuralNetwork {
/// Network configuration; `num_neurons`, `dt` and `learning_rate` drive construction and
/// simulation.
pub config: NeuromorphicConfig,
/// One neuron per `config.num_neurons`, indexed by neuron id.
pub neurons: Vec<SpikingNeuron>,
/// Outgoing synapses grouped by presynaptic neuron: `synapses[i]` holds the synapses built
/// with `source == i`.
pub synapses: Vec<Vec<Synapse>>,
/// Current simulation time; advanced by `config.dt` at the end of every simulation step.
pub current_time: f64,
/// Recent spike events in chronological order; events older than about 0.1 time units are
/// pruned after each simulation step.
pub spike_history: VecDeque<SpikeEvent>,
/// Per-neuron activity trace: set to 1.0 on the step a neuron spikes and multiplied by
/// 0.95 on every step it does not.
pub population_activity: Array1<f64>,
}
/// State of a single leaky integrate-and-fire neuron with spike-triggered adaptation.
#[derive(Debug, Clone, PartialEq)]
pub struct SpikingNeuron {
    /// Current membrane potential; reset to `resting_potential` after a spike.
    pub membrane_potential: f64,
    /// Potential the membrane leaks toward and is reset to after firing.
    pub resting_potential: f64,
    /// The neuron fires when the membrane potential reaches or exceeds this value.
    pub threshold: f64,
    /// Membrane time constant of the leak.
    pub tau_membrane: f64,
    /// Time after a spike during which the neuron ignores all input.
    pub refractory_period: f64,
    /// Simulation time of the most recent spike; `None` if the neuron has not fired since
    /// creation or the last reset.
    pub last_spike_time: Option<f64>,
    /// Persistent bias current added to the external input on every update (this is where
    /// parameter encoding writes its values).
    pub input_current: f64,
    /// Spike-triggered adaptation current, subtracted from the input on every update.
    pub adaptation_current: f64,
    /// Amplitude of the uniform noise added to the input on every update; `0.0` disables it.
    pub noise_amplitude: f64,
}
/// A directed synapse with short-term plasticity (STP) factors and STDP traces.
#[derive(Debug, Clone, PartialEq)]
pub struct Synapse {
    /// Index of the presynaptic neuron.
    pub source: usize,
    /// Index of the postsynaptic neuron.
    pub target: usize,
    /// Synaptic efficacy; STDP updates keep it within `[-1.0, 1.0]`.
    pub weight: f64,
    /// Transmission delay between a presynaptic spike and its effect on the target.
    pub delay: f64,
    /// STP facilitation factor; starts at 1.0, grows on presynaptic spikes (capped at 3.0)
    /// and relaxes back toward 1.0.
    pub facilitation: f64,
    /// STP depression factor; starts at 1.0, shrinks on presynaptic spikes and relaxes back
    /// toward 1.0.
    pub depression: f64,
    /// Exponentially decaying trace of presynaptic spikes, used by STDP.
    pub pre_trace: f64,
    /// Exponentially decaying trace of postsynaptic spikes, used by STDP.
    pub post_trace: f64,
}
impl SpikingNeuron {
    /// Creates a neuron at rest (membrane potential 0.0) that takes its threshold,
    /// refractory period and noise level from `config`.
    ///
    /// The membrane time constant is fixed at 0.020. NOTE(review): that is 20 ms if
    /// simulation time is in seconds; confirm against the unit of `NeuromorphicConfig::dt`.
    pub fn new(config: &NeuromorphicConfig) -> Self {
        Self {
            membrane_potential: 0.0,
            resting_potential: 0.0,
            threshold: config.spike_threshold,
            tau_membrane: 0.020,
            refractory_period: config.refractory_period,
            last_spike_time: None,
            input_current: 0.0,
            adaptation_current: 0.0,
            noise_amplitude: config.noise_level,
        }
    }

    /// Advances the membrane dynamics by one forward-Euler step of size `dt`.
    ///
    /// The driving current is `external_current + input_current - adaptation_current`,
    /// plus uniform noise of amplitude `noise_amplitude` when that is positive. While the
    /// neuron is refractory (less than `refractory_period` has elapsed since
    /// `last_spike_time`), the step is skipped and the state is left untouched.
    ///
    /// Returns `Some(current_time)` if the neuron fired during this step. The membrane is
    /// then reset and `last_spike_time` is set to `current_time`. Returns `None` otherwise.
    pub fn update(&mut self, dt: f64, external_current: f64, current_time: f64) -> Option<f64> {
        if let Some(last_spike) = self.last_spike_time {
            if current_time - last_spike < self.refractory_period {
                return None;
            }
        }

        // Uniform noise in [-noise_amplitude, noise_amplitude), assuming `random::<f64>()`
        // samples [0, 1).
        let noise = if self.noise_amplitude > 0.0 {
            let mut rng = scirs2_core::random::rng();
            (rng.random::<f64>() - 0.5) * 2.0 * self.noise_amplitude
        } else {
            0.0
        };

        let total_current = external_current + self.input_current - self.adaptation_current + noise;
        let dv_dt =
            (-(self.membrane_potential - self.resting_potential) + total_current) / self.tau_membrane;
        self.membrane_potential += dv_dt * dt;

        if self.membrane_potential >= self.threshold {
            // Record the real spike time. Storing a constant here would make the
            // refractory check above measure from t = 0.
            self.fire_spike(current_time);
            Some(current_time)
        } else {
            None
        }
    }

    /// Resets the membrane, records `spike_time` as the latest spike and increments the
    /// spike-triggered adaptation current.
    fn fire_spike(&mut self, spike_time: f64) {
        /// Adaptation current added by every spike (spike-frequency adaptation).
        const ADAPTATION_INCREMENT: f64 = 0.1;

        self.membrane_potential = self.resting_potential;
        self.last_spike_time = Some(spike_time);
        self.adaptation_current += ADAPTATION_INCREMENT;
    }

    /// Exponentially decays the adaptation current over `dt` with a fixed time constant of
    /// 0.1.
    pub fn decay_adaptation(&mut self, dt: f64) {
        const TAU_ADAPTATION: f64 = 0.1;
        self.adaptation_current *= (-dt / TAU_ADAPTATION).exp();
    }
}
impl Synapse {
pub fn new(source: usize, target: usize, weight: f64, delay: f64) -> Self {
Self {
source,
target,
weight,
delay,
facilitation: 1.0,
depression: 1.0,
pre_trace: 0.0,
post_trace: 0.0,
}
}
pub fn compute_current(&self, pre_spike: bool) -> f64 {
if pre_spike {
self.weight * self.facilitation * self.depression
} else {
0.0
}
}
pub fn update_stp(&mut self, dt: f64, pre_spike: bool) {
let tau_facilitation = 0.050; let tau_depression = 0.100;
self.facilitation += (1.0 - self.facilitation) * dt / tau_facilitation;
self.depression += (1.0 - self.depression) * dt / tau_depression;
if pre_spike {
self.facilitation = (self.facilitation * 1.2).min(3.0); self.depression *= 0.8; }
}
pub fn update_stdp_traces(&mut self, dt: f64, pre_spike: bool, post_spike: bool) {
let tau_stdp = 0.020;
self.pre_trace *= (-dt / tau_stdp).exp();
self.post_trace *= (-dt / tau_stdp).exp();
if pre_spike {
self.pre_trace += 1.0;
}
if post_spike {
self.post_trace += 1.0;
}
}
pub fn apply_stdp(&mut self, learning_rate: f64, pre_spike: bool, post_spike: bool) {
let mut weight_change = 0.0;
if pre_spike && self.post_trace > 0.0 {
weight_change += learning_rate * self.post_trace;
}
if post_spike && self.pre_trace > 0.0 {
weight_change -= learning_rate * 0.5 * self.pre_trace;
}
self.weight += weight_change;
self.weight = self.weight.max(-1.0).min(1.0); }
}
impl SpikingNeuralNetwork {
pub fn new(config: NeuromorphicConfig, num_parameters: usize) -> Self {
let mut neurons = Vec::with_capacity(config.num_neurons);
for _ in 0..config.num_neurons {
neurons.push(SpikingNeuron::new(&config));
}
let mut synapses = vec![Vec::new(); config.num_neurons];
let connection_probability = 0.1; let mut rng = scirs2_core::random::rng();
for i in 0..config.num_neurons {
for j in 0..config.num_neurons {
if i != j && rng.random::<f64>() < connection_probability {
let weight = (rng.random::<f64>() - 0.5) * 0.2;
let delay = rng.random::<f64>() * 0.005; synapses[i].push(Synapse::new(i, j, weight, delay));
}
}
}
let num_neurons = config.num_neurons;
Self {
config,
neurons,
synapses,
current_time: 0.0,
spike_history: VecDeque::with_capacity(10000),
population_activity: Array1::zeros(num_neurons),
}
}
pub fn encode_parameters(&mut self, parameters: &ArrayView1<f64>) {
let neurons_per_param = self.config.num_neurons / parameters.len();
for (param_idx, ¶m_val) in parameters.iter().enumerate() {
let start_idx = param_idx * neurons_per_param;
let end_idx = ((param_idx + 1) * neurons_per_param).min(self.config.num_neurons);
let input_current = (param_val + 1.0) * 5.0;
for neuron_idx in start_idx..end_idx {
self.neurons[neuron_idx].input_current = input_current;
}
}
}
pub fn decode_parameters(&self, num_parameters: usize) -> Array1<f64> {
let mut decoded = Array1::zeros(num_parameters);
let neurons_per_param = self.config.num_neurons / num_parameters;
for param_idx in 0..num_parameters {
let start_idx = param_idx * neurons_per_param;
let end_idx = ((param_idx + 1) * neurons_per_param).min(self.config.num_neurons);
let mut activity_sum = 0.0;
for neuron_idx in start_idx..end_idx {
activity_sum += self.population_activity[neuron_idx];
}
if end_idx > start_idx {
decoded[param_idx] = (activity_sum / (end_idx - start_idx) as f64) - 1.0;
}
}
decoded
}
pub fn simulate_step(&mut self, objective_feedback: f64) -> Result<Vec<usize>> {
let mut spiked_neurons = Vec::new();
let inputs: Vec<(f64, f64)> = (0..self.neurons.len())
.map(|neuron_idx| {
let synaptic_input = self.compute_synaptic_input(neuron_idx);
let feedback_input = self.compute_feedback_input(neuron_idx, objective_feedback);
(synaptic_input, feedback_input)
})
.collect();
for (neuron_idx, neuron) in self.neurons.iter_mut().enumerate() {
let (synaptic_input, feedback_input) = inputs[neuron_idx];
let total_input = synaptic_input + feedback_input;
if let Some(_spike_time) = neuron.update(self.config.dt, total_input, self.current_time)
{
spiked_neurons.push(neuron_idx);
neuron.last_spike_time = Some(self.current_time);
self.spike_history.push_back(SpikeEvent {
time: self.current_time,
neuron_id: neuron_idx,
weight: 1.0,
});
self.population_activity[neuron_idx] = 1.0;
} else {
self.population_activity[neuron_idx] *= 0.95;
}
neuron.decay_adaptation(self.config.dt);
}
self.update_synapses(&spiked_neurons)?;
self.cleanup_spike_history();
self.current_time += self.config.dt;
Ok(spiked_neurons)
}
fn compute_synaptic_input(&self, target_neuron: usize) -> f64 {
let mut total_input = 0.0;
for source_neuron in 0..self.config.num_neurons {
for synapse in &self.synapses[source_neuron] {
if synapse.target == target_neuron {
if let Some(last_spike) = self.neurons[source_neuron].last_spike_time {
let time_since_spike = self.current_time - last_spike;
if time_since_spike >= synapse.delay
&& time_since_spike < synapse.delay + self.config.dt
{
total_input += synapse.compute_current(true);
}
}
}
}
}
total_input
}
fn compute_feedback_input(&self, neuron_idx: usize, objective_feedback: f64) -> f64 {
let feedback_strength = 1.0;
let normalized_feedback = -objective_feedback;
let phase = neuron_idx as f64 / self.config.num_neurons as f64 * 2.0 * std::f64::consts::PI;
feedback_strength * normalized_feedback * (phase.sin() + 1.0) * 0.5
}
fn update_synapses(&mut self, spiked_neurons: &[usize]) -> Result<()> {
for source_neuron in 0..self.config.num_neurons {
let source_spiked = spiked_neurons.contains(&source_neuron);
for synapse in &mut self.synapses[source_neuron] {
let target_spiked = spiked_neurons.contains(&synapse.target);
synapse.update_stp(self.config.dt, source_spiked);
synapse.update_stdp_traces(self.config.dt, source_spiked, target_spiked);
synapse.apply_stdp(self.config.learning_rate, source_spiked, target_spiked);
}
}
Ok(())
}
fn cleanup_spike_history(&mut self) {
let cutoff_time = self.current_time - 0.1; while let Some(spike) = self.spike_history.front() {
if spike.time < cutoff_time {
self.spike_history.pop_front();
} else {
break;
}
}
}
pub fn get_firing_rates(&self, window_duration: f64) -> Array1<f64> {
let mut rates = Array1::zeros(self.config.num_neurons);
let start_time = self.current_time - window_duration;
for spike in &self.spike_history {
if spike.time >= start_time {
rates[spike.neuron_id] += 1.0;
}
}
rates /= window_duration;
rates
}
pub fn reset(&mut self) {
self.current_time = 0.0;
self.spike_history.clear();
self.population_activity.fill(0.0);
for neuron in &mut self.neurons {
neuron.membrane_potential = neuron.resting_potential;
neuron.last_spike_time = None;
neuron.input_current = 0.0;
neuron.adaptation_current = 0.0;
}
for synapse_group in &mut self.synapses {
for synapse in synapse_group {
synapse.facilitation = 1.0;
synapse.depression = 1.0;
synapse.pre_trace = 0.0;
synapse.post_trace = 0.0;
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spiking_neuron_creation() {
        let config = NeuromorphicConfig::default();
        let neuron = SpikingNeuron::new(&config);
        assert_eq!(neuron.membrane_potential, 0.0);
        assert_eq!(neuron.threshold, config.spike_threshold);
        assert!(neuron.last_spike_time.is_none());
    }

    #[test]
    fn test_neuron_spike() {
        let config = NeuromorphicConfig::default();
        let mut neuron = SpikingNeuron::new(&config);
        let spike_time = neuron.update(0.001, 50.0, 0.0);
        assert!(spike_time.is_some());
        assert_eq!(neuron.membrane_potential, neuron.resting_potential);
    }

    #[test]
    fn test_synapse_creation() {
        let synapse = Synapse::new(0, 1, 0.5, 0.002);
        assert_eq!(synapse.source, 0);
        assert_eq!(synapse.target, 1);
        assert_eq!(synapse.weight, 0.5);
        assert_eq!(synapse.delay, 0.002);
    }

    #[test]
    fn test_synapse_current() {
        let mut synapse = Synapse::new(0, 1, 0.5, 0.001);
        assert_eq!(synapse.compute_current(false), 0.0);
        let current = synapse.compute_current(true);
        assert!(current > 0.0);
        synapse.update_stp(0.001, true);
        let current_after_stp = synapse.compute_current(true);
        // A presynaptic spike changes the facilitation/depression factors.
        assert!(current_after_stp != current);
    }

    #[test]
    fn test_stdp_weight_is_clamped() {
        let mut synapse = Synapse::new(0, 1, 0.9, 0.001);
        synapse.post_trace = 1.0;
        synapse.apply_stdp(1.0, true, false);
        assert_eq!(synapse.weight, 1.0);
        synapse.pre_trace = 10.0;
        synapse.apply_stdp(1.0, false, true);
        assert_eq!(synapse.weight, -1.0);
    }

    #[test]
    fn test_spiking_network_creation() {
        let config = NeuromorphicConfig::default();
        // Read the size from the config instead of hard-coding the default value.
        let expected_neurons = config.num_neurons;
        let network = SpikingNeuralNetwork::new(config, 3);
        assert_eq!(network.neurons.len(), expected_neurons);
        assert_eq!(network.synapses.len(), expected_neurons);
        assert_eq!(network.current_time, 0.0);
    }

    #[test]
    fn test_parameter_encoding() {
        let config = NeuromorphicConfig::default();
        let mut network = SpikingNeuralNetwork::new(config, 2);
        let params = Array1::from(vec![0.5, -0.3]);
        network.encode_parameters(&params.view());
        assert!(network.neurons.iter().any(|n| n.input_current != 0.0));
    }

    #[test]
    fn test_network_simulation() {
        let config = NeuromorphicConfig {
            num_neurons: 10,
            ..Default::default()
        };
        let mut network = SpikingNeuralNetwork::new(config, 2);
        for _ in 0..10 {
            let _spiked = network.simulate_step(1.0).expect("Operation failed");
        }
        assert!(network.current_time > 0.0);
    }

    #[test]
    fn test_firing_rates() {
        let config = NeuromorphicConfig {
            num_neurons: 5,
            ..Default::default()
        };
        let mut network = SpikingNeuralNetwork::new(config, 1);
        for neuron in &mut network.neurons {
            neuron.input_current = 20.0;
        }
        for _ in 0..100 {
            network.simulate_step(0.0).expect("Operation failed");
        }
        let rates = network.get_firing_rates(0.1);
        // A strong constant drive must make at least one neuron fire within the window.
        assert!(rates.iter().any(|&r| r > 0.0));
    }

    #[test]
    fn test_network_reset() {
        let config = NeuromorphicConfig::default();
        let mut network = SpikingNeuralNetwork::new(config, 2);
        for _ in 0..10 {
            network.simulate_step(1.0).expect("Operation failed");
        }
        let time_before_reset = network.current_time;
        assert!(time_before_reset > 0.0);
        network.reset();
        assert_eq!(network.current_time, 0.0);
        assert!(network.spike_history.is_empty());
        assert!(network.population_activity.iter().all(|&x| x == 0.0));
    }
}