use super::{SimTime, Spike};
use crate::graph::{DynamicGraph, VertexId, Weight};
use std::collections::HashMap;
/// Tunable parameters of the spike-timing-dependent plasticity (STDP) rule.
///
/// Read by `Synapse::stdp_update`, `Synapse::reward_modulated_update` and the
/// eligibility decay of `SynapseMatrix`. Time constants are expressed in the
/// same unit the caller uses for spike times.
#[derive(Debug, Clone)]
pub struct STDPConfig {
/// Amplitude of the weight increase applied when `t_post - t_pre > 0`.
pub a_plus: f64,
/// Amplitude of the weight decrease applied when `t_post - t_pre <= 0`.
pub a_minus: f64,
/// Decay constant of the exponential window on the `a_plus` branch.
pub tau_plus: f64,
/// Decay constant of the exponential window on the `a_minus` branch.
pub tau_minus: f64,
/// Lower bound that weights are clamped to after every update.
pub w_min: f64,
/// Upper bound that weights are clamped to after every update.
pub w_max: f64,
/// Factor applied to the raw STDP change and to reward-modulated changes
/// before they are added to the weight.
pub learning_rate: f64,
/// Time constant of the exponential decay of eligibility traces.
pub tau_eligibility: f64,
}
impl Default for STDPConfig {
    /// Returns the stock parameter set: depression slightly stronger than
    /// potentiation, identical decay constants for both windows, weights
    /// confined to `[0, 1]`, and an unscaled learning rate.
    fn default() -> Self {
        // Amplitudes of the potentiation / depression branches.
        let (a_plus, a_minus) = (0.01, 0.012);
        // Both exponential windows decay with the same constant.
        let (tau_plus, tau_minus) = (20.0, 20.0);
        // Admissible weight interval.
        let (w_min, w_max) = (0.0, 1.0);

        Self {
            a_plus,
            a_minus,
            tau_plus,
            tau_minus,
            w_min,
            w_max,
            learning_rate: 1.0,
            tau_eligibility: 1000.0,
        }
    }
}
/// A single directed connection from a presynaptic to a postsynaptic neuron.
#[derive(Debug, Clone)]
pub struct Synapse {
/// Index of the presynaptic (source) neuron.
pub pre: usize,
/// Index of the postsynaptic (target) neuron.
pub post: usize,
/// Current synaptic weight; STDP and reward updates clamp it to
/// `[w_min, w_max]` of the configuration they are given.
pub weight: f64,
/// Transmission delay; `Synapse::new` sets it to `1.0`. It is stored but not
/// read by the code visible in this file.
pub delay: f64,
/// Eligibility trace: accumulates the raw STDP change of each update and
/// scales reward-modulated weight changes.
pub eligibility: f64,
/// Initialised to `0.0` by both constructors; no method in this file
/// modifies it afterwards.
// NOTE(review): presumably meant to record the time of the last weight update — confirm.
pub last_update: SimTime,
}
impl Synapse {
pub fn new(pre: usize, post: usize, weight: f64) -> Self {
Self {
pre,
post,
weight,
delay: 1.0,
eligibility: 0.0,
last_update: 0.0,
}
}
pub fn with_delay(pre: usize, post: usize, weight: f64, delay: f64) -> Self {
Self {
pre,
post,
weight,
delay,
eligibility: 0.0,
last_update: 0.0,
}
}
pub fn stdp_update(&mut self, t_pre: SimTime, t_post: SimTime, config: &STDPConfig) -> f64 {
let dt = t_post - t_pre;
let dw = if dt > 0.0 {
config.a_plus * (-dt / config.tau_plus).exp()
} else {
-config.a_minus * (dt / config.tau_minus).exp()
};
let delta = config.learning_rate * dw;
self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
self.eligibility += dw;
delta
}
pub fn decay_eligibility(&mut self, dt: f64, tau: f64) {
self.eligibility *= (-dt / tau).exp();
}
pub fn reward_modulated_update(&mut self, reward: f64, config: &STDPConfig) {
let delta = reward * self.eligibility * config.learning_rate;
self.weight = (self.weight + delta).clamp(config.w_min, config.w_max);
self.eligibility *= 0.5;
}
}
/// Sparse set of [`Synapse`]s between `n_pre` presynaptic and `n_post`
/// postsynaptic neurons, plus the per-neuron spike history needed for STDP.
#[derive(Debug, Clone)]
pub struct SynapseMatrix {
/// Number of presynaptic neurons; `add_synapse` accepts `pre` in `0..n_pre`.
pub n_pre: usize,
/// Number of postsynaptic neurons; `add_synapse` accepts `post` in `0..n_post`.
pub n_post: usize,
/// Existing synapses keyed by `(pre, post)`; a missing key means "no connection".
synapses: HashMap<(usize, usize), Synapse>,
/// STDP parameters applied to every synapse of this matrix.
pub config: STDPConfig,
/// Most recent spike time per presynaptic neuron; `f64::NEG_INFINITY` marks
/// a neuron that has not spiked yet.
pre_spike_times: Vec<SimTime>,
/// Most recent spike time per postsynaptic neuron; `f64::NEG_INFINITY` marks
/// a neuron that has not spiked yet.
post_spike_times: Vec<SimTime>,
}
impl SynapseMatrix {
    /// Creates an empty matrix for `n_pre` presynaptic and `n_post`
    /// postsynaptic neurons using `STDPConfig::default()`.
    pub fn new(n_pre: usize, n_post: usize) -> Self {
        Self::with_config(n_pre, n_post, STDPConfig::default())
    }

    /// Creates an empty matrix with an explicit STDP configuration.
    ///
    /// Every last-spike time starts at `f64::NEG_INFINITY`, which the spike
    /// handlers treat as "this neuron has not spiked yet".
    pub fn with_config(n_pre: usize, n_post: usize, config: STDPConfig) -> Self {
        Self {
            n_pre,
            n_post,
            synapses: HashMap::new(),
            config,
            pre_spike_times: vec![f64::NEG_INFINITY; n_pre],
            post_spike_times: vec![f64::NEG_INFINITY; n_post],
        }
    }

    /// Inserts a fresh synapse `pre -> post` with the given weight.
    ///
    /// An existing synapse for the same pair is replaced (its eligibility
    /// trace is lost). Indices outside `0..n_pre` / `0..n_post` are ignored.
    pub fn add_synapse(&mut self, pre: usize, post: usize, weight: f64) {
        if pre < self.n_pre && post < self.n_post {
            self.synapses
                .insert((pre, post), Synapse::new(pre, post, weight));
        }
    }

    /// Returns the synapse `pre -> post`, if one exists.
    pub fn get_synapse(&self, pre: usize, post: usize) -> Option<&Synapse> {
        self.synapses.get(&(pre, post))
    }

    /// Returns a mutable reference to the synapse `pre -> post`, if one exists.
    pub fn get_synapse_mut(&mut self, pre: usize, post: usize) -> Option<&mut Synapse> {
        self.synapses.get_mut(&(pre, post))
    }

    /// Returns the weight of `pre -> post`, or `0.0` when there is no synapse.
    pub fn weight(&self, pre: usize, post: usize) -> f64 {
        self.get_synapse(pre, post).map_or(0.0, |s| s.weight)
    }

    /// Computes, for every postsynaptic neuron, the sum of
    /// `weight * activation` over its incoming synapses.
    ///
    /// Synapses whose presynaptic index lies beyond `pre_activations` are
    /// skipped. Terms are accumulated in hash-map iteration order, so the
    /// floating-point rounding of a sum is not guaranteed to be identical
    /// between runs.
    #[inline]
    pub fn compute_weighted_sums(&self, pre_activations: &[f64]) -> Vec<f64> {
        let mut sums = vec![0.0; self.n_post];
        for (&(pre, post), synapse) in &self.synapses {
            if let Some(&activation) = pre_activations.get(pre) {
                sums[post] += synapse.weight * activation;
            }
        }
        sums
    }

    /// Computes the weighted input of a single postsynaptic neuron, visiting
    /// presynaptic indices in ascending order.
    ///
    /// Only existing synapses contribute; activations of unconnected
    /// presynaptic neurons are never read into the sum.
    #[inline]
    pub fn weighted_sum_for_post(&self, post: usize, pre_activations: &[f64]) -> f64 {
        let mut sum = 0.0;
        for (pre, &activation) in pre_activations.iter().enumerate().take(self.n_pre) {
            if let Some(synapse) = self.synapses.get(&(pre, post)) {
                sum += synapse.weight * activation;
            }
        }
        sum
    }

    /// Sets the weight of `pre -> post`, creating the synapse when it does
    /// not exist yet (subject to the bounds check of [`Self::add_synapse`]).
    /// The weight is stored as given, without clamping.
    pub fn set_weight(&mut self, pre: usize, post: usize, weight: f64) {
        if let Some(synapse) = self.get_synapse_mut(pre, post) {
            synapse.weight = weight;
        } else {
            self.add_synapse(pre, post, weight);
        }
    }

    /// Records a presynaptic spike of `pre` at `time` and applies STDP to
    /// each outgoing synapse whose postsynaptic neuron has spiked before.
    ///
    /// Out-of-range `pre` indices are ignored.
    pub fn on_pre_spike(&mut self, pre: usize, time: SimTime) {
        if pre >= self.n_pre {
            return;
        }
        self.pre_spike_times[pre] = time;
        for post in 0..self.n_post {
            if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
                let t_post = self.post_spike_times[post];
                // NEG_INFINITY marks a postsynaptic neuron without any spike.
                if t_post > f64::NEG_INFINITY {
                    synapse.stdp_update(time, t_post, &self.config);
                }
            }
        }
    }

    /// Records a postsynaptic spike of `post` at `time` and applies STDP to
    /// each incoming synapse whose presynaptic neuron has spiked before.
    ///
    /// Out-of-range `post` indices are ignored.
    pub fn on_post_spike(&mut self, post: usize, time: SimTime) {
        if post >= self.n_post {
            return;
        }
        self.post_spike_times[post] = time;
        for pre in 0..self.n_pre {
            if let Some(synapse) = self.synapses.get_mut(&(pre, post)) {
                let t_pre = self.pre_spike_times[pre];
                // NEG_INFINITY marks a presynaptic neuron without any spike.
                if t_pre > f64::NEG_INFINITY {
                    synapse.stdp_update(t_pre, time, &self.config);
                }
            }
        }
    }

    /// Feeds every spike to both the presynaptic and the postsynaptic
    /// handler, in slice order.
    pub fn process_spikes(&mut self, spikes: &[Spike]) {
        for spike in spikes {
            // Both handlers return early for ids outside their own range, so
            // no additional bounds check is needed here.
            self.on_pre_spike(spike.neuron_id, spike.time);
            self.on_post_spike(spike.neuron_id, spike.time);
        }
    }

    /// Decays the eligibility trace of every synapse over an interval `dt`
    /// using `config.tau_eligibility`.
    pub fn decay_eligibility(&mut self, dt: f64) {
        let tau = self.config.tau_eligibility;
        for synapse in self.synapses.values_mut() {
            synapse.decay_eligibility(dt, tau);
        }
    }

    /// Applies a reward-modulated update with the same `reward` to every synapse.
    pub fn apply_reward(&mut self, reward: f64) {
        for synapse in self.synapses.values_mut() {
            synapse.reward_modulated_update(reward, &self.config);
        }
    }

    /// Iterates over `((pre, post), synapse)` entries in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&(usize, usize), &Synapse)> {
        self.synapses.iter()
    }

    /// Number of synapses currently stored.
    pub fn num_synapses(&self) -> usize {
        self.synapses.len()
    }

    /// Total weighted input arriving at `post`.
    ///
    /// Delegates to [`Self::weighted_sum_for_post`] so that only existing
    /// synapses contribute. Multiplying an implicit `0.0` weight with the
    /// activation of an unconnected neuron would turn an infinite or NaN
    /// activation into a NaN total.
    pub fn input_to(&self, post: usize, pre_activities: &[f64]) -> f64 {
        self.weighted_sum_for_post(post, pre_activities)
    }

    /// Expands the matrix into `n_pre` rows of `n_post` weights, with `0.0`
    /// for missing synapses.
    pub fn to_dense(&self) -> Vec<Vec<f64>> {
        let mut matrix = vec![vec![0.0; self.n_post]; self.n_pre];
        for (&(pre, post), synapse) in &self.synapses {
            matrix[pre][post] = synapse.weight;
        }
        matrix
    }

    /// Builds a matrix (with the default configuration) from dense rows,
    /// creating a synapse for every entry that is not equal to `0.0`.
    ///
    /// `n_post` is the length of the longest row, so ragged input keeps all
    /// of its entries instead of being truncated to the first row's length.
    pub fn from_dense(matrix: &[Vec<f64>]) -> Self {
        let n_pre = matrix.len();
        let n_post = matrix.iter().map(|row| row.len()).max().unwrap_or(0);
        let mut sm = Self::new(n_pre, n_post);
        for (pre, row) in matrix.iter().enumerate() {
            for (post, &weight) in row.iter().enumerate() {
                if weight != 0.0 {
                    sm.add_synapse(pre, post, weight);
                }
            }
        }
        sm
    }

    /// Copies synapse weights onto the matching edges of `graph`.
    ///
    /// `neuron_to_vertex` maps neuron indices to graph vertices. Synapses
    /// without a corresponding edge are skipped, and the return value of
    /// `update_edge_weight` is deliberately discarded (best-effort sync).
    pub fn sync_to_graph<F>(&self, graph: &mut DynamicGraph, neuron_to_vertex: F)
    where
        F: Fn(usize) -> VertexId,
    {
        for (&(pre, post), synapse) in &self.synapses {
            let u = neuron_to_vertex(pre);
            let v = neuron_to_vertex(post);
            if graph.has_edge(u, v) {
                let _ = graph.update_edge_weight(u, v, synapse.weight);
            }
        }
    }

    /// Copies edge weights of `graph` into the matrix.
    ///
    /// `vertex_to_neuron` maps graph vertices to neuron indices. Edges that
    /// map outside the matrix are skipped; missing synapses are created by
    /// [`Self::set_weight`].
    pub fn sync_from_graph<F>(&mut self, graph: &DynamicGraph, vertex_to_neuron: F)
    where
        F: Fn(VertexId) -> usize,
    {
        for edge in graph.edges() {
            let pre = vertex_to_neuron(edge.source);
            let post = vertex_to_neuron(edge.target);
            if pre < self.n_pre && post < self.n_post {
                self.set_weight(pre, post, edge.weight);
            }
        }
    }

    /// Returns every `(pre, post)` pair whose weight is at least `threshold`,
    /// sorted ascending so the result does not depend on hash-map iteration
    /// order.
    pub fn high_correlation_pairs(&self, threshold: f64) -> Vec<(usize, usize)> {
        let mut pairs: Vec<(usize, usize)> = self
            .synapses
            .iter()
            .filter(|(_, synapse)| synapse.weight >= threshold)
            .map(|(&pair, _)| pair)
            .collect();
        pairs.sort_unstable();
        pairs
    }
}
/// STDP rule whose two branches have independent amplitudes and decay
/// constants; see `compute_dw` for the exact formula.
#[derive(Debug, Clone)]
pub struct AsymmetricSTDP {
/// Decay constant of the exponential window used when `dt > 0`.
pub tau_forward: f64,
/// Decay constant of the exponential window used when `dt <= 0`.
pub tau_backward: f64,
/// Amplitude of the weight increase produced when `dt > 0`.
pub a_forward: f64,
/// Amplitude of the weight decrease produced when `dt <= 0`.
pub a_backward: f64,
}
impl Default for AsymmetricSTDP {
    /// Returns the stock rule: a short, strong forward window and a longer,
    /// weaker backward window.
    fn default() -> Self {
        // (decay constant, amplitude) of the `dt > 0` branch.
        let (tau_forward, a_forward) = (15.0, 0.015);
        // (decay constant, amplitude) of the `dt <= 0` branch.
        let (tau_backward, a_backward) = (30.0, 0.008);

        Self {
            tau_forward,
            tau_backward,
            a_forward,
            a_backward,
        }
    }
}
impl AsymmetricSTDP {
    /// Weight change for a timing difference `dt`.
    ///
    /// Returns `a_forward * exp(-dt / tau_forward)` when `dt > 0` and
    /// `-a_backward * exp(dt / tau_backward)` otherwise (including `dt == 0`).
    pub fn compute_dw(&self, dt: f64) -> f64 {
        let (amplitude, exponent) = if dt > 0.0 {
            (self.a_forward, -dt / self.tau_forward)
        } else {
            (-self.a_backward, dt / self.tau_backward)
        };
        amplitude * exponent.exp()
    }

    /// Applies this rule to `matrix` for a spike of `neuron_id` at `time`.
    ///
    /// Two passes are made, both using the last-spike times already recorded
    /// in `matrix` (this method does not record the new spike itself):
    ///
    /// 1. every synapse `pre -> neuron_id` whose presynaptic neuron has
    ///    spiked is changed by `compute_dw(time - t_pre)`;
    /// 2. every synapse `neuron_id -> post` whose postsynaptic neuron has
    ///    spiked is changed by `compute_dw(t_post - time)`.
    ///
    /// Each changed weight is clamped to `[config.w_min, config.w_max]` of
    /// the matrix. Neurons whose recorded time is `f64::NEG_INFINITY` (no
    /// spike yet) are skipped.
    pub fn update_weights(&self, matrix: &mut SynapseMatrix, neuron_id: usize, time: SimTime) {
        let (lower, upper) = (matrix.config.w_min, matrix.config.w_max);

        // Pass 1: `neuron_id` acts as the postsynaptic side.
        for pre in 0..matrix.n_pre {
            let last_pre_spike = matrix
                .pre_spike_times
                .get(pre)
                .copied()
                .filter(|&t| t > f64::NEG_INFINITY);
            if let Some(t_pre) = last_pre_spike {
                let change = self.compute_dw(time - t_pre);
                if let Some(synapse) = matrix.get_synapse_mut(pre, neuron_id) {
                    synapse.weight = (synapse.weight + change).clamp(lower, upper);
                }
            }
        }

        // Pass 2: `neuron_id` acts as the presynaptic side.
        for post in 0..matrix.n_post {
            let last_post_spike = matrix
                .post_spike_times
                .get(post)
                .copied()
                .filter(|&t| t > f64::NEG_INFINITY);
            if let Some(t_post) = last_post_spike {
                let change = self.compute_dw(t_post - time);
                if let Some(synapse) = matrix.get_synapse_mut(neuron_id, post) {
                    synapse.weight = (synapse.weight + change).clamp(lower, upper);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Approximate comparison with the absolute tolerance used by these tests.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 0.001
    }

    /// `Synapse::new` stores the endpoints and the weight it is given.
    #[test]
    fn test_synapse_creation() {
        let created = Synapse::new(0, 1, 0.5);
        assert_eq!(created.pre, 0);
        assert_eq!(created.post, 1);
        assert_eq!(created.weight, 0.5);
    }

    /// Pre before post (positive lag) strengthens the synapse.
    #[test]
    fn test_stdp_ltp() {
        let cfg = STDPConfig::default();
        let mut syn = Synapse::new(0, 1, 0.5);
        let change = syn.stdp_update(10.0, 15.0, &cfg);
        assert!(change > 0.0);
        assert!(syn.weight > 0.5);
    }

    /// Post before pre (negative lag) weakens the synapse.
    #[test]
    fn test_stdp_ltd() {
        let cfg = STDPConfig::default();
        let mut syn = Synapse::new(0, 1, 0.5);
        let change = syn.stdp_update(15.0, 10.0, &cfg);
        assert!(change < 0.0);
        assert!(syn.weight < 0.5);
    }

    /// Stored weights are retrievable; missing pairs read as zero.
    #[test]
    fn test_synapse_matrix() {
        let mut m = SynapseMatrix::new(10, 10);
        m.add_synapse(0, 1, 0.5);
        m.add_synapse(1, 2, 0.3);

        assert_eq!(m.num_synapses(), 2);
        assert!(close(m.weight(0, 1), 0.5));
        assert!(close(m.weight(1, 2), 0.3));
        assert_eq!(m.weight(2, 3), 0.0);
    }

    /// A pre spike followed by a post spike potentiates the connecting synapse.
    #[test]
    fn test_spike_processing() {
        let mut m = SynapseMatrix::new(5, 5);
        // Fully connected, without self-connections.
        let off_diagonal = (0..5)
            .flat_map(|i| (0..5).map(move |j| (i, j)))
            .filter(|(i, j)| i != j);
        for (i, j) in off_diagonal {
            m.add_synapse(i, j, 0.5);
        }

        m.on_pre_spike(0, 10.0);
        m.on_post_spike(1, 15.0);
        assert!(m.weight(0, 1) > 0.5);
    }

    /// With default parameters the forward branch outweighs the backward
    /// branch at equal absolute lag.
    #[test]
    fn test_asymmetric_stdp() {
        let rule = AsymmetricSTDP::default();
        let forward = rule.compute_dw(5.0);
        let backward = rule.compute_dw(-5.0);

        assert!(forward > 0.0);
        assert!(backward < 0.0);
        assert!(forward.abs() > backward.abs());
    }

    /// Dense export contains the stored weights and converts back to the
    /// same number of synapses.
    #[test]
    fn test_dense_conversion() {
        let mut m = SynapseMatrix::new(3, 3);
        m.add_synapse(0, 1, 0.5);
        m.add_synapse(1, 2, 0.7);

        let dense = m.to_dense();
        assert_eq!(dense.len(), 3);
        assert!(close(dense[0][1], 0.5));
        assert!(close(dense[1][2], 0.7));

        let rebuilt = SynapseMatrix::from_dense(&dense);
        assert_eq!(rebuilt.num_synapses(), 2);
    }
}