/// Generates a WGSL compute shader for single-precision GEMM:
/// `C = alpha * A * B + beta * C`, with A: [m, k], B: [k, n], C: [m, n]
/// (row-major, dimensions supplied at dispatch time via `GemmParams`).
///
/// `tile_size` becomes both dimensions of the 2-D workgroup, so the caller
/// must dispatch `ceil(n / tile_size) x ceil(m / tile_size)` workgroups.
pub fn gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct GemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
}}

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[row * params.k + i] * b[i * params.n + col];
    }}

    let idx = row * params.n + col;
    // BLAS convention: when beta == 0 the previous contents of C must not
    // be read at all (C may be uninitialized, and 0.0 * NaN would poison
    // the result). Only do the read-modify-write for a nonzero beta.
    if (params.beta == 0.0) {{
        c[idx] = params.alpha * acc;
    }} else {{
        c[idx] = params.alpha * acc + params.beta * c[idx];
    }}
}}
"#,
        ts = tile_size
    )
}
47
/// Generates a WGSL compute shader that applies a unary elementwise
/// operation to every element of `input`, writing the result to `output`.
///
/// Known ops: `relu`, `sigmoid`, `tanh`, `exp`, `log`, `sqrt`, `abs`,
/// `neg`. Any other string falls back to the identity mapping.
pub fn elementwise_wgsl(op: &str) -> String {
    // Map the op name onto the WGSL expression applied to `x`.
    let expr = match op {
        "relu" => "max(x, 0.0)",
        "sigmoid" => "1.0 / (1.0 + exp(-x))",
        "tanh" => "tanh(x)",
        "exp" => "exp(x)",
        "log" => "log(x)",
        "sqrt" => "sqrt(x)",
        "abs" => "abs(x)",
        "neg" => "-x",
        // Unrecognized op: pass the value through unchanged.
        _ => "x",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&input)) {{ return; }}
    let x = input[i];
    output[i] = {expr};
}}
"#
    )
}
86
/// Generates a WGSL compute shader that combines two equally-sized buffers
/// element by element: `output[i] = op(lhs[i], rhs[i])`.
///
/// Known ops: `add`, `sub`, `mul`, `div`, `max`, `min`, `pow`. Any other
/// string falls back to emitting the left-hand operand unchanged.
pub fn binary_wgsl(op: &str) -> String {
    // Map the op name onto the WGSL expression over operands `a` and `b`.
    let expr = match op {
        "add" => "a + b",
        "sub" => "a - b",
        "mul" => "a * b",
        "div" => "a / b",
        "max" => "max(a, b)",
        "min" => "min(a, b)",
        "pow" => "pow(a, b)",
        // Unrecognized op: pass the left operand through.
        _ => "a",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read> lhs: array<f32>;
@group(0) @binding(1) var<storage, read> rhs: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&lhs)) {{ return; }}
    let a = lhs[i];
    let b = rhs[i];
    output[i] = {expr};
}}
"#
    )
}
126
/// Generates the first-stage WGSL reduction shader: each 256-thread
/// workgroup collapses its slice of `input` to a single value, written to
/// `partial_sums[workgroup_id]`.
///
/// `op` selects the combiner: `"max"`, `"min"`, or (default) addition —
/// the addition arm also serves `"sum"` and `"mean"` (no division is
/// applied here; presumably the mean is finalized elsewhere).
pub fn reduction_wgsl(op: &str) -> String {
    // Identity element (used to pad out-of-range lanes) and combining
    // expression over `acc` and `val` for the requested op.
    let (identity, merge) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)"),
        "min" => ("f32(1e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let global_idx = gid.x;

    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {identity};
    }}
    workgroupBarrier();

    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {merge};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#
    )
}
201
/// Generates a WGSL direct-convolution shader for NCHW tensors, with every
/// dimension baked into the shader text at generation time.
///
/// Layouts:
/// - input  : `[n, c_in, h_in, w_in]`
/// - filter : `[k_out, c_in, fh, fw]`
/// - output : `[n, k_out, oh, ow]`
///
/// The caller chooses `oh`/`ow` consistent with the stride and padding.
/// Each invocation computes one output element: `(batch, out-channel, oy)`
/// is flattened into `gid.y` and `ox` is `gid.x`.
#[allow(clippy::too_many_arguments)]
pub fn conv2d_wgsl(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> String {
    // All placeholders below capture the parameters by name directly.
    format!(
        r#"
// Conv2D NCHW — generated by oxicuda-webgpu
// input : [{n}, {c_in}, {h_in}, {w_in}]
// filter: [{k_out}, {c_in}, {fh}, {fw}]
// output: [{n}, {k_out}, {oh}, {ow}]

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> filter: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    // gid.x = output x (ox mapped across batches*k_out*oh)
    // We flatten (batch, k, oy) into gid.y and ox into gid.x
    let ox = gid.x;
    let linear_y = gid.y;

    let batch_k_oh = {n}u * {k_out}u * {oh}u;
    if (ox >= {ow}u || linear_y >= batch_k_oh) {{ return; }}

    let b = linear_y / ({k_out}u * {oh}u);
    let rem = linear_y % ({k_out}u * {oh}u);
    let kf = rem / {oh}u;
    let oy = rem % {oh}u;

    var acc: f32 = 0.0;
    for (var ci: u32 = 0u; ci < {c_in}u; ci = ci + 1u) {{
        for (var fy: u32 = 0u; fy < {fh}u; fy = fy + 1u) {{
            for (var fx: u32 = 0u; fx < {fw}u; fx = fx + 1u) {{
                let iy_raw = i32(oy * {stride_h}u + fy) - i32({pad_h}u);
                let ix_raw = i32(ox * {stride_w}u + fx) - i32({pad_w}u);
                if (iy_raw >= 0 && iy_raw < i32({h_in}u) && ix_raw >= 0 && ix_raw < i32({w_in}u)) {{
                    let iy = u32(iy_raw);
                    let ix = u32(ix_raw);
                    let in_idx = ((b * {c_in}u + ci) * {h_in}u + iy) * {w_in}u + ix;
                    let f_idx = ((kf * {c_in}u + ci) * {fh}u + fy) * {fw}u + fx;
                    acc += input[in_idx] * filter[f_idx];
                }}
            }}
        }}
    }}

    let o_idx = ((b * {k_out}u + kf) * {oh}u + oy) * {ow}u + ox;
    output[o_idx] = acc;
}}
"#
    )
}
296
/// Generates a WGSL shader for scaled dot-product attention with a
/// numerically stable two-pass softmax (max-subtraction before `exp`).
///
/// Layouts:
/// - Q       : `[batch_heads, seq_q, head_dim]`
/// - K, V    : `[batch_heads, seq_kv, head_dim]`
/// - O       : `[batch_heads, seq_q, head_dim]`
///
/// Each invocation owns exactly one `(batch*head, query)` output row.
/// When `causal` is true, keys with `sk > sq` are masked out (their score
/// is forced to -1e38, so their softmax weight underflows to zero).
pub fn attention_wgsl(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> String {
    // Opens either a plain block or an if/else that skips the dot product
    // for masked (future) positions; the matching close brace is emitted
    // by the template below.
    let causal_check = if causal {
        "if (sk > sq) { score = f32(-1e38); } else {"
    } else {
        "{"
    };

    format!(
        r#"
// Scaled dot-product attention — generated by oxicuda-webgpu
// Q : [{batch_heads}, {seq_q}, {head_dim}]
// K, V : [{batch_heads}, {seq_kv}, {head_dim}]
// O : [{batch_heads}, {seq_q}, {head_dim}]
// scale : {scale}
// causal : {causal}

@group(0) @binding(0) var<storage, read> q_buf: array<f32>;
@group(0) @binding(1) var<storage, read> k_buf: array<f32>;
@group(0) @binding(2) var<storage, read> v_buf: array<f32>;
@group(0) @binding(3) var<storage, read_write> o_buf: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let linear = gid.x;
    let total = {batch_heads}u * {seq_q}u;
    if (linear >= total) {{ return; }}

    let bh = linear / {seq_q}u;
    let sq = linear % {seq_q}u;

    let q_base = (bh * {seq_q}u + sq) * {head_dim}u;
    let o_base = (bh * {seq_q}u + sq) * {head_dim}u;

    // This thread exclusively owns its output row. Clear it first:
    // pass 2 accumulates with +=, so stale contents from a previous
    // dispatch (or an uninitialized buffer) would corrupt the result.
    for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
        o_buf[o_base + d] = 0.0;
    }}

    // Pass 1: find max score for numerical stability
    var max_score: f32 = f32(-1e38);
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        if (score > max_score) {{ max_score = score; }}
    }}

    // Pass 2: compute exp(score - max), accumulate weighted V
    var sum_exp: f32 = 0.0;
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        let w = exp(score - max_score);
        sum_exp += w;
        let v_base = (bh * {seq_kv}u + sk) * {head_dim}u;
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            // Accumulate in-place (we normalise after the loop).
            o_buf[o_base + d] += w * v_buf[v_base + d];
        }}
    }}

    // Pass 3: normalise
    if (sum_exp > 0.0) {{
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            o_buf[o_base + d] /= sum_exp;
        }}
    }}
}}
"#
    )
}
403
/// Generates the second-stage WGSL reduction shader that collapses the
/// per-workgroup partial results into a single value at `output[0]`.
///
/// Dispatched as a single 256-thread workgroup. Each thread first folds a
/// strided slice of `partial_sums` into a private accumulator — so any
/// `num_groups` is handled correctly, not just the first 256 — and a
/// shared-memory tree reduction then combines the 256 accumulators.
///
/// `op` selects the combiner: `"max"`, `"min"`, or (default) addition.
pub fn reduction_final_wgsl(op: &str) -> String {
    // Identity element and combining expression over `acc` and `val`.
    let (neutral, combine) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)"),
        "min" => ("f32(1e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}

@group(0) @binding(0) var<storage, read> partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: FinalReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;

    // Strided fold: each thread consumes partial_sums[tid], [tid + 256],
    // [tid + 512], ... so num_groups > 256 is reduced correctly instead
    // of silently dropping the extra partials.
    var running: f32 = {neutral};
    for (var i: u32 = tid; i < params.num_groups; i = i + 256u) {{
        let acc = running;
        let val = partial_sums[i];
        running = {combine};
    }}
    shared_data[tid] = running;
    workgroupBarrier();

    // Tree reduction across the 256 per-thread accumulators.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#
    )
}
462
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wgsl_gemm_contains_workgroup() {
        let shader = gemm_wgsl(16);
        assert!(shader.contains("@compute @workgroup_size(16, 16)"));
        assert!(shader.contains("GemmParams"));
        assert!(shader.contains("alpha"));
        assert!(shader.contains("beta"));
    }

    #[test]
    fn wgsl_gemm_tile_size_embedded() {
        assert!(gemm_wgsl(8).contains("@workgroup_size(8, 8)"));
        assert!(gemm_wgsl(32).contains("@workgroup_size(32, 32)"));
    }

    #[test]
    fn wgsl_elementwise_relu_contains_max() {
        assert!(elementwise_wgsl("relu").contains("max(x, 0.0)"));
    }

    #[test]
    fn wgsl_elementwise_all_ops() {
        assert!(elementwise_wgsl("sigmoid").contains("exp(-x)"));
        assert!(elementwise_wgsl("tanh").contains("tanh(x)"));
        assert!(elementwise_wgsl("exp").contains("exp(x)"));
        assert!(elementwise_wgsl("log").contains("log(x)"));
        assert!(elementwise_wgsl("sqrt").contains("sqrt(x)"));
        assert!(elementwise_wgsl("abs").contains("abs(x)"));
        assert!(elementwise_wgsl("neg").contains("-x"));
        // Unknown ops must fall back to the identity mapping.
        assert!(elementwise_wgsl("identity_op").contains("output[i] = x;"));
    }

    #[test]
    fn wgsl_reduction_sum_contains_addition() {
        let shader = reduction_wgsl("sum");
        assert!(shader.contains("acc + val"));
        assert!(shader.contains("workgroupBarrier"));
    }

    #[test]
    fn wgsl_reduction_max_uses_max_fn() {
        assert!(reduction_wgsl("max").contains("max(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_min_uses_min_fn() {
        assert!(reduction_wgsl("min").contains("min(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_mean_same_as_sum() {
        // "mean" shares the addition combiner with "sum".
        assert_eq!(reduction_wgsl("sum"), reduction_wgsl("mean"));
    }

    #[test]
    fn wgsl_reduction_final_sum() {
        let shader = reduction_final_wgsl("sum");
        assert!(shader.contains("num_groups"));
        assert!(shader.contains("output[0]"));
    }

    #[test]
    fn wgsl_binary_add() {
        let shader = binary_wgsl("add");
        assert!(shader.contains("a + b"));
        assert!(shader.contains("lhs"));
        assert!(shader.contains("rhs"));
    }

    #[test]
    fn wgsl_binary_all_ops() {
        assert!(binary_wgsl("sub").contains("a - b"));
        assert!(binary_wgsl("mul").contains("a * b"));
        assert!(binary_wgsl("div").contains("a / b"));
        assert!(binary_wgsl("max").contains("max(a, b)"));
        assert!(binary_wgsl("min").contains("min(a, b)"));
        assert!(binary_wgsl("pow").contains("pow(a, b)"));
        // Unknown ops must fall back to passing the left operand through.
        assert!(binary_wgsl("unknown_op").contains("output[i] = a;"));
    }

    #[test]
    fn wgsl_binary_workgroup_size() {
        assert!(binary_wgsl("add").contains("@workgroup_size(256)"));
    }

    #[test]
    fn wgsl_conv2d_contains_workgroup() {
        let shader = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        assert!(shader.contains("@compute @workgroup_size(8, 8)"));
    }

    #[test]
    fn wgsl_conv2d_contains_storage_bindings() {
        let shader = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        assert!(shader.contains("var<storage, read> input:"));
        assert!(shader.contains("var<storage, read> filter:"));
        assert!(shader.contains("var<storage, read_write> output:"));
    }

    #[test]
    fn wgsl_conv2d_embeds_dimensions() {
        let shader = conv2d_wgsl(2, 8, 64, 64, 32, 5, 5, 60, 60, 1, 1, 0, 0);
        assert!(shader.contains("8u"));
        assert!(shader.contains("64u"));
        assert!(shader.contains("32u"));
        assert!(shader.contains("5u"));
        assert!(shader.contains("60u"));
    }

    #[test]
    fn wgsl_conv2d_has_padding_check() {
        let shader = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 8, 8, 1, 1, 1, 1);
        assert!(shader.contains("iy_raw >= 0"));
        assert!(shader.contains("ix_raw >= 0"));
    }

    #[test]
    fn wgsl_conv2d_has_stride() {
        let shader = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 3, 3, 2, 2, 0, 0);
        assert!(shader.contains("2u"));
    }

    #[test]
    fn wgsl_attention_contains_workgroup() {
        let shader = attention_wgsl(4, 8, 8, 64, 0.125, false);
        assert!(shader.contains("@compute @workgroup_size(64)"));
    }

    #[test]
    fn wgsl_attention_contains_storage_bindings() {
        let shader = attention_wgsl(4, 8, 8, 64, 0.125, false);
        assert!(shader.contains("var<storage, read> q_buf:"));
        assert!(shader.contains("var<storage, read> k_buf:"));
        assert!(shader.contains("var<storage, read> v_buf:"));
        assert!(shader.contains("var<storage, read_write> o_buf:"));
    }

    #[test]
    fn wgsl_attention_stable_softmax() {
        let shader = attention_wgsl(1, 4, 4, 32, 0.25, false);
        assert!(shader.contains("max_score"));
        assert!(shader.contains("exp(score - max_score)"));
        assert!(shader.contains("sum_exp"));
    }

    #[test]
    fn wgsl_attention_causal_mask() {
        assert!(attention_wgsl(1, 4, 4, 32, 0.25, true).contains("sk > sq"));
        assert!(!attention_wgsl(1, 4, 4, 32, 0.25, false).contains("sk > sq"));
    }

    #[test]
    fn wgsl_attention_embeds_scale() {
        assert!(attention_wgsl(2, 16, 16, 64, 0.125, false).contains("0.125"));
    }

    #[test]
    fn wgsl_attention_embeds_dimensions() {
        let shader = attention_wgsl(8, 32, 32, 128, 0.088, true);
        assert!(shader.contains("128u"));
        assert!(shader.contains("32u"));
        assert!(shader.contains("8u"));
    }
}