use crate::error::Result;
use crate::layers::Layer;
use crate::models::History;
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// Point in the training loop at which a [`Callback`] is being invoked.
///
/// Passed as the `timing` argument of [`Callback::on_event`] so a single callback
/// can react differently to each event.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CallbackTiming {
    /// Before training starts.
    BeforeTraining,
    /// Before an epoch starts.
    BeforeEpoch,
    /// Before a batch is processed.
    BeforeBatch,
    /// After a batch has been processed.
    AfterBatch,
    /// After an epoch has finished.
    AfterEpoch,
    /// After training has finished.
    AfterTraining,
}
/// Training state handed to every [`Callback`] on each event.
///
/// Callbacks receive it as `&mut`, so they can read progress information and also
/// write back to it (for example by setting `stop_training`).
pub struct CallbackContext<'a, F>
where
    F: Float + Debug + ScalarOperand + NumAssign,
{
    /// Current epoch counter (0- or 1-based is decided by the caller — confirm).
    pub epoch: usize,
    /// Total number of epochs in the run.
    pub total_epochs: usize,
    /// Current batch counter (presumably within the current epoch — confirm).
    pub batch: usize,
    /// Total number of batches (presumably per epoch — confirm).
    pub total_batches: usize,
    /// Loss of the current batch, or `None` when the caller has none to report.
    pub batch_loss: Option<F>,
    /// Loss for the epoch, or `None` when the caller has none to report.
    pub epoch_loss: Option<F>,
    /// Validation loss, or `None` when the caller has none to report.
    pub val_loss: Option<F>,
    /// Metric values (ordering is defined by the caller — NOTE(review): document it).
    pub metrics: Vec<F>,
    /// Training history, borrowed immutably for `'a`.
    pub history: &'a History<F>,
    /// Flag a callback can set; presumably read by the training loop to stop early.
    pub stop_training: bool,
    /// Mutable access to the model, when the caller supplies one.
    pub model: Option<&'a mut dyn Layer<F>>,
}
/// Hook invoked by the training loop at the events described by [`CallbackTiming`].
pub trait Callback<F>
where
    F: Float + Debug + ScalarOperand + NumAssign,
{
    /// Handles one training event.
    ///
    /// `timing` says which event fired. `context` exposes the current training state
    /// and may be mutated by the implementation.
    ///
    /// # Errors
    ///
    /// Implementations return `Err` to report a failure to the caller.
    fn on_event(
        &mut self,
        timing: CallbackTiming,
        context: &mut CallbackContext<'_, F>,
    ) -> Result<()>;
}
mod callback_manager;
mod checkpoint;
mod early_stopping;
mod gradient_clipping;
mod learning_rate_scheduler;
mod learning_rate_scheduler_trait;
mod metrics;
mod model_checkpoint;
mod tensorboard;
mod visualization_callback;
/// A [`Callback`] backed by a plain, argument-less closure.
///
/// The closure receives neither the [`CallbackTiming`] nor the [`CallbackContext`], so it
/// cannot tell which event triggered it:
///
/// - A callback built with [`FunctionCallback::new`] runs the closure on *every* event,
///   including every `BeforeBatch` and `AfterBatch`.
/// - Use [`FunctionCallback::with_timing`] to run it for a single event only.
pub struct FunctionCallback<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    /// User-supplied closure; its `Result` is returned from `on_event` unchanged.
    func: Box<dyn Fn() -> Result<()> + Send + Sync>,
    /// Event filter: `Some(t)` runs the closure only for `t`; `None` runs it for every event.
    timing: Option<CallbackTiming>,
    /// Ties the otherwise-unused numeric type `F` to the struct.
    _phantom: std::marker::PhantomData<F>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FunctionCallback<F> {
    /// Creates a callback that invokes `func` on every event, whatever its timing.
    pub fn new(func: Box<dyn Fn() -> Result<()> + Send + Sync>) -> Self {
        Self {
            func,
            timing: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Creates a callback that invokes `func` only when fired with `timing`.
    ///
    /// All other events are ignored and return `Ok(())`.
    pub fn with_timing(
        timing: CallbackTiming,
        func: Box<dyn Fn() -> Result<()> + Send + Sync>,
    ) -> Self {
        Self {
            func,
            timing: Some(timing),
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Callback<F>
    for FunctionCallback<F>
{
    /// Runs the wrapped closure if `timing` passes the configured filter.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped closure returns.
    fn on_event(
        &mut self,
        timing: CallbackTiming,
        _context: &mut CallbackContext<'_, F>,
    ) -> Result<()> {
        // An event outside the configured filter is a no-op, not an error.
        if matches!(self.timing, Some(wanted) if wanted != timing) {
            return Ok(());
        }
        (self.func)()
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Debug for FunctionCallback<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The boxed closure is not `Debug`, so only the timing filter is shown.
        f.debug_struct("FunctionCallback")
            .field("timing", &self.timing)
            .finish_non_exhaustive()
    }
}
pub use callback_manager::CallbackManager;
pub use checkpoint::ModelCheckpoint;
pub use early_stopping::EarlyStopping;
pub use gradient_clipping::{GradientClipping, GradientClippingMethod};
pub use learning_rate_scheduler::{CosineAnnealingLR, ReduceOnPlateau, ScheduleMethod, StepDecay};
pub use learning_rate_scheduler_trait::LearningRateScheduler;
#[cfg(feature = "metrics_integration")]
pub use metrics::*;
pub use tensorboard::TensorBoardLogger;
pub use visualization_callback::VisualizationCallback;