#![allow(unused)]
/// Parameters of a pair-based STDP rule with exponential timing windows.
///
/// Consumed by `compute_stdp_weight_change`: a positive time difference is scaled by
/// `a_plus` and decays with `tau_pre`; a negative one is scaled by `a_minus` and decays
/// with `tau_post`. The time constants must be expressed in the same unit as the time
/// differences passed to that function.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct STDPConfig {
    /// Decay time constant of the potentiation window (used when `dt > 0`).
    /// Values below `1e-6` are floored to `1e-6` when the rule is evaluated.
    pub tau_pre: f32,
    /// Decay time constant of the depression window (used when `dt < 0`).
    /// Values below `1e-6` are floored to `1e-6` when the rule is evaluated.
    pub tau_post: f32,
    /// Peak potentiation amplitude; also the value returned for `dt == 0`.
    pub a_plus: f32,
    /// Peak depression amplitude; the resulting change is `-a_minus * exp(...)`.
    pub a_minus: f32,
}

impl Default for STDPConfig {
    /// Symmetric time constants (`20.0` each) with depression slightly stronger than
    /// potentiation (`a_minus = 0.012` vs `a_plus = 0.01`).
    fn default() -> Self {
        Self {
            tau_pre: 20.0,
            tau_post: 20.0,
            a_plus: 0.01,
            a_minus: 0.012,
        }
    }
}
/// Computes the normalized STDP weight change for one spike-time difference `dt`.
///
/// The sign of `dt` selects the branch (presumably `dt = t_post - t_pre`; confirm
/// with callers):
/// * `dt > 0`: potentiation, `a_plus * exp(-dt / tau_pre)`.
/// * `dt < 0`: depression, `-a_minus * exp(-|dt| / tau_post)`.
/// * `dt == 0`: returns `a_plus` unchanged.
///
/// The magnitude decays exponentially with `|dt|` and may underflow to zero for very
/// large differences. Both time constants are floored at `1e-6`, so a zero or negative
/// tau cannot cause a division by zero. The full `i32` range of `dt` is accepted.
#[inline]
pub fn compute_stdp_weight_change(dt: i32, config: &STDPConfig) -> f32 {
    // `unsigned_abs` instead of `-dt`: negating `i32::MIN` overflows (panic in debug,
    // wrap to a negative value in release, which would make the exponent positive and
    // the result `-inf`).
    let magnitude = dt.unsigned_abs() as f32;
    if dt > 0 {
        config.a_plus * (-magnitude / config.tau_pre.max(1e-6)).exp()
    } else if dt < 0 {
        -config.a_minus * (-magnitude / config.tau_post.max(1e-6)).exp()
    } else {
        config.a_plus
    }
}
/// Applies one STDP update to an 8-bit synaptic weight.
///
/// The normalized change from `compute_stdp_weight_change` is scaled by 255 (the full
/// `u8` range), added to `current_weight`, rounded to the nearest integer and saturated
/// to `0..=255`.
///
/// If the computed change is NaN (e.g. a NaN or infinite `a_plus`/`a_minus` in
/// `config`), the weight is returned unchanged instead of collapsing to 0.
#[inline]
pub fn update_weight_stdp(current_weight: u8, dt: i32, config: &STDPConfig) -> u8 {
    let delta_w = compute_stdp_weight_change(dt, config);
    if delta_w.is_nan() {
        return current_weight;
    }
    // Round rather than truncate. `as u8` truncates toward zero, which silently drops
    // every potentiation smaller than one step but turns every depression, however
    // small, into a full one-step decrease.
    let new_weight = (f32::from(current_weight) + delta_w * 255.0).round();
    new_weight.clamp(0.0, 255.0) as u8
}
/// Evaluates `compute_stdp_weight_change` for every entry of `time_diffs`, writing the
/// result for `time_diffs[i]` into `weight_changes[i]`.
///
/// # Panics
///
/// Panics if `weight_changes.len() != time_diffs.len()`.
#[inline]
pub fn compute_stdp_batch(time_diffs: &[i32], config: &STDPConfig, weight_changes: &mut [f32]) {
    // A hard assert rather than `debug_assert_eq!`, so release builds also reject a
    // longer output buffer instead of silently leaving its tail stale.
    assert_eq!(
        weight_changes.len(),
        time_diffs.len(),
        "compute_stdp_batch: output and input slices must have the same length"
    );
    for (out, &dt) in weight_changes.iter_mut().zip(time_diffs) {
        *out = compute_stdp_weight_change(dt, config);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration shared by the tests that do not tweak any parameter.
    fn default_cfg() -> STDPConfig {
        STDPConfig::default()
    }

    #[test]
    fn test_stdp_potentiation() {
        let cfg = default_cfg();
        let dw = compute_stdp_weight_change(5, &cfg);
        // Positive dt strengthens, but never beyond the peak amplitude.
        assert!(dw > 0.0);
        assert!(dw <= cfg.a_plus);
    }

    #[test]
    fn test_stdp_depression() {
        let cfg = default_cfg();
        let dw = compute_stdp_weight_change(-5, &cfg);
        // Negative dt weakens, but never beyond the peak depression amplitude.
        assert!(dw < 0.0);
        assert!(dw >= -cfg.a_minus);
    }

    #[test]
    fn test_stdp_same_time() {
        let cfg = default_cfg();
        assert_eq!(compute_stdp_weight_change(0, &cfg), cfg.a_plus);
    }

    #[test]
    fn test_stdp_exponential_decay() {
        let cfg = default_cfg();
        let near = compute_stdp_weight_change(1, &cfg);
        let far = compute_stdp_weight_change(10, &cfg);
        assert!(near > far);
    }

    #[test]
    fn test_update_weight_potentiation() {
        let cfg = default_cfg();
        let start: u8 = 128;
        let updated = update_weight_stdp(start, 5, &cfg);
        assert!(updated > start);
    }

    #[test]
    fn test_update_weight_depression() {
        let cfg = default_cfg();
        let start: u8 = 128;
        let updated = update_weight_stdp(start, -5, &cfg);
        assert!(updated < start);
    }

    #[test]
    fn test_update_weight_clamp_high() {
        let cfg = STDPConfig {
            a_plus: 1.0,
            ..default_cfg()
        };
        assert_eq!(update_weight_stdp(250, 1, &cfg), 255);
    }

    #[test]
    fn test_update_weight_clamp_low() {
        let cfg = STDPConfig {
            a_minus: 1.0,
            ..default_cfg()
        };
        assert_eq!(update_weight_stdp(5, -1, &cfg), 0);
    }

    #[test]
    fn test_batch_computation() {
        let cfg = default_cfg();
        let diffs = [5, -5, 0, 10];
        let mut out = [0.0; 4];
        compute_stdp_batch(&diffs, &cfg, &mut out);

        let [pos_near, neg_near, simultaneous, pos_far] = out;
        assert!(pos_near > 0.0);
        assert!(neg_near < 0.0);
        assert_eq!(simultaneous, cfg.a_plus);
        assert!(pos_far > 0.0);
        assert!(pos_near > pos_far);
    }
}