/// WGSL source for a naive single-precision GEMM kernel.
///
/// Computes `C = alpha * A * B + beta * C` for row-major matrices, one
/// output element per invocation, dispatched with `tile_size` x
/// `tile_size` workgroups.
pub fn gemm_wgsl(tile_size: u32) -> String {
    let ts = tile_size;
    format!(
        r#"
struct GemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
}}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[row * params.k + i] * b[i * params.n + col];
    }}

    let idx = row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#
    )
}
47
/// WGSL source for a batched single-precision GEMM kernel.
///
/// Each batch `b` computes `C_b = alpha * A_b * B_b + beta * C_b`, where
/// the per-batch matrices live at `b * stride_{{a,b,c}}` inside the flat
/// buffers. `gid.z` selects the batch; `gid.x`/`gid.y` select the output
/// column/row.
pub fn batched_gemm_wgsl(tile_size: u32) -> String {
    let ts = tile_size;
    format!(
        r#"
struct BatchedGemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
    batch_count: u32,
    stride_a: u32,
    stride_b: u32,
    stride_c: u32,
}}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: BatchedGemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    let batch_index = gid.z;
    if (batch_index >= params.batch_count || row >= params.m || col >= params.n) {{ return; }}

    let a_offset = batch_index * params.stride_a;
    let b_offset = batch_index * params.stride_b;
    let c_offset = batch_index * params.stride_c;

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[a_offset + row * params.k + i] * b[b_offset + i * params.n + col];
    }}

    let idx = c_offset + row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#
    )
}
102
/// WGSL source for a GEMM kernel with `f16` storage and `f32` accumulation.
///
/// Requires the `shader-f16` feature (`enable f16;`). A, B and C are stored
/// as half precision; the dot product accumulates in `f32` to limit
/// rounding error and the result is narrowed back to `f16` on store.
pub fn gemm_wgsl_f16(tile_size: u32) -> String {
    let ts = tile_size;
    format!(
        r#"
enable f16;

struct GemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
}}

@group(0) @binding(0) var<storage, read> a: array<f16>;
@group(0) @binding(1) var<storage, read> b: array<f16>;
@group(0) @binding(2) var<storage, read_write> c: array<f16>;
@group(0) @binding(3) var<uniform> params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += f32(a[row * params.k + i]) * f32(b[i * params.n + col]);
    }}

    let idx = row * params.n + col;
    let prev = f32(c[idx]);
    c[idx] = f16(params.alpha * acc + params.beta * prev);
}}
"#
    )
}
148
/// WGSL source for a unary element-wise kernel over an `f32` buffer.
///
/// `op` selects the per-element expression: `relu`, `sigmoid`, `tanh`,
/// `exp`, `log`, `sqrt`, `abs` or `neg`. Any other string produces the
/// identity kernel (`output[i] = x`).
pub fn elementwise_wgsl(op: &str) -> String {
    let expr = match op {
        "relu" => "max(x, 0.0)",
        "sigmoid" => "1.0 / (1.0 + exp(-x))",
        "tanh" => "tanh(x)",
        "exp" => "exp(x)",
        "log" => "log(x)",
        "sqrt" => "sqrt(x)",
        "abs" => "abs(x)",
        "neg" => "-x",
        // Unknown ops fall back to the identity expression.
        _ => "x",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&input)) {{ return; }}
    let x = input[i];
    output[i] = {expr};
}}
"#
    )
}
187
/// WGSL source for a binary element-wise kernel over two `f32` buffers.
///
/// `op` selects the combining expression: `add`, `sub`, `mul`, `div`,
/// `max`, `min` or `pow`. Any other string copies the left operand
/// through unchanged (`output[i] = a`).
pub fn binary_wgsl(op: &str) -> String {
    let expr = match op {
        "add" => "a + b",
        "sub" => "a - b",
        "mul" => "a * b",
        "div" => "a / b",
        "max" => "max(a, b)",
        "min" => "min(a, b)",
        "pow" => "pow(a, b)",
        // Unknown ops pass the left-hand side through.
        _ => "a",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read> lhs: array<f32>;
@group(0) @binding(1) var<storage, read> rhs: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&lhs)) {{ return; }}
    let a = lhs[i];
    let b = rhs[i];
    output[i] = {expr};
}}
"#
    )
}
227
/// WGSL source for the first stage of a 1-D reduction.
///
/// Each 256-thread workgroup reduces its slice of `input` with a shared
/// -memory tree reduction and writes one value per workgroup into
/// `partial_sums`. `op` may be `max` or `min`; anything else (including
/// `sum` and `mean`) reduces by addition — a `mean` is finished on the
/// host by dividing the final sum.
pub fn reduction_wgsl(op: &str) -> String {
    // Identity element used to pad out-of-range lanes.
    let neutral = match op {
        "max" => "f32(-1e38)",
        "min" => "f32(1e38)",
        _ => "f32(0.0)",
    };
    // Pairwise combining expression for the tree reduction.
    let combine = match op {
        "max" => "max(acc, val)",
        "min" => "min(acc, val)",
        _ => "acc + val",
    };

    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let global_idx = gid.x;

    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#
    )
}
302
/// WGSL source for a direct (non-im2col) 2-D convolution in NCHW layout.
///
/// All shape parameters are baked into the shader as constants:
/// * input  `[n, c_in, h_in, w_in]`
/// * filter `[k_out, c_in, fh, fw]`
/// * output `[n, k_out, oh, ow]`
///
/// The dispatch flattens `(batch, k_out, oy)` into `gid.y` and maps `ox`
/// to `gid.x`. `stride_h`/`stride_w` and `pad_h`/`pad_w` implement strided
/// and zero-padded convolution; out-of-bounds taps contribute nothing.
#[allow(clippy::too_many_arguments)]
pub fn conv2d_wgsl(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> String {
    format!(
        r#"
// Conv2D NCHW — generated by oxicuda-webgpu
// input : [{n}, {c_in}, {h_in}, {w_in}]
// filter: [{k_out}, {c_in}, {fh}, {fw}]
// output: [{n}, {k_out}, {oh}, {ow}]

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> filter: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    // gid.x = output x (ox mapped across batches*k_out*oh)
    // We flatten (batch, k, oy) into gid.y and ox into gid.x
    let ox = gid.x;
    let linear_y = gid.y;

    let batch_k_oh = {n}u * {k_out}u * {oh}u;
    if (ox >= {ow}u || linear_y >= batch_k_oh) {{ return; }}

    let b = linear_y / ({k_out}u * {oh}u);
    let rem = linear_y % ({k_out}u * {oh}u);
    let kf = rem / {oh}u;
    let oy = rem % {oh}u;

    var acc: f32 = 0.0;
    for (var ci: u32 = 0u; ci < {c_in}u; ci = ci + 1u) {{
        for (var fy: u32 = 0u; fy < {fh}u; fy = fy + 1u) {{
            for (var fx: u32 = 0u; fx < {fw}u; fx = fx + 1u) {{
                let iy_raw = i32(oy * {stride_h}u + fy) - i32({pad_h}u);
                let ix_raw = i32(ox * {stride_w}u + fx) - i32({pad_w}u);
                if (iy_raw >= 0 && iy_raw < i32({h_in}u) && ix_raw >= 0 && ix_raw < i32({w_in}u)) {{
                    let iy = u32(iy_raw);
                    let ix = u32(ix_raw);
                    let in_idx = ((b * {c_in}u + ci) * {h_in}u + iy) * {w_in}u + ix;
                    let f_idx = ((kf * {c_in}u + ci) * {fh}u + fy) * {fw}u + fx;
                    acc += input[in_idx] * filter[f_idx];
                }}
            }}
        }}
    }}

    let o_idx = ((b * {k_out}u + kf) * {oh}u + oy) * {ow}u + ox;
    output[o_idx] = acc;
}}
"#
    )
}
397
/// WGSL source for scaled dot-product attention, one output row per thread.
///
/// Q, K and V are flat row-major `f32` buffers of shape
/// `[batch_heads, seq, head_dim]`; each invocation handles one
/// `(batch_head, query)` pair and writes `O[bh, sq, :]` using a
/// numerically stable two-pass softmax (max-subtraction, then
/// exponentiate/accumulate, then normalise). When `causal` is true,
/// keys with `sk > sq` are masked to -1e38 and contribute zero weight.
pub fn attention_wgsl(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> String {
    // Spliced into the shader as the opening of a block: the causal
    // variant short-circuits masked keys, the non-causal variant is a
    // plain block. Both are closed by a literal `}` in the template.
    let causal_check = if causal {
        "if (sk > sq) { score = f32(-1e38); } else {"
    } else {
        "{"
    };

    format!(
        r#"
// Scaled dot-product attention — generated by oxicuda-webgpu
// Q, K, V : [{batch_heads}, seq, {head_dim}]
// O       : [{batch_heads}, {seq_q}, {head_dim}]
// scale   : {scale}
// causal  : {causal}

@group(0) @binding(0) var<storage, read> q_buf: array<f32>;
@group(0) @binding(1) var<storage, read> k_buf: array<f32>;
@group(0) @binding(2) var<storage, read> v_buf: array<f32>;
@group(0) @binding(3) var<storage, read_write> o_buf: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let linear = gid.x;
    let total = {batch_heads}u * {seq_q}u;
    if (linear >= total) {{ return; }}

    let bh = linear / {seq_q}u;
    let sq = linear % {seq_q}u;

    let q_base = (bh * {seq_q}u + sq) * {head_dim}u;
    let o_base = (bh * {seq_q}u + sq) * {head_dim}u;

    // Pass 2 accumulates into o_buf with `+=`, so the output row must
    // start from zero — a reused buffer may hold stale results from a
    // previous dispatch.
    for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
        o_buf[o_base + d] = 0.0;
    }}

    // Pass 1: find max score for numerical stability
    var max_score: f32 = f32(-1e38);
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        if (score > max_score) {{ max_score = score; }}
    }}

    // Pass 2: compute exp(score - max), accumulate weighted V
    var sum_exp: f32 = 0.0;
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        // Masked keys carry score = -1e38, so w underflows to 0 here.
        let w = exp(score - max_score);
        sum_exp += w;
        let v_base = (bh * {seq_kv}u + sk) * {head_dim}u;
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            // Accumulate in-place (we normalise after the loop).
            o_buf[o_base + d] += w * v_buf[v_base + d];
        }}
    }}

    // Pass 3: normalise
    if (sum_exp > 0.0) {{
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            o_buf[o_base + d] /= sum_exp;
        }}
    }}
}}
"#
    )
}
504
/// WGSL source for reducing one axis (`dk`) of an N-D tensor.
///
/// The output has `outer * inner` slots; each slot gets its own
/// 256-thread workgroup, decoded from a 2-D dispatch grid via
/// `params.grid_x`. Threads first reduce the `dk` axis with a stride-256
/// loop, then finish with a shared-memory tree reduction. `op` may be
/// `max`, `min` or `mean`; anything else sums. `mean` additionally
/// divides the final value by `dk`.
pub fn reduction_nd_wgsl(op: &str) -> String {
    // Identity element for padding the per-thread accumulator.
    let neutral = match op {
        "max" => "f32(-1e38)",
        "min" => "f32(1e38)",
        _ => "f32(0.0)",
    };
    // Combiner for the strided per-thread loop (binds `acc`/`val`).
    let combine = match op {
        "max" => "max(acc, val)",
        "min" => "min(acc, val)",
        _ => "acc + val",
    };
    // Same combiner for the tree phase, which binds `acc2`/`val`.
    let combine_alias = match op {
        "max" => "max(acc2, val)",
        "min" => "min(acc2, val)",
        _ => "acc2 + val",
    };
    // Only `mean` post-processes the reduced value.
    let final_expr = match op {
        "mean" => "shared_data[0] / f32(params.dk)",
        _ => "shared_data[0]",
    };

    format!(
        r#"
struct ReduceNdParams {{
    outer: u32,
    dk: u32,
    inner: u32,
    outer_stride: u32,
    dk_stride: u32,
    inner_stride: u32,
    grid_x: u32,
    _pad: u32,
}}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceNdParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let total = params.outer * params.inner;

    // Decode 2-D workgroup id back to a linear output slot.
    let slot = wgid.y * params.grid_x + wgid.x;
    if (slot >= total) {{ return; }}

    let o = slot / params.inner;
    let j = slot % params.inner;
    let base = o * params.outer_stride + j * params.inner_stride;

    // Strided per-thread reduction across the dk axis.
    var acc: f32 = {neutral};
    var i: u32 = tid;
    loop {{
        if (i >= params.dk) {{ break; }}
        let val = input[base + i * params.dk_stride];
        acc = {combine};
        i = i + 256u;
    }}

    shared_data[tid] = acc;
    workgroupBarrier();

    // Tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc2 = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine_alias};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[slot] = {final_expr};
    }}
}}
"#
    )
}
619
/// WGSL source for the second (final) stage of a 1-D reduction.
///
/// A single 256-thread workgroup folds up to 256 partial results from
/// the first stage into `output[0]` via a shared-memory tree reduction.
/// `op` may be `max` or `min`; anything else reduces by addition.
pub fn reduction_final_wgsl(op: &str) -> String {
    // Identity element for lanes beyond `num_groups`.
    let neutral = match op {
        "max" => "f32(-1e38)",
        "min" => "f32(1e38)",
        _ => "f32(0.0)",
    };
    // Pairwise combining expression for the tree reduction.
    let combine = match op {
        "max" => "max(acc, val)",
        "min" => "min(acc, val)",
        _ => "acc + val",
    };

    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}

@group(0) @binding(0) var<storage, read> partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: FinalReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;

    if (tid < params.num_groups) {{
        shared_data[tid] = partial_sums[tid];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#
    )
}
678
#[cfg(test)]
mod tests {
    use super::*;

    // --- plain GEMM ---

    #[test]
    fn wgsl_gemm_contains_workgroup() {
        let src = gemm_wgsl(16);
        assert!(src.contains("@compute @workgroup_size(16, 16)"));
        assert!(src.contains("GemmParams"));
        assert!(src.contains("alpha"));
        assert!(src.contains("beta"));
    }

    #[test]
    fn wgsl_gemm_tile_size_embedded() {
        assert!(gemm_wgsl(8).contains("@workgroup_size(8, 8)"));
        assert!(gemm_wgsl(32).contains("@workgroup_size(32, 32)"));
    }

    // --- elementwise ---

    #[test]
    fn wgsl_elementwise_relu_contains_max() {
        assert!(elementwise_wgsl("relu").contains("max(x, 0.0)"));
    }

    #[test]
    fn wgsl_elementwise_all_ops() {
        assert!(elementwise_wgsl("sigmoid").contains("exp(-x)"));
        assert!(elementwise_wgsl("tanh").contains("tanh(x)"));
        assert!(elementwise_wgsl("exp").contains("exp(x)"));
        assert!(elementwise_wgsl("log").contains("log(x)"));
        assert!(elementwise_wgsl("sqrt").contains("sqrt(x)"));
        assert!(elementwise_wgsl("abs").contains("abs(x)"));
        assert!(elementwise_wgsl("neg").contains("-x"));
        assert!(elementwise_wgsl("identity_op").contains("output[i] = x;"));
    }

    // --- single-axis reductions ---

    #[test]
    fn wgsl_reduction_sum_contains_addition() {
        let src = reduction_wgsl("sum");
        assert!(src.contains("acc + val"));
        assert!(src.contains("workgroupBarrier"));
    }

    #[test]
    fn wgsl_reduction_max_uses_max_fn() {
        assert!(reduction_wgsl("max").contains("max(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_min_uses_min_fn() {
        assert!(reduction_wgsl("min").contains("min(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_mean_same_as_sum() {
        // The divide-by-count for mean happens on the host, so the
        // first-stage shader is identical to the sum shader.
        assert_eq!(reduction_wgsl("sum"), reduction_wgsl("mean"));
    }

    #[test]
    fn wgsl_reduction_final_sum() {
        let src = reduction_final_wgsl("sum");
        assert!(src.contains("num_groups"));
        assert!(src.contains("output[0]"));
    }

    // --- N-D reductions ---

    #[test]
    fn wgsl_reduction_nd_sum_contains_addition() {
        let src = reduction_nd_wgsl("sum");
        assert!(src.contains("acc + val"));
        assert!(src.contains("acc2 + val"));
        assert!(src.contains("workgroupBarrier"));
        assert!(src.contains("ReduceNdParams"));
    }

    #[test]
    fn wgsl_reduction_nd_max_uses_max_fn() {
        let src = reduction_nd_wgsl("max");
        assert!(src.contains("max(acc, val)"));
        assert!(src.contains("max(acc2, val)"));
    }

    #[test]
    fn wgsl_reduction_nd_min_uses_min_fn() {
        let src = reduction_nd_wgsl("min");
        assert!(src.contains("min(acc, val)"));
        assert!(src.contains("min(acc2, val)"));
    }

    #[test]
    fn wgsl_reduction_nd_mean_divides_by_dk() {
        let src = reduction_nd_wgsl("mean");
        assert!(src.contains("shared_data[0] / f32(params.dk)"));
        assert!(src.contains("acc + val"));
    }

    #[test]
    fn wgsl_reduction_nd_sum_does_not_divide() {
        assert!(!reduction_nd_wgsl("sum").contains("/ f32(params.dk)"));
    }

    #[test]
    fn wgsl_reduction_nd_decodes_2d_dispatch() {
        assert!(reduction_nd_wgsl("sum").contains("wgid.y * params.grid_x + wgid.x"));
    }

    #[test]
    fn wgsl_reduction_nd_uses_strided_loop() {
        assert!(reduction_nd_wgsl("sum").contains("i = i + 256u"));
    }

    // --- binary elementwise ---

    #[test]
    fn wgsl_binary_add() {
        let src = binary_wgsl("add");
        assert!(src.contains("a + b"));
        assert!(src.contains("lhs"));
        assert!(src.contains("rhs"));
    }

    #[test]
    fn wgsl_binary_all_ops() {
        assert!(binary_wgsl("sub").contains("a - b"));
        assert!(binary_wgsl("mul").contains("a * b"));
        assert!(binary_wgsl("div").contains("a / b"));
        assert!(binary_wgsl("max").contains("max(a, b)"));
        assert!(binary_wgsl("min").contains("min(a, b)"));
        assert!(binary_wgsl("pow").contains("pow(a, b)"));
        assert!(binary_wgsl("unknown_op").contains("output[i] = a;"));
    }

    #[test]
    fn wgsl_binary_workgroup_size() {
        assert!(binary_wgsl("add").contains("@workgroup_size(256)"));
    }

    // --- conv2d ---

    #[test]
    fn wgsl_conv2d_contains_workgroup() {
        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        assert!(src.contains("@compute @workgroup_size(8, 8)"));
    }

    #[test]
    fn wgsl_conv2d_contains_storage_bindings() {
        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        assert!(src.contains("var<storage, read> input:"));
        assert!(src.contains("var<storage, read> filter:"));
        assert!(src.contains("var<storage, read_write> output:"));
    }

    #[test]
    fn wgsl_conv2d_embeds_dimensions() {
        let src = conv2d_wgsl(2, 8, 64, 64, 32, 5, 5, 60, 60, 1, 1, 0, 0);
        assert!(src.contains("8u"));
        assert!(src.contains("64u"));
        assert!(src.contains("32u"));
        assert!(src.contains("5u"));
        assert!(src.contains("60u"));
    }

    #[test]
    fn wgsl_conv2d_has_padding_check() {
        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 8, 8, 1, 1, 1, 1);
        assert!(src.contains("iy_raw >= 0"));
        assert!(src.contains("ix_raw >= 0"));
    }

    #[test]
    fn wgsl_conv2d_has_stride() {
        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 3, 3, 2, 2, 0, 0);
        assert!(src.contains("2u"));
    }

    // --- attention ---

    #[test]
    fn wgsl_attention_contains_workgroup() {
        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
        assert!(src.contains("@compute @workgroup_size(64)"));
    }

    #[test]
    fn wgsl_attention_contains_storage_bindings() {
        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
        assert!(src.contains("var<storage, read> q_buf:"));
        assert!(src.contains("var<storage, read> k_buf:"));
        assert!(src.contains("var<storage, read> v_buf:"));
        assert!(src.contains("var<storage, read_write> o_buf:"));
    }

    #[test]
    fn wgsl_attention_stable_softmax() {
        let src = attention_wgsl(1, 4, 4, 32, 0.25, false);
        assert!(src.contains("max_score"));
        assert!(src.contains("exp(score - max_score)"));
        assert!(src.contains("sum_exp"));
    }

    #[test]
    fn wgsl_attention_causal_mask() {
        assert!(attention_wgsl(1, 4, 4, 32, 0.25, true).contains("sk > sq"));
        assert!(!attention_wgsl(1, 4, 4, 32, 0.25, false).contains("sk > sq"));
    }

    #[test]
    fn wgsl_attention_embeds_scale() {
        assert!(attention_wgsl(2, 16, 16, 64, 0.125, false).contains("0.125"));
    }

    // --- batched GEMM ---

    #[test]
    fn wgsl_batched_gemm_contains_batch_params() {
        let src = batched_gemm_wgsl(16);
        assert!(src.contains("batch_count"));
        assert!(src.contains("stride_a"));
        assert!(src.contains("stride_b"));
        assert!(src.contains("stride_c"));
    }

    #[test]
    fn wgsl_batched_gemm_contains_workgroup() {
        let src = batched_gemm_wgsl(16);
        assert!(src.contains("@compute @workgroup_size(16, 16)"));
        assert!(src.contains("BatchedGemmParams"));
    }

    #[test]
    fn wgsl_batched_gemm_uses_batch_index() {
        let src = batched_gemm_wgsl(8);
        assert!(src.contains("batch_index"));
        assert!(src.contains("gid.z"));
    }

    #[test]
    fn wgsl_batched_gemm_tile_size_embedded() {
        assert!(batched_gemm_wgsl(8).contains("@workgroup_size(8, 8)"));
        assert!(batched_gemm_wgsl(32).contains("@workgroup_size(32, 32)"));
    }

    // --- f16 GEMM ---

    #[test]
    fn wgsl_gemm_f16_enables_extension() {
        assert!(gemm_wgsl_f16(16).contains("enable f16;"));
    }

    #[test]
    fn wgsl_gemm_f16_uses_f16_storage() {
        assert!(gemm_wgsl_f16(16).contains("array<f16>"));
    }

    #[test]
    fn wgsl_gemm_f16_accumulates_in_f32() {
        let src = gemm_wgsl_f16(16);
        assert!(src.contains("var acc: f32 = 0.0;"));
        assert!(src.contains("f32(a["));
        assert!(src.contains("f32(b["));
    }

    #[test]
    fn wgsl_gemm_f16_contains_workgroup() {
        let src = gemm_wgsl_f16(8);
        assert!(src.contains("@compute @workgroup_size(8, 8)"));
        assert!(src.contains("GemmParams"));
    }

    #[test]
    fn wgsl_attention_embeds_dimensions() {
        let src = attention_wgsl(8, 32, 32, 128, 0.088, true);
        assert!(src.contains("128u"));
        assert!(src.contains("32u"));
        assert!(src.contains("8u"));
    }
}