//! WGSL shader source generation for common compute kernels.
//!
//! Each function returns a complete, self-contained WGSL source string
//! suitable for passing to `device.create_shader_module()`.

/// Generate WGSL source for a GEMM kernel: `C = alpha * A * B + beta * C`.
///
/// Dispatched with `tile_size × tile_size` workgroups; each invocation
/// computes one element of `C` directly from global memory (no
/// shared-memory tiling).  Both A and B are stored row-major.
///
/// # Arguments
///
/// * `tile_size` — workgroup tile dimension (e.g. 8, 16, 32).
pub fn gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct GemmParams {{
    m:     u32,
    n:     u32,
    k:     u32,
    alpha: f32,
    beta:  f32,
}}

@group(0) @binding(0) var<storage, read>       a:      array<f32>;
@group(0) @binding(1) var<storage, read>       b:      array<f32>;
@group(0) @binding(2) var<storage, read_write> c:      array<f32>;
@group(0) @binding(3) var<uniform>             params: GemmParams;

@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}

    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[row * params.k + i] * b[i * params.n + col];
    }}

    let idx = row * params.n + col;
    c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#,
        ts = tile_size
    )
}
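
// A minimal sketch of the dispatch arithmetic `gemm_wgsl` assumes: gid.x
// walks the columns of C and gid.y the rows, so a caller would launch
// ceil(n / tile_size) × ceil(m / tile_size) workgroups.  `gemm_dispatch_dims`
// is a hypothetical helper, not part of this module's public API.
#[allow(dead_code)]
fn gemm_dispatch_dims(m: u32, n: u32, tile_size: u32) -> (u32, u32) {
    let ceil_div = |a: u32, b: u32| (a + b - 1) / b;
    // (x, y): x covers columns (n), y covers rows (m).
    (ceil_div(n, tile_size), ceil_div(m, tile_size))
}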

/// Generate WGSL source for an element-wise unary operation.
///
/// The shader applies the operation to every element of `input` and writes
/// the results to `output`.  The element count comes from
/// `arrayLength(&input)`, so `output` must be at least as long as `input`.
///
/// # Arguments
///
/// * `op` — one of: `"relu"`, `"sigmoid"`, `"tanh"`, `"exp"`, `"log"`,
///   `"sqrt"`, `"abs"`, `"neg"`.  Unknown ops are treated as identity.
pub fn elementwise_wgsl(op: &str) -> String {
    let op_expr = match op {
        "relu" => "max(x, 0.0)",
        "sigmoid" => "1.0 / (1.0 + exp(-x))",
        "tanh" => "tanh(x)",
        "exp" => "exp(x)",
        "log" => "log(x)",
        "sqrt" => "sqrt(x)",
        "abs" => "abs(x)",
        "neg" => "-x",
        _ => "x",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read>       input:  array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&input)) {{ return; }}
    let x = input[i];
    output[i] = {op};
}}
"#,
        op = op_expr
    )
}
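
// One way to sanity-check a generated source without a GPU, assuming the
// `naga` crate (the WGSL parser wgpu itself uses) is available as a
// dev-dependency; `parse_str` returns the parsed module or a syntax error:
//
//     let src = elementwise_wgsl("relu");
//     naga::front::wgsl::parse_str(&src).expect("generated WGSL should parse");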

/// Generate WGSL source for an element-wise binary operation.
///
/// The shader reads corresponding elements from two input buffers (`lhs`
/// and `rhs`), applies the operation, and writes the results to `output`.
/// Only `arrayLength(&lhs)` is bounds-checked, so all three buffers must
/// have the same length.
///
/// # Arguments
///
/// * `op` — one of: `"add"`, `"sub"`, `"mul"`, `"div"`, `"max"`, `"min"`,
///   `"pow"`.  Unknown ops fall back to identity on `lhs`.
pub fn binary_wgsl(op: &str) -> String {
    let op_expr = match op {
        "add" => "a + b",
        "sub" => "a - b",
        "mul" => "a * b",
        "div" => "a / b",
        "max" => "max(a, b)",
        "min" => "min(a, b)",
        "pow" => "pow(a, b)",
        _ => "a",
    };

    format!(
        r#"
@group(0) @binding(0) var<storage, read>       lhs:    array<f32>;
@group(0) @binding(1) var<storage, read>       rhs:    array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let i = gid.x;
    if (i >= arrayLength(&lhs)) {{ return; }}
    let a = lhs[i];
    let b = rhs[i];
    output[i] = {op};
}}
"#,
        op = op_expr
    )
}
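
// Both element-wise shaders use flat 256-thread workgroups, so the dispatch
// is one-dimensional: ceil(n / 256) groups for n elements.  A hypothetical
// helper, not part of this module's API:
#[allow(dead_code)]
fn elementwise_dispatch_dim(n: u32) -> u32 {
    (n + 255) / 256 // ceiling division by the workgroup size
}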

/// Generate WGSL source for a parallel workgroup-level reduction.
///
/// This is the first pass of a two-pass scheme: each workgroup of 256
/// threads reduces its tile to a single value in workgroup shared memory
/// and writes it to the `partial_sums` buffer.  A second dispatch with a
/// single workgroup (see [`reduction_final_wgsl`]) then reduces the
/// partials to the final scalar, which caps the supported input length at
/// 256 × 256 = 65,536 elements.
///
/// # Arguments
///
/// * `op` — one of: `"sum"`, `"max"`, `"min"`, `"mean"`.  `"mean"` behaves
///   like `"sum"` in the shader; the CPU is responsible for dividing by N.
///   Unknown ops fall back to `"sum"`.
pub fn reduction_wgsl(op: &str) -> String {
    // Neutral elements and combine expressions for each operation.
    let (neutral, combine) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)"),
        "min" => ("f32(1e38)", "min(acc, val)"),
        // "sum" and "mean" use the same reduction body.
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}

@group(0) @binding(0) var<storage, read>       input:        array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform>             params:       ReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid:  vec3<u32>,
    @builtin(local_invocation_id)  lid:  vec3<u32>,
    @builtin(workgroup_id)         wgid: vec3<u32>,
) {{
    let tid         = lid.x;
    let global_idx  = gid.x;

    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
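
// The full reduction protocol this shader assumes, as arithmetic (the
// `dispatch` calls are pseudocode; only the group counts are load-bearing):
//
//     let num_groups = (n + 255) / 256;          // pass 1 output length
//     // dispatch(reduction_wgsl(op), num_groups) -> partial_sums[0..num_groups]
//     // dispatch(reduction_final_wgsl(op), 1)    -> output[0]
//     // for "mean", divide output[0] by n as f32 on the CPU afterwards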

/// Generate a WGSL compute shader for 2D convolution in NCHW format.
///
/// The shader reads from `input` (NCHW) and `filter` (K×C×FH×FW), writing
/// the result to `output` (N×K×OH×OW).  Padding is handled via bounds
/// checking — out-of-range input positions contribute zero.
///
/// # Arguments
///
/// * `n` — batch size
/// * `c_in` — number of input channels
/// * `h_in`, `w_in` — spatial input dimensions
/// * `k_out` — number of output channels (filters)
/// * `fh`, `fw` — filter height / width
/// * `oh`, `ow` — output height / width
/// * `stride_h`, `stride_w` — convolution strides
/// * `pad_h`, `pad_w` — zero-padding applied to the input
#[allow(clippy::too_many_arguments)]
pub fn conv2d_wgsl(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> String {
    format!(
        r#"
// Conv2D NCHW — generated by oxicuda-webgpu
// input : [{n}, {c_in}, {h_in}, {w_in}]
// filter: [{k_out}, {c_in}, {fh}, {fw}]
// output: [{n}, {k_out}, {oh}, {ow}]

@group(0) @binding(0) var<storage, read>       input:  array<f32>;
@group(0) @binding(1) var<storage, read>       filter: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    // gid.x = output x (ox); gid.y = (batch, k, oy) flattened into one index.
    let ox = gid.x;
    let linear_y = gid.y;

    let batch_k_oh = {n}u * {k_out}u * {oh}u;
    if (ox >= {ow}u || linear_y >= batch_k_oh) {{ return; }}

    let b  = linear_y / ({k_out}u * {oh}u);
    let rem = linear_y % ({k_out}u * {oh}u);
    let kf = rem / {oh}u;
    let oy = rem % {oh}u;

    var acc: f32 = 0.0;
    for (var ci: u32 = 0u; ci < {c_in}u; ci = ci + 1u) {{
        for (var fy: u32 = 0u; fy < {fh}u; fy = fy + 1u) {{
            for (var fx: u32 = 0u; fx < {fw}u; fx = fx + 1u) {{
                let iy_raw = i32(oy * {stride_h}u + fy) - i32({pad_h}u);
                let ix_raw = i32(ox * {stride_w}u + fx) - i32({pad_w}u);
                if (iy_raw >= 0 && iy_raw < i32({h_in}u) && ix_raw >= 0 && ix_raw < i32({w_in}u)) {{
                    let iy = u32(iy_raw);
                    let ix = u32(ix_raw);
                    let in_idx = ((b * {c_in}u + ci) * {h_in}u + iy) * {w_in}u + ix;
                    let f_idx  = ((kf * {c_in}u + ci) * {fh}u + fy) * {fw}u + fx;
                    acc += input[in_idx] * filter[f_idx];
                }}
            }}
        }}
    }}

    let o_idx = ((b * {k_out}u + kf) * {oh}u + oy) * {ow}u + ox;
    output[o_idx] = acc;
}}
"#,
        n = n,
        c_in = c_in,
        h_in = h_in,
        w_in = w_in,
        k_out = k_out,
        fh = fh,
        fw = fw,
        oh = oh,
        ow = ow,
        stride_h = stride_h,
        stride_w = stride_w,
        pad_h = pad_h,
        pad_w = pad_w,
    )
}
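
// The `oh`/`ow` the caller passes are expected to satisfy the standard
// convolution shape formula; a hypothetical helper (not part of this
// module's API) that computes them:
#[allow(dead_code)]
fn conv2d_out_dim(in_dim: u32, filter: u32, stride: u32, pad: u32) -> u32 {
    // e.g. in_dim = 32, filter = 3, stride = 1, pad = 0  ->  30,
    // matching the 32×32 → 30×30 shapes used in the tests below.
    (in_dim + 2 * pad - filter) / stride + 1
}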

/// Generate a WGSL compute shader for scaled dot-product attention.
///
/// Implements: `O = softmax(Q·K^T * scale [+ causal_mask]) · V`
///
/// The softmax is numerically stable (subtracts the row max before `exp`).
/// When `causal` is true, positions where `sk > sq` are masked to −∞
/// (approximated as `-1e38` in the shader).
///
/// # Arguments
///
/// * `batch_heads` — combined batch × heads dimension
/// * `seq_q` — query sequence length
/// * `seq_kv` — key/value sequence length
/// * `head_dim` — dimension of each head
/// * `scale` — scaling factor (typically `1 / sqrt(head_dim)`)
/// * `causal` — whether to apply a causal (upper-triangular) mask
pub fn attention_wgsl(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> String {
    let causal_check = if causal {
        "if (sk > sq) { score = f32(-1e38); } else {"
    } else {
        "{"
    };

    format!(
        r#"
// Scaled dot-product attention — generated by oxicuda-webgpu
// Q       : [{batch_heads}, {seq_q}, {head_dim}]
// K, V    : [{batch_heads}, {seq_kv}, {head_dim}]
// O       : [{batch_heads}, {seq_q}, {head_dim}]
// scale   : {scale}
// causal  : {causal}

@group(0) @binding(0) var<storage, read>       q_buf: array<f32>;
@group(0) @binding(1) var<storage, read>       k_buf: array<f32>;
@group(0) @binding(2) var<storage, read>       v_buf: array<f32>;
@group(0) @binding(3) var<storage, read_write> o_buf: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let linear = gid.x;
    let total = {batch_heads}u * {seq_q}u;
    if (linear >= total) {{ return; }}

    let bh = linear / {seq_q}u;
    let sq = linear % {seq_q}u;

    let q_base = (bh * {seq_q}u + sq) * {head_dim}u;
    let o_base = (bh * {seq_q}u + sq) * {head_dim}u;

    // Pass 1: find max score for numerical stability
    var max_score: f32 = f32(-1e38);
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        if (score > max_score) {{ max_score = score; }}
    }}

    // Zero this thread's output slice so the += accumulation in pass 2
    // does not depend on the buffer's previous contents.
    for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
        o_buf[o_base + d] = 0.0;
    }}

    // Pass 2: compute exp(score - max), accumulate weighted V
    var sum_exp: f32 = 0.0;
    for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
        var score: f32 = 0.0;
        {causal_check}
            let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
            for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
                score += q_buf[q_base + d] * k_buf[k_base + d];
            }}
            score *= f32({scale});
        }}
        let w = exp(score - max_score);
        sum_exp += w;
        let v_base = (bh * {seq_kv}u + sk) * {head_dim}u;
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            // Accumulate in-place (we normalise after the loop).
            o_buf[o_base + d] += w * v_buf[v_base + d];
        }}
    }}

    // Pass 3: normalise
    if (sum_exp > 0.0) {{
        for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
            o_buf[o_base + d] /= sum_exp;
        }}
    }}
}}
"#,
        batch_heads = batch_heads,
        seq_q = seq_q,
        seq_kv = seq_kv,
        head_dim = head_dim,
        scale = scale,
        causal = causal,
        causal_check = causal_check,
    )
}
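
// Sketch of a typical invocation (illustrative values; nothing here is a
// fixed API): one thread per (batch·head, query) pair, 64 threads per group.
#[allow(dead_code)]
fn attention_example() -> (String, u32) {
    let (batch_heads, seq_q, seq_kv, head_dim) = (4u32, 128u32, 128u32, 64u32);
    let scale = 1.0 / (head_dim as f32).sqrt(); // the conventional 1/sqrt(d)
    let src = attention_wgsl(batch_heads, seq_q, seq_kv, head_dim, scale, true);
    let groups = (batch_heads * seq_q + 63) / 64; // ceil over @workgroup_size(64)
    (src, groups)
}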

/// Generate WGSL for the final scalar reduction of partial sums.
///
/// Takes a `partial_sums` array of length `num_groups` and reduces it to a
/// single value at `output[0]`.  Must be dispatched with a single workgroup
/// of 256 threads, so it can fold at most 256 partial values.
pub fn reduction_final_wgsl(op: &str) -> String {
    let (neutral, combine) = match op {
        "max" => ("f32(-1e38)", "max(acc, val)"),
        "min" => ("f32(1e38)", "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };

    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}

@group(0) @binding(0) var<storage, read>       partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output:       array<f32>;
@group(0) @binding(2) var<uniform>             params:       FinalReduceParams;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;

    if (tid < params.num_groups) {{
        shared_data[tid] = partial_sums[tid];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();

    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}

    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
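
// The fixed 256-thread passes cap the two-pass scheme at 256 × 256 = 65,536
// input elements; a hypothetical guard a caller might apply before choosing
// this pair of shaders:
#[allow(dead_code)]
fn two_pass_reduction_fits(n: u32) -> bool {
    let num_groups = (n + 255) / 256; // partials emitted by the first pass
    num_groups <= 256                 // the final pass folds at most 256
}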

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wgsl_gemm_contains_workgroup() {
        let src = gemm_wgsl(16);
        assert!(src.contains("@compute @workgroup_size(16, 16)"));
        assert!(src.contains("GemmParams"));
        assert!(src.contains("alpha"));
        assert!(src.contains("beta"));
    }

    #[test]
    fn wgsl_gemm_tile_size_embedded() {
        let src8 = gemm_wgsl(8);
        assert!(src8.contains("@workgroup_size(8, 8)"));
        let src32 = gemm_wgsl(32);
        assert!(src32.contains("@workgroup_size(32, 32)"));
    }

    #[test]
    fn wgsl_elementwise_relu_contains_max() {
        let src = elementwise_wgsl("relu");
        assert!(src.contains("max(x, 0.0)"));
    }

    #[test]
    fn wgsl_elementwise_all_ops() {
        assert!(elementwise_wgsl("sigmoid").contains("exp(-x)"));
        assert!(elementwise_wgsl("tanh").contains("tanh(x)"));
        assert!(elementwise_wgsl("exp").contains("exp(x)"));
        assert!(elementwise_wgsl("log").contains("log(x)"));
        assert!(elementwise_wgsl("sqrt").contains("sqrt(x)"));
        assert!(elementwise_wgsl("abs").contains("abs(x)"));
        assert!(elementwise_wgsl("neg").contains("-x"));
        // Unknown op is identity.
        assert!(elementwise_wgsl("identity_op").contains("output[i] = x;"));
    }

    #[test]
    fn wgsl_reduction_sum_contains_addition() {
        let src = reduction_wgsl("sum");
        assert!(src.contains("acc + val"));
        assert!(src.contains("workgroupBarrier"));
    }

    #[test]
    fn wgsl_reduction_max_uses_max_fn() {
        let src = reduction_wgsl("max");
        assert!(src.contains("max(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_min_uses_min_fn() {
        let src = reduction_wgsl("min");
        assert!(src.contains("min(acc, val)"));
    }

    #[test]
    fn wgsl_reduction_mean_same_as_sum() {
        // "mean" divides on the CPU side; the shader is identical to sum.
        let sum_src = reduction_wgsl("sum");
        let mean_src = reduction_wgsl("mean");
        assert_eq!(sum_src, mean_src);
    }

    #[test]
    fn wgsl_reduction_final_sum() {
        let src = reduction_final_wgsl("sum");
        assert!(src.contains("num_groups"));
        assert!(src.contains("output[0]"));
    }

    // ── binary_wgsl tests ─────────────────────────────────────────────────

    #[test]
    fn wgsl_binary_add() {
        let src = binary_wgsl("add");
        assert!(src.contains("a + b"));
        assert!(src.contains("lhs"));
        assert!(src.contains("rhs"));
    }

    #[test]
    fn wgsl_binary_all_ops() {
        assert!(binary_wgsl("sub").contains("a - b"));
        assert!(binary_wgsl("mul").contains("a * b"));
        assert!(binary_wgsl("div").contains("a / b"));
        assert!(binary_wgsl("max").contains("max(a, b)"));
        assert!(binary_wgsl("min").contains("min(a, b)"));
        assert!(binary_wgsl("pow").contains("pow(a, b)"));
        // Unknown op is identity on lhs.
        assert!(binary_wgsl("unknown_op").contains("output[i] = a;"));
    }

    #[test]
    fn wgsl_binary_workgroup_size() {
        let src = binary_wgsl("add");
        assert!(src.contains("@workgroup_size(256)"));
    }

    // ── conv2d_wgsl tests ─────────────────────────────────────────────────

    #[test]
    fn wgsl_conv2d_contains_workgroup() {
        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        assert!(src.contains("@compute @workgroup_size(8, 8)"));
    }

    #[test]
    fn wgsl_conv2d_contains_storage_bindings() {
        let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
        assert!(src.contains("var<storage, read>       input:"));
        assert!(src.contains("var<storage, read>       filter:"));
        assert!(src.contains("var<storage, read_write> output:"));
    }

    #[test]
    fn wgsl_conv2d_embeds_dimensions() {
        let src = conv2d_wgsl(2, 8, 64, 64, 32, 5, 5, 60, 60, 1, 1, 0, 0);
        // Check that the shape constants appear in the shader.
        assert!(src.contains("8u")); // c_in
        assert!(src.contains("64u")); // h_in or w_in
        assert!(src.contains("32u")); // k_out
        assert!(src.contains("5u")); // fh or fw
        assert!(src.contains("60u")); // oh or ow
    }

    #[test]
    fn wgsl_conv2d_has_padding_check() {
        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 8, 8, 1, 1, 1, 1);
        // Padding check with signed comparison.
        assert!(src.contains("iy_raw >= 0"));
        assert!(src.contains("ix_raw >= 0"));
    }

    #[test]
    fn wgsl_conv2d_has_stride() {
        let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 3, 3, 2, 2, 0, 0);
        assert!(src.contains("2u")); // stride
    }

    // ── attention_wgsl tests ──────────────────────────────────────────────

    #[test]
    fn wgsl_attention_contains_workgroup() {
        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
        assert!(src.contains("@compute @workgroup_size(64)"));
    }

    #[test]
    fn wgsl_attention_contains_storage_bindings() {
        let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
        assert!(src.contains("var<storage, read>       q_buf:"));
        assert!(src.contains("var<storage, read>       k_buf:"));
        assert!(src.contains("var<storage, read>       v_buf:"));
        assert!(src.contains("var<storage, read_write> o_buf:"));
    }

    #[test]
    fn wgsl_attention_stable_softmax() {
        let src = attention_wgsl(1, 4, 4, 32, 0.25, false);
        assert!(src.contains("max_score"));
        assert!(src.contains("exp(score - max_score)"));
        assert!(src.contains("sum_exp"));
    }

    #[test]
    fn wgsl_attention_causal_mask() {
        let src_causal = attention_wgsl(1, 4, 4, 32, 0.25, true);
        assert!(src_causal.contains("sk > sq"));

        let src_non_causal = attention_wgsl(1, 4, 4, 32, 0.25, false);
        assert!(!src_non_causal.contains("sk > sq"));
    }

    #[test]
    fn wgsl_attention_embeds_scale() {
        let src = attention_wgsl(2, 16, 16, 64, 0.125, false);
        assert!(src.contains("0.125"));
    }

    #[test]
    fn wgsl_attention_embeds_dimensions() {
        let src = attention_wgsl(8, 32, 32, 128, 0.088, true);
        assert!(src.contains("128u")); // head_dim
        assert!(src.contains("32u")); // seq_q or seq_kv
        assert!(src.contains("8u")); // batch_heads
    }
}
653}