1use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
9#[repr(C)]
10#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
11struct AdamWParams {
12 n: u32,
13 lr: f32,
14 beta1: f32,
15 beta2: f32,
16 eps: f32,
17 weight_decay: f32,
18 beta1_t: f32, beta2_t: f32, }
21
/// WGSL compute shader that applies one AdamW update per buffer element.
///
/// Bindings (group 0):
/// - 0: uniform `P` — same fields, in the same order, as `AdamWParams`
/// - 1: parameters (read/write)
/// - 2: gradients (read-only)
/// - 3: first-moment buffer `m` (read/write)
/// - 4: second-moment buffer `v` (read/write)
///
/// The workgroup size is 256. Each invocation handles a single element `idx`
/// and returns early when `idx >= p.n`. The flat index adds
/// `gid.y * 65535 * 256` to `gid.x`; this presumably matches how
/// `crate::ops::dispatch_1d` spills large dispatches into the y dimension —
/// TODO confirm against that helper.
const SHADER_ADAMW: &str = "
struct P {
 n: u32, lr: f32, beta1: f32, beta2: f32,
 eps: f32, weight_decay: f32, beta1_t: f32, beta2_t: f32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read_write> param: array<f32>;
@group(0) @binding(2) var<storage, read> grad: array<f32>;
@group(0) @binding(3) var<storage, read_write> m: array<f32>;
@group(0) @binding(4) var<storage, read_write> v: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= p.n { return; }

 let g = grad[idx];

 // Update biased first moment
 m[idx] = p.beta1 * m[idx] + (1.0 - p.beta1) * g;
 // Update biased second moment
 v[idx] = p.beta2 * v[idx] + (1.0 - p.beta2) * g * g;

 // Bias correction
 let m_hat = m[idx] / (1.0 - p.beta1_t);
 let v_hat = v[idx] / (1.0 - p.beta2_t);

 // Weight decay + Adam update
 param[idx] = param[idx] * (1.0 - p.lr * p.weight_decay) - p.lr * m_hat / (sqrt(v_hat) + p.eps);
}
";
52
/// AdamW optimizer that keeps its moment estimates in GPU buffers.
///
/// Hyperparameters are public fields and are read on every call to `step`,
/// so they can be adjusted between steps.
pub struct AdamW {
    /// Learning rate.
    pub lr: f32,
    /// Decay rate of the first-moment estimate.
    pub beta1: f32,
    /// Decay rate of the second-moment estimate.
    pub beta2: f32,
    /// Term added to the denominator of the update.
    pub eps: f32,
    /// Decoupled weight-decay coefficient.
    pub weight_decay: f32,
    /// Number of `step` calls made so far; drives the bias-correction terms.
    step: u32,
    /// `(m, v)` moment buffers, one pair per index of the `params` slice
    /// passed to `step`.
    states: Vec<(GpuBuffer, GpuBuffer)>,
}
64
65impl AdamW {
66 pub fn new(lr: f32) -> Self {
67 Self {
68 lr,
69 beta1: 0.9,
70 beta2: 0.999,
71 eps: 1e-8,
72 weight_decay: 0.01,
73 step: 0,
74 states: Vec::new(),
75 }
76 }
77
78 pub fn step(&mut self, dev: &GpuDevice, params: &mut [GpuBuffer], grads: &[GpuBuffer]) -> Result<()> {
81 ensure!(params.len() == grads.len(), "params/grads length mismatch");
82
83 self.step += 1;
84 let beta1_t = self.beta1.powi(self.step as i32);
85 let beta2_t = self.beta2.powi(self.step as i32);
86
87 while self.states.len() < params.len() {
89 let n = params[self.states.len()].len;
90 let m = dev.upload(&vec![0.0f32; n]);
91 let v = dev.upload(&vec![0.0f32; n]);
92 self.states.push((m, v));
93 }
94
95 for (i, (param, grad)) in params.iter().zip(grads.iter()).enumerate() {
96 ensure!(param.len == grad.len, "param/grad size mismatch at index {i}");
97 let n = param.len as u32;
98 let (m, v) = &self.states[i];
99
100 let p = AdamWParams {
101 n,
102 lr: self.lr,
103 beta1: self.beta1,
104 beta2: self.beta2,
105 eps: self.eps,
106 weight_decay: self.weight_decay,
107 beta1_t,
108 beta2_t,
109 };
110
111 let params_buf = dev.upload_uniform(&p);
113 let pipeline = dev.pipeline(SHADER_ADAMW, Some("adamw"));
114 let bind_group = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
115 label: None,
116 layout: &pipeline.get_bind_group_layout(0),
117 entries: &[
118 wgpu::BindGroupEntry { binding: 0, resource: params_buf.as_entire_binding() },
119 wgpu::BindGroupEntry { binding: 1, resource: param.buffer.as_entire_binding() },
120 wgpu::BindGroupEntry { binding: 2, resource: grad.buffer.as_entire_binding() },
121 wgpu::BindGroupEntry { binding: 3, resource: m.buffer.as_entire_binding() },
122 wgpu::BindGroupEntry { binding: 4, resource: v.buffer.as_entire_binding() },
123 ],
124 });
125 let (wx, wy, wz) = crate::ops::dispatch_1d(n);
126 let mut encoder = dev.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
127 {
128 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
129 label: Some("adamw"),
130 timestamp_writes: None,
131 });
132 pass.set_pipeline(&pipeline);
133 pass.set_bind_group(0, &bind_group, &[]);
134 pass.dispatch_workgroups(wx, wy, wz);
135 }
136 dev.queue.submit(Some(encoder.finish()));
137 }
138 Ok(())
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Shared GPU device used by every test in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// A single step with positive gradients moves every parameter down.
    #[test]
    fn test_adamw_basic() {
        let mut param = dev().upload(&[1.0, 2.0, 3.0]);
        let grad = dev().upload(&[0.1, 0.2, 0.3]);

        let mut opt = AdamW::new(0.01);
        opt.weight_decay = 0.0;
        opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();

        let result = dev().read(&param).unwrap();
        assert!(result[0] < 1.0, "param[0] should decrease, got {}", result[0]);
        assert!(result[1] < 2.0, "param[1] should decrease, got {}", result[1]);
        assert!(result[2] < 3.0, "param[2] should decrease, got {}", result[2]);
        // On the first step bias correction gives m_hat = g and v_hat = g^2,
        // so each element moves by lr * g / (|g| + eps), i.e. roughly lr.
        assert_approx(&result, &[0.99, 1.99, 2.99], 1e-3);
    }

    /// Repeated steps with a constant gradient keep decreasing the parameter.
    #[test]
    fn test_adamw_multiple_steps() {
        let mut param = dev().upload(&[10.0]);
        let grad = dev().upload(&[1.0]);
        let mut opt = AdamW::new(0.1);
        opt.weight_decay = 0.0;

        for _ in 0..10 {
            opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();
        }

        let result = dev().read(&param).unwrap();
        assert!(result[0] < 10.0, "after 10 steps param should decrease, got {}", result[0]);
        // With a constant gradient every bias-corrected step has size ~lr:
        // 10.0 - 10 * 0.1 = 9.0.
        assert_approx(&result, &[9.0], 1e-2);
    }

    /// With a zero gradient only the decoupled weight decay acts:
    /// 10 * (1 - 0.01 * 0.1) = 9.99.
    #[test]
    fn test_adamw_weight_decay() {
        let mut param = dev().upload(&[10.0]);
        let grad = dev().upload(&[0.0]);
        let mut opt = AdamW::new(0.01);
        opt.weight_decay = 0.1;

        opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();

        let result = dev().read(&param).unwrap();
        assert_approx(&result, &[9.99], 1e-3);
    }

    /// Passing a different number of params and grads is rejected.
    #[test]
    fn test_adamw_params_grads_length_mismatch() {
        let p1 = dev().upload(&[1.0]);
        let p2 = dev().upload(&[2.0]);
        let g1 = dev().upload(&[0.1]);
        let mut opt = AdamW::new(0.01);
        assert!(opt.step(dev(), &mut [p1, p2], &[g1]).is_err());
    }

    /// A gradient whose element count differs from its parameter is rejected.
    #[test]
    fn test_adamw_param_grad_size_mismatch() {
        let mut param = dev().upload(&[1.0, 2.0, 3.0]);
        let grad = dev().upload(&[0.1, 0.2]);
        let mut opt = AdamW::new(0.01);
        assert!(opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).is_err());
    }

    /// A negative gradient moves the parameter up.
    #[test]
    fn test_adamw_negative_gradient() {
        let mut param = dev().upload(&[5.0]);
        let grad = dev().upload(&[-1.0]);
        let mut opt = AdamW::new(0.01);
        opt.weight_decay = 0.0;
        opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();
        let result = dev().read(&param).unwrap();
        assert!(result[0] > 5.0, "negative grad should increase param, got {}", result[0]);
    }

    /// With `lr = 0` and no weight decay the parameter is left untouched.
    #[test]
    fn test_adamw_lr_zero() {
        let mut param = dev().upload(&[10.0]);
        let grad = dev().upload(&[100.0]);
        let mut opt = AdamW::new(0.0);
        opt.weight_decay = 0.0;
        opt.step(dev(), std::slice::from_mut(&mut param), std::slice::from_ref(&grad)).unwrap();
        let result = dev().read(&param).unwrap();
        assert_approx(&result, &[10.0], 1e-5);
    }

    /// `AdamW::new` stores the given learning rate and the documented defaults.
    #[test]
    fn test_adamw_defaults() {
        let opt = AdamW::new(0.001);
        assert_eq!(opt.lr, 0.001);
        assert_eq!(opt.beta1, 0.9);
        assert_eq!(opt.beta2, 0.999);
        assert_eq!(opt.eps, 1e-8);
        assert_eq!(opt.weight_decay, 0.01);
    }
}