use serde::{Deserialize, Serialize};
use crate::ternary::TernaryWeight;
/// Tunable parameters that control how connection strengths are created,
/// merged, updated, decayed, and pruned.
///
/// The `policy` field selects which branch of
/// [`PlasticityRule::apply_update`] and
/// [`PlasticityRule::apply_ternary_update`] is executed; the remaining fields
/// are thresholds and rates consumed by the methods of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlasticityRule {
    /// Update strategy used by `apply_update` and `apply_ternary_update`.
    pub policy: UpdatePolicy,
    /// `should_create_new` returns `true` for novelty scores strictly above
    /// this value.
    pub novelty_threshold: f32,
    /// `should_merge` returns `true` for similarities strictly above
    /// `1.0 - merge_threshold`.
    pub merge_threshold: f32,
    /// Fraction of the current strength removed per 86 400 seconds of elapsed
    /// time in the STDP branch of `apply_update`.
    pub decay_rate: f32,
    /// `should_prune` returns `true` for strengths strictly below this value.
    pub prune_threshold: f32,
    /// Scale applied to the new strength in the STDP branch, the blending
    /// factor in the EMA branch of `apply_update`, and the confidence cut-off
    /// in the EMA branch of `apply_ternary_update`.
    pub learning_rate: f32,
}
/// Strategy a [`PlasticityRule`] uses to combine an existing strength with a
/// newly observed one.
///
/// The per-variant notes describe the scalar path,
/// [`PlasticityRule::apply_update`]; the ternary path has its own
/// confidence-based rules in [`PlasticityRule::apply_ternary_update`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UpdatePolicy {
    /// Decay the current strength with elapsed time, add the
    /// learning-rate-scaled new strength, and clamp the result to `[0, 1]`.
    STDP,
    /// Discard the current strength and return the new one as-is.
    Replace,
    /// Blend the two strengths, weighting the new one by `learning_rate` and
    /// the current one by `1 - learning_rate`.
    EMA,
    /// Average the two strengths, each weighted by its own magnitude
    /// (`(c² + n²) / (c + n)`), or return `0.0` when their sum is not positive.
    Bayesian,
    /// Keep whichever of the two strengths is larger.
    WTA,
}
impl PlasticityRule {
pub fn stdp_like() -> Self {
Self {
policy: UpdatePolicy::STDP,
novelty_threshold: 0.6,
merge_threshold: 0.3,
decay_rate: 0.01, prune_threshold: 0.1, learning_rate: 0.1, }
}
pub fn conservative() -> Self {
Self {
policy: UpdatePolicy::EMA,
novelty_threshold: 0.8, merge_threshold: 0.4,
decay_rate: 0.005, prune_threshold: 0.05,
learning_rate: 0.05, }
}
pub fn aggressive() -> Self {
Self {
policy: UpdatePolicy::STDP,
novelty_threshold: 0.4, merge_threshold: 0.2,
decay_rate: 0.02, prune_threshold: 0.2,
learning_rate: 0.2, }
}
pub fn replace_only() -> Self {
Self {
policy: UpdatePolicy::Replace,
novelty_threshold: 1.0, merge_threshold: 0.0, decay_rate: 0.0, prune_threshold: 0.0, learning_rate: 1.0, }
}
pub fn bayesian() -> Self {
Self {
policy: UpdatePolicy::Bayesian,
novelty_threshold: 0.7,
merge_threshold: 0.3,
decay_rate: 0.01,
prune_threshold: 0.1,
learning_rate: 0.1,
}
}
pub fn apply_update(
&self,
current_strength: f32,
new_strength: f32,
time_delta_seconds: f64,
) -> f32 {
match self.policy {
UpdatePolicy::STDP => {
let decayed = current_strength * (1.0 - self.decay_rate * time_delta_seconds as f32 / 86400.0);
let updated = decayed + self.learning_rate * new_strength;
updated.clamp(0.0, 1.0)
}
UpdatePolicy::Replace => new_strength,
UpdatePolicy::EMA => {
let alpha = self.learning_rate;
alpha * new_strength + (1.0 - alpha) * current_strength
}
UpdatePolicy::Bayesian => {
let total_weight = current_strength + new_strength;
if total_weight > 0.0 {
(current_strength * current_strength + new_strength * new_strength) / total_weight
} else {
0.0
}
}
UpdatePolicy::WTA => {
current_strength.max(new_strength)
}
}
}
pub fn should_create_new(&self, novelty_score: f32) -> bool {
novelty_score > self.novelty_threshold
}
pub fn should_merge(&self, similarity: f32) -> bool {
similarity > (1.0 - self.merge_threshold)
}
pub fn should_prune(&self, strength: f32) -> bool {
strength < self.prune_threshold
}
pub fn apply_ternary_update(
&self,
current: TernaryWeight,
observed: TernaryWeight,
confidence: f32,
) -> TernaryWeight {
match self.policy {
UpdatePolicy::STDP => {
if current == observed {
if confidence > 0.7 {
current.strengthen()
} else {
current
}
} else if confidence > 0.5 {
self.move_toward(current, observed)
} else {
self.decay_toward_zero(current)
}
}
UpdatePolicy::Replace => {
if confidence > 0.5 {
observed
} else {
current
}
}
UpdatePolicy::EMA => {
if current == observed {
current
} else if confidence > self.learning_rate {
observed
} else {
current
}
}
UpdatePolicy::Bayesian => {
if confidence > 0.7 {
observed
} else if confidence > 0.3 {
self.decay_toward_zero(current)
} else {
current
}
}
UpdatePolicy::WTA => {
if confidence > 0.5 {
observed
} else {
current
}
}
}
}
fn move_toward(&self, current: TernaryWeight, target: TernaryWeight) -> TernaryWeight {
use TernaryWeight::*;
match (current, target) {
(Pos, Pos) | (Zero, Zero) | (Neg, Neg) => current,
(Zero, Pos) | (Neg, Pos) => current.strengthen(),
(Zero, Neg) | (Pos, Neg) => current.weaken(),
(Pos, Zero) => current.weaken(),
(Neg, Zero) => current.strengthen(),
}
}
fn decay_toward_zero(&self, current: TernaryWeight) -> TernaryWeight {
match current {
TernaryWeight::Pos => TernaryWeight::Zero,
TernaryWeight::Neg => TernaryWeight::Zero,
TernaryWeight::Zero => TernaryWeight::Zero,
}
}
pub fn decay_toward_zero_public(&self, current: TernaryWeight) -> TernaryWeight {
self.decay_toward_zero(current)
}
pub fn should_prune_ternary(&self, weight: TernaryWeight) -> bool {
weight == TernaryWeight::Zero
}
pub fn apply_ternary_stdp(
&self,
current: TernaryWeight,
pre_fired: bool,
post_fired: bool,
time_delta_ms: i64,
) -> TernaryWeight {
if pre_fired && post_fired {
if time_delta_ms > 0 {
current.strengthen()
} else if time_delta_ms < 0 {
current.weaken()
} else {
if current == TernaryWeight::Neg {
TernaryWeight::Zero
} else {
current
}
}
} else {
current
}
}
pub fn ternary_majority_vote(weights: &[TernaryWeight]) -> TernaryWeight {
let mut pos_count = 0i32;
let mut neg_count = 0i32;
for &w in weights {
match w {
TernaryWeight::Pos => pos_count += 1,
TernaryWeight::Neg => neg_count += 1,
TernaryWeight::Zero => {}
}
}
if pos_count > neg_count {
TernaryWeight::Pos
} else if neg_count > pos_count {
TernaryWeight::Neg
} else {
TernaryWeight::Zero
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no elapsed time the STDP policy adds the scaled observation and
    /// stays within the upper bound.
    #[test]
    fn test_stdp_update() {
        let rule = PlasticityRule::stdp_like();
        let before = 0.5;
        let observed = 0.8;
        let elapsed_seconds = 0.0;

        let after = rule.apply_update(before, observed, elapsed_seconds);

        assert!(after > before);
        assert!(after <= 1.0);
    }

    /// The `stdp_like` preset has a novelty threshold of 0.6.
    #[test]
    fn test_novelty_threshold() {
        let rule = PlasticityRule::stdp_like();

        // Above the threshold: a new entry should be created.
        assert!(rule.should_create_new(0.7));
        // Below the threshold: it should not.
        assert!(!rule.should_create_new(0.5));
    }

    /// The `stdp_like` preset merges above `1.0 - 0.3`.
    #[test]
    fn test_merge_decision() {
        let rule = PlasticityRule::stdp_like();

        assert!(rule.should_merge(0.8));
        assert!(!rule.should_merge(0.6));
    }

    /// The `stdp_like` preset prunes strengths below 0.1.
    #[test]
    fn test_prune_decision() {
        let rule = PlasticityRule::stdp_like();

        assert!(rule.should_prune(0.05));
        assert!(!rule.should_prune(0.5));
    }

    /// The EMA policy blends old and new strengths by the learning rate.
    #[test]
    fn test_ema_update() {
        let rule = PlasticityRule {
            policy: UpdatePolicy::EMA,
            learning_rate: 0.1,
            ..PlasticityRule::stdp_like()
        };
        let before = 0.5;
        let observed = 1.0;

        let after = rule.apply_update(before, observed, 0.0);

        let expected = 0.1 * 1.0 + 0.9 * 0.5;
        assert!((after - expected).abs() < 0.001);
    }
}