Skip to main content

entrenar/train/trainer/
epoch.rs

1//! Epoch-level training and validation operations
2
3use super::core::Trainer;
4use crate::train::Batch;
5use crate::Tensor;
6
7impl Trainer {
8    /// Train for one epoch
9    ///
10    /// # Arguments
11    ///
12    /// * `batches` - Iterator over training batches
13    /// * `forward_fn` - Closure that computes predictions from inputs
14    ///
15    /// # Returns
16    ///
17    /// Average loss over the epoch
18    pub fn train_epoch<F, I>(&mut self, batches: I, forward_fn: F) -> f32
19    where
20        F: Fn(&Tensor) -> Tensor,
21        I: IntoIterator<Item = Batch>,
22    {
23        let mut total_loss = 0.0;
24        let mut num_batches = 0;
25
26        for (i, batch) in batches.into_iter().enumerate() {
27            let loss = self.train_step(&batch, &forward_fn);
28            total_loss += loss;
29            num_batches += 1;
30
31            // Log progress
32            if (i + 1) % self.config.log_interval == 0 {
33                let avg_loss = total_loss / num_batches as f32;
34                println!(
35                    "Epoch {}, Step {}: loss={:.4}, lr={:.6}",
36                    self.metrics.epoch,
37                    i + 1,
38                    avg_loss,
39                    self.lr()
40                );
41            }
42        }
43
44        let avg_loss = if num_batches > 0 { total_loss / num_batches as f32 } else { 0.0 };
45
46        // Record epoch metrics
47        self.metrics.record_epoch(avg_loss, self.lr());
48
49        avg_loss
50    }
51
52    /// Validate on a dataset without updating parameters
53    ///
54    /// # Arguments
55    ///
56    /// * `batches` - Iterator over validation batches
57    /// * `forward_fn` - Closure that computes predictions from inputs
58    ///
59    /// # Returns
60    ///
61    /// Average validation loss
62    ///
63    /// # Example
64    ///
65    /// ```no_run
66    /// # use entrenar::train::{Trainer, Batch};
67    /// # use entrenar::Tensor;
68    /// # let mut trainer: Trainer = todo!();
69    /// # let val_batches: Vec<Batch> = vec![];
70    /// let val_loss = trainer.validate(val_batches, |x| x.clone());
71    /// println!("Validation loss: {:.4}", val_loss);
72    /// ```
73    pub fn validate<F, I>(&mut self, batches: I, forward_fn: F) -> f32
74    where
75        F: Fn(&Tensor) -> Tensor,
76        I: IntoIterator<Item = Batch>,
77    {
78        assert!(self.loss_fn.is_some(), "Loss function must be set before validation");
79
80        let mut total_loss = 0.0;
81        let mut num_batches = 0;
82
83        for batch in batches {
84            // Forward pass only (no gradients, no optimizer step)
85            let predictions = forward_fn(&batch.inputs);
86            let loss = self
87                .loss_fn
88                .as_ref()
89                .expect("loss function must be set before validation")
90                .forward(&predictions, &batch.targets);
91            total_loss += loss.data()[0];
92            num_batches += 1;
93        }
94
95        let avg_loss = if num_batches > 0 { total_loss / num_batches as f32 } else { 0.0 };
96
97        // Record validation loss
98        self.metrics.record_val_loss(avg_loss);
99
100        avg_loss
101    }
102}
103
#[cfg(test)]
mod tests {
    use crate::optim::Adam;
    use crate::train::{Batch, MSELoss, TrainConfig, Trainer};
    use crate::Tensor;

    /// Builds a trainer over one parameter tensor holding `initial`. It uses the Adam
    /// settings shared by every test in this module and MSE as the loss.
    fn make_trainer(initial: Vec<f32>, config: TrainConfig) -> Trainer {
        let weights = vec![Tensor::from_vec(initial, true)];
        let mut trainer =
            Trainer::new(weights, Box::new(Adam::new(0.01, 0.9, 0.999, 1e-8)), config);
        trainer.set_loss(Box::new(MSELoss));
        trainer
    }

    /// Builds an input/target batch. Both tensors are created with the `false` flag
    /// (presumably no gradient tracking).
    fn batch(inputs: Vec<f32>, targets: Vec<f32>) -> Batch {
        Batch::new(Tensor::from_vec(inputs, false), Tensor::from_vec(targets, false))
    }

    /// The two-batch dataset used by both the epoch and the validation tests.
    fn two_batches() -> Vec<Batch> {
        vec![
            batch(vec![1.0, 2.0], vec![2.0, 3.0]),
            batch(vec![2.0, 3.0], vec![3.0, 4.0]),
        ]
    }

    #[test]
    fn test_train_epoch() {
        // The interval is larger than the batch count, so no progress line is printed.
        let mut trainer = make_trainer(vec![1.0, 2.0], TrainConfig::new().with_log_interval(100));

        let avg_loss = trainer.train_epoch(two_batches(), Tensor::clone);

        assert!(avg_loss > 0.0);
        assert_eq!(trainer.metrics.epoch, 1);
        assert_eq!(trainer.metrics.steps, 2);
    }

    #[test]
    fn test_train_epoch_with_empty_batches() {
        let mut trainer = make_trainer(vec![1.0], TrainConfig::new().with_log_interval(100));

        let avg_loss = trainer.train_epoch(Vec::<Batch>::new(), Tensor::clone);

        // No batches means the reported average is exactly zero.
        assert_eq!(avg_loss, 0.0);
    }

    #[test]
    fn test_validate() {
        let mut trainer = make_trainer(vec![1.0, 2.0], TrainConfig::default());

        let val_loss = trainer.validate(two_batches(), Tensor::clone);

        assert!(val_loss > 0.0);
        assert!(val_loss.is_finite());
        assert_eq!(trainer.metrics.val_losses.len(), 1);
        // Validation must not count as a training step.
        assert_eq!(trainer.metrics.steps, 0);
    }

    #[test]
    fn test_validate_does_not_update_params() {
        let initial_params = vec![1.0, 2.0];
        let mut trainer = make_trainer(initial_params.clone(), TrainConfig::default());

        // Targets far from the inputs, so there is a non-zero loss to (not) act on.
        trainer.validate(vec![batch(vec![1.0, 2.0], vec![5.0, 6.0])], Tensor::clone);

        // Parameters should remain unchanged after validation
        let params_after: Vec<f32> = trainer.params()[0].data().to_vec();
        assert_eq!(params_after, initial_params);
    }

    #[test]
    fn test_validate_with_empty_batches() {
        let mut trainer = make_trainer(vec![1.0], TrainConfig::default());

        let val_loss = trainer.validate(Vec::<Batch>::new(), Tensor::clone);

        // No batches means the reported average is exactly zero.
        assert_eq!(val_loss, 0.0);
    }
}