// mlx-native 0.6.2
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
// dense_gemv_bf16.metal — Dense bf16 × f32 → f32 GEMV (matrix-vector multiply).
//
// Port of llama.cpp's `kernel_mul_mv_t_t_4` template instantiation for
// `bfloat, bfloat4, float, float4` (i.e. `kernel_mul_mv_bf16_f32_4`).
// Reference: ggml/src/ggml-metal/ggml-metal.metal.
//
// Used when M == 1 (single-token decode) for linear projections.
// For M > 1, use the GEMM tensor-core kernel (dense_mm_bf16_tensor.metal).
//
// Layout contract (identical to llama.cpp):
//   src0  [src0_batch, N, K]  bfloat, row-major  (weight matrix, transposed convention)
//   src1  [src1_batch, M, K]  float,  row-major  (input vectors)
//   dst   [src1_batch, M, N]  float,  row-major  (output vectors)
//
// Reduction strategy:
//   - Grid: (ceil(N/NR0), M, src1_batch).
//   - Each threadgroup handles NR0=2 output elements (weight rows) for one
//     input row.
//   - Each threadgroup has 32 × NSG threads, i.e. NSG simdgroups of 32 lanes.
//     llama.cpp picks NSG = min(4, (K + 127) / 128) empirically; this port
//     hard-codes NSG = 4, so the host must launch 128 threads per threadgroup.
//   - Each simdgroup computes a partial dot product of its K-slice and reduces
//     via simd_sum, then stores to threadgroup memory for the final cross-group
//     reduction.
//
// Derived from llama.cpp (https://github.com/ggml-org/llama.cpp), MIT licensed.
// Copyright the llama.cpp Authors.  See LICENSE-MIT-llamacpp.

#include <metal_stdlib>
using namespace metal;

// ---- Host-facing params struct ---------------------------------------------
//
// Field layout is identical to `ggml_metal_kargs_mul_mv` in ggml-metal-impl.h
// (bytes 0-111).  Unused fields (nb00, ne10) are present for layout
// compatibility but ignored by this kernel.
// Per-field notes below describe how hf2q_dense_gemv_bf16_f32_4 (the kernel in
// this file) consumes each field.  Fields marked "not read" exist only to keep
// the byte layout; fields used as divisors must be non-zero.
struct DenseGemvBf16Params {
    int32_t  ne00;   // K  — contract dim (split into 32-element blocks plus a scalar tail)
    int32_t  ne01;   // N  — number of weight rows (output dim); bounds the output write
    int32_t  ne02;   // src0_batch — not read by the kernel
    uint64_t nb00;   // src0 element stride (bytes) — not read (assumed 2)
    uint64_t nb01;   // src0 row stride (bytes)     = K * sizeof(bfloat)
    uint64_t nb02;   // src0 batch stride (bytes)   = N * K * sizeof(bfloat)
    uint64_t nb03;   // src0 super-batch stride (bytes) — read, scaled by (i13 / r3)
    int32_t  ne10;   // ne10 — not read (= K)
    int32_t  ne11;   // M  — number of input rows; not read (row index comes from the grid y position)
    int32_t  ne12;   // src1_batch — divisor/modulus for the grid z index; must be non-zero
    uint64_t nb10;   // src1 element stride (bytes) — not read (assumed 4)
    uint64_t nb11;   // src1 row stride (bytes)     = K * sizeof(float)
    uint64_t nb12;   // src1 batch stride (bytes)   = M * K * sizeof(float)
    uint64_t nb13;   // src1 super-batch stride (bytes) — read, scaled by i13 (= grid z / ne12)
    int32_t  ne0;    // N (output cols, = ne01) — dst row stride in elements
    int32_t  ne1;    // M (output rows, = ne11) — with ne0 gives the dst batch stride
    int32_t  nr0;    // NR0 — weight rows per threadgroup; not read (kernel hard-codes 2)
    int16_t  r2;     // src1_batch / src0_batch (GQA broadcast factor) — divisor, must be non-zero
    int16_t  r3;     // super-batch broadcast (always 1) — still used as a divisor, must be non-zero
};

// ---- Kernel ----------------------------------------------------------------
//
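// Compile-time constants (template / function constants in llama.cpp, both
// hard-coded in this port):
//   NR0  = weight rows per threadgroup (2, matching llama.cpp default).
//   NSG  = simdgroups per threadgroup (4).
//          The K-loop strides assume exactly NSG = 4 simdgroups, so the host
//          must dispatch 32 × 4 = 128 threads per threadgroup; launching fewer
//          would skip the K blocks assigned to the missing simdgroups.  For
//          small K, simdgroups whose first block lies past the end of K simply
//          contribute zero to the final sum.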
//
// llama.cpp uses a function_constant for NSG; we hard-code NSG=4 since we
// only need to cover the M5 Max's large K dimensions (K ≥ 2048 always for
// our model, so NSG = min(4, (K+127)/128) = 4 in all cases).

// Dense GEMV: dst[batch, r1, r0..r0+1] = dot(src0 row, src1 row) with bfloat
// weights, float activations and float accumulation.
//
// Launch preconditions (derived from the constants and indexing below):
//   - Threadgroup position: x = output-row pair, y = input row, z = batch.
//   - Exactly NSG = 4 simdgroups (128 threads) per threadgroup; the K-loop
//     strides are hard-coded for that count.
//   - threadgroup(0) must provide at least NR0 * NW * sizeof(float) = 256 bytes.
//   - src0 rows are loaded as bfloat4 and src1 rows as float4, so row starts
//     are assumed to be suitably aligned for those vector types.
//   - args.ne12, args.r2 and args.r3 are used as divisors and must be non-zero.
//
// When N (args.ne01) is odd, the last threadgroup has only one valid output
// row.  The second row's address is clamped to the last valid weight row so no
// lane reads past the end of src0; its result is computed but never stored.
kernel void hf2q_dense_gemv_bf16_f32_4(
        constant DenseGemvBf16Params & args,
        device const char * src0,           // bfloat  [ne02, N, K]
        device const char * src1,           // float   [ne12, M, K]
        device       char * dst,            // float   [ne12, M, N]
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig [[threadgroup_position_in_grid]],
        ushort tiisg [[thread_index_in_simdgroup]],
        ushort sgitg [[simdgroup_index_in_threadgroup]]) {

    constexpr short NR0 = 2;    // weight rows (outputs) per threadgroup
    constexpr short NSG = 4;    // simdgroups per threadgroup (hard-coded)
    constexpr short NW  = 32;   // lanes per simdgroup
    constexpr short NB  = 32;   // scalars per inner block of K
    constexpr short NF  = 16;   // scalars handled by one lane per iteration
    constexpr short NF4 = NF/4; // = 4 four-wide vectors per lane per iteration

    const int nb = args.ne00 / NB;  // number of full NB-wide inner blocks in K

    // Threadgroup position:
    //   tgpig.x = output-row tile index  (covers NR0 output rows)
    //   tgpig.y = input-row  index       (one threadgroup per M row)
    //   tgpig.z = batch index
    const int r0 = (int)tgpig.x * NR0;
    const int r1 = (int)tgpig.y;
    const int im = (int)tgpig.z;

    const uint i12 = (uint)im % (uint)args.ne12;
    const uint i13 = (uint)im / (uint)args.ne12;

    // Input vector (src1) for this input row and batch, viewed both as scalars
    // (tail loop) and as float4 (main loop).
    const uint64_t offset1 = (uint64_t)r1   * args.nb11
                           + (uint64_t)i12  * args.nb12
                           + (uint64_t)i13  * args.nb13;

    device const float  * y  = (device const float  *)(src1 + offset1);
    device const float4 * y4 = (device const float4 *)(src1 + offset1);

    // Weight row pointers for the NR0 output rows of this threadgroup.  The row
    // index is clamped to the last valid row so that the out-of-range second
    // row of an odd-N matrix never addresses memory beyond src0.
    const int last_row = max(args.ne01 - 1, 0);
    const uint64_t batch_offset0 = (uint64_t)(i12 / (uint)args.r2) * args.nb02
                                 + (uint64_t)(i13 / (uint)args.r3) * args.nb03;

    device const bfloat  * ax [NR0];
    device const bfloat4 * ax4[NR0];
    for (short row = 0; row < NR0; ++row) {
        const int src_row = min(r0 + row, last_row);
        const uint64_t offset0 = (uint64_t)src_row * args.nb01 + batch_offset0;
        ax [row] = (device const bfloat  *)(src0 + offset0);
        ax4[row] = (device const bfloat4 *)(src0 + offset0);
    }

    // Partial dot products per row.
    float sumf[NR0] = { 0.f, 0.f };

    // Two lanes share one NB-wide inner block (NW / NF = 2):
    //   ix = which of the simdgroup's 16 concurrent blocks this lane works on (0..15)
    //   il = which NF-wide half of that block this lane covers            (0..1)
    const short ix = tiisg / (NW / NF);
    const short il = tiisg % (NW / NF);

    // First inner block for this simdgroup + lane.
    const int ib0 = (int)sgitg * NF + ix;

    // Register cache for the current NF-wide slice of the input vector.
    float4 yl4[NF4];

    // Start of this lane's slice within the input vector.
    device const float4 * yb4 = y4 + (ib0 * NB + il * NF) / 4;

    // Main loop: the whole threadgroup advances NSG*NF inner blocks per step.
    for (int ib = ib0; ib < nb; ib += NSG * NF) {
        for (short i = 0; i < NF4; ++i) {
            yl4[i] = yb4[i];
        }

        // Accumulate the dot product for each weight row.
        for (short row = 0; row < NR0; ++row) {
            device const bfloat4 * xb4 = ax4[row] + (ib * NB + il * NF) / 4;
            float sumq = 0.f;
            for (short i = 0; i < NF4; ++i) {
                sumq += dot(float4(xb4[i]), yl4[i]);
            }
            sumf[row] += sumq;
        }

        // NSG*NF blocks of NB (== NW) scalars, expressed in float4 units.
        yb4 += NSG * NF * NW / 4;
    }

    // Tail loop: scalars past the last full NB block (fewer than NB of them),
    // one element per lane.
    for (int i = nb * NB + (int)sgitg * NW + (int)tiisg; i < args.ne00; i += NW * NSG) {
        for (short row = 0; row < NR0; ++row) {
            sumf[row] += (float)ax[row][i] * y[i];
        }
    }

    // ---- Threadgroup reduction (mirrors llama.cpp helper_mv_reduce_and_write) ----
    //
    // Layout of threadgroup memory: [NR0][NW] floats = [2][32] floats = 256 bytes.
    threadgroup float * shmem_f32 = (threadgroup float *)shmem;

    // Phase 1: simdgroup 0 zeroes every slot (so slots >= NSG contribute 0 in
    // phase 2) while each simdgroup reduces its own partial sums.
    for (short row = 0; row < NR0; ++row) {
        if (sgitg == 0) {
            shmem_f32[row * NW + tiisg] = 0.f;
        }
        sumf[row] = simd_sum(sumf[row]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Lane 0 of each simdgroup publishes its total to slot [row][sgitg].
    for (short row = 0; row < NR0; ++row) {
        if (tiisg == 0) {
            shmem_f32[row * NW + sgitg] = sumf[row];
        }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: sum the per-simdgroup totals and write one float per valid row.
    device float * dst_f32 = (device float *)dst
        + (uint64_t)im * (uint64_t)args.ne0 * (uint64_t)args.ne1
        + (uint64_t)r1 * (uint64_t)args.ne0;

    for (short row = 0; row < NR0 && r0 + row < args.ne01; ++row) {
        float tot = simd_sum(shmem_f32[row * NW + tiisg]);
        if (tiisg == 0 && sgitg == 0) {
            dst_f32[r0 + row] = tot;
        }
    }
}