1use crate::buffer::MlxBuffer;
23use crate::device::MlxDevice;
24use crate::dtypes::DType;
25use crate::encoder::CommandEncoder;
26use crate::error::{MlxError, Result};
27use crate::kernel_registry::KernelRegistry;
28use crate::ops::quantized_matmul_ggml::GgmlType;
29
30#[derive(Debug, Clone, Copy)]
39pub struct MulMvExtParams {
40 pub m: u32,
42 pub n: u32,
44 pub k: u32,
46 pub batch: u32,
48 pub ggml_type: GgmlType,
51}
52
53#[repr(C)]
62#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
63struct MulMvExtGpuArgs {
64 ne00: i32,
65 ne01: i32,
66 ne02: i32,
67 _pad0: u32,
68 nb00: u64,
69 nb01: u64,
70 nb02: u64,
71 nb03: u64,
72 ne10: i32,
73 ne11: i32,
74 ne12: i32,
75 _pad1: u32,
76 nb10: u64,
77 nb11: u64,
78 nb12: u64,
79 nb13: u64,
80 ne0: i32,
81 ne1: i32,
82 r2: i16,
83 r3: i16,
84 _pad2: u32,
89}
90
/// Selects the `nxpsg` value for a given reduction length `k` and row count `m`.
///
/// * `16` when `k` is a multiple of 256 and `m < 3`,
/// * `8` when `k` is a multiple of 128 (and the first rule did not apply),
/// * `4` otherwise.
///
/// The dispatcher passes the result to the pipeline as function constant 601
/// and derives `nypsg = 32 / nxpsg` from it.
fn pick_nxpsg(k: u32, m: u32) -> i32 {
    let fits_widest = k % 256 == 0 && m < 3;
    let fits_medium = k % 128 == 0;
    match (fits_widest, fits_medium) {
        (true, _) => 16,
        (false, true) => 8,
        (false, false) => 4,
    }
}
101
102fn pick_r1ptg(m: u32) -> Result<i32> {
105 match m {
106 2 => Ok(2),
107 3 | 6 => Ok(3),
108 4 | 7 | 8 => Ok(4),
109 5 => Ok(5),
110 other => Err(MlxError::InvalidArgument(format!(
111 "mul_mv_ext: unsupported m {} (peer mapping covers 2..=8 only)",
112 other
113 ))),
114 }
115}
116
117fn kernel_name(ggml_type: GgmlType, r1ptg: i32) -> Result<&'static str> {
123 Ok(match (ggml_type, r1ptg) {
124 (GgmlType::Q5_1, 2) => "kernel_mul_mv_ext_q5_1_f32_r1_2",
125 (GgmlType::Q5_1, 3) => "kernel_mul_mv_ext_q5_1_f32_r1_3",
126 (GgmlType::Q5_1, 4) => "kernel_mul_mv_ext_q5_1_f32_r1_4",
127 (GgmlType::Q5_1, 5) => "kernel_mul_mv_ext_q5_1_f32_r1_5",
128 (GgmlType::IQ4_NL, 2) => "kernel_mul_mv_ext_iq4_nl_f32_r1_2",
129 (GgmlType::IQ4_NL, 3) => "kernel_mul_mv_ext_iq4_nl_f32_r1_3",
130 (GgmlType::IQ4_NL, 4) => "kernel_mul_mv_ext_iq4_nl_f32_r1_4",
131 (GgmlType::IQ4_NL, 5) => "kernel_mul_mv_ext_iq4_nl_f32_r1_5",
132 (GgmlType::Q4_0, 2) => "kernel_mul_mv_ext_q4_0_f32_r1_2",
133 (GgmlType::Q4_0, 3) => "kernel_mul_mv_ext_q4_0_f32_r1_3",
134 (GgmlType::Q4_0, 4) => "kernel_mul_mv_ext_q4_0_f32_r1_4",
135 (GgmlType::Q4_0, 5) => "kernel_mul_mv_ext_q4_0_f32_r1_5",
136 (GgmlType::Q8_0, 2) => "kernel_mul_mv_ext_q8_0_f32_r1_2",
137 (GgmlType::Q8_0, 3) => "kernel_mul_mv_ext_q8_0_f32_r1_3",
138 (GgmlType::Q8_0, 4) => "kernel_mul_mv_ext_q8_0_f32_r1_4",
139 (GgmlType::Q8_0, 5) => "kernel_mul_mv_ext_q8_0_f32_r1_5",
140 (GgmlType::Q4_K, 2) => "kernel_mul_mv_ext_q4_K_f32_r1_2",
141 (GgmlType::Q4_K, 3) => "kernel_mul_mv_ext_q4_K_f32_r1_3",
142 (GgmlType::Q4_K, 4) => "kernel_mul_mv_ext_q4_K_f32_r1_4",
143 (GgmlType::Q4_K, 5) => "kernel_mul_mv_ext_q4_K_f32_r1_5",
144 (GgmlType::Q5_K, 2) => "kernel_mul_mv_ext_q5_K_f32_r1_2",
145 (GgmlType::Q5_K, 3) => "kernel_mul_mv_ext_q5_K_f32_r1_3",
146 (GgmlType::Q5_K, 4) => "kernel_mul_mv_ext_q5_K_f32_r1_4",
147 (GgmlType::Q5_K, 5) => "kernel_mul_mv_ext_q5_K_f32_r1_5",
148 (GgmlType::Q6_K, 2) => "kernel_mul_mv_ext_q6_K_f32_r1_2",
149 (GgmlType::Q6_K, 3) => "kernel_mul_mv_ext_q6_K_f32_r1_3",
150 (GgmlType::Q6_K, 4) => "kernel_mul_mv_ext_q6_K_f32_r1_4",
151 (GgmlType::Q6_K, 5) => "kernel_mul_mv_ext_q6_K_f32_r1_5",
152 (other_type, other_r1) => {
153 return Err(MlxError::InvalidArgument(format!(
154 "mul_mv_ext: no kernel for type {:?} × r1ptg {} (Phase 1+4 ports Q4_0/Q8_0/Q4_K/Q5_K/Q6_K/Q5_1/IQ4_NL × r1∈{{2,3,4,5}})",
155 other_type, other_r1
156 )));
157 }
158 })
159}
160
161pub fn mul_mv_ext_dispatch(
172 encoder: &mut CommandEncoder,
173 registry: &mut KernelRegistry,
174 device: &MlxDevice,
175 weight: &MlxBuffer,
176 input: &MlxBuffer,
177 output: &MlxBuffer,
178 params: &MulMvExtParams,
179) -> Result<()> {
180 if params.m == 0 || params.n == 0 || params.k == 0 || params.batch == 0 {
181 return Err(MlxError::InvalidArgument(
182 "mul_mv_ext: m, n, k, batch must all be > 0".into(),
183 ));
184 }
185 let block_qk = params.ggml_type.block_values();
189 if params.k % block_qk != 0 {
190 return Err(MlxError::InvalidArgument(format!(
191 "mul_mv_ext: k ({}) must be divisible by block QK ({}) for {:?}",
192 params.k, block_qk, params.ggml_type
193 )));
194 }
195
196 let r1ptg = pick_r1ptg(params.m)?;
197 let nxpsg = pick_nxpsg(params.k, params.m);
198 let nsg: i32 = 2;
199 let nypsg = 32 / nxpsg;
200 let r0ptg = nypsg * nsg;
201
202 let kname = kernel_name(params.ggml_type, r1ptg)?;
203
204 let pipeline = registry
206 .get_pipeline_with_constants(
207 kname,
208 device.metal_device(),
209 &[],
210 &[(600, nsg), (601, nxpsg)],
211 )?
212 .clone();
213
214 let block_bytes_per_row =
217 (params.k as usize / block_qk as usize) * (params.ggml_type.block_bytes() as usize);
218 let weight_required = (params.n as usize) * block_bytes_per_row;
219 if weight.byte_len() < weight_required {
220 return Err(MlxError::InvalidArgument(format!(
221 "mul_mv_ext: weight buffer too small: {} < {} bytes",
222 weight.byte_len(),
223 weight_required
224 )));
225 }
226 let input_required = (params.batch as usize)
227 * (params.m as usize)
228 * (params.k as usize)
229 * DType::F32.size_of();
230 if input.byte_len() < input_required {
231 return Err(MlxError::InvalidArgument(format!(
232 "mul_mv_ext: input buffer too small: {} < {} bytes",
233 input.byte_len(),
234 input_required
235 )));
236 }
237 let output_required = (params.batch as usize)
238 * (params.m as usize)
239 * (params.n as usize)
240 * DType::F32.size_of();
241 if output.byte_len() < output_required {
242 return Err(MlxError::InvalidArgument(format!(
243 "mul_mv_ext: output buffer too small: {} < {} bytes",
244 output.byte_len(),
245 output_required
246 )));
247 }
248
249 let nb00 = params.ggml_type.block_bytes() as u64;
258 let nb01 = block_bytes_per_row as u64;
259 let nb02 = nb01 * params.n as u64;
260 let nb10: u64 = 4;
261 let nb11 = (params.k as u64) * 4;
262 let nb12 = nb11 * params.m as u64;
263 let args = MulMvExtGpuArgs {
264 ne00: params.k as i32,
265 ne01: params.n as i32,
266 ne02: 1,
267 _pad0: 0,
268 nb00,
269 nb01,
270 nb02,
271 nb03: nb02, ne10: params.k as i32,
273 ne11: params.m as i32,
274 ne12: params.batch as i32,
275 _pad1: 0,
276 nb10,
277 nb11,
278 nb12,
279 nb13: nb12, ne0: params.n as i32,
281 ne1: params.m as i32,
282 r2: 1,
283 r3: 1,
284 _pad2: 0,
285 };
286
287 use crate::encoder::{as_bytes, KernelArg};
288
289 let args_bytes = as_bytes(&args);
290 let r0_groups = ((params.n as i32) + r0ptg - 1) / r0ptg;
291 let r1_groups = ((params.m as i32) + r1ptg - 1) / r1ptg;
292
293 encoder.encode_threadgroups_with_args(
294 &pipeline,
295 &[
296 (0, KernelArg::Bytes(args_bytes)),
297 (1, KernelArg::Buffer(weight)),
298 (2, KernelArg::Buffer(input)),
299 (3, KernelArg::Buffer(output)),
300 ],
301 crate::MTLSize::new(r0_groups as u64, r1_groups as u64, params.batch as u64),
302 crate::MTLSize::new(32, nsg as u64, 1),
303 );
304
305 Ok(())
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// `pick_nxpsg` must reproduce the peer heuristic for representative shapes.
    #[test]
    fn pick_nxpsg_matches_peer_logic() {
        // (k, m, expected nxpsg)
        let cases: [(u32, u32, i32); 7] = [
            (128, 1, 8),
            (256, 2, 16),
            (256, 3, 8),
            (64, 4, 4),
            (2816, 2, 16),
            (2816, 3, 8),
            (512, 4, 8),
        ];
        for &(k, m, expected) in cases.iter() {
            assert_eq!(pick_nxpsg(k, m), expected, "k={}, m={}", k, m);
        }
    }

    /// `pick_r1ptg` maps m in 2..=8 and rejects everything else.
    #[test]
    fn pick_r1ptg_matches_peer_switch() {
        // (m, expected r1ptg)
        let accepted: [(u32, i32); 7] =
            [(2, 2), (3, 3), (4, 4), (5, 5), (6, 3), (7, 4), (8, 4)];
        for &(m, expected) in accepted.iter() {
            assert_eq!(pick_r1ptg(m).unwrap(), expected, "m={}", m);
        }
        for &m in [1u32, 9].iter() {
            assert!(pick_r1ptg(m).is_err(), "m={}", m);
        }
    }

    /// Every Phase 1 type has a kernel for each supported r1ptg.
    #[test]
    fn kernel_name_covers_all_phase1_combinations() {
        for r1 in 2..=5 {
            for &ty in [GgmlType::Q5_1, GgmlType::IQ4_NL].iter() {
                assert!(kernel_name(ty, r1).is_ok());
            }
        }
    }

    /// Every Phase 4 type has a kernel for each supported r1ptg.
    #[test]
    fn kernel_name_covers_all_phase4_combinations() {
        for r1 in 2..=5 {
            let has_kernel = |ty: GgmlType| kernel_name(ty, r1).is_ok();
            assert!(has_kernel(GgmlType::Q4_0),
                "Phase 4 Q4_0 r1={r1} must have a kernel");
            assert!(has_kernel(GgmlType::Q8_0),
                "Phase 4 Q8_0 r1={r1} must have a kernel");
            assert!(has_kernel(GgmlType::Q4_K),
                "Phase 4 Q4_K r1={r1} must have a kernel");
            assert!(has_kernel(GgmlType::Q5_K),
                "Phase 4 Q5_K r1={r1} must have a kernel");
            assert!(has_kernel(GgmlType::Q6_K),
                "Phase 4 Q6_K r1={r1} must have a kernel");
        }
    }

    /// r1ptg values outside 2..=5 have no kernel even for supported types.
    #[test]
    fn kernel_name_rejects_unsupported_combinations() {
        // (type, r1ptg, failure message)
        let rejected: [(GgmlType, i32, &str); 4] = [
            (GgmlType::Q5_1, 1, "r1=1 not supported by any phase"),
            (GgmlType::Q5_1, 6, "r1=6 not supported by any phase"),
            (GgmlType::Q4_0, 0, "r1=0 not supported"),
            (GgmlType::Q4_0, -1, "r1=-1 not supported"),
        ];
        for &(ty, r1, why) in rejected.iter() {
            assert!(kernel_name(ty, r1).is_err(), "{}", why);
        }
    }
}