Skip to main content

oxicuda_webgpu/
shader.rs

1//! WGSL shader source generation for common compute kernels.
2//!
3//! Each function returns a complete, self-contained WGSL source string
4//! suitable for passing to `device.create_shader_module()`.
5
/// Build the WGSL source of a shared-memory tiled GEMM kernel computing
/// `C = alpha * op(A) * op(B) + beta * C`.
///
/// The entry point is declared with `@workgroup_size(tile_size, tile_size)`.
/// Every invocation owns one element of `C` (`gid.y` is the row, `gid.x` the
/// column) and walks the `k` dimension in steps of `tile_size`, staging one
/// tile of `op(A)` and one tile of `op(B)` in `var<workgroup>` arrays per
/// step.  Tile elements that fall outside the logical operand are loaded as
/// `0.0`, and invocations outside `m × n` skip the final store.
///
/// `op(A)` is the logical `m × k` left operand and `op(B)` the logical
/// `k × n` right operand.  The `trans_a` / `trans_b` fields of the
/// `GemmParams` uniform choose how the stored buffers are indexed:
///
/// * `trans_a == 0` — `op(A)[r, i]` is read from `a[r * k + i]`
///   (row-major `m × k` storage).
/// * `trans_a != 0` — `op(A)[r, i]` is read from `a[i * m + r]`
///   (row-major `k × m` storage, i.e. the transpose is stored).
/// * `trans_b == 0` — `op(B)[i, c]` is read from `b[i * n + c]`
///   (row-major `k × n` storage).
/// * `trans_b != 0` — `op(B)[i, c]` is read from `b[c * k + i]`
///   (row-major `n × k` storage).
///
/// `C` is always addressed as `c[row * n + col]`.  Because the flags are
/// read from the uniform at run time, one generated module covers the
/// NN / NT / TN / TT cases.
///
/// # Arguments
///
/// * `tile_size` — edge length of the square workgroup tile (e.g. 8, 16, 32);
///   it is substituted into the workgroup size, the tile array types and the
///   tile-stepping arithmetic.
pub fn gemm_wgsl(tile_size: u32) -> String {
    // Short alias so the template below can capture it by name.
    let ts = tile_size;
    format!(
        r#"
struct GemmParams {{
    m:       u32,
    n:       u32,
    k:       u32,
    alpha:   f32,
    beta:    f32,
    trans_a: u32,
    trans_b: u32,
    _pad:    u32,
}}

@group(0) @binding(0) var<storage, read>       a:      array<f32>;
@group(0) @binding(1) var<storage, read>       b:      array<f32>;
@group(0) @binding(2) var<storage, read_write> c:      array<f32>;
@group(0) @binding(3) var<uniform>             params: GemmParams;

var<workgroup> tile_a: array<array<f32, {ts}>, {ts}>;
var<workgroup> tile_b: array<array<f32, {ts}>, {ts}>;

// op(A)[r, i] — logical m×k left operand.
fn load_a(r: u32, i: u32) -> f32 {{
    if (r >= params.m || i >= params.k) {{ return 0.0; }}
    if (params.trans_a == 0u) {{
        return a[r * params.k + i];
    }}
    return a[i * params.m + r];
}}

// op(B)[i, col] — logical k×n right operand.
fn load_b(i: u32, col: u32) -> f32 {{
    if (i >= params.k || col >= params.n) {{ return 0.0; }}
    if (params.trans_b == 0u) {{
        return b[i * params.n + col];
    }}
    return b[col * params.k + i];
}}

@compute @workgroup_size({ts}, {ts})
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id)  lid: vec3<u32>,
) {{
    let row = gid.y;
    let col = gid.x;
    let lr  = lid.y;
    let lc  = lid.x;

    var acc: f32 = 0.0;
    let num_tiles = (params.k + {ts}u - 1u) / {ts}u;
    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {{
        let a_col = t * {ts}u + lc;
        let b_row = t * {ts}u + lr;
        tile_a[lr][lc] = load_a(row, a_col);
        tile_b[lr][lc] = load_b(b_row, col);
        workgroupBarrier();

        for (var e: u32 = 0u; e < {ts}u; e = e + 1u) {{
            acc += tile_a[lr][e] * tile_b[e][lc];
        }}
        workgroupBarrier();
    }}

    if (row >= params.m || col >= params.n) {{ return; }}
    let idx = row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#
    )
}
102
/// Generate WGSL source for a batched (strided) GEMM kernel.
///
/// For each batch `b` in `0..batch_count`:
///   `C_b = alpha * op(A_b) * op(B_b) + beta * C_b`
/// where `A_b` starts at `a[b * stride_a]`, `B_b` at `b[b * stride_b]` and
/// `C_b` at `c[b * stride_c]`.
///
/// Uses `tile_size × tile_size` workgroup tiles with shared-memory staging.
/// The batch is selected by the workgroup's Z index; workgroups whose Z index
/// is `>= batch_count` exit immediately.  Transpose handling matches
/// [`gemm_wgsl`]: `trans_a == 0` reads `op(A_b)[r, i]` from
/// `a[offset + r * k + i]`, otherwise from `a[offset + i * m + r]`;
/// `trans_b == 0` reads `op(B_b)[i, c]` from `b[offset + i * n + c]`,
/// otherwise from `b[offset + c * k + i]`.
///
/// The batch index is taken from `workgroup_id.z` rather than
/// `global_invocation_id.z`.  The two are equal because the workgroup is one
/// invocation deep in Z, but only `workgroup_id` is a uniform value in WGSL's
/// uniformity analysis, and the early batch-range `return` precedes
/// `workgroupBarrier()` calls, which must sit in uniform control flow.
///
/// # Arguments
///
/// * `tile_size` — workgroup tile dimension (e.g. 8, 16, 32).
pub fn batched_gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct BatchedGemmParams {{
    m:        u32,
    n:        u32,
    k:        u32,
    alpha:    f32,
    beta:     f32,
    batch_count: u32,
    stride_a: u32,
    stride_b: u32,
    stride_c: u32,
    trans_a:  u32,
    trans_b:  u32,
}}

@group(0) @binding(0) var<storage, read>       a:      array<f32>;
@group(0) @binding(1) var<storage, read>       b:      array<f32>;
@group(0) @binding(2) var<storage, read_write> c:      array<f32>;
@group(0) @binding(3) var<uniform>             params: BatchedGemmParams;

var<workgroup> tile_a: array<array<f32, {ts}>, {ts}>;
var<workgroup> tile_b: array<array<f32, {ts}>, {ts}>;

// op(A_b)[r, i] — logical m×k left operand for batch `a_offset`.
fn load_a(a_offset: u32, r: u32, i: u32) -> f32 {{
    if (r >= params.m || i >= params.k) {{ return 0.0; }}
    if (params.trans_a == 0u) {{
        return a[a_offset + r * params.k + i];
    }}
    return a[a_offset + i * params.m + r];
}}

// op(B_b)[i, col] — logical k×n right operand for batch `b_offset`.
fn load_b(b_offset: u32, i: u32, col: u32) -> f32 {{
    if (i >= params.k || col >= params.n) {{ return 0.0; }}
    if (params.trans_b == 0u) {{
        return b[b_offset + i * params.n + col];
    }}
    return b[b_offset + col * params.k + i];
}}

@compute @workgroup_size({ts}, {ts})
fn main(
    @builtin(global_invocation_id) gid:  vec3<u32>,
    @builtin(local_invocation_id)  lid:  vec3<u32>,
    @builtin(workgroup_id)         wgid: vec3<u32>,
) {{
    let row = gid.y;
    let col = gid.x;
    // The workgroup is one invocation deep in Z, so wgid.z equals gid.z.
    // Reading the uniform workgroup id keeps the early return below, and
    // the barriers after it, in uniform control flow.
    let batch_index = wgid.z;
    let lr  = lid.y;
    let lc  = lid.x;
    if (batch_index >= params.batch_count) {{ return; }}

    let a_offset = batch_index * params.stride_a;
    let b_offset = batch_index * params.stride_b;
    let c_offset = batch_index * params.stride_c;

    var acc: f32 = 0.0;
    let num_tiles = (params.k + {ts}u - 1u) / {ts}u;
    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {{
        let a_col = t * {ts}u + lc;
        let b_row = t * {ts}u + lr;
        tile_a[lr][lc] = load_a(a_offset, row, a_col);
        tile_b[lr][lc] = load_b(b_offset, b_row, col);
        workgroupBarrier();

        for (var e: u32 = 0u; e < {ts}u; e = e + 1u) {{
            acc += tile_a[lr][e] * tile_b[e][lc];
        }}
        workgroupBarrier();
    }}

    if (row >= params.m || col >= params.n) {{ return; }}
    let idx = c_offset + row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#,
        ts = tile_size
    )
}
199
/// Build the WGSL source of a GEMM kernel whose buffers hold `f16` values.
///
/// The shader begins with `enable f16;` and binds `a`, `b` and `c` as
/// `array<f16>`.  Products are accumulated in an `f32` variable and the
/// result `alpha * acc + beta * c[idx]` is converted back to `f16` on store.
///
/// Unlike [`gemm_wgsl`], the generated kernel does not stage tiles in
/// workgroup memory and carries no transpose flags: each invocation loops
/// over the full `k` dimension, reading `a[row * k + i]` and
/// `b[i * n + col]`.  `tile_size` only determines the workgroup size.
///
/// # Arguments
///
/// * `tile_size` — workgroup edge length; the entry point is declared with
///   `@workgroup_size(tile_size, tile_size)` (e.g. 8, 16, 32).
pub fn gemm_wgsl_f16(tile_size: u32) -> String {
    // Short alias so the template below can capture it by name.
    let ts = tile_size;
    format!(
        r#"
enable f16;

struct GemmParams {{
    m:     u32,
    n:     u32,
    k:     u32,
    alpha: f32,
    beta:  f32,
}}

@group(0) @binding(0) var<storage, read>       a:      array<f16>;
@group(0) @binding(1) var<storage, read>       b:      array<f16>;
@group(0) @binding(2) var<storage, read_write> c:      array<f16>;
@group(0) @binding(3) var<uniform>             params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += f32(a[row * params.k + i]) * f32(b[i * params.n + col]);
    }}

    let idx = row * params.n + col;
    let prev = f32(c[idx]);
    c[idx] = f16(params.alpha * acc + params.beta * prev);
}}
"#
    )
}
245
/// Build the WGSL source of an element-wise unary kernel.
///
/// Each invocation handles index `gid.x`: it returns early when the index is
/// `>= arrayLength(&input)`, otherwise it reads `input[i]` into `x`, evaluates
/// the selected expression and stores it to `output[i]`.  The workgroup size
/// is fixed at 256.
///
/// # Arguments
///
/// * `op` — one of: `"relu"`, `"sigmoid"`, `"tanh"`, `"exp"`, `"log"`,
///   `"sqrt"`, `"abs"`, `"neg"`.  Any other name produces the identity
///   kernel (`output[i] = x`).
pub fn elementwise_wgsl(op: &str) -> String {
    // Operation name -> WGSL expression over the loaded value `x`.
    const UNARY_EXPRS: [(&str, &str); 8] = [
        ("relu", "max(x, 0.0)"),
        ("sigmoid", "1.0 / (1.0 + exp(-x))"),
        ("tanh", "tanh(x)"),
        ("exp", "exp(x)"),
        ("log", "log(x)"),
        ("sqrt", "sqrt(x)"),
        ("abs", "abs(x)"),
        ("neg", "-x"),
    ];

    // Unknown names fall through to the identity expression.
    let expr = UNARY_EXPRS
        .iter()
        .find(|(name, _)| *name == op)
        .map_or("x", |&(_, wgsl)| wgsl);

    format!(
        r#"
@group(0) @binding(0) var<storage, read>       input:  array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&input)) {{ return; }}
    let x = input[i];
    output[i] = {expr};
}}
"#
    )
}
284
/// Build the WGSL source of an element-wise binary kernel.
///
/// Each invocation handles index `gid.x`: it returns early when the index is
/// `>= arrayLength(&lhs)`, otherwise it loads `lhs[i]` as `a` and `rhs[i]` as
/// `b`, evaluates the selected expression and stores it to `output[i]`.  The
/// workgroup size is fixed at 256.
///
/// # Arguments
///
/// * `op` — one of: `"add"`, `"sub"`, `"mul"`, `"div"`, `"max"`, `"min"`,
///   `"pow"`.  Any other name copies the left operand (`output[i] = a`).
pub fn binary_wgsl(op: &str) -> String {
    // Operation name -> WGSL expression over the loaded operands `a` and `b`.
    const BINARY_EXPRS: [(&str, &str); 7] = [
        ("add", "a + b"),
        ("sub", "a - b"),
        ("mul", "a * b"),
        ("div", "a / b"),
        ("max", "max(a, b)"),
        ("min", "min(a, b)"),
        ("pow", "pow(a, b)"),
    ];

    // Unknown names fall through to "pass the left operand through".
    let expr = BINARY_EXPRS
        .iter()
        .find(|(name, _)| *name == op)
        .map_or("a", |&(_, wgsl)| wgsl);

    format!(
        r#"
@group(0) @binding(0) var<storage, read>       lhs:    array<f32>;
@group(0) @binding(1) var<storage, read>       rhs:    array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&lhs)) {{ return; }}
    let a = lhs[i];
    let b = rhs[i];
    output[i] = {expr};
}}
"#
    )
}
324
/// Build the WGSL source of the first pass of a workgroup-level reduction.
///
/// Each workgroup of 256 invocations copies one input element per invocation
/// into workgroup memory (substituting the operation's neutral element for
/// indices `>= params.n`), folds the 256 slots together with a halving-stride
/// tree, and has invocation 0 write the result to `partial_sums[wgid.x]`.
/// Combining those partial results into a single scalar is left to a further
/// pass (see [`reduction_final_wgsl`]).
///
/// # Arguments
///
/// * `op` — one of: `"sum"`, `"max"`, `"min"`, `"mean"`.  `"mean"` generates
///   exactly the same shader as `"sum"`; dividing by the element count is not
///   done here.  Unrecognised names also fall back to `"sum"`.
pub fn reduction_wgsl(op: &str) -> String {
    // Identity value used for padding, and the expression folding `val` into `acc`.
    let (neutral, combine) = if op == "max" {
        ("f32(-1e38)", "max(acc, val)")
    } else if op == "min" {
        ("f32(1e38)", "min(acc, val)")
    } else {
        // "sum", "mean" and anything unrecognised share the additive body.
        ("f32(0.0)", "acc + val")
    };

    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}

@group(0) @binding(0) var<storage, read>       input:        array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform>             params:       ReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid:  vec3<u32>,
    @builtin(local_invocation_id)  lid:  vec3<u32>,
    @builtin(workgroup_id)         wgid: vec3<u32>,
) {{
    let tid         = lid.x;
    let global_idx  = gid.x;

    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#
    )
}
399
/// Build a WGSL compute shader performing a direct 2D convolution on NCHW data.
///
/// All shape, stride and padding values are baked into the shader text as
/// literals.  The kernel reads `input` indexed as `[n, c_in, h_in, w_in]` and
/// `filter` indexed as `[k_out, c_in, fh, fw]`, and writes `output` indexed
/// as `[n, k_out, oh, ow]`.  `gid.x` selects the output column and `gid.y` is
/// a flattened `(batch, filter, output row)` index; invocations outside that
/// range return without writing.  Input positions that land in the padding
/// region are skipped, so they contribute zero to the sum.
///
/// NOTE(review): the storage binding is named `filter`, which appears to be
/// on the WGSL specification's reserved-word list — confirm that the target
/// WGSL front end accepts this shader (renaming it would also require
/// updating the tests that look for `filter:`).
///
/// # Arguments
///
/// * `n` — batch size
/// * `c_in` — number of input channels
/// * `h_in`, `w_in` — input height / width
/// * `k_out` — number of output channels (filters)
/// * `fh`, `fw` — filter height / width
/// * `oh`, `ow` — output height / width
/// * `stride_h`, `stride_w` — vertical / horizontal stride
/// * `pad_h`, `pad_w` — padding subtracted from the input coordinates
// The thirteen scalars mirror the convolution descriptor one-to-one; the
// signature is kept as is for existing callers.
#[allow(clippy::too_many_arguments)]
pub fn conv2d_wgsl(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> String {
    // Every placeholder in the template captures the parameter of the same name.
    format!(
        r#"
// Conv2D NCHW — generated by oxicuda-webgpu
// input : [{n}, {c_in}, {h_in}, {w_in}]
// filter: [{k_out}, {c_in}, {fh}, {fw}]
// output: [{n}, {k_out}, {oh}, {ow}]

@group(0) @binding(0) var<storage, read>       input:  array<f32>;
@group(0) @binding(1) var<storage, read>       filter: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    // gid.x = output x (ox mapped across batches*k_out*oh)
    // We flatten (batch, k, oy) into gid.y and ox into gid.x
    let ox = gid.x;
    let linear_y = gid.y;

    let batch_k_oh = {n}u * {k_out}u * {oh}u;
    if (ox >= {ow}u || linear_y >= batch_k_oh) {{ return; }}

    let b  = linear_y / ({k_out}u * {oh}u);
    let rem = linear_y % ({k_out}u * {oh}u);
    let kf = rem / {oh}u;
    let oy = rem % {oh}u;

    var acc: f32 = 0.0;
    for (var ci: u32 = 0u; ci < {c_in}u; ci = ci + 1u) {{
        for (var fy: u32 = 0u; fy < {fh}u; fy = fy + 1u) {{
            for (var fx: u32 = 0u; fx < {fw}u; fx = fx + 1u) {{
                let iy_raw = i32(oy * {stride_h}u + fy) - i32({pad_h}u);
                let ix_raw = i32(ox * {stride_w}u + fx) - i32({pad_w}u);
                if (iy_raw >= 0 && iy_raw < i32({h_in}u) && ix_raw >= 0 && ix_raw < i32({w_in}u)) {{
                    let iy = u32(iy_raw);
                    let ix = u32(ix_raw);
                    let in_idx = ((b * {c_in}u + ci) * {h_in}u + iy) * {w_in}u + ix;
                    let f_idx  = ((kf * {c_in}u + ci) * {fh}u + fy) * {fw}u + fx;
                    acc += input[in_idx] * filter[f_idx];
                }}
            }}
        }}
    }}

    let o_idx = ((b * {k_out}u + kf) * {oh}u + oy) * {ow}u + ox;
    output[o_idx] = acc;
}}
"#
    )
}
494
/// Generate a WGSL compute shader for scaled dot-product attention.
///
/// Implements: `O = softmax(Q·K^T * scale [+ causal_mask]) · V`
///
/// All dimensions and the scale are baked into the shader text.  Each
/// invocation handles one `(batch_head, query position)` pair and makes
/// three passes over its row:
///
/// 1. compute every score to find the maximum (for a numerically stable
///    softmax);
/// 2. recompute each score, weight it with `exp(score - max)` and accumulate
///    the weighted `V` rows into the output row;
/// 3. divide the output row by the sum of the weights.
///
/// The output row is cleared before pass 2, so the result does not depend on
/// the previous contents of `o_buf`.
///
/// When `causal` is true, key positions with `sk > sq` get the score
/// `-1e38`, which makes their softmax weight vanish.
///
/// # Arguments
///
/// * `batch_heads` — combined batch × heads dimension
/// * `seq_q` — query sequence length
/// * `seq_kv` — key/value sequence length
/// * `head_dim` — dimension of each head
/// * `scale` — scaling factor (typically `1 / sqrt(head_dim)`); it is written
///   into the shader using Rust's `Display` formatting
/// * `causal` — whether to apply a causal (upper-triangular) mask
pub fn attention_wgsl(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> String {
    // Opening of the per-key scoring block; the template supplies the closing
    // brace.  In the causal case masked keys take the `if` arm and skip the
    // dot product entirely.
    let causal_check = if causal {
        "if (sk > sq) { score = f32(-1e38); } else {"
    } else {
        "{"
    };

    format!(
        r#"
// Scaled dot-product attention — generated by oxicuda-webgpu
// Q, K, V : [{batch_heads}, seq, {head_dim}]
// O       : [{batch_heads}, {seq_q}, {head_dim}]
// scale   : {scale}
// causal  : {causal}

@group(0) @binding(0) var<storage, read>       q_buf: array<f32>;
@group(0) @binding(1) var<storage, read>       k_buf: array<f32>;
@group(0) @binding(2) var<storage, read>       v_buf: array<f32>;
@group(0) @binding(3) var<storage, read_write> o_buf: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let linear = gid.x;
    let total = {batch_heads}u * {seq_q}u;
    if (linear >= total) {{ return; }}

    let bh = linear / {seq_q}u;
    let sq = linear % {seq_q}u;

    let q_base = (bh * {seq_q}u + sq) * {head_dim}u;
    let o_base = (bh * {seq_q}u + sq) * {head_dim}u;

    // Clear the output row first: pass 2 accumulates into it with +=, so
    // stale buffer contents would otherwise leak into the result.
    for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
        o_buf[o_base + d] = 0.0;
    }}

    // Pass 1: find max score for numerical stability
    var max_score: f32 = f32(-1e38);
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        if (score > max_score) {{ max_score = score; }}
    }}

    // Pass 2: compute exp(score - max), accumulate weighted V
    var sum_exp: f32 = 0.0;
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        let w = exp(score - max_score);
        sum_exp += w;
        let v_base = (bh * {seq_kv}u + sk) * {head_dim}u;
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            // Accumulate in-place (we normalise after the loop).
            o_buf[o_base + d] += w * v_buf[v_base + d];
        }}
    }}

    // Pass 3: normalise
    if (sum_exp > 0.0) {{
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            o_buf[o_base + d] /= sum_exp;
        }}
    }}
}}
"#,
        batch_heads = batch_heads,
        seq_q = seq_q,
        seq_kv = seq_kv,
        head_dim = head_dim,
        scale = scale,
        causal = causal,
        causal_check = causal_check,
    )
}
601
/// Build the WGSL source of a reduction along one axis of an N-D tensor.
///
/// The shader views the tensor through the `ReduceNdParams` uniform as
/// `[outer, dk, inner]`, with `dk` the length of the reduced axis, and writes
/// one value per `(outer, inner)` pair to `output[slot]` where
/// `slot = outer_index * inner + inner_index`.  Input elements are addressed
/// as `outer_index * outer_stride + inner_index * inner_stride +
/// i * dk_stride`.
///
/// One workgroup of 256 invocations produces each output slot: every
/// invocation folds the elements `tid, tid + 256, …` of the axis into a
/// private accumulator, the accumulators are placed in workgroup memory, and
/// a halving-stride tree combines them.
///
/// The slot is decoded from a 2-D workgroup id as
/// `wgid.y * params.grid_x + wgid.x`, and workgroups whose slot is
/// `>= outer * inner` return immediately, so the dispatch may be rounded up
/// in Y.  (The original documentation motivates the 2-D dispatch by WebGPU's
/// per-axis workgroup-count limit.)
///
/// For `"mean"` the stored value is the reduced sum divided by `params.dk`
/// inside the shader.
///
/// # Arguments
///
/// * `op` — one of: `"sum"`, `"max"`, `"min"`, `"mean"`.  Any other name is
///   treated as `"sum"`.
pub fn reduction_nd_wgsl(op: &str) -> String {
    // [identity value, combine step over `acc`/`val`, combine step over
    // `acc2`/`val`].  The first combine is used by the per-invocation strided
    // loop, the second by the workgroup-memory tree, which binds its
    // left-hand value under the name `acc2`.
    let [neutral, combine, combine_alias] = if op == "max" {
        ["f32(-1e38)", "max(acc, val)", "max(acc2, val)"]
    } else if op == "min" {
        ["f32(1e38)", "min(acc, val)", "min(acc2, val)"]
    } else {
        // "sum", "mean" and unknown names all add; "mean" differs only in
        // the final store below.
        ["f32(0.0)", "acc + val", "acc2 + val"]
    };

    // Value written by invocation 0 once the tree has finished.
    let final_expr = match op {
        "mean" => "shared_data[0] / f32(params.dk)",
        _ => "shared_data[0]",
    };

    format!(
        r#"
struct ReduceNdParams {{
    outer:        u32,
    dk:           u32,
    inner:        u32,
    outer_stride: u32,
    dk_stride:    u32,
    inner_stride: u32,
    grid_x:       u32,
    _pad:         u32,
}}

@group(0) @binding(0) var<storage, read>       input:  array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform>             params: ReduceNdParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid:  vec3<u32>,
    @builtin(workgroup_id)        wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let total = params.outer * params.inner;

    // Decode 2-D workgroup id back to a linear output slot.
    let slot = wgid.y * params.grid_x + wgid.x;
    if (slot >= total) {{ return; }}

    let o = slot / params.inner;
    let j = slot % params.inner;
    let base = o * params.outer_stride + j * params.inner_stride;

    // Strided per-thread reduction across the dk axis.
    var acc: f32 = {neutral};
    var i: u32 = tid;
    loop {{
        if (i >= params.dk) {{ break; }}
        let val = input[base + i * params.dk_stride];
        acc = {combine};
        i = i + 256u;
    }}

    shared_data[tid] = acc;
    workgroupBarrier();

    // Tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc2 = shared_data[tid];
            let val  = shared_data[tid + stride];
            shared_data[tid] = {combine_alias};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[slot] = {final_expr};
    }}
}}
"#
    )
}
716
/// Generate WGSL for the final scalar reduction of partial sums.
///
/// Takes a `partial_sums` array of length `num_groups` and reduces it to a
/// single value at `output[0]`.  Should be dispatched with a single workgroup
/// of 256 threads.
///
/// `num_groups` may exceed the workgroup size: each thread first folds the
/// entries `tid, tid + 256, tid + 512, …` into a private accumulator, and the
/// 256 accumulators are then combined by a halving-stride tree in workgroup
/// memory.  Threads with no entry at all contribute the operation's neutral
/// element.  For `num_groups <= 256` every thread holds exactly the value it
/// would have loaded directly, so results for small inputs are unchanged.
///
/// # Arguments
///
/// * `op` — `"max"` or `"min"` select the corresponding combine; every other
///   name (including `"sum"` and `"mean"`) produces an additive reduction.
pub fn reduction_final_wgsl(op: &str) -> String {
    // Neutral element and the expression folding `val` into `acc`.
    let (neutral, combine) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)"),
        "min" => ("f32(1e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}

@group(0) @binding(0) var<storage, read>       partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output:       array<f32>;
@group(0) @binding(2) var<uniform>             params:       FinalReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;

    // Seed this thread with its first partial sum (or the neutral element
    // when it has none), then fold in every 256th entry after it so that
    // inputs longer than one workgroup are consumed completely.
    var thread_acc: f32 = {neutral};
    var i: u32 = tid;
    if (i < params.num_groups) {{
        thread_acc = partial_sums[i];
        i = i + 256u;
    }}
    loop {{
        if (i >= params.num_groups) {{ break; }}
        let acc = thread_acc;
        let val = partial_sums[i];
        thread_acc = {combine};
        i = i + 256u;
    }}
    shared_data[tid] = thread_acc;
    workgroupBarrier();

    // Tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
775
776#[cfg(test)]
777mod tests {
778    use super::*;
779
780    #[test]
781    fn wgsl_gemm_contains_workgroup() {
782        let src = gemm_wgsl(16);
783        assert!(src.contains("@compute @workgroup_size(16, 16)"));
784        assert!(src.contains("GemmParams"));
785        assert!(src.contains("alpha"));
786        assert!(src.contains("beta"));
787    }
788
789    #[test]
790    fn wgsl_gemm_tile_size_embedded() {
791        let src8 = gemm_wgsl(8);
792        assert!(src8.contains("@workgroup_size(8, 8)"));
793        let src32 = gemm_wgsl(32);
794        assert!(src32.contains("@workgroup_size(32, 32)"));
795    }
796
797    #[test]
798    fn wgsl_gemm_has_transpose_flags() {
799        let src = gemm_wgsl(8);
800        // Transpose flags live in the uniform struct.
801        assert!(src.contains("trans_a: u32"));
802        assert!(src.contains("trans_b: u32"));
803        // Both row-major and column-major index forms must be present.
804        assert!(src.contains("a[r * params.k + i]"));
805        assert!(src.contains("a[i * params.m + r]"));
806        assert!(src.contains("b[i * params.n + col]"));
807        assert!(src.contains("b[col * params.k + i]"));
808    }
809
810    #[test]
811    fn wgsl_gemm_uses_shared_memory_tiling() {
812        let src = gemm_wgsl(16);
813        assert!(src.contains("var<workgroup> tile_a"));
814        assert!(src.contains("var<workgroup> tile_b"));
815        assert!(src.contains("workgroupBarrier"));
816        // Tile dimension is embedded in the workgroup-array declaration.
817        assert!(src.contains("array<array<f32, 16>, 16>"));
818    }
819
820    #[test]
821    fn wgsl_elementwise_relu_contains_max() {
822        let src = elementwise_wgsl("relu");
823        assert!(src.contains("max(x, 0.0)"));
824    }
825
826    #[test]
827    fn wgsl_elementwise_all_ops() {
828        assert!(elementwise_wgsl("sigmoid").contains("exp(-x)"));
829        assert!(elementwise_wgsl("tanh").contains("tanh(x)"));
830        assert!(elementwise_wgsl("exp").contains("exp(x)"));
831        assert!(elementwise_wgsl("log").contains("log(x)"));
832        assert!(elementwise_wgsl("sqrt").contains("sqrt(x)"));
833        assert!(elementwise_wgsl("abs").contains("abs(x)"));
834        assert!(elementwise_wgsl("neg").contains("-x"));
835        // Unknown op is identity.
836        assert!(elementwise_wgsl("identity_op").contains("output[i] = x;"));
837    }
838
839    #[test]
840    fn wgsl_reduction_sum_contains_addition() {
841        let src = reduction_wgsl("sum");
842        assert!(src.contains("acc + val"));
843        assert!(src.contains("workgroupBarrier"));
844    }
845
846    #[test]
847    fn wgsl_reduction_max_uses_max_fn() {
848        let src = reduction_wgsl("max");
849        assert!(src.contains("max(acc, val)"));
850    }
851
852    #[test]
853    fn wgsl_reduction_min_uses_min_fn() {
854        let src = reduction_wgsl("min");
855        assert!(src.contains("min(acc, val)"));
856    }
857
858    #[test]
859    fn wgsl_reduction_mean_same_as_sum() {
860        // "mean" divides on the CPU side; the shader is identical to sum.
861        let sum_src = reduction_wgsl("sum");
862        let mean_src = reduction_wgsl("mean");
863        assert_eq!(sum_src, mean_src);
864    }
865
866    #[test]
867    fn wgsl_reduction_final_sum() {
868        let src = reduction_final_wgsl("sum");
869        assert!(src.contains("num_groups"));
870        assert!(src.contains("output[0]"));
871    }
872
873    // ── reduction_nd_wgsl tests ───────────────────────────────────────────
874
875    #[test]
876    fn wgsl_reduction_nd_sum_contains_addition() {
877        let src = reduction_nd_wgsl("sum");
878        assert!(src.contains("acc + val"));
879        // Tree-step reuses the same combine with renamed lhs.
880        assert!(src.contains("acc2 + val"));
881        assert!(src.contains("workgroupBarrier"));
882        assert!(src.contains("ReduceNdParams"));
883    }
884
885    #[test]
886    fn wgsl_reduction_nd_max_uses_max_fn() {
887        let src = reduction_nd_wgsl("max");
888        assert!(src.contains("max(acc, val)"));
889        assert!(src.contains("max(acc2, val)"));
890    }
891
892    #[test]
893    fn wgsl_reduction_nd_min_uses_min_fn() {
894        let src = reduction_nd_wgsl("min");
895        assert!(src.contains("min(acc, val)"));
896        assert!(src.contains("min(acc2, val)"));
897    }
898
899    #[test]
900    fn wgsl_reduction_nd_mean_divides_by_dk() {
901        let src = reduction_nd_wgsl("mean");
902        assert!(src.contains("shared_data[0] / f32(params.dk)"));
903        assert!(src.contains("acc + val"));
904    }
905
906    #[test]
907    fn wgsl_reduction_nd_sum_does_not_divide() {
908        let src = reduction_nd_wgsl("sum");
909        assert!(!src.contains("/ f32(params.dk)"));
910    }
911
912    #[test]
913    fn wgsl_reduction_nd_decodes_2d_dispatch() {
914        let src = reduction_nd_wgsl("sum");
915        assert!(src.contains("wgid.y * params.grid_x + wgid.x"));
916    }
917
918    #[test]
919    fn wgsl_reduction_nd_uses_strided_loop() {
920        let src = reduction_nd_wgsl("sum");
921        assert!(src.contains("i = i + 256u"));
922    }
923
924    // ── binary_wgsl tests ─────────────────────────────────────────────────
925
926    #[test]
927    fn wgsl_binary_add() {
928        let src = binary_wgsl("add");
929        assert!(src.contains("a + b"));
930        assert!(src.contains("lhs"));
931        assert!(src.contains("rhs"));
932    }
933
934    #[test]
935    fn wgsl_binary_all_ops() {
936        assert!(binary_wgsl("sub").contains("a - b"));
937        assert!(binary_wgsl("mul").contains("a * b"));
938        assert!(binary_wgsl("div").contains("a / b"));
939        assert!(binary_wgsl("max").contains("max(a, b)"));
940        assert!(binary_wgsl("min").contains("min(a, b)"));
941        assert!(binary_wgsl("pow").contains("pow(a, b)"));
942        // Unknown op is identity on lhs.
943        assert!(binary_wgsl("unknown_op").contains("output[i] = a;"));
944    }
945
946    #[test]
947    fn wgsl_binary_workgroup_size() {
948        let src = binary_wgsl("add");
949        assert!(src.contains("@workgroup_size(256)"));
950    }
951
952    // ── conv2d_wgsl tests ─────────────────────────────────────────────────
953
954    #[test]
955    fn wgsl_conv2d_contains_workgroup() {
956        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
957        assert!(src.contains("@compute @workgroup_size(8, 8)"));
958    }
959
960    #[test]
961    fn wgsl_conv2d_contains_storage_bindings() {
962        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
963        assert!(src.contains("var<storage, read>       input:"));
964        assert!(src.contains("var<storage, read>       filter:"));
965        assert!(src.contains("var<storage, read_write> output:"));
966    }
967
968    #[test]
969    fn wgsl_conv2d_embeds_dimensions() {
970        let src = conv2d_wgsl(2, 8, 64, 64, 32, 5, 5, 60, 60, 1, 1, 0, 0);
971        // Check that the shape constants appear in the shader
972        assert!(src.contains("8u")); // c_in
973        assert!(src.contains("64u")); // h_in or w_in
974        assert!(src.contains("32u")); // k_out
975        assert!(src.contains("5u")); // fh or fw
976        assert!(src.contains("60u")); // oh or ow
977    }
978
979    #[test]
980    fn wgsl_conv2d_has_padding_check() {
981        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 8, 8, 1, 1, 1, 1);
982        // Padding check with signed comparison
983        assert!(src.contains("iy_raw >= 0"));
984        assert!(src.contains("ix_raw >= 0"));
985    }
986
987    #[test]
988    fn wgsl_conv2d_has_stride() {
989        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 3, 3, 2, 2, 0, 0);
990        assert!(src.contains("2u")); // stride
991    }
992
993    // ── attention_wgsl tests ──────────────────────────────────────────────
994
995    #[test]
996    fn wgsl_attention_contains_workgroup() {
997        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
998        assert!(src.contains("@compute @workgroup_size(64)"));
999    }
1000
1001    #[test]
1002    fn wgsl_attention_contains_storage_bindings() {
1003        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
1004        assert!(src.contains("var<storage, read>       q_buf:"));
1005        assert!(src.contains("var<storage, read>       k_buf:"));
1006        assert!(src.contains("var<storage, read>       v_buf:"));
1007        assert!(src.contains("var<storage, read_write> o_buf:"));
1008    }
1009
1010    #[test]
1011    fn wgsl_attention_stable_softmax() {
1012        let src = attention_wgsl(1, 4, 4, 32, 0.25, false);
1013        assert!(src.contains("max_score"));
1014        assert!(src.contains("exp(score - max_score)"));
1015        assert!(src.contains("sum_exp"));
1016    }
1017
1018    #[test]
1019    fn wgsl_attention_causal_mask() {
1020        let src_causal = attention_wgsl(1, 4, 4, 32, 0.25, true);
1021        assert!(src_causal.contains("sk > sq"));
1022
1023        let src_non_causal = attention_wgsl(1, 4, 4, 32, 0.25, false);
1024        assert!(!src_non_causal.contains("sk > sq"));
1025    }
1026
1027    #[test]
1028    fn wgsl_attention_embeds_scale() {
1029        let src = attention_wgsl(2, 16, 16, 64, 0.125, false);
1030        assert!(src.contains("0.125"));
1031    }
1032
1033    // ── batched_gemm_wgsl tests ────────────────────────────────────────────
1034
1035    #[test]
1036    fn wgsl_batched_gemm_contains_batch_params() {
1037        let src = batched_gemm_wgsl(16);
1038        assert!(src.contains("batch_count"));
1039        assert!(src.contains("stride_a"));
1040        assert!(src.contains("stride_b"));
1041        assert!(src.contains("stride_c"));
1042    }
1043
1044    #[test]
1045    fn wgsl_batched_gemm_contains_workgroup() {
1046        let src = batched_gemm_wgsl(16);
1047        assert!(src.contains("@compute @workgroup_size(16, 16)"));
1048        assert!(src.contains("BatchedGemmParams"));
1049    }
1050
1051    #[test]
1052    fn wgsl_batched_gemm_uses_batch_index() {
1053        let src = batched_gemm_wgsl(8);
1054        assert!(src.contains("batch_index"));
1055        assert!(src.contains("gid.z"));
1056    }
1057
1058    #[test]
1059    fn wgsl_batched_gemm_tile_size_embedded() {
1060        let src8 = batched_gemm_wgsl(8);
1061        assert!(src8.contains("@workgroup_size(8, 8)"));
1062        let src32 = batched_gemm_wgsl(32);
1063        assert!(src32.contains("@workgroup_size(32, 32)"));
1064    }
1065
1066    #[test]
1067    fn wgsl_batched_gemm_has_transpose_flags() {
1068        let src = batched_gemm_wgsl(8);
1069        assert!(src.contains("trans_a:  u32"));
1070        assert!(src.contains("trans_b:  u32"));
1071        // Per-batch offset is applied to every index form.
1072        assert!(src.contains("a[a_offset + r * params.k + i]"));
1073        assert!(src.contains("a[a_offset + i * params.m + r]"));
1074        assert!(src.contains("b[b_offset + i * params.n + col]"));
1075        assert!(src.contains("b[b_offset + col * params.k + i]"));
1076    }
1077
1078    #[test]
1079    fn wgsl_batched_gemm_uses_shared_memory_tiling() {
1080        let src = batched_gemm_wgsl(8);
1081        assert!(src.contains("var<workgroup> tile_a"));
1082        assert!(src.contains("var<workgroup> tile_b"));
1083        assert!(src.contains("workgroupBarrier"));
1084        assert!(src.contains("array<array<f32, 8>, 8>"));
1085    }
1086
1087    // ── gemm_wgsl_f16 tests ─────────────────────────────────────────────
1088
1089    #[test]
1090    fn wgsl_gemm_f16_enables_extension() {
1091        let src = gemm_wgsl_f16(16);
1092        assert!(src.contains("enable f16;"));
1093    }
1094
1095    #[test]
1096    fn wgsl_gemm_f16_uses_f16_storage() {
1097        let src = gemm_wgsl_f16(16);
1098        assert!(src.contains("array<f16>"));
1099    }
1100
1101    #[test]
1102    fn wgsl_gemm_f16_accumulates_in_f32() {
1103        let src = gemm_wgsl_f16(16);
1104        assert!(src.contains("var acc: f32 = 0.0;"));
1105        assert!(src.contains("f32(a["));
1106        assert!(src.contains("f32(b["));
1107    }
1108
1109    #[test]
1110    fn wgsl_gemm_f16_contains_workgroup() {
1111        let src = gemm_wgsl_f16(8);
1112        assert!(src.contains("@compute @workgroup_size(8, 8)"));
1113        assert!(src.contains("GemmParams"));
1114    }
1115
1116    #[test]
1117    fn wgsl_attention_embeds_dimensions() {
1118        let src = attention_wgsl(8, 32, 32, 128, 0.088, true);
1119        assert!(src.contains("128u")); // head_dim
1120        assert!(src.contains("32u")); // seq_q or seq_kv
1121        assert!(src.contains("8u")); // batch_heads
1122    }
1123}