use metaltile::{bench_kernel, kernel};
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
// Converts a sum of squares into the RMS scale factor
// `rsqrt(ssq / n + eps)` and writes it to `out[0]`.
//
// - `partial_ssq`: element 0 is loaded and passed through `reduce_sum`.
// - `eps_buf`: element 0 is the epsilon added before the reciprocal sqrt.
// - `out`: element 0 receives the result.
// - `n`: compile-time divisor applied to the reduced sum.
//
// NOTE(review): every participating thread loads the same `partial_ssq[0]`
// before `reduce_sum`, so this presumably expects a single-thread dispatch —
// confirm against the caller.
#[kernel]
pub fn mt_rms_inv_scalar(
    partial_ssq: Tensor<f32>,
    eps_buf: Tensor<f32>,
    mut out: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let epsilon = load(eps_buf[0u32]);
    let ssq_in = load(partial_ssq[0u32]);
    let sum_sq = reduce_sum(ssq_in);
    // Mean square plus epsilon, then reciprocal square root.
    let inv_rms = rsqrt(sum_sq / n + epsilon);
    store(out[0u32], inv_rms);
}
// Row-wise RMS normalization with a per-column weight. Each thread handles
// four consecutive columns of the row selected by `program_id::<0>()`:
//
//   out[row, c] = x[row, c] * rsqrt(ssq / n + eps) * w[c]
//
// where `ssq` is the `reduce_sum` of every thread's squared inputs
// (presumably the whole row's sum of squares — confirm `reduce_sum` scope).
// Arithmetic is done in f32 and the result is cast back to `T`.
//
// A thread whose four columns do not all lie below `n` contributes zero to
// the sum and writes nothing, so the launch must supply at least `n / 4`
// threads per row (the bench below uses tpg=1024 for n=4096).
#[bench_kernel(
    op="rms_norm",
    subop="rms_norm",
    class=RowNorm,
    b=1024,
    n=4096,
    tpg=1024,
    reads=2,
    pre_weight=1.0,
    post_eps=1e-5,
    tol=1e-4,
    mlx="rms{tn}",
    metal_file="rms_norm.metal",
)]
#[kernel]
pub fn mt_rms_norm<T>(
    x: Tensor<T>,
    w: Tensor<T>,
    out: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let row = program_id::<0>();
    let row_start = row * n;
    let first_col = tid * 4u32;
    let lane_active = first_col + 3u32 < n;

    // Inactive threads read from column 0 of the row so their loads stay
    // inside the row; their values are masked out of the sum below.
    let clamped_col = select(lane_active, first_col, 0u32);
    let load_base = row_start + clamped_col;
    let store_base = row_start + first_col;

    let x0 = load(x[load_base]).cast::<f32>();
    let x1 = load(x[load_base + 1u32]).cast::<f32>();
    let x2 = load(x[load_base + 2u32]).cast::<f32>();
    let x3 = load(x[load_base + 3u32]).cast::<f32>();

    let local_sq = x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
    let masked_sq = select(lane_active, local_sq, 0.0f32);
    let row_sq = reduce_sum(masked_sq);

    let eps = load(eps_buf[0]);
    let inv_rms = rsqrt(row_sq / n + eps);

    if lane_active {
        // Each weight load stays directly ahead of its store, matching the
        // original access order.
        let w0 = load(w[first_col]).cast::<f32>();
        store(out[store_base], (x0 * inv_rms * w0).cast::<T>());
        let w1 = load(w[first_col + 1u32]).cast::<f32>();
        store(out[store_base + 1u32], (x1 * inv_rms * w1).cast::<T>());
        let w2 = load(w[first_col + 2u32]).cast::<f32>();
        store(out[store_base + 2u32], (x2 * inv_rms * w2).cast::<T>());
        let w3 = load(w[first_col + 3u32]).cast::<f32>();
        store(out[store_base + 3u32], (x3 * inv_rms * w3).cast::<T>());
    }
}
// Small-row variant of the RMS norm kernel: each thread handles exactly two
// consecutive columns and there is no bounds guard, so the launch is expected
// to supply `n / 2` threads per row (the bench below uses tpg=32 for n=64).
//
//   out[row, c] = x[row, c] * rsqrt(ssq / n + eps) * w[c]
//
// with arithmetic in f32 and the result cast back to `T`.
#[bench_kernel(
    op="rms_norm",
    subop="rms_norm_small",
    class=RowNorm,
    // Per-head dispatch shape: head_dim=64, with the row count chosen so the
    // bench walks a representative batched-prefill workload (4 batches × 16
    // tokens × 16 q heads at head_dim=64 = 1024 rows). The row count `b`
    // matches the parent `mt_rms_norm` bench; rows here are 64 elements wide
    // versus 4096 there, so the total element counts differ.
    // NOTE(review): an earlier wording claimed identical `n × b` totals —
    // keep that in mind when comparing GB/s between the two benches.
    b=1024,
    n=64,
    tpg=32,
    reads=2,
    pre_weight=1.0,
    post_eps=1e-5,
    tol=1e-4,
    mlx="rms{tn}",
    metal_file="rms_norm.metal",
)]
#[kernel]
pub fn mt_rms_norm_small<T>(
    x: Tensor<T>,
    w: Tensor<T>,
    out: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let row = program_id::<0>();
    let row_start = row * n;
    let first_col = tid * 2u32;
    let elem_base = row_start + first_col;

    let x0 = load(x[elem_base]).cast::<f32>();
    let x1 = load(x[elem_base + 1u32]).cast::<f32>();

    let local_sq = x0 * x0 + x1 * x1;
    let row_sq = reduce_sum(local_sq);

    let eps = load(eps_buf[0]);
    let inv_rms = rsqrt(row_sq / n + eps);

    // Each weight load stays directly ahead of its store, matching the
    // original access order.
    let w0 = load(w[first_col]).cast::<f32>();
    store(out[elem_base], (x0 * inv_rms * w0).cast::<T>());
    let w1 = load(w[first_col + 1u32]).cast::<f32>();
    store(out[elem_base + 1u32], (x1 * inv_rms * w1).cast::<T>());
}
// RMS normalization for rows of arbitrary width. Each thread walks the row in
// steps of `n_simd * 32`, starting at `tid`, in two passes:
//
//   1. accumulate the squares of its elements and `reduce_sum` them;
//   2. reload each element and write `x * rsqrt(ssq / n + eps) * w[col]`.
//
// Arithmetic is done in f32 and the result is cast back to `T`.
//
// NOTE(review): the step presumably equals the threadgroup size (number of
// SIMD groups × 32 lanes each) — confirm against the DSL's `n_simd`.
#[kernel]
pub fn mt_rms_norm_wide<T>(
    x: Tensor<T>,
    w: Tensor<T>,
    out: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let row = program_id::<0>();
    let row_start = row * n;
    let stride = n_simd * 32u32;

    // Pass 1: per-thread partial sum of squares.
    let mut sq_acc = 0.0f32;
    for c in range(tid, n, stride) {
        let value = load(x[row_start + c]).cast::<f32>();
        sq_acc = sq_acc + value * value;
    }
    let row_sq = reduce_sum(sq_acc);

    let eps = load(eps_buf[0]);
    let inv_rms = rsqrt(row_sq / n + eps);

    // Pass 2: scale every element by the shared factor and its weight.
    for c in range(tid, n, stride) {
        let value = load(x[row_start + c]).cast::<f32>();
        let weight = load(w[c]).cast::<f32>();
        store(out[row_start + c], (value * inv_rms * weight).cast::<T>());
    }
}
// RMS-normalizes `y` with weight `w` and gates the result with SiLU of `z`.
// Each thread handles four consecutive columns of the row selected by
// `program_id::<0>()`:
//
//   out[row, c] = (y[row, c] * rsqrt(ssq / n + eps) * w[c]) * silu(z[row, c])
//   silu(v)     = v / (1 + exp(-v))
//
// `y` is f32; `z`, `w` and `out` use `T`. Arithmetic is done in f32 and the
// result is cast back to `T`. A thread whose four columns do not all lie
// below `n` contributes zero to the sum and writes nothing.
#[kernel]
pub fn mt_gated_mixer_norm<T>(
    y: Tensor<f32>,
    z: Tensor<T>,
    w: Tensor<T>,
    out: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let row = program_id::<0>();
    let row_start = row * n;
    let first_col = tid * 4u32;
    let lane_active = first_col + 3u32 < n;

    // Inactive threads read from column 0 of the row so their loads stay
    // inside the row; their values are masked out of the sum below.
    let clamped_col = select(lane_active, first_col, 0u32);
    let load_base = row_start + clamped_col;
    let elem_base = row_start + first_col;

    let y0 = load(y[load_base]).cast::<f32>();
    let y1 = load(y[load_base + 1u32]).cast::<f32>();
    let y2 = load(y[load_base + 2u32]).cast::<f32>();
    let y3 = load(y[load_base + 3u32]).cast::<f32>();

    let local_sq = y0 * y0 + y1 * y1 + y2 * y2 + y3 * y3;
    let masked_sq = select(lane_active, local_sq, 0.0f32);
    let row_sq = reduce_sum(masked_sq);

    let eps = load(eps_buf[0]);
    let inv_rms = rsqrt(row_sq / n + eps);

    if lane_active {
        // Gate inputs and their SiLU activations.
        let z0 = load(z[elem_base]).cast::<f32>();
        let z1 = load(z[elem_base + 1u32]).cast::<f32>();
        let z2 = load(z[elem_base + 2u32]).cast::<f32>();
        let z3 = load(z[elem_base + 3u32]).cast::<f32>();
        let gate0 = z0 / (1.0f32 + exp(0.0f32 - z0));
        let gate1 = z1 / (1.0f32 + exp(0.0f32 - z1));
        let gate2 = z2 / (1.0f32 + exp(0.0f32 - z2));
        let gate3 = z3 / (1.0f32 + exp(0.0f32 - z3));

        // Per-column norm weights.
        let w0 = load(w[first_col]).cast::<f32>();
        let w1 = load(w[first_col + 1u32]).cast::<f32>();
        let w2 = load(w[first_col + 2u32]).cast::<f32>();
        let w3 = load(w[first_col + 3u32]).cast::<f32>();

        // All loads are complete before the first store.
        store(out[elem_base], ((y0 * inv_rms * w0) * gate0).cast::<T>());
        store(out[elem_base + 1u32], ((y1 * inv_rms * w1) * gate1).cast::<T>());
        store(out[elem_base + 2u32], ((y2 * inv_rms * w2) * gate2).cast::<T>());
        store(out[elem_base + 3u32], ((y3 * inv_rms * w3) * gate3).cast::<T>());
    }
}
// Hand-written bench registration for `mt_rms_norm_wide`, which carries no
// `#[bench_kernel]` attribute of its own.
inventory::submit! {
    BenchSpec {
        // Identity.
        op: "rms_norm",
        subop: "rms_norm_wide",
        kernel_name: "mt_rms_norm_wide",

        // Kernel IR source and how it is dispatched.
        kernel_ir: mt_rms_norm_wide::kernel_ir_for,
        kernel_mode: Some(KernelMode::Reduction),
        dispatch: BenchDispatch::Generic,

        // Coverage and tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        shapes: &[],
        tol: 5e-4,

        // No MLX counterpart is attached to this spec.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Hand-written bench registration for `mt_gated_mixer_norm`, which carries no
// `#[bench_kernel]` attribute of its own.
inventory::submit! {
    BenchSpec {
        // Identity.
        op: "rms_norm",
        subop: "gated_mixer_norm",
        kernel_name: "mt_gated_mixer_norm",

        // Kernel IR source and how it is dispatched.
        kernel_ir: mt_gated_mixer_norm::kernel_ir_for,
        kernel_mode: Some(KernelMode::Reduction),
        dispatch: BenchDispatch::Generic,

        // Coverage and tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        shapes: &[],
        tol: 1e-3,

        // No MLX counterpart is attached to this spec.
        mlx_src: None,
        mlx_pattern: None,
    }
}
#[cfg(test)]
mod wide_tests {
    use metaltile_codegen::msl::MslGenerator;
    use metaltile_core::ir::KernelMode;

    use super::mt_rms_norm_wide;
    use crate::bench_types::DType;

    /// Floating-point dtypes the wide kernel is checked against.
    const FLOAT_DTYPES: [DType; 3] = [DType::F32, DType::F16, DType::BF16];

    /// Builds the kernel IR for `dt`, forces reduction mode, and lowers it to
    /// MSL source. Panics if code generation fails.
    fn msl_for(dt: DType) -> String {
        let mut kernel = mt_rms_norm_wide::kernel_ir_for(dt);
        kernel.mode = KernelMode::Reduction;
        MslGenerator::default()
            .generate(&kernel)
            .expect("mt_rms_norm_wide codegen succeeds")
    }

    /// Codegen must yield non-blank MSL that declares the kernel entry point
    /// for every float dtype.
    #[test]
    fn codegen_produces_nonempty_msl_for_all_float_dtypes() {
        for dt in FLOAT_DTYPES {
            let src = msl_for(dt);

            let is_blank = src.trim().is_empty();
            assert!(!is_blank, "MSL for {dt:?} should not be empty");

            let declares_kernel = src.contains("kernel void mt_rms_norm_wide");
            assert!(
                declares_kernel,
                "MSL for {dt:?} should declare mt_rms_norm_wide:\n{src}",
            );
        }
    }
}