1use alloc::vec::Vec;
21
22use super::ContinualStrategy;
23use crate::drift::DriftSignal;
24use crate::math;
25
/// Drift-triggered parameter freezing strategy for continual learning.
///
/// Every parameter carries an importance score: an exponentially weighted
/// moving average (EWMA) of its absolute gradient. When a drift is signalled,
/// the most important fraction of the still-unfrozen parameters is frozen, and
/// frozen parameters have their gradients zeroed on subsequent updates.
/// Freezes accumulate across drift events until explicitly cleared.
#[derive(Debug, Clone)]
pub struct DriftMask {
    /// Per-parameter EWMA of `|gradient|`.
    importance: Vec<f64>,
    /// Per-parameter freeze flag; frozen parameters receive a zero gradient.
    frozen: Vec<bool>,
    /// Fraction of the currently unfrozen parameters frozen per drift, in `[0, 1]`.
    freeze_fraction: f64,
    /// EWMA decay for `importance`, in `[0, 1]`; larger values adapt more slowly.
    importance_alpha: f64,
    /// Cached number of `true` entries in `frozen`.
    n_frozen: usize,
}
66
67impl DriftMask {
68 pub fn new(n_params: usize, freeze_fraction: f64, importance_alpha: f64) -> Self {
81 assert!(
82 (0.0..=1.0).contains(&freeze_fraction),
83 "freeze_fraction must be in [0.0, 1.0], got {freeze_fraction}"
84 );
85 assert!(
86 (0.0..=1.0).contains(&importance_alpha),
87 "importance_alpha must be in [0.0, 1.0], got {importance_alpha}"
88 );
89 Self {
90 importance: alloc::vec![0.0; n_params],
91 frozen: alloc::vec![false; n_params],
92 freeze_fraction,
93 importance_alpha,
94 n_frozen: 0,
95 }
96 }
97
98 pub fn with_defaults(n_params: usize) -> Self {
102 Self::new(n_params, 0.3, 0.99)
103 }
104
105 #[inline]
111 pub fn is_frozen(&self, idx: usize) -> bool {
112 self.frozen[idx]
113 }
114
115 #[inline]
117 pub fn n_frozen(&self) -> usize {
118 self.n_frozen
119 }
120
121 #[inline]
123 pub fn frozen_fraction(&self) -> f64 {
124 if self.frozen.is_empty() {
125 return 0.0;
126 }
127 self.n_frozen as f64 / self.frozen.len() as f64
128 }
129
130 #[inline]
132 pub fn importance(&self) -> &[f64] {
133 &self.importance
134 }
135
136 pub fn unfreeze_all(&mut self) {
138 for f in &mut self.frozen {
139 *f = false;
140 }
141 self.n_frozen = 0;
142 }
143
144 fn apply_freeze(&mut self) {
149 let n = self.importance.len();
150 if n == 0 {
151 return;
152 }
153
154 let mut unfrozen_importance: Vec<(usize, f64)> = Vec::new();
156 for i in 0..n {
157 if !self.frozen[i] {
158 unfrozen_importance.push((i, self.importance[i]));
159 }
160 }
161
162 if unfrozen_importance.is_empty() {
163 return;
164 }
165
166 let n_unfrozen = unfrozen_importance.len();
167 let n_to_freeze = math::round(self.freeze_fraction * n_unfrozen as f64) as usize;
169 if n_to_freeze == 0 {
170 return;
171 }
172
173 for i in 1..unfrozen_importance.len() {
176 let mut j = i;
177 while j > 0 && unfrozen_importance[j].1 > unfrozen_importance[j - 1].1 {
178 unfrozen_importance.swap(j, j - 1);
179 j -= 1;
180 }
181 }
182
183 for &(idx, _) in unfrozen_importance.iter().take(n_to_freeze) {
185 self.frozen[idx] = true;
186 }
187
188 self.n_frozen = self.frozen.iter().filter(|&&f| f).count();
190 }
191}
192
193impl ContinualStrategy for DriftMask {
194 fn pre_update(&mut self, _params: &[f64], gradients: &mut [f64]) {
195 let n = self.importance.len();
196 debug_assert_eq!(gradients.len(), n);
197
198 let alpha = self.importance_alpha;
199 let one_minus_alpha = 1.0 - alpha;
200
201 for ((imp, grad), &is_frozen) in self
202 .importance
203 .iter_mut()
204 .zip(gradients.iter_mut())
205 .zip(self.frozen.iter())
206 {
207 *imp = alpha * *imp + one_minus_alpha * math::abs(*grad);
209
210 if is_frozen {
212 *grad = 0.0;
213 }
214 }
215 }
216
217 fn post_update(&mut self, _params: &[f64]) {
218 }
220
221 fn on_drift(&mut self, _params: &[f64], signal: DriftSignal) {
222 match signal {
223 DriftSignal::Drift => {
224 self.apply_freeze();
225 }
226 DriftSignal::Warning | DriftSignal::Stable => {
227 }
229 }
230 }
231
232 #[inline]
233 fn n_params(&self) -> usize {
234 self.importance.len()
235 }
236
237 fn reset(&mut self) {
238 for v in &mut self.importance {
239 *v = 0.0;
240 }
241 for f in &mut self.frozen {
242 *f = false;
243 }
244 self.n_frozen = 0;
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn initially_nothing_frozen() {
        let mask = DriftMask::with_defaults(10);
        assert_eq!(mask.n_frozen(), 0);
        for i in 0..10 {
            assert!(
                !mask.is_frozen(i),
                "param {i} should not be frozen initially"
            );
        }
        assert!((mask.frozen_fraction() - 0.0).abs() < 1e-12);
    }

    #[test]
    fn drift_freezes_top_fraction() {
        let mut mask = DriftMask::new(10, 0.3, 0.0);
        let params = [0.0; 10];
        // With alpha = 0 the importance equals |gradient| after one update.
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        mask.pre_update(&params, &mut grads);

        mask.on_drift(&params, DriftSignal::Drift);

        assert_eq!(mask.n_frozen(), 3, "should freeze 30% = 3 params");
        assert!(
            mask.is_frozen(9),
            "param 9 (importance 10) should be frozen"
        );
        assert!(mask.is_frozen(8), "param 8 (importance 9) should be frozen");
        assert!(mask.is_frozen(7), "param 7 (importance 8) should be frozen");
        assert!(!mask.is_frozen(6), "param 6 should remain unfrozen");
        assert!(!mask.is_frozen(0), "param 0 should remain unfrozen");
    }

    #[test]
    fn frozen_params_have_zero_gradient() {
        let mut mask = DriftMask::new(4, 0.5, 0.0);

        let params = [0.0; 4];
        let mut grads = [1.0, 2.0, 3.0, 4.0];
        mask.pre_update(&params, &mut grads);

        mask.on_drift(&params, DriftSignal::Drift);
        assert!(mask.is_frozen(2));
        assert!(mask.is_frozen(3));

        let mut new_grads = [0.5, 0.5, 0.5, 0.5];
        mask.pre_update(&params, &mut new_grads);

        assert!(
            new_grads[2].abs() < 1e-12,
            "frozen param 2 gradient should be zero, got {}",
            new_grads[2]
        );
        assert!(
            new_grads[3].abs() < 1e-12,
            "frozen param 3 gradient should be zero, got {}",
            new_grads[3]
        );
    }

    #[test]
    fn unfrozen_params_pass_gradient_through() {
        let mut mask = DriftMask::new(4, 0.5, 0.0);

        let params = [0.0; 4];
        let mut grads = [1.0, 2.0, 3.0, 4.0];
        mask.pre_update(&params, &mut grads);

        mask.on_drift(&params, DriftSignal::Drift);

        let mut new_grads = [0.7, 0.8, 0.9, 1.0];
        mask.pre_update(&params, &mut new_grads);

        assert!(
            new_grads[0].abs() > 1e-12,
            "unfrozen param 0 should have non-zero gradient"
        );
        assert!(
            new_grads[1].abs() > 1e-12,
            "unfrozen param 1 should have non-zero gradient"
        );
        assert!(
            (new_grads[0] - 0.7).abs() < 1e-12,
            "unfrozen param 0 gradient should pass through: got {}",
            new_grads[0]
        );
        assert!(
            (new_grads[1] - 0.8).abs() < 1e-12,
            "unfrozen param 1 gradient should pass through: got {}",
            new_grads[1]
        );
    }

    #[test]
    fn importance_tracks_gradient_magnitude() {
        let mut mask = DriftMask::new(3, 0.3, 0.5);

        let params = [0.0; 3];

        // imp = 0.5 * 0 + 0.5 * |g|
        let mut grads = [2.0, -4.0, 6.0];
        mask.pre_update(&params, &mut grads);

        let expected = [1.0, 2.0, 3.0];
        for (i, &exp) in expected.iter().enumerate() {
            assert!(
                (mask.importance()[i] - exp).abs() < 1e-12,
                "importance[{i}] = {}, expected {}",
                mask.importance()[i],
                exp
            );
        }

        // A zero gradient halves the previous importance.
        let mut grads2 = [0.0, 0.0, 0.0];
        mask.pre_update(&params, &mut grads2);
        let expected2 = [0.5, 1.0, 1.5];
        for (i, &exp) in expected2.iter().enumerate() {
            assert!(
                (mask.importance()[i] - exp).abs() < 1e-12,
                "importance[{i}] after 2nd = {}, expected {}",
                mask.importance()[i],
                exp
            );
        }
    }

    #[test]
    fn unfreeze_all_resets_mask() {
        let mut mask = DriftMask::new(5, 0.4, 0.0);

        let params = [0.0; 5];
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0];
        mask.pre_update(&params, &mut grads);

        mask.on_drift(&params, DriftSignal::Drift);
        assert!(mask.n_frozen() > 0, "should have frozen some params");

        mask.unfreeze_all();
        assert_eq!(mask.n_frozen(), 0, "all params should be unfrozen");
        for i in 0..5 {
            assert!(!mask.is_frozen(i), "param {i} should be unfrozen");
        }
    }

    #[test]
    fn multiple_drifts_accumulate_frozen() {
        let mut mask = DriftMask::new(10, 0.3, 0.0);

        let params = [0.0; 10];
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        mask.pre_update(&params, &mut grads);

        mask.on_drift(&params, DriftSignal::Drift);
        let frozen_after_first = mask.n_frozen();
        assert_eq!(frozen_after_first, 3);

        let mut grads2 = [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 0.0, 0.0, 0.0];
        mask.pre_update(&params, &mut grads2);

        mask.on_drift(&params, DriftSignal::Drift);
        let frozen_after_second = mask.n_frozen();
        assert!(
            frozen_after_second > frozen_after_first,
            "second drift should freeze more: first={}, second={}",
            frozen_after_first,
            frozen_after_second
        );
        // Earlier freezes are never undone by a later drift.
        assert!(mask.is_frozen(9), "param 9 should still be frozen");
        assert!(mask.is_frozen(8), "param 8 should still be frozen");
        assert!(mask.is_frozen(7), "param 7 should still be frozen");
    }

    #[test]
    fn ties_freeze_lowest_indices_first() {
        let mut mask = DriftMask::new(4, 0.5, 0.0);

        let params = [0.0; 4];
        let mut grads = [1.0; 4];
        mask.pre_update(&params, &mut grads);

        mask.on_drift(&params, DriftSignal::Drift);
        assert_eq!(mask.n_frozen(), 2);
        assert!(mask.is_frozen(0), "tie should favour param 0");
        assert!(mask.is_frozen(1), "tie should favour param 1");
        assert!(!mask.is_frozen(2));
        assert!(!mask.is_frozen(3));
    }

    #[test]
    fn zero_fraction_never_freezes_and_full_fraction_freezes_all() {
        let params = [0.0; 4];

        let mut none = DriftMask::new(4, 0.0, 0.0);
        let mut grads = [1.0, 2.0, 3.0, 4.0];
        none.pre_update(&params, &mut grads);
        none.on_drift(&params, DriftSignal::Drift);
        assert_eq!(none.n_frozen(), 0);

        let mut all = DriftMask::new(4, 1.0, 0.0);
        let mut grads = [1.0, 2.0, 3.0, 4.0];
        all.pre_update(&params, &mut grads);
        all.on_drift(&params, DriftSignal::Drift);
        assert_eq!(all.n_frozen(), 4);
        assert!((all.frozen_fraction() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn reset_clears_everything() {
        let mut mask = DriftMask::new(5, 0.4, 0.0);

        let params = [0.0; 5];
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);

        assert!(mask.n_frozen() > 0);
        assert!(mask.importance().iter().any(|&v| v > 0.0));

        mask.reset();

        assert_eq!(
            mask.n_frozen(),
            0,
            "frozen count should be zero after reset"
        );
        assert!(
            mask.importance().iter().all(|&v| v == 0.0),
            "importance should be zeroed after reset"
        );
        for i in 0..5 {
            assert!(
                !mask.is_frozen(i),
                "param {i} should be unfrozen after reset"
            );
        }
    }

    #[test]
    fn warning_and_stable_do_not_freeze() {
        let mut mask = DriftMask::new(5, 0.5, 0.0);

        let params = [0.0; 5];
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0];
        mask.pre_update(&params, &mut grads);

        mask.on_drift(&params, DriftSignal::Warning);
        assert_eq!(mask.n_frozen(), 0, "Warning should not freeze anything");

        mask.on_drift(&params, DriftSignal::Stable);
        assert_eq!(mask.n_frozen(), 0, "Stable should not freeze anything");
    }

    #[test]
    fn empty_mask_operations() {
        let mut mask = DriftMask::with_defaults(0);
        assert_eq!(mask.n_frozen(), 0);
        assert!((mask.frozen_fraction() - 0.0).abs() < 1e-12);

        let params: [f64; 0] = [];
        let mut grads: [f64; 0] = [];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        mask.reset();
        assert_eq!(mask.n_params(), 0);
    }
}