use crate::error::Result;
use crate::models::History;
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, NumAssign};
use std::collections::HashMap;
use std::fmt::Debug;
use super::{Callback, CallbackContext, CallbackTiming};
pub struct CallbackManager<F: Float + Debug + ScalarOperand + NumAssign> {
callbacks: Vec<Box<dyn Callback<F> + Send + Sync>>,
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Debug for CallbackManager<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CallbackManager")
.field("callbacks_count", &self.callbacks.len())
.finish()
}
}
impl<F: Float + Debug + ScalarOperand + NumAssign> CallbackManager<F> {
pub fn new() -> Self {
Self {
callbacks: Vec::new(),
}
}
pub fn add_callback(&mut self, callback: Box<dyn Callback<F> + Send + Sync>) {
self.callbacks.push(callback);
}
pub fn on_train_begin(&mut self) -> Result<()> {
let mut context = CallbackContext {
epoch: 0,
total_epochs: 0,
batch: 0,
total_batches: 0,
batch_loss: None,
epoch_loss: None,
val_loss: None,
metrics: Vec::new(),
history: &History::default(),
stop_training: false,
model: None,
};
for callback in &mut self.callbacks {
callback.on_event(CallbackTiming::BeforeTraining, &mut context)?;
}
Ok(())
}
pub fn on_train_end(&mut self) -> Result<()> {
let mut context = CallbackContext {
epoch: 0,
total_epochs: 0,
batch: 0,
total_batches: 0,
batch_loss: None,
epoch_loss: None,
val_loss: None,
metrics: Vec::new(),
history: &History::default(),
stop_training: false,
model: None,
};
for callback in &mut self.callbacks {
callback.on_event(CallbackTiming::AfterTraining, &mut context)?;
}
Ok(())
}
pub fn on_epoch_begin(&mut self, epoch: usize) -> Result<()> {
let mut context = CallbackContext {
epoch,
total_epochs: 0,
batch: 0,
total_batches: 0,
batch_loss: None,
epoch_loss: None,
val_loss: None,
metrics: Vec::new(),
history: &History::default(),
stop_training: false,
model: None,
};
for callback in &mut self.callbacks {
callback.on_event(CallbackTiming::BeforeEpoch, &mut context)?;
}
Ok(())
}
pub fn on_epoch_end(&mut self, epoch: usize, metrics: &HashMap<String, F>) -> Result<bool> {
let metrics_vec = metrics.values().cloned().collect::<Vec<_>>();
let mut context = CallbackContext {
epoch,
total_epochs: 0,
batch: 0,
total_batches: 0,
batch_loss: metrics.get("loss").copied(),
epoch_loss: metrics.get("loss").copied(),
val_loss: metrics.get("val_loss").copied(),
metrics: metrics_vec,
history: &History::default(),
stop_training: false,
model: None,
};
for callback in &mut self.callbacks {
callback.on_event(CallbackTiming::AfterEpoch, &mut context)?;
}
let stop_training = context.stop_training;
Ok(stop_training)
}
pub fn on_batch_begin(&mut self, batch: usize) -> Result<()> {
let mut context = CallbackContext {
epoch: 0,
total_epochs: 0,
batch,
total_batches: 0,
batch_loss: None,
epoch_loss: None,
val_loss: None,
metrics: Vec::new(),
history: &History::default(),
stop_training: false,
model: None,
};
for callback in &mut self.callbacks {
callback.on_event(CallbackTiming::BeforeBatch, &mut context)?;
}
Ok(())
}
pub fn on_batch_end(&mut self, batch: usize, metrics: &HashMap<String, F>) -> Result<()> {
let metrics_vec = metrics.values().cloned().collect::<Vec<_>>();
let mut context = CallbackContext {
epoch: 0,
total_epochs: 0,
batch,
total_batches: 0,
batch_loss: metrics.get("loss").copied(),
epoch_loss: None,
val_loss: None,
metrics: metrics_vec,
history: &History::default(),
stop_training: false,
model: None,
};
for callback in &mut self.callbacks {
callback.on_event(CallbackTiming::AfterBatch, &mut context)?;
}
Ok(())
}
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Default for CallbackManager<F> {
fn default() -> Self {
Self::new()
}
}