use crate::error::{NeuralError, Result};
use crate::models::NeuralModel;
use crate::training::{
data_loader::{DataLoader, TimeSeriesDataset},
optimizer::{LRScheduler, Optimizer, OptimizerConfig, SchedulerMode},
TrainingConfig, TrainingMetrics,
};
#[cfg(feature = "candle")]
use candle_core::{Device, Tensor};
#[cfg(feature = "candle")]
use candle_nn::VarMap;
use std::path::{Path, PathBuf};
use std::time::Instant;
use tracing::{info, warn};
/// Drives model training: epoch loop, validation, early stopping and checkpointing.
pub struct Trainer {
/// Training settings; this file reads `num_epochs`, `batch_size`,
/// `early_stopping_patience` and `gradient_clip` from it.
config: TrainingConfig,
/// Device handed to the data loaders when batches are fetched.
device: Device,
/// Variable map given to the optimizer and to gradient clipping.
varmap: VarMap,
/// Lowest validation loss recorded so far; `None` until the first validation pass.
best_val_loss: Option<f64>,
/// Number of consecutive validation passes that did not improve on `best_val_loss`.
epochs_without_improvement: usize,
/// Directory for checkpoint files; `None` disables checkpointing.
checkpoint_dir: Option<PathBuf>,
}
impl Trainer {
/// Builds a trainer for `config` on `device`.
///
/// The trainer starts with a fresh, empty [`VarMap`], no recorded best
/// validation loss, a zeroed no-improvement counter, and checkpointing
/// disabled.
pub fn new(config: TrainingConfig, device: Device) -> Self {
    let varmap = VarMap::new();
    Self {
        config,
        device,
        varmap,
        checkpoint_dir: None,
        best_val_loss: None,
        epochs_without_improvement: 0,
    }
}
pub fn with_checkpointing(mut self, dir: impl AsRef<Path>) -> Self {
self.checkpoint_dir = Some(dir.as_ref().to_path_buf());
self
}
/// Returns the trainer's variable map (the one handed to the optimizer in `train`).
// NOTE(review): presumably callers build the model's parameters from this map so
// the optimizer can see them — confirm against the model construction code.
pub fn varmap(&self) -> &VarMap {
&self.varmap
}
/// Returns the device this trainer passes to the data loaders.
pub fn device(&self) -> &Device {
&self.device
}
/// Runs the training loop for up to `config.num_epochs` epochs and returns
/// the model together with one [`TrainingMetrics`] entry per completed epoch.
///
/// Each epoch resets the loaders, runs one training pass and, when
/// `val_loader` is provided, one validation pass. The learning-rate scheduler
/// is then stepped with the validation loss and its result is applied to the
/// optimizer. When a validation loss is available it also drives early
/// stopping and best-model checkpointing.
///
/// Checkpointing happens only when a directory was set via
/// `with_checkpointing`:
/// - the best model is saved whenever the validation loss strictly improves;
/// - every 10th epoch is written as `checkpoint_epoch_{n}.safetensors`;
/// - after the loop, `best_model.safetensors` is loaded back into the model
///   if that file exists.
///
/// Declared `async`, although the body contains no `.await` points.
///
/// # Errors
///
/// Propagates any error from optimizer construction, the epoch passes,
/// learning-rate updates, and checkpoint reads/writes.
pub async fn train<M: NeuralModel>(
    &mut self,
    mut model: M,
    mut train_loader: DataLoader,
    mut val_loader: Option<DataLoader>,
    optimizer_config: OptimizerConfig,
) -> Result<(M, Vec<TrainingMetrics>)> {
    /// A periodic checkpoint is written every this many epochs.
    const PERIODIC_CHECKPOINT_INTERVAL: usize = 10;

    // Read the learning rate up front so the config can be moved into the
    // optimizer instead of cloned.
    let initial_lr = optimizer_config.learning_rate;
    let mut optimizer = Optimizer::new(optimizer_config, &self.varmap)?;
    // Built from the initial learning rate, half the early-stopping patience
    // and 0.5 (presumably plateau patience and reduction factor — confirm
    // against `LRScheduler::reduce_on_plateau`).
    let mut scheduler = LRScheduler::reduce_on_plateau(
        initial_lr,
        self.config.early_stopping_patience / 2,
        0.5,
    );
    let mut metrics_history = Vec::new();

    info!(
        "Starting training for {} epochs (batch_size={}, lr={})",
        self.config.num_epochs, self.config.batch_size, initial_lr
    );

    for epoch in 0..self.config.num_epochs {
        let epoch_start = Instant::now();

        train_loader.reset();
        let train_loss = self.train_epoch(&model, &mut train_loader, &mut optimizer)?;

        let val_loss = if let Some(ref mut val_loader) = val_loader {
            val_loader.reset();
            Some(self.validate_epoch(&model, val_loader)?)
        } else {
            None
        };

        let current_lr = scheduler.step(val_loss, epoch);
        optimizer.set_learning_rate(current_lr)?;

        let epoch_time = epoch_start.elapsed().as_secs_f64();

        // Format with `{}`: the value is already a `String`, and `{:?}` would
        // wrap it in quotes in the log line.
        let val_loss_text = val_loss
            .map(|v| format!("{:.6}", v))
            .unwrap_or_else(|| "N/A".to_string());
        info!(
            "Epoch {}/{}: train_loss={:.6}, val_loss={}, lr={:.2e}, time={:.2}s",
            epoch + 1,
            self.config.num_epochs,
            train_loss,
            val_loss_text,
            current_lr,
            epoch_time
        );

        metrics_history.push(TrainingMetrics {
            epoch,
            train_loss,
            val_loss,
            learning_rate: current_lr,
            epoch_time_seconds: epoch_time,
        });

        if let Some(val_loss) = val_loss {
            let previous_best = self.best_val_loss;
            let should_stop = self.check_early_stopping(val_loss);

            // Save only on a genuine improvement: the tracker adopted this
            // epoch's loss AND it differs from the previous best. A loss that
            // merely ties the previous best must not overwrite the checkpoint.
            let improved =
                self.best_val_loss == Some(val_loss) && previous_best != Some(val_loss);
            if improved {
                if let Some(ref checkpoint_dir) = self.checkpoint_dir {
                    self.save_checkpoint(&model, epoch, val_loss, checkpoint_dir)?;
                }
            }

            if should_stop {
                info!("Early stopping triggered after {} epochs", epoch + 1);
                break;
            }
        }

        if let Some(ref checkpoint_dir) = self.checkpoint_dir {
            if (epoch + 1) % PERIODIC_CHECKPOINT_INTERVAL == 0 {
                // The directory may not exist yet, e.g. when training without
                // a validation loader so no best-model checkpoint was written.
                std::fs::create_dir_all(checkpoint_dir)?;
                let checkpoint_path =
                    checkpoint_dir.join(format!("checkpoint_epoch_{}.safetensors", epoch + 1));
                model.save_weights(&checkpoint_path.to_string_lossy())?;
            }
        }
    }

    // NOTE(review): this loads any `best_model.safetensors` found in the
    // directory, including one left behind by an earlier run.
    if let Some(ref checkpoint_dir) = self.checkpoint_dir {
        let best_path = checkpoint_dir.join("best_model.safetensors");
        if best_path.exists() {
            info!("Loading best model from checkpoint");
            model.load_weights(&best_path.to_string_lossy())?;
        }
    }

    Ok((model, metrics_history))
}
fn train_epoch<M: NeuralModel>(
&self,
model: &M,
loader: &mut DataLoader,
optimizer: &mut Optimizer,
) -> Result<f64> {
let mut total_loss = 0.0;
let mut batch_count = 0;
while let Some((inputs, targets)) = loader.next_batch(&self.device)? {
let predictions = model.forward(&inputs)?;
let loss = self.mse_loss(&predictions, &targets)?;
optimizer.zero_grad()?;
loss.backward()?;
if let Some(max_norm) = self.config.gradient_clip {
self.clip_gradients(&self.varmap, max_norm)?;
}
optimizer.step()?;
total_loss += loss.to_scalar::<f64>()?;
batch_count += 1;
}
Ok(total_loss / batch_count as f64)
}
fn validate_epoch<M: NeuralModel>(
&self,
model: &M,
loader: &mut DataLoader,
) -> Result<f64> {
let mut total_loss = 0.0;
let mut batch_count = 0;
while let Some((inputs, targets)) = loader.next_batch(&self.device)? {
let predictions = model.forward(&inputs)?;
let loss = self.mse_loss(&predictions, &targets)?;
total_loss += loss.to_scalar::<f64>()?;
batch_count += 1;
}
Ok(total_loss / batch_count as f64)
}
/// Mean squared error: the element-wise difference `predictions - targets`
/// is squared and then averaged over all elements with `mean_all`.
fn mse_loss(&self, predictions: &Tensor, targets: &Tensor) -> Result<Tensor> {
    let squared_error = predictions.sub(targets)?.sqr()?;
    Ok(squared_error.mean_all()?)
}
/// Rescales all gradients in `varmap` so their combined L2 norm does not
/// exceed `max_norm`.
///
/// The combined norm is the square root of the summed squared gradient
/// entries across every variable that has a gradient. When it exceeds
/// `max_norm`, each gradient is multiplied by `max_norm / norm`; otherwise
/// gradients are left untouched.
fn clip_gradients(&self, varmap: &VarMap, max_norm: f64) -> Result<()> {
    let vars = varmap.all_vars();

    // First pass: accumulate the sum of squares over every gradient.
    let mut sum_of_squares = 0.0_f64;
    for var in &vars {
        if let Some(grad) = var.grad() {
            sum_of_squares += grad.as_ref().sqr()?.sum_all()?.to_scalar::<f64>()?;
        }
    }
    let global_norm = sum_of_squares.sqrt();

    // Written as a negated `>` so a NaN norm skips clipping, as any
    // comparison with NaN is false.
    if !(global_norm > max_norm) {
        return Ok(());
    }

    // Second pass: scale every gradient by the same coefficient.
    let scale = max_norm / global_norm;
    for var in vars {
        if let Some(grad) = var.grad() {
            let scaled = grad.as_ref().mul(&scale)?;
            var.set_grad(scaled)?;
        }
    }
    Ok(())
}
/// Updates the best-validation-loss tracker and reports whether training
/// should stop.
///
/// `val_loss` counts as an improvement when no best loss is recorded yet,
/// when it is strictly lower than the recorded best, or when the recorded
/// best is NaN and `val_loss` is not. On improvement the best loss is
/// replaced and the no-improvement counter resets; otherwise the counter is
/// incremented.
///
/// Returns `true` once the counter reaches `config.early_stopping_patience`.
fn check_early_stopping(&mut self, val_loss: f64) -> bool {
    let improved = match self.best_val_loss {
        None => true,
        // `val_loss < NaN` is always false, so a NaN recorded as "best" could
        // never be beaten; let any non-NaN loss replace it.
        Some(best) => val_loss < best || (best.is_nan() && !val_loss.is_nan()),
    };

    if improved {
        self.best_val_loss = Some(val_loss);
        self.epochs_without_improvement = 0;
        return false;
    }

    self.epochs_without_improvement += 1;
    self.epochs_without_improvement >= self.config.early_stopping_patience
}
/// Writes the model weights and a JSON metadata file into `checkpoint_dir`.
///
/// The directory is created if missing. Weights go to
/// `best_model.safetensors`; the metadata (epoch, validation loss, current
/// UTC timestamp, a copy of the training config) goes to
/// `checkpoint_metadata.json`. Existing files with those names are replaced.
///
/// # Errors
///
/// Propagates directory-creation, weight-saving, serialisation and
/// file-write errors.
fn save_checkpoint<M: NeuralModel>(
    &self,
    model: &M,
    epoch: usize,
    val_loss: f64,
    checkpoint_dir: &Path,
) -> Result<()> {
    std::fs::create_dir_all(checkpoint_dir)?;

    let weights_file = checkpoint_dir.join("best_model.safetensors");
    model.save_weights(&weights_file.to_string_lossy())?;

    let serialized = serde_json::to_string_pretty(&CheckpointMetadata {
        epoch,
        val_loss,
        timestamp: chrono::Utc::now(),
        config: self.config.clone(),
    })?;
    std::fs::write(checkpoint_dir.join("checkpoint_metadata.json"), serialized)?;

    // Logged 1-based, while the stored `epoch` stays 0-based.
    info!("Saved checkpoint at epoch {} (val_loss={:.6})", epoch + 1, val_loss);
    Ok(())
}
/// Restores a checkpoint from `checkpoint_dir`.
///
/// Reads and parses `checkpoint_metadata.json` first, then loads
/// `best_model.safetensors` into `model`. Returns the model together with the
/// parsed metadata.
///
/// # Errors
///
/// Propagates file-read, JSON-parse and weight-loading errors.
pub fn load_checkpoint<M: NeuralModel>(
    checkpoint_dir: impl AsRef<Path>,
    mut model: M,
) -> Result<(M, CheckpointMetadata)> {
    let dir = checkpoint_dir.as_ref();

    let raw_metadata = std::fs::read_to_string(dir.join("checkpoint_metadata.json"))?;
    let metadata = serde_json::from_str::<CheckpointMetadata>(&raw_metadata)?;

    let weights_file = dir.join("best_model.safetensors");
    model.load_weights(&weights_file.to_string_lossy())?;

    Ok((model, metadata))
}
}
/// Metadata stored next to a saved checkpoint as `checkpoint_metadata.json`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CheckpointMetadata {
/// Zero-based index of the epoch at which the checkpoint was saved.
pub epoch: usize,
/// Validation loss of the checkpointed model.
pub val_loss: f64,
/// UTC time at which the checkpoint was written.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Copy of the training configuration in effect when the checkpoint was saved.
pub config: TrainingConfig,
}
/// Quantile (pinball) loss averaged over all elements.
///
/// With `d = targets - predictions`, each element contributes
/// `quantile * max(d, 0) + (quantile - 1) * min(d, 0)`; the result is the
/// mean of those contributions.
///
/// # Errors
///
/// Propagates any tensor-operation error.
pub fn quantile_loss(predictions: &Tensor, targets: &Tensor, quantile: f64) -> Result<Tensor> {
    let residual = targets.sub(predictions)?;

    // Split the residual into its non-negative and non-positive parts.
    let under_prediction = residual.maximum(&Tensor::zeros_like(&residual)?)?;
    let over_prediction = residual.minimum(&Tensor::zeros_like(&residual)?)?;

    let under_penalty = under_prediction.mul(&quantile)?;
    let over_penalty = over_prediction.mul(&(quantile - 1.0))?;
    let pinball = under_penalty.add(&over_penalty)?;

    Ok(pinball.mean_all()?)
}
// Unit tests for trainer construction and the early-stopping counter.
#[cfg(test)]
mod tests {
use super::*;
// A freshly built trainer has no best loss and a zeroed counter.
#[test]
fn test_trainer_creation() {
let config = TrainingConfig::default();
let device = Device::Cpu;
let trainer = Trainer::new(config.clone(), device);
assert_eq!(trainer.best_val_loss, None);
assert_eq!(trainer.epochs_without_improvement, 0);
}
// With patience 3, stopping triggers on the third consecutive non-improving loss.
#[test]
fn test_early_stopping() {
let config = TrainingConfig {
early_stopping_patience: 3,
..Default::default()
};
let device = Device::Cpu;
let mut trainer = Trainer::new(config, device);
// First loss becomes the best.
assert!(!trainer.check_early_stopping(1.0));
// Strict improvement resets the counter.
assert!(!trainer.check_early_stopping(0.8));
// Two non-improving epochs: counter is 1, then 2.
assert!(!trainer.check_early_stopping(0.9));
assert!(!trainer.check_early_stopping(0.9));
// Third non-improving epoch reaches the patience limit.
assert!(trainer.check_early_stopping(0.9));
}
}