// mlx-native 0.3.2
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon.
// Documentation:
//! Hadamard-quantize KV cache kernel dispatch (ADR-007 Phase 1.1).
//!
//! Replaces `kv_cache_copy_batch_f32_to_f16` with a fused kernel that
//! applies a Fast Walsh-Hadamard Transform, extracts the L2 norm, and
//! quantizes each coordinate using the 4-bit Lloyd-Max codebook before
//! packing the indices as nibbles into the output buffer.
//!
//! Output format per head per token:
//! - `packed`: `[num_kv_heads, cache_capacity, head_dim/2]` u8 — nibble-packed 4-bit indices
//! - `norms`:  `[num_kv_heads, cache_capacity]` f32 — per-position L2 norm scalar

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{encode_threadgroups_with_args_and_shared, KernelArg};

/// MSL source for the `hadamard_quantize_kv` kernel (embedded at compile time).
///
/// Pulled in via `include_str!`, so the shader text is baked into the binary
/// and no runtime file lookup is needed; pass this to a registry via
/// [`register`] before requesting the pipeline.
pub static HADAMARD_QUANTIZE_KV_SHADER_SOURCE: &str =
    include_str!("../shaders/hadamard_quantize_kv.metal");

/// Register the `hadamard_quantize_kv` shader source with the given kernel registry.
///
/// Must be called once before `dispatch_hadamard_quantize_kv` can resolve its
/// pipelines from `registry`.
pub fn register(registry: &mut KernelRegistry) {
    let source_name = "hadamard_quantize_kv";
    registry.register_source(source_name, HADAMARD_QUANTIZE_KV_SHADER_SOURCE);
}

/// Parameters struct matching the `HadamardQuantizeParams` in the Metal shader.
///
/// `repr(C)` + `bytemuck::Pod` ensures the struct can be passed directly via
/// `set_bytes` without any marshalling.
///
/// Field order and widths must stay in sync with the shader-side struct
/// definition — `repr(C)` fixes the layout on the Rust side only.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct HadamardQuantizeParams {
    // Elements per KV head vector (power of two, validated at dispatch time).
    head_dim: u32,
    // Number of KV heads; one threadgroup is dispatched per head.
    num_kv_heads: u32,
    // Write position in the cache; wrapped modulo `cache_capacity` by the
    // kernel when `is_sliding` is nonzero.
    write_pos: u32,
    // Ring-buffer size (sliding window) or max sequence length (global cache).
    cache_capacity: u32,
    // Boolean carried as u32 (1 = sliding window, 0 = global) so the struct
    // stays `Pod`; `bool` is not a `Pod` type.
    is_sliding: u32,
}

/// Dispatch the fused Hadamard-quantize KV kernel on the GPU.
///
/// For each KV head vector (length `head_dim`) in the source:
/// 1. Applies in-place normalized FWHT (butterfly, in shared memory).
/// 2. Extracts the L2 norm of the rotated vector.
/// 3. Normalizes to unit sphere, then scales to N(0,1) domain.
/// 4. Finds the nearest 4-bit Lloyd-Max centroid for every coordinate.
/// 5. Packs pairs of 4-bit indices as nibbles into `packed`.
/// 6. Writes the L2 norm scalar to `norms`.
///
/// # Arguments
///
/// * `encoder`       — Command encoder to record the dispatch into.
/// * `registry`      — Kernel registry (must have `hadamard_quantize_kv` registered).
/// * `device`        — Metal device for pipeline compilation.
/// * `src`           — F32 buffer of shape `[num_kv_heads, head_dim]` (one token, all heads).
/// * `packed`        — u8 buffer of shape `[num_kv_heads, cache_capacity, head_dim/2]`.
/// * `norms`         — F32 buffer of shape `[num_kv_heads, cache_capacity]`.
/// * `num_kv_heads`  — Number of KV heads (threadgroups dispatched).
/// * `head_dim`      — Elements per head.  Must be a power of two in `[4, 4096]`.
/// * `cache_capacity`— Cache capacity (ring buffer size for sliding, max_seq_len for global).
/// * `write_pos`     — Write position in cache (the kernel applies modulo for sliding window).
/// * `is_sliding`    — If `true`, `write_pos` is wrapped modulo `cache_capacity`.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if:
/// - `head_dim` is not a power of two.
/// - `head_dim` is larger than 4096 (would exceed Metal 32 KB threadgroup limit at 2× float).
/// - `head_dim` is odd (nibble packing requires even count).
/// - `cache_capacity` is zero (the kernel indexes modulo `cache_capacity`, so
///   zero would be a division by zero on the GPU).
/// - Source buffer is smaller than `num_kv_heads * head_dim` f32 elements.
/// - `packed` buffer is smaller than `num_kv_heads * cache_capacity * head_dim/2` bytes.
/// - `norms` buffer is smaller than `num_kv_heads * cache_capacity` f32 elements.
/// - For global (non-sliding) caches: `write_pos >= cache_capacity`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_hadamard_quantize_kv(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    packed: &MlxBuffer,
    norms: &MlxBuffer,
    num_kv_heads: u32,
    head_dim: u32,
    cache_capacity: u32,
    write_pos: u32,
    is_sliding: bool,
) -> Result<()> {
    use super::encode_helpers::encode_threadgroups_with_args;

    // Degenerate shapes: nothing to dispatch.
    if num_kv_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    // head_dim must be a power of two for the butterfly pattern.
    if !head_dim.is_power_of_two() {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: head_dim must be a power of two, got {}",
            head_dim
        )));
    }

    // Shared memory: 2 * head_dim floats (data region + norm reduction scratch).
    // 2 * head_dim * 4 bytes <= 32768  =>  head_dim <= 4096.
    if head_dim > 4096 {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: head_dim {} exceeds Metal 32 KB threadgroup limit \
             (max 4096 for 2x f32 shared memory)",
            head_dim
        )));
    }

    // Nibble packing requires an even head_dim (always true for powers of two >= 2,
    // so this only rejects head_dim == 1).
    if head_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: head_dim must be even for nibble packing, got {}",
            head_dim
        )));
    }

    // The kernel indexes `write_pos % cache_capacity` for sliding caches; a zero
    // capacity would previously slip through every size validation (all required
    // byte counts compute to 0) and trigger a modulo-by-zero on the GPU.
    if cache_capacity == 0 {
        return Err(MlxError::InvalidArgument(
            "hadamard_quantize_kv: cache_capacity must be nonzero".to_string(),
        ));
    }

    // For global (non-sliding) cache, write_pos must be within bounds.
    if !is_sliding && write_pos >= cache_capacity {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: global cache write_pos({}) >= cache_capacity({})",
            write_pos, cache_capacity
        )));
    }

    // Validate source buffer size.
    let required_src = (num_kv_heads as u64) * (head_dim as u64);
    if (src.element_count() as u64) < required_src {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: src has {} elements but need {} \
             (num_kv_heads={} * head_dim={})",
            src.element_count(),
            required_src,
            num_kv_heads,
            head_dim,
        )));
    }

    // Validate packed buffer size (in bytes).
    let required_packed_bytes =
        (num_kv_heads as u64) * (cache_capacity as u64) * (head_dim as u64 / 2);
    if (packed.byte_len() as u64) < required_packed_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: packed buffer has {} bytes but need {} \
             (num_kv_heads={} * cache_capacity={} * head_dim/2={})",
            packed.byte_len(),
            required_packed_bytes,
            num_kv_heads,
            cache_capacity,
            head_dim / 2,
        )));
    }

    // Validate norms buffer size.
    let required_norms = (num_kv_heads as u64) * (cache_capacity as u64);
    if (norms.element_count() as u64) < required_norms {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: norms buffer has {} elements but need {} \
             (num_kv_heads={} * cache_capacity={})",
            norms.element_count(),
            required_norms,
            num_kv_heads,
            cache_capacity,
        )));
    }

    // Kernel selection: specialized SIMD-shuffle variants (zero threadgroup
    // barriers) for the common head dims, shared-memory fallback otherwise.
    let (kernel_name, is_fast) = match head_dim {
        256 => ("hadamard_quantize_kv_fast_d256", true),
        512 => ("hadamard_quantize_kv_fast_d512", true),
        _ => ("hadamard_quantize_kv", false),
    };

    let pipeline = registry.get_pipeline(kernel_name, device)?;

    let params = HadamardQuantizeParams {
        head_dim,
        num_kv_heads,
        write_pos,
        cache_capacity,
        is_sliding: u32::from(is_sliding),
    };
    let params_bytes = bytemuck::bytes_of(&params);

    // Both variants take the same four bindings and one threadgroup per head.
    let args = [
        (0, KernelArg::Buffer(src)),
        (1, KernelArg::Buffer(packed)),
        (2, KernelArg::Buffer(norms)),
        (3, KernelArg::Bytes(params_bytes)),
    ];
    let grid = MTLSize::new(num_kv_heads as u64, 1, 1);

    if is_fast {
        // Fast kernel: 1 simdgroup (32 threads) per head, no shared memory.
        encode_threadgroups_with_args(encoder, pipeline, &args, grid, MTLSize::new(32, 1, 1));
    } else {
        // Fallback: one thread per element with 2 * head_dim f32 of shared
        // memory (data region + norm-reduction scratch).
        // NOTE(review): this launches head_dim threads per threadgroup, so
        // head_dim of 2048/4096 exceeds the common Metal limit of 1024
        // threads per threadgroup — confirm the pipeline's
        // maxTotalThreadsPerThreadgroup before relying on those sizes.
        let shared_mem_bytes = 2u64 * (head_dim as u64) * 4;
        encode_threadgroups_with_args_and_shared(
            encoder,
            pipeline,
            &args,
            &[(0, shared_mem_bytes)],
            grid,
            MTLSize::new(head_dim as u64, 1, 1),
        );
    }

    Ok(())
}