Skip to main content

any_gpu/
train.rs

1// Unlicense — cochranblock.org
2// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
3//
4// Training loop: forward + backward + optimizer step.
5// One function call, not a framework.
6
7use crate::autograd::{Tape, TensorId};
8use crate::device::GpuDevice;
9use crate::optim::AdamW;
10use anyhow::Result;
11
12/// Training step result.
13pub struct StepResult {
14    pub loss: f32,
15    pub step: u32,
16}
17
18/// Train an MLP (or any differentiable graph) for one step.
19/// `forward_fn` builds the computation graph on the tape and returns (loss_id, param_ids).
20/// The training loop runs backward, extracts gradients, and updates params.
21pub fn train_step(
22    dev: &GpuDevice,
23    opt: &mut AdamW,
24    step_num: u32,
25    forward_fn: impl FnOnce(&mut Tape) -> Result<(TensorId, Vec<TensorId>)>,
26) -> Result<StepResult> {
27    let mut tape = Tape::new(dev);
28
29    // Forward: user builds the graph
30    let (loss_id, param_ids) = forward_fn(&mut tape)?;
31
32    // Read loss value
33    let loss_val = tape.read(loss_id)?[0];
34
35    // Backward
36    tape.backward(loss_id)?;
37
38    // Extract param buffers and grad buffers for optimizer
39    let mut params: Vec<_> = param_ids.iter().map(|id| {
40        // We need to extract the buffer from the tape.
41        // For now, read grad and param, re-upload for optimizer.
42        // This is inefficient (CPU roundtrip) but correct. Pipeline caching will fix later.
43        tape.read(*id).unwrap()
44    }).collect();
45
46    let grads: Vec<_> = param_ids.iter().map(|id| {
47        tape.read_grad(*id).unwrap().unwrap_or_else(|| vec![0.0; params[0].len()])
48    }).collect();
49
50    // Upload params as mutable GPU buffers and grads as read-only
51    let mut param_bufs: Vec<_> = params.iter().map(|p| dev.upload(p)).collect();
52    let grad_bufs: Vec<_> = grads.iter().map(|g| dev.upload(g)).collect();
53
54    // Optimizer step (in-place update on GPU)
55    opt.step(dev, &mut param_bufs, &grad_bufs)?;
56
57    // Read updated params back (caller can use these for next step)
58    for (i, buf) in param_bufs.iter().enumerate() {
59        params[i] = dev.read(buf)?;
60    }
61
62    Ok(StepResult { loss: loss_val, step: step_num })
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// Fits `y = 2x + 1` on `x = [1, 2, 3]` with MSE loss.
    ///
    /// Gradients come from the autograd tape, but the update itself is plain SGD done
    /// on the CPU. This therefore checks the forward/backward path that
    /// [`train_step`] relies on, not the `AdamW` update.
    #[test]
    fn test_train_step_linear_regression() {
        const STEPS: u32 = 50;
        const LR: f32 = 0.01;
        let x: [f32; 3] = [1.0, 2.0, 3.0];
        let y: [f32; 3] = [3.0, 5.0, 7.0];

        // Scalar parameters, both starting at zero.
        let mut w = 0.0f32;
        let mut b = 0.0f32;
        let mut last_loss = f32::MAX;

        for step in 0..STEPS {
            let mut tape = Tape::new(dev());
            let x_id = tape.leaf(&x);
            let y_id = tape.leaf(&y);
            // The ops are element-wise, so w and b are broadcast by hand into
            // 3-element leaves matching x.
            let w3_id = tape.leaf(&[w; 3]);
            let b3_id = tape.leaf(&[b; 3]);

            // Forward: pred = x * w + b
            let xw = tape.mul(x_id, w3_id).unwrap();
            let pred = tape.add(xw, b3_id).unwrap();
            let loss = tape.mse_loss(pred, y_id).unwrap();

            let loss_val = tape.read(loss).unwrap()[0];
            tape.backward(loss).unwrap();

            // The gradient of a broadcast scalar is the sum of its copies' gradients.
            let gw: f32 = tape
                .read_grad(w3_id)
                .unwrap()
                .expect("w3 contributes to the loss, so it must have a gradient")
                .iter()
                .sum();
            let gb: f32 = tape
                .read_grad(b3_id)
                .unwrap()
                .expect("b3 contributes to the loss, so it must have a gradient")
                .iter()
                .sum();

            w -= LR * gw;
            b -= LR * gb;

            if step > 0 && step % 10 == 0 {
                assert!(
                    loss_val < last_loss,
                    "loss should decrease: step {step} loss {loss_val} >= prev {last_loss}"
                );
            }
            last_loss = loss_val;
        }

        // After STEPS updates, w should approach 2.0 and b should approach 1.0.
        assert!((w - 2.0).abs() < 0.5, "w should be near 2.0, got {w}");
        assert!((b - 1.0).abs() < 0.5, "b should be near 1.0, got {b}");
    }
}