Skip to main content

mlx_native/ops/
mul_mv_ext.rs

//! ADR-022 `mul_mv_ext` r1 family (Phase 1 P1.7 + Phase 4).
//!
//! Wraps the Metal kernels in `shaders/mul_mv_ext.metal`:
//!
//!   `kernel_mul_mv_ext_<q>_f32_r1_<r1>` for
//!   q ∈ {q5_1, iq4_nl} (Phase 1) ∪ {q4_0, q8_0, q4_K, q5_K, q6_K} (Phase 4),
//!   r1 ∈ {2, 3, 4, 5}.
//!
//! The host dispatcher mirrors llama.cpp's
//! `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp:2080-2152`:
//!   - `nsg = 2` (constant)
//!   - `nxpsg = 16` if `K % 256 == 0 && M < 3`,
//!     else `8` if `K % 128 == 0`,
//!     else `4`
//!   - `r0ptg = (32 / nxpsg) * nsg`
//!   - `r1ptg` selected by m: m=2→2, m∈{3,6}→3, m∈{4,7,8}→4, m=5→5
//!   - threadgroups = (ceil(N/r0ptg), ceil(M/r1ptg), batch)
//!   - threads-per-tg = (32, nsg, 1)
//!
//! Covers the public dispatcher's m = 2..=8 small-batch range. ADR-022 P1.7
//! ported the family for Q5_1 + IQ4_NL; Phase 4 extended it to
//! Q4_0 / Q8_0 / Q4_K / Q5_K / Q6_K.
21
22use crate::buffer::MlxBuffer;
23use crate::device::MlxDevice;
24use crate::dtypes::DType;
25use crate::encoder::CommandEncoder;
26use crate::error::{MlxError, Result};
27use crate::kernel_registry::KernelRegistry;
28use crate::ops::quantized_matmul_ggml::GgmlType;
29
30/// Host-side parameters for [`mul_mv_ext_dispatch`].
31///
32/// Buffer layout contract:
33///   - `weight` (= src0): row-major `[N, blocks_per_row]` GGUF blocks.
34///   - `input`  (= src1): row-major `[batch, M, K]` f32. K = ne00.
35///   - `output` (= dst):  row-major `[batch, M, N]` f32. N = ne01 / ne0.
36///
37/// `r2`, `r3` model llama.cpp's batch-broadcast (default 1, 1).
38#[derive(Debug, Clone, Copy)]
39pub struct MulMvExtParams {
40    /// M — number of src1 rows (small batch, must be ∈ [2, 8]).
41    pub m: u32,
42    /// N — number of weight rows (output dim).
43    pub n: u32,
44    /// K — contract dim (input dim, must be divisible by 32).
45    pub k: u32,
46    /// Batch-broadcast factor for src0 vs src1 (typical 1).
47    pub batch: u32,
48    /// GGUF weight type. Phase 1 supports Q5_1 + IQ4_NL only; other types
49    /// return `MlxError::InvalidArgument`.
50    pub ggml_type: GgmlType,
51}
52
53/// GPU args struct — must match `hf2q_mul_mv_ext_args` in
54/// `shaders/mul_mv_ext.metal` byte-for-byte.
55///
56/// llama.cpp's C layout puts an int32 triple before u64 fields, then more
57/// int32 + u64, ending with two i16. The Metal-side struct's natural
58/// alignment matches this with padding inserted after `ne02` (4-byte pad
59/// before nb00 to reach 8-byte alignment). We model that explicitly so
60/// `bytemuck::Pod` is happy and the byte layout matches the GPU struct.
61#[repr(C)]
62#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
63struct MulMvExtGpuArgs {
64    ne00: i32,
65    ne01: i32,
66    ne02: i32,
67    _pad0: u32,
68    nb00: u64,
69    nb01: u64,
70    nb02: u64,
71    nb03: u64,
72    ne10: i32,
73    ne11: i32,
74    ne12: i32,
75    _pad1: u32,
76    nb10: u64,
77    nb11: u64,
78    nb12: u64,
79    nb13: u64,
80    ne0: i32,
81    ne1: i32,
82    r2: i16,
83    r3: i16,
84    // Trailing pad: largest member u64 (align 8) → struct size must be a
85    // multiple of 8. Without this, bytemuck::Pod's no-padding contract
86    // refuses the type. Match the implicit C tail padding the Metal
87    // compiler emits for the MSL struct.
88    _pad2: u32,
89}
90
/// Choose the `nxpsg` kernel constant, following llama.cpp's
/// `ggml-metal-ops.cpp:2094-2100`.
///
/// The result is:
///   - 16 when K is a multiple of 256 and fewer than three src1 rows are
///     processed;
///   - otherwise 8 when K is a multiple of 128;
///   - otherwise 4.
///
/// The dispatcher derives `nypsg = 32 / nxpsg` from this value.
fn pick_nxpsg(k: u32, m: u32) -> i32 {
    let wide_path = k % 256 == 0 && m < 3;
    let k_is_128_aligned = k % 128 == 0;
    match (wide_path, k_is_128_aligned) {
        (true, _) => 16,
        (false, true) => 8,
        (false, false) => 4,
    }
}
101
102/// Pick `r1ptg` per llama.cpp's `ggml-metal-ops.cpp:2107-2120` switch.
103/// Returns `Err(InvalidArgument)` for unsupported m values.
104fn pick_r1ptg(m: u32) -> Result<i32> {
105    match m {
106        2 => Ok(2),
107        3 | 6 => Ok(3),
108        4 | 7 | 8 => Ok(4),
109        5 => Ok(5),
110        other => Err(MlxError::InvalidArgument(format!(
111            "mul_mv_ext: unsupported m {} (peer mapping covers 2..=8 only)",
112            other
113        ))),
114    }
115}
116
117/// Compose the kernel name from ggml_type + r1ptg, matching the metal
118/// shader's `[[host_name(...)]]` attributes.
119///
120/// Phase 1 (P1.7): Q5_1 + IQ4_NL × r1∈{2,3,4,5}.
121/// Phase 4: Q4_0, Q8_0, Q4_K, Q5_K, Q6_K × r1∈{2,3,4,5}.
122fn kernel_name(ggml_type: GgmlType, r1ptg: i32) -> Result<&'static str> {
123    Ok(match (ggml_type, r1ptg) {
124        (GgmlType::Q5_1, 2) => "kernel_mul_mv_ext_q5_1_f32_r1_2",
125        (GgmlType::Q5_1, 3) => "kernel_mul_mv_ext_q5_1_f32_r1_3",
126        (GgmlType::Q5_1, 4) => "kernel_mul_mv_ext_q5_1_f32_r1_4",
127        (GgmlType::Q5_1, 5) => "kernel_mul_mv_ext_q5_1_f32_r1_5",
128        (GgmlType::IQ4_NL, 2) => "kernel_mul_mv_ext_iq4_nl_f32_r1_2",
129        (GgmlType::IQ4_NL, 3) => "kernel_mul_mv_ext_iq4_nl_f32_r1_3",
130        (GgmlType::IQ4_NL, 4) => "kernel_mul_mv_ext_iq4_nl_f32_r1_4",
131        (GgmlType::IQ4_NL, 5) => "kernel_mul_mv_ext_iq4_nl_f32_r1_5",
132        (GgmlType::Q4_0, 2) => "kernel_mul_mv_ext_q4_0_f32_r1_2",
133        (GgmlType::Q4_0, 3) => "kernel_mul_mv_ext_q4_0_f32_r1_3",
134        (GgmlType::Q4_0, 4) => "kernel_mul_mv_ext_q4_0_f32_r1_4",
135        (GgmlType::Q4_0, 5) => "kernel_mul_mv_ext_q4_0_f32_r1_5",
136        (GgmlType::Q8_0, 2) => "kernel_mul_mv_ext_q8_0_f32_r1_2",
137        (GgmlType::Q8_0, 3) => "kernel_mul_mv_ext_q8_0_f32_r1_3",
138        (GgmlType::Q8_0, 4) => "kernel_mul_mv_ext_q8_0_f32_r1_4",
139        (GgmlType::Q8_0, 5) => "kernel_mul_mv_ext_q8_0_f32_r1_5",
140        (GgmlType::Q4_K, 2) => "kernel_mul_mv_ext_q4_K_f32_r1_2",
141        (GgmlType::Q4_K, 3) => "kernel_mul_mv_ext_q4_K_f32_r1_3",
142        (GgmlType::Q4_K, 4) => "kernel_mul_mv_ext_q4_K_f32_r1_4",
143        (GgmlType::Q4_K, 5) => "kernel_mul_mv_ext_q4_K_f32_r1_5",
144        (GgmlType::Q5_K, 2) => "kernel_mul_mv_ext_q5_K_f32_r1_2",
145        (GgmlType::Q5_K, 3) => "kernel_mul_mv_ext_q5_K_f32_r1_3",
146        (GgmlType::Q5_K, 4) => "kernel_mul_mv_ext_q5_K_f32_r1_4",
147        (GgmlType::Q5_K, 5) => "kernel_mul_mv_ext_q5_K_f32_r1_5",
148        (GgmlType::Q6_K, 2) => "kernel_mul_mv_ext_q6_K_f32_r1_2",
149        (GgmlType::Q6_K, 3) => "kernel_mul_mv_ext_q6_K_f32_r1_3",
150        (GgmlType::Q6_K, 4) => "kernel_mul_mv_ext_q6_K_f32_r1_4",
151        (GgmlType::Q6_K, 5) => "kernel_mul_mv_ext_q6_K_f32_r1_5",
152        (other_type, other_r1) => {
153            return Err(MlxError::InvalidArgument(format!(
154                "mul_mv_ext: no kernel for type {:?} × r1ptg {} (Phase 1+4 ports Q4_0/Q8_0/Q4_K/Q5_K/Q6_K/Q5_1/IQ4_NL × r1∈{{2,3,4,5}})",
155                other_type, other_r1
156            )));
157        }
158    })
159}
160
161/// Encode a `mul_mv_ext` dispatch.
162///
163/// # Errors
164///
165/// Returns `MlxError::InvalidArgument` if:
166///   - `ggml_type` is not Q5_1 or IQ4_NL,
167///   - `m` is outside [2, 8],
168///   - `k` is not divisible by 32,
169///   - any of `m`, `n`, `k`, `batch` is zero,
170///   - any buffer is too small.
171pub fn mul_mv_ext_dispatch(
172    encoder: &mut CommandEncoder,
173    registry: &mut KernelRegistry,
174    device: &MlxDevice,
175    weight: &MlxBuffer,
176    input: &MlxBuffer,
177    output: &MlxBuffer,
178    params: &MulMvExtParams,
179) -> Result<()> {
180    if params.m == 0 || params.n == 0 || params.k == 0 || params.batch == 0 {
181        return Err(MlxError::InvalidArgument(
182            "mul_mv_ext: m, n, k, batch must all be > 0".into(),
183        ));
184    }
185    // K must be divisible by the block size of the weight type.
186    // Legacy 32-element types (Q4_0/Q8_0/Q5_1/IQ4_NL): k % 32 == 0.
187    // K-quants (Q4_K/Q5_K/Q6_K): k % 256 == 0.
188    let block_qk = params.ggml_type.block_values();
189    if params.k % block_qk != 0 {
190        return Err(MlxError::InvalidArgument(format!(
191            "mul_mv_ext: k ({}) must be divisible by block QK ({}) for {:?}",
192            params.k, block_qk, params.ggml_type
193        )));
194    }
195
196    let r1ptg = pick_r1ptg(params.m)?;
197    let nxpsg = pick_nxpsg(params.k, params.m);
198    let nsg: i32 = 2;
199    let nypsg = 32 / nxpsg;
200    let r0ptg = nypsg * nsg;
201
202    let kname = kernel_name(params.ggml_type, r1ptg)?;
203
204    // PSO compile keyed on (kname, FC_mul_mv_nsg, FC_mul_mv_nxpsg).
205    let pipeline = registry
206        .get_pipeline_with_constants(
207            kname,
208            device.metal_device(),
209            &[],
210            &[(600, nsg), (601, nxpsg)],
211        )?
212        .clone();
213
214    // Buffer-size validation. Use the block-aware formula so K-quants
215    // (256-element blocks) work alongside legacy 32-element types.
216    let block_bytes_per_row =
217        (params.k as usize / block_qk as usize) * (params.ggml_type.block_bytes() as usize);
218    let weight_required = (params.n as usize) * block_bytes_per_row;
219    if weight.byte_len() < weight_required {
220        return Err(MlxError::InvalidArgument(format!(
221            "mul_mv_ext: weight buffer too small: {} < {} bytes",
222            weight.byte_len(),
223            weight_required
224        )));
225    }
226    let input_required = (params.batch as usize)
227        * (params.m as usize)
228        * (params.k as usize)
229        * DType::F32.size_of();
230    if input.byte_len() < input_required {
231        return Err(MlxError::InvalidArgument(format!(
232            "mul_mv_ext: input buffer too small: {} < {} bytes",
233            input.byte_len(),
234            input_required
235        )));
236    }
237    let output_required = (params.batch as usize)
238        * (params.m as usize)
239        * (params.n as usize)
240        * DType::F32.size_of();
241    if output.byte_len() < output_required {
242        return Err(MlxError::InvalidArgument(format!(
243            "mul_mv_ext: output buffer too small: {} < {} bytes",
244            output.byte_len(),
245            output_required
246        )));
247    }
248
249    // GPU args. nb01 = bytes per weight row; nb00 = block_bytes (1 block per
250    // QK4_0=32 elements). The args mirror llama.cpp's:
251    //   nb00 = ggml_type_size(weight)            (single block)
252    //   nb01 = ggml_row_size(weight, K)           (full row)
253    //   nb02 = nb01 * N                           (single batch)
254    //   nb10 = sizeof(float) = 4
255    //   nb11 = K * sizeof(float)                  (single src1 row)
256    //   nb12 = nb11 * M                           (single src1 batch)
257    let nb00 = params.ggml_type.block_bytes() as u64;
258    let nb01 = block_bytes_per_row as u64;
259    let nb02 = nb01 * params.n as u64;
260    let nb10: u64 = 4;
261    let nb11 = (params.k as u64) * 4;
262    let nb12 = nb11 * params.m as u64;
263    let args = MulMvExtGpuArgs {
264        ne00: params.k as i32,
265        ne01: params.n as i32,
266        ne02: 1,
267        _pad0: 0,
268        nb00,
269        nb01,
270        nb02,
271        nb03: nb02, // unused
272        ne10: params.k as i32,
273        ne11: params.m as i32,
274        ne12: params.batch as i32,
275        _pad1: 0,
276        nb10,
277        nb11,
278        nb12,
279        nb13: nb12, // unused
280        ne0: params.n as i32,
281        ne1: params.m as i32,
282        r2: 1,
283        r3: 1,
284        _pad2: 0,
285    };
286
287    use crate::encoder::{as_bytes, KernelArg};
288
289    let args_bytes = as_bytes(&args);
290    let r0_groups = ((params.n as i32) + r0ptg - 1) / r0ptg;
291    let r1_groups = ((params.m as i32) + r1ptg - 1) / r1ptg;
292
293    encoder.encode_threadgroups_with_args(
294        &pipeline,
295        &[
296            (0, KernelArg::Bytes(args_bytes)),
297            (1, KernelArg::Buffer(weight)),
298            (2, KernelArg::Buffer(input)),
299            (3, KernelArg::Buffer(output)),
300        ],
301        crate::MTLSize::new(r0_groups as u64, r1_groups as u64, params.batch as u64),
302        crate::MTLSize::new(32, nsg as u64, 1),
303    );
304
305    Ok(())
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pick_nxpsg_matches_peer_logic() {
        // K=128, M=1 → not %256 / yes %128 → 8
        assert_eq!(pick_nxpsg(128, 1), 8);
        // K=256, M=2 → yes %256 + M<3 → 16
        assert_eq!(pick_nxpsg(256, 2), 16);
        // K=256, M=3 → yes %256 + M>=3 → 8
        assert_eq!(pick_nxpsg(256, 3), 8);
        // K=64, M=4 → not %256 / not %128 → 4
        assert_eq!(pick_nxpsg(64, 4), 4);
        // K=2816, M=2 → 2816%256==0 (11×256) AND M<3 → 16
        assert_eq!(pick_nxpsg(2816, 2), 16);
        // K=2816, M=3 → 2816%256==0 but M>=3 → fallthrough %128==0 → 8
        assert_eq!(pick_nxpsg(2816, 3), 8);
        // K=512, M=4 → 512%256==0 but M>=3 → fallthrough %128==0 → 8
        assert_eq!(pick_nxpsg(512, 4), 8);
        // K=32 (smallest legal legacy K), M=2 → not %256 / not %128 → 4
        assert_eq!(pick_nxpsg(32, 2), 4);
    }

    #[test]
    fn pick_r1ptg_matches_peer_switch() {
        assert_eq!(pick_r1ptg(2).unwrap(), 2);
        assert_eq!(pick_r1ptg(3).unwrap(), 3);
        assert_eq!(pick_r1ptg(4).unwrap(), 4);
        assert_eq!(pick_r1ptg(5).unwrap(), 5);
        assert_eq!(pick_r1ptg(6).unwrap(), 3);
        assert_eq!(pick_r1ptg(7).unwrap(), 4);
        assert_eq!(pick_r1ptg(8).unwrap(), 4);
        assert!(pick_r1ptg(0).is_err());
        assert!(pick_r1ptg(1).is_err());
        assert!(pick_r1ptg(9).is_err());
    }

    #[test]
    fn kernel_name_covers_all_phase1_combinations() {
        for r1 in 2..=5 {
            assert!(kernel_name(GgmlType::Q5_1, r1).is_ok());
            assert!(kernel_name(GgmlType::IQ4_NL, r1).is_ok());
        }
        // Note: Q4_0 was Phase 1's "no kernel" canary, but Phase 4
        // (commit `9ee8a28`) added Q4_0/Q8_0/Q4_K/Q5_K/Q6_K coverage,
        // so Q4_0 is now a hit, not a miss. The sibling test
        // `kernel_name_covers_all_phase4_combinations` pins Phase 4
        // coverage. `kernel_name_rejects_unsupported_combinations` is the
        // "no kernel" canary; it currently exercises out-of-range r1ptg only.
    }

    #[test]
    fn kernel_name_covers_all_phase4_combinations() {
        // Phase 4 (commit `9ee8a28`): Q4_0, Q8_0, Q4_K, Q5_K, Q6_K
        // × r1 ∈ {2, 3, 4, 5}. Pin coverage so future GgmlType
        // additions don't silently drop Phase 4 wires.
        for r1 in 2..=5 {
            assert!(kernel_name(GgmlType::Q4_0, r1).is_ok(),
                "Phase 4 Q4_0 r1={r1} must have a kernel");
            assert!(kernel_name(GgmlType::Q8_0, r1).is_ok(),
                "Phase 4 Q8_0 r1={r1} must have a kernel");
            assert!(kernel_name(GgmlType::Q4_K, r1).is_ok(),
                "Phase 4 Q4_K r1={r1} must have a kernel");
            assert!(kernel_name(GgmlType::Q5_K, r1).is_ok(),
                "Phase 4 Q5_K r1={r1} must have a kernel");
            assert!(kernel_name(GgmlType::Q6_K, r1).is_ok(),
                "Phase 4 Q6_K r1={r1} must have a kernel");
        }
    }

    #[test]
    fn kernel_name_rejects_unsupported_combinations() {
        // r1 outside [2, 5] is rejected for ALL types (including
        // Phase 1 + Phase 4 covered ones).
        assert!(kernel_name(GgmlType::Q5_1, 1).is_err(),
            "r1=1 not supported by any phase");
        assert!(kernel_name(GgmlType::Q5_1, 6).is_err(),
            "r1=6 not supported by any phase");
        assert!(kernel_name(GgmlType::Q4_0, 0).is_err(),
            "r1=0 not supported");
        assert!(kernel_name(GgmlType::Q4_0, -1).is_err(),
            "r1=-1 not supported");
    }

    #[test]
    fn gpu_args_layout_matches_metal_struct_size() {
        // 3×i32 + pad + 4×u64 + 3×i32 + pad + 4×u64 + 2×i32 + 2×i16 + tail pad
        // = 16 + 32 + 16 + 32 + 8 + 4 + 4 = 112 bytes.
        assert_eq!(std::mem::size_of::<MulMvExtGpuArgs>(), 112);
    }
}