// mlx-native 0.8.1
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
// dequant_to_f16.metal — whole-tensor dequantization from block-quantized
// formats (Q4_0, Q5_K, Q6_K, Q8_0, etc.) to F16 storage.
//
// ADR-029 iter-28 H29 — peer (llama.cpp on Apple Silicon) NEVER dispatches
// quantized mat-mat kernels for gemma4 attn weights (Q6_K).  Instead it
// pre-dequantizes Q6_K → F16 once at model load, then runs the F16-input
// `kernel_mul_mm_f16_f32_*` for every dense attn dispatch.  This trades
// ~1 GB of extra resident memory for 2-3× faster per-call dense MM at
// prefill — bandwidth-friendly F16 reads instead of per-call dequant work
// inside the matmul kernel.
//
// hf2q has 128 GB unified memory on the target M5 Max device; the memory
// trade is favorable.  This shader produces the F16 shadow buffer that
// the load-path materializes once per quantized weight tensor.
//
// Design (mirrors peer's `kernel_get_rows_q` at ggml-metal.metal:9164-9191):
//   * One thread per 16-element group.
//   * Each thread reads its corresponding `(block_idx, il)` slot:
//       block_idx = ind / nl
//       il        = ind % nl
//   * Calls `dequantize_func(block_ptr + block_idx, il, temp)` which
//     produces a `half4x4` (16 halfs).
//   * Writes the half4x4 to `dst[ind]` (treated as `device half4x4 *`).
//
// Dispatch: total threads = n_elements / 16.  Pick threadgroup size to
// saturate (e.g. 256 threads/tg).
//
// `nl` is the type-specific QK_NL constant (2 for legacy block-quant types
// at QK=32; 16 for K-quants at QK_K=256).  See callers for the per-type
// instantiation list.
//
// Coherence: F16 storage of dequantized Q6_K introduces F16-rounding drift
// vs the per-call float-precision dequant the V1/V2 matmul kernel does.
// Empirically peer ships this in production on gemma4-26B with no observable
// quality regression: F16's 10-bit mantissa (relative spacing ~1e-3) is far
// finer than Q6_K's 6-bit quantization step, so the added rounding error sits
// well below the quantization noise floor.  Sourdough byte-identity is NOT
// expected; the coherence gate is "fluent output at temp=0 across regimes".

#include <metal_stdlib>
using namespace metal;

// ---- GGML block sizes (must match quantized_matmul_mm_tensor.metal) ----
// Elements per K-quant super-block (Q4_K / Q5_K / Q6_K).
#define QK_K  256
// Elements per legacy 32-wide block (Q4_0; also sizes Q5_1 / IQ4_NL qs[]).
#define QK4_0 32
// Elements per Q8_0 block.
#define QK8_0 32
// Bytes of packed 6-bit (scale, min) pairs in Q4_K / Q5_K.
#define K_SCALE_SIZE 12

// ---- Block struct definitions (byte-identical to quantized_matmul_mm_tensor.metal) ----

// Q4_0: 32 weights = one f16 scale + 16 bytes of packed 4-bit quants.
struct block_q4_0 {
    half    d;               // block scale
    uint8_t qs[QK4_0 / 2];   // two 4-bit quants per byte
};

// Q8_0: 32 weights = one f16 scale + 32 signed 8-bit quants.
struct block_q8_0 {
    half   d;            // block scale
    int8_t qs[QK8_0];    // one signed quant per weight
};

// Q6_K: 256 weights per super-block; the scale `d` comes LAST in memory.
struct block_q6_K {
    uint8_t ql[QK_K/2];       // low 4 bits of each 6-bit quant
    uint8_t qh[QK_K/4];       // high 2 bits, four quants per byte
    int8_t  scales[QK_K/16];  // signed sub-scale per 16 weights
    half    d;                // super-block scale
};

// Q5_K: 256 weights per super-block, 8 sub-blocks of 32 with 6-bit scale/min.
struct block_q5_K {
    half    d;                     // multiplier for the sub-block scales
    half    dmin;                  // multiplier for the sub-block mins
    uint8_t scales[K_SCALE_SIZE];  // 8 packed 6-bit (scale, min) pairs
    uint8_t qh[QK_K/8];            // 5th bit of each quant
    uint8_t qs[QK_K/2];            // low 4 bits, two quants per byte
};

// Q4_K: 256 weights per super-block, 8 sub-blocks of 32 with 6-bit scale/min.
struct block_q4_K {
    half    d;                     // multiplier for the sub-block scales
    half    dmin;                  // multiplier for the sub-block mins
    uint8_t scales[K_SCALE_SIZE];  // 8 packed 6-bit (scale, min) pairs
    uint8_t qs[QK_K/2];            // 4-bit quants, two per byte
};

// Q5_1: 32 weights, value = d * q + m with a 5-bit unsigned quant q.
struct block_q5_1 {
    half    d;               // block scale
    half    m;               // block offset (added after scaling)
    uint    qh;              // 5th bit of each of the 32 quants (bit e <-> weight e)
    uint8_t qs[QK4_0 / 2];   // low 4 bits, two quants per byte
};

// IQ4_NL: 32 weights, each a 4-bit index into the kvalues_iq4nl codebook.
struct block_iq4_nl {
    half    d;               // block scale
    uint8_t qs[QK4_0 / 2];   // two 4-bit codebook indices per byte
};

// IQ4_NL non-linear codebook: a 4-bit index selects one of these 16 signed
// levels, which is then multiplied by the block scale.
constant int8_t kvalues_iq4nl[16] = {
    -127, -104,  -83,  -65,
     -49,  -35,  -22,  -10,
       1,   13,   25,   38,
      53,   69,   89,  113,
};

// Unpack the 6-bit (scale, min) pair for sub-block `j` from the 12-byte packed
// K-quant scales array `q`, offset by `k` bytes.
//   j < 4 : scale / min are the low 6 bits of bytes j and j+4.
//   j >= 4: the low 4 bits come from byte j+4 (scale = low nibble, min = high
//           nibble); the top 2 bits are the spare high bits of bytes j-4 / j.
// Returns {scale, min}.
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
    if (j < 4) {
        const uchar scale = q[j + k] & 63;
        const uchar mn    = q[j + 4 + k] & 63;
        return uchar2{scale, mn};
    }
    const uchar packed = q[j + 4 + k];
    const uchar scale  = (packed & 0xF) | ((q[j - 4 + k] & 0xc0) >> 2);
    const uchar mn     = (packed >> 4)  | ((q[j + k] & 0xc0) >> 2);
    return uchar2{scale, mn};
}

// ---- Dequantize helpers (counterparts of those in quantized_matmul_mm_tensor.metal) ----
// Each function dequantizes ONE 16-element sub-group (index `il`) of one block_q.

// Q4_0 → 16 values.  Byte b of qs holds weight b in its low nibble and weight
// b+16 in its high nibble; il = 0 selects the low nibbles, il = 1 the high ones.
// Bytes are read in pairs as uint16.  Nibbles are masked in place, without
// shifting them down, and the resulting x16 / x256 position weights are folded
// into the two scales instead.  Result: d * (q - 8).
template <typename type4x4>
void dq_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
    const half dh = xb->d;

    const ushort lo_mask = (il != 0) ? 0x00F0 : 0x000F;
    const ushort hi_mask = lo_mask << 8;

    const float scale_lo = (il != 0) ? (float)(dh / 16.h) : (float)dh;
    const float scale_hi = scale_lo / 256.f;
    const float offset   = -8.h * dh;

    device const uint16_t * pairs = ((device const uint16_t *)xb) + 1;  // skip `d`

    float4x4 acc;
    for (int p = 0; p < 8; ++p) {
        const uint16_t w = pairs[p];
        acc[p >> 1][(p & 1) * 2]     = scale_lo * (w & lo_mask) + offset;
        acc[p >> 1][(p & 1) * 2 + 1] = scale_hi * (w & hi_mask) + offset;
    }
    reg = (type4x4) acc;
}

// Q8_0 → 16 values: weights [16*il, 16*il + 16) of the block, each d * q.
// The f16 scale is widened to float BEFORE the multiply, as in every other
// helper here.  Multiplying in `half` first would round in half precision
// even when the caller instantiates with a float4x4 destination.
template <typename type4x4>
void dq_q8_0(device const block_q8_0 * xb, short il, thread type4x4 & reg) {
    device const int8_t * qs = xb->qs + 16*il;
    const float d = xb->d;
    float4x4 reg_f;
    for (int i = 0; i < 16; i++) {
        reg_f[i/4][i%4] = d * (float)qs[i];
    }
    reg = (type4x4) reg_f;
}

// Q5_1 → 16 values, each d * q + m with q a 5-bit unsigned quant.
// The low 4 bits come from qs, laid out as in Q4_0: il = 0 takes the low
// nibbles (weights 0..15) and il = 1 the high nibbles (weights 16..31).
// The 5th bit of weight e is bit e of the 32-bit `qh` word; it is moved into
// bit position 4 (0x10) and OR-ed in.
template <typename type4x4>
void dq_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
    const bool upper = il != 0;

    const float    scale     = xb->d;
    const float    bias      = xb->m;
    const uint32_t high_bits = xb->qh;

    const ushort nib_mask  = upper ? 0x00F0 : 0x000F;
    const int    nib_shift = upper ? 4 : 0;
    const int    hb_shift  = upper ? 12 : 0;  // lands bit 16+e at position 4
    const int    hb_lift   = upper ?  0 : 4;  // lifts bit e up to position 4

    device const uint16_t * words = ((device const uint16_t *)xb) + 4;  // skip d, m, qh

    float4x4 acc;
    for (int p = 0; p < 8; ++p) {
        const uint16_t w = words[p];
        const uint8_t h0 = ((high_bits >> (hb_shift + 2*p    )) << hb_lift) & 0x10;
        const uint8_t h1 = ((high_bits >> (hb_shift + 2*p + 1)) << hb_lift) & 0x10;
        const int32_t v0 = (((w       & nib_mask) >> nib_shift) | h0);
        const int32_t v1 = ((((w >> 8) & nib_mask) >> nib_shift) | h1);
        acc[p >> 1][(p & 1) * 2]     = scale * v0 + bias;
        acc[p >> 1][(p & 1) * 2 + 1] = scale * v1 + bias;
    }
    reg = (type4x4) acc;
}

// IQ4_NL → 16 values: d * kvalues_iq4nl[idx] for each 4-bit index.
// il = 0 takes the low nibbles (weights 0..15), il = 1 the high nibbles
// (weights 16..31).  Each output row gathers 4 consecutive qs bytes into one
// 32-bit word; after the nibble shift and 0x0f0f0f0f mask, byte `col` of
// that word is the index for column `col`.
template <typename type4x4>
void dq_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
    const float d = xb->d;
    const int nib_shift = 4 * il;
    device const uint16_t * halves = (device const uint16_t *)xb->qs;

    for (int row = 0; row < 4; ++row) {
        const uint32_t word = (uint32_t)halves[2*row] | ((uint32_t)halves[2*row + 1] << 16);
        const uint32_t idx4 = (word >> nib_shift) & 0x0f0f0f0f;
        for (int col = 0; col < 4; ++col) {
            reg[row][col] = d * (float)kvalues_iq4nl[(idx4 >> (8*col)) & 0xFF];
        }
    }
}

// Q5_K → 16 values: weights [16*il, 16*il + 16) of the 256-weight super-block.
//
// Layout: 4 chunks of 64 weights.  Chunk c = il/4 owns qs[32c .. 32c+32):
// low nibbles hold weights [64c, 64c+32), high nibbles hold [64c+32, 64c+64).
// Each 32-weight sub-block sb = il/2 (0..7) has its own 6-bit (scale, min).
// The 5th bit of every quant lives in qh, which is indexed by byte position
// only (shared by all chunks); bit sb of qh[l] belongs to sub-block sb.
//
// Value = d*scale[sb] * (nibble + 16*hbit) - dmin*min[sb].
template <typename type4x4>
void dq_q5_K(device const block_q5_K * xb, short il, thread type4x4 & reg) {
    device const uchar * q  = xb->qs + 32*(il/4) + 16*(il&1);
    device const uchar * qh = xb->qh + 16*(il&1);

    const short sb = il / 2;
    // k must be 0: `sb` already addresses the 12-byte scales array directly.
    // A non-zero k reads the wrong bytes (and past the array for sb < 4).
    const uchar2 sc = get_scale_min_k4_just2(sb, 0, xb->scales);
    const float dl = (float)xb->d    * sc[0];
    const float ml = (float)xb->dmin * sc[1];

    // Odd sub-blocks (il & 2 set) are stored in the high nibble: shift by 4.
    const short shift = (il & 2) ? 4 : 0;
    const uchar mask  = (uchar)(1 << sb);

    float4x4 reg_f;
    for (int i = 0; i < 16; ++i) {
        const int q5 = ((q[i] >> shift) & 0xF) + ((qh[i] & mask) ? 16 : 0);
        reg_f[i/4][i%4] = dl * q5 - ml;
    }
    reg = (type4x4) reg_f;
}

// Q4_K → 16 values: weights [16*il, 16*il + 16) of the 256-weight super-block.
//
// Layout: 4 chunks of 64 weights.  Chunk c = il/4 owns qs[32c .. 32c+32):
// low nibbles hold weights [64c, 64c+32), high nibbles hold [64c+32, 64c+64).
// Each 32-weight sub-block sb = il/2 (0..7) has its own 6-bit (scale, min).
//
// Value = d*scale[sb] * nibble - dmin*min[sb].
template <typename type4x4>
void dq_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
    device const uchar * q = xb->qs + 32*(il/4) + 16*(il&1);

    const short sb = il / 2;
    // k must be 0: `sb` already addresses the 12-byte scales array directly.
    // A non-zero k reads the wrong bytes (and past the array for sb < 4).
    const uchar2 sc = get_scale_min_k4_just2(sb, 0, xb->scales);
    const float dl = (float)xb->d    * sc[0];
    const float ml = (float)xb->dmin * sc[1];

    // Odd sub-blocks (il & 2 set) are stored in the high nibble: shift by 4.
    // (Shifting a byte by 8 would always yield 0.)
    const short shift = (il & 2) ? 4 : 0;

    float4x4 reg_f;
    for (int i = 0; i < 16; ++i) {
        reg_f[i/4][i%4] = dl * ((q[i] >> shift) & 0xF) - ml;
    }
    reg = (type4x4) reg_f;
}

// ---- Q6_K: dequant in LINEAR element order (ADR-029 iter-29) ----
// Mirrors the CPU reference dequantize_q6_k at
// /opt/mlx-native/src/gguf/mod.rs:667-720 (loops at :702-715).
//
// Q6_K super-block (210 bytes → 256 weights):
//   ql[128]     low 4 bits of each quant
//   qh[64]      high 2 bits, four quants per byte
//   scales[16]  int8 sub-scales
//   d           f16 super-block scale
//
// The 256 weights split into two 128-weight halves h ∈ {0, 1}; half h reads
// ql[64h..], qh[32h..] and scales[8h..].  Inside a half, byte position
// l ∈ [0, 32) yields four weights:
//   group 0 → pos l      : low  nibble of ql[l]    | bits 0-1 of qh[l]
//   group 1 → pos l + 32 : low  nibble of ql[l+32] | bits 2-3 of qh[l]
//   group 2 → pos l + 64 : high nibble of ql[l]    | bits 4-5 of qh[l]
//   group 3 → pos l + 96 : high nibble of ql[l+32] | bits 6-7 of qh[l]
// The 6-bit value is re-centred to [-32, 31] and multiplied by
// d * scales[8h + 2*group + l/16], i.e. 2 sub-scales per group per half.
//
// Strip il ∈ [0, 16) covers weights [16*il, 16*il + 16).  Strips are
// 16-aligned, so a strip never straddles a 32-weight group: group, l/16 and
// the sub-scale are uniform across the strip and are computed once.  Storing
// with `*((device half4x4 *)dst + tid) = reg` therefore produces the row-major,
// linear-K F16 tensor that `dense_matmul_f16_f32_tensor` expects.  Peer's
// `kernel_get_rows_q6_K` ordering would need a runtime re-pack; this kernel
// does that work once at load time instead.
template <typename type4x4>
void dq_q6_K(device const block_q6_K * xb, short il, thread type4x4 & reg) {
    const float d_all = (float)xb->d;

    const short half_idx = il / 8;            // 0 or 1
    const short strip    = (il % 8) * 16;     // 0, 16, ..., 112 within the half
    const short group    = strip / 32;        // 0..3, fixed for the whole strip
    const short l0       = strip % 32;        // 0 or 16
    const short is       = l0 / 16;           // which of the group's 2 sub-scales

    // Groups 1 and 3 read the second 32 bytes of this half's ql.
    device const uint8_t * ql = (device const uint8_t *)xb->ql + 64 * half_idx + ((group & 1) ? 32 : 0);
    device const uint8_t * qh = (device const uint8_t *)xb->qh + 32 * half_idx;
    device const int8_t  * sc = (device const int8_t  *)xb->scales + 8 * half_idx;

    const float   scale    = d_all * (float)(sc[group * 2 + is]);
    const bool    high_nib = group >= 2;      // groups 2,3 use the high nibble
    const uint8_t qh_shift = 2 * group;       // 2-bit field inside qh[l]

    float4x4 reg_f;
    for (short i = 0; i < 16; ++i) {
        const short   l   = l0 + i;
        const uint8_t lo4 = high_nib ? (ql[l] >> 4) : (ql[l] & 0xF);
        const uint8_t hi2 = (qh[l] >> qh_shift) & 0x3;
        const int     q   = (int)((int8_t)((lo4 | (hi2 << 4)) - 32));
        reg_f[i / 4][i % 4] = scale * (float)q;
    }
    reg = (type4x4) reg_f;
}

// ---- The whole-tensor dequant kernel ----
//
// One thread per 16-element group.  Launch at least n_groups threads, where
// n_groups = n_blocks * nl = n_elements / 16; surplus threads return at once.
//
// For thread `tid`:
//   block  = tid / nl   (which quant block to read)
//   il     = tid % nl   (which 16-element group inside that block)
//   output = one half4x4 (16 halfs, 32 bytes) at dst + tid*16 halfs
//
// Buffers:
//   buffer(0) n_groups  total number of 16-element groups
//   buffer(1) src       packed array of block_q (raw bytes)
//   buffer(2) dst       F16 output, at least n_groups * 32 bytes
template<typename block_q, short nl,
         void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void hf2q_dequant_to_f16_impl(
        constant uint32_t & n_groups [[buffer(0)]],
        device const char * src      [[buffer(1)]],
        device       char * dst      [[buffer(2)]],
        uint tid [[thread_position_in_grid]]) {
    if (tid < n_groups) {
        device const block_q * blocks = (device const block_q *)src;
        device half4x4       * tiles  = (device half4x4 *)dst;

        half4x4 tile;
        dequantize_func(blocks + tid / nl, (short)(tid % nl), tile);
        tiles[tid] = tile;
    }
}

// ---- Kernel instantiations (one per supported quant type) ----
//
// Each invocation emits one host-visible Metal function named HOST_NAME that
// dequantizes BLOCK_T blocks with DEQ_FN, NL 16-element groups per block.
#define HF2Q_DEQUANT_TO_F16_INSTANCE(HOST_NAME, BLOCK_T, NL, DEQ_FN)          \
    template [[host_name(HOST_NAME)]]                                         \
    kernel void hf2q_dequant_to_f16_impl<BLOCK_T, NL, DEQ_FN>(                \
        constant uint32_t &, device const char *, device char *, uint);

// 32-weight blocks: nl = 2 groups of 16.
HF2Q_DEQUANT_TO_F16_INSTANCE("hf2q_dequant_q4_0_to_f16",   block_q4_0,   2, dq_q4_0)
HF2Q_DEQUANT_TO_F16_INSTANCE("hf2q_dequant_q8_0_to_f16",   block_q8_0,   2, dq_q8_0)
HF2Q_DEQUANT_TO_F16_INSTANCE("hf2q_dequant_q5_1_to_f16",   block_q5_1,   2, dq_q5_1)
HF2Q_DEQUANT_TO_F16_INSTANCE("hf2q_dequant_iq4_nl_to_f16", block_iq4_nl, 2, dq_iq4_nl)

// 256-weight K-quant super-blocks: nl = 16 groups of 16.
HF2Q_DEQUANT_TO_F16_INSTANCE("hf2q_dequant_q4_K_to_f16",   block_q4_K,  16, dq_q4_K)
HF2Q_DEQUANT_TO_F16_INSTANCE("hf2q_dequant_q5_K_to_f16",   block_q5_K,  16, dq_q5_K)
HF2Q_DEQUANT_TO_F16_INSTANCE("hf2q_dequant_q6_K_to_f16",   block_q6_K,  16, dq_q6_K)

#undef HF2Q_DEQUANT_TO_F16_INSTANCE