// flash_attn_prefill — mlx-native flash-attention prefill kernel.
//
// What this file is:
// The mlx-native Metal kernel for batched scaled-dot-product attention at
// prefill time, using online softmax + simdgroup matrix-multiply-accumulate
// (MMA) tiling. It is the prefill counterpart to flash_attn_vec.metal,
// which handles the seq_len=1 decode case where one-thread-per-q-position
// is bandwidth-optimal.
//
// Algorithm:
// The "flash-attention" formulation as described by Dao et al. (2022,
// arXiv:2205.14135), implemented for Apple GPU's 8x8x8 simdgroup MMA
// primitives. Q is pre-scaled by `scale * log2(e)` so the inner softmax
// is computed via `fast::exp2`; per-row max and sum are maintained as
// running registers across the K-tile sweep; output is normalized at the
// end with a single divide.
//
// Tile geometry (this kernel):
// D=256: BQ=32, BK=16, 4 simdgroups per threadgroup (128 threads).
// ~29 KB threadgroup memory at bf16/f16 I/O — fits Apple Silicon's
// 32 KB MTLDevice.maxThreadgroupMemoryLength budget with headroom.
// D=512: BQ=8, BK=8, 1 simdgroup per threadgroup (32 threads).
// ~25 KB threadgroup memory at bf16/f16. Smaller tiles because
// the Qs footprint scales with BQ * BD; preserving the larger
// BQ at D=512 would overflow the 32 KB budget.
// No f32 instantiation at either D: the Qs tile alone is 32 KB at f32
// D=256 (BQ=32 * BD=256 * 4 bytes), saturating the budget before KV_smem
// or scratch. f32 numerics are verified at the CPU reference layer in
// tests/test_flash_attn_prefill.rs. See ADR-011-phase1-port-source-decision.md
// §3 for the full threadgroup-memory analysis.
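//
// Worked threadgroup-memory arithmetic behind those figures (a sketch, using
// the Q_smem / KV_smem layouts declared in the kernel body below; 2-byte T,
// so padQ = padK = padV = 16/sizeof(T) = 8):
// D=256: Q_smem = 32*(256+8) = 8448 elems = 16,896 B
// KV_smem = max((16+8)*256, 16*(256+8)) = 6144 elems = 12,288 B
// total = 29,184 B (~29 KB) < 32,768 B budget
// D=512: Q_smem = 8*(512+8) = 4160 elems = 8,320 B
// KV_smem = max((8+8)*512, 8*(512+8)) = 8192 elems = 16,384 B
// total = 24,704 B (~25 KB)
// D=64 (same BQ/BK/simdgroup geometry as D=256, instantiated at the
// bottom of this file): Q_smem = 32*(64+8) = 2304 elems = 4,608 B
// KV_smem = max((16+8)*64, 16*(64+8)) = 1536 elems = 3,072 B
// total = 7,680 B (~7.5 KB)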
//
// Numerical guard (output normalisation — ONE guard, matches llama.cpp):
// This kernel follows llama.cpp's non-vec flash-attention design: the
// per-row running max `M` is initialised to a FINITE sentinel `-FLT_MAX/2`
// (~-1.7e38) rather than true `-infinity`. Masked positions in the input
// mask buffer still arrive as `-inf` (consistent with llama.cpp's CPU
// convention at `llama-graph.cpp:421,436,1572`), so scores `s2` CAN become
// `-inf` mid-flight. But because `M` is kept finite by the `simd_max`
// floor of `-FLT_MAX/2`, every `exp(score - M)` evaluates as
// `exp(-inf - finite) = exp(-inf) = +0.0` (IEEE-754 exact) rather than
// `exp(-inf - -inf) = exp(NaN) = NaN`. No intermediate guard needed.
//
// The ONE surviving guard is the final output normalisation: for a row
// where every K position was masked, `sum_score` stays at bit-exact 0
// across the K-sweep, and the final `output / sum_score = 0/0 = NaN`
// without a guard. `DivOp` returns 0 in that case, mirroring llama.cpp's
// `const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];` at
// `ggml-metal.metal:6358`.
//
// Fully-masked-row output is exact 0.0 in every component under this
// regime — verified end-to-end by the
// `test_gpu_bf16_d256_fully_masked_nan_guard` integration test, and
// line-by-line traced in ADR-011-phase2-port-sentinel.md §3.3.
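//
// Compact trace of that fully-masked row (additive mask, every score -inf):
// M = max(-FLT_MAX/2, -inf, ...) = -FLT_MAX/2 (stays finite)
// P = exp2(-inf - (-FLT_MAX/2)) = exp2(-inf) = +0.0 per position
// sum = 0 + 0 + ... + 0 = 0 (bit-exact zero)
// out = DivOp(acc, 0) = 0 (the ONE guard fires)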
//
// References (algorithmic, not source dependencies):
// - Dao et al., "FlashAttention: Fast and Memory-Efficient Exact
// Attention with IO-Awareness" (2022).
// - MLX backend/metal/kernels/steel/attn — Apple Inc.'s reference Metal
// implementation; provided the simdgroup MMA tile structure we use.
// - llama.cpp ggml/src/ggml-metal — Apple-Silicon flash-attention. We
// port llama.cpp's numerical convention directly: M-init = -FLT_MAX/2
// (non-vec: `:5891`, vec: `:6725`), unguarded exp in the K-sweep
// (`:6155-6156`), and a single `S == 0 ? 0 : 1/S` guard at the output
// normalisation (`:6358`). See ADR-011-phase2-port-sentinel.md.
// - candle-metal-kernels/src/metal_src/scaled_dot_product_attention.metal
// — Hugging Face's Apache-2.0/MIT port of MLX with NaN guards; the MMA
// tile structure we retain is from candle's port, but the numerical
// design (finite M-init, single output-side guard) is llama.cpp's.
//
// SPDX-License-Identifier: MIT
#include <metal_stdlib>
#include <metal_simdgroup>
using namespace metal;
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
#if defined(__HAVE_BFLOAT__)
typedef bfloat bfloat16_t;
typedef half float16_t;
#else
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
// Check for NaN

if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
_fp_encoding_traits<float>::inf_mask) {
return uint16_t(as_type<uint32_t>(0x7FC0));
}
// Take bits
uint32_t float_bits = as_type<uint32_t>(x);
// Round to nearest even
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
// Take upper 16 bits
return float_bits >> 16;
}
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
// Upper 16 bits are the data and lower 16 bits are 0s
return as_type<float>((uint32_t)x << 16);
}
struct _MLX_BFloat16;
template <typename T>
static constexpr constant bool can_convert_to_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
template <typename T>
static constexpr constant bool can_convert_from_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
/////////////////////////////////////////////////////////////////////////////
// Bfloat struct
/////////////////////////////////////////////////////////////////////////////
struct _MLX_BFloat16 {
/////////////////////////////////////////////////////////////////////////////
// Constructors
uint16_t bits_;
_MLX_BFloat16() thread = default;
_MLX_BFloat16() threadgroup = default;
_MLX_BFloat16() device = default;
_MLX_BFloat16() constant = default;
struct bits_to_bfloat_struct {};
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
return bits_to_bfloat_struct();
}
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
: bits_(bits) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions to bfloat
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) device
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions from bfloat
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const thread {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const threadgroup {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const device {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const constant {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
};
/////////////////////////////////////////////////////////////////////////////
// Bfloat operators
/////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////
// Unary ops
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
return -static_cast<float>(x);
}
/////////////////////////////////////////////////////////////////////////////
// Binary operators
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
} \
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
/////////////////////////////////////////////////////////////////////////////
// Arithmetic Operators
#define bfloat_binop(_op_, _operator_) \
bfloat_binop_base( \
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(_op_, _operator_, float, float, float); \
bfloat_binop_helper(_op_, _operator_, float, half, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);
/////////////////////////////////////////////////////////////////////////////
// Comparison ops
#define bfloat_compop(__op__, __operator__) \
bfloat_binop_base( \
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);
#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop
/////////////////////////////////////////////////////////////////////////////
// Inplace Operators
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, itype rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
} \
constexpr METAL_FUNC addr_space itype& __operator__( \
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
#define bfloat_inplace_op(itype) \
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
bfloat_inplace_op(float);
bfloat_inplace_op(half);
bfloat_inplace_op(int16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(uint64_t);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
bfloat_inplace_op_helper(__op__, __operator__, device); \
bfloat_inplace_op_helper(__op__, __operator__, thread); \
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
/////////////////////////////////////////////////////////////////////////////
// Bfloat typedef
/////////////////////////////////////////////////////////////////////////////
typedef struct _MLX_BFloat16 bfloat16_t;
#endif
// ============ "mlx/backend/metal/kernels/utils.h"
template <typename U>
struct Limits {
static const constant U max = metal::numeric_limits<U>::max();
static const constant U min = metal::numeric_limits<U>::min();
static const constant U finite_max = metal::numeric_limits<U>::max();
static const constant U finite_min = metal::numeric_limits<U>::min();
};
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = metal::numeric_limits<type>::max(); \
static constexpr constant type min = metal::numeric_limits<type>::min(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
metal::numeric_limits<type>::min(); \
};
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = \
metal::numeric_limits<type>::infinity(); \
static constexpr constant type min = \
-metal::numeric_limits<type>::infinity(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
-metal::numeric_limits<type>::max(); \
};
instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
// ============ "mlx/backend/metal/kernels/steel/attn/loader.h"
template <int R, int C>
struct CShape {
STEEL_CONST int kRows = R;
STEEL_CONST int kCols = C;
};
template <
typename T,
short BROWS,
short BCOLS,
short kDstStrRow,
short kDstStrCol,
short reduction_dim,
short tgp_size,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
METAL_FUNC BlockLoaderT(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] =
op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
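// Worked example of the derived template parameters (a sketch, for the D=256
// Q loader instantiated in the kernel below: BROWS = BQ = 32, BCOLS = BD = 256,
// tgp_size = 128):
// n_reads = (256 * 32) / 128 = 64 -> each thread copies one contiguous
// TCOLS = 256 / 64 = 4 run of 64 elements per pass
// TROWS = 128 / 4 = 32
// n_rows = ceil(32 / 32) = 1 -> the whole 32 x 256 tile is covered
// in a single pass over the loop nest.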
// ============ "mlx/backend/metal/kernels/steel/utils/type_traits.h"
template <typename... Ts>
struct make_void {
typedef void type;
};
template <typename... Ts>
using void_t = typename make_void<Ts...>::type;
template <typename T>
struct pointer_element {};
template <typename T>
struct pointer_element<thread T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<device T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<constant T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<threadgroup T*> {
using type = remove_cv_t<T>;
};
template <typename T>
using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;
// ============ "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////
template <int val>
using Int = integral_constant<int, val>;
///////////////////////////////////////////////////////////////////////////////
// Binary Operators on Integral constants
///////////////////////////////////////////////////////////////////////////////
#define integral_const_binop(__op__, __operator__) \
template <typename T, T tv, typename U, U uv> \
METAL_FUNC constexpr auto __operator__( \
integral_constant<T, tv>, integral_constant<U, uv>) { \
constexpr auto res = tv __op__ uv; \
return integral_constant<decltype(res), res>{}; \
}
integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);
integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);
integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);
#undef integral_const_binop
// ============ "mlx/backend/metal/kernels/steel/attn/mma.h"
template <typename RInt, typename CInt>
struct Shape2D {
RInt r;
CInt c;
Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
};
template <typename Shape, typename Layout>
struct Layout2D {
Shape shape;
Layout layout;
};
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
static_assert(
kFragRows_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
static_assert(
kFragCols_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
};
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
STEEL_CONST int kFragRows = 8;
STEEL_CONST int kFragCols = 8;
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
STEEL_CONST int kElemRows = 1;
STEEL_CONST int kElemCols = 2;
static_assert(
kElemRows * kElemCols == kElemsPerFrag,
"MMAFrag shape is not consistent with MMAFrag size");
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
typedef metal::vec<T, kElemsPerFrag> frag_type;
typedef metal::vec<T, kElemRows> row_frag_type;
typedef metal::vec<T, kElemCols> col_frag_type;
template <typename U>
using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
template <typename U>
using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
[[thread_index_in_simdgroup]]) {
const short qid = simd_lane_id / 4;
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
return short2{fn, fm};
}
template <typename SrcPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x.value + j * str_y.value]);
}
}
}
template <
typename SrcPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void load_safe(
thread frag_type& dst,
SrcPtrType src,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y.value]);
} else {
dst[i * kElemCols + j] = T(0);
}
}
}
}
template <typename DstPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * str_x + j * str_y.value] = static_cast<U>(src[i * kElemCols + j]);
}
}
}
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_safe(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y.value] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
template <typename Atype, typename Btype, typename Ctype>
METAL_FUNC static constexpr void mma(
thread frag_type& D,
thread dtype_frag_t<Atype>& A,
thread dtype_frag_t<Btype>& B,
thread dtype_frag_t<Ctype>& C) {
mat_type D_mat;
dtype_mat_t<Atype> A_mat;
dtype_mat_t<Btype> B_mat;
dtype_mat_t<Ctype> C_mat;
reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;
reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;
reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;
mma(D_mat, A_mat, B_mat, C_mat);
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
}
template <typename Atype, typename Btype, typename Ctype>
METAL_FUNC static constexpr void mma(
thread mat_type& D,
thread dtype_mat_t<Atype>& A,
thread dtype_mat_t<Btype>& B,
thread dtype_mat_t<Ctype>& C) {
simdgroup_multiply_accumulate(D, A, B, C);
}
template <typename Op>
METAL_FUNC static constexpr void row_reduce(
thread const frag_type& inp_vals,
thread T* reduced_vals) {
T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
}
template <typename Op>
METAL_FUNC static constexpr void row_bin_op(
thread frag_type& inp_vals,
thread T* row_vals) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
inp_vals[i * kElemCols + j] =
Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
}
}
}
};
template <
typename T,
int kTileRows_,
int kTileCols_,
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
using MMAFrag_t = MMAFrag_;
using elem_type = T;
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
STEEL_CONST int kTileRows = kTileRows_;
STEEL_CONST int kTileCols = kTileCols_;
STEEL_CONST int kRows = kTileRows * kFragRows;
STEEL_CONST int kCols = kTileCols * kFragCols;
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
typedef typename MMAFrag_t::mat_type mat_type;
typedef typename MMAFrag_t::frag_type frag_type;
frag_type val_frags[kNumFrags]; // = {frag_type(0)};
METAL_FUNC MMATile() thread {}
METAL_FUNC constexpr void clear() {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kNumFrags; ++i) {
val_frags[i] = frag_type(0);
}
}
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
return val_frags[i * kTileCols + j];
}
METAL_FUNC constexpr const thread frag_type& frag_at(
const short i,
const short j) const {
return val_frags[i * kTileCols + j];
}
METAL_FUNC mat_type mat_at(const short i, const short j) {
mat_type val_mat;
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
}
return val_mat;
}
METAL_FUNC thread elem_type* elems() {
return reinterpret_cast<thread elem_type*>(val_frags);
}
METAL_FUNC const thread elem_type* elems() const {
return reinterpret_cast<const thread elem_type*>(val_frags);
}
template <typename Op>
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_reduce<Op>(
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename Op>
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_bin_op<Op>(
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U* src) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(
src[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void store(threadgroup U* dst) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(
dst[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void load(const device U* src, const int ld) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store(device U* dst, const int ld) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::load_safe(
frag_at(i, j),
src,
ld,
Int<1>{},
src_tile_dims.y,
src_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_safe(
frag_at(i, j),
dst,
ld,
Int<1>{},
dst_tile_dims.y,
dst_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
template <
typename Dtype,
typename Atype,
typename Btype,
typename Ctype,
int M,
int N,
int K,
class MMAFragD,
class MMAFragA,
class MMAFragB,
class MMAFragC>
METAL_FUNC void tile_matmad(
thread MMATile<Dtype, M, N, MMAFragD>& D,
thread MMATile<Atype, M, K, MMAFragA>& A,
thread MMATile<Btype, K, N, MMAFragB>& B,
thread MMATile<Ctype, M, N, MMAFragC>& C) {
STEEL_PRAGMA_UNROLL
for (short m = 0; m < M; ++m) {
STEEL_PRAGMA_UNROLL
for (short n = 0; n < N; ++n) {
short m_serp = m; //(n % 2) ? (M - 1 - m) : m;
short n_serp = (m % 2) ? (N - 1 - n) : n;
STEEL_PRAGMA_UNROLL
for (short k = 0; k < K; ++k) {
MMAFragD::mma(
D.frag_at(m_serp, n_serp),
A.frag_at(m_serp, k),
B.frag_at(k, n_serp),
C.frag_at(m_serp, n_serp));
}
}
}
}
// ──────────────────────────────────────────────────────────────────────────
// AttnParams — host/device ABI for the attention kernel.
//
// The Rust mirror of this struct (AttnParamsGpu) lives in
// /opt/mlx-native/src/ops/flash_attn_prefill.rs and must match this layout
// byte-for-byte. Total size: 160 bytes (verified at test time).
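// Offset arithmetic (a sketch, assuming MSL's C++-style layout rules where
// int64_t is 8-byte aligned):
// 15 four-byte scalars (B .. qL_off) bytes 0..59
// 4 bytes of implicit padding to align int64_t bytes 60..63
// Q/K/V/O_strides, 4 x int64_t[3] = 4 x 24 B bytes 64..159
// => sizeof(AttnParams) == 160.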
// ──────────────────────────────────────────────────────────────────────────
struct AttnParams {
int B; ///< Batch Size
int H; ///< Heads
int D; ///< Head Dim
int qL; ///< Query Sequence Length
int kL; ///< Key Sequence Length
int gqa_factor; ///< Grouped-query (GQA) factor (query heads per KV head)
float scale; ///< Attention scale
float softcapping; ///< Softcapping value (1.0 = disabled)
int NQ; ///< Number of query blocks
int NK; ///< Number of key/value blocks
int NQ_aligned; ///< Number of full query blocks
int NK_aligned; ///< Number of full key/value blocks
int qL_rem; ///< Remainder in last query block
int kL_rem; ///< Remainder in last key/value block
int qL_off; ///< Offset in query sequence start
int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
struct AttnMaskParams {
int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1)
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
// Wave 2E tile-skip pre-pass: when true, the kernel reads a per-(qtile, ktile)
// classification byte from buffer(7) and uses it to skip fully-masked tiles
// and the mask-add on all-attended tiles. See
// /opt/mlx-native/src/shaders/flash_attn_prefill_blk.metal and
// /opt/hf2q/docs/ADR-011-phase2-port-tile-skip.md for the port spec. When
// has_blk is false, the function_constant-gated buffer(7) is NOT bound and
// the Metal compiler dead-codes every blk reference below, so supporting
// the pre-pass costs nothing for callers that leave it disabled.
constant bool has_blk [[function_constant(303)]];
template <typename T>
struct TransformScale {
T scale;
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
METAL_FUNC T apply(T x) const {
return scale * x;
}
};
struct MaxOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return metal::max(x, y);
}
};
struct SumOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x + y;
}
};
struct MulOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x * y;
}
};
struct SubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x - y;
}
};
struct ExpSubOp {
// Unguarded under the finite-M regime: M is initialised to -FLT_MAX/2
// and floor-capped by simd_max, so y is ALWAYS finite. When a score
// x is -inf (from a masked position), exp2(-inf - finite) = exp2(-inf)
// = +0.0 (IEEE-754 exact), never NaN. Matches llama.cpp's
// `const float2 vs2 = exp(s2 - M[jj]);` at ggml-metal.metal:6156.
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return fast::exp2(x - y);
}
};
struct DivOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
// THE SOLE remaining numerical guard under the llama.cpp-derived
// finite-M regime. Mirrors llama.cpp's output-normalisation guard at
// `ggml-metal.metal:6358`:
// const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
//
// For a row where every KV position was masked to -inf, scores are
// -inf, exp2(-inf - -FLT_MAX/2) = 0 (IEEE-754 exact, NOT NaN), so the
// K-sweep accumulates sum_score = bit-exact 0 with no intermediate
// NaN. The final output normalisation is then `output / sum_score
// = 0/0 = NaN` without this guard. Returning 0 in that case
// preserves the "no valid keys attended → no contribution"
// semantics; it is a no-op for any non-degenerate row where
// sum_score > 0. See ADR-011-phase2-port-sentinel.md §2.3.
return (y == T(0)) ? T(0) : x / y;
}
};
// clang-format off
template <
typename T,
int BQ,
int BK,
int BD,
int WM,
int WN,
typename MaskType = float,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device char* blk [[buffer(7), function_constant(has_blk)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
// Move to correct block
ulong3 tidl{tid.x, tid.y, tid.z};
Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Sequence
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
kv_head_idx * params->K_strides[1]; // Head
V += tidl.z * params->V_strides[0] + // Batch
kv_head_idx * params->V_strides[1]; // Head
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Sequence
if (has_mask) {
mask += tidl.z * mask_params->M_strides[0] + // Batch
tidl.y * mask_params->M_strides[1]; // Head
}
// Prepare threadgroup memory
constexpr short padQ = 16 / sizeof(T);
constexpr short padK = 16 / sizeof(T);
constexpr short padV = 16 / sizeof(T);
constexpr short LDQ_tgp = BD + padQ;
constexpr short LDK_tgp = BK + padK;
constexpr short LDV_tgp = BD + padV;
constexpr short tgp_mem_0 = (BK + padK) * (BD);
constexpr short tgp_mem_1 = BK * (BD + padV);
constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;
threadgroup T Q_smem[BQ * (BD + padQ)];
threadgroup T KV_smem[tgp_mem_s];
threadgroup T* Qs = Q_smem;
threadgroup T* Ks = KV_smem;
threadgroup T* Vs = KV_smem;
// Prepare block loaders
using QBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BQ,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDQ_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 1,
/* short tgp_size = */ WM * WN * 32>;
// K is loaded transposed into threadgroup memory
using KBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ 1,
/* short kDstStrCol = */ LDK_tgp,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
using VBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDV_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
QBlockLoader loader_q(
Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
KBlockLoader loader_k(
K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
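// Fold log2(e) into the Q scale: softmax exp(s * scale) is computed below as
// fast::exp2(s * scale * log2(e)), and 1.44269504089 ≈ log2(e). The same
// constant reappears where an additive mask value is folded into Stile.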
TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
constexpr int kNWarps = WM * WN;
static_assert(
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
// Q seq frags per warp
constexpr int TQ = BQ / (kNWarps * kFragSize);
// KV sequence frags (all warps load the same frags)
constexpr int TK = BK / kFragSize;
// HeadDim frags (all warps load the same frags)
constexpr int TD = BD / kFragSize;
static_assert(TQ == 1, "Check TQ");
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
MMATile<AccumType, 1, 1, MMAFrag_acc_t> Vtile;
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
Otile.clear();
// Prepare mma tile offsets
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
const short sm = simd_coord.y;
const short sn = simd_coord.x;
const short tm = kFragSize * TQ * simd_group_id;
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
const short Ks_offset = sm * LDK_tgp + sn;
const short Vs_offset = sm * LDV_tgp + sn;
constexpr short Qs_tile_stride = kFragSize;
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load Q block and apply scale
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
loader_q.load_safe(short2(BD, params->qL_rem));
} else {
loader_q.load_unsafe();
}
loader_q.apply_inplace_op(ts);
// Init row reduction variables
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
AccumType max_score[kRowsPT];
AccumType sum_score[kRowsPT] = {0};
// Init max_score to finite sentinel -FLT_MAX/2 per llama.cpp's convention
// at ggml-metal.metal:5891 (non-vec prefill) and :6725 (vec decode).
// A finite sentinel absorbs -inf scores (from masked positions) via
// simd_max without ever letting M become -inf itself, so exp(score - M)
// evaluates cleanly as exp(-inf) = 0 rather than exp(NaN) = NaN. This
// is the whole reason the three candle-style NaN guards are not needed.
// See ADR-011-phase2-port-sentinel.md §1.3.
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = -FLT_MAX / AccumType(2);
}
int kb_lim = params->NK;
if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
}
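// e.g. with BQ = 32, BK = 16, tid.x = 0, qL_off = 0: q_max = 32 and
// kb_lim = (32 + 15) / 16 = 2, so the K-sweep visits only KV tiles 0 and 1
// (columns 0..31); every later tile lies strictly above the causal diagonal
// for this Q tile and is never loaded.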
// ── Wave 2E tile-skip pre-pass row base ─────────────────────────────────
//
// When has_blk is true the dispatcher has bound the per-(qtile, ktile)
// classification byte buffer at buffer(7). The mask (and therefore the
// blk) produced by Wave 2D is a single [qL, kL] plane that is broadcast
// across batch and heads via m_strides = (0, 0, kL). So the blk buffer
// is shape [NQ, NK] (no batch / head axis) and each main-kernel
// threadgroup reads its row at `blk + tid.x * NK`.
//
// Port of llama.cpp ggml-metal.metal:5841-5846, adapted to our 2D mask
// layout. See /opt/hf2q/docs/ADR-011-phase2-port-tile-skip.md §6.
const device char* blk_row = nullptr;
if (has_blk) {
const int NK_blk = (params->kL + BK - 1) / BK;
blk_row = blk + int(tid.x) * NK_blk;
}
// Loop over KV seq length
for (int kb = 0; kb < kb_lim; kb++) {
// ── Wave 2E tile-skip branch ─────────────────────────────────────────
//
// blk_cur:
// 0 → skip entire tile (port of ggml-metal.metal:5956-5962)
// 1 → standard mask-add + softmax (default; matches pre-Wave-2E behaviour)
// 2 → skip mask-add, compute Q·K^T + softmax normally (port of :6145)
//
// When has_blk is false blk_cur is forced to 1 below, and the compiler
// dead-codes both the byte load and the skip branch (blk_cur == 0 can
// never be true). The subsequent `blk_cur != 2` gate on mask-add also
// becomes a constant-true under `has_blk=false && has_mask=true`, so
// the mask-add code path is unchanged from pre-Wave-2E in that case.
char blk_cur = 1;
if (has_blk) {
blk_cur = blk_row[kb];
if (blk_cur == 0) {
// Fully-masked KV tile — skip the entire iteration. The running
// per-row (max, sum, O) accumulators are unchanged because this
// tile contributes no finite scores: equivalent to the standard
// path with mqk=-inf, which under the finite-M-sentinel regime
// yields exp2(-inf - finite) = 0 contribution and `factor = 1`
// rescale (M unchanged, S unchanged, O unchanged). Matches
// llama.cpp's `continue` at ggml-metal.metal:5961.
//
// IMPORTANT: K/V block loaders advance via loader_k.next() /
// loader_v.next() at the END of every iteration. A `continue`
// that skips the end-of-iter `next()` calls would leave the
// loaders pointing at the same KV tile on the NEXT iteration —
// the subsequent `load_unsafe()` would read the wrong data.
// Advance the loaders before `continue` so the next iteration
// sees the correct KV tile. llama.cpp handles this equivalently
// via its `pm2[jj] += NW` per-row mask pointer advance at
// ggml-metal.metal:5958-5960 (there the K/V stream is per-chunk
// and the advance happens implicitly via ic0++ in the for-head).
loader_k.next();
loader_v.next();
continue;
}
}
// Load K block
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!align_K && kb == (params->NK_aligned)) {
loader_k.load_safe(short2(BD, params->kL_rem));
} else {
loader_k.load_unsafe();
}
// Do S = Q @ K.T
Stile.clear();
threadgroup_barrier(mem_flags::mem_threadgroup);
STEEL_PRAGMA_UNROLL
for (short dd = 0; dd < TD; dd++) {
simdgroup_barrier(mem_flags::mem_none);
Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
&Qs[Qs_offset + dd * Qs_tile_stride]);
Ktile.template load<T, 1, 1, LDK_tgp, 1>(
&Ks[Ks_offset + dd * Ks_tile_stride]);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Stile, Qtile, Ktile, Stile);
}
// Mask out columns past the end of the K sequence (last, partial K tile)
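// e.g. kL = 100 with the D=256 geometry (BK = 16): NK_aligned = 6,
// kL_rem = 4, so in the final tile (kb == 6) local columns 4..15 fall past
// the end of K and are forced to -inf here before the softmax.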
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
short col_pos = sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if ((col_pos + jj) >= params->kL_rem) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}
// Mask out if causal
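// With BQ = 32, BK = 16 this gate reads kb >= kb_lim - 2 - int(!align_K):
// only the last two in-range KV tiles (three when the K length is unaligned)
// take the per-element causal mask; earlier in-range tiles are treated as
// fully attended and skip this loop nest entirely.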
if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
const int row_pos =
tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if (row_pos < (col_pos + jj)) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}
// Other masking as needed.
//
// Wave 2E: when has_blk && blk_cur == 2 the entire mask tile is
// bit-exact 0.0, so adding it is a no-op — skip the load+add. Port
// of llama.cpp's `if (blk_cur != 2)` guard at
// ggml-metal.metal:6145. The gate is constant-false when has_blk is
// false (blk_cur always == 1 in that case), so the compiler treats
// this identically to pre-Wave-2E code.
if (has_mask && blk_cur != 2) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;
using frag_t = typename MMAFrag_mask_t::frag_type;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
frag_t mfrag;
// Wave 4 Phase B: pass M_strides[2] as int64_t (its native type
// in `AttnMaskParams`) instead of narrowing through int(). The
// load_safe template body computes `(off_x + i) * str_x`; with
// off_x = row_pos (int) and str_x narrowed to int, the product
// overflows at row_pos * kL >= i32::MAX (e.g. row_pos >= 32768
// when kL = 65536), wrapping to a large negative pointer offset
// and reading garbage before the mask buffer's base. Mirrors
// the already-correct flash_attn_prefill_d512.metal:411-413
// ulong-cast idiom. See
// /tmp/cfa-cfa-20260427-adr005-wave4/phase-A-report.md §2.5.1
// for the closed-form overflow argument.
MMAFrag_mask_t::load_safe(
mfrag,
mask,
mask_params->M_strides[2],
Int<1>{},
params->qL,
params->kL,
row_pos,
col_pos);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
if constexpr (is_bool) {
Stile.frag_at(i, j)[jj] =
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
} else {
Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
}
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load V blocks
if (!align_K && kb == (params->NK_aligned)) {
loader_v.load_safe(short2(BD, params->kL_rem));
} else {
loader_v.load_unsafe();
}
// Do softmax
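// Online-softmax recurrence for K tile i (Dao et al. 2022), as realised by
// the steps below (all scores already carry the log2(e) factor, so exp2
// here corresponds to exp of the unscaled scores):
// M_i = max(M_{i-1}, rowmax(S_i)) row_reduce<MaxOp>
// P_i = exp2(S_i - M_i) row_bin_op<ExpSubOp>
// f_i = exp2(M_{i-1} - M_i) factor[]
// l_i = l_{i-1} * f_i + rowsum(P_i) sum_score update
// O_i = O_{i-1} * f_i + P_i @ V_i row_bin_op<MulOp> + MMA
// After the sweep, O = O_N / l_N (row_bin_op<DivOp> past the loop).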
// Temp variables
AccumType new_max[kRowsPT];
AccumType factor[kRowsPT];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i];
}
// Row max
Stile.template row_reduce<MaxOp>(new_max);
// exp(Si - rowmax(Si))
Stile.template row_bin_op<ExpSubOp>(new_max);
// Factor exp(rowmax(Si) - rowmax(Si-1))
// Unguarded under the finite-M regime: max_score is -FLT_MAX/2 initially
// and simd_max-floor-capped at -FLT_MAX/2 thereafter, so the difference
// max_score - new_max is ALWAYS finite (in [-FLT_MAX, 0]). On the first
// K-tile iteration of a fully-masked row the difference is exactly 0
// and factor = exp2(0) = 1, which is correct: sum_score starts at 0,
// stays at sum_score*1 + 0 = 0; Otile starts at 0, stays at 0*1 = 0.
// Matches llama.cpp's unguarded `const float ms = exp(m - M[jj]);`
// at ggml-metal.metal:6155.
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp2(max_score[i] - new_max[i]);
}
// Save max for next iteration
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = new_max[i];
}
// Row Sum
AccumType sum_score_tmp[kRowsPT] = {0};
Stile.template row_reduce<SumOp>(sum_score_tmp);
// Update norm
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
}
// Update O
Otile.template row_bin_op<MulOp>(factor);
// Load V into registers
threadgroup_barrier(mem_flags::mem_threadgroup);
STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short id = 0; id < TD; id++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
if constexpr (BD == 128) {
simdgroup_barrier(mem_flags::mem_none);
}
const short kk = ik * kFragSize;
const short dd = id * kFragSize;
Vtile.template load<T, 1, 1, LDV_tgp, 1>(
&Vs[Vs_offset + kk * LDV_tgp + dd]);
if constexpr (BD == 128) {
simdgroup_barrier(mem_flags::mem_none);
}
MMAFrag_acc_t::mma(
Otile.frag_at(iq, id),
Stile.frag_at(iq, ik),
Vtile.frag_at(0, 0),
Otile.frag_at(iq, id));
}
}
}
// Prepare for next iteration
loader_k.next();
loader_v.next();
}
// Normalize output
Otile.template row_bin_op<DivOp>(sum_score);
threadgroup_barrier(mem_flags::mem_none);
// Store results
O += (tm + sm) * params->O_strides[2] + sn;
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
} else {
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}
}
// clang-format off
// ──────────────────────────────────────────────────────────────────────────
// Kernel instantiations
// ──────────────────────────────────────────────────────────────────────────
//
// Twelve host-visible entry points covering D ∈ {64, 256, 512} × I/O dtype ∈
// {bf16, f16} × mask kind ∈ {additive (same dtype as I/O), bool}.
// f32 I/O is excluded at every head dim: at D=256 the Qs threadgroup tile
// alone would exceed Apple Silicon's 32 KB MTLDevice.maxThreadgroupMemoryLength
// budget (BQ * BD * 4 = 32 KB), and the D=64 comment below gives the ABI
// rationale for excluding it there too. See the preamble for the full
// threadgroup-memory analysis.
//
// The mask-kind suffix `_boolmask` selects an `is_attended` boolean mask
// where `false` masks the position; the unsuffixed form takes an additive
// mask in the same dtype as I/O (the standard log-domain mask, with
// `-inf` to mask).
#define instantiate_kernel(name, func, ...) \
template [[host_name( \
name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
#define instantiate_flash_attn_prefill(name, io_dtype, bq, bk, bd, wm, wn, mask_dtype) \
instantiate_kernel(name, \
attention, io_dtype, bq, bk, bd, wm, wn, mask_dtype, float)
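// For reference, the first D=256 line below expands to (modulo whitespace):
//   template [[host_name("flash_attn_prefill_bf16_d256")]] [[kernel]]
//   decltype(attention<bfloat16_t, 32, 16, 256, 4, 1, bfloat16_t, float>)
//   attention<bfloat16_t, 32, 16, 256, 4, 1, bfloat16_t, float>;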
// D=256 — sliding-window-attention layers (Gemma 4 family).
// BQ=32, BK=16, WM=4, WN=1 → 128 threads / threadgroup, 4 simdgroups.
// Threadgroup memory ~29 KB at bf16/f16 — fits comfortably under 32 KB.
instantiate_flash_attn_prefill("flash_attn_prefill_bf16_d256", bfloat16_t, 32, 16, 256, 4, 1, bfloat16_t)
instantiate_flash_attn_prefill("flash_attn_prefill_bf16_d256_boolmask", bfloat16_t, 32, 16, 256, 4, 1, bool)
instantiate_flash_attn_prefill("flash_attn_prefill_f16_d256", half, 32, 16, 256, 4, 1, half)
instantiate_flash_attn_prefill("flash_attn_prefill_f16_d256_boolmask", half, 32, 16, 256, 4, 1, bool)
// D=512 — global-attention layers (Gemma 4 family).
// BQ=8, BK=8, WM=1, WN=1 → 32 threads / threadgroup, 1 simdgroup.
// Smaller tiles than D=256 because Qs (BQ * BD * sizeof(T)) scales with
// BQ * BD; the BQ=32 geometry would exceed 32 KB at BD=512.
instantiate_flash_attn_prefill("flash_attn_prefill_bf16_d512", bfloat16_t, 8, 8, 512, 1, 1, bfloat16_t)
instantiate_flash_attn_prefill("flash_attn_prefill_bf16_d512_boolmask", bfloat16_t, 8, 8, 512, 1, 1, bool)
instantiate_flash_attn_prefill("flash_attn_prefill_f16_d512", half, 8, 8, 512, 1, 1, half)
instantiate_flash_attn_prefill("flash_attn_prefill_f16_d512_boolmask", half, 8, 8, 512, 1, 1, bool)
// D=64 — small head dim (BERT family: nomic-bert/bge/mxbai/minilm and any
// other 64-dim attention-head model). Same 4-simdgroup geometry as D=256
// (BQ=32, BK=16, WM=4, WN=1 → 128 threads/threadgroup); threadgroup memory
// drops to ~7.5 KB at bf16/f16 because Qs (BQ × BD × sizeof(T)) scales with
// BD. Static-asserts pass: BQ=32 ≥ kNWarps×kFragSize = 4×8 = 32, divisible;
// TQ = 32/(4×8) = 1 ✓; TD = 64/8 = 8 ✓; TK = 16/8 = 2 ✓. f32 is excluded
// for ABI consistency with D=256/D=512 and because BERT linears land in f32
// then cast to bf16 before this kernel.
instantiate_flash_attn_prefill("flash_attn_prefill_bf16_d64", bfloat16_t, 32, 16, 64, 4, 1, bfloat16_t)
instantiate_flash_attn_prefill("flash_attn_prefill_bf16_d64_boolmask", bfloat16_t, 32, 16, 64, 4, 1, bool)
instantiate_flash_attn_prefill("flash_attn_prefill_f16_d64", half, 32, 16, 64, 4, 1, half)
instantiate_flash_attn_prefill("flash_attn_prefill_f16_d64_boolmask", half, 32, 16, 64, 4, 1, bool)
// clang-format on