Skip to main content

any_gpu/
optim.rs

1// Unlicense — cochranblock.org
2// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
3//
4// AdamW optimizer. Single WGSL shader: weight update step.
5
6use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
9#[repr(C)]
10#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
11struct AdamWParams {
12    n: u32,
13    lr: f32,
14    beta1: f32,
15    beta2: f32,
16    eps: f32,
17    weight_decay: f32,
18    beta1_t: f32, // beta1^t (for bias correction)
19    beta2_t: f32, // beta2^t
20}
21
/// WGSL compute kernel performing one AdamW update, one invocation per element.
///
/// Bindings (group 0), which must match the bind group built in `AdamW::step`:
/// 0 = uniform params (`AdamWParams`), 1 = params (read_write), 2 = grads (read),
/// 3 = first moment `m` (read_write), 4 = second moment `v` (read_write).
///
/// The flat index folds `gid.y` in with a stride of `65535 * 256`. This presumably
/// matches how `crate::ops::dispatch_1d` splits large dispatches across x/y —
/// NOTE(review): confirm against `dispatch_1d`.
///
/// Each storage element is read once and written once: the new moment values
/// are kept in locals instead of being re-read from the buffers after the store.
const SHADER_ADAMW: &str = "
struct P {
    n: u32, lr: f32, beta1: f32, beta2: f32,
    eps: f32, weight_decay: f32, beta1_t: f32, beta2_t: f32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read_write> param: array<f32>;
@group(0) @binding(2) var<storage, read> grad: array<f32>;
@group(0) @binding(3) var<storage, read_write> m: array<f32>;
@group(0) @binding(4) var<storage, read_write> v: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= p.n { return; }

    let g = grad[idx];

    // Update biased first and second moment estimates
    let m_t = p.beta1 * m[idx] + (1.0 - p.beta1) * g;
    let v_t = p.beta2 * v[idx] + (1.0 - p.beta2) * g * g;
    m[idx] = m_t;
    v[idx] = v_t;

    // Bias correction
    let m_hat = m_t / (1.0 - p.beta1_t);
    let v_hat = v_t / (1.0 - p.beta2_t);

    // Decoupled weight decay + Adam update
    param[idx] = param[idx] * (1.0 - p.lr * p.weight_decay) - p.lr * m_hat / (sqrt(v_hat) + p.eps);
}
";
52
53/// AdamW optimizer state for a single parameter group.
54pub struct AdamW {
55    pub lr: f32,
56    pub beta1: f32,
57    pub beta2: f32,
58    pub eps: f32,
59    pub weight_decay: f32,
60    step: u32,
61    // Per-parameter state: (first moment, second moment)
62    states: Vec<(GpuBuffer, GpuBuffer)>,
63}
64
65impl AdamW {
66    pub fn new(lr: f32) -> Self {
67        Self {
68            lr,
69            beta1: 0.9,
70            beta2: 0.999,
71            eps: 1e-8,
72            weight_decay: 0.01,
73            step: 0,
74            states: Vec::new(),
75        }
76    }
77
78    /// Run one optimizer step. Updates params in-place using their gradients.
79    /// `params` and `grads` must be the same length, with matching buffer sizes.
80    pub fn step(&mut self, dev: &GpuDevice, params: &mut [GpuBuffer], grads: &[GpuBuffer]) -> Result<()> {
81        ensure!(params.len() == grads.len(), "params/grads length mismatch");
82
83        self.step += 1;
84        let beta1_t = self.beta1.powi(self.step as i32);
85        let beta2_t = self.beta2.powi(self.step as i32);
86
87        // Lazy init state buffers
88        while self.states.len() < params.len() {
89            let n = params[self.states.len()].len;
90            let m = dev.upload(&vec![0.0f32; n]);
91            let v = dev.upload(&vec![0.0f32; n]);
92            self.states.push((m, v));
93        }
94
95        for (i, (param, grad)) in params.iter().zip(grads.iter()).enumerate() {
96            ensure!(param.len == grad.len, "param/grad size mismatch at index {i}");
97            let n = param.len as u32;
98            let (m, v) = &self.states[i];
99
100            let p = AdamWParams {
101                n,
102                lr: self.lr,
103                beta1: self.beta1,
104                beta2: self.beta2,
105                eps: self.eps,
106                weight_decay: self.weight_decay,
107                beta1_t,
108                beta2_t,
109            };
110
111            // AdamW needs read_write on param, m, v. Use raw dispatch with cached pipeline.
112            let params_buf = dev.upload_uniform(&p);
113            let pipeline = dev.pipeline(SHADER_ADAMW, Some("adamw"));
114            let bind_group = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
115                label: None,
116                layout: &pipeline.get_bind_group_layout(0),
117                entries: &[
118                    wgpu::BindGroupEntry { binding: 0, resource: params_buf.as_entire_binding() },
119                    wgpu::BindGroupEntry { binding: 1, resource: param.buffer.as_entire_binding() },
120                    wgpu::BindGroupEntry { binding: 2, resource: grad.buffer.as_entire_binding() },
121                    wgpu::BindGroupEntry { binding: 3, resource: m.buffer.as_entire_binding() },
122                    wgpu::BindGroupEntry { binding: 4, resource: v.buffer.as_entire_binding() },
123                ],
124            });
125            let (wx, wy, wz) = crate::ops::dispatch_1d(n);
126            let mut encoder = dev.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
127            {
128                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
129                    label: Some("adamw"),
130                    timestamp_writes: None,
131                });
132                pass.set_pipeline(&pipeline);
133                pass.set_bind_group(0, &bind_group, &[]);
134                pass.dispatch_workgroups(wx, wy, wz);
135            }
136            dev.queue.submit(Some(encoder.finish()));
137        }
138        Ok(())
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Shared GPU device used by every test in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    #[test]
    fn test_adamw_basic() {
        let mut param = dev().upload(&[1.0, 2.0, 3.0]);
        let grad = dev().upload(&[0.1, 0.2, 0.3]);

        let mut opt = AdamW::new(0.01);
        opt.weight_decay = 0.0; // disable weight decay for predictable test
        opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();

        // After 1 step with no weight decay:
        // m = 0.1 * grad, v = 0.001 * grad^2
        // m_hat = m / (1 - 0.9) = grad, v_hat = v / (1 - 0.999) = grad^2
        // update = lr * grad / (sqrt(grad^2) + eps) ≈ lr * sign(grad)
        // For positive grads: param -= 0.01 (eps is negligible at these magnitudes)
        let result = dev().read(&param).unwrap();
        assert_approx(&result, &[0.99, 1.99, 2.99], 1e-4);
    }

    #[test]
    fn test_adamw_multiple_steps() {
        let mut param = dev().upload(&[10.0]);
        let grad = dev().upload(&[1.0]); // constant gradient pushing param down

        let mut opt = AdamW::new(0.1);
        opt.weight_decay = 0.0;

        for _ in 0..10 {
            opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();
        }

        let result = dev().read(&param).unwrap();
        // With a constant gradient, bias correction makes m_hat = 1 and v_hat = 1
        // on every step, so each step moves the param by exactly lr = 0.1.
        assert_approx(&result, &[9.0], 1e-3);
    }

    #[test]
    fn test_adamw_weight_decay() {
        let mut param = dev().upload(&[10.0]);
        let grad = dev().upload(&[0.0]); // zero gradient, only weight decay

        let mut opt = AdamW::new(0.01);
        opt.weight_decay = 0.1;

        opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();

        let result = dev().read(&param).unwrap();
        // With zero grad, only weight decay: param *= (1 - lr * wd) = 10 * (1 - 0.001) = 9.99
        assert_approx(&result, &[9.99], 1e-3);
    }

    #[test]
    fn test_adamw_multiple_param_groups() {
        let p1 = dev().upload(&[1.0]);
        let p2 = dev().upload(&[2.0, 3.0]);
        let g1 = dev().upload(&[0.5]);
        let g2 = dev().upload(&[0.5, -0.5]);
        let mut opt = AdamW::new(0.01);
        opt.weight_decay = 0.0;

        let mut params = [p1, p2];
        opt.step(dev(), &mut params, &[g1, g2]).unwrap();

        // Each element moves by ≈ lr against the sign of its gradient.
        assert_approx(&dev().read(&params[0]).unwrap(), &[0.99], 1e-4);
        assert_approx(&dev().read(&params[1]).unwrap(), &[1.99, 3.01], 1e-4);
    }

    #[test]
    fn test_adamw_params_grads_length_mismatch() {
        let p1 = dev().upload(&[1.0]);
        let p2 = dev().upload(&[2.0]);
        let g1 = dev().upload(&[0.1]);
        let mut opt = AdamW::new(0.01);
        // 2 params, 1 grad -> error
        assert!(opt.step(dev(), &mut [p1, p2], &[g1]).is_err());
    }

    #[test]
    fn test_adamw_param_grad_size_mismatch() {
        let mut param = dev().upload(&[1.0, 2.0, 3.0]); // 3 elements
        let grad = dev().upload(&[0.1, 0.2]); // 2 elements
        let mut opt = AdamW::new(0.01);
        assert!(opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).is_err());
    }

    #[test]
    fn test_adamw_negative_gradient() {
        let mut param = dev().upload(&[5.0]);
        let grad = dev().upload(&[-1.0]); // negative grad -> param increases
        let mut opt = AdamW::new(0.01);
        opt.weight_decay = 0.0;
        opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();
        let result = dev().read(&param).unwrap();
        assert!(result[0] > 5.0, "negative grad should increase param, got {}", result[0]);
    }

    #[test]
    fn test_adamw_lr_zero() {
        let mut param = dev().upload(&[10.0]);
        let grad = dev().upload(&[100.0]);
        let mut opt = AdamW::new(0.0); // lr=0
        opt.weight_decay = 0.0;
        opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();
        let result = dev().read(&param).unwrap();
        assert_approx(&result, &[10.0], 1e-5); // no update
    }

    #[test]
    fn test_adamw_defaults() {
        let opt = AdamW::new(0.001);
        assert_eq!(opt.lr, 0.001);
        assert_eq!(opt.beta1, 0.9);
        assert_eq!(opt.beta2, 0.999);
        assert_eq!(opt.eps, 1e-8);
        assert_eq!(opt.weight_decay, 0.01);
    }
}