use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
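
// GGML block layout constants: each quantized block stores `QK*` values in
// `BLOCK_*_BYTES` bytes (f16 scale data plus the packed quants).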
const QK4_0: u32 = 32;
const BLOCK_Q4_0_BYTES: u32 = 18; // f16 scale (2) + 32 packed 4-bit values (16)
const QK8_0: u32 = 32;
const BLOCK_Q8_0_BYTES: u32 = 34; // f16 scale (2) + 32 int8 values (32)
const QK4_K: u32 = 256;
const BLOCK_Q4_K_BYTES: u32 = 144; // 2 f16 scales (4) + 6-bit scales (12) + 4-bit quants (128)
const QK6_K: u32 = 256;
const BLOCK_Q6_K_BYTES: u32 = 210; // 4-bit low (128) + 2-bit high (64) + scales (16) + f16 d (2)
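
/// GGML tensor/quantization formats understood by this module.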
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(non_camel_case_types)]
pub enum GgmlType {
    F32,
    F16,
    Q4_0,
    Q8_0,
    Q4_K,
    Q6_K,
}

impl GgmlType {
    /// Number of values stored in one block of this type.
    pub fn block_values(self) -> u32 {
        match self {
            GgmlType::F32 => 1,
            GgmlType::F16 => 1,
            GgmlType::Q4_0 => QK4_0,
            GgmlType::Q8_0 => QK8_0,
            GgmlType::Q4_K => QK4_K,
            GgmlType::Q6_K => QK6_K,
        }
    }

    /// Size in bytes of one block of this type.
    pub fn block_bytes(self) -> u32 {
        match self {
            GgmlType::F32 => 4,
            GgmlType::F16 => 2,
            GgmlType::Q4_0 => BLOCK_Q4_0_BYTES,
            GgmlType::Q8_0 => BLOCK_Q8_0_BYTES,
            GgmlType::Q4_K => BLOCK_Q4_K_BYTES,
            GgmlType::Q6_K => BLOCK_Q6_K_BYTES,
        }
    }

    /// Name of the fused Metal mat-vec kernel for this type. Types without
    /// one return a placeholder; `quantized_matmul_ggml` rejects them before
    /// this name is ever looked up.
    fn kernel_name(self) -> &'static str {
        match self {
            GgmlType::F32 | GgmlType::F16 | GgmlType::Q4_K => "unsupported",
            GgmlType::Q4_0 => "kernel_mul_mv_q4_0_f32",
            GgmlType::Q8_0 => "kernel_mul_mv_q8_0_f32",
            GgmlType::Q6_K => "kernel_mul_mv_q6_K_f32",
        }
    }
}
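
/// Dimensions of a quantized matmul over GGML-packed weights: `input` is
/// `M x K` f32, `weight` is `N x K` in GGML blocks, `output` is `M x N` f32.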
#[derive(Debug, Clone, Copy)]
pub struct GgmlQuantizedMatmulParams {
    /// Number of input rows (M).
    pub m: u32,
    /// Number of output features, i.e. weight rows (N).
    pub n: u32,
    /// Reduction dimension (K); must be a multiple of the block size.
    pub k: u32,
    /// Quantization format of the weight buffer.
    pub ggml_type: GgmlType,
}
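
// Kernel-side parameter block; field names mirror the `ne*`/`r*` dimension
// and broadcast-ratio naming used by the ggml Metal mat-vec kernels.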
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlMatvecGpuParams {
    ne00: i64, // K: values per weight row
    ne01: i64, // N: number of weight rows
    ne02: i64, // weight batch dim (1 here)
    ne10: i64, // K: values per input row
    ne12: i64, // input batch dim (1 here)
    ne0: i64,  // N: output row length
    ne1: i64,  // M: number of output rows
    r2: u32,   // broadcast ratio (1 here)
    r3: u32,   // broadcast ratio (1 here)
}
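
/// Encodes a quantized matrix multiply over GGML-packed weights, dispatching
/// the fused Metal mat-vec kernel for the weight's [`GgmlType`].
///
/// `weight` must hold at least `n * k / block_values()` blocks of
/// `block_bytes()` bytes, `input` at least `m * k` f32 values, and `output`
/// at least `m * n` f32 values. Only `Q4_0`, `Q8_0`, and `Q6_K` are accepted;
/// other types return `MlxError::InvalidArgument`.
///
/// A minimal usage sketch (not compiled here; the encoder, registry, device,
/// and buffer setup is application-specific and omitted):
///
/// ```ignore
/// let params = GgmlQuantizedMatmulParams {
///     m: 1,                      // a single input row, i.e. a mat-vec
///     n: 4096,                   // output features (weight rows)
///     k: 4096,                   // reduction dim, a multiple of the 32-value block
///     ggml_type: GgmlType::Q8_0,
/// };
/// quantized_matmul_ggml(&mut encoder, &mut registry, &device, &input, &weight, &mut output, &params)?;
/// ```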
pub fn quantized_matmul_ggml(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &mut MlxBuffer,
    params: &GgmlQuantizedMatmulParams,
) -> Result<()> {
    let qk = params.ggml_type.block_values();
    let block_bytes = params.ggml_type.block_bytes();

    // Only types with a fused Metal mat-vec kernel are handled here.
    match params.ggml_type {
        GgmlType::Q4_0 | GgmlType::Q8_0 | GgmlType::Q6_K => {}
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "quantized_matmul_ggml does not support {:?}; use a different dispatch path",
                other
            )));
        }
    }
    if params.m == 0 || params.k == 0 || params.n == 0 {
        return Err(MlxError::InvalidArgument(
            "M, K, and N must all be > 0".into(),
        ));
    }
    if params.k % qk != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "K ({}) must be divisible by block QK ({})",
            params.k, qk
        )));
    }

    // Validate buffer sizes against the declared shape before encoding anything.
    let blocks_per_row = params.k / qk;
    let expected_weight_bytes =
        (params.n as usize) * (blocks_per_row as usize) * (block_bytes as usize);
    if weight.byte_len() < expected_weight_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "Weight buffer too small: expected {} bytes for {:?} [{}x{}], got {}",
            expected_weight_bytes,
            params.ggml_type,
            params.n,
            params.k,
            weight.byte_len()
        )));
    }
    let expected_input_bytes =
        (params.m as usize) * (params.k as usize) * DType::F32.size_of();
    if input.byte_len() < expected_input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "Input buffer too small: expected {} bytes for [{}x{}] f32, got {}",
            expected_input_bytes, params.m, params.k, input.byte_len()
        )));
    }
    let expected_output_bytes =
        (params.m as usize) * (params.n as usize) * DType::F32.size_of();
    if output.byte_len() < expected_output_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "Output buffer too small: expected {} bytes for [{}x{}] f32, got {}",
            expected_output_bytes, params.m, params.n, output.byte_len()
        )));
    }

    let kernel_name = params.ggml_type.kernel_name();
    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;

    // ggml dimension naming: ne00/ne10 are the reduction dim (K), ne01/ne0 the
    // output features (N), ne1 the number of input rows (M). Batch dims and
    // broadcast ratios are fixed to 1 for a plain 2D matmul.
    let gpu_params = GgmlMatvecGpuParams {
        ne00: params.k as i64,
        ne01: params.n as i64,
        ne02: 1,
        ne10: params.k as i64,
        ne12: 1,
        ne0: params.n as i64,
        ne1: params.m as i64,
        r2: 1,
        r3: 1,
    };
    let n = params.n as usize;
    let m = params.m as usize;

    match params.ggml_type {
        GgmlType::Q8_0 => {
            // Simdgroup-based dispatch: `nsg` simdgroups of 32 threads per
            // threadgroup, each threadgroup covering `nr0` weight rows
            // (output features), with `nr0 * 32` f32 values of threadgroup
            // memory allocated for the kernel.
            let nsg: usize = 4;
            let nr0: usize = 2;
            let align = nr0;
            let smem_bytes = nr0 * 32 * std::mem::size_of::<f32>();
            let threadgroups = metal::MTLSize::new(
                div_ceil(n, align) as u64,
                m as u64,
                1,
            );
            let threads_per_tg = metal::MTLSize::new(32, nsg as u64, 1);
            encoder.encode_threadgroups_with_args_and_shared(
                pipeline,
                &[
                    (0, KernelArg::Buffer(weight)),
                    (1, KernelArg::Buffer(input)),
                    (2, KernelArg::Buffer(output)),
                    (3, KernelArg::Bytes(as_bytes(&gpu_params))),
                ],
                &[(0, smem_bytes as u64)],
                threadgroups,
                threads_per_tg,
            );
        }
        _ => {
            // Q4_0 and Q6_K use an (nth0 x nth1)-thread threadgroup and cover
            // `align` weight rows per threadgroup along x.
            let (nth0, nth1, align) = match params.ggml_type {
                GgmlType::Q4_0 => (8u64, 8u64, 8usize),
                GgmlType::Q6_K => (2u64, 32u64, 2usize),
                _ => unreachable!(),
            };
            let threadgroups = metal::MTLSize::new(
                div_ceil(n, align) as u64,
                m as u64,
                1,
            );
            let threads_per_tg = metal::MTLSize::new(nth0, nth1, 1);
            encoder.encode_threadgroups_with_args(
                pipeline,
                &[
                    (0, KernelArg::Buffer(weight)),
                    (1, KernelArg::Buffer(input)),
                    (2, KernelArg::Buffer(output)),
                    (3, KernelArg::Bytes(as_bytes(&gpu_params))),
                ],
                threadgroups,
                threads_per_tg,
            );
        }
    }

    Ok(())
}

/// Integer ceiling division.
fn div_ceil(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}