
Function dispatch_quantized_matmul_simd_bf16 

pub fn dispatch_quantized_matmul_simd_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    packed_weights: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    params: &QuantizedMatmulParams,
) -> Result<MlxBuffer>

Dispatch the bf16 I/O variant of the SIMD quantized matmul kernel.

Input and output are both bf16. Accumulation happens in f32 inside the shader for numerical stability, matching the precision of the f32 variant.

Falls back to the scalar quantized_matmul kernel (which produces f32 output) if the dimensions don’t satisfy the SIMD path’s alignment requirements.
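The exact alignment condition isn’t spelled out here, but dispatch wrappers of this shape typically gate the SIMD path on divisibility of the dimensions. A minimal sketch of such a check, using hypothetical tile sizes (the real kernel’s constants may differ):

```rust
/// Hypothetical SIMD eligibility check; the actual kernel's tile sizes
/// and constraints may differ from these assumed constants.
fn simd_dims_ok(k: usize, n: usize) -> bool {
    const K_CHUNK: usize = 32; // assumed: K elements consumed per SIMD step
    const N_TILE: usize = 8;   // assumed: output columns per simdgroup tile
    k % K_CHUNK == 0 && n % N_TILE == 0
}

fn main() {
    // Aligned dims would take the SIMD path; misaligned ones fall back
    // to the scalar quantized_matmul kernel.
    assert!(simd_dims_ok(4096, 4096));
    assert!(!simd_dims_ok(4100, 4096));
    println!("ok");
}
```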

§Arguments

  • encoder — The command encoder to record the dispatch into.
  • registry — Kernel registry (compiles the shader on first call).
  • device — The Metal device for buffer allocation.
  • input — bf16 input matrix buffer, shape [M, K].
  • packed_weights — Packed quantized weight buffer, shape [N, packed_k].
  • scales — bf16 scale buffer, shape [N, num_groups].
  • biases — bf16 bias buffer, shape [N, num_groups].
  • params — Dimensions and quantization parameters.
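The buffer shapes above are tied together by the quantization parameters. As an illustration, assuming group-wise quantization with values packed into 32-bit words (the field names `bits` and `group_size` here are assumptions, not necessarily members of QuantizedMatmulParams):

```rust
/// Hypothetical derivation of the packed/group dimensions from K.
/// Assumes weights are packed 32 / bits values per u32 word and that
/// one (scale, bias) pair is shared per group_size weights along K.
fn packed_k(k: usize, bits: usize) -> usize {
    let per_word = 32 / bits;     // e.g. 8 values per u32 word at 4 bits
    (k + per_word - 1) / per_word // ceil-divide K into packed words
}

fn num_groups(k: usize, group_size: usize) -> usize {
    (k + group_size - 1) / group_size
}

fn main() {
    // 4-bit weights, K = 4096, groups of 64:
    assert_eq!(packed_k(4096, 4), 512);   // 4096 / 8
    assert_eq!(num_groups(4096, 64), 64); // 4096 / 64
    println!("ok");
}
```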

§Returns

A freshly allocated MlxBuffer for the output of shape [M, N] with dtype BF16.