mlx-native 0.8.1

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
// flash_attn_vec_reduce_tq_hb_undo.metal — ADR-028 §iter-485 H3 fusion prototype.
//
// Phase 7d Worker C: TQ-HB SDPA output + FWHT-sign-undo fusion.
//
// FUSION RATIONALE
// ================
// Current TQ-HB SDPA chain (per layer, decode):
//   1. fwht_sign_premult_f32 on Q                  (1 dispatch + barrier)
//   2. flash_attn_vec_tq_hb_dk256                  (1 dispatch + barrier)
//   3. flash_attn_vec_reduce_dk256                 (1 dispatch + barrier)  ← FUSE HERE
//   4. fwht_sign_undo_f32_d256 on output           (1 dispatch + barrier)  ← FOLD IN
//   5. o_proj
//
// H3 (this kernel): fuse 3 + 4 into a single kernel that does the cross-WG
// online-softmax reduce, divides by S_total, then in-place applies the
// inverse-FWHT-then-sign-undo on the final output row before the global store.
//
// Each threadgroup handles ONE output row (one query head). The reduce uses
// `32 * NWG` threads, i.e. NWG simdgroups of 32 lanes; within each simdgroup,
// lane i reads workgroup i's S/M scalars and partial output. After the reduce
// writes the final per-head output row to threadgroup memory, a single
// simdgroup runs a 32-lane butterfly FWHT (EPT=8 for D=256, EPT=16 for D=512)
// plus the sign-mask flip, identical to fwht_sign_undo_fast.
//
// CHESTERTON'S FENCE
// ==================
// The original flash_attn_vec_reduce kernel + standalone fwht_sign_undo kernel
// are NOT modified. This is a new variant gated by env flag HF2Q_TQ_HB_OUT_FUSED=1.
// Default OFF until parity + delta clears the +3% gate.

#include <metal_stdlib>
using namespace metal;

// Reduce params (matches FlashAttnVecReduceParams in flash_attn_vec.metal).
// Bound at buffer(0) of flash_attn_vec_reduce_tq_hb_undo. The host-side
// struct must keep the identical layout (a single 32-bit uint).
struct FlashAttnVecReduceTqHbUndoParams {
    // Number of output rows (one per query head; one threadgroup per row).
    // Also used to locate the S/M block stored after all rows' partial
    // outputs in htmp.
    uint nrows;
};

// Sign-mask tables — MUST match TBQ_SIGNS_{256,512} in fwht_standalone.metal
// byte-for-byte. Copy-pasted here so the fused kernel is self-contained.
//
// Bit layout: element j of a row uses bit (j & 7) of byte (j >> 3). A set bit
// means the element is negated during sign-undo. 32 bytes cover D=256 and
// 64 bytes cover D=512. (The first 32 bytes of the 512 table are identical
// to the 256 table.)
constant uint8_t TBQ_SIGNS_256_REDUCE[32] = {
    0xa7,0x3b,0x91,0xf4,0x6d,0xc2,0x58,0x0e,
    0xb3,0x7f,0x24,0xd6,0x89,0x45,0xea,0x1c,
    0x63,0xaf,0xd8,0x52,0x97,0x0b,0xe1,0x3d,
    0x76,0xc4,0x19,0xfe,0x4a,0x85,0x2c,0xdb,
};
constant uint8_t TBQ_SIGNS_512_REDUCE[64] = {
    0xa7,0x3b,0x91,0xf4,0x6d,0xc2,0x58,0x0e,
    0xb3,0x7f,0x24,0xd6,0x89,0x45,0xea,0x1c,
    0x63,0xaf,0xd8,0x52,0x97,0x0b,0xe1,0x3d,
    0x76,0xc4,0x19,0xfe,0x4a,0x85,0x2c,0xdb,
    0xd3,0x4e,0xa8,0x17,0x9c,0x5b,0xe6,0x31,
    0x72,0xb9,0x0d,0xf5,0x43,0x8a,0x6e,0xc7,
    0x58,0x2f,0x94,0xe1,0xb6,0x3d,0x0a,0x7c,
    0xc5,0x61,0xd8,0x4f,0xa3,0x97,0x1e,0x85,
};

// In-register radix-2 butterfly on two values held by the same lane:
// (a, b) <- (a + b, a - b). Same butterfly as used by fwht_simd in
// fwht_standalone.metal.
inline void butterfly_local_undo(thread float &a, thread float &b) {
    const float lhs = a;
    const float rhs = b;
    a = lhs + rhs;
    b = lhs - rhs;
}

// Unnormalized Walsh-Hadamard transform of a 32 * EPT element vector spread
// across one simdgroup. elems[i] is expected to hold element lane*EPT + i
// (the layout the reduce kernel loads). The lane bits therefore act as the
// high index bits and i as the low bits. All 32 lanes must call this together
// because of simd_shuffle_xor. The caller applies the 1/sqrt(N) scale.
template<ushort EPT>
inline void fwht_simd_undo(thread float *elems, uint lane) {
    // Pass 1: in-register stages over the low log2(EPT) index bits.
    for (ushort width = 1; width < EPT; width *= 2) {
        for (ushort lo = 0; lo < EPT; ++lo) {
            if ((lo & width) != 0) {
                continue;  // upper half of a pair; handled with its partner
            }
            const ushort hi = lo | width;
            const float x = elems[lo];
            const float y = elems[hi];
            elems[lo] = x + y;
            elems[hi] = x - y;
        }
    }

    // Pass 2: cross-lane stages over the 5 lane bits. The lane with the bit
    // set plays the "high" role and keeps (low - high); the other keeps the sum.
    for (ushort mask = 1; mask < 32; mask *= 2) {
        const bool is_high = (lane & mask) != 0;
        for (ushort k = 0; k < EPT; ++k) {
            const float other = simd_shuffle_xor(elems[k], mask);
            elems[k] = is_high ? (other - elems[k]) : (elems[k] + other);
        }
    }
}

// Fused reduce + FWHT-sign-undo.
//
// Grid: (nrows, 1, 1) threadgroups   Threadgroup: (32 * NWG, 1, 1) threads
//
// Host-side preconditions:
//   - 1 <= NWG <= 32 (buffer(3)). Lane `tiisg` of every simdgroup reads the
//     S/M scalars and partials of workgroup `tiisg`, so at most 32 workgroups
//     can be reduced; workgroups >= 32 would be silently ignored. 32 * NWG
//     must also fit the device's max threadgroup size.
//   - The threadgroup is exactly 32 * NWG threads (NWG simdgroups of 32
//     lanes). Simdgroup s reduces D4 chunks s, s + NWG, s + 2*NWG, ..., so
//     fewer simdgroups would leave parts of `out_tg` unwritten.
//   - threadgroup(0) provides at least DV floats (DV * 4 bytes).
//
// Layout matches flash_attn_vec_reduce:
//   - htmp first holds the per-WG partial outputs of all rows
//     (nrows * DV * NWG floats, indexed as float4 [row][D4 chunk][wg]),
//     followed by the per-row scalars indexed as float [row][wg][S, M].
//
// Difference from the unfused reduce: instead of writing each reduced D4
// chunk directly to dst, all simdgroups write the final output row into
// threadgroup memory (`out_tg[DV]`). After a barrier, simdgroup 0 alone runs
// the 32-lane butterfly + sign pass on the whole row and writes the
// inverse-rotated result to dst. The other simdgroups only take part in the
// load + reduce.
template<short DV>
kernel void flash_attn_vec_reduce_tq_hb_undo(
    constant FlashAttnVecReduceTqHbUndoParams  &params [[buffer(0)]],
    device const float                          *htmp   [[buffer(1)]],
    device       float                          *dst    [[buffer(2)]],
    constant     uint                           &nwg_param [[buffer(3)]],
    threadgroup  float                          *out_tg [[threadgroup(0)]],
    uint   tgpig [[threadgroup_position_in_grid]],
    ushort tiisg [[thread_index_in_simdgroup]],
    ushort sgitg [[simdgroup_index_in_threadgroup]])
{
    // Sign tables exist only for D=256 and D=512. Any other DV would silently
    // read the 512 table, so reject it at compile time.
    static_assert(DV == 256 || DV == 512,
                  "flash_attn_vec_reduce_tq_hb_undo: DV must be 256 or 512");

    constexpr short  DV4 = DV / 4;
    constexpr ushort EPT = DV / 32;  // 8 for D=256, 16 for D=512

    const uint NWG = nwg_param;
    const uint64_t rid = tgpig;  // row index (one head per threadgroup)
    const ushort iwg = tiisg;    // lane i of every simdgroup handles workgroup i

    // Both conditions are uniform across the threadgroup, so all threads
    // leave together and the barrier below is never left partially reached.
    // NWG == 0 would otherwise make the `i += NWG` loop spin forever.
    if (rid >= params.nrows || NWG == 0) {
        return;
    }

    // S and M values are stored after all DV data (same layout as
    // flash_attn_vec_reduce — htmp packing is dictated by the SDPA kernel).
    device const float *sm = htmp + (uint64_t)params.nrows * DV * NWG;

    // Load this workgroup's S and M. Padding lanes (iwg >= NWG) contribute
    // S = 0 and a very small M, so their rescale factor is 0.
    const bool wg_active = iwg < NWG;
    const float S_wg = wg_active ? sm[rid * (2 * NWG) + 2 * iwg + 0] : 0.0f;
    const float M_wg = wg_active ? sm[rid * (2 * NWG) + 2 * iwg + 1] : -FLT_MAX / 2;

    // Global max across workgroups. Clamped to the padding sentinel as a
    // guard: if the SDPA kernel ever wrote -INFINITY for every workgroup's M
    // (only possible with NWG == 32, where no padding lane exists), the
    // unclamped max would give exp(-inf - -inf) = NaN. The FWHT below would
    // then spread that NaN to every element of the row. With the clamp,
    // ms = 0, S_total = 0 and the row becomes zeros. Finite M values
    // >= -FLT_MAX/2 are unaffected.
    const float M_global = max(simd_max(M_wg), -FLT_MAX / 2);
    const float ms = exp(M_wg - M_global);
    const float S_total = simd_sum(S_wg * ms);
    const float inv_S = (S_total == 0.0f) ? 0.0f : 1.0f / S_total;

    // Reduce: simdgroup `sgitg` owns D4 chunks sgitg, sgitg + NWG, ...
    // For each chunk, every lane loads its workgroup's rescaled partial, the
    // simdgroup sums them, and lane 0 stores the normalized float4 into the
    // shared row buffer.
    device const float4 *htmp4 = (device const float4 *)htmp + rid * DV4 * NWG;
    threadgroup float4 *out_tg4 = (threadgroup float4 *)out_tg;
    for (short i = sgitg; i < DV4; i += NWG) {
        const float4 val = wg_active ? htmp4[i * NWG + iwg] * ms : float4(0.0f);
        const float4 reduced = simd_sum(val);  // component-wise SIMD reduce
        if (iwg == 0) {
            out_tg4[i] = reduced * inv_S;
        }
    }

    // Synchronize: all simdgroups must finish writing out_tg before
    // simdgroup 0 reads it for the FWHT. Reached by every thread.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // --- FWHT (= IWHT for normalized H) + sign-undo, by simdgroup 0 ---
    //
    // Mirrors fwht_sign_undo_fast: load EPT consecutive elements per lane,
    // run the simd butterfly, normalize by 1/sqrt(d), apply sign mask, store.
    if (sgitg == 0) {
        const uint lane = tiisg;
        const uint base = lane * EPT;
        float elems[EPT];
        for (ushort i = 0; i < EPT; i++) {
            elems[i] = out_tg[base + i];
        }

        fwht_simd_undo<EPT>(elems, lane);

        // Normalize and sign-undo in one multiply. The sign table is the same
        // one fwht_sign_premult uses, and flipping a sign is self-inverse.
        // Folding the sign into the scale is exact: x * (-s) == -(x * s) in
        // IEEE arithmetic, so results match a separate multiply by +/-1.
        const float inv_sqrt_d = rsqrt(float(DV));
        device float *dst_row = dst + rid * DV;
        for (ushort i = 0; i < EPT; i++) {
            const ushort j = base + i;  // element index within the row
            const uint8_t sign_byte = (DV == 256) ? TBQ_SIGNS_256_REDUCE[j >> 3]
                                                  : TBQ_SIGNS_512_REDUCE[j >> 3];
            const float scale = ((sign_byte >> (j & 7)) & 1u) ? -inv_sqrt_d
                                                               : inv_sqrt_d;
            dst_row[base + i] = elems[i] * scale;
        }
    }
}

// Kernel function type shared by both instantiations below. DV is only a
// template argument, so the <256> and <512> signatures are identical.
typedef decltype(flash_attn_vec_reduce_tq_hb_undo<256>) flash_attn_vec_reduce_tq_hb_undo_t;

// Host-visible entry points. The host_name strings are the function names
// exposed in the compiled Metal library. The Rust caller presumably looks
// them up by these exact names (verify before renaming).
template [[host_name("flash_attn_vec_reduce_tq_hb_undo_dk256")]]
kernel flash_attn_vec_reduce_tq_hb_undo_t flash_attn_vec_reduce_tq_hb_undo<256>;

template [[host_name("flash_attn_vec_reduce_tq_hb_undo_dk512")]]
kernel flash_attn_vec_reduce_tq_hb_undo_t flash_attn_vec_reduce_tq_hb_undo<512>;