1use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
/// Uniform parameters uploaded for [`SHADER_MATMUL`].
///
/// The fields appear in the same order as the WGSL `Dims` struct declared in
/// that shader (`m`, `n`, `k`, `_pad`), and `#[repr(C)]` keeps that order in
/// memory. Do not reorder or retype fields without editing the shader too.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct MatmulDims {
    /// Number of rows of `A` and of the output.
    m: u32,
    /// Number of columns of `B` and of the output.
    n: u32,
    /// Shared inner dimension (columns of `A`, rows of `B`).
    k: u32,
    /// Unused filler; `matmul` always writes 0 here.
    /// NOTE(review): presumably rounds the struct up to 16 bytes — confirm.
    _pad: u32,
}
19
/// WGSL compute shader for a single row-major `(m x k) * (k x n)` product.
///
/// Bindings in group 0: 0 = `Dims` uniform, 1 = `a`, 2 = `b` (read-only
/// storage), 3 = `out` (read-write storage).
///
/// Each 16x16 workgroup walks the `k` dimension in steps of 16. Per step every
/// invocation copies one element of `a` and one of `b` into workgroup arrays
/// (writing 0.0 for out-of-range positions), waits on a barrier, and adds 16
/// products to its private sum. `gid.x` selects the output row and `gid.y` the
/// output column; invocations outside `m x n` never write to `out`.
const SHADER_MATMUL: &str = "
const TILE: u32 = 16u;
struct Dims { m: u32, n: u32, k: u32, _pad: u32, }
@group(0) @binding(0) var<uniform> dims: Dims;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;

var<workgroup> tile_a: array<f32, 256>; // TILE * TILE
var<workgroup> tile_b: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let row = gid.x;
    let col = gid.y;
    let lr = lid.x;
    let lc = lid.y;

    var sum: f32 = 0.0;
    let num_tiles = (dims.k + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t++) {
        // Load tile of A: row from global row, col from tile offset
        let a_col = t * TILE + lc;
        if row < dims.m && a_col < dims.k {
            tile_a[lr * TILE + lc] = a[row * dims.k + a_col];
        } else {
            tile_a[lr * TILE + lc] = 0.0;
        }

        // Load tile of B: row from tile offset, col from global col
        let b_row = t * TILE + lr;
        if b_row < dims.k && col < dims.n {
            tile_b[lr * TILE + lc] = b[b_row * dims.n + col];
        } else {
            tile_b[lr * TILE + lc] = 0.0;
        }

        workgroupBarrier();

        // Accumulate dot product from shared memory
        for (var i: u32 = 0u; i < TILE; i++) {
            sum += tile_a[lr * TILE + i] * tile_b[i * TILE + lc];
        }

        workgroupBarrier();
    }

    if row < dims.m && col < dims.n {
        out[row * dims.n + col] = sum;
    }
}
";
79
/// Uniform parameters uploaded for [`SHADER_BATCH_MATMUL`].
///
/// The fields appear in the same order as the WGSL `Dims` struct declared in
/// that shader (`batch`, `m`, `n`, `k`); `#[repr(C)]` keeps that order in
/// memory. Do not reorder or retype fields without editing the shader too.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct BatchMatmulDims {
    /// Number of independent matrix products.
    batch: u32,
    /// Rows of each `A` matrix and of each output matrix.
    m: u32,
    /// Columns of each `B` matrix and of each output matrix.
    n: u32,
    /// Shared inner dimension (columns of `A`, rows of `B`).
    k: u32,
}
90
/// WGSL compute shader for `batch` independent row-major `(m x k) * (k x n)`
/// products stored back to back.
///
/// Bindings in group 0: 0 = `Dims` uniform, 1 = `a`, 2 = `b` (read-only
/// storage), 3 = `out` (read-write storage).
///
/// The tiling scheme is the same as in the single-matrix kernel: each 16x16
/// workgroup stages 16x16 tiles of `a` and `b` in workgroup arrays, with
/// `gid.x` selecting the output row and `gid.y` the output column. The batch
/// index is the workgroup's z coordinate, used to offset into all three
/// buffers.
///
/// The batch index is read from `workgroup_id` rather than
/// `global_invocation_id`: WGSL requires `workgroupBarrier()` to be reached in
/// uniform control flow, and an early `return` that depends on
/// `global_invocation_id` makes everything after it non-uniform. `workgroup_id`
/// is uniform across a workgroup, so the guard below keeps the barriers valid.
const SHADER_BATCH_MATMUL: &str = "
const TILE: u32 = 16u;
struct Dims { batch: u32, m: u32, n: u32, k: u32, }
@group(0) @binding(0) var<uniform> dims: Dims;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;

var<workgroup> tile_a: array<f32, 256>;
var<workgroup> tile_b: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let row = gid.x;
    let col = gid.y;
    // Workgroup depth is 1, so the workgroup z index is also the batch index.
    // Using the workgroup id keeps this early exit in uniform control flow,
    // which the barriers below require.
    let bat = wid.z;
    if bat >= dims.batch { return; }
    let lr = lid.x;
    let lc = lid.y;
    let a_off = bat * dims.m * dims.k;
    let b_off = bat * dims.k * dims.n;
    let o_off = bat * dims.m * dims.n;

    var sum: f32 = 0.0;
    let num_tiles = (dims.k + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t++) {
        let a_col = t * TILE + lc;
        if row < dims.m && a_col < dims.k {
            tile_a[lr * TILE + lc] = a[a_off + row * dims.k + a_col];
        } else {
            tile_a[lr * TILE + lc] = 0.0;
        }

        let b_row = t * TILE + lr;
        if b_row < dims.k && col < dims.n {
            tile_b[lr * TILE + lc] = b[b_off + b_row * dims.n + col];
        } else {
            tile_b[lr * TILE + lc] = 0.0;
        }

        workgroupBarrier();

        for (var i: u32 = 0u; i < TILE; i++) {
            sum += tile_a[lr * TILE + i] * tile_b[i * TILE + lc];
        }

        workgroupBarrier();
    }

    if row < dims.m && col < dims.n {
        out[o_off + row * dims.n + col] = sum;
    }
}
";
149
/// Uniform parameters uploaded for [`SHADER_CONV2D`].
///
/// The sixteen `u32` fields appear in the same order as the WGSL `P` struct
/// declared in that shader; `#[repr(C)]` keeps that order in memory. Do not
/// reorder or retype fields without editing the shader too.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Conv2dParams {
    /// Number of images in the batch.
    batch: u32,
    /// Input channel count (across all groups).
    in_c: u32,
    /// Output channel count (across all groups).
    out_c: u32,
    /// Input height.
    in_h: u32,
    /// Input width.
    in_w: u32,
    /// Output height, computed by `conv2d` before dispatch.
    out_h: u32,
    /// Output width, computed by `conv2d` before dispatch.
    out_w: u32,
    /// Kernel height.
    kh: u32,
    /// Kernel width.
    kw: u32,
    /// Vertical stride.
    stride_h: u32,
    /// Horizontal stride.
    stride_w: u32,
    /// Vertical padding applied on each side of the input.
    pad_h: u32,
    /// Horizontal padding applied on each side of the input.
    pad_w: u32,
    /// Vertical spacing between kernel taps.
    dilation_h: u32,
    /// Horizontal spacing between kernel taps.
    dilation_w: u32,
    /// Number of channel groups; the shader divides `in_c` and `out_c` by it.
    groups: u32,
}
172
/// WGSL compute shader for the forward 2-D convolution.
///
/// Bindings in group 0: 0 = `P` uniform, 1 = `input`, 2 = `weight`,
/// 3 = `bias` (read-only storage), 4 = `out` (read-write storage).
///
/// One invocation produces one output element. The flat index is decoded as
/// `(n, oc, oh, ow)` with `ow` fastest; `input` is indexed as
/// `[n, channel, ih, iw]` and `weight` as `[oc, in_c / groups, kh, kw]`.
/// The sum starts from `bias[oc]`.
///
/// Taps that land in the padding are skipped without a signed comparison:
/// subtracting the pad in `u32` wraps "negative" coordinates to very large
/// values, which then fail the `< in_h` / `< in_w` test.
///
/// NOTE(review): `idx = gid.x + gid.y * 65535 * 256` presumably mirrors the
/// grid shape produced by `dispatch_1d` (defined outside this file) — confirm.
const SHADER_CONV2D: &str = "
struct P {
    batch: u32, in_c: u32, out_c: u32, in_h: u32,
    in_w: u32, out_h: u32, out_w: u32, kh: u32,
    kw: u32, stride_h: u32, stride_w: u32, pad_h: u32,
    pad_w: u32, dilation_h: u32, dilation_w: u32, groups: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<f32>;
@group(0) @binding(3) var<storage, read> bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = p.batch * p.out_c * p.out_h * p.out_w;
    if idx >= total { return; }

    let ow = idx % p.out_w;
    let oh = (idx / p.out_w) % p.out_h;
    let oc = (idx / (p.out_w * p.out_h)) % p.out_c;
    let n = idx / (p.out_w * p.out_h * p.out_c);

    let group_in = p.in_c / p.groups;
    let group_out = p.out_c / p.groups;
    let g = oc / group_out;

    var sum: f32 = bias[oc];
    for (var ic: u32 = 0u; ic < group_in; ic++) {
        for (var kh: u32 = 0u; kh < p.kh; kh++) {
            for (var kw: u32 = 0u; kw < p.kw; kw++) {
                let ih = oh * p.stride_h + kh * p.dilation_h - p.pad_h;
                let iw = ow * p.stride_w + kw * p.dilation_w - p.pad_w;
                if ih < p.in_h && iw < p.in_w {
                    let in_idx = n * (p.in_c * p.in_h * p.in_w)
                        + (g * group_in + ic) * (p.in_h * p.in_w)
                        + ih * p.in_w + iw;
                    let w_idx = oc * (group_in * p.kh * p.kw)
                        + ic * (p.kh * p.kw)
                        + kh * p.kw + kw;
                    sum += input[in_idx] * weight[w_idx];
                }
            }
        }
    }
    out[idx] = sum;
}
";
221
/// Uniform parameters uploaded for [`SHADER_CONV_TRANSPOSE2D`].
///
/// The sixteen `u32` fields appear in the same order as the WGSL `P` struct
/// declared in that shader; `#[repr(C)]` keeps that order in memory. Do not
/// reorder or retype fields without editing the shader too.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct ConvTranspose2dParams {
    /// Number of images in the batch.
    batch: u32,
    /// Input channel count (across all groups).
    in_c: u32,
    /// Output channel count (across all groups).
    out_c: u32,
    /// Input height.
    in_h: u32,
    /// Input width.
    in_w: u32,
    /// Output height, computed by `conv_transpose2d` before dispatch.
    out_h: u32,
    /// Output width, computed by `conv_transpose2d` before dispatch.
    out_w: u32,
    /// Kernel height.
    kh: u32,
    /// Kernel width.
    kw: u32,
    /// Vertical stride; the shader divides by it.
    stride_h: u32,
    /// Horizontal stride; the shader divides by it.
    stride_w: u32,
    /// Vertical padding.
    pad_h: u32,
    /// Horizontal padding.
    pad_w: u32,
    /// Vertical spacing between kernel taps.
    dilation_h: u32,
    /// Horizontal spacing between kernel taps.
    dilation_w: u32,
    /// Number of channel groups; the shader divides `in_c` and `out_c` by it.
    groups: u32,
}
244
/// WGSL compute shader for the forward 2-D transposed convolution.
///
/// Bindings in group 0: 0 = `P` uniform, 1 = `input`, 2 = `weight`,
/// 3 = `bias` (read-only storage), 4 = `out` (read-write storage).
///
/// One invocation produces one output element, decoded from the flat index as
/// `(n, oc, oh, ow)` with `ow` fastest. For every kernel tap it derives the
/// input pixel that would contribute to this output pixel and adds the product
/// only when the stride divides the offset exactly and the pixel lies inside
/// the input. `weight` is indexed as `[in_c, out_c / groups, kh, kw]`, and the
/// sum starts from `bias[oc]`.
///
/// Offsets that would be negative wrap around in `u32`; the shader relies on
/// the resulting quotient failing the `< in_h` / `< in_w` test.
///
/// NOTE(review): `idx = gid.x + gid.y * 65535 * 256` presumably mirrors the
/// grid shape produced by `dispatch_1d` (defined outside this file) — confirm.
const SHADER_CONV_TRANSPOSE2D: &str = "
struct P {
    batch: u32, in_c: u32, out_c: u32, in_h: u32,
    in_w: u32, out_h: u32, out_w: u32, kh: u32,
    kw: u32, stride_h: u32, stride_w: u32, pad_h: u32,
    pad_w: u32, dilation_h: u32, dilation_w: u32, groups: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<f32>;
@group(0) @binding(3) var<storage, read> bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = p.batch * p.out_c * p.out_h * p.out_w;
    if idx >= total { return; }

    let ow = idx % p.out_w;
    let oh = (idx / p.out_w) % p.out_h;
    let oc = (idx / (p.out_w * p.out_h)) % p.out_c;
    let n = idx / (p.out_w * p.out_h * p.out_c);

    let group_in = p.in_c / p.groups;
    let group_out = p.out_c / p.groups;
    let g = oc / group_out;
    let oc_local = oc % group_out;

    var sum: f32 = bias[oc];
    for (var ic: u32 = 0u; ic < group_in; ic++) {
        for (var kh: u32 = 0u; kh < p.kh; kh++) {
            for (var kw: u32 = 0u; kw < p.kw; kw++) {
                // Transposed conv: output pixel (oh,ow) is affected by input pixel (ih,iw)
                // where oh = ih * stride - pad + kh * dilation
                // so ih = (oh + pad - kh * dilation) / stride, must be exact division
                let oh_off = oh + p.pad_h - kh * p.dilation_h;
                let ow_off = ow + p.pad_w - kw * p.dilation_w;
                // Check exact divisibility and bounds (unsigned wraparound handles negatives)
                if oh_off % p.stride_h == 0u && ow_off % p.stride_w == 0u {
                    let ih = oh_off / p.stride_h;
                    let iw = ow_off / p.stride_w;
                    if ih < p.in_h && iw < p.in_w {
                        let in_idx = n * (p.in_c * p.in_h * p.in_w)
                            + (g * group_in + ic) * (p.in_h * p.in_w)
                            + ih * p.in_w + iw;
                        // weight: [in_c, out_c/groups, kh, kw]
                        let w_idx = (g * group_in + ic) * (group_out * p.kh * p.kw)
                            + oc_local * (p.kh * p.kw)
                            + kh * p.kw + kw;
                        sum += input[in_idx] * weight[w_idx];
                    }
                }
            }
        }
    }
    out[idx] = sum;
}
";
303
/// Uniform parameters uploaded for [`SHADER_CONV2D_GRAD_WEIGHT`].
///
/// The sixteen `u32` fields appear in the same order as the WGSL `P` struct
/// declared in that shader; `#[repr(C)]` keeps that order in memory. Do not
/// reorder or retype fields without editing the shader too.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Conv2dGradParams {
    /// Number of images in the batch.
    batch: u32,
    /// Input channel count (across all groups).
    in_c: u32,
    /// Output channel count (across all groups).
    out_c: u32,
    /// Height of the forward-pass input.
    in_h: u32,
    /// Width of the forward-pass input.
    in_w: u32,
    /// Height of the output gradient.
    out_h: u32,
    /// Width of the output gradient.
    out_w: u32,
    /// Kernel height.
    kh: u32,
    /// Kernel width.
    kw: u32,
    /// Vertical stride.
    stride_h: u32,
    /// Horizontal stride.
    stride_w: u32,
    /// Vertical padding.
    pad_h: u32,
    /// Horizontal padding.
    pad_w: u32,
    /// Vertical spacing between kernel taps.
    dilation_h: u32,
    /// Horizontal spacing between kernel taps.
    dilation_w: u32,
    /// Number of channel groups; the shader divides `in_c` and `out_c` by it.
    groups: u32,
}
326
/// Uniform parameters uploaded for [`SHADER_CONV2D_GRAD_BIAS`].
///
/// The fields appear in the same order as the WGSL `P` struct declared in that
/// shader (`batch`, `out_c`, `out_h`, `out_w`); `#[repr(C)]` keeps that order
/// in memory.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Conv2dGradBiasParams {
    /// Number of images in the batch.
    batch: u32,
    /// Output channel count; also the length of the produced bias gradient.
    out_c: u32,
    /// Height of the output gradient.
    out_h: u32,
    /// Width of the output gradient.
    out_w: u32,
}
335
/// WGSL compute shader for the gradient of a 2-D convolution with respect to
/// its weight.
///
/// Bindings in group 0: 0 = `P` uniform, 1 = `input`, 2 = `grad_out`
/// (read-only storage), 3 = `grad_w` (read-write storage).
///
/// One invocation produces one weight-gradient element, decoded from the flat
/// index as `(oc, ic, kh, kw)` with `kw` fastest and `ic` ranging over
/// `in_c / groups`. It sums `grad_out[n, oc, oh, ow] * input[n, channel, ih, iw]`
/// over the whole batch and every output position, skipping positions whose
/// input coordinate falls in the padding. Unlike the forward kernel, the
/// padding check is explicit (`ih >= pad_h`) before the subtraction.
///
/// NOTE(review): `idx = gid.x + gid.y * 65535 * 256` presumably mirrors the
/// grid shape produced by `dispatch_1d` (defined outside this file) — confirm.
const SHADER_CONV2D_GRAD_WEIGHT: &str = "
struct P {
    batch: u32, in_c: u32, out_c: u32, in_h: u32,
    in_w: u32, out_h: u32, out_w: u32, kh: u32,
    kw: u32, stride_h: u32, stride_w: u32, pad_h: u32,
    pad_w: u32, dilation_h: u32, dilation_w: u32, groups: u32,
}
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> grad_out: array<f32>;
@group(0) @binding(3) var<storage, read_write> grad_w: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let group_in = p.in_c / p.groups;
    let group_out = p.out_c / p.groups;
    let total = p.out_c * group_in * p.kh * p.kw;
    if idx >= total { return; }

    let kw_i = idx % p.kw;
    let kh_i = (idx / p.kw) % p.kh;
    let ic = (idx / (p.kw * p.kh)) % group_in;
    let oc = idx / (p.kw * p.kh * group_in);

    let g = oc / group_out;

    var sum: f32 = 0.0;
    for (var n: u32 = 0u; n < p.batch; n++) {
        for (var oh: u32 = 0u; oh < p.out_h; oh++) {
            for (var ow: u32 = 0u; ow < p.out_w; ow++) {
                let ih = oh * p.stride_h + kh_i * p.dilation_h;
                let iw = ow * p.stride_w + kw_i * p.dilation_w;
                if ih >= p.pad_h && iw >= p.pad_w {
                    let ih2 = ih - p.pad_h;
                    let iw2 = iw - p.pad_w;
                    if ih2 < p.in_h && iw2 < p.in_w {
                        let in_idx = n * (p.in_c * p.in_h * p.in_w)
                            + (g * group_in + ic) * (p.in_h * p.in_w)
                            + ih2 * p.in_w + iw2;
                        let go_idx = n * (p.out_c * p.out_h * p.out_w)
                            + oc * (p.out_h * p.out_w)
                            + oh * p.out_w + ow;
                        sum += grad_out[go_idx] * input[in_idx];
                    }
                }
            }
        }
    }
    grad_w[idx] = sum;
}
";
388
/// WGSL compute shader for the gradient of a 2-D convolution with respect to
/// its bias.
///
/// Bindings in group 0: 0 = `P` uniform, 1 = `grad_out` (read-only storage),
/// 2 = `grad_bias` (read-write storage).
///
/// One invocation handles one output channel and serially sums
/// `grad_out[n, oc, oh, ow]` over the batch and both spatial axes.
///
/// NOTE(review): `oc = gid.x + gid.y * 65535 * 256` presumably mirrors the
/// grid shape produced by `dispatch_1d` (defined outside this file) — confirm.
const SHADER_CONV2D_GRAD_BIAS: &str = "
struct P { batch: u32, out_c: u32, out_h: u32, out_w: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> grad_out: array<f32>;
@group(0) @binding(2) var<storage, read_write> grad_bias: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let oc = gid.x + gid.y * 65535u * 256u;
    if oc >= p.out_c { return; }
    var sum: f32 = 0.0;
    for (var n: u32 = 0u; n < p.batch; n++) {
        for (var oh: u32 = 0u; oh < p.out_h; oh++) {
            for (var ow: u32 = 0u; ow < p.out_w; ow++) {
                let idx = n * (p.out_c * p.out_h * p.out_w)
                    + oc * (p.out_h * p.out_w)
                    + oh * p.out_w + ow;
                sum += grad_out[idx];
            }
        }
    }
    grad_bias[oc] = sum;
}
";
413
414impl GpuDevice {
415 pub fn matmul(&self, a: &GpuBuffer, b: &GpuBuffer, m: u32, n: u32, k: u32) -> Result<GpuBuffer> {
417 ensure!(a.len == (m * k) as usize, "matmul: A has {} elems, expected {}", a.len, m * k);
418 ensure!(b.len == (k * n) as usize, "matmul: B has {} elems, expected {}", b.len, k * n);
419 let out = self.alloc((m * n) as usize);
420 let dims = MatmulDims { m, n, k, _pad: 0 };
421 self.dispatch_shader(SHADER_MATMUL, Some("matmul"), &dims, &[a, b], &out, (m.div_ceil(16), n.div_ceil(16), 1));
422 Ok(out)
423 }
424
425 pub fn batch_matmul(&self, a: &GpuBuffer, b: &GpuBuffer, batch: u32, m: u32, n: u32, k: u32) -> Result<GpuBuffer> {
427 ensure!(a.len == (batch * m * k) as usize);
428 ensure!(b.len == (batch * k * n) as usize);
429 let out = self.alloc((batch * m * n) as usize);
430 let dims = BatchMatmulDims { batch, m, n, k };
431 self.dispatch_shader(SHADER_BATCH_MATMUL, Some("batch_matmul"), &dims, &[a, b], &out, (m.div_ceil(16), n.div_ceil(16), batch));
432 Ok(out)
433 }
434
435 pub fn conv2d(
438 &self,
439 input: &GpuBuffer,
440 weight: &GpuBuffer,
441 bias: Option<&GpuBuffer>,
442 batch: u32, in_c: u32, in_h: u32, in_w: u32,
443 out_c: u32, kh: u32, kw: u32,
444 stride: (u32, u32), padding: (u32, u32),
445 dilation: (u32, u32), groups: u32,
446 ) -> Result<GpuBuffer> {
447 let out_h = (in_h + 2 * padding.0 - dilation.0 * (kh - 1) - 1) / stride.0 + 1;
448 let out_w = (in_w + 2 * padding.1 - dilation.1 * (kw - 1) - 1) / stride.1 + 1;
449 let total = batch * out_c * out_h * out_w;
450
451 ensure!(input.len == (batch * in_c * in_h * in_w) as usize);
452 ensure!(weight.len == (out_c * (in_c / groups) * kh * kw) as usize);
453
454 let zero_bias;
455 let bias_buf = match bias {
456 Some(b) => {
457 ensure!(b.len == out_c as usize);
458 b
459 }
460 None => {
461 zero_bias = self.upload(&vec![0.0f32; out_c as usize]);
462 &zero_bias
463 }
464 };
465
466 let out = self.alloc(total as usize);
467 let params = Conv2dParams {
468 batch, in_c, out_c, in_h, in_w, out_h, out_w,
469 kh, kw, stride_h: stride.0, stride_w: stride.1,
470 pad_h: padding.0, pad_w: padding.1,
471 dilation_h: dilation.0, dilation_w: dilation.1, groups,
472 };
473
474 self.dispatch_shader(
475 SHADER_CONV2D, Some("conv2d"),
476 ¶ms, &[input, weight, bias_buf], &out,
477 super::dispatch_1d(total),
478 );
479 Ok(out)
480 }
481
482 pub fn conv_transpose2d(
485 &self,
486 input: &GpuBuffer,
487 weight: &GpuBuffer,
488 bias: Option<&GpuBuffer>,
489 batch: u32, in_c: u32, in_h: u32, in_w: u32,
490 out_c: u32, kh: u32, kw: u32,
491 stride: (u32, u32), padding: (u32, u32),
492 output_padding: (u32, u32),
493 dilation: (u32, u32), groups: u32,
494 ) -> Result<GpuBuffer> {
495 let out_h = (in_h - 1) * stride.0 - 2 * padding.0 + dilation.0 * (kh - 1) + output_padding.0 + 1;
496 let out_w = (in_w - 1) * stride.1 - 2 * padding.1 + dilation.1 * (kw - 1) + output_padding.1 + 1;
497 let total = batch * out_c * out_h * out_w;
498
499 ensure!(input.len == (batch * in_c * in_h * in_w) as usize);
500 ensure!(weight.len == (in_c * (out_c / groups) * kh * kw) as usize);
501
502 let zero_bias;
503 let bias_buf = match bias {
504 Some(b) => {
505 ensure!(b.len == out_c as usize);
506 b
507 }
508 None => {
509 zero_bias = self.upload(&vec![0.0f32; out_c as usize]);
510 &zero_bias
511 }
512 };
513
514 let out = self.alloc(total as usize);
515 let params = ConvTranspose2dParams {
516 batch, in_c, out_c, in_h, in_w, out_h, out_w,
517 kh, kw, stride_h: stride.0, stride_w: stride.1,
518 pad_h: padding.0, pad_w: padding.1,
519 dilation_h: dilation.0, dilation_w: dilation.1, groups,
520 };
521
522 self.dispatch_shader(
523 SHADER_CONV_TRANSPOSE2D, Some("conv_transpose2d"),
524 ¶ms, &[input, weight, bias_buf], &out,
525 super::dispatch_1d(total),
526 );
527 Ok(out)
528 }
529 pub(crate) fn conv2d_grad_weight(
530 &self, input: &GpuBuffer, grad_out: &GpuBuffer,
531 batch: u32, in_c: u32, in_h: u32, in_w: u32,
532 out_c: u32, out_h: u32, out_w: u32, kh: u32, kw: u32,
533 stride_h: u32, stride_w: u32, pad_h: u32, pad_w: u32,
534 dil_h: u32, dil_w: u32, groups: u32,
535 ) -> Result<GpuBuffer> {
536 let group_in = in_c / groups;
537 let total = out_c * group_in * kh * kw;
538 let out = self.alloc(total as usize);
539 let params = Conv2dGradParams {
540 batch, in_c, out_c, in_h, in_w, out_h, out_w,
541 kh, kw, stride_h, stride_w, pad_h, pad_w,
542 dilation_h: dil_h, dilation_w: dil_w, groups,
543 };
544 self.dispatch_shader(
545 SHADER_CONV2D_GRAD_WEIGHT, Some("conv2d_grad_weight"),
546 ¶ms, &[input, grad_out], &out,
547 super::dispatch_1d(total),
548 );
549 Ok(out)
550 }
551
552 pub(crate) fn conv2d_grad_bias(
553 &self, grad_out: &GpuBuffer,
554 batch: u32, out_c: u32, out_h: u32, out_w: u32,
555 ) -> Result<GpuBuffer> {
556 let out = self.alloc(out_c as usize);
557 let params = Conv2dGradBiasParams { batch, out_c, out_h, out_w };
558 self.dispatch_shader(
559 SHADER_CONV2D_GRAD_BIAS, Some("conv2d_grad_bias"),
560 ¶ms, &[grad_out], &out,
561 super::dispatch_1d(out_c),
562 );
563 Ok(out)
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Shared GPU device used by every test in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// CPU reference for a row-major `(m x k) * (k x n)` product.
    fn cpu_matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; m * n];
        for row in 0..m {
            for col in 0..n {
                let mut sum = 0.0;
                for i in 0..k {
                    sum += a[row * k + i] * b[i * n + col];
                }
                out[row * n + col] = sum;
            }
        }
        out
    }

    /// CPU reference for a grouped, strided, padded 2-D convolution.
    ///
    /// Dilation is not modelled; use it only to check dilation `(1, 1)` runs.
    fn cpu_conv2d(
        input: &[f32], weight: &[f32], bias: &[f32],
        batch: usize, in_c: usize, in_h: usize, in_w: usize,
        out_c: usize, kh: usize, kw: usize,
        stride: (usize, usize), padding: (usize, usize), groups: usize,
    ) -> Vec<f32> {
        let out_h = (in_h + 2 * padding.0 - kh) / stride.0 + 1;
        let out_w = (in_w + 2 * padding.1 - kw) / stride.1 + 1;
        let group_in = in_c / groups;
        let group_out = out_c / groups;
        let mut out = vec![0.0f32; batch * out_c * out_h * out_w];
        for n in 0..batch {
            for oc in 0..out_c {
                let g = oc / group_out;
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut sum = bias[oc];
                        for ic in 0..group_in {
                            for kr in 0..kh {
                                for kc in 0..kw {
                                    // Signed coordinates so taps in the padding can be rejected.
                                    let ih = (oh * stride.0 + kr) as isize - padding.0 as isize;
                                    let iw = (ow * stride.1 + kc) as isize - padding.1 as isize;
                                    if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
                                        let in_idx = n * in_c * in_h * in_w
                                            + (g * group_in + ic) * in_h * in_w
                                            + ih as usize * in_w + iw as usize;
                                        let w_idx = oc * group_in * kh * kw + ic * kh * kw + kr * kw + kc;
                                        sum += input[in_idx] * weight[w_idx];
                                    }
                                }
                            }
                        }
                        out[n * out_c * out_h * out_w + oc * out_h * out_w + oh * out_w + ow] = sum;
                    }
                }
            }
        }
        out
    }

    #[test]
    fn test_matmul_2x2() {
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        let b = dev().upload(&[5.0, 6.0, 7.0, 8.0]);
        let result = dev().read(&dev().matmul(&a, &b, 2, 2, 2).unwrap()).unwrap();
        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_matmul_nonsquare_vs_cpu() {
        let a: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let b: Vec<f32> = (1..=8).map(|x| x as f32 * 0.1).collect();
        let expected = cpu_matmul(&a, &b, 3, 2, 4);
        let result = dev()
            .read(&dev().matmul(&dev().upload(&a), &dev().upload(&b), 3, 2, 4).unwrap())
            .unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    #[test]
    fn test_matmul_1x1() {
        let result = dev()
            .read(&dev().matmul(&dev().upload(&[3.0]), &dev().upload(&[7.0]), 1, 1, 1).unwrap())
            .unwrap();
        assert_eq!(result, vec![21.0]);
    }

    /// Sizes that are not multiples of the 16x16 tile exercise the edge guards.
    #[test]
    fn test_matmul_17x13_vs_cpu() {
        let m = 17;
        let n = 13;
        let k = 11;
        let a: Vec<f32> = (0..m * k).map(|i| (i as f32 * 0.01) - 0.5).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i as f32 * 0.01) - 0.3).collect();
        let expected = cpu_matmul(&a, &b, m, n, k);
        let result = dev().read(&dev().matmul(
            &dev().upload(&a), &dev().upload(&b), m as u32, n as u32, k as u32,
        ).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_batch_matmul() {
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let b = dev().upload(&[1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0]);
        let result = dev().read(&dev().batch_matmul(&a, &b, 2, 2, 2, 2).unwrap()).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0, 10.0, 12.0, 14.0, 16.0]);
    }

    #[test]
    fn test_batch_matmul_nonsquare() {
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = dev().upload(&[1.0, 1.0, 1.0]);
        let result = dev().read(&dev().batch_matmul(&a, &b, 1, 2, 1, 3).unwrap()).unwrap();
        assert_eq!(result, vec![6.0, 15.0]);
    }

    #[test]
    fn test_conv2d_3x3_vs_cpu() {
        let input: Vec<f32> = (1..=16).map(|x| x as f32).collect();
        let weight = vec![1.0, 0.0, -1.0, 1.0, 0.0, -1.0, 1.0, 0.0, -1.0];
        let bias = vec![0.0];
        let expected = cpu_conv2d(&input, &weight, &bias, 1, 1, 4, 4, 1, 3, 3, (1, 1), (0, 0), 1);
        let result = dev().read(&dev().conv2d(
            &dev().upload(&input), &dev().upload(&weight), Some(&dev().upload(&bias)),
            1, 1, 4, 4, 1, 3, 3, (1, 1), (0, 0), (1, 1), 1,
        ).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-5);
    }

    #[test]
    fn test_conv2d_1x1_kernel() {
        // Two 2x2 input channels mixed into three output channels.
        let input = dev().upload(&[1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]);
        let weight = dev().upload(&[1.0, 0.5, 0.0, 1.0, -1.0, 2.0]);
        let bias = dev().upload(&[0.0, 0.0, 0.0]);
        let result = dev().read(&dev().conv2d(
            &input, &weight, Some(&bias),
            1, 2, 2, 2, 3, 1, 1, (1, 1), (0, 0), (1, 1), 1,
        ).unwrap()).unwrap();
        assert_eq!(result.len(), 12);
        // oc0 = 1.0 * ch0 + 0.5 * ch1
        assert_approx(&result[0..4], &[6.0, 12.0, 18.0, 24.0], 1e-5);
        // oc1 = ch1
        assert_approx(&result[4..8], &[10.0, 20.0, 30.0, 40.0], 1e-5);
        // oc2 = -1.0 * ch0 + 2.0 * ch1
        assert_approx(&result[8..12], &[19.0, 38.0, 57.0, 76.0], 1e-5);
    }

    #[test]
    fn test_conv2d_padding_vs_cpu() {
        let input: Vec<f32> = (1..=9).map(|x| x as f32).collect();
        let weight = vec![1.0; 9];
        let bias = vec![0.0];
        let expected = cpu_conv2d(&input, &weight, &bias, 1, 1, 3, 3, 1, 3, 3, (1, 1), (1, 1), 1);
        let result = dev().read(&dev().conv2d(
            &dev().upload(&input), &dev().upload(&weight), None,
            1, 1, 3, 3, 1, 3, 3, (1, 1), (1, 1), (1, 1), 1,
        ).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-5);
    }

    #[test]
    fn test_conv2d_stride2() {
        let input: Vec<f32> = (1..=16).map(|x| x as f32).collect();
        let result = dev().read(&dev().conv2d(
            &dev().upload(&input), &dev().upload(&[1.0]), None,
            1, 1, 4, 4, 1, 1, 1, (2, 2), (0, 0), (1, 1), 1,
        ).unwrap()).unwrap();
        assert_eq!(result, vec![1.0, 3.0, 9.0, 11.0]);
    }

    #[test]
    fn test_conv2d_5x5_kernel_vs_cpu() {
        let input: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
        // Only the centre tap is set, so the padded convolution is an identity.
        let weight: Vec<f32> = (0..25).map(|i| if i == 12 { 1.0 } else { 0.0 }).collect();
        let bias = vec![0.0];
        let expected = cpu_conv2d(&input, &weight, &bias, 1, 1, 8, 8, 1, 5, 5, (1, 1), (2, 2), 1);
        let result = dev().read(&dev().conv2d(
            &dev().upload(&input), &dev().upload(&weight), None,
            1, 1, 8, 8, 1, 5, 5, (1, 1), (2, 2), (1, 1), 1,
        ).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-4);
    }

    #[test]
    fn test_conv2d_multichannel_vs_cpu() {
        let batch = 2;
        let in_c = 3;
        let out_c = 2;
        let h = 4;
        let w = 4;
        let input: Vec<f32> = (0..batch * in_c * h * w).map(|i| (i as f32) * 0.01 - 0.5).collect();
        let weight: Vec<f32> = (0..out_c * in_c * 3 * 3).map(|i| (i as f32) * 0.02 - 0.3).collect();
        let bias = vec![0.1, -0.2];
        let expected = cpu_conv2d(&input, &weight, &bias, batch, in_c, h, w, out_c, 3, 3, (1, 1), (1, 1), 1);
        let result = dev().read(&dev().conv2d(
            &dev().upload(&input), &dev().upload(&weight), Some(&dev().upload(&bias)),
            batch as u32, in_c as u32, h as u32, w as u32, out_c as u32, 3, 3, (1, 1), (1, 1), (1, 1), 1,
        ).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    /// Grouped convolution: each output channel only sees its group's inputs.
    #[test]
    fn test_conv2d_groups_vs_cpu() {
        let (batch, in_c, out_c, h, w, groups) = (1usize, 4, 4, 5, 5, 2);
        let input: Vec<f32> = (0..batch * in_c * h * w).map(|i| (i as f32) * 0.01 - 0.5).collect();
        let weight: Vec<f32> = (0..out_c * (in_c / groups) * 3 * 3).map(|i| (i as f32) * 0.02 - 0.3).collect();
        let bias = vec![0.1, -0.2, 0.3, -0.4];
        let expected = cpu_conv2d(&input, &weight, &bias, batch, in_c, h, w, out_c, 3, 3, (1, 1), (1, 1), groups);
        let result = dev().read(&dev().conv2d(
            &dev().upload(&input), &dev().upload(&weight), Some(&dev().upload(&bias)),
            batch as u32, in_c as u32, h as u32, w as u32, out_c as u32, 3, 3,
            (1, 1), (1, 1), (1, 1), groups as u32,
        ).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_conv2d_with_bias() {
        let input = dev().upload(&[0.0; 4]);
        let weight = dev().upload(&[0.0]);
        let bias = dev().upload(&[42.0]);
        let result = dev().read(&dev().conv2d(
            &input, &weight, Some(&bias),
            1, 1, 2, 2, 1, 1, 1, (1, 1), (0, 0), (1, 1), 1,
        ).unwrap()).unwrap();
        assert_eq!(result, vec![42.0, 42.0, 42.0, 42.0]);
    }

    #[test]
    fn test_conv_transpose2d_stride2() {
        let input = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        let weight = dev().upload(&[1.0]);
        let result = dev().read(&dev().conv_transpose2d(
            &input, &weight, None,
            1, 1, 2, 2, 1, 1, 1, (2, 2), (0, 0), (0, 0), (1, 1), 1,
        ).unwrap()).unwrap();
        assert_eq!(result.len(), 9);
        assert_approx(&result, &[1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0], 1e-5);
    }

    #[test]
    fn test_conv_transpose2d_3x3_kernel() {
        let input = dev().upload(&[5.0]);
        let weight = dev().upload(&[1.0; 9]);
        let result = dev().read(&dev().conv_transpose2d(
            &input, &weight, None,
            1, 1, 1, 1, 1, 3, 3, (1, 1), (0, 0), (0, 0), (1, 1), 1,
        ).unwrap()).unwrap();
        assert_eq!(result.len(), 9);
        assert_approx(&result, &[5.0; 9], 1e-5);
    }

    #[test]
    fn test_matmul_a_size_mismatch() {
        let a = dev().upload(&[1.0, 2.0, 3.0]);
        let b = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        assert!(dev().matmul(&a, &b, 2, 2, 2).is_err());
    }

    #[test]
    fn test_matmul_b_size_mismatch() {
        let a = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        let b = dev().upload(&[1.0, 2.0, 3.0]);
        assert!(dev().matmul(&a, &b, 2, 2, 2).is_err());
    }

    #[test]
    fn test_conv2d_input_size_mismatch() {
        let input = dev().upload(&[1.0; 10]);
        let weight = dev().upload(&[1.0; 9]);
        assert!(dev()
            .conv2d(&input, &weight, None, 1, 1, 4, 4, 1, 3, 3, (1, 1), (0, 0), (1, 1), 1)
            .is_err());
    }

    #[test]
    fn test_conv2d_weight_size_mismatch() {
        let input = dev().upload(&[1.0; 16]);
        let weight = dev().upload(&[1.0; 5]);
        assert!(dev()
            .conv2d(&input, &weight, None, 1, 1, 4, 4, 1, 3, 3, (1, 1), (0, 0), (1, 1), 1)
            .is_err());
    }

    #[test]
    fn test_conv2d_bias_size_mismatch() {
        let input = dev().upload(&[1.0; 16]);
        let weight = dev().upload(&[1.0; 9]);
        let bias = dev().upload(&[1.0, 2.0]);
        assert!(dev()
            .conv2d(&input, &weight, Some(&bias), 1, 1, 4, 4, 1, 3, 3, (1, 1), (0, 0), (1, 1), 1)
            .is_err());
    }

    #[test]
    fn test_batch_matmul_vs_cpu() {
        let batch = 3;
        let m = 4;
        let n = 3;
        let k = 5;
        let a: Vec<f32> = (0..batch * m * k).map(|i| (i as f32) * 0.01 - 0.3).collect();
        let b: Vec<f32> = (0..batch * k * n).map(|i| (i as f32) * 0.02 - 0.1).collect();
        let mut expected = vec![0.0f32; batch * m * n];
        for bat in 0..batch {
            for row in 0..m {
                for col in 0..n {
                    let mut sum = 0.0;
                    for i in 0..k {
                        sum += a[bat * m * k + row * k + i] * b[bat * k * n + i * n + col];
                    }
                    expected[bat * m * n + row * n + col] = sum;
                }
            }
        }
        let result = dev().read(&dev().batch_matmul(
            &dev().upload(&a), &dev().upload(&b), batch as u32, m as u32, n as u32, k as u32,
        ).unwrap()).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    /// Compares the analytical weight gradient with a central finite
    /// difference of `sum(conv2d(input, w))`, i.e. an all-ones upstream gradient.
    #[test]
    fn test_conv2d_grad_weight_vs_numeric() {
        let input_data: Vec<f32> = (1..=16).map(|x| x as f32 * 0.1).collect();
        let weight_data = vec![0.5f32, -0.5, 0.3, -0.3];
        let eps = 1e-3f32;

        let input_buf = dev().upload(&input_data);
        let grad_out_data = vec![1.0f32; 9];
        let grad_out_buf = dev().upload(&grad_out_data);
        let gw = dev()
            .conv2d_grad_weight(&input_buf, &grad_out_buf, 1, 1, 4, 4, 1, 3, 3, 2, 2, 1, 1, 0, 0, 1, 1, 1)
            .unwrap();
        let gw_data = dev().read(&gw).unwrap();

        for i in 0..4 {
            let mut wp = weight_data.clone();
            let mut wm = weight_data.clone();
            wp[i] += eps;
            wm[i] -= eps;
            let wp_buf = dev().upload(&wp);
            let wm_buf = dev().upload(&wm);
            let outp = dev().read(&dev().conv2d(
                &input_buf, &wp_buf, None, 1, 1, 4, 4, 1, 2, 2, (1, 1), (0, 0), (1, 1), 1,
            ).unwrap()).unwrap();
            let outm = dev().read(&dev().conv2d(
                &input_buf, &wm_buf, None, 1, 1, 4, 4, 1, 2, 2, (1, 1), (0, 0), (1, 1), 1,
            ).unwrap()).unwrap();
            let numeric: f32 = (outp.iter().sum::<f32>() - outm.iter().sum::<f32>()) / (2.0 * eps);
            assert!((gw_data[i] - numeric).abs() < 1e-2,
                "grad_w[{i}]: analytical={}, numeric={}", gw_data[i], numeric);
        }
    }

    #[test]
    fn test_conv2d_grad_bias_basic() {
        // Single channel: the gradient is the sum of all four entries.
        let grad_out_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let grad_out_buf = dev().upload(&grad_out_data);
        let gb = dev().conv2d_grad_bias(&grad_out_buf, 1, 1, 2, 2).unwrap();
        let gb_data = dev().read(&gb).unwrap();
        assert_approx(&gb_data, &[10.0], 1e-5);

        // Two channels: each channel is summed independently.
        let grad_out2 = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let gb2 = dev().conv2d_grad_bias(&dev().upload(&grad_out2), 1, 2, 2, 2).unwrap();
        let gb2_data = dev().read(&gb2).unwrap();
        assert_approx(&gb2_data, &[10.0, 26.0], 1e-5);
    }
}