Skip to main content

entrenar/train/trainer/
step.rs

1//! Training step operations
2
3use super::core::Trainer;
4use crate::optim::clip_grad_norm;
5use crate::train::Batch;
6use crate::Tensor;
7use provable_contracts_macros::ensures;
8
9impl Trainer {
10    /// Perform a single training step
11    ///
12    /// # Arguments
13    ///
14    /// * `batch` - Training batch with inputs and targets
15    /// * `forward_fn` - Closure that computes predictions from inputs
16    ///
17    /// # Returns
18    ///
19    /// Scalar loss value for this batch
20    ///
21    /// # Example
22    ///
23    /// ```no_run
24    /// # use entrenar::train::{Trainer, Batch};
25    /// # use entrenar::Tensor;
26    /// # let mut trainer: Trainer = todo!();
27    /// # let batch: Batch = todo!();
28    /// let loss = trainer.train_step(&batch, |inputs| {
29    ///     // Forward pass: compute predictions
30    ///     inputs.clone() // Simplified example
31    /// });
32    /// ```
33    #[ensures(ret.is_finite())]
34    pub fn train_step<F>(&mut self, batch: &Batch, forward_fn: F) -> f32
35    where
36        F: FnOnce(&Tensor) -> Tensor,
37    {
38        assert!(self.loss_fn.is_some(), "Loss function must be set before training");
39
40        // Zero gradients
41        self.optimizer.zero_grad(&mut self.params);
42
43        // Forward pass
44        let predictions = forward_fn(&batch.inputs);
45
46        // Compute loss
47        let loss = self
48            .loss_fn
49            .as_ref()
50            .expect("loss function must be set before training")
51            .forward(&predictions, &batch.targets);
52
53        let loss_val = loss.data()[0];
54
55        // Backward pass
56        if let Some(backward_op) = loss.backward_op() {
57            backward_op.backward();
58        }
59
60        // Gradient clipping
61        if let Some(max_norm) = self.config.max_grad_norm {
62            clip_grad_norm(&mut self.params, max_norm);
63        }
64
65        // Optimizer step
66        self.optimizer.step(&mut self.params);
67
68        // Update metrics
69        self.metrics.increment_step();
70
71        loss_val
72    }
73
74    /// Perform forward and backward pass without optimizer step (for gradient accumulation)
75    ///
76    /// This is used internally for gradient accumulation. Gradients accumulate
77    /// across calls until zero_grad is called.
78    pub(crate) fn accumulate_gradients<F>(&mut self, batch: &Batch, forward_fn: F) -> f32
79    where
80        F: FnOnce(&Tensor) -> Tensor,
81    {
82        assert!(self.loss_fn.is_some(), "Loss function must be set before training");
83
84        // Forward pass
85        let predictions = forward_fn(&batch.inputs);
86
87        // Compute loss
88        let loss = self
89            .loss_fn
90            .as_ref()
91            .expect("loss function must be set before training")
92            .forward(&predictions, &batch.targets);
93
94        let loss_val = loss.data()[0];
95
96        // Backward pass (gradients accumulate)
97        if let Some(backward_op) = loss.backward_op() {
98            backward_op.backward();
99        }
100
101        loss_val
102    }
103}
104
#[cfg(test)]
mod tests {
    use crate::optim::Adam;
    use crate::train::{Batch, MSELoss, TrainConfig, Trainer};
    use crate::Tensor;

    #[test]
    fn test_train_step() {
        let params = vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], true)];
        let optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
        let config = TrainConfig::default();

        let mut trainer = Trainer::new(params, Box::new(optimizer), config);
        trainer.set_loss(Box::new(MSELoss));

        // Create a simple batch
        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let targets = Tensor::from_vec(vec![2.0, 3.0, 4.0], false);
        let batch = Batch::new(inputs, targets);

        // Train step (identity function)
        let loss = trainer.train_step(&batch, std::clone::Clone::clone);

        // Loss should be positive (predictions != targets)
        assert!(loss > 0.0);
        assert!(loss.is_finite());
        assert_eq!(trainer.metrics.steps, 1);
    }

    #[test]
    #[should_panic(expected = "Loss function must be set")]
    fn test_train_step_without_loss() {
        let params = vec![Tensor::zeros(10, true)];
        let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
        let config = TrainConfig::default();

        let mut trainer = Trainer::new(params, Box::new(optimizer), config);

        let batch = Batch::new(Tensor::zeros(10, false), Tensor::zeros(10, false));

        trainer.train_step(&batch, std::clone::Clone::clone);
    }

    #[test]
    fn test_accumulate_gradients_does_not_step() {
        let params = vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], true)];
        let optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
        let config = TrainConfig::default();

        let mut trainer = Trainer::new(params, Box::new(optimizer), config);
        trainer.set_loss(Box::new(MSELoss));

        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let targets = Tensor::from_vec(vec![2.0, 3.0, 4.0], false);
        let batch = Batch::new(inputs, targets);

        // Identity forward pass: predictions differ from targets, so the loss is non-zero.
        let loss = trainer.accumulate_gradients(&batch, std::clone::Clone::clone);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
        // Accumulation performs no optimizer step, so the step counter must not move.
        assert_eq!(trainer.metrics.steps, 0);
    }

    #[test]
    #[should_panic(expected = "Loss function must be set")]
    fn test_accumulate_gradients_without_loss() {
        let params = vec![Tensor::zeros(10, true)];
        let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
        let config = TrainConfig::default();

        let mut trainer = Trainer::new(params, Box::new(optimizer), config);

        let batch = Batch::new(Tensor::zeros(10, false), Tensor::zeros(10, false));

        trainer.accumulate_gradients(&batch, std::clone::Clone::clone);
    }
}