// mlx-native 0.6.2
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
// flash_attn_prefill_mask — GPU fill kernel for the bf16 additive attention
// mask consumed by the flash_attn_prefill family of kernels.
//
// Ported from: llama.cpp's llm_graph_input_attn_no_cache::set_input mask-fill
// algorithm at /opt/llama.cpp/src/llama-graph.cpp:380-444.
//
// llama.cpp fills the mask CPU-side then relies on implicit upload.  We fill
// it GPU-side because (a) Apple Silicon has unified memory so there is no
// meaningful "upload", (b) GPU fill parallelises trivially over (qL, kL) and
// stays on-device to match the rest of the mlx-native dispatcher, and
// (c) we avoid the cache-invalidation overhead of a large host→device transfer
// per prefill when we build both the global and sliding masks.  The mask
// values written by this kernel are byte-identical to llama.cpp's post-cast
// bf16 mask (see ADR-011 phase 2 §6.1).
//
// Reference: the canonical in-kernel attended predicate (simplified for the
// batch=1, single-sequence, no-ALiBi, causal_attn=true case, per ADR-011
// phase 2 §1.5) is:
//
//   attended iff (k_pos <= q_abs)              // causal
//            AND (!swa_standard ||
//                 q_abs - k_pos < n_swa)        // SWA window
//
// Otherwise the cell is written as bfloat16_t(-INFINITY), matching
// llama-graph.cpp:421,436 and the flash_attn_prefill.metal mask-sentinel
// contract (masked = bf16 -inf = bit pattern 0xFF80; attended = +0.0 =
// bit pattern 0x0000).
//
// Grid geometry:
//   One threadgroup per query row.  The threads of a threadgroup sweep the
//   K dimension of that row in strides of the threadgroup size.  Each
//   thread therefore writes ceil(kL / tgsize) cells, and any tgsize >= 1 is
//   correct because the stride loop bounds-checks against kL.
//
//   The host is expected to use tgsize = (min(kL, 256), 1, 1).  The
//   threadgroup grid must match how the kernel derives the row index from
//   [[threadgroup_position_in_grid]] (see the kernel's parameter list).
//   This mirrors softmax.metal's one-threadgroup-per-row layout (see
//   ops/softmax.rs:93-106).
//
// Params layout (inline bytes at buffer(1)):
//   struct FlashAttnPrefillMaskParams {
//     uint  seq_len_k;     // K dimension (mask stride between rows)
//     uint  q_abs_offset;  // absolute offset of the first query row (ql_off)
//     int   n_swa;         // sliding window size; <= 0 disables the window (hosts pass -1)
//     uint  causal;        // 1 = apply causal (k_pos > q_abs → masked)
//   };
//
// The signed `n_swa` convention (n_swa <= 0 disables the window, rather than
// a separate boolean) is a host convenience: a global mask is built with
// n_swa=-1 and the SWA test is skipped.
// `causal` is a uint (not bool) so the param struct layout is trivial to
// serialize from Rust via bytemuck::Pod.

#include <metal_stdlib>
using namespace metal;

// bfloat16_t: 16-bit storage type for the mask cells.
//
// When the target provides Metal's native `bfloat`, this is an alias for it.
// Otherwise it falls back to `half`.  That fallback is not expected to be
// used on the supported Apple Silicon targets (the flash_attn_prefill kernel
// family needs bfloat too) and exists only so this file still compiles.
// Note that under the fallback, -inf has the half bit pattern 0xFC00 rather
// than the bf16 pattern 0xFF80.
#ifdef __HAVE_BFLOAT__
using bfloat16_t = bfloat;
#else
using bfloat16_t = half;
#endif

// Parameter block passed inline at buffer(1).  Its layout must match the
// Rust-side bytemuck::Pod struct byte-for-byte: four 32-bit fields, no
// padding, 16 bytes total.
struct FlashAttnPrefillMaskParams {
    uint seq_len_k;     // K dimension; also the row stride of the mask
    uint q_abs_offset;  // absolute position of query row 0 (added to the row index)
    int  n_swa;         // sliding-window size; <= 0 disables the window
    uint causal;        // nonzero: mask cells with k_pos > q_abs
};

// Guard the host/GPU ABI: any layout change here must be mirrored in Rust.
static_assert(sizeof(FlashAttnPrefillMaskParams) == 16,
              "FlashAttnPrefillMaskParams must stay 16 bytes to match the Rust Pod struct");

// Fill the bf16 additive attention mask, one query row per threadgroup.
//
// The mask is row-major with row stride params.seq_len_k.  For each cell
// (q_row, k_pos), the kernel writes +0.0 when the key is attended and -inf
// when it is masked.  With q_abs = q_row + params.q_abs_offset, a cell is
// masked iff
//   (params.causal != 0 && k_pos > q_abs)                   // future key
//   || (params.n_swa > 0 && q_abs - k_pos >= params.n_swa)  // outside SWA window
// This mirrors llama_hparams::is_masked_swa (llama-hparams.h:316-328) plus
// the causal "future" skip in llama.cpp's mask-fill loop.  The kernel is
// stateless, so the global and sliding masks can be built by separate
// dispatches with different (n_swa, causal) settings.
//
// Buffers:
//   mask   [[buffer(0)]] - output; must hold at least
//                          (number of threadgroups) * seq_len_k elements.
//   params [[buffer(1)]] - see FlashAttnPrefillMaskParams.
//
// Grid geometry:
//   The row index is the linearized threadgroup position (x fastest, then
//   y, then z).  Rows may be laid out along x (threadgroups = (qL, 1, 1))
//   or along y (threadgroups = (1, qL, 1)); both yield q_row in [0, qL).
//   The threads of a threadgroup are flattened and stride across K by the
//   total threadgroup size.  Any threadgroup shape or size therefore writes
//   every cell of the row exactly once, including an unaligned trailing
//   remainder of kL and including fewer than 32 threads when kL is small.
//   The output is exactly the bf16 mask; there are no pad cells.
//
// Element indices are computed in 64 bits so that masks with
// qL * kL >= 2^32 cells do not wrap.
kernel void flash_attn_prefill_mask_fill_bf16(
    device bfloat16_t* mask                                  [[buffer(0)]],
    constant FlashAttnPrefillMaskParams& params              [[buffer(1)]],
    uint3 tg_pos                                             [[threadgroup_position_in_grid]],
    uint3 tg_count                                           [[threadgroups_per_grid]],
    uint  tid                                                [[thread_index_in_threadgroup]],
    uint3 tg_dims                                            [[threads_per_threadgroup]]
) {
    // Linearize the threadgroup position: this equals tg_pos.x for a
    // (qL, 1, 1) grid and tg_pos.y for a (1, qL, 1) grid.
    const uint q_row   = (tg_pos.z * tg_count.y + tg_pos.y) * tg_count.x + tg_pos.x;
    // tid is the flattened in-threadgroup index, so stride by the flattened size.
    const uint tg_size = tg_dims.x * tg_dims.y * tg_dims.z;

    const uint  seq_len_k  = params.seq_len_k;
    const int   q_abs      = int(q_row + params.q_abs_offset);
    const int   n_swa      = params.n_swa;           // <= 0 disables SWA
    const bool  causal     = params.causal != 0u;
    const ulong row_offset = ulong(q_row) * ulong(seq_len_k);

    // bf16 -inf has bit pattern 0xFF80 and bf16 +0.0 has 0x0000.  Both are
    // exactly representable, so the float -> bfloat16_t conversion is exact.
    // (Under the `half` fallback typedef, -inf is 0xFC00 instead.)
    const bfloat16_t masked_val   = bfloat16_t(-INFINITY);
    const bfloat16_t attended_val = bfloat16_t(0.0f);

    for (uint k_pos = tid; k_pos < seq_len_k; k_pos += tg_size) {
        const int  kp          = int(k_pos);
        // llama.cpp's "if future ... continue" gate.
        const bool is_future   = causal && kp > q_abs;
        // llama.cpp's "if masked_swa ... continue" gate (SWA_STANDARD).
        const bool outside_swa = n_swa > 0 && (q_abs - kp) >= n_swa;

        mask[row_offset + ulong(k_pos)] =
            (is_future || outside_swa) ? masked_val : attended_val;
    }
}