1use crate::autograd::{Tape, TensorId};
8use crate::device::GpuDevice;
9use crate::optim::AdamW;
10use anyhow::Result;
11
/// Outcome of a single [`train_step`] call.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StepResult {
    /// First element of the loss tensor, read from the tape before `backward` ran.
    pub loss: f32,
    /// The `step_num` the caller passed to [`train_step`], echoed back unchanged.
    pub step: u32,
}
17
18pub fn train_step(
22 dev: &GpuDevice,
23 opt: &mut AdamW,
24 step_num: u32,
25 forward_fn: impl FnOnce(&mut Tape) -> Result<(TensorId, Vec<TensorId>)>,
26) -> Result<StepResult> {
27 let mut tape = Tape::new(dev);
28
29 let (loss_id, param_ids) = forward_fn(&mut tape)?;
31
32 let loss_val = tape.read(loss_id)?[0];
34
35 tape.backward(loss_id)?;
37
38 let mut params: Vec<_> = param_ids.iter().map(|id| {
40 tape.read(*id).unwrap()
44 }).collect();
45
46 let grads: Vec<_> = param_ids.iter().map(|id| {
47 tape.read_grad(*id).unwrap().unwrap_or_else(|| vec![0.0; params[0].len()])
48 }).collect();
49
50 let mut param_bufs: Vec<_> = params.iter().map(|p| dev.upload(p)).collect();
52 let grad_bufs: Vec<_> = grads.iter().map(|g| dev.upload(g)).collect();
53
54 opt.step(dev, &mut param_bufs, &grad_bufs)?;
56
57 for (i, buf) in param_bufs.iter().enumerate() {
59 params[i] = dev.read(buf)?;
60 }
61
62 Ok(StepResult { loss: loss_val, step: step_num })
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// Fits `y = 2x + 1` on three points with hand-rolled gradient descent, using
    /// the [`Tape`] for the forward pass (`mul`, `add`, `mse_loss`), `backward`,
    /// and gradient readback.
    ///
    /// NOTE(review): despite its name this test does not call [`train_step`] or
    /// use an optimizer; the parameter update is applied manually below.
    #[test]
    fn test_train_step_linear_regression() {
        const LR: f32 = 0.01;
        let x_data = [1.0f32, 2.0, 3.0];
        let y_data = [3.0f32, 5.0, 7.0];

        let mut w = 0.0f32;
        let mut b = 0.0f32;

        let mut last_loss = f32::MAX;
        for step in 0..50 {
            let mut tape = Tape::new(dev());
            let x_id = tape.leaf(&x_data);
            let y_id = tape.leaf(&y_data);

            // The tape ops used here are applied to equal-length tensors, so the
            // scalars are broadcast to the data length by hand.
            let w3_id = tape.leaf(&[w; 3]);
            let b3_id = tape.leaf(&[b; 3]);

            let xw = tape.mul(x_id, w3_id).unwrap();
            let pred = tape.add(xw, b3_id).unwrap();
            let loss = tape.mse_loss(pred, y_id).unwrap();

            let loss_val = tape.read(loss).unwrap()[0];
            tape.backward(loss).unwrap();

            // Because w and b were broadcast, the scalar gradient is the sum of
            // the per-element gradients of the broadcast copies.
            let gw_sum: f32 = tape
                .read_grad(w3_id)
                .unwrap()
                .expect("w3 should have a gradient")
                .iter()
                .sum();
            let gb_sum: f32 = tape
                .read_grad(b3_id)
                .unwrap()
                .expect("b3 should have a gradient")
                .iter()
                .sum();

            w -= LR * gw_sum;
            b -= LR * gb_sum;

            if step % 10 == 0 {
                assert!(loss_val < last_loss || step == 0, "loss should decrease: step {step} loss {loss_val} >= prev {last_loss}");
            }
            last_loss = loss_val;
        }

        assert!((w - 2.0).abs() < 0.5, "w should be near 2.0, got {}", w);
        assert!((b - 1.0).abs() < 0.5, "b should be near 1.0, got {}", b);
    }
}