// oxicuda_webgpu/shader.rs
//! WGSL shader source generation for common compute kernels.
//!
//! Each function returns a complete, self-contained WGSL source string
//! suitable for passing to `device.create_shader_module()`.

/// Generate WGSL source for a GEMM kernel: `C = alpha * A * B + beta * C`.
///
/// Dispatched with `tile_size × tile_size` workgroups; each invocation
/// computes one element of `C` with a straight dot-product loop over `k`
/// (no shared-memory tiling is performed).  Both A and B are row-major.
///
/// # Arguments
///
/// * `tile_size` — workgroup tile dimension (e.g. 8, 16, 32).
pub fn gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct GemmParams {{
    m:     u32,
    n:     u32,
    k:     u32,
    alpha: f32,
    beta:  f32,
}}

@group(0) @binding(0) var<storage, read>       a:      array<f32>;
@group(0) @binding(1) var<storage, read>       b:      array<f32>;
@group(0) @binding(2) var<storage, read_write> c:      array<f32>;
@group(0) @binding(3) var<uniform>             params: GemmParams;

@compute @workgroup_size({tile_size}, {tile_size})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[row * params.k + i] * b[i * params.n + col];
    }}

    let idx = row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#
    )
}

/// Generate WGSL source for a batched (strided) GEMM kernel.
///
/// For each batch `b` in `0..batch_count`:
///   `C_b = alpha * A_b * B_b + beta * C_b`
/// where `A_b` starts at `a[b * stride_a]`, etc.  Each invocation computes
/// one output element; the batch index rides on the Z dispatch axis, so the
/// dispatch is `tile_size × tile_size` workgroups with Z = batch_count.
///
/// # Arguments
///
/// * `tile_size` — workgroup tile dimension (e.g. 8, 16, 32).
pub fn batched_gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct BatchedGemmParams {{
    m:        u32,
    n:        u32,
    k:        u32,
    alpha:    f32,
    beta:     f32,
    batch_count: u32,
    stride_a: u32,
    stride_b: u32,
    stride_c: u32,
}}

@group(0) @binding(0) var<storage, read>       a:      array<f32>;
@group(0) @binding(1) var<storage, read>       b:      array<f32>;
@group(0) @binding(2) var<storage, read_write> c:      array<f32>;
@group(0) @binding(3) var<uniform>             params: BatchedGemmParams;

@compute @workgroup_size({tile_size}, {tile_size})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    let batch_index = gid.z;
    if (batch_index >= params.batch_count || row >= params.m || col >= params.n) {{ return; }}

    let a_offset = batch_index * params.stride_a;
    let b_offset = batch_index * params.stride_b;
    let c_offset = batch_index * params.stride_c;

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[a_offset + row * params.k + i] * b[b_offset + i * params.n + col];
    }}

    let idx = c_offset + row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#
    )
}

/// Generate WGSL source for a GEMM kernel with FP16 storage.
///
/// Emits `enable f16;`, so the device must expose the `shader-f16` feature.
/// Buffers are `array<f16>`, but the dot product accumulates in f32 and is
/// only narrowed back to f16 on the final store, preserving precision.
/// Each invocation computes one element of `C`; no shared-memory tiling.
///
/// # Arguments
///
/// * `tile_size` — workgroup tile dimension (e.g. 8, 16, 32).
pub fn gemm_wgsl_f16(tile_size: u32) -> String {
    format!(
        r#"
enable f16;

struct GemmParams {{
    m:     u32,
    n:     u32,
    k:     u32,
    alpha: f32,
    beta:  f32,
}}

@group(0) @binding(0) var<storage, read>       a:      array<f16>;
@group(0) @binding(1) var<storage, read>       b:      array<f16>;
@group(0) @binding(2) var<storage, read_write> c:      array<f16>;
@group(0) @binding(3) var<uniform>             params: GemmParams;

@compute @workgroup_size({tile_size}, {tile_size})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += f32(a[row * params.k + i]) * f32(b[i * params.n + col]);
    }}

    let idx = row * params.n + col;
    let prev = f32(c[idx]);
    c[idx] = f16(params.alpha * acc + params.beta * prev);
}}
"#
    )
}

/// Generate WGSL source for an element-wise unary operation.
///
/// The kernel applies the chosen expression to every element of `input`
/// and writes the result to the matching slot of `output`; the element
/// count is taken from `arrayLength(&input)`.
///
/// # Arguments
///
/// * `op` — one of: `"relu"`, `"sigmoid"`, `"tanh"`, `"exp"`, `"log"`,
///   `"sqrt"`, `"abs"`, `"neg"`.  Unknown ops are treated as identity.
pub fn elementwise_wgsl(op: &str) -> String {
    // Map the op name onto a WGSL expression over the loaded value `x`.
    let expr = match op {
        "relu" => "max(x, 0.0)",
        "sigmoid" => "1.0 / (1.0 + exp(-x))",
        "tanh" => "tanh(x)",
        "exp" => "exp(x)",
        "log" => "log(x)",
        "sqrt" => "sqrt(x)",
        "abs" => "abs(x)",
        "neg" => "-x",
        _ => "x",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read>       input:  array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&input)) {{ return; }}
    let x = input[i];
    output[i] = {expr};
}}
"#
    )
}

/// Generate WGSL source for an element-wise binary operation.
///
/// The shader reads paired elements from `lhs` and `rhs`, applies the
/// operation, and writes the results to `output`.
///
/// The generated kernel guards against the length of *both* inputs: WGSL
/// out-of-bounds accesses clamp rather than trap, so an undersized `rhs`
/// would otherwise silently combine `lhs` with arbitrary in-bounds values.
/// With the guard, trailing output elements are simply left untouched.
///
/// # Arguments
///
/// * `op` — one of: `"add"`, `"sub"`, `"mul"`, `"div"`, `"max"`, `"min"`,
///   `"pow"`.  Unknown ops fall back to identity on `lhs`.
pub fn binary_wgsl(op: &str) -> String {
    let op_expr = match op {
        "add" => "a + b",
        "sub" => "a - b",
        "mul" => "a * b",
        "div" => "a / b",
        "max" => "max(a, b)",
        "min" => "min(a, b)",
        "pow" => "pow(a, b)",
        _ => "a",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read>       lhs:    array<f32>;
@group(0) @binding(1) var<storage, read>       rhs:    array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    // Guard both inputs: OOB reads clamp in WGSL and would silently
    // produce wrong values if `rhs` were shorter than `lhs`.
    if (i >= arrayLength(&lhs) || i >= arrayLength(&rhs)) {{ return; }}
    let a = lhs[i];
    let b = rhs[i];
    output[i] = {op};
}}
"#,
        op = op_expr
    )
}

/// Generate WGSL source for a parallel workgroup-level reduction.
///
/// First pass of a two-pass scheme: every 256-thread workgroup folds its
/// tile down to one value via a shared-memory tree reduction and stores it
/// in `partial_sums[workgroup_id]`.  A follow-up dispatch (single
/// workgroup) collapses the partials to the final scalar.
///
/// # Arguments
///
/// * `op` — one of: `"sum"`, `"max"`, `"min"`, `"mean"`.  `"mean"` behaves
///   like `"sum"` in the shader; the CPU is responsible for dividing by N.
///   Unknown ops fall back to `"sum"`.
pub fn reduction_wgsl(op: &str) -> String {
    // Identity element and fold expression for the requested operation.
    // "sum" and "mean" share the same reduction body.
    let (identity, fold) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)"),
        "min" => ("f32(1e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}

@group(0) @binding(0) var<storage, read>       input:        array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform>             params:       ReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid:  vec3<u32>,
    @builtin(local_invocation_id)  lid:  vec3<u32>,
    @builtin(workgroup_id)         wgid: vec3<u32>,
) {{
    let tid         = lid.x;
    let global_idx  = gid.x;

    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {identity};
    }}
    workgroupBarrier();

    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {fold};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#
    )
}

/// Generate a WGSL compute shader for 2D convolution in NCHW format.
///
/// Reads `input` (N×C×H×W) and `filter` (K×C×FH×FW), writing `output`
/// (N×K×OH×OW).  All shape constants are baked into the shader text.
/// Zero padding is implemented purely by bounds checking: input positions
/// that fall outside the image contribute nothing to the accumulator.
///
/// Dispatch mapping: `gid.x` is the output x coordinate; `gid.y` is the
/// flattened `(batch, out_channel, output_y)` index.
///
/// # Arguments
///
/// * `n` — batch size
/// * `c_in` — number of input channels
/// * `h_in`, `w_in` — spatial input dimensions
/// * `k_out` — number of output channels (filters)
/// * `fh`, `fw` — filter height / width
/// * `oh`, `ow` — output height / width
/// * `stride_h`, `stride_w` — convolution strides
/// * `pad_h`, `pad_w` — zero-padding applied to the input
#[allow(clippy::too_many_arguments)]
pub fn conv2d_wgsl(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> String {
    // All placeholders capture the parameters above directly.
    format!(
        r#"
// Conv2D NCHW — generated by oxicuda-webgpu
// input : [{n}, {c_in}, {h_in}, {w_in}]
// filter: [{k_out}, {c_in}, {fh}, {fw}]
// output: [{n}, {k_out}, {oh}, {ow}]

@group(0) @binding(0) var<storage, read>       input:  array<f32>;
@group(0) @binding(1) var<storage, read>       filter: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    // gid.x = output x (ox mapped across batches*k_out*oh)
    // We flatten (batch, k, oy) into gid.y and ox into gid.x
    let ox = gid.x;
    let linear_y = gid.y;

    let batch_k_oh = {n}u * {k_out}u * {oh}u;
    if (ox >= {ow}u || linear_y >= batch_k_oh) {{ return; }}

    let b  = linear_y / ({k_out}u * {oh}u);
    let rem = linear_y % ({k_out}u * {oh}u);
    let kf = rem / {oh}u;
    let oy = rem % {oh}u;

    var acc: f32 = 0.0;
    for (var ci: u32 = 0u; ci < {c_in}u; ci = ci + 1u) {{
        for (var fy: u32 = 0u; fy < {fh}u; fy = fy + 1u) {{
            for (var fx: u32 = 0u; fx < {fw}u; fx = fx + 1u) {{
                let iy_raw = i32(oy * {stride_h}u + fy) - i32({pad_h}u);
                let ix_raw = i32(ox * {stride_w}u + fx) - i32({pad_w}u);
                if (iy_raw >= 0 && iy_raw < i32({h_in}u) && ix_raw >= 0 && ix_raw < i32({w_in}u)) {{
                    let iy = u32(iy_raw);
                    let ix = u32(ix_raw);
                    let in_idx = ((b * {c_in}u + ci) * {h_in}u + iy) * {w_in}u + ix;
                    let f_idx  = ((kf * {c_in}u + ci) * {fh}u + fy) * {fw}u + fx;
                    acc += input[in_idx] * filter[f_idx];
                }}
            }}
        }}
    }}

    let o_idx = ((b * {k_out}u + kf) * {oh}u + oy) * {ow}u + ox;
    output[o_idx] = acc;
}}
"#
    )
}

/// Generate a WGSL compute shader for scaled dot-product attention.
///
/// Implements: `O = softmax(Q·K^T * scale [+ causal_mask]) · V`
///
/// The softmax is numerically stable (subtracts max before exp).
/// When `causal` is true, positions where `sk > sq` are masked to −∞.
///
/// Each invocation owns one `(batch*head, query)` output row and zeroes it
/// before accumulating, so `o_buf` does NOT need to be cleared by the host
/// between dispatches (the accumulation is read-modify-write and would
/// otherwise pick up stale buffer contents).
///
/// # Arguments
///
/// * `batch_heads` — combined batch × heads dimension
/// * `seq_q` — query sequence length
/// * `seq_kv` — key/value sequence length
/// * `head_dim` — dimension of each head
/// * `scale` — scaling factor (typically `1 / sqrt(head_dim)`)
/// * `causal` — whether to apply a causal (upper-triangular) mask
pub fn attention_wgsl(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> String {
    // Masked positions overwrite `score` with -1e38 (≈ −∞) and skip the
    // dot product; the non-causal variant opens a plain block instead.
    let causal_check = if causal {
        "if (sk > sq) { score = f32(-1e38); } else {"
    } else {
        "{"
    };

    format!(
        r#"
// Scaled dot-product attention — generated by oxicuda-webgpu
// Q, K, V : [{batch_heads}, seq, {head_dim}]
// O       : [{batch_heads}, {seq_q}, {head_dim}]
// scale   : {scale}
// causal  : {causal}

@group(0) @binding(0) var<storage, read>       q_buf: array<f32>;
@group(0) @binding(1) var<storage, read>       k_buf: array<f32>;
@group(0) @binding(2) var<storage, read>       v_buf: array<f32>;
@group(0) @binding(3) var<storage, read_write> o_buf: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let linear = gid.x;
    let total = {batch_heads}u * {seq_q}u;
    if (linear >= total) {{ return; }}

    let bh = linear / {seq_q}u;
    let sq = linear % {seq_q}u;

    let q_base = (bh * {seq_q}u + sq) * {head_dim}u;
    let o_base = (bh * {seq_q}u + sq) * {head_dim}u;

    // Pass 0: zero this invocation's output row.  The weighted-V pass
    // accumulates with +=, so stale buffer contents must not leak in.
    for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
        o_buf[o_base + d] = 0.0;
    }}

    // Pass 1: find max score for numerical stability
    var max_score: f32 = f32(-1e38);
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        if (score > max_score) {{ max_score = score; }}
    }}

    // Pass 2: compute exp(score - max), accumulate weighted V
    var sum_exp: f32 = 0.0;
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        let w = exp(score - max_score);
        sum_exp += w;
        let v_base = (bh * {seq_kv}u + sk) * {head_dim}u;
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            // Accumulate in-place (we normalise after the loop).
            o_buf[o_base + d] += w * v_buf[v_base + d];
        }}
    }}

    // Pass 3: normalise
    if (sum_exp > 0.0) {{
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            o_buf[o_base + d] /= sum_exp;
        }}
    }}
}}
"#,
        batch_heads = batch_heads,
        seq_q = seq_q,
        seq_kv = seq_kv,
        head_dim = head_dim,
        scale = scale,
        causal = causal,
        causal_check = causal_check,
    )
}

/// Generate WGSL source for an N-D reduction along a single axis.
///
/// The tensor is logically reshaped to `[outer, dk, inner]`, where the
/// reduce axis spans `dk` elements, `outer` is the product of dimensions
/// before the axis, and `inner` the product of dimensions after it.  The
/// output is a flat buffer of `outer * inner` values.
///
/// Each output slot is produced by one full 256-thread workgroup: threads
/// first fold disjoint strided subsets of the `dk` axis into per-thread
/// accumulators, then collapse those via a shared-memory tree reduction.
///
/// Dispatch must be 2-D — `(grid_x, ceil((outer * inner) / grid_x), 1)` —
/// to stay under WebGPU's 65 535 per-axis workgroup limit.  The shader
/// recovers its slot as `wgid.y * params.grid_x + wgid.x` and returns
/// early when `slot >= outer * inner`.
///
/// `"mean"` divides each result by `dk` inside the shader, so no host-side
/// post-processing is needed.
///
/// # Arguments
///
/// * `op` — one of: `"sum"`, `"max"`, `"min"`, `"mean"`.  Unknown ops fall
///   back to `"sum"`.
pub fn reduction_nd_wgsl(op: &str) -> String {
    // Two fold forms: one over `acc` (per-thread strided pass) and one
    // over `acc2` (shared-memory tree pass).  Spelling both out is more
    // robust than string substitution for future ops.
    let (identity, fold, fold_tree) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)", "max(acc2, val)"),
        "min" => ("f32(1e38)", "min(acc, val)", "min(acc2, val)"),
        // "sum" and "mean" share a fold; "mean" divides at the end.
        _ => ("f32(0.0)", "acc + val", "acc2 + val"),
    };

    // Final expression written by thread 0: "mean" scales by 1/dk.
    let finalize = if op == "mean" {
        "shared_data[0] / f32(params.dk)"
    } else {
        "shared_data[0]"
    };

    format!(
        r#"
struct ReduceNdParams {{
    outer:        u32,
    dk:           u32,
    inner:        u32,
    outer_stride: u32,
    dk_stride:    u32,
    inner_stride: u32,
    grid_x:       u32,
    _pad:         u32,
}}

@group(0) @binding(0) var<storage, read>       input:  array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform>             params: ReduceNdParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid:  vec3<u32>,
    @builtin(workgroup_id)        wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let total = params.outer * params.inner;

    // Decode 2-D workgroup id back to a linear output slot.
    let slot = wgid.y * params.grid_x + wgid.x;
    if (slot >= total) {{ return; }}

    let o = slot / params.inner;
    let j = slot % params.inner;
    let base = o * params.outer_stride + j * params.inner_stride;

    // Strided per-thread reduction across the dk axis.
    var acc: f32 = {identity};
    var i: u32 = tid;
    loop {{
        if (i >= params.dk) {{ break; }}
        let val = input[base + i * params.dk_stride];
        acc = {fold};
        i = i + 256u;
    }}

    shared_data[tid] = acc;
    workgroupBarrier();

    // Tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc2 = shared_data[tid];
            let val  = shared_data[tid + stride];
            shared_data[tid] = {fold_tree};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[slot] = {finalize};
    }}
}}
"#
    )
}

/// Generate WGSL for the final scalar reduction of partial sums.
///
/// Collapses a `partial_sums` array of `num_groups` values into a single
/// scalar at `output[0]` via a shared-memory tree reduction.  Dispatch
/// with exactly one 256-thread workgroup.
pub fn reduction_final_wgsl(op: &str) -> String {
    // Identity element and fold expression; unknown ops reduce as sums.
    let (identity, fold) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)"),
        "min" => ("f32(1e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}

@group(0) @binding(0) var<storage, read>       partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output:       array<f32>;
@group(0) @binding(2) var<uniform>             params:       FinalReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;

    if (tid < params.num_groups) {{
        shared_data[tid] = partial_sums[tid];
    }} else {{
        shared_data[tid] = {identity};
    }}
    workgroupBarrier();

    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {fold};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `src` contains every one of `needles`.
    fn assert_contains_all(src: &str, needles: &[&str]) {
        for needle in needles {
            assert!(src.contains(needle), "shader source missing {:?}", needle);
        }
    }

    #[test]
    fn wgsl_gemm_contains_workgroup() {
        assert_contains_all(
            &gemm_wgsl(16),
            &[
                "@compute @workgroup_size(16, 16)",
                "GemmParams",
                "alpha",
                "beta",
            ],
        );
    }

    #[test]
    fn wgsl_gemm_tile_size_embedded() {
        assert!(gemm_wgsl(8).contains("@workgroup_size(8, 8)"));
        assert!(gemm_wgsl(32).contains("@workgroup_size(32, 32)"));
    }

    #[test]
    fn wgsl_elementwise_relu_contains_max() {
        assert!(elementwise_wgsl("relu").contains("max(x, 0.0)"));
    }

    #[test]
    fn wgsl_elementwise_all_ops() {
        for (op, needle) in [
            ("sigmoid", "exp(-x)"),
            ("tanh", "tanh(x)"),
            ("exp", "exp(x)"),
            ("log", "log(x)"),
            ("sqrt", "sqrt(x)"),
            ("abs", "abs(x)"),
            ("neg", "-x"),
            // Unknown op is identity.
            ("identity_op", "output[i] = x;"),
        ] {
            assert!(elementwise_wgsl(op).contains(needle), "op {:?}", op);
        }
    }

    #[test]
    fn wgsl_reduction_sum_contains_addition() {
        assert_contains_all(&reduction_wgsl("sum"), &["acc + val", "workgroupBarrier"]);
    }

    #[test]
    fn wgsl_reduction_max_uses_max_fn() {
        assert!(reduction_wgsl("max").contains("max(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_min_uses_min_fn() {
        assert!(reduction_wgsl("min").contains("min(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_mean_same_as_sum() {
        // "mean" divides on the CPU side; the shader is identical to sum.
        assert_eq!(reduction_wgsl("sum"), reduction_wgsl("mean"));
    }

    #[test]
    fn wgsl_reduction_final_sum() {
        assert_contains_all(&reduction_final_wgsl("sum"), &["num_groups", "output[0]"]);
    }

    // ── reduction_nd_wgsl tests ───────────────────────────────────────────

    #[test]
    fn wgsl_reduction_nd_sum_contains_addition() {
        // Tree-step reuses the same combine with renamed lhs (`acc2`).
        assert_contains_all(
            &reduction_nd_wgsl("sum"),
            &["acc + val", "acc2 + val", "workgroupBarrier", "ReduceNdParams"],
        );
    }

    #[test]
    fn wgsl_reduction_nd_max_uses_max_fn() {
        assert_contains_all(
            &reduction_nd_wgsl("max"),
            &["max(acc, val)", "max(acc2, val)"],
        );
    }

    #[test]
    fn wgsl_reduction_nd_min_uses_min_fn() {
        assert_contains_all(
            &reduction_nd_wgsl("min"),
            &["min(acc, val)", "min(acc2, val)"],
        );
    }

    #[test]
    fn wgsl_reduction_nd_mean_divides_by_dk() {
        assert_contains_all(
            &reduction_nd_wgsl("mean"),
            &["shared_data[0] / f32(params.dk)", "acc + val"],
        );
    }

    #[test]
    fn wgsl_reduction_nd_sum_does_not_divide() {
        assert!(!reduction_nd_wgsl("sum").contains("/ f32(params.dk)"));
    }

    #[test]
    fn wgsl_reduction_nd_decodes_2d_dispatch() {
        assert!(reduction_nd_wgsl("sum").contains("wgid.y * params.grid_x + wgid.x"));
    }

    #[test]
    fn wgsl_reduction_nd_uses_strided_loop() {
        assert!(reduction_nd_wgsl("sum").contains("i = i + 256u"));
    }

    // ── binary_wgsl tests ─────────────────────────────────────────────────

    #[test]
    fn wgsl_binary_add() {
        assert_contains_all(&binary_wgsl("add"), &["a + b", "lhs", "rhs"]);
    }

    #[test]
    fn wgsl_binary_all_ops() {
        for (op, needle) in [
            ("sub", "a - b"),
            ("mul", "a * b"),
            ("div", "a / b"),
            ("max", "max(a, b)"),
            ("min", "min(a, b)"),
            ("pow", "pow(a, b)"),
            // Unknown op is identity on lhs.
            ("unknown_op", "output[i] = a;"),
        ] {
            assert!(binary_wgsl(op).contains(needle), "op {:?}", op);
        }
    }

    #[test]
    fn wgsl_binary_workgroup_size() {
        assert!(binary_wgsl("add").contains("@workgroup_size(256)"));
    }

    // ── conv2d_wgsl tests ─────────────────────────────────────────────────

    #[test]
    fn wgsl_conv2d_contains_workgroup() {
        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        assert!(src.contains("@compute @workgroup_size(8, 8)"));
    }

    #[test]
    fn wgsl_conv2d_contains_storage_bindings() {
        assert_contains_all(
            &conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0),
            &[
                "var<storage, read>       input:",
                "var<storage, read>       filter:",
                "var<storage, read_write> output:",
            ],
        );
    }

    #[test]
    fn wgsl_conv2d_embeds_dimensions() {
        // The shape constants must be baked into the generated source:
        // c_in=8, h_in/w_in=64, k_out=32, fh/fw=5, oh/ow=60.
        assert_contains_all(
            &conv2d_wgsl(2, 8, 64, 64, 32, 5, 5, 60, 60, 1, 1, 0, 0),
            &["8u", "64u", "32u", "5u", "60u"],
        );
    }

    #[test]
    fn wgsl_conv2d_has_padding_check() {
        // Padding is handled via signed bounds comparisons.
        assert_contains_all(
            &conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 8, 8, 1, 1, 1, 1),
            &["iy_raw >= 0", "ix_raw >= 0"],
        );
    }

    #[test]
    fn wgsl_conv2d_has_stride() {
        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 3, 3, 2, 2, 0, 0);
        assert!(src.contains("2u")); // stride
    }

    // ── attention_wgsl tests ──────────────────────────────────────────────

    #[test]
    fn wgsl_attention_contains_workgroup() {
        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
        assert!(src.contains("@compute @workgroup_size(64)"));
    }

    #[test]
    fn wgsl_attention_contains_storage_bindings() {
        assert_contains_all(
            &attention_wgsl(4, 8, 8, 64, 0.125, false),
            &[
                "var<storage, read>       q_buf:",
                "var<storage, read>       k_buf:",
                "var<storage, read>       v_buf:",
                "var<storage, read_write> o_buf:",
            ],
        );
    }

    #[test]
    fn wgsl_attention_stable_softmax() {
        assert_contains_all(
            &attention_wgsl(1, 4, 4, 32, 0.25, false),
            &["max_score", "exp(score - max_score)", "sum_exp"],
        );
    }

    #[test]
    fn wgsl_attention_causal_mask() {
        assert!(attention_wgsl(1, 4, 4, 32, 0.25, true).contains("sk > sq"));
        assert!(!attention_wgsl(1, 4, 4, 32, 0.25, false).contains("sk > sq"));
    }

    #[test]
    fn wgsl_attention_embeds_scale() {
        assert!(attention_wgsl(2, 16, 16, 64, 0.125, false).contains("0.125"));
    }

    // ── batched_gemm_wgsl tests ────────────────────────────────────────────

    #[test]
    fn wgsl_batched_gemm_contains_batch_params() {
        assert_contains_all(
            &batched_gemm_wgsl(16),
            &["batch_count", "stride_a", "stride_b", "stride_c"],
        );
    }

    #[test]
    fn wgsl_batched_gemm_contains_workgroup() {
        assert_contains_all(
            &batched_gemm_wgsl(16),
            &["@compute @workgroup_size(16, 16)", "BatchedGemmParams"],
        );
    }

    #[test]
    fn wgsl_batched_gemm_uses_batch_index() {
        assert_contains_all(&batched_gemm_wgsl(8), &["batch_index", "gid.z"]);
    }

    #[test]
    fn wgsl_batched_gemm_tile_size_embedded() {
        assert!(batched_gemm_wgsl(8).contains("@workgroup_size(8, 8)"));
        assert!(batched_gemm_wgsl(32).contains("@workgroup_size(32, 32)"));
    }

    // ── gemm_wgsl_f16 tests ─────────────────────────────────────────────

    #[test]
    fn wgsl_gemm_f16_enables_extension() {
        assert!(gemm_wgsl_f16(16).contains("enable f16;"));
    }

    #[test]
    fn wgsl_gemm_f16_uses_f16_storage() {
        assert!(gemm_wgsl_f16(16).contains("array<f16>"));
    }

    #[test]
    fn wgsl_gemm_f16_accumulates_in_f32() {
        assert_contains_all(
            &gemm_wgsl_f16(16),
            &["var acc: f32 = 0.0;", "f32(a[", "f32(b["],
        );
    }

    #[test]
    fn wgsl_gemm_f16_contains_workgroup() {
        assert_contains_all(
            &gemm_wgsl_f16(8),
            &["@compute @workgroup_size(8, 8)", "GemmParams"],
        );
    }

    #[test]
    fn wgsl_attention_embeds_dimensions() {
        // head_dim=128, seq=32, batch_heads=8 must appear as literals.
        assert_contains_all(
            &attention_wgsl(8, 32, 32, 128, 0.088, true),
            &["128u", "32u", "8u"],
        );
    }
}