Skip to main content

any_gpu/ops/
elementwise.rs

1// Unlicense — cochranblock.org
2// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
3//
4// Element-wise ops: add, mul, sub, scale, relu, sigmoid, swish, tanh.
5
6use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
/// WGSL kernel: `out[i] = a[i] + b[i]` for every `i < params.n`.
///
/// Bindings: 0 = `Params` uniform (only `n` is read), 1 = `a`, 2 = `b`, 3 = `out`.
/// The flat index is `gid.x + gid.y * 65535 * 256`: each `gid.y` row spans 65535
/// workgroups of 256 invocations. This presumably matches the x-dimension cap used by
/// `super::dispatch_1d` (verify). Invocations past `n` return early.
const SHADER_ADD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    out[idx] = a[idx] + b[idx];
}
";
22
/// WGSL kernel: `out[i] = a[i] - b[i]` for every `i < params.n`.
///
/// Same binding layout and 2-D index flattening as `SHADER_ADD`:
/// 0 = `Params` uniform, 1 = `a`, 2 = `b`, 3 = `out`.
const SHADER_SUB: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    out[idx] = a[idx] - b[idx];
}
";
36
/// WGSL kernel: element-wise (Hadamard) product `out[i] = a[i] * b[i]` for `i < params.n`.
///
/// Bindings: 0 = `Params` uniform, 1 = `a`, 2 = `b`, 3 = `out`.
const SHADER_MUL: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    out[idx] = a[idx] * b[idx];
}
";
50
/// WGSL kernel: `out[i] = max(a[i], 0.0)` for `i < params.n`.
///
/// Unary layout: 0 = `Params` uniform, 1 = `a`, 2 = `out`.
const SHADER_RELU: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    out[idx] = max(a[idx], 0.0);
}
";
63
/// WGSL kernel: `out[i] = sigmoid(a[i]) = 1 / (1 + exp(-a[i]))` for `i < params.n`.
///
/// The shader evaluates `e = exp(-|x|)` and returns `1 / (1 + e)` for `x >= 0` and
/// `e / (1 + e)` otherwise. The `exp` argument is therefore never positive and
/// cannot overflow. The naive `1 / (1 + exp(-x))` overflows `exp` for `x < ~-88`,
/// and WGSL lets implementations produce an indeterminate value on overflow
/// instead of `inf`.
///
/// Unary layout: 0 = `Params` uniform, 1 = `a`, 2 = `out`.
const SHADER_SIGMOID: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    let x = a[idx];
    let e = exp(-abs(x));
    out[idx] = select(e / (1.0 + e), 1.0 / (1.0 + e), x >= 0.0);
}
";
76
/// WGSL kernel: `out[i] = swish(x) = x * sigmoid(x)` with `x = a[i]`, for `i < params.n`.
///
/// The sigmoid uses the overflow-free form `e = exp(-|x|)`, then `1 / (1 + e)` for
/// `x >= 0` and `e / (1 + e)` otherwise. The naive `x / (1 + exp(-x))` overflows
/// `exp` for `x < ~-88`, which WGSL permits to yield an indeterminate value.
///
/// Unary layout: 0 = `Params` uniform, 1 = `a`, 2 = `out`.
const SHADER_SWISH: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    let x = a[idx];
    let e = exp(-abs(x));
    let s = select(e / (1.0 + e), 1.0 / (1.0 + e), x >= 0.0);
    out[idx] = x * s;
}
";
90
/// WGSL kernel: `out[i] = tanh(a[i])` for `i < params.n`.
///
/// The input is clamped to `[-20, 20]` before `tanh`. WGSL only guarantees `tanh`
/// accuracy as "inherited from sinh(x)/cosh(x)", and both of those overflow for
/// `|x| > ~89`, so very large inputs could otherwise produce NaN or indeterminate
/// values on some backends. The clamp does not change any f32 result, because
/// `tanh(x)` already rounds to exactly `±1.0` for `|x| > ~9.01`.
///
/// Unary layout: 0 = `Params` uniform, 1 = `a`, 2 = `out`.
const SHADER_TANH: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    out[idx] = tanh(clamp(a[idx], -20.0, 20.0));
}
";
103
/// Uniform block for `SHADER_SCALE`.
///
/// With `repr(C)`, the field order and sizes (4 x 4 bytes = 16 bytes) mirror the
/// WGSL `Params { n: u32, scale: f32, _p0: u32, _p1: u32 }`, so do not reorder fields.
/// The `Pod`/`Zeroable` derives allow a plain byte view of the struct (presumably
/// used by `dispatch_shader` to upload it; verify).
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct ScaleParams {
    /// Number of elements to process; the shader returns early for `idx >= n`.
    n: u32,
    /// Multiplier applied to every element.
    scale: f32,
    /// Padding to 16 bytes, matching `_p0`/`_p1` in the WGSL struct.
    _pad: [u32; 2],
}
111
/// WGSL kernel: `out[i] = a[i] * params.scale` for `i < params.n`.
///
/// Unlike the other kernels, its `Params` carries the scalar `scale` in the second
/// slot; the host side fills it from `ScaleParams`.
/// Layout: 0 = `Params` uniform, 1 = `a`, 2 = `out`.
const SHADER_SCALE: &str = "
struct Params { n: u32, scale: f32, _p0: u32, _p1: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    out[idx] = a[idx] * params.scale;
}
";
124
125// --- Backward shaders ---
126
/// WGSL kernel: ReLU gradient. Computes `out[i] = grad_out[i]` when `input[i] > 0`,
/// and `0.0` otherwise. This includes `input[i] == 0`, where the subgradient is taken as 0.
///
/// Bindings: 0 = `Params` uniform, 1 = `grad_out`, 2 = `input` (the forward
/// pre-activation), 3 = `out`.
const SHADER_RELU_BACKWARD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read> input: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    out[idx] = select(0.0, grad_out[idx], input[idx] > 0.0);
}
";
140
/// WGSL kernel: sigmoid gradient expressed through the forward *output* `s`:
/// `out[i] = grad_out[i] * s * (1 - s)` with `s = sig_out[i]`.
///
/// Bindings: 0 = `Params` uniform, 1 = `grad_out`, 2 = `sig_out` (the result of the
/// forward sigmoid, not its input), 3 = `out`.
const SHADER_SIGMOID_BACKWARD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read> sig_out: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    let s = sig_out[idx];
    out[idx] = grad_out[idx] * s * (1.0 - s);
}
";
155
/// WGSL kernel: swish gradient with respect to the pre-activation `input`:
/// `out[i] = grad_out[i] * (s + x * s * (1 - s))` with `x = input[i]`, `s = sigmoid(x)`.
///
/// `s` uses the same overflow-free form as the forward sigmoid: `e = exp(-|x|)`,
/// then `1 / (1 + e)` for `x >= 0` and `e / (1 + e)` otherwise. The naive
/// `1 / (1 + exp(-x))` overflows `exp` for `x < ~-88`, which WGSL permits to yield
/// an indeterminate value.
///
/// Bindings: 0 = `Params` uniform, 1 = `grad_out`, 2 = `input`, 3 = `out`.
const SHADER_SWISH_BACKWARD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read> input: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    let x = input[idx];
    let e = exp(-abs(x));
    let s = select(e / (1.0 + e), 1.0 / (1.0 + e), x >= 0.0);
    out[idx] = grad_out[idx] * (s + x * s * (1.0 - s));
}
";
171
/// WGSL kernel: tanh gradient expressed through the forward *output* `t`:
/// `out[i] = grad_out[i] * (1 - t * t)` with `t = tanh_out[i]`.
///
/// Bindings: 0 = `Params` uniform, 1 = `grad_out`, 2 = `tanh_out` (the result of
/// the forward tanh), 3 = `out`.
const SHADER_TANH_BACKWARD: &str = "
struct Params { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read> tanh_out: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= params.n { return; }
    let t = tanh_out[idx];
    out[idx] = grad_out[idx] * (1.0 - t * t);
}
";
186
187impl GpuDevice {
188    pub fn add(&self, a: &GpuBuffer, b: &GpuBuffer) -> Result<GpuBuffer> {
189        ensure!(a.len == b.len, "add: length mismatch ({} vs {})", a.len, b.len);
190        self.binary_op(SHADER_ADD, a, b)
191    }
192
193    pub fn sub(&self, a: &GpuBuffer, b: &GpuBuffer) -> Result<GpuBuffer> {
194        ensure!(a.len == b.len, "sub: length mismatch ({} vs {})", a.len, b.len);
195        self.binary_op(SHADER_SUB, a, b)
196    }
197
198    pub fn mul(&self, a: &GpuBuffer, b: &GpuBuffer) -> Result<GpuBuffer> {
199        ensure!(a.len == b.len, "mul: length mismatch ({} vs {})", a.len, b.len);
200        self.binary_op(SHADER_MUL, a, b)
201    }
202
203    pub fn relu(&self, a: &GpuBuffer) -> Result<GpuBuffer> {
204        self.unary_op(SHADER_RELU, a)
205    }
206
207    pub fn sigmoid(&self, a: &GpuBuffer) -> Result<GpuBuffer> {
208        self.unary_op(SHADER_SIGMOID, a)
209    }
210
211    pub fn swish(&self, a: &GpuBuffer) -> Result<GpuBuffer> {
212        self.unary_op(SHADER_SWISH, a)
213    }
214
215    pub fn tanh_act(&self, a: &GpuBuffer) -> Result<GpuBuffer> {
216        self.unary_op(SHADER_TANH, a)
217    }
218
219    pub fn scale(&self, a: &GpuBuffer, s: f32) -> Result<GpuBuffer> {
220        let out = self.alloc(a.len);
221        let params = ScaleParams { n: a.len as u32, scale: s, _pad: [0; 2] };
222        self.dispatch_shader(SHADER_SCALE, None, &params, &[a], &out, super::dispatch_1d(a.len as u32));
223        Ok(out)
224    }
225
226    // --- Backward shaders for autograd ---
227
228    /// ReLU backward: grad_a = grad_out * (input > 0)
229    pub fn relu_backward(&self, grad_out: &GpuBuffer, input: &GpuBuffer) -> Result<GpuBuffer> {
230        ensure!(grad_out.len == input.len);
231        self.binary_op(SHADER_RELU_BACKWARD, grad_out, input)
232    }
233
234    /// Sigmoid backward: grad_a = grad_out * output * (1 - output)
235    pub fn sigmoid_backward(&self, grad_out: &GpuBuffer, output: &GpuBuffer) -> Result<GpuBuffer> {
236        ensure!(grad_out.len == output.len);
237        self.binary_op(SHADER_SIGMOID_BACKWARD, grad_out, output)
238    }
239
240    /// Swish backward: grad_a = grad_out * (sig(x) + x * sig(x) * (1 - sig(x)))
241    pub fn swish_backward(&self, grad_out: &GpuBuffer, input: &GpuBuffer) -> Result<GpuBuffer> {
242        ensure!(grad_out.len == input.len);
243        self.binary_op(SHADER_SWISH_BACKWARD, grad_out, input)
244    }
245
246    /// Tanh backward: grad_a = grad_out * (1 - output^2)
247    pub fn tanh_backward(&self, grad_out: &GpuBuffer, output: &GpuBuffer) -> Result<GpuBuffer> {
248        ensure!(grad_out.len == output.len);
249        self.binary_op(SHADER_TANH_BACKWARD, grad_out, output)
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Shared test device.
    fn dev() -> &'static GpuDevice { &crate::ops::TEST_DEV }

    // CPU references for cross-validation
    fn cpu_sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }
    fn cpu_swish(x: f32) -> f32 { x * cpu_sigmoid(x) }

    #[test]
    fn test_add() {
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        let b = dev().upload(&[10.0, 20.0, 30.0, 40.0]);
        let result = dev().read(&dev().add(&a, &b).unwrap()).unwrap();
        assert_eq!(result, vec![11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_add_odd_size() {
        // 13 elements — not aligned to workgroup size 256
        let a_data: Vec<f32> = (0..13).map(|i| i as f32).collect();
        let b_data: Vec<f32> = (0..13).map(|i| i as f32 * 10.0).collect();
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(a, b)| a + b).collect();
        let result = dev().read(&dev().add(&dev().upload(&a_data), &dev().upload(&b_data)).unwrap()).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_add_spans_multiple_workgroups() {
        // 1000 elements = 4 workgroups of 256, the last one only partially filled.
        // Integer-valued sums stay exact in f32, so exact equality is safe.
        let a: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..1000).map(|i| i as f32 * 2.0).collect();
        let expected: Vec<f32> = a.iter().zip(&b).map(|(x, y)| x + y).collect();
        let result = dev().read(&dev().add(&dev().upload(&a), &dev().upload(&b)).unwrap()).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_add_single_element() {
        let result = dev().read(&dev().add(&dev().upload(&[42.0]), &dev().upload(&[-42.0])).unwrap()).unwrap();
        assert_eq!(result, vec![0.0]);
    }

    #[test]
    fn test_sub() {
        let a = dev().upload(&[10.0, 20.0, 30.0]);
        let b = dev().upload(&[1.0, 2.0, 3.0]);
        let result = dev().read(&dev().sub(&a, &b).unwrap()).unwrap();
        assert_eq!(result, vec![9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_mul() {
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        let b = dev().upload(&[10.0, 20.0, 30.0, 40.0]);
        let result = dev().read(&dev().mul(&a, &b).unwrap()).unwrap();
        assert_eq!(result, vec![10.0, 40.0, 90.0, 160.0]);
    }

    #[test]
    fn test_mul_zeros() {
        let a = dev().upload(&[1.0, 2.0, 3.0]);
        let b = dev().upload(&[0.0, 0.0, 0.0]);
        let result = dev().read(&dev().mul(&a, &b).unwrap()).unwrap();
        assert_eq!(result, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_relu() {
        let a = dev().upload(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let result = dev().read(&dev().relu(&a).unwrap()).unwrap();
        assert_eq!(result, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_relu_all_negative() {
        let result = dev().read(&dev().relu(&dev().upload(&[-100.0, -0.001, -1e-10])).unwrap()).unwrap();
        assert_eq!(result, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_sigmoid_vs_cpu() {
        let data: Vec<f32> = vec![-50.0, -10.0, -1.0, 0.0, 1.0, 10.0, 50.0];
        let expected: Vec<f32> = data.iter().map(|&x| cpu_sigmoid(x)).collect();
        let result = dev().read(&dev().sigmoid(&dev().upload(&data)).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    #[test]
    fn test_swish_vs_cpu() {
        let data: Vec<f32> = vec![-5.0, -2.0, -1.0, 0.0, 1.0, 2.0, 5.0];
        let expected: Vec<f32> = data.iter().map(|&x| cpu_swish(x)).collect();
        let result = dev().read(&dev().swish(&dev().upload(&data)).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    #[test]
    fn test_tanh_vs_cpu() {
        let data: Vec<f32> = vec![-10.0, -1.0, 0.0, 1.0, 10.0];
        let expected: Vec<f32> = data.iter().map(|&x| x.tanh()).collect();
        let result = dev().read(&dev().tanh_act(&dev().upload(&data)).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    #[test]
    fn test_scale() {
        let result = dev().read(&dev().scale(&dev().upload(&[1.0, 2.0, 3.0, 4.0]), 0.5).unwrap()).unwrap();
        assert_eq!(result, vec![0.5, 1.0, 1.5, 2.0]);
    }

    #[test]
    fn test_scale_zero() {
        let result = dev().read(&dev().scale(&dev().upload(&[99.0, -99.0]), 0.0).unwrap()).unwrap();
        assert_eq!(result, vec![0.0, 0.0]);
    }

    #[test]
    fn test_scale_negative() {
        let result = dev().read(&dev().scale(&dev().upload(&[1.0, -2.0, 3.0]), -2.0).unwrap()).unwrap();
        assert_eq!(result, vec![-2.0, 4.0, -6.0]);
    }

    #[test]
    fn test_scale_spans_multiple_workgroups() {
        // `scale` dispatches via its own path (not unary_op), so cover > 256 elements too.
        // Multiplying by 0.5 is exact in f32.
        let a: Vec<f32> = (0..1000).map(|i| i as f32 - 500.0).collect();
        let expected: Vec<f32> = a.iter().map(|x| x * 0.5).collect();
        let result = dev().read(&dev().scale(&dev().upload(&a), 0.5).unwrap()).unwrap();
        assert_eq!(result, expected);
    }

    // --- Error path tests ---

    #[test]
    fn test_add_length_mismatch() {
        let a = dev().upload(&[1.0, 2.0]);
        let b = dev().upload(&[1.0, 2.0, 3.0]);
        assert!(dev().add(&a, &b).is_err());
    }

    #[test]
    fn test_sub_length_mismatch() {
        let a = dev().upload(&[1.0]);
        let b = dev().upload(&[1.0, 2.0]);
        assert!(dev().sub(&a, &b).is_err());
    }

    #[test]
    fn test_mul_length_mismatch() {
        let a = dev().upload(&[1.0, 2.0, 3.0]);
        let b = dev().upload(&[1.0]);
        assert!(dev().mul(&a, &b).is_err());
    }

    #[test]
    fn test_backward_length_mismatch() {
        // Every backward op must reject mismatched lengths in either argument order.
        let short = dev().upload(&[1.0, 2.0]);
        let long = dev().upload(&[1.0, 2.0, 3.0]);
        assert!(dev().relu_backward(&short, &long).is_err());
        assert!(dev().sigmoid_backward(&short, &long).is_err());
        assert!(dev().swish_backward(&long, &short).is_err());
        assert!(dev().tanh_backward(&long, &short).is_err());
    }

    // --- CPU cross-validation for add/sub/mul ---

    #[test]
    fn test_add_vs_cpu() {
        let a: Vec<f32> = (0..100).map(|i| (i as f32) * 0.3 - 15.0).collect();
        let b: Vec<f32> = (0..100).map(|i| (i as f32) * -0.2 + 10.0).collect();
        let expected: Vec<f32> = a.iter().zip(&b).map(|(x, y)| x + y).collect();
        let result = dev().read(&dev().add(&dev().upload(&a), &dev().upload(&b)).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-5);
    }

    #[test]
    fn test_sub_vs_cpu() {
        let a: Vec<f32> = (0..100).map(|i| (i as f32) * 0.7).collect();
        let b: Vec<f32> = (0..100).map(|i| (i as f32) * 0.3).collect();
        let expected: Vec<f32> = a.iter().zip(&b).map(|(x, y)| x - y).collect();
        let result = dev().read(&dev().sub(&dev().upload(&a), &dev().upload(&b)).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-5);
    }

    #[test]
    fn test_mul_vs_cpu() {
        let a: Vec<f32> = (0..100).map(|i| (i as f32) * 0.1 - 5.0).collect();
        let b: Vec<f32> = (0..100).map(|i| (i as f32) * 0.05 + 0.5).collect();
        let expected: Vec<f32> = a.iter().zip(&b).map(|(x, y)| x * y).collect();
        let result = dev().read(&dev().mul(&dev().upload(&a), &dev().upload(&b)).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    // --- Backward shader direct tests ---

    #[test]
    fn test_relu_backward_vs_cpu() {
        let grad = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let input = dev().upload(&[-1.0, 0.5, 0.0, -0.1, 2.0]);
        let result = dev().read(&dev().relu_backward(&grad, &input).unwrap()).unwrap();
        // relu_backward: grad * (input > 0)
        assert_approx(&result, &[0.0, 2.0, 0.0, 0.0, 5.0], 1e-5);
    }

    #[test]
    fn test_sigmoid_backward_vs_cpu() {
        let sig_out = vec![0.5, 0.7311, 0.2689]; // sigmoid outputs
        let grad = vec![1.0, 1.0, 1.0];
        let expected: Vec<f32> = sig_out.iter().zip(&grad).map(|(s, g)| g * s * (1.0 - s)).collect();
        let result = dev().read(&dev().sigmoid_backward(&dev().upload(&grad), &dev().upload(&sig_out)).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_swish_backward_vs_cpu() {
        let input: Vec<f32> = vec![0.0, 1.0, -1.0, 2.0];
        let grad: Vec<f32> = vec![1.0, 1.0, 1.0, 1.0];
        // With unit upstream gradient the result is just d/dx swish(x).
        let expected: Vec<f32> = input
            .iter()
            .map(|&x| {
                let s = cpu_sigmoid(x);
                s + x * s * (1.0 - s)
            })
            .collect();
        let result = dev().read(&dev().swish_backward(&dev().upload(&grad), &dev().upload(&input)).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_tanh_backward_vs_cpu() {
        let tanh_out = vec![0.0, 0.7616, -0.7616, 0.9951]; // tanh outputs
        let grad = vec![1.0, 1.0, 1.0, 1.0];
        let expected: Vec<f32> = tanh_out.iter().map(|&t| 1.0 - t * t).collect();
        let result = dev().read(&dev().tanh_backward(&dev().upload(&grad), &dev().upload(&tanh_out)).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }
}