use alloc::vec::Vec;
use super::ContinualStrategy;
use crate::drift::DriftSignal;
/// Streaming variant of Elastic Weight Consolidation (EWC).
///
/// Keeps a per-parameter estimate of the diagonal Fisher information as an
/// exponential moving average (EMA) of squared gradients, plus a vector of anchor
/// parameters. Once an anchor has been set, the quadratic penalty
/// `0.5 * ewc_lambda * sum_i F_i * (theta_i - anchor_i)^2` discourages parameters that
/// carried large gradients from moving away from their anchored values.
#[derive(Debug, Clone)]
pub struct StreamingEWC {
    /// EMA of squared gradients, one entry per parameter (diagonal Fisher estimate).
    fisher_diag: Vec<f64>,
    /// Parameter values the penalty pulls towards; only used once `initialized` is set.
    anchor_params: Vec<f64>,
    /// EMA decay for `fisher_diag`, in `[0.0, 1.0]`; larger values keep a longer memory.
    fisher_alpha: f64,
    /// Overall strength of the EWC penalty (finite and non-negative).
    ewc_lambda: f64,
    /// Number of completed updates (incremented by the strategy's `post_update`).
    n_updates: u64,
    /// Whether an anchor has been set; the penalty is inactive until then.
    initialized: bool,
}

impl StreamingEWC {
    /// Creates a strategy for `n_params` parameters with a zeroed Fisher estimate and anchor.
    ///
    /// * `ewc_lambda` - penalty strength; `0.0` disables the penalty entirely.
    /// * `fisher_alpha` - EMA decay for the Fisher estimate. `0.0` keeps only the most
    ///   recent squared gradient; `1.0` effectively freezes the estimate.
    ///
    /// # Panics
    ///
    /// Panics if `fisher_alpha` is outside `[0.0, 1.0]` (including NaN), or if
    /// `ewc_lambda` is negative, infinite, or NaN. Such a lambda would either reverse
    /// the penalty or poison every gradient with NaN.
    pub fn new(n_params: usize, ewc_lambda: f64, fisher_alpha: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&fisher_alpha),
            "fisher_alpha must be in [0.0, 1.0], got {fisher_alpha}"
        );
        assert!(
            ewc_lambda.is_finite() && ewc_lambda >= 0.0,
            "ewc_lambda must be finite and non-negative, got {ewc_lambda}"
        );
        Self {
            fisher_diag: alloc::vec![0.0; n_params],
            anchor_params: alloc::vec![0.0; n_params],
            fisher_alpha,
            ewc_lambda,
            n_updates: 0,
            initialized: false,
        }
    }

    /// Creates a strategy with `ewc_lambda = 1.0` and `fisher_alpha = 0.99`.
    pub fn with_defaults(n_params: usize) -> Self {
        Self::new(n_params, 1.0, 0.99)
    }

    /// Current diagonal Fisher estimate, one entry per parameter.
    #[inline]
    pub fn fisher(&self) -> &[f64] {
        &self.fisher_diag
    }

    /// Current anchor parameters (all zeros until an anchor has been set).
    #[inline]
    pub fn anchor(&self) -> &[f64] {
        &self.anchor_params
    }

    /// Penalty strength this strategy was constructed with.
    #[inline]
    pub fn ewc_lambda(&self) -> f64 {
        self.ewc_lambda
    }

    /// Number of updates recorded since construction or the last reset.
    #[inline]
    pub fn n_updates(&self) -> u64 {
        self.n_updates
    }

    /// Whether an anchor has been set, i.e. whether the penalty is active.
    #[inline]
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Records `params` as the new anchor and activates the penalty.
    ///
    /// # Panics
    ///
    /// Panics if `params.len()` differs from the number of parameters.
    pub fn set_anchor(&mut self, params: &[f64]) {
        assert_eq!(
            params.len(),
            self.fisher_diag.len(),
            "set_anchor: expected {} params, got {}",
            self.fisher_diag.len(),
            params.len()
        );
        self.anchor_params.copy_from_slice(params);
        self.initialized = true;
    }

    /// Returns `0.5 * ewc_lambda * sum_i F_i * (params_i - anchor_i)^2`.
    ///
    /// Returns `0.0` if no anchor has been set yet. In debug builds, panics when
    /// `params` has the wrong length; a mismatch would otherwise silently truncate the sum.
    pub fn penalty(&self, params: &[f64]) -> f64 {
        if !self.initialized {
            return 0.0;
        }
        debug_assert_eq!(
            params.len(),
            self.fisher_diag.len(),
            "penalty: expected {} params, got {}",
            self.fisher_diag.len(),
            params.len()
        );
        let total = self
            .fisher_diag
            .iter()
            .zip(&self.anchor_params)
            .zip(params)
            .fold(0.0, |acc, ((&f, &a), &p)| {
                let diff = p - a;
                acc + f * diff * diff
            });
        0.5 * self.ewc_lambda * total
    }
}
impl ContinualStrategy for StreamingEWC {
    /// Folds the incoming `gradients` into the Fisher EMA. Once an anchor has been
    /// set, it then adds the EWC penalty gradient
    /// `ewc_lambda * F_i * (params_i - anchor_i)` to each entry of `gradients` in place.
    ///
    /// Each Fisher entry is updated from the gradient *before* the penalty term is added.
    /// The estimate is therefore built from the incoming gradients only, and the
    /// penalty uses the freshly updated Fisher value.
    fn pre_update(&mut self, params: &[f64], gradients: &mut [f64]) {
        let n = self.fisher_diag.len();
        debug_assert_eq!(params.len(), n);
        debug_assert_eq!(gradients.len(), n);
        // Slice to exactly `n` up front. A too-short input then panics before any Fisher
        // entry is modified (rather than part-way through the loop), and the zipped loop
        // below needs no per-element bounds checks.
        let params = &params[..n];
        let gradients = &mut gradients[..n];

        let alpha = self.fisher_alpha;
        let one_minus_alpha = 1.0 - alpha;
        let lambda = self.ewc_lambda;
        let initialized = self.initialized;

        let per_param = self
            .fisher_diag
            .iter_mut()
            .zip(&self.anchor_params)
            .zip(params.iter().zip(gradients.iter_mut()));
        for ((fisher, &anchor), (&param, grad)) in per_param {
            *fisher = alpha * *fisher + one_minus_alpha * *grad * *grad;
            if initialized {
                *grad += lambda * *fisher * (param - anchor);
            }
        }
    }

    /// Increments the update counter; `params` is not used.
    fn post_update(&mut self, _params: &[f64]) {
        self.n_updates += 1;
    }

    /// On `DriftSignal::Drift`, re-anchors the penalty at the current `params`.
    /// `Warning` and `Stable` leave the state unchanged.
    fn on_drift(&mut self, params: &[f64], signal: DriftSignal) {
        match signal {
            DriftSignal::Drift => self.set_anchor(params),
            // Only a confirmed drift moves the anchor.
            DriftSignal::Warning | DriftSignal::Stable => {}
        }
    }

    /// Number of parameters this strategy was built for.
    #[inline]
    fn n_params(&self) -> usize {
        self.fisher_diag.len()
    }

    /// Zeroes the Fisher estimate, the anchor and the update counter, and deactivates
    /// the penalty.
    fn reset(&mut self) {
        self.fisher_diag.fill(0.0);
        self.anchor_params.fill(0.0);
        self.n_updates = 0;
        self.initialized = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ewc_gradient_penalty_pushes_toward_anchor() {
        let mut ewc = StreamingEWC::new(3, 2.0, 0.5);
        let anchor = [0.0, 0.0, 0.0];
        ewc.set_anchor(&anchor);
        let params = [1.0, -1.0, 0.5];
        let mut grads = [0.1, 0.1, 0.1];
        ewc.pre_update(&params, &mut grads);
        // The penalty gradient has the sign of (param - anchor), so a descent step moves
        // each parameter back toward its anchor.
        assert!(
            grads[0] > 0.1,
            "param above anchor should get a larger gradient: got {}",
            grads[0]
        );
        assert!(
            grads[1] < 0.1,
            "param below anchor should get a smaller gradient: got {}",
            grads[1]
        );
    }

    #[test]
    fn pre_update_adds_exact_penalty_gradient() {
        let mut ewc = StreamingEWC::new(1, 2.0, 0.5);
        ewc.set_anchor(&[0.0]);
        let mut grads = [2.0];
        ewc.pre_update(&[1.0], &mut grads);
        // Fisher: 0.5 * 0.0 + 0.5 * 2.0^2 = 2.0.
        // Gradient: 2.0 + lambda * F * (param - anchor) = 2.0 + 2.0 * 2.0 * 1.0 = 6.0.
        assert_eq!(ewc.fisher(), &[2.0]);
        assert_eq!(grads, [6.0]);
    }

    #[test]
    fn fisher_accumulates_squared_gradients() {
        let mut ewc = StreamingEWC::new(2, 0.0, 0.5);
        let params = [0.0, 0.0];
        let mut grads = [2.0, 3.0];
        ewc.pre_update(&params, &mut grads);
        // Fisher starts at zero: F = alpha * 0 + (1 - alpha) * g^2.
        let expected_f0 = 0.5 * 0.0 + 0.5 * 4.0;
        let expected_f1 = 0.5 * 0.0 + 0.5 * 9.0;
        assert!(
            (ewc.fisher()[0] - expected_f0).abs() < 1e-12,
            "fisher[0] = {}, expected {}",
            ewc.fisher()[0],
            expected_f0
        );
        assert!(
            (ewc.fisher()[1] - expected_f1).abs() < 1e-12,
            "fisher[1] = {}, expected {}",
            ewc.fisher()[1],
            expected_f1
        );

        let mut grads2 = [1.0, 1.0];
        ewc.pre_update(&params, &mut grads2);
        // Second step blends the previous estimate with the new squared gradient.
        let expected_f0_2 = 0.5 * expected_f0 + 0.5 * 1.0;
        let expected_f1_2 = 0.5 * expected_f1 + 0.5 * 1.0;
        assert!(
            (ewc.fisher()[0] - expected_f0_2).abs() < 1e-12,
            "fisher[0] after 2nd = {}, expected {}",
            ewc.fisher()[0],
            expected_f0_2
        );
        assert!(
            (ewc.fisher()[1] - expected_f1_2).abs() < 1e-12,
            "fisher[1] after 2nd = {}, expected {}",
            ewc.fisher()[1],
            expected_f1_2
        );
    }

    #[test]
    fn drift_signal_updates_anchor() {
        let mut ewc = StreamingEWC::with_defaults(3);
        let initial = [1.0, 2.0, 3.0];
        ewc.set_anchor(&initial);
        assert_eq!(ewc.anchor(), &[1.0, 2.0, 3.0]);
        let new_params = [4.0, 5.0, 6.0];
        ewc.on_drift(&new_params, DriftSignal::Drift);
        assert_eq!(
            ewc.anchor(),
            &[4.0, 5.0, 6.0],
            "anchor should be updated on Drift signal"
        );
    }

    #[test]
    fn warning_signal_no_anchor_change() {
        let mut ewc = StreamingEWC::with_defaults(2);
        let anchor = [1.0, 2.0];
        ewc.set_anchor(&anchor);
        let new_params = [10.0, 20.0];
        ewc.on_drift(&new_params, DriftSignal::Warning);
        assert_eq!(
            ewc.anchor(),
            &[1.0, 2.0],
            "anchor should not change on Warning"
        );
    }

    #[test]
    fn stable_signal_no_effect() {
        let mut ewc = StreamingEWC::with_defaults(2);
        let anchor = [1.0, 2.0];
        ewc.set_anchor(&anchor);
        let new_params = [10.0, 20.0];
        ewc.on_drift(&new_params, DriftSignal::Stable);
        assert_eq!(
            ewc.anchor(),
            &[1.0, 2.0],
            "anchor should not change on Stable"
        );
    }

    #[test]
    fn penalty_increases_with_distance_from_anchor() {
        let mut ewc = StreamingEWC::new(2, 1.0, 0.5);
        let anchor = [0.0, 0.0];
        ewc.set_anchor(&anchor);
        let params = [0.0, 0.0];
        let mut grads = [1.0, 1.0];
        ewc.pre_update(&params, &mut grads);
        let close = [0.1, 0.1];
        let far = [1.0, 1.0];
        let penalty_close = ewc.penalty(&close);
        let penalty_far = ewc.penalty(&far);
        assert!(
            penalty_far > penalty_close,
            "penalty should increase with distance: close={}, far={}",
            penalty_close,
            penalty_far
        );
        assert!(
            penalty_close > 0.0,
            "penalty should be positive for non-zero distance"
        );
    }

    #[test]
    fn penalty_matches_closed_form() {
        let mut ewc = StreamingEWC::new(2, 2.0, 0.0);
        ewc.set_anchor(&[0.0, 0.0]);
        let mut grads = [1.0, 2.0];
        ewc.pre_update(&[0.0, 0.0], &mut grads);
        // alpha = 0 keeps only the latest squared gradient: F = [1.0, 4.0].
        // penalty = 0.5 * 2.0 * (1.0 * 1.0^2 + 4.0 * 1.0^2) = 5.0.
        assert!((ewc.penalty(&[1.0, 1.0]) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn reset_clears_all_state() {
        let mut ewc = StreamingEWC::with_defaults(3);
        let params = [1.0, 2.0, 3.0];
        ewc.set_anchor(&params);
        let mut grads = [0.5, 0.5, 0.5];
        ewc.pre_update(&params, &mut grads);
        ewc.post_update(&params);
        assert!(ewc.is_initialized());
        assert!(ewc.n_updates() > 0);
        assert!(ewc.fisher().iter().any(|&f| f > 0.0));
        ewc.reset();
        assert!(!ewc.is_initialized());
        assert_eq!(ewc.n_updates(), 0);
        assert!(
            ewc.fisher().iter().all(|&f| f == 0.0),
            "Fisher should be zeroed after reset"
        );
        assert!(
            ewc.anchor().iter().all(|&a| a == 0.0),
            "anchor should be zeroed after reset"
        );
    }

    #[test]
    fn zero_lambda_means_no_penalty() {
        let mut ewc = StreamingEWC::new(3, 0.0, 0.99);
        let anchor = [0.0, 0.0, 0.0];
        ewc.set_anchor(&anchor);
        let params = [0.0, 0.0, 0.0];
        // Seed a non-zero Fisher so that only lambda = 0 can explain an unchanged gradient.
        let mut grads_seed = [1.0, 1.0, 1.0];
        ewc.pre_update(&params, &mut grads_seed);
        let params_far = [10.0, 10.0, 10.0];
        let original_grads = [0.5, -0.3, 0.7];
        let mut grads = original_grads;
        ewc.pre_update(&params_far, &mut grads);
        for i in 0..3 {
            assert!(
                (grads[i] - original_grads[i]).abs() < 1e-12,
                "gradient[{i}] should be unchanged with lambda=0: got {}, expected {}",
                grads[i],
                original_grads[i]
            );
        }
        assert!(
            ewc.penalty(&params_far).abs() < 1e-12,
            "penalty should be zero with lambda=0"
        );
    }

    #[test]
    fn uninitialized_ewc_has_no_penalty() {
        let ewc = StreamingEWC::with_defaults(3);
        let params = [10.0, 20.0, 30.0];
        assert!(
            ewc.penalty(&params).abs() < 1e-12,
            "penalty should be zero before anchor is set"
        );
    }

    #[test]
    fn post_update_increments_counter() {
        let mut ewc = StreamingEWC::with_defaults(2);
        assert_eq!(ewc.n_updates(), 0);
        ewc.post_update(&[1.0, 2.0]);
        assert_eq!(ewc.n_updates(), 1);
        ewc.post_update(&[1.0, 2.0]);
        assert_eq!(ewc.n_updates(), 2);
    }
}