1use crate::error::{Result, SynapseError};
12
13#[derive(Debug, Clone)]
21pub struct STDP {
22 pub a_plus: f64,
24
25 pub a_minus: f64,
27
28 pub tau_plus: f64,
30
31 pub tau_minus: f64,
33
34 pub w_min: f64,
36
37 pub w_max: f64,
39
40 pub multiplicative: bool,
42
43 last_pre_spike: Option<f64>,
45
46 last_post_spike: Option<f64>,
48
49 pub accumulated_dw: f64,
51}
52
53impl Default for STDP {
54 fn default() -> Self {
55 Self {
56 a_plus: 0.01,
57 a_minus: 0.01,
58 tau_plus: 20.0,
59 tau_minus: 20.0,
60 w_min: 0.0,
61 w_max: 1.0,
62 multiplicative: false,
63 last_pre_spike: None,
64 last_post_spike: None,
65 accumulated_dw: 0.0,
66 }
67 }
68}
69
70impl STDP {
71 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn with_params(a_plus: f64, a_minus: f64, tau_plus: f64, tau_minus: f64) -> Result<Self> {
78 if tau_plus <= 0.0 || tau_minus <= 0.0 {
79 return Err(SynapseError::InvalidTimeConstant(tau_plus.min(tau_minus)));
80 }
81
82 Ok(Self {
83 a_plus,
84 a_minus,
85 tau_plus,
86 tau_minus,
87 ..Self::default()
88 })
89 }
90
91 pub fn multiplicative(mut self) -> Self {
93 self.multiplicative = true;
94 self
95 }
96
97 pub fn pre_spike(&mut self, time: f64, current_weight: f64) -> f64 {
106 let mut dw = 0.0;
107
108 if let Some(post_time) = self.last_post_spike {
110 let dt = time - post_time;
111 if dt > 0.0 && dt < 5.0 * self.tau_minus {
112 dw = -self.a_minus * (-dt / self.tau_minus).exp();
113
114 if self.multiplicative {
116 dw *= current_weight;
117 }
118 }
119 }
120
121 self.last_pre_spike = Some(time);
122 self.accumulated_dw += dw;
123 dw
124 }
125
126 pub fn post_spike(&mut self, time: f64, current_weight: f64) -> f64 {
135 let mut dw = 0.0;
136
137 if let Some(pre_time) = self.last_pre_spike {
139 let dt = time - pre_time;
140 if dt > 0.0 && dt < 5.0 * self.tau_plus {
141 dw = self.a_plus * (-dt / self.tau_plus).exp();
142
143 if self.multiplicative {
145 dw *= self.w_max - current_weight;
146 }
147 }
148 }
149
150 self.last_post_spike = Some(time);
151 self.accumulated_dw += dw;
152 dw
153 }
154
155 pub fn apply_update(&mut self, weight: f64) -> f64 {
163 let new_weight = (weight + self.accumulated_dw).clamp(self.w_min, self.w_max);
164 self.accumulated_dw = 0.0;
165 new_weight
166 }
167
168 pub fn window(&self, dt: f64) -> f64 {
173 if dt > 0.0 {
174 self.a_plus * (-dt / self.tau_plus).exp()
175 } else {
176 -self.a_minus * (dt / self.tau_minus).exp()
177 }
178 }
179
180 pub fn reset(&mut self) {
182 self.last_pre_spike = None;
183 self.last_post_spike = None;
184 self.accumulated_dw = 0.0;
185 }
186}
187
/// BCM (Bienenstock–Cooper–Munro) rate-based plasticity with a sliding
/// modification threshold derived from the running post-activity average.
#[derive(Debug, Clone)]
pub struct BCM {
    /// Step size applied to every weight update.
    pub learning_rate: f64,
    /// Current modification threshold; recomputed after each update as the
    /// square of the running post-activity average.
    pub threshold: f64,
    /// Time constant of the running post-activity average.
    pub tau_threshold: f64,
    /// Low-pass-filtered post-synaptic activity.
    avg_post_activity: f64,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
}

impl Default for BCM {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            threshold: 0.5,
            tau_threshold: 10000.0,
            avg_post_activity: 0.0,
            w_min: 0.0,
            w_max: 1.0,
        }
    }
}

impl BCM {
    /// Creates a BCM rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Advances the rule by one step of length `dt` and returns the new,
    /// clamped weight. Post-activity above the current threshold
    /// potentiates; activity below it depresses.
    pub fn update(
        &mut self,
        pre_activity: f64,
        post_activity: f64,
        current_weight: f64,
        dt: f64,
    ) -> f64 {
        // The weight change uses the threshold from *before* this sample.
        let drive = post_activity - self.threshold;
        let dw = self.learning_rate * pre_activity * drive * post_activity * dt;

        // Slide the threshold: low-pass filter the post activity, then set
        // the threshold to its square.
        let drift = (post_activity - self.avg_post_activity) / self.tau_threshold * dt;
        self.avg_post_activity += drift;
        self.threshold = self.avg_post_activity.powi(2);

        (current_weight + dw).clamp(self.w_min, self.w_max)
    }

    /// Restores the threshold and activity average to their defaults.
    pub fn reset(&mut self) {
        self.threshold = 0.5;
        self.avg_post_activity = 0.0;
    }
}
259
/// Oja's rule: Hebbian growth plus a quadratic forgetting term that keeps
/// the weight bounded.
#[derive(Debug, Clone)]
pub struct OjasRule {
    /// Step size applied to every weight update.
    pub learning_rate: f64,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
}

impl Default for OjasRule {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            w_min: 0.0,
            w_max: 1.0,
        }
    }
}

impl OjasRule {
    /// Creates an Oja rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Advances the weight by one step of length `dt` and returns the new,
    /// clamped value.
    pub fn update(
        &mut self,
        pre_activity: f64,
        post_activity: f64,
        current_weight: f64,
        dt: f64,
    ) -> f64 {
        // Hebbian term y*x minus the weight-decay term y^2 * w.
        let hebbian = post_activity * pre_activity;
        let decay = post_activity * post_activity * current_weight;
        let dw = self.learning_rate * (hebbian - decay) * dt;

        (current_weight + dw).clamp(self.w_min, self.w_max)
    }
}
307
/// Plain Hebbian learning, optionally with an Oja-style normalizing decay.
#[derive(Debug, Clone)]
pub struct HebbianRule {
    /// Step size applied to every weight update.
    pub learning_rate: f64,
    /// When true, a quadratic decay term bounds weight growth.
    pub normalize: bool,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
}

impl Default for HebbianRule {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            normalize: false,
            w_min: 0.0,
            w_max: 1.0,
        }
    }
}

impl HebbianRule {
    /// Creates a Hebbian rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: enables the normalizing (Oja-style) decay term.
    pub fn normalized(mut self) -> Self {
        self.normalize = true;
        self
    }

    /// Advances the weight by one step of length `dt` and returns the new,
    /// clamped value.
    pub fn update(
        &mut self,
        pre_activity: f64,
        post_activity: f64,
        current_weight: f64,
        dt: f64,
    ) -> f64 {
        let dw = match self.normalize {
            // Oja-style: subtract w * y^2 so the weight stays bounded.
            true => {
                self.learning_rate
                    * (pre_activity * post_activity - current_weight * post_activity.powi(2))
                    * dt
            }
            // Plain Hebb: grow with correlated activity.
            false => self.learning_rate * pre_activity * post_activity * dt,
        };

        (current_weight + dw).clamp(self.w_min, self.w_max)
    }
}
370
/// Anti-Hebbian rule: correlated pre/post activity weakens the synapse.
#[derive(Debug, Clone)]
pub struct AntiHebbianRule {
    /// Step size applied to every weight update.
    pub learning_rate: f64,
    /// Lower weight bound.
    pub w_min: f64,
    /// Upper weight bound.
    pub w_max: f64,
}

impl Default for AntiHebbianRule {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            w_min: 0.0,
            w_max: 1.0,
        }
    }
}

impl AntiHebbianRule {
    /// Creates an anti-Hebbian rule with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Advances the weight by one step of length `dt` and returns the new,
    /// clamped value; the sign is opposite to plain Hebbian learning.
    pub fn update(
        &mut self,
        pre_activity: f64,
        post_activity: f64,
        current_weight: f64,
        dt: f64,
    ) -> f64 {
        // Same product as the Hebbian rule, negated.
        let dw = -self.learning_rate * pre_activity * post_activity * dt;
        (current_weight + dw).clamp(self.w_min, self.w_max)
    }
}
410
/// Homeostatic synaptic scaling: integrates the firing-rate error so a
/// global scaling factor drifts toward restoring `target_rate`.
#[derive(Debug, Clone)]
pub struct HomeostaticPlasticity {
    /// Desired long-run firing rate.
    pub target_rate: f64,
    /// Time constant of both the rate average and the scaling dynamics.
    pub tau_homeostatic: f64,
    /// Low-pass-filtered estimate of the observed firing rate.
    avg_rate: f64,
    /// Multiplicative factor applied to weights via `apply_scaling`;
    /// kept within [0.1, 10.0].
    pub scaling_factor: f64,
}

impl Default for HomeostaticPlasticity {
    fn default() -> Self {
        Self {
            target_rate: 5.0,
            tau_homeostatic: 1000000.0,
            // Start at the target so no correction occurs until actual
            // rates are observed.
            avg_rate: 5.0,
            scaling_factor: 1.0,
        }
    }
}

impl HomeostaticPlasticity {
    /// Creates a homeostatic controller with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Folds `current_rate` into the rate average and moves the scaling
    /// factor toward correcting the remaining rate error.
    pub fn update(&mut self, current_rate: f64, dt: f64) {
        // Low-pass filter the observed rate.
        self.avg_rate += (current_rate - self.avg_rate) / self.tau_homeostatic * dt;

        // Integrate the normalized rate error into the scaling factor.
        let rate_error = self.target_rate - self.avg_rate;
        self.scaling_factor += rate_error / self.target_rate / self.tau_homeostatic * dt;
        // clamp() replaces the previous .max(0.1).min(10.0) chain: same
        // result for finite values, consistent with the clamp() calls used
        // elsewhere in this file, and what clippy's `manual_clamp` suggests.
        self.scaling_factor = self.scaling_factor.clamp(0.1, 10.0);
    }

    /// Returns `weight` multiplied by the current scaling factor.
    pub fn apply_scaling(&self, weight: f64) -> f64 {
        weight * self.scaling_factor
    }

    /// Restores the rate average and scaling factor to their resting state.
    pub fn reset(&mut self) {
        self.avg_rate = self.target_rate;
        self.scaling_factor = 1.0;
    }
}
473
/// Metaplasticity: modulates an effective learning rate from the running
/// activity average (high history halves it, low history doubles it).
#[derive(Debug, Clone)]
pub struct MetaPlasticity {
    /// Unmodulated learning rate.
    pub base_learning_rate: f64,
    /// Currently effective (modulated) learning rate.
    pub learning_rate: f64,
    /// Time constant of the activity average.
    pub tau_meta: f64,
    /// Low-pass-filtered activity estimate.
    avg_activity: f64,
    /// Average-activity level separating the two modulation regimes.
    pub activity_threshold: f64,
}

impl Default for MetaPlasticity {
    fn default() -> Self {
        Self {
            base_learning_rate: 0.01,
            learning_rate: 0.01,
            tau_meta: 100000.0,
            avg_activity: 0.0,
            activity_threshold: 0.5,
        }
    }
}

impl MetaPlasticity {
    /// Creates a metaplasticity controller with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Folds `activity` into the running average and recomputes the
    /// effective learning rate.
    pub fn update(&mut self, activity: f64, dt: f64) {
        // Low-pass filter the instantaneous activity.
        self.avg_activity += (activity - self.avg_activity) / self.tau_meta * dt;

        // Above-threshold history damps plasticity (x0.5); below-threshold
        // history boosts it (x2.0).
        self.learning_rate = self.base_learning_rate
            * if self.avg_activity > self.activity_threshold {
                0.5
            } else {
                2.0
            };
    }

    /// Returns the currently effective learning rate.
    pub fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    /// Restores the learning rate and activity average to their defaults.
    pub fn reset(&mut self) {
        self.learning_rate = self.base_learning_rate;
        self.avg_activity = 0.0;
    }
}
545
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stdp_creation() {
        // Default amplitudes are symmetric.
        let rule = STDP::new();
        assert_eq!(rule.a_plus, 0.01);
        assert_eq!(rule.a_minus, 0.01);
    }

    #[test]
    fn test_stdp_potentiation() {
        // Pre spike followed by a post spike -> positive weight change.
        let mut rule = STDP::new();
        let w = 0.5;
        rule.pre_spike(0.0, w);
        assert!(rule.post_spike(10.0, w) > 0.0);
    }

    #[test]
    fn test_stdp_depression() {
        // Post spike followed by a pre spike -> negative weight change.
        let mut rule = STDP::new();
        let w = 0.5;
        rule.post_spike(0.0, w);
        assert!(rule.pre_spike(10.0, w) < 0.0);
    }

    #[test]
    fn test_stdp_window() {
        // The window is positive for post-after-pre and negative otherwise.
        let rule = STDP::new();
        assert!(rule.window(10.0) > 0.0);
        assert!(rule.window(-10.0) < 0.0);
    }

    #[test]
    fn test_bcm_rule() {
        let mut bcm = BCM::new();
        let w = 0.5;
        // Post activity below the threshold depresses...
        assert!(bcm.update(1.0, 0.1, w, 1.0) < w);
        // ...and activity above it potentiates.
        assert!(bcm.update(1.0, 0.9, w, 1.0) > w);
    }

    #[test]
    fn test_ojas_rule() {
        let mut oja = OjasRule::new();
        let updated = oja.update(1.0, 1.0, 0.5, 1.0);
        // The result must respect the [w_min, w_max] bounds.
        assert!(updated >= 0.0 && updated <= 1.0);
    }

    #[test]
    fn test_hebbian_rule() {
        // Correlated activity strengthens the synapse.
        let mut hebb = HebbianRule::new();
        assert!(hebb.update(1.0, 1.0, 0.5, 1.0) > 0.5);
    }

    #[test]
    fn test_anti_hebbian_rule() {
        // Correlated activity weakens the synapse.
        let mut anti = AntiHebbianRule::new();
        assert!(anti.update(1.0, 1.0, 0.5, 1.0) < 0.5);
    }

    #[test]
    fn test_homeostatic_plasticity() {
        let mut homeo = HomeostaticPlasticity::new();

        // Sustained above-target rates scale weights down.
        for _ in 0..100 {
            homeo.update(10.0, 100.0);
        }
        assert!(homeo.scaling_factor < 1.0);

        homeo.reset();

        // Sustained below-target rates scale weights up.
        for _ in 0..100 {
            homeo.update(1.0, 100.0);
        }
        assert!(homeo.scaling_factor > 1.0);
    }

    #[test]
    fn test_meta_plasticity() {
        let mut meta = MetaPlasticity::new();

        // High sustained activity suppresses the learning rate.
        for _ in 0..1000 {
            meta.update(0.8, 100.0);
        }
        assert!(meta.avg_activity > meta.activity_threshold);
        assert!(meta.learning_rate < meta.base_learning_rate);

        meta.reset();

        // Low sustained activity boosts it.
        for _ in 0..1000 {
            meta.update(0.2, 100.0);
        }
        assert!(meta.avg_activity < meta.activity_threshold);
        assert!(meta.learning_rate > meta.base_learning_rate);
    }
}