Skip to main content

entrenar/train/trainer/train_loop/
validation.rs

1//! Training loop with validation support
2
3use super::basic::safe_avg;
4use crate::train::callback::CallbackAction;
5use crate::train::trainer::core::Trainer;
6use crate::train::trainer::result::TrainResult;
7use crate::train::Batch;
8use crate::Tensor;
9use std::time::Instant;
10
11impl Trainer {
12    /// Train for multiple epochs with validation after each epoch
13    ///
14    /// This method runs training and validation each epoch, passing validation
15    /// loss to callbacks for proper early stopping and checkpointing.
16    ///
17    /// # Arguments
18    ///
19    /// * `max_epochs` - Maximum number of epochs to train
20    /// * `train_fn` - Function that returns training batches for each epoch
21    /// * `val_fn` - Function that returns validation batches for each epoch
22    /// * `forward_fn` - Closure that computes predictions from inputs
23    ///
24    /// # Returns
25    ///
26    /// TrainResult with final metrics including best validation loss
27    ///
28    /// # Example
29    ///
30    /// ```no_run
31    /// # use entrenar::train::{Trainer, Batch, EarlyStopping};
32    /// # use entrenar::Tensor;
33    /// # let mut trainer: Trainer = todo!();
34    /// # let train_batches: Vec<Batch> = vec![];
35    /// # let val_batches: Vec<Batch> = vec![];
36    /// trainer.add_callback(EarlyStopping::new(5, 0.001).monitor_validation());
37    ///
38    /// let result = trainer.train_with_val(
39    ///     100,
40    ///     || train_batches.clone(),
41    ///     || val_batches.clone(),
42    ///     |x| x.clone()
43    /// );
44    /// println!("Best val loss: {:.4}", result.best_loss);
45    /// ```
46    pub fn train_with_val<F, BT, BV, IT, IV>(
47        &mut self,
48        max_epochs: usize,
49        train_fn: BT,
50        val_fn: BV,
51        forward_fn: F,
52    ) -> TrainResult
53    where
54        F: Fn(&Tensor) -> Tensor,
55        BT: Fn() -> IT,
56        BV: Fn() -> IV,
57        IT: IntoIterator<Item = Batch>,
58        IV: IntoIterator<Item = Batch>,
59    {
60        self.start_time = Some(Instant::now());
61        self.best_loss = None;
62        let mut stopped_early = false;
63        let mut final_loss = 0.0;
64        let mut best_val_loss: Option<f32> = None;
65
66        let ctx = self.build_context(0, max_epochs, 0, 0, 0.0, None);
67        if self.callbacks.on_train_begin(&ctx) == CallbackAction::Stop {
68            return self.make_early_stop_result();
69        }
70
71        for epoch in 0..max_epochs {
72            let action = self.fire_epoch_begin(epoch, max_epochs, final_loss);
73            if action == CallbackAction::Stop {
74                stopped_early = true;
75                break;
76            }
77            if action == CallbackAction::SkipEpoch {
78                continue;
79            }
80
81            // Training phase
82            let train_batches: Vec<Batch> = train_fn().into_iter().collect();
83            let steps_per_epoch = train_batches.len();
84
85            let step_result = self.run_epoch_steps(
86                train_batches,
87                steps_per_epoch,
88                epoch,
89                max_epochs,
90                final_loss,
91                &forward_fn,
92            );
93            stopped_early = step_result.stopped_early;
94            if stopped_early {
95                break;
96            }
97
98            let avg_train_loss = safe_avg(step_result.total_loss, step_result.num_batches);
99            final_loss = avg_train_loss;
100
101            // Validation phase
102            let val_loss = self.compute_validation_loss(&val_fn, &forward_fn);
103
104            // Update best losses
105            let monitored_loss = val_loss.unwrap_or(avg_train_loss);
106            update_tracked_best(&mut best_val_loss, monitored_loss);
107            self.update_best_loss(avg_train_loss);
108
109            self.metrics.record_epoch(avg_train_loss, self.lr());
110
111            if self.fire_epoch_end(epoch, max_epochs, steps_per_epoch, avg_train_loss, val_loss) {
112                stopped_early = true;
113                break;
114            }
115        }
116
117        let best = best_val_loss.unwrap_or(self.best_loss.unwrap_or(final_loss));
118        self.finalize_training(max_epochs, final_loss, best, stopped_early)
119    }
120
121    /// Run validation batches and return average validation loss
122    ///
123    /// Returns `None` if there are no validation batches or no loss function.
124    fn compute_validation_loss<F, BV, IV>(&mut self, val_fn: &BV, forward_fn: &F) -> Option<f32>
125    where
126        F: Fn(&Tensor) -> Tensor,
127        BV: Fn() -> IV,
128        IV: IntoIterator<Item = Batch>,
129    {
130        let val_batches: Vec<Batch> = val_fn().into_iter().collect();
131        if val_batches.is_empty() {
132            return None;
133        }
134
135        let mut val_total = 0.0;
136        let mut val_count = 0;
137        for batch in val_batches {
138            if let Some(loss_fn) = self.loss_fn.as_ref() {
139                let predictions = forward_fn(&batch.inputs);
140                let loss = loss_fn.forward(&predictions, &batch.targets);
141                val_total += loss.data()[0];
142                val_count += 1;
143            }
144        }
145
146        let val_avg = safe_avg(val_total, val_count);
147        self.metrics.record_val_loss(val_avg);
148        Some(val_avg)
149    }
150}
151
/// Lower `tracked` to `value` when `value` is a new minimum.
///
/// * An empty slot (`None`) always accepts `value`.
/// * A stored NaN is replaced by any incoming value: `value < NaN` is always
///   false, so without this rule a NaN recorded first would stay as the
///   "best" for the rest of the run.
/// * An incoming NaN never displaces a stored non-NaN best.
fn update_tracked_best(tracked: &mut Option<f32>, value: f32) {
    let improves = match *tracked {
        None => true,
        Some(best) => best.is_nan() || value < best,
    };
    if improves {
        *tracked = Some(value);
    }
}
157}