// mlx-native 0.6.2
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
// flash_attn_prefill_blk — pre-pass tile-skip classifier for the
// flash_attn_prefill family of kernels.
//
// Ported from llama.cpp's `kernel_flash_attn_ext_blk`
// (/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal:5666-5719).
//
// ## What it does
//
// Walks the additive attention mask in tile-sized chunks matching the main
// kernel's (BQ, BK) geometry and emits a single byte per (qtile, ktile) pair:
//
//   0 — "skip"          : the entire tile is masked to -inf.  Main kernel
//                         does `continue` for this KV tile (no K-load, no
//                         Q·K^T, no V-load, no mask-add, no softmax update).
//   1 — "mixed"         : the tile is neither fully masked nor uniformly
//                         0.0 (e.g. a mix of -inf and 0.0 cells, or any
//                         finite non-zero value).  Main kernel does the
//                         normal mask-load + mask-add path.
//   2 — "all_attended"  : every mask cell in the tile is exactly 0.0.  Main
//                         kernel computes Q·K^T and softmax normally but
//                         skips the mask-add (saving one mask load per tile).
//
// ## Why not re-use llama.cpp's (8, 64) tile shape
//
// The blk byte is indexed by the main kernel as `blk[qt][kt]` where
// `(qt, kt)` is the main kernel's outer KV-tile loop position.  Our
// flash_attn_prefill main kernels use a DIFFERENT geometry than llama.cpp:
//
//   D=256  : BQ=32, BK=16   (candle-derived per-warp-Q-stacking template)
//   D=512  : BQ= 8, BK= 8   (llama.cpp-derived per-simdgroup-Q-distributed)
//
// The pre-pass MUST use the same (BQ, BK) as the main kernel it feeds,
// otherwise the blk index arithmetic is wrong.  See ADR-011 phase 2 §5.1
// for the full analysis.
//
// ## Sentinel convention (differs from llama.cpp)
//
// llama.cpp uses f16 masks with `-MAXHALF` as the "fully masked" sentinel
// (`ggml-metal.metal:5704`).  We use bf16 masks with a true `-INFINITY`
// encoding (`ADR-011-phase2-port-sentinel.md §2`, Wave 2A + Wave 2D).  A
// tile is classified as skip when its f32-widened max satisfies
// `mmax <= -1.0e30f` — a deliberately wide threshold that matches both true
// `-inf` and any "very negative" finite sentinel a future caller might pass.
// The `mmin == 0 && mmax == 0` check is exact because the mask builder writes
// bit-exact `bfloat(0.0)` (bit pattern 0x0000) for every attended cell.
//
// ## Grid geometry
//
//   Threadgroups: (NK, NQ, 1)   — one threadgroup per (Q-tile, K-tile) pair.
//   Threads/TG  : (32, 1, 1)    — one simdgroup (matches llama.cpp:5666).
//
// Each lane in the simdgroup reads elements from the tile, performs a
// simdgroup-wide min/max reduction, and lane 0 writes the classification
// byte.
//
// For our D=256 geometry (BQ=32, BK=16) the tile is 32 rows × 16 cols = 512
// elements.  Each lane in [0, BK) owns one column of the tile and walks it
// down all BQ rows in the inner loop (16 lanes × 32 rows); lanes [BK, NW=32)
// read nothing and feed identity values into the reduction, so half the
// simdgroup idles.  This is Option (b) in ADR §5.1 — acceptable for the
// initial port because the pre-pass is already cheap vs the main kernel.
// Revisit only if profiling shows the pre-pass on the critical path.
//
// For our D=512 geometry (BQ=8, BK=8) the tile is 8 × 8 = 64 elements;
// lanes [0, 8) participate, lanes [8, 32) carry identity.  Same structure.
//
// ## Function constants
//
// | Index | Name   | Purpose |
// |-------|--------|---------|
// | 400   | BQ_blk | Q-rows per tile (matches main kernel's BQ).  32 or 8. |
// | 401   | BK_blk | K-cols per tile (matches main kernel's BK).  16 or 8. |
//
// A single pipeline is compiled per (BQ, BK) combination; the dispatcher
// decides which values to set based on which main kernel is being fed.
//
// SPDX-License-Identifier: MIT

#include <metal_stdlib>
using namespace metal;

// Mask element type.  The Wave-2D mask builder emits bf16, so bf16 is the
// only mask dtype this pre-pass understands.
#if !defined(__HAVE_BFLOAT__)
// Placeholder only: lets this file compile on toolchains that lack a native
// `bfloat`.  The flash_attn_prefill kernels this pre-pass feeds are not
// expected to build on such targets either, so this branch should never be
// exercised.  NOTE: `half` uses a different bit layout from bf16, so if this
// branch ever did run, bf16 mask values would be misread.
using bfloat16_t = half;
#else
using bfloat16_t = bfloat;
#endif

// Function constants — the dispatcher specialises these at pipeline
// creation time.  Kernels compiled with different (BQ, BK) share the same
// entry point name; the cache key in KernelRegistry disambiguates.
constant int BQ_blk [[function_constant(400)]];
constant int BK_blk [[function_constant(401)]];

// Sentinel guards for function-constant presence.  The dispatcher always
// sets these, but Metal's function-constant machinery requires us to
// handle the "undefined" case defensively to avoid spurious compile
// warnings during preview compiles.
constant int BQ_def = is_function_constant_defined(BQ_blk) ? BQ_blk : 32;
constant int BK_def = is_function_constant_defined(BK_blk) ? BK_blk : 16;

// Shader-side parameter block.  Mirrors the Rust `BlkParamsGpu` struct in
// flash_attn_prefill_blk.rs byte-for-byte: four 32-bit ints, 16 bytes total,
// with no implicit padding.  The static_assert below turns an accidental
// layout change on this side into a compile error, instead of silently
// shifting every field the kernel reads from the constant buffer.
struct FlashAttnPrefillBlkParams {
    int seq_len_q;        // qL — number of Q rows in the mask
    int seq_len_k;        // kL — number of K cols in the mask
    int mask_row_stride;  // distance between consecutive Q rows of the mask, in ELEMENTS (bf16 units)
    int _pad;             // explicit padding to keep the bytemuck::Pod layout obvious
};
static_assert(sizeof(FlashAttnPrefillBlkParams) == 16,
              "FlashAttnPrefillBlkParams must stay 16 bytes to match the Rust BlkParamsGpu struct");

// Single-simdgroup tile classifier.
//
// One threadgroup per (qtile, ktile) pair.  Within the threadgroup, lane 0 of
// the simdgroup writes the output byte after a simdgroup-wide min/max
// reduction.
//
// The tile is read directly from device memory — we do NOT stage the mask
// through threadgroup memory, matching llama.cpp's design choice
// (`ggml-metal.metal:5685-5699`).  This is what makes the pre-pass
// asymptotically cheaper than inline classification in the main kernel:
// no K/V loads, no mask staging to shared memory, no cross-simdgroup
// synchronisation.
//
// Output: `blk_out` is [NQ, NK] row-major with NQ = ceil(qL / BQ) and
// NK = ceil(kL / BK).  A threadgroup whose (qt, kt) falls outside that grid
// writes nothing (see the out-of-grid guard below).
kernel void flash_attn_prefill_blk_bf16(
    device const bfloat16_t* mask                   [[buffer(0)]],
    device       char*       blk_out                [[buffer(1)]],
    constant FlashAttnPrefillBlkParams& params      [[buffer(2)]],
    uint3  tgpig                                    [[threadgroup_position_in_grid]],
    ushort tiisg                                    [[thread_index_in_simdgroup]]
) {
    const int BQ = BQ_def;          // Q-rows per tile
    const int BK = BK_def;          // K-cols per tile
    const int NW = 32;              // simd width (Apple GPU)

    const int qt = int(tgpig.y);    // Q-tile index
    const int kt = int(tgpig.x);    // K-tile index

    const int qL = params.seq_len_q;
    const int kL = params.seq_len_k;
    const int M_stride = params.mask_row_stride;  // elements between mask rows

    const int tile_q_start = qt * BQ;
    const int tile_k_start = kt * BK;
    const int tile_k_end   = tile_k_start + BK;

    // Out-of-grid guard.  `tile_k_start >= kL` is equivalent to kt >= NK, and
    // `tile_q_start >= qL` is equivalent to qt >= NQ.  Such a tile has no
    // slot in the [NQ, NK] output:
    //   - for kt >= NK, blk_out[qt * NK + kt] is the NEXT Q-row's slot
    //     (qt + 1, kt - NK).  Writing there races with, and can clobber, that
    //     tile's real classification (e.g. turning a live tile into "skip").
    //   - for qt >= NQ (or the last Q-row with kt >= NK), the write would land
    //     past the end of the buffer.
    // So the only safe action is to write nothing.
    //
    // The condition depends only on tgpig and params, so it is uniform across
    // the simdgroup.  Returning early therefore cannot leave some lanes out of
    // the simd_min / simd_max reduction below.
    if (tile_k_start >= kL || tile_q_start >= qL) {
        return;
    }

    char res;
    if (tile_k_end > kL) {
        // Partial right-edge tile (straddles kL).  Mirrors llama.cpp
        // ggml-metal.metal:5683.  It cannot be cleanly classified as skip
        // (edge columns hold valid mask values) or as all-zero (trailing
        // columns past kL are undefined).  Fall back to mixed, and let the
        // main kernel's kL_rem branch do the per-element bound checks.
        res = 1;
    } else {
        // Rows that actually exist in this Q-tile.  The last Q-tile may
        // straddle qL, and rows at or past qL are not part of the mask:
        // reading them would pull in garbage and could flip the
        // classification.  The guard above ensures q_rows >= 1.
        const int q_rows = min(BQ, qL - tile_q_start);

        // Tile origin in the mask.  The row offset is widened to 64-bit
        // because (qt * BQ) * M_stride overflows i32 for long sequences
        // (qt >= 1024 with BQ = 32 and M_stride = 65536 hits 2^31).
        device const bfloat16_t* tile_src =
            mask + (int64_t)tile_q_start * (int64_t)M_stride + tile_k_start;

        // Reduce in f32.  Widening bf16 -> f32 is exact, so -inf and 0.0 keep
        // their identity.  +inf / -inf are the identities of min / max.
        float mmin =  INFINITY;
        float mmax = -INFINITY;

        // Lane c reads columns c, c + NW, c + 2*NW, ... below BK, each down
        // all q_rows rows.  For the shipped geometries (BK = 16 or 8, both
        // <= NW), this is one column per lane in [0, BK), while lanes
        // [BK, NW) read nothing and carry the identity values.  Striding by NW
        // keeps every column covered if a future (BQ, BK) uses BK > NW.
        for (int c = int(tiisg); c < BK; c += NW) {
            device const bfloat16_t* col_src = tile_src + c;
            for (int j = 0; j < q_rows; ++j) {
                const float v = float(col_src[(int64_t)j * M_stride]);
                mmin = min(mmin, v);
                mmax = max(mmax, v);
            }
        }

        // Fold the per-lane extrema across the whole simdgroup.  A tile whose
        // cells are all -inf reduces to mmax == -inf.
        mmin = simd_min(mmin);
        mmax = simd_max(mmax);

        // Three-way classification.  See ggml-metal.metal:5704-5710 for the
        // equivalent llama.cpp f16 logic.
        //
        // Fully masked: the mask builder writes -INFINITY for blocked cells.
        // The wide `<= -1.0e30f` threshold also catches finite "very negative"
        // sentinels (e.g. -FLT_MAX/2) without exact-representation checks.
        //
        // All attended: the mask builder writes bit-exact bfloat(0.0) for
        // attended cells, so exact equality is safe.  IEEE comparison also
        // treats -0.0 == 0.0; adding a -0.0 mask value is a no-op, so
        // classifying such a tile as all-attended is still correct.
        if (mmax <= -1.0e30f) {
            res = 0;  // fully masked
        } else if (mmin == 0.0f && mmax == 0.0f) {
            res = 2;  // all attended (mask is a no-op for this tile)
        } else {
            res = 1;  // mixed
        }
    }

    // Write the classification byte into the [NQ, NK] row-major table.  Only
    // lane 0 writes, so 32 lanes do not race on one byte.  The flat index is
    // widened to 64-bit because NQ * NK can exceed i32 for very long
    // sequences with small tiles.
    const int NK = (kL + BK - 1) / BK;
    if (tiisg == 0) {
        blk_out[(int64_t)qt * NK + kt] = res;
    }
}