1use crate::buffer::MlxBuffer;
14use crate::device::MlxDevice;
15use crate::dtypes::DType;
16use crate::encoder::CommandEncoder;
17use crate::error::{MlxError, Result};
18use crate::kernel_registry::KernelRegistry;
19
/// Host-side parameters for an expert-routed ("matmul id") quantized matmul.
///
/// Describes an `[m, k]` f32 input multiplied against per-expert quantized
/// weight matrices (packed layout per `expert_weight_bytes`), where each input
/// row selects `n_expert_used` experts via an `[m, n_expert_used]` id buffer.
#[derive(Debug, Clone, Copy)]
pub struct QuantizedMatmulIdParams {
    /// Number of input rows.
    pub m: u32,
    /// Reduction (inner) dimension shared by input rows and weight rows.
    pub k: u32,
    /// Output columns produced per selected expert.
    pub n: u32,
    /// Quantization group size along `k`; one scale/bias element is stored per group.
    pub group_size: u32,
    /// Quantization bit width; only 4, 6, and 8 are accepted by `quantized_matmul_id`.
    pub bits: u32,
    /// Number of experts applied per input row (width of the `ids` buffer).
    pub n_expert_used: u32,
    /// Total number of experts stored in the weight/scales/biases buffers.
    pub num_experts: u32,
}
38
/// GPU-visible mirror of `QuantizedMatmulIdParams`, extended with the
/// per-expert strides the kernel needs to index the flattened
/// weight/scales/biases buffers.
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` so the struct can be byte-copied into a
/// Metal buffer with `bytemuck` (see `quantized_matmul_id`). 10 x u32 = 40 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QuantizedMatmulIdGpuParams {
    m: u32,
    k: u32,
    n: u32,
    group_size: u32,
    bits: u32,
    n_expert_used: u32,
    num_experts: u32,
    // Stride between consecutive experts' packed weights, in BYTES
    // (set from `expert_weight_bytes`).
    expert_weight_stride: u32,
    // Stride between consecutive experts' scales, in ELEMENTS
    // (set from `expert_scales_elements`).
    expert_scales_stride: u32,
    // Stride between consecutive experts' biases, in ELEMENTS; biases share
    // the scales layout, so the same value is used for both.
    expert_biases_stride: u32,
}
54
/// Number of bytes occupied by one expert's packed weight matrix of `n` rows
/// by `k` quantized values.
///
/// Packed layout per `bits`:
/// - 4-bit: 8 values per 32-bit word; each row padded to whole words.
/// - 6-bit: 4 values per 3-byte triplet; each row padded to whole triplets.
/// - 8-bit: 4 values per 32-bit word; each row padded to whole words.
///
/// Returns 0 for any other `bits` value; `quantized_matmul_id` rejects such
/// values before calling this.
///
/// The ceiling divisions are performed in `usize` so that `k` values near
/// `u32::MAX` cannot overflow the `k + divisor - 1` intermediate (which would
/// panic in debug builds if done in `u32`).
fn expert_weight_bytes(k: u32, n: u32, bits: u32) -> usize {
    let k = k as usize;
    let n = n as usize;
    match bits {
        // ceil(k / 8) words of 4 bytes per row.
        4 => n * ((k + 7) / 8) * 4,
        // ceil(k / 4) triplets of 3 bytes per row.
        6 => n * ((k + 3) / 4) * 3,
        // ceil(k / 4) words of 4 bytes per row.
        8 => n * ((k + 3) / 4) * 4,
        _ => 0,
    }
}
75
/// Number of scale (or bias) ELEMENTS for one expert: one element per
/// quantization group of `group_size` values along `k`, for each of the `n`
/// rows.
///
/// Precondition: `group_size > 0` (enforced by `quantized_matmul_id` before
/// this is called); a zero group size would divide by zero.
///
/// The ceiling division is done in `usize` so `k` near `u32::MAX` cannot
/// overflow the `k + group_size - 1` intermediate (a debug-build panic if
/// computed in `u32`).
fn expert_scales_elements(k: u32, n: u32, group_size: u32) -> usize {
    let group_size = group_size as usize;
    let num_groups = ((k as usize) + group_size - 1) / group_size;
    (n as usize) * num_groups
}
82
83#[allow(clippy::too_many_arguments)]
110pub fn quantized_matmul_id(
111 encoder: &mut CommandEncoder,
112 registry: &mut KernelRegistry,
113 device: &MlxDevice,
114 input: &MlxBuffer,
115 weight: &MlxBuffer,
116 scales: &MlxBuffer,
117 biases: &MlxBuffer,
118 ids: &MlxBuffer,
119 params: &QuantizedMatmulIdParams,
120) -> Result<MlxBuffer> {
121 if params.bits != 4 && params.bits != 6 && params.bits != 8 {
123 return Err(MlxError::InvalidArgument(format!(
124 "quantized_matmul_id: unsupported bits value {}; only 4, 6, and 8 are supported",
125 params.bits
126 )));
127 }
128
129 if params.m == 0 || params.k == 0 || params.n == 0 {
131 return Err(MlxError::InvalidArgument(
132 "quantized_matmul_id: M, K, and N must all be > 0".into(),
133 ));
134 }
135 if params.group_size == 0 {
136 return Err(MlxError::InvalidArgument(
137 "quantized_matmul_id: group_size must be > 0".into(),
138 ));
139 }
140 if params.n_expert_used == 0 {
141 return Err(MlxError::InvalidArgument(
142 "quantized_matmul_id: n_expert_used must be > 0".into(),
143 ));
144 }
145 if params.num_experts == 0 {
146 return Err(MlxError::InvalidArgument(
147 "quantized_matmul_id: num_experts must be > 0".into(),
148 ));
149 }
150
151 let expected_input = (params.m as usize) * (params.k as usize) * DType::F32.size_of();
153 if input.byte_len() < expected_input {
154 return Err(MlxError::InvalidArgument(format!(
155 "quantized_matmul_id: input buffer too small: expected at least {} bytes for [{}x{}] f32, got {}",
156 expected_input, params.m, params.k, input.byte_len()
157 )));
158 }
159
160 let per_expert_w = expert_weight_bytes(params.k, params.n, params.bits);
161 let total_w = per_expert_w * (params.num_experts as usize);
162 if weight.byte_len() < total_w {
163 return Err(MlxError::InvalidArgument(format!(
164 "quantized_matmul_id: weight buffer too small: expected at least {} bytes for {} experts, got {}",
165 total_w, params.num_experts, weight.byte_len()
166 )));
167 }
168
169 let per_expert_s = expert_scales_elements(params.k, params.n, params.group_size);
170 let total_s_bytes = per_expert_s * (params.num_experts as usize) * 2; if scales.byte_len() < total_s_bytes {
172 return Err(MlxError::InvalidArgument(format!(
173 "quantized_matmul_id: scales buffer too small: expected at least {} bytes, got {}",
174 total_s_bytes, scales.byte_len()
175 )));
176 }
177 if biases.byte_len() < total_s_bytes {
178 return Err(MlxError::InvalidArgument(format!(
179 "quantized_matmul_id: biases buffer too small: expected at least {} bytes, got {}",
180 total_s_bytes, biases.byte_len()
181 )));
182 }
183
184 let expected_ids = (params.m as usize) * (params.n_expert_used as usize) * DType::U32.size_of();
185 if ids.byte_len() < expected_ids {
186 return Err(MlxError::InvalidArgument(format!(
187 "quantized_matmul_id: ids buffer too small: expected at least {} bytes for [{}x{}] u32, got {}",
188 expected_ids, params.m, params.n_expert_used, ids.byte_len()
189 )));
190 }
191
192 let pipeline = registry.get_pipeline("quantized_matmul_id", device.metal_device())?;
194
195 let output_elems = (params.m as usize) * (params.n_expert_used as usize) * (params.n as usize);
197 let output_bytes = output_elems * DType::F32.size_of();
198 let output = device.alloc_buffer(
199 output_bytes,
200 DType::F32,
201 vec![
202 params.m as usize,
203 params.n_expert_used as usize,
204 params.n as usize,
205 ],
206 )?;
207
208 let gpu_params = QuantizedMatmulIdGpuParams {
210 m: params.m,
211 k: params.k,
212 n: params.n,
213 group_size: params.group_size,
214 bits: params.bits,
215 n_expert_used: params.n_expert_used,
216 num_experts: params.num_experts,
217 expert_weight_stride: per_expert_w as u32,
218 expert_scales_stride: per_expert_s as u32,
219 expert_biases_stride: per_expert_s as u32,
220 };
221 let params_bytes = std::mem::size_of::<QuantizedMatmulIdGpuParams>();
222 let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![10])?;
223 {
224 let slice: &mut [QuantizedMatmulIdGpuParams] = bytemuck::cast_slice_mut(
225 params_buf
226 .as_mut_slice::<u8>()
227 .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
228 );
229 slice[0] = gpu_params;
230 }
231
232 let total_rows = (params.m as u64) * (params.n_expert_used as u64);
235 let tg_x = 16u64.min(params.n as u64);
236 let tg_y = 16u64.min(total_rows);
237 let threadgroup_size = metal::MTLSize::new(tg_x, tg_y, 1);
238
239 let grid_groups = metal::MTLSize::new(
240 (params.n as u64 + tg_x - 1) / tg_x,
241 (total_rows + tg_y - 1) / tg_y,
242 1,
243 );
244
245 encoder.encode_threadgroups(
246 pipeline,
247 &[
248 (0, input),
249 (1, weight),
250 (2, scales),
251 (3, biases),
252 (4, ids),
253 (5, &output),
254 (6, ¶ms_buf),
255 ],
256 grid_groups,
257 threadgroup_size,
258 );
259
260 Ok(output)
261}