1use std::collections::HashMap;
14
/// Opaque identifier for a group of weights.
///
/// Used as the key for consolidated Fisher/anchor state in `EwcPlusPlusBackend`
/// and echoed back in `PlasticityDelta::weight_id`.
pub type WeightId = u64;
17
18#[derive(Debug, Clone)]
20pub struct FisherDiagonal {
21 pub values: Vec<f32>,
23 pub phi_weight: f32,
25 pub mode: PlasticityMode,
27}
28
/// Plasticity regime requested from `PlasticityEngine::compute_delta`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlasticityMode {
    /// Immediate EWC++-regularised gradient step; the engine's default mode.
    Instant,
    /// BTSP step; the engine falls back to EWC++ when no BTSP backend is configured.
    Behavioral,
    /// EWC++ step with a reduced learning rate (scaled by 0.3 in `PlasticityEngine`).
    Eligibility,
    /// EWC++ step; also the tag stored on freshly consolidated Fisher diagonals.
    Classic,
}
41
42#[derive(Debug, Clone)]
44pub struct PlasticityDelta {
45 pub weight_id: WeightId,
46 pub delta: Vec<f32>,
47 pub mode: PlasticityMode,
48 pub ewc_penalty: f32,
49 pub phi_protection_applied: bool,
50}
51
52pub trait PlasticityBackend: Send + Sync {
54 fn name(&self) -> &'static str;
55 fn compute_delta(
56 &self,
57 weight_id: WeightId,
58 current: &[f32],
59 gradient: &[f32],
60 lr: f32,
61 ) -> PlasticityDelta;
62}
63
64pub struct EwcPlusPlusBackend {
67 fisher: HashMap<WeightId, FisherDiagonal>,
69 theta_star: HashMap<WeightId, Vec<f32>>,
71 pub lambda: f32,
73 pub phi_scale: f32,
75}
76
77impl EwcPlusPlusBackend {
78 pub fn new(lambda: f32) -> Self {
79 Self {
80 fisher: HashMap::new(),
81 theta_star: HashMap::new(),
82 lambda,
83 phi_scale: 1.0,
84 }
85 }
86
87 pub fn consolidate(&mut self, weight_id: WeightId, weights: Vec<f32>, phi: Option<f32>) {
90 let phi_weight = phi.unwrap_or(1.0).max(0.01);
91 let n = weights.len();
92 let fisher = FisherDiagonal {
94 values: vec![1.0; n],
95 phi_weight,
96 mode: PlasticityMode::Classic,
97 };
98 self.fisher.insert(weight_id, fisher);
99 self.theta_star.insert(weight_id, weights);
100 }
101
102 pub fn update_fisher(&mut self, weight_id: WeightId, gradient: &[f32]) {
104 if let Some(f) = self.fisher.get_mut(&weight_id) {
105 let alpha = 0.9f32;
107 for (fi, gi) in f.values.iter_mut().zip(gradient.iter()) {
108 *fi = alpha * *fi + (1.0 - alpha) * gi * gi;
109 }
110 }
111 }
112
113 fn ewc_penalty(&self, weight_id: WeightId, current: &[f32]) -> f32 {
115 match (self.fisher.get(&weight_id), self.theta_star.get(&weight_id)) {
116 (Some(f), Some(theta)) => {
117 let penalty: f32 = f
118 .values
119 .iter()
120 .zip(current.iter().zip(theta.iter()))
121 .map(|(fi, (ci, ti))| fi * (ci - ti).powi(2))
122 .sum::<f32>();
123 penalty * self.lambda * f.phi_weight * self.phi_scale
124 }
125 _ => 0.0,
126 }
127 }
128}
129
130impl PlasticityBackend for EwcPlusPlusBackend {
131 fn name(&self) -> &'static str {
132 "ewc++"
133 }
134
135 fn compute_delta(
136 &self,
137 weight_id: WeightId,
138 current: &[f32],
139 gradient: &[f32],
140 lr: f32,
141 ) -> PlasticityDelta {
142 let penalty = self.ewc_penalty(weight_id, current);
143 let phi_applied = self
144 .fisher
145 .get(&weight_id)
146 .map(|f| f.phi_weight > 1.0)
147 .unwrap_or(false);
148
149 let delta: Vec<f32> = gradient
151 .iter()
152 .enumerate()
153 .map(|(i, g)| {
154 let ewc_term = self
155 .fisher
156 .get(&weight_id)
157 .zip(self.theta_star.get(&weight_id))
158 .map(|(f, t)| {
159 let fi = f.values[i.min(f.values.len() - 1)];
160 let ci = current[i.min(current.len() - 1)];
161 let ti = t[i.min(t.len() - 1)];
162 self.lambda * fi * (ci - ti) * f.phi_weight
163 })
164 .unwrap_or(0.0);
165 -lr * (g + ewc_term)
166 })
167 .collect();
168
169 PlasticityDelta {
170 weight_id,
171 delta,
172 mode: PlasticityMode::Instant,
173 ewc_penalty: penalty,
174 phi_protection_applied: phi_applied,
175 }
176 }
177}
178
/// Parameters for the BTSP (behavioral-timescale plasticity) backend.
pub struct BtspBackend {
    /// Time window in milliseconds (2000.0 from `new`).
    /// NOTE(review): not read by `compute_delta` in this module — confirm intended use.
    pub window_ms: f32,
    /// Mean |gradient| above which the full `lr_btsp` rate is used.
    pub plateau_threshold: f32,
    /// Step size for above-threshold updates; below-threshold updates use 10% of it.
    pub lr_btsp: f32,
}
189
190impl BtspBackend {
191 pub fn new() -> Self {
192 Self {
193 window_ms: 2000.0,
194 plateau_threshold: 0.7,
195 lr_btsp: 0.3,
196 }
197 }
198}
199
200impl Default for BtspBackend {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
206impl PlasticityBackend for BtspBackend {
207 fn name(&self) -> &'static str {
208 "btsp"
209 }
210
211 fn compute_delta(
212 &self,
213 weight_id: WeightId,
214 _current: &[f32],
215 gradient: &[f32],
216 _lr: f32,
217 ) -> PlasticityDelta {
218 let n = gradient.len().max(1);
220 let plateau = gradient.iter().map(|g| g.abs()).sum::<f32>() / n as f32;
221 let btsp_lr = if plateau > self.plateau_threshold {
222 self.lr_btsp
223 } else {
224 self.lr_btsp * 0.1
225 };
226 let delta: Vec<f32> = gradient.iter().map(|g| -btsp_lr * g).collect();
227 PlasticityDelta {
228 weight_id,
229 delta,
230 mode: PlasticityMode::Behavioral,
231 ewc_penalty: 0.0,
232 phi_protection_applied: false,
233 }
234 }
235}
236
237pub struct PlasticityEngine {
239 pub ewc: EwcPlusPlusBackend,
241 pub btsp: Option<BtspBackend>,
243 pub default_mode: PlasticityMode,
245}
246
247impl PlasticityEngine {
248 pub fn new(lambda: f32) -> Self {
249 Self {
250 ewc: EwcPlusPlusBackend::new(lambda),
251 btsp: None,
252 default_mode: PlasticityMode::Instant,
253 }
254 }
255
256 pub fn with_btsp(mut self) -> Self {
257 self.btsp = Some(BtspBackend::new());
258 self
259 }
260
261 pub fn consolidate_with_phi(&mut self, weight_id: WeightId, weights: Vec<f32>, phi: f32) {
264 self.ewc.consolidate(weight_id, weights, Some(phi));
265 }
266
267 pub fn compute_delta(
269 &mut self,
270 weight_id: WeightId,
271 current: &[f32],
272 gradient: &[f32],
273 lr: f32,
274 mode: Option<PlasticityMode>,
275 ) -> PlasticityDelta {
276 self.ewc.update_fisher(weight_id, gradient);
278
279 let mode = mode.unwrap_or(self.default_mode);
280 match mode {
281 PlasticityMode::Instant | PlasticityMode::Classic => {
282 self.ewc.compute_delta(weight_id, current, gradient, lr)
283 }
284 PlasticityMode::Behavioral => self
285 .btsp
286 .as_ref()
287 .map(|b| b.compute_delta(weight_id, current, gradient, lr))
288 .unwrap_or_else(|| self.ewc.compute_delta(weight_id, current, gradient, lr)),
289 PlasticityMode::Eligibility =>
290 {
292 self.ewc
293 .compute_delta(weight_id, current, gradient, lr * 0.3)
294 }
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewc_prevents_catastrophic_forgetting() {
        let mut engine = PlasticityEngine::new(10.0);
        let weights = vec![1.0f32, 2.0, 3.0, 4.0];
        engine.consolidate_with_phi(0, weights, 2.0);
        let current = vec![5.0f32, 6.0, 7.0, 8.0];
        let gradient = vec![1.0f32; 4];

        let delta = engine.compute_delta(0, &current, &gradient, 0.01, None);

        assert!(delta.ewc_penalty > 0.0, "EWC penalty should be nonzero");
        // Φ = 2.0 > 1.0, so protection is reported as applied.
        assert!(delta.phi_protection_applied);
    }

    #[test]
    fn test_btsp_one_shot_large_update() {
        let btsp = BtspBackend::new();
        // Mean |g| = 0.8 exceeds the 0.7 plateau threshold, so the full rate applies.
        let gradient = vec![0.8f32; 10];
        let delta = btsp.compute_delta(0, &[0.0; 10], &gradient, 0.01);
        assert!(
            delta.delta[0].abs() > 0.1,
            "BTSP should produce large one-shot update"
        );
    }

    #[test]
    fn test_btsp_sub_threshold_update_is_damped() {
        let btsp = BtspBackend::new();
        // Mean |g| = 0.1 is below the threshold, so 10% of `lr_btsp` is used.
        let gradient = [0.1f32; 4];
        let delta = btsp.compute_delta(0, &[0.0; 4], &gradient, 0.01);
        let expected = -(btsp.lr_btsp * 0.1) * 0.1;
        assert!(delta.delta.iter().all(|&d| (d - expected).abs() < 1e-6));
        assert_eq!(delta.mode, PlasticityMode::Behavioral);
    }

    #[test]
    fn test_phi_weighted_protection() {
        let mut engine = PlasticityEngine::new(1.0);
        let weights = vec![0.0f32; 4];
        engine.consolidate_with_phi(1, weights.clone(), 5.0);
        engine.consolidate_with_phi(2, weights, 0.1);
        let current = vec![1.0f32; 4];
        let gradient = vec![0.1f32; 4];

        let delta_high_phi = engine.compute_delta(1, &current, &gradient, 0.01, None);
        let delta_low_phi = engine.compute_delta(2, &current, &gradient, 0.01, None);

        assert!(
            delta_high_phi.ewc_penalty > delta_low_phi.ewc_penalty,
            "High Φ patterns should be protected more strongly"
        );
    }

    #[test]
    fn test_behavioral_without_btsp_falls_back_to_ewc() {
        // No BTSP backend and no consolidation: a plain `-lr * g` EWC step.
        let mut engine = PlasticityEngine::new(1.0);
        let gradient = [1.0f32, 2.0];
        let delta = engine.compute_delta(
            7,
            &[0.0, 0.0],
            &gradient,
            0.5,
            Some(PlasticityMode::Behavioral),
        );
        assert_eq!(delta.delta, vec![-0.5, -1.0]);
        assert_eq!(delta.ewc_penalty, 0.0);
        assert!(!delta.phi_protection_applied);
    }
}