Function quantized_matmul 

pub fn quantized_matmul(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    params: &QuantizedMatmulParams,
) -> Result<MlxBuffer>

Encode a quantized matrix multiplication onto the given command encoder.

This does not commit the command buffer — the caller is responsible for calling encoder.commit_and_wait() after encoding all desired operations, so multiple dispatches can be batched into a single submission.

§Arguments

  • encoder — The command encoder to record the dispatch into.
  • registry — Kernel registry (compiles the shader on first call).
  • device — The Metal device (needed for pipeline compilation and output allocation).
  • input — f32 input matrix buffer, shape [M, K].
  • weight — Packed quantized weight buffer, shape [N, packed_k].
  • scales — bf16 scale buffer, shape [N, num_groups].
  • biases — bf16 bias buffer, shape [N, num_groups].
  • params — Dimensions and quantization parameters.
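
The relationship between the logical dimensions and the buffer sizes above can be sketched with a small helper. Note the helper itself, and the formulas packed_k = K * bits / 32 and num_groups = K / group_size, are assumptions based on the common MLX convention of packing quantized values into u32 words and grouping scales along K — verify against the crate's size checks:

```rust
// Hypothetical helper (not part of this API): computes the expected
// element counts for each buffer, assuming u32-word packing along K
// and one scale/bias per quantization group along K.
fn expected_sizes(
    m: usize,
    k: usize,
    n: usize,
    bits: usize,
    group_size: usize,
) -> (usize, usize, usize) {
    let packed_k = k * bits / 32; // u32 words per weight row
    let num_groups = k / group_size; // quantization groups along K
    (
        m * k,          // input: f32 elements, shape [M, K]
        n * packed_k,   // weight: u32 words, shape [N, packed_k]
        n * num_groups, // scales (and biases): bf16 elements, shape [N, num_groups]
    )
}
```

For example, with K = 64, bits = 4, and group_size = 32, each weight row packs into 8 u32 words and carries 2 scale/bias pairs.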

§Returns

A freshly allocated MlxBuffer for the output of shape [M, N] with dtype F32.

§Errors

  • MlxError::InvalidArgument — unsupported bits value, or buffer sizes do not match the expected dimensions.
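
§Example

A hypothetical call sequence. The constructors shown (MlxDevice::new, KernelRegistry::new, new_command_encoder) and the QuantizedMatmulParams fields are assumptions for illustration, not part of the documented signature — consult the crate's own examples for the exact API:

```rust
// Sketch only: setup names and params fields are assumed, and the
// input/weight/scales/biases buffers must already be populated.
let device = MlxDevice::new()?;
let mut registry = KernelRegistry::new();
let mut encoder = device.new_command_encoder()?;

let params = QuantizedMatmulParams {
    m: 1,
    k: 4096,
    n: 4096,
    bits: 4,
    group_size: 64,
};

// Encode the dispatch; nothing executes on the GPU yet.
let out = quantized_matmul(
    &mut encoder,
    &mut registry,
    &device,
    &input,
    &weight,
    &scales,
    &biases,
    &params,
)?;

// Commit once all desired operations are encoded (see note above).
encoder.commit_and_wait();
```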