use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use std::collections::HashMap;
/// Maintains an exponential moving average (EMA) of model parameters.
///
/// Shadow copies are updated as `shadow = decay * shadow + (1 - decay) * param`
/// on each call to `update`; `initialize` must be called first to seed them.
pub struct ModelEMACallback {
    // Smoothing factor; values closer to 1.0 make the average move more slowly.
    decay: f64,
    // EMA ("shadow") copy of each named 2-D parameter array.
    shadow_params: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    // When true, the effective decay ramps up over the first updates.
    use_warmup: bool,
    // Number of `update` calls performed so far (drives the warmup schedule).
    num_updates: usize,
    // Set by `initialize`; `update` returns an error until then.
    initialized: bool,
}
impl ModelEMACallback {
    /// Creates a new EMA tracker.
    ///
    /// * `decay` - smoothing factor applied to the shadow parameters
    ///   (values close to 1.0 give slower-moving averages).
    /// * `use_warmup` - when true, the effective decay ramps up over the
    ///   first updates as `(1 + t) / (10 + t)`, capped at `decay`.
    pub fn new(decay: f64, use_warmup: bool) -> Self {
        Self {
            decay,
            shadow_params: HashMap::new(),
            use_warmup,
            num_updates: 0,
            initialized: false,
        }
    }

    /// Seeds the shadow parameters with a full copy of `parameters`.
    /// Must be called before `update`.
    pub fn initialize(
        &mut self,
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) {
        // Cloning the whole map is equivalent to clearing and re-inserting
        // every entry one by one.
        self.shadow_params = parameters.clone();
        self.initialized = true;
    }

    /// Folds the current `parameters` into the shadow copies:
    /// `shadow = decay * shadow + (1 - decay) * param`.
    ///
    /// # Errors
    /// Returns `TrainError::CallbackError` if `initialize` was never called.
    pub fn update(
        &mut self,
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> TrainResult<()> {
        if !self.initialized {
            return Err(TrainError::CallbackError(
                "ModelEMA not initialized. Call initialize() first.".to_string(),
            ));
        }
        self.num_updates += 1;

        // Warmup follows the common (1 + t) / (10 + t) schedule, never
        // exceeding the configured decay.
        let t = self.num_updates as f64;
        let decay = match self.use_warmup {
            true => ((1.0 + t) / (10.0 + t)).min(self.decay),
            false => self.decay,
        };

        // Parameters without a matching shadow entry are silently skipped.
        for (name, param) in parameters.iter() {
            if let Some(shadow) = self.shadow_params.get_mut(name) {
                *shadow = &*shadow * decay + &(param * (1.0 - decay));
            }
        }
        Ok(())
    }

    /// Read-only access to the shadow (EMA) parameters.
    pub fn get_shadow_params(
        &self,
    ) -> &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
        &self.shadow_params
    }

    /// Overwrites matching entries of `parameters` with their shadow values.
    pub fn apply_shadow(
        &self,
        parameters: &mut HashMap<
            String,
            scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>,
        >,
    ) {
        for (name, shadow) in self.shadow_params.iter() {
            if let Some(param) = parameters.get_mut(name) {
                // clone_from may reuse the destination's allocation.
                param.clone_from(shadow);
            }
        }
    }
}
impl Callback for ModelEMACallback {
    // NOTE(review): both hooks are intentional no-ops — presumably because
    // `TrainingState` does not carry the model parameters, so the training
    // loop must call `initialize`/`update` on this struct directly. Confirm
    // against the `Callback` trait and callers.
    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }
    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
        Ok(())
    }
}
/// Stochastic Weight Averaging (SWA): keeps a running element-wise mean of
/// model parameters collected once `start_epoch` has been reached.
pub struct SWACallback {
    // Epoch (0-based) at which averaging becomes active.
    start_epoch: usize,
    // Once active, the average is due for refresh every this many epochs.
    update_frequency: usize,
    // Running mean of each named 2-D parameter array.
    swa_params: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    // How many parameter snapshots have been folded into the mean.
    num_averaged: usize,
    // Becomes true in `on_epoch_end` once `start_epoch` is reached.
    active: bool,
    // True after the first snapshot has seeded `swa_params`.
    initialized: bool,
    // Print progress messages to stdout.
    verbose: bool,
}
impl SWACallback {
    /// Creates an SWA tracker that activates at `start_epoch` and is due for
    /// an average refresh every `update_frequency` epochs.
    pub fn new(start_epoch: usize, update_frequency: usize, verbose: bool) -> Self {
        Self {
            start_epoch,
            update_frequency,
            swa_params: HashMap::new(),
            num_averaged: 0,
            active: false,
            initialized: false,
            verbose,
        }
    }

    /// Folds the current `parameters` into the running mean.
    ///
    /// Does nothing while inactive. The first call after activation seeds the
    /// average with a copy of the parameters; later calls update it as
    /// `mean = (mean * n + param) / (n + 1)`.
    pub fn update_average(
        &mut self,
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> TrainResult<()> {
        if !self.active {
            return Ok(());
        }
        if self.initialized {
            let n = self.num_averaged as f64;
            for (name, param) in parameters.iter() {
                if let Some(avg) = self.swa_params.get_mut(name) {
                    *avg = &(&*avg * n + param) / (n + 1.0);
                }
            }
            self.num_averaged += 1;
            if self.verbose {
                println!("SWA: Updated average (n={})", self.num_averaged);
            }
        } else {
            // First snapshot: the mean of one sample is the sample itself.
            self.swa_params = parameters.clone();
            self.initialized = true;
            self.num_averaged = 1;
            if self.verbose {
                println!("SWA: Initialized with model parameters");
            }
        }
        Ok(())
    }

    /// Read-only access to the averaged parameters.
    pub fn get_swa_params(
        &self,
    ) -> &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
        &self.swa_params
    }

    /// Overwrites matching entries of `parameters` with the SWA average.
    /// No-op until the average has been seeded.
    pub fn apply_swa(
        &self,
        parameters: &mut HashMap<
            String,
            scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>,
        >,
    ) {
        if !self.initialized {
            return;
        }
        for (name, avg) in self.swa_params.iter() {
            if let Some(param) = parameters.get_mut(name) {
                // clone_from may reuse the destination's allocation.
                param.clone_from(avg);
            }
        }
    }

    /// True once at least one snapshot has been averaged.
    pub fn is_ready(&self) -> bool {
        self.initialized && self.num_averaged > 0
    }
}
impl Callback for SWACallback {
    /// Activates SWA once `start_epoch` is reached and, on scheduled epochs,
    /// prints a reminder (when verbose) to feed parameters to `update_average`.
    fn on_epoch_end(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        if !self.active && epoch >= self.start_epoch {
            self.active = true;
            if self.verbose {
                println!("\nSWA: Activated at epoch {}", epoch + 1);
            }
        }
        if self.active && epoch >= self.start_epoch {
            // `is_multiple_of` tolerates update_frequency == 0 (true only at
            // the activation epoch itself), whereas `%` would panic.
            let due = (epoch - self.start_epoch).is_multiple_of(self.update_frequency);
            if due && self.verbose && self.initialized {
                println!(
                    "SWA: Ready to update at epoch {} (call update_average with parameters)",
                    epoch + 1
                );
            }
        }
        Ok(())
    }

    /// Final summary; SWA parameters are applied separately via `apply_swa`.
    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        if self.verbose && self.initialized {
            println!(
                "\nSWA: Training complete. Averaged {} models.",
                self.num_averaged
            );
            println!("SWA: Call apply_swa() to use averaged parameters.");
        }
        Ok(())
    }
}