use alloc::vec::Vec;
use super::ContinualStrategy;
use crate::drift::DriftSignal;
use crate::math;
/// Continual-learning strategy that, on each detected drift, freezes the
/// parameters with the highest running gradient-magnitude importance so their
/// gradients are zeroed from then on.
#[derive(Debug, Clone)]
pub struct DriftMask {
    /// Per-parameter exponential moving average of `|gradient|`.
    importance: Vec<f64>,
    /// `true` for parameters whose gradients are zeroed before each update.
    frozen: Vec<bool>,
    /// Fraction of the currently-unfrozen parameters frozen per drift event, in `[0, 1]`.
    freeze_fraction: f64,
    /// Decay factor of the importance moving average, in `[0, 1]`.
    importance_alpha: f64,
    /// Cached number of `true` entries in `frozen`.
    n_frozen: usize,
}
impl DriftMask {
    /// Creates a mask for `n_params` parameters with nothing frozen and every
    /// importance score at zero.
    ///
    /// * `freeze_fraction` — fraction of the *currently unfrozen* parameters
    ///   that each drift event freezes.
    /// * `importance_alpha` — decay of the importance moving average; the newest
    ///   gradient magnitude is weighted by `1 - importance_alpha`.
    ///
    /// # Panics
    ///
    /// Panics if `freeze_fraction` or `importance_alpha` lies outside
    /// `[0.0, 1.0]`. This includes NaN.
    pub fn new(n_params: usize, freeze_fraction: f64, importance_alpha: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&freeze_fraction),
            "freeze_fraction must be in [0.0, 1.0], got {freeze_fraction}"
        );
        assert!(
            (0.0..=1.0).contains(&importance_alpha),
            "importance_alpha must be in [0.0, 1.0], got {importance_alpha}"
        );
        Self {
            importance: alloc::vec![0.0; n_params],
            frozen: alloc::vec![false; n_params],
            freeze_fraction,
            importance_alpha,
            n_frozen: 0,
        }
    }

    /// Creates a mask with `freeze_fraction = 0.3` and `importance_alpha = 0.99`.
    pub fn with_defaults(n_params: usize) -> Self {
        Self::new(n_params, 0.3, 0.99)
    }

    /// Returns whether parameter `idx` is currently frozen.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not less than the number of parameters.
    #[inline]
    pub fn is_frozen(&self, idx: usize) -> bool {
        self.frozen[idx]
    }

    /// Number of parameters currently frozen.
    #[inline]
    pub fn n_frozen(&self) -> usize {
        self.n_frozen
    }

    /// Fraction of all parameters currently frozen. Returns `0.0` for an empty mask.
    #[inline]
    pub fn frozen_fraction(&self) -> f64 {
        if self.frozen.is_empty() {
            return 0.0;
        }
        self.n_frozen as f64 / self.frozen.len() as f64
    }

    /// Per-parameter running importance scores, indexed like the parameters.
    #[inline]
    pub fn importance(&self) -> &[f64] {
        &self.importance
    }

    /// Unfreezes every parameter. Importance scores are left untouched.
    pub fn unfreeze_all(&mut self) {
        self.frozen.fill(false);
        self.n_frozen = 0;
    }

    /// Freezes the most important `freeze_fraction` of the parameters that are
    /// not yet frozen.
    ///
    /// The count is `math::round(freeze_fraction * n_unfrozen)`. Parameters that
    /// are already frozen are never unfrozen here, so repeated drift events
    /// accumulate frozen parameters.
    fn apply_freeze(&mut self) {
        // (index, importance) for every parameter that is still trainable.
        let mut candidates: Vec<(usize, f64)> = self
            .importance
            .iter()
            .copied()
            .zip(self.frozen.iter().copied())
            .enumerate()
            .filter_map(|(i, (imp, is_frozen))| (!is_frozen).then_some((i, imp)))
            .collect();
        if candidates.is_empty() {
            return;
        }
        let n_unfrozen = candidates.len();
        // freeze_fraction is in [0, 1], so the rounded product is at most
        // n_unfrozen. The min() only guards the slice below.
        let n_to_freeze =
            (math::round(self.freeze_fraction * n_unfrozen as f64) as usize).min(n_unfrozen);
        if n_to_freeze == 0 {
            return;
        }
        // Sort in descending order of importance. A stable sort keeps tied
        // entries in index order. total_cmp gives a total order even when an
        // importance is NaN. A plain `>` comparison treats NaN as incomparable,
        // which would stop other entries from being ranked past it.
        candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
        for &(idx, _) in &candidates[..n_to_freeze] {
            self.frozen[idx] = true;
        }
        // Every index frozen above was previously unfrozen and distinct.
        self.n_frozen += n_to_freeze;
    }
}
impl ContinualStrategy for DriftMask {
    /// Blends each gradient's magnitude into the running importance
    /// (`importance = alpha * importance + (1 - alpha) * |grad|`), then zeroes
    /// the gradient of every frozen parameter.
    ///
    /// The importance update reads the raw gradient before masking, so frozen
    /// parameters keep accumulating importance.
    fn pre_update(&mut self, _params: &[f64], gradients: &mut [f64]) {
        debug_assert_eq!(gradients.len(), self.importance.len());
        let decay = self.importance_alpha;
        let blend = 1.0 - decay;
        let lanes = self
            .importance
            .iter_mut()
            .zip(self.frozen.iter())
            .zip(gradients.iter_mut());
        for ((score, &locked), g) in lanes {
            *score = decay * *score + blend * math::abs(*g);
            if locked {
                *g = 0.0;
            }
        }
    }

    /// No-op: all masking happens in `pre_update`.
    fn post_update(&mut self, _params: &[f64]) {}

    /// Freezes another slice of parameters on `Drift`. `Warning` and `Stable`
    /// leave the mask unchanged.
    fn on_drift(&mut self, _params: &[f64], signal: DriftSignal) {
        match signal {
            DriftSignal::Drift => self.apply_freeze(),
            DriftSignal::Warning | DriftSignal::Stable => {}
        }
    }

    /// Number of parameters tracked by this mask.
    #[inline]
    fn n_params(&self) -> usize {
        self.importance.len()
    }

    /// Zeroes all importance scores and unfreezes every parameter.
    fn reset(&mut self) {
        self.importance.fill(0.0);
        self.unfreeze_all();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn initially_nothing_frozen() {
        let mask = DriftMask::with_defaults(10);
        assert_eq!(mask.n_frozen(), 0);
        for i in 0..10 {
            assert!(
                !mask.is_frozen(i),
                "param {i} should not be frozen initially"
            );
        }
        assert!((mask.frozen_fraction() - 0.0).abs() < 1e-12);
    }

    #[test]
    fn drift_freezes_top_fraction() {
        let mut mask = DriftMask::new(10, 0.3, 0.0);
        let params = [0.0; 10];
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        assert_eq!(mask.n_frozen(), 3, "should freeze 30% = 3 params");
        assert!(
            mask.is_frozen(9),
            "param 9 (importance 10) should be frozen"
        );
        assert!(mask.is_frozen(8), "param 8 (importance 9) should be frozen");
        assert!(mask.is_frozen(7), "param 7 (importance 8) should be frozen");
        assert!(!mask.is_frozen(6), "param 6 should remain unfrozen");
        assert!(!mask.is_frozen(0), "param 0 should remain unfrozen");
    }

    #[test]
    fn frozen_params_have_zero_gradient() {
        let mut mask = DriftMask::new(4, 0.5, 0.0);
        let params = [0.0; 4];
        let mut grads = [1.0, 2.0, 3.0, 4.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        assert!(mask.is_frozen(2));
        assert!(mask.is_frozen(3));
        let mut new_grads = [0.5, 0.5, 0.5, 0.5];
        mask.pre_update(&params, &mut new_grads);
        assert!(
            new_grads[2].abs() < 1e-12,
            "frozen param 2 gradient should be zero, got {}",
            new_grads[2]
        );
        assert!(
            new_grads[3].abs() < 1e-12,
            "frozen param 3 gradient should be zero, got {}",
            new_grads[3]
        );
    }

    #[test]
    fn unfrozen_params_pass_gradient_through() {
        let mut mask = DriftMask::new(4, 0.5, 0.0);
        let params = [0.0; 4];
        let mut grads = [1.0, 2.0, 3.0, 4.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        let mut new_grads = [0.7, 0.8, 0.9, 1.0];
        mask.pre_update(&params, &mut new_grads);
        assert!(
            new_grads[0].abs() > 1e-12,
            "unfrozen param 0 should have non-zero gradient"
        );
        assert!(
            new_grads[1].abs() > 1e-12,
            "unfrozen param 1 should have non-zero gradient"
        );
        assert!(
            (new_grads[0] - 0.7).abs() < 1e-12,
            "unfrozen param 0 gradient should pass through: got {}",
            new_grads[0]
        );
        assert!(
            (new_grads[1] - 0.8).abs() < 1e-12,
            "unfrozen param 1 gradient should pass through: got {}",
            new_grads[1]
        );
    }

    #[test]
    fn importance_tracks_gradient_magnitude() {
        let mut mask = DriftMask::new(3, 0.3, 0.5);
        let params = [0.0; 3];
        let mut grads = [2.0, -4.0, 6.0];
        mask.pre_update(&params, &mut grads);
        let expected = [1.0, 2.0, 3.0];
        for (i, &exp) in expected.iter().enumerate() {
            assert!(
                (mask.importance()[i] - exp).abs() < 1e-12,
                "importance[{i}] = {}, expected {}",
                mask.importance()[i],
                exp
            );
        }
        let mut grads2 = [0.0, 0.0, 0.0];
        mask.pre_update(&params, &mut grads2);
        let expected2 = [0.5, 1.0, 1.5];
        for (i, &exp) in expected2.iter().enumerate() {
            assert!(
                (mask.importance()[i] - exp).abs() < 1e-12,
                "importance[{i}] after 2nd = {}, expected {}",
                mask.importance()[i],
                exp
            );
        }
    }

    #[test]
    fn unfreeze_all_resets_mask() {
        let mut mask = DriftMask::new(5, 0.4, 0.0);
        let params = [0.0; 5];
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        assert!(mask.n_frozen() > 0, "should have frozen some params");
        mask.unfreeze_all();
        assert_eq!(mask.n_frozen(), 0, "all params should be unfrozen");
        for i in 0..5 {
            assert!(!mask.is_frozen(i), "param {i} should be unfrozen");
        }
    }

    #[test]
    fn multiple_drifts_accumulate_frozen() {
        let mut mask = DriftMask::new(10, 0.3, 0.0);
        let params = [0.0; 10];
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        let frozen_after_first = mask.n_frozen();
        assert_eq!(frozen_after_first, 3);
        let mut grads2 = [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 0.0, 0.0, 0.0];
        mask.pre_update(&params, &mut grads2);
        mask.on_drift(&params, DriftSignal::Drift);
        let frozen_after_second = mask.n_frozen();
        assert!(
            frozen_after_second > frozen_after_first,
            "second drift should freeze more: first={}, second={}",
            frozen_after_first,
            frozen_after_second
        );
        assert!(mask.is_frozen(9), "param 9 should still be frozen");
        assert!(mask.is_frozen(8), "param 8 should still be frozen");
        assert!(mask.is_frozen(7), "param 7 should still be frozen");
    }

    #[test]
    fn tied_importance_freezes_lowest_indices_first() {
        // Ties are broken by parameter index (the ranking sort is stable).
        let mut mask = DriftMask::new(4, 0.5, 0.0);
        let params = [0.0; 4];
        let mut grads = [1.0, 1.0, 1.0, 1.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        assert_eq!(mask.n_frozen(), 2);
        assert!(mask.is_frozen(0) && mask.is_frozen(1));
        assert!(!mask.is_frozen(2) && !mask.is_frozen(3));
    }

    #[test]
    fn nan_importance_does_not_block_ranking() {
        // A NaN importance must not stop the largest finite importance
        // (param 2) from being ranked into the frozen set.
        let mut mask = DriftMask::new(4, 0.5, 0.0);
        let params = [0.0; 4];
        let mut grads = [1.0, f64::NAN, 3.0, 2.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        assert_eq!(mask.n_frozen(), 2);
        assert!(mask.is_frozen(2), "highest finite importance should be frozen");
    }

    #[test]
    fn reset_clears_everything() {
        let mut mask = DriftMask::new(5, 0.4, 0.0);
        let params = [0.0; 5];
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        assert!(mask.n_frozen() > 0);
        assert!(mask.importance().iter().any(|&v| v > 0.0));
        mask.reset();
        assert_eq!(
            mask.n_frozen(),
            0,
            "frozen count should be zero after reset"
        );
        assert!(
            mask.importance().iter().all(|&v| v == 0.0),
            "importance should be zeroed after reset"
        );
        for i in 0..5 {
            assert!(
                !mask.is_frozen(i),
                "param {i} should be unfrozen after reset"
            );
        }
    }

    #[test]
    fn warning_and_stable_do_not_freeze() {
        let mut mask = DriftMask::new(5, 0.5, 0.0);
        let params = [0.0; 5];
        let mut grads = [1.0, 2.0, 3.0, 4.0, 5.0];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Warning);
        assert_eq!(mask.n_frozen(), 0, "Warning should not freeze anything");
        mask.on_drift(&params, DriftSignal::Stable);
        assert_eq!(mask.n_frozen(), 0, "Stable should not freeze anything");
    }

    #[test]
    fn empty_mask_operations() {
        let mut mask = DriftMask::with_defaults(0);
        assert_eq!(mask.n_frozen(), 0);
        assert!((mask.frozen_fraction() - 0.0).abs() < 1e-12);
        let params: [f64; 0] = [];
        let mut grads: [f64; 0] = [];
        mask.pre_update(&params, &mut grads);
        mask.on_drift(&params, DriftSignal::Drift);
        mask.reset();
        assert_eq!(mask.n_params(), 0);
    }
}