/// Generates the WGSL source of a tiled single-precision GEMM kernel computing
/// `C = alpha * op(A) * op(B) + beta * C`.
///
/// `tile_size` is embedded as both workgroup dimensions and as the edge length
/// of the two `var<workgroup>` staging tiles, so each workgroup runs
/// `tile_size * tile_size` invocations. The value is not validated here; WGSL
/// requires it to be non-zero and within the device's workgroup limits.
///
/// Bindings (group 0):
/// * 0 — `a`, read-only `array<f32>`
/// * 1 — `b`, read-only `array<f32>`
/// * 2 — `c`, read-write `array<f32>`, indexed as `row * n + col`
/// * 3 — uniform `GemmParams` (eight 4-byte fields, `_pad` last)
///
/// A non-zero `trans_a` / `trans_b` switches `load_a` / `load_b` to the
/// transposed indexing. Out-of-range loads return `0.0`, so partial tiles at
/// the matrix edges contribute nothing to the accumulator.
///
/// When `beta` is exactly zero the kernel does not read `C`, following the
/// BLAS convention that `C` need not hold valid data on input in that case.
pub fn gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct GemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
    trans_a: u32,
    trans_b: u32,
    _pad: u32,
}}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: GemmParams;

var<workgroup> tile_a: array<array<f32, {ts}>, {ts}>;
var<workgroup> tile_b: array<array<f32, {ts}>, {ts}>;

// op(A)[r, i] — logical m×k left operand.
fn load_a(r: u32, i: u32) -> f32 {{
    if (r >= params.m || i >= params.k) {{ return 0.0; }}
    if (params.trans_a == 0u) {{
        return a[r * params.k + i];
    }}
    return a[i * params.m + r];
}}

// op(B)[i, col] — logical k×n right operand.
fn load_b(i: u32, col: u32) -> f32 {{
    if (i >= params.k || col >= params.n) {{ return 0.0; }}
    if (params.trans_b == 0u) {{
        return b[i * params.n + col];
    }}
    return b[col * params.k + i];
}}

@compute @workgroup_size({ts}, {ts})
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let row = gid.y;
    let col = gid.x;
    let lr = lid.y;
    let lc = lid.x;

    var acc: f32 = 0.0;
    let num_tiles = (params.k + {ts}u - 1u) / {ts}u;
    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {{
        let a_col = t * {ts}u + lc;
        let b_row = t * {ts}u + lr;
        tile_a[lr][lc] = load_a(row, a_col);
        tile_b[lr][lc] = load_b(b_row, col);
        workgroupBarrier();

        for (var e: u32 = 0u; e < {ts}u; e = e + 1u) {{
            acc += tile_a[lr][e] * tile_b[e][lc];
        }}
        workgroupBarrier();
    }}

    // Out-of-range invocations still take part in the tile loop above so that
    // every barrier is reached by the whole workgroup; only the store is skipped.
    if (row >= params.m || col >= params.n) {{ return; }}
    let idx = row * params.n + col;
    // BLAS convention: with beta == 0 the previous contents of C are ignored,
    // so a stale NaN or Inf in C cannot leak into the result.
    if (params.beta == 0.0) {{
        c[idx] = params.alpha * acc;
    }} else {{
        c[idx] = params.alpha * acc + params.beta * c[idx];
    }}
}}
"#,
        ts = tile_size
    )
}
102
/// Generates the WGSL source of a tiled, batched single-precision GEMM kernel
/// computing `C_b = alpha * op(A_b) * op(B_b) + beta * C_b` for each batch `b`.
///
/// `tile_size` is embedded as the x/y workgroup dimensions and as the edge
/// length of the two `var<workgroup>` staging tiles. The workgroup is one
/// invocation deep along z, and the z workgroup index selects the batch.
/// The value is not validated here; WGSL requires it to be non-zero.
///
/// Bindings (group 0): 0 = `a`, 1 = `b` (read-only `array<f32>`),
/// 2 = `c` (read-write `array<f32>`), 3 = uniform `BatchedGemmParams`.
/// `stride_a` / `stride_b` / `stride_c` are element offsets between consecutive
/// batches; a non-zero `trans_a` / `trans_b` selects transposed indexing.
///
/// The batch index is taken from `workgroup_id.z` rather than
/// `global_invocation_id.z`. Both hold the same value here, but only
/// `workgroup_id` is uniform under WGSL's uniformity analysis, which the
/// `workgroupBarrier()` calls after the batch-range early return require.
///
/// When `beta` is exactly zero the kernel does not read `C`, following the
/// BLAS convention that `C` need not hold valid data on input in that case.
pub fn batched_gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct BatchedGemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
    batch_count: u32,
    stride_a: u32,
    stride_b: u32,
    stride_c: u32,
    trans_a: u32,
    trans_b: u32,
}}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: BatchedGemmParams;

var<workgroup> tile_a: array<array<f32, {ts}>, {ts}>;
var<workgroup> tile_b: array<array<f32, {ts}>, {ts}>;

// op(A_b)[r, i] — logical m×k left operand for batch `a_offset`.
fn load_a(a_offset: u32, r: u32, i: u32) -> f32 {{
    if (r >= params.m || i >= params.k) {{ return 0.0; }}
    if (params.trans_a == 0u) {{
        return a[a_offset + r * params.k + i];
    }}
    return a[a_offset + i * params.m + r];
}}

// op(B_b)[i, col] — logical k×n right operand for batch `b_offset`.
fn load_b(b_offset: u32, i: u32, col: u32) -> f32 {{
    if (i >= params.k || col >= params.n) {{ return 0.0; }}
    if (params.trans_b == 0u) {{
        return b[b_offset + i * params.n + col];
    }}
    return b[b_offset + col * params.k + i];
}}

@compute @workgroup_size({ts}, {ts})
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {{
    let row = gid.y;
    let col = gid.x;
    // The workgroup is one invocation deep along z, so workgroup_id.z equals
    // global_invocation_id.z. Reading the uniform workgroup_id keeps the early
    // return below, and the barriers that follow it, in uniform control flow.
    let batch_index = wgid.z;
    let lr = lid.y;
    let lc = lid.x;
    if (batch_index >= params.batch_count) {{ return; }}

    let a_offset = batch_index * params.stride_a;
    let b_offset = batch_index * params.stride_b;
    let c_offset = batch_index * params.stride_c;

    var acc: f32 = 0.0;
    let num_tiles = (params.k + {ts}u - 1u) / {ts}u;
    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {{
        let a_col = t * {ts}u + lc;
        let b_row = t * {ts}u + lr;
        tile_a[lr][lc] = load_a(a_offset, row, a_col);
        tile_b[lr][lc] = load_b(b_offset, b_row, col);
        workgroupBarrier();

        for (var e: u32 = 0u; e < {ts}u; e = e + 1u) {{
            acc += tile_a[lr][e] * tile_b[e][lc];
        }}
        workgroupBarrier();
    }}

    // Out-of-range invocations still take part in the tile loop above so that
    // every barrier is reached by the whole workgroup; only the store is skipped.
    if (row >= params.m || col >= params.n) {{ return; }}
    let idx = c_offset + row * params.n + col;
    // BLAS convention: with beta == 0 the previous contents of C are ignored,
    // so a stale NaN or Inf in C cannot leak into the result.
    if (params.beta == 0.0) {{
        c[idx] = params.alpha * acc;
    }} else {{
        c[idx] = params.alpha * acc + params.beta * c[idx];
    }}
}}
"#,
        ts = tile_size
    )
}
199
/// Generates the WGSL source of a half-precision GEMM kernel computing
/// `C = alpha * A * B + beta * C` with `f16` storage and `f32` accumulation.
///
/// The shader begins with `enable f16;`. Unlike the `f32` kernel it uses no
/// workgroup tiling and has no transpose flags: each invocation walks the full
/// `k` dimension for one output element, indexing `a` as `row * k + i` and `b`
/// as `i * n + col`.
///
/// `tile_size` is embedded as both workgroup dimensions. The value is not
/// validated here; WGSL requires it to be non-zero.
///
/// Bindings (group 0): 0 = `a`, 1 = `b` (read-only `array<f16>`),
/// 2 = `c` (read-write `array<f16>`), 3 = uniform `GemmParams`
/// (`m`, `n`, `k`, `alpha`, `beta`).
///
/// When `beta` is exactly zero the kernel does not read `C`, following the
/// BLAS convention that `C` need not hold valid data on input in that case.
pub fn gemm_wgsl_f16(tile_size: u32) -> String {
    format!(
        r#"
enable f16;

struct GemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
}}

@group(0) @binding(0) var<storage, read> a: array<f16>;
@group(0) @binding(1) var<storage, read> b: array<f16>;
@group(0) @binding(2) var<storage, read_write> c: array<f16>;
@group(0) @binding(3) var<uniform> params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    // Accumulate in f32 to limit rounding error across the k dimension.
    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += f32(a[row * params.k + i]) * f32(b[i * params.n + col]);
    }}

    let idx = row * params.n + col;
    var result: f32 = params.alpha * acc;
    // BLAS convention: with beta == 0 the previous contents of C are ignored,
    // so a stale NaN or Inf in C cannot leak into the result.
    if (params.beta != 0.0) {{
        result = result + params.beta * f32(c[idx]);
    }}
    c[idx] = f16(result);
}}
"#,
        ts = tile_size
    )
}
245
/// Generates the WGSL source of a unary element-wise kernel that writes
/// `output[i] = f(input[i])`.
///
/// `op` selects `f` by name: `"relu"`, `"sigmoid"`, `"tanh"`, `"exp"`,
/// `"log"`, `"sqrt"`, `"abs"` or `"neg"`. Any other name yields the identity
/// (`output[i] = x`).
///
/// The shader uses a 1-D workgroup of 256 invocations and skips invocations
/// whose index is at or beyond `arrayLength(&input)`.
/// Bindings (group 0): 0 = `input` (read-only), 1 = `output` (read-write).
pub fn elementwise_wgsl(op: &str) -> String {
    // (operation name, WGSL expression over the loaded element `x`)
    const UNARY_OPS: [(&str, &str); 8] = [
        ("relu", "max(x, 0.0)"),
        ("sigmoid", "1.0 / (1.0 + exp(-x))"),
        ("tanh", "tanh(x)"),
        ("exp", "exp(x)"),
        ("log", "log(x)"),
        ("sqrt", "sqrt(x)"),
        ("abs", "abs(x)"),
        ("neg", "-x"),
    ];

    // Unrecognised names deliberately degrade to a plain copy.
    let op_expr = UNARY_OPS
        .iter()
        .find(|(name, _)| *name == op)
        .map_or("x", |(_, expr)| *expr);

    format!(
        r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&input)) {{ return; }}
    let x = input[i];
    output[i] = {op};
}}
"#,
        op = op_expr
    )
}
284
/// Generates the WGSL source of a binary element-wise kernel that writes
/// `output[i] = f(lhs[i], rhs[i])`.
///
/// `op` selects `f` by name: `"add"`, `"sub"`, `"mul"`, `"div"`, `"max"`,
/// `"min"` or `"pow"`. Any other name yields `output[i] = lhs[i]`.
///
/// The shader uses a 1-D workgroup of 256 invocations and skips invocations
/// whose index is at or beyond `arrayLength(&lhs)`; the lengths of `rhs` and
/// `output` are not checked by the shader.
/// Bindings (group 0): 0 = `lhs`, 1 = `rhs` (read-only), 2 = `output`
/// (read-write).
pub fn binary_wgsl(op: &str) -> String {
    // (operation name, WGSL expression over the loaded operands `a` and `b`)
    const BINARY_OPS: [(&str, &str); 7] = [
        ("add", "a + b"),
        ("sub", "a - b"),
        ("mul", "a * b"),
        ("div", "a / b"),
        ("max", "max(a, b)"),
        ("min", "min(a, b)"),
        ("pow", "pow(a, b)"),
    ];

    // Unrecognised names deliberately degrade to passing the left operand through.
    let op_expr = BINARY_OPS
        .iter()
        .find(|(name, _)| *name == op)
        .map_or("a", |(_, expr)| *expr);

    format!(
        r#"
@group(0) @binding(0) var<storage, read> lhs: array<f32>;
@group(0) @binding(1) var<storage, read> rhs: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&lhs)) {{ return; }}
    let a = lhs[i];
    let b = rhs[i];
    output[i] = {op};
}}
"#,
        op = op_expr
    )
}
324
/// Generates the WGSL source of the first pass of a flat reduction: each
/// 256-invocation workgroup reduces its slice of `input` to one value and
/// stores it in `partial_sums[workgroup_id.x]`.
///
/// `op` selects the combiner: `"max"`, `"min"`, or anything else for addition.
/// Note that `"mean"` therefore produces exactly the same shader as `"sum"`;
/// this shader performs no division.
///
/// Invocations whose global index is at or beyond `params.n` contribute the
/// operation's neutral element (`-1e38` for max, `1e38` for min, `0.0`
/// otherwise).
/// Bindings (group 0): 0 = `input` (read-only), 1 = `partial_sums`
/// (read-write), 2 = uniform `ReduceParams` with the element count `n`.
pub fn reduction_wgsl(op: &str) -> String {
    // Value loaded by out-of-range invocations.
    let neutral = match op {
        "max" => "f32(-1e38)",
        "min" => "f32(1e38)",
        _ => "f32(0.0)",
    };
    // Expression folding `val` into `acc` during the tree reduction.
    let combine = match op {
        "max" => "max(acc, val)",
        "min" => "min(acc, val)",
        _ => "acc + val",
    };

    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let global_idx = gid.x;

    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#
    )
}
399
/// Generates the WGSL source of a direct (non-tiled) 2-D convolution over
/// NCHW tensors with every dimension baked in as a literal.
///
/// Shapes, as indexed by the shader:
/// * input  — `[n, c_in, h_in, w_in]`
/// * filter — `[k_out, c_in, fh, fw]`
/// * output — `[n, k_out, oh, ow]`
///
/// `stride_h` / `stride_w` and `pad_h` / `pad_w` map an output coordinate and
/// filter tap to an input coordinate; taps that land outside the input are
/// skipped, which acts as zero padding. `oh` and `ow` are taken as given and
/// are not derived from the other arguments here.
///
/// Dispatch layout: 8×8 workgroups, `gid.x` is the output column and `gid.y`
/// is the flattened `(batch, output channel, output row)` index. Each
/// invocation overwrites exactly one output element.
/// Bindings (group 0): 0 = `input`, 1 = `filter` (read-only), 2 = `output`
/// (read-write).
// The flat argument list mirrors the tensor/filter/stride/padding dimensions
// one-to-one; bundling them into a struct would change the public signature.
#[allow(clippy::too_many_arguments)]
pub fn conv2d_wgsl(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> String {
    // Every placeholder below is captured directly from the parameter of the
    // same name.
    format!(
        r#"
// Conv2D NCHW — generated by oxicuda-webgpu
// input : [{n}, {c_in}, {h_in}, {w_in}]
// filter: [{k_out}, {c_in}, {fh}, {fw}]
// output: [{n}, {k_out}, {oh}, {ow}]

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> filter: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    // gid.x = output x (ox mapped across batches*k_out*oh)
    // We flatten (batch, k, oy) into gid.y and ox into gid.x
    let ox = gid.x;
    let linear_y = gid.y;

    let batch_k_oh = {n}u * {k_out}u * {oh}u;
    if (ox >= {ow}u || linear_y >= batch_k_oh) {{ return; }}

    let b = linear_y / ({k_out}u * {oh}u);
    let rem = linear_y % ({k_out}u * {oh}u);
    let kf = rem / {oh}u;
    let oy = rem % {oh}u;

    var acc: f32 = 0.0;
    for (var ci: u32 = 0u; ci < {c_in}u; ci = ci + 1u) {{
        for (var fy: u32 = 0u; fy < {fh}u; fy = fy + 1u) {{
            for (var fx: u32 = 0u; fx < {fw}u; fx = fx + 1u) {{
                let iy_raw = i32(oy * {stride_h}u + fy) - i32({pad_h}u);
                let ix_raw = i32(ox * {stride_w}u + fx) - i32({pad_w}u);
                if (iy_raw >= 0 && iy_raw < i32({h_in}u) && ix_raw >= 0 && ix_raw < i32({w_in}u)) {{
                    let iy = u32(iy_raw);
                    let ix = u32(ix_raw);
                    let in_idx = ((b * {c_in}u + ci) * {h_in}u + iy) * {w_in}u + ix;
                    let f_idx = ((kf * {c_in}u + ci) * {fh}u + fy) * {fw}u + fx;
                    acc += input[in_idx] * filter[f_idx];
                }}
            }}
        }}
    }}

    let o_idx = ((b * {k_out}u + kf) * {oh}u + oy) * {ow}u + ox;
    output[o_idx] = acc;
}}
"#
    )
}
494
/// Generates the WGSL source of a scaled dot-product attention kernel,
/// `O = softmax(scale * Q * K^T) * V`, with every dimension baked in as a
/// literal.
///
/// Layouts, as indexed by the shader:
/// * `q_buf`, `o_buf` — `[batch_heads, seq_q, head_dim]`
/// * `k_buf`, `v_buf` — `[batch_heads, seq_kv, head_dim]`
///
/// One invocation (1-D workgroups of 64) handles one `(batch_head, query)`
/// row in three passes: find the maximum score, accumulate
/// `exp(score - max) * V`, then divide by the sum of exponentials.
///
/// When `causal` is true, key positions greater than the query position get a
/// score of `-1e38`, so their softmax weight underflows to zero.
///
/// `scale` is embedded through its `Display` form and is expected to be
/// finite.
///
/// The output row is cleared before pass 2 because that pass accumulates into
/// `o_buf` in place; the result therefore no longer depends on whatever the
/// output buffer held before the dispatch.
pub fn attention_wgsl(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> String {
    // Opens the block that computes one score. In causal mode the block
    // becomes the `else` arm of the mask test; the template supplies the
    // closing brace in both cases. This is a plain string, not a format
    // string, so its braces are not doubled.
    let causal_check = if causal {
        "if (sk > sq) { score = f32(-1e38); } else {"
    } else {
        "{"
    };

    format!(
        r#"
// Scaled dot-product attention — generated by oxicuda-webgpu
// Q, K, V : [{batch_heads}, seq, {head_dim}]
// O : [{batch_heads}, {seq_q}, {head_dim}]
// scale : {scale}
// causal : {causal}

@group(0) @binding(0) var<storage, read> q_buf: array<f32>;
@group(0) @binding(1) var<storage, read> k_buf: array<f32>;
@group(0) @binding(2) var<storage, read> v_buf: array<f32>;
@group(0) @binding(3) var<storage, read_write> o_buf: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let linear = gid.x;
    let total = {batch_heads}u * {seq_q}u;
    if (linear >= total) {{ return; }}

    let bh = linear / {seq_q}u;
    let sq = linear % {seq_q}u;

    let q_base = (bh * {seq_q}u + sq) * {head_dim}u;
    // O shares the [batch_heads, seq_q, head_dim] layout of Q.
    let o_base = q_base;

    // Pass 1: find max score for numerical stability
    var max_score: f32 = f32(-1e38);
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        if (score > max_score) {{ max_score = score; }}
    }}

    // Clear this row of O: pass 2 accumulates into it in place, so stale
    // buffer contents must not leak into the result.
    for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
        o_buf[o_base + d] = 0.0;
    }}

    // Pass 2: compute exp(score - max), accumulate weighted V
    var sum_exp: f32 = 0.0;
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        let w = exp(score - max_score);
        sum_exp += w;
        let v_base = (bh * {seq_kv}u + sk) * {head_dim}u;
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            // Accumulate in-place (we normalise after the loop).
            o_buf[o_base + d] += w * v_buf[v_base + d];
        }}
    }}

    // Pass 3: normalise
    if (sum_exp > 0.0) {{
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            o_buf[o_base + d] /= sum_exp;
        }}
    }}
}}
"#,
        batch_heads = batch_heads,
        seq_q = seq_q,
        seq_kv = seq_kv,
        head_dim = head_dim,
        scale = scale,
        causal = causal,
        causal_check = causal_check,
    )
}
601
/// Generates the WGSL source of a single-axis reduction over a tensor viewed
/// as `[outer, dk, inner]`, producing one value per `(outer, inner)` pair.
///
/// `op` selects the combiner: `"max"`, `"min"`, or anything else for
/// addition. `"mean"` additionally divides the final sum by `params.dk`.
///
/// One 256-invocation workgroup handles one output slot. The slot is decoded
/// from a 2-D dispatch as `wgid.y * params.grid_x + wgid.x`; workgroups whose
/// slot is at or beyond `outer * inner` exit immediately. Each invocation
/// first folds elements `tid, tid + 256, ...` along the reduced axis, then the
/// workgroup combines the per-invocation results through shared memory.
///
/// Element addresses are
/// `o * outer_stride + j * inner_stride + i * dk_stride`.
/// Bindings (group 0): 0 = `input` (read-only), 1 = `output` (read-write),
/// 2 = uniform `ReduceNdParams`.
pub fn reduction_nd_wgsl(op: &str) -> String {
    // `combine` folds into the per-invocation accumulator `acc`;
    // `combine_alias` is the same operation spelled against `acc2`, the name
    // used inside the shared-memory tree phase.
    let (neutral, combine, combine_alias) = if op == "max" {
        ("f32(-1e38)", "max(acc, val)", "max(acc2, val)")
    } else if op == "min" {
        ("f32(1e38)", "min(acc, val)", "min(acc2, val)")
    } else {
        ("f32(0.0)", "acc + val", "acc2 + val")
    };

    // Only the mean needs post-processing of the reduced sum.
    let final_expr = match op {
        "mean" => "shared_data[0] / f32(params.dk)",
        _ => "shared_data[0]",
    };

    format!(
        r#"
struct ReduceNdParams {{
    outer: u32,
    dk: u32,
    inner: u32,
    outer_stride: u32,
    dk_stride: u32,
    inner_stride: u32,
    grid_x: u32,
    _pad: u32,
}}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceNdParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let total = params.outer * params.inner;

    // Decode 2-D workgroup id back to a linear output slot.
    let slot = wgid.y * params.grid_x + wgid.x;
    if (slot >= total) {{ return; }}

    let o = slot / params.inner;
    let j = slot % params.inner;
    let base = o * params.outer_stride + j * params.inner_stride;

    // Strided per-thread reduction across the dk axis.
    var acc: f32 = {neutral};
    var i: u32 = tid;
    loop {{
        if (i >= params.dk) {{ break; }}
        let val = input[base + i * params.dk_stride];
        acc = {combine};
        i = i + 256u;
    }}

    shared_data[tid] = acc;
    workgroupBarrier();

    // Tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc2 = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine_alias};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[slot] = {final_expr};
    }}
}}
"#
    )
}
716
/// Generates the WGSL source of the final pass of a flat reduction: a single
/// 256-invocation workgroup folds `partial_sums[0..num_groups]` into
/// `output[0]`.
///
/// `op` selects the combiner: `"max"`, `"min"`, or anything else for addition
/// (so `"mean"` produces the same shader as `"sum"`; no division happens
/// here).
///
/// Each invocation starts from `partial_sums[tid]` (or the neutral element
/// when `tid >= num_groups`) and then folds in entries `tid + 256`,
/// `tid + 512`, ... so that any number of partial results is covered, not only
/// the first 256. For `num_groups <= 256` the extra loop never executes and
/// the shader behaves as a plain load followed by the tree reduction.
///
/// Bindings (group 0): 0 = `partial_sums` (read-only), 1 = `output`
/// (read-write), 2 = uniform `FinalReduceParams` with `num_groups`.
pub fn reduction_final_wgsl(op: &str) -> String {
    let (neutral, combine) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)"),
        "min" => ("f32(1e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}

@group(0) @binding(0) var<storage, read> partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: FinalReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;

    // Seed with this invocation's own partial result, or the neutral element
    // when there are fewer partial results than invocations.
    var running: f32 = {neutral};
    if (tid < params.num_groups) {{
        running = partial_sums[tid];
    }}

    // Fold in the partial results beyond the first 256, one stride at a time.
    var i: u32 = tid + 256u;
    loop {{
        if (i >= params.num_groups) {{ break; }}
        let acc = running;
        let val = partial_sums[i];
        running = {combine};
        i = i + 256u;
    }}

    shared_data[tid] = running;
    workgroupBarrier();

    // Tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
775
#[cfg(test)]
mod tests {
    //! String-level checks on the generated WGSL. Nothing here compiles or
    //! runs a shader; each test only looks for expected fragments.

    use super::*;

    /// Fails the calling test unless every fragment occurs in `src`.
    fn expect_all(src: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                src.contains(fragment),
                "generated WGSL lacks `{}`",
                fragment
            );
        }
    }

    // ── GEMM (f32) ──────────────────────────────────────────────────────

    #[test]
    fn wgsl_gemm_contains_workgroup() {
        expect_all(
            &gemm_wgsl(16),
            &["@compute @workgroup_size(16, 16)", "GemmParams", "alpha", "beta"],
        );
    }

    #[test]
    fn wgsl_gemm_tile_size_embedded() {
        expect_all(&gemm_wgsl(8), &["@workgroup_size(8, 8)"]);
        expect_all(&gemm_wgsl(32), &["@workgroup_size(32, 32)"]);
    }

    #[test]
    fn wgsl_gemm_has_transpose_flags() {
        expect_all(
            &gemm_wgsl(8),
            &[
                "trans_a: u32",
                "trans_b: u32",
                // Plain and transposed indexing for each operand.
                "a[r * params.k + i]",
                "a[i * params.m + r]",
                "b[i * params.n + col]",
                "b[col * params.k + i]",
            ],
        );
    }

    #[test]
    fn wgsl_gemm_uses_shared_memory_tiling() {
        expect_all(
            &gemm_wgsl(16),
            &[
                "var<workgroup> tile_a",
                "var<workgroup> tile_b",
                "workgroupBarrier",
                "array<array<f32, 16>, 16>",
            ],
        );
    }

    // ── Element-wise (unary) ────────────────────────────────────────────

    #[test]
    fn wgsl_elementwise_relu_contains_max() {
        expect_all(&elementwise_wgsl("relu"), &["max(x, 0.0)"]);
    }

    #[test]
    fn wgsl_elementwise_all_ops() {
        let expectations = [
            ("sigmoid", "exp(-x)"),
            ("tanh", "tanh(x)"),
            ("exp", "exp(x)"),
            ("log", "log(x)"),
            ("sqrt", "sqrt(x)"),
            ("abs", "abs(x)"),
            ("neg", "-x"),
            // Unknown names fall back to the identity.
            ("identity_op", "output[i] = x;"),
        ];
        for (op, fragment) in expectations {
            expect_all(&elementwise_wgsl(op), &[fragment]);
        }
    }

    // ── Flat reduction ──────────────────────────────────────────────────

    #[test]
    fn wgsl_reduction_sum_contains_addition() {
        expect_all(&reduction_wgsl("sum"), &["acc + val", "workgroupBarrier"]);
    }

    #[test]
    fn wgsl_reduction_max_uses_max_fn() {
        expect_all(&reduction_wgsl("max"), &["max(acc, val)"]);
    }

    #[test]
    fn wgsl_reduction_min_uses_min_fn() {
        expect_all(&reduction_wgsl("min"), &["min(acc, val)"]);
    }

    #[test]
    fn wgsl_reduction_mean_same_as_sum() {
        // The flat reduction shader never divides, so both names must yield
        // the same source.
        let sum_src = reduction_wgsl("sum");
        let mean_src = reduction_wgsl("mean");
        assert_eq!(sum_src, mean_src);
    }

    #[test]
    fn wgsl_reduction_final_sum() {
        expect_all(&reduction_final_wgsl("sum"), &["num_groups", "output[0]"]);
    }

    // ── N-d (single-axis) reduction ─────────────────────────────────────

    #[test]
    fn wgsl_reduction_nd_sum_contains_addition() {
        expect_all(
            &reduction_nd_wgsl("sum"),
            &[
                "acc + val",
                "acc2 + val",
                "workgroupBarrier",
                "ReduceNdParams",
            ],
        );
    }

    #[test]
    fn wgsl_reduction_nd_max_uses_max_fn() {
        expect_all(
            &reduction_nd_wgsl("max"),
            &["max(acc, val)", "max(acc2, val)"],
        );
    }

    #[test]
    fn wgsl_reduction_nd_min_uses_min_fn() {
        expect_all(
            &reduction_nd_wgsl("min"),
            &["min(acc, val)", "min(acc2, val)"],
        );
    }

    #[test]
    fn wgsl_reduction_nd_mean_divides_by_dk() {
        expect_all(
            &reduction_nd_wgsl("mean"),
            &["shared_data[0] / f32(params.dk)", "acc + val"],
        );
    }

    #[test]
    fn wgsl_reduction_nd_sum_does_not_divide() {
        let src = reduction_nd_wgsl("sum");
        assert!(!src.contains("/ f32(params.dk)"));
    }

    #[test]
    fn wgsl_reduction_nd_decodes_2d_dispatch() {
        expect_all(
            &reduction_nd_wgsl("sum"),
            &["wgid.y * params.grid_x + wgid.x"],
        );
    }

    #[test]
    fn wgsl_reduction_nd_uses_strided_loop() {
        expect_all(&reduction_nd_wgsl("sum"), &["i = i + 256u"]);
    }

    // ── Element-wise (binary) ───────────────────────────────────────────

    #[test]
    fn wgsl_binary_add() {
        expect_all(&binary_wgsl("add"), &["a + b", "lhs", "rhs"]);
    }

    #[test]
    fn wgsl_binary_all_ops() {
        let expectations = [
            ("sub", "a - b"),
            ("mul", "a * b"),
            ("div", "a / b"),
            ("max", "max(a, b)"),
            ("min", "min(a, b)"),
            ("pow", "pow(a, b)"),
            // Unknown names pass the left operand through.
            ("unknown_op", "output[i] = a;"),
        ];
        for (op, fragment) in expectations {
            expect_all(&binary_wgsl(op), &[fragment]);
        }
    }

    #[test]
    fn wgsl_binary_workgroup_size() {
        expect_all(&binary_wgsl("add"), &["@workgroup_size(256)"]);
    }

    // ── Conv2D ──────────────────────────────────────────────────────────

    #[test]
    fn wgsl_conv2d_contains_workgroup() {
        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        expect_all(&src, &["@compute @workgroup_size(8, 8)"]);
    }

    #[test]
    fn wgsl_conv2d_contains_storage_bindings() {
        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        expect_all(
            &src,
            &[
                "var<storage, read> input:",
                "var<storage, read> filter:",
                "var<storage, read_write> output:",
            ],
        );
    }

    #[test]
    fn wgsl_conv2d_embeds_dimensions() {
        let src = conv2d_wgsl(2, 8, 64, 64, 32, 5, 5, 60, 60, 1, 1, 0, 0);
        expect_all(
            &src,
            &[
                "8u",  // c_in
                "64u", // h_in / w_in
                "32u", // k_out
                "5u",  // fh / fw
                "60u", // oh / ow
            ],
        );
    }

    #[test]
    fn wgsl_conv2d_has_padding_check() {
        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 8, 8, 1, 1, 1, 1);
        expect_all(&src, &["iy_raw >= 0", "ix_raw >= 0"]);
    }

    #[test]
    fn wgsl_conv2d_has_stride() {
        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 3, 3, 2, 2, 0, 0);
        // stride_h / stride_w of 2
        expect_all(&src, &["2u"]);
    }

    // ── Attention ───────────────────────────────────────────────────────

    #[test]
    fn wgsl_attention_contains_workgroup() {
        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
        expect_all(&src, &["@compute @workgroup_size(64)"]);
    }

    #[test]
    fn wgsl_attention_contains_storage_bindings() {
        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
        expect_all(
            &src,
            &[
                "var<storage, read> q_buf:",
                "var<storage, read> k_buf:",
                "var<storage, read> v_buf:",
                "var<storage, read_write> o_buf:",
            ],
        );
    }

    #[test]
    fn wgsl_attention_stable_softmax() {
        let src = attention_wgsl(1, 4, 4, 32, 0.25, false);
        expect_all(&src, &["max_score", "exp(score - max_score)", "sum_exp"]);
    }

    #[test]
    fn wgsl_attention_causal_mask() {
        let src_causal = attention_wgsl(1, 4, 4, 32, 0.25, true);
        assert!(src_causal.contains("sk > sq"));

        let src_non_causal = attention_wgsl(1, 4, 4, 32, 0.25, false);
        assert!(!src_non_causal.contains("sk > sq"));
    }

    #[test]
    fn wgsl_attention_embeds_scale() {
        let src = attention_wgsl(2, 16, 16, 64, 0.125, false);
        expect_all(&src, &["0.125"]);
    }

    // ── Batched GEMM ────────────────────────────────────────────────────

    #[test]
    fn wgsl_batched_gemm_contains_batch_params() {
        expect_all(
            &batched_gemm_wgsl(16),
            &["batch_count", "stride_a", "stride_b", "stride_c"],
        );
    }

    #[test]
    fn wgsl_batched_gemm_contains_workgroup() {
        expect_all(
            &batched_gemm_wgsl(16),
            &["@compute @workgroup_size(16, 16)", "BatchedGemmParams"],
        );
    }

    #[test]
    fn wgsl_batched_gemm_uses_batch_index() {
        expect_all(&batched_gemm_wgsl(8), &["batch_index", "gid.z"]);
    }

    #[test]
    fn wgsl_batched_gemm_tile_size_embedded() {
        expect_all(&batched_gemm_wgsl(8), &["@workgroup_size(8, 8)"]);
        expect_all(&batched_gemm_wgsl(32), &["@workgroup_size(32, 32)"]);
    }

    #[test]
    fn wgsl_batched_gemm_has_transpose_flags() {
        expect_all(
            &batched_gemm_wgsl(8),
            &[
                "trans_a: u32",
                "trans_b: u32",
                // Plain and transposed indexing for each operand.
                "a[a_offset + r * params.k + i]",
                "a[a_offset + i * params.m + r]",
                "b[b_offset + i * params.n + col]",
                "b[b_offset + col * params.k + i]",
            ],
        );
    }

    #[test]
    fn wgsl_batched_gemm_uses_shared_memory_tiling() {
        expect_all(
            &batched_gemm_wgsl(8),
            &[
                "var<workgroup> tile_a",
                "var<workgroup> tile_b",
                "workgroupBarrier",
                "array<array<f32, 8>, 8>",
            ],
        );
    }

    // ── GEMM (f16) ──────────────────────────────────────────────────────

    #[test]
    fn wgsl_gemm_f16_enables_extension() {
        expect_all(&gemm_wgsl_f16(16), &["enable f16;"]);
    }

    #[test]
    fn wgsl_gemm_f16_uses_f16_storage() {
        expect_all(&gemm_wgsl_f16(16), &["array<f16>"]);
    }

    #[test]
    fn wgsl_gemm_f16_accumulates_in_f32() {
        expect_all(
            &gemm_wgsl_f16(16),
            &["var acc: f32 = 0.0;", "f32(a[", "f32(b["],
        );
    }

    #[test]
    fn wgsl_gemm_f16_contains_workgroup() {
        expect_all(
            &gemm_wgsl_f16(8),
            &["@compute @workgroup_size(8, 8)", "GemmParams"],
        );
    }

    #[test]
    fn wgsl_attention_embeds_dimensions() {
        let src = attention_wgsl(8, 32, 32, 128, 0.088, true);
        expect_all(
            &src,
            &[
                "128u", // head_dim
                "32u",  // seq_q / seq_kv
                "8u",   // batch_heads
            ],
        );
    }
}