use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;
/// Configuration for Elastic Weight Consolidation (EWC).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EWCConfig {
    /// SGD step size used by `EWC::step`.
    pub learning_rate: f32,
    /// Strength of the quadratic penalty pulling parameters toward the anchor.
    pub lambda: f32,
    /// Fisher-estimation method.
    /// NOTE(review): not consulted by the visible implementation — confirm intended use.
    pub fisher_method: FisherMethod,
    /// Sample budget for Fisher estimation.
    /// NOTE(review): not consulted by the visible implementation — confirm intended use.
    pub fisher_samples: usize,
    /// When true, maintain a decayed running importance across tasks
    /// ("online EWC") instead of a per-task importance.
    pub online: bool,
    /// Decay applied to the accumulated importance in online mode.
    pub decay_factor: f32,
}
impl Default for EWCConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
lambda: 1000.0,
fisher_method: FisherMethod::Empirical,
fisher_samples: 1000,
online: false,
decay_factor: 0.9,
}
}
}
/// Strategy for estimating the Fisher information matrix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FisherMethod {
    /// Empirical Fisher: expectation of squared gradients under the data.
    Empirical,
    /// True Fisher: expectation under the model's own predictive distribution.
    True,
    /// Diagonal approximation of the Fisher.
    Diagonal,
}
/// Configuration for PackNet-style parameter partitioning across tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PackNetConfig {
    /// SGD step size used by `PackNet::step`.
    pub learning_rate: f32,
    /// Fraction of each layer's weights reserved per task (0.0..=1.0).
    pub sparsity_level: f32,
    /// Expected number of tasks.
    /// NOTE(review): not consulted by the visible implementation — confirm intended use.
    pub num_tasks: usize,
    /// How weights are chosen for each new task's mask.
    pub allocation_strategy: AllocationStrategy,
}
impl Default for PackNetConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
sparsity_level: 0.5,
num_tasks: 10,
allocation_strategy: AllocationStrategy::Sequential,
}
}
}
/// How PackNet picks which weights a new task receives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AllocationStrategy {
    /// Contiguous ranges, assigned in order of task arrival.
    Sequential,
    /// Uniformly random subset of the layer's weights.
    Random,
    /// Selection by weight importance.
    /// NOTE(review): the visible implementation falls back to taking the
    /// leading weights; no importance scores are computed — confirm.
    ImportanceBased,
}
/// Configuration for L2-anchored regularization toward previous-task weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct L2RegularizationConfig {
    /// SGD step size used by `L2Regularization::step`.
    pub learning_rate: f32,
    /// Coefficient of the pull toward the anchor parameters.
    pub reg_strength: f32,
    /// Policy for when the anchor is refreshed.
    pub update_strategy: UpdateStrategy,
}
impl Default for L2RegularizationConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
reg_strength: 0.1,
update_strategy: UpdateStrategy::EMA,
}
}
}
/// When the L2 anchor parameters are refreshed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UpdateStrategy {
    /// Anchor never moves after construction.
    Fixed,
    /// Anchor tracks the live parameters with an exponential moving average.
    EMA,
    /// Anchor is snapshotted only when `finish_task` is called.
    TaskBoundary,
}
/// Configuration for gradient memory replay.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryReplayConfig {
    /// SGD step size used by `MemoryReplay::step`.
    pub learning_rate: f32,
    /// Maximum number of stored gradient sets before eviction kicks in.
    pub memory_size: usize,
    /// A replay pass runs every `replay_frequency` optimization steps.
    pub replay_frequency: usize,
    /// Number of stored gradient sets re-applied per replay pass.
    pub replay_batch_size: usize,
    /// Policy for evicting entries when the buffer is full.
    pub selection_strategy: MemorySelectionStrategy,
}
impl Default for MemoryReplayConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
memory_size: 1000,
replay_frequency: 10,
replay_batch_size: 32,
selection_strategy: MemorySelectionStrategy::Random,
}
}
}
/// Which stored entry is evicted when the replay buffer is full.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemorySelectionStrategy {
    /// Evict a uniformly random entry.
    Random,
    /// Gradient-magnitude based eviction.
    /// NOTE(review): the visible implementation falls back to FIFO — confirm.
    GradientBased,
    /// Uncertainty based eviction.
    /// NOTE(review): the visible implementation falls back to FIFO — confirm.
    UncertaintyBased,
}
/// Elastic Weight Consolidation state.
///
/// Penalizes movement of `parameters` away from `anchor_parameters`,
/// weighted by a diagonal Fisher-information estimate.
pub struct EWC {
    config: EWCConfig,
    /// Live (trainable) parameters.
    parameters: Vec<Tensor>,
    /// Per-task diagonal Fisher estimate (same shapes as `parameters`).
    importance_weights: Vec<Tensor>,
    /// Snapshot of parameters at the end of the previous task.
    anchor_parameters: Vec<Tensor>,
    /// Index of the task currently being trained.
    current_task: usize,
    /// Decayed running importance across tasks (used when `config.online`).
    accumulated_importance: Vec<Tensor>,
}
impl EWC {
    /// Creates a new EWC state anchored at `initial_parameters`, with all
    /// importance estimates initialized to zero tensors of matching shapes.
    ///
    /// # Errors
    /// Returns an error when a zero tensor cannot be created.
    pub fn new(config: EWCConfig, initial_parameters: Vec<Tensor>) -> Result<Self> {
        // One zero tensor per parameter, matching its shape. Iterating the
        // slice directly avoids the index-based lookup of the original, and
        // `anyhow!` is already imported at the top of the file.
        fn zeros_like(params: &[Tensor]) -> Result<Vec<Tensor>> {
            params
                .iter()
                .map(|p| Tensor::zeros(&p.shape()).map_err(|e| anyhow!("{:?}", e)))
                .collect()
        }
        let importance_weights = zeros_like(&initial_parameters)?;
        let accumulated_importance = zeros_like(&initial_parameters)?;
        // The anchor starts equal to the live parameters; clone once and move
        // the original in, instead of cloning twice.
        let anchor_parameters = initial_parameters.clone();
        Ok(Self {
            config,
            parameters: initial_parameters,
            importance_weights,
            anchor_parameters,
            current_task: 0,
            accumulated_importance,
        })
    }

    /// Estimates the diagonal Fisher information as the mean of squared
    /// gradients over `gradients_samples` (one `Vec<Tensor>` per sample).
    /// When `config.online` is set, the accumulated importance is first
    /// decayed by `config.decay_factor` and then incremented by the fresh
    /// estimate.
    ///
    /// # Errors
    /// Returns an error when `gradients_samples` is empty or a tensor
    /// operation fails.
    pub fn compute_fisher_information(&mut self, gradients_samples: &[Vec<Tensor>]) -> Result<()> {
        let num_samples = gradients_samples.len();
        if num_samples == 0 {
            return Err(anyhow!("No gradient samples provided"));
        }
        // Reset the per-task estimate before accumulating.
        for importance in self.importance_weights.iter_mut() {
            *importance = Tensor::zeros(&importance.shape())?;
        }
        // Accumulate squared gradients: F_i ≈ E[g_i²].
        for gradient_sample in gradients_samples {
            for (i, gradient) in gradient_sample.iter().enumerate() {
                if i < self.importance_weights.len() {
                    let squared_grad = gradient.mul(gradient)?;
                    self.importance_weights[i] = self.importance_weights[i].add(&squared_grad)?;
                }
            }
        }
        // Normalize by the sample count to obtain the mean.
        for importance in self.importance_weights.iter_mut() {
            *importance = importance.div_scalar(num_samples as f32)?;
        }
        // Online EWC: decay the running importance, then fold in the new one.
        if self.config.online {
            for i in 0..self.accumulated_importance.len() {
                let decayed =
                    self.accumulated_importance[i].mul_scalar(self.config.decay_factor)?;
                self.accumulated_importance[i] = decayed.add(&self.importance_weights[i])?;
            }
        }
        Ok(())
    }

    /// Ends the current task: snapshots the live parameters as the new anchor
    /// and advances the task counter. Does NOT recompute importance — call
    /// `compute_fisher_information` before this.
    pub fn finish_task(&mut self) -> Result<()> {
        self.anchor_parameters = self.parameters.clone();
        self.current_task += 1;
        Ok(())
    }

    /// Applies one SGD step with the EWC penalty:
    /// `theta -= lr * (g + lambda * F ⊙ (theta - anchor))`,
    /// where `F` is the accumulated importance in online mode and the
    /// per-task importance otherwise.
    ///
    /// # Errors
    /// Returns an error when a tensor operation fails.
    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
        for (i, gradient) in gradients.iter().enumerate() {
            if i < self.parameters.len() {
                let param_diff = self.parameters[i].sub(&self.anchor_parameters[i])?;
                let importance = if self.config.online {
                    &self.accumulated_importance[i]
                } else {
                    &self.importance_weights[i]
                };
                let ewc_grad = param_diff.mul(importance)?.mul_scalar(self.config.lambda)?;
                let total_grad = gradient.add(&ewc_grad)?;
                let update = total_grad.mul_scalar(self.config.learning_rate)?;
                self.parameters[i] = self.parameters[i].sub(&update)?;
            }
        }
        Ok(())
    }

    /// Current (trained) parameters.
    pub fn get_parameters(&self) -> &[Tensor] {
        &self.parameters
    }

    /// Most recent per-task diagonal Fisher estimate.
    pub fn get_importance_weights(&self) -> &[Tensor] {
        &self.importance_weights
    }
}
/// PackNet state: partitions network weights into disjoint per-task subsets.
pub struct PackNet {
    config: PackNetConfig,
    /// Live (trainable) parameters.
    parameters: Vec<Tensor>,
    // Retained for future pruning logic; never read in the visible code.
    #[allow(dead_code)]
    parameter_masks: Vec<Tensor>,
    /// Per-task binary masks, one `Vec<Tensor>` per allocated task id.
    task_allocations: HashMap<usize, Vec<Tensor>>,
    /// Task whose mask gates `step` updates.
    current_task: usize,
    /// Fraction of each layer still unassigned to any task (starts at 1.0).
    available_capacity: Vec<f32>,
}
impl PackNet {
    /// Creates a PackNet trainer; every mask starts as all-ones and every
    /// layer begins with full (1.0) free capacity.
    ///
    /// # Errors
    /// Returns an error when a ones tensor cannot be created.
    pub fn new(config: PackNetConfig, initial_parameters: Vec<Tensor>) -> Result<Self> {
        let param_count = initial_parameters.len();
        // One all-ones mask per parameter tensor, matching its shape.
        let parameter_masks: Result<Vec<Tensor>> = (0..param_count)
            .map(|i| {
                Tensor::ones(&initial_parameters[i].shape()).map_err(|e| anyhow::anyhow!("{:?}", e))
            })
            .collect();
        Ok(Self {
            config,
            parameters: initial_parameters.clone(),
            parameter_masks: parameter_masks?,
            task_allocations: HashMap::new(),
            current_task: 0,
            available_capacity: vec![1.0; param_count],
        })
    }
    /// Reserves a `sparsity_level` fraction of every layer's weights for
    /// `task_id` and makes it the current task.
    ///
    /// # Errors
    /// Fails when any layer has less than `sparsity_level` free capacity.
    ///
    /// NOTE(review): re-allocating an existing `task_id` replaces its masks
    /// but still decrements capacity — confirm callers allocate each task at
    /// most once.
    pub fn allocate_task(&mut self, task_id: usize) -> Result<()> {
        if self.available_capacity.iter().any(|&cap| cap < self.config.sparsity_level) {
            return Err(anyhow!("Insufficient parameter capacity for new task"));
        }
        let mut task_masks = Vec::new();
        for (i, param) in self.parameters.iter().enumerate() {
            let shape = param.shape();
            let total_params = shape.iter().product::<usize>();
            // Number of weight positions this task receives in this layer.
            let allocated_params = (total_params as f32 * self.config.sparsity_level) as usize;
            let mut mask_data = vec![0.0; total_params];
            match self.config.allocation_strategy {
                AllocationStrategy::Sequential => {
                    // Take the next contiguous range after what earlier tasks
                    // consumed (inferred from the spent capacity fraction).
                    let start_idx =
                        ((1.0 - self.available_capacity[i]) * total_params as f32) as usize;
                    let end_idx = (start_idx + allocated_params).min(total_params);
                    for idx in start_idx..end_idx {
                        mask_data[idx] = 1.0;
                    }
                },
                AllocationStrategy::Random => {
                    // Pick `allocated_params` positions uniformly at random.
                    use scirs2_core::random::*;
                    let mut indices: Vec<usize> = (0..total_params).collect();
                    let mut rng = thread_rng();
                    indices.shuffle(rng.rng_mut());
                    for &idx in indices.iter().take(allocated_params) {
                        mask_data[idx] = 1.0;
                    }
                },
                AllocationStrategy::ImportanceBased => {
                    // Placeholder: takes the leading weights; no importance
                    // scores are actually computed here.
                    for idx in 0..allocated_params.min(total_params) {
                        mask_data[idx] = 1.0;
                    }
                },
            }
            // NOTE(review): the mask is built from flat data; this assumes
            // `Tensor::new` yields a tensor that is shape/broadcast-compatible
            // with `param` for the elementwise `mul` in `step` — confirm.
            let task_mask = Tensor::new(mask_data)?;
            task_masks.push(task_mask);
            self.available_capacity[i] -= self.config.sparsity_level;
        }
        self.task_allocations.insert(task_id, task_masks);
        self.current_task = task_id;
        Ok(())
    }
    /// Applies one masked SGD step: only weights allocated to the current
    /// task receive gradient updates.
    ///
    /// # Errors
    /// Fails when the current task has no allocation yet.
    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
        let task_masks = self
            .task_allocations
            .get(&self.current_task)
            .ok_or_else(|| anyhow!("No allocation for current task"))?;
        for (i, gradient) in gradients.iter().enumerate() {
            if i < self.parameters.len() && i < task_masks.len() {
                // Zero gradient entries outside this task's mask.
                let masked_grad = gradient.mul(&task_masks[i])?;
                let update = masked_grad.mul_scalar(self.config.learning_rate)?;
                self.parameters[i] = self.parameters[i].sub(&update)?;
            }
        }
        Ok(())
    }
    /// Current (trained) parameters.
    pub fn get_parameters(&self) -> &[Tensor] {
        &self.parameters
    }
    /// Remaining free-capacity fraction per layer.
    pub fn get_available_capacity(&self) -> &[f32] {
        &self.available_capacity
    }
}
/// L2-anchored trainer: pulls parameters toward a (possibly moving) anchor.
pub struct L2Regularization {
    config: L2RegularizationConfig,
    /// Live (trainable) parameters.
    parameters: Vec<Tensor>,
    /// Reference parameters the regularizer pulls toward.
    anchor_parameters: Vec<Tensor>,
    /// EMA coefficient for the anchor (fixed at 0.999 by `new`).
    ema_decay: f32,
}
impl L2Regularization {
    /// Creates a trainer whose anchor starts equal to `initial_parameters`;
    /// the EMA decay is fixed at 0.999.
    pub fn new(config: L2RegularizationConfig, initial_parameters: Vec<Tensor>) -> Self {
        Self {
            config,
            parameters: initial_parameters.clone(),
            anchor_parameters: initial_parameters,
            ema_decay: 0.999,
        }
    }

    /// Applies one SGD step with an L2 pull toward the anchor:
    /// `theta -= lr * (g + reg_strength * (theta - anchor))`.
    /// Under the EMA strategy the anchor then tracks the parameters.
    ///
    /// # Errors
    /// Returns an error when a tensor operation fails.
    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
        for (i, gradient) in gradients.iter().enumerate() {
            if i < self.parameters.len() {
                let param_diff = self.parameters[i].sub(&self.anchor_parameters[i])?;
                let reg_grad = param_diff.mul_scalar(self.config.reg_strength)?;
                // Fixed: the original contained a mis-encoded token `®_grad`
                // (mangled `&reg_grad`), which does not compile.
                let total_grad = gradient.add(&reg_grad)?;
                let update = total_grad.mul_scalar(self.config.learning_rate)?;
                self.parameters[i] = self.parameters[i].sub(&update)?;
                match self.config.update_strategy {
                    // Fixed anchor: never moves after construction.
                    UpdateStrategy::Fixed => {},
                    // EMA anchor: anchor = decay*anchor + (1-decay)*theta.
                    UpdateStrategy::EMA => {
                        let anchor_update = self.parameters[i].mul_scalar(1.0 - self.ema_decay)?;
                        let anchor_keep = self.anchor_parameters[i].mul_scalar(self.ema_decay)?;
                        self.anchor_parameters[i] = anchor_update.add(&anchor_keep)?;
                    },
                    // Anchor refreshed only via `finish_task`.
                    UpdateStrategy::TaskBoundary => {},
                }
            }
        }
        Ok(())
    }

    /// Snapshots the current parameters as the new anchor when using the
    /// `TaskBoundary` strategy; otherwise a no-op.
    pub fn finish_task(&mut self) -> Result<()> {
        if matches!(self.config.update_strategy, UpdateStrategy::TaskBoundary) {
            self.anchor_parameters = self.parameters.clone();
        }
        Ok(())
    }

    /// Current (trained) parameters.
    pub fn get_parameters(&self) -> &[Tensor] {
        &self.parameters
    }
}
/// Gradient-replay trainer: keeps past gradients in a bounded buffer and
/// periodically re-applies a sampled batch of them.
pub struct MemoryReplay {
    config: MemoryReplayConfig,
    /// Live (trainable) parameters.
    parameters: Vec<Tensor>,
    /// Bounded buffer of stored gradient sets (one `Vec<Tensor>` per step).
    memory_buffer: Vec<Vec<Tensor>>,
    /// Number of `step` calls made so far; drives the replay schedule.
    step_count: usize,
}
impl MemoryReplay {
    /// Creates a replay trainer with an empty gradient memory buffer.
    pub fn new(config: MemoryReplayConfig, initial_parameters: Vec<Tensor>) -> Self {
        Self {
            config,
            parameters: initial_parameters,
            memory_buffer: Vec::new(),
            step_count: 0,
        }
    }
    /// Inserts `gradients` into the bounded buffer, evicting one stored entry
    /// first when the buffer is already at `memory_size`.
    pub fn store_gradient(&mut self, gradients: &[Tensor]) -> Result<()> {
        if self.memory_buffer.len() >= self.config.memory_size {
            match self.config.selection_strategy {
                MemorySelectionStrategy::Random => {
                    // Evict a uniformly random entry.
                    use scirs2_core::random::*;
                    let idx = thread_rng().random_range(0..self.memory_buffer.len());
                    self.memory_buffer.remove(idx);
                },
                _ => {
                    // GradientBased / UncertaintyBased currently fall back to
                    // FIFO eviction (oldest entry first).
                    // NOTE(review): Vec::remove(0) is O(n); a VecDeque would
                    // make FIFO eviction O(1) — worth considering.
                    self.memory_buffer.remove(0);
                },
            }
        }
        self.memory_buffer.push(gradients.to_vec());
        Ok(())
    }
    /// Applies one SGD step with `gradients`, stores them in memory, and
    /// every `replay_frequency` steps additionally replays a random batch of
    /// stored gradients at half the learning rate.
    ///
    /// NOTE(review): `step_count` starts at 0, so the very first call already
    /// satisfies `0 % replay_frequency == 0` and replays the gradient it just
    /// stored (applying it again at 0.5*lr) — confirm this is intended.
    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
        for (i, gradient) in gradients.iter().enumerate() {
            if i < self.parameters.len() {
                let update = gradient.mul_scalar(self.config.learning_rate)?;
                self.parameters[i] = self.parameters[i].sub(&update)?;
            }
        }
        self.store_gradient(gradients)?;
        if self.step_count % self.config.replay_frequency == 0 && !self.memory_buffer.is_empty() {
            self.replay_step()?;
        }
        self.step_count += 1;
        Ok(())
    }
    /// Re-applies a random batch of stored gradients at half the learning
    /// rate (sampling without replacement via a shuffled index list).
    fn replay_step(&mut self) -> Result<()> {
        let batch_size = self.config.replay_batch_size.min(self.memory_buffer.len());
        use scirs2_core::random::*;
        let mut indices: Vec<usize> = (0..self.memory_buffer.len()).collect();
        let mut rng = thread_rng();
        indices.shuffle(rng.rng_mut());
        for &idx in indices.iter().take(batch_size) {
            let replay_gradients = &self.memory_buffer[idx];
            // Replayed gradients are damped to avoid overpowering new data.
            let replay_lr = self.config.learning_rate * 0.5;
            for (i, gradient) in replay_gradients.iter().enumerate() {
                if i < self.parameters.len() {
                    let update = gradient.mul_scalar(replay_lr)?;
                    self.parameters[i] = self.parameters[i].sub(&update)?;
                }
            }
        }
        Ok(())
    }
    /// Current (trained) parameters.
    pub fn get_parameters(&self) -> &[Tensor] {
        &self.parameters
    }
    /// Number of gradient sets currently stored.
    pub fn memory_size(&self) -> usize {
        self.memory_buffer.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the documented EWC defaults.
    #[test]
    fn test_ewc_config() {
        let config = EWCConfig::default();
        assert!(!config.online);
        assert_eq!(config.lambda, 1000.0);
        assert_eq!(config.learning_rate, 1e-3);
    }

    // Pins the documented PackNet defaults.
    #[test]
    fn test_packnet_config() {
        let config = PackNetConfig::default();
        assert_eq!(config.num_tasks, 10);
        assert_eq!(config.sparsity_level, 0.5);
    }

    // Pins the documented L2 regularization defaults.
    #[test]
    fn test_l2_regularization_config() {
        let config = L2RegularizationConfig::default();
        assert!(matches!(config.update_strategy, UpdateStrategy::EMA));
        assert_eq!(config.reg_strength, 0.1);
    }

    // Pins the documented memory-replay defaults.
    #[test]
    fn test_memory_replay_config() {
        let config = MemoryReplayConfig::default();
        assert!(matches!(
            config.selection_strategy,
            MemorySelectionStrategy::Random
        ));
        assert_eq!(config.replay_frequency, 10);
        assert_eq!(config.memory_size, 1000);
    }

    // Smoke-checks that all Fisher method variants are constructible.
    #[test]
    fn test_fisher_methods() {
        assert!(matches!(FisherMethod::Empirical, FisherMethod::Empirical));
        assert!(matches!(FisherMethod::True, FisherMethod::True));
        assert!(matches!(FisherMethod::Diagonal, FisherMethod::Diagonal));
    }
}