use serde::{Deserialize, Serialize};
use ternary_signal::Signal;
use crate::ternary::TernaryWeight;
/// Parameters controlling how a stored strength/weight is updated, and when
/// new entries are created, merged, or pruned.
///
/// Thresholds and rates are `Signal`s whose magnitudes are compared on a
/// 0–255 scale (`should_merge` uses `255 - magnitude`, and `apply_update`
/// re-quantizes results with `* 255.0`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlasticityRule {
/// Update formula used by `apply_update` and `apply_ternary_update`.
pub policy: UpdatePolicy,
/// `should_create_new` returns true when a novelty score's magnitude is
/// strictly greater than this magnitude.
pub novelty_threshold: Signal,
/// `should_merge` returns true when a similarity's magnitude is strictly
/// greater than `255 - merge_threshold.magnitude`; a larger value therefore
/// makes merging more permissive.
pub merge_threshold: Signal,
/// Decay applied by the `STDP` policy of `apply_update`, scaled by the
/// elapsed time in days (`time_delta_seconds / 86400`).
pub decay_rate: Signal,
/// `should_prune_signal` returns true when a strength's magnitude is
/// strictly less than this magnitude.
pub prune_threshold: Signal,
/// Weight of the new observation in the `STDP` and `EMA` policies of
/// `apply_update`; in `apply_ternary_update` (`EMA`), the confidence must
/// exceed this value for the observed weight to be adopted.
pub learning_rate: Signal,
}
/// Strategy a `PlasticityRule` uses to combine an existing strength with a
/// new observation.
///
/// The variant names stay as all-caps acronyms on purpose: they are public
/// API and, with the default serde derive, also the serialized variant names.
/// Renaming them (e.g. to `Stdp`) would break callers and previously stored
/// data, so the Clippy acronym lint is silenced instead.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UpdatePolicy {
    /// Time-decay the current strength, then add a learning-rate-scaled
    /// share of the new strength.
    STDP,
    /// Take the new strength in place of the current one.
    Replace,
    /// Exponential moving average weighted by the learning rate.
    EMA,
    /// Self-weighted average `(c² + n²) / (c + n)` of the two magnitudes.
    Bayesian,
    /// Winner-take-all: keep the larger magnitude.
    WTA,
}
impl PlasticityRule {
pub fn stdp_like() -> Self {
Self {
policy: UpdatePolicy::STDP,
novelty_threshold: Signal::positive(153), merge_threshold: Signal::positive(77), decay_rate: Signal::positive(3), prune_threshold: Signal::positive(26), learning_rate: Signal::positive(26), }
}
pub fn conservative() -> Self {
Self {
policy: UpdatePolicy::EMA,
novelty_threshold: Signal::positive(204), merge_threshold: Signal::positive(102), decay_rate: Signal::positive(1), prune_threshold: Signal::positive(13), learning_rate: Signal::positive(13), }
}
pub fn aggressive() -> Self {
Self {
policy: UpdatePolicy::STDP,
novelty_threshold: Signal::positive(102), merge_threshold: Signal::positive(51), decay_rate: Signal::positive(5), prune_threshold: Signal::positive(51), learning_rate: Signal::positive(51), }
}
pub fn replace_only() -> Self {
Self {
policy: UpdatePolicy::Replace,
novelty_threshold: Signal::positive(255), merge_threshold: Signal::positive(0), decay_rate: Signal::positive(0), prune_threshold: Signal::positive(0), learning_rate: Signal::positive(255), }
}
pub fn bayesian() -> Self {
Self {
policy: UpdatePolicy::Bayesian,
novelty_threshold: Signal::positive(179), merge_threshold: Signal::positive(77), decay_rate: Signal::positive(3), prune_threshold: Signal::positive(26), learning_rate: Signal::positive(26), }
}
pub fn apply_update(
&self,
current_strength: Signal,
new_strength: Signal,
time_delta_seconds: f64,
) -> Signal {
let cur_f = current_strength.magnitude_f32();
let new_f = new_strength.magnitude_f32();
let decay_f = self.decay_rate.magnitude_f32();
let lr_f = self.learning_rate.magnitude_f32();
let result_f = match self.policy {
UpdatePolicy::STDP => {
let decayed = cur_f * (1.0 - decay_f * time_delta_seconds as f32 / 86400.0);
let updated = decayed + lr_f * new_f;
updated.clamp(0.0, 1.0)
}
UpdatePolicy::Replace => new_f,
UpdatePolicy::EMA => {
lr_f * new_f + (1.0 - lr_f) * cur_f
}
UpdatePolicy::Bayesian => {
let total_weight = cur_f + new_f;
if total_weight > 0.0 {
(cur_f * cur_f + new_f * new_f) / total_weight
} else {
0.0
}
}
UpdatePolicy::WTA => {
cur_f.max(new_f)
}
};
let polarity = if new_f >= cur_f {
new_strength.polarity
} else {
current_strength.polarity
};
Signal::new(polarity, (result_f * 255.0) as u8)
}
pub fn should_create_new(&self, novelty_score: Signal) -> bool {
novelty_score.magnitude > self.novelty_threshold.magnitude
}
pub fn should_merge(&self, similarity: Signal) -> bool {
let anti_threshold = 255 - self.merge_threshold.magnitude;
similarity.magnitude > anti_threshold
}
pub fn should_prune_signal(&self, strength: Signal) -> bool {
strength.magnitude < self.prune_threshold.magnitude
}
pub fn apply_ternary_update(
&self,
current: TernaryWeight,
observed: TernaryWeight,
confidence: f32,
) -> TernaryWeight {
match self.policy {
UpdatePolicy::STDP => {
if current == observed {
if confidence > 0.7 {
current.strengthen()
} else {
current
}
} else if confidence > 0.5 {
self.move_toward(current, observed)
} else {
self.decay_toward_zero(current)
}
}
UpdatePolicy::Replace => {
if confidence > 0.5 {
observed
} else {
current
}
}
UpdatePolicy::EMA => {
if current == observed {
current
} else if confidence > self.learning_rate.magnitude_f32() {
observed
} else {
current
}
}
UpdatePolicy::Bayesian => {
if confidence > 0.7 {
observed
} else if confidence > 0.3 {
self.decay_toward_zero(current)
} else {
current
}
}
UpdatePolicy::WTA => {
if confidence > 0.5 {
observed
} else {
current
}
}
}
}
fn move_toward(&self, current: TernaryWeight, target: TernaryWeight) -> TernaryWeight {
use TernaryWeight::*;
match (current, target) {
(Pos, Pos) | (Zero, Zero) | (Neg, Neg) => current,
(Zero, Pos) | (Neg, Pos) => current.strengthen(),
(Zero, Neg) | (Pos, Neg) => current.weaken(),
(Pos, Zero) => current.weaken(),
(Neg, Zero) => current.strengthen(),
}
}
fn decay_toward_zero(&self, current: TernaryWeight) -> TernaryWeight {
match current {
TernaryWeight::Pos => TernaryWeight::Zero,
TernaryWeight::Neg => TernaryWeight::Zero,
TernaryWeight::Zero => TernaryWeight::Zero,
}
}
pub fn decay_toward_zero_public(&self, current: TernaryWeight) -> TernaryWeight {
self.decay_toward_zero(current)
}
pub fn should_prune_ternary(&self, weight: TernaryWeight) -> bool {
weight == TernaryWeight::Zero
}
pub fn apply_ternary_stdp(
&self,
current: TernaryWeight,
pre_fired: bool,
post_fired: bool,
time_delta_ms: i64,
) -> TernaryWeight {
if pre_fired && post_fired {
if time_delta_ms > 0 {
current.strengthen()
} else if time_delta_ms < 0 {
current.weaken()
} else {
if current == TernaryWeight::Neg {
TernaryWeight::Zero
} else {
current
}
}
} else {
current
}
}
pub fn ternary_majority_vote(weights: &[TernaryWeight]) -> TernaryWeight {
let mut pos_count = 0i32;
let mut neg_count = 0i32;
for &w in weights {
match w {
TernaryWeight::Pos => pos_count += 1,
TernaryWeight::Neg => neg_count += 1,
TernaryWeight::Zero => {}
}
}
if pos_count > neg_count {
TernaryWeight::Pos
} else if neg_count > pos_count {
TernaryWeight::Neg
} else {
TernaryWeight::Zero
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An STDP update with no elapsed time increases the strength.
    #[test]
    fn test_stdp_update() {
        let rule = PlasticityRule::stdp_like();
        let current = Signal::positive(128);
        let new = Signal::positive(204);
        let time_delta = 0.0;
        let updated = rule.apply_update(current, new, time_delta);
        assert!(updated.magnitude > current.magnitude);
        assert!(updated.magnitude <= 255);
    }

    #[test]
    fn test_novelty_threshold() {
        let rule = PlasticityRule::stdp_like();
        assert!(rule.should_create_new(Signal::positive(179)));
        assert!(!rule.should_create_new(Signal::positive(128)));
    }

    #[test]
    fn test_merge_decision() {
        let rule = PlasticityRule::stdp_like();
        assert!(rule.should_merge(Signal::positive(204)));
        assert!(!rule.should_merge(Signal::positive(153)));
    }

    #[test]
    fn test_prune_decision() {
        let rule = PlasticityRule::stdp_like();
        assert!(rule.should_prune_signal(Signal::positive(13)));
        assert!(!rule.should_prune_signal(Signal::positive(128)));
    }

    /// All three signal decisions use strict comparisons at their thresholds.
    #[test]
    fn test_decision_thresholds_are_strict() {
        let rule = PlasticityRule::stdp_like();
        assert!(!rule.should_create_new(Signal::positive(153)));
        assert!(rule.should_create_new(Signal::positive(154)));
        // merge threshold 77 => similarity must exceed 255 - 77 = 178.
        assert!(!rule.should_merge(Signal::positive(178)));
        assert!(rule.should_merge(Signal::positive(179)));
        assert!(!rule.should_prune_signal(Signal::positive(26)));
        assert!(rule.should_prune_signal(Signal::positive(25)));
    }

    #[test]
    fn test_ema_update() {
        let rule = PlasticityRule {
            policy: UpdatePolicy::EMA,
            learning_rate: Signal::positive(26),
            ..PlasticityRule::stdp_like()
        };
        let current = Signal::positive(128);
        let new = Signal::positive(255);
        let updated = rule.apply_update(current, new, 0.0);
        // Derive the expectation from the rule's actual rate (26/255 ≈ 0.102)
        // rather than a rounded 0.1, so the tolerance only absorbs quantization.
        let lr = rule.learning_rate.magnitude_f32();
        let expected = (lr * new.magnitude_f32() + (1.0 - lr) * current.magnitude_f32()) * 255.0;
        assert!((updated.magnitude as f32 - expected).abs() <= 1.0);
    }

    #[test]
    fn test_ternary_majority_vote() {
        use TernaryWeight::*;
        assert!(PlasticityRule::ternary_majority_vote(&[]) == Zero);
        assert!(PlasticityRule::ternary_majority_vote(&[Zero, Zero]) == Zero);
        assert!(PlasticityRule::ternary_majority_vote(&[Pos, Pos, Neg]) == Pos);
        assert!(PlasticityRule::ternary_majority_vote(&[Neg, Zero, Zero]) == Neg);
        assert!(PlasticityRule::ternary_majority_vote(&[Pos, Neg, Zero]) == Zero);
    }

    /// Without coincident firing the weight never changes; a zero delta only
    /// lifts `Neg` to `Zero`.
    #[test]
    fn test_ternary_stdp_gating() {
        let rule = PlasticityRule::stdp_like();
        for &w in &[TernaryWeight::Pos, TernaryWeight::Zero, TernaryWeight::Neg] {
            assert!(rule.apply_ternary_stdp(w, true, false, 10) == w);
            assert!(rule.apply_ternary_stdp(w, false, true, -10) == w);
            assert!(rule.apply_ternary_stdp(w, false, false, 0) == w);
        }
        assert!(rule.apply_ternary_stdp(TernaryWeight::Neg, true, true, 0) == TernaryWeight::Zero);
        assert!(rule.apply_ternary_stdp(TernaryWeight::Pos, true, true, 0) == TernaryWeight::Pos);
        assert!(rule.apply_ternary_stdp(TernaryWeight::Zero, true, true, 0) == TernaryWeight::Zero);
    }

    #[test]
    fn test_ternary_replace_and_decay() {
        let rule = PlasticityRule::replace_only();
        assert!(
            rule.apply_ternary_update(TernaryWeight::Pos, TernaryWeight::Neg, 0.9)
                == TernaryWeight::Neg
        );
        assert!(
            rule.apply_ternary_update(TernaryWeight::Pos, TernaryWeight::Neg, 0.4)
                == TernaryWeight::Pos
        );
        assert!(rule.decay_toward_zero_public(TernaryWeight::Pos) == TernaryWeight::Zero);
        assert!(rule.decay_toward_zero_public(TernaryWeight::Neg) == TernaryWeight::Zero);
        assert!(rule.should_prune_ternary(TernaryWeight::Zero));
        assert!(!rule.should_prune_ternary(TernaryWeight::Pos));
    }
}