// mlx-native 0.6.2
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
#include <metal_stdlib>
using namespace metal;

// Wave 5b.1 iter 2 — chunk_scaled_dot_kkt Metal kernel.
//
// Spec source:
// - FLA reference: `chunk_scaled_dot_kkt_fwd_kernel` at
//   /opt/vllm/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py:36-99
//
// No FLA / Triton / CUDA / Metal code is copied. The math here is a
// re-derivation from the FLA spec; the algorithmic structure (load /
// scale / bf16-cast / dot / gate / mask) follows the FLA Triton kernel
// pattern but is open-coded for Metal.
//
// # Algorithm (per (batch b, V-head i_h, chunk i_t))
//
//   kh        = i_h / (H / Hg)                        # GQA-mapped K-head
//   b_beta    = beta[b, t_chunk, i_h]                 # [BT] f32
//   b_g       = g[b, t_chunk, i_h]                    # [BT] f32 (cumsumed)
//   ba_acc    = zeros([BT, BT])                       # f32 in shared mem
//   for i_k in 0..(K // BK):
//       # Cooperative load of bk_stage[BT, BK] in bf16 (post-scale-cast):
//       for cells in (BT*BK) / TG_THREADS:
//           bk_stage[bt, bk] = bf16(b_k[bt, bk].float() * b_beta[bt])  # FLA :86
//       barrier
//       # Per-thread accumulate into ba_acc[BT, BT]:
//       for cells in (BT*BT) / TG_THREADS:
//           (i, j) = unflatten(cell)
//           dot = 0.0f
//           for kk in 0..BK:
//               dot += bk_stage[i, kk].float() * b_k_orig[j, kk].float()
//           ba_acc[i, j] += dot
//       barrier
//   # Apply gate exp(b_g[:, None] - b_g[None, :]):
//   for cells in (BT*BT) / TG_THREADS:
//       (i, j) = unflatten(cell)
//       ba_acc[i, j] *= exp(b_g[i] - b_g[j])
//   # Apply strict-lower mask (row > col only) and store to A:
//   for cells in (BT*BT) / TG_THREADS:
//       (i, j) = unflatten(cell)
//       v = (i > j) ? ba_acc[i, j] : 0.0f
//       A[b, t_chunk[i], i_h, j] = v
//
// # Numerical precision
//
//   Inputs k bf16, beta f32, g f32. Intermediate ba_acc in f32. The bf16
//   cast on b_kb (post-scale, pre-dot) follows FLA chunk_scaled_dot_kkt.py:86;
//   the dot accumulator stays in f32 (each bf16 operand is widened to f32
//   before the multiply-add).
//
// # bf16 round-trip placement
//
//   FLA line 85: b_kb = b_k * b_beta[:, None]    # bf16 * f32 → f32
//   FLA line 86: b_A += dot(b_kb.to(b_k.dtype), trans(b_k))
//                              ^^^^^^^^^^^^^^^^^
//                              cast to bf16 BEFORE dot, AFTER scale
//
//   We mirror that ordering exactly:
//     1. local_kb_f32 = b_k.float() * b_beta[bt]      (FLA :85, scale in f32)
//     2. bk_stage[bt, bk] = bfloat(local_kb_f32)      (FLA :86, post-scale cast)
//     3. ba_acc[i, j] += bk_stage[i, kk].float() * b_k[j, kk].float()  (f32 dot)
//
// # Threadgroup memory layout (24 KB total at BT=BK=64)
//
//   bk_stage  [[threadgroup(0)]]  : BT * BK * 2 bytes = 8  KB  (bf16)
//   ba_acc    [[threadgroup(1)]]  : BT * BT * 4 bytes = 16 KB  (f32)
//
//   Total: 24 KB < 32 KB M5 Max cap.
//
// # Memory layouts (innermost-first)
//
//   k:    [B, T, Hg, K]   bf16  — K innermost
//   beta: [B, T, H]       f32   — H innermost
//   g:    [B, T, H]       f32   — H innermost
//   A:    [B, T, H, BT]   f32   — BT innermost (row-major within each chunk's
//                                  [BT, BT] block, which is stored at rows
//                                  [bos+i_t*BT : bos+(i_t+1)*BT] of A)
//
// # Threading
//
//   Grid: (NT, H, B)
//   Threadgroup: TG_THREADS = 256 (8 simdgroups × 32 lanes), flat 1D.
//
// # Buffer bindings
//
//   buffer(0): k        bf16
//   buffer(1): beta     f32
//   buffer(2): g        f32
//   buffer(3): A        f32  (output)
//   buffer(4): params   uint[8] = [B, T, Hg, H, K, BT, NT, BK]

// Threadgroup width the host is expected to dispatch: 8 simdgroups x 32 lanes.
constant uint TG_THREADS = 8u * 32u;

// Iter-2 fixed tile sizes. gated_delta_net_kkt_bf16 reads the live BT / BK
// values from params[5] / params[7] and does not reference these constants.
constant uint BT_FIXED = 64u;   // chunk length along T
constant uint BK_FIXED = 64u;   // K-tile width

// Computes one [BT, BT] block of the gated, beta-scaled K*K^T matrix A for a
// single (chunk i_t, V-head i_h, batch i_b) triple. See the file header for
// the algorithm, memory layouts, and buffer bindings.
//
// Dispatch expectations (not checked by the kernel):
//   - grid of (NT, H, B) threadgroups; threadgroups are 1D (only .x is used);
//   - threadgroup(0) holds >= BT*BK bfloat, threadgroup(1) >= BT*BT float;
//   - Hg > 0 and BK > 0 (presumably H % Hg == 0 for the GQA mapping -- confirm
//     against the caller).
// Any 1D threadgroup width works: every cooperative loop strides by the real
// threads_per_threadgroup value instead of assuming TG_THREADS.
//
// Tail handling (intended to match FLA's boundary-checked loads/stores --
// confirm against the reference):
//   - the last chunk is partial when T % BT != 0: rows t >= T are treated as
//     zero, never read from k/beta/g, and never written to A;
//   - a partial last K-tile (K % BK != 0) is zero-padded instead of dropped.
//
// NOTE(review): all offsets are 32-bit uint, so B*T*H*BT and B*T*Hg*K must
// stay below 2^32 -- confirm the caller guarantees this.
kernel void gated_delta_net_kkt_bf16(
    device const bfloat *k          [[buffer(0)]],
    device const float  *beta       [[buffer(1)]],
    device const float  *g          [[buffer(2)]],
    device float        *A          [[buffer(3)]],
    device const uint   *params     [[buffer(4)]],
    threadgroup bfloat  *bk_stage   [[threadgroup(0)]],   // [BT, BK] bf16
    threadgroup float   *ba_acc     [[threadgroup(1)]],   // [BT, BT] f32
    uint3 tid3    [[thread_position_in_threadgroup]],
    uint3 tgid    [[threadgroup_position_in_grid]],
    uint3 tg_dim3 [[threads_per_threadgroup]]
) {
    const uint B   = params[0];
    const uint T   = params[1];
    const uint Hg  = params[2];
    const uint H   = params[3];
    const uint K   = params[4];
    const uint BT  = params[5];   // = 64 in iter-2
    const uint NT  = params[6];
    const uint BK  = params[7];   // = 64 in iter-2

    const uint i_t = tgid.x;
    const uint i_h = tgid.y;
    const uint i_b = tgid.z;
    const uint tid = tid3.x;
    // Stride for every cooperative loop: the actual threadgroup width, so the
    // cells are partitioned exactly once even if the host does not dispatch
    // exactly TG_THREADS threads.
    const uint n_threads = tg_dim3.x;

    // Both exits depend only on tgid and params, so they are uniform across
    // the threadgroup: no thread is left waiting at a barrier below.
    if (i_b >= B || i_h >= H || i_t >= NT) return;
    const uint t_start = i_t * BT;
    if (t_start >= T) return;

    // Rows of this chunk that lie inside the sequence (< BT only for a
    // partial last chunk when T % BT != 0).
    const uint n_rows = min(BT, T - t_start);

    const uint group_ratio = H / Hg;
    const uint kh          = i_h / group_ratio;   // GQA-mapped K-head

    // Strides (in elements).
    const uint k_t_stride   = Hg * K;
    const uint k_seq_stride = T * k_t_stride;
    const uint g_t_stride   = H;
    const uint g_seq_stride = T * g_t_stride;
    // A: [B, T, H, BT] -- stride for (b, t, h, bt) is (T*H*BT, H*BT, BT, 1).
    const uint a_t_stride   = H * BT;
    const uint a_seq_stride = T * a_t_stride;

    // beta and g share the [B, T, H] layout, so they share a base offset.
    // beta is re-read per cell rather than staged in threadgroup memory;
    // at BT=64 the redundant reads are expected to hit cache.
    const uint beta_base    = i_b * g_seq_stride + i_h;
    const uint g_base       = beta_base;
    // Element offset of k[i_b, t_start, kh, 0] (loop-invariant).
    const uint k_chunk_base = i_b * k_seq_stride + t_start * k_t_stride + kh * K;

    const uint ba_cells = BT * BT;   // = 4096 at BT=64
    const uint bk_cells = BT * BK;   // = 4096 at BT=BK=64

    // ===================================================================
    // 0. Initialize ba_acc to zero.
    // ===================================================================
    for (uint cell = tid; cell < ba_cells; cell += n_threads) {
        ba_acc[cell] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ===================================================================
    // 1. K-tile loop. Ceil-div so a partial last tile (K % BK != 0) is
    //    processed (zero-padded) rather than silently dropped.
    // ===================================================================
    const uint nbk = (K + BK - 1u) / BK;   // = 2 at K=128, BK=64

    for (uint i_k = 0; i_k < nbk; ++i_k) {
        const uint k_off   = i_k * BK;
        const uint k_width = min(BK, K - k_off);   // valid columns in this tile

        // -----------------------------------------------------------
        // 1a. Cooperative load of bk_stage[bt, bk] = bf16(k[bt, bk] * beta[bt]).
        //     Scale in f32 (FLA :85), then cast to bf16 (FLA :86).
        //     Cells outside the sequence or past K are zero-filled and
        //     never touch global memory.
        // -----------------------------------------------------------
        for (uint cell = tid; cell < bk_cells; cell += n_threads) {
            const uint bt_idx = cell / BK;
            const uint bk_idx = cell - bt_idx * BK;
            float kb = 0.0f;
            if (bt_idx < n_rows && bk_idx < k_width) {
                const float k_val  = float(k[k_chunk_base + bt_idx * k_t_stride + k_off + bk_idx]);
                const float beta_v = beta[beta_base + (t_start + bt_idx) * g_t_stride];
                kb = k_val * beta_v;
            }
            bk_stage[cell] = bfloat(kb);   // cell == bt_idx * BK + bk_idx
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // -----------------------------------------------------------
        // 1b. Accumulate ba_acc[i, j] += sum_kk(bk_stage[i, kk] * k[j, kk])
        //     in f32. Only strict-lower cells (i > j) of in-sequence rows
        //     survive the final mask, so the rest are skipped; j < i < n_rows
        //     also guarantees row j is inside the sequence.
        // -----------------------------------------------------------
        for (uint cell = tid; cell < ba_cells; cell += n_threads) {
            const uint i = cell / BT;
            const uint j = cell - i * BT;
            if (i <= j || i >= n_rows) continue;

            threadgroup const bfloat *kb_row = bk_stage + i * BK;                         // scaled, bf16
            device const bfloat      *k_row  = k + (k_chunk_base + j * k_t_stride + k_off); // original, bf16
            float dot_val = 0.0f;
            for (uint kk = 0; kk < k_width; ++kk) {
                dot_val += float(kb_row[kk]) * float(k_row[kk]);
            }
            ba_acc[cell] += dot_val;
        }
        // Needed before the next 1a overwrites bk_stage.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ===================================================================
    // 2. Gate, mask, store:
    //    A[b, t_start+i, h, j] = (i > j) ? ba_acc[i, j] * exp(g[i] - g[j]) : 0
    //    Rows past T are not part of A and are not written. exp() is only
    //    evaluated for kept cells, so g is never read out of range.
    // ===================================================================
    for (uint cell = tid; cell < ba_cells; cell += n_threads) {
        const uint i = cell / BT;
        const uint j = cell - i * BT;
        if (i >= n_rows) continue;

        // FLA :94-95: strict-lower mask (i > j ONLY).
        float v = 0.0f;
        if (i > j) {
            const float g_i = g[g_base + (t_start + i) * g_t_stride];
            const float g_j = g[g_base + (t_start + j) * g_t_stride];
            v = ba_acc[cell] * metal::exp(g_i - g_j);
        }

        const uint a_off = i_b * a_seq_stride
                         + (t_start + i) * a_t_stride
                         + i_h * BT
                         + j;
        A[a_off] = v;
    }
}