1use alloc::vec::Vec;
32
33use super::ContinualStrategy;
34use crate::drift::DriftSignal;
35
/// Streaming (online) Elastic Weight Consolidation.
///
/// Keeps a per-parameter diagonal Fisher-information estimate, maintained as an
/// exponential moving average of squared gradients, plus an anchor parameter vector.
/// Once an anchor is set, the [`ContinualStrategy`] implementation adds the gradient
/// of the quadratic penalty
///
/// ```text
/// L_ewc(θ) = ½ · λ · Σ_i F_i · (θ_i − θ*_i)²
/// ```
///
/// to each incoming gradient. This discourages moves away from the anchor along the
/// directions the Fisher estimate marks as important. A `Drift` signal re-anchors at
/// the current parameters. The Fisher estimate is not cleared when re-anchoring.
#[derive(Debug, Clone)]
pub struct StreamingEWC {
    /// Diagonal Fisher estimate: EMA of squared raw gradients, one entry per parameter.
    fisher_diag: Vec<f64>,
    /// Parameter snapshot the penalty pulls toward; only meaningful once `initialized`.
    anchor_params: Vec<f64>,
    /// EMA decay `α` in `[0, 1]`: `F ← α·F + (1−α)·g²`. Larger values forget more slowly.
    fisher_alpha: f64,
    /// Penalty strength `λ`; `0.0` disables the penalty.
    ewc_lambda: f64,
    /// Number of `post_update` calls since construction or the last reset.
    n_updates: u64,
    /// Whether an anchor has been set (via `set_anchor` or a drift signal).
    initialized: bool,
}

impl StreamingEWC {
    /// Creates a strategy for `n_params` parameters with zeroed Fisher and anchor.
    ///
    /// No penalty is applied until an anchor is set with
    /// [`set_anchor`](Self::set_anchor) or by a drift signal.
    ///
    /// * `ewc_lambda` — penalty strength `λ`; must be finite and `>= 0.0`.
    /// * `fisher_alpha` — EMA decay for the Fisher estimate; must lie in `[0.0, 1.0]`.
    ///
    /// # Panics
    ///
    /// Panics if `fisher_alpha` is outside `[0.0, 1.0]` (or NaN).
    ///
    /// Also panics if `ewc_lambda` is negative, infinite, or NaN:
    /// - a non-finite `λ` makes penalized gradients NaN (e.g. `∞ · 0`);
    /// - a negative `λ` would push parameters *away* from the anchor.
    pub fn new(n_params: usize, ewc_lambda: f64, fisher_alpha: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&fisher_alpha),
            "fisher_alpha must be in [0.0, 1.0], got {fisher_alpha}"
        );
        assert!(
            ewc_lambda.is_finite() && ewc_lambda >= 0.0,
            "ewc_lambda must be finite and non-negative, got {ewc_lambda}"
        );
        Self {
            fisher_diag: alloc::vec![0.0; n_params],
            anchor_params: alloc::vec![0.0; n_params],
            fisher_alpha,
            ewc_lambda,
            n_updates: 0,
            initialized: false,
        }
    }

    /// Creates a strategy with `ewc_lambda = 1.0` and `fisher_alpha = 0.99`.
    pub fn with_defaults(n_params: usize) -> Self {
        Self::new(n_params, 1.0, 0.99)
    }

    /// Current diagonal Fisher estimate, one entry per parameter.
    #[inline]
    pub fn fisher(&self) -> &[f64] {
        &self.fisher_diag
    }

    /// Current anchor parameters (all zeros until an anchor is set).
    #[inline]
    pub fn anchor(&self) -> &[f64] {
        &self.anchor_params
    }

    /// Penalty strength `λ`.
    #[inline]
    pub fn ewc_lambda(&self) -> f64 {
        self.ewc_lambda
    }

    /// Number of `post_update` calls since construction or the last reset.
    #[inline]
    pub fn n_updates(&self) -> u64 {
        self.n_updates
    }

    /// Whether an anchor has been set, i.e. whether the penalty is active.
    #[inline]
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Snapshots `params` as the new anchor and enables the penalty.
    ///
    /// The Fisher estimate is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `params.len()` differs from the number of parameters.
    pub fn set_anchor(&mut self, params: &[f64]) {
        assert_eq!(
            params.len(),
            self.fisher_diag.len(),
            "set_anchor: expected {} params, got {}",
            self.fisher_diag.len(),
            params.len()
        );
        self.anchor_params.copy_from_slice(params);
        self.initialized = true;
    }

    /// EWC penalty `½ · λ · Σ_i F_i · (params_i − anchor_i)²` at `params`.
    ///
    /// Returns `0.0` while no anchor is set. In debug builds a length mismatch
    /// panics. Otherwise extra trailing entries of the longer slice are ignored.
    pub fn penalty(&self, params: &[f64]) -> f64 {
        debug_assert_eq!(
            params.len(),
            self.fisher_diag.len(),
            "penalty: expected {} params, got {}",
            self.fisher_diag.len(),
            params.len()
        );
        if !self.initialized {
            return 0.0;
        }
        let weighted_sq_dist = self
            .fisher_diag
            .iter()
            .zip(&self.anchor_params)
            .zip(params)
            .fold(0.0_f64, |acc, ((&f, &a), &p)| {
                let diff = p - a;
                acc + f * diff * diff
            });
        0.5 * self.ewc_lambda * weighted_sq_dist
    }
}
175
176impl ContinualStrategy for StreamingEWC {
177 fn pre_update(&mut self, params: &[f64], gradients: &mut [f64]) {
178 let n = self.fisher_diag.len();
179 debug_assert_eq!(params.len(), n);
180 debug_assert_eq!(gradients.len(), n);
181
182 let alpha = self.fisher_alpha;
183 let one_minus_alpha = 1.0 - alpha;
184
185 for i in 0..n {
186 self.fisher_diag[i] =
188 alpha * self.fisher_diag[i] + one_minus_alpha * gradients[i] * gradients[i];
189
190 if self.initialized {
192 let diff = params[i] - self.anchor_params[i];
193 gradients[i] += self.ewc_lambda * self.fisher_diag[i] * diff;
194 }
195 }
196 }
197
198 fn post_update(&mut self, _params: &[f64]) {
199 self.n_updates += 1;
200 }
201
202 fn on_drift(&mut self, params: &[f64], signal: DriftSignal) {
203 match signal {
204 DriftSignal::Drift => {
205 self.set_anchor(params);
207 }
208 DriftSignal::Warning | DriftSignal::Stable => {
209 }
211 }
212 }
213
214 #[inline]
215 fn n_params(&self) -> usize {
216 self.fisher_diag.len()
217 }
218
219 fn reset(&mut self) {
220 for v in &mut self.fisher_diag {
221 *v = 0.0;
222 }
223 for v in &mut self.anchor_params {
224 *v = 0.0;
225 }
226 self.n_updates = 0;
227 self.initialized = false;
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ewc_gradient_penalty_pushes_toward_anchor() {
        let mut ewc = StreamingEWC::new(3, 2.0, 0.5);
        let anchor = [0.0, 0.0, 0.0];
        ewc.set_anchor(&anchor);

        // Parameters displaced on both sides of the anchor.
        let params = [1.0, -1.0, 0.5];
        let mut grads = [0.1, 0.1, 0.1];
        ewc.pre_update(&params, &mut grads);

        // The added penalty gradient is lambda * F * (param - anchor). It is positive
        // above the anchor and negative below it, so a descent step moves each
        // parameter back toward the anchor.
        assert!(
            grads[0] > 0.1,
            "penalty gradient should be positive for a param above the anchor: got {}",
            grads[0]
        );
        assert!(
            grads[1] < 0.1,
            "penalty gradient should be negative for a param below the anchor: got {}",
            grads[1]
        );
        assert!(
            grads[2] > 0.1,
            "penalty gradient should be positive for a param above the anchor: got {}",
            grads[2]
        );
    }

    #[test]
    fn fisher_accumulates_squared_gradients() {
        let mut ewc = StreamingEWC::new(2, 0.0, 0.5);
        let params = [0.0, 0.0];

        // First step: F = alpha * 0 + (1 - alpha) * g^2.
        let mut grads = [2.0, 3.0];
        ewc.pre_update(&params, &mut grads);

        let expected_f0 = 0.5 * 0.0 + 0.5 * 4.0;
        let expected_f1 = 0.5 * 0.0 + 0.5 * 9.0;
        assert!(
            (ewc.fisher()[0] - expected_f0).abs() < 1e-12,
            "fisher[0] = {}, expected {}",
            ewc.fisher()[0],
            expected_f0
        );
        assert!(
            (ewc.fisher()[1] - expected_f1).abs() < 1e-12,
            "fisher[1] = {}, expected {}",
            ewc.fisher()[1],
            expected_f1
        );

        // Second step blends the previous estimate with the new squared gradient.
        let mut grads2 = [1.0, 1.0];
        ewc.pre_update(&params, &mut grads2);

        let expected_f0_2 = 0.5 * expected_f0 + 0.5 * 1.0;
        let expected_f1_2 = 0.5 * expected_f1 + 0.5 * 1.0;
        assert!(
            (ewc.fisher()[0] - expected_f0_2).abs() < 1e-12,
            "fisher[0] after 2nd = {}, expected {}",
            ewc.fisher()[0],
            expected_f0_2
        );
        assert!(
            (ewc.fisher()[1] - expected_f1_2).abs() < 1e-12,
            "fisher[1] after 2nd = {}, expected {}",
            ewc.fisher()[1],
            expected_f1_2
        );
    }

    #[test]
    fn drift_signal_updates_anchor() {
        let mut ewc = StreamingEWC::with_defaults(3);
        let initial = [1.0, 2.0, 3.0];
        ewc.set_anchor(&initial);
        assert_eq!(ewc.anchor(), &[1.0, 2.0, 3.0]);

        let new_params = [4.0, 5.0, 6.0];
        ewc.on_drift(&new_params, DriftSignal::Drift);
        assert_eq!(
            ewc.anchor(),
            &[4.0, 5.0, 6.0],
            "anchor should be updated on Drift signal"
        );
    }

    #[test]
    fn warning_signal_no_anchor_change() {
        let mut ewc = StreamingEWC::with_defaults(2);
        let anchor = [1.0, 2.0];
        ewc.set_anchor(&anchor);

        let new_params = [10.0, 20.0];
        ewc.on_drift(&new_params, DriftSignal::Warning);
        assert_eq!(
            ewc.anchor(),
            &[1.0, 2.0],
            "anchor should not change on Warning"
        );
    }

    #[test]
    fn stable_signal_no_effect() {
        let mut ewc = StreamingEWC::with_defaults(2);
        let anchor = [1.0, 2.0];
        ewc.set_anchor(&anchor);

        let new_params = [10.0, 20.0];
        ewc.on_drift(&new_params, DriftSignal::Stable);
        assert_eq!(
            ewc.anchor(),
            &[1.0, 2.0],
            "anchor should not change on Stable"
        );
    }

    #[test]
    fn penalty_increases_with_distance_from_anchor() {
        let mut ewc = StreamingEWC::new(2, 1.0, 0.5);
        let anchor = [0.0, 0.0];
        ewc.set_anchor(&anchor);

        // Seed a non-zero Fisher estimate so the penalty is not trivially zero.
        let params = [0.0, 0.0];
        let mut grads = [1.0, 1.0];
        ewc.pre_update(&params, &mut grads);

        let close = [0.1, 0.1];
        let far = [1.0, 1.0];
        let penalty_close = ewc.penalty(&close);
        let penalty_far = ewc.penalty(&far);

        assert!(
            penalty_far > penalty_close,
            "penalty should increase with distance: close={}, far={}",
            penalty_close,
            penalty_far
        );
        assert!(
            penalty_close > 0.0,
            "penalty should be positive for non-zero distance"
        );
    }

    #[test]
    fn reset_clears_all_state() {
        let mut ewc = StreamingEWC::with_defaults(3);
        let params = [1.0, 2.0, 3.0];
        ewc.set_anchor(&params);

        let mut grads = [0.5, 0.5, 0.5];
        ewc.pre_update(&params, &mut grads);
        ewc.post_update(&params);

        assert!(ewc.is_initialized());
        assert!(ewc.n_updates() > 0);
        assert!(ewc.fisher().iter().any(|&f| f > 0.0));

        ewc.reset();

        assert!(!ewc.is_initialized());
        assert_eq!(ewc.n_updates(), 0);
        assert!(
            ewc.fisher().iter().all(|&f| f == 0.0),
            "Fisher should be zeroed after reset"
        );
        assert!(
            ewc.anchor().iter().all(|&a| a == 0.0),
            "anchor should be zeroed after reset"
        );
    }

    #[test]
    fn zero_lambda_means_no_penalty() {
        let mut ewc = StreamingEWC::new(3, 0.0, 0.99);
        let anchor = [0.0, 0.0, 0.0];
        ewc.set_anchor(&anchor);

        // Seed a non-zero Fisher estimate so only lambda can zero the penalty.
        let params = [0.0, 0.0, 0.0];
        let mut grads_seed = [1.0, 1.0, 1.0];
        ewc.pre_update(&params, &mut grads_seed);

        let params_far = [10.0, 10.0, 10.0];
        let original_grads = [0.5, -0.3, 0.7];
        let mut grads = original_grads;
        ewc.pre_update(&params_far, &mut grads);

        for i in 0..3 {
            assert!(
                (grads[i] - original_grads[i]).abs() < 1e-12,
                "gradient[{i}] should be unchanged with lambda=0: got {}, expected {}",
                grads[i],
                original_grads[i]
            );
        }

        assert!(
            ewc.penalty(&params_far).abs() < 1e-12,
            "penalty should be zero with lambda=0"
        );
    }

    #[test]
    fn uninitialized_ewc_has_no_penalty() {
        let ewc = StreamingEWC::with_defaults(3);
        let params = [10.0, 20.0, 30.0];
        assert!(
            ewc.penalty(&params).abs() < 1e-12,
            "penalty should be zero before anchor is set"
        );
    }

    #[test]
    fn post_update_increments_counter() {
        let mut ewc = StreamingEWC::with_defaults(2);
        assert_eq!(ewc.n_updates(), 0);

        ewc.post_update(&[1.0, 2.0]);
        assert_eq!(ewc.n_updates(), 1);

        ewc.post_update(&[1.0, 2.0]);
        assert_eq!(ewc.n_updates(), 2);
    }
}