entrenar/train/trainer/train_loop/
validation.rs1use super::basic::safe_avg;
4use crate::train::callback::CallbackAction;
5use crate::train::trainer::core::Trainer;
6use crate::train::trainer::result::TrainResult;
7use crate::train::Batch;
8use crate::Tensor;
9use std::time::Instant;
10
11impl Trainer {
12 pub fn train_with_val<F, BT, BV, IT, IV>(
47 &mut self,
48 max_epochs: usize,
49 train_fn: BT,
50 val_fn: BV,
51 forward_fn: F,
52 ) -> TrainResult
53 where
54 F: Fn(&Tensor) -> Tensor,
55 BT: Fn() -> IT,
56 BV: Fn() -> IV,
57 IT: IntoIterator<Item = Batch>,
58 IV: IntoIterator<Item = Batch>,
59 {
60 self.start_time = Some(Instant::now());
61 self.best_loss = None;
62 let mut stopped_early = false;
63 let mut final_loss = 0.0;
64 let mut best_val_loss: Option<f32> = None;
65
66 let ctx = self.build_context(0, max_epochs, 0, 0, 0.0, None);
67 if self.callbacks.on_train_begin(&ctx) == CallbackAction::Stop {
68 return self.make_early_stop_result();
69 }
70
71 for epoch in 0..max_epochs {
72 let action = self.fire_epoch_begin(epoch, max_epochs, final_loss);
73 if action == CallbackAction::Stop {
74 stopped_early = true;
75 break;
76 }
77 if action == CallbackAction::SkipEpoch {
78 continue;
79 }
80
81 let train_batches: Vec<Batch> = train_fn().into_iter().collect();
83 let steps_per_epoch = train_batches.len();
84
85 let step_result = self.run_epoch_steps(
86 train_batches,
87 steps_per_epoch,
88 epoch,
89 max_epochs,
90 final_loss,
91 &forward_fn,
92 );
93 stopped_early = step_result.stopped_early;
94 if stopped_early {
95 break;
96 }
97
98 let avg_train_loss = safe_avg(step_result.total_loss, step_result.num_batches);
99 final_loss = avg_train_loss;
100
101 let val_loss = self.compute_validation_loss(&val_fn, &forward_fn);
103
104 let monitored_loss = val_loss.unwrap_or(avg_train_loss);
106 update_tracked_best(&mut best_val_loss, monitored_loss);
107 self.update_best_loss(avg_train_loss);
108
109 self.metrics.record_epoch(avg_train_loss, self.lr());
110
111 if self.fire_epoch_end(epoch, max_epochs, steps_per_epoch, avg_train_loss, val_loss) {
112 stopped_early = true;
113 break;
114 }
115 }
116
117 let best = best_val_loss.unwrap_or(self.best_loss.unwrap_or(final_loss));
118 self.finalize_training(max_epochs, final_loss, best, stopped_early)
119 }
120
121 fn compute_validation_loss<F, BV, IV>(&mut self, val_fn: &BV, forward_fn: &F) -> Option<f32>
125 where
126 F: Fn(&Tensor) -> Tensor,
127 BV: Fn() -> IV,
128 IV: IntoIterator<Item = Batch>,
129 {
130 let val_batches: Vec<Batch> = val_fn().into_iter().collect();
131 if val_batches.is_empty() {
132 return None;
133 }
134
135 let mut val_total = 0.0;
136 let mut val_count = 0;
137 for batch in val_batches {
138 if let Some(loss_fn) = self.loss_fn.as_ref() {
139 let predictions = forward_fn(&batch.inputs);
140 let loss = loss_fn.forward(&predictions, &batch.targets);
141 val_total += loss.data()[0];
142 val_count += 1;
143 }
144 }
145
146 let val_avg = safe_avg(val_total, val_count);
147 self.metrics.record_val_loss(val_avg);
148 Some(val_avg)
149 }
150}
151
/// Stores `value` in `tracked` if it is the first value seen, or if it is
/// strictly lower than the current best.
///
/// A NaN currently held in `tracked` is always replaced. This way a single
/// diverged epoch cannot permanently block later, finite improvements. A NaN
/// `value` never replaces an existing non-NaN best, because `NaN < best` is
/// false.
fn update_tracked_best(tracked: &mut Option<f32>, value: f32) {
    let improves = match *tracked {
        None => true,
        // `value < NaN` is always false. Without the explicit check, a NaN
        // best would be stuck forever.
        Some(best) => best.is_nan() || value < best,
    };
    if improves {
        *tracked = Some(value);
    }
}