entrenar/train/trainer/train_loop/
basic.rs1use crate::optim::clip_grad_norm;
4use crate::train::callback::CallbackAction;
5use crate::train::trainer::core::Trainer;
6use crate::train::trainer::result::TrainResult;
7use crate::train::Batch;
8use crate::Tensor;
9use std::time::Instant;
10
11pub(super) struct EpochStepResult {
13 pub total_loss: f32,
14 pub num_batches: usize,
15 pub stopped_early: bool,
16}
17
18impl Trainer {
19 pub fn train<F, B, I>(&mut self, max_epochs: usize, batch_fn: B, forward_fn: F) -> TrainResult
44 where
45 F: Fn(&Tensor) -> Tensor,
46 B: Fn() -> I,
47 I: IntoIterator<Item = Batch>,
48 {
49 self.start_time = Some(Instant::now());
50 self.best_loss = None;
51 let mut stopped_early = false;
52 let mut final_loss = 0.0;
53
54 let ctx = self.build_context(0, max_epochs, 0, 0, 0.0, None);
55 if self.callbacks.on_train_begin(&ctx) == CallbackAction::Stop {
56 return self.make_early_stop_result();
57 }
58
59 for epoch in 0..max_epochs {
60 let action = self.fire_epoch_begin(epoch, max_epochs, final_loss);
61 if action == CallbackAction::Stop {
62 stopped_early = true;
63 break;
64 }
65 if action == CallbackAction::SkipEpoch {
66 continue;
67 }
68
69 let batches: Vec<Batch> = batch_fn().into_iter().collect();
70 let steps_per_epoch = batches.len();
71
72 let step_result = self.run_epoch_steps(
73 batches,
74 steps_per_epoch,
75 epoch,
76 max_epochs,
77 final_loss,
78 &forward_fn,
79 );
80 stopped_early = step_result.stopped_early;
81 if stopped_early {
82 break;
83 }
84
85 let avg_loss = safe_avg(step_result.total_loss, step_result.num_batches);
86 final_loss = avg_loss;
87 self.update_best_loss(avg_loss);
88 self.metrics.record_epoch(avg_loss, self.lr());
89
90 if self.fire_epoch_end(epoch, max_epochs, steps_per_epoch, avg_loss, None) {
91 stopped_early = true;
92 break;
93 }
94 }
95
96 self.finalize_training(
97 max_epochs,
98 final_loss,
99 self.best_loss.unwrap_or(final_loss),
100 stopped_early,
101 )
102 }
103
104 pub(super) fn fire_epoch_begin(
108 &mut self,
109 epoch: usize,
110 max_epochs: usize,
111 current_loss: f32,
112 ) -> CallbackAction {
113 let ctx = self.build_context(epoch, max_epochs, 0, 0, current_loss, None);
114 self.callbacks.on_epoch_begin(&ctx)
115 }
116
117 pub(super) fn fire_epoch_end(
119 &mut self,
120 epoch: usize,
121 max_epochs: usize,
122 steps_per_epoch: usize,
123 loss: f32,
124 val_loss: Option<f32>,
125 ) -> bool {
126 let ctx =
127 self.build_context(epoch, max_epochs, steps_per_epoch, steps_per_epoch, loss, val_loss);
128 self.callbacks.on_epoch_end(&ctx) == CallbackAction::Stop
129 }
130
131 pub(super) fn run_epoch_steps<F>(
133 &mut self,
134 batches: Vec<Batch>,
135 steps_per_epoch: usize,
136 epoch: usize,
137 max_epochs: usize,
138 current_loss: f32,
139 forward_fn: &F,
140 ) -> EpochStepResult
141 where
142 F: Fn(&Tensor) -> Tensor,
143 {
144 let mut total_loss = 0.0;
145 let mut num_batches = 0;
146 let accum_steps = self.config.gradient_accumulation_steps.max(1);
147
148 for (step, batch) in batches.into_iter().enumerate() {
149 let ctx =
150 self.build_context(epoch, max_epochs, step, steps_per_epoch, current_loss, None);
151 if self.callbacks.on_step_begin(&ctx) == CallbackAction::Stop {
152 return EpochStepResult { total_loss, num_batches, stopped_early: true };
153 }
154
155 if step % accum_steps == 0 {
156 self.optimizer.zero_grad(&mut self.params);
157 }
158
159 let loss = self.accumulate_gradients(&batch, forward_fn);
160 total_loss += loss;
161 num_batches += 1;
162
163 self.maybe_clip_and_step(step, steps_per_epoch, accum_steps);
164 self.metrics.increment_step();
165
166 let ctx = self.build_context(epoch, max_epochs, step, steps_per_epoch, loss, None);
167 if self.callbacks.on_step_end(&ctx) == CallbackAction::Stop {
168 return EpochStepResult { total_loss, num_batches, stopped_early: true };
169 }
170 }
171
172 EpochStepResult { total_loss, num_batches, stopped_early: false }
173 }
174
175 fn maybe_clip_and_step(&mut self, step: usize, steps_per_epoch: usize, accum_steps: usize) {
177 let is_accum_boundary = (step + 1).is_multiple_of(accum_steps);
178 let is_last_batch = step + 1 == steps_per_epoch;
179 if is_accum_boundary || is_last_batch {
180 if let Some(max_norm) = self.config.max_grad_norm {
181 clip_grad_norm(&mut self.params, max_norm);
182 }
183 self.optimizer.step(&mut self.params);
184 }
185 }
186
187 pub(super) fn update_best_loss(&mut self, loss: f32) {
189 if self.best_loss.is_none_or(|bl| loss < bl) {
190 self.best_loss = Some(loss);
191 }
192 }
193
194 pub(super) fn make_early_stop_result(&self) -> TrainResult {
196 TrainResult {
197 final_epoch: 0,
198 final_loss: 0.0,
199 best_loss: 0.0,
200 stopped_early: true,
201 elapsed_secs: self.elapsed_secs(),
202 }
203 }
204
205 pub(super) fn finalize_training(
207 &mut self,
208 max_epochs: usize,
209 final_loss: f32,
210 best_loss: f32,
211 stopped_early: bool,
212 ) -> TrainResult {
213 let ctx = self.build_context(self.metrics.epoch, max_epochs, 0, 0, final_loss, None);
214 self.callbacks.on_train_end(&ctx);
215
216 TrainResult {
217 final_epoch: self.metrics.epoch,
218 final_loss,
219 best_loss,
220 stopped_early,
221 elapsed_secs: self.elapsed_secs(),
222 }
223 }
224
225 pub(super) fn elapsed_secs(&self) -> f64 {
227 self.start_time.map_or(0.0, |t| t.elapsed().as_secs_f64())
228 }
229}
230
231pub(super) fn safe_avg(total: f32, count: usize) -> f32 {
233 if count > 0 {
234 total / count as f32
235 } else {
236 0.0
237 }
238}