/// Generates WGSL for a naive row-major SGEMM: `C = alpha * A*B + beta * C`.
///
/// * `a` is `m x k`, `b` is `k x n`, `c` is `m x n`, all flat row-major `f32`.
/// * `tile_size` is baked in as the square 2-D workgroup size, so one shader
///   module is needed per tile size; dispatch covers `n` on x and `m` on y.
pub fn gemm_wgsl(tile_size: u32) -> String {
    let ts = tile_size;
    format!(
        r#"
struct GemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
}}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[row * params.k + i] * b[i * params.n + col];
    }}

    let idx = row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#
    )
}
47
/// Generates WGSL for a batched SGEMM: for each batch `i`,
/// `C_i = alpha * A_i * B_i + beta * C_i`.
///
/// `stride_a` / `stride_b` / `stride_c` are element offsets between
/// consecutive batch matrices inside the flat buffers. The batch index is
/// taken from `gid.z`, so the dispatch's z dimension must cover
/// `batch_count`; `tile_size` sets the 2-D workgroup shape as in [`gemm_wgsl`].
pub fn batched_gemm_wgsl(tile_size: u32) -> String {
    let ts = tile_size;
    format!(
        r#"
struct BatchedGemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
    batch_count: u32,
    stride_a: u32,
    stride_b: u32,
    stride_c: u32,
}}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: BatchedGemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    let batch_index = gid.z;
    if (batch_index >= params.batch_count || row >= params.m || col >= params.n) {{ return; }}

    let a_offset = batch_index * params.stride_a;
    let b_offset = batch_index * params.stride_b;
    let c_offset = batch_index * params.stride_c;

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[a_offset + row * params.k + i] * b[b_offset + i * params.n + col];
    }}

    let idx = c_offset + row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#
    )
}
102
/// Generates WGSL for a half-precision GEMM with f32 accumulation.
///
/// Storage buffers hold `f16`; each product is widened to `f32` before
/// accumulating to limit rounding error, and the result is narrowed back to
/// `f16` on store. The shader starts with `enable f16;`, so the device must
/// expose the `shader-f16` feature.
pub fn gemm_wgsl_f16(tile_size: u32) -> String {
    let ts = tile_size;
    format!(
        r#"
enable f16;

struct GemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
}}

@group(0) @binding(0) var<storage, read> a: array<f16>;
@group(0) @binding(1) var<storage, read> b: array<f16>;
@group(0) @binding(2) var<storage, read_write> c: array<f16>;
@group(0) @binding(3) var<uniform> params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += f32(a[row * params.k + i]) * f32(b[i * params.n + col]);
    }}

    let idx = row * params.n + col;
    let prev = f32(c[idx]);
    c[idx] = f16(params.alpha * acc + params.beta * prev);
}}
"#
    )
}
148
/// Generates a WGSL shader applying a unary `op` element-wise over `input`.
///
/// Recognised ops: `relu`, `sigmoid`, `tanh`, `exp`, `log`, `sqrt`, `abs`,
/// `neg`. Any other string degrades to the identity copy (`output[i] = x`).
/// One thread per element, 256-wide workgroups, bounds-checked against the
/// runtime buffer length.
pub fn elementwise_wgsl(op: &str) -> String {
    let expr = match op {
        "relu" => "max(x, 0.0)",
        "sigmoid" => "1.0 / (1.0 + exp(-x))",
        "tanh" => "tanh(x)",
        "exp" => "exp(x)",
        "log" => "log(x)",
        "sqrt" => "sqrt(x)",
        "abs" => "abs(x)",
        "neg" => "-x",
        // Unknown op names intentionally pass the input through unchanged.
        _ => "x",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&input)) {{ return; }}
    let x = input[i];
    output[i] = {expr};
}}
"#
    )
}
187
/// Generates a WGSL shader applying a binary `op` element-wise over `lhs`
/// and `rhs` (equal-length buffers, no broadcasting).
///
/// Recognised ops: `add`, `sub`, `mul`, `div`, `max`, `min`, `pow`. Any
/// other string degrades to copying `lhs` (`output[i] = a`). One thread per
/// element with 256-wide workgroups; bounds are checked against `lhs` only.
pub fn binary_wgsl(op: &str) -> String {
    let expr = match op {
        "add" => "a + b",
        "sub" => "a - b",
        "mul" => "a * b",
        "div" => "a / b",
        "max" => "max(a, b)",
        "min" => "min(a, b)",
        "pow" => "pow(a, b)",
        // Unknown op names intentionally fall back to the lhs value.
        _ => "a",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read> lhs: array<f32>;
@group(0) @binding(1) var<storage, read> rhs: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&lhs)) {{ return; }}
    let a = lhs[i];
    let b = rhs[i];
    output[i] = {expr};
}}
"#
    )
}
227
/// Generates the first-stage WGSL reduction kernel.
///
/// Each 256-thread workgroup loads one element per lane (neutral element for
/// out-of-range lanes), runs a shared-memory tree reduction, and has lane 0
/// write the workgroup's result into `partial_sums[workgroup_id]`.
///
/// `op` = `"max"` / `"min"` select those combiners; every other string
/// (e.g. `"sum"`, `"mean"`) reduces by addition — a mean still needs the
/// host to divide by the element count afterwards.
pub fn reduction_wgsl(op: &str) -> String {
    // Identity element loaded by lanes past the end of the input.
    let neutral = match op {
        "max" => "f32(-1e38)",
        "min" => "f32(1e38)",
        _ => "f32(0.0)",
    };
    // Combining expression over the `acc` / `val` locals in the tree step.
    let combine = match op {
        "max" => "max(acc, val)",
        "min" => "min(acc, val)",
        _ => "acc + val",
    };

    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let global_idx = gid.x;

    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#
    )
}
302
/// Generates a direct (non-im2col) WGSL Conv2D kernel for NCHW layouts, with
/// every shape/stride/padding value baked into the shader as a constant.
///
/// * input  : `[n, c_in, h_in, w_in]`
/// * filter : `[k_out, c_in, fh, fw]`
/// * output : `[n, k_out, oh, ow]`
///
/// The dispatch maps `ox` to `gid.x` and the flattened `(batch, k_out, oy)`
/// index to `gid.y`, using an 8x8 workgroup. Out-of-bounds taps from padding
/// are skipped (implicit zero padding).
#[allow(clippy::too_many_arguments)]
pub fn conv2d_wgsl(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> String {
    // All placeholders capture the parameters above by name.
    format!(
        r#"
// Conv2D NCHW — generated by oxicuda-webgpu
// input : [{n}, {c_in}, {h_in}, {w_in}]
// filter: [{k_out}, {c_in}, {fh}, {fw}]
// output: [{n}, {k_out}, {oh}, {ow}]

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> filter: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    // gid.x = output x (ox mapped across batches*k_out*oh)
    // We flatten (batch, k, oy) into gid.y and ox into gid.x
    let ox = gid.x;
    let linear_y = gid.y;

    let batch_k_oh = {n}u * {k_out}u * {oh}u;
    if (ox >= {ow}u || linear_y >= batch_k_oh) {{ return; }}

    let b = linear_y / ({k_out}u * {oh}u);
    let rem = linear_y % ({k_out}u * {oh}u);
    let kf = rem / {oh}u;
    let oy = rem % {oh}u;

    var acc: f32 = 0.0;
    for (var ci: u32 = 0u; ci < {c_in}u; ci = ci + 1u) {{
        for (var fy: u32 = 0u; fy < {fh}u; fy = fy + 1u) {{
            for (var fx: u32 = 0u; fx < {fw}u; fx = fx + 1u) {{
                let iy_raw = i32(oy * {stride_h}u + fy) - i32({pad_h}u);
                let ix_raw = i32(ox * {stride_w}u + fx) - i32({pad_w}u);
                if (iy_raw >= 0 && iy_raw < i32({h_in}u) && ix_raw >= 0 && ix_raw < i32({w_in}u)) {{
                    let iy = u32(iy_raw);
                    let ix = u32(ix_raw);
                    let in_idx = ((b * {c_in}u + ci) * {h_in}u + iy) * {w_in}u + ix;
                    let f_idx = ((kf * {c_in}u + ci) * {fh}u + fy) * {fw}u + fx;
                    acc += input[in_idx] * filter[f_idx];
                }}
            }}
        }}
    }}

    let o_idx = ((b * {k_out}u + kf) * {oh}u + oy) * {ow}u + ox;
    output[o_idx] = acc;
}}
"#
    )
}
397
/// Generates a WGSL scaled-dot-product-attention kernel: one thread per
/// `(batch*head, query row)` pair, 64-wide workgroups.
///
/// Q/K/V are flat `[batch_heads, seq, head_dim]` f32 buffers; the output is
/// `[batch_heads, seq_q, head_dim]`. Softmax is numerically stabilised with
/// two passes: pass 1 finds the per-row max score, pass 2 accumulates
/// `exp(score - max)`-weighted V rows. With `causal` set, keys with
/// `sk > sq` are masked to a large negative score so their weight is 0.
///
/// Bug fix: the weighted-V sum is now accumulated in a function-private
/// array and written to `o_buf` exactly once, instead of `+=` on `o_buf`
/// directly — the old form silently required a pre-zeroed output buffer and
/// double-accumulated if the kernel was dispatched twice.
pub fn attention_wgsl(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> String {
    // Causal variant masks future keys; non-causal degenerates to a plain
    // block so the closing `}}` in the template stays balanced either way.
    let causal_check = if causal {
        "if (sk > sq) { score = f32(-1e38); } else {"
    } else {
        "{"
    };

    format!(
        r#"
// Scaled dot-product attention — generated by oxicuda-webgpu
// Q, K, V : [{batch_heads}, seq, {head_dim}]
// O       : [{batch_heads}, {seq_q}, {head_dim}]
// scale   : {scale}
// causal  : {causal}

@group(0) @binding(0) var<storage, read> q_buf: array<f32>;
@group(0) @binding(1) var<storage, read> k_buf: array<f32>;
@group(0) @binding(2) var<storage, read> v_buf: array<f32>;
@group(0) @binding(3) var<storage, read_write> o_buf: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let linear = gid.x;
    let total = {batch_heads}u * {seq_q}u;
    if (linear >= total) {{ return; }}

    let bh = linear / {seq_q}u;
    let sq = linear % {seq_q}u;

    let q_base = (bh * {seq_q}u + sq) * {head_dim}u;
    let o_base = (bh * {seq_q}u + sq) * {head_dim}u;

    // Pass 1: row maximum for a numerically stable softmax.
    var max_score: f32 = f32(-1e38);
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        if (score > max_score) {{ max_score = score; }}
    }}

    // Pass 2: accumulate exp(score - max) weighted V rows in registers.
    // Function-scope vars without an initializer are zero-initialized (WGSL spec).
    var o_acc: array<f32, {head_dim}u>;
    var sum_exp: f32 = 0.0;
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        let w = exp(score - max_score);
        sum_exp += w;
        let v_base = (bh * {seq_kv}u + sk) * {head_dim}u;
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            o_acc[d] += w * v_buf[v_base + d];
        }}
    }}

    // Pass 3: normalise and write the output exactly once (no read-modify-write).
    let inv = select(0.0, 1.0 / sum_exp, sum_exp > 0.0);
    for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
        o_buf[o_base + d] = o_acc[d] * inv;
    }}
}}
"#
    )
}
504
/// Generates the second-stage WGSL reduction kernel: one 256-thread
/// workgroup folds the per-workgroup partials from [`reduction_wgsl`] into
/// `output[0]`.
///
/// `op` selects the combiner exactly as in the first stage.
///
/// NOTE(review): lanes `>= params.num_groups` load the neutral element, so
/// this single dispatch only handles up to 256 partials — confirm the host
/// chains reduction passes when the first stage produced more workgroups.
pub fn reduction_final_wgsl(op: &str) -> String {
    let (neutral, combine) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)"),
        "min" => ("f32(1e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}

@group(0) @binding(0) var<storage, read> partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: FinalReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;

    if (tid < params.num_groups) {{
        shared_data[tid] = partial_sums[tid];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#
    )
}
563
#[cfg(test)]
mod tests {
    use super::*;

    // --- GEMM ------------------------------------------------------------

    #[test]
    fn wgsl_gemm_contains_workgroup() {
        let src = gemm_wgsl(16);
        assert!(src.contains("@compute @workgroup_size(16, 16)"));
        assert!(src.contains("GemmParams"));
        assert!(src.contains("alpha"));
        assert!(src.contains("beta"));
    }

    #[test]
    fn wgsl_gemm_tile_size_embedded() {
        assert!(gemm_wgsl(8).contains("@workgroup_size(8, 8)"));
        assert!(gemm_wgsl(32).contains("@workgroup_size(32, 32)"));
    }

    // --- Element-wise ----------------------------------------------------

    #[test]
    fn wgsl_elementwise_relu_contains_max() {
        assert!(elementwise_wgsl("relu").contains("max(x, 0.0)"));
    }

    #[test]
    fn wgsl_elementwise_all_ops() {
        assert!(elementwise_wgsl("sigmoid").contains("exp(-x)"));
        assert!(elementwise_wgsl("tanh").contains("tanh(x)"));
        assert!(elementwise_wgsl("exp").contains("exp(x)"));
        assert!(elementwise_wgsl("log").contains("log(x)"));
        assert!(elementwise_wgsl("sqrt").contains("sqrt(x)"));
        assert!(elementwise_wgsl("abs").contains("abs(x)"));
        assert!(elementwise_wgsl("neg").contains("-x"));
        // Unknown op names fall back to the identity copy.
        assert!(elementwise_wgsl("identity_op").contains("output[i] = x;"));
    }

    // --- Reductions ------------------------------------------------------

    #[test]
    fn wgsl_reduction_sum_contains_addition() {
        let src = reduction_wgsl("sum");
        assert!(src.contains("acc + val"));
        assert!(src.contains("workgroupBarrier"));
    }

    #[test]
    fn wgsl_reduction_max_uses_max_fn() {
        assert!(reduction_wgsl("max").contains("max(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_min_uses_min_fn() {
        assert!(reduction_wgsl("min").contains("min(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_mean_same_as_sum() {
        // Mean is a sum on-device; the divide happens host-side.
        assert_eq!(reduction_wgsl("sum"), reduction_wgsl("mean"));
    }

    #[test]
    fn wgsl_reduction_final_sum() {
        let src = reduction_final_wgsl("sum");
        assert!(src.contains("num_groups"));
        assert!(src.contains("output[0]"));
    }

    // --- Binary ops ------------------------------------------------------

    #[test]
    fn wgsl_binary_add() {
        let src = binary_wgsl("add");
        assert!(src.contains("a + b"));
        assert!(src.contains("lhs"));
        assert!(src.contains("rhs"));
    }

    #[test]
    fn wgsl_binary_all_ops() {
        assert!(binary_wgsl("sub").contains("a - b"));
        assert!(binary_wgsl("mul").contains("a * b"));
        assert!(binary_wgsl("div").contains("a / b"));
        assert!(binary_wgsl("max").contains("max(a, b)"));
        assert!(binary_wgsl("min").contains("min(a, b)"));
        assert!(binary_wgsl("pow").contains("pow(a, b)"));
        // Unknown op names fall back to copying lhs.
        assert!(binary_wgsl("unknown_op").contains("output[i] = a;"));
    }

    #[test]
    fn wgsl_binary_workgroup_size() {
        assert!(binary_wgsl("add").contains("@workgroup_size(256)"));
    }

    // --- Conv2D ----------------------------------------------------------

    #[test]
    fn wgsl_conv2d_contains_workgroup() {
        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        assert!(src.contains("@compute @workgroup_size(8, 8)"));
    }

    #[test]
    fn wgsl_conv2d_contains_storage_bindings() {
        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        assert!(src.contains("var<storage, read> input:"));
        assert!(src.contains("var<storage, read> filter:"));
        assert!(src.contains("var<storage, read_write> output:"));
    }

    #[test]
    fn wgsl_conv2d_embeds_dimensions() {
        let src = conv2d_wgsl(2, 8, 64, 64, 32, 5, 5, 60, 60, 1, 1, 0, 0);
        assert!(src.contains("8u"));
        assert!(src.contains("64u"));
        assert!(src.contains("32u"));
        assert!(src.contains("5u"));
        assert!(src.contains("60u"));
    }

    #[test]
    fn wgsl_conv2d_has_padding_check() {
        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 8, 8, 1, 1, 1, 1);
        assert!(src.contains("iy_raw >= 0"));
        assert!(src.contains("ix_raw >= 0"));
    }

    #[test]
    fn wgsl_conv2d_has_stride() {
        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 3, 3, 2, 2, 0, 0);
        assert!(src.contains("2u"));
    }

    // --- Attention -------------------------------------------------------

    #[test]
    fn wgsl_attention_contains_workgroup() {
        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
        assert!(src.contains("@compute @workgroup_size(64)"));
    }

    #[test]
    fn wgsl_attention_contains_storage_bindings() {
        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
        assert!(src.contains("var<storage, read> q_buf:"));
        assert!(src.contains("var<storage, read> k_buf:"));
        assert!(src.contains("var<storage, read> v_buf:"));
        assert!(src.contains("var<storage, read_write> o_buf:"));
    }

    #[test]
    fn wgsl_attention_stable_softmax() {
        let src = attention_wgsl(1, 4, 4, 32, 0.25, false);
        assert!(src.contains("max_score"));
        assert!(src.contains("exp(score - max_score)"));
        assert!(src.contains("sum_exp"));
    }

    #[test]
    fn wgsl_attention_causal_mask() {
        let src_causal = attention_wgsl(1, 4, 4, 32, 0.25, true);
        assert!(src_causal.contains("sk > sq"));

        let src_non_causal = attention_wgsl(1, 4, 4, 32, 0.25, false);
        assert!(!src_non_causal.contains("sk > sq"));
    }

    #[test]
    fn wgsl_attention_embeds_scale() {
        assert!(attention_wgsl(2, 16, 16, 64, 0.125, false).contains("0.125"));
    }

    #[test]
    fn wgsl_attention_embeds_dimensions() {
        let src = attention_wgsl(8, 32, 32, 128, 0.088, true);
        assert!(src.contains("128u"));
        assert!(src.contains("32u"));
        assert!(src.contains("8u"));
    }

    // --- Batched GEMM ----------------------------------------------------

    #[test]
    fn wgsl_batched_gemm_contains_batch_params() {
        let src = batched_gemm_wgsl(16);
        assert!(src.contains("batch_count"));
        assert!(src.contains("stride_a"));
        assert!(src.contains("stride_b"));
        assert!(src.contains("stride_c"));
    }

    #[test]
    fn wgsl_batched_gemm_contains_workgroup() {
        let src = batched_gemm_wgsl(16);
        assert!(src.contains("@compute @workgroup_size(16, 16)"));
        assert!(src.contains("BatchedGemmParams"));
    }

    #[test]
    fn wgsl_batched_gemm_uses_batch_index() {
        let src = batched_gemm_wgsl(8);
        assert!(src.contains("batch_index"));
        assert!(src.contains("gid.z"));
    }

    #[test]
    fn wgsl_batched_gemm_tile_size_embedded() {
        assert!(batched_gemm_wgsl(8).contains("@workgroup_size(8, 8)"));
        assert!(batched_gemm_wgsl(32).contains("@workgroup_size(32, 32)"));
    }

    // --- f16 GEMM --------------------------------------------------------

    #[test]
    fn wgsl_gemm_f16_enables_extension() {
        assert!(gemm_wgsl_f16(16).contains("enable f16;"));
    }

    #[test]
    fn wgsl_gemm_f16_uses_f16_storage() {
        assert!(gemm_wgsl_f16(16).contains("array<f16>"));
    }

    #[test]
    fn wgsl_gemm_f16_accumulates_in_f32() {
        let src = gemm_wgsl_f16(16);
        assert!(src.contains("var acc: f32 = 0.0;"));
        assert!(src.contains("f32(a["));
        assert!(src.contains("f32(b["));
    }

    #[test]
    fn wgsl_gemm_f16_contains_workgroup() {
        let src = gemm_wgsl_f16(8);
        assert!(src.contains("@compute @workgroup_size(8, 8)"));
        assert!(src.contains("GemmParams"));
    }
}