// mlx-native 0.6.7
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// Documentation
#include <metal_stdlib>
using namespace metal;

/// Q4_0 quantize-dequantize round-trip in fp32.
///
/// For each block of QK4_0=32 fp32 input values, computes the GGUF
/// Q4_0 quant→dequant round-trip and writes the rounded fp32 values
/// back to the output buffer (same shape as input).  Intended to be
/// byte-identical to `quantize_row_q4_0 → dequantize_row_q4_0` from
/// `hf2q::quantize::q_legacy`.
///
/// Per-block formula (matches q_legacy.rs:329 + 380):
///   max  = signed value at the position with largest |.|  (ties go to LOWER index;
///          a block with no |v| > 0 yields +0.0, the CPU loop's initial `max`)
///   d    = max / -8.0                  (fp32)
///   id   = (d == 0) ? 0 : 1/d          (fp32)
///   d_h  = float(half(d))              (f16 round-trip — dequantize uses this)
///   For each element v:
///     q  = clamp((v * id + 8.5) as int, 0, 15)
///     dq = (q - 8) * d_h
///
/// Buffer layout:
///   buffer(0): input  — float[num_blocks * 32]
///   buffer(1): output — float[num_blocks * 32]
///
/// Threadgroup: (32, 1, 1) — one block per threadgroup.  The reduction
/// assumes exactly 32 threads per threadgroup.
/// Grid threadgroups: (num_blocks, 1, 1)
///
/// Threadgroup shared memory: 64 * sizeof(float) = 256 bytes
///   layout: [amax_arr (32 floats), max_arr (32 floats)]
///
/// NOTE(review): exact agreement with the CPU also relies on `1.0f / d` being
/// correctly rounded and `v * id + 8.5f` not being FMA-contracted; Metal enables
/// fast math by default — confirm the pipeline's compile options.
kernel void qdq_q4_0_f32(
    device const float *input  [[buffer(0)]],
    device float       *output [[buffer(1)]],
    uint  block_idx [[threadgroup_position_in_grid]],
    uint  tid       [[thread_index_in_threadgroup]],
    threadgroup float *shared  [[threadgroup(0)]]
) {
    threadgroup float *amax_arr = shared;          // [32]
    threadgroup float *max_arr  = shared + 32;     // [32]

    const uint base = block_idx * 32u + tid;
    const float v  = input[base];
    const float av = fabs(v);

    // Seed each lane the way the CPU loop would treat it: an element only
    // becomes the running max if `av > amax` with amax starting at 0.0, so
    // lanes with |v| == 0 (including -0.0) or NaN fall back to the CPU's
    // initial state (amax = 0.0, max = 0.0).
    // NOTE(review): initial max = 0.0 follows ggml's quantize_row_q4_0_ref —
    // confirm against q_legacy.rs.
    const bool candidate = av > 0.0f;
    amax_arr[tid] = candidate ? av : 0.0f;
    max_arr[tid]  = candidate ? v  : 0.0f;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Pairwise reduction over CONTIGUOUS index ranges: after the step with
    // width `w`, lane `tid` (a multiple of 2w) holds the winner of
    // [tid, tid + 2w).  The left range always contains strictly lower indices
    // than the right range, so keeping LEFT on a tie yields the lowest index
    // overall — matching the CPU's strict `>`
    // (q_legacy.rs:351: `if av > amax { amax = av; max = v; }`).
    // A strided tree (tid vs tid+16, +8, ...) does not guarantee this: equal
    // |v| at indices 8 and 16 would resolve to index 16.
    for (uint w = 1u; w < 32u; w <<= 1u) {
        if ((tid & (2u * w - 1u)) == 0u) {
            const float r_amax = amax_arr[tid + w];
            if (r_amax > amax_arr[tid]) {
                amax_arr[tid] = r_amax;
                max_arr[tid]  = max_arr[tid + w];
            }
        }
        // Reached by every lane: the loop bounds are uniform across the threadgroup.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const float max_val = max_arr[0];
    const float d  = max_val / -8.0f;
    const float id = (d == 0.0f) ? 0.0f : (1.0f / d);
    // f16 round-trip — dequant reads d as f16 (q_legacy.rs:358 + d() at decode).
    const float d_h = float(half(d));

    // Rust's `as i32` truncates toward zero whereas floor() rounds down; they
    // differ only for negative arguments, which all clamp to 0 below.
    int q = int(floor(v * id + 8.5f));
    q = clamp(q, 0, 15);
    output[base] = float(q - 8) * d_h;
}

/// Q8_0 quantize-dequantize round-trip in fp32.
///
/// Per-block formula (matches q_legacy.rs:149 + 187):
///   amax = max(|v|)                     (fp32)
///   d    = amax / 127.0                 (fp32)
///   id   = (d == 0) ? 0 : 1/d           (fp32)
///   d_h  = float(half(d))               (f16 round-trip — dequant uses this)
///   For each element v:
///     q  = clamp(round(v * id) as int, -128, 127)   (round = half away from zero)
///     dq = q * d_h
///
/// Buffer layout:
///   buffer(0): input  — float[num_blocks * 32]
///   buffer(1): output — float[num_blocks * 32]
///
/// Threadgroup: (32, 1, 1) — one block per threadgroup.  The reduction
/// assumes exactly 32 threads per threadgroup.
/// Grid threadgroups: (num_blocks, 1, 1)
///
/// Threadgroup shared memory: 32 * sizeof(float) = 128 bytes
///
/// NOTE(review): exact agreement with the CPU also relies on `amax / 127.0f`
/// and `1.0f / d` being correctly rounded; Metal enables fast math by
/// default — confirm the pipeline's compile options.
kernel void qdq_q8_0_f32(
    device const float *input  [[buffer(0)]],
    device float       *output [[buffer(1)]],
    uint  block_idx [[threadgroup_position_in_grid]],
    uint  tid       [[thread_index_in_threadgroup]],
    threadgroup float *shared  [[threadgroup(0)]]
) {
    const uint base = block_idx * 32u + tid;
    const float v = input[base];
    shared[tid] = fabs(v);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction for amax (max of |v|).  Only the magnitude is needed, so
    // the order in which lanes are combined does not affect the result.
    for (uint stride = 16u; stride > 0u; stride >>= 1u) {
        if (tid < stride) {
            const float r = shared[tid + stride];
            const float l = shared[tid];
            if (r > l) {
                shared[tid] = r;
            }
        }
        // Reached by every lane: the loop bounds are uniform across the threadgroup.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const float amax = shared[0];
    const float d   = amax / 127.0f;
    const float id  = (d == 0.0f) ? 0.0f : (1.0f / d);
    const float d_h = float(half(d));

    // CPU uses `(v * id).round() as i32` then `clamp(-128, 127) as i8`.
    // Rust's `f32::round` rounds halfway cases away from zero, and so does
    // Metal's `round()` (`rint` is the round-half-to-even variant).  Emulating
    // it with `floor(x + 0.5)` is NOT equivalent: for x = 0.49999997f the sum
    // rounds up to 1.0f in fp32, giving 1 where Rust gives 0.
    int q = int(round(v * id));
    q = clamp(q, -128, 127);
    output[base] = float(q) * d_h;
}