use scirs2_core::error::CoreResult as Result;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::random::{Rng, RngExt};
use statrs::statistics::Statistics;
use std::collections::VecDeque;
#[derive(Debug, Clone)]
pub struct AdvancedAdvancedSTDP {
pub pre_trace_fast: f64,
pub post_trace_fast: f64,
pub pre_trace_slow: f64,
pub post_trace_slow: f64,
pub pre_trace_triplet: f64,
pub post_trace_triplet: f64,
pub calcium_concentration: f64,
pub calcium_threshold_low: f64,
pub calcium_threshold_high: f64,
pub metaplasticity_factor: f64,
pub recent_activity: VecDeque<f64>,
pub sliding_threshold: f64,
pub target_firing_rate: f64,
pub current_firing_rate: f64,
pub scaling_factor: f64,
pub tau_plus_fast: f64,
pub tau_minus_fast: f64,
pub tau_plus_slow: f64,
pub tau_minus_slow: f64,
pub tau_calcium: f64,
pub tau_metaplasticity: f64,
pub eta_ltp: f64,
pub eta_ltd: f64,
pub eta_triplet: f64,
pub eta_homeostatic: f64,
pub theta_d: f64,
pub theta_p: f64,
pub spike_history_pre: VecDeque<f64>,
pub spike_history_post: VecDeque<f64>,
pub w_min: f64,
pub w_max: f64,
}
impl AdvancedAdvancedSTDP {
pub fn new(eta_ltp: f64, eta_ltd: f64, target_firing_rate: f64) -> Self {
Self {
pre_trace_fast: 0.0,
post_trace_fast: 0.0,
pre_trace_slow: 0.0,
post_trace_slow: 0.0,
pre_trace_triplet: 0.0,
post_trace_triplet: 0.0,
calcium_concentration: 0.0,
calcium_threshold_low: 0.2,
calcium_threshold_high: 0.6,
metaplasticity_factor: 1.0,
recent_activity: VecDeque::with_capacity(1000),
sliding_threshold: 0.5,
target_firing_rate,
current_firing_rate: 0.0,
scaling_factor: 1.0,
tau_plus_fast: 0.017, tau_minus_fast: 0.034, tau_plus_slow: 0.688, tau_minus_slow: 0.688, tau_calcium: 0.048, tau_metaplasticity: 100.0,
eta_ltp,
eta_ltd,
eta_triplet: eta_ltp * 0.1,
eta_homeostatic: eta_ltp * 0.01,
theta_d: 0.2,
theta_p: 0.8,
spike_history_pre: VecDeque::with_capacity(100),
spike_history_post: VecDeque::with_capacity(100),
w_min: -2.0,
w_max: 2.0,
}
}
pub fn update_weight_advanced(
&mut self,
current_weight: f64,
pre_spike: bool,
post_spike: bool,
dt: f64,
current_time: f64,
objective_improvement: f64,
) -> f64 {
self.update_calcium(pre_spike, post_spike, dt);
self.update_metaplasticity(current_time, objective_improvement);
self.update_homeostasis(pre_spike, post_spike, dt);
self.decay_traces(dt);
if pre_spike {
self.spike_history_pre.push_back(current_time);
if self.spike_history_pre.len() > 100 {
self.spike_history_pre.pop_front();
}
}
if post_spike {
self.spike_history_post.push_back(current_time);
if self.spike_history_post.len() > 100 {
self.spike_history_post.pop_front();
}
}
let mut total_weight_change = 0.0;
total_weight_change += self.compute_pairwise_stdp(pre_spike, post_spike);
total_weight_change += self.compute_triplet_stdp(pre_spike, post_spike);
total_weight_change += self.compute_calcium_plasticity(current_weight);
total_weight_change += self.compute_bcm_plasticity(pre_spike, post_spike, current_weight);
total_weight_change += self.compute_homeostatic_scaling(current_weight);
total_weight_change *= self.metaplasticity_factor;
total_weight_change *= self.scaling_factor;
let new_weight = current_weight + total_weight_change;
self.apply_weight_constraints(new_weight)
}
fn update_calcium(&mut self, pre_spike: bool, post_spike: bool, dt: f64) {
self.calcium_concentration *= (-dt / self.tau_calcium).exp();
if pre_spike {
self.calcium_concentration += 0.1;
}
if post_spike {
self.calcium_concentration += 0.2;
}
self.calcium_concentration = self.calcium_concentration.min(1.0);
}
fn update_metaplasticity(&mut self, current_time: f64, objective_improvement: f64) {
self.recent_activity.push_back(objective_improvement);
if self.recent_activity.len() > 1000 {
self.recent_activity.pop_front();
}
if self.recent_activity.len() > 10 {
let mean: f64 =
self.recent_activity.iter().sum::<f64>() / self.recent_activity.len() as f64;
let variance: f64 = self
.recent_activity
.iter()
.map(|&x| (x - mean).powi(2))
.sum::<f64>()
/ self.recent_activity.len() as f64;
self.metaplasticity_factor = 1.0 + variance.sqrt();
self.sliding_threshold = 0.9 * self.sliding_threshold + 0.1 * mean.abs();
}
}
fn update_homeostasis(&mut self, pre_spike: bool, post_spike: bool, dt: f64) {
let spike_rate = if post_spike { 1.0 / dt } else { 0.0 };
self.current_firing_rate = 0.999 * self.current_firing_rate + 0.001 * spike_rate;
let rate_ratio = self.current_firing_rate / self.target_firing_rate.max(0.1);
self.scaling_factor = (2.0 / (1.0 + rate_ratio)).min(2.0).max(0.5);
}
fn decay_traces(&mut self, dt: f64) {
self.pre_trace_fast *= (-dt / self.tau_plus_fast).exp();
self.post_trace_fast *= (-dt / self.tau_minus_fast).exp();
self.pre_trace_slow *= (-dt / self.tau_plus_slow).exp();
self.post_trace_slow *= (-dt / self.tau_minus_slow).exp();
self.pre_trace_triplet *= (-dt / (self.tau_plus_fast * 2.0)).exp();
self.post_trace_triplet *= (-dt / (self.tau_minus_fast * 2.0)).exp();
}
fn compute_pairwise_stdp(&mut self, pre_spike: bool, post_spike: bool) -> f64 {
let mut weight_change = 0.0;
if pre_spike {
self.pre_trace_fast += 1.0;
self.pre_trace_slow += 1.0;
weight_change -= self.eta_ltd * (self.post_trace_fast + 0.1 * self.post_trace_slow);
}
if post_spike {
self.post_trace_fast += 1.0;
self.post_trace_slow += 1.0;
weight_change += self.eta_ltp * (self.pre_trace_fast + 0.1 * self.pre_trace_slow);
}
weight_change
}
fn compute_triplet_stdp(&mut self, pre_spike: bool, post_spike: bool) -> f64 {
let mut weight_change = 0.0;
if pre_spike {
self.pre_trace_triplet += 1.0;
weight_change -= self.eta_triplet * self.post_trace_fast * self.post_trace_triplet;
}
if post_spike {
self.post_trace_triplet += 1.0;
weight_change += self.eta_triplet * self.pre_trace_fast * self.pre_trace_triplet;
}
weight_change
}
fn compute_calcium_plasticity(&self, current_weight: f64) -> f64 {
let ca = self.calcium_concentration;
if ca < self.calcium_threshold_low {
-self.eta_ltd * 0.1 * current_weight.abs()
} else if ca > self.calcium_threshold_high {
self.eta_ltp * 0.1 * (self.w_max - current_weight.abs())
} else {
let normalized_ca = (ca - self.calcium_threshold_low)
/ (self.calcium_threshold_high - self.calcium_threshold_low);
self.eta_ltp * 0.05 * (2.0 * normalized_ca - 1.0)
}
}
fn compute_bcm_plasticity(
&self,
pre_spike: bool,
post_spike: bool,
_current_weight: f64,
) -> f64 {
if !pre_spike && !post_spike {
return 0.0;
}
let post_activity = if post_spike { 1.0 } else { 0.0 };
let pre_activity = if pre_spike { 1.0 } else { 0.0 };
let theta = self.sliding_threshold;
pre_activity * post_activity * (post_activity - theta) * self.eta_ltp * 0.1
}
fn compute_homeostatic_scaling(&self, current_weight: f64) -> f64 {
let rate_error = self.target_firing_rate - self.current_firing_rate;
self.eta_homeostatic * rate_error * current_weight * 0.01
}
fn apply_weight_constraints(&self, weight: f64) -> f64 {
if weight > self.w_max {
self.w_max - (weight - self.w_max).exp().recip()
} else if weight < self.w_min {
self.w_min + (self.w_min - weight).exp().recip()
} else {
weight
}
}
pub fn get_plasticity_stats(&self) -> PlasticityStats {
PlasticityStats {
calcium_level: self.calcium_concentration,
metaplasticity_factor: self.metaplasticity_factor,
scaling_factor: self.scaling_factor,
firing_rate_error: self.target_firing_rate - self.current_firing_rate,
sliding_threshold: self.sliding_threshold,
trace_strength: (self.pre_trace_fast + self.post_trace_fast) / 2.0,
}
}
}
#[derive(Debug, Clone)]
pub struct PlasticityStats {
pub calcium_level: f64,
pub metaplasticity_factor: f64,
pub scaling_factor: f64,
pub firing_rate_error: f64,
pub sliding_threshold: f64,
pub trace_strength: f64,
}
#[derive(Debug, Clone)]
pub struct STDPLearningRule {
pub pre_trace: f64,
pub post_trace: f64,
pub learning_rate: f64,
pub tau_plus: f64,
pub tau_minus: f64,
}
impl STDPLearningRule {
pub fn new(learning_rate: f64) -> Self {
Self {
pre_trace: 0.0,
post_trace: 0.0,
learning_rate,
tau_plus: 0.020, tau_minus: 0.020, }
}
pub fn update_weight(
&mut self,
current_weight: f64,
pre_spike: bool,
post_spike: bool,
dt: f64,
) -> f64 {
self.pre_trace *= (-dt / self.tau_plus).exp();
self.post_trace *= (-dt / self.tau_minus).exp();
let mut weight_change = 0.0;
if pre_spike {
self.pre_trace += 1.0;
if self.post_trace > 0.0 {
weight_change -= self.learning_rate * self.post_trace;
}
}
if post_spike {
self.post_trace += 1.0;
if self.pre_trace > 0.0 {
weight_change += self.learning_rate * self.pre_trace;
}
}
(current_weight + weight_change).max(-1.0).min(1.0)
}
}
#[derive(Debug, Clone)]
pub struct AdvancedSTDPNetwork {
pub layers: Vec<STDPLayer>,
pub advanced_stdp_rules: Vec<Vec<AdvancedAdvancedSTDP>>,
pub current_params: Array1<f64>,
pub best_params: Array1<f64>,
pub best_objective: f64,
pub nit: usize,
pub network_stats: NetworkStats,
}
#[derive(Debug, Clone)]
pub struct STDPLayer {
pub size: usize,
pub potentials: Array1<f64>,
pub last_spike_times: Array1<Option<f64>>,
pub firing_rates: Array1<f64>,
}
#[derive(Debug, Clone)]
pub struct NetworkStats {
pub avg_plasticity: f64,
pub synchrony: f64,
pub energy_consumption: f64,
pub convergence: f64,
}
impl Default for NetworkStats {
fn default() -> Self {
Self {
avg_plasticity: 0.0,
synchrony: 0.0,
energy_consumption: 0.0,
convergence: 0.0,
}
}
}
impl AdvancedSTDPNetwork {
pub fn new(layer_sizes: Vec<usize>, target_firing_rate: f64, learning_rate: f64) -> Self {
let mut layers = Vec::new();
let mut advanced_stdp_rules = Vec::new();
for (layer_idx, &size) in layer_sizes.iter().enumerate() {
let layer = STDPLayer {
size,
potentials: Array1::zeros(size),
last_spike_times: Array1::from_vec(vec![None; size]),
firing_rates: Array1::zeros(size),
};
layers.push(layer);
if layer_idx > 0 {
let prev_size = layer_sizes[layer_idx - 1];
let mut layer_rules = Vec::new();
for _i in 0..size {
for _j in 0..prev_size {
layer_rules.push(AdvancedAdvancedSTDP::new(
learning_rate,
learning_rate * 0.5,
target_firing_rate,
));
}
}
advanced_stdp_rules.push(layer_rules);
}
}
let input_size = layer_sizes[0];
Self {
layers,
advanced_stdp_rules,
current_params: Array1::zeros(input_size),
best_params: Array1::zeros(input_size),
best_objective: f64::INFINITY,
nit: 0,
network_stats: NetworkStats::default(),
}
}
pub fn optimize<F>(
&mut self,
objective: F,
initial_params: &ArrayView1<f64>,
max_nit: usize,
dt: f64,
) -> Result<Array1<f64>>
where
F: Fn(&ArrayView1<f64>) -> f64,
{
self.current_params = initial_params.to_owned();
self.best_params = initial_params.to_owned();
self.best_objective = objective(initial_params);
let mut prev_objective = self.best_objective;
for iteration in 0..max_nit {
let current_time = iteration as f64 * dt;
let current_objective = objective(&self.current_params.view());
let objective_improvement = prev_objective - current_objective;
if current_objective < self.best_objective {
self.best_objective = current_objective;
self.best_params = self.current_params.clone();
}
let spike_patterns =
self.encode_parameters_to_spikes(&self.current_params, current_time);
let network_spikes =
self.simulate_network_dynamics(&spike_patterns, current_time, dt)?;
self.update_advanced_stdp_weights(
&network_spikes,
current_time,
dt,
objective_improvement,
)?;
let param_updates = self.decode_parameters_from_network(current_time);
let step_size = self.compute_adaptive_step_size(objective_improvement, iteration);
for (i, update) in param_updates.iter().enumerate() {
if i < self.current_params.len() {
self.current_params[i] += step_size * update;
}
}
self.update_network_statistics(current_time);
if objective_improvement.abs() < 1e-8 && iteration > 100 {
break;
}
prev_objective = current_objective;
self.nit = iteration + 1;
}
Ok(self.best_params.clone())
}
fn encode_parameters_to_spikes(
&self,
params: &Array1<f64>,
_current_time: f64,
) -> Vec<Vec<bool>> {
let mut spike_patterns = Vec::new();
for layer in &self.layers {
let mut layer_spikes = vec![false; layer.size];
for i in 0..layer.size.min(params.len()) {
let spike_prob = ((params[i] + 1.0) / 2.0).max(0.0).min(1.0);
layer_spikes[i] = scirs2_core::random::rng().random::<f64>() < spike_prob * 0.1;
}
spike_patterns.push(layer_spikes);
}
spike_patterns
}
fn simulate_network_dynamics(
&mut self,
input_spikes: &[Vec<bool>],
current_time: f64,
dt: f64,
) -> Result<Vec<Vec<bool>>> {
let mut all_spikes = input_spikes.to_vec();
for layer_idx in 1..self.layers.len() {
let mut layer_spikes = vec![false; self.layers[layer_idx].size];
for neuron_idx in 0..self.layers[layer_idx].size {
let mut input_current = 0.0;
for prev_neuron_idx in 0..self.layers[layer_idx - 1].size {
if all_spikes[layer_idx - 1][prev_neuron_idx] {
input_current += 0.1;
}
}
self.layers[layer_idx].potentials[neuron_idx] +=
dt * (-self.layers[layer_idx].potentials[neuron_idx] + input_current) / 0.02;
if self.layers[layer_idx].potentials[neuron_idx] > 1.0 {
self.layers[layer_idx].potentials[neuron_idx] = 0.0;
self.layers[layer_idx].last_spike_times[neuron_idx] = Some(current_time);
layer_spikes[neuron_idx] = true;
}
let spike_rate = if layer_spikes[neuron_idx] {
1.0 / dt
} else {
0.0
};
self.layers[layer_idx].firing_rates[neuron_idx] =
0.99 * self.layers[layer_idx].firing_rates[neuron_idx] + 0.01 * spike_rate;
}
all_spikes.push(layer_spikes);
}
Ok(all_spikes)
}
fn update_advanced_stdp_weights(
&mut self,
all_spikes: &[Vec<bool>],
current_time: f64,
dt: f64,
objective_improvement: f64,
) -> Result<()> {
for layer_idx in 0..self.advanced_stdp_rules.len() {
let input_spikes = &all_spikes[layer_idx];
let output_spikes = &all_spikes[layer_idx + 1];
for (connection_idx, rule) in self.advanced_stdp_rules[layer_idx].iter_mut().enumerate()
{
let _layer_size = self.layers[layer_idx + 1].size;
let prev_layer_size = self.layers[layer_idx].size;
let neuron_idx = connection_idx / prev_layer_size;
let input_idx = connection_idx % prev_layer_size;
let pre_spike = input_spikes.get(input_idx).copied().unwrap_or(false);
let post_spike = output_spikes.get(neuron_idx).copied().unwrap_or(false);
let _new_weight = rule.update_weight_advanced(
0.5, pre_spike,
post_spike,
dt,
current_time,
objective_improvement,
);
}
}
Ok(())
}
fn decode_parameters_from_network(&self, current_time: f64) -> Array1<f64> {
let mut updates = Array1::zeros(self.current_params.len());
if !self.layers.is_empty() {
for (i, &rate) in self.layers[0].firing_rates.iter().enumerate() {
if i < updates.len() {
updates[i] = (rate - 5.0) * 0.01; }
}
}
updates
}
fn compute_adaptive_step_size(&self, objective_improvement: f64, iteration: usize) -> f64 {
let base_step = 0.01;
let improvement_factor = if objective_improvement > 0.0 {
1.2
} else {
0.8
};
let decay_factor = 1.0 / (1.0 + iteration as f64 * 0.001);
base_step * improvement_factor * decay_factor
}
fn update_network_statistics(&mut self, current_time: f64) {
let mut total_plasticity = 0.0;
let mut count = 0;
for layer_rules in &self.advanced_stdp_rules {
for rule in layer_rules {
let stats = rule.get_plasticity_stats();
total_plasticity += stats.metaplasticity_factor;
count += 1;
}
}
if count > 0 {
self.network_stats.avg_plasticity = total_plasticity / count as f64;
}
let mut synchrony = 0.0;
for layer in &self.layers {
let rate_variance = layer.firing_rates.clone().variance();
synchrony += 1.0 / (1.0 + rate_variance);
}
self.network_stats.synchrony = synchrony / self.layers.len() as f64;
let total_spikes: f64 = self
.layers
.iter()
.map(|layer| layer.firing_rates.sum())
.sum();
self.network_stats.energy_consumption = total_spikes * 1e-12; }
pub fn get_network_stats(&self) -> &NetworkStats {
&self.network_stats
}
}
#[allow(dead_code)]
pub fn stdp_optimize<F>(
objective: F,
initial_params: &ArrayView1<f64>,
num_nit: usize,
) -> Result<Array1<f64>>
where
F: Fn(&ArrayView1<f64>) -> f64,
{
let mut params = initial_params.to_owned();
let mut stdp_rules: Vec<STDPLearningRule> = (0..params.len())
.map(|_| STDPLearningRule::new(0.01))
.collect();
let mut prev_obj = objective(¶ms.view());
for _iter in 0..num_nit {
let current_obj = objective(¶ms.view());
let improvement = prev_obj - current_obj;
for (i, rule) in stdp_rules.iter_mut().enumerate() {
let pre_spike =
scirs2_core::random::rng().random::<f64>() < (params[i].abs() * 0.1).min(0.5);
let post_spike = improvement > 0.0 && scirs2_core::random::rng().random::<f64>() < 0.2;
params[i] = rule.update_weight(params[i], pre_spike, post_spike, 0.001);
}
prev_obj = current_obj;
}
Ok(params)
}
#[allow(dead_code)]
pub fn advanced_stdp_optimize<F>(
objective: F,
initial_params: &ArrayView1<f64>,
max_nit: usize,
network_config: Option<(Vec<usize>, f64, f64)>, ) -> Result<Array1<f64>>
where
F: Fn(&ArrayView1<f64>) -> f64,
{
let (layer_sizes, target_rate, learning_rate) = network_config.unwrap_or_else(|| {
let input_size = initial_params.len();
(vec![input_size, input_size * 2, input_size], 5.0, 0.01)
});
let mut network = AdvancedSTDPNetwork::new(layer_sizes, target_rate, learning_rate);
network.optimize(objective, initial_params, max_nit, 0.001)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_advanced_stdp_creation() {
let stdp = AdvancedAdvancedSTDP::new(0.01, 0.005, 5.0);
assert_eq!(stdp.eta_ltp, 0.01);
assert_eq!(stdp.target_firing_rate, 5.0);
}
#[test]
fn test_advanced_stdp_weight_update() {
let mut stdp = AdvancedAdvancedSTDP::new(0.1, 0.05, 5.0);
let new_weight = stdp.update_weight_advanced(0.5, true, true, 0.001, 0.0, 0.1);
assert!(new_weight.is_finite());
assert!(new_weight >= stdp.w_min && new_weight <= stdp.w_max);
}
#[test]
fn test_advanced_stdp_network() {
let layer_sizes = vec![3, 5, 3];
let network = AdvancedSTDPNetwork::new(layer_sizes, 5.0, 0.01);
assert_eq!(network.layers.len(), 3);
assert_eq!(network.layers[0].size, 3);
assert_eq!(network.layers[1].size, 5);
assert_eq!(network.layers[2].size, 3);
}
#[test]
fn test_plasticity_stats() {
let stdp = AdvancedAdvancedSTDP::new(0.01, 0.005, 5.0);
let stats = stdp.get_plasticity_stats();
assert!(stats.calcium_level >= 0.0);
assert!(stats.metaplasticity_factor > 0.0);
assert!(stats.scaling_factor > 0.0);
}
#[test]
fn test_basic_stdp_optimization() {
let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
let initial = Array1::from(vec![1.0, 1.0]);
let result = stdp_optimize(objective, &initial.view(), 100).expect("Operation failed");
let final_obj = objective(&result.view());
let initial_obj = objective(&initial.view());
assert!(final_obj <= initial_obj);
}
#[test]
fn test_advanced_stdp_optimization() {
let objective = |x: &ArrayView1<f64>| (x[0] - 1.0).powi(2) + (x[1] + 0.5).powi(2);
let initial = Array1::from(vec![0.0, 0.0]);
let result = advanced_stdp_optimize(
objective,
&initial.view(),
50,
Some((vec![2, 4, 2], 3.0, 0.05)),
)
.expect("Operation failed");
assert_eq!(result.len(), 2);
let final_obj = objective(&result.view());
let initial_obj = objective(&initial.view());
assert!(final_obj <= initial_obj * 2.0); }
}
#[allow(dead_code)]
pub fn placeholder() {
}