Skip to main content

any_gpu/ops/
conv.rs

1// Unlicense — cochranblock.org
2// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
3//
// Matmul, batched matmul, conv2d, conv_transpose2d, and conv2d weight/bias gradients.
5
6use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
// --- Matmul ---

/// Uniform block for `SHADER_MATMUL`, bound at `@group(0) @binding(0)`.
///
/// The `#[repr(C)]` field order must match the WGSL `struct Dims { m, n, k, _pad }`
/// declared in the shader. Deriving `bytemuck::Pod` lets the struct be reinterpreted as raw bytes.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct MatmulDims {
    /// Rows of A and of the output C.
    m: u32,
    /// Columns of B and of the output C.
    n: u32,
    /// Inner dimension: columns of A, rows of B.
    k: u32,
    /// Unused padding that rounds the block up to four `u32`s (16 bytes).
    _pad: u32,
}
19
// Tiled matmul: 16x16 tiles in workgroup shared memory.
// Each tile of A and B is loaded from global memory once per tile iteration,
// not once per output element. Reduces global memory reads by ~16x.
//
// Thread mapping: gid.x selects the output row (< m) and gid.y the output column (< n),
// so the host must dispatch ceil(m/16) x ceil(n/16) workgroups. Invocations that fall
// outside the matrix load zeros into the tiles instead of returning early. As a result,
// every invocation reaches both workgroupBarrier() calls in each tile iteration, and the
// final store is guarded separately.
// NOTE(review): consecutive invocations along x differ in *row*, so the global loads of A
// and B and the store to `out` are strided by k / n rather than contiguous. Swapping the
// x/y roles here and in the host-side dispatch together may improve memory coalescing;
// benchmark before changing.
const SHADER_MATMUL: &str = "
const TILE: u32 = 16u;
struct Dims { m: u32, n: u32, k: u32, _pad: u32, }
@group(0) @binding(0) var<uniform> dims: Dims;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;

var<workgroup> tile_a: array<f32, 256>;  // TILE * TILE
var<workgroup> tile_b: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let row = gid.x;
    let col = gid.y;
    let lr = lid.x;
    let lc = lid.y;

    var sum: f32 = 0.0;
    let num_tiles = (dims.k + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t++) {
        // Load tile of A: row from global row, col from tile offset
        let a_col = t * TILE + lc;
        if row < dims.m && a_col < dims.k {
            tile_a[lr * TILE + lc] = a[row * dims.k + a_col];
        } else {
            tile_a[lr * TILE + lc] = 0.0;
        }

        // Load tile of B: row from tile offset, col from global col
        let b_row = t * TILE + lr;
        if b_row < dims.k && col < dims.n {
            tile_b[lr * TILE + lc] = b[b_row * dims.n + col];
        } else {
            tile_b[lr * TILE + lc] = 0.0;
        }

        workgroupBarrier();

        // Accumulate dot product from shared memory
        for (var i: u32 = 0u; i < TILE; i++) {
            sum += tile_a[lr * TILE + i] * tile_b[i * TILE + lc];
        }

        workgroupBarrier();
    }

    if row < dims.m && col < dims.n {
        out[row * dims.n + col] = sum;
    }
}
";
79
// --- Batched Matmul ---

/// Uniform block for `SHADER_BATCH_MATMUL`, bound at `@group(0) @binding(0)`.
///
/// The `#[repr(C)]` field order must match the WGSL `struct Dims { batch, m, n, k }`.
/// The four `u32`s already total 16 bytes, the same size `MatmulDims` reaches via `_pad`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct BatchMatmulDims {
    /// Number of independent (m,k) x (k,n) products.
    batch: u32,
    /// Rows of each A and each output.
    m: u32,
    /// Columns of each B and each output.
    n: u32,
    /// Shared inner dimension of each product.
    k: u32,
}
90
// Batched tiled matmul: the 16x16 shared-memory tiling of SHADER_MATMUL applied to each of
// `batch` independent products A[b](m,k) x B[b](k,n). The batch is the z dimension of the
// dispatch; with a workgroup depth of 1, each z-layer of workgroups handles one batch item.
//
// The batch index comes from `workgroup_id.z` rather than `global_invocation_id.z`. The two
// are equal here (workgroup depth 1), but WGSL's uniformity analysis treats `workgroup_id`
// as uniform and `global_invocation_id` as non-uniform. An early `return` keyed on a
// non-uniform value leaves the following `workgroupBarrier()` calls in non-uniform control
// flow, which the WGSL spec forbids; strict implementations (e.g. Tint) reject the shader.
const SHADER_BATCH_MATMUL: &str = "
const TILE: u32 = 16u;
struct Dims { batch: u32, m: u32, n: u32, k: u32, }
@group(0) @binding(0) var<uniform> dims: Dims;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;

var<workgroup> tile_a: array<f32, 256>;  // TILE * TILE
var<workgroup> tile_b: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let row = gid.x;
    let col = gid.y;
    // Workgroup-uniform batch index: keeps the early return below (and therefore every
    // workgroupBarrier()) in uniform control flow.
    let bat = wid.z;
    if bat >= dims.batch { return; }
    let lr = lid.x;
    let lc = lid.y;
    let a_off = bat * dims.m * dims.k;
    let b_off = bat * dims.k * dims.n;
    let o_off = bat * dims.m * dims.n;

    var sum: f32 = 0.0;
    let num_tiles = (dims.k + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t++) {
        // Tile of A: global row, columns from the tile offset (zero-filled out of range).
        let a_col = t * TILE + lc;
        if row < dims.m && a_col < dims.k {
            tile_a[lr * TILE + lc] = a[a_off + row * dims.k + a_col];
        } else {
            tile_a[lr * TILE + lc] = 0.0;
        }

        // Tile of B: rows from the tile offset, global column (zero-filled out of range).
        let b_row = t * TILE + lr;
        if b_row < dims.k && col < dims.n {
            tile_b[lr * TILE + lc] = b[b_off + b_row * dims.n + col];
        } else {
            tile_b[lr * TILE + lc] = 0.0;
        }

        workgroupBarrier();

        for (var i: u32 = 0u; i < TILE; i++) {
            sum += tile_a[lr * TILE + i] * tile_b[i * TILE + lc];
        }

        workgroupBarrier();
    }

    if row < dims.m && col < dims.n {
        out[o_off + row * dims.n + col] = sum;
    }
}
";
149
// --- Conv2d ---

/// Uniform block for `SHADER_CONV2D`, bound at `@group(0) @binding(0)`.
///
/// Sixteen `u32`s (64 bytes). The `#[repr(C)]` field order must match the WGSL `struct P`
/// declared in the shader. `out_h`/`out_w` are computed on the host by `conv2d` before upload.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Conv2dParams {
    batch: u32,
    in_c: u32,
    out_c: u32,
    in_h: u32,
    in_w: u32,
    out_h: u32,
    out_w: u32,
    kh: u32,
    kw: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
    dilation_h: u32,
    dilation_w: u32,
    /// Channel groups; the shader divides both `in_c` and `out_c` by this value.
    groups: u32,
}
172
// Direct conv2d forward: one invocation per element of out[N, C_out, out_H, out_W]
// (NCHW, contiguous). The flat index is decomposed as
//     idx = ((n * C_out + oc) * out_H + oh) * out_W + ow.
// It is rebuilt as gid.x + gid.y * 65535 * 256. This assumes the host splits large
// 1-D dispatches into rows of 65535 workgroups of 256 invocations (presumably what
// `super::dispatch_1d` does; confirm there).
// Padding relies on u32 wraparound: when oh*stride + kh*dilation < pad, the subtraction
// wraps to a huge value that fails the `< in_h` test, so padded taps contribute nothing.
// Weight layout: [C_out, C_in/groups, kH, kW]. `bias[oc]` is always read, so a bias buffer
// of length C_out must always be bound.
const SHADER_CONV2D: &str = "
struct P {
    batch: u32, in_c: u32, out_c: u32, in_h: u32,
    in_w: u32, out_h: u32, out_w: u32, kh: u32,
    kw: u32, stride_h: u32, stride_w: u32, pad_h: u32,
    pad_w: u32, dilation_h: u32, dilation_w: u32, groups: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<f32>;
@group(0) @binding(3) var<storage, read> bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = p.batch * p.out_c * p.out_h * p.out_w;
    if idx >= total { return; }

    let ow = idx % p.out_w;
    let oh = (idx / p.out_w) % p.out_h;
    let oc = (idx / (p.out_w * p.out_h)) % p.out_c;
    let n  = idx / (p.out_w * p.out_h * p.out_c);

    let group_in = p.in_c / p.groups;
    let group_out = p.out_c / p.groups;
    let g = oc / group_out;

    var sum: f32 = bias[oc];
    for (var ic: u32 = 0u; ic < group_in; ic++) {
        for (var kh: u32 = 0u; kh < p.kh; kh++) {
            for (var kw: u32 = 0u; kw < p.kw; kw++) {
                let ih = oh * p.stride_h + kh * p.dilation_h - p.pad_h;
                let iw = ow * p.stride_w + kw * p.dilation_w - p.pad_w;
                if ih < p.in_h && iw < p.in_w {
                    let in_idx = n * (p.in_c * p.in_h * p.in_w)
                               + (g * group_in + ic) * (p.in_h * p.in_w)
                               + ih * p.in_w + iw;
                    let w_idx = oc * (group_in * p.kh * p.kw)
                              + ic * (p.kh * p.kw)
                              + kh * p.kw + kw;
                    sum += input[in_idx] * weight[w_idx];
                }
            }
        }
    }
    out[idx] = sum;
}
";
221
// --- TransposeConv2d ---

/// Uniform block for `SHADER_CONV_TRANSPOSE2D`, bound at `@group(0) @binding(0)`.
///
/// Same sixteen-`u32` layout as `Conv2dParams`; the `#[repr(C)]` field order must match the
/// WGSL `struct P`. Here `in_h`/`in_w` describe the transposed convolution's input, and
/// `out_h`/`out_w` its output as computed on the host by `conv_transpose2d`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct ConvTranspose2dParams {
    batch: u32,
    in_c: u32,
    out_c: u32,
    in_h: u32,
    in_w: u32,
    out_h: u32,
    out_w: u32,
    kh: u32,
    kw: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
    dilation_h: u32,
    dilation_w: u32,
    /// Channel groups; the shader divides both `in_c` and `out_c` by this value.
    groups: u32,
}
244
// Transposed conv2d (deconvolution) forward: one invocation per element of
// out[N, C_out, out_H, out_W] (NCHW). Each invocation gathers from every input pixel that
// scatters into it. Weight layout: [C_in, C_out/groups, kH, kW].
//
// Output row `oh` receives input row `ih` through kernel tap `kh` when
//     oh = ih * stride_h - pad_h + kh * dilation_h,
// i.e. ih = (oh + pad_h - kh * dilation_h) / stride_h. That quotient must be non-negative
// and exact, and likewise for columns. Non-negativity is checked explicitly before
// subtracting, as SHADER_CONV2D_GRAD_WEIGHT does, instead of relying on a wrapped u32 that
// happens to fail the later bounds test. Row checks are hoisted out of the `kw` loop; the
// accumulation order (ic, kh, kw) is unchanged.
const SHADER_CONV_TRANSPOSE2D: &str = "
struct P {
    batch: u32, in_c: u32, out_c: u32, in_h: u32,
    in_w: u32, out_h: u32, out_w: u32, kh: u32,
    kw: u32, stride_h: u32, stride_w: u32, pad_h: u32,
    pad_w: u32, dilation_h: u32, dilation_w: u32, groups: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<f32>;
@group(0) @binding(3) var<storage, read> bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = p.batch * p.out_c * p.out_h * p.out_w;
    if idx >= total { return; }

    let ow = idx % p.out_w;
    let oh = (idx / p.out_w) % p.out_h;
    let oc = (idx / (p.out_w * p.out_h)) % p.out_c;
    let n  = idx / (p.out_w * p.out_h * p.out_c);

    let group_in = p.in_c / p.groups;
    let group_out = p.out_c / p.groups;
    let g = oc / group_out;
    let oc_local = oc % group_out;
    // Output coordinates shifted by the padding; the kernel offset is subtracted per tap.
    let oh_pad = oh + p.pad_h;
    let ow_pad = ow + p.pad_w;

    var sum: f32 = bias[oc];
    for (var ic: u32 = 0u; ic < group_in; ic++) {
        for (var kh: u32 = 0u; kh < p.kh; kh++) {
            // Row tap: oh + pad_h - kh * dilation_h must be >= 0, divisible by stride_h,
            // and its quotient must lie inside the input.
            let kh_off = kh * p.dilation_h;
            if oh_pad < kh_off { continue; }
            let oh_off = oh_pad - kh_off;
            if oh_off % p.stride_h != 0u { continue; }
            let ih = oh_off / p.stride_h;
            if ih >= p.in_h { continue; }
            for (var kw: u32 = 0u; kw < p.kw; kw++) {
                // Column tap: same conditions along the width.
                let kw_off = kw * p.dilation_w;
                if ow_pad < kw_off { continue; }
                let ow_off = ow_pad - kw_off;
                if ow_off % p.stride_w != 0u { continue; }
                let iw = ow_off / p.stride_w;
                if iw >= p.in_w { continue; }
                let in_idx = n * (p.in_c * p.in_h * p.in_w)
                           + (g * group_in + ic) * (p.in_h * p.in_w)
                           + ih * p.in_w + iw;
                // weight: [in_c, out_c/groups, kh, kw]
                let w_idx = (g * group_in + ic) * (group_out * p.kh * p.kw)
                          + oc_local * (p.kh * p.kw)
                          + kh * p.kw + kw;
                sum += input[in_idx] * weight[w_idx];
            }
        }
    }
    out[idx] = sum;
}
";
303
// --- Conv2d gradient shaders ---

/// Uniform block for `SHADER_CONV2D_GRAD_WEIGHT`, bound at `@group(0) @binding(0)`.
///
/// Same sixteen-`u32` layout as `Conv2dParams`; the `#[repr(C)]` field order must match the
/// WGSL `struct P`. All sizes describe the *forward* convolution: `in_*` is the forward
/// input, and `out_*` is the forward output, i.e. the shape of `grad_out`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Conv2dGradParams {
    batch: u32,
    in_c: u32,
    out_c: u32,
    in_h: u32,
    in_w: u32,
    out_h: u32,
    out_w: u32,
    kh: u32,
    kw: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
    dilation_h: u32,
    dilation_w: u32,
    groups: u32,
}
326
/// Uniform block for `SHADER_CONV2D_GRAD_BIAS`, bound at `@group(0) @binding(0)`.
///
/// The `#[repr(C)]` field order must match the WGSL `struct P { batch, out_c, out_h, out_w }`.
/// Together the fields give the shape of `grad_out`: [batch, out_c, out_h, out_w].
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Conv2dGradBiasParams {
    batch: u32,
    out_c: u32,
    out_h: u32,
    out_w: u32,
}
335
// grad_weight[oc, ic_local, kh, kw] = sum_{n,oh,ow} grad_out[n,oc,oh,ow] * input[n, g*group_in+ic_local, oh*sh+kh*dh-ph, ow*sw+kw*dw-pw]
//
// Gradient of conv2d w.r.t. its weight. There is one invocation per weight element, with
// g = oc / (out_c / groups). Each invocation loops serially over the whole batch and every
// output position (batch * out_h * out_w iterations). Taps that land in the padding are
// skipped with explicit `>= pad` checks before subtracting. The flat index is rebuilt from
// gid.x / gid.y using the same 65535 * 256 convention as the other 1-D kernels in this file.
const SHADER_CONV2D_GRAD_WEIGHT: &str = "
struct P {
    batch: u32, in_c: u32, out_c: u32, in_h: u32,
    in_w: u32, out_h: u32, out_w: u32, kh: u32,
    kw: u32, stride_h: u32, stride_w: u32, pad_h: u32,
    pad_w: u32, dilation_h: u32, dilation_w: u32, groups: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> grad_out: array<f32>;
@group(0) @binding(3) var<storage, read_write> grad_w: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let group_in = p.in_c / p.groups;
    let group_out = p.out_c / p.groups;
    let total = p.out_c * group_in * p.kh * p.kw;
    if idx >= total { return; }

    let kw_i = idx % p.kw;
    let kh_i = (idx / p.kw) % p.kh;
    let ic = (idx / (p.kw * p.kh)) % group_in;
    let oc = idx / (p.kw * p.kh * group_in);

    let g = oc / group_out;

    var sum: f32 = 0.0;
    for (var n: u32 = 0u; n < p.batch; n++) {
        for (var oh: u32 = 0u; oh < p.out_h; oh++) {
            for (var ow: u32 = 0u; ow < p.out_w; ow++) {
                let ih = oh * p.stride_h + kh_i * p.dilation_h;
                let iw = ow * p.stride_w + kw_i * p.dilation_w;
                if ih >= p.pad_h && iw >= p.pad_w {
                    let ih2 = ih - p.pad_h;
                    let iw2 = iw - p.pad_w;
                    if ih2 < p.in_h && iw2 < p.in_w {
                        let in_idx = n * (p.in_c * p.in_h * p.in_w)
                                   + (g * group_in + ic) * (p.in_h * p.in_w)
                                   + ih2 * p.in_w + iw2;
                        let go_idx = n * (p.out_c * p.out_h * p.out_w)
                                   + oc * (p.out_h * p.out_w)
                                   + oh * p.out_w + ow;
                        sum += grad_out[go_idx] * input[in_idx];
                    }
                }
            }
        }
    }
    grad_w[idx] = sum;
}
";
388
// grad_bias[oc] = sum_{n,oh,ow} grad_out[n,oc,oh,ow]
//
// Gradient of conv2d w.r.t. its bias: one invocation per output channel. Each invocation
// sums batch * out_h * out_w values serially, with no parallel reduction across invocations.
const SHADER_CONV2D_GRAD_BIAS: &str = "
struct P { batch: u32, out_c: u32, out_h: u32, out_w: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read_write> grad_bias: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let oc = gid.x + gid.y * 65535u * 256u;
    if oc >= p.out_c { return; }
    var sum: f32 = 0.0;
    for (var n: u32 = 0u; n < p.batch; n++) {
        for (var oh: u32 = 0u; oh < p.out_h; oh++) {
            for (var ow: u32 = 0u; ow < p.out_w; ow++) {
                let idx = n * (p.out_c * p.out_h * p.out_w)
                        + oc * (p.out_h * p.out_w)
                        + oh * p.out_w + ow;
                sum += grad_out[idx];
            }
        }
    }
    grad_bias[oc] = sum;
}
";
413
414impl GpuDevice {
415    /// Matrix multiply: A(m,k) x B(k,n) = C(m,n). Row-major layout.
416    pub fn matmul(&self, a: &GpuBuffer, b: &GpuBuffer, m: u32, n: u32, k: u32) -> Result<GpuBuffer> {
417        ensure!(a.len == (m * k) as usize, "matmul: A has {} elems, expected {}", a.len, m * k);
418        ensure!(b.len == (k * n) as usize, "matmul: B has {} elems, expected {}", b.len, k * n);
419        let out = self.alloc((m * n) as usize);
420        let dims = MatmulDims { m, n, k, _pad: 0 };
421        self.dispatch_shader(SHADER_MATMUL, Some("matmul"), &dims, &[a, b], &out, (m.div_ceil(16), n.div_ceil(16), 1));
422        Ok(out)
423    }
424
425    /// Batched matmul: A[batch,m,k] x B[batch,k,n] = C[batch,m,n].
426    pub fn batch_matmul(&self, a: &GpuBuffer, b: &GpuBuffer, batch: u32, m: u32, n: u32, k: u32) -> Result<GpuBuffer> {
427        ensure!(a.len == (batch * m * k) as usize);
428        ensure!(b.len == (batch * k * n) as usize);
429        let out = self.alloc((batch * m * n) as usize);
430        let dims = BatchMatmulDims { batch, m, n, k };
431        self.dispatch_shader(SHADER_BATCH_MATMUL, Some("batch_matmul"), &dims, &[a, b], &out, (m.div_ceil(16), n.div_ceil(16), batch));
432        Ok(out)
433    }
434
435    /// Conv2d: input[N,C_in,H,W] * weight[C_out,C_in/groups,kH,kW] + bias[C_out].
436    /// NCHW layout. Returns output[N,C_out,out_H,out_W].
437    pub fn conv2d(
438        &self,
439        input: &GpuBuffer,
440        weight: &GpuBuffer,
441        bias: Option<&GpuBuffer>,
442        batch: u32, in_c: u32, in_h: u32, in_w: u32,
443        out_c: u32, kh: u32, kw: u32,
444        stride: (u32, u32), padding: (u32, u32),
445        dilation: (u32, u32), groups: u32,
446    ) -> Result<GpuBuffer> {
447        let out_h = (in_h + 2 * padding.0 - dilation.0 * (kh - 1) - 1) / stride.0 + 1;
448        let out_w = (in_w + 2 * padding.1 - dilation.1 * (kw - 1) - 1) / stride.1 + 1;
449        let total = batch * out_c * out_h * out_w;
450
451        ensure!(input.len == (batch * in_c * in_h * in_w) as usize);
452        ensure!(weight.len == (out_c * (in_c / groups) * kh * kw) as usize);
453
454        let zero_bias;
455        let bias_buf = match bias {
456            Some(b) => {
457                ensure!(b.len == out_c as usize);
458                b
459            }
460            None => {
461                zero_bias = self.upload(&vec![0.0f32; out_c as usize]);
462                &zero_bias
463            }
464        };
465
466        let out = self.alloc(total as usize);
467        let params = Conv2dParams {
468            batch, in_c, out_c, in_h, in_w, out_h, out_w,
469            kh, kw, stride_h: stride.0, stride_w: stride.1,
470            pad_h: padding.0, pad_w: padding.1,
471            dilation_h: dilation.0, dilation_w: dilation.1, groups,
472        };
473
474        self.dispatch_shader(
475            SHADER_CONV2D, Some("conv2d"),
476            &params, &[input, weight, bias_buf], &out,
477            super::dispatch_1d(total),
478        );
479        Ok(out)
480    }
481
482    /// Transposed conv2d (deconvolution): input[N,C_in,H,W] -> output[N,C_out,out_H,out_W].
483    /// Weight layout: [C_in, C_out/groups, kH, kW].
484    pub fn conv_transpose2d(
485        &self,
486        input: &GpuBuffer,
487        weight: &GpuBuffer,
488        bias: Option<&GpuBuffer>,
489        batch: u32, in_c: u32, in_h: u32, in_w: u32,
490        out_c: u32, kh: u32, kw: u32,
491        stride: (u32, u32), padding: (u32, u32),
492        output_padding: (u32, u32),
493        dilation: (u32, u32), groups: u32,
494    ) -> Result<GpuBuffer> {
495        let out_h = (in_h - 1) * stride.0 - 2 * padding.0 + dilation.0 * (kh - 1) + output_padding.0 + 1;
496        let out_w = (in_w - 1) * stride.1 - 2 * padding.1 + dilation.1 * (kw - 1) + output_padding.1 + 1;
497        let total = batch * out_c * out_h * out_w;
498
499        ensure!(input.len == (batch * in_c * in_h * in_w) as usize);
500        ensure!(weight.len == (in_c * (out_c / groups) * kh * kw) as usize);
501
502        let zero_bias;
503        let bias_buf = match bias {
504            Some(b) => {
505                ensure!(b.len == out_c as usize);
506                b
507            }
508            None => {
509                zero_bias = self.upload(&vec![0.0f32; out_c as usize]);
510                &zero_bias
511            }
512        };
513
514        let out = self.alloc(total as usize);
515        let params = ConvTranspose2dParams {
516            batch, in_c, out_c, in_h, in_w, out_h, out_w,
517            kh, kw, stride_h: stride.0, stride_w: stride.1,
518            pad_h: padding.0, pad_w: padding.1,
519            dilation_h: dilation.0, dilation_w: dilation.1, groups,
520        };
521
522        self.dispatch_shader(
523            SHADER_CONV_TRANSPOSE2D, Some("conv_transpose2d"),
524            &params, &[input, weight, bias_buf], &out,
525            super::dispatch_1d(total),
526        );
527        Ok(out)
528    }
529    pub(crate) fn conv2d_grad_weight(
530        &self, input: &GpuBuffer, grad_out: &GpuBuffer,
531        batch: u32, in_c: u32, in_h: u32, in_w: u32,
532        out_c: u32, out_h: u32, out_w: u32, kh: u32, kw: u32,
533        stride_h: u32, stride_w: u32, pad_h: u32, pad_w: u32,
534        dil_h: u32, dil_w: u32, groups: u32,
535    ) -> Result<GpuBuffer> {
536        let group_in = in_c / groups;
537        let total = out_c * group_in * kh * kw;
538        let out = self.alloc(total as usize);
539        let params = Conv2dGradParams {
540            batch, in_c, out_c, in_h, in_w, out_h, out_w,
541            kh, kw, stride_h, stride_w, pad_h, pad_w,
542            dilation_h: dil_h, dilation_w: dil_w, groups,
543        };
544        self.dispatch_shader(
545            SHADER_CONV2D_GRAD_WEIGHT, Some("conv2d_grad_weight"),
546            &params, &[input, grad_out], &out,
547            super::dispatch_1d(total),
548        );
549        Ok(out)
550    }
551
552    pub(crate) fn conv2d_grad_bias(
553        &self, grad_out: &GpuBuffer,
554        batch: u32, out_c: u32, out_h: u32, out_w: u32,
555    ) -> Result<GpuBuffer> {
556        let out = self.alloc(out_c as usize);
557        let params = Conv2dGradBiasParams { batch, out_c, out_h, out_w };
558        self.dispatch_shader(
559            SHADER_CONV2D_GRAD_BIAS, Some("conv2d_grad_bias"),
560            &params, &[grad_out], &out,
561            super::dispatch_1d(out_c),
562        );
563        Ok(out)
564    }
565}
566
567#[cfg(test)]
568mod tests {
569    use super::*;
570    use crate::ops::assert_approx;
571
572    fn dev() -> &'static GpuDevice { &crate::ops::TEST_DEV }
573
574    // CPU reference matmul for cross-validation
575    fn cpu_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
576        let mut out = vec![0.0f32; m * n];
577        for row in 0..m {
578            for col in 0..n {
579                let mut sum = 0.0;
580                for i in 0..k { sum += a[row * k + i] * b[i * n + col]; }
581                out[row * n + col] = sum;
582            }
583        }
584        out
585    }
586
587    // CPU reference conv2d for cross-validation
588    fn cpu_conv2d(
589        input: &[f32], weight: &[f32], bias: &[f32],
590        batch: usize, in_c: usize, in_h: usize, in_w: usize,
591        out_c: usize, kh: usize, kw: usize,
592        stride: (usize, usize), padding: (usize, usize), groups: usize,
593    ) -> Vec<f32> {
594        let out_h = (in_h + 2 * padding.0 - kh) / stride.0 + 1;
595        let out_w = (in_w + 2 * padding.1 - kw) / stride.1 + 1;
596        let group_in = in_c / groups;
597        let group_out = out_c / groups;
598        let mut out = vec![0.0f32; batch * out_c * out_h * out_w];
599        for n in 0..batch {
600            for oc in 0..out_c {
601                let g = oc / group_out;
602                for oh in 0..out_h {
603                    for ow in 0..out_w {
604                        let mut sum = bias[oc];
605                        for ic in 0..group_in {
606                            for kr in 0..kh {
607                                for kc in 0..kw {
608                                    let ih = oh * stride.0 + kr;
609                                    let iw = ow * stride.1 + kc;
610                                    let ih = ih as isize - padding.0 as isize;
611                                    let iw = iw as isize - padding.1 as isize;
612                                    if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
613                                        let in_idx = n * in_c * in_h * in_w
614                                            + (g * group_in + ic) * in_h * in_w
615                                            + ih as usize * in_w + iw as usize;
616                                        let w_idx = oc * group_in * kh * kw + ic * kh * kw + kr * kw + kc;
617                                        sum += input[in_idx] * weight[w_idx];
618                                    }
619                                }
620                            }
621                        }
622                        out[n * out_c * out_h * out_w + oc * out_h * out_w + oh * out_w + ow] = sum;
623                    }
624                }
625            }
626        }
627        out
628    }
629
630    // --- Matmul tests ---
631
632    #[test]
633    fn test_matmul_2x2() {
634        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
635        let b = dev().upload(&[5.0, 6.0, 7.0, 8.0]);
636        let result = dev().read(&dev().matmul(&a, &b, 2, 2, 2).unwrap()).unwrap();
637        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);
638    }
639
640    #[test]
641    fn test_matmul_nonsquare_vs_cpu() {
642        // 3x4 @ 4x2 = 3x2
643        let a: Vec<f32> = (1..=12).map(|x| x as f32).collect();
644        let b: Vec<f32> = (1..=8).map(|x| x as f32 * 0.1).collect();
645        let expected = cpu_matmul(&a, &b, 3, 2, 4);
646        let result = dev().read(&dev().matmul(&dev().upload(&a), &dev().upload(&b), 3, 2, 4).unwrap()).unwrap();
647        assert_approx(&result, &expected, 1e-4);
648    }
649
650    #[test]
651    fn test_matmul_1x1() {
652        let result = dev().read(&dev().matmul(&dev().upload(&[3.0]), &dev().upload(&[7.0]), 1, 1, 1).unwrap()).unwrap();
653        assert_eq!(result, vec![21.0]);
654    }
655
656    #[test]
657    fn test_matmul_17x13_vs_cpu() {
658        // Odd dims that don't align to 16x16 workgroup
659        let m = 17; let n = 13; let k = 11;
660        let a: Vec<f32> = (0..m*k).map(|i| (i as f32 * 0.01) - 0.5).collect();
661        let b: Vec<f32> = (0..k*n).map(|i| (i as f32 * 0.01) - 0.3).collect();
662        let expected = cpu_matmul(&a, &b, m, n, k);
663        let result = dev().read(&dev().matmul(
664            &dev().upload(&a), &dev().upload(&b), m as u32, n as u32, k as u32
665        ).unwrap()).unwrap();
666        assert_approx(&result, &expected, 1e-3);
667    }
668
669    // --- Batch matmul tests ---
670
671    #[test]
672    fn test_batch_matmul() {
673        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
674        let b = dev().upload(&[1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0]);
675        let result = dev().read(&dev().batch_matmul(&a, &b, 2, 2, 2, 2).unwrap()).unwrap();
676        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 10.0, 12.0, 14.0, 16.0]);
677    }
678
679    #[test]
680    fn test_batch_matmul_nonsquare() {
681        // batch=1, 2x3 @ 3x1 = 2x1
682        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
683        let b = dev().upload(&[1.0, 1.0, 1.0]);
684        let result = dev().read(&dev().batch_matmul(&a, &b, 1, 2, 1, 3).unwrap()).unwrap();
685        assert_eq!(result, vec![6.0, 15.0]);
686    }
687
688    // --- Conv2d tests ---
689
690    #[test]
691    fn test_conv2d_3x3_vs_cpu() {
692        // Verify GPU conv2d matches CPU reference
693        let input: Vec<f32> = (1..=16).map(|x| x as f32).collect();
694        let weight = vec![1.0, 0.0, -1.0, 1.0, 0.0, -1.0, 1.0, 0.0, -1.0];
695        let bias = vec![0.0];
696        let expected = cpu_conv2d(&input, &weight, &bias, 1, 1, 4, 4, 1, 3, 3, (1,1), (0,0), 1);
697        let result = dev().read(&dev().conv2d(
698            &dev().upload(&input), &dev().upload(&weight), Some(&dev().upload(&bias)),
699            1, 1, 4, 4, 1, 3, 3, (1,1), (0,0), (1,1), 1
700        ).unwrap()).unwrap();
701        assert_approx(&result, &expected, 1e-5);
702    }
703
704    #[test]
705    fn test_conv2d_1x1_kernel() {
706        // 1x1 conv = per-pixel channel mixing
707        // 2 in channels, 3 out channels, 2x2 spatial
708        let input = dev().upload(&[1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]);
709        // weight[out_c, in_c, 1, 1]: 3 output channels, 2 input channels
710        let weight = dev().upload(&[1.0, 0.5, 0.0, 1.0, -1.0, 2.0]);
711        let bias = dev().upload(&[0.0, 0.0, 0.0]);
712        let result = dev().read(&dev().conv2d(&input, &weight, Some(&bias),
713            1, 2, 2, 2, 3, 1, 1, (1,1), (0,0), (1,1), 1).unwrap()).unwrap();
714        // out_c=0: 1.0*in0 + 0.5*in1
715        assert_approx(&result[0..4], &[6.0, 12.0, 18.0, 24.0], 1e-5);
716        // out_c=1: 0.0*in0 + 1.0*in1
717        assert_approx(&result[4..8], &[10.0, 20.0, 30.0, 40.0], 1e-5);
718    }
719
720    #[test]
721    fn test_conv2d_padding_vs_cpu() {
722        let input: Vec<f32> = (1..=9).map(|x| x as f32).collect();
723        let weight = vec![1.0; 9];
724        let bias = vec![0.0];
725        let expected = cpu_conv2d(&input, &weight, &bias, 1, 1, 3, 3, 1, 3, 3, (1,1), (1,1), 1);
726        let result = dev().read(&dev().conv2d(
727            &dev().upload(&input), &dev().upload(&weight), None,
728            1, 1, 3, 3, 1, 3, 3, (1,1), (1,1), (1,1), 1
729        ).unwrap()).unwrap();
730        assert_approx(&result, &expected, 1e-5);
731    }
732
733    #[test]
734    fn test_conv2d_stride2() {
735        let input: Vec<f32> = (1..=16).map(|x| x as f32).collect();
736        let result = dev().read(&dev().conv2d(
737            &dev().upload(&input), &dev().upload(&[1.0]), None,
738            1, 1, 4, 4, 1, 1, 1, (2,2), (0,0), (1,1), 1
739        ).unwrap()).unwrap();
740        assert_eq!(result, vec![1.0, 3.0, 9.0, 11.0]);
741    }
742
743    #[test]
744    fn test_conv2d_5x5_kernel_vs_cpu() {
745        // 1x1x8x8 input, 1x1x5x5 kernel, padding=2 -> 8x8 output
746        let input: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
747        let weight: Vec<f32> = (0..25).map(|i| if i == 12 { 1.0 } else { 0.0 }).collect(); // center=1, rest=0 -> identity
748        let bias = vec![0.0];
749        let expected = cpu_conv2d(&input, &weight, &bias, 1, 1, 8, 8, 1, 5, 5, (1,1), (2,2), 1);
750        let result = dev().read(&dev().conv2d(
751            &dev().upload(&input), &dev().upload(&weight), None,
752            1, 1, 8, 8, 1, 5, 5, (1,1), (2,2), (1,1), 1
753        ).unwrap()).unwrap();
754        assert_approx(&result, &expected, 1e-4);
755    }
756
757    #[test]
758    fn test_conv2d_multichannel_vs_cpu() {
759        // batch=2, in_c=3, out_c=2, 4x4, 3x3 kernel, padding=1
760        let batch = 2; let in_c = 3; let out_c = 2; let h = 4; let w = 4;
761        let input: Vec<f32> = (0..batch*in_c*h*w).map(|i| (i as f32) * 0.01 - 0.5).collect();
762        let weight: Vec<f32> = (0..out_c*in_c*3*3).map(|i| (i as f32) * 0.02 - 0.3).collect();
763        let bias = vec![0.1, -0.2];
764        let expected = cpu_conv2d(&input, &weight, &bias, batch, in_c, h, w, out_c, 3, 3, (1,1), (1,1), 1);
765        let result = dev().read(&dev().conv2d(
766            &dev().upload(&input), &dev().upload(&weight), Some(&dev().upload(&bias)),
767            batch as u32, in_c as u32, h as u32, w as u32, out_c as u32, 3, 3, (1,1), (1,1), (1,1), 1
768        ).unwrap()).unwrap();
769        assert_approx(&result, &expected, 1e-3);
770    }
771
772    #[test]
773    fn test_conv2d_with_bias() {
774        // Verify bias is added correctly
775        let input = dev().upload(&[0.0; 4]); // 1x1x2x2 zeros
776        let weight = dev().upload(&[0.0]); // 1x1x1x1 zero kernel
777        let bias = dev().upload(&[42.0]);
778        let result = dev().read(&dev().conv2d(&input, &weight, Some(&bias),
779            1, 1, 2, 2, 1, 1, 1, (1,1), (0,0), (1,1), 1).unwrap()).unwrap();
780        assert_eq!(result, vec![42.0, 42.0, 42.0, 42.0]);
781    }
782
783    // --- Transpose conv2d tests ---
784
785    #[test]
786    fn test_conv_transpose2d_stride2() {
787        let input = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
788        let weight = dev().upload(&[1.0]);
789        let result = dev().read(&dev().conv_transpose2d(&input, &weight, None,
790            1, 1, 2, 2, 1, 1, 1, (2,2), (0,0), (0,0), (1,1), 1).unwrap()).unwrap();
791        assert_eq!(result.len(), 9);
792        assert_approx(&result, &[1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0], 1e-5);
793    }
794
795    #[test]
796    fn test_conv_transpose2d_3x3_kernel() {
797        let input = dev().upload(&[5.0]);
798        let weight = dev().upload(&[1.0; 9]);
799        let result = dev().read(&dev().conv_transpose2d(&input, &weight, None,
800            1, 1, 1, 1, 1, 3, 3, (1,1), (0,0), (0,0), (1,1), 1).unwrap()).unwrap();
801        assert_eq!(result.len(), 9);
802        assert_approx(&result, &[5.0; 9], 1e-5);
803    }
804
805    // --- Error path tests ---
806
807    #[test]
808    fn test_matmul_a_size_mismatch() {
809        let a = dev().upload(&[1.0, 2.0, 3.0]); // 3 elements
810        let b = dev().upload(&[1.0, 2.0, 3.0, 4.0]); // 4 elements
811        assert!(dev().matmul(&a, &b, 2, 2, 2).is_err()); // expects a=4 elements
812    }
813
814    #[test]
815    fn test_matmul_b_size_mismatch() {
816        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
817        let b = dev().upload(&[1.0, 2.0, 3.0]); // expects 4
818        assert!(dev().matmul(&a, &b, 2, 2, 2).is_err());
819    }
820
821    #[test]
822    fn test_conv2d_input_size_mismatch() {
823        let input = dev().upload(&[1.0; 10]); // wrong size for 1x1x4x4=16
824        let weight = dev().upload(&[1.0; 9]);
825        assert!(dev().conv2d(&input, &weight, None, 1, 1, 4, 4, 1, 3, 3, (1,1), (0,0), (1,1), 1).is_err());
826    }
827
828    #[test]
829    fn test_conv2d_weight_size_mismatch() {
830        let input = dev().upload(&[1.0; 16]);
831        let weight = dev().upload(&[1.0; 5]); // wrong size for 1*1*3*3=9
832        assert!(dev().conv2d(&input, &weight, None, 1, 1, 4, 4, 1, 3, 3, (1,1), (0,0), (1,1), 1).is_err());
833    }
834
835    #[test]
836    fn test_conv2d_bias_size_mismatch() {
837        let input = dev().upload(&[1.0; 16]);
838        let weight = dev().upload(&[1.0; 9]);
839        let bias = dev().upload(&[1.0, 2.0]); // expects 1 (out_c=1)
840        assert!(dev().conv2d(&input, &weight, Some(&bias), 1, 1, 4, 4, 1, 3, 3, (1,1), (0,0), (1,1), 1).is_err());
841    }
842
843    // --- CPU cross-validation for batch_matmul ---
844
845    #[test]
846    fn test_batch_matmul_vs_cpu() {
847        let batch = 3; let m = 4; let n = 3; let k = 5;
848        let a: Vec<f32> = (0..batch*m*k).map(|i| (i as f32) * 0.01 - 0.3).collect();
849        let b: Vec<f32> = (0..batch*k*n).map(|i| (i as f32) * 0.02 - 0.1).collect();
850        let mut expected = vec![0.0f32; batch * m * n];
851        for bat in 0..batch {
852            for row in 0..m {
853                for col in 0..n {
854                    let mut sum = 0.0;
855                    for i in 0..k {
856                        sum += a[bat*m*k + row*k + i] * b[bat*k*n + i*n + col];
857                    }
858                    expected[bat*m*n + row*n + col] = sum;
859                }
860            }
861        }
862        let result = dev().read(&dev().batch_matmul(
863            &dev().upload(&a), &dev().upload(&b), batch as u32, m as u32, n as u32, k as u32
864        ).unwrap()).unwrap();
865        assert_approx(&result, &expected, 1e-3);
866    }
867
868    // --- Conv2d grad weight numeric check ---
869
870    #[test]
871    fn test_conv2d_grad_weight_vs_numeric() {
872        // 1x1x4x4 input, 1x1x2x2 kernel, stride 1, pad 0
873        // out_h = out_w = 3
874        let input_data: Vec<f32> = (1..=16).map(|x| x as f32 * 0.1).collect();
875        let weight_data = vec![0.5f32, -0.5, 0.3, -0.3];
876        let eps = 1e-3f32;
877
878        // Compute analytical grad_weight
879        let input_buf = dev().upload(&input_data);
880        let weight_buf = dev().upload(&weight_data);
881        let out = dev().conv2d(&input_buf, &weight_buf, None, 1, 1, 4, 4, 1, 2, 2, (1,1), (0,0), (1,1), 1).unwrap();
882        // Use grad_out = all ones to make it a simple sum
883        let grad_out_data = vec![1.0f32; 9]; // 1x1x3x3
884        let grad_out_buf = dev().upload(&grad_out_data);
885        let gw = dev().conv2d_grad_weight(&input_buf, &grad_out_buf, 1, 1, 4, 4, 1, 3, 3, 2, 2, 1, 1, 0, 0, 1, 1, 1).unwrap();
886        let gw_data = dev().read(&gw).unwrap();
887
888        // Numeric gradient check
889        for i in 0..4 {
890            let mut wp = weight_data.clone();
891            let mut wm = weight_data.clone();
892            wp[i] += eps;
893            wm[i] -= eps;
894            let wp_buf = dev().upload(&wp);
895            let wm_buf = dev().upload(&wm);
896            let outp = dev().read(&dev().conv2d(&input_buf, &wp_buf, None, 1, 1, 4, 4, 1, 2, 2, (1,1), (0,0), (1,1), 1).unwrap()).unwrap();
897            let outm = dev().read(&dev().conv2d(&input_buf, &wm_buf, None, 1, 1, 4, 4, 1, 2, 2, (1,1), (0,0), (1,1), 1).unwrap()).unwrap();
898            // loss = sum(out) => numeric grad = (sum(outp)-sum(outm)) / (2*eps)
899            let numeric: f32 = (outp.iter().sum::<f32>() - outm.iter().sum::<f32>()) / (2.0 * eps);
900            assert!((gw_data[i] - numeric).abs() < 1e-2,
901                "grad_w[{i}]: analytical={}, numeric={}", gw_data[i], numeric);
902        }
903        let _ = out;
904    }
905
906    #[test]
907    fn test_conv2d_grad_bias_basic() {
908        // grad_out = [[1,2],[3,4]], batch=1, out_c=1, out_h=2, out_w=2
909        // grad_bias[0] should = 1+2+3+4 = 10
910        let grad_out_data = vec![1.0f32, 2.0, 3.0, 4.0];
911        let grad_out_buf = dev().upload(&grad_out_data);
912        let gb = dev().conv2d_grad_bias(&grad_out_buf, 1, 1, 2, 2).unwrap();
913        let gb_data = dev().read(&gb).unwrap();
914        assert_approx(&gb_data, &[10.0], 1e-5);
915
916        // 2 output channels, batch=1, out_h=2, out_w=2
917        // grad_out[0] = [1,2,3,4], grad_out[1] = [5,6,7,8]
918        let grad_out2 = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
919        let gb2 = dev().conv2d_grad_bias(&dev().upload(&grad_out2), 1, 2, 2, 2).unwrap();
920        let gb2_data = dev().read(&gb2).unwrap();
921        assert_approx(&gb2_data, &[10.0, 26.0], 1e-5);
922    }
923}