Function quantized_matmul_simd

pub fn quantized_matmul_simd(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    params: &QuantizedMatmulParams,
) -> Result<MlxBuffer>

Encode a quantized matrix-vector multiply using the SIMD-cooperative kernel that matches MLX’s qmv_fast accumulation pattern exactly.

The kernel launches 2 simdgroups of 32 threads each; each simdgroup produces 4 output rows, reducing per-lane partial sums with simd_sum(). The accumulation order matches MLX bit-for-bit.

Falls back to the scalar quantized_matmul kernel if the dimensions don’t meet the alignment requirements.
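The cooperative accumulation pattern described above can be sketched as a plain-Rust CPU simulation. This is an illustrative assumption, not the actual Metal kernel: it models 2 simdgroups of 32 lanes, each simdgroup producing 4 output rows, with each lane accumulating a strided partial dot product that is then reduced simd_sum()-style (dequantization is omitted to keep the structure visible):

```rust
// CPU simulation of the simdgroup-cooperative matrix-vector multiply.
// Hypothetical sketch: the real kernel runs on the GPU and operates on
// quantized weights with scales/biases; here we use plain f32 weights.
const SIMD_WIDTH: usize = 32; // lanes per simdgroup
const SIMDGROUPS: usize = 2; // simdgroups per threadgroup
const ROWS_PER_GROUP: usize = 4; // output rows produced per simdgroup

fn simulated_qmv(weight: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    let rows = SIMDGROUPS * ROWS_PER_GROUP;
    let mut y = vec![0.0f32; rows];
    for g in 0..SIMDGROUPS {
        for r in 0..ROWS_PER_GROUP {
            let row = g * ROWS_PER_GROUP + r;
            // Each lane accumulates the strided elements
            // row[lane], row[lane + 32], row[lane + 64], ...
            let mut partial = [0.0f32; SIMD_WIDTH];
            for lane in 0..SIMD_WIDTH {
                let mut acc = 0.0f32;
                let mut k = lane;
                while k < x.len() {
                    acc += weight[row][k] * x[k];
                    k += SIMD_WIDTH;
                }
                partial[lane] = acc;
            }
            // simd_sum()-style reduction of the 32 per-lane partials.
            y[row] = partial.iter().sum();
        }
    }
    y
}

fn main() {
    let k = 64;
    let x: Vec<f32> = (0..k).map(|i| (i % 3) as f32).collect();
    let weight: Vec<Vec<f32>> = (0..8).map(|r| vec![(r + 1) as f32; k]).collect();
    println!("{:?}", simulated_qmv(&weight, &x)); // one value per output row
}
```

The point of the fixed lane/stride structure is reproducibility: because every lane always visits the same elements in the same order and the partials are reduced in a fixed order, the floating-point result is deterministic, which is what allows the encoder to match MLX's accumulation bit-for-bit.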

§Arguments

Same as quantized_matmul.

§Returns

A freshly allocated MlxBuffer for the output of shape [M, N] with dtype F32.