use crate::error::{Result, SynapseError};
/// Pair-based spike-timing-dependent plasticity (STDP) rule.
///
/// Stores the STDP kernel parameters plus the most recent pre-/post-synaptic
/// spike times and an accumulator of weight changes not yet applied.
#[derive(Debug, Clone)]
pub struct STDP {
    /// Amplitude of potentiation (pre-before-post pairings).
    pub a_plus: f64,
    /// Amplitude of depression (post-before-pre pairings).
    pub a_minus: f64,
    /// Time constant of the potentiation window.
    pub tau_plus: f64,
    /// Time constant of the depression window.
    pub tau_minus: f64,
    /// Lower bound used when `apply_update` clamps the weight.
    pub w_min: f64,
    /// Upper bound used when `apply_update` clamps the weight.
    pub w_max: f64,
    /// When true, updates scale with the current weight (weight-dependent
    /// "multiplicative" STDP); otherwise updates are additive.
    pub multiplicative: bool,
    // Time of the most recent pre-synaptic spike, if any has occurred.
    last_pre_spike: Option<f64>,
    // Time of the most recent post-synaptic spike, if any has occurred.
    last_post_spike: Option<f64>,
    /// Sum of weight changes produced by spike events but not yet folded
    /// into a weight via `apply_update`.
    pub accumulated_dw: f64,
}
impl Default for STDP {
    /// Symmetric 0.01 amplitudes, 20-unit time constants, weights bounded
    /// to [0, 1], additive updates, and no spike history.
    fn default() -> Self {
        Self {
            // Kernel shape.
            a_plus: 0.01,
            a_minus: 0.01,
            tau_plus: 20.0,
            tau_minus: 20.0,
            // Weight bounds and update mode.
            w_min: 0.0,
            w_max: 1.0,
            multiplicative: false,
            // Mutable state starts empty.
            last_pre_spike: None,
            last_post_spike: None,
            accumulated_dw: 0.0,
        }
    }
}
impl STDP {
    /// Creates an STDP rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds an STDP rule from explicit amplitudes and time constants,
    /// keeping defaults for everything else.
    ///
    /// # Errors
    /// Returns `SynapseError::InvalidTimeConstant` when either time
    /// constant is zero or negative (the smaller of the two is reported).
    pub fn with_params(a_plus: f64, a_minus: f64, tau_plus: f64, tau_minus: f64) -> Result<Self> {
        let smallest_tau = tau_plus.min(tau_minus);
        if smallest_tau <= 0.0 {
            return Err(SynapseError::InvalidTimeConstant(smallest_tau));
        }
        Ok(Self {
            a_plus,
            a_minus,
            tau_plus,
            tau_minus,
            ..Self::default()
        })
    }

    /// Builder: switch to multiplicative (weight-dependent) updates.
    pub fn multiplicative(mut self) -> Self {
        self.multiplicative = true;
        self
    }

    /// Registers a pre-synaptic spike at `time`; returns the resulting
    /// weight change (depression when it closely follows a post-spike).
    pub fn pre_spike(&mut self, time: f64, current_weight: f64) -> f64 {
        let dw = match self.last_post_spike {
            Some(post_time) => {
                let dt = time - post_time;
                // Exponential depression window, cut off at five time
                // constants to ignore negligibly old pairings.
                if dt > 0.0 && dt < 5.0 * self.tau_minus {
                    let base = -self.a_minus * (-dt / self.tau_minus).exp();
                    if self.multiplicative { base * current_weight } else { base }
                } else {
                    0.0
                }
            }
            None => 0.0,
        };
        self.last_pre_spike = Some(time);
        self.accumulated_dw += dw;
        dw
    }

    /// Registers a post-synaptic spike at `time`; returns the resulting
    /// weight change (potentiation when it closely follows a pre-spike).
    pub fn post_spike(&mut self, time: f64, current_weight: f64) -> f64 {
        let dw = match self.last_pre_spike {
            Some(pre_time) => {
                let dt = time - pre_time;
                // Exponential potentiation window with the same cutoff.
                if dt > 0.0 && dt < 5.0 * self.tau_plus {
                    let base = self.a_plus * (-dt / self.tau_plus).exp();
                    // Multiplicative mode scales by the headroom to w_max
                    // (soft bound at the ceiling).
                    if self.multiplicative { base * (self.w_max - current_weight) } else { base }
                } else {
                    0.0
                }
            }
            None => 0.0,
        };
        self.last_post_spike = Some(time);
        self.accumulated_dw += dw;
        dw
    }

    /// Folds the accumulated change into `weight`, clamps the result to
    /// `[w_min, w_max]`, clears the accumulator, and returns the new weight.
    pub fn apply_update(&mut self, weight: f64) -> f64 {
        let updated = (weight + self.accumulated_dw).clamp(self.w_min, self.w_max);
        self.accumulated_dw = 0.0;
        updated
    }

    /// The raw STDP kernel: potentiation for positive `dt` (pre before
    /// post), depression for `dt <= 0`.
    pub fn window(&self, dt: f64) -> f64 {
        match dt > 0.0 {
            true => self.a_plus * (-dt / self.tau_plus).exp(),
            false => -self.a_minus * (dt / self.tau_minus).exp(),
        }
    }

    /// Clears spike history and any pending weight change.
    pub fn reset(&mut self) {
        self.accumulated_dw = 0.0;
        self.last_pre_spike = None;
        self.last_post_spike = None;
    }
}
/// Bienenstock–Cooper–Munro (BCM) rate-based plasticity rule with a
/// sliding modification threshold.
#[derive(Debug, Clone)]
pub struct BCM {
    /// Scale factor applied to every weight change.
    pub learning_rate: f64,
    /// Sliding modification threshold; recomputed each step as the square
    /// of the running average post-synaptic activity.
    pub threshold: f64,
    /// Time constant of the running activity average.
    pub tau_threshold: f64,
    // Running average of post-synaptic activity (internal state).
    avg_post_activity: f64,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
}

impl Default for BCM {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            threshold: 0.5,
            tau_threshold: 10000.0,
            avg_post_activity: 0.0,
            w_min: 0.0,
            w_max: 1.0,
        }
    }
}

impl BCM {
    /// Creates a BCM rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Performs one BCM step and returns the new, clamped weight.
    ///
    /// The weight change uses the threshold value from *before* this call;
    /// the threshold then slides toward the square of the average activity.
    pub fn update(&mut self, pre_activity: f64, post_activity: f64, current_weight: f64, dt: f64) -> f64 {
        // dphi < 0 below threshold (depression), > 0 above (potentiation).
        let dw = self.learning_rate * pre_activity * (post_activity - self.threshold) * post_activity * dt;
        // Low-pass filter the post-synaptic activity, then slide the
        // threshold quadratically with it.
        self.avg_post_activity += (post_activity - self.avg_post_activity) / self.tau_threshold * dt;
        self.threshold = self.avg_post_activity * self.avg_post_activity;
        (current_weight + dw).clamp(self.w_min, self.w_max)
    }

    /// Restores the default threshold and clears the activity average.
    pub fn reset(&mut self) {
        self.avg_post_activity = 0.0;
        self.threshold = 0.5;
    }
}
/// Oja's rule: Hebbian growth with a quadratic decay term that keeps the
/// weight bounded without an explicit normalization pass.
#[derive(Debug, Clone)]
pub struct OjasRule {
    /// Scale factor applied to every weight change.
    pub learning_rate: f64,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
}

impl Default for OjasRule {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            w_min: 0.0,
            w_max: 1.0,
        }
    }
}

impl OjasRule {
    /// Creates an Oja rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// One Oja step; returns the clamped new weight.
    pub fn update(&mut self, pre_activity: f64, post_activity: f64, current_weight: f64, dt: f64) -> f64 {
        // Hebbian term minus post^2-weighted decay.
        let dw = self.learning_rate * (post_activity * pre_activity - post_activity * post_activity * current_weight) * dt;
        (current_weight + dw).clamp(self.w_min, self.w_max)
    }
}
/// Classic Hebbian rule, optionally with an Oja-style normalization term.
#[derive(Debug, Clone)]
pub struct HebbianRule {
    /// Scale factor applied to every weight change.
    pub learning_rate: f64,
    /// When true, subtracts a weight-decay term (Oja normalization).
    pub normalize: bool,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
}

impl Default for HebbianRule {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            normalize: false,
            w_min: 0.0,
            w_max: 1.0,
        }
    }
}

impl HebbianRule {
    /// Creates a plain (unnormalized) Hebbian rule.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: switch on the normalization (decay) term.
    pub fn normalized(self) -> Self {
        Self { normalize: true, ..self }
    }

    /// One Hebbian step; returns the clamped new weight.
    pub fn update(&mut self, pre_activity: f64, post_activity: f64, current_weight: f64, dt: f64) -> f64 {
        let dw = if self.normalize {
            // Oja-style: correlation term minus post^2-weighted decay.
            self.learning_rate * (pre_activity * post_activity - current_weight * post_activity.powi(2)) * dt
        } else {
            // Pure correlation of pre- and post-synaptic activity.
            self.learning_rate * pre_activity * post_activity * dt
        };
        (current_weight + dw).clamp(self.w_min, self.w_max)
    }
}
/// Anti-Hebbian rule: correlated activity *weakens* the synapse.
#[derive(Debug, Clone)]
pub struct AntiHebbianRule {
    /// Scale factor applied to every weight change.
    pub learning_rate: f64,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
}

impl Default for AntiHebbianRule {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            w_min: 0.0,
            w_max: 1.0,
        }
    }
}

impl AntiHebbianRule {
    /// Creates an anti-Hebbian rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// One anti-Hebbian step; returns the clamped new weight.
    pub fn update(&mut self, pre_activity: f64, post_activity: f64, current_weight: f64, dt: f64) -> f64 {
        // Negated correlation term: coincident activity depresses.
        let dw = -self.learning_rate * pre_activity * post_activity * dt;
        (current_weight + dw).clamp(self.w_min, self.w_max)
    }
}
/// Homeostatic synaptic scaling: slowly adjusts a multiplicative scaling
/// factor so the neuron's average firing rate tracks `target_rate`.
#[derive(Debug, Clone)]
pub struct HomeostaticPlasticity {
    /// Desired average firing rate.
    pub target_rate: f64,
    /// Time constant of both the rate average and the scaling adjustment;
    /// large by design so scaling is much slower than synaptic learning.
    pub tau_homeostatic: f64,
    // Running average of the observed firing rate (internal state).
    avg_rate: f64,
    /// Multiplicative factor applied to weights, kept within [0.1, 10.0].
    pub scaling_factor: f64,
}

impl Default for HomeostaticPlasticity {
    fn default() -> Self {
        Self {
            target_rate: 5.0,
            tau_homeostatic: 1000000.0,
            // Start the average at the target so there is no initial error.
            avg_rate: 5.0,
            scaling_factor: 1.0,
        }
    }
}

impl HomeostaticPlasticity {
    /// Creates a homeostatic controller with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Integrates the observed rate over `dt` and nudges the scaling
    /// factor toward restoring `target_rate`.
    pub fn update(&mut self, current_rate: f64, dt: f64) {
        // Low-pass filter the observed rate.
        self.avg_rate += (current_rate - self.avg_rate) / self.tau_homeostatic * dt;
        // Positive error (rate too low) grows the factor; negative shrinks it.
        let rate_error = self.target_rate - self.avg_rate;
        self.scaling_factor += rate_error / self.target_rate / self.tau_homeostatic * dt;
        // Idiomatic clamp (was `.max(0.1).min(10.0)`), consistent with the
        // weight clamping used by the other rules in this module.
        self.scaling_factor = self.scaling_factor.clamp(0.1, 10.0);
    }

    /// Returns `weight` multiplied by the current scaling factor.
    pub fn apply_scaling(&self, weight: f64) -> f64 {
        weight * self.scaling_factor
    }

    /// Resets the average to the target rate and the factor to unity.
    pub fn reset(&mut self) {
        self.avg_rate = self.target_rate;
        self.scaling_factor = 1.0;
    }
}
/// Metaplasticity: modulates the effective learning rate of another rule
/// based on the long-run activity level — halved when the running average
/// is above threshold, doubled when below.
#[derive(Debug, Clone)]
pub struct MetaPlasticity {
    /// Unmodulated learning rate.
    pub base_learning_rate: f64,
    /// Currently effective (modulated) learning rate.
    pub learning_rate: f64,
    /// Time constant of the activity average.
    pub tau_meta: f64,
    // Running average of activity (internal state).
    avg_activity: f64,
    /// Activity level separating the low- and high-activity regimes.
    pub activity_threshold: f64,
}

impl Default for MetaPlasticity {
    fn default() -> Self {
        Self {
            base_learning_rate: 0.01,
            learning_rate: 0.01,
            tau_meta: 100000.0,
            avg_activity: 0.0,
            activity_threshold: 0.5,
        }
    }
}

impl MetaPlasticity {
    /// Creates a metaplasticity modulator with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Integrates `activity` over `dt` and recomputes the effective
    /// learning rate from the resulting average.
    pub fn update(&mut self, activity: f64, dt: f64) {
        // Low-pass filter the activity signal.
        self.avg_activity += (activity - self.avg_activity) / self.tau_meta * dt;
        // High activity suppresses learning; low activity boosts it.
        let factor = if self.avg_activity > self.activity_threshold { 0.5 } else { 2.0 };
        self.learning_rate = self.base_learning_rate * factor;
    }

    /// The current effective learning rate.
    pub fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Restores the base learning rate and clears the activity average.
    pub fn reset(&mut self) {
        self.avg_activity = 0.0;
        self.learning_rate = self.base_learning_rate;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stdp_creation() {
        let rule = STDP::new();
        assert_eq!(rule.a_plus, 0.01);
        assert_eq!(rule.a_minus, 0.01);
    }

    #[test]
    fn test_stdp_potentiation() {
        // Pre before post within the window -> positive weight change.
        let mut rule = STDP::new();
        let weight = 0.5;
        rule.pre_spike(0.0, weight);
        assert!(rule.post_spike(10.0, weight) > 0.0);
    }

    #[test]
    fn test_stdp_depression() {
        // Post before pre within the window -> negative weight change.
        let mut rule = STDP::new();
        let weight = 0.5;
        rule.post_spike(0.0, weight);
        assert!(rule.pre_spike(10.0, weight) < 0.0);
    }

    #[test]
    fn test_stdp_window() {
        // Kernel sign flips with the sign of dt.
        let rule = STDP::new();
        assert!(rule.window(10.0) > 0.0);
        assert!(rule.window(-10.0) < 0.0);
    }

    #[test]
    fn test_bcm_rule() {
        let mut rule = BCM::new();
        let weight = 0.5;
        // Below-threshold post activity depresses...
        assert!(rule.update(1.0, 0.1, weight, 1.0) < weight);
        // ...and above-threshold activity potentiates.
        assert!(rule.update(1.0, 0.9, weight, 1.0) > weight);
    }

    #[test]
    fn test_ojas_rule() {
        // Result stays inside the [w_min, w_max] bounds.
        let mut rule = OjasRule::new();
        let updated = rule.update(1.0, 1.0, 0.5, 1.0);
        assert!(updated >= 0.0 && updated <= 1.0);
    }

    #[test]
    fn test_hebbian_rule() {
        // Correlated activity strengthens the synapse.
        let mut rule = HebbianRule::new();
        let weight = 0.5;
        assert!(rule.update(1.0, 1.0, weight, 1.0) > weight);
    }

    #[test]
    fn test_anti_hebbian_rule() {
        // Correlated activity weakens the synapse.
        let mut rule = AntiHebbianRule::new();
        let weight = 0.5;
        assert!(rule.update(1.0, 1.0, weight, 1.0) < weight);
    }

    #[test]
    fn test_homeostatic_plasticity() {
        let mut controller = HomeostaticPlasticity::new();
        // Sustained above-target rate shrinks the scaling factor.
        for _ in 0..100 {
            controller.update(10.0, 100.0);
        }
        assert!(controller.scaling_factor < 1.0);
        controller.reset();
        // Sustained below-target rate grows it.
        for _ in 0..100 {
            controller.update(1.0, 100.0);
        }
        assert!(controller.scaling_factor > 1.0);
    }

    #[test]
    fn test_meta_plasticity() {
        let mut modulator = MetaPlasticity::new();
        // High sustained activity -> average above threshold, rate halved.
        for _ in 0..1000 {
            modulator.update(0.8, 100.0);
        }
        assert!(modulator.avg_activity > modulator.activity_threshold);
        assert!(modulator.learning_rate < modulator.base_learning_rate);
        modulator.reset();
        // Low sustained activity -> average below threshold, rate doubled.
        for _ in 0..1000 {
            modulator.update(0.2, 100.0);
        }
        assert!(modulator.avg_activity < modulator.activity_threshold);
        assert!(modulator.learning_rate > modulator.base_learning_rate);
    }
}