use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
/// Hyper-parameters for [`EwcPlusPlus`].
///
/// NOTE(review): serde serializes these fields in declaration order; positional formats
/// (e.g. bincode) would break if the fields were reordered — keep the order stable.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcConfig {
    /// Required length of every gradient / weight slice passed to `EwcPlusPlus`;
    /// slices of any other length are rejected (the call is a no-op or returns a neutral value).
    pub param_count: usize,
    /// Maximum number of per-task Fisher snapshots retained; the oldest is evicted first.
    pub max_tasks: usize,
    /// Starting regularization strength, and the base that `adapt_lambda` scales by task count.
    pub initial_lambda: f32,
    /// Lower clamp applied whenever lambda is adapted or set.
    pub min_lambda: f32,
    /// Upper clamp applied whenever lambda is adapted or set.
    pub max_lambda: f32,
    /// Decay `d` of the running Fisher estimate: `F <- d * F + (1 - d) * g^2`.
    pub fisher_ema_decay: f32,
    /// Average |z-score| of an incoming gradient above which a task boundary is reported.
    pub boundary_threshold: f32,
    /// Number of most recent gradient vectors retained in the history buffer.
    pub gradient_history_size: usize,
}
impl Default for EwcConfig {
fn default() -> Self {
Self {
param_count: 1000,
max_tasks: 10,
initial_lambda: 2000.0, min_lambda: 100.0,
max_lambda: 15000.0, fisher_ema_decay: 0.999,
boundary_threshold: 2.0,
gradient_history_size: 100,
}
}
}
/// Snapshot of one finished task: the importance of each parameter and the weights
/// the regularizer pulls back towards.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TaskFisher {
    /// Id the task had when it was snapshotted (a consolidated entry uses 0).
    pub task_id: usize,
    /// Per-parameter running average of squared gradients (empirical diagonal Fisher).
    pub fisher: Vec<f32>,
    /// Anchor weights that deviations are measured against in the penalty.
    pub optimal_weights: Vec<f32>,
    /// Scalar multiplier applied to every entry of `fisher`.
    pub importance: f32,
}
/// Online Elastic Weight Consolidation state: a running Fisher estimate for the
/// current task plus a bounded memory of snapshots from earlier tasks.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcPlusPlus {
    config: EwcConfig,
    // EMA of squared gradients for the task currently being trained.
    current_fisher: Vec<f32>,
    // Weights last supplied via `set_optimal_weights`; copied into each task snapshot.
    current_weights: Vec<f32>,
    // Snapshots of earlier tasks, oldest at the front.
    task_memory: VecDeque<TaskFisher>,
    current_task_id: usize,
    // Current regularization strength.
    lambda: f32,
    // Recent raw gradients; NOTE(review): written and cleared but never read by any
    // method visible in this file — confirm whether other code consumes it.
    gradient_history: VecDeque<Vec<f32>>,
    // Welford running mean of each gradient component.
    gradient_mean: Vec<f32>,
    // Welford sum of squared deviations (M2), seeded with 1.0 rather than 0;
    // divided by `samples_seen` to obtain a variance.
    gradient_var: Vec<f32>,
    // Number of gradient updates since the current task started.
    samples_seen: u64,
}
impl EwcPlusPlus {
    /// Samples required in the current task before boundary detection is trusted.
    const MIN_SAMPLES_FOR_BOUNDARY: u64 = 50;
    /// Values at or below this are treated as zero (variance / importance).
    const EPS: f32 = 1e-8;
    /// Fraction of lambda applied to the Fisher of the task currently in progress.
    const CURRENT_TASK_PENALTY_SCALE: f32 = 0.1;

    /// Creates an empty EWC++ state sized for `config.param_count` parameters.
    ///
    /// Fisher and weights start at zero, lambda at `config.initial_lambda`, and the
    /// Welford variance accumulator is seeded with 1.0 per parameter.
    pub fn new(config: EwcConfig) -> Self {
        let param_count = config.param_count;
        Self {
            current_fisher: vec![0.0; param_count],
            current_weights: vec![0.0; param_count],
            task_memory: VecDeque::with_capacity(config.max_tasks),
            current_task_id: 0,
            lambda: config.initial_lambda,
            gradient_history: VecDeque::with_capacity(config.gradient_history_size),
            gradient_mean: vec![0.0; param_count],
            gradient_var: vec![1.0; param_count],
            samples_seen: 0,
            // Moved last: every field above only reads `config` before this point.
            config,
        }
    }

    /// Pushes `item` onto the back of `queue`, evicting from the front so that at most
    /// `capacity` items remain. A capacity of 0 keeps nothing.
    fn push_bounded<T>(queue: &mut VecDeque<T>, item: T, capacity: usize) {
        if capacity == 0 {
            return;
        }
        while queue.len() >= capacity {
            queue.pop_front();
        }
        queue.push_back(item);
    }

    /// Folds one gradient into the running Fisher estimate and gradient statistics.
    ///
    /// Each Fisher entry becomes `d * F + (1 - d) * g^2` with `d = fisher_ema_decay`.
    /// Gradients whose length differs from `param_count` are ignored.
    pub fn update_fisher(&mut self, gradients: &[f32]) {
        if gradients.len() != self.config.param_count {
            return;
        }
        let decay = self.config.fisher_ema_decay;
        for (fisher, &g) in self.current_fisher.iter_mut().zip(gradients) {
            *fisher = decay * *fisher + (1.0 - decay) * g * g;
        }
        self.update_gradient_stats(gradients);
        self.samples_seen += 1;
    }

    /// Records `gradients` in the bounded history and updates the Welford mean / M2.
    ///
    /// Must be called before `samples_seen` is incremented: `n` counts this sample.
    fn update_gradient_stats(&mut self, gradients: &[f32]) {
        Self::push_bounded(
            &mut self.gradient_history,
            gradients.to_vec(),
            self.config.gradient_history_size,
        );
        let n = self.samples_seen as f32 + 1.0;
        for ((mean, m2), &g) in self
            .gradient_mean
            .iter_mut()
            .zip(self.gradient_var.iter_mut())
            .zip(gradients)
        {
            let delta = g - *mean;
            *mean += delta / n;
            *m2 += delta * (g - *mean);
        }
    }

    /// Returns `true` when `gradients` looks statistically unlike the current task.
    ///
    /// Computes `|g - mean| / std` for every parameter whose variance (`M2 / samples_seen`)
    /// exceeds a tiny epsilon, and compares the average to `boundary_threshold`.
    /// Always `false` before `MIN_SAMPLES_FOR_BOUNDARY` samples, for a length mismatch,
    /// or when no parameter has usable variance.
    pub fn detect_task_boundary(&self, gradients: &[f32]) -> bool {
        if self.samples_seen < Self::MIN_SAMPLES_FOR_BOUNDARY
            || gradients.len() != self.config.param_count
        {
            return false;
        }
        let n = self.samples_seen as f32;
        let (z_score_sum, count) = gradients
            .iter()
            .zip(&self.gradient_mean)
            .zip(&self.gradient_var)
            .filter_map(|((&g, &mean), &m2)| {
                let var = m2 / n;
                (var > Self::EPS).then(|| (g - mean).abs() / var.sqrt())
            })
            .fold((0.0f32, 0usize), |(sum, c), z| (sum + z, c + 1));
        count > 0 && z_score_sum / count as f32 > self.config.boundary_threshold
    }

    /// Snapshots the current task into memory and resets per-task statistics.
    ///
    /// The snapshot stores the current Fisher and the weights last given to
    /// [`set_optimal_weights`](Self::set_optimal_weights), with importance 1.0. Memory
    /// keeps at most `max_tasks` snapshots (oldest evicted; none kept when it is 0).
    /// Lambda is then re-derived from the number of stored tasks.
    pub fn start_new_task(&mut self) {
        let snapshot = TaskFisher {
            task_id: self.current_task_id,
            fisher: self.current_fisher.clone(),
            optimal_weights: self.current_weights.clone(),
            importance: 1.0,
        };
        Self::push_bounded(&mut self.task_memory, snapshot, self.config.max_tasks);
        self.current_task_id += 1;
        self.current_fisher.fill(0.0);
        self.gradient_history.clear();
        self.gradient_mean.fill(0.0);
        self.gradient_var.fill(1.0);
        self.samples_seen = 0;
        self.adapt_lambda();
    }

    /// Sets lambda to `initial_lambda * (1 + 0.1 * stored_tasks)`, clamped to
    /// `[min_lambda, max_lambda]`. Leaves lambda untouched when no task is stored.
    ///
    /// # Panics
    /// If `min_lambda > max_lambda` or either bound is NaN (via `f32::clamp`).
    fn adapt_lambda(&mut self) {
        let task_count = self.task_memory.len();
        if task_count == 0 {
            return;
        }
        let scale = 1.0 + 0.1 * task_count as f32;
        self.lambda = (self.config.initial_lambda * scale)
            .clamp(self.config.min_lambda, self.config.max_lambda);
    }

    /// Shrinks each gradient component according to its importance to earlier tasks.
    ///
    /// For every stored task with `fisher * importance > eps`, the component is scaled by
    /// `1 / (1 + lambda * fisher * importance)`. The current task's Fisher contributes the
    /// same way at `CURRENT_TASK_PENALTY_SCALE * lambda`. A length mismatch returns the
    /// input unchanged.
    pub fn apply_constraints(&self, gradients: &[f32]) -> Vec<f32> {
        if gradients.len() != self.config.param_count {
            return gradients.to_vec();
        }
        let mut constrained = gradients.to_vec();
        for task in &self.task_memory {
            for (g, &fisher) in constrained.iter_mut().zip(&task.fisher) {
                let importance = fisher * task.importance;
                if importance > Self::EPS {
                    let penalty_grad = self.lambda * importance;
                    *g *= 1.0 / (1.0 + penalty_grad);
                }
            }
        }
        for (g, &fisher) in constrained.iter_mut().zip(&self.current_fisher) {
            if fisher > Self::EPS {
                let penalty_grad = self.lambda * fisher * Self::CURRENT_TASK_PENALTY_SCALE;
                *g *= 1.0 / (1.0 + penalty_grad);
            }
        }
        constrained
    }

    /// EWC penalty `(lambda / 2) * sum_t sum_i F_t,i * imp_t * (w_i - w*_t,i)^2` over all
    /// stored tasks. Returns 0.0 for a length mismatch or when no task is stored.
    pub fn regularization_loss(&self, current_weights: &[f32]) -> f32 {
        if current_weights.len() != self.config.param_count {
            return 0.0;
        }
        let mut loss = 0.0f32;
        for task in &self.task_memory {
            for ((&w, &w_star), &fisher) in current_weights
                .iter()
                .zip(&task.optimal_weights)
                .zip(&task.fisher)
            {
                let diff = w - w_star;
                loss += fisher * diff * diff * task.importance;
            }
        }
        self.lambda * loss / 2.0
    }

    /// Records the weights that the next task snapshot will use as its anchor.
    /// Ignored when `weights.len() != param_count`.
    pub fn set_optimal_weights(&mut self, weights: &[f32]) {
        if weights.len() == self.config.param_count {
            self.current_weights.copy_from_slice(weights);
        }
    }

    /// Merges every stored task into a single snapshot.
    ///
    /// The merged Fisher is the importance-weighted mean of the task Fishers and its
    /// importance is the total importance, so `fisher * importance` equals the summed
    /// per-task contribution. The anchor is `current_weights` (the weights last passed
    /// to `set_optimal_weights`), not the individual tasks' anchors. No-op when empty.
    pub fn consolidate_all_tasks(&mut self) {
        if self.task_memory.is_empty() {
            return;
        }
        let mut consolidated_fisher = vec![0.0f32; self.config.param_count];
        let mut total_importance = 0.0f32;
        for task in &self.task_memory {
            for (acc, &fisher) in consolidated_fisher.iter_mut().zip(&task.fisher) {
                *acc += fisher * task.importance;
            }
            total_importance += task.importance;
        }
        if total_importance > 0.0 {
            for fisher in &mut consolidated_fisher {
                *fisher /= total_importance;
            }
        }
        let consolidated = TaskFisher {
            task_id: 0,
            fisher: consolidated_fisher,
            optimal_weights: self.current_weights.clone(),
            importance: total_importance,
        };
        self.task_memory.clear();
        self.task_memory.push_back(consolidated);
    }

    /// Current regularization strength.
    pub fn lambda(&self) -> f32 {
        self.lambda
    }

    /// Overrides lambda, clamped to `[min_lambda, max_lambda]`. The next
    /// [`start_new_task`](Self::start_new_task) recomputes it from the task count.
    ///
    /// # Panics
    /// If `min_lambda > max_lambda` or either bound is NaN (via `f32::clamp`).
    pub fn set_lambda(&mut self, lambda: f32) {
        self.lambda = lambda.clamp(self.config.min_lambda, self.config.max_lambda);
    }

    /// Number of task snapshots currently held in memory.
    pub fn task_count(&self) -> usize {
        self.task_memory.len()
    }

    /// Id that the next snapshot will be stored under.
    pub fn current_task_id(&self) -> usize {
        self.current_task_id
    }

    /// Gradient updates folded in since the current task started.
    pub fn samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Per-parameter importance: current Fisher plus every stored task's
    /// `fisher * importance`.
    pub fn importance_scores(&self) -> Vec<f32> {
        let mut scores = self.current_fisher.clone();
        for task in &self.task_memory {
            for (score, &fisher) in scores.iter_mut().zip(&task.fisher) {
                *score += fisher * task.importance;
            }
        }
        scores
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewc_creation() {
        let config = EwcConfig {
            param_count: 100,
            ..Default::default()
        };
        let ewc = EwcPlusPlus::new(config);
        assert_eq!(ewc.task_count(), 0);
        assert_eq!(ewc.current_task_id(), 0);
    }

    #[test]
    fn test_fisher_update() {
        let config = EwcConfig {
            param_count: 10,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);
        ewc.update_fisher(&[0.5; 10]);
        assert!(ewc.samples_seen() > 0);
        assert!(ewc.current_fisher.iter().any(|&f| f > 0.0));
    }

    #[test]
    fn test_task_boundary() {
        let config = EwcConfig {
            param_count: 10,
            gradient_history_size: 10,
            boundary_threshold: 2.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);
        for _ in 0..60 {
            ewc.update_fisher(&[0.1; 10]);
        }
        // A gradient equal to the running mean has z-score 0: no boundary.
        assert!(!ewc.detect_task_boundary(&[0.1; 10]));
        // A gradient ~100x the running mean lies far outside the observed spread.
        assert!(ewc.detect_task_boundary(&[10.0; 10]));
    }

    #[test]
    fn test_boundary_requires_warmup() {
        let config = EwcConfig {
            param_count: 10,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);
        for _ in 0..10 {
            ewc.update_fisher(&[0.1; 10]);
        }
        // Too few samples for the statistics to be trusted.
        assert!(!ewc.detect_task_boundary(&[10.0; 10]));
    }

    #[test]
    fn test_mismatched_lengths_are_ignored() {
        let config = EwcConfig {
            param_count: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);
        ewc.update_fisher(&[1.0; 4]);
        assert_eq!(ewc.samples_seen(), 0);
        assert_eq!(ewc.apply_constraints(&[1.0; 3]), vec![1.0; 3]);
        assert_eq!(ewc.regularization_loss(&[1.0; 3]), 0.0);
        assert!(!ewc.detect_task_boundary(&[1.0; 3]));
    }

    #[test]
    fn test_constraint_application() {
        let config = EwcConfig {
            param_count: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();
        let gradients = [1.0f32; 5];
        let constrained = ewc.apply_constraints(&gradients);
        let orig_mag: f32 = gradients.iter().map(|x| x.abs()).sum();
        let const_mag: f32 = constrained.iter().map(|x| x.abs()).sum();
        assert!(const_mag <= orig_mag);
    }

    #[test]
    fn test_regularization_loss() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 100.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);
        ewc.set_optimal_weights(&[0.0; 5]);
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();
        let at_optimal = ewc.regularization_loss(&[0.0; 5]);
        let deviated = ewc.regularization_loss(&[1.0; 5]);
        assert!(deviated > at_optimal);
    }

    #[test]
    fn test_task_consolidation() {
        let config = EwcConfig {
            param_count: 5,
            max_tasks: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);
        for _ in 0..3 {
            for _ in 0..10 {
                ewc.update_fisher(&[1.0; 5]);
            }
            ewc.start_new_task();
        }
        assert_eq!(ewc.task_count(), 3);
        ewc.consolidate_all_tasks();
        assert_eq!(ewc.task_count(), 1);
    }

    #[test]
    fn test_lambda_adaptation() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 1000.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);
        let initial_lambda = ewc.lambda();
        for _ in 0..5 {
            ewc.start_new_task();
        }
        assert!(ewc.lambda() >= initial_lambda);
    }
}