metaltile-std 0.1.0

MetalTile kernel standard library — benchmark metadata and type definitions
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! RMS normalization benchmark — #[kernel] DSL vs MLX metal/rms_norm.metal
//!
//! The kernel is generic over `N = tpg * 4` — each thread owns 4
//! consecutive elements, the partial sum-of-squares reduces across
//! the threadgroup. The bench wires `n=4096, tpg=1024` for the
//! hidden-axis case. For per-head normalisation (Qwen3-style q_norm
//! / k_norm pre-RoPE), the same kernel is dispatched as one
//! threadgroup per `(batch*token*n_heads)` row at `tpg = head_dim/4`
//! with the per-head_dim weight broadcast across all rows. The
//! per-head contract is pinned by
//! `tests/rms_norm_per_head_gpu.rs`.
//!
//! Row lengths the 4-element layout cannot reach — head_dim < 128
//! (older 7B-class, head_dim=64) or a head_dim that is not a multiple
//! of 128 (e.g. 192) — need [`mt_rms_norm_small`] instead, which uses a
//! 2-elements-per-thread layout so head_dim=64 still hits the tpg=32
//! minimum.
//!
//! ## DISPATCH INVARIANTS
//!
//! This kernel is reduction-mode and has STRICT threadgroup-geometry
//! requirements. Violating any of these silently miscomputes the
//! output (best case) or pins the GPU in an infinite loop (worst
//! case — see FFAI post-mortem 2026-05-19). Consumers MUST encode
//! these as preconditions in their wrappers.
//!
//! - **`N = TPG * 4`.** Each thread owns exactly 4 consecutive
//!   elements of the row, loaded unconditionally at offsets
//!   `tid*4 + {0..3}`. The wrapper computes `TPG = n / 4`.
//! - **`TPG` must be a multiple of 32** (one full Apple simdgroup).
//!   The cross-simdgroup combine reads `n_simd = TPG / 32` slots
//!   from threadgroup memory; with `TPG < 32` the combine reads
//!   zero everywhere and `tg_ssq` silently collapses to 0.
//! - **`TPG ≤ 1024`** (Apple's max-threads-per-threadgroup cap on
//!   M-series). Combined with `N = TPG*4`, this means `N ≤ 4096`;
//!   larger rows need [`mt_rms_norm_wide`], which strides over the row
//!   and has no upper bound on `n`.
//! - **Combined**: `n` must be a multiple of 128 and `n ≤ 4096`.
//! - **Grid: 1 threadgroup per row.** Multi-row dispatch uses
//!   `grid = (nRows * TPG, 1, 1)`, `tg = (TPG, 1, 1)`; Metal slices
//!   that into `nRows` threadgroups of `TPG` threads each.

use metaltile::{bench_kernel, kernel};
use metaltile_core::ir::KernelMode;

use crate::{
    bench_types::DType,
    spec::{BenchDispatch, BenchSpec},
};

/// Cross-kernel callee: threadgroup-wide RMS inverse.
///
/// Given each thread's pre-computed `partial_ssq` (sum of squares for its
/// slice of the row), reduces across the threadgroup and returns:
///
/// ```text
///   rsqrt(reduce_sum(partial_ssq) / n + eps)
/// ```
///
/// This kernel exists **only** as a cross-kernel callee. Kernels that fuse
/// RMSNorm with a second operation (residual add, RoPE, quantized GEMV) call
/// it via the DSL cross-kernel syntax so that the reduction + rsqrt body is
/// expressed once and inlined by `KernelInlinePass` rather than copy-pasted.
///
/// ## Calling convention
///
/// The snippet below is `#[kernel]` DSL, not plain Rust. It is fenced as
/// `ignore` so rustdoc does not try to compile it as a doctest (its
/// identifiers only exist inside a caller kernel body).
///
/// ```rust,ignore
/// // In the caller kernel body (after computing per-thread partial_ssq):
/// let inv_rms = mt_rms_inv_scalar(partial_ssq, eps_buf, n);
/// ```
///
/// - `partial_ssq` → `KernelCallArg::Value`: the callee's param-load is
///   replaced by the caller's pre-computed scalar. No memory round-trip.
/// - `eps_buf`, `n` → `KernelCallArg::Tensor`: the callee's loads are kept
///   but renamed to the caller's buffer/constexpr names, so the inlined code
///   reads the correct per-kernel eps and row length.
/// - The output param `out` receives no arg; its store is skipped and the
///   stored `inv_rms` value is returned as the call result.
///
/// ## Standalone vs inlined semantics
///
/// `mt_rms_inv_scalar` is a **valid standalone kernel**: `partial_ssq` is a
/// real 1-element `Tensor<f32>` and `load(partial_ssq[0u32])` is a legal
/// memory access. It can be dispatched directly (e.g. in tests) by passing a
/// 1-element buffer containing the pre-summed partial sum.
///
/// NOTE(review): in a standalone dispatch every lane loads the same
/// `partial_ssq[0]` before `reduce_sum`, so a TPG-thread dispatch sums TPG
/// copies of that value rather than the value itself. Direct callers
/// presumably account for this (pre-divide by TPG, or a geometry the
/// codegen reduces correctly) — confirm against the tests that dispatch
/// this kernel directly.
///
/// When called via the cross-kernel DSL (`let inv = mt_rms_inv_scalar(g, ...)`)
/// the caller passes `g` as a `KernelCallArg::Value` — a pre-computed scalar
/// already in registers. `KernelInlinePass` detects the `Value` arg, skips the
/// load, and substitutes `g` directly, eliminating the memory round-trip.
/// This is load-forwarding: each lane contributes its own `g` to the
/// `reduce_sum`.
#[kernel]
pub fn mt_rms_inv_scalar(
    partial_ssq: Tensor<f32>,
    eps_buf: Tensor<f32>,
    mut out: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let v = load(partial_ssq[0u32]); // replaced by Value arg at inline time
    // Threadgroup-wide sum of every lane's partial sum of squares.
    let tg_ssq = reduce_sum(v);
    // Element 0 of `eps_buf` holds the epsilon.
    let eps = load(eps_buf[0u32]);
    // rsqrt(mean(x²) + eps); the inliner drops this store and returns the value.
    store(out[0u32], rsqrt(tg_ssq / n + eps));
}

/// Row-wise RMSNorm: `out[i] = x[i] * rsqrt(mean(x²) + eps) * w[i]`,
/// one threadgroup per row, 4 consecutive elements per thread
/// (`N = TPG * 4`). The geometry contract (`n` a multiple of 128,
/// `n ≤ 4096`, one threadgroup per row) is spelled out in the module-level
/// DISPATCH INVARIANTS.
///
/// - `x`, `out`: rows of length `n`; row `program_id::<0>()` starts at
///   offset `row * n`.
/// - `w`: weight indexed by column only, so it is shared by every row.
/// - `eps_buf`: element 0 holds `eps`.
///
/// The sum of squares and the scaling are computed in f32. Only the final
/// per-element product is cast back to `T`.
#[bench_kernel(
    op="rms_norm",
    subop="rms_norm",
    class=RowNorm,
    b=1024,
    n=4096,
    tpg=1024,
    reads=2,
    pre_weight=1.0,
    post_eps=1e-5,
    tol=1e-4,
    mlx="rms{tn}",
    metal_file="rms_norm.metal",
)]
#[kernel]
pub fn mt_rms_norm<T>(
    x: Tensor<T>,
    w: Tensor<T>,
    out: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let row = program_id::<0>();
    let rs = row * n;
    // Each thread owns exactly 4 consecutive elements (N = TPG * 4).
    // The wrapper enforces this — but as belt-and-braces (the original
    // 2026-05-19 freeze came from a wrong-TPG dispatch in a sibling
    // kernel), clamp the load base for OOB threads and mask their SSQ
    // contribution + skip their stores. Threads with `col >= n` re-read
    // row[0..3] (benign, since `partial_ssq` for them is forced to 0),
    // participate in `reduce_sum` (required — Apple simdgroup
    // primitives need all lanes active), and skip their stores so
    // they don't trample a neighbouring row.
    let col = tid * 4u32;
    let in_bounds = col + 3u32 < n;
    let safe_col = select(in_bounds, col, 0u32);
    let safe_base = rs + safe_col;
    let base = rs + col; // only used inside the in_bounds-guarded store block.
    // Read x exactly once into registers and reuse it for both the
    // sum-of-squares and the output. Per element that is one x read plus
    // one w read (in the store block below) — presumably what `reads=2`
    // in the bench attribute above counts. Keep the `.cast::<f32>()` on
    // each load: the codegen's vectorize pass keys on this pattern.
    let x0 = load(x[safe_base]).cast::<f32>();
    let x1 = load(x[safe_base + 1u32]).cast::<f32>();
    let x2 = load(x[safe_base + 2u32]).cast::<f32>();
    let x3 = load(x[safe_base + 3u32]).cast::<f32>();
    let raw_ssq = x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
    // Mask OOB lanes to 0 contribution so `mean(x²) = tg_ssq / n` stays
    // correct: in-bounds lanes contribute their real x² values, the
    // sum/n divisor is unchanged. Only valid when the wrapper has
    // ensured the in-bounds lanes cover the full row exactly once;
    // duplicate / missing coverage is a wrapper bug we can't repair here.
    let partial_ssq = select(in_bounds, raw_ssq, 0.0f32);
    let tg_ssq = reduce_sum(partial_ssq);
    let eps = load(eps_buf[0]);
    // Inverse RMS, shared by every lane of the row.
    let rms = rsqrt(tg_ssq / n + eps);
    if in_bounds {
        store(out[base], (x0 * rms * load(w[col]).cast::<f32>()).cast::<T>());
        store(out[base + 1u32], (x1 * rms * load(w[col + 1u32]).cast::<f32>()).cast::<T>());
        store(out[base + 2u32], (x2 * rms * load(w[col + 2u32]).cast::<f32>()).cast::<T>());
        store(out[base + 3u32], (x3 * rms * load(w[col + 3u32]).cast::<f32>()).cast::<T>());
    }
}

/// Small-head RMSNorm — 2 consecutive elements per thread, so
/// `N = tpg * 2`. Covers per-head dispatch at head_dim ∈ {64, 128,
/// 192, 256} (head_dim=64 → tpg=32 hits the single-simdgroup
/// minimum that the 4-element variant misses). At head_dim ≥ 128
/// the 4-element [`mt_rms_norm`] has better ILP per lane and is
/// preferred; this variant exists to cover the small-head_dim
/// regime (older 7B-class architectures) without a dispatch-time
/// fallback.
///
/// Algorithm-identical to `mt_rms_norm`: f32 accumulator for the
/// sum-of-squares, threadgroup-wide `reduce_sum`, `rsqrt(ssq/n + eps)`
/// scaling, per-element output store rounded through `T`.
#[bench_kernel(
    op="rms_norm",
    subop="rms_norm_small",
    class=RowNorm,
    // Per-head dispatch shape: head_dim=64 row count tuned so the bench
    // walks a representative batched-prefill workload (4 batches × 16
    // tokens × 16 q heads at head_dim=64 = 1024 rows). Same `n × b`
    // total element count as the parent `mt_rms_norm` bench so the
    // GB/s comparison is apples-to-apples.
    b=1024,
    n=64,
    tpg=32,
    reads=2,
    pre_weight=1.0,
    post_eps=1e-5,
    tol=1e-4,
    mlx="rms{tn}",
    metal_file="rms_norm.metal",
)]
#[kernel]
pub fn mt_rms_norm_small<T>(
    x: Tensor<T>,
    w: Tensor<T>,
    out: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let row = program_id::<0>();
    let rs = row * n;
    // 2 elements per thread → tpg = n / 2. The minimum supported is
    // tpg = 32 (one full simdgroup) → n ≥ 64.
    let base = rs + tid * 2u32;
    let col = tid * 2u32;
    let x0 = load(x[base]).cast::<f32>();
    let x1 = load(x[base + 1u32]).cast::<f32>();
    let partial_ssq = x0 * x0 + x1 * x1;
    let tg_ssq = reduce_sum(partial_ssq);
    let eps = load(eps_buf[0]);
    let rms = rsqrt(tg_ssq / n + eps);
    store(out[base], (x0 * rms * load(w[col]).cast::<f32>()).cast::<T>());
    store(out[base + 1u32], (x1 * rms * load(w[col + 1u32]).cast::<f32>()).cast::<T>());
}

/// Wide-row RMSNorm — handles rows wider than the 4096-element cap of
/// [`mt_rms_norm`]. Where `mt_rms_norm` fixes `N = TPG * 4` (so a
/// 1024-thread group tops out at 4096), this kernel has each thread
/// *stride* over the row in steps of one full threadgroup, so any `n`
/// is covered regardless of the threadgroup size. Needed for
/// large-hidden models (e.g. Gemma 4 31B, hidden 5376).
///
/// Two passes over device memory: pass 1 accumulates the strided
/// sum-of-squares and reduces it threadgroup-wide; pass 2 re-reads `x`
/// and writes the scaled output. The per-thread element count is
/// `ceil(n / TPG)` and varies with `n`, so the `x` values cannot be
/// held in registers across the reduction the way `mt_rms_norm` does
/// — hence the re-read. RMSNorm is memory-bound; the extra `x` read is
/// the price of unbounded `n`.
///
/// ## DISPATCH INVARIANTS
///
/// - **TPG a multiple of 32** (one full Apple simdgroup) so the
///   `reduce_sum` cross-simdgroup combine is well-defined. The wrapper
///   uses TPG = 1024. The stride is derived as `n_simd * 32`, so the
///   kernel is correct for any such TPG. If TPG were not a multiple of
///   32, `n_simd * 32` would differ from TPG. Elements would then be
///   skipped or double-counted, depending on how `n_simd` rounds.
/// - **Grid: 1 threadgroup per row.** Multi-row dispatch uses
///   `grid = (nRows * TPG, 1, 1)`, `tg = (TPG, 1, 1)`.
/// - **`n` may be any positive value.** The strided loops bound on
///   `n`, so no `N = TPG * k` relationship is required; threads whose
///   stride walks past `n` simply stop. Unlike `mt_rms_norm` there is
///   no 128-alignment or `n ≤ 4096` requirement.
#[kernel]
pub fn mt_rms_norm_wide<T>(
    x: Tensor<T>,
    w: Tensor<T>,
    out: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let row = program_id::<0>();
    let rs = row * n;
    // One full threadgroup of threads; every thread strides by this.
    // NOTE(review): assumes `n_simd` is the DSL builtin for simdgroups
    // per threadgroup, so this equals TPG when TPG is a multiple of 32.
    let tpg = n_simd * 32u32;
    // Pass 1: strided sum-of-squares. A thread with `tid >= n` runs
    // zero iterations and contributes 0 — still required to reach
    // `reduce_sum` (Apple simdgroup reductions need all lanes active).
    let mut acc = 0.0f32;
    for i in range(tid, n, tpg) {
        let xi = load(x[rs + i]).cast::<f32>();
        acc = acc + xi * xi;
    }
    let tg_ssq = reduce_sum(acc);
    let eps = load(eps_buf[0]);
    // Inverse RMS for the whole row: rsqrt(mean(x²) + eps).
    let rms = rsqrt(tg_ssq / n + eps);
    // Pass 2: strided scaled store. `x` is re-read from device memory
    // (see the doc note above).
    for i in range(tid, n, tpg) {
        let xi = load(x[rs + i]).cast::<f32>();
        // `w` is indexed by column only, so every row shares it.
        let wi = load(w[i]).cast::<f32>();
        store(out[rs + i], (xi * rms * wi).cast::<T>());
    }
}

/// Fused gated-mixer-norm: `out = rms_norm(y, w) · silu(z)`. Per-row
/// across `[Hv, Dv]` — one row per threadgroup. Used by the FFAI
/// Qwen3.5 / Qwen3.6 GDN mixer's phase-2 step (`y` is the recurrence
/// output in fp32; `z` is the gate from `in_proj_z` in the model
/// dtype; `w` is `mixer.norm.weight`). Folding RMSNorm + weight +
/// `silu(z)` into one dispatch kills the host round-trip the legacy
/// path needed to compute this on the CPU between phases — 30 host
/// commit+waits per Qwen3.6-A3B decode token recovered.
///
/// Math (one row):
///   rms = rsqrt(mean(y²) + eps)
///   y_normed[i] = y[i] * rms * w[i]
///   silu(z)[i]  = z[i] / (1 + exp(-z[i]))
///   out[i] = y_normed[i] * silu(z)[i]
///
/// Indexing: `y`, `z` and `out` share the row-major layout (row offset
/// `row * n`). `w` is indexed by column only and shared across rows.
///
/// Same `N = TPG * 4` invariant as `mt_rms_norm`. Combined with the
/// module-level requirement that TPG be a multiple of 32 and ≤ 1024,
/// this means Dv must be a multiple of 128 and ≤ 4096. Every shipped
/// Qwen3 hybrid satisfies this (128 / 256 / 512). One thread owns 4
/// consecutive `Dv`-axis elements.
///
/// The OOB clamp + mask copies the `mt_rms_norm` template. Lanes past
/// the row (TPG too large) therefore neither read nor write outside it
/// and contribute 0 to the sum of squares. This does NOT make a
/// wrong-TPG dispatch fail loudly: with TPG too small, the elements
/// past `TPG * 4` are silently left unwritten and missing from the sum
/// of squares.
#[kernel]
pub fn mt_gated_mixer_norm<T>(
    y: Tensor<f32>,
    z: Tensor<T>,
    w: Tensor<T>,
    out: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let row = program_id::<0>();
    let rs = row * n;
    let col = tid * 4u32;
    // OOB guard (same template as `mt_rms_norm`): clamp loads to the row
    // start for out-of-range lanes; only in-bounds lanes store.
    let in_bounds = col + 3u32 < n;
    let safe_col = select(in_bounds, col, 0u32);
    let safe_base = rs + safe_col;
    let base = rs + col;
    // y is already fp32, but mirror the mt_rms_norm load pattern
    // (`.cast::<f32>()` after each load) — the vectorize pass on this
    // codegen reads the cast as the consumer hook for the float4
    // load+extract emit. Removing the cast leaves the vectorize pass
    // half-finished (load merges into a float4, scalar y_n references
    // never get rewritten into VectorExtract — see emit + bug-report
    // in metaltile codegen `vectorize.rs`).
    let y0 = load(y[safe_base]).cast::<f32>();
    let y1 = load(y[safe_base + 1u32]).cast::<f32>();
    let y2 = load(y[safe_base + 2u32]).cast::<f32>();
    let y3 = load(y[safe_base + 3u32]).cast::<f32>();
    let raw_ssq = y0 * y0 + y1 * y1 + y2 * y2 + y3 * y3;
    // OOB lanes contribute 0 so `tg_ssq / n` stays the row's mean square.
    let partial_ssq = select(in_bounds, raw_ssq, 0.0f32);
    let tg_ssq = reduce_sum(partial_ssq);
    let eps = load(eps_buf[0]);
    let rms = rsqrt(tg_ssq / n + eps);
    if in_bounds {
        let w0 = load(w[col]).cast::<f32>();
        let w1 = load(w[col + 1u32]).cast::<f32>();
        let w2 = load(w[col + 2u32]).cast::<f32>();
        let w3 = load(w[col + 3u32]).cast::<f32>();
        let z0 = load(z[base]).cast::<f32>();
        let z1 = load(z[base + 1u32]).cast::<f32>();
        let z2 = load(z[base + 2u32]).cast::<f32>();
        let z3 = load(z[base + 3u32]).cast::<f32>();
        // silu(z) = z / (1 + exp(-z)). Inlined per the `mt_sigmoid`
        // precedent — Activation::Sigmoid folds into FusedElementwise
        // and the per-kernel feature analyzer would miss it, so the
        // emitted MSL stays self-contained without an `mt_sigmoid`
        // helper. Same as `mt_gated_delta_prep_step`'s `beta` path.
        let silu0 = z0 / (1.0f32 + exp(0.0f32 - z0));
        let silu1 = z1 / (1.0f32 + exp(0.0f32 - z1));
        let silu2 = z2 / (1.0f32 + exp(0.0f32 - z2));
        let silu3 = z3 / (1.0f32 + exp(0.0f32 - z3));
        store(out[base], ((y0 * rms * w0) * silu0).cast::<T>());
        store(out[base + 1u32], ((y1 * rms * w1) * silu1).cast::<T>());
        store(out[base + 2u32], ((y2 * rms * w2) * silu2).cast::<T>());
        store(out[base + 3u32], ((y3 * rms * w3) * silu3).cast::<T>());
    }
}

// Bench registration for the wide-row RMSNorm kernel. It is registered
// by hand (no `#[bench_kernel]` attribute on `mt_rms_norm_wide`).
inventory::submit! {
    BenchSpec {
        // Identity: which bench row this is and the kernel it drives.
        op: "rms_norm",
        subop: "rms_norm_wide",
        kernel_name: "mt_rms_norm_wide",
        kernel_ir: mt_rms_norm_wide::kernel_ir_for,
        // Threadgroup-wide `reduce_sum` kernel.
        kernel_mode: Some(KernelMode::Reduction),
        // Numerics: the float dtypes and the comparison tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 5e-4,
        // No MLX source or pattern is attached to this entry.
        mlx_src: None,
        mlx_pattern: None,
        // Launch configuration: generic dispatch, no explicit shape list.
        dispatch: BenchDispatch::Generic,
        shapes: &[],
    }
}

// Bench registration for the fused RMSNorm · silu(gate) mixer kernel.
inventory::submit! {
    BenchSpec {
        // Identity: which bench row this is and the kernel it drives.
        op: "rms_norm",
        subop: "gated_mixer_norm",
        kernel_name: "mt_gated_mixer_norm",
        kernel_ir: mt_gated_mixer_norm::kernel_ir_for,
        // Threadgroup-wide `reduce_sum` kernel.
        kernel_mode: Some(KernelMode::Reduction),
        // Numerics: the float dtypes and the comparison tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 1e-3,
        // No MLX source or pattern is attached to this entry.
        mlx_src: None,
        mlx_pattern: None,
        // Launch configuration: generic dispatch, no explicit shape list.
        dispatch: BenchDispatch::Generic,
        shapes: &[],
    }
}

/// Codegen smoke tests for [`mt_rms_norm_wide`].
#[cfg(test)]
mod wide_tests {
    use metaltile_codegen::msl::MslGenerator;
    use metaltile_core::ir::KernelMode;

    use super::mt_rms_norm_wide;
    use crate::bench_types::DType;

    /// Float element types the wide kernel is expected to lower for.
    const FLOAT_DTYPES: [DType; 3] = [DType::F32, DType::F16, DType::BF16];

    /// Lower `mt_rms_norm_wide` for `dt` in reduction mode and return the MSL.
    fn msl_for(dt: DType) -> String {
        let kernel = {
            let mut ir = mt_rms_norm_wide::kernel_ir_for(dt);
            // Match the `kernel_mode` the bench registration uses.
            ir.mode = KernelMode::Reduction;
            ir
        };
        MslGenerator::default()
            .generate(&kernel)
            .expect("mt_rms_norm_wide codegen succeeds")
    }

    #[test]
    fn codegen_produces_nonempty_msl_for_all_float_dtypes() {
        for dt in FLOAT_DTYPES {
            let src = msl_for(dt);
            let has_content = !src.trim().is_empty();
            assert!(has_content, "MSL for {dt:?} should not be empty");
            let declares_entry_point = src.contains("kernel void mt_rms_norm_wide");
            assert!(
                declares_entry_point,
                "MSL for {dt:?} should declare mt_rms_norm_wide:\n{src}",
            );
        }
    }
}