Skip to main content

irithyll_core/continual/
parameter_isolation.rs

1//! Drift-triggered parameter isolation.
2//!
3//! When drift is detected, computes an importance mask from accumulated
4//! gradient magnitudes and freezes the most important parameters (setting
5//! their gradients to zero). The remaining parameters are free to adapt.
6//!
7//! This is a streaming version of PackNet (Mallya & Lazebnik, 2018):
8//! instead of task boundaries, irithyll's drift detectors trigger mask updates.
9//!
10//! # How it works
11//!
12//! 1. Importance is tracked as an EWMA of `|gradient|` for each parameter.
//! 2. When a `DriftSignal::Drift` arrives, the top `freeze_fraction` of the
//!    currently unfrozen parameters (by importance) are frozen.
15//! 3. Frozen parameters have their gradients zeroed in every subsequent
16//!    `pre_update` call, effectively isolating them from further learning.
17//! 4. Multiple drift events accumulate: once frozen, a parameter stays frozen
18//!    until explicitly unfrozen or the strategy is reset.
19
20use alloc::vec::Vec;
21
22use super::ContinualStrategy;
23use crate::drift::DriftSignal;
24use crate::math;
25
/// Drift-triggered parameter isolation.
///
/// When drift is detected, computes an importance mask from the accumulated
/// gradient magnitudes and freezes the most important parameters (setting their
/// gradients to zero). The remaining parameters are free to adapt.
///
/// This is a streaming version of PackNet (Mallya & Lazebnik, 2018):
/// instead of task boundaries, irithyll's drift detectors trigger mask updates.
///
/// Each drift event freezes a fraction of the parameters that are *still
/// unfrozen* at that moment, so repeated drift events freeze progressively
/// smaller batches.
///
/// # Example
///
/// ```
/// use irithyll_core::continual::{DriftMask, ContinualStrategy};
/// use irithyll_core::drift::DriftSignal;
///
/// let mut mask = DriftMask::with_defaults(10);
///
/// // Train for a while, accumulating importance
/// for _ in 0..100 {
///     let params = vec![0.0; 10];
///     let mut grads = vec![0.1; 10];
///     mask.pre_update(&params, &mut grads);
/// }
///
/// // Drift detected: freeze top 30% of params
/// mask.on_drift(&[0.0; 10], DriftSignal::Drift);
/// assert_eq!(mask.n_frozen(), 3); // 30% of 10
/// ```
#[derive(Debug, Clone)]
pub struct DriftMask {
    /// Accumulated importance per parameter (EWMA of |grad|).
    importance: Vec<f64>,
    /// Per-parameter frozen mask: `true` = frozen (gradient zeroed).
    /// Always the same length as `importance`.
    frozen: Vec<bool>,
    /// Fraction of the *currently unfrozen* params to freeze on each drift
    /// event (default: 0.3).
    freeze_fraction: f64,
    /// EWMA decay for importance tracking (default: 0.99); this is the weight
    /// kept from the previous importance value on each update.
    importance_alpha: f64,
    /// Count of currently frozen parameters (number of `true` entries in `frozen`).
    n_frozen: usize,
}
66
67impl DriftMask {
68    /// Create a new `DriftMask` with explicit hyperparameters.
69    ///
70    /// # Arguments
71    ///
72    /// * `n_params` -- number of parameters to manage
73    /// * `freeze_fraction` -- fraction of *currently unfrozen* params to freeze on drift
74    /// * `importance_alpha` -- EWMA decay for importance; closer to 1.0 = longer memory
75    ///
76    /// # Panics
77    ///
78    /// Panics if `freeze_fraction` is not in `[0.0, 1.0]` or `importance_alpha`
79    /// is not in `[0.0, 1.0]`.
80    pub fn new(n_params: usize, freeze_fraction: f64, importance_alpha: f64) -> Self {
81        assert!(
82            (0.0..=1.0).contains(&freeze_fraction),
83            "freeze_fraction must be in [0.0, 1.0], got {freeze_fraction}"
84        );
85        assert!(
86            (0.0..=1.0).contains(&importance_alpha),
87            "importance_alpha must be in [0.0, 1.0], got {importance_alpha}"
88        );
89        Self {
90            importance: alloc::vec![0.0; n_params],
91            frozen: alloc::vec![false; n_params],
92            freeze_fraction,
93            importance_alpha,
94            n_frozen: 0,
95        }
96    }
97
98    /// Create a new `DriftMask` with default hyperparameters.
99    ///
100    /// Defaults: `freeze_fraction = 0.3`, `importance_alpha = 0.99`.
101    pub fn with_defaults(n_params: usize) -> Self {
102        Self::new(n_params, 0.3, 0.99)
103    }
104
105    /// Check if a specific parameter is frozen.
106    ///
107    /// # Panics
108    ///
109    /// Panics if `idx >= n_params()`.
110    #[inline]
111    pub fn is_frozen(&self, idx: usize) -> bool {
112        self.frozen[idx]
113    }
114
115    /// Number of currently frozen parameters.
116    #[inline]
117    pub fn n_frozen(&self) -> usize {
118        self.n_frozen
119    }
120
121    /// Fraction of parameters currently frozen.
122    #[inline]
123    pub fn frozen_fraction(&self) -> f64 {
124        if self.frozen.is_empty() {
125            return 0.0;
126        }
127        self.n_frozen as f64 / self.frozen.len() as f64
128    }
129
130    /// Current importance values.
131    #[inline]
132    pub fn importance(&self) -> &[f64] {
133        &self.importance
134    }
135
136    /// Unfreeze all parameters (fresh start for isolation).
137    pub fn unfreeze_all(&mut self) {
138        for f in &mut self.frozen {
139            *f = false;
140        }
141        self.n_frozen = 0;
142    }
143
144    /// Compute the freeze mask from current importance values.
145    ///
146    /// Freezes the top `freeze_fraction` of currently-unfrozen parameters
147    /// by importance. Already-frozen parameters remain frozen.
148    fn apply_freeze(&mut self) {
149        let n = self.importance.len();
150        if n == 0 {
151            return;
152        }
153
154        // Collect importance values of unfrozen parameters
155        let mut unfrozen_importance: Vec<(usize, f64)> = Vec::new();
156        for i in 0..n {
157            if !self.frozen[i] {
158                unfrozen_importance.push((i, self.importance[i]));
159            }
160        }
161
162        if unfrozen_importance.is_empty() {
163            return;
164        }
165
166        let n_unfrozen = unfrozen_importance.len();
167        // Number to freeze from the currently-unfrozen set
168        let n_to_freeze = math::round(self.freeze_fraction * n_unfrozen as f64) as usize;
169        if n_to_freeze == 0 {
170            return;
171        }
172
173        // Sort by importance descending to find top-k
174        // Manual sort: simple insertion sort (small param counts in streaming models)
175        for i in 1..unfrozen_importance.len() {
176            let mut j = i;
177            while j > 0 && unfrozen_importance[j].1 > unfrozen_importance[j - 1].1 {
178                unfrozen_importance.swap(j, j - 1);
179                j -= 1;
180            }
181        }
182
183        // Freeze the top n_to_freeze
184        for &(idx, _) in unfrozen_importance.iter().take(n_to_freeze) {
185            self.frozen[idx] = true;
186        }
187
188        // Recount
189        self.n_frozen = self.frozen.iter().filter(|&&f| f).count();
190    }
191}
192
193impl ContinualStrategy for DriftMask {
194    fn pre_update(&mut self, _params: &[f64], gradients: &mut [f64]) {
195        let n = self.importance.len();
196        debug_assert_eq!(gradients.len(), n);
197
198        let alpha = self.importance_alpha;
199        let one_minus_alpha = 1.0 - alpha;
200
201        for ((imp, grad), &is_frozen) in self
202            .importance
203            .iter_mut()
204            .zip(gradients.iter_mut())
205            .zip(self.frozen.iter())
206        {
207            // Update importance EWMA with |grad|
208            *imp = alpha * *imp + one_minus_alpha * math::abs(*grad);
209
210            // Zero gradients for frozen parameters
211            if is_frozen {
212                *grad = 0.0;
213            }
214        }
215    }
216
217    fn post_update(&mut self, _params: &[f64]) {
218        // No-op
219    }
220
221    fn on_drift(&mut self, _params: &[f64], signal: DriftSignal) {
222        match signal {
223            DriftSignal::Drift => {
224                self.apply_freeze();
225            }
226            DriftSignal::Warning | DriftSignal::Stable => {
227                // No action
228            }
229        }
230    }
231
232    #[inline]
233    fn n_params(&self) -> usize {
234        self.importance.len()
235    }
236
237    fn reset(&mut self) {
238        for v in &mut self.importance {
239            *v = 0.0;
240        }
241        for f in &mut self.frozen {
242            *f = false;
243        }
244        self.n_frozen = 0;
245    }
246}
247
248// ---------------------------------------------------------------------------
249// Tests
250// ---------------------------------------------------------------------------
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a mask with `alpha = 0.0` (importance == |grad| of the last update)
    /// and feed it one gradient vector so importance is fully determined.
    fn primed_mask(freeze_fraction: f64, grads: &[f64]) -> DriftMask {
        let mut mask = DriftMask::new(grads.len(), freeze_fraction, 0.0);
        let params = alloc::vec![0.0; grads.len()];
        let mut scratch = grads.to_vec();
        mask.pre_update(&params, &mut scratch);
        mask
    }

    #[test]
    fn initially_nothing_frozen() {
        let mask = DriftMask::with_defaults(10);
        assert_eq!(mask.n_frozen(), 0);
        for i in 0..10 {
            assert!(
                !mask.is_frozen(i),
                "param {i} should not be frozen initially"
            );
        }
        assert!((mask.frozen_fraction() - 0.0).abs() < 1e-12);
    }

    #[test]
    fn drift_freezes_top_fraction() {
        // Importance is [1, 2, ..., 10]; top 30% of 10 = indices 7, 8, 9.
        let grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let mut mask = primed_mask(0.3, &grads);
        let params = [0.0; 10];

        mask.on_drift(&params, DriftSignal::Drift);

        assert_eq!(mask.n_frozen(), 3, "should freeze 30% = 3 params");
        assert!(
            mask.is_frozen(9),
            "param 9 (importance 10) should be frozen"
        );
        assert!(mask.is_frozen(8), "param 8 (importance 9) should be frozen");
        assert!(mask.is_frozen(7), "param 7 (importance 8) should be frozen");
        assert!(!mask.is_frozen(6), "param 6 should remain unfrozen");
        assert!(!mask.is_frozen(0), "param 0 should remain unfrozen");
        assert!((mask.frozen_fraction() - 0.3).abs() < 1e-12);
    }

    #[test]
    fn equal_importance_freezes_lowest_indices() {
        // All importances tie, so the ranking falls back to index order.
        let mut mask = primed_mask(0.3, &[0.1; 10]);
        let params = [0.0; 10];

        mask.on_drift(&params, DriftSignal::Drift);

        assert_eq!(mask.n_frozen(), 3);
        for i in 0..3 {
            assert!(mask.is_frozen(i), "param {i} should be frozen on a tie");
        }
        for i in 3..10 {
            assert!(!mask.is_frozen(i), "param {i} should stay unfrozen");
        }
    }

    #[test]
    fn frozen_params_have_zero_gradient() {
        // Importance [1, 2, 3, 4]; top 50% = indices 2 and 3.
        let mut mask = primed_mask(0.5, &[1.0, 2.0, 3.0, 4.0]);
        let params = [0.0; 4];

        mask.on_drift(&params, DriftSignal::Drift);
        assert!(mask.is_frozen(2));
        assert!(mask.is_frozen(3));

        // Now call pre_update again with new gradients
        let mut new_grads = [0.5, 0.5, 0.5, 0.5];
        mask.pre_update(&params, &mut new_grads);

        assert!(
            new_grads[2].abs() < 1e-12,
            "frozen param 2 gradient should be zero, got {}",
            new_grads[2]
        );
        assert!(
            new_grads[3].abs() < 1e-12,
            "frozen param 3 gradient should be zero, got {}",
            new_grads[3]
        );
    }

    #[test]
    fn unfrozen_params_pass_gradient_through() {
        // Top 50% (indices 2, 3) get frozen; 0 and 1 must be left alone.
        let mut mask = primed_mask(0.5, &[1.0, 2.0, 3.0, 4.0]);
        let params = [0.0; 4];

        mask.on_drift(&params, DriftSignal::Drift);

        let mut new_grads = [0.7, 0.8, 0.9, 1.0];
        mask.pre_update(&params, &mut new_grads);

        assert!(
            new_grads[0].abs() > 1e-12,
            "unfrozen param 0 should have non-zero gradient"
        );
        assert!(
            new_grads[1].abs() > 1e-12,
            "unfrozen param 1 should have non-zero gradient"
        );
        assert!(
            (new_grads[0] - 0.7).abs() < 1e-12,
            "unfrozen param 0 gradient should pass through: got {}",
            new_grads[0]
        );
        assert!(
            (new_grads[1] - 0.8).abs() < 1e-12,
            "unfrozen param 1 gradient should pass through: got {}",
            new_grads[1]
        );
    }

    #[test]
    fn importance_tracks_gradient_magnitude() {
        let mut mask = DriftMask::new(3, 0.3, 0.5);
        let params = [0.0; 3];

        // First update: importance = 0.5 * 0 + 0.5 * |grad|
        let mut grads = [2.0, -4.0, 6.0];
        mask.pre_update(&params, &mut grads);

        let expected = [1.0, 2.0, 3.0];
        for (i, &exp) in expected.iter().enumerate() {
            assert!(
                (mask.importance()[i] - exp).abs() < 1e-12,
                "importance[{i}] = {}, expected {}",
                mask.importance()[i],
                exp
            );
        }

        // Second update: importance = 0.5 * prev + 0.5 * 0 = prev / 2
        let mut grads2 = [0.0, 0.0, 0.0];
        mask.pre_update(&params, &mut grads2);

        let expected2 = [0.5, 1.0, 1.5];
        for (i, &exp) in expected2.iter().enumerate() {
            assert!(
                (mask.importance()[i] - exp).abs() < 1e-12,
                "importance[{i}] after 2nd = {}, expected {}",
                mask.importance()[i],
                exp
            );
        }
    }

    #[test]
    fn unfreeze_all_resets_mask() {
        let mut mask = primed_mask(0.4, &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let params = [0.0; 5];

        mask.on_drift(&params, DriftSignal::Drift);
        assert!(mask.n_frozen() > 0, "should have frozen some params");

        mask.unfreeze_all();
        assert_eq!(mask.n_frozen(), 0, "all params should be unfrozen");
        for i in 0..5 {
            assert!(!mask.is_frozen(i), "param {i} should be unfrozen");
        }
    }

    #[test]
    fn multiple_drifts_accumulate_frozen() {
        let grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let mut mask = primed_mask(0.3, &grads);
        let params = [0.0; 10];

        // First drift: freeze top 30% of 10 unfrozen = 3 params (indices 7, 8, 9)
        mask.on_drift(&params, DriftSignal::Drift);
        let frozen_after_first = mask.n_frozen();
        assert_eq!(frozen_after_first, 3);

        // With alpha = 0 the importance becomes exactly these magnitudes, so
        // among the 7 unfrozen params, indices 0 and 1 now rank highest.
        let mut grads2 = [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 0.0, 0.0, 0.0];
        mask.pre_update(&params, &mut grads2);

        // Second drift: freeze top 30% of 7 unfrozen = round(0.3 * 7) = 2 params
        mask.on_drift(&params, DriftSignal::Drift);
        let frozen_after_second = mask.n_frozen();
        assert!(
            frozen_after_second > frozen_after_first,
            "second drift should freeze more: first={}, second={}",
            frozen_after_first,
            frozen_after_second
        );
        assert_eq!(frozen_after_second, 5, "3 from first drift + 2 from second");
        assert!(mask.is_frozen(0), "param 0 (importance 10) should be frozen");
        assert!(mask.is_frozen(1), "param 1 (importance 9) should be frozen");
        assert!(!mask.is_frozen(2), "param 2 should remain unfrozen");

        // Previously frozen params should still be frozen
        assert!(mask.is_frozen(9), "param 9 should still be frozen");
        assert!(mask.is_frozen(8), "param 8 should still be frozen");
        assert!(mask.is_frozen(7), "param 7 should still be frozen");
    }

    #[test]
    fn boundary_fractions_freeze_none_or_all() {
        let params = [0.0; 4];

        let mut none = primed_mask(0.0, &[1.0, 2.0, 3.0, 4.0]);
        none.on_drift(&params, DriftSignal::Drift);
        assert_eq!(none.n_frozen(), 0, "fraction 0.0 should freeze nothing");

        let mut all = primed_mask(1.0, &[1.0, 2.0, 3.0, 4.0]);
        all.on_drift(&params, DriftSignal::Drift);
        assert_eq!(all.n_frozen(), 4, "fraction 1.0 should freeze everything");
        assert!((all.frozen_fraction() - 1.0).abs() < 1e-12);

        // Nothing left to freeze: a further drift must be a no-op.
        all.on_drift(&params, DriftSignal::Drift);
        assert_eq!(all.n_frozen(), 4);
    }

    #[test]
    fn reset_clears_everything() {
        let mut mask = primed_mask(0.4, &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let params = [0.0; 5];
        mask.on_drift(&params, DriftSignal::Drift);

        assert!(mask.n_frozen() > 0);
        assert!(mask.importance().iter().any(|&v| v > 0.0));

        mask.reset();

        assert_eq!(
            mask.n_frozen(),
            0,
            "frozen count should be zero after reset"
        );
        assert!(
            mask.importance().iter().all(|&v| v == 0.0),
            "importance should be zeroed after reset"
        );
        for i in 0..5 {
            assert!(
                !mask.is_frozen(i),
                "param {i} should be unfrozen after reset"
            );
        }
    }

    #[test]
    fn warning_and_stable_do_not_freeze() {
        let mut mask = primed_mask(0.5, &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let params = [0.0; 5];

        mask.on_drift(&params, DriftSignal::Warning);
        assert_eq!(mask.n_frozen(), 0, "Warning should not freeze anything");

        mask.on_drift(&params, DriftSignal::Stable);
        assert_eq!(mask.n_frozen(), 0, "Stable should not freeze anything");
    }

    #[test]
    fn empty_mask_operations() {
        let mut mask = DriftMask::with_defaults(0);
        assert_eq!(mask.n_frozen(), 0);
        assert!((mask.frozen_fraction() - 0.0).abs() < 1e-12);

        let params: [f64; 0] = [];
        let mut grads: [f64; 0] = [];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        mask.reset();
        assert_eq!(mask.n_params(), 0);
    }

    #[test]
    #[should_panic(expected = "freeze_fraction must be in [0.0, 1.0]")]
    fn new_rejects_out_of_range_freeze_fraction() {
        let _ = DriftMask::new(4, 1.5, 0.5);
    }

    #[test]
    #[should_panic(expected = "importance_alpha must be in [0.0, 1.0]")]
    fn new_rejects_out_of_range_importance_alpha() {
        let _ = DriftMask::new(4, 0.5, -0.1);
    }
}