// File: mlx_native/ops/quantized_matmul_id.rs
1//! Expert-routed (MoE) quantized matrix-vector multiply dispatch.
2//!
3//! Encodes a GPU compute command that performs, for each (token, expert-slot):
4//!   expert_id = ids[token * n_expert_used + slot]
5//!   output[token][slot][col] = sum_k(dequant(expert_weight[expert_id][col][k]) * input[token][k])
6//!
7//! This is the _id variant of quantized_matmul: same dequantization logic but
8//! with per-token expert selection via an ids buffer, enabling fused MoE dispatch.
9//!
10//! Portions derived from candle-metal-kernels v0.10.2 (Apache-2.0).
11//! See src/shaders/quantized_matmul_id.metal for full attribution.
12
13use crate::buffer::MlxBuffer;
14use crate::device::MlxDevice;
15use crate::dtypes::DType;
16use crate::encoder::CommandEncoder;
17use crate::error::{MlxError, Result};
18use crate::kernel_registry::KernelRegistry;
19
/// Parameters describing the expert-routed quantized matmul dimensions.
///
/// Naming follows the usual GEMM convention: the input is `[m, k]`, each
/// expert's weight is logically `[n, k]` (stored packed; see the module
/// docs), and the output is `[m, n_expert_used, n]`.
#[derive(Debug, Clone, Copy)]
pub struct QuantizedMatmulIdParams {
    /// Number of input rows (tokens).
    pub m: u32,
    /// Inner dimension (shared between input and weight).
    pub k: u32,
    /// Number of output columns per expert.
    pub n: u32,
    /// Number of consecutive values along K sharing one scale/bias pair.
    pub group_size: u32,
    /// Quantization bit width; only 4, 6, and 8 are accepted.
    pub bits: u32,
    /// Number of experts each token is routed to (top-k).
    pub n_expert_used: u32,
    /// Total number of experts in the weight tensor.
    pub num_experts: u32,
}
38
/// GPU-side params struct -- must match the Metal shader's QuantizedMatmulIdParams.
///
/// The `#[repr(C)]` field order is part of the shader ABI: do not reorder or
/// insert fields without updating the struct in
/// src/shaders/quantized_matmul_id.metal in lockstep.
///
/// Stride units are asymmetric (as populated in `quantized_matmul_id`):
/// `expert_weight_stride` is in **bytes** (from `expert_weight_bytes`),
/// while `expert_scales_stride` / `expert_biases_stride` are in bf16
/// **elements** (from `expert_scales_elements`) -- presumably the shader
/// indexes scales/biases as a typed array; confirm against the .metal source.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QuantizedMatmulIdGpuParams {
    m: u32,
    k: u32,
    n: u32,
    group_size: u32,
    bits: u32,
    n_expert_used: u32,
    num_experts: u32,
    // Byte offset between consecutive experts' packed weights.
    expert_weight_stride: u32,
    // bf16-element offset between consecutive experts' scales.
    expert_scales_stride: u32,
    // bf16-element offset between consecutive experts' biases.
    expert_biases_stride: u32,
}
54
/// Bytes occupied by one expert's packed quantized weight matrix.
///
/// Packing layout per row of `k` values, by bit width:
/// * 4-bit: 8 values per 4-byte word,
/// * 6-bit: 4 values per 3-byte triplet,
/// * 8-bit: 4 values per 4-byte word.
///
/// A partially-filled final unit in a row is rounded up to a whole unit.
/// Any other `bits` value yields 0 (the caller validates `bits` first).
fn expert_weight_bytes(k: u32, n: u32, bits: u32) -> usize {
    // (values packed per storage unit, bytes per storage unit)
    let (values_per_unit, bytes_per_unit) = match bits {
        4 => (8u32, 4usize),
        6 => (4u32, 3usize),
        8 => (4u32, 4usize),
        _ => return 0,
    };
    // Round up so a partial trailing unit is still counted.
    let units_per_row = (k + values_per_unit - 1) / values_per_unit;
    (n as usize) * (units_per_row as usize) * bytes_per_unit
}
75
/// Number of scale (or bias) elements stored for one expert.
///
/// Every output column carries `ceil(k / group_size)` quantization groups,
/// and each group contributes exactly one bf16 scale/bias value.
fn expert_scales_elements(k: u32, n: u32, group_size: u32) -> usize {
    let groups_per_column = ((k + group_size - 1) / group_size) as usize;
    (n as usize) * groups_per_column
}
82
83/// Encode an expert-routed quantized matrix multiplication onto the command encoder.
84///
85/// This does **not** commit the command buffer -- the caller is responsible for
86/// calling `encoder.commit_and_wait()` after encoding all desired operations.
87///
88/// # Arguments
89///
90/// * `encoder`  -- The command encoder to record the dispatch into.
91/// * `registry` -- Kernel registry (compiles the shader on first call).
92/// * `device`   -- The Metal device (needed for pipeline compilation and output allocation).
93/// * `input`    -- f32 input matrix buffer, shape `[M, K]`.
94/// * `weight`   -- Packed quantized weight buffer, shape `[num_experts, N, packed_k]` contiguous.
95/// * `scales`   -- bf16 scale buffer, shape `[num_experts, N, num_groups]` contiguous.
96/// * `biases`   -- bf16 bias buffer, shape `[num_experts, N, num_groups]` contiguous.
97/// * `ids`      -- u32 expert index buffer, shape `[M, n_expert_used]`.
98/// * `params`   -- Dimensions and quantization parameters.
99///
100/// # Returns
101///
102/// A freshly allocated `MlxBuffer` for the output of shape `[M, n_expert_used, N]`
103/// with dtype `F32`.
104///
105/// # Errors
106///
107/// * `MlxError::InvalidArgument` -- unsupported `bits` value, or buffer sizes
108///   do not match the expected dimensions.
109#[allow(clippy::too_many_arguments)]
110pub fn quantized_matmul_id(
111    encoder: &mut CommandEncoder,
112    registry: &mut KernelRegistry,
113    device: &MlxDevice,
114    input: &MlxBuffer,
115    weight: &MlxBuffer,
116    scales: &MlxBuffer,
117    biases: &MlxBuffer,
118    ids: &MlxBuffer,
119    params: &QuantizedMatmulIdParams,
120) -> Result<MlxBuffer> {
121    // --- Validate bits ---
122    if params.bits != 4 && params.bits != 6 && params.bits != 8 {
123        return Err(MlxError::InvalidArgument(format!(
124            "quantized_matmul_id: unsupported bits value {}; only 4, 6, and 8 are supported",
125            params.bits
126        )));
127    }
128
129    // --- Validate dimensions are non-zero ---
130    if params.m == 0 || params.k == 0 || params.n == 0 {
131        return Err(MlxError::InvalidArgument(
132            "quantized_matmul_id: M, K, and N must all be > 0".into(),
133        ));
134    }
135    if params.group_size == 0 {
136        return Err(MlxError::InvalidArgument(
137            "quantized_matmul_id: group_size must be > 0".into(),
138        ));
139    }
140    if params.n_expert_used == 0 {
141        return Err(MlxError::InvalidArgument(
142            "quantized_matmul_id: n_expert_used must be > 0".into(),
143        ));
144    }
145    if params.num_experts == 0 {
146        return Err(MlxError::InvalidArgument(
147            "quantized_matmul_id: num_experts must be > 0".into(),
148        ));
149    }
150
151    // --- Validate buffer sizes ---
152    let expected_input = (params.m as usize) * (params.k as usize) * DType::F32.size_of();
153    if input.byte_len() < expected_input {
154        return Err(MlxError::InvalidArgument(format!(
155            "quantized_matmul_id: input buffer too small: expected at least {} bytes for [{}x{}] f32, got {}",
156            expected_input, params.m, params.k, input.byte_len()
157        )));
158    }
159
160    let per_expert_w = expert_weight_bytes(params.k, params.n, params.bits);
161    let total_w = per_expert_w * (params.num_experts as usize);
162    if weight.byte_len() < total_w {
163        return Err(MlxError::InvalidArgument(format!(
164            "quantized_matmul_id: weight buffer too small: expected at least {} bytes for {} experts, got {}",
165            total_w, params.num_experts, weight.byte_len()
166        )));
167    }
168
169    let per_expert_s = expert_scales_elements(params.k, params.n, params.group_size);
170    let total_s_bytes = per_expert_s * (params.num_experts as usize) * 2; // 2 bytes per bf16
171    if scales.byte_len() < total_s_bytes {
172        return Err(MlxError::InvalidArgument(format!(
173            "quantized_matmul_id: scales buffer too small: expected at least {} bytes, got {}",
174            total_s_bytes, scales.byte_len()
175        )));
176    }
177    if biases.byte_len() < total_s_bytes {
178        return Err(MlxError::InvalidArgument(format!(
179            "quantized_matmul_id: biases buffer too small: expected at least {} bytes, got {}",
180            total_s_bytes, biases.byte_len()
181        )));
182    }
183
184    let expected_ids = (params.m as usize) * (params.n_expert_used as usize) * DType::U32.size_of();
185    if ids.byte_len() < expected_ids {
186        return Err(MlxError::InvalidArgument(format!(
187            "quantized_matmul_id: ids buffer too small: expected at least {} bytes for [{}x{}] u32, got {}",
188            expected_ids, params.m, params.n_expert_used, ids.byte_len()
189        )));
190    }
191
192    // --- Get (or compile) the pipeline ---
193    let pipeline = registry.get_pipeline("quantized_matmul_id", device.metal_device())?;
194
195    // --- Allocate output buffer ---
196    let output_elems = (params.m as usize) * (params.n_expert_used as usize) * (params.n as usize);
197    let output_bytes = output_elems * DType::F32.size_of();
198    let output = device.alloc_buffer(
199        output_bytes,
200        DType::F32,
201        vec![
202            params.m as usize,
203            params.n_expert_used as usize,
204            params.n as usize,
205        ],
206    )?;
207
208    // --- Create GPU params ---
209    let gpu_params = QuantizedMatmulIdGpuParams {
210        m: params.m,
211        k: params.k,
212        n: params.n,
213        group_size: params.group_size,
214        bits: params.bits,
215        n_expert_used: params.n_expert_used,
216        num_experts: params.num_experts,
217        expert_weight_stride: per_expert_w as u32,
218        expert_scales_stride: per_expert_s as u32,
219        expert_biases_stride: per_expert_s as u32,
220    };
221    let params_bytes = std::mem::size_of::<QuantizedMatmulIdGpuParams>();
222    let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![10])?;
223    {
224        let slice: &mut [QuantizedMatmulIdGpuParams] = bytemuck::cast_slice_mut(
225            params_buf
226                .as_mut_slice::<u8>()
227                .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
228        );
229        slice[0] = gpu_params;
230    }
231
232    // --- Dispatch ---
233    // Grid: (N, M * n_expert_used, 1)
234    let total_rows = (params.m as u64) * (params.n_expert_used as u64);
235    let tg_x = 16u64.min(params.n as u64);
236    let tg_y = 16u64.min(total_rows);
237    let threadgroup_size = metal::MTLSize::new(tg_x, tg_y, 1);
238
239    let grid_groups = metal::MTLSize::new(
240        (params.n as u64 + tg_x - 1) / tg_x,
241        (total_rows + tg_y - 1) / tg_y,
242        1,
243    );
244
245    encoder.encode_threadgroups(
246        pipeline,
247        &[
248            (0, input),
249            (1, weight),
250            (2, scales),
251            (3, biases),
252            (4, ids),
253            (5, &output),
254            (6, &params_buf),
255        ],
256        grid_groups,
257        threadgroup_size,
258    );
259
260    Ok(output)
261}