mlx-native 0.9.0

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
#include <metal_stdlib>
using namespace metal;

/// ADR-020 iter-15b — tiled variant of `qmm_affine_t_f32`.
///
/// Same math + I/O contract as iter-15's per-element kernel, but
/// shifts the inner reduction onto threadgroup-shared tiles so reads
/// of X, q_int and (scales, biases) are amortized across the 16x16
/// threadgroup.  The per-element kernel reads X[m, k] independently for
/// each (m, n) pair; the tiled variant loads X[BM, BK] once per K-tile
/// into threadgroup memory, where each X row is then read by all BN
/// threads that compute outputs in that row.
///
/// Tile geometry (compile-time constants, must agree with the host
/// dispatcher's threadgroup size and shared-memory allocation):
///   BM = 16   output rows per threadgroup
///   BN = 16   output cols per threadgroup
///   BK = 32   K-tile depth = group_size (REQUIRED — see below)
///
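/// **Constraint**: this kernel REQUIRES `group_size == 32`.  This
/// makes BK align exactly with one (scales, biases) pair per
/// K-tile per weight row (i.e. per output column n), eliminating the
/// inner per-group fetch.  The host validates this; for
/// `group_size != 32` callers fall back to the iter-15 per-element
/// kernel.  K is likewise expected to be a multiple of 32: the kernel
/// sweeps floor(K / 32) K-tiles, so any remainder columns would not
/// contribute to the output.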
///
/// Threadgroup-shared budget (sizes only; in memory the three f32
/// regions x_tile, s_tile, b_tile come first and q_tile follows them):
///   x_tile : f32[BM][BK] = 16 * 32 * 4 =  2048 bytes
///   q_tile :  u8[BN][BK] = 16 * 32 * 1 =   512 bytes
///   s_tile : f32[BN]      =      16 * 4 =    64 bytes
///   b_tile : f32[BN]      =      16 * 4 =    64 bytes
///   total                                =  2688 bytes
///
/// Threadgroup: 16x16 = 256 threads.  Each thread computes one
/// output element y[m, n].
///
/// Performance vs iter-15 per-element kernel: ~2-5× depending on
/// matrix size (smaller speedup at small M/N where TG-overhead
/// dominates).  The simdgroup-MMA optimization (mlx::steel::BlockMMA
/// equivalent) is queued for iter-15c.

// Tile geometry for qmm_affine_t_f32_tiled.  These values must agree
// with the host dispatcher's threadgroup size and threadgroup-memory
// allocation; BK also has to equal the quantization group_size.
constant constexpr uint BM = 16u;  // output rows handled per threadgroup
constant constexpr uint BN = 16u;  // output columns handled per threadgroup
constant constexpr uint BK = 32u;  // K-tile depth (== group_size)

/// Tiled affine-dequantize matmul with transposed weights:
///
///   y[m, n] = sum_k x[m, k] * (float(q_int[n, k]) * scales[n, k / BK]
///                                 + biases[n, k / BK])
///
/// Buffer layouts, as implied by the indexing below (all row-major):
///   x      : f32 [M, K]
///   q_int  : u8  [N, K]          one quantized value per byte
///   scales : f32 [N, K / BK]
///   biases : f32 [N, K / BK]
///   y      : f32 [M, N]          one element written per thread
///   meta   : u32 [M, N, K, group_size]; meta[3] is not read here
///
/// Dispatch contract:
///   * threadgroup = 16 x 16 threads (BM x BN); the lane mapping below
///     also needs BM == BN, which is checked at compile time;
///   * bid.x selects a BM-row block of y and bid.y a BN-column block;
///     out-of-range rows/columns are masked, so partial edge tiles work;
///   * threadgroup(0) must hold (BM*BK + 2*BN) floats + BN*BK bytes
///     = 2688 bytes.
///
/// Preconditions not checked by the shader: group_size == BK (32), and
/// K % BK == 0 -- only floor(K / BK) K-tiles are swept, so remainder
/// columns would be dropped.
///
/// Lane mapping: Metal linearizes thread_position_in_threadgroup as
/// tid.x + tid.y * width and forms SIMD-groups from consecutive linear
/// indices, so tid.x is the fast-varying lane.  Therefore:
///   * the cooperative tile loads are walked with that hardware linear
///     index, so (with 32-lane SIMD-groups) each SIMD-group reads one
///     contiguous 32-element row of X / q_int per slot;
///   * tid.x is mapped to the output column n, so adjacent lanes store
///     to adjacent addresses of y.
/// Each output element is accumulated by a single thread over k in
/// ascending order, exactly as before the remapping.
kernel void qmm_affine_t_f32_tiled(
    device const float    *x          [[buffer(0)]],
    device const uchar    *q_int      [[buffer(1)]],
    device const float    *scales     [[buffer(2)]],
    device const float    *biases     [[buffer(3)]],
    device float          *y          [[buffer(4)]],
    device const uint     *meta       [[buffer(5)]],  // [M, N, K, group_size]
    threadgroup float     *shmem      [[threadgroup(0)]],
    uint2 tid                         [[thread_position_in_threadgroup]],
    uint2 bid                         [[threadgroup_position_in_grid]]
) {
    // Work split derived from the tile constants (instead of hard-coded
    // 256 / 2) so a change to BM/BN/BK cannot silently break the loads.
    constexpr uint TG_THREADS = BM * BN;                 // 256 threads
    constexpr uint X_SLOTS    = (BM * BK) / TG_THREADS;  // floats loaded per thread
    constexpr uint Q_SLOTS    = (BN * BK) / TG_THREADS;  // bytes loaded per thread
    static_assert(BM == BN,
                  "lane mapping uses tid.y for rows and tid.x for columns of a square threadgroup");
    static_assert(X_SLOTS * TG_THREADS == BM * BK,
                  "x tile must split evenly across the threadgroup");
    static_assert(Q_SLOTS * TG_THREADS == BN * BK,
                  "q tile must split evenly across the threadgroup");
    static_assert(BN <= TG_THREADS,
                  "s/b tiles are loaded by the first BN threads");

    const uint M = meta[0];
    const uint N = meta[1];
    const uint K = meta[2];
    // group_size (meta[3]) is host-validated to equal BK and is not read.

    // Threadgroup-shared layout:
    //   shmem[0 .. BM*BK)          = x_tile  (f32, BM*BK floats)
    //   shmem[BM*BK .. BM*BK+BN)   = s_tile  (f32, BN floats)
    //   shmem[BM*BK+BN .. +BN)     = b_tile  (f32, BN floats)
    //   tail: q_tile (u8, BN*BK bytes) packed after the float regions
    threadgroup float *x_tile = shmem;
    threadgroup float *s_tile = shmem + BM * BK;
    threadgroup float *b_tile = s_tile + BN;
    threadgroup uchar *q_tile = (threadgroup uchar *)(b_tile + BN);

    // Position inside the output tile: fast lane (tid.x) -> column.
    const uint local_row = tid.y;   // 0 .. BM-1
    const uint local_col = tid.x;   // 0 .. BN-1

    const uint m = bid.x * BM + local_row;
    const uint n = bid.y * BN + local_col;
    const bool in_bounds = (m < M) && (n < N);

    const uint groups_per_row = K / BK;     // = K / group_size (floor)

    // Hardware linear thread index within the threadgroup (0 .. TG_THREADS-1).
    const uint tlin = tid.y * BN + tid.x;

    float acc = 0.0f;

    // Sweep K in BK-sized tiles.
    for (uint kt = 0; kt < groups_per_row; kt++) {
        const uint k_base = kt * BK;

        // ---- Cooperative loads ----
        // X tile [BM][BK]: consecutive lanes fill consecutive elements of
        // one row, giving contiguous global reads.  (k_base + col) < K is
        // always true with the floor-divided tile count; kept as a guard.
        for (uint slot = 0; slot < X_SLOTS; slot++) {
            const uint lin_idx = slot * TG_THREADS + tlin;
            const uint row = lin_idx / BK;
            const uint col = lin_idx % BK;
            const uint mm = bid.x * BM + row;
            x_tile[lin_idx] = (mm < M && (k_base + col) < K)
                ? x[mm * K + k_base + col]
                : 0.0f;
        }

        // q_int tile [BN][BK]: same contiguous walk over the weight rows.
        for (uint slot = 0; slot < Q_SLOTS; slot++) {
            const uint lin_idx = slot * TG_THREADS + tlin;
            const uint row = lin_idx / BK;
            const uint col = lin_idx % BK;
            const uint nn = bid.y * BN + row;
            q_tile[lin_idx] = (nn < N && (k_base + col) < K)
                ? q_int[nn * K + k_base + col]
                : (uchar)0;
        }

        // s/b tiles: one (scale, bias) per weight row for this K-tile,
        // loaded by the first BN threads in linear order.
        if (tlin < BN) {
            const uint nn = bid.y * BN + tlin;
            const uint sb_idx = nn * groups_per_row + kt;
            s_tile[tlin] = (nn < N) ? scales[sb_idx] : 0.0f;
            b_tile[tlin] = (nn < N) ? biases[sb_idx] : 0.0f;
        }

        // Tiles must be fully written before any thread reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ---- Inner reduction across BK ----
        if (in_bounds) {
            const float s = s_tile[local_col];
            const float b = b_tile[local_col];
            // X row shared by every thread with this local_row.
            threadgroup const float *x_row = x_tile + local_row * BK;
            // Weight row shared by every thread with this local_col.
            threadgroup const uchar *q_row = q_tile + local_col * BK;

            for (uint k = 0; k < BK; k++) {
                const float xv = x_row[k];
                const float qv = float(q_row[k]);
                const float wv = qv * s + b;
                acc += xv * wv;
            }
        }

        // All reads must finish before the next iteration overwrites tiles.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (in_bounds) {
        y[m * N + n] = acc;
    }
}