use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::ops::quantized_matmul_ggml::GgmlType;
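
/// GPU-side parameter block for the `kernel_mul_mv_id_*` Metal kernels.
///
/// Field names follow ggml's convention: `ne0x` are weight (src0) extents,
/// `ne1x` are input (src1) extents, and `ne0`/`ne1` are output extents. The
/// layout must match the parameter struct declared in the Metal shader.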
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlMatvecIdGpuParams {
    ne00: i64,          // weight row length (K)
    ne01: i64,          // rows per expert matrix (N)
    ne02: i64,          // weight batch extent (1 here)
    ne10: i64,          // input row length (K)
    ne12: i64,          // input batch extent (1 here)
    ne0: i64,           // output row width (N)
    ne1: i64,           // total output rows (n_tokens * top_k)
    r2: u32,            // broadcast ratios (1 here)
    r3: u32,
    top_k: u32,         // experts routed per token
    n_tokens: u32,      // input rows
    expert_stride: i64, // bytes between consecutive expert weight slabs
}
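
/// Host-side launch parameters for an expert-routed (mixture-of-experts)
/// quantized matmul. Each of the `n_tokens` input rows is multiplied by the
/// `top_k` experts selected for it in the `ids` buffer; expert `i`'s weights
/// begin `i * expert_stride` bytes into the weight buffer.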
#[derive(Debug, Clone, Copy)]
pub struct GgmlQuantizedMatmulIdParams {
pub n_tokens: u32,
pub top_k: u32,
pub n: u32,
pub k: u32,
pub n_experts: u32,
pub expert_stride: u64,
pub ggml_type: GgmlType,
}

impl GgmlType {
    /// Metal kernel name for the indirect (expert-routed) mat-vec variant of
    /// this quantization type, or `None` if this path does not support it.
    fn id_kernel_name(self) -> Option<&'static str> {
        match self {
            GgmlType::Q4_0 => Some("kernel_mul_mv_id_q4_0_f32"),
            GgmlType::Q8_0 => Some("kernel_mul_mv_id_q8_0_f32"),
            GgmlType::Q6_K => Some("kernel_mul_mv_id_q6_K_f32"),
            GgmlType::F32 | GgmlType::F16 | GgmlType::Q4_K => None,
        }
    }
}
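
/// Encodes an expert-routed (mixture-of-experts) quantized matmul.
///
/// For each of `params.n_tokens` input rows, the `params.top_k` expert
/// indices stored in `ids` select which expert weight slab multiplies that
/// row, producing `n_tokens * top_k` output rows of width `params.n`.
///
/// Expected buffer layouts:
/// - `input`:  `[n_tokens, k]` f32
/// - `weight`: `n_experts` slabs of GGML blocks, `expert_stride` bytes apart
/// - `ids`:    `[n_tokens, top_k]` u32 expert indices
/// - `output`: `[n_tokens * top_k, n]` f32
///
/// A sketch of a call site (marked `ignore`; the surrounding encoder/buffer
/// setup is elided and the concrete sizes are illustrative only):
///
/// ```ignore
/// let params = GgmlQuantizedMatmulIdParams {
///     n_tokens: 4,               // rows in `input`
///     top_k: 2,                  // experts routed per token
///     n: 4096,                   // output width per expert matmul
///     k: 14336,                  // input width; must be a multiple of QK
///     n_experts: 8,
///     expert_stride,             // byte distance between expert slabs
///     ggml_type: GgmlType::Q4_0,
/// };
/// quantized_matmul_id_ggml(
///     &mut encoder, &mut registry, &device,
///     &input, &weight, &ids, &mut output, &params,
/// )?;
/// ```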
#[allow(clippy::too_many_arguments)]
pub fn quantized_matmul_id_ggml(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
ids: &MlxBuffer,
output: &mut MlxBuffer,
params: &GgmlQuantizedMatmulIdParams,
) -> Result<()> {
    // Values per quantization block ("QK") and bytes per encoded block.
    let qk = params.ggml_type.block_values();
    let block_bytes = params.ggml_type.block_bytes();
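
    // Validate shapes and buffer sizes up front so we never encode a dispatch
    // that could read or write out of bounds.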
if params.n_tokens == 0 || params.k == 0 || params.n == 0 {
return Err(MlxError::InvalidArgument(
"quantized_matmul_id_ggml: n_tokens, K, and N must all be > 0".into(),
));
}
if params.top_k == 0 {
return Err(MlxError::InvalidArgument(
"quantized_matmul_id_ggml: top_k must be > 0".into(),
));
}
if params.n_experts == 0 {
return Err(MlxError::InvalidArgument(
"quantized_matmul_id_ggml: n_experts must be > 0".into(),
));
}
if params.k % qk != 0 {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_ggml: K ({}) must be divisible by block QK ({})",
params.k, qk
)));
}
let expected_input_bytes =
(params.n_tokens as usize) * (params.k as usize) * DType::F32.size_of();
if input.byte_len() < expected_input_bytes {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_ggml: input buffer too small: expected {} bytes for [{} x {}] f32, got {}",
expected_input_bytes, params.n_tokens, params.k, input.byte_len()
)));
}
let blocks_per_row = params.k / qk;
let per_expert_bytes =
(params.n as usize) * (blocks_per_row as usize) * (block_bytes as usize);
if params.expert_stride < per_expert_bytes as u64 {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_ggml: expert_stride ({}) < per_expert_bytes ({})",
params.expert_stride, per_expert_bytes
)));
}
let total_weight_bytes = per_expert_bytes * (params.n_experts as usize);
if weight.byte_len() < total_weight_bytes {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_ggml: weight buffer too small: expected {} bytes for {} experts, got {}",
total_weight_bytes, params.n_experts, weight.byte_len()
)));
}
let total_rows = (params.n_tokens as usize) * (params.top_k as usize);
let expected_ids_bytes = total_rows * DType::U32.size_of();
if ids.byte_len() < expected_ids_bytes {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_ggml: ids buffer too small: expected {} bytes for [{} * {}] u32, got {}",
expected_ids_bytes, params.n_tokens, params.top_k, ids.byte_len()
)));
}
let expected_output_bytes = total_rows * (params.n as usize) * DType::F32.size_of();
if output.byte_len() < expected_output_bytes {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_ggml: output buffer too small: expected {} bytes for [{} x {}] f32, got {}",
expected_output_bytes, total_rows, params.n, output.byte_len()
)));
}
    // Resolve the kernel for this quantization type; unsupported types fail
    // here instead of surfacing as a missing pipeline.
    let kernel_name = params.ggml_type.id_kernel_name().ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "quantized_matmul_id_ggml does not support {:?}",
            params.ggml_type
        ))
    })?;
    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
let gpu_params = GgmlMatvecIdGpuParams {
ne00: params.k as i64,
ne01: params.n as i64,
ne02: 1,
ne10: params.k as i64,
ne12: 1,
ne0: params.n as i64,
ne1: total_rows as i64,
r2: 1,
r3: 1,
top_k: params.top_k,
n_tokens: params.n_tokens,
expert_stride: params.expert_stride as i64,
};
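
    // Per-type threadgroup shape (nth0 x nth1 threads) and the number of
    // output columns along N that each threadgroup covers (`align`).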
    let (nth0, nth1, align) = match params.ggml_type {
        GgmlType::Q4_0 | GgmlType::Q8_0 => (8u64, 8u64, 8usize),
        GgmlType::Q6_K => (2u64, 32u64, 2usize),
        // Already rejected when the kernel name was resolved; this arm only
        // keeps the match exhaustive.
        GgmlType::F32 | GgmlType::F16 | GgmlType::Q4_K => {
            return Err(MlxError::InvalidArgument(format!(
                "quantized_matmul_id_ggml does not support {:?}",
                params.ggml_type
            )));
        }
    };
let n = params.n as usize;
let m = total_rows;
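    // Grid: one threadgroup per `align` output columns (x) for each of the
    // `total_rows` output rows (y).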
let threadgroups = metal::MTLSize::new(
div_ceil(n, align) as u64,
m as u64,
1,
);
let threads_per_tg = metal::MTLSize::new(nth0, nth1, 1);
encoder.encode_threadgroups_with_args(
pipeline,
&[
            (0, KernelArg::Buffer(weight)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Buffer(ids)),
            (4, KernelArg::Bytes(as_bytes(&gpu_params))),
],
threadgroups,
threads_per_tg,
);
Ok(())
}

/// Integer ceiling division; equivalent to `usize::div_ceil`, which is only
/// stable since Rust 1.73.
fn div_ceil(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}