use std::collections::HashMap;
use std::path::PathBuf;
use crate::error::Result;
/// Mutable training state handed to every callback hook.
///
/// Callbacks in this file read the counters, losses and metrics from it and
/// communicate back through the `stop_training` flag.
pub struct CallbackContext {
    /// Current epoch index (the logging callbacks display it as `epoch + 1`).
    pub epoch: usize,
    /// Total number of epochs planned for the run.
    pub n_epochs: usize,
    /// Current batch index within the epoch.
    pub batch: usize,
    /// Number of batches per epoch.
    pub n_batches: usize,
    /// Current learning rate.
    pub lr: f64,
    /// Most recent training loss, if one has been recorded.
    pub train_loss: Option<f32>,
    /// Most recent validation loss, if one has been recorded.
    pub valid_loss: Option<f32>,
    /// Named metric values for the current epoch.
    pub metrics: HashMap<String, f32>,
    /// Set by callbacks to request that training stop.
    pub stop_training: bool,
    /// NOTE(review): presumably asks the training loop to skip the current
    /// batch; nothing in this file sets or reads it — confirm against the loop.
    pub skip_batch: bool,
}

impl CallbackContext {
    /// Creates a context positioned at epoch 0 / batch 0 with no losses or
    /// metrics recorded, a learning rate of `0.0` and both flags cleared.
    pub fn new(n_epochs: usize, n_batches: usize) -> Self {
        Self {
            epoch: 0,
            n_epochs,
            batch: 0,
            n_batches,
            lr: 0.0,
            train_loss: None,
            valid_loss: None,
            metrics: HashMap::new(),
            stop_training: false,
            skip_batch: false,
        }
    }

    /// Returns the fraction of the run completed so far:
    /// `(epoch * n_batches + batch) / (n_epochs * n_batches)`.
    ///
    /// Returns `0.0` when the run has no batches at all (`n_epochs == 0` or
    /// `n_batches == 0`) instead of the `NaN` that `0 / 0` would produce.
    pub fn progress(&self) -> f32 {
        let total_batches = self.n_epochs * self.n_batches;
        if total_batches == 0 {
            // An empty schedule has made no progress; avoid a 0/0 division.
            return 0.0;
        }
        let current = self.epoch * self.n_batches + self.batch;
        current as f32 / total_batches as f32
    }
}
/// Hook interface for observing and influencing a training run.
///
/// Every hook receives the shared `CallbackContext` mutably and has a default
/// implementation that does nothing and returns `Ok(())`, so implementors only
/// override the hooks they care about.
///
/// NOTE(review): the code that invokes these hooks is not part of this file;
/// the timing described below is inferred from the hook names — confirm
/// against the training loop.
pub trait Callback: Send + Sync {
/// Hook named for the start of a fit. Default: no-op.
fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
Ok(())
}
/// Hook named for the end of a fit. Default: no-op.
fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
Ok(())
}
/// Hook named for the start of an epoch. Default: no-op.
fn before_epoch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
Ok(())
}
/// Hook named for the end of an epoch. Default: no-op.
fn after_epoch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
Ok(())
}
/// Hook named for the start of a batch. Default: no-op.
fn before_batch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
Ok(())
}
/// Hook named for the end of a batch. Default: no-op.
fn after_batch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
Ok(())
}
/// Hook named for the start of validation. Default: no-op.
fn before_validate(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
Ok(())
}
/// Hook named for the end of validation. Default: no-op.
fn after_validate(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
Ok(())
}
/// Name identifying this callback.
///
/// Defaults to whatever `std::any::type_name::<Self>()` reports for the
/// implementing type; the callbacks in this file override it with a short
/// literal.
fn name(&self) -> &str {
std::any::type_name::<Self>()
}
}
/// Ordered collection of boxed callbacks that are driven together.
///
/// Each hook method forwards to the same-named hook of every registered
/// callback, in registration order, and stops at the first error.
#[derive(Default)]
pub struct CallbackList {
    callbacks: Vec<Box<dyn Callback>>,
}

impl CallbackList {
    /// Creates an empty list.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `callback` to the end of the list.
    pub fn add<C: Callback + 'static>(&mut self, callback: C) {
        let boxed: Box<dyn Callback> = Box::new(callback);
        self.callbacks.push(boxed);
    }

    /// Runs `before_fit` on every callback, returning the first error.
    pub fn before_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|callback| callback.before_fit(ctx))
    }

    /// Runs `after_fit` on every callback, returning the first error.
    pub fn after_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|callback| callback.after_fit(ctx))
    }

    /// Runs `before_epoch` on every callback, returning the first error.
    pub fn before_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|callback| callback.before_epoch(ctx))
    }

    /// Runs `after_epoch` on every callback, returning the first error.
    pub fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|callback| callback.after_epoch(ctx))
    }

    /// Runs `before_batch` on every callback, returning the first error.
    pub fn before_batch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|callback| callback.before_batch(ctx))
    }

    /// Runs `after_batch` on every callback, returning the first error.
    pub fn after_batch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|callback| callback.after_batch(ctx))
    }

    /// Runs `before_validate` on every callback, returning the first error.
    pub fn before_validate(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|callback| callback.before_validate(ctx))
    }

    /// Runs `after_validate` on every callback, returning the first error.
    pub fn after_validate(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        self.callbacks
            .iter_mut()
            .try_for_each(|callback| callback.after_validate(ctx))
    }
}
/// Callback that reports training progress through `tracing` log events.
pub struct ProgressCallback {
// The flag is stored by `new` but no code in this file reads it, which is
// why the dead-code lint is silenced.
// NOTE(review): presumably intended to enable per-batch output — confirm.
#[allow(dead_code)]
show_batch: bool,
}
impl ProgressCallback {
/// Creates the callback, storing `show_batch` as given.
pub fn new(show_batch: bool) -> Self {
Self { show_batch }
}
}
impl Callback for ProgressCallback {
    /// Logs the number of epochs the run is about to execute.
    fn before_fit(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        tracing::info!("Starting training for {} epochs", ctx.n_epochs);
        Ok(())
    }

    /// Logs a one-line epoch summary followed by one line per metric.
    ///
    /// A loss that has not been recorded is rendered as an empty string.
    /// Metrics are logged in ascending name order so the output is the same
    /// from run to run (plain `HashMap` iteration order is not).
    fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        let train_loss = ctx
            .train_loss
            .map(|l| format!("{:.4}", l))
            .unwrap_or_default();
        let valid_loss = ctx
            .valid_loss
            .map(|l| format!("{:.4}", l))
            .unwrap_or_default();
        tracing::info!(
            "Epoch {}/{}: train_loss={}, valid_loss={}, lr={:.6}",
            ctx.epoch + 1,
            ctx.n_epochs,
            train_loss,
            valid_loss,
            ctx.lr
        );

        // Keys are unique, so an unstable sort is sufficient.
        let mut metrics: Vec<(&String, &f32)> = ctx.metrics.iter().collect();
        metrics.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (name, value) in metrics {
            tracing::info!(" {}: {:.4}", name, value);
        }
        Ok(())
    }

    /// Logs that training has finished.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        tracing::info!("Training completed");
        Ok(())
    }

    /// Returns the fixed name `"ProgressCallback"`.
    fn name(&self) -> &str {
        "ProgressCallback"
    }
}
/// Requests that training stop once the validation loss has gone a number of
/// consecutive epochs without improving.
pub struct EarlyStoppingCallback {
    /// Consecutive non-improving epochs tolerated before stopping.
    patience: usize,
    /// Margin by which the value must beat the best so far to count as an
    /// improvement.
    min_delta: f32,
    /// Best value observed so far (starts at the worst value for the mode).
    best_loss: f32,
    /// Consecutive epochs without improvement.
    counter: usize,
    /// Direction in which the monitored value is expected to improve.
    mode: EarlyStoppingMode,
}

/// Direction of improvement for [`EarlyStoppingCallback`].
///
/// Derives the same basic traits as the other mode enums in this file so it
/// can be copied, compared and debug-printed by callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EarlyStoppingMode {
    /// Lower values are better.
    Min,
    /// Higher values are better.
    Max,
}

impl EarlyStoppingCallback {
    /// Creates the callback with a zeroed counter.
    ///
    /// `best_loss` starts at `+inf` for [`EarlyStoppingMode::Min`] and `-inf`
    /// for [`EarlyStoppingMode::Max`], so the first finite value observed
    /// counts as an improvement.
    pub fn new(patience: usize, min_delta: f32, mode: EarlyStoppingMode) -> Self {
        let best_loss = match mode {
            EarlyStoppingMode::Min => f32::INFINITY,
            EarlyStoppingMode::Max => f32::NEG_INFINITY,
        };
        Self {
            patience,
            min_delta,
            best_loss,
            counter: 0,
            mode,
        }
    }
}
impl Callback for EarlyStoppingCallback {
    /// Compares this epoch's validation loss against the best seen so far and
    /// sets `ctx.stop_training` once `patience` consecutive epochs have passed
    /// without improvement.
    ///
    /// An epoch without a validation loss counts as "no improvement" in both
    /// modes. (Substituting `+inf` for the missing value, as was done before,
    /// made `Max` mode treat it as a new best and pinned `best_loss` at `+inf`
    /// so no later value could ever improve on it.)
    fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        // `Some(value)` only when a loss is present AND beats the best by more
        // than `min_delta` in the configured direction.
        let improved_value = ctx.valid_loss.filter(|&current| match self.mode {
            EarlyStoppingMode::Min => current < self.best_loss - self.min_delta,
            EarlyStoppingMode::Max => current > self.best_loss + self.min_delta,
        });

        match improved_value {
            Some(current) => {
                self.best_loss = current;
                self.counter = 0;
            }
            None => {
                self.counter += 1;
                if self.counter >= self.patience {
                    tracing::info!(
                        "Early stopping triggered after {} epochs without improvement",
                        self.patience
                    );
                    ctx.stop_training = true;
                }
            }
        }
        Ok(())
    }

    /// Returns the fixed name `"EarlyStoppingCallback"`.
    fn name(&self) -> &str {
        "EarlyStoppingCallback"
    }
}
/// Policy deciding which epochs `SaveModelCallback` treats as worth saving.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SaveModelMode {
/// Save when the monitored value is strictly lower than the best so far.
Min,
/// Save when the monitored value is strictly higher than the best so far.
Max,
/// Save on every epoch for which a monitored value is available.
Every,
}
/// Callback that writes checkpoint metadata files into a directory and tracks
/// the best monitored value seen during training.
pub struct SaveModelCallback {
/// Directory that receives the checkpoint files.
save_dir: PathBuf,
/// Policy used to decide whether an epoch's value warrants a save.
mode: SaveModelMode,
/// Best monitored value recorded so far.
best_value: f32,
/// Metric to monitor; when `None` the validation loss is used instead.
metric_name: Option<String>,
/// When `true`, only epochs judged "best" are written.
save_best_only: bool,
/// Prefix of every checkpoint file name.
filename_prefix: String,
/// Epoch index at which `best_value` was recorded.
best_epoch: usize,
}
impl SaveModelCallback {
    /// Creates a callback that saves into `save_dir` according to `mode`.
    ///
    /// Defaults: monitors the validation loss, saves best epochs only, and
    /// uses the file-name prefix `"checkpoint"`. The initial best value is
    /// `+inf` for `Min`, `-inf` for `Max` and `0.0` for `Every`.
    pub fn new<P: Into<PathBuf>>(save_dir: P, mode: SaveModelMode) -> Self {
        Self {
            save_dir: save_dir.into(),
            mode,
            best_value: match mode {
                SaveModelMode::Min => f32::INFINITY,
                SaveModelMode::Max => f32::NEG_INFINITY,
                SaveModelMode::Every => 0.0,
            },
            metric_name: None,
            save_best_only: true,
            filename_prefix: String::from("checkpoint"),
            best_epoch: 0,
        }
    }

    /// Monitors the metric called `name` instead of the validation loss.
    #[must_use]
    pub fn with_metric(mut self, name: &str) -> Self {
        self.metric_name = Some(name.to_owned());
        self
    }

    /// Chooses whether only "best" epochs (`true`) or every evaluated epoch
    /// (`false`) is written.
    #[must_use]
    pub fn save_best_only(mut self, value: bool) -> Self {
        self.save_best_only = value;
        self
    }

    /// Replaces the checkpoint file-name prefix.
    #[must_use]
    pub fn with_prefix(mut self, prefix: &str) -> Self {
        self.filename_prefix = prefix.to_owned();
        self
    }

    /// Path of the "best" checkpoint file: `<save_dir>/<prefix>_best.json`.
    pub fn best_checkpoint_path(&self) -> PathBuf {
        let file_name = format!("{}_best.json", self.filename_prefix);
        self.save_dir.join(file_name)
    }

    /// Path of the checkpoint file for `epoch`:
    /// `<save_dir>/<prefix>_epoch_<epoch>.json`.
    pub fn epoch_checkpoint_path(&self, epoch: usize) -> PathBuf {
        let file_name = format!("{}_epoch_{}.json", self.filename_prefix, epoch);
        self.save_dir.join(file_name)
    }

    /// Epoch index at which the best value was recorded.
    pub fn best_epoch(&self) -> usize {
        self.best_epoch
    }

    /// Best monitored value recorded so far.
    pub fn best_value(&self) -> f32 {
        self.best_value
    }

    /// Reads the monitored value from `ctx`: the configured metric if one was
    /// set, otherwise the validation loss. `None` when it is unavailable.
    fn get_current_value(&self, ctx: &CallbackContext) -> Option<f32> {
        match self.metric_name.as_deref() {
            Some(metric) => ctx.metrics.get(metric).copied(),
            None => ctx.valid_loss,
        }
    }

    /// Whether `current` qualifies for saving under the configured mode.
    fn should_save(&self, current: f32) -> bool {
        match self.mode {
            SaveModelMode::Every => true,
            SaveModelMode::Min => current < self.best_value,
            SaveModelMode::Max => current > self.best_value,
        }
    }

    /// Writes the metadata for the current epoch as pretty-printed JSON to the
    /// per-epoch path and, when `is_best`, copies that file to the best path.
    ///
    /// # Errors
    ///
    /// Returns a `CheckpointError` if the directory cannot be created or a
    /// file cannot be written/copied, and a `SerializationError` if the
    /// metadata cannot be serialized.
    fn save_checkpoint(&self, ctx: &CallbackContext, is_best: bool) -> Result<()> {
        // The directory may have been removed since `before_fit`; recreate it.
        std::fs::create_dir_all(&self.save_dir).map_err(|e| {
            crate::error::TrainError::CheckpointError(format!(
                "Failed to create checkpoint directory: {}",
                e
            ))
        })?;

        let metadata = CheckpointMetadata {
            epoch: ctx.epoch,
            train_loss: ctx.train_loss,
            valid_loss: ctx.valid_loss,
            metrics: ctx.metrics.clone(),
            is_best,
        };
        let target = self.epoch_checkpoint_path(ctx.epoch);
        let serialized = serde_json::to_string_pretty(&metadata).map_err(|e| {
            crate::error::TrainError::SerializationError(format!(
                "Failed to serialize checkpoint: {}",
                e
            ))
        })?;
        std::fs::write(&target, serialized).map_err(|e| {
            crate::error::TrainError::CheckpointError(format!("Failed to write checkpoint: {}", e))
        })?;

        if !is_best {
            return Ok(());
        }
        std::fs::copy(&target, self.best_checkpoint_path()).map_err(|e| {
            crate::error::TrainError::CheckpointError(format!(
                "Failed to copy best checkpoint: {}",
                e
            ))
        })?;
        Ok(())
    }
}
/// Serializable record describing one saved checkpoint.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CheckpointMetadata {
/// Epoch index the checkpoint was taken at.
pub epoch: usize,
/// Training loss at that epoch, if recorded.
pub train_loss: Option<f32>,
/// Validation loss at that epoch, if recorded.
pub valid_loss: Option<f32>,
/// Metric values at that epoch.
pub metrics: HashMap<String, f32>,
/// Whether this checkpoint was judged the best so far when it was written.
pub is_best: bool,
}
impl Callback for SaveModelCallback {
fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
std::fs::create_dir_all(&self.save_dir).map_err(|e| {
crate::error::TrainError::CheckpointError(format!(
"Failed to create checkpoint directory: {}",
e
))
})?;
tracing::info!("Checkpoints will be saved to: {:?}", self.save_dir);
Ok(())
}
fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
let Some(current) = self.get_current_value(ctx) else {
return Ok(());
};
let is_best = self.should_save(current);
if is_best {
self.best_value = current;
self.best_epoch = ctx.epoch;
}
if !self.save_best_only || is_best {
self.save_checkpoint(ctx, is_best)?;
if is_best {
let metric_display = self
.metric_name
.as_deref()
.unwrap_or("valid_loss");
tracing::info!(
"Epoch {}: {} improved to {:.4}, saving checkpoint",
ctx.epoch + 1,
metric_display,
current
);
}
}
Ok(())
}
fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
tracing::info!(
"Best model from epoch {} with value {:.4}",
self.best_epoch + 1,
self.best_value
);
Ok(())
}
fn name(&self) -> &str {
"SaveModelCallback"
}
}
/// How gradients are to be clipped, together with the threshold.
#[derive(Debug, Clone, Copy)]
pub enum GradientClipMode {
    /// Clip by value, using the wrapped threshold.
    Value(f32),
    /// Clip by norm, using the wrapped threshold.
    Norm(f32),
}

/// Holds the gradient-clipping configuration and clipping statistics.
///
/// The callback does not touch gradients itself; whoever performs the clipping
/// reads [`mode`](Self::mode) / [`clip_value`](Self::clip_value) and reports
/// each clipped batch through [`record_clip`](Self::record_clip) so the
/// end-of-training summary has something to report.
pub struct GradientClipCallback {
    /// Configured clipping mode and threshold.
    mode: GradientClipMode,
    /// Number of batches reported as clipped.
    clip_count: usize,
    /// Number of batches counted so far.
    total_batches: usize,
}

impl GradientClipCallback {
    /// Creates the callback with the given mode and zeroed statistics.
    pub fn new(mode: GradientClipMode) -> Self {
        Self {
            mode,
            clip_count: 0,
            total_batches: 0,
        }
    }

    /// Shorthand for `new(GradientClipMode::Norm(max_norm))`.
    pub fn by_norm(max_norm: f32) -> Self {
        Self::new(GradientClipMode::Norm(max_norm))
    }

    /// Shorthand for `new(GradientClipMode::Value(max_value))`.
    pub fn by_value(max_value: f32) -> Self {
        Self::new(GradientClipMode::Value(max_value))
    }

    /// Returns the configured clipping mode.
    pub fn mode(&self) -> GradientClipMode {
        self.mode
    }

    /// Returns the threshold carried by the mode, whichever variant it is.
    pub fn clip_value(&self) -> f32 {
        match self.mode {
            GradientClipMode::Value(v) => v,
            GradientClipMode::Norm(n) => n,
        }
    }

    /// Records that clipping was applied to a batch.
    ///
    /// Without this the clip counter could never leave zero, so the clip-rate
    /// summary would never be emitted.
    pub fn record_clip(&mut self) {
        self.clip_count += 1;
    }

    /// Returns how many batches have been reported as clipped.
    pub fn clip_count(&self) -> usize {
        self.clip_count
    }
}
impl Callback for GradientClipCallback {
    /// Resets both statistics counters at the start of a fit.
    fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.total_batches = 0;
        self.clip_count = 0;
        Ok(())
    }

    /// Counts one more processed batch.
    fn after_batch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.total_batches += 1;
        Ok(())
    }

    /// Logs the percentage of batches in which clipping was applied, provided
    /// at least one batch was counted and the percentage is above zero.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        if self.total_batches == 0 {
            return Ok(());
        }
        let clip_rate = self.clip_count as f32 / self.total_batches as f32 * 100.0;
        if clip_rate > 0.0 {
            tracing::info!(
                "Gradient clipping was applied in {:.1}% of batches",
                clip_rate
            );
        }
        Ok(())
    }

    /// Returns the fixed name `"GradientClipCallback"`.
    fn name(&self) -> &str {
        "GradientClipCallback"
    }
}
/// Accumulates per-epoch losses, learning rates and metric values.
///
/// A loss is only appended for epochs in which it was available, so the loss
/// series can be shorter than the learning-rate series.
#[derive(Default)]
pub struct HistoryCallback {
    /// Training losses, one per epoch that reported one.
    train_losses: Vec<f32>,
    /// Validation losses, one per epoch that reported one.
    valid_losses: Vec<f32>,
    /// Learning rate recorded at the end of every epoch.
    learning_rates: Vec<f64>,
    /// Recorded values for each metric, keyed by metric name.
    metrics_history: HashMap<String, Vec<f32>>,
}

impl HistoryCallback {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self::default()
    }

    /// Recorded training losses, oldest first.
    pub fn train_losses(&self) -> &[f32] {
        &self.train_losses
    }

    /// Recorded validation losses, oldest first.
    pub fn valid_losses(&self) -> &[f32] {
        &self.valid_losses
    }

    /// Recorded learning rates, oldest first.
    pub fn learning_rates(&self) -> &[f64] {
        &self.learning_rates
    }

    /// Recorded values of the metric `name`, or `None` if it was never seen.
    pub fn metric_history(&self, name: &str) -> Option<&[f32]> {
        self.metrics_history.get(name).map(Vec::as_slice)
    }

    /// Names of all recorded metrics, in ascending order so the result does
    /// not depend on hash-map iteration order.
    pub fn metric_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.metrics_history.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Index (into [`valid_losses`](Self::valid_losses)) of the lowest recorded
    /// validation loss; the earliest index wins ties.
    ///
    /// NaN entries are skipped rather than causing a panic. Returns `None`
    /// when no non-NaN validation loss has been recorded.
    pub fn best_epoch(&self) -> Option<usize> {
        self.valid_losses
            .iter()
            .enumerate()
            .filter(|(_, loss)| !loss.is_nan())
            // With NaNs removed `partial_cmp` always yields an ordering; the
            // fallback only exists to keep this path panic-free.
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(index, _)| index)
    }
}
impl Callback for HistoryCallback {
    /// Appends this epoch's values to the history.
    ///
    /// Losses are appended only when present; the learning rate and every
    /// metric in the context are always appended.
    fn after_epoch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        // Extending from an `Option` pushes the value when it is `Some` and
        // is a no-op for `None`.
        self.train_losses.extend(ctx.train_loss);
        self.valid_losses.extend(ctx.valid_loss);
        self.learning_rates.push(ctx.lr);

        for (metric, value) in ctx.metrics.iter() {
            let series = self.metrics_history.entry(metric.clone()).or_default();
            series.push(*value);
        }
        Ok(())
    }

    /// Returns the fixed name `"HistoryCallback"`.
    fn name(&self) -> &str {
        "HistoryCallback"
    }
}
/// Tracks a loss-scale factor that shrinks on reported overflows and grows
/// after a run of batches without one.
pub struct MixedPrecisionCallback {
    /// Scale the callback starts from (and returns to at the start of a fit).
    initial_scale: f32,
    /// Scale currently in effect.
    current_scale: f32,
    /// Multiplier applied when the scale grows.
    growth_factor: f32,
    /// Multiplier applied when an overflow is reported.
    backoff_factor: f32,
    /// Number of batches without a rescale after which the scale grows.
    growth_interval: usize,
    /// Batches counted since the scale last changed.
    batches_since_rescale: usize,
    /// Number of overflows reported so far.
    overflow_count: usize,
}

impl MixedPrecisionCallback {
    /// Creates the callback starting at `initial_scale`, with a growth factor
    /// of `2.0`, a back-off factor of `0.5` and a growth interval of `2000`
    /// batches.
    pub fn new(initial_scale: f32) -> Self {
        MixedPrecisionCallback {
            overflow_count: 0,
            batches_since_rescale: 0,
            growth_interval: 2000,
            backoff_factor: 0.5,
            growth_factor: 2.0,
            current_scale: initial_scale,
            initial_scale,
        }
    }

    /// Returns the scale currently in effect.
    pub fn current_scale(&self) -> f32 {
        self.current_scale
    }

    /// Registers an overflow: counts it, multiplies the scale by the back-off
    /// factor and restarts the growth countdown.
    pub fn report_overflow(&mut self) {
        self.batches_since_rescale = 0;
        self.current_scale *= self.backoff_factor;
        self.overflow_count += 1;
    }
}
impl Callback for MixedPrecisionCallback {
    /// Restores the initial scale and clears the counters at the start of a
    /// fit.
    fn before_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.current_scale = self.initial_scale;
        self.batches_since_rescale = 0;
        self.overflow_count = 0;
        Ok(())
    }

    /// Counts the batch and, once `growth_interval` batches have passed since
    /// the last rescale, multiplies the scale by the growth factor.
    ///
    /// Growth is skipped when the result would not be finite: a scale of
    /// `+inf` could never be reduced again by the multiplicative back-off
    /// (`inf * 0.5` is still `inf`). The countdown restarts either way.
    fn after_batch(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        self.batches_since_rescale += 1;
        if self.batches_since_rescale >= self.growth_interval {
            let grown = self.current_scale * self.growth_factor;
            if grown.is_finite() {
                self.current_scale = grown;
            }
            self.batches_since_rescale = 0;
        }
        Ok(())
    }

    /// Logs the overflow count and final scale if any overflow was reported.
    fn after_fit(&mut self, _ctx: &mut CallbackContext) -> Result<()> {
        if self.overflow_count > 0 {
            tracing::info!(
                "Mixed precision: {} overflow events, final scale = {:.0}",
                self.overflow_count,
                self.current_scale
            );
        }
        Ok(())
    }

    /// Returns the fixed name `"MixedPrecisionCallback"`.
    fn name(&self) -> &str {
        "MixedPrecisionCallback"
    }
}
/// Callback that requests a stop when the training loss is NaN or infinite.
#[derive(Default)]
pub struct TerminateOnNanCallback {
    /// Number of non-finite losses detected so far.
    nan_count: usize,
}

impl TerminateOnNanCallback {
    /// Creates the callback with a zeroed detection counter.
    pub fn new() -> Self {
        Self::default()
    }
}
impl Callback for TerminateOnNanCallback {
    /// Checks the training loss after each batch; if it is NaN or infinite,
    /// counts the event, logs an error and sets `ctx.stop_training`.
    ///
    /// Does nothing when no training loss is available.
    fn after_batch(&mut self, ctx: &mut CallbackContext) -> Result<()> {
        match ctx.train_loss {
            // "Not finite" is exactly "NaN or +/-infinity".
            Some(loss) if !loss.is_finite() => {
                self.nan_count += 1;
                tracing::error!("NaN/Inf detected in training loss at batch {}", ctx.batch);
                ctx.stop_training = true;
            }
            _ => {}
        }
        Ok(())
    }

    /// Returns the fixed name `"TerminateOnNanCallback"`.
    fn name(&self) -> &str {
        "TerminateOnNanCallback"
    }
}
// Unit tests for the callback types defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// A fresh context starts at epoch 0 and reports zero progress.
#[test]
fn test_callback_context() {
let ctx = CallbackContext::new(10, 100);
assert_eq!(ctx.epoch, 0);
assert_eq!(ctx.n_epochs, 10);
assert_eq!(ctx.progress(), 0.0);
}
// Smoke test: a callback can be added to a list (no assertions are made).
#[test]
fn test_callback_list() {
let mut list = CallbackList::new();
list.add(ProgressCallback::new(false));
}
// Both constructors expose their threshold through `clip_value`.
#[test]
fn test_gradient_clip_callback() {
let clip = GradientClipCallback::by_norm(1.0);
assert_eq!(clip.clip_value(), 1.0);
let clip = GradientClipCallback::by_value(0.5);
assert_eq!(clip.clip_value(), 0.5);
}
// One `after_epoch` call records the losses and learning rate from the context.
#[test]
fn test_history_callback() {
let mut history = HistoryCallback::new();
let mut ctx = CallbackContext::new(10, 100);
ctx.train_loss = Some(0.5);
ctx.valid_loss = Some(0.4);
ctx.lr = 0.001;
history.after_epoch(&mut ctx).unwrap();
assert_eq!(history.train_losses(), &[0.5]);
assert_eq!(history.valid_losses(), &[0.4]);
assert_eq!(history.learning_rates(), &[0.001]);
}
}