// mlx-native 0.8.1
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
// flash_attn_vec_peer_port_f16_reduce.metal — verbatim port of
// llama.cpp's kernel_flash_attn_ext_vec_reduce (ggml-metal.metal lines 7232-7275)
// for DV=256, NWG=32. Used after the NWG=32 vec kernel to combine partial
// results from each workgroup into the final dst.
//
// ADR-029 iter-134 (cfa/fa-peer-port-nwg32 follow-up).
//
// Hypothesis (refined from iter-127/132): peer's flash-attn-vec uses NWG=32 +
// reduce kernel by default (ggml-metal-ops.cpp:2944 — `if (false)` disables
// NWG=1 path). The original CFA port (iter-126) targeted the unused NWG=1 dead
// code path and was falsified at tg5000 (-25%). This file is part of the proper
// NWG=32 + reduce-kernel port.
//
// Surface adaptations only:
//   (a) args struct → FlashAttnVecPeerPortReduceParams (just nrows)
//   (b) DV/NWG baked as file-header #define matching peer's FC-bake
//   (c) Buffer slots: 0=params, 1=htmp (workgroup partials, float*), 2=dst (float*)
// Kernel body VERBATIM from peer 7235-7275.

#include <metal_stdlib>
using namespace metal;

// FC-bake constants (peer specializes via function constants; we bake symbolically per RULE-1 of cfa spec)
#define DV  256
#define NWG 32

// Compile-time guards for the baked constants. The reduce kernel maps one
// simdgroup lane to one workgroup partial (iwg = thread_index_in_simdgroup)
// and combines partials with simd_max/simd_sum. NWG must therefore equal the
// simdgroup width, which is assumed to be 32 here. If NWG were smaller, lanes
// >= NWG would read past a row's partials; if larger, partials would be dropped.
// DV is consumed in float4 chunks, so it must be a multiple of 4.
static_assert(NWG == 32, "NWG must equal the simdgroup width (32): one lane per workgroup partial");
static_assert(DV % 4 == 0, "DV must be a multiple of 4: htmp/dst are accessed as float4");

// Params struct for the reduce kernel, bound at [[buffer(0)]].
//
// nrows: row count. The kernel uses it to locate the (S, M) section of htmp,
//        which begins after nrows*DV*NWG partial-output floats.
//
// The host encoder must upload bytes with exactly this layout: a single
// 4-byte signed integer. The static_assert below pins that size so a field
// added on only one side fails to compile instead of silently misreading
// the buffer.
struct FlashAttnVecPeerPortReduceParams {
    int32_t nrows;
};

static_assert(sizeof(FlashAttnVecPeerPortReduceParams) == 4,
              "FlashAttnVecPeerPortReduceParams must stay a single int32 to match the host-side layout");

// Combine the NWG per-workgroup partial outputs of one row into the final dst row.
//
// For each row, every workgroup wg contributed:
//   - a partial output vector of DV floats, and
//   - a pair (S, M).
// The kernel takes m = max over wg of M and rescales each contribution by
// exp(M - m). It then writes:
//   dst[row] = sum_wg(partial_wg * exp(M_wg - m)) / sum_wg(S_wg * exp(M_wg - m)).
//
// Buffers:
//   args (buffer 0): args.nrows, used only to find the (S, M) section of htmp.
//   htmp (buffer 1): read as floats.
//     - [0, nrows*DV*NWG): partial outputs, indexed as float4 in the order
//       [row][DV/4 chunk][workgroup].
//     - After that: 2*NWG floats per row, stored as interleaved (S, M) pairs
//       per workgroup.
//   dst  (buffer 2): written as DV floats per row, in float4 chunks.
//
// Launch contract, inferred from the indexing below. NOTE(review): confirm
// against the host dispatch.
//   - One threadgroup per output row (rid = threadgroup_position_in_grid).
//     The grid is presumably args.nrows threadgroups wide; the kernel itself
//     performs no rid bounds check.
//   - Lane iwg of each simdgroup reads workgroup iwg's partial, so the
//     simdgroup width must equal NWG.
//   - The chunk loop strides by NWG across simdgroups, so the threadgroup is
//     expected to hold NWG simdgroups (32*NWG threads). With fewer simdgroups,
//     some dst chunks would never be written.
kernel void flash_attn_vec_peer_port_f16_reduce_dv256_nwg32(
        constant FlashAttnVecPeerPortReduceParams & args [[buffer(0)]],
        device   const char *                       htmp [[buffer(1)]],
        device         char *                       dst  [[buffer(2)]],
        uint   tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {

    // Output row handled by this threadgroup.
    const uint64_t rid = tgpig;

    // This lane owns workgroup partial number iwg.
    const short iwg = tiisg;

    // The (S, M) section starts right after all nrows*DV*NWG partial-output floats.
    device const float  * ss    = (device const float  *) htmp + (uint64_t)args.nrows*DV*NWG;

    // Each row stores 2*NWG floats: (S, M) for workgroup 0, then workgroup 1, ...
    float S = ss[rid*(2*NWG) + 2*iwg + 0];
    float M = ss[rid*(2*NWG) + 2*iwg + 1];

    // Row-wide maximum across all workgroups, and this workgroup's rescale factor relative to it.
    const float m  = simd_max(M);
    const float ms = exp(M - m);

    // Combined denominator. Its reciprocal is taken once per row; S == 0 maps to 0 to avoid dividing by zero.
    S = simd_sum(S*ms);
    S = S == 0.0f ? 0.0f : 1.0f/S;

    const short DV4 = DV/4;

    // Per-row bases. Partials for this row are NWG float4s per chunk; the dst row is DV4 float4s.
    device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG;
    device       float4 * dst4  = (device       float4 *) dst  + rid*DV4;

    // Simdgroup sgitg handles chunks sgitg, sgitg+NWG, ... This assumes NWG simdgroups per threadgroup.
    for (short i = sgitg; i < DV4; i += NWG) {
        // Weighted sum of the NWG partials for chunk i, reduced across the lanes of the simdgroup.
        const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms);

        // Every lane now holds the same v, so a single lane performs the store.
        if (iwg == 0) {
            dst4[i] = v*S;
        }
    }
}