Skip to main content

exo_core/
plasticity_engine.rs

//! PlasticityEngine — ADR-029 canonical plasticity system.
//!
//! Unifies four previously independent plasticity implementations:
//! - SONA EWC++ (production, <1 ms, ReasoningBank)
//! - ruvector-nervous-system BTSP (behavioral timescale, 1–3 s windows)
//! - ruvector-nervous-system E-prop (eligibility propagation, 1000 ms)
//! - ruvector-gnn EWC (deprecated; this module replaces it)
//!
//! Key property: EWC Fisher Information weights are scaled by the IIT Φ score
//! of the pattern being protected, so high-Φ ("high-consciousness") patterns are
//! protected more strongly from catastrophic forgetting.
12
13use std::collections::HashMap;
14
/// Identifier of a weight vector (parameter group) in the model being protected.
///
/// Used as the key under which per-weight EWC state (Fisher diagonal and
/// consolidation point) is stored.
pub type WeightId = u64;
17
18/// Fisher Information diagonal approximation for EWC.
19#[derive(Debug, Clone)]
20pub struct FisherDiagonal {
21    /// Fisher Information for each weight dimension
22    pub values: Vec<f32>,
23    /// Φ-weighted importance multiplier (1.0 = neutral, >1.0 = protect more)
24    pub phi_weight: f32,
25    /// Which plasticity mode computed this
26    pub mode: PlasticityMode,
27}
28
/// Plasticity learning modes.
///
/// Used both to request a routing in `PlasticityEngine::compute_delta` and to tag
/// which path produced a `PlasticityDelta` or `FisherDiagonal`. Derives `Hash` so a
/// mode can key a `HashMap`/`HashSet` (e.g. per-mode statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlasticityMode {
    /// SONA MicroLoRA: <1ms instant adaptation, EWC++ regularization
    Instant,
    /// BTSP: behavioral timescale, 1–3 second windows, one-shot
    Behavioral,
    /// E-prop: eligibility propagation, 1000ms credit assignment
    Eligibility,
    /// EWC: classic Fisher Information regularization
    Classic,
}
41
42/// Δ-parameter update from plasticity engine.
43#[derive(Debug, Clone)]
44pub struct PlasticityDelta {
45    pub weight_id: WeightId,
46    pub delta: Vec<f32>,
47    pub mode: PlasticityMode,
48    pub ewc_penalty: f32,
49    pub phi_protection_applied: bool,
50}
51
52/// Trait for plasticity backend implementations.
53pub trait PlasticityBackend: Send + Sync {
54    fn name(&self) -> &'static str;
55    fn compute_delta(
56        &self,
57        weight_id: WeightId,
58        current: &[f32],
59        gradient: &[f32],
60        lr: f32,
61    ) -> PlasticityDelta;
62}
63
64/// EWC++ implementation — the canonical production backend.
65/// Bidirectional plasticity: strengthens important weights, prunes irrelevant ones.
66pub struct EwcPlusPlusBackend {
67    /// Fisher diagonal per weight
68    fisher: HashMap<WeightId, FisherDiagonal>,
69    /// Optimal weights (consolidation point)
70    theta_star: HashMap<WeightId, Vec<f32>>,
71    /// EWC regularization strength λ
72    pub lambda: f32,
73    /// Φ-weighting scale (0.0 = ignore Φ, 1.0 = full Φ-weighting)
74    pub phi_scale: f32,
75}
76
77impl EwcPlusPlusBackend {
78    pub fn new(lambda: f32) -> Self {
79        Self {
80            fisher: HashMap::new(),
81            theta_star: HashMap::new(),
82            lambda,
83            phi_scale: 1.0,
84        }
85    }
86
87    /// Consolidate current weights as the new optimal point.
88    /// Called after learning a task to protect it from future forgetting.
89    pub fn consolidate(&mut self, weight_id: WeightId, weights: Vec<f32>, phi: Option<f32>) {
90        let phi_weight = phi.unwrap_or(1.0).max(0.01);
91        let n = weights.len();
92        // Initialize Fisher diagonal to 1.0 (uniform importance baseline)
93        let fisher = FisherDiagonal {
94            values: vec![1.0; n],
95            phi_weight,
96            mode: PlasticityMode::Classic,
97        };
98        self.fisher.insert(weight_id, fisher);
99        self.theta_star.insert(weight_id, weights);
100    }
101
102    /// Update Fisher diagonal from gradient samples (online estimation).
103    pub fn update_fisher(&mut self, weight_id: WeightId, gradient: &[f32]) {
104        if let Some(f) = self.fisher.get_mut(&weight_id) {
105            // F_i ← α·F_i + (1-α)·g_i² (running average)
106            let alpha = 0.9f32;
107            for (fi, gi) in f.values.iter_mut().zip(gradient.iter()) {
108                *fi = alpha * *fi + (1.0 - alpha) * gi * gi;
109            }
110        }
111    }
112
113    /// Compute EWC++ penalty term for a weight update.
114    fn ewc_penalty(&self, weight_id: WeightId, current: &[f32]) -> f32 {
115        match (self.fisher.get(&weight_id), self.theta_star.get(&weight_id)) {
116            (Some(f), Some(theta)) => {
117                let penalty: f32 = f
118                    .values
119                    .iter()
120                    .zip(current.iter().zip(theta.iter()))
121                    .map(|(fi, (ci, ti))| fi * (ci - ti).powi(2))
122                    .sum::<f32>();
123                penalty * self.lambda * f.phi_weight * self.phi_scale
124            }
125            _ => 0.0,
126        }
127    }
128}
129
130impl PlasticityBackend for EwcPlusPlusBackend {
131    fn name(&self) -> &'static str {
132        "ewc++"
133    }
134
135    fn compute_delta(
136        &self,
137        weight_id: WeightId,
138        current: &[f32],
139        gradient: &[f32],
140        lr: f32,
141    ) -> PlasticityDelta {
142        let penalty = self.ewc_penalty(weight_id, current);
143        let phi_applied = self
144            .fisher
145            .get(&weight_id)
146            .map(|f| f.phi_weight > 1.0)
147            .unwrap_or(false);
148
149        // EWC++ update: θ ← θ - lr·(∇L + λ·F·(θ - θ*))
150        let delta: Vec<f32> = gradient
151            .iter()
152            .enumerate()
153            .map(|(i, g)| {
154                let ewc_term = self
155                    .fisher
156                    .get(&weight_id)
157                    .zip(self.theta_star.get(&weight_id))
158                    .map(|(f, t)| {
159                        let fi = f.values[i.min(f.values.len() - 1)];
160                        let ci = current[i.min(current.len() - 1)];
161                        let ti = t[i.min(t.len() - 1)];
162                        self.lambda * fi * (ci - ti) * f.phi_weight
163                    })
164                    .unwrap_or(0.0);
165                -lr * (g + ewc_term)
166            })
167            .collect();
168
169        PlasticityDelta {
170            weight_id,
171            delta,
172            mode: PlasticityMode::Instant,
173            ewc_penalty: penalty,
174            phi_protection_applied: phi_applied,
175        }
176    }
177}
178
179/// BTSP (Behavioral Timescale Synaptic Plasticity) backend.
180/// One-shot learning within 1–3 second behavioral windows.
181pub struct BtspBackend {
182    /// Window duration in milliseconds
183    pub window_ms: f32,
184    /// Plateau potential threshold (triggers one-shot learning)
185    pub plateau_threshold: f32,
186    /// BTSP learning rate (typically large — one-shot)
187    pub lr_btsp: f32,
188}
189
190impl BtspBackend {
191    pub fn new() -> Self {
192        Self {
193            window_ms: 2000.0,
194            plateau_threshold: 0.7,
195            lr_btsp: 0.3,
196        }
197    }
198}
199
200impl Default for BtspBackend {
201    fn default() -> Self {
202        Self::new()
203    }
204}
205
206impl PlasticityBackend for BtspBackend {
207    fn name(&self) -> &'static str {
208        "btsp"
209    }
210
211    fn compute_delta(
212        &self,
213        weight_id: WeightId,
214        _current: &[f32],
215        gradient: &[f32],
216        _lr: f32,
217    ) -> PlasticityDelta {
218        // BTSP: large update if plateau potential exceeds threshold
219        let n = gradient.len().max(1);
220        let plateau = gradient.iter().map(|g| g.abs()).sum::<f32>() / n as f32;
221        let btsp_lr = if plateau > self.plateau_threshold {
222            self.lr_btsp
223        } else {
224            self.lr_btsp * 0.1
225        };
226        let delta: Vec<f32> = gradient.iter().map(|g| -btsp_lr * g).collect();
227        PlasticityDelta {
228            weight_id,
229            delta,
230            mode: PlasticityMode::Behavioral,
231            ewc_penalty: 0.0,
232            phi_protection_applied: false,
233        }
234    }
235}
236
237/// The unified plasticity engine.
238pub struct PlasticityEngine {
239    /// EWC++ is always present (canonical production backend)
240    pub ewc: EwcPlusPlusBackend,
241    /// Optional BTSP for biological one-shot plasticity
242    pub btsp: Option<BtspBackend>,
243    /// Default mode for new weight updates
244    pub default_mode: PlasticityMode,
245}
246
247impl PlasticityEngine {
248    pub fn new(lambda: f32) -> Self {
249        Self {
250            ewc: EwcPlusPlusBackend::new(lambda),
251            btsp: None,
252            default_mode: PlasticityMode::Instant,
253        }
254    }
255
256    pub fn with_btsp(mut self) -> Self {
257        self.btsp = Some(BtspBackend::new());
258        self
259    }
260
261    /// Set Φ-based protection weight for a consolidated pattern.
262    /// phi > 1.0 protects the pattern more strongly from forgetting.
263    pub fn consolidate_with_phi(&mut self, weight_id: WeightId, weights: Vec<f32>, phi: f32) {
264        self.ewc.consolidate(weight_id, weights, Some(phi));
265    }
266
267    /// Compute update delta for a weight, routing to appropriate backend.
268    pub fn compute_delta(
269        &mut self,
270        weight_id: WeightId,
271        current: &[f32],
272        gradient: &[f32],
273        lr: f32,
274        mode: Option<PlasticityMode>,
275    ) -> PlasticityDelta {
276        // Update Fisher diagonal online
277        self.ewc.update_fisher(weight_id, gradient);
278
279        let mode = mode.unwrap_or(self.default_mode);
280        match mode {
281            PlasticityMode::Instant | PlasticityMode::Classic => {
282                self.ewc.compute_delta(weight_id, current, gradient, lr)
283            }
284            PlasticityMode::Behavioral => self
285                .btsp
286                .as_ref()
287                .map(|b| b.compute_delta(weight_id, current, gradient, lr))
288                .unwrap_or_else(|| self.ewc.compute_delta(weight_id, current, gradient, lr)),
289            PlasticityMode::Eligibility =>
290            // E-prop: use EWC with reduced learning rate (credit assignment delay)
291            {
292                self.ewc
293                    .compute_delta(weight_id, current, gradient, lr * 0.3)
294            }
295        }
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Drifting far from a high-Φ consolidation point is penalised, and the weight is
    /// pulled back harder than an unconsolidated weight with the same gradient.
    #[test]
    fn test_ewc_prevents_catastrophic_forgetting() {
        let mut engine = PlasticityEngine::new(10.0);
        // High Φ = protect more.
        engine.consolidate_with_phi(0, vec![1.0f32, 2.0, 3.0, 4.0], 2.0);

        // Simulate a gradient while the weights have drifted far from the consolidation point.
        let current = [5.0f32, 6.0, 7.0, 8.0];
        let gradient = [1.0f32; 4];
        let delta = engine.compute_delta(0, &current, &gradient, 0.01, None);

        // EWC penalty should be nonzero (current far from theta_star).
        assert!(delta.ewc_penalty > 0.0, "EWC penalty should be nonzero");
        // Phi protection should be applied.
        assert!(delta.phi_protection_applied);

        // A never-consolidated id gets a plain gradient step; the protected weight must
        // be pulled back more strongly in every dimension.
        let unprotected = engine.compute_delta(99, &current, &gradient, 0.01, None);
        for (p, u) in delta.delta.iter().zip(&unprotected.delta) {
            assert!(
                p.abs() > u.abs(),
                "EWC should strengthen the pull towards theta_star"
            );
        }
    }

    /// A supra-threshold gradient makes BTSP apply its large one-shot learning rate.
    #[test]
    fn test_btsp_one_shot_large_update() {
        let btsp = BtspBackend::new();
        let gradient = [0.8f32; 10]; // Above plateau threshold
        let delta = btsp.compute_delta(0, &[0.0; 10], &gradient, 0.01);
        // BTSP lr (0.3) should dominate over standard lr (0.01).
        assert!(
            delta.delta[0].abs() > 0.1,
            "BTSP should produce large one-shot update"
        );
    }

    /// With identical drift, the high-Φ pattern must carry the larger EWC penalty.
    #[test]
    fn test_phi_weighted_protection() {
        let mut engine = PlasticityEngine::new(1.0);
        let weights = vec![0.0f32; 4];
        engine.consolidate_with_phi(1, weights.clone(), 5.0); // Very high Φ
        engine.consolidate_with_phi(2, weights, 0.1); // Very low Φ

        let current = [1.0f32; 4];
        let gradient = [0.1f32; 4];

        let delta_high_phi = engine.compute_delta(1, &current, &gradient, 0.01, None);
        let delta_low_phi = engine.compute_delta(2, &current, &gradient, 0.01, None);

        // High Φ pattern should have larger EWC penalty (more protection).
        assert!(
            delta_high_phi.ewc_penalty > delta_low_phi.ewc_penalty,
            "High Φ patterns should be protected more strongly"
        );
    }
}