metaltile-std 0.1.0

MetalTile kernel standard library — benchmark metadata and type definitions
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! GatedDeltaNet innovation-tape capture + replay — port of
//! `gated_delta_replay.metal` (spec 020 phase 2). Companion to
//! `gated_delta.rs`; the speculative-decode rollback path for
//! GDN-bearing models (Qwen 3.5 / 3.6).
//!
//! Two kernels:
//!   - `gated_delta_step_record` — the standard GatedDelta forward step
//!     that *also* writes each step's `delta_t` to a `delta_log` tape.
//!   - `state_replay` — re-folds the accepted prefix `[0, accepted)` of
//!     an innovation tape onto a pre-record state snapshot:
//!     `state ← select(do_step, state·g_t + k_t·delta_t, state)`,
//!     branchless via `select` (good SIMD occupancy when the timestep
//!     mask is non-uniform within a simdgroup).
//!
//! Tape layout: `delta_log` [B, T, Hv, Dv], `k_log` [B, T, Hv, Dk]
//! (GQA-expanded by the cache), `g_log` [B, T, Hv].
//!
//! ## DISPATCH INVARIANTS
//!
//! - **Grid3D**, `grid = [1, Dv, batch*Hv]`, `tg = [32, 1, 1]`.
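//! - `Dk` a multiple of 32: each of the 32 lanes owns `n_per_t = Dk / 32`
//!   consecutive state elements (the `n_per_t` macro argument).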
//!
//! Codegen-only; correctness pinned by
//! `tests/gated_delta_replay_gpu_correctness.rs`.

use metaltile::kernel;
use metaltile_core::ir::KernelMode;

use crate::{
    bench_types::DType,
    spec::{BenchDispatch, BenchSpec},
};

// Registers one generated kernel as a `BenchSpec` through `inventory::submit!`.
//
// Arguments:
//   - `$kernel`  — identifier of the `#[kernel]` function. It provides both the
//     `kernel_name` string (via `stringify!`) and the `kernel_ir_for` entry.
//   - `$variant` — string literal stored as the spec's `subop`.
//
// Everything else is shared by all entries of this file: op
// "gated_delta_replay", dtypes F32/F16/BF16, tolerance 1e-3, `mlx_src` and
// `mlx_pattern` unset, an empty shape list, generic dispatch, Grid3D mode.
macro_rules! gdr_spec {
    ($kernel:ident, $variant:literal) => {
        inventory::submit! {
            BenchSpec {
                // Identity of the entry.
                op: "gated_delta_replay",
                subop: $variant,
                kernel_name: stringify!($kernel),
                kernel_ir: $kernel::kernel_ir_for,
                // How the kernel is launched.
                kernel_mode: Some(KernelMode::Grid3D),
                dispatch: BenchDispatch::Generic,
                shapes: &[],
                // Element types and comparison tolerance.
                dtypes: &[DType::F32, DType::F16, DType::BF16],
                tol: 1e-3,
                // No MLX-side fields are populated for this op.
                mlx_src: None,
                mlx_pattern: None,
            }
        }
    };
}

// ── Forward GatedDelta step with per-step delta-tape capture ────────────────
// Generates one specialised "record" kernel and registers it via `gdr_spec!`.
//
// Macro arguments (all sizes are `u32` literals baked into the index math):
//   - `$name`     — identifier of the generated kernel function.
//   - `$dk`/`$dv` — key / value head dimensions.
//   - `$hk`/`$hv` — number of key / value heads; `hv_idx / ($hv / $hk)` maps a
//     value head to its key head, so `$hv` is expected to be a non-zero
//     multiple of `$hk`.
//   - `$n_per_t`  — state elements held per lane (`$n_per_t * lane + i` is the
//     position along the `$dk` axis).
//   - `$subop`    — sub-operation label forwarded to `gdr_spec!`.
//
// Flat layouts implied by the index arithmetic below:
//   - `q`, `k`                 : [B, T, Hk, Dk]
//   - `v`, `y`, `delta_log`    : [B, T, Hv, Dv]
//   - `g`, `beta`              : [B, T, Hv]
//   - `state_in`, `state_out`  : [B, Hv, Dv, Dk]
//   - `mask`                   : [B, T]
//
// Per timestep `t` (only when the step is active):
//   1. state ← state · g_t                         (decay)
//   2. kv    ← simd_sum(Σ state · k_t)             (reduced partial dot products)
//   3. delta ← (v_t[dv_idx] − kv) · beta_t         (written to `delta_log`)
//   4. state ← state + k_t · delta
//   5. y_t[dv_idx] ← simd_sum(Σ state · q_t)
// Inactive (masked) steps leave the state untouched and write neither `y` nor
// `delta_log` for that step.
macro_rules! gated_delta_record {
    ($name:ident, $dk:literal, $dv:literal, $hk:literal, $hv:literal, $n_per_t:literal, $subop:literal) => {
        #[kernel]
        pub fn $name<T>(
            q: Tensor<T>,
            k: Tensor<T>,
            v: Tensor<T>,
            g: Tensor<T>,
            beta: Tensor<T>,
            state_in: Tensor<T>,
            mask: Tensor<u32>,
            mut y: Tensor<T>,
            mut state_out: Tensor<T>,
            mut delta_log: Tensor<T>,
            #[constexpr] t_val: u32,
            #[constexpr] has_mask: u32,
        ) {
            // Axis 0: lane; axis 1: value-dimension index; axis 2: fused
            // (batch, value-head) index.
            let lane = program_id::<0>();
            let dv_idx = program_id::<1>();
            let n = program_id::<2>();
            // Un-fuse axis 2 into batch and value-head indices.
            let b_idx = n / $hv;
            let hv_idx = n - b_idx * $hv;
            // Key head shared by this value head (grouped heads).
            let hk_idx = hv_idx / ($hv / $hk);
            // Start of this (batch, head, dv) row of length Dk in the state.
            let i_state_base = (n * $dv + dv_idx) * $dk;

            // Per-lane working copy of the state slice, kept in f32 regardless
            // of the tensor element type `T`.
            stack_alloc("state", $n_per_t, "f32");
            for i in range(0u32, $n_per_t, 1u32) {
                // NOTE(review): this local `v` shadows the tensor parameter `v`
                // inside the loop body only; the later `load(v[...])` is
                // outside this scope. Confirm the DSL scopes it the same way.
                let v = load(state_in[i_state_base + $n_per_t * lane + i]).cast::<f32>();
                stack_store("state", i, v);
            }

            for t in range(0u32, t_val, 1u32) {
                // With `has_mask == 0` every step is active; otherwise the
                // per-(batch, t) mask value decides.
                // NOTE(review): `load(mask[...])` appears as a `select` operand
                // even when `has_mask == 0` — confirm a valid mask buffer is
                // bound in that case, or that the load is not evaluated.
                let m = select(has_mask == 0u32, 1u32, load(mask[b_idx * t_val + t]));
                if m > 0u32 {
                    // Row offsets for this (batch, t) into the per-step inputs.
                    let qk_base = (b_idx * t_val + t) * $hk * $dk + hk_idx * $dk;
                    let v_base = (b_idx * t_val + t) * $hv * $dv + hv_idx * $dv;
                    let gb_idx = (b_idx * t_val + t) * $hv + hv_idx;
                    let g_val = load(g[gb_idx]).cast::<f32>();
                    let beta_val = load(beta[gb_idx]).cast::<f32>();

                    // Decay the state by g_t and accumulate this lane's share
                    // of dot(state, k_t).
                    let mut kv_mem = 0.0f32;
                    for i in range(0u32, $n_per_t, 1u32) {
                        let s_idx = $n_per_t * lane + i;
                        let st = stack_load("state", i) * g_val;
                        stack_store("state", i, st);
                        kv_mem = kv_mem + st * load(k[qk_base + s_idx]).cast::<f32>();
                    }
                    // Combine the per-lane partial sums (presumably across the
                    // simdgroup — see the module-level dispatch invariants).
                    let kv = simd_sum(kv_mem);
                    let delta = (load(v[v_base + dv_idx]).cast::<f32>() - kv) * beta_val;

                    // Tape write: surface delta_t for the replay kernel.
                    // Only lane 0 stores the value.
                    if lane == 0u32 {
                        store(delta_log[v_base + dv_idx], delta.cast::<T>());
                    }

                    // Rank-1 update state += k_t · delta, accumulating this
                    // lane's share of dot(state, q_t) from the updated state.
                    let mut out_acc = 0.0f32;
                    for i in range(0u32, $n_per_t, 1u32) {
                        let s_idx = $n_per_t * lane + i;
                        let st =
                            stack_load("state", i) + load(k[qk_base + s_idx]).cast::<f32>() * delta;
                        stack_store("state", i, st);
                        out_acc = out_acc + st * load(q[qk_base + s_idx]).cast::<f32>();
                    }
                    let out_red = simd_sum(out_acc);
                    // Only lane 0 stores the reduced output.
                    if lane == 0u32 {
                        store(y[v_base + dv_idx], out_red.cast::<T>());
                    }
                }
            }

            // Write the final state back, cast from f32 to `T`.
            for i in range(0u32, $n_per_t, 1u32) {
                let st = stack_load("state", i);
                store(state_out[i_state_base + $n_per_t * lane + i], st.cast::<T>());
            }
        }
        // Register the generated kernel with the benchmark registry.
        gdr_spec!($name, $subop);
    };
}

// ── Tape replay: re-fold the accepted prefix onto a snapshot ────────────────
// Generates one specialised "replay" kernel and registers it via `gdr_spec!`.
//
// Macro arguments (sizes are `u32` literals baked into the index math):
//   - `$name`    — identifier of the generated kernel function.
//   - `$dk`/`$dv`— key / value head dimensions.
//   - `$hv`      — number of value heads (the key tape is indexed per value
//     head here, so no key-head count is needed).
//   - `$n_per_t` — state elements held per lane.
//   - `$subop`   — sub-operation label forwarded to `gdr_spec!`.
//
// Flat layouts implied by the index arithmetic below:
//   - `delta_log`              : [B, T_log, Hv, Dv]
//   - `k_log`                  : [B, T_log, Hv, Dk]
//   - `g_log`                  : [B, T_log, Hv]
//   - `state_in`, `state_out`  : [B, Hv, Dv, Dk]
//   - `mask`                   : [B, T_log]
//
// For each taped step `t` the state becomes `state · g_t + k_t · delta_t` when
// `t < accepted` and the mask passes; otherwise it is kept unchanged.
macro_rules! state_replay {
    ($name:ident, $dk:literal, $dv:literal, $hv:literal, $n_per_t:literal, $subop:literal) => {
        #[kernel]
        pub fn $name<T>(
            delta_log: Tensor<T>,
            k_log: Tensor<T>,
            g_log: Tensor<T>,
            state_in: Tensor<T>,
            mask: Tensor<u32>,
            mut state_out: Tensor<T>,
            #[constexpr] t_log: u32,
            #[constexpr] accepted: u32,
            #[constexpr] has_mask: u32,
        ) {
            // Axis 0: lane; axis 1: value-dimension index; axis 2: fused
            // (batch, value-head) index.
            let lane = program_id::<0>();
            let dv_idx = program_id::<1>();
            let n = program_id::<2>();
            // Un-fuse axis 2 into batch and value-head indices.
            let b_idx = n / $hv;
            let hv_idx = n - b_idx * $hv;
            // Start of this (batch, head, dv) row of length Dk in the state.
            let i_state_base = (n * $dv + dv_idx) * $dk;

            // Per-lane working copy of the snapshot slice, kept in f32.
            stack_alloc("state", $n_per_t, "f32");
            for i in range(0u32, $n_per_t, 1u32) {
                let v = load(state_in[i_state_base + $n_per_t * lane + i]).cast::<f32>();
                stack_store("state", i, v);
            }

            // The loop always walks the full tape length `t_log`; steps at or
            // beyond `accepted` are neutralised through `do_step` rather than
            // by shortening the loop.
            for t in range(0u32, t_log, 1u32) {
                // NOTE(review): `load(mask[...])` appears as a `select` operand
                // even when `has_mask == 0` — confirm a valid mask buffer is
                // bound in that case, or that the load is not evaluated.
                let mask_v = select(has_mask == 0u32, 1u32, load(mask[b_idx * t_log + t]));
                // do_step = (t < accepted) && mask_passes — branchless.
                let do_step = select(t < accepted, mask_v, 0u32);

                // Tape offsets for this (batch, t, value head). The loads below
                // are issued for every `t`, including rejected steps, so all
                // `t_log` tape entries must be readable.
                let delta_row = (b_idx * t_log + t) * $hv * $dv + hv_idx * $dv;
                let k_row = (b_idx * t_log + t) * $hv * $dk + hv_idx * $dk;
                let g_idx = (b_idx * t_log + t) * $hv + hv_idx;
                let g_val = load(g_log[g_idx]).cast::<f32>();
                let d_val = load(delta_log[delta_row + dv_idx]).cast::<f32>();

                for i in range(0u32, $n_per_t, 1u32) {
                    let s_idx = $n_per_t * lane + i;
                    let old = stack_load("state", i);
                    // Candidate update: decay by g_t, then add k_t · delta_t.
                    let new_val = old * g_val + load(k_log[k_row + s_idx]).cast::<f32>() * d_val;
                    // Keep the old value when the step is rejected or masked.
                    stack_store("state", i, select(do_step > 0u32, new_val, old));
                }
            }

            // Write the replayed state back, cast from f32 to `T`.
            for i in range(0u32, $n_per_t, 1u32) {
                let st = stack_load("state", i);
                store(state_out[i_state_base + $n_per_t * lane + i], st.cast::<T>());
            }
        }
        // Register the generated kernel with the benchmark registry.
        gdr_spec!($name, $subop);
    };
}

// Kernel instantiations. Argument order:
//   gated_delta_record!(name, Dk, Dv, Hk, Hv, n_per_t, subop)
//   state_replay!(name, Dk, Dv, Hv, n_per_t, subop)
// `n_per_t` is Dk / 32 (one 32-lane group covers the Dk axis).

// Qwen 3.5/3.6 A3B: Dk=192, Dv=128, Hk=4, Hv=4.
// n_per_t = 192 / 32 = 6.
gated_delta_record!(
    gated_delta_step_record_d192_128_4_4,
    192u32,
    128u32,
    4u32,
    4u32,
    6u32,
    "record_d192_128_4_4"
);
state_replay!(state_replay_d192_128_4_4, 192u32, 128u32, 4u32, 6u32, "replay_d192_128_4_4");
// Small unit-test cell: Dk=64, Dv=32, Hk=2, Hv=2.
// n_per_t = 64 / 32 = 2.
gated_delta_record!(
    gated_delta_step_record_d64_32_2_2,
    64u32,
    32u32,
    2u32,
    2u32,
    2u32,
    "record_d64_32_2_2"
);
state_replay!(state_replay_d64_32_2_2, 64u32, 32u32, 2u32, 2u32, "replay_d64_32_2_2");