// mlx_native/ops/quantized_matmul.rs
//! Quantized matrix multiplication host-side dispatch.
//!
//! Encodes a GPU compute command that performs:
//!   output[row][col] = sum_k(dequant(weight[col][k]) * input[row][k])
//!
//! Weights are stored in packed quantized format (4-bit, 6-bit, or 8-bit) with
//! per-group bf16 scales and biases for affine dequantization.

9use crate::buffer::MlxBuffer;
10use crate::device::MlxDevice;
11use crate::dtypes::DType;
12use crate::encoder::CommandEncoder;
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
/// Parameters describing the quantized matmul dimensions and format.
///
/// Shapes follow the module contract: `output[M][N]` from `input[M][K]` and a
/// packed weight indexed as `weight[col][k]` (i.e. logically `[N, K]`).
#[derive(Debug, Clone, Copy)]
pub struct QuantizedMatmulParams {
    /// Number of input rows (tokens).
    pub m: u32,
    /// Inner dimension (shared between input and weight).
    pub k: u32,
    /// Number of output columns.
    pub n: u32,
    /// Number of consecutive values sharing one scale/bias pair.
    pub group_size: u32,
    /// Quantization bit width (4, 6, or 8).
    pub bits: u32,
}
30
/// GPU-side params struct — must match the Metal shader's `QuantizedMatmulParams`.
///
/// This is `#[repr(C)]` to guarantee C-compatible layout for Metal buffer binding.
/// All five fields are `u32`, so the struct is 20 bytes, 4-byte aligned, with no
/// padding — it can be blitted byte-for-byte into a `DType::U32` buffer.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QuantizedMatmulGpuParams {
    /// Rows of the input matrix (M).
    m: u32,
    /// Shared inner dimension (K).
    k: u32,
    /// Output columns (N).
    n: u32,
    /// Values per scale/bias group.
    group_size: u32,
    /// Quantization bit width (4, 6, or 8).
    bits: u32,
}
43
/// Compute the expected weight buffer size in bytes for the given parameters.
///
/// Packing layouts per bit width:
/// - 4-bit: 8 values per uint32 → N * ceil(K/8) * 4 bytes.
/// - 6-bit: 4 values per 3-byte triplet (MLX packing) → N * ceil(K/4) * 3 bytes.
/// - 8-bit: 4 values per uint32 → N * ceil(K/4) * 4 bytes.
///
/// Returns 0 for any other bit width (callers validate `bits` before sizing).
fn expected_weight_bytes(k: u32, n: u32, bits: u32) -> usize {
    let k = k as usize;
    let n = n as usize;
    // (values packed per unit, bytes per packed unit) for each supported width.
    let (values_per_pack, bytes_per_pack) = match bits {
        4 => (8, 4),
        6 => (4, 3),
        8 => (4, 4),
        _ => return 0,
    };
    // Ceil-divide so a partial trailing pack is still counted.
    let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
    n * packs_per_row * bytes_per_pack
}
72
/// Compute the expected scales (or biases) buffer size in bytes.
///
/// Each of the N output columns has ceil(K / group_size) groups, and every
/// group stores a single bf16 value (2 bytes).
fn expected_scales_bytes(k: u32, n: u32, group_size: u32) -> usize {
    const BF16_BYTES: usize = 2;
    let groups_per_col = ((k + group_size - 1) / group_size) as usize;
    BF16_BYTES * groups_per_col * (n as usize)
}
81
82/// Encode a quantized matrix multiplication onto the given command encoder.
83///
84/// This does **not** commit the command buffer — the caller is responsible for
85/// calling `encoder.commit_and_wait()` after encoding all desired operations.
86///
87/// # Arguments
88///
89/// * `encoder`  — The command encoder to record the dispatch into.
90/// * `registry` — Kernel registry (compiles the shader on first call).
91/// * `device`   — The Metal device (needed for pipeline compilation and output allocation).
92/// * `input`    — f32 input matrix buffer, shape `[M, K]`.
93/// * `weight`   — Packed quantized weight buffer, shape `[N, packed_k]`.
94/// * `scales`   — bf16 scale buffer, shape `[N, num_groups]`.
95/// * `biases`   — bf16 bias buffer, shape `[N, num_groups]`.
96/// * `params`   — Dimensions and quantization parameters.
97///
98/// # Returns
99///
100/// A freshly allocated `MlxBuffer` for the output of shape `[M, N]` with dtype `F32`.
101///
102/// # Errors
103///
104/// * `MlxError::InvalidArgument` — unsupported `bits` value, or buffer sizes
105///   do not match the expected dimensions.
106pub fn quantized_matmul(
107    encoder: &mut CommandEncoder,
108    registry: &mut KernelRegistry,
109    device: &MlxDevice,
110    input: &MlxBuffer,
111    weight: &MlxBuffer,
112    scales: &MlxBuffer,
113    biases: &MlxBuffer,
114    params: &QuantizedMatmulParams,
115) -> Result<MlxBuffer> {
116    // --- Validate bits ---
117    if params.bits != 4 && params.bits != 6 && params.bits != 8 {
118        return Err(MlxError::InvalidArgument(format!(
119            "Unsupported bits value {}; only 4, 6, and 8 are supported",
120            params.bits
121        )));
122    }
123
124    // --- Validate dimensions are non-zero ---
125    if params.m == 0 || params.k == 0 || params.n == 0 {
126        return Err(MlxError::InvalidArgument(
127            "M, K, and N must all be > 0".into(),
128        ));
129    }
130    if params.group_size == 0 {
131        return Err(MlxError::InvalidArgument(
132            "group_size must be > 0".into(),
133        ));
134    }
135
136    // --- Validate buffer sizes ---
137    let expected_input = (params.m as usize) * (params.k as usize) * DType::F32.size_of();
138    if input.byte_len() < expected_input {
139        return Err(MlxError::InvalidArgument(format!(
140            "Input buffer too small: expected at least {} bytes for [{}x{}] f32, got {}",
141            expected_input, params.m, params.k, input.byte_len()
142        )));
143    }
144
145    let expected_w = expected_weight_bytes(params.k, params.n, params.bits);
146    if weight.byte_len() < expected_w {
147        return Err(MlxError::InvalidArgument(format!(
148            "Weight buffer too small: expected at least {} bytes for {}bit [{}x{}], got {}",
149            expected_w, params.bits, params.n, params.k, weight.byte_len()
150        )));
151    }
152
153    let expected_s = expected_scales_bytes(params.k, params.n, params.group_size);
154    if scales.byte_len() < expected_s {
155        return Err(MlxError::InvalidArgument(format!(
156            "Scales buffer too small: expected at least {} bytes, got {}",
157            expected_s, scales.byte_len()
158        )));
159    }
160    if biases.byte_len() < expected_s {
161        return Err(MlxError::InvalidArgument(format!(
162            "Biases buffer too small: expected at least {} bytes, got {}",
163            expected_s, biases.byte_len()
164        )));
165    }
166
167    // --- Get (or compile) the pipeline ---
168    let pipeline = registry.get_pipeline("quantized_matmul", device.metal_device())?;
169
170    // --- Allocate output buffer ---
171    // Output is f32 to avoid f16 overflow (max ~65504) on projections with large
172    // accumulated values (e.g. attention output projections where K=4096).
173    let output_bytes = (params.m as usize) * (params.n as usize) * DType::F32.size_of();
174    let output = device.alloc_buffer(
175        output_bytes,
176        DType::F32,
177        vec![params.m as usize, params.n as usize],
178    )?;
179
180    // --- Create GPU params buffer ---
181    let gpu_params = QuantizedMatmulGpuParams {
182        m: params.m,
183        k: params.k,
184        n: params.n,
185        group_size: params.group_size,
186        bits: params.bits,
187    };
188    let params_bytes = std::mem::size_of::<QuantizedMatmulGpuParams>();
189    let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
190    {
191        let slice: &mut [QuantizedMatmulGpuParams] = bytemuck::cast_slice_mut(
192            params_buf
193                .as_mut_slice::<u8>()
194                .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
195        );
196        slice[0] = gpu_params;
197    }
198
199    // --- Dispatch ---
200    // Grid: (N, M, 1) — one thread per output element.
201    // Threadgroup: up to 256 threads, arranged as (tx, ty, 1).
202    // We use encode_threadgroups with explicit threadgroup counts for non-even dims.
203
204    // Choose threadgroup size: we want a 2D block that fits within 256 threads.
205    // Use 16x16 = 256 as a good default for 2D dispatch.
206    let tg_x = 16u64.min(params.n as u64);
207    let tg_y = 16u64.min(params.m as u64);
208    let threadgroup_size = metal::MTLSize::new(tg_x, tg_y, 1);
209
210    // Use encode_threadgroups with ceil-division for non-even dimensions.
211    let grid_groups = metal::MTLSize::new(
212        (params.n as u64 + tg_x - 1) / tg_x,
213        (params.m as u64 + tg_y - 1) / tg_y,
214        1,
215    );
216
217    encoder.encode_threadgroups(
218        pipeline,
219        &[
220            (0, input),
221            (1, weight),
222            (2, scales),
223            (3, biases),
224            (4, &output),
225            (5, &params_buf),
226        ],
227        grid_groups,
228        threadgroup_size,
229    );
230
231    Ok(output)
232}
233
234/// Check whether the SIMD-cooperative kernel can be used for the given params.
235///
236/// The SIMD path requires:
237///   - bits is 4 or 8 (not 6)
238///   - N is divisible by 8 (results_per_simdgroup * num_simdgroups)
239///   - K is divisible by block_size:
240///     - 4-bit: K % 256 == 0 (block_size = 8 values/thread * 32 SIMD = 256, qmv)
241///     - 8-bit: K % 256 == 0 (block_size = 8 * 32 = 256)
242///
243/// NOTE: The f32 SIMD kernel uses qmv params (values_per_thread=8) not qmv_fast,
244/// because K=2816 (Gemma 4 hidden_size) is 256-aligned but not 512-aligned.
245/// The bf16 SIMD kernels use qmv_fast (values_per_thread=16, block_size=512).
246/// Check alignment for f32 SIMD kernel (qmv: values_per_thread=8, block_size=256).
247fn can_use_simd_kernel(params: &QuantizedMatmulParams) -> bool {
248    let bn = 8u32; // num_simdgroups * results_per_simdgroup
249    if params.n % bn != 0 {
250        return false;
251    }
252    match params.bits {
253        4 => params.k % 256 == 0,  // qmv: block_size = 8 * 32 = 256
254        8 => params.k % 256 == 0,
255        _ => false,
256    }
257}
258
259/// Check alignment for bf16 SIMD kernel (qmv_fast: values_per_thread=16, block_size=512).
260fn can_use_simd_kernel_bf16(params: &QuantizedMatmulParams) -> bool {
261    let bn = 8u32;
262    if params.n % bn != 0 {
263        return false;
264    }
265    match params.bits {
266        4 => params.k % 512 == 0,  // qmv_fast: block_size = 16 * 32 = 512
267        8 => params.k % 256 == 0,
268        _ => false,
269    }
270}
271
272/// Encode a quantized matrix-vector multiply using the SIMD-cooperative kernel
273/// that matches MLX's `qmv_fast` accumulation pattern exactly.
274///
275/// This kernel uses 2 simdgroups of 32 threads, each producing 4 output rows,
276/// with `simd_sum()` reduction. The accumulation order matches MLX bit-for-bit.
277///
278/// Falls back to the scalar `quantized_matmul` kernel if the dimensions don't
279/// meet the alignment requirements.
280///
281/// # Arguments
282///
283/// Same as [`quantized_matmul`].
284///
285/// # Returns
286///
287/// A freshly allocated `MlxBuffer` for the output of shape `[M, N]` with dtype `F32`.
288pub fn quantized_matmul_simd(
289    encoder: &mut CommandEncoder,
290    registry: &mut KernelRegistry,
291    device: &MlxDevice,
292    input: &MlxBuffer,
293    weight: &MlxBuffer,
294    scales: &MlxBuffer,
295    biases: &MlxBuffer,
296    params: &QuantizedMatmulParams,
297) -> Result<MlxBuffer> {
298    // Fall back to scalar kernel if dimensions don't support SIMD path.
299    if !can_use_simd_kernel(params) {
300        return quantized_matmul(encoder, registry, device, input, weight, scales, biases, params);
301    }
302
303    // --- Validate bits ---
304    // 6-bit: fall back to scalar GPU kernel (SIMD path only handles 4/8-bit)
305    if params.bits == 6 {
306        return quantized_matmul(encoder, registry, device, input, weight, scales, biases, params);
307    }
308    if params.bits != 4 && params.bits != 8 {
309        return Err(MlxError::InvalidArgument(format!(
310            "SIMD kernel: unsupported bits value {}; only 4, 6, and 8 are supported",
311            params.bits
312        )));
313    }
314
315    // --- Validate dimensions are non-zero ---
316    if params.m == 0 || params.k == 0 || params.n == 0 {
317        return Err(MlxError::InvalidArgument(
318            "M, K, and N must all be > 0".into(),
319        ));
320    }
321    if params.group_size == 0 {
322        return Err(MlxError::InvalidArgument(
323            "group_size must be > 0".into(),
324        ));
325    }
326
327    // --- Validate buffer sizes ---
328    let expected_input = (params.m as usize) * (params.k as usize) * DType::F32.size_of();
329    if input.byte_len() < expected_input {
330        return Err(MlxError::InvalidArgument(format!(
331            "Input buffer too small: expected at least {} bytes for [{}x{}] f32, got {}",
332            expected_input, params.m, params.k, input.byte_len()
333        )));
334    }
335
336    let expected_w = expected_weight_bytes(params.k, params.n, params.bits);
337    if weight.byte_len() < expected_w {
338        return Err(MlxError::InvalidArgument(format!(
339            "Weight buffer too small: expected at least {} bytes for {}bit [{}x{}], got {}",
340            expected_w, params.bits, params.n, params.k, weight.byte_len()
341        )));
342    }
343
344    let expected_s = expected_scales_bytes(params.k, params.n, params.group_size);
345    if scales.byte_len() < expected_s {
346        return Err(MlxError::InvalidArgument(format!(
347            "Scales buffer too small: expected at least {} bytes, got {}",
348            expected_s, scales.byte_len()
349        )));
350    }
351    if biases.byte_len() < expected_s {
352        return Err(MlxError::InvalidArgument(format!(
353            "Biases buffer too small: expected at least {} bytes, got {}",
354            expected_s, biases.byte_len()
355        )));
356    }
357
358    // --- Get (or compile) the SIMD pipeline ---
359    let pipeline = registry.get_pipeline("quantized_matmul_simd", device.metal_device())?;
360
361    // --- Allocate output buffer ---
362    let output_bytes = (params.m as usize) * (params.n as usize) * DType::F32.size_of();
363    let output = device.alloc_buffer(
364        output_bytes,
365        DType::F32,
366        vec![params.m as usize, params.n as usize],
367    )?;
368
369    // --- Create GPU params buffer ---
370    let gpu_params = QuantizedMatmulGpuParams {
371        m: params.m,
372        k: params.k,
373        n: params.n,
374        group_size: params.group_size,
375        bits: params.bits,
376    };
377    let params_bytes = std::mem::size_of::<QuantizedMatmulGpuParams>();
378    let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
379    {
380        let slice: &mut [QuantizedMatmulGpuParams] = bytemuck::cast_slice_mut(
381            params_buf
382                .as_mut_slice::<u8>()
383                .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
384        );
385        slice[0] = gpu_params;
386    }
387
388    // --- Dispatch with MLX's qmv_fast pattern ---
389    // threadgroup_size = (SIMD_SIZE, num_simdgroups, 1) = (32, 2, 1) = 64 threads
390    // threadgroups     = (M, ceil(N / 8), 1)
391    //
392    // In the kernel:
393    //   tid.x = input row (M dimension)
394    //   tid.y = output column block (each block produces 8 output columns)
395    //   simd_gid = which simdgroup (0 or 1), each handles 4 of the 8 columns
396    //   simd_lid = thread index within simdgroup (0..31)
397    let num_simdgroups = 2u64;
398    let results_per_simdgroup = 4u64;
399    let bn = num_simdgroups * results_per_simdgroup; // 8
400
401    let threadgroup_size = metal::MTLSize::new(32, num_simdgroups, 1);
402    let threadgroups = metal::MTLSize::new(
403        params.m as u64,
404        (params.n as u64 + bn - 1) / bn,
405        1,
406    );
407
408    encoder.encode_threadgroups(
409        pipeline,
410        &[
411            (0, input),
412            (1, weight),
413            (2, scales),
414            (3, biases),
415            (4, &output),
416            (5, &params_buf),
417        ],
418        threadgroups,
419        threadgroup_size,
420    );
421
422    Ok(output)
423}
424
425// ---------------------------------------------------------------------------
426// bf16 I/O variant — eliminates 2 cast dispatches per projection by accepting
427// bfloat input and producing bfloat output directly.
428// ---------------------------------------------------------------------------
429
/// GPU-side params for the bf16 kernels.  Identical layout to
/// `QuantizedMatmulGpuParams`; kept as a separate type for clarity.
///
/// `#[repr(C)]` guarantees field order and packing match the Metal-side struct;
/// all-u32 fields mean no padding (20 bytes, 4-byte alignment).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QMatmulBf16GpuParams {
    /// Rows of the input matrix (M).
    m: u32,
    /// Shared inner dimension (K).
    k: u32,
    /// Output columns (N).
    n: u32,
    /// Values per scale/bias group.
    group_size: u32,
    /// Quantization bit width (4 or 8 on the SIMD path; 6 takes the fallback).
    bits: u32,
}
441
442/// Dispatch the bf16 I/O variant of the SIMD quantized matmul kernel.
443///
444/// Input and output are both bf16.  Accumulation happens in f32 inside the
445/// shader for numerical stability, matching the precision of the f32 variant.
446///
447/// Falls back to the scalar `quantized_matmul` kernel (with f32 output) if the
448/// dimensions don't satisfy SIMD alignment requirements.
449///
450/// # Arguments
451///
452/// * `encoder`      — The command encoder to record the dispatch into.
453/// * `registry`     — Kernel registry (compiles the shader on first call).
454/// * `device`       — The Metal device for buffer allocation.
455/// * `input`        — bf16 input matrix buffer, shape `[M, K]`.
456/// * `packed_weights` — Packed quantized weight buffer, shape `[N, packed_k]`.
457/// * `scales`       — bf16 scale buffer, shape `[N, num_groups]`.
458/// * `biases`       — bf16 bias buffer, shape `[N, num_groups]`.
459/// * `params`       — Dimensions and quantization parameters.
460///
461/// # Returns
462///
463/// A freshly allocated `MlxBuffer` for the output of shape `[M, N]` with dtype `BF16`.
464pub fn dispatch_quantized_matmul_simd_bf16(
465    encoder: &mut CommandEncoder,
466    registry: &mut KernelRegistry,
467    device: &MlxDevice,
468    input: &MlxBuffer,
469    packed_weights: &MlxBuffer,
470    scales: &MlxBuffer,
471    biases: &MlxBuffer,
472    params: &QuantizedMatmulParams,
473) -> Result<MlxBuffer> {
474    // Fall back to f32 cast+scalar path for shapes where bf16 SIMD (qmv_fast)
475    // alignment is not met (e.g. K=2816 is 256-aligned but not 512-aligned).
476    if !can_use_simd_kernel_bf16(params) {
477        let n_in = (params.m as usize) * (params.k as usize);
478        let f32_input = if input.dtype() == DType::BF16 {
479            let f32_buf = device.alloc_buffer(n_in * DType::F32.size_of(), DType::F32, vec![params.m as usize, params.k as usize])?;
480            crate::ops::elementwise::cast(
481                encoder, registry, device.metal_device(),
482                input, &f32_buf, n_in,
483                crate::ops::elementwise::CastDirection::BF16ToF32,
484            )?;
485            Some(f32_buf)
486        } else {
487            None
488        };
489        let actual_input = f32_input.as_ref().unwrap_or(input);
490        let f32_result = quantized_matmul(encoder, registry, device, actual_input, packed_weights, scales, biases, params)?;
491        // Cast f32 output back to bf16
492        let n_out = (params.m as usize) * (params.n as usize);
493        let bf16_out = device.alloc_buffer(n_out * DType::BF16.size_of(), DType::BF16, vec![params.m as usize, params.n as usize])?;
494        crate::ops::elementwise::cast(
495            encoder, registry, device.metal_device(),
496            &f32_result, &bf16_out, n_out,
497            crate::ops::elementwise::CastDirection::F32ToBF16,
498        )?;
499        return Ok(bf16_out);
500    }
501
502    // 6-bit: already handled by the !can_use_simd_kernel_bf16 fallback above
503    // which routes through quantized_matmul (scalar GPU kernel that supports 6-bit).
504    // But add an explicit check to be safe.
505    if params.bits == 6 {
506        // Route through the same fallback path as non-SIMD-aligned dimensions
507        let n_in = (params.m as usize) * (params.k as usize);
508        let f32_input = if input.dtype() == DType::BF16 {
509            let f32_buf = device.alloc_buffer(n_in * DType::F32.size_of(), DType::F32, vec![params.m as usize, params.k as usize])?;
510            crate::ops::elementwise::cast(
511                encoder, registry, device.metal_device(),
512                input, &f32_buf, n_in,
513                crate::ops::elementwise::CastDirection::BF16ToF32,
514            )?;
515            Some(f32_buf)
516        } else {
517            None
518        };
519        let actual_input = f32_input.as_ref().unwrap_or(input);
520        let f32_result = quantized_matmul(encoder, registry, device, actual_input, packed_weights, scales, biases, params)?;
521        let n_out = (params.m as usize) * (params.n as usize);
522        let bf16_out = device.alloc_buffer(n_out * DType::BF16.size_of(), DType::BF16, vec![params.m as usize, params.n as usize])?;
523        crate::ops::elementwise::cast(
524            encoder, registry, device.metal_device(),
525            &f32_result, &bf16_out, n_out,
526            crate::ops::elementwise::CastDirection::F32ToBF16,
527        )?;
528        return Ok(bf16_out);
529    }
530    if params.bits != 4 && params.bits != 8 {
531        return Err(MlxError::InvalidArgument(format!(
532            "bf16 SIMD kernel: unsupported bits value {}; only 4, 6, and 8 are supported",
533            params.bits
534        )));
535    }
536    if params.m == 0 || params.k == 0 || params.n == 0 || params.group_size == 0 {
537        return Err(MlxError::InvalidArgument(
538            "M, K, N, and group_size must all be > 0".into(),
539        ));
540    }
541
542    // Buffer size validation (input is bf16 = 2 bytes per element).
543    let expected_input = (params.m as usize) * (params.k as usize) * DType::BF16.size_of();
544    if input.byte_len() < expected_input {
545        return Err(MlxError::InvalidArgument(format!(
546            "bf16 input buffer too small: expected {} bytes for [{}x{}] bf16, got {}",
547            expected_input, params.m, params.k, input.byte_len()
548        )));
549    }
550
551    let expected_w = expected_weight_bytes(params.k, params.n, params.bits);
552    if packed_weights.byte_len() < expected_w {
553        return Err(MlxError::InvalidArgument(format!(
554            "Weight buffer too small: expected {} bytes, got {}",
555            expected_w, packed_weights.byte_len()
556        )));
557    }
558
559    let expected_s = expected_scales_bytes(params.k, params.n, params.group_size);
560    if scales.byte_len() < expected_s {
561        return Err(MlxError::InvalidArgument(format!(
562            "Scales buffer too small: expected {} bytes, got {}",
563            expected_s, scales.byte_len()
564        )));
565    }
566    if biases.byte_len() < expected_s {
567        return Err(MlxError::InvalidArgument(format!(
568            "Biases buffer too small: expected {} bytes, got {}",
569            expected_s, biases.byte_len()
570        )));
571    }
572
573    let pipeline = registry.get_pipeline("quantized_matmul_simd_bf16", device.metal_device())?;
574
575    // Output is bf16.
576    let output_bytes = (params.m as usize) * (params.n as usize) * DType::BF16.size_of();
577    let output = device.alloc_buffer(
578        output_bytes,
579        DType::BF16,
580        vec![params.m as usize, params.n as usize],
581    )?;
582
583    let gpu_params = QMatmulBf16GpuParams {
584        m: params.m,
585        k: params.k,
586        n: params.n,
587        group_size: params.group_size,
588        bits: params.bits,
589    };
590    let params_bytes = std::mem::size_of::<QMatmulBf16GpuParams>();
591    let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
592    {
593        let slice: &mut [QMatmulBf16GpuParams] = bytemuck::cast_slice_mut(
594            params_buf
595                .as_mut_slice::<u8>()
596                .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
597        );
598        slice[0] = gpu_params;
599    }
600
601    let num_simdgroups = 2u64;
602    let results_per_simdgroup = 4u64;
603    let bn = num_simdgroups * results_per_simdgroup; // 8
604
605    let threadgroup_size = metal::MTLSize::new(32, num_simdgroups, 1);
606    let threadgroups = metal::MTLSize::new(
607        params.m as u64,
608        (params.n as u64 + bn - 1) / bn,
609        1,
610    );
611
612    encoder.encode_threadgroups(
613        pipeline,
614        &[
615            (0, input),
616            (1, packed_weights),
617            (2, scales),
618            (3, biases),
619            (4, &output),
620            (5, &params_buf),
621        ],
622        threadgroups,
623        threadgroup_size,
624    );
625
626    Ok(output)
627}
628
629/// Dispatch bf16 quantized matmul with expert offset for MoE inference.
630///
631/// Indexes into a 3D packed weight tensor `[n_experts, rows, packed_cols]` using
632/// byte offsets, eliminating CPU memcpy for expert weight selection.
633///
634/// # Arguments
635///
636/// * `encoder`              — The command encoder to record the dispatch into.
637/// * `registry`             — Kernel registry (compiles the shader on first call).
638/// * `device`               — The Metal device for buffer allocation.
639/// * `input`                — bf16 input matrix buffer, shape `[M, K]`.
640/// * `packed_weights`       — Full 3D packed weight tensor for all experts.
641/// * `scales`               — Full scales buffer for all experts.
642/// * `biases`               — Full biases buffer for all experts.
643/// * `params`               — Dimensions for this expert's projection (M, K, N).
644/// * `expert_offset_bytes`  — Byte offset into `packed_weights` for this expert.
645/// * `scales_offset_bytes`  — Byte offset into `scales` for this expert.
646/// * `biases_offset_bytes`  — Byte offset into `biases` for this expert.
647///
648/// # Returns
649///
650/// A freshly allocated `MlxBuffer` for the output of shape `[M, N]` with dtype `BF16`.
651pub fn dispatch_quantized_matmul_simd_bf16_expert(
652    encoder: &mut CommandEncoder,
653    registry: &mut KernelRegistry,
654    device: &MlxDevice,
655    input: &MlxBuffer,
656    packed_weights: &MlxBuffer,
657    scales: &MlxBuffer,
658    biases: &MlxBuffer,
659    params: &QuantizedMatmulParams,
660    expert_offset_bytes: u32,
661    scales_offset_bytes: u32,
662    biases_offset_bytes: u32,
663) -> Result<MlxBuffer> {
664    // Expert-offset path requires bf16 SIMD (qmv_fast) alignment; no fallback
665    // because the scalar kernel doesn't understand 3D expert packing.
666    if !can_use_simd_kernel_bf16(params) {
667        return Err(MlxError::InvalidArgument(
668            "dispatch_quantized_matmul_simd_bf16_expert: dimensions do not satisfy bf16 SIMD \
669             alignment requirements (N%8==0 and K%512==0 for 4-bit, K%256==0 for 8-bit)".into(),
670        ));
671    }
672
673    if params.bits != 4 && params.bits != 8 {
674        return Err(MlxError::InvalidArgument(format!(
675            "bf16 expert kernel: unsupported bits value {}; only 4 and 8 are supported",
676            params.bits
677        )));
678    }
679    if params.m == 0 || params.k == 0 || params.n == 0 || params.group_size == 0 {
680        return Err(MlxError::InvalidArgument(
681            "M, K, N, and group_size must all be > 0".into(),
682        ));
683    }
684
685    // We trust the caller to have sized the 3D buffers correctly.  Validate
686    // that the requested slice (offset + one-expert size) fits.
687    let expert_weight_bytes = expected_weight_bytes(params.k, params.n, params.bits);
688    let expert_scales_bytes = expected_scales_bytes(params.k, params.n, params.group_size);
689
690    if packed_weights.byte_len() < (expert_offset_bytes as usize) + expert_weight_bytes {
691        return Err(MlxError::InvalidArgument(format!(
692            "packed_weights too small for expert slice: offset={} + size={} > buffer={}",
693            expert_offset_bytes, expert_weight_bytes, packed_weights.byte_len()
694        )));
695    }
696    if scales.byte_len() < (scales_offset_bytes as usize) + expert_scales_bytes {
697        return Err(MlxError::InvalidArgument(format!(
698            "scales buffer too small for expert slice: offset={} + size={} > buffer={}",
699            scales_offset_bytes, expert_scales_bytes, scales.byte_len()
700        )));
701    }
702    if biases.byte_len() < (biases_offset_bytes as usize) + expert_scales_bytes {
703        return Err(MlxError::InvalidArgument(format!(
704            "biases buffer too small for expert slice: offset={} + size={} > buffer={}",
705            biases_offset_bytes, expert_scales_bytes, biases.byte_len()
706        )));
707    }
708
709    let pipeline = registry.get_pipeline("quantized_matmul_simd_bf16_expert", device.metal_device())?;
710
711    let output_bytes = (params.m as usize) * (params.n as usize) * DType::BF16.size_of();
712    let output = device.alloc_buffer(
713        output_bytes,
714        DType::BF16,
715        vec![params.m as usize, params.n as usize],
716    )?;
717
718    let gpu_params = QMatmulBf16GpuParams {
719        m: params.m,
720        k: params.k,
721        n: params.n,
722        group_size: params.group_size,
723        bits: params.bits,
724    };
725    let params_bytes = std::mem::size_of::<QMatmulBf16GpuParams>();
726    let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
727    {
728        let slice: &mut [QMatmulBf16GpuParams] = bytemuck::cast_slice_mut(
729            params_buf
730                .as_mut_slice::<u8>()
731                .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
732        );
733        slice[0] = gpu_params;
734    }
735
736    // Pack the three byte-offset values into individual u32 buffers.
737    let mut expert_offset_buf = device.alloc_buffer(4, DType::U32, vec![1])?;
738    {
739        let s: &mut [u32] = expert_offset_buf
740            .as_mut_slice()
741            .map_err(|e| MlxError::InvalidArgument(format!("expert_offset buf: {e}")))?;
742        s[0] = expert_offset_bytes;
743    }
744    let mut scales_offset_buf = device.alloc_buffer(4, DType::U32, vec![1])?;
745    {
746        let s: &mut [u32] = scales_offset_buf
747            .as_mut_slice()
748            .map_err(|e| MlxError::InvalidArgument(format!("scales_offset buf: {e}")))?;
749        s[0] = scales_offset_bytes;
750    }
751    let mut biases_offset_buf = device.alloc_buffer(4, DType::U32, vec![1])?;
752    {
753        let s: &mut [u32] = biases_offset_buf
754            .as_mut_slice()
755            .map_err(|e| MlxError::InvalidArgument(format!("biases_offset buf: {e}")))?;
756        s[0] = biases_offset_bytes;
757    }
758
759    let num_simdgroups = 2u64;
760    let results_per_simdgroup = 4u64;
761    let bn = num_simdgroups * results_per_simdgroup;
762
763    let threadgroup_size = metal::MTLSize::new(32, num_simdgroups, 1);
764    let threadgroups = metal::MTLSize::new(
765        params.m as u64,
766        (params.n as u64 + bn - 1) / bn,
767        1,
768    );
769
770    encoder.encode_threadgroups(
771        pipeline,
772        &[
773            (0, input),
774            (1, packed_weights),
775            (2, scales),
776            (3, biases),
777            (4, &output),
778            (5, &params_buf),
779            (6, &expert_offset_buf),
780            (7, &scales_offset_buf),
781            (8, &biases_offset_buf),
782        ],
783        threadgroups,
784        threadgroup_size,
785    );
786
787    Ok(output)
788}
789
790#[cfg(test)]
791#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
792mod tests {
793    use super::*;
794    use crate::MlxDevice;
795
796    // ---- f16 / bf16 conversion helpers (no external dependency) ----
797
798    /// Convert an f32 to bfloat16 (bf16) bits.
799    /// bf16 is simply the top 16 bits of the IEEE 754 f32 representation.
800    fn f32_to_bf16_bits(val: f32) -> u16 {
801        (val.to_bits() >> 16) as u16
802    }
803
804    /// Convert an f32 to IEEE 754 half-precision (f16) bits.
805    /// Uses round-to-nearest-even.
806    fn f32_to_f16_bits(val: f32) -> u16 {
807        let bits = val.to_bits();
808        let sign = (bits >> 16) & 0x8000;
809        let exp = ((bits >> 23) & 0xFF) as i32;
810        let mantissa = bits & 0x007F_FFFF;
811
812        if exp == 255 {
813            // Inf or NaN
814            let m = if mantissa != 0 { 0x0200 } else { 0 };
815            return (sign | 0x7C00 | m) as u16;
816        }
817
818        // Rebias exponent from f32 (bias=127) to f16 (bias=15).
819        let new_exp = exp - 127 + 15;
820
821        if new_exp >= 31 {
822            // Overflow → Inf
823            return (sign | 0x7C00) as u16;
824        }
825
826        if new_exp <= 0 {
827            // Denormalized or zero
828            if new_exp < -10 {
829                return sign as u16; // Too small → zero
830            }
831            let m = (mantissa | 0x0080_0000) >> (1 - new_exp + 13);
832            return (sign | m) as u16;
833        }
834
835        // Normalized: round-to-nearest-even
836        let m = mantissa >> 13;
837        let round_bit = (mantissa >> 12) & 1;
838        let sticky = if (mantissa & 0xFFF) != 0 { 1u32 } else { 0 };
839        let round_up = round_bit & (sticky | m);
840        let result = sign | ((new_exp as u32) << 10) | m;
841        (result + round_up) as u16
842    }
843
844    /// Convert IEEE 754 half-precision (f16) bits to f32.
845    fn f16_bits_to_f32(bits: u16) -> f32 {
846        let sign = ((bits as u32 & 0x8000) as u32) << 16;
847        let exp = (bits >> 10) & 0x1F;
848        let mantissa = (bits & 0x03FF) as u32;
849
850        if exp == 0 {
851            if mantissa == 0 {
852                return f32::from_bits(sign); // +/- zero
853            }
854            // Denormalized: normalize it.
855            let mut m = mantissa;
856            let mut e: i32 = -14;
857            while (m & 0x0400) == 0 {
858                m <<= 1;
859                e -= 1;
860            }
861            m &= 0x03FF;
862            let f32_exp = ((e + 127) as u32) << 23;
863            let f32_mantissa = m << 13;
864            return f32::from_bits(sign | f32_exp | f32_mantissa);
865        }
866
867        if exp == 31 {
868            let m = if mantissa != 0 { 0x007F_FFFF } else { 0 };
869            return f32::from_bits(sign | 0x7F80_0000 | m);
870        }
871
872        let f32_exp = ((exp as u32 - 15 + 127) as u32) << 23;
873        let f32_mantissa = mantissa << 13;
874        f32::from_bits(sign | f32_exp | f32_mantissa)
875    }
876
877    // Helper: create an f16 buffer from f32 values.
878    #[allow(dead_code)]
879    fn f16_buffer(device: &MlxDevice, shape: Vec<usize>, values: &[f32]) -> MlxBuffer {
880        let byte_len = values.len() * 2;
881        let mut buf = device.alloc_buffer(byte_len, DType::F16, shape).expect("alloc");
882        {
883            let slice: &mut [u16] = buf.as_mut_slice().expect("as_mut_slice");
884            for (i, &v) in values.iter().enumerate() {
885                slice[i] = f32_to_f16_bits(v);
886            }
887        }
888        buf
889    }
890
891    // Helper: create a bf16 buffer from f32 values (used for scales/biases).
892    fn bf16_buffer(device: &MlxDevice, shape: Vec<usize>, values: &[f32]) -> MlxBuffer {
893        let byte_len = values.len() * 2;
894        let mut buf = device.alloc_buffer(byte_len, DType::BF16, shape).expect("alloc");
895        {
896            let slice: &mut [u16] = buf.as_mut_slice().expect("as_mut_slice");
897            for (i, &v) in values.iter().enumerate() {
898                slice[i] = f32_to_bf16_bits(v);
899            }
900        }
901        buf
902    }
903
904    // Helper: create an f32 buffer from f32 values (used for input).
905    fn f32_buffer(device: &MlxDevice, shape: Vec<usize>, values: &[f32]) -> MlxBuffer {
906        let byte_len = values.len() * 4;
907        let mut buf = device.alloc_buffer(byte_len, DType::F32, shape).expect("alloc");
908        {
909            let slice: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
910            slice.copy_from_slice(values);
911        }
912        buf
913    }
914
915    // Helper: pack 4-bit values into uint32 buffer.
916    // `quant_values` is a flat array of quantized unsigned values (0..15),
917    // laid out as weight[col][k] (N rows of K values each).
918    fn pack_4bit_buffer(device: &MlxDevice, n: usize, k: usize, quant_values: &[u8]) -> MlxBuffer {
919        let values_per_pack = 8;
920        let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
921        let total_packs = n * packs_per_row;
922        let byte_len = total_packs * 4;
923
924        let mut buf = device.alloc_buffer(byte_len, DType::U32, vec![n, packs_per_row]).expect("alloc");
925        {
926            let slice: &mut [u32] = buf.as_mut_slice().expect("as_mut_slice");
927            for col in 0..n {
928                for pack in 0..packs_per_row {
929                    let mut packed: u32 = 0;
930                    for i in 0..values_per_pack {
931                        let k_idx = pack * values_per_pack + i;
932                        if k_idx < k {
933                            let val = quant_values[col * k + k_idx] as u32 & 0xF;
934                            packed |= val << (4 * i);
935                        }
936                    }
937                    slice[col * packs_per_row + pack] = packed;
938                }
939            }
940        }
941        buf
942    }
943
944    // Helper: pack 6-bit values into uint32 buffer.
945    fn pack_6bit_buffer(device: &MlxDevice, n: usize, k: usize, quant_values: &[u8]) -> MlxBuffer {
946        // 6-bit: 4 values per 3-byte triplet (24 bits). The Metal shader reads
947        // raw bytes as 3-byte triplets, NOT uint32, so we must match that layout.
948        let triplets_per_row = (k + 3) / 4;
949        let row_bytes = triplets_per_row * 3;
950        let total_bytes = n * row_bytes;
951
952        let mut buf = device.alloc_buffer(total_bytes, DType::U8, vec![total_bytes]).expect("alloc");
953        {
954            let slice: &mut [u8] = buf.as_mut_slice().expect("as_mut_slice");
955            for col in 0..n {
956                for t in 0..triplets_per_row {
957                    let mut packed: u32 = 0;
958                    for i in 0..4 {
959                        let k_idx = t * 4 + i;
960                        if k_idx < k {
961                            let val = quant_values[col * k + k_idx] as u32 & 0x3F;
962                            packed |= val << (6 * i);
963                        }
964                    }
965                    let base = col * row_bytes + t * 3;
966                    slice[base] = (packed & 0xFF) as u8;
967                    slice[base + 1] = ((packed >> 8) & 0xFF) as u8;
968                    slice[base + 2] = ((packed >> 16) & 0xFF) as u8;
969                }
970            }
971        }
972        buf
973    }
974
975    // Helper: pack 8-bit values into uint32 buffer.
976    // 4 values per uint32 (8 bits each).
977    fn pack_8bit_buffer(device: &MlxDevice, n: usize, k: usize, quant_values: &[u8]) -> MlxBuffer {
978        let values_per_pack = 4;
979        let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
980        let total_packs = n * packs_per_row;
981        let byte_len = total_packs * 4;
982
983        let mut buf = device.alloc_buffer(byte_len, DType::U32, vec![n, packs_per_row]).expect("alloc");
984        {
985            let slice: &mut [u32] = buf.as_mut_slice().expect("as_mut_slice");
986            for col in 0..n {
987                for pack in 0..packs_per_row {
988                    let mut packed: u32 = 0;
989                    for i in 0..values_per_pack {
990                        let k_idx = pack * values_per_pack + i;
991                        if k_idx < k {
992                            let val = quant_values[col * k + k_idx] as u32 & 0xFF;
993                            packed |= val << (8 * i);
994                        }
995                    }
996                    slice[col * packs_per_row + pack] = packed;
997                }
998            }
999        }
1000        buf
1001    }
1002
1003    // Helper: read f16 buffer back as f32.
1004    #[allow(dead_code)]
1005    fn read_f16(buf: &MlxBuffer) -> Vec<f32> {
1006        let slice: &[u16] = buf.as_slice().expect("as_slice");
1007        slice.iter().map(|&bits| f16_bits_to_f32(bits)).collect()
1008    }
1009
1010    // Helper: read f32 output buffer.
1011    fn read_f32(buf: &MlxBuffer) -> Vec<f32> {
1012        let slice: &[f32] = buf.as_slice().expect("as_slice");
1013        slice.to_vec()
1014    }
1015
1016    /// Test 4-bit quantized matmul with a small known example.
1017    ///
1018    /// input = [[1.0, 2.0, 3.0, 4.0]]  (M=1, K=4)
1019    /// weight quantized values (N=2, K=4): [[1, 2, 3, 4], [5, 6, 7, 8]]
1020    /// scales = [[0.1], [0.2]]  (1 group per row since group_size=64 > K=4)
1021    /// biases = [[0.0], [0.0]]
1022    ///
1023    /// dequant weight row 0: [0.1, 0.2, 0.3, 0.4]
1024    /// dequant weight row 1: [1.0, 1.2, 1.4, 1.6]
1025    ///
1026    /// output[0][0] = 1.0*0.1 + 2.0*0.2 + 3.0*0.3 + 4.0*0.4 = 0.1+0.4+0.9+1.6 = 3.0
1027    /// output[0][1] = 1.0*1.0 + 2.0*1.2 + 3.0*1.4 + 4.0*1.6 = 1.0+2.4+4.2+6.4 = 14.0
1028    #[test]
1029    fn test_4bit_matmul_small_known() {
1030        let device = MlxDevice::new().expect("device");
1031        let mut registry = KernelRegistry::new();
1032        let mut encoder = device.command_encoder().expect("encoder");
1033
1034        let m = 1u32;
1035        let k = 4u32;
1036        let n = 2u32;
1037        let group_size = 64u32;
1038        let bits = 4u32;
1039
1040        let input = f32_buffer(&device, vec![m as usize, k as usize], &[1.0, 2.0, 3.0, 4.0]);
1041
1042        // Quantized weight values (unsigned): row0=[1,2,3,4], row1=[5,6,7,8]
1043        let quant_w: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
1044        let weight = pack_4bit_buffer(&device, n as usize, k as usize, &quant_w);
1045
1046        let scales = bf16_buffer(&device, vec![n as usize, 1], &[0.1, 0.2]);
1047        let biases = bf16_buffer(&device, vec![n as usize, 1], &[0.0, 0.0]);
1048
1049        let params = QuantizedMatmulParams { m, k, n, group_size, bits };
1050
1051        let output = quantized_matmul(
1052            &mut encoder, &mut registry, &device,
1053            &input, &weight, &scales, &biases, &params,
1054        ).expect("quantized_matmul");
1055
1056        encoder.commit_and_wait().expect("commit");
1057
1058        let result = read_f32(&output);
1059        assert_eq!(result.len(), 2);
1060
1061        // Tolerance: bf16 precision is ~1e-2 for these magnitudes.
1062        let tol = 1e-1; // bf16 has limited precision, be generous for this test
1063        assert!(
1064            (result[0] - 3.0).abs() < tol,
1065            "output[0]={}, expected ~3.0", result[0]
1066        );
1067        assert!(
1068            (result[1] - 14.0).abs() < tol,
1069            "output[1]={}, expected ~14.0", result[1]
1070        );
1071    }
1072
1073    /// Test 6-bit quantized matmul with a small known example.
1074    #[test]
1075    fn test_6bit_matmul_small_known() {
1076        let device = MlxDevice::new().expect("device");
1077        let mut registry = KernelRegistry::new();
1078        let mut encoder = device.command_encoder().expect("encoder");
1079
1080        let m = 1u32;
1081        let k = 4u32;
1082        let n = 2u32;
1083        let group_size = 64u32;
1084        let bits = 6u32;
1085
1086        let input = f32_buffer(&device, vec![m as usize, k as usize], &[1.0, 2.0, 3.0, 4.0]);
1087
1088        // 6-bit quantized weight values (0..63): row0=[1,2,3,4], row1=[10,20,30,40]
1089        let quant_w: Vec<u8> = vec![1, 2, 3, 4, 10, 20, 30, 40];
1090        let weight = pack_6bit_buffer(&device, n as usize, k as usize, &quant_w);
1091
1092        let scales = bf16_buffer(&device, vec![n as usize, 1], &[0.1, 0.05]);
1093        let biases = bf16_buffer(&device, vec![n as usize, 1], &[0.0, 0.0]);
1094
1095        let params = QuantizedMatmulParams { m, k, n, group_size, bits };
1096
1097        let output = quantized_matmul(
1098            &mut encoder, &mut registry, &device,
1099            &input, &weight, &scales, &biases, &params,
1100        ).expect("quantized_matmul");
1101
1102        encoder.commit_and_wait().expect("commit");
1103
1104        let result = read_f32(&output);
1105        assert_eq!(result.len(), 2);
1106
1107        // dequant row 0: [0.1, 0.2, 0.3, 0.4]
1108        // output[0] = 1*0.1 + 2*0.2 + 3*0.3 + 4*0.4 = 3.0
1109        // dequant row 1: [0.5, 1.0, 1.5, 2.0]
1110        // output[1] = 1*0.5 + 2*1.0 + 3*1.5 + 4*2.0 = 0.5+2.0+4.5+8.0 = 15.0
1111        let tol = 1e-1;
1112        assert!(
1113            (result[0] - 3.0).abs() < tol,
1114            "output[0]={}, expected ~3.0", result[0]
1115        );
1116        assert!(
1117            (result[1] - 15.0).abs() < tol,
1118            "output[1]={}, expected ~15.0", result[1]
1119        );
1120    }
1121
1122    /// Test with non-zero biases.
1123    #[test]
1124    fn test_4bit_matmul_with_bias() {
1125        let device = MlxDevice::new().expect("device");
1126        let mut registry = KernelRegistry::new();
1127        let mut encoder = device.command_encoder().expect("encoder");
1128
1129        let m = 1u32;
1130        let k = 4u32;
1131        let n = 1u32;
1132        let group_size = 64u32;
1133        let bits = 4u32;
1134
1135        let input = f32_buffer(&device, vec![1, 4], &[1.0, 1.0, 1.0, 1.0]);
1136
1137        // quant values all 0 → dequant = scale*0 + bias = bias
1138        let quant_w: Vec<u8> = vec![0, 0, 0, 0];
1139        let weight = pack_4bit_buffer(&device, 1, 4, &quant_w);
1140
1141        let scales = bf16_buffer(&device, vec![1, 1], &[1.0]);
1142        let biases = bf16_buffer(&device, vec![1, 1], &[0.5]);
1143
1144        let params = QuantizedMatmulParams { m, k, n, group_size, bits };
1145
1146        let output = quantized_matmul(
1147            &mut encoder, &mut registry, &device,
1148            &input, &weight, &scales, &biases, &params,
1149        ).expect("quantized_matmul");
1150
1151        encoder.commit_and_wait().expect("commit");
1152
1153        let result = read_f32(&output);
1154        // Each weight dequantized to 0.5, dot with [1,1,1,1] = 2.0
1155        let tol = 1e-2;
1156        assert!(
1157            (result[0] - 2.0).abs() < tol,
1158            "output[0]={}, expected ~2.0", result[0]
1159        );
1160    }
1161
1162    /// Test batch (M > 1).
1163    #[test]
1164    fn test_4bit_batch_matmul() {
1165        let device = MlxDevice::new().expect("device");
1166        let mut registry = KernelRegistry::new();
1167        let mut encoder = device.command_encoder().expect("encoder");
1168
1169        let m = 2u32;
1170        let k = 4u32;
1171        let n = 1u32;
1172        let group_size = 64u32;
1173        let bits = 4u32;
1174
1175        // Two input rows: [1,0,0,0] and [0,1,0,0]
1176        let input = f32_buffer(&device, vec![2, 4], &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
1177
1178        // quant weight row: [2, 4, 6, 8]
1179        let quant_w: Vec<u8> = vec![2, 4, 6, 8];
1180        let weight = pack_4bit_buffer(&device, 1, 4, &quant_w);
1181
1182        let scales = bf16_buffer(&device, vec![1, 1], &[0.5]);
1183        let biases = bf16_buffer(&device, vec![1, 1], &[0.0]);
1184
1185        let params = QuantizedMatmulParams { m, k, n, group_size, bits };
1186
1187        let output = quantized_matmul(
1188            &mut encoder, &mut registry, &device,
1189            &input, &weight, &scales, &biases, &params,
1190        ).expect("quantized_matmul");
1191
1192        encoder.commit_and_wait().expect("commit");
1193
1194        let result = read_f32(&output);
1195        assert_eq!(result.len(), 2);
1196
1197        // dequant: [1.0, 2.0, 3.0, 4.0]
1198        // row0: 1.0*1.0 = 1.0
1199        // row1: 1.0*2.0 = 2.0
1200        let tol = 1e-2;
1201        assert!((result[0] - 1.0).abs() < tol, "row0={}, expected 1.0", result[0]);
1202        assert!((result[1] - 2.0).abs() < tol, "row1={}, expected 2.0", result[1]);
1203    }
1204
1205    /// Test invalid bits returns error.
1206    #[test]
1207    fn test_invalid_bits_returns_error() {
1208        let device = MlxDevice::new().expect("device");
1209        let mut registry = KernelRegistry::new();
1210        let mut encoder = device.command_encoder().expect("encoder");
1211
1212        let input = f32_buffer(&device, vec![1, 4], &[1.0; 4]);
1213        // Minimal buffers — validation should fail before size checks matter.
1214        let weight = device.alloc_buffer(4, DType::U32, vec![1]).expect("alloc");
1215        let scales = bf16_buffer(&device, vec![1], &[1.0]);
1216        let biases = bf16_buffer(&device, vec![1], &[0.0]);
1217
1218        let params = QuantizedMatmulParams {
1219            m: 1, k: 4, n: 1, group_size: 64, bits: 5,
1220        };
1221
1222        let result = quantized_matmul(
1223            &mut encoder, &mut registry, &device,
1224            &input, &weight, &scales, &biases, &params,
1225        );
1226
1227        assert!(result.is_err());
1228        match result {
1229            Err(MlxError::InvalidArgument(msg)) => {
1230                assert!(msg.contains("bits"), "Error should mention bits: {msg}");
1231            }
1232            other => panic!("Expected InvalidArgument, got {:?}", other),
1233        }
1234    }
1235
1236    /// Test mismatched dimensions returns error.
1237    #[test]
1238    fn test_mismatched_dimensions_returns_error() {
1239        let device = MlxDevice::new().expect("device");
1240        let mut registry = KernelRegistry::new();
1241        let mut encoder = device.command_encoder().expect("encoder");
1242
1243        // Input is 1x4 but we'll claim K=128 in params.
1244        let input = f32_buffer(&device, vec![1, 4], &[1.0; 4]);
1245        let weight = device.alloc_buffer(4, DType::U32, vec![1]).expect("alloc");
1246        let scales = bf16_buffer(&device, vec![1], &[1.0]);
1247        let biases = bf16_buffer(&device, vec![1], &[0.0]);
1248
1249        let params = QuantizedMatmulParams {
1250            m: 1, k: 128, n: 1, group_size: 64, bits: 4,
1251        };
1252
1253        let result = quantized_matmul(
1254            &mut encoder, &mut registry, &device,
1255            &input, &weight, &scales, &biases, &params,
1256        );
1257
1258        assert!(result.is_err());
1259        match result {
1260            Err(MlxError::InvalidArgument(msg)) => {
1261                assert!(msg.contains("Input buffer too small"), "msg: {msg}");
1262            }
1263            other => panic!("Expected InvalidArgument for input size, got {:?}", other),
1264        }
1265    }
1266
1267    /// Test 8-bit quantized matmul with a small known example.
1268    ///
1269    /// input = [[1.0, 2.0, 3.0, 4.0]]  (M=1, K=4)
1270    /// weight quantized values (N=2, K=4): [[10, 20, 30, 40], [50, 60, 70, 80]]
1271    /// scales = [[0.01], [0.02]]  (1 group per row since group_size=64 > K=4)
1272    /// biases = [[0.0], [0.0]]
1273    ///
1274    /// dequant weight row 0: [0.1, 0.2, 0.3, 0.4]
1275    /// dequant weight row 1: [1.0, 1.2, 1.4, 1.6]
1276    ///
1277    /// output[0][0] = 1.0*0.1 + 2.0*0.2 + 3.0*0.3 + 4.0*0.4 = 3.0
1278    /// output[0][1] = 1.0*1.0 + 2.0*1.2 + 3.0*1.4 + 4.0*1.6 = 14.0
1279    #[test]
1280    fn test_8bit_matmul_small_known() {
1281        let device = MlxDevice::new().expect("device");
1282        let mut registry = KernelRegistry::new();
1283        let mut encoder = device.command_encoder().expect("encoder");
1284
1285        let m = 1u32;
1286        let k = 4u32;
1287        let n = 2u32;
1288        let group_size = 64u32;
1289        let bits = 8u32;
1290
1291        let input = f32_buffer(&device, vec![m as usize, k as usize], &[1.0, 2.0, 3.0, 4.0]);
1292
1293        // 8-bit quantized weight values (0..255): row0=[10,20,30,40], row1=[50,60,70,80]
1294        let quant_w: Vec<u8> = vec![10, 20, 30, 40, 50, 60, 70, 80];
1295        let weight = pack_8bit_buffer(&device, n as usize, k as usize, &quant_w);
1296
1297        let scales = bf16_buffer(&device, vec![n as usize, 1], &[0.01, 0.02]);
1298        let biases = bf16_buffer(&device, vec![n as usize, 1], &[0.0, 0.0]);
1299
1300        let params = QuantizedMatmulParams { m, k, n, group_size, bits };
1301
1302        let output = quantized_matmul(
1303            &mut encoder, &mut registry, &device,
1304            &input, &weight, &scales, &biases, &params,
1305        ).expect("quantized_matmul");
1306
1307        encoder.commit_and_wait().expect("commit");
1308
1309        let result = read_f32(&output);
1310        assert_eq!(result.len(), 2);
1311
1312        // Tolerance: bf16 precision is ~1e-2 for these magnitudes.
1313        let tol = 1e-1;
1314        assert!(
1315            (result[0] - 3.0).abs() < tol,
1316            "output[0]={}, expected ~3.0", result[0]
1317        );
1318        assert!(
1319            (result[1] - 14.0).abs() < tol,
1320            "output[1]={}, expected ~14.0", result[1]
1321        );
1322    }
1323
1324    /// Test 8-bit with non-zero biases.
1325    #[test]
1326    fn test_8bit_matmul_with_bias() {
1327        let device = MlxDevice::new().expect("device");
1328        let mut registry = KernelRegistry::new();
1329        let mut encoder = device.command_encoder().expect("encoder");
1330
1331        let m = 1u32;
1332        let k = 4u32;
1333        let n = 1u32;
1334        let group_size = 64u32;
1335        let bits = 8u32;
1336
1337        let input = f32_buffer(&device, vec![1, 4], &[1.0, 1.0, 1.0, 1.0]);
1338
1339        // quant values all 0 -> dequant = scale*0 + bias = bias
1340        let quant_w: Vec<u8> = vec![0, 0, 0, 0];
1341        let weight = pack_8bit_buffer(&device, 1, 4, &quant_w);
1342
1343        let scales = bf16_buffer(&device, vec![1, 1], &[1.0]);
1344        let biases = bf16_buffer(&device, vec![1, 1], &[0.5]);
1345
1346        let params = QuantizedMatmulParams { m, k, n, group_size, bits };
1347
1348        let output = quantized_matmul(
1349            &mut encoder, &mut registry, &device,
1350            &input, &weight, &scales, &biases, &params,
1351        ).expect("quantized_matmul");
1352
1353        encoder.commit_and_wait().expect("commit");
1354
1355        let result = read_f32(&output);
1356        // Each weight dequantized to 0.5, dot with [1,1,1,1] = 2.0
1357        let tol = 1e-2;
1358        assert!(
1359            (result[0] - 2.0).abs() < tol,
1360            "output[0]={}, expected ~2.0", result[0]
1361        );
1362    }
1363
1364    /// Test with multiple groups along K (K > group_size).
1365    #[test]
1366    fn test_4bit_multiple_groups() {
1367        let device = MlxDevice::new().expect("device");
1368        let mut registry = KernelRegistry::new();
1369        let mut encoder = device.command_encoder().expect("encoder");
1370
1371        // Use K=8 with group_size=4 → 2 groups per column.
1372        let m = 1u32;
1373        let k = 8u32;
1374        let n = 1u32;
1375        let group_size = 4u32;
1376        let bits = 4u32;
1377
1378        let input = f32_buffer(&device, vec![1, 8], &[1.0; 8]);
1379
1380        // quant values: [1,1,1,1, 2,2,2,2]
1381        let quant_w: Vec<u8> = vec![1, 1, 1, 1, 2, 2, 2, 2];
1382        let weight = pack_4bit_buffer(&device, 1, 8, &quant_w);
1383
1384        // 2 groups: scale=[0.5, 1.0], bias=[0.0, 0.0]
1385        let scales = bf16_buffer(&device, vec![1, 2], &[0.5, 1.0]);
1386        let biases = bf16_buffer(&device, vec![1, 2], &[0.0, 0.0]);
1387
1388        let params = QuantizedMatmulParams { m, k, n, group_size, bits };
1389
1390        let output = quantized_matmul(
1391            &mut encoder, &mut registry, &device,
1392            &input, &weight, &scales, &biases, &params,
1393        ).expect("quantized_matmul");
1394
1395        encoder.commit_and_wait().expect("commit");
1396
1397        let result = read_f32(&output);
1398        // Group 0: dequant=[0.5,0.5,0.5,0.5], sum = 4*0.5 = 2.0
1399        // Group 1: dequant=[2.0,2.0,2.0,2.0], sum = 4*2.0 = 8.0
1400        // Total = 10.0
1401        let tol = 1e-1;
1402        assert!(
1403            (result[0] - 10.0).abs() < tol,
1404            "output[0]={}, expected ~10.0", result[0]
1405        );
1406    }
1407}