// synapse_models/plasticity.rs

1//! Synaptic plasticity rules for learning and memory.
2//!
3//! This module implements various plasticity mechanisms including:
4//! - Spike-Timing Dependent Plasticity (STDP)
5//! - BCM (Bienenstock-Cooper-Munro) rule
6//! - Oja's rule
7//! - Hebbian and Anti-Hebbian learning
8//! - Homeostatic plasticity
9//! - Meta-plasticity
10
11use crate::error::{Result, SynapseError};
12
13/// Spike-Timing Dependent Plasticity (STDP) implementation.
14///
15/// STDP is a biological learning rule where synaptic strength changes depend
16/// on the relative timing of pre- and postsynaptic spikes.
17///
18/// Δw = A+ * exp(-Δt/τ+) for Δt > 0 (pre before post, potentiation)
19/// Δw = -A- * exp(Δt/τ-) for Δt < 0 (post before pre, depression)
20#[derive(Debug, Clone)]
21pub struct STDP {
22    /// Amplitude of potentiation.
23    pub a_plus: f64,
24
25    /// Amplitude of depression.
26    pub a_minus: f64,
27
28    /// Time constant for potentiation (ms).
29    pub tau_plus: f64,
30
31    /// Time constant for depression (ms).
32    pub tau_minus: f64,
33
34    /// Minimum synaptic weight.
35    pub w_min: f64,
36
37    /// Maximum synaptic weight.
38    pub w_max: f64,
39
40    /// Whether to use multiplicative updates (vs additive).
41    pub multiplicative: bool,
42
43    /// Last presynaptic spike time (ms).
44    last_pre_spike: Option<f64>,
45
46    /// Last postsynaptic spike time (ms).
47    last_post_spike: Option<f64>,
48
49    /// Accumulated weight change.
50    pub accumulated_dw: f64,
51}
52
53impl Default for STDP {
54    fn default() -> Self {
55        Self {
56            a_plus: 0.01,
57            a_minus: 0.01,
58            tau_plus: 20.0,
59            tau_minus: 20.0,
60            w_min: 0.0,
61            w_max: 1.0,
62            multiplicative: false,
63            last_pre_spike: None,
64            last_post_spike: None,
65            accumulated_dw: 0.0,
66        }
67    }
68}
69
70impl STDP {
71    /// Create new STDP with default parameters.
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    /// Create STDP with custom parameters.
77    pub fn with_params(a_plus: f64, a_minus: f64, tau_plus: f64, tau_minus: f64) -> Result<Self> {
78        if tau_plus <= 0.0 || tau_minus <= 0.0 {
79            return Err(SynapseError::InvalidTimeConstant(tau_plus.min(tau_minus)));
80        }
81
82        Ok(Self {
83            a_plus,
84            a_minus,
85            tau_plus,
86            tau_minus,
87            ..Self::default()
88        })
89    }
90
91    /// Create STDP with multiplicative updates (weight-dependent).
92    pub fn multiplicative(mut self) -> Self {
93        self.multiplicative = true;
94        self
95    }
96
97    /// Register presynaptic spike and calculate weight change.
98    ///
99    /// # Arguments
100    /// * `time` - Current time (ms)
101    /// * `current_weight` - Current synaptic weight
102    ///
103    /// # Returns
104    /// Weight change (Δw)
105    pub fn pre_spike(&mut self, time: f64, current_weight: f64) -> f64 {
106        let mut dw = 0.0;
107
108        // If there was a recent postsynaptic spike, apply depression
109        if let Some(post_time) = self.last_post_spike {
110            let dt = time - post_time;
111            if dt > 0.0 && dt < 5.0 * self.tau_minus {
112                dw = -self.a_minus * (-dt / self.tau_minus).exp();
113
114                // Multiplicative depression: Δw ∝ w
115                if self.multiplicative {
116                    dw *= current_weight;
117                }
118            }
119        }
120
121        self.last_pre_spike = Some(time);
122        self.accumulated_dw += dw;
123        dw
124    }
125
126    /// Register postsynaptic spike and calculate weight change.
127    ///
128    /// # Arguments
129    /// * `time` - Current time (ms)
130    /// * `current_weight` - Current synaptic weight
131    ///
132    /// # Returns
133    /// Weight change (Δw)
134    pub fn post_spike(&mut self, time: f64, current_weight: f64) -> f64 {
135        let mut dw = 0.0;
136
137        // If there was a recent presynaptic spike, apply potentiation
138        if let Some(pre_time) = self.last_pre_spike {
139            let dt = time - pre_time;
140            if dt > 0.0 && dt < 5.0 * self.tau_plus {
141                dw = self.a_plus * (-dt / self.tau_plus).exp();
142
143                // Multiplicative potentiation: Δw ∝ (w_max - w)
144                if self.multiplicative {
145                    dw *= self.w_max - current_weight;
146                }
147            }
148        }
149
150        self.last_post_spike = Some(time);
151        self.accumulated_dw += dw;
152        dw
153    }
154
155    /// Apply accumulated weight change to synaptic weight.
156    ///
157    /// # Arguments
158    /// * `weight` - Current synaptic weight
159    ///
160    /// # Returns
161    /// New synaptic weight
162    pub fn apply_update(&mut self, weight: f64) -> f64 {
163        let new_weight = (weight + self.accumulated_dw).clamp(self.w_min, self.w_max);
164        self.accumulated_dw = 0.0;
165        new_weight
166    }
167
168    /// Calculate STDP window function for a given time difference.
169    ///
170    /// # Arguments
171    /// * `dt` - Time difference (post - pre) in ms
172    pub fn window(&self, dt: f64) -> f64 {
173        if dt > 0.0 {
174            self.a_plus * (-dt / self.tau_plus).exp()
175        } else {
176            -self.a_minus * (dt / self.tau_minus).exp()
177        }
178    }
179
180    /// Reset STDP state.
181    pub fn reset(&mut self) {
182        self.last_pre_spike = None;
183        self.last_post_spike = None;
184        self.accumulated_dw = 0.0;
185    }
186}
187
/// BCM (Bienenstock-Cooper-Munro) plasticity rule.
///
/// BCM theory proposes that synaptic modification depends on postsynaptic
/// activity relative to a sliding threshold:
///
/// Δw = η * x * (y - θ) * y
/// where x = presynaptic activity, y = postsynaptic activity, θ = threshold
///
/// The threshold slides with the square of the average postsynaptic activity,
/// which stabilizes learning.
#[derive(Debug, Clone)]
pub struct BCM {
    /// Learning rate (η).
    pub learning_rate: f64,

    /// Modification threshold (θ); slides during `update`.
    pub threshold: f64,

    /// Time constant for threshold adaptation (ms).
    pub tau_threshold: f64,

    /// Running average of postsynaptic activity (drives the sliding threshold).
    avg_post_activity: f64,

    /// Minimum weight.
    pub w_min: f64,

    /// Maximum weight.
    pub w_max: f64,
}

impl Default for BCM {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            threshold: 0.5,
            tau_threshold: 10000.0, // Slow adaptation
            avg_post_activity: 0.0,
            w_min: 0.0,
            w_max: 1.0,
        }
    }
}

impl BCM {
    /// Create new BCM with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Update synaptic weight using BCM rule.
    ///
    /// The weight change uses the *current* threshold; the threshold is then
    /// slid toward the square of the average postsynaptic activity.
    ///
    /// # Arguments
    /// * `pre_activity` - Presynaptic activity (firing rate or activation)
    /// * `post_activity` - Postsynaptic activity
    /// * `current_weight` - Current synaptic weight
    /// * `dt` - Time step (ms)
    ///
    /// # Returns
    /// New synaptic weight, clamped to [w_min, w_max].
    pub fn update(&mut self, pre_activity: f64, post_activity: f64, current_weight: f64, dt: f64) -> f64 {
        // BCM rule: Δw = η * x * (y - θ) * y
        let dw = self.learning_rate * pre_activity * (post_activity - self.threshold) * post_activity * dt;

        // Slide the threshold: θ = <y>², where <y> is an exponential moving
        // average of postsynaptic activity with time constant tau_threshold.
        self.avg_post_activity += (post_activity - self.avg_post_activity) / self.tau_threshold * dt;
        self.threshold = self.avg_post_activity * self.avg_post_activity;

        (current_weight + dw).clamp(self.w_min, self.w_max)
    }

    /// Reset BCM state to the default threshold and activity average.
    pub fn reset(&mut self) {
        // Derive the reset values from `Default` instead of repeating the
        // literals, so reset cannot drift out of sync with the defaults.
        let defaults = Self::default();
        self.threshold = defaults.threshold;
        self.avg_post_activity = defaults.avg_post_activity;
    }
}
259
/// Oja's rule for normalized Hebbian learning.
///
/// Keeps weights bounded by subtracting a decay term proportional to the
/// squared postsynaptic activity:
///
/// Δw = η * (y * x - y² * w)
#[derive(Debug, Clone)]
pub struct OjasRule {
    /// Learning rate.
    pub learning_rate: f64,

    /// Minimum weight.
    pub w_min: f64,

    /// Maximum weight.
    pub w_max: f64,
}

impl Default for OjasRule {
    fn default() -> Self {
        OjasRule {
            w_min: 0.0,
            w_max: 1.0,
            learning_rate: 0.001,
        }
    }
}

impl OjasRule {
    /// Create new Oja's rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Update synaptic weight using Oja's rule.
    ///
    /// # Arguments
    /// * `pre_activity` - Presynaptic activity
    /// * `post_activity` - Postsynaptic activity
    /// * `current_weight` - Current synaptic weight
    /// * `dt` - Time step (ms)
    ///
    /// # Returns
    /// New synaptic weight, clamped to [w_min, w_max].
    pub fn update(&mut self, pre_activity: f64, post_activity: f64, current_weight: f64, dt: f64) -> f64 {
        // Hebbian growth term: y * x.
        let hebbian = post_activity * pre_activity;
        // Normalizing decay term: y² * w.
        let decay = post_activity * post_activity * current_weight;
        let dw = self.learning_rate * (hebbian - decay) * dt;
        (current_weight + dw).clamp(self.w_min, self.w_max)
    }
}
307
/// Hebbian learning rule.
///
/// Classic Hebbian learning: "Cells that fire together, wire together."
///
/// Δw = η * x * y
#[derive(Debug, Clone)]
pub struct HebbianRule {
    /// Learning rate.
    pub learning_rate: f64,

    /// Whether to normalize weights (adds an Oja-style decay term).
    pub normalize: bool,

    /// Minimum weight.
    pub w_min: f64,

    /// Maximum weight.
    pub w_max: f64,
}

impl Default for HebbianRule {
    fn default() -> Self {
        HebbianRule {
            normalize: false,
            learning_rate: 0.001,
            w_max: 1.0,
            w_min: 0.0,
        }
    }
}

impl HebbianRule {
    /// Create new Hebbian rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style switch to the normalized (weight-decaying) variant.
    pub fn normalized(mut self) -> Self {
        self.normalize = true;
        self
    }

    /// Update synaptic weight using Hebbian rule.
    ///
    /// # Arguments
    /// * `pre_activity` - Presynaptic activity
    /// * `post_activity` - Postsynaptic activity
    /// * `current_weight` - Current synaptic weight
    /// * `dt` - Time step (ms)
    ///
    /// # Returns
    /// New synaptic weight, clamped to [w_min, w_max].
    pub fn update(&mut self, pre_activity: f64, post_activity: f64, current_weight: f64, dt: f64) -> f64 {
        // Correlation term x * y shared by both variants.
        let correlation = pre_activity * post_activity;
        // Optional Oja-style decay (w * y²) keeps weights from growing
        // without bound; zero in the plain Hebbian case.
        let decay = if self.normalize {
            current_weight * post_activity.powi(2)
        } else {
            0.0
        };
        let dw = self.learning_rate * (correlation - decay) * dt;
        (current_weight + dw).clamp(self.w_min, self.w_max)
    }
}
370
/// Anti-Hebbian learning rule.
///
/// The mirror image of Hebbian learning: connections between co-active
/// neurons are weakened.
///
/// Δw = -η * x * y
#[derive(Debug, Clone)]
pub struct AntiHebbianRule {
    /// Learning rate.
    pub learning_rate: f64,

    /// Minimum weight.
    pub w_min: f64,

    /// Maximum weight.
    pub w_max: f64,
}

impl Default for AntiHebbianRule {
    fn default() -> Self {
        AntiHebbianRule {
            w_min: 0.0,
            w_max: 1.0,
            learning_rate: 0.001,
        }
    }
}

impl AntiHebbianRule {
    /// Create new Anti-Hebbian rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Update synaptic weight using the Anti-Hebbian rule.
    ///
    /// Co-activity (x * y > 0) decreases the weight; the result is clamped
    /// to [w_min, w_max].
    pub fn update(&mut self, pre_activity: f64, post_activity: f64, current_weight: f64, dt: f64) -> f64 {
        let depression = self.learning_rate * pre_activity * post_activity * dt;
        (current_weight - depression).clamp(self.w_min, self.w_max)
    }
}
410
/// Homeostatic plasticity for maintaining stable activity levels.
///
/// Tracks the neuron's average firing rate and derives a multiplicative
/// scaling factor that nudges synaptic weights toward a target rate.
#[derive(Debug, Clone)]
pub struct HomeostaticPlasticity {
    /// Target firing rate (Hz).
    pub target_rate: f64,

    /// Time constant for homeostatic adjustment (ms).
    pub tau_homeostatic: f64,

    /// Running average of the firing rate (Hz).
    avg_rate: f64,

    /// Multiplicative scaling factor applied to weights.
    pub scaling_factor: f64,
}

impl Default for HomeostaticPlasticity {
    fn default() -> Self {
        HomeostaticPlasticity {
            target_rate: 5.0,           // 5 Hz target
            tau_homeostatic: 1000000.0, // Very slow (hours)
            avg_rate: 5.0,              // start at target -> no initial drive
            scaling_factor: 1.0,
        }
    }
}

impl HomeostaticPlasticity {
    /// Create new homeostatic plasticity with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Update homeostatic scaling based on current activity.
    ///
    /// # Arguments
    /// * `current_rate` - Current firing rate (Hz)
    /// * `dt` - Time step (ms)
    pub fn update(&mut self, current_rate: f64, dt: f64) {
        // Exponential moving average of the observed rate.
        self.avg_rate += (current_rate - self.avg_rate) / self.tau_homeostatic * dt;

        // Drive the scaling factor by the relative rate error: too much
        // activity shrinks weights, too little grows them.
        let rate_error = self.target_rate - self.avg_rate;
        self.scaling_factor += rate_error / self.target_rate / self.tau_homeostatic * dt;
        // Keep scaling inside a sane band.
        self.scaling_factor = self.scaling_factor.max(0.1).min(10.0);
    }

    /// Apply homeostatic scaling to a synaptic weight.
    pub fn apply_scaling(&self, weight: f64) -> f64 {
        self.scaling_factor * weight
    }

    /// Reset homeostatic state to the target rate and unit scaling.
    pub fn reset(&mut self) {
        self.scaling_factor = 1.0;
        self.avg_rate = self.target_rate;
    }
}
473
/// Meta-plasticity: plasticity of plasticity.
///
/// The effective learning rate adapts to the recent level of synaptic
/// activity: persistently high activity halves it, low activity doubles it.
#[derive(Debug, Clone)]
pub struct MetaPlasticity {
    /// Base learning rate.
    pub base_learning_rate: f64,

    /// Current learning rate (modulated).
    pub learning_rate: f64,

    /// Time constant for meta-plasticity (ms).
    pub tau_meta: f64,

    /// Running average of synaptic activity.
    avg_activity: f64,

    /// Activity level separating the "high" and "low" regimes.
    pub activity_threshold: f64,
}

impl Default for MetaPlasticity {
    fn default() -> Self {
        MetaPlasticity {
            base_learning_rate: 0.01,
            learning_rate: 0.01,
            tau_meta: 100000.0, // Slow time scale
            avg_activity: 0.0,
            activity_threshold: 0.5,
        }
    }
}

impl MetaPlasticity {
    /// Create new meta-plasticity with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Update meta-plasticity based on synaptic activity.
    ///
    /// # Arguments
    /// * `activity` - Current synaptic activity level
    /// * `dt` - Time step (ms)
    pub fn update(&mut self, activity: f64, dt: f64) {
        // Exponential moving average of activity.
        self.avg_activity += (activity - self.avg_activity) / self.tau_meta * dt;

        // Homeostatic modulation of the learning rate: saturated (high)
        // average activity damps further learning, quiet periods enhance it.
        let gain = if self.avg_activity > self.activity_threshold { 0.5 } else { 2.0 };
        self.learning_rate = self.base_learning_rate * gain;
    }

    /// Get current learning rate.
    pub fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Reset meta-plasticity state.
    pub fn reset(&mut self) {
        self.avg_activity = 0.0;
        self.learning_rate = self.base_learning_rate;
    }
}
545
#[cfg(test)]
mod tests {
    use super::*;

    // Default STDP exposes the documented textbook amplitudes.
    #[test]
    fn test_stdp_creation() {
        let stdp = STDP::new();
        assert_eq!(stdp.a_plus, 0.01);
        assert_eq!(stdp.a_minus, 0.01);
    }

    // Causal pairing (pre before post) must produce a positive Δw.
    #[test]
    fn test_stdp_potentiation() {
        let mut stdp = STDP::new();
        let weight = 0.5;

        // Pre spike at t=0
        stdp.pre_spike(0.0, weight);

        // Post spike at t=10 (pre before post -> potentiation)
        let dw = stdp.post_spike(10.0, weight);

        assert!(dw > 0.0); // Should potentiate
    }

    // Anti-causal pairing (post before pre) must produce a negative Δw.
    #[test]
    fn test_stdp_depression() {
        let mut stdp = STDP::new();
        let weight = 0.5;

        // Post spike at t=0
        stdp.post_spike(0.0, weight);

        // Pre spike at t=10 (post before pre -> depression)
        let dw = stdp.pre_spike(10.0, weight);

        assert!(dw < 0.0); // Should depress
    }

    // The analytical window function has the expected signs on each side.
    #[test]
    fn test_stdp_window() {
        let stdp = STDP::new();

        let pot = stdp.window(10.0);  // Potentiation
        let dep = stdp.window(-10.0); // Depression

        assert!(pot > 0.0);
        assert!(dep < 0.0);
    }

    // BCM: sign of the weight change flips around the threshold (0.5).
    #[test]
    fn test_bcm_rule() {
        let mut bcm = BCM::new();
        let weight = 0.5;

        // Low postsynaptic activity -> depression
        let w1 = bcm.update(1.0, 0.1, weight, 1.0);
        assert!(w1 < weight);

        // High postsynaptic activity -> potentiation
        let w2 = bcm.update(1.0, 0.9, weight, 1.0);
        assert!(w2 > weight);
    }

    // Oja's rule must keep the weight inside the [w_min, w_max] bounds.
    #[test]
    fn test_ojas_rule() {
        let mut oja = OjasRule::new();
        let weight = 0.5;

        let new_weight = oja.update(1.0, 1.0, weight, 1.0);
        assert!(new_weight >= 0.0 && new_weight <= 1.0);
    }

    // Hebbian: coincident activity strengthens the synapse.
    #[test]
    fn test_hebbian_rule() {
        let mut hebb = HebbianRule::new();
        let weight = 0.5;

        // Both active -> strengthen
        let new_weight = hebb.update(1.0, 1.0, weight, 1.0);
        assert!(new_weight > weight);
    }

    // Anti-Hebbian: coincident activity weakens the synapse.
    #[test]
    fn test_anti_hebbian_rule() {
        let mut anti = AntiHebbianRule::new();
        let weight = 0.5;

        // Both active -> weaken
        let new_weight = anti.update(1.0, 1.0, weight, 1.0);
        assert!(new_weight < weight);
    }

    // Scaling factor moves opposite to the rate error relative to the
    // 5 Hz target (tau is huge, so repeated updates are needed to see it).
    #[test]
    fn test_homeostatic_plasticity() {
        let mut homeo = HomeostaticPlasticity::new();

        // High activity should reduce scaling
        for _ in 0..100 {
            homeo.update(10.0, 100.0); // 10 Hz, higher than target
        }
        assert!(homeo.scaling_factor < 1.0);

        homeo.reset();

        // Low activity should increase scaling
        for _ in 0..100 {
            homeo.update(1.0, 100.0); // 1 Hz, lower than target
        }
        assert!(homeo.scaling_factor > 1.0);
    }

    // Learning rate is halved above the activity threshold (0.5) and
    // doubled below it; the slow average needs many steps to cross.
    #[test]
    fn test_meta_plasticity() {
        let mut meta = MetaPlasticity::new();

        // High activity should reduce learning rate (needs longer to accumulate)
        for _ in 0..1000 {
            meta.update(0.8, 100.0);
        }
        // After high activity, learning rate should be modulated down
        assert!(meta.avg_activity > meta.activity_threshold);
        assert!(meta.learning_rate < meta.base_learning_rate);

        meta.reset();

        // Low activity should increase learning rate
        for _ in 0..1000 {
            meta.update(0.2, 100.0);
        }
        assert!(meta.avg_activity < meta.activity_threshold);
        assert!(meta.learning_rate > meta.base_learning_rate);
    }
}