use super::core::Trainer;
use crate::train::Batch;
use crate::Tensor;
impl Trainer {
    /// Runs a single training epoch over `batches`.
    ///
    /// Each batch is delegated to `self.train_step` (defined elsewhere; presumably
    /// performs forward/backward/optimizer-step — confirm in core.rs). Progress is
    /// logged every `config.log_interval` steps. Returns the mean loss across the
    /// epoch, or `0.0` if `batches` is empty.
    pub fn train_epoch<F, I>(&mut self, batches: I, forward_fn: F) -> f32
    where
        F: Fn(&Tensor) -> Tensor,
        I: IntoIterator<Item = Batch>,
    {
        let mut total_loss = 0.0;
        let mut num_batches = 0;
        for (i, batch) in batches.into_iter().enumerate() {
            let loss = self.train_step(&batch, &forward_fn);
            total_loss += loss;
            num_batches += 1;
            // `% 0` panics in Rust; a log_interval of 0 means "never log".
            if self.config.log_interval > 0 && (i + 1) % self.config.log_interval == 0 {
                let avg_loss = total_loss / num_batches as f32;
                println!(
                    "Epoch {}, Step {}: loss={:.4}, lr={:.6}",
                    self.metrics.epoch,
                    i + 1,
                    avg_loss,
                    self.lr()
                );
            }
        }
        // Avoid 0/0 = NaN on an empty epoch.
        let avg_loss = if num_batches > 0 { total_loss / num_batches as f32 } else { 0.0 };
        self.metrics.record_epoch(avg_loss, self.lr());
        avg_loss
    }

    /// Evaluates `forward_fn` on `batches` without updating parameters.
    ///
    /// Records the mean validation loss in `self.metrics` and returns it
    /// (`0.0` for an empty iterator).
    ///
    /// # Panics
    /// Panics if no loss function has been set via `set_loss`.
    pub fn validate<F, I>(&mut self, batches: I, forward_fn: F) -> f32
    where
        F: Fn(&Tensor) -> Tensor,
        I: IntoIterator<Item = Batch>,
    {
        // Bind the loss function once up front instead of asserting and then
        // re-unwrapping it on every batch; one check, one message.
        let loss_fn = self
            .loss_fn
            .as_ref()
            .expect("Loss function must be set before validation");
        let mut total_loss = 0.0;
        let mut num_batches = 0;
        for batch in batches {
            let predictions = forward_fn(&batch.inputs);
            let loss = loss_fn.forward(&predictions, &batch.targets);
            // Loss is assumed to be a scalar tensor; element 0 holds the value.
            total_loss += loss.data()[0];
            num_batches += 1;
        }
        let avg_loss = if num_batches > 0 { total_loss / num_batches as f32 } else { 0.0 };
        self.metrics.record_val_loss(avg_loss);
        avg_loss
    }
}
#[cfg(test)]
mod tests {
    use crate::optim::Adam;
    use crate::train::{Batch, MSELoss, TrainConfig, Trainer};
    use crate::Tensor;

    /// Builds a trainer with a single parameter tensor, an Adam optimizer,
    /// and an MSE loss, ready for training or validation.
    fn build_trainer(param_data: Vec<f32>, config: TrainConfig) -> Trainer {
        let params = vec![Tensor::from_vec(param_data, true)];
        let optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
        let mut trainer = Trainer::new(params, Box::new(optimizer), config);
        trainer.set_loss(Box::new(MSELoss));
        trainer
    }

    /// Two small input/target batches shared by the epoch and validation tests.
    fn two_batches() -> Vec<Batch> {
        vec![
            Batch::new(
                Tensor::from_vec(vec![1.0, 2.0], false),
                Tensor::from_vec(vec![2.0, 3.0], false),
            ),
            Batch::new(
                Tensor::from_vec(vec![2.0, 3.0], false),
                Tensor::from_vec(vec![3.0, 4.0], false),
            ),
        ]
    }

    #[test]
    fn test_train_epoch() {
        let config = TrainConfig::new().with_log_interval(100);
        let mut trainer = build_trainer(vec![1.0, 2.0], config);
        // Identity forward pass: predictions are the inputs themselves.
        let avg_loss = trainer.train_epoch(two_batches(), |t: &Tensor| t.clone());
        assert!(avg_loss > 0.0);
        assert_eq!(trainer.metrics.epoch, 1);
        assert_eq!(trainer.metrics.steps, 2);
    }

    #[test]
    fn test_train_epoch_with_empty_batches() {
        let config = TrainConfig::new().with_log_interval(100);
        let mut trainer = build_trainer(vec![1.0], config);
        let no_batches: Vec<Batch> = Vec::new();
        let avg_loss = trainer.train_epoch(no_batches, |t: &Tensor| t.clone());
        assert_eq!(avg_loss, 0.0);
    }

    #[test]
    fn test_validate() {
        let mut trainer = build_trainer(vec![1.0, 2.0], TrainConfig::default());
        let val_loss = trainer.validate(two_batches(), |t: &Tensor| t.clone());
        assert!(val_loss > 0.0);
        assert!(val_loss.is_finite());
        assert_eq!(trainer.metrics.val_losses.len(), 1);
        // Validation must not count as training steps.
        assert_eq!(trainer.metrics.steps, 0);
    }

    #[test]
    fn test_validate_does_not_update_params() {
        let initial_params = vec![1.0, 2.0];
        let mut trainer = build_trainer(initial_params.clone(), TrainConfig::default());
        let val_batches = vec![Batch::new(
            Tensor::from_vec(vec![1.0, 2.0], false),
            Tensor::from_vec(vec![5.0, 6.0], false),
        )];
        trainer.validate(val_batches, |t: &Tensor| t.clone());
        // Parameters must be byte-identical after validation.
        let params_after: Vec<f32> = trainer.params()[0].data().to_vec();
        assert_eq!(params_after, initial_params);
    }

    #[test]
    fn test_validate_with_empty_batches() {
        let mut trainer = build_trainer(vec![1.0], TrainConfig::default());
        let no_batches: Vec<Batch> = Vec::new();
        let val_loss = trainer.validate(no_batches, |t: &Tensor| t.clone());
        assert_eq!(val_loss, 0.0);
    }
}