// mlx-native 0.9.0
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// (Documentation)
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
using namespace metal;

/// ADR-020 iter-15c — simdgroup-MMA variant of `qmm_affine_t_f32`.
///
/// Same I/O contract as iter-15 / iter-15b: computes
/// `Y[m, n] = Σ_k X[m, k] · (q_int[n, k]·scales[n, g(k)] + biases[n, g(k)])`
/// for affine-quantized weight tensors, with `g(k) = k / 32`.  Uses Apple
/// GPU's hardware `simdgroup_matrix<float, 8, 8>` MMA instead of the scalar
/// per-thread inner loop in iter-15b.
///
/// ## Buffers (row-major, as indexed by this kernel)
///
///   * `x`       : `[M, K]` f32.
///   * `q_int`   : `[N, K]`, one `uchar` per quantized element.
///   * `scales`, `biases` : `[N, K / 32]` f32, one pair per 32-wide K group.
///   * `y`       : `[M, N]` f32 output.
///   * `meta`    : `[M, N, K, group_size]`; `group_size` (meta[3]) is not read.
///
/// ## Layout
///
///   * Each threadgroup: ONE simdgroup (32 threads).  All cooperative work
///     is indexed only by `thread_index_in_simdgroup`, so a larger
///     threadgroup would have several simdgroups racing on `shmem`.
///   * Each TG produces an 8×8 tile of `Y`.
///   * Grid: `(ceil(M/8), ceil(N/8), 1)` threadgroups; `tg_id.x` selects the
///     M tile and `tg_id.y` the N tile.
///
/// ## Algorithm
///
///   * Per K-step of `BK = 32` (= group_size, same constraint as 15b):
///     - Cooperatively load `X[8][32]` into `a_tile` (row-major, row stride
///       BK).  Each lane writes 8 floats: in pass `slot`, lane `t` writes
///       linear element `slot * 32 + t`, so every pass covers one contiguous
///       32-float row of the tile.
///     - Cooperatively dequantize `W[8][32]` into `b_tile` with the same
///       pattern; each element is `q · scale + bias` for its row's group.
///     - Inner unroll over `BK / 8 = 4` sub-tiles:
///         simdgroup_load `ma` from `a_tile`, `mb` (transposed) from
///         `b_tile`, then `simdgroup_multiply_accumulate(mc, ma, mb, mc)`.
///   * Final: `simdgroup_store(mc → Y[8][8])` for full tiles; edge tiles are
///     staged through threadgroup memory and copied element-wise with
///     bounds checks.
///
/// ## Performance expectation
///
/// Apple GPU hardware MMA does an 8x8x8 multiply-accumulate per simdgroup
/// per cycle.  iter-15b's 256-thread × 32-iter scalar inner loop is
/// compute-bound at ~2 TFLOPS on M5 Max.  Here each BK = 32 K-step costs
/// `BK / 8 = 4` MMAs per 8×8 output tile.  The original estimate is
/// `(32 × 4) = 128` MMA cycles per output tile vs `256 × 32 = 8192` scalar
/// cycles, expected to land as 3-4× wall-clock once threadgroup-launch +
/// load/store overhead amortizes.
/// NOTE(review): 8192 / 128 is 64×, but the estimate was previously
/// described as an "8× algorithmic improvement" — confirm the intended figure.
///
/// ## Constraints
///
///   * `group_size == 32` (= BK; one (scales, biases) pair per K-tile).
///     Not checked in-kernel.
///   * `K` divisible by `group_size` (host-validated; same as 15b).  Any
///     remainder `K % 32` would be skipped, not computed.
///   * `M, N` not required to be multiples of 8.  Out-of-range rows are
///     zero-filled on load, and write-back is bounds-checked.
///   * Row base offsets are formed in 64-bit (`ulong`), so `M·K`, `N·K` and
///     `M·N` may exceed 2^32 elements.  In-tile offsets stay 32-bit: they
///     are below `8·K` / `8·N`.
///   * Single simdgroup per TG (32 threads).  Threadgroup memory budget at
///     `[[threadgroup(0)]]`: `8*32*4 (a_tile) + 8*32*4 (b_tile) = 2048 bytes`.

constant constexpr uint BM = 8;
constant constexpr uint BN = 8;
constant constexpr uint BK = 32;

// The MMA sweep splits BK into 8-wide sub-tiles.  The cooperative loops
// split each tile evenly across the 32 lanes of the single simdgroup.  The
// partial-tile write-back reuses a_tile as 8x8 staging.
static_assert(BK % 8 == 0, "BK must be a multiple of the 8x8 MMA tile width");
static_assert((BM * BK) % 32 == 0 && (BN * BK) % 32 == 0 && (BM * BN) % 32 == 0,
              "tile element counts must divide evenly across 32 lanes");
static_assert(BM * BN <= BM * BK, "output staging must fit in a_tile");

kernel void qmm_affine_t_f32_simd(
    device const float    *x          [[buffer(0)]],
    device const uchar    *q_int      [[buffer(1)]],
    device const float    *scales     [[buffer(2)]],
    device const float    *biases     [[buffer(3)]],
    device float          *y          [[buffer(4)]],
    device const uint     *meta       [[buffer(5)]],  // [M, N, K, group_size]
    threadgroup float     *shmem      [[threadgroup(0)]],
    uint   tid_in_simd                [[thread_index_in_simdgroup]],
    uint3  tg_id                      [[threadgroup_position_in_grid]]
) {
    const uint M = meta[0];
    const uint N = meta[1];
    const uint K = meta[2];
    // meta[3] (group_size) is intentionally not read: the host validates
    // group_size == BK == 32.

    const uint groups_per_row = K / BK;

    // Threadgroup-shared layout:
    //   shmem[0 .. BM*BK)         = a_tile (X)
    //   shmem[BM*BK .. 2*BM*BK)   = b_tile (dequantized W)
    threadgroup float *a_tile = shmem;
    threadgroup float *b_tile = shmem + BM * BK;

    // Output tile origin in (M, N).
    const uint m_origin = tg_id.x * BM;
    const uint n_origin = tg_id.y * BN;

    // Per-tile base pointers.  The row offset is computed in 64 bits so that
    // operands with more than 2^32 elements do not wrap.  Offsets inside the
    // tile (< 8*K, < 8*N) safely stay 32-bit.  These pointers are only
    // dereferenced under the bounds checks below.
    device const float *x_tile  = x      + ulong(m_origin) * K;
    device const uchar *q_tile  = q_int  + ulong(n_origin) * K;
    device const float *sc_tile = scales + ulong(n_origin) * groups_per_row;
    device const float *bi_tile = biases + ulong(n_origin) * groups_per_row;
    device float       *y_tile  = y      + ulong(m_origin) * N + n_origin;

    // MMA accumulator (8x8 floats — one full output tile).
    simdgroup_float8x8 mc = make_filled_simdgroup_matrix<float, 8>(0.f);

    // Sweep K in BK-sized chunks (= group_size).  Since kt < K / BK, every
    // k_base + c below is < K, so no K bounds check is needed.  Any tail
    // K % BK is never visited.
    for (uint kt = 0; kt < groups_per_row; kt++) {
        const uint k_base = kt * BK;

        // ---- Cooperative load: X tile [BM, BK] ----
        // Pass `slot` covers linear elements [slot*32, slot*32 + 32).  With
        // BK == 32 that is one tile row, so lanes read consecutive addresses.
        // Rows past M are zero-filled so they contribute nothing to the MMA.
        for (uint slot = 0; slot < (BM * BK) / 32; slot++) {
            const uint lin = slot * 32 + tid_in_simd;
            const uint r = lin / BK;
            const uint c = lin % BK;
            float v = 0.0f;
            if (m_origin + r < M) {
                v = x_tile[r * K + k_base + c];
            }
            a_tile[lin] = v;
        }

        // ---- Cooperative dequant: W tile [BN, BK] ----
        //   w_dq[n, c] = q_int[n, k_base + c] * scales[n, kt] + biases[n, kt]
        // Stored row-major in b_tile[r * BK + c]; rows past N are zero-filled.
        for (uint slot = 0; slot < (BN * BK) / 32; slot++) {
            const uint lin = slot * 32 + tid_in_simd;
            const uint r = lin / BK;
            const uint c = lin % BK;
            float v = 0.0f;
            if (n_origin + r < N) {
                const uint sb_idx = r * groups_per_row + kt;
                const float s = sc_tile[sb_idx];
                const float b = bi_tile[sb_idx];
                const float q = float(q_tile[r * K + k_base + c]);
                v = q * s + b;
            }
            b_tile[lin] = v;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ---- Inner MMA: BK / 8 sub-tiles of 8x8 along the K axis ----
        // Both tiles are row-major [8][BK].  simdgroup_load's
        // `elements_per_row` argument is the source row stride, so passing BK
        // reads the 8x8 sub-block starting at column ik*8 in place, with no
        // staging copy.
        //   ma = a_tile[m][ik*8 .. ik*8+8]        (m × k)
        //   mb = b_tile[n][ik*8 .. ik*8+8]^T      (k × n, via transpose=true)
        // so mc += ma · mb accumulates Y = X · W^T.
        for (uint ik = 0; ik < BK / 8; ik++) {
            simdgroup_float8x8 ma;
            simdgroup_float8x8 mb;
            simdgroup_load(ma, a_tile + ik * 8, BK, 0, false);
            simdgroup_load(mb, b_tile + ik * 8, BK, 0, true);
            simdgroup_multiply_accumulate(mc, ma, mb, mc);
        }

        // Tiles are overwritten by the next K-step (or by output staging).
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ---- Write-back ----
    // This condition depends only on tg_id, M and N, so it is uniform across
    // the threadgroup and the barriers in the else-branch are safe.
    const bool full_tile_in_bounds =
        (m_origin + BM <= M) && (n_origin + BN <= N);
    if (full_tile_in_bounds) {
        // Fast path: store the whole 8x8 tile directly with row stride N.
        simdgroup_store(mc, y_tile, N, 0, false);
    } else {
        // Partial-tile path: stage the 8x8 tile in threadgroup memory, then
        // copy it out element-wise with bounds checks.
        threadgroup float *staging = a_tile;  // a_tile is free after the K sweep
        threadgroup_barrier(mem_flags::mem_threadgroup);
        simdgroup_store(mc, staging, BN, 0, false);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // 32 lanes cooperatively copy BM*BN = 64 elements (2 each).
        for (uint slot = 0; slot < (BM * BN) / 32; slot++) {
            const uint lin = slot * 32 + tid_in_simd;
            const uint r = lin / BN;
            const uint c = lin % BN;
            if (m_origin + r < M && n_origin + c < N) {
                y_tile[r * N + c] = staging[r * BN + c];
            }
        }
    }
}