use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Host-side parameters for a quantized matmul with per-row expert routing
/// (mixture-of-experts style "matmul by id"); consumed by `quantized_matmul_id`.
#[derive(Debug, Clone, Copy)]
pub struct QuantizedMatmulIdParams {
    /// Rows of the f32 input activation matrix ([m x k]).
    pub m: u32,
    /// Shared inner dimension: input columns / per-expert weight row length.
    pub k: u32,
    /// Output columns produced per expert matmul.
    pub n: u32,
    /// Number of consecutive values along `k` sharing one scale/bias group.
    pub group_size: u32,
    /// Quantization bit width; only 4, 6, and 8 are accepted.
    pub bits: u32,
    /// Experts consulted per input row (second dim of the `ids` buffer).
    pub n_expert_used: u32,
    /// Total expert matrices stored back-to-back in weight/scales/biases.
    pub num_experts: u32,
}
/// Device-side mirror of `QuantizedMatmulIdParams`, byte-copied into a u32
/// params buffer and bound at index 6 of the `quantized_matmul_id` kernel.
///
/// `#[repr(C)]` + `bytemuck::Pod` allow the struct to be written verbatim
/// into GPU memory; field order presumably must match the Metal kernel's
/// corresponding struct — NOTE(review): confirm against the shader source.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QuantizedMatmulIdGpuParams {
    m: u32,
    k: u32,
    n: u32,
    group_size: u32,
    bits: u32,
    n_expert_used: u32,
    num_experts: u32,
    // Distance in BYTES between consecutive experts' weight blocks.
    expert_weight_stride: u32,
    // Distance in ELEMENTS between consecutive experts' scale blocks.
    expert_scales_stride: u32,
    // Distance in ELEMENTS between consecutive experts' bias blocks.
    expert_biases_stride: u32,
}
/// Bytes occupied by one expert's quantized weight matrix of logical shape
/// [n x k] for the given bit width.
///
/// Packing schemes per output row of length `k`:
/// - 4-bit: 8 values per 32-bit word (4 bytes per pack).
/// - 6-bit: 4 values per 3 bytes (4 * 6 = 24 bits).
/// - 8-bit: 4 values per 32-bit word (4 bytes per pack).
///
/// Returns 0 for unsupported bit widths; callers are expected to validate
/// `bits` first (see `quantized_matmul_id`).
fn expert_weight_bytes(k: u32, n: u32, bits: u32) -> usize {
    // div_ceil avoids the wrap-around the manual `(k + d - 1) / d` ceiling
    // form hits when `k` is near u32::MAX.
    let row_bytes = match bits {
        // 8 4-bit values packed into each u32 word.
        4 => (k.div_ceil(8) as usize) * 4,
        // Every group of 4 values occupies 3 bytes (24 bits).
        6 => (k.div_ceil(4) as usize) * 3,
        // 4 8-bit values packed into each u32 word.
        8 => (k.div_ceil(4) as usize) * 4,
        _ => return 0,
    };
    (n as usize) * row_bytes
}
/// Number of scale (or bias) elements for one expert: one per
/// `group_size`-sized group along `k`, for each of the `n` output rows.
///
/// Precondition: `group_size != 0` — division by zero otherwise; callers
/// validate this before calling (see `quantized_matmul_id`).
fn expert_scales_elements(k: u32, n: u32, group_size: u32) -> usize {
    // div_ceil avoids the wrap-around of the manual `(k + g - 1) / g` form
    // when `k` is near u32::MAX.
    (n as usize) * (k.div_ceil(group_size) as usize)
}
/// Encodes a quantized matmul with per-row expert selection ("matmul by
/// id", mixture-of-experts style) onto the command encoder.
///
/// `input` is an [M x K] f32 activation matrix and `ids` an
/// [M x n_expert_used] u32 matrix of expert indices; `weight`, `scales`,
/// and `biases` hold `num_experts` quantized expert matrices laid out
/// back-to-back. Returns a freshly allocated [M x n_expert_used x N] f32
/// output buffer.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when `bits` is unsupported, any
/// dimension is zero, a per-expert stride does not fit in `u32`, or any
/// buffer is smaller than the dimensions in `params` require.
#[allow(clippy::too_many_arguments)]
pub fn quantized_matmul_id(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    ids: &MlxBuffer,
    params: &QuantizedMatmulIdParams,
) -> Result<MlxBuffer> {
    // Scales/biases are stored as 2-byte elements (presumably f16 —
    // NOTE(review): confirm against the kernel's expected dtype).
    const SCALE_ELEM_BYTES: usize = 2;

    // --- scalar parameter validation ------------------------------------
    if params.bits != 4 && params.bits != 6 && params.bits != 8 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_id: unsupported bits value {}; only 4, 6, and 8 are supported",
            params.bits
        )));
    }
    if params.m == 0 || params.k == 0 || params.n == 0 {
        return Err(MlxError::InvalidArgument(
            "quantized_matmul_id: M, K, and N must all be > 0".into(),
        ));
    }
    if params.group_size == 0 {
        return Err(MlxError::InvalidArgument(
            "quantized_matmul_id: group_size must be > 0".into(),
        ));
    }
    if params.n_expert_used == 0 {
        return Err(MlxError::InvalidArgument(
            "quantized_matmul_id: n_expert_used must be > 0".into(),
        ));
    }
    if params.num_experts == 0 {
        return Err(MlxError::InvalidArgument(
            "quantized_matmul_id: num_experts must be > 0".into(),
        ));
    }

    // --- buffer size validation ------------------------------------------
    let expected_input = (params.m as usize) * (params.k as usize) * DType::F32.size_of();
    if input.byte_len() < expected_input {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_id: input buffer too small: expected at least {} bytes for [{}x{}] f32, got {}",
            expected_input, params.m, params.k, input.byte_len()
        )));
    }
    let per_expert_w = expert_weight_bytes(params.k, params.n, params.bits);
    let total_w = per_expert_w * (params.num_experts as usize);
    if weight.byte_len() < total_w {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_id: weight buffer too small: expected at least {} bytes for {} experts, got {}",
            total_w, params.num_experts, weight.byte_len()
        )));
    }
    let per_expert_s = expert_scales_elements(params.k, params.n, params.group_size);
    let total_s_bytes = per_expert_s * (params.num_experts as usize) * SCALE_ELEM_BYTES;
    if scales.byte_len() < total_s_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_id: scales buffer too small: expected at least {} bytes, got {}",
            total_s_bytes, scales.byte_len()
        )));
    }
    if biases.byte_len() < total_s_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_id: biases buffer too small: expected at least {} bytes, got {}",
            total_s_bytes, biases.byte_len()
        )));
    }
    let expected_ids = (params.m as usize) * (params.n_expert_used as usize) * DType::U32.size_of();
    if ids.byte_len() < expected_ids {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_id: ids buffer too small: expected at least {} bytes for [{}x{}] u32, got {}",
            expected_ids, params.m, params.n_expert_used, ids.byte_len()
        )));
    }

    // Checked narrowing: a plain `as u32` would silently truncate huge
    // per-expert strides and corrupt the kernel's expert indexing.
    let expert_weight_stride = u32::try_from(per_expert_w).map_err(|_| {
        MlxError::InvalidArgument(format!(
            "quantized_matmul_id: per-expert weight stride {per_expert_w} does not fit in u32"
        ))
    })?;
    let expert_scales_stride = u32::try_from(per_expert_s).map_err(|_| {
        MlxError::InvalidArgument(format!(
            "quantized_matmul_id: per-expert scales stride {per_expert_s} does not fit in u32"
        ))
    })?;

    // --- pipeline, output, and params buffer ------------------------------
    let pipeline = registry.get_pipeline("quantized_matmul_id", device.metal_device())?;
    let output_elems = (params.m as usize) * (params.n_expert_used as usize) * (params.n as usize);
    let output_bytes = output_elems * DType::F32.size_of();
    let output = device.alloc_buffer(
        output_bytes,
        DType::F32,
        vec![
            params.m as usize,
            params.n_expert_used as usize,
            params.n as usize,
        ],
    )?;
    let gpu_params = QuantizedMatmulIdGpuParams {
        m: params.m,
        k: params.k,
        n: params.n,
        group_size: params.group_size,
        bits: params.bits,
        n_expert_used: params.n_expert_used,
        num_experts: params.num_experts,
        expert_weight_stride,
        expert_scales_stride,
        // Biases share the scales layout, so the stride is identical.
        expert_biases_stride: expert_scales_stride,
    };
    let params_bytes = std::mem::size_of::<QuantizedMatmulIdGpuParams>();
    // The GPU params struct is 10 u32 fields, hence the [10] u32 shape.
    let mut params_buf = device.alloc_buffer(params_bytes, DType::U32, vec![10])?;
    {
        // NOTE(review): cast_slice_mut panics unless the byte slice is
        // 4-byte aligned; device allocations are presumed aligned.
        let slice: &mut [QuantizedMatmulIdGpuParams] = bytemuck::cast_slice_mut(
            params_buf
                .as_mut_slice::<u8>()
                .map_err(|e| MlxError::InvalidArgument(format!("params buf write: {e}")))?,
        );
        slice[0] = gpu_params;
    }

    // --- dispatch ---------------------------------------------------------
    // One thread per (output column, flattened routed row): each of the
    // m * n_expert_used routed rows yields an independent output row of N.
    let total_rows = (params.m as u64) * (params.n_expert_used as u64);
    let tg_x = 16u64.min(params.n as u64);
    let tg_y = 16u64.min(total_rows);
    let threadgroup_size = metal::MTLSize::new(tg_x, tg_y, 1);
    // Enough threadgroups to cover all N columns and total_rows rows.
    let grid_groups = metal::MTLSize::new(
        (params.n as u64).div_ceil(tg_x),
        total_rows.div_ceil(tg_y),
        1,
    );
    encoder.encode_threadgroups(
        pipeline,
        &[
            (0, input),
            (1, weight),
            (2, scales),
            (3, biases),
            (4, ids),
            (5, &output),
            (6, &params_buf),
        ],
        grid_groups,
        threadgroup_size,
    );
    Ok(output)
}