1use super::core::Trainer;
4use crate::train::Batch;
5use crate::Tensor;
6
7impl Trainer {
8 pub fn train_epoch<F, I>(&mut self, batches: I, forward_fn: F) -> f32
19 where
20 F: Fn(&Tensor) -> Tensor,
21 I: IntoIterator<Item = Batch>,
22 {
23 let mut total_loss = 0.0;
24 let mut num_batches = 0;
25
26 for (i, batch) in batches.into_iter().enumerate() {
27 let loss = self.train_step(&batch, &forward_fn);
28 total_loss += loss;
29 num_batches += 1;
30
31 if (i + 1) % self.config.log_interval == 0 {
33 let avg_loss = total_loss / num_batches as f32;
34 println!(
35 "Epoch {}, Step {}: loss={:.4}, lr={:.6}",
36 self.metrics.epoch,
37 i + 1,
38 avg_loss,
39 self.lr()
40 );
41 }
42 }
43
44 let avg_loss = if num_batches > 0 { total_loss / num_batches as f32 } else { 0.0 };
45
46 self.metrics.record_epoch(avg_loss, self.lr());
48
49 avg_loss
50 }
51
52 pub fn validate<F, I>(&mut self, batches: I, forward_fn: F) -> f32
74 where
75 F: Fn(&Tensor) -> Tensor,
76 I: IntoIterator<Item = Batch>,
77 {
78 assert!(self.loss_fn.is_some(), "Loss function must be set before validation");
79
80 let mut total_loss = 0.0;
81 let mut num_batches = 0;
82
83 for batch in batches {
84 let predictions = forward_fn(&batch.inputs);
86 let loss = self
87 .loss_fn
88 .as_ref()
89 .expect("loss function must be set before validation")
90 .forward(&predictions, &batch.targets);
91 total_loss += loss.data()[0];
92 num_batches += 1;
93 }
94
95 let avg_loss = if num_batches > 0 { total_loss / num_batches as f32 } else { 0.0 };
96
97 self.metrics.record_val_loss(avg_loss);
99
100 avg_loss
101 }
102}
103
#[cfg(test)]
mod tests {
    use crate::optim::Adam;
    use crate::train::{Batch, MSELoss, TrainConfig, Trainer};
    use crate::Tensor;

    /// Builds a trainer with Adam and no loss function configured.
    /// The trainer owns a single trainable parameter tensor initialised from
    /// `initial`.
    fn make_trainer_without_loss(initial: Vec<f32>, config: TrainConfig) -> Trainer {
        let params = vec![Tensor::from_vec(initial, true)];
        let optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
        Trainer::new(params, Box::new(optimizer), config)
    }

    /// Builds a trainer like [`make_trainer_without_loss`] with MSE loss set.
    fn make_trainer(initial: Vec<f32>, config: TrainConfig) -> Trainer {
        let mut trainer = make_trainer_without_loss(initial, config);
        trainer.set_loss(Box::new(MSELoss));
        trainer
    }

    /// Builds a batch of non-trainable input and target tensors.
    fn batch(inputs: Vec<f32>, targets: Vec<f32>) -> Batch {
        Batch::new(
            Tensor::from_vec(inputs, false),
            Tensor::from_vec(targets, false),
        )
    }

    /// Two small batches whose targets differ from their inputs, so an
    /// identity forward pass yields a non-zero loss.
    fn sample_batches() -> Vec<Batch> {
        vec![
            batch(vec![1.0, 2.0], vec![2.0, 3.0]),
            batch(vec![2.0, 3.0], vec![3.0, 4.0]),
        ]
    }

    #[test]
    fn test_train_epoch() {
        let config = TrainConfig::new().with_log_interval(100);
        let mut trainer = make_trainer(vec![1.0, 2.0], config);

        let avg_loss = trainer.train_epoch(sample_batches(), std::clone::Clone::clone);

        assert!(avg_loss > 0.0);
        assert_eq!(trainer.metrics.epoch, 1);
        assert_eq!(trainer.metrics.steps, 2);
    }

    #[test]
    fn test_train_epoch_with_empty_batches() {
        let config = TrainConfig::new().with_log_interval(100);
        let mut trainer = make_trainer(vec![1.0], config);

        let batches: Vec<Batch> = vec![];
        let avg_loss = trainer.train_epoch(batches, std::clone::Clone::clone);

        assert_eq!(avg_loss, 0.0);
        // The epoch is still recorded, but no optimizer step was taken.
        assert_eq!(trainer.metrics.epoch, 1);
        assert_eq!(trainer.metrics.steps, 0);
    }

    #[test]
    fn test_validate() {
        let mut trainer = make_trainer(vec![1.0, 2.0], TrainConfig::default());

        let val_loss = trainer.validate(sample_batches(), std::clone::Clone::clone);

        assert!(val_loss > 0.0);
        assert!(val_loss.is_finite());
        assert_eq!(trainer.metrics.val_losses.len(), 1);
        assert_eq!(trainer.metrics.steps, 0);
    }

    #[test]
    fn test_validate_does_not_update_params() {
        let initial_params = vec![1.0, 2.0];
        let mut trainer = make_trainer(initial_params.clone(), TrainConfig::default());

        let val_batches = vec![batch(vec![1.0, 2.0], vec![5.0, 6.0])];
        trainer.validate(val_batches, std::clone::Clone::clone);

        let params_after: Vec<f32> = trainer.params()[0].data().to_vec();
        assert_eq!(params_after, initial_params);
    }

    #[test]
    fn test_validate_with_empty_batches() {
        let mut trainer = make_trainer(vec![1.0], TrainConfig::default());

        let batches: Vec<Batch> = vec![];
        let val_loss = trainer.validate(batches, std::clone::Clone::clone);

        assert_eq!(val_loss, 0.0);
        // The (zero) validation loss is still recorded for the empty run.
        assert_eq!(trainer.metrics.val_losses.len(), 1);
    }

    #[test]
    #[should_panic(expected = "Loss function must be set before validation")]
    fn test_validate_without_loss_fn_panics() {
        let mut trainer = make_trainer_without_loss(vec![1.0], TrainConfig::default());

        let batches: Vec<Batch> = vec![];
        trainer.validate(batches, std::clone::Clone::clone);
    }
}