1use crate::buffer::MlxBuffer;
10use crate::device::MlxDevice;
11use crate::dtypes::DType;
12use crate::encoder::CommandEncoder;
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
/// Host-side dimensions and quantization settings for one quantized matmul
/// dispatch: `output[M x N] = input[M x K] x dequant(weight)`.
#[derive(Debug, Clone, Copy)]
pub struct QuantizedMatmulParams {
    /// Rows of the input / output (batch dimension).
    pub m: u32,
    /// Shared inner dimension (input columns == dequantized weight rows).
    pub k: u32,
    /// Output columns (weight output channels).
    pub n: u32,
    /// Number of consecutive K values that share one scale/bias pair.
    pub group_size: u32,
    /// Quantization width in bits; this module supports 4, 6, and 8.
    pub bits: u32,
}
30
/// GPU-visible mirror of [`QuantizedMatmulParams`], copied into a small U32
/// buffer and bound at buffer index 5 of the kernels below.
/// `#[repr(C)]` keeps the field order/layout matching the Metal-side struct;
/// `Pod`/`Zeroable` allow the raw byte write via `bytemuck::cast_slice_mut`.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QuantizedMatmulGpuParams {
    m: u32,
    k: u32,
    n: u32,
    group_size: u32,
    bits: u32,
}
43
/// Minimum byte size of a packed weight buffer for an `[N x K]` weight matrix
/// at the given quantization width.
///
/// Packing layouts (per output column, K values each):
/// - 4-bit: 8 values per 4-byte pack,
/// - 6-bit: 4 values per 3-byte pack,
/// - 8-bit: 4 values per 4-byte pack.
///
/// Returns 0 for unsupported widths; callers validate `bits` before relying
/// on this value.
fn expected_weight_bytes(k: u32, n: u32, bits: u32) -> usize {
    // (values packed together, bytes occupied by one pack) per supported width.
    let (values_per_pack, bytes_per_pack) = match bits {
        4 => (8u32, 4usize),
        6 => (4u32, 3usize),
        8 => (4u32, 4usize),
        _ => return 0,
    };
    // Ceil-divide so a partial trailing pack is still counted.
    let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
    (n as usize) * (packs_per_row as usize) * bytes_per_pack
}
72
/// Minimum byte size of the scales (or biases) buffer: one 2-byte value per
/// quantization group per output column, with `ceil(K / group_size)` groups
/// per column. Callers must ensure `group_size > 0` before calling.
fn expected_scales_bytes(k: u32, n: u32, group_size: u32) -> usize {
    let groups_per_row = (k + group_size - 1) / group_size;
    2 * (n as usize) * (groups_per_row as usize)
}
81
/// Encodes the generic `quantized_matmul` Metal kernel:
/// `output[M x N] (f32) = input[M x K] (f32) x dequant(weight)[K x N]`.
///
/// `weight` is packed per output column (see [`expected_weight_bytes`] for the
/// 4/6/8-bit layouts); `scales` and `biases` hold one 2-byte value per
/// `group_size` chunk of K per column (see [`expected_scales_bytes`]).
///
/// # Errors
/// `MlxError::InvalidArgument` when `bits` is not 4/6/8, any of M/K/N or
/// `group_size` is zero, or any buffer is smaller than the sizes implied by
/// `params`.
pub fn quantized_matmul(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    params: &QuantizedMatmulParams,
) -> Result<MlxBuffer> {
    // Reject unsupported widths first so callers get a specific message.
    if params.bits != 4 && params.bits != 6 && params.bits != 8 {
        return Err(MlxError::InvalidArgument(format!(
            "Unsupported bits value {}; only 4, 6, and 8 are supported",
            params.bits
        )));
    }

    if params.m == 0 || params.k == 0 || params.n == 0 {
        return Err(MlxError::InvalidArgument(
            "M, K, and N must all be > 0".into(),
        ));
    }
    // Checked before expected_scales_bytes below divides by it.
    if params.group_size == 0 {
        return Err(MlxError::InvalidArgument(
            "group_size must be > 0".into(),
        ));
    }

    // All size checks are `<` rather than `!=`: oversized buffers are allowed.
    let expected_input = (params.m as usize) * (params.k as usize) * DType::F32.size_of();
    if input.byte_len() < expected_input {
        return Err(MlxError::InvalidArgument(format!(
            "Input buffer too small: expected at least {} bytes for [{}x{}] f32, got {}",
            expected_input, params.m, params.k, input.byte_len()
        )));
    }

    let expected_w = expected_weight_bytes(params.k, params.n, params.bits);
    if weight.byte_len() < expected_w {
        return Err(MlxError::InvalidArgument(format!(
            "Weight buffer too small: expected at least {} bytes for {}bit [{}x{}], got {}",
            expected_w, params.bits, params.n, params.k, weight.byte_len()
        )));
    }

    let expected_s = expected_scales_bytes(params.k, params.n, params.group_size);
    if scales.byte_len() < expected_s {
        return Err(MlxError::InvalidArgument(format!(
            "Scales buffer too small: expected at least {} bytes, got {}",
            expected_s, scales.byte_len()
        )));
    }
    if biases.byte_len() < expected_s {
        return Err(MlxError::InvalidArgument(format!(
            "Biases buffer too small: expected at least {} bytes, got {}",
            expected_s, biases.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("quantized_matmul", device.metal_device())?;

    // f32 output of shape [M, N].
    let output_bytes = (params.m as usize) * (params.n as usize) * DType::F32.size_of();
    let output = device.alloc_buffer(
        output_bytes,
        DType::F32,
        vec![params.m as usize, params.n as usize],
    )?;

    // Mirror the dimensions into a small GPU-visible parameter buffer
    // (five u32 fields, bound at buffer index 5 below).
    let gpu_params = QuantizedMatmulGpuParams {
        m: params.m,
        k: params.k,
        n: params.n,
        group_size: params.group_size,
        bits: params.bits,
    };
    let params_bytes = std::mem::size_of::<QuantizedMatmulGpuParams>();
    let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
    {
        let slice: &mut [QuantizedMatmulGpuParams] = bytemuck::cast_slice_mut(
            params_buf
                .as_mut_slice::<u8>()
                .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
        );
        slice[0] = gpu_params;
    }

    // One thread per output element, in up to 16x16 threadgroups
    // (clamped so tiny problems don't request oversized threadgroups).
    let tg_x = 16u64.min(params.n as u64);
    let tg_y = 16u64.min(params.m as u64);
    let threadgroup_size = metal::MTLSize::new(tg_x, tg_y, 1);

    // Ceil-divide the output grid: x spans N (columns), y spans M (rows).
    let grid_groups = metal::MTLSize::new(
        (params.n as u64 + tg_x - 1) / tg_x,
        (params.m as u64 + tg_y - 1) / tg_y,
        1,
    );

    encoder.encode_threadgroups(
        pipeline,
        &[
            (0, input),
            (1, weight),
            (2, scales),
            (3, biases),
            (4, &output),
            (5, &params_buf),
        ],
        grid_groups,
        threadgroup_size,
    );

    Ok(output)
}
233
234fn can_use_simd_kernel(params: &QuantizedMatmulParams) -> bool {
248 let bn = 8u32; if params.n % bn != 0 {
250 return false;
251 }
252 match params.bits {
253 4 => params.k % 256 == 0, 8 => params.k % 256 == 0,
255 _ => false,
256 }
257}
258
259fn can_use_simd_kernel_bf16(params: &QuantizedMatmulParams) -> bool {
261 let bn = 8u32;
262 if params.n % bn != 0 {
263 return false;
264 }
265 match params.bits {
266 4 => params.k % 512 == 0, 8 => params.k % 256 == 0,
268 _ => false,
269 }
270}
271
/// SIMD-group variant of [`quantized_matmul`]: each threadgroup handles one
/// input row with 2 simdgroups of 32 lanes, each simdgroup accumulating 4
/// output columns (8 columns per threadgroup). When the alignment
/// preconditions of [`can_use_simd_kernel`] are not met, this transparently
/// falls back to the generic kernel, so callers can use it unconditionally.
///
/// # Errors
/// Same validation errors as [`quantized_matmul`]. The explicit
/// unsupported-bits arm below is defensive: `can_use_simd_kernel` only
/// accepts 4 and 8, so 6-bit (and any other width) already took the fallback.
pub fn quantized_matmul_simd(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    params: &QuantizedMatmulParams,
) -> Result<MlxBuffer> {
    // Fall back whenever the SIMD alignment preconditions don't hold.
    if !can_use_simd_kernel(params) {
        return quantized_matmul(encoder, registry, device, input, weight, scales, biases, params);
    }

    // Defensive: unreachable today (can_use_simd_kernel rejects 6-bit), kept
    // in case the predicate is ever widened.
    if params.bits == 6 {
        return quantized_matmul(encoder, registry, device, input, weight, scales, biases, params);
    }
    if params.bits != 4 && params.bits != 8 {
        return Err(MlxError::InvalidArgument(format!(
            "SIMD kernel: unsupported bits value {}; only 4, 6, and 8 are supported",
            params.bits
        )));
    }

    if params.m == 0 || params.k == 0 || params.n == 0 {
        return Err(MlxError::InvalidArgument(
            "M, K, and N must all be > 0".into(),
        ));
    }
    // Checked before expected_scales_bytes below divides by it.
    if params.group_size == 0 {
        return Err(MlxError::InvalidArgument(
            "group_size must be > 0".into(),
        ));
    }

    // Buffer size validation mirrors quantized_matmul (`<`, not `!=`).
    let expected_input = (params.m as usize) * (params.k as usize) * DType::F32.size_of();
    if input.byte_len() < expected_input {
        return Err(MlxError::InvalidArgument(format!(
            "Input buffer too small: expected at least {} bytes for [{}x{}] f32, got {}",
            expected_input, params.m, params.k, input.byte_len()
        )));
    }

    let expected_w = expected_weight_bytes(params.k, params.n, params.bits);
    if weight.byte_len() < expected_w {
        return Err(MlxError::InvalidArgument(format!(
            "Weight buffer too small: expected at least {} bytes for {}bit [{}x{}], got {}",
            expected_w, params.bits, params.n, params.k, weight.byte_len()
        )));
    }

    let expected_s = expected_scales_bytes(params.k, params.n, params.group_size);
    if scales.byte_len() < expected_s {
        return Err(MlxError::InvalidArgument(format!(
            "Scales buffer too small: expected at least {} bytes, got {}",
            expected_s, scales.byte_len()
        )));
    }
    if biases.byte_len() < expected_s {
        return Err(MlxError::InvalidArgument(format!(
            "Biases buffer too small: expected at least {} bytes, got {}",
            expected_s, biases.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("quantized_matmul_simd", device.metal_device())?;

    // f32 output of shape [M, N].
    let output_bytes = (params.m as usize) * (params.n as usize) * DType::F32.size_of();
    let output = device.alloc_buffer(
        output_bytes,
        DType::F32,
        vec![params.m as usize, params.n as usize],
    )?;

    // Parameter block bound at buffer index 5.
    let gpu_params = QuantizedMatmulGpuParams {
        m: params.m,
        k: params.k,
        n: params.n,
        group_size: params.group_size,
        bits: params.bits,
    };
    let params_bytes = std::mem::size_of::<QuantizedMatmulGpuParams>();
    let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
    {
        let slice: &mut [QuantizedMatmulGpuParams] = bytemuck::cast_slice_mut(
            params_buf
                .as_mut_slice::<u8>()
                .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
        );
        slice[0] = gpu_params;
    }

    // 32 lanes x 2 simdgroups per threadgroup; each simdgroup produces 4
    // output columns, so one threadgroup covers bn = 8 columns.
    let num_simdgroups = 2u64;
    let results_per_simdgroup = 4u64;
    let bn = num_simdgroups * results_per_simdgroup;
    let threadgroup_size = metal::MTLSize::new(32, num_simdgroups, 1);
    // Grid: one threadgroup per row (x), ceil(N / bn) column blocks (y).
    let threadgroups = metal::MTLSize::new(
        params.m as u64,
        (params.n as u64 + bn - 1) / bn,
        1,
    );

    encoder.encode_threadgroups(
        pipeline,
        &[
            (0, input),
            (1, weight),
            (2, scales),
            (3, biases),
            (4, &output),
            (5, &params_buf),
        ],
        threadgroups,
        threadgroup_size,
    );

    Ok(output)
}
424
/// GPU-side parameter block for the bf16 SIMD kernels, bound at buffer
/// index 5. Field-for-field identical to `QuantizedMatmulGpuParams`; kept as
/// a separate type so the bf16 kernels' ABI can diverge independently.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QMatmulBf16GpuParams {
    m: u32,
    k: u32,
    n: u32,
    group_size: u32,
    bits: u32,
}
441
442pub fn dispatch_quantized_matmul_simd_bf16(
465 encoder: &mut CommandEncoder,
466 registry: &mut KernelRegistry,
467 device: &MlxDevice,
468 input: &MlxBuffer,
469 packed_weights: &MlxBuffer,
470 scales: &MlxBuffer,
471 biases: &MlxBuffer,
472 params: &QuantizedMatmulParams,
473) -> Result<MlxBuffer> {
474 if !can_use_simd_kernel_bf16(params) {
477 let n_in = (params.m as usize) * (params.k as usize);
478 let f32_input = if input.dtype() == DType::BF16 {
479 let f32_buf = device.alloc_buffer(n_in * DType::F32.size_of(), DType::F32, vec![params.m as usize, params.k as usize])?;
480 crate::ops::elementwise::cast(
481 encoder, registry, device.metal_device(),
482 input, &f32_buf, n_in,
483 crate::ops::elementwise::CastDirection::BF16ToF32,
484 )?;
485 Some(f32_buf)
486 } else {
487 None
488 };
489 let actual_input = f32_input.as_ref().unwrap_or(input);
490 let f32_result = quantized_matmul(encoder, registry, device, actual_input, packed_weights, scales, biases, params)?;
491 let n_out = (params.m as usize) * (params.n as usize);
493 let bf16_out = device.alloc_buffer(n_out * DType::BF16.size_of(), DType::BF16, vec![params.m as usize, params.n as usize])?;
494 crate::ops::elementwise::cast(
495 encoder, registry, device.metal_device(),
496 &f32_result, &bf16_out, n_out,
497 crate::ops::elementwise::CastDirection::F32ToBF16,
498 )?;
499 return Ok(bf16_out);
500 }
501
502 if params.bits == 6 {
506 let n_in = (params.m as usize) * (params.k as usize);
508 let f32_input = if input.dtype() == DType::BF16 {
509 let f32_buf = device.alloc_buffer(n_in * DType::F32.size_of(), DType::F32, vec![params.m as usize, params.k as usize])?;
510 crate::ops::elementwise::cast(
511 encoder, registry, device.metal_device(),
512 input, &f32_buf, n_in,
513 crate::ops::elementwise::CastDirection::BF16ToF32,
514 )?;
515 Some(f32_buf)
516 } else {
517 None
518 };
519 let actual_input = f32_input.as_ref().unwrap_or(input);
520 let f32_result = quantized_matmul(encoder, registry, device, actual_input, packed_weights, scales, biases, params)?;
521 let n_out = (params.m as usize) * (params.n as usize);
522 let bf16_out = device.alloc_buffer(n_out * DType::BF16.size_of(), DType::BF16, vec![params.m as usize, params.n as usize])?;
523 crate::ops::elementwise::cast(
524 encoder, registry, device.metal_device(),
525 &f32_result, &bf16_out, n_out,
526 crate::ops::elementwise::CastDirection::F32ToBF16,
527 )?;
528 return Ok(bf16_out);
529 }
530 if params.bits != 4 && params.bits != 8 {
531 return Err(MlxError::InvalidArgument(format!(
532 "bf16 SIMD kernel: unsupported bits value {}; only 4, 6, and 8 are supported",
533 params.bits
534 )));
535 }
536 if params.m == 0 || params.k == 0 || params.n == 0 || params.group_size == 0 {
537 return Err(MlxError::InvalidArgument(
538 "M, K, N, and group_size must all be > 0".into(),
539 ));
540 }
541
542 let expected_input = (params.m as usize) * (params.k as usize) * DType::BF16.size_of();
544 if input.byte_len() < expected_input {
545 return Err(MlxError::InvalidArgument(format!(
546 "bf16 input buffer too small: expected {} bytes for [{}x{}] bf16, got {}",
547 expected_input, params.m, params.k, input.byte_len()
548 )));
549 }
550
551 let expected_w = expected_weight_bytes(params.k, params.n, params.bits);
552 if packed_weights.byte_len() < expected_w {
553 return Err(MlxError::InvalidArgument(format!(
554 "Weight buffer too small: expected {} bytes, got {}",
555 expected_w, packed_weights.byte_len()
556 )));
557 }
558
559 let expected_s = expected_scales_bytes(params.k, params.n, params.group_size);
560 if scales.byte_len() < expected_s {
561 return Err(MlxError::InvalidArgument(format!(
562 "Scales buffer too small: expected {} bytes, got {}",
563 expected_s, scales.byte_len()
564 )));
565 }
566 if biases.byte_len() < expected_s {
567 return Err(MlxError::InvalidArgument(format!(
568 "Biases buffer too small: expected {} bytes, got {}",
569 expected_s, biases.byte_len()
570 )));
571 }
572
573 let pipeline = registry.get_pipeline("quantized_matmul_simd_bf16", device.metal_device())?;
574
575 let output_bytes = (params.m as usize) * (params.n as usize) * DType::BF16.size_of();
577 let output = device.alloc_buffer(
578 output_bytes,
579 DType::BF16,
580 vec![params.m as usize, params.n as usize],
581 )?;
582
583 let gpu_params = QMatmulBf16GpuParams {
584 m: params.m,
585 k: params.k,
586 n: params.n,
587 group_size: params.group_size,
588 bits: params.bits,
589 };
590 let params_bytes = std::mem::size_of::<QMatmulBf16GpuParams>();
591 let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
592 {
593 let slice: &mut [QMatmulBf16GpuParams] = bytemuck::cast_slice_mut(
594 params_buf
595 .as_mut_slice::<u8>()
596 .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
597 );
598 slice[0] = gpu_params;
599 }
600
601 let num_simdgroups = 2u64;
602 let results_per_simdgroup = 4u64;
603 let bn = num_simdgroups * results_per_simdgroup; let threadgroup_size = metal::MTLSize::new(32, num_simdgroups, 1);
606 let threadgroups = metal::MTLSize::new(
607 params.m as u64,
608 (params.n as u64 + bn - 1) / bn,
609 1,
610 );
611
612 encoder.encode_threadgroups(
613 pipeline,
614 &[
615 (0, input),
616 (1, packed_weights),
617 (2, scales),
618 (3, biases),
619 (4, &output),
620 (5, ¶ms_buf),
621 ],
622 threadgroups,
623 threadgroup_size,
624 );
625
626 Ok(output)
627}
628
/// Expert-slice variant of the bf16 SIMD quantized matmul, for
/// mixture-of-experts weights: one expert's `[N x K]` weight slab is selected
/// out of the shared `packed_weights` / `scales` / `biases` buffers by byte
/// offsets, which are passed to the kernel as single-u32 buffers at binding
/// indices 6, 7, and 8.
///
/// Unlike [`dispatch_quantized_matmul_simd_bf16`], there is no fallback path:
/// dimensions that fail [`can_use_simd_kernel_bf16`] are a hard error.
///
/// # Errors
/// `MlxError::InvalidArgument` if the SIMD alignment requirements are not
/// met, `bits` is not 4/8, any dimension or `group_size` is zero, or any
/// buffer cannot contain `offset + expert slice size` bytes.
pub fn dispatch_quantized_matmul_simd_bf16_expert(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    packed_weights: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    params: &QuantizedMatmulParams,
    expert_offset_bytes: u32,
    scales_offset_bytes: u32,
    biases_offset_bytes: u32,
) -> Result<MlxBuffer> {
    if !can_use_simd_kernel_bf16(params) {
        return Err(MlxError::InvalidArgument(
            "dispatch_quantized_matmul_simd_bf16_expert: dimensions do not satisfy bf16 SIMD \
             alignment requirements (N%8==0 and K%512==0 for 4-bit, K%256==0 for 8-bit)".into(),
        ));
    }

    if params.bits != 4 && params.bits != 8 {
        return Err(MlxError::InvalidArgument(format!(
            "bf16 expert kernel: unsupported bits value {}; only 4 and 8 are supported",
            params.bits
        )));
    }
    if params.m == 0 || params.k == 0 || params.n == 0 || params.group_size == 0 {
        return Err(MlxError::InvalidArgument(
            "M, K, N, and group_size must all be > 0".into(),
        ));
    }

    // Size of one expert's slice within each shared buffer.
    let expert_weight_bytes = expected_weight_bytes(params.k, params.n, params.bits);
    let expert_scales_bytes = expected_scales_bytes(params.k, params.n, params.group_size);

    // Each offset + slice must fit inside its buffer.
    if packed_weights.byte_len() < (expert_offset_bytes as usize) + expert_weight_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "packed_weights too small for expert slice: offset={} + size={} > buffer={}",
            expert_offset_bytes, expert_weight_bytes, packed_weights.byte_len()
        )));
    }
    if scales.byte_len() < (scales_offset_bytes as usize) + expert_scales_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "scales buffer too small for expert slice: offset={} + size={} > buffer={}",
            scales_offset_bytes, expert_scales_bytes, scales.byte_len()
        )));
    }
    if biases.byte_len() < (biases_offset_bytes as usize) + expert_scales_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "biases buffer too small for expert slice: offset={} + size={} > buffer={}",
            biases_offset_bytes, expert_scales_bytes, biases.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("quantized_matmul_simd_bf16_expert", device.metal_device())?;

    // bf16 output of shape [M, N].
    let output_bytes = (params.m as usize) * (params.n as usize) * DType::BF16.size_of();
    let output = device.alloc_buffer(
        output_bytes,
        DType::BF16,
        vec![params.m as usize, params.n as usize],
    )?;

    // Parameter block bound at buffer index 5.
    let gpu_params = QMatmulBf16GpuParams {
        m: params.m,
        k: params.k,
        n: params.n,
        group_size: params.group_size,
        bits: params.bits,
    };
    let params_bytes = std::mem::size_of::<QMatmulBf16GpuParams>();
    let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
    {
        let slice: &mut [QMatmulBf16GpuParams] = bytemuck::cast_slice_mut(
            params_buf
                .as_mut_slice::<u8>()
                .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
        );
        slice[0] = gpu_params;
    }

    // The three byte offsets travel as tiny single-u32 buffers (bindings 6-8).
    let mut expert_offset_buf = device.alloc_buffer(4, DType::U32, vec![1])?;
    {
        let s: &mut [u32] = expert_offset_buf
            .as_mut_slice()
            .map_err(|e| MlxError::InvalidArgument(format!("expert_offset buf: {e}")))?;
        s[0] = expert_offset_bytes;
    }
    let mut scales_offset_buf = device.alloc_buffer(4, DType::U32, vec![1])?;
    {
        let s: &mut [u32] = scales_offset_buf
            .as_mut_slice()
            .map_err(|e| MlxError::InvalidArgument(format!("scales_offset buf: {e}")))?;
        s[0] = scales_offset_bytes;
    }
    let mut biases_offset_buf = device.alloc_buffer(4, DType::U32, vec![1])?;
    {
        let s: &mut [u32] = biases_offset_buf
            .as_mut_slice()
            .map_err(|e| MlxError::InvalidArgument(format!("biases_offset buf: {e}")))?;
        s[0] = biases_offset_bytes;
    }

    // Same dispatch geometry as the other SIMD kernels: 32 lanes x 2
    // simdgroups, 4 output columns per simdgroup => 8 columns per threadgroup.
    let num_simdgroups = 2u64;
    let results_per_simdgroup = 4u64;
    let bn = num_simdgroups * results_per_simdgroup;

    let threadgroup_size = metal::MTLSize::new(32, num_simdgroups, 1);
    let threadgroups = metal::MTLSize::new(
        params.m as u64,
        (params.n as u64 + bn - 1) / bn,
        1,
    );

    encoder.encode_threadgroups(
        pipeline,
        &[
            (0, input),
            (1, packed_weights),
            (2, scales),
            (3, biases),
            (4, &output),
            (5, &params_buf),
            (6, &expert_offset_buf),
            (7, &scales_offset_buf),
            (8, &biases_offset_buf),
        ],
        threadgroups,
        threadgroup_size,
    );

    Ok(output)
}
789
790#[cfg(test)]
791#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
792mod tests {
793 use super::*;
794 use crate::MlxDevice;
795
796 fn f32_to_bf16_bits(val: f32) -> u16 {
801 (val.to_bits() >> 16) as u16
802 }
803
804 fn f32_to_f16_bits(val: f32) -> u16 {
807 let bits = val.to_bits();
808 let sign = (bits >> 16) & 0x8000;
809 let exp = ((bits >> 23) & 0xFF) as i32;
810 let mantissa = bits & 0x007F_FFFF;
811
812 if exp == 255 {
813 let m = if mantissa != 0 { 0x0200 } else { 0 };
815 return (sign | 0x7C00 | m) as u16;
816 }
817
818 let new_exp = exp - 127 + 15;
820
821 if new_exp >= 31 {
822 return (sign | 0x7C00) as u16;
824 }
825
826 if new_exp <= 0 {
827 if new_exp < -10 {
829 return sign as u16; }
831 let m = (mantissa | 0x0080_0000) >> (1 - new_exp + 13);
832 return (sign | m) as u16;
833 }
834
835 let m = mantissa >> 13;
837 let round_bit = (mantissa >> 12) & 1;
838 let sticky = if (mantissa & 0xFFF) != 0 { 1u32 } else { 0 };
839 let round_up = round_bit & (sticky | m);
840 let result = sign | ((new_exp as u32) << 10) | m;
841 (result + round_up) as u16
842 }
843
844 fn f16_bits_to_f32(bits: u16) -> f32 {
846 let sign = ((bits as u32 & 0x8000) as u32) << 16;
847 let exp = (bits >> 10) & 0x1F;
848 let mantissa = (bits & 0x03FF) as u32;
849
850 if exp == 0 {
851 if mantissa == 0 {
852 return f32::from_bits(sign); }
854 let mut m = mantissa;
856 let mut e: i32 = -14;
857 while (m & 0x0400) == 0 {
858 m <<= 1;
859 e -= 1;
860 }
861 m &= 0x03FF;
862 let f32_exp = ((e + 127) as u32) << 23;
863 let f32_mantissa = m << 13;
864 return f32::from_bits(sign | f32_exp | f32_mantissa);
865 }
866
867 if exp == 31 {
868 let m = if mantissa != 0 { 0x007F_FFFF } else { 0 };
869 return f32::from_bits(sign | 0x7F80_0000 | m);
870 }
871
872 let f32_exp = ((exp as u32 - 15 + 127) as u32) << 23;
873 let f32_mantissa = mantissa << 13;
874 f32::from_bits(sign | f32_exp | f32_mantissa)
875 }
876
877 #[allow(dead_code)]
879 fn f16_buffer(device: &MlxDevice, shape: Vec<usize>, values: &[f32]) -> MlxBuffer {
880 let byte_len = values.len() * 2;
881 let mut buf = device.alloc_buffer(byte_len, DType::F16, shape).expect("alloc");
882 {
883 let slice: &mut [u16] = buf.as_mut_slice().expect("as_mut_slice");
884 for (i, &v) in values.iter().enumerate() {
885 slice[i] = f32_to_f16_bits(v);
886 }
887 }
888 buf
889 }
890
891 fn bf16_buffer(device: &MlxDevice, shape: Vec<usize>, values: &[f32]) -> MlxBuffer {
893 let byte_len = values.len() * 2;
894 let mut buf = device.alloc_buffer(byte_len, DType::BF16, shape).expect("alloc");
895 {
896 let slice: &mut [u16] = buf.as_mut_slice().expect("as_mut_slice");
897 for (i, &v) in values.iter().enumerate() {
898 slice[i] = f32_to_bf16_bits(v);
899 }
900 }
901 buf
902 }
903
904 fn f32_buffer(device: &MlxDevice, shape: Vec<usize>, values: &[f32]) -> MlxBuffer {
906 let byte_len = values.len() * 4;
907 let mut buf = device.alloc_buffer(byte_len, DType::F32, shape).expect("alloc");
908 {
909 let slice: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
910 slice.copy_from_slice(values);
911 }
912 buf
913 }
914
    /// Packs `quant_values` (row-major [n x k], each value 0..=15) into the
    /// 4-bit weight layout: per output column, 8 consecutive K values are
    /// packed into one u32 with value `i` occupying bits `4*i..4*i+4`.
    /// Unused slots of a partial trailing pack stay zero.
    fn pack_4bit_buffer(device: &MlxDevice, n: usize, k: usize, quant_values: &[u8]) -> MlxBuffer {
        let values_per_pack = 8;
        let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
        let total_packs = n * packs_per_row;
        let byte_len = total_packs * 4;

        let mut buf = device.alloc_buffer(byte_len, DType::U32, vec![n, packs_per_row]).expect("alloc");
        {
            let slice: &mut [u32] = buf.as_mut_slice().expect("as_mut_slice");
            for col in 0..n {
                for pack in 0..packs_per_row {
                    let mut packed: u32 = 0;
                    for i in 0..values_per_pack {
                        let k_idx = pack * values_per_pack + i;
                        if k_idx < k {
                            // Mask to 4 bits, then place at nibble position i.
                            let val = quant_values[col * k + k_idx] as u32 & 0xF;
                            packed |= val << (4 * i);
                        }
                    }
                    slice[col * packs_per_row + pack] = packed;
                }
            }
        }
        buf
    }
943
    /// Packs `quant_values` (row-major [n x k], each value 0..=63) into the
    /// 6-bit weight layout: per output column, 4 consecutive K values (24
    /// bits total) are packed into 3 bytes, little-endian, with value `i`
    /// occupying bits `6*i..6*i+6` of the 24-bit group.
    fn pack_6bit_buffer(device: &MlxDevice, n: usize, k: usize, quant_values: &[u8]) -> MlxBuffer {
        let triplets_per_row = (k + 3) / 4;
        let row_bytes = triplets_per_row * 3;
        let total_bytes = n * row_bytes;

        let mut buf = device.alloc_buffer(total_bytes, DType::U8, vec![total_bytes]).expect("alloc");
        {
            let slice: &mut [u8] = buf.as_mut_slice().expect("as_mut_slice");
            for col in 0..n {
                for t in 0..triplets_per_row {
                    // Accumulate up to four 6-bit values into a 24-bit word.
                    let mut packed: u32 = 0;
                    for i in 0..4 {
                        let k_idx = t * 4 + i;
                        if k_idx < k {
                            let val = quant_values[col * k + k_idx] as u32 & 0x3F;
                            packed |= val << (6 * i);
                        }
                    }
                    // Emit the 24-bit word as 3 little-endian bytes.
                    let base = col * row_bytes + t * 3;
                    slice[base] = (packed & 0xFF) as u8;
                    slice[base + 1] = ((packed >> 8) & 0xFF) as u8;
                    slice[base + 2] = ((packed >> 16) & 0xFF) as u8;
                }
            }
        }
        buf
    }
974
    /// Packs `quant_values` (row-major [n x k], full-byte values) into the
    /// 8-bit weight layout: per output column, 4 consecutive K values are
    /// packed into one u32 with value `i` occupying bits `8*i..8*i+8`.
    /// Unused slots of a partial trailing pack stay zero.
    fn pack_8bit_buffer(device: &MlxDevice, n: usize, k: usize, quant_values: &[u8]) -> MlxBuffer {
        let values_per_pack = 4;
        let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
        let total_packs = n * packs_per_row;
        let byte_len = total_packs * 4;

        let mut buf = device.alloc_buffer(byte_len, DType::U32, vec![n, packs_per_row]).expect("alloc");
        {
            let slice: &mut [u32] = buf.as_mut_slice().expect("as_mut_slice");
            for col in 0..n {
                for pack in 0..packs_per_row {
                    let mut packed: u32 = 0;
                    for i in 0..values_per_pack {
                        let k_idx = pack * values_per_pack + i;
                        if k_idx < k {
                            // Place the byte at position i within the u32.
                            let val = quant_values[col * k + k_idx] as u32 & 0xFF;
                            packed |= val << (8 * i);
                        }
                    }
                    slice[col * packs_per_row + pack] = packed;
                }
            }
        }
        buf
    }
1002
1003 #[allow(dead_code)]
1005 fn read_f16(buf: &MlxBuffer) -> Vec<f32> {
1006 let slice: &[u16] = buf.as_slice().expect("as_slice");
1007 slice.iter().map(|&bits| f16_bits_to_f32(bits)).collect()
1008 }
1009
1010 fn read_f32(buf: &MlxBuffer) -> Vec<f32> {
1012 let slice: &[f32] = buf.as_slice().expect("as_slice");
1013 slice.to_vec()
1014 }
1015
    /// 1x4 @ 4x2 4-bit matmul with hand-computed expectations:
    /// col0 = 0.1 * [1,2,3,4] . [1,2,3,4] = 3.0,
    /// col1 = 0.2 * [5,6,7,8] . [1,2,3,4] = 14.0.
    #[test]
    fn test_4bit_matmul_small_known() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        let m = 1u32;
        let k = 4u32;
        let n = 2u32;
        let group_size = 64u32; // one group covers all of K
        let bits = 4u32;

        let input = f32_buffer(&device, vec![m as usize, k as usize], &[1.0, 2.0, 3.0, 4.0]);

        // Column-major quant values: col0 = [1,2,3,4], col1 = [5,6,7,8].
        let quant_w: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let weight = pack_4bit_buffer(&device, n as usize, k as usize, &quant_w);

        let scales = bf16_buffer(&device, vec![n as usize, 1], &[0.1, 0.2]);
        let biases = bf16_buffer(&device, vec![n as usize, 1], &[0.0, 0.0]);

        let params = QuantizedMatmulParams { m, k, n, group_size, bits };

        let output = quantized_matmul(
            &mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
        ).expect("quantized_matmul");

        encoder.commit_and_wait().expect("commit");

        let result = read_f32(&output);
        assert_eq!(result.len(), 2);

        // Loose tolerance: bf16 scales carry only 8 bits of mantissa.
        let tol = 1e-1;
        assert!(
            (result[0] - 3.0).abs() < tol,
            "output[0]={}, expected ~3.0", result[0]
        );
        assert!(
            (result[1] - 14.0).abs() < tol,
            "output[1]={}, expected ~14.0", result[1]
        );
    }
1072
    /// 1x4 @ 4x2 6-bit matmul with hand-computed expectations:
    /// col0 = 0.1 * [1,2,3,4] . [1,2,3,4] = 3.0,
    /// col1 = 0.05 * [10,20,30,40] . [1,2,3,4] = 15.0.
    #[test]
    fn test_6bit_matmul_small_known() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        let m = 1u32;
        let k = 4u32;
        let n = 2u32;
        let group_size = 64u32; // one group covers all of K
        let bits = 6u32;

        let input = f32_buffer(&device, vec![m as usize, k as usize], &[1.0, 2.0, 3.0, 4.0]);

        // Column-major quant values; 6-bit range allows values up to 63.
        let quant_w: Vec<u8> = vec![1, 2, 3, 4, 10, 20, 30, 40];
        let weight = pack_6bit_buffer(&device, n as usize, k as usize, &quant_w);

        let scales = bf16_buffer(&device, vec![n as usize, 1], &[0.1, 0.05]);
        let biases = bf16_buffer(&device, vec![n as usize, 1], &[0.0, 0.0]);

        let params = QuantizedMatmulParams { m, k, n, group_size, bits };

        let output = quantized_matmul(
            &mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
        ).expect("quantized_matmul");

        encoder.commit_and_wait().expect("commit");

        let result = read_f32(&output);
        assert_eq!(result.len(), 2);

        // Loose tolerance: bf16 scales carry only 8 bits of mantissa.
        let tol = 1e-1;
        assert!(
            (result[0] - 3.0).abs() < tol,
            "output[0]={}, expected ~3.0", result[0]
        );
        assert!(
            (result[1] - 15.0).abs() < tol,
            "output[1]={}, expected ~15.0", result[1]
        );
    }
1121
    /// Verifies the per-group bias is applied: zero weights with bias 0.5
    /// dequantize to 0.5 each, so ones(4) . [0.5; 4] = 2.0.
    #[test]
    fn test_4bit_matmul_with_bias() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        let m = 1u32;
        let k = 4u32;
        let n = 1u32;
        let group_size = 64u32;
        let bits = 4u32;

        let input = f32_buffer(&device, vec![1, 4], &[1.0, 1.0, 1.0, 1.0]);

        // All-zero quantized weights: output comes entirely from the bias.
        let quant_w: Vec<u8> = vec![0, 0, 0, 0];
        let weight = pack_4bit_buffer(&device, 1, 4, &quant_w);

        let scales = bf16_buffer(&device, vec![1, 1], &[1.0]);
        let biases = bf16_buffer(&device, vec![1, 1], &[0.5]);

        let params = QuantizedMatmulParams { m, k, n, group_size, bits };

        let output = quantized_matmul(
            &mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
        ).expect("quantized_matmul");

        encoder.commit_and_wait().expect("commit");

        let result = read_f32(&output);
        let tol = 1e-2;
        assert!(
            (result[0] - 2.0).abs() < tol,
            "output[0]={}, expected ~2.0", result[0]
        );
    }
1161
    /// M = 2 batch: one-hot input rows e0 and e1 select individual
    /// dequantized weights (0.5 * [2,4,6,8] = [1,2,3,4]), so the two output
    /// rows are 1.0 and 2.0.
    #[test]
    fn test_4bit_batch_matmul() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        let m = 2u32;
        let k = 4u32;
        let n = 1u32;
        let group_size = 64u32;
        let bits = 4u32;

        // Rows are the unit vectors e0 and e1.
        let input = f32_buffer(&device, vec![2, 4], &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]);

        let quant_w: Vec<u8> = vec![2, 4, 6, 8];
        let weight = pack_4bit_buffer(&device, 1, 4, &quant_w);

        let scales = bf16_buffer(&device, vec![1, 1], &[0.5]);
        let biases = bf16_buffer(&device, vec![1, 1], &[0.0]);

        let params = QuantizedMatmulParams { m, k, n, group_size, bits };

        let output = quantized_matmul(
            &mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
        ).expect("quantized_matmul");

        encoder.commit_and_wait().expect("commit");

        let result = read_f32(&output);
        assert_eq!(result.len(), 2);

        let tol = 1e-2;
        assert!((result[0] - 1.0).abs() < tol, "row0={}, expected 1.0", result[0]);
        assert!((result[1] - 2.0).abs() < tol, "row1={}, expected 2.0", result[1]);
    }
1204
    /// `bits = 5` must be rejected with `InvalidArgument` before any GPU work
    /// is encoded.
    #[test]
    fn test_invalid_bits_returns_error() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        let input = f32_buffer(&device, vec![1, 4], &[1.0; 4]);
        // Minimal placeholder buffers: sizes don't matter, bits fails first.
        let weight = device.alloc_buffer(4, DType::U32, vec![1]).expect("alloc");
        let scales = bf16_buffer(&device, vec![1], &[1.0]);
        let biases = bf16_buffer(&device, vec![1], &[0.0]);

        let params = QuantizedMatmulParams {
            m: 1, k: 4, n: 1, group_size: 64, bits: 5,
        };

        let result = quantized_matmul(
            &mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
        );

        assert!(result.is_err());
        match result {
            Err(MlxError::InvalidArgument(msg)) => {
                assert!(msg.contains("bits"), "Error should mention bits: {msg}");
            }
            other => panic!("Expected InvalidArgument, got {:?}", other),
        }
    }
1235
    /// An input buffer sized for K = 4 but params claiming K = 128 must fail
    /// the input-size validation with `InvalidArgument`.
    #[test]
    fn test_mismatched_dimensions_returns_error() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        // Input holds only 4 f32 values; params below claim 128.
        let input = f32_buffer(&device, vec![1, 4], &[1.0; 4]);
        let weight = device.alloc_buffer(4, DType::U32, vec![1]).expect("alloc");
        let scales = bf16_buffer(&device, vec![1], &[1.0]);
        let biases = bf16_buffer(&device, vec![1], &[0.0]);

        let params = QuantizedMatmulParams {
            m: 1, k: 128, n: 1, group_size: 64, bits: 4,
        };

        let result = quantized_matmul(
            &mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
        );

        assert!(result.is_err());
        match result {
            Err(MlxError::InvalidArgument(msg)) => {
                assert!(msg.contains("Input buffer too small"), "msg: {msg}");
            }
            other => panic!("Expected InvalidArgument for input size, got {:?}", other),
        }
    }
1266
    /// 1x4 @ 4x2 8-bit matmul with hand-computed expectations:
    /// col0 = 0.01 * [10,20,30,40] . [1,2,3,4] = 3.0,
    /// col1 = 0.02 * [50,60,70,80] . [1,2,3,4] = 14.0.
    #[test]
    fn test_8bit_matmul_small_known() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        let m = 1u32;
        let k = 4u32;
        let n = 2u32;
        let group_size = 64u32; // one group covers all of K
        let bits = 8u32;

        let input = f32_buffer(&device, vec![m as usize, k as usize], &[1.0, 2.0, 3.0, 4.0]);

        // Column-major quant values: col0 = [10..40], col1 = [50..80].
        let quant_w: Vec<u8> = vec![10, 20, 30, 40, 50, 60, 70, 80];
        let weight = pack_8bit_buffer(&device, n as usize, k as usize, &quant_w);

        let scales = bf16_buffer(&device, vec![n as usize, 1], &[0.01, 0.02]);
        let biases = bf16_buffer(&device, vec![n as usize, 1], &[0.0, 0.0]);

        let params = QuantizedMatmulParams { m, k, n, group_size, bits };

        let output = quantized_matmul(
            &mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
        ).expect("quantized_matmul");

        encoder.commit_and_wait().expect("commit");

        let result = read_f32(&output);
        assert_eq!(result.len(), 2);

        // Loose tolerance: bf16 scales carry only 8 bits of mantissa.
        let tol = 1e-1;
        assert!(
            (result[0] - 3.0).abs() < tol,
            "output[0]={}, expected ~3.0", result[0]
        );
        assert!(
            (result[1] - 14.0).abs() < tol,
            "output[1]={}, expected ~14.0", result[1]
        );
    }
1323
    /// 8-bit counterpart of the bias test: zero weights with bias 0.5 yield
    /// ones(4) . [0.5; 4] = 2.0.
    #[test]
    fn test_8bit_matmul_with_bias() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        let m = 1u32;
        let k = 4u32;
        let n = 1u32;
        let group_size = 64u32;
        let bits = 8u32;

        let input = f32_buffer(&device, vec![1, 4], &[1.0, 1.0, 1.0, 1.0]);

        // All-zero quantized weights: output comes entirely from the bias.
        let quant_w: Vec<u8> = vec![0, 0, 0, 0];
        let weight = pack_8bit_buffer(&device, 1, 4, &quant_w);

        let scales = bf16_buffer(&device, vec![1, 1], &[1.0]);
        let biases = bf16_buffer(&device, vec![1, 1], &[0.5]);

        let params = QuantizedMatmulParams { m, k, n, group_size, bits };

        let output = quantized_matmul(
            &mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
        ).expect("quantized_matmul");

        encoder.commit_and_wait().expect("commit");

        let result = read_f32(&output);
        let tol = 1e-2;
        assert!(
            (result[0] - 2.0).abs() < tol,
            "output[0]={}, expected ~2.0", result[0]
        );
    }
1363
    /// K = 8 with group_size = 4 exercises two quantization groups per
    /// column: group0 = 4 * (1 * 0.5) = 2.0, group1 = 4 * (2 * 1.0) = 8.0,
    /// total 10.0.
    #[test]
    fn test_4bit_multiple_groups() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        let m = 1u32;
        let k = 8u32;
        let n = 1u32;
        let group_size = 4u32; // two groups of 4 along K
        let bits = 4u32;

        let input = f32_buffer(&device, vec![1, 8], &[1.0; 8]);

        // First group all 1s, second group all 2s.
        let quant_w: Vec<u8> = vec![1, 1, 1, 1, 2, 2, 2, 2];
        let weight = pack_4bit_buffer(&device, 1, 8, &quant_w);

        // One scale/bias per group.
        let scales = bf16_buffer(&device, vec![1, 2], &[0.5, 1.0]);
        let biases = bf16_buffer(&device, vec![1, 2], &[0.0, 0.0]);

        let params = QuantizedMatmulParams { m, k, n, group_size, bits };

        let output = quantized_matmul(
            &mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
        ).expect("quantized_matmul");

        encoder.commit_and_wait().expect("commit");

        let result = read_f32(&output);
        let tol = 1e-1;
        assert!(
            (result[0] - 10.0).abs() < tol,
            "output[0]={}, expected ~10.0", result[0]
        );
    }
1407}