use serde::{Deserialize, Serialize};
/// Parameters controlling how strength values are combined by
/// [`PlasticityRule::apply_update`] and the create / merge / prune decisions made by
/// [`PlasticityRule::should_create_new`], [`PlasticityRule::should_merge`] and
/// [`PlasticityRule::should_prune`].
///
/// Ready-made configurations are available via constructors such as
/// [`PlasticityRule::stdp_like`], [`PlasticityRule::conservative`],
/// [`PlasticityRule::aggressive`], [`PlasticityRule::replace_only`] and
/// [`PlasticityRule::bayesian`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlasticityRule {
    /// Which combination rule `apply_update` dispatches on.
    pub policy: UpdatePolicy,
    /// `should_create_new` returns `true` only when the novelty score is strictly
    /// greater than this value.
    pub novelty_threshold: f32,
    /// Dissimilarity tolerance: `should_merge` returns `true` when the similarity is
    /// strictly greater than `1.0 - merge_threshold`.
    pub merge_threshold: f32,
    /// Linear decay of the current strength per elapsed day (`time_delta_seconds / 86400`).
    /// Read only by the [`UpdatePolicy::STDP`] branch of `apply_update`.
    pub decay_rate: f32,
    /// `should_prune` returns `true` when a strength is strictly below this value.
    pub prune_threshold: f32,
    /// Gain applied to the new strength under [`UpdatePolicy::STDP`], and the smoothing
    /// factor (alpha) under [`UpdatePolicy::EMA`]. Not read by the other policies.
    pub learning_rate: f32,
}
/// Strategy that [`PlasticityRule::apply_update`] uses to combine the current strength
/// (`c`) with a newly observed strength (`n`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
// The acronym variant names are public API and, through the derived serde impls, also
// the serialized representation; renaming them to `Stdp`/`Ema`/`Wta` would break callers
// and previously stored data, so the lint is silenced deliberately.
#[allow(clippy::upper_case_acronyms)]
pub enum UpdatePolicy {
    /// Decays `c` linearly with elapsed time, adds `learning_rate * n`, and clamps the
    /// result to `[0.0, 1.0]`.
    STDP,
    /// Returns `n`, discarding `c`.
    Replace,
    /// Exponential moving average: `learning_rate * n + (1 - learning_rate) * c`.
    EMA,
    /// Strength-weighted average `(c² + n²) / (c + n)`, or `0.0` when `c + n` is not positive.
    Bayesian,
    /// Keeps the larger of `c` and `n` (winner-take-all).
    WTA,
}
impl PlasticityRule {
pub fn stdp_like() -> Self {
Self {
policy: UpdatePolicy::STDP,
novelty_threshold: 0.6,
merge_threshold: 0.3,
decay_rate: 0.01, prune_threshold: 0.1, learning_rate: 0.1, }
}
pub fn conservative() -> Self {
Self {
policy: UpdatePolicy::EMA,
novelty_threshold: 0.8, merge_threshold: 0.4,
decay_rate: 0.005, prune_threshold: 0.05,
learning_rate: 0.05, }
}
pub fn aggressive() -> Self {
Self {
policy: UpdatePolicy::STDP,
novelty_threshold: 0.4, merge_threshold: 0.2,
decay_rate: 0.02, prune_threshold: 0.2,
learning_rate: 0.2, }
}
pub fn replace_only() -> Self {
Self {
policy: UpdatePolicy::Replace,
novelty_threshold: 1.0, merge_threshold: 0.0, decay_rate: 0.0, prune_threshold: 0.0, learning_rate: 1.0, }
}
pub fn bayesian() -> Self {
Self {
policy: UpdatePolicy::Bayesian,
novelty_threshold: 0.7,
merge_threshold: 0.3,
decay_rate: 0.01,
prune_threshold: 0.1,
learning_rate: 0.1,
}
}
pub fn apply_update(
&self,
current_strength: f32,
new_strength: f32,
time_delta_seconds: f64,
) -> f32 {
match self.policy {
UpdatePolicy::STDP => {
let decayed = current_strength * (1.0 - self.decay_rate * time_delta_seconds as f32 / 86400.0);
let updated = decayed + self.learning_rate * new_strength;
updated.clamp(0.0, 1.0)
}
UpdatePolicy::Replace => new_strength,
UpdatePolicy::EMA => {
let alpha = self.learning_rate;
alpha * new_strength + (1.0 - alpha) * current_strength
}
UpdatePolicy::Bayesian => {
let total_weight = current_strength + new_strength;
if total_weight > 0.0 {
(current_strength * current_strength + new_strength * new_strength) / total_weight
} else {
0.0
}
}
UpdatePolicy::WTA => {
current_strength.max(new_strength)
}
}
}
pub fn should_create_new(&self, novelty_score: f32) -> bool {
novelty_score > self.novelty_threshold
}
pub fn should_merge(&self, similarity: f32) -> bool {
similarity > (1.0 - self.merge_threshold)
}
pub fn should_prune(&self, strength: f32) -> bool {
strength < self.prune_threshold
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing `f32` results of the update rules.
    const EPS: f32 = 1e-5;

    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPS
    }

    #[test]
    fn test_stdp_update() {
        let rule = PlasticityRule::stdp_like();
        let current = 0.5;
        let new = 0.8;
        let time_delta = 0.0;
        let updated = rule.apply_update(current, new, time_delta);
        assert!(updated > current);
        assert!(updated <= 1.0);
        // No elapsed time means no decay: current + learning_rate * new.
        assert!(approx_eq(updated, 0.5 + 0.1 * 0.8));
    }

    #[test]
    fn test_stdp_decay_over_one_day() {
        let rule = PlasticityRule::stdp_like();
        // One day at decay_rate 0.01 retains 99% of the current strength.
        let updated = rule.apply_update(0.5, 0.8, 86_400.0);
        assert!(approx_eq(updated, 0.5 * 0.99 + 0.1 * 0.8));
    }

    #[test]
    fn test_stdp_clamps_to_unit_interval() {
        let rule = PlasticityRule::aggressive();
        // 1.0 + 0.2 * 1.0 would exceed 1.0 without the final clamp.
        assert!(approx_eq(rule.apply_update(1.0, 1.0, 0.0), 1.0));
    }

    #[test]
    fn test_novelty_threshold() {
        let rule = PlasticityRule::stdp_like();
        assert!(rule.should_create_new(0.7));
        assert!(!rule.should_create_new(0.5));
    }

    #[test]
    fn test_merge_decision() {
        let rule = PlasticityRule::stdp_like();
        assert!(rule.should_merge(0.8));
        assert!(!rule.should_merge(0.6));
    }

    #[test]
    fn test_prune_decision() {
        let rule = PlasticityRule::stdp_like();
        assert!(rule.should_prune(0.05));
        assert!(!rule.should_prune(0.5));
    }

    #[test]
    fn test_thresholds_are_strict() {
        let rule = PlasticityRule::stdp_like();
        assert!(!rule.should_create_new(rule.novelty_threshold));
        assert!(!rule.should_prune(rule.prune_threshold));
    }

    #[test]
    fn test_ema_update() {
        let rule = PlasticityRule {
            policy: UpdatePolicy::EMA,
            learning_rate: 0.1,
            ..PlasticityRule::stdp_like()
        };
        let current = 0.5;
        let new = 1.0;
        let updated = rule.apply_update(current, new, 0.0);
        let expected = 0.1 * 1.0 + 0.9 * 0.5;
        assert!((updated - expected).abs() < 0.001);
    }

    #[test]
    fn test_replace_update() {
        let rule = PlasticityRule::replace_only();
        assert_eq!(rule.apply_update(0.9, 0.2, 1_000.0), 0.2);
    }

    #[test]
    fn test_bayesian_update() {
        let rule = PlasticityRule::bayesian();
        let expected = (0.5 * 0.5 + 0.8 * 0.8) / (0.5 + 0.8);
        assert!(approx_eq(rule.apply_update(0.5, 0.8, 0.0), expected));
        // Zero total weight hits the guard branch instead of dividing by zero.
        assert_eq!(rule.apply_update(0.0, 0.0, 0.0), 0.0);
    }

    #[test]
    fn test_wta_update() {
        let rule = PlasticityRule {
            policy: UpdatePolicy::WTA,
            ..PlasticityRule::stdp_like()
        };
        assert_eq!(rule.apply_update(0.3, 0.7, 0.0), 0.7);
        assert_eq!(rule.apply_update(0.7, 0.3, 0.0), 0.7);
    }
}