use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
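/// Host-side shape and quantization parameters for a quantized matmul:
/// `output[m, n]` is `input[m, k]` times the dequantized `weight[n, k]`, where
/// each weight dequantizes as `q * scale + bias`, with one (scale, bias) pair
/// per `group_size` consecutive values along K.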
#[derive(Debug, Clone, Copy)]
pub struct QuantizedMatmulParams {
pub m: u32,
pub k: u32,
pub n: u32,
pub group_size: u32,
pub bits: u32,
}
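/// GPU-side mirror of `QuantizedMatmulParams`. `#[repr(C)]` plus
/// `bytemuck::Pod` so the struct can be written byte-for-byte into the
/// constant buffer the kernels read at buffer index 5.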
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QuantizedMatmulGpuParams {
m: u32,
k: u32,
n: u32,
group_size: u32,
bits: u32,
}
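/// Minimum packed-weight buffer size in bytes for an `[n, k]` weight matrix:
/// 4-bit packs 8 values per u32 word, 6-bit packs 4 values per 3-byte
/// triplet, and 8-bit packs 4 values per u32 word. Unsupported bit widths
/// yield 0; callers reject them before calling this.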
fn expected_weight_bytes(k: u32, n: u32, bits: u32) -> usize {
match bits {
4 => {
let values_per_pack = 8u32;
let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
(n as usize) * (packs_per_row as usize) * 4
}
6 => {
let triplets_per_row = (k + 3) / 4;
(n as usize) * (triplets_per_row as usize) * 3
}
8 => {
let values_per_pack = 4u32;
let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
(n as usize) * (packs_per_row as usize) * 4
}
_ => 0,
}
}
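/// Minimum scales (or biases) buffer size in bytes: one 2-byte (bf16) value
/// per output column per `group_size`-wide group along K.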
fn expected_scales_bytes(k: u32, n: u32, group_size: u32) -> usize {
let num_groups = (k + group_size - 1) / group_size;
    (n as usize) * (num_groups as usize) * 2
}
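/// Generic quantized matmul: f32 `input [m, k]` against a packed 4-, 6-, or
/// 8-bit `weight [n, k]`, producing f32 `output [m, n]`. Dispatches one
/// thread per output element in (up to) 16x16 threadgroups.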
pub fn quantized_matmul(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
scales: &MlxBuffer,
biases: &MlxBuffer,
params: &QuantizedMatmulParams,
) -> Result<MlxBuffer> {
if params.bits != 4 && params.bits != 6 && params.bits != 8 {
return Err(MlxError::InvalidArgument(format!(
"Unsupported bits value {}; only 4, 6, and 8 are supported",
params.bits
)));
}
if params.m == 0 || params.k == 0 || params.n == 0 {
return Err(MlxError::InvalidArgument(
"M, K, and N must all be > 0".into(),
));
}
if params.group_size == 0 {
return Err(MlxError::InvalidArgument(
"group_size must be > 0".into(),
));
}
let expected_input = (params.m as usize) * (params.k as usize) * DType::F32.size_of();
if input.byte_len() < expected_input {
return Err(MlxError::InvalidArgument(format!(
"Input buffer too small: expected at least {} bytes for [{}x{}] f32, got {}",
expected_input, params.m, params.k, input.byte_len()
)));
}
let expected_w = expected_weight_bytes(params.k, params.n, params.bits);
if weight.byte_len() < expected_w {
return Err(MlxError::InvalidArgument(format!(
"Weight buffer too small: expected at least {} bytes for {}bit [{}x{}], got {}",
expected_w, params.bits, params.n, params.k, weight.byte_len()
)));
}
let expected_s = expected_scales_bytes(params.k, params.n, params.group_size);
if scales.byte_len() < expected_s {
return Err(MlxError::InvalidArgument(format!(
"Scales buffer too small: expected at least {} bytes, got {}",
expected_s, scales.byte_len()
)));
}
if biases.byte_len() < expected_s {
return Err(MlxError::InvalidArgument(format!(
"Biases buffer too small: expected at least {} bytes, got {}",
expected_s, biases.byte_len()
)));
}
let pipeline = registry.get_pipeline("quantized_matmul", device.metal_device())?;
let output_bytes = (params.m as usize) * (params.n as usize) * DType::F32.size_of();
let output = device.alloc_buffer(
output_bytes,
DType::F32,
vec![params.m as usize, params.n as usize],
)?;
let gpu_params = QuantizedMatmulGpuParams {
m: params.m,
k: params.k,
n: params.n,
group_size: params.group_size,
bits: params.bits,
};
let params_bytes = std::mem::size_of::<QuantizedMatmulGpuParams>();
let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
{
let slice: &mut [QuantizedMatmulGpuParams] = bytemuck::cast_slice_mut(
params_buf
.as_mut_slice::<u8>()
.map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
);
slice[0] = gpu_params;
}
let tg_x = 16u64.min(params.n as u64);
let tg_y = 16u64.min(params.m as u64);
let threadgroup_size = metal::MTLSize::new(tg_x, tg_y, 1);
let grid_groups = metal::MTLSize::new(
(params.n as u64 + tg_x - 1) / tg_x,
(params.m as u64 + tg_y - 1) / tg_y,
1,
);
encoder.encode_threadgroups(
pipeline,
&[
(0, input),
(1, weight),
(2, scales),
(3, biases),
(4, &output),
            (5, &params_buf),
],
grid_groups,
threadgroup_size,
);
Ok(output)
}
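/// The f32 SIMD kernel requires N % 8 == 0 (two simdgroups producing four
/// output columns each per threadgroup) and K % 256 == 0 for both 4- and
/// 8-bit weights; all other configurations, including 6-bit, use the generic
/// kernel.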
fn can_use_simd_kernel(params: &QuantizedMatmulParams) -> bool {
    let bn = 8u32;
    if params.n % bn != 0 {
return false;
}
match params.bits {
        4 | 8 => params.k % 256 == 0,
_ => false,
}
}
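/// Same N requirement as the f32 SIMD kernel, but the bf16 kernel needs
/// K % 512 == 0 for 4-bit weights (K % 256 == 0 for 8-bit).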
fn can_use_simd_kernel_bf16(params: &QuantizedMatmulParams) -> bool {
let bn = 8u32;
if params.n % bn != 0 {
return false;
}
match params.bits {
        4 => params.k % 512 == 0,
        8 => params.k % 256 == 0,
_ => false,
}
}
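/// SIMD-group variant of `quantized_matmul` for 4- and 8-bit weights on
/// aligned shapes. Each 32x2 threadgroup handles one input row and eight
/// output columns; shapes rejected by `can_use_simd_kernel` fall back to the
/// generic kernel.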
pub fn quantized_matmul_simd(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
scales: &MlxBuffer,
biases: &MlxBuffer,
params: &QuantizedMatmulParams,
) -> Result<MlxBuffer> {
    // can_use_simd_kernel admits only 4- and 8-bit weights on aligned shapes,
    // so everything else (including all 6-bit weights) falls back to the
    // generic kernel, which performs its own bits validation.
    if !can_use_simd_kernel(params) {
        return quantized_matmul(encoder, registry, device, input, weight, scales, biases, params);
    }
if params.m == 0 || params.k == 0 || params.n == 0 {
return Err(MlxError::InvalidArgument(
"M, K, and N must all be > 0".into(),
));
}
if params.group_size == 0 {
return Err(MlxError::InvalidArgument(
"group_size must be > 0".into(),
));
}
let expected_input = (params.m as usize) * (params.k as usize) * DType::F32.size_of();
if input.byte_len() < expected_input {
return Err(MlxError::InvalidArgument(format!(
"Input buffer too small: expected at least {} bytes for [{}x{}] f32, got {}",
expected_input, params.m, params.k, input.byte_len()
)));
}
let expected_w = expected_weight_bytes(params.k, params.n, params.bits);
if weight.byte_len() < expected_w {
return Err(MlxError::InvalidArgument(format!(
"Weight buffer too small: expected at least {} bytes for {}bit [{}x{}], got {}",
expected_w, params.bits, params.n, params.k, weight.byte_len()
)));
}
let expected_s = expected_scales_bytes(params.k, params.n, params.group_size);
if scales.byte_len() < expected_s {
return Err(MlxError::InvalidArgument(format!(
"Scales buffer too small: expected at least {} bytes, got {}",
expected_s, scales.byte_len()
)));
}
if biases.byte_len() < expected_s {
return Err(MlxError::InvalidArgument(format!(
"Biases buffer too small: expected at least {} bytes, got {}",
expected_s, biases.byte_len()
)));
}
let pipeline = registry.get_pipeline("quantized_matmul_simd", device.metal_device())?;
let output_bytes = (params.m as usize) * (params.n as usize) * DType::F32.size_of();
let output = device.alloc_buffer(
output_bytes,
DType::F32,
vec![params.m as usize, params.n as usize],
)?;
let gpu_params = QuantizedMatmulGpuParams {
m: params.m,
k: params.k,
n: params.n,
group_size: params.group_size,
bits: params.bits,
};
let params_bytes = std::mem::size_of::<QuantizedMatmulGpuParams>();
let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
{
let slice: &mut [QuantizedMatmulGpuParams] = bytemuck::cast_slice_mut(
params_buf
.as_mut_slice::<u8>()
.map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
);
slice[0] = gpu_params;
}
let num_simdgroups = 2u64;
let results_per_simdgroup = 4u64;
let bn = num_simdgroups * results_per_simdgroup;
let threadgroup_size = metal::MTLSize::new(32, num_simdgroups, 1);
let threadgroups = metal::MTLSize::new(
params.m as u64,
(params.n as u64 + bn - 1) / bn,
1,
);
encoder.encode_threadgroups(
pipeline,
&[
(0, input),
(1, weight),
(2, scales),
(3, biases),
(4, &output),
            (5, &params_buf),
],
threadgroups,
threadgroup_size,
);
Ok(output)
}
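/// GPU params for the bf16 kernels; same layout as `QuantizedMatmulGpuParams`
/// but kept as a separate type for the bf16 kernel family.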
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QMatmulBf16GpuParams {
m: u32,
k: u32,
n: u32,
group_size: u32,
bits: u32,
}
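/// bf16 SIMD quantized matmul: bf16 `input [m, k]` against packed 4- or 8-bit
/// weights, producing bf16 `output [m, n]`. Shapes that miss the bf16
/// alignment requirements are routed through the generic f32 kernel, with
/// casts on the way in and out.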
pub fn dispatch_quantized_matmul_simd_bf16(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
packed_weights: &MlxBuffer,
scales: &MlxBuffer,
biases: &MlxBuffer,
params: &QuantizedMatmulParams,
) -> Result<MlxBuffer> {
    // can_use_simd_kernel_bf16 admits only 4- and 8-bit weights on aligned
    // shapes; everything else (including all 6-bit weights) takes the generic
    // f32 path: cast the bf16 input to f32, run the generic kernel, and cast
    // the result back to bf16.
    if !can_use_simd_kernel_bf16(params) {
let n_in = (params.m as usize) * (params.k as usize);
let f32_input = if input.dtype() == DType::BF16 {
let f32_buf = device.alloc_buffer(n_in * DType::F32.size_of(), DType::F32, vec![params.m as usize, params.k as usize])?;
crate::ops::elementwise::cast(
encoder, registry, device.metal_device(),
input, &f32_buf, n_in,
crate::ops::elementwise::CastDirection::BF16ToF32,
)?;
Some(f32_buf)
} else {
None
};
let actual_input = f32_input.as_ref().unwrap_or(input);
let f32_result = quantized_matmul(encoder, registry, device, actual_input, packed_weights, scales, biases, params)?;
let n_out = (params.m as usize) * (params.n as usize);
let bf16_out = device.alloc_buffer(n_out * DType::BF16.size_of(), DType::BF16, vec![params.m as usize, params.n as usize])?;
crate::ops::elementwise::cast(
encoder, registry, device.metal_device(),
&f32_result, &bf16_out, n_out,
crate::ops::elementwise::CastDirection::F32ToBF16,
)?;
return Ok(bf16_out);
}
if params.m == 0 || params.k == 0 || params.n == 0 || params.group_size == 0 {
return Err(MlxError::InvalidArgument(
"M, K, N, and group_size must all be > 0".into(),
));
}
let expected_input = (params.m as usize) * (params.k as usize) * DType::BF16.size_of();
if input.byte_len() < expected_input {
return Err(MlxError::InvalidArgument(format!(
"bf16 input buffer too small: expected {} bytes for [{}x{}] bf16, got {}",
expected_input, params.m, params.k, input.byte_len()
)));
}
let expected_w = expected_weight_bytes(params.k, params.n, params.bits);
if packed_weights.byte_len() < expected_w {
return Err(MlxError::InvalidArgument(format!(
"Weight buffer too small: expected {} bytes, got {}",
expected_w, packed_weights.byte_len()
)));
}
let expected_s = expected_scales_bytes(params.k, params.n, params.group_size);
if scales.byte_len() < expected_s {
return Err(MlxError::InvalidArgument(format!(
"Scales buffer too small: expected {} bytes, got {}",
expected_s, scales.byte_len()
)));
}
if biases.byte_len() < expected_s {
return Err(MlxError::InvalidArgument(format!(
"Biases buffer too small: expected {} bytes, got {}",
expected_s, biases.byte_len()
)));
}
let pipeline = registry.get_pipeline("quantized_matmul_simd_bf16", device.metal_device())?;
let output_bytes = (params.m as usize) * (params.n as usize) * DType::BF16.size_of();
let output = device.alloc_buffer(
output_bytes,
DType::BF16,
vec![params.m as usize, params.n as usize],
)?;
let gpu_params = QMatmulBf16GpuParams {
m: params.m,
k: params.k,
n: params.n,
group_size: params.group_size,
bits: params.bits,
};
let params_bytes = std::mem::size_of::<QMatmulBf16GpuParams>();
let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
{
let slice: &mut [QMatmulBf16GpuParams] = bytemuck::cast_slice_mut(
params_buf
.as_mut_slice::<u8>()
.map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
);
slice[0] = gpu_params;
}
let num_simdgroups = 2u64;
let results_per_simdgroup = 4u64;
let bn = num_simdgroups * results_per_simdgroup;
let threadgroup_size = metal::MTLSize::new(32, num_simdgroups, 1);
let threadgroups = metal::MTLSize::new(
params.m as u64,
(params.n as u64 + bn - 1) / bn,
1,
);
encoder.encode_threadgroups(
pipeline,
&[
(0, input),
(1, packed_weights),
(2, scales),
(3, biases),
(4, &output),
            (5, &params_buf),
],
threadgroups,
threadgroup_size,
);
Ok(output)
}
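/// Expert-sliced variant of the bf16 SIMD matmul for weights, scales, and
/// biases that live stacked per-expert in shared buffers: the byte offsets
/// select one expert's slice and are passed to the kernel at buffer indices
/// 6-8. Unlike `dispatch_quantized_matmul_simd_bf16`, this variant has no
/// generic fallback; misaligned shapes are an error.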
pub fn dispatch_quantized_matmul_simd_bf16_expert(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
packed_weights: &MlxBuffer,
scales: &MlxBuffer,
biases: &MlxBuffer,
params: &QuantizedMatmulParams,
expert_offset_bytes: u32,
scales_offset_bytes: u32,
biases_offset_bytes: u32,
) -> Result<MlxBuffer> {
    // can_use_simd_kernel_bf16 also guarantees bits is 4 or 8, so no separate
    // bits check is needed here.
    if !can_use_simd_kernel_bf16(params) {
        return Err(MlxError::InvalidArgument(
            "dispatch_quantized_matmul_simd_bf16_expert: dimensions do not satisfy bf16 SIMD \
             alignment requirements (N%8==0 and K%512==0 for 4-bit, K%256==0 for 8-bit)".into(),
        ));
    }
if params.m == 0 || params.k == 0 || params.n == 0 || params.group_size == 0 {
return Err(MlxError::InvalidArgument(
"M, K, N, and group_size must all be > 0".into(),
));
}
let expert_weight_bytes = expected_weight_bytes(params.k, params.n, params.bits);
let expert_scales_bytes = expected_scales_bytes(params.k, params.n, params.group_size);
if packed_weights.byte_len() < (expert_offset_bytes as usize) + expert_weight_bytes {
return Err(MlxError::InvalidArgument(format!(
"packed_weights too small for expert slice: offset={} + size={} > buffer={}",
expert_offset_bytes, expert_weight_bytes, packed_weights.byte_len()
)));
}
if scales.byte_len() < (scales_offset_bytes as usize) + expert_scales_bytes {
return Err(MlxError::InvalidArgument(format!(
"scales buffer too small for expert slice: offset={} + size={} > buffer={}",
scales_offset_bytes, expert_scales_bytes, scales.byte_len()
)));
}
if biases.byte_len() < (biases_offset_bytes as usize) + expert_scales_bytes {
return Err(MlxError::InvalidArgument(format!(
"biases buffer too small for expert slice: offset={} + size={} > buffer={}",
biases_offset_bytes, expert_scales_bytes, biases.byte_len()
)));
}
let pipeline = registry.get_pipeline("quantized_matmul_simd_bf16_expert", device.metal_device())?;
let output_bytes = (params.m as usize) * (params.n as usize) * DType::BF16.size_of();
let output = device.alloc_buffer(
output_bytes,
DType::BF16,
vec![params.m as usize, params.n as usize],
)?;
let gpu_params = QMatmulBf16GpuParams {
m: params.m,
k: params.k,
n: params.n,
group_size: params.group_size,
bits: params.bits,
};
let params_bytes = std::mem::size_of::<QMatmulBf16GpuParams>();
let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![5])?;
{
let slice: &mut [QMatmulBf16GpuParams] = bytemuck::cast_slice_mut(
params_buf
.as_mut_slice::<u8>()
.map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
);
slice[0] = gpu_params;
}
let mut expert_offset_buf = device.alloc_buffer(4, DType::U32, vec![1])?;
{
let s: &mut [u32] = expert_offset_buf
.as_mut_slice()
.map_err(|e| MlxError::InvalidArgument(format!("expert_offset buf: {e}")))?;
s[0] = expert_offset_bytes;
}
let mut scales_offset_buf = device.alloc_buffer(4, DType::U32, vec![1])?;
{
let s: &mut [u32] = scales_offset_buf
.as_mut_slice()
.map_err(|e| MlxError::InvalidArgument(format!("scales_offset buf: {e}")))?;
s[0] = scales_offset_bytes;
}
let mut biases_offset_buf = device.alloc_buffer(4, DType::U32, vec![1])?;
{
let s: &mut [u32] = biases_offset_buf
.as_mut_slice()
.map_err(|e| MlxError::InvalidArgument(format!("biases_offset buf: {e}")))?;
s[0] = biases_offset_bytes;
}
let num_simdgroups = 2u64;
let results_per_simdgroup = 4u64;
let bn = num_simdgroups * results_per_simdgroup;
let threadgroup_size = metal::MTLSize::new(32, num_simdgroups, 1);
let threadgroups = metal::MTLSize::new(
params.m as u64,
(params.n as u64 + bn - 1) / bn,
1,
);
encoder.encode_threadgroups(
pipeline,
&[
(0, input),
(1, packed_weights),
(2, scales),
(3, biases),
(4, &output),
            (5, &params_buf),
(6, &expert_offset_buf),
(7, &scales_offset_buf),
(8, &biases_offset_buf),
],
threadgroups,
threadgroup_size,
);
Ok(output)
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
use super::*;
use crate::MlxDevice;
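    /// bf16 is the upper 16 bits of the corresponding f32 (truncating round).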
fn f32_to_bf16_bits(val: f32) -> u16 {
(val.to_bits() >> 16) as u16
}
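    /// Scalar f32 -> IEEE 754 half conversion with round-to-nearest-even,
    /// handling NaN/infinity, subnormals, and overflow.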
fn f32_to_f16_bits(val: f32) -> u16 {
let bits = val.to_bits();
let sign = (bits >> 16) & 0x8000;
let exp = ((bits >> 23) & 0xFF) as i32;
let mantissa = bits & 0x007F_FFFF;
if exp == 255 {
let m = if mantissa != 0 { 0x0200 } else { 0 };
return (sign | 0x7C00 | m) as u16;
}
let new_exp = exp - 127 + 15;
if new_exp >= 31 {
return (sign | 0x7C00) as u16;
}
if new_exp <= 0 {
if new_exp < -10 {
                return sign as u16;
            }
let m = (mantissa | 0x0080_0000) >> (1 - new_exp + 13);
return (sign | m) as u16;
}
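        // Round to nearest, ties to even: `round_bit` is 0 or 1, so the
        // bitwise AND below only sees bit 0 of `sticky | m`, i.e. sticky OR
        // the mantissa's least significant bit.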
let m = mantissa >> 13;
let round_bit = (mantissa >> 12) & 1;
let sticky = if (mantissa & 0xFFF) != 0 { 1u32 } else { 0 };
let round_up = round_bit & (sticky | m);
let result = sign | ((new_exp as u32) << 10) | m;
(result + round_up) as u16
}
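    /// Scalar IEEE 754 half -> f32 conversion, including subnormals and
    /// NaN/infinity.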
fn f16_bits_to_f32(bits: u16) -> f32 {
let sign = ((bits as u32 & 0x8000) as u32) << 16;
let exp = (bits >> 10) & 0x1F;
let mantissa = (bits & 0x03FF) as u32;
if exp == 0 {
if mantissa == 0 {
                return f32::from_bits(sign);
            }
let mut m = mantissa;
let mut e: i32 = -14;
while (m & 0x0400) == 0 {
m <<= 1;
e -= 1;
}
m &= 0x03FF;
let f32_exp = ((e + 127) as u32) << 23;
let f32_mantissa = m << 13;
return f32::from_bits(sign | f32_exp | f32_mantissa);
}
if exp == 31 {
let m = if mantissa != 0 { 0x007F_FFFF } else { 0 };
return f32::from_bits(sign | 0x7F80_0000 | m);
}
let f32_exp = ((exp as u32 - 15 + 127) as u32) << 23;
let f32_mantissa = mantissa << 13;
f32::from_bits(sign | f32_exp | f32_mantissa)
}
#[allow(dead_code)]
fn f16_buffer(device: &MlxDevice, shape: Vec<usize>, values: &[f32]) -> MlxBuffer {
let byte_len = values.len() * 2;
let mut buf = device.alloc_buffer(byte_len, DType::F16, shape).expect("alloc");
{
let slice: &mut [u16] = buf.as_mut_slice().expect("as_mut_slice");
for (i, &v) in values.iter().enumerate() {
slice[i] = f32_to_f16_bits(v);
}
}
buf
}
fn bf16_buffer(device: &MlxDevice, shape: Vec<usize>, values: &[f32]) -> MlxBuffer {
let byte_len = values.len() * 2;
let mut buf = device.alloc_buffer(byte_len, DType::BF16, shape).expect("alloc");
{
let slice: &mut [u16] = buf.as_mut_slice().expect("as_mut_slice");
for (i, &v) in values.iter().enumerate() {
slice[i] = f32_to_bf16_bits(v);
}
}
buf
}
fn f32_buffer(device: &MlxDevice, shape: Vec<usize>, values: &[f32]) -> MlxBuffer {
let byte_len = values.len() * 4;
let mut buf = device.alloc_buffer(byte_len, DType::F32, shape).expect("alloc");
{
let slice: &mut [f32] = buf.as_mut_slice().expect("as_mut_slice");
slice.copy_from_slice(values);
}
buf
}
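    /// Packs `[n, k]` 4-bit values row-major, 8 values per u32, lowest nibble
    /// first.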
fn pack_4bit_buffer(device: &MlxDevice, n: usize, k: usize, quant_values: &[u8]) -> MlxBuffer {
let values_per_pack = 8;
let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
let total_packs = n * packs_per_row;
let byte_len = total_packs * 4;
let mut buf = device.alloc_buffer(byte_len, DType::U32, vec![n, packs_per_row]).expect("alloc");
{
let slice: &mut [u32] = buf.as_mut_slice().expect("as_mut_slice");
for col in 0..n {
for pack in 0..packs_per_row {
let mut packed: u32 = 0;
for i in 0..values_per_pack {
let k_idx = pack * values_per_pack + i;
if k_idx < k {
let val = quant_values[col * k + k_idx] as u32 & 0xF;
packed |= val << (4 * i);
}
}
slice[col * packs_per_row + pack] = packed;
}
}
}
buf
}
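    /// Packs `[n, k]` 6-bit values: every 4 values form 24 bits, stored as a
    /// little-endian 3-byte triplet.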
fn pack_6bit_buffer(device: &MlxDevice, n: usize, k: usize, quant_values: &[u8]) -> MlxBuffer {
let triplets_per_row = (k + 3) / 4;
let row_bytes = triplets_per_row * 3;
let total_bytes = n * row_bytes;
let mut buf = device.alloc_buffer(total_bytes, DType::U8, vec![total_bytes]).expect("alloc");
{
let slice: &mut [u8] = buf.as_mut_slice().expect("as_mut_slice");
for col in 0..n {
for t in 0..triplets_per_row {
let mut packed: u32 = 0;
for i in 0..4 {
let k_idx = t * 4 + i;
if k_idx < k {
let val = quant_values[col * k + k_idx] as u32 & 0x3F;
packed |= val << (6 * i);
}
}
let base = col * row_bytes + t * 3;
slice[base] = (packed & 0xFF) as u8;
slice[base + 1] = ((packed >> 8) & 0xFF) as u8;
slice[base + 2] = ((packed >> 16) & 0xFF) as u8;
}
}
}
buf
}
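    /// Packs `[n, k]` 8-bit values row-major, 4 values per u32, lowest byte
    /// first.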
fn pack_8bit_buffer(device: &MlxDevice, n: usize, k: usize, quant_values: &[u8]) -> MlxBuffer {
let values_per_pack = 4;
let packs_per_row = (k + values_per_pack - 1) / values_per_pack;
let total_packs = n * packs_per_row;
let byte_len = total_packs * 4;
let mut buf = device.alloc_buffer(byte_len, DType::U32, vec![n, packs_per_row]).expect("alloc");
{
let slice: &mut [u32] = buf.as_mut_slice().expect("as_mut_slice");
for col in 0..n {
for pack in 0..packs_per_row {
let mut packed: u32 = 0;
for i in 0..values_per_pack {
let k_idx = pack * values_per_pack + i;
if k_idx < k {
let val = quant_values[col * k + k_idx] as u32 & 0xFF;
packed |= val << (8 * i);
}
}
slice[col * packs_per_row + pack] = packed;
}
}
}
buf
}
#[allow(dead_code)]
fn read_f16(buf: &MlxBuffer) -> Vec<f32> {
let slice: &[u16] = buf.as_slice().expect("as_slice");
slice.iter().map(|&bits| f16_bits_to_f32(bits)).collect()
}
fn read_f32(buf: &MlxBuffer) -> Vec<f32> {
let slice: &[f32] = buf.as_slice().expect("as_slice");
slice.to_vec()
}
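    /// CPU reference for the dequantization scheme the kernels above assume
    /// (each weight dequantizes as `q * scale + bias`, one (scale, bias) pair
    /// per `group_size` values along K). A minimal illustrative sketch, not
    /// called by the tests below, which check hand-computed values instead.
    #[allow(dead_code)]
    fn reference_matmul(
        input: &[f32],  // [m, k], row-major
        quant_w: &[u8], // [n, k], unpacked quantized values
        scales: &[f32], // [n, ceil(k / group_size)]
        biases: &[f32], // [n, ceil(k / group_size)]
        m: usize,
        k: usize,
        n: usize,
        group_size: usize,
    ) -> Vec<f32> {
        let num_groups = (k + group_size - 1) / group_size;
        let mut out = vec![0.0f32; m * n];
        for row in 0..m {
            for col in 0..n {
                let mut acc = 0.0f32;
                for i in 0..k {
                    let g = i / group_size;
                    // Dequantize one weight, then accumulate into the dot product.
                    let w = quant_w[col * k + i] as f32 * scales[col * num_groups + g]
                        + biases[col * num_groups + g];
                    acc += input[row * k + i] * w;
                }
                out[row * n + col] = acc;
            }
        }
        out
    }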
#[test]
fn test_4bit_matmul_small_known() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let mut encoder = device.command_encoder().expect("encoder");
let m = 1u32;
let k = 4u32;
let n = 2u32;
let group_size = 64u32;
let bits = 4u32;
let input = f32_buffer(&device, vec![m as usize, k as usize], &[1.0, 2.0, 3.0, 4.0]);
let quant_w: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
let weight = pack_4bit_buffer(&device, n as usize, k as usize, &quant_w);
let scales = bf16_buffer(&device, vec![n as usize, 1], &[0.1, 0.2]);
let biases = bf16_buffer(&device, vec![n as usize, 1], &[0.0, 0.0]);
let params = QuantizedMatmulParams { m, k, n, group_size, bits };
let output = quantized_matmul(
&mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
).expect("quantized_matmul");
encoder.commit_and_wait().expect("commit");
let result = read_f32(&output);
assert_eq!(result.len(), 2);
        let tol = 1e-1;
        assert!(
(result[0] - 3.0).abs() < tol,
"output[0]={}, expected ~3.0", result[0]
);
assert!(
(result[1] - 14.0).abs() < tol,
"output[1]={}, expected ~14.0", result[1]
);
}
#[test]
fn test_6bit_matmul_small_known() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let mut encoder = device.command_encoder().expect("encoder");
let m = 1u32;
let k = 4u32;
let n = 2u32;
let group_size = 64u32;
let bits = 6u32;
let input = f32_buffer(&device, vec![m as usize, k as usize], &[1.0, 2.0, 3.0, 4.0]);
let quant_w: Vec<u8> = vec![1, 2, 3, 4, 10, 20, 30, 40];
let weight = pack_6bit_buffer(&device, n as usize, k as usize, &quant_w);
let scales = bf16_buffer(&device, vec![n as usize, 1], &[0.1, 0.05]);
let biases = bf16_buffer(&device, vec![n as usize, 1], &[0.0, 0.0]);
let params = QuantizedMatmulParams { m, k, n, group_size, bits };
let output = quantized_matmul(
&mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
).expect("quantized_matmul");
encoder.commit_and_wait().expect("commit");
let result = read_f32(&output);
assert_eq!(result.len(), 2);
let tol = 1e-1;
assert!(
(result[0] - 3.0).abs() < tol,
"output[0]={}, expected ~3.0", result[0]
);
assert!(
(result[1] - 15.0).abs() < tol,
"output[1]={}, expected ~15.0", result[1]
);
}
#[test]
fn test_4bit_matmul_with_bias() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let mut encoder = device.command_encoder().expect("encoder");
let m = 1u32;
let k = 4u32;
let n = 1u32;
let group_size = 64u32;
let bits = 4u32;
let input = f32_buffer(&device, vec![1, 4], &[1.0, 1.0, 1.0, 1.0]);
let quant_w: Vec<u8> = vec![0, 0, 0, 0];
let weight = pack_4bit_buffer(&device, 1, 4, &quant_w);
let scales = bf16_buffer(&device, vec![1, 1], &[1.0]);
let biases = bf16_buffer(&device, vec![1, 1], &[0.5]);
let params = QuantizedMatmulParams { m, k, n, group_size, bits };
let output = quantized_matmul(
&mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
).expect("quantized_matmul");
encoder.commit_and_wait().expect("commit");
let result = read_f32(&output);
let tol = 1e-2;
assert!(
(result[0] - 2.0).abs() < tol,
"output[0]={}, expected ~2.0", result[0]
);
}
#[test]
fn test_4bit_batch_matmul() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let mut encoder = device.command_encoder().expect("encoder");
let m = 2u32;
let k = 4u32;
let n = 1u32;
let group_size = 64u32;
let bits = 4u32;
let input = f32_buffer(&device, vec![2, 4], &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
let quant_w: Vec<u8> = vec![2, 4, 6, 8];
let weight = pack_4bit_buffer(&device, 1, 4, &quant_w);
let scales = bf16_buffer(&device, vec![1, 1], &[0.5]);
let biases = bf16_buffer(&device, vec![1, 1], &[0.0]);
let params = QuantizedMatmulParams { m, k, n, group_size, bits };
let output = quantized_matmul(
&mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
).expect("quantized_matmul");
encoder.commit_and_wait().expect("commit");
let result = read_f32(&output);
assert_eq!(result.len(), 2);
let tol = 1e-2;
assert!((result[0] - 1.0).abs() < tol, "row0={}, expected 1.0", result[0]);
assert!((result[1] - 2.0).abs() < tol, "row1={}, expected 2.0", result[1]);
}
#[test]
fn test_invalid_bits_returns_error() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let mut encoder = device.command_encoder().expect("encoder");
let input = f32_buffer(&device, vec![1, 4], &[1.0; 4]);
let weight = device.alloc_buffer(4, DType::U32, vec![1]).expect("alloc");
let scales = bf16_buffer(&device, vec![1], &[1.0]);
let biases = bf16_buffer(&device, vec![1], &[0.0]);
let params = QuantizedMatmulParams {
m: 1, k: 4, n: 1, group_size: 64, bits: 5,
};
let result = quantized_matmul(
&mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
);
assert!(result.is_err());
match result {
Err(MlxError::InvalidArgument(msg)) => {
assert!(msg.contains("bits"), "Error should mention bits: {msg}");
}
other => panic!("Expected InvalidArgument, got {:?}", other),
}
}
#[test]
fn test_mismatched_dimensions_returns_error() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let mut encoder = device.command_encoder().expect("encoder");
let input = f32_buffer(&device, vec![1, 4], &[1.0; 4]);
let weight = device.alloc_buffer(4, DType::U32, vec![1]).expect("alloc");
let scales = bf16_buffer(&device, vec![1], &[1.0]);
let biases = bf16_buffer(&device, vec![1], &[0.0]);
let params = QuantizedMatmulParams {
m: 1, k: 128, n: 1, group_size: 64, bits: 4,
};
let result = quantized_matmul(
&mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
);
assert!(result.is_err());
match result {
Err(MlxError::InvalidArgument(msg)) => {
assert!(msg.contains("Input buffer too small"), "msg: {msg}");
}
other => panic!("Expected InvalidArgument for input size, got {:?}", other),
}
}
#[test]
fn test_8bit_matmul_small_known() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let mut encoder = device.command_encoder().expect("encoder");
let m = 1u32;
let k = 4u32;
let n = 2u32;
let group_size = 64u32;
let bits = 8u32;
let input = f32_buffer(&device, vec![m as usize, k as usize], &[1.0, 2.0, 3.0, 4.0]);
let quant_w: Vec<u8> = vec![10, 20, 30, 40, 50, 60, 70, 80];
let weight = pack_8bit_buffer(&device, n as usize, k as usize, &quant_w);
let scales = bf16_buffer(&device, vec![n as usize, 1], &[0.01, 0.02]);
let biases = bf16_buffer(&device, vec![n as usize, 1], &[0.0, 0.0]);
let params = QuantizedMatmulParams { m, k, n, group_size, bits };
let output = quantized_matmul(
&mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
).expect("quantized_matmul");
encoder.commit_and_wait().expect("commit");
let result = read_f32(&output);
assert_eq!(result.len(), 2);
let tol = 1e-1;
assert!(
(result[0] - 3.0).abs() < tol,
"output[0]={}, expected ~3.0", result[0]
);
assert!(
(result[1] - 14.0).abs() < tol,
"output[1]={}, expected ~14.0", result[1]
);
}
#[test]
fn test_8bit_matmul_with_bias() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let mut encoder = device.command_encoder().expect("encoder");
let m = 1u32;
let k = 4u32;
let n = 1u32;
let group_size = 64u32;
let bits = 8u32;
let input = f32_buffer(&device, vec![1, 4], &[1.0, 1.0, 1.0, 1.0]);
let quant_w: Vec<u8> = vec![0, 0, 0, 0];
let weight = pack_8bit_buffer(&device, 1, 4, &quant_w);
let scales = bf16_buffer(&device, vec![1, 1], &[1.0]);
let biases = bf16_buffer(&device, vec![1, 1], &[0.5]);
let params = QuantizedMatmulParams { m, k, n, group_size, bits };
let output = quantized_matmul(
&mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
).expect("quantized_matmul");
encoder.commit_and_wait().expect("commit");
let result = read_f32(&output);
let tol = 1e-2;
assert!(
(result[0] - 2.0).abs() < tol,
"output[0]={}, expected ~2.0", result[0]
);
}
#[test]
fn test_4bit_multiple_groups() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let mut encoder = device.command_encoder().expect("encoder");
let m = 1u32;
let k = 8u32;
let n = 1u32;
let group_size = 4u32;
let bits = 4u32;
let input = f32_buffer(&device, vec![1, 8], &[1.0; 8]);
let quant_w: Vec<u8> = vec![1, 1, 1, 1, 2, 2, 2, 2];
let weight = pack_4bit_buffer(&device, 1, 8, &quant_w);
let scales = bf16_buffer(&device, vec![1, 2], &[0.5, 1.0]);
let biases = bf16_buffer(&device, vec![1, 2], &[0.0, 0.0]);
let params = QuantizedMatmulParams { m, k, n, group_size, bits };
let output = quantized_matmul(
&mut encoder, &mut registry, &device,
            &input, &weight, &scales, &biases, &params,
).expect("quantized_matmul");
encoder.commit_and_wait().expect("commit");
let result = read_f32(&output);
let tol = 1e-1;
assert!(
(result[0] - 10.0).abs() < tol,
"output[0]={}, expected ~10.0", result[0]
);
}
}