Skip to main content

irithyll_core/continual/
ewc.rs

//! Streaming Elastic Weight Consolidation.
//!
//! Maintains a diagonal Fisher Information Matrix as an exponential moving
//! average of squared gradients. When drift is detected, the anchor point
//! (parameter snapshot) shifts, allowing adaptation to the new regime while
//! preserving knowledge of the old.
//!
//! # The Math
//!
//! The EWC penalty added to the loss is:
//! ```text
//! L_ewc = (ewc_lambda / 2) * sum_i F_i * (theta_i - theta_i*)^2
//! ```
//! where F_i is the diagonal Fisher entry for parameter i, theta_i* is the
//! anchored value of parameter i, and theta_i is its current value.
//!
//! Differentiating L_ewc with respect to theta_i gives the gradient modification:
//! ```text
//! grad_i = grad_i + ewc_lambda * F_i * (theta_i - theta_i*)
//! ```
//!
//! The Fisher is maintained as an EWMA over update steps t, starting from zero:
//! ```text
//! F_i(t) = fisher_alpha * F_i(t-1) + (1 - fisher_alpha) * grad_i(t)^2
//! ```
//!
//! # References
//!
//! - Kirkpatrick et al., "Overcoming catastrophic forgetting in neural networks",
//!   PNAS 114(13), 2017.
30
31use alloc::vec::Vec;
32
33use super::ContinualStrategy;
34use crate::drift::DriftSignal;
35
/// Streaming EWC: prevents catastrophic forgetting via Fisher regularization.
///
/// Maintains a diagonal Fisher Information Matrix as an exponential moving
/// average of squared gradients. When drift is detected, the anchor point
/// (parameter snapshot) shifts, allowing adaptation to the new regime while
/// preserving knowledge of the old.
///
/// # Example
///
/// ```
/// use irithyll_core::continual::{StreamingEWC, ContinualStrategy};
/// use irithyll_core::drift::DriftSignal;
///
/// let mut ewc = StreamingEWC::with_defaults(4);
/// let anchor = vec![1.0, 2.0, 3.0, 4.0];
/// ewc.set_anchor(&anchor);
///
/// // The parameters have since moved away from the anchor on indices 0 and 2.
/// let params = vec![1.5, 2.0, 2.5, 4.0];
/// let mut grads = vec![0.1, -0.2, 0.3, -0.4];
/// ewc.pre_update(&params, &mut grads);
///
/// // grads[i] += ewc_lambda * F_i * (params[i] - anchor[i]): a parameter above
/// // its anchor gets a larger gradient, so a descent step pulls it back, while
/// // coordinates still sitting on the anchor are left unchanged.
/// assert!(grads[0] > 0.1);
/// assert_eq!(grads[1], -0.2);
///
/// // A confirmed drift re-anchors at the current parameters.
/// ewc.on_drift(&params, DriftSignal::Drift);
/// assert_eq!(ewc.anchor(), &params[..]);
/// ```
#[derive(Debug, Clone)]
pub struct StreamingEWC {
    /// Diagonal Fisher information (EWMA of grad^2); starts at all zeros.
    fisher_diag: Vec<f64>,
    /// Parameter snapshot recorded by the last `set_anchor` call or confirmed
    /// drift; all zeros (and unused) until then.
    anchor_params: Vec<f64>,
    /// EWMA decay for Fisher (default: 0.99).
    fisher_alpha: f64,
    /// Regularization strength (default: 1.0).
    ewc_lambda: f64,
    /// Number of `post_update` calls since construction or the last reset.
    n_updates: u64,
    /// Whether an anchor has been recorded; the penalty is inactive otherwise.
    initialized: bool,
}
71
72impl StreamingEWC {
73    /// Create a new `StreamingEWC` with explicit hyperparameters.
74    ///
75    /// # Arguments
76    ///
77    /// * `n_params` -- number of parameters to protect
78    /// * `ewc_lambda` -- regularization strength; higher = stronger memory
79    /// * `fisher_alpha` -- EWMA decay for Fisher; closer to 1.0 = longer memory
80    ///
81    /// # Panics
82    ///
83    /// Panics if `fisher_alpha` is not in `[0.0, 1.0]`.
84    pub fn new(n_params: usize, ewc_lambda: f64, fisher_alpha: f64) -> Self {
85        assert!(
86            (0.0..=1.0).contains(&fisher_alpha),
87            "fisher_alpha must be in [0.0, 1.0], got {fisher_alpha}"
88        );
89        Self {
90            fisher_diag: alloc::vec![0.0; n_params],
91            anchor_params: alloc::vec![0.0; n_params],
92            fisher_alpha,
93            ewc_lambda,
94            n_updates: 0,
95            initialized: false,
96        }
97    }
98
99    /// Create a new `StreamingEWC` with default hyperparameters.
100    ///
101    /// Defaults: `ewc_lambda = 1.0`, `fisher_alpha = 0.99`.
102    pub fn with_defaults(n_params: usize) -> Self {
103        Self::new(n_params, 1.0, 0.99)
104    }
105
106    /// Current diagonal Fisher information.
107    #[inline]
108    pub fn fisher(&self) -> &[f64] {
109        &self.fisher_diag
110    }
111
112    /// Current anchor parameters.
113    #[inline]
114    pub fn anchor(&self) -> &[f64] {
115        &self.anchor_params
116    }
117
118    /// Current EWC regularization strength.
119    #[inline]
120    pub fn ewc_lambda(&self) -> f64 {
121        self.ewc_lambda
122    }
123
124    /// Number of updates processed.
125    #[inline]
126    pub fn n_updates(&self) -> u64 {
127        self.n_updates
128    }
129
130    /// Whether the anchor has been initialized.
131    #[inline]
132    pub fn is_initialized(&self) -> bool {
133        self.initialized
134    }
135
136    /// Manually set the anchor to a parameter snapshot.
137    ///
138    /// # Panics
139    ///
140    /// Panics if `params.len() != n_params()`.
141    pub fn set_anchor(&mut self, params: &[f64]) {
142        assert_eq!(
143            params.len(),
144            self.fisher_diag.len(),
145            "set_anchor: expected {} params, got {}",
146            self.fisher_diag.len(),
147            params.len()
148        );
149        self.anchor_params.copy_from_slice(params);
150        self.initialized = true;
151    }
152
153    /// Compute the EWC penalty value for given parameters.
154    ///
155    /// Returns `(ewc_lambda / 2) * sum_i F_i * (theta_i - anchor_i)^2`.
156    ///
157    /// Returns `0.0` if the anchor has not been set.
158    pub fn penalty(&self, params: &[f64]) -> f64 {
159        if !self.initialized {
160            return 0.0;
161        }
162        let mut total = 0.0;
163        for ((&f, &a), &p) in self
164            .fisher_diag
165            .iter()
166            .zip(self.anchor_params.iter())
167            .zip(params.iter())
168        {
169            let diff = p - a;
170            total += f * diff * diff;
171        }
172        0.5 * self.ewc_lambda * total
173    }
174}
175
176impl ContinualStrategy for StreamingEWC {
177    fn pre_update(&mut self, params: &[f64], gradients: &mut [f64]) {
178        let n = self.fisher_diag.len();
179        debug_assert_eq!(params.len(), n);
180        debug_assert_eq!(gradients.len(), n);
181
182        let alpha = self.fisher_alpha;
183        let one_minus_alpha = 1.0 - alpha;
184
185        for i in 0..n {
186            // Update Fisher EWMA with squared gradient
187            self.fisher_diag[i] =
188                alpha * self.fisher_diag[i] + one_minus_alpha * gradients[i] * gradients[i];
189
190            // Add EWC gradient penalty if anchor is set
191            if self.initialized {
192                let diff = params[i] - self.anchor_params[i];
193                gradients[i] += self.ewc_lambda * self.fisher_diag[i] * diff;
194            }
195        }
196    }
197
198    fn post_update(&mut self, _params: &[f64]) {
199        self.n_updates += 1;
200    }
201
202    fn on_drift(&mut self, params: &[f64], signal: DriftSignal) {
203        match signal {
204            DriftSignal::Drift => {
205                // Confirmed drift: snapshot current params as new anchor
206                self.set_anchor(params);
207            }
208            DriftSignal::Warning | DriftSignal::Stable => {
209                // No action on Warning or Stable
210            }
211        }
212    }
213
214    #[inline]
215    fn n_params(&self) -> usize {
216        self.fisher_diag.len()
217    }
218
219    fn reset(&mut self) {
220        for v in &mut self.fisher_diag {
221            *v = 0.0;
222        }
223        for v in &mut self.anchor_params {
224            *v = 0.0;
225        }
226        self.n_updates = 0;
227        self.initialized = false;
228    }
229}
230
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-12;

    #[test]
    fn ewc_gradient_penalty_pushes_toward_anchor() {
        let mut ewc = StreamingEWC::new(3, 2.0, 0.5);
        let anchor = [0.0, 0.0, 0.0];
        ewc.set_anchor(&anchor);

        // A single call both seeds the Fisher and applies the penalty. The
        // penalty uses the freshly updated Fisher: F_i = 0.5 * 0.0 + 0.5 * 0.1^2 = 0.005.
        let params = [1.0, -1.0, 0.5];
        let mut grads = [0.1, 0.1, 0.1];
        ewc.pre_update(&params, &mut grads);

        // The penalty term is ewc_lambda * F_i * (param_i - anchor_i). A gradient
        // descent step subtracts the gradient, so:
        // - param above the anchor -> gradient grows -> the step moves it down;
        // - param below the anchor -> gradient shrinks -> the step moves it up.
        // Either way the parameter is pulled back toward the anchor.
        assert!(
            grads[0] > 0.1,
            "param above anchor: penalty should increase the gradient, got {}",
            grads[0]
        );
        assert!(
            grads[1] < 0.1,
            "param below anchor: penalty should decrease the gradient, got {}",
            grads[1]
        );

        // Exact values: 0.1 + 2.0 * 0.005 * (param_i - anchor_i).
        let expected = [0.11, 0.09, 0.105];
        for (i, (&got, &want)) in grads.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < EPS,
                "grads[{i}] = {}, expected {}",
                got,
                want
            );
        }
    }

    #[test]
    fn fisher_accumulates_squared_gradients() {
        // No anchor is set (and lambda = 0), so pre_update only touches the Fisher.
        let mut ewc = StreamingEWC::new(2, 0.0, 0.5);

        let params = [0.0, 0.0];
        let mut grads = [2.0, 3.0];
        ewc.pre_update(&params, &mut grads);

        // After one update: F_i = 0.5 * 0.0 + 0.5 * grad_i^2
        let expected_f0 = 0.5 * 0.0 + 0.5 * 4.0; // 2.0
        let expected_f1 = 0.5 * 0.0 + 0.5 * 9.0; // 4.5
        assert!(
            (ewc.fisher()[0] - expected_f0).abs() < EPS,
            "fisher[0] = {}, expected {}",
            ewc.fisher()[0],
            expected_f0
        );
        assert!(
            (ewc.fisher()[1] - expected_f1).abs() < EPS,
            "fisher[1] = {}, expected {}",
            ewc.fisher()[1],
            expected_f1
        );

        // Second update with different grads
        let mut grads2 = [1.0, 1.0];
        ewc.pre_update(&params, &mut grads2);
        // F_i = 0.5 * prev + 0.5 * grad_i^2
        let expected_f0_2 = 0.5 * expected_f0 + 0.5 * 1.0; // 1.5
        let expected_f1_2 = 0.5 * expected_f1 + 0.5 * 1.0; // 2.75
        assert!(
            (ewc.fisher()[0] - expected_f0_2).abs() < EPS,
            "fisher[0] after 2nd = {}, expected {}",
            ewc.fisher()[0],
            expected_f0_2
        );
        assert!(
            (ewc.fisher()[1] - expected_f1_2).abs() < EPS,
            "fisher[1] after 2nd = {}, expected {}",
            ewc.fisher()[1],
            expected_f1_2
        );
    }

    #[test]
    fn drift_signal_updates_anchor() {
        let mut ewc = StreamingEWC::with_defaults(3);
        let initial = [1.0, 2.0, 3.0];
        ewc.set_anchor(&initial);
        assert_eq!(ewc.anchor(), &[1.0, 2.0, 3.0]);

        let new_params = [4.0, 5.0, 6.0];
        ewc.on_drift(&new_params, DriftSignal::Drift);
        assert_eq!(
            ewc.anchor(),
            &[4.0, 5.0, 6.0],
            "anchor should be updated on Drift signal"
        );
    }

    #[test]
    fn drift_keeps_accumulated_fisher() {
        let mut ewc = StreamingEWC::new(2, 1.0, 0.5);
        let params = [0.0, 0.0];
        ewc.set_anchor(&params);
        let mut grads = [2.0, 3.0];
        ewc.pre_update(&params, &mut grads);
        let fisher_before = [ewc.fisher()[0], ewc.fisher()[1]];

        ewc.on_drift(&[1.0, 1.0], DriftSignal::Drift);

        assert_eq!(ewc.anchor(), &[1.0, 1.0]);
        assert_eq!(
            ewc.fisher(),
            &fisher_before,
            "drift should move the anchor without clearing the Fisher"
        );
    }

    #[test]
    fn warning_signal_no_anchor_change() {
        let mut ewc = StreamingEWC::with_defaults(2);
        let anchor = [1.0, 2.0];
        ewc.set_anchor(&anchor);

        let new_params = [10.0, 20.0];
        ewc.on_drift(&new_params, DriftSignal::Warning);
        assert_eq!(
            ewc.anchor(),
            &[1.0, 2.0],
            "anchor should not change on Warning"
        );
    }

    #[test]
    fn stable_signal_no_effect() {
        let mut ewc = StreamingEWC::with_defaults(2);
        let anchor = [1.0, 2.0];
        ewc.set_anchor(&anchor);

        let new_params = [10.0, 20.0];
        ewc.on_drift(&new_params, DriftSignal::Stable);
        assert_eq!(
            ewc.anchor(),
            &[1.0, 2.0],
            "anchor should not change on Stable"
        );
    }

    #[test]
    fn penalty_increases_with_distance_from_anchor() {
        let mut ewc = StreamingEWC::new(2, 1.0, 0.5);
        let anchor = [0.0, 0.0];
        ewc.set_anchor(&anchor);

        // Seed Fisher with known values
        let params = [0.0, 0.0];
        let mut grads = [1.0, 1.0];
        ewc.pre_update(&params, &mut grads);
        // Fisher is now [0.5, 0.5]

        let close = [0.1, 0.1];
        let far = [1.0, 1.0];
        let penalty_close = ewc.penalty(&close);
        let penalty_far = ewc.penalty(&far);

        assert!(
            penalty_far > penalty_close,
            "penalty should increase with distance: close={}, far={}",
            penalty_close,
            penalty_far
        );
        assert!(
            penalty_close > 0.0,
            "penalty should be positive for non-zero distance"
        );
    }

    #[test]
    fn penalty_matches_closed_form() {
        let mut ewc = StreamingEWC::new(2, 2.0, 0.5);
        let anchor = [0.0, 0.0];
        ewc.set_anchor(&anchor);

        // Seed the Fisher at the anchor (diff = 0, so gradients are untouched):
        // F = [0.5 * 2.0^2, 0.5 * 0.0^2] = [2.0, 0.0].
        let mut grads = [2.0, 0.0];
        ewc.pre_update(&anchor, &mut grads);
        assert_eq!(grads, [2.0, 0.0]);

        // (2.0 / 2) * (2.0 * 1.0^2 + 0.0 * 5.0^2) = 2.0. The zero-Fisher
        // coordinate can move without cost.
        let p = ewc.penalty(&[1.0, 5.0]);
        assert!((p - 2.0).abs() < EPS, "penalty = {p}, expected 2.0");
    }

    #[test]
    fn reset_clears_all_state() {
        let mut ewc = StreamingEWC::with_defaults(3);
        let params = [1.0, 2.0, 3.0];
        ewc.set_anchor(&params);

        let mut grads = [0.5, 0.5, 0.5];
        ewc.pre_update(&params, &mut grads);
        ewc.post_update(&params);

        assert!(ewc.is_initialized());
        assert!(ewc.n_updates() > 0);
        assert!(ewc.fisher().iter().any(|&f| f > 0.0));

        ewc.reset();

        assert!(!ewc.is_initialized());
        assert_eq!(ewc.n_updates(), 0);
        assert!(
            ewc.fisher().iter().all(|&f| f == 0.0),
            "Fisher should be zeroed after reset"
        );
        assert!(
            ewc.anchor().iter().all(|&a| a == 0.0),
            "anchor should be zeroed after reset"
        );
    }

    #[test]
    fn zero_lambda_means_no_penalty() {
        let mut ewc = StreamingEWC::new(3, 0.0, 0.99);
        let anchor = [0.0, 0.0, 0.0];
        ewc.set_anchor(&anchor);

        // Seed Fisher
        let params = [0.0, 0.0, 0.0];
        let mut grads_seed = [1.0, 1.0, 1.0];
        ewc.pre_update(&params, &mut grads_seed);

        // Now with params far from anchor
        let params_far = [10.0, 10.0, 10.0];
        let original_grads = [0.5, -0.3, 0.7];
        let mut grads = original_grads;
        ewc.pre_update(&params_far, &mut grads);

        // With lambda = 0 the penalty term lambda * F_i * diff is zero. The
        // gradients must come back untouched even though the Fisher is non-zero
        // and the params are far from the anchor.
        for (i, (&got, &want)) in grads.iter().zip(original_grads.iter()).enumerate() {
            assert!(
                (got - want).abs() < EPS,
                "gradient[{i}] should be unchanged with lambda=0: got {}, expected {}",
                got,
                want
            );
        }

        assert!(
            ewc.penalty(&params_far).abs() < EPS,
            "penalty should be zero with lambda=0"
        );
    }

    #[test]
    fn uninitialized_ewc_has_no_penalty() {
        let ewc = StreamingEWC::with_defaults(3);
        let params = [10.0, 20.0, 30.0];
        assert!(
            ewc.penalty(&params).abs() < EPS,
            "penalty should be zero before anchor is set"
        );
    }

    #[test]
    fn post_update_increments_counter() {
        let mut ewc = StreamingEWC::with_defaults(2);
        assert_eq!(ewc.n_updates(), 0);

        ewc.post_update(&[1.0, 2.0]);
        assert_eq!(ewc.n_updates(), 1);

        ewc.post_update(&[1.0, 2.0]);
        assert_eq!(ewc.n_updates(), 2);
    }
}