Skip to main content

entrenar/train/trainer/train_loop/
basic.rs

1//! Basic training loop without validation
2
3use crate::optim::clip_grad_norm;
4use crate::train::callback::CallbackAction;
5use crate::train::trainer::core::Trainer;
6use crate::train::trainer::result::TrainResult;
7use crate::train::Batch;
8use crate::Tensor;
9use std::time::Instant;
10
11/// Result of running the inner step loop for one epoch
12pub(super) struct EpochStepResult {
13    pub total_loss: f32,
14    pub num_batches: usize,
15    pub stopped_early: bool,
16}
17
18impl Trainer {
19    /// Train for multiple epochs with full callback support
20    ///
21    /// # Arguments
22    ///
23    /// * `max_epochs` - Maximum number of epochs to train
24    /// * `batch_fn` - Function that returns batches for each epoch
25    /// * `forward_fn` - Closure that computes predictions from inputs
26    ///
27    /// # Returns
28    ///
29    /// TrainResult with final metrics
30    ///
31    /// # Example
32    ///
33    /// ```no_run
34    /// # use entrenar::train::{Trainer, Batch, EarlyStopping};
35    /// # use entrenar::Tensor;
36    /// # let mut trainer: Trainer = todo!();
37    /// # let batches: Vec<Batch> = vec![];
38    /// trainer.add_callback(EarlyStopping::new(5, 0.001));
39    ///
40    /// let result = trainer.train(100, || batches.clone(), |x| x.clone());
41    /// println!("Trained {} epochs, final loss: {:.4}", result.final_epoch, result.final_loss);
42    /// ```
43    pub fn train<F, B, I>(&mut self, max_epochs: usize, batch_fn: B, forward_fn: F) -> TrainResult
44    where
45        F: Fn(&Tensor) -> Tensor,
46        B: Fn() -> I,
47        I: IntoIterator<Item = Batch>,
48    {
49        self.start_time = Some(Instant::now());
50        self.best_loss = None;
51        let mut stopped_early = false;
52        let mut final_loss = 0.0;
53
54        let ctx = self.build_context(0, max_epochs, 0, 0, 0.0, None);
55        if self.callbacks.on_train_begin(&ctx) == CallbackAction::Stop {
56            return self.make_early_stop_result();
57        }
58
59        for epoch in 0..max_epochs {
60            let action = self.fire_epoch_begin(epoch, max_epochs, final_loss);
61            if action == CallbackAction::Stop {
62                stopped_early = true;
63                break;
64            }
65            if action == CallbackAction::SkipEpoch {
66                continue;
67            }
68
69            let batches: Vec<Batch> = batch_fn().into_iter().collect();
70            let steps_per_epoch = batches.len();
71
72            let step_result = self.run_epoch_steps(
73                batches,
74                steps_per_epoch,
75                epoch,
76                max_epochs,
77                final_loss,
78                &forward_fn,
79            );
80            stopped_early = step_result.stopped_early;
81            if stopped_early {
82                break;
83            }
84
85            let avg_loss = safe_avg(step_result.total_loss, step_result.num_batches);
86            final_loss = avg_loss;
87            self.update_best_loss(avg_loss);
88            self.metrics.record_epoch(avg_loss, self.lr());
89
90            if self.fire_epoch_end(epoch, max_epochs, steps_per_epoch, avg_loss, None) {
91                stopped_early = true;
92                break;
93            }
94        }
95
96        self.finalize_training(
97            max_epochs,
98            final_loss,
99            self.best_loss.unwrap_or(final_loss),
100            stopped_early,
101        )
102    }
103
104    // -- Shared helpers used by both basic.rs and validation.rs --
105
106    /// Fire the epoch_begin callback and return the requested action
107    pub(super) fn fire_epoch_begin(
108        &mut self,
109        epoch: usize,
110        max_epochs: usize,
111        current_loss: f32,
112    ) -> CallbackAction {
113        let ctx = self.build_context(epoch, max_epochs, 0, 0, current_loss, None);
114        self.callbacks.on_epoch_begin(&ctx)
115    }
116
117    /// Fire the epoch_end callback; returns true if training should stop
118    pub(super) fn fire_epoch_end(
119        &mut self,
120        epoch: usize,
121        max_epochs: usize,
122        steps_per_epoch: usize,
123        loss: f32,
124        val_loss: Option<f32>,
125    ) -> bool {
126        let ctx =
127            self.build_context(epoch, max_epochs, steps_per_epoch, steps_per_epoch, loss, val_loss);
128        self.callbacks.on_epoch_end(&ctx) == CallbackAction::Stop
129    }
130
131    /// Run the inner step loop for one epoch
132    pub(super) fn run_epoch_steps<F>(
133        &mut self,
134        batches: Vec<Batch>,
135        steps_per_epoch: usize,
136        epoch: usize,
137        max_epochs: usize,
138        current_loss: f32,
139        forward_fn: &F,
140    ) -> EpochStepResult
141    where
142        F: Fn(&Tensor) -> Tensor,
143    {
144        let mut total_loss = 0.0;
145        let mut num_batches = 0;
146        let accum_steps = self.config.gradient_accumulation_steps.max(1);
147
148        for (step, batch) in batches.into_iter().enumerate() {
149            let ctx =
150                self.build_context(epoch, max_epochs, step, steps_per_epoch, current_loss, None);
151            if self.callbacks.on_step_begin(&ctx) == CallbackAction::Stop {
152                return EpochStepResult { total_loss, num_batches, stopped_early: true };
153            }
154
155            if step % accum_steps == 0 {
156                self.optimizer.zero_grad(&mut self.params);
157            }
158
159            let loss = self.accumulate_gradients(&batch, forward_fn);
160            total_loss += loss;
161            num_batches += 1;
162
163            self.maybe_clip_and_step(step, steps_per_epoch, accum_steps);
164            self.metrics.increment_step();
165
166            let ctx = self.build_context(epoch, max_epochs, step, steps_per_epoch, loss, None);
167            if self.callbacks.on_step_end(&ctx) == CallbackAction::Stop {
168                return EpochStepResult { total_loss, num_batches, stopped_early: true };
169            }
170        }
171
172        EpochStepResult { total_loss, num_batches, stopped_early: false }
173    }
174
175    /// Clip gradients and run optimizer step at accumulation boundaries
176    fn maybe_clip_and_step(&mut self, step: usize, steps_per_epoch: usize, accum_steps: usize) {
177        let is_accum_boundary = (step + 1).is_multiple_of(accum_steps);
178        let is_last_batch = step + 1 == steps_per_epoch;
179        if is_accum_boundary || is_last_batch {
180            if let Some(max_norm) = self.config.max_grad_norm {
181                clip_grad_norm(&mut self.params, max_norm);
182            }
183            self.optimizer.step(&mut self.params);
184        }
185    }
186
187    /// Update best_loss if the new loss is lower
188    pub(super) fn update_best_loss(&mut self, loss: f32) {
189        if self.best_loss.is_none_or(|bl| loss < bl) {
190            self.best_loss = Some(loss);
191        }
192    }
193
194    /// Create a TrainResult for immediate early stop at train_begin
195    pub(super) fn make_early_stop_result(&self) -> TrainResult {
196        TrainResult {
197            final_epoch: 0,
198            final_loss: 0.0,
199            best_loss: 0.0,
200            stopped_early: true,
201            elapsed_secs: self.elapsed_secs(),
202        }
203    }
204
205    /// Fire train_end and build the final TrainResult
206    pub(super) fn finalize_training(
207        &mut self,
208        max_epochs: usize,
209        final_loss: f32,
210        best_loss: f32,
211        stopped_early: bool,
212    ) -> TrainResult {
213        let ctx = self.build_context(self.metrics.epoch, max_epochs, 0, 0, final_loss, None);
214        self.callbacks.on_train_end(&ctx);
215
216        TrainResult {
217            final_epoch: self.metrics.epoch,
218            final_loss,
219            best_loss,
220            stopped_early,
221            elapsed_secs: self.elapsed_secs(),
222        }
223    }
224
225    /// Compute elapsed seconds from start_time
226    pub(super) fn elapsed_secs(&self) -> f64 {
227        self.start_time.map_or(0.0, |t| t.elapsed().as_secs_f64())
228    }
229}
230
231/// Safely compute average, returning 0.0 for empty sets
232pub(super) fn safe_avg(total: f32, count: usize) -> f32 {
233    if count > 0 {
234        total / count as f32
235    } else {
236        0.0
237    }
238}