entrenar/train/trainer/
step.rs1use super::core::Trainer;
4use crate::optim::clip_grad_norm;
5use crate::train::Batch;
6use crate::Tensor;
7use provable_contracts_macros::ensures;
8
9impl Trainer {
10 #[ensures(ret.is_finite())]
34 pub fn train_step<F>(&mut self, batch: &Batch, forward_fn: F) -> f32
35 where
36 F: FnOnce(&Tensor) -> Tensor,
37 {
38 assert!(self.loss_fn.is_some(), "Loss function must be set before training");
39
40 self.optimizer.zero_grad(&mut self.params);
42
43 let predictions = forward_fn(&batch.inputs);
45
46 let loss = self
48 .loss_fn
49 .as_ref()
50 .expect("loss function must be set before training")
51 .forward(&predictions, &batch.targets);
52
53 let loss_val = loss.data()[0];
54
55 if let Some(backward_op) = loss.backward_op() {
57 backward_op.backward();
58 }
59
60 if let Some(max_norm) = self.config.max_grad_norm {
62 clip_grad_norm(&mut self.params, max_norm);
63 }
64
65 self.optimizer.step(&mut self.params);
67
68 self.metrics.increment_step();
70
71 loss_val
72 }
73
74 pub(crate) fn accumulate_gradients<F>(&mut self, batch: &Batch, forward_fn: F) -> f32
79 where
80 F: FnOnce(&Tensor) -> Tensor,
81 {
82 assert!(self.loss_fn.is_some(), "Loss function must be set before training");
83
84 let predictions = forward_fn(&batch.inputs);
86
87 let loss = self
89 .loss_fn
90 .as_ref()
91 .expect("loss function must be set before training")
92 .forward(&predictions, &batch.targets);
93
94 let loss_val = loss.data()[0];
95
96 if let Some(backward_op) = loss.backward_op() {
98 backward_op.backward();
99 }
100
101 loss_val
102 }
103}
104
#[cfg(test)]
mod tests {
    use crate::optim::Adam;
    use crate::train::{Batch, MSELoss, TrainConfig, Trainer};
    use crate::Tensor;

    /// Builds a trainer with one 3-element trainable parameter, an Adam
    /// optimizer, the default config, and an MSE loss.
    fn trainer_with_mse() -> Trainer {
        let params = vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], true)];
        let optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
        let mut trainer = Trainer::new(params, Box::new(optimizer), TrainConfig::default());
        trainer.set_loss(Box::new(MSELoss));
        trainer
    }

    /// Builds a batch whose targets are offset from its inputs, so an identity
    /// forward pass gives a strictly positive loss.
    fn offset_batch() -> Batch {
        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let targets = Tensor::from_vec(vec![2.0, 3.0, 4.0], false);
        Batch::new(inputs, targets)
    }

    #[test]
    fn test_train_step() {
        let mut trainer = trainer_with_mse();
        let batch = offset_batch();

        let loss = trainer.train_step(&batch, std::clone::Clone::clone);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
        assert_eq!(trainer.metrics.steps, 1);
    }

    #[test]
    fn test_accumulate_gradients_does_not_count_a_step() {
        let mut trainer = trainer_with_mse();
        let batch = offset_batch();

        let loss = trainer.accumulate_gradients(&batch, std::clone::Clone::clone);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
        // Accumulating gradients alone must not advance the step counter.
        assert_eq!(trainer.metrics.steps, 0);
    }

    #[test]
    fn test_accumulate_and_train_step_report_same_loss() {
        let mut trainer = trainer_with_mse();
        let batch = offset_batch();

        // The identity forward pass ignores the parameters, so the optimizer
        // update in train_step cannot change the loss on the same batch.
        let accumulated = trainer.accumulate_gradients(&batch, std::clone::Clone::clone);
        let stepped = trainer.train_step(&batch, std::clone::Clone::clone);

        assert_eq!(accumulated, stepped);
        assert_eq!(trainer.metrics.steps, 1);
    }

    #[test]
    #[should_panic(expected = "Loss function must be set")]
    fn test_train_step_without_loss() {
        let params = vec![Tensor::zeros(10, true)];
        let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
        let config = TrainConfig::default();

        let mut trainer = Trainer::new(params, Box::new(optimizer), config);

        let batch = Batch::new(Tensor::zeros(10, false), Tensor::zeros(10, false));

        trainer.train_step(&batch, std::clone::Clone::clone);
    }

    #[test]
    #[should_panic(expected = "Loss function must be set")]
    fn test_accumulate_gradients_without_loss() {
        let params = vec![Tensor::zeros(10, true)];
        let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
        let config = TrainConfig::default();

        let mut trainer = Trainer::new(params, Box::new(optimizer), config);

        let batch = Batch::new(Tensor::zeros(10, false), Tensor::zeros(10, false));

        trainer.accumulate_gradients(&batch, std::clone::Clone::clone);
    }
}