use super::callbacks::Callback;
use super::metrics::{MetricsCollector, TrainingMetrics};
use super::state::{BatchState, EpochState, TrainingState};
use crate::autograd::Variable;
use crate::data::{DataLoader, Dataset};
use crate::nn::loss::Loss;
use crate::optim::Optimizer;
use crate::tensor::Tensor;
use num_traits::Float;
use std::fmt::Debug;
use std::time::Instant;
/// Settings that control a [`Trainer`] run.
#[derive(Debug, Clone)]
pub struct TrainerConfig {
    /// Number of passes over the training data.
    pub epochs: usize,
    /// Print the batch loss every `log_frequency` batches.
    pub log_frequency: usize,
    /// Run validation (and print the epoch summary) every `validation_frequency` epochs.
    pub validation_frequency: usize,
    /// Optional gradient clipping threshold applied at the end of each accumulation window.
    pub gradient_clip_value: Option<f32>,
    /// Device name; stored here but not read by the code in this module.
    pub device: String,
    /// Mixed-precision flag; stored here but not read by the code in this module.
    pub use_mixed_precision: bool,
    /// Number of batches whose gradients are accumulated before clipping.
    pub accumulation_steps: usize,
}

impl Default for TrainerConfig {
    /// 10 epochs on `"cpu"`, logging every 100 batches, validating every epoch,
    /// no clipping, full precision, and no gradient accumulation.
    fn default() -> Self {
        let device = String::from("cpu");
        TrainerConfig {
            epochs: 10,
            log_frequency: 100,
            validation_frequency: 1,
            gradient_clip_value: None,
            device,
            use_mixed_precision: false,
            accumulation_steps: 1,
        }
    }
}
/// Source of `(inputs, targets)` tensor batches consumed by [`Trainer`].
pub trait TrainingDataLoader<T>
where
    T: Float,
{
    /// Rewinds the loader for a new pass; the trainer calls this before every
    /// training and validation pass.
    fn reset(&mut self);

    /// Returns the next `(inputs, targets)` batch; the trainer ends the pass on `None`.
    fn next_batch(&mut self) -> Option<(Tensor<T>, Tensor<T>)>;

    /// Presumably reports whether the current pass has no data left
    /// (implementation-defined — confirm per implementor).
    fn is_empty(&self) -> bool;
}
/// Drives a training loop over batches from a [`TrainingDataLoader`], running the model,
/// the loss function and the optimizer, and notifying registered callbacks.
pub struct Trainer<T, O, L>
where
    T: Float + 'static + Send + Sync + Debug + Clone,
    T: ndarray::ScalarOperand + num_traits::FromPrimitive,
    O: Optimizer + Clone,
    L: Loss<T> + Clone,
{
    /// Loop settings (epochs, logging/validation cadence, clipping, accumulation).
    config: TrainerConfig,
    /// Optimizer whose gradients are zeroed at the start of each accumulation window.
    optimizer: O,
    /// Loss applied to `(model output, targets)` for every batch.
    loss_fn: L,
    /// Turns the final training state into the returned metrics.
    metrics_collector: MetricsCollector<T>,
    /// Hooks invoked in registration order.
    callbacks: Vec<Box<dyn Callback<T> + Send + Sync>>,
    _phantom: std::marker::PhantomData<T>,
}
impl<T, O, L> Trainer<T, O, L>
where
    T: Float
        + 'static
        + Send
        + Sync
        + Debug
        + Clone
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
    O: Optimizer + Clone,
    L: Loss<T> + Clone,
{
    /// Creates a trainer from `config`, an optimizer and a loss function, with a fresh
    /// metrics collector and no callbacks.
    pub fn new(config: TrainerConfig, optimizer: O, loss_fn: L) -> Self {
        Self {
            config,
            optimizer,
            loss_fn,
            metrics_collector: MetricsCollector::new(),
            callbacks: Vec::new(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Registers a callback. Callbacks are invoked in registration order at every hook.
    pub fn add_callback(&mut self, callback: Box<dyn Callback<T> + Send + Sync>) {
        self.callbacks.push(callback);
    }

    /// Runs the training loop for `config.epochs` epochs and returns the collected metrics.
    ///
    /// Each epoch runs these steps in order:
    /// 1. The `on_epoch_begin` hooks.
    /// 2. A training pass with the model in train mode.
    /// 3. On epochs that are multiples of `config.validation_frequency`, a validation pass
    ///    over `val_loader` (if given) with the model in eval mode.
    /// 4. The `on_epoch_end` hooks.
    ///
    /// The first callback that returns [`CallbackSignal::Stop`] ends training after the
    /// current epoch's summary has been printed. Callbacks registered after it do not see
    /// `on_epoch_end` for that epoch. The `on_train_end` hooks still run.
    ///
    /// A `validation_frequency` of 0 disables validation and the per-epoch summary,
    /// rather than panicking on a modulo by zero.
    ///
    /// # Errors
    /// Propagates the first error from any callback hook or from the train/validation passes.
    pub fn train<M>(
        &mut self,
        model: &mut M,
        train_loader: &mut dyn TrainingDataLoader<T>,
        mut val_loader: Option<&mut dyn TrainingDataLoader<T>>,
    ) -> anyhow::Result<TrainingMetrics<T>>
    where
        M: TrainableModel<T>,
    {
        let start_time = Instant::now();
        let mut state = TrainingState::new(self.config.epochs);
        for callback in &mut self.callbacks {
            callback.on_train_begin(&mut state)?;
        }

        for epoch in 0..self.config.epochs {
            let epoch_start = Instant::now();
            let mut epoch_state = EpochState::new(epoch);
            for callback in &mut self.callbacks {
                callback.on_epoch_begin(&mut state, &mut epoch_state)?;
            }

            model.train();
            let train_metrics = self.train_epoch(model, train_loader, &mut state)?;
            epoch_state.train_metrics = Some(train_metrics);

            let is_validation_epoch = Self::is_due(epoch, self.config.validation_frequency);
            if is_validation_epoch {
                if let Some(ref mut val_loader) = val_loader {
                    model.eval();
                    let val_metrics = self.validate_epoch(model, &mut **val_loader, &mut state)?;
                    epoch_state.val_metrics = Some(val_metrics);
                }
            }

            epoch_state.duration = epoch_start.elapsed();
            state.add_epoch(epoch_state.clone());

            // A plain `break` inside the callback loop would only leave that inner loop.
            // Record the request so the epoch loop itself can stop below.
            let mut stop_requested = false;
            for callback in &mut self.callbacks {
                match callback.on_epoch_end(&mut state, &epoch_state)? {
                    Some(CallbackSignal::Stop) => {
                        println!("Training stopped by callback at epoch {}", epoch + 1);
                        stop_requested = true;
                        break;
                    }
                    Some(CallbackSignal::Continue) | None => {}
                }
            }

            if is_validation_epoch {
                self.log_epoch_summary(epoch, &epoch_state);
            }
            if stop_requested {
                break;
            }
        }

        state.total_duration = start_time.elapsed();
        for callback in &mut self.callbacks {
            callback.on_train_end(&mut state)?;
        }
        Ok(self.metrics_collector.finalize(state))
    }

    /// Runs one training pass over `train_loader`, resetting it first, and returns the
    /// summed and mean loss over the batches seen.
    ///
    /// Accumulation: gradients are zeroed at the start of each window of
    /// `config.accumulation_steps` batches (0 is treated as 1). When
    /// `config.gradient_clip_value` is set, clipping runs at the end of each complete window.
    /// A trailing partial window at the end of the pass is not clipped.
    ///
    /// Logging: the batch loss is printed every `config.log_frequency` batches
    /// (0 disables printing).
    ///
    /// NOTE(review): no optimizer update step is taken anywhere in this loop, so model
    /// parameters are never modified here. Confirm the `Optimizer` API and add the update
    /// after clipping.
    ///
    /// # Errors
    /// Propagates callback errors. Also fails if a loss value cannot be represented as `T`.
    fn train_epoch<M>(
        &mut self,
        model: &mut M,
        train_loader: &mut dyn TrainingDataLoader<T>,
        state: &mut TrainingState<T>,
    ) -> anyhow::Result<EpochMetrics<T>>
    where
        M: TrainableModel<T>,
    {
        // Guard against a zero window size, which would otherwise panic on `%`.
        let accumulation_steps = self.config.accumulation_steps.max(1);
        let mut epoch_metrics = EpochMetrics::new();
        let mut batch_count: usize = 0;

        train_loader.reset();
        while let Some((inputs, targets)) = train_loader.next_batch() {
            let batch_start = Instant::now();
            let mut batch_state = BatchState::new(batch_count);
            for callback in &mut self.callbacks {
                callback.on_batch_begin(state, &mut batch_state)?;
            }

            let outputs = model.forward(&Variable::new(inputs, false));
            let loss = self
                .loss_fn
                .forward(&outputs, &Variable::new(targets, false));

            // Start of an accumulation window: clear the gradients left by the previous window.
            if batch_count % accumulation_steps == 0 {
                self.optimizer.zero_grad();
            }
            loss.backward();
            // End of an accumulation window.
            if (batch_count + 1) % accumulation_steps == 0 {
                if let Some(clip_value) = self.config.gradient_clip_value {
                    self.clip_gradients(model, clip_value);
                }
            }

            let loss_value = self.extract_scalar_value(&loss);
            epoch_metrics.total_loss = epoch_metrics.total_loss + Self::loss_as_t(loss_value)?;
            epoch_metrics.batch_count += 1;

            batch_state.loss = Some(loss_value);
            batch_state.duration = batch_start.elapsed();
            for callback in &mut self.callbacks {
                callback.on_batch_end(state, &batch_state)?;
            }

            if Self::is_due(batch_count, self.config.log_frequency) {
                println!("Batch {}: Loss = {:.4}", batch_count, loss_value);
            }
            batch_count += 1;
        }

        epoch_metrics.avg_loss = Self::mean_loss(&epoch_metrics);
        Ok(epoch_metrics)
    }

    /// Runs one validation pass over `val_loader`, resetting it first, and returns the
    /// summed and mean loss. No callbacks are invoked and no gradients are touched.
    ///
    /// # Errors
    /// Fails if a loss value cannot be represented as `T`.
    fn validate_epoch<M>(
        &mut self,
        model: &mut M,
        val_loader: &mut dyn TrainingDataLoader<T>,
        _state: &mut TrainingState<T>,
    ) -> anyhow::Result<EpochMetrics<T>>
    where
        M: TrainableModel<T>,
    {
        let mut epoch_metrics = EpochMetrics::new();
        val_loader.reset();
        while let Some((inputs, targets)) = val_loader.next_batch() {
            let outputs = model.forward(&Variable::new(inputs, false));
            let loss = self
                .loss_fn
                .forward(&outputs, &Variable::new(targets, false));
            let loss_value = self.extract_scalar_value(&loss);
            epoch_metrics.total_loss = epoch_metrics.total_loss + Self::loss_as_t(loss_value)?;
            epoch_metrics.batch_count += 1;
        }
        epoch_metrics.avg_loss = Self::mean_loss(&epoch_metrics);
        Ok(epoch_metrics)
    }

    /// Returns `true` when `counter` is a multiple of `frequency`.
    ///
    /// A `frequency` of 0 means "never", so a zero config value disables the periodic
    /// action instead of dividing by zero.
    fn is_due(counter: usize, frequency: usize) -> bool {
        frequency != 0 && counter % frequency == 0
    }

    /// Converts a scalar loss into `T`, reporting an error instead of panicking.
    fn loss_as_t(value: f64) -> anyhow::Result<T> {
        T::from(value).ok_or_else(|| {
            anyhow::anyhow!("loss value {} is not representable in the model float type", value)
        })
    }

    /// Mean loss over the batches in `metrics`, or zero when no batches were seen.
    fn mean_loss(metrics: &EpochMetrics<T>) -> T {
        if metrics.batch_count == 0 {
            return T::zero();
        }
        let count = T::from(metrics.batch_count)
            .expect("batch count must be representable in the model float type");
        metrics.total_loss / count
    }

    /// Placeholder: currently a no-op. `_clip_value` is ignored and no gradients are clipped.
    fn clip_gradients<M>(&self, _model: &M, _clip_value: f32)
    where
        M: TrainableModel<T>,
    {
    }

    /// Placeholder: always returns `0.5`, whatever `_variable` holds, so every reported
    /// loss is 0.5. TODO: read the scalar out of the loss variable.
    fn extract_scalar_value(&self, _variable: &Variable<T>) -> f64 {
        0.5
    }

    /// Prints a one-line summary: the epoch number, the train/val mean losses when
    /// present, and the wall time.
    fn log_epoch_summary(&self, epoch: usize, epoch_state: &EpochState<T>) {
        let mut summary = format!("Epoch {}/{}", epoch + 1, self.config.epochs);
        if let Some(ref train_metrics) = epoch_state.train_metrics {
            summary.push_str(&format!(
                " | Train Loss: {:.4}",
                train_metrics.avg_loss.to_f64().unwrap_or(0.0)
            ));
        }
        if let Some(ref val_metrics) = epoch_state.val_metrics {
            summary.push_str(&format!(
                " | Val Loss: {:.4}",
                val_metrics.avg_loss.to_f64().unwrap_or(0.0)
            ));
        }
        summary.push_str(&format!(
            " | Time: {:.2}s",
            epoch_state.duration.as_secs_f64()
        ));
        println!("{}", summary);
    }
}
/// Model interface driven by [`Trainer`].
pub trait TrainableModel<T>
where
    T: Float + Send + Sync + 'static,
    T: ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Computes the model output for `input`.
    fn forward(&self, input: &Variable<T>) -> Variable<T>;

    /// Puts the model in training mode (semantics are up to the implementor).
    /// `Trainer::train` calls this before every training pass.
    fn train(&mut self);

    /// Puts the model in evaluation mode (semantics are up to the implementor).
    /// `Trainer::train` calls this before every validation pass.
    fn eval(&mut self);

    /// Borrows the model's parameters.
    fn parameters(&self) -> Vec<&Variable<T>>;

    /// Mutably borrows the model's parameters.
    fn parameters_mut(&mut self) -> Vec<&mut Variable<T>>;
}
/// Loss statistics accumulated over one training or validation pass.
#[derive(Debug, Clone)]
pub struct EpochMetrics<T: Float> {
    /// Sum of the per-batch loss values.
    pub total_loss: T,
    /// Mean per-batch loss (zero until computed, or when no batches were seen).
    pub avg_loss: T,
    /// Number of batches folded into `total_loss`.
    pub batch_count: usize,
}

impl<T> EpochMetrics<T>
where
    T: Float,
{
    /// Returns empty metrics: zero loss and zero batches.
    pub fn new() -> Self {
        let zero = T::zero();
        EpochMetrics {
            total_loss: zero,
            avg_loss: zero,
            batch_count: 0,
        }
    }
}

impl<T> Default for EpochMetrics<T>
where
    T: Float,
{
    /// Same as [`EpochMetrics::new`].
    fn default() -> Self {
        EpochMetrics::new()
    }
}
/// Signal a callback can return from `on_epoch_end` to steer the training loop.
///
/// Equality and hashing are derived so signals can be compared directly
/// (e.g. `signal == CallbackSignal::Stop`) and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallbackSignal {
    /// Keep training.
    Continue,
    /// Request that training end.
    Stop,
}
/// Fluent builder for [`Trainer`]; both an optimizer and a loss function must be supplied.
pub struct TrainerBuilder<T, O, L>
where
    T: Float + 'static + Send + Sync + Debug + Clone,
    T: ndarray::ScalarOperand + num_traits::FromPrimitive,
    O: Optimizer + Clone,
    L: Loss<T> + Clone,
{
    /// Configuration being assembled; starts from `TrainerConfig::default()`.
    config: TrainerConfig,
    /// Required; `build` fails while this is `None`.
    optimizer: Option<O>,
    /// Required; `build` fails while this is `None`.
    loss_fn: Option<L>,
    _phantom: std::marker::PhantomData<T>,
}
impl<T, O, L> TrainerBuilder<T, O, L>
where
    T: Float
        + 'static
        + Send
        + Sync
        + Debug
        + Clone
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
    O: Optimizer + Clone,
    L: Loss<T> + Clone,
{
    /// Starts from `TrainerConfig::default()` with no optimizer and no loss function.
    pub fn new() -> Self {
        Self {
            config: TrainerConfig::default(),
            optimizer: None,
            loss_fn: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Sets the number of epochs.
    pub fn epochs(mut self, epochs: usize) -> Self {
        self.config.epochs = epochs;
        self
    }

    /// Sets how often, in batches, the batch loss is printed.
    pub fn log_frequency(mut self, frequency: usize) -> Self {
        self.config.log_frequency = frequency;
        self
    }

    /// Sets how often, in epochs, validation runs.
    pub fn validation_frequency(mut self, frequency: usize) -> Self {
        self.config.validation_frequency = frequency;
        self
    }

    /// Enables gradient clipping with the given threshold.
    pub fn gradient_clip_value(mut self, value: f32) -> Self {
        self.config.gradient_clip_value = Some(value);
        self
    }

    /// Sets the device name stored in the config.
    pub fn device(mut self, device: String) -> Self {
        self.config.device = device;
        self
    }

    /// Sets how many batches gradients are accumulated over before clipping.
    ///
    /// Previously this field of [`TrainerConfig`] could not be set through the builder.
    pub fn accumulation_steps(mut self, steps: usize) -> Self {
        self.config.accumulation_steps = steps;
        self
    }

    /// Sets the mixed-precision flag stored in the config.
    ///
    /// Previously this field of [`TrainerConfig`] could not be set through the builder.
    pub fn use_mixed_precision(mut self, enabled: bool) -> Self {
        self.config.use_mixed_precision = enabled;
        self
    }

    /// Sets the optimizer (required).
    pub fn optimizer(mut self, optimizer: O) -> Self {
        self.optimizer = Some(optimizer);
        self
    }

    /// Sets the loss function (required).
    pub fn loss_fn(mut self, loss_fn: L) -> Self {
        self.loss_fn = Some(loss_fn);
        self
    }

    /// Builds the trainer.
    ///
    /// # Errors
    /// Fails if no optimizer or no loss function was provided.
    pub fn build(self) -> anyhow::Result<Trainer<T, O, L>> {
        let optimizer = self
            .optimizer
            .ok_or_else(|| anyhow::anyhow!("Optimizer not provided"))?;
        let loss_fn = self
            .loss_fn
            .ok_or_else(|| anyhow::anyhow!("Loss function not provided"))?;
        Ok(Trainer::new(self.config, optimizer, loss_fn))
    }
}
impl<T, O, L> Default for TrainerBuilder<T, O, L>
where
T: Float
+ 'static
+ Send
+ Sync
+ Debug
+ Clone
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive,
O: Optimizer + Clone,
L: Loss<T> + Clone,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default config has its documented default value.
    #[test]
    fn test_trainer_config_default() {
        let TrainerConfig {
            epochs,
            log_frequency,
            validation_frequency,
            gradient_clip_value,
            device,
            use_mixed_precision,
            accumulation_steps,
        } = TrainerConfig::default();
        assert_eq!(epochs, 10);
        assert_eq!(log_frequency, 100);
        assert_eq!(validation_frequency, 1);
        assert_eq!(gradient_clip_value, None);
        assert_eq!(device, "cpu");
        assert!(!use_mixed_precision);
        assert_eq!(accumulation_steps, 1);
    }

    /// Fresh metrics start at zero loss with no batches.
    #[test]
    fn test_epoch_metrics_creation() {
        let EpochMetrics {
            total_loss,
            avg_loss,
            batch_count,
        } = EpochMetrics::<f32>::new();
        assert_eq!(total_loss, 0.0);
        assert_eq!(avg_loss, 0.0);
        assert_eq!(batch_count, 0);
    }

    /// A `Continue` signal never takes the `Stop` path.
    #[test]
    fn test_callback_signal() {
        let signal = CallbackSignal::Continue;
        if let CallbackSignal::Stop = signal {
            unreachable!("Stop signal should be handled earlier");
        }
    }
}
/// Mini-batch loader over a borrowed [`Dataset`] whose items are `Vec<Tensor<T>>`.
///
/// When batching, element 0 of each item is used as the feature tensor and element 1 as
/// the target tensor.
pub struct Phase5TrainingDataLoader<'a, D, T>
where
    D: Dataset<Vec<Tensor<T>>>,
    T: Float,
{
    /// Borrowed source of samples.
    dataset: &'a D,
    /// Maximum number of samples per batch.
    batch_size: usize,
    /// Visit order over `0..dataset.len()`; reshuffled on construction and on `reset`
    /// when `shuffle` is set.
    indices: Vec<usize>,
    /// Position in `indices` where the next batch starts.
    current_index: usize,
    /// Whether to randomize `indices`.
    shuffle: bool,
    _phantom: std::marker::PhantomData<T>,
}
impl<
'a,
D: Dataset<Vec<Tensor<T>>>,
T: Float + Clone + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
> Phase5TrainingDataLoader<'a, D, T>
{
pub fn new(dataset: &'a D, batch_size: usize, shuffle: bool) -> Self {
let mut indices: Vec<usize> = (0..dataset.len()).collect();
if shuffle {
use rand::seq::SliceRandom;
indices.shuffle(&mut rand::thread_rng());
}
Self {
dataset,
batch_size,
indices,
current_index: 0,
shuffle,
_phantom: std::marker::PhantomData,
}
}
}
impl<'a, D, T> TrainingDataLoader<T> for Phase5TrainingDataLoader<'a, D, T>
where
    D: Dataset<Vec<Tensor<T>>>,
    T: Float + Clone + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Rewinds to the first index. When `shuffle` is set, the visit order is reshuffled.
    fn reset(&mut self) {
        self.current_index = 0;
        if self.shuffle {
            use rand::seq::SliceRandom;
            self.indices.shuffle(&mut rand::thread_rng());
        }
    }

    /// Returns the next `(features, targets)` batch, or `None` once every index is consumed.
    ///
    /// Which items are skipped:
    /// - Items whose `get_item` fails.
    /// - Items with fewer than two tensors.
    ///
    /// A chunk in which every item is skipped no longer ends the pass early; the loader
    /// moves on to the next chunk.
    ///
    /// How a batch is built:
    /// - The surviving items' tensors are stacked.
    /// - If stacking fails, the batch falls back to the first surviving sample alone.
    ///
    /// A `batch_size` of 0 yields no batches.
    fn next_batch(&mut self) -> Option<(Tensor<T>, Tensor<T>)> {
        // Without this guard the loop below would never advance.
        if self.batch_size == 0 {
            return None;
        }
        while !self.is_empty() {
            let start = self.current_index;
            // Bound by `indices.len()` so the slice below can never go out of range.
            let end = start
                .saturating_add(self.batch_size)
                .min(self.indices.len());
            self.current_index = end;

            let mut batch_features = Vec::with_capacity(end - start);
            let mut batch_targets = Vec::with_capacity(end - start);
            for &index in &self.indices[start..end] {
                if let Ok(tensors) = self.dataset.get_item(index) {
                    if tensors.len() >= 2 {
                        batch_features.push(tensors[0].clone());
                        batch_targets.push(tensors[1].clone());
                    }
                }
            }
            if batch_features.is_empty() {
                // Every item in this chunk was unusable; try the next chunk instead of
                // signalling end-of-data.
                continue;
            }

            let feature_refs: Vec<&Tensor<T>> = batch_features.iter().collect();
            let target_refs: Vec<&Tensor<T>> = batch_targets.iter().collect();
            return match (Tensor::stack(&feature_refs), Tensor::stack(&target_refs)) {
                (Ok(stacked_features), Ok(stacked_targets)) => {
                    Some((stacked_features, stacked_targets))
                }
                // Best-effort fallback kept from the original: the batch shrinks to one sample.
                _ => Some((batch_features[0].clone(), batch_targets[0].clone())),
            };
        }
        None
    }

    /// True once every index of the current pass has been consumed.
    fn is_empty(&self) -> bool {
        self.current_index >= self.indices.len()
    }
}