use super::{
EventPriority, MembraneDynamicsConfig, NeuromorphicEvent, NeuromorphicMetrics, PlasticityModel,
STDPConfig, Spike, SpikeTrain,
};
use scirs2_neural::activations_minimal::Activation;
use scirs2_neural::layers::Layer;
use scirs2_stats::distributions;
use crate::error::Result;
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, DataMut, Dimension};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::time::Instant;
/// Top-level configuration for [`SpikingOptimizer`] simulations.
///
/// NOTE(review): time quantities appear to be milliseconds (rate coding
/// converts `rate * time_step` with a `/ 1000` factor) — confirm.
#[derive(Debug, Clone)]
pub struct SpikingConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Simulation step; also the bin width used when rate-coding inputs.
    pub time_step: T,
    /// Length of the spike trains produced by rate coding.
    pub simulation_time: T,
    /// How analog inputs are turned into spike trains.
    pub encoding_method: SpikeEncodingMethod,
    /// How spike trains are turned back into analog values.
    pub decoding_method: SpikeDecodingMethod,
    /// NOTE(review): not read by the code in this file.
    pub spike_learning_rate: T,
    /// Window length used to normalise spike counts in rate decoding.
    pub temporal_window: T,
    /// NOTE(review): not read by the code in this file.
    pub lateral_inhibition: bool,
    /// Homeostatic scaling / intrinsic plasticity settings.
    pub homeostatic_config: HomeostaticConfig<T>,
    /// NOTE(review): not read by the code in this file.
    pub noise_config: SpikeNoiseConfig<T>,
}
/// Strategy for converting an analog value into a spike train.
///
/// Variants without a dedicated encoder are currently handled with rate
/// coding by `SpikingOptimizer::encode_input`.
#[derive(Debug, Clone, Copy)]
pub enum SpikeEncodingMethod {
    /// Value magnitude sets a stochastic firing rate.
    RateCoding,
    /// Value sets the latency of a single spike (larger = earlier).
    TemporalCoding,
    /// Currently delegates to rate coding.
    PopulationVectorCoding,
    /// Rate coding, but only for values whose magnitude exceeds a threshold.
    SparseCoding,
    /// No dedicated encoder yet (rate-coding fallback).
    PhaseCoding,
    /// No dedicated encoder yet (rate-coding fallback).
    BurstCoding,
    /// No dedicated encoder yet (rate-coding fallback).
    RankOrderCoding,
}
/// Strategy for converting a spike train back into an analog value.
///
/// Variants without a dedicated decoder are currently handled with rate
/// decoding by `SpikingOptimizer::decode_output`.
#[derive(Debug, Clone, Copy)]
pub enum SpikeDecodingMethod {
    /// Spike count normalised by the configured temporal window.
    RateDecoding,
    /// Latency of the first spike (earlier = larger value).
    TemporalDecoding,
    /// No dedicated decoder yet (rate-decoding fallback).
    PopulationVectorDecoding,
    /// Exponentially time-weighted spike count.
    WeightedSpikeCount,
    /// No dedicated decoder yet (rate-decoding fallback).
    MovingAverageFilter,
    /// No dedicated decoder yet (rate-decoding fallback).
    ExponentialDecayFilter,
}
/// Settings for homeostatic regulation of neuron activity.
#[derive(Debug, Clone)]
pub struct HomeostaticConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Enables per-step synaptic scaling in `SpikingOptimizer::simulate_step`.
    pub enable_homeostatic_scaling: bool,
    /// Firing rate the scaling drives each neuron towards.
    pub target_firing_rate: T,
    /// Divides the rate error; larger values mean slower scaling.
    pub scaling_time_constant: T,
    /// NOTE(review): not read by the code in this file.
    pub scaling_factor: T,
    /// NOTE(review): not read by the code in this file.
    pub enable_intrinsic_plasticity: bool,
    /// NOTE(review): not read by the code in this file.
    pub threshold_adaptation_rate: T,
}
/// Noise-injection settings for spike simulations.
///
/// NOTE(review): none of these fields are read by the code in this file;
/// the descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone)]
pub struct SpikeNoiseConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Presumably a spontaneous background spike rate.
    pub background_rate: T,
    /// Presumably the standard deviation of spike-time jitter.
    pub jitter_std: T,
    /// Presumably toggles Poisson-distributed noise.
    pub poisson_noise: bool,
    /// Presumably the magnitude of injected noise.
    pub noise_amplitude: T,
    /// Presumably the correlation between noise sources.
    pub correlation_noise: T,
}
impl<T> Default for SpikingConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Step `0.1`, simulation length `1000`, rate coding/decoding, learning
    /// rate `0.01`, temporal window `20`, no lateral inhibition, and default
    /// homeostatic and noise settings.
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero when the
        // value is not representable.
        let lit = |x: f64| -> T { T::from(x).unwrap_or_else(|| T::zero()) };
        Self {
            time_step: lit(0.1),
            simulation_time: lit(1000.0),
            encoding_method: SpikeEncodingMethod::RateCoding,
            decoding_method: SpikeDecodingMethod::RateDecoding,
            spike_learning_rate: lit(0.01),
            temporal_window: lit(20.0),
            lateral_inhibition: false,
            homeostatic_config: HomeostaticConfig::default(),
            noise_config: SpikeNoiseConfig::default(),
        }
    }
}
impl<T> Default for HomeostaticConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Scaling and intrinsic plasticity disabled; target rate `10`, time
    /// constant `1000`, scaling factor `0.01`, threshold adaptation `0.001`.
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero when the
        // value is not representable.
        let lit = |x: f64| -> T { T::from(x).unwrap_or_else(|| T::zero()) };
        Self {
            enable_homeostatic_scaling: false,
            target_firing_rate: lit(10.0),
            scaling_time_constant: lit(1000.0),
            scaling_factor: lit(0.01),
            enable_intrinsic_plasticity: false,
            threshold_adaptation_rate: lit(0.001),
        }
    }
}
impl<T> Default for SpikeNoiseConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Background rate `1`, jitter `0.5`, Poisson noise off, amplitude `0.1`,
    /// no noise correlation.
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero when the
        // value is not representable.
        let lit = |x: f64| -> T { T::from(x).unwrap_or_else(|| T::zero()) };
        Self {
            background_rate: lit(1.0),
            jitter_std: lit(0.5),
            poisson_noise: false,
            noise_amplitude: lit(0.1),
            correlation_noise: T::zero(),
        }
    }
}
/// Discrete-time simulator of a fully connected population of spiking
/// neurons with leaky membranes, refractory periods, synaptic plasticity and
/// optional homeostatic scaling.
pub struct SpikingOptimizer<T>
where
    T: Float + Debug + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    /// Simulation, coding and homeostasis settings.
    config: SpikingConfig<T>,
    /// STDP time constants, learning rates and weight bounds.
    stdp_config: STDPConfig<T>,
    /// Membrane potentials, time constant and refractory period.
    membrane_config: MembraneDynamicsConfig<T>,
    /// Current simulated time.
    current_time: T,
    /// Spikes emitted so far, keyed by neuron index.
    spike_trains: HashMap<usize, SpikeTrain<T>>,
    /// Membrane potential of each neuron.
    membrane_potentials: Array1<T>,
    /// Weight matrix indexed `[pre, post]`.
    synaptic_weights: Array2<T>,
    /// Time of each neuron's most recent spike (sentinel if none yet).
    last_spike_times: Array1<T>,
    /// Time until which each neuron ignores membrane updates.
    refractory_until: Array1<T>,
    /// Per-neuron homeostatic scale.
    homeostatic_scales: Array1<T>,
    /// NOTE(review): only cleared in this file, never filled.
    spike_buffer: VecDeque<Spike<T>>,
    /// Running statistics (e.g. total spike count).
    metrics: NeuromorphicMetrics<T>,
    /// Learning rule applied after each step.
    plasticity_model: PlasticityModel,
}
impl<
T: Float
+ Debug
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ 'static
+ std::iter::Sum,
> SpikingOptimizer<T>
{
pub fn new(
config: SpikingConfig<T>,
stdp_config: STDPConfig<T>,
membrane_config: MembraneDynamicsConfig<T>,
num_neurons: usize,
) -> Self {
let resting_potential = membrane_config.resting_potential;
Self {
config,
stdp_config,
membrane_config,
current_time: T::zero(),
spike_trains: HashMap::new(),
membrane_potentials: Array1::from_elem(num_neurons, resting_potential),
synaptic_weights: Array2::ones((num_neurons, num_neurons))
* T::from(0.1).unwrap_or_else(|| T::zero()),
last_spike_times: Array1::from_elem(
num_neurons,
T::from(-1000.0).unwrap_or_else(|| T::zero()),
),
refractory_until: Array1::zeros(num_neurons),
homeostatic_scales: Array1::ones(num_neurons),
spike_buffer: VecDeque::new(),
metrics: NeuromorphicMetrics::default(),
plasticity_model: PlasticityModel::STDP,
}
}
pub fn encode_input(&self, input: &Array1<T>) -> Result<Vec<SpikeTrain<T>>> {
let mut spike_trains = Vec::new();
for (neuron_id, &value) in input.iter().enumerate() {
let spike_train = match self.config.encoding_method {
SpikeEncodingMethod::RateCoding => self.rate_encode(neuron_id, value)?,
SpikeEncodingMethod::TemporalCoding => self.temporal_encode(neuron_id, value)?,
SpikeEncodingMethod::PopulationVectorCoding => {
self.population_vector_encode(neuron_id, value)?
}
SpikeEncodingMethod::SparseCoding => self.sparse_encode(neuron_id, value)?,
_ => {
self.rate_encode(neuron_id, value)?
}
};
spike_trains.push(spike_train);
}
Ok(spike_trains)
}
fn rate_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
let max_rate = T::from(100.0).unwrap_or_else(|| T::zero()); let firing_rate = value.abs() * max_rate;
let mut spike_times = Vec::new();
let dt = self.config.time_step;
let total_time = self.config.simulation_time;
let mut time = T::zero();
while time < total_time {
let spike_prob = firing_rate * dt / T::from(1000.0).unwrap_or_else(|| T::zero());
if thread_rng().random::<f64>() < spike_prob.to_f64().unwrap_or(0.0) {
spike_times.push(time);
}
time = time + dt;
}
Ok(SpikeTrain::new(neuron_id, spike_times))
}
fn temporal_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
let max_delay = T::from(20.0).unwrap_or_else(|| T::zero()); let spike_time = if value > T::zero() {
max_delay * (T::one() - value.min(T::one()))
} else {
max_delay };
let spike_times = if spike_time < max_delay {
vec![spike_time]
} else {
Vec::new()
};
Ok(SpikeTrain::new(neuron_id, spike_times))
}
fn population_vector_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
self.rate_encode(neuron_id, value)
}
fn sparse_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
let threshold = T::from(0.5).unwrap_or_else(|| T::zero());
if value.abs() > threshold {
self.rate_encode(neuron_id, value)
} else {
Ok(SpikeTrain::new(neuron_id, Vec::new()))
}
}
pub fn decode_output(&self, spike_trains: &[SpikeTrain<T>]) -> Result<Array1<T>> {
let mut output = Array1::zeros(spike_trains.len());
for (i, spike_train) in spike_trains.iter().enumerate() {
output[i] = match self.config.decoding_method {
SpikeDecodingMethod::RateDecoding => self.rate_decode(spike_train)?,
SpikeDecodingMethod::TemporalDecoding => self.temporal_decode(spike_train)?,
SpikeDecodingMethod::WeightedSpikeCount => {
self.weighted_spike_count_decode(spike_train)?
}
_ => {
self.rate_decode(spike_train)?
}
};
}
Ok(output)
}
fn rate_decode(&self, spike_train: &SpikeTrain<T>) -> Result<T> {
let window_duration = self.config.temporal_window;
let spike_count = T::from(spike_train.spike_count).unwrap_or_else(|| T::zero());
let rate = spike_count / (window_duration / T::from(1000.0).unwrap_or_else(|| T::zero()));
Ok(rate / T::from(100.0).unwrap_or_else(|| T::zero())) }
fn temporal_decode(&self, spike_train: &SpikeTrain<T>) -> Result<T> {
if spike_train.spike_times.is_empty() {
Ok(T::zero())
} else {
let first_spike = spike_train.spike_times[0];
let max_delay = T::from(20.0).unwrap_or_else(|| T::zero());
Ok(T::one() - (first_spike / max_delay).min(T::one()))
}
}
fn weighted_spike_count_decode(&self, spike_train: &SpikeTrain<T>) -> Result<T> {
if spike_train.spike_times.is_empty() {
return Ok(T::zero());
}
let mut weighted_sum = T::zero();
let current_time = self.current_time;
for &spike_time in &spike_train.spike_times {
let time_diff = current_time - spike_time;
let weight = (-time_diff / T::from(10.0).unwrap_or_else(|| T::zero())).exp(); weighted_sum = weighted_sum + weight;
}
Ok(weighted_sum)
}
pub fn simulate_step(&mut self, input_spikes: &[Spike<T>]) -> Result<Vec<Spike<T>>> {
let mut output_spikes = Vec::new();
let dt = self.config.time_step;
for spike in input_spikes {
self.process_input_spike(spike)?;
}
for neuron_id in 0..self.membrane_potentials.len() {
if self.current_time >= self.refractory_until[neuron_id] {
self.update_membrane_potential(neuron_id, dt)?;
if self.membrane_potentials[neuron_id] >= self.membrane_config.threshold_potential {
let spike = self.generate_spike(neuron_id)?;
output_spikes.push(spike);
}
}
}
self.update_plasticity(&output_spikes)?;
if self.config.homeostatic_config.enable_homeostatic_scaling {
self.update_homeostatic_scaling()?;
}
self.current_time = self.current_time + dt;
Ok(output_spikes)
}
fn process_input_spike(&mut self, spike: &Spike<T>) -> Result<()> {
let target_neuron = spike.postsynaptic_id.unwrap_or(spike.neuron_id);
if target_neuron < self.membrane_potentials.len() {
let synaptic_current = spike.weight * spike.amplitude;
self.membrane_potentials[target_neuron] =
self.membrane_potentials[target_neuron] + synaptic_current;
}
Ok(())
}
fn update_membrane_potential(&mut self, neuron_id: usize, dt: T) -> Result<()> {
let v = self.membrane_potentials[neuron_id];
let v_rest = self.membrane_config.resting_potential;
let tau = self.membrane_config.tau_membrane;
let dv_dt = (v_rest - v) / tau;
let new_v = v + dv_dt * dt;
self.membrane_potentials[neuron_id] = new_v;
Ok(())
}
fn generate_spike(&mut self, neuron_id: usize) -> Result<Spike<T>> {
self.membrane_potentials[neuron_id] = self.membrane_config.reset_potential;
self.refractory_until[neuron_id] =
self.current_time + self.membrane_config.refractory_period;
self.last_spike_times[neuron_id] = self.current_time;
let spike = Spike {
neuron_id,
time: self.current_time,
amplitude: T::from(1.0).unwrap_or_else(|| T::zero()),
width: Some(T::from(1.0).unwrap_or_else(|| T::zero())),
weight: T::one(),
presynaptic_id: None,
postsynaptic_id: None,
};
if let Some(spike_train) = self.spike_trains.get_mut(&neuron_id) {
spike_train.spike_times.push(self.current_time);
spike_train.spike_count += 1;
} else {
let spike_train = SpikeTrain::new(neuron_id, vec![self.current_time]);
self.spike_trains.insert(neuron_id, spike_train);
}
self.metrics.total_spikes += 1;
Ok(spike)
}
fn update_plasticity(&mut self, output_spikes: &[Spike<T>]) -> Result<()> {
match self.plasticity_model {
PlasticityModel::STDP => {
self.update_stdp(output_spikes)?;
}
PlasticityModel::Hebbian => {
self.update_hebbian(output_spikes)?;
}
_ => {
self.update_stdp(output_spikes)?;
}
}
Ok(())
}
fn update_stdp(&mut self, output_spikes: &[Spike<T>]) -> Result<()> {
for spike in output_spikes {
let post_id = spike.neuron_id;
let post_time = spike.time;
for pre_id in 0..self.last_spike_times.len() {
if pre_id != post_id {
let pre_time = self.last_spike_times[pre_id];
if pre_time > T::from(-1000.0).unwrap_or_else(|| T::zero()) {
let dt = post_time - pre_time;
let weight_change = self.compute_stdp_update(dt);
self.synaptic_weights[[pre_id, post_id]] =
(self.synaptic_weights[[pre_id, post_id]] + weight_change)
.max(self.stdp_config.weight_min)
.min(self.stdp_config.weight_max);
}
}
}
}
Ok(())
}
fn compute_stdp_update(&self, dt: T) -> T {
if dt > T::zero() {
let exp_arg = -dt / self.stdp_config.tau_pot;
self.stdp_config.learning_rate_pot * exp_arg.exp()
} else {
let exp_arg = dt / self.stdp_config.tau_dep;
-self.stdp_config.learning_rate_dep * exp_arg.exp()
}
}
fn update_hebbian(&mut self, output_spikes: &[Spike<T>]) -> Result<()> {
for spike in output_spikes {
let post_id = spike.neuron_id;
for pre_id in 0..self.membrane_potentials.len() {
if pre_id != post_id {
let pre_activity =
self.membrane_potentials[pre_id] / self.membrane_config.threshold_potential;
let weight_change = self.stdp_config.learning_rate_pot * pre_activity;
self.synaptic_weights[[pre_id, post_id]] =
(self.synaptic_weights[[pre_id, post_id]] + weight_change)
.max(self.stdp_config.weight_min)
.min(self.stdp_config.weight_max);
}
}
}
Ok(())
}
fn update_homeostatic_scaling(&mut self) -> Result<()> {
let target_rate = self.config.homeostatic_config.target_firing_rate;
let time_constant = self.config.homeostatic_config.scaling_time_constant;
let dt = self.config.time_step;
for neuron_id in 0..self.homeostatic_scales.len() {
if let Some(spike_train) = self.spike_trains.get(&neuron_id) {
let current_rate = spike_train.firing_rate;
let rate_error = target_rate - current_rate;
let scale_change = rate_error * dt / time_constant;
self.homeostatic_scales[neuron_id] =
self.homeostatic_scales[neuron_id] + scale_change;
for pre_id in 0..self.synaptic_weights.nrows() {
self.synaptic_weights[[pre_id, neuron_id]] = self.synaptic_weights
[[pre_id, neuron_id]]
* self.homeostatic_scales[neuron_id];
}
}
}
Ok(())
}
pub fn get_metrics(&self) -> &NeuromorphicMetrics<T> {
&self.metrics
}
pub fn reset(&mut self) {
self.current_time = T::zero();
self.membrane_potentials
.fill(self.membrane_config.resting_potential);
self.last_spike_times
.fill(T::from(-1000.0).unwrap_or_else(|| T::zero()));
self.refractory_until.fill(T::zero());
self.spike_trains.clear();
self.spike_buffer.clear();
self.metrics = NeuromorphicMetrics::default();
}
}
/// Learns recurring spike-timing templates from spike trains and recognises
/// them in new trains.
pub struct SpikeTrainOptimizer<T>
where
    T: Float + Debug + scirs2_core::ndarray::ScalarOperand + Send + Sync,
{
    /// NOTE(review): stored but not read by the methods in this file.
    config: SpikingConfig<T>,
    /// Learned templates, indexed by position.
    pattern_templates: Vec<SpikePattern<T>>,
    /// Similarity above which a window counts as matching a template.
    matching_threshold: T,
    /// Blend factor used when updating a matched template.
    pattern_learning_rate: T,
    /// Kernel used to compare spike times.
    temporal_kernel: TemporalKernel<T>,
}
/// A spike-timing template extracted from a fixed-length window.
#[derive(Debug, Clone)]
pub struct SpikePattern<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Identifier assigned when the template was created.
    pub pattern_id: usize,
    /// Spike times measured from the start of the window.
    pub relative_spike_times: Vec<T>,
    /// Length of the window the pattern was taken from.
    pub duration: T,
    /// Template weight (blended on updates).
    pub weight: T,
    /// How many windows have been merged into this template.
    pub observation_count: usize,
}
/// Kernel used to compare spike times when matching patterns.
#[derive(Debug, Clone)]
pub struct TemporalKernel<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Kernel shape. NOTE(review): `SpikeTrainOptimizer` currently always
    /// evaluates a Gaussian regardless of this value.
    pub kernel_type: TemporalKernelType,
    /// Kernel width (the Gaussian's standard deviation).
    pub width: T,
    /// NOTE(review): not read by the code in this file.
    pub parameters: Vec<T>,
}
/// Shape of a [`TemporalKernel`].
#[derive(Debug, Clone, Copy)]
pub enum TemporalKernelType {
    /// Bell-shaped kernel; the shape used by pattern matching in this file.
    Gaussian,
    /// Exponential decay kernel.
    Exponential,
    /// Alpha-function kernel.
    Alpha,
    /// Box (rectangular) kernel.
    Rectangular,
}
impl<T: Float + Debug + Send + Sync + scirs2_core::ndarray::ScalarOperand + std::fmt::Debug>
    SpikeTrainOptimizer<T>
{
    /// Creates an optimizer with an empty template library.
    ///
    /// Defaults: matching threshold `0.8`, template learning rate `0.1`, and a
    /// Gaussian temporal kernel of width `5.0`.
    pub fn new(config: SpikingConfig<T>) -> Self {
        Self {
            config,
            pattern_templates: Vec::new(),
            matching_threshold: T::from(0.8).unwrap_or_else(|| T::zero()),
            pattern_learning_rate: T::from(0.1).unwrap_or_else(|| T::zero()),
            temporal_kernel: TemporalKernel {
                kernel_type: TemporalKernelType::Gaussian,
                width: T::from(5.0).unwrap_or_else(|| T::zero()),
                parameters: vec![T::one()],
            },
        }
    }

    /// Learns templates from every train in `spike_trains`.
    ///
    /// # Errors
    /// Propagates errors from template updates.
    pub fn learn_patterns(&mut self, spike_trains: &[SpikeTrain<T>]) -> Result<()> {
        for spike_train in spike_trains {
            self.extract_and_learn_patterns(spike_train)?;
        }
        Ok(())
    }

    /// Returns the spikes in `[window_start, window_start + window_size)`,
    /// re-expressed relative to `window_start`.
    fn window_spikes(spike_train: &SpikeTrain<T>, window_start: T, window_size: T) -> Vec<T> {
        let window_end = window_start + window_size;
        spike_train
            .spike_times
            .iter()
            .copied()
            .filter(|&t| t >= window_start && t < window_end)
            .map(|t| t - window_start)
            .collect()
    }

    /// Slides a 50-unit window in 10-unit steps over `[0, duration)`.
    ///
    /// Each non-empty window is either merged into its best-matching template
    /// (similarity above `matching_threshold`) or stored as a new template.
    fn extract_and_learn_patterns(&mut self, spike_train: &SpikeTrain<T>) -> Result<()> {
        let window_size = T::from(50.0).unwrap_or_else(|| T::zero());
        let step_size = T::from(10.0).unwrap_or_else(|| T::zero());
        let mut window_start = T::zero();
        while window_start < spike_train.duration {
            let window_spikes = Self::window_spikes(spike_train, window_start, window_size);
            if !window_spikes.is_empty() {
                let pattern = SpikePattern {
                    pattern_id: self.pattern_templates.len(),
                    relative_spike_times: window_spikes,
                    duration: window_size,
                    weight: T::one(),
                    observation_count: 1,
                };
                match self.find_similar_pattern(&pattern) {
                    Some(similar_pattern_id) => self.update_pattern(similar_pattern_id, &pattern)?,
                    None => self.pattern_templates.push(pattern),
                }
            }
            window_start = window_start + step_size;
        }
        Ok(())
    }

    /// Index and similarity of the template most similar to `pattern`.
    ///
    /// Returns `(0, 0)` when no template scores above zero, including when
    /// there are no templates. Ties keep the earliest template.
    fn best_template_match(&self, pattern: &SpikePattern<T>) -> (usize, T) {
        let mut best = (0, T::zero());
        for (i, template) in self.pattern_templates.iter().enumerate() {
            let similarity = self.compute_pattern_similarity(pattern, template);
            if similarity > best.1 {
                best = (i, similarity);
            }
        }
        best
    }

    /// Index of the most similar template, if its similarity exceeds
    /// `matching_threshold`.
    ///
    /// The best match is chosen rather than the first one over the threshold,
    /// consistent with `recognize_patterns`.
    fn find_similar_pattern(&self, new_pattern: &SpikePattern<T>) -> Option<usize> {
        let (best_id, best_similarity) = self.best_template_match(new_pattern);
        if best_similarity > self.matching_threshold {
            Some(best_id)
        } else {
            None
        }
    }

    /// Similarity in `[0, 1]`: the mean of spike-count similarity and
    /// shift-tolerant temporal similarity.
    ///
    /// Count similarity is `1 - |n1 - n2| / max(n1, n2)`. When one pattern is
    /// empty, only count similarity is used; two empty patterns score `1`.
    fn compute_pattern_similarity(
        &self,
        pattern1: &SpikePattern<T>,
        pattern2: &SpikePattern<T>,
    ) -> T {
        let n1 = pattern1.relative_spike_times.len();
        let n2 = pattern2.relative_spike_times.len();
        let max_spikes = n1.max(n2);
        if max_spikes == 0 {
            return T::one();
        }
        let count_diff = n1.abs_diff(n2) as f64;
        let count_similarity =
            T::one() - T::from(count_diff / max_spikes as f64).unwrap_or_else(|| T::zero());
        if n1 > 0 && n2 > 0 {
            let temporal_similarity = self.compute_temporal_similarity(
                &pattern1.relative_spike_times,
                &pattern2.relative_spike_times,
            );
            (count_similarity + temporal_similarity) / T::from(2.0).unwrap_or_else(|| T::zero())
        } else {
            count_similarity
        }
    }

    /// Maximum normalised kernel correlation over shifts of `spikes2` in
    /// `[-10, 10]` (step 1).
    ///
    /// The cross kernel sum is divided by `sqrt(K(s1, s1) * K(s2, s2))`
    /// (cosine normalisation), so identical patterns score exactly 1. With a
    /// positive-definite kernel the result stays in `[0, 1]`. Dividing by
    /// `n1 * n2` instead would give identical multi-spike patterns a score
    /// below 1, so they could never reach the matching threshold.
    fn compute_temporal_similarity(&self, spikes1: &[T], spikes2: &[T]) -> T {
        if spikes1.is_empty() || spikes2.is_empty() {
            return T::zero();
        }
        // Self-correlations are shift-invariant, so compute the norm once.
        let norm = (self.compute_spike_correlation(spikes1, spikes1, T::zero())
            * self.compute_spike_correlation(spikes2, spikes2, T::zero()))
        .sqrt();
        if norm.is_nan() || norm <= T::zero() {
            return T::zero();
        }
        let max_shift = T::from(10.0).unwrap_or_else(|| T::zero());
        let shift_step = T::one();
        let mut max_correlation = T::zero();
        let mut shift = -max_shift;
        while shift <= max_shift {
            let correlation = self.compute_spike_correlation(spikes1, spikes2, shift) / norm;
            max_correlation = max_correlation.max(correlation);
            shift = shift + shift_step;
        }
        // Guard against rounding pushing the ratio marginally above 1.
        max_correlation.min(T::one())
    }

    /// Unnormalised Gaussian kernel sum.
    ///
    /// Computes `sum_{i,j} exp(-(t1_i - (t2_j + shift))^2 / (2 * width^2))`
    /// over all spike pairs, where `width` is the temporal kernel width. The
    /// kernel type is not consulted: the kernel is always Gaussian.
    fn compute_spike_correlation(&self, spikes1: &[T], spikes2: &[T], shift: T) -> T {
        let kernel_width = self.temporal_kernel.width;
        let denom = T::from(2.0).unwrap_or_else(|| T::zero()) * kernel_width * kernel_width;
        let mut total = T::zero();
        for &t1 in spikes1 {
            for &t2 in spikes2 {
                let dt = t1 - (t2 + shift);
                total = total + (-(dt * dt) / denom).exp();
            }
        }
        total
    }

    /// Blends `new_pattern` into template `pattern_id` with rate `alpha`.
    ///
    /// - Spike times are blended only when both have the same spike count.
    /// - The observation count always increments.
    /// - The weight is always blended.
    /// - Unknown ids are ignored.
    fn update_pattern(&mut self, pattern_id: usize, new_pattern: &SpikePattern<T>) -> Result<()> {
        if let Some(existing_pattern) = self.pattern_templates.get_mut(pattern_id) {
            let alpha = self.pattern_learning_rate;
            if existing_pattern.relative_spike_times.len() == new_pattern.relative_spike_times.len()
            {
                for (existing_time, &new_time) in existing_pattern
                    .relative_spike_times
                    .iter_mut()
                    .zip(new_pattern.relative_spike_times.iter())
                {
                    *existing_time = *existing_time * (T::one() - alpha) + new_time * alpha;
                }
            }
            existing_pattern.observation_count += 1;
            existing_pattern.weight =
                existing_pattern.weight * (T::one() - alpha) + new_pattern.weight * alpha;
        }
        Ok(())
    }

    /// Scans `spike_train` with a 50-unit window in 5-unit steps.
    ///
    /// Returns `(template_index, window_start, similarity)` for every
    /// non-empty window whose best template similarity exceeds
    /// `matching_threshold`.
    pub fn recognize_patterns(&self, spike_train: &SpikeTrain<T>) -> Result<Vec<(usize, T, T)>> {
        let mut recognized_patterns = Vec::new();
        let window_size = T::from(50.0).unwrap_or_else(|| T::zero());
        let step_size = T::from(5.0).unwrap_or_else(|| T::zero());
        let mut window_start = T::zero();
        while window_start < spike_train.duration {
            let window_spikes = Self::window_spikes(spike_train, window_start, window_size);
            if !window_spikes.is_empty() {
                let test_pattern = SpikePattern {
                    pattern_id: 0,
                    relative_spike_times: window_spikes,
                    duration: window_size,
                    weight: T::one(),
                    observation_count: 1,
                };
                let (best_id, best_similarity) = self.best_template_match(&test_pattern);
                if best_similarity > self.matching_threshold {
                    recognized_patterns.push((best_id, window_start, best_similarity));
                }
            }
            window_start = window_start + step_size;
        }
        Ok(recognized_patterns)
    }

    /// Returns the learned templates.
    pub fn get_patterns(&self) -> &[SpikePattern<T>] {
        &self.pattern_templates
    }
}