1use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
/// WGSL compute shader for element-wise addition: `out[i] = a[i] + b[i]`.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read; `_p0`..`_p2`
/// are unused by the shader), 1 and 2 are the read-only `f32` inputs `a` and `b`,
/// and 3 is the read-write `f32` output. The entry point uses a workgroup size of
/// 256 and returns early for any invocation whose flat index is `>= params.n`.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`, so each step in `y` skips
/// 65535 workgroups' worth of invocations.
/// NOTE(review): presumably this mirrors the grid produced by `dispatch_1d` — confirm.
///
/// Passed to `binary_op` by `GpuDevice::add`.
const SHADER_ADD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 out[idx] = a[idx] + b[idx];
}
";
22
/// WGSL compute shader for element-wise subtraction: `out[i] = a[i] - b[i]`.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 and 2
/// are the read-only `f32` inputs `a` and `b`, and 3 is the read-write `f32`
/// output. Workgroup size is 256; invocations whose flat index is `>= params.n`
/// return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `binary_op` by `GpuDevice::sub`.
const SHADER_SUB: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 out[idx] = a[idx] - b[idx];
}
";
36
/// WGSL compute shader for element-wise multiplication: `out[i] = a[i] * b[i]`.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 and 2
/// are the read-only `f32` inputs `a` and `b`, and 3 is the read-write `f32`
/// output. Workgroup size is 256; invocations whose flat index is `>= params.n`
/// return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `binary_op` by `GpuDevice::mul`.
const SHADER_MUL: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 out[idx] = a[idx] * b[idx];
}
";
50
/// WGSL compute shader for ReLU: `out[i] = max(a[i], 0.0)`.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 is the
/// read-only `f32` input `a`, and 2 is the read-write `f32` output. Workgroup size
/// is 256; invocations whose flat index is `>= params.n` return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `unary_op` by `GpuDevice::relu`.
const SHADER_RELU: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 out[idx] = max(a[idx], 0.0);
}
";
63
/// WGSL compute shader for the logistic sigmoid: `out[i] = 1 / (1 + exp(-a[i]))`.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 is the
/// read-only `f32` input `a`, and 2 is the read-write `f32` output. Workgroup size
/// is 256; invocations whose flat index is `>= params.n` return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `unary_op` by `GpuDevice::sigmoid`.
const SHADER_SIGMOID: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 out[idx] = 1.0 / (1.0 + exp(-a[idx]));
}
";
76
/// WGSL compute shader for Swish: `out[i] = x / (1 + exp(-x))` with `x = a[i]`,
/// which is algebraically `x * sigmoid(x)`.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 is the
/// read-only `f32` input `a`, and 2 is the read-write `f32` output. Workgroup size
/// is 256; invocations whose flat index is `>= params.n` return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `unary_op` by `GpuDevice::swish`.
const SHADER_SWISH: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 let x = a[idx];
 out[idx] = x / (1.0 + exp(-x));
}
";
90
/// WGSL compute shader for hyperbolic tangent: `out[i] = tanh(a[i])`.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 is the
/// read-only `f32` input `a`, and 2 is the read-write `f32` output. Workgroup size
/// is 256; invocations whose flat index is `>= params.n` return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `unary_op` by `GpuDevice::tanh_act`.
const SHADER_TANH: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 out[idx] = tanh(a[idx]);
}
";
103
/// Host-side parameter block handed to `dispatch_shader` for `SHADER_SCALE`.
///
/// The field order follows the WGSL `Params` struct declared in that shader
/// (`n`, `scale`, then two unused `u32` slots). `#[repr(C)]` keeps the fields in
/// declaration order, and the `bytemuck` derives allow the value to be viewed as
/// plain bytes.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct ScaleParams {
    /// Element count; the shader skips any index `>= n`.
    n: u32,
    /// Factor each input element is multiplied by.
    scale: f32,
    /// Filler matching the shader's `_p0` / `_p1` fields; `GpuDevice::scale` sets it to zero.
    _pad: [u32; 2],
}
111
/// WGSL compute shader that multiplies every element by a scalar:
/// `out[i] = a[i] * params.scale`.
///
/// Group 0 bindings: 0 is the uniform `Params` block (`n` and `scale` are read;
/// `_p0` and `_p1` are unused), 1 is the read-only `f32` input `a`, and 2 is the
/// read-write `f32` output. The Rust-side counterpart of `Params` is `ScaleParams`.
/// Workgroup size is 256; invocations whose flat index is `>= params.n` return
/// without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
const SHADER_SCALE: &str = "
struct Params { n: u32, scale: f32, _p0: u32, _p1: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 out[idx] = a[idx] * params.scale;
}
";
124
/// WGSL compute shader for the ReLU gradient: `out[i]` is `grad_out[i]` where
/// `input[i] > 0.0`, and `0.0` otherwise (so an input of exactly zero yields zero).
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 is the
/// read-only upstream gradient `grad_out`, 2 is the read-only forward `input`, and
/// 3 is the read-write `f32` output. Workgroup size is 256; invocations whose flat
/// index is `>= params.n` return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `binary_op` by `GpuDevice::relu_backward`.
const SHADER_RELU_BACKWARD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read> input: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 out[idx] = select(0.0, grad_out[idx], input[idx] > 0.0);
}
";
140
/// WGSL compute shader for the sigmoid gradient:
/// `out[i] = grad_out[i] * s * (1 - s)` with `s = sig_out[i]`.
///
/// The second input is read as the already-computed sigmoid value `s`, not the
/// pre-activation input; the shader does not evaluate `exp` itself.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 is the
/// read-only upstream gradient `grad_out`, 2 is the read-only `sig_out`, and 3 is
/// the read-write `f32` output. Workgroup size is 256; invocations whose flat
/// index is `>= params.n` return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `binary_op` by `GpuDevice::sigmoid_backward`.
const SHADER_SIGMOID_BACKWARD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read> sig_out: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 let s = sig_out[idx];
 out[idx] = grad_out[idx] * s * (1.0 - s);
}
";
155
/// WGSL compute shader for the Swish gradient:
/// `out[i] = grad_out[i] * (s + x * s * (1 - s))`, where `x = input[i]` and
/// `s = 1 / (1 + exp(-x))` is recomputed in the shader from the forward input.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 is the
/// read-only upstream gradient `grad_out`, 2 is the read-only forward `input`, and
/// 3 is the read-write `f32` output. Workgroup size is 256; invocations whose flat
/// index is `>= params.n` return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `binary_op` by `GpuDevice::swish_backward`.
const SHADER_SWISH_BACKWARD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read> input: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 let x = input[idx];
 let s = 1.0 / (1.0 + exp(-x));
 out[idx] = grad_out[idx] * (s + x * s * (1.0 - s));
}
";
171
/// WGSL compute shader for the tanh gradient:
/// `out[i] = grad_out[i] * (1 - t * t)` with `t = tanh_out[i]`.
///
/// The second input is read as the already-computed tanh value `t`, not the
/// pre-activation input; the shader does not call `tanh` itself.
///
/// Group 0 bindings: 0 is the uniform `Params` block (only `n` is read), 1 is the
/// read-only upstream gradient `grad_out`, 2 is the read-only `tanh_out`, and 3 is
/// the read-write `f32` output. Workgroup size is 256; invocations whose flat
/// index is `>= params.n` return without writing.
///
/// The flat index is `gid.x + gid.y * 65535 * 256`.
/// NOTE(review): the `65535 * 256` row stride presumably matches the dispatch grid
/// built by `dispatch_1d` — confirm.
///
/// Passed to `binary_op` by `GpuDevice::tanh_backward`.
const SHADER_TANH_BACKWARD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read> tanh_out: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 if idx >= params.n { return; }
 let t = tanh_out[idx];
 out[idx] = grad_out[idx] * (1.0 - t * t);
}
";
186
187impl GpuDevice {
188 pub fn add(&self, a: &GpuBuffer, b: &GpuBuffer) -> Result<GpuBuffer> {
189 ensure!(a.len == b.len, "add: length mismatch ({} vs {})", a.len, b.len);
190 self.binary_op(SHADER_ADD, a, b)
191 }
192
193 pub fn sub(&self, a: &GpuBuffer, b: &GpuBuffer) -> Result<GpuBuffer> {
194 ensure!(a.len == b.len, "sub: length mismatch ({} vs {})", a.len, b.len);
195 self.binary_op(SHADER_SUB, a, b)
196 }
197
198 pub fn mul(&self, a: &GpuBuffer, b: &GpuBuffer) -> Result<GpuBuffer> {
199 ensure!(a.len == b.len, "mul: length mismatch ({} vs {})", a.len, b.len);
200 self.binary_op(SHADER_MUL, a, b)
201 }
202
203 pub fn relu(&self, a: &GpuBuffer) -> Result<GpuBuffer> {
204 self.unary_op(SHADER_RELU, a)
205 }
206
207 pub fn sigmoid(&self, a: &GpuBuffer) -> Result<GpuBuffer> {
208 self.unary_op(SHADER_SIGMOID, a)
209 }
210
211 pub fn swish(&self, a: &GpuBuffer) -> Result<GpuBuffer> {
212 self.unary_op(SHADER_SWISH, a)
213 }
214
215 pub fn tanh_act(&self, a: &GpuBuffer) -> Result<GpuBuffer> {
216 self.unary_op(SHADER_TANH, a)
217 }
218
219 pub fn scale(&self, a: &GpuBuffer, s: f32) -> Result<GpuBuffer> {
220 let out = self.alloc(a.len);
221 let params = ScaleParams { n: a.len as u32, scale: s, _pad: [0; 2] };
222 self.dispatch_shader(SHADER_SCALE, None, ¶ms, &[a], &out, super::dispatch_1d(a.len as u32));
223 Ok(out)
224 }
225
226 pub fn relu_backward(&self, grad_out: &GpuBuffer, input: &GpuBuffer) -> Result<GpuBuffer> {
230 ensure!(grad_out.len == input.len);
231 self.binary_op(SHADER_RELU_BACKWARD, grad_out, input)
232 }
233
234 pub fn sigmoid_backward(&self, grad_out: &GpuBuffer, output: &GpuBuffer) -> Result<GpuBuffer> {
236 ensure!(grad_out.len == output.len);
237 self.binary_op(SHADER_SIGMOID_BACKWARD, grad_out, output)
238 }
239
240 pub fn swish_backward(&self, grad_out: &GpuBuffer, input: &GpuBuffer) -> Result<GpuBuffer> {
242 ensure!(grad_out.len == input.len);
243 self.binary_op(SHADER_SWISH_BACKWARD, grad_out, input)
244 }
245
246 pub fn tanh_backward(&self, grad_out: &GpuBuffer, output: &GpuBuffer) -> Result<GpuBuffer> {
248 ensure!(grad_out.len == output.len);
249 self.binary_op(SHADER_TANH_BACKWARD, grad_out, output)
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Handle to the device shared by all tests in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// CPU reference implementation of the logistic sigmoid.
    fn cpu_sigmoid(x: f32) -> f32 {
        let denom = 1.0 + (-x).exp();
        1.0 / denom
    }

    /// CPU reference implementation of Swish (`x * sigmoid(x)`).
    fn cpu_swish(x: f32) -> f32 {
        let gate = cpu_sigmoid(x);
        x * gate
    }

    // ---- forward arithmetic: exact small cases -------------------------------

    #[test]
    fn test_add() {
        let gpu = dev();
        let lhs = gpu.upload(&[1.0, 2.0, 3.0, 4.0]);
        let rhs = gpu.upload(&[10.0, 20.0, 30.0, 40.0]);
        let sum = gpu.add(&lhs, &rhs).unwrap();
        assert_eq!(gpu.read(&sum).unwrap(), vec![11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_add_odd_size() {
        // 13 elements: deliberately not a multiple of the workgroup size.
        let lhs_host: Vec<f32> = (0..13).map(|i| i as f32).collect();
        let rhs_host: Vec<f32> = (0..13).map(|i| i as f32 * 10.0).collect();
        let expected: Vec<f32> = lhs_host
            .iter()
            .zip(rhs_host.iter())
            .map(|(x, y)| x + y)
            .collect();

        let gpu = dev();
        let lhs = gpu.upload(&lhs_host);
        let rhs = gpu.upload(&rhs_host);
        let sum = gpu.add(&lhs, &rhs).unwrap();
        assert_eq!(gpu.read(&sum).unwrap(), expected);
    }

    #[test]
    fn test_add_single_element() {
        let gpu = dev();
        let lhs = gpu.upload(&[42.0]);
        let rhs = gpu.upload(&[-42.0]);
        let sum = gpu.add(&lhs, &rhs).unwrap();
        assert_eq!(gpu.read(&sum).unwrap(), vec![0.0]);
    }

    #[test]
    fn test_sub() {
        let gpu = dev();
        let lhs = gpu.upload(&[10.0, 20.0, 30.0]);
        let rhs = gpu.upload(&[1.0, 2.0, 3.0]);
        let diff = gpu.sub(&lhs, &rhs).unwrap();
        assert_eq!(gpu.read(&diff).unwrap(), vec![9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_mul() {
        let gpu = dev();
        let lhs = gpu.upload(&[1.0, 2.0, 3.0, 4.0]);
        let rhs = gpu.upload(&[10.0, 20.0, 30.0, 40.0]);
        let product = gpu.mul(&lhs, &rhs).unwrap();
        assert_eq!(gpu.read(&product).unwrap(), vec![10.0, 40.0, 90.0, 160.0]);
    }

    #[test]
    fn test_mul_zeros() {
        let gpu = dev();
        let lhs = gpu.upload(&[1.0, 2.0, 3.0]);
        let rhs = gpu.upload(&[0.0, 0.0, 0.0]);
        let product = gpu.mul(&lhs, &rhs).unwrap();
        assert_eq!(gpu.read(&product).unwrap(), vec![0.0, 0.0, 0.0]);
    }

    // ---- activations ---------------------------------------------------------

    #[test]
    fn test_relu() {
        let gpu = dev();
        let src = gpu.upload(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let activated = gpu.relu(&src).unwrap();
        assert_eq!(gpu.read(&activated).unwrap(), vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_relu_all_negative() {
        let gpu = dev();
        let src = gpu.upload(&[-100.0, -0.001, -1e-10]);
        let activated = gpu.relu(&src).unwrap();
        assert_eq!(gpu.read(&activated).unwrap(), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_sigmoid_vs_cpu() {
        let data: Vec<f32> = vec![-50.0, -10.0, -1.0, 0.0, 1.0, 10.0, 50.0];
        let expected: Vec<f32> = data.iter().copied().map(cpu_sigmoid).collect();

        let gpu = dev();
        let src = gpu.upload(&data);
        let result = gpu.read(&gpu.sigmoid(&src).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    #[test]
    fn test_swish_vs_cpu() {
        let data: Vec<f32> = vec![-5.0, -2.0, -1.0, 0.0, 1.0, 2.0, 5.0];
        let expected: Vec<f32> = data.iter().copied().map(cpu_swish).collect();

        let gpu = dev();
        let src = gpu.upload(&data);
        let result = gpu.read(&gpu.swish(&src).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    #[test]
    fn test_tanh_vs_cpu() {
        let data: Vec<f32> = vec![-10.0, -1.0, 0.0, 1.0, 10.0];
        let expected: Vec<f32> = data.iter().copied().map(f32::tanh).collect();

        let gpu = dev();
        let src = gpu.upload(&data);
        let result = gpu.read(&gpu.tanh_act(&src).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    // ---- scalar scaling ------------------------------------------------------

    #[test]
    fn test_scale() {
        let gpu = dev();
        let src = gpu.upload(&[1.0, 2.0, 3.0, 4.0]);
        let scaled = gpu.scale(&src, 0.5).unwrap();
        assert_eq!(gpu.read(&scaled).unwrap(), vec![0.5, 1.0, 1.5, 2.0]);
    }

    #[test]
    fn test_scale_zero() {
        let gpu = dev();
        let src = gpu.upload(&[99.0, -99.0]);
        let scaled = gpu.scale(&src, 0.0).unwrap();
        assert_eq!(gpu.read(&scaled).unwrap(), vec![0.0, 0.0]);
    }

    #[test]
    fn test_scale_negative() {
        let gpu = dev();
        let src = gpu.upload(&[1.0, -2.0, 3.0]);
        let scaled = gpu.scale(&src, -2.0).unwrap();
        assert_eq!(gpu.read(&scaled).unwrap(), vec![-2.0, 4.0, -6.0]);
    }

    // ---- length validation ---------------------------------------------------

    #[test]
    fn test_add_length_mismatch() {
        let gpu = dev();
        let two = gpu.upload(&[1.0, 2.0]);
        let three = gpu.upload(&[1.0, 2.0, 3.0]);
        let outcome = gpu.add(&two, &three);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_sub_length_mismatch() {
        let gpu = dev();
        let one = gpu.upload(&[1.0]);
        let two = gpu.upload(&[1.0, 2.0]);
        let outcome = gpu.sub(&one, &two);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_mul_length_mismatch() {
        let gpu = dev();
        let three = gpu.upload(&[1.0, 2.0, 3.0]);
        let one = gpu.upload(&[1.0]);
        let outcome = gpu.mul(&three, &one);
        assert!(outcome.is_err());
    }

    // ---- forward arithmetic: 100-element comparison against the CPU ----------

    #[test]
    fn test_add_vs_cpu() {
        let lhs_host: Vec<f32> = (0..100).map(|i| (i as f32) * 0.3 - 15.0).collect();
        let rhs_host: Vec<f32> = (0..100).map(|i| (i as f32) * -0.2 + 10.0).collect();
        let expected: Vec<f32> = lhs_host
            .iter()
            .zip(rhs_host.iter())
            .map(|(x, y)| x + y)
            .collect();

        let gpu = dev();
        let lhs = gpu.upload(&lhs_host);
        let rhs = gpu.upload(&rhs_host);
        let result = gpu.read(&gpu.add(&lhs, &rhs).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-5);
    }

    #[test]
    fn test_sub_vs_cpu() {
        let lhs_host: Vec<f32> = (0..100).map(|i| (i as f32) * 0.7).collect();
        let rhs_host: Vec<f32> = (0..100).map(|i| (i as f32) * 0.3).collect();
        let expected: Vec<f32> = lhs_host
            .iter()
            .zip(rhs_host.iter())
            .map(|(x, y)| x - y)
            .collect();

        let gpu = dev();
        let lhs = gpu.upload(&lhs_host);
        let rhs = gpu.upload(&rhs_host);
        let result = gpu.read(&gpu.sub(&lhs, &rhs).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-5);
    }

    #[test]
    fn test_mul_vs_cpu() {
        let lhs_host: Vec<f32> = (0..100).map(|i| (i as f32) * 0.1 - 5.0).collect();
        let rhs_host: Vec<f32> = (0..100).map(|i| (i as f32) * 0.05 + 0.5).collect();
        let expected: Vec<f32> = lhs_host
            .iter()
            .zip(rhs_host.iter())
            .map(|(x, y)| x * y)
            .collect();

        let gpu = dev();
        let lhs = gpu.upload(&lhs_host);
        let rhs = gpu.upload(&rhs_host);
        let result = gpu.read(&gpu.mul(&lhs, &rhs).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    // ---- activation gradients ------------------------------------------------

    #[test]
    fn test_relu_backward_vs_cpu() {
        let gpu = dev();
        let grad = gpu.upload(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let input = gpu.upload(&[-1.0, 0.5, 0.0, -0.1, 2.0]);
        let grad_in = gpu.relu_backward(&grad, &input).unwrap();
        let result = gpu.read(&grad_in).unwrap();
        // Gradient survives only where the input is strictly positive.
        assert_approx(&result, &[0.0, 2.0, 0.0, 0.0, 5.0], 1e-5);
    }

    #[test]
    fn test_sigmoid_backward_vs_cpu() {
        let forward: Vec<f32> = vec![0.5, 0.7311, 0.2689];
        let upstream: Vec<f32> = vec![1.0, 1.0, 1.0];
        let expected: Vec<f32> = forward
            .iter()
            .zip(upstream.iter())
            .map(|(s, g)| g * s * (1.0 - s))
            .collect();

        let gpu = dev();
        let grad = gpu.upload(&upstream);
        let sig_out = gpu.upload(&forward);
        let result = gpu.read(&gpu.sigmoid_backward(&grad, &sig_out).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_swish_backward_vs_cpu() {
        let xs: Vec<f32> = vec![0.0, 1.0, -1.0, 2.0];
        let upstream: Vec<f32> = vec![1.0, 1.0, 1.0, 1.0];
        // With a unit upstream gradient the expected value is d/dx swish(x).
        let expected: Vec<f32> = xs
            .iter()
            .map(|&x| {
                let s = cpu_sigmoid(x);
                s + x * s * (1.0 - s)
            })
            .collect();

        let gpu = dev();
        let grad = gpu.upload(&upstream);
        let input = gpu.upload(&xs);
        let result = gpu.read(&gpu.swish_backward(&grad, &input).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_tanh_backward_vs_cpu() {
        let forward: Vec<f32> = vec![0.0, 0.7616, -0.7616, 0.9951];
        let upstream: Vec<f32> = vec![1.0, 1.0, 1.0, 1.0];
        // With a unit upstream gradient the expected value is 1 - tanh(x)^2.
        let expected: Vec<f32> = forward.iter().map(|&t| 1.0 - t * t).collect();

        let gpu = dev();
        let grad = gpu.upload(&upstream);
        let tanh_out = gpu.upload(&forward);
        let result = gpu.read(&gpu.tanh_backward(&grad, &tanh_out).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }
}