use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
// Per-format block geometry, consumed by `GgmlType::block_values` and
// `GgmlType::block_bytes` below.
//
// Each `QK*` constant is the number of values along K covered by one block of the
// format; each `BLOCK_*_BYTES` constant is the size in bytes of one such block.

// Q4_0
const QK4_0: u32 = 32;
const BLOCK_Q4_0_BYTES: u32 = 18;
// Q8_0
const QK8_0: u32 = 32;
const BLOCK_Q8_0_BYTES: u32 = 34;
// Q4_K
const QK4_K: u32 = 256;
const BLOCK_Q4_K_BYTES: u32 = 144;
// Q5_K
const QK5_K: u32 = 256;
const BLOCK_Q5_K_BYTES: u32 = 176;
// Q6_K
const QK6_K: u32 = 256;
const BLOCK_Q6_K_BYTES: u32 = 210;
// Q5_1
const QK5_1: u32 = 32;
const BLOCK_Q5_1_BYTES: u32 = 24;
// IQ4_NL
const QK4_NL: u32 = 32;
const BLOCK_IQ4_NL_BYTES: u32 = 18;
/// Storage formats known to this module.
///
/// Only the block-quantized variants (`Q4_0`, `Q8_0`, `Q4_K`, `Q5_K`, `Q6_K`, `Q5_1`,
/// `IQ4_NL`) are accepted by `quantized_matmul_ggml`. `F32`, `F16` and `I16` are
/// rejected there, and the kernel-name lookups return `"unsupported"` for them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// Variant names such as `Q4_0` contain underscores, which the lint would flag.
#[allow(non_camel_case_types)]
pub enum GgmlType {
/// One value per "block", 4 bytes per block.
F32,
/// One value per "block", 2 bytes per block.
F16,
/// 32 values per block, 18 bytes per block.
Q4_0,
/// 32 values per block, 34 bytes per block.
Q8_0,
/// 256 values per block, 144 bytes per block.
Q4_K,
/// 256 values per block, 176 bytes per block.
Q5_K,
/// 256 values per block, 210 bytes per block.
Q6_K,
/// One value per "block", 2 bytes per block.
I16,
/// 32 values per block, 24 bytes per block.
Q5_1,
/// 32 values per block, 18 bytes per block.
IQ4_NL,
}
impl GgmlType {
pub fn block_values(self) -> u32 {
match self {
GgmlType::F32 => 1,
GgmlType::F16 => 1,
GgmlType::Q4_0 => QK4_0,
GgmlType::Q8_0 => QK8_0,
GgmlType::Q4_K => QK4_K,
GgmlType::Q5_K => QK5_K,
GgmlType::Q6_K => QK6_K,
GgmlType::I16 => 1,
GgmlType::Q5_1 => QK5_1,
GgmlType::IQ4_NL => QK4_NL,
}
}
pub fn block_bytes(self) -> u32 {
match self {
GgmlType::F32 => 4,
GgmlType::F16 => 2,
GgmlType::Q4_0 => BLOCK_Q4_0_BYTES,
GgmlType::Q8_0 => BLOCK_Q8_0_BYTES,
GgmlType::Q4_K => BLOCK_Q4_K_BYTES,
GgmlType::Q5_K => BLOCK_Q5_K_BYTES,
GgmlType::Q6_K => BLOCK_Q6_K_BYTES,
GgmlType::I16 => 2,
GgmlType::Q5_1 => BLOCK_Q5_1_BYTES,
GgmlType::IQ4_NL => BLOCK_IQ4_NL_BYTES,
}
}
fn kernel_name(self) -> &'static str {
match self {
GgmlType::F32 | GgmlType::F16 | GgmlType::I16 => "unsupported",
GgmlType::Q4_0 => "kernel_mul_mv_q4_0_f32",
GgmlType::Q8_0 => "kernel_mul_mv_q8_0_f32",
GgmlType::Q4_K => "kernel_mul_mv_q4_K_f32",
GgmlType::Q5_K => "kernel_mul_mv_q5_K_f32",
GgmlType::Q6_K => "kernel_mul_mv_q6_K_f32",
GgmlType::Q5_1 => "kernel_mul_mv_q5_1_f32",
GgmlType::IQ4_NL => "kernel_mul_mv_iq4_nl_f32",
}
}
fn mm_kernel_name(self) -> &'static str {
match self {
GgmlType::F32
| GgmlType::F16
| GgmlType::I16 => "unsupported",
GgmlType::Q4_0 => "kernel_mul_mm_q4_0_f32",
GgmlType::Q8_0 => "kernel_mul_mm_q8_0_f32",
GgmlType::Q4_K => "kernel_mul_mm_q4_K_f32",
GgmlType::Q5_K => "kernel_mul_mm_q5_K_f32",
GgmlType::Q6_K => "kernel_mul_mm_q6_K_f32",
GgmlType::Q5_1 => "kernel_mul_mm_q5_1_f32",
GgmlType::IQ4_NL => "kernel_mul_mm_iq4_nl_f32",
}
}
fn mm_tensor_kernel_name(self) -> &'static str {
match self {
GgmlType::F32
| GgmlType::F16
| GgmlType::I16 => "unsupported",
GgmlType::Q4_0 => "kernel_mul_mm_q4_0_tensor_f32",
GgmlType::Q8_0 => "kernel_mul_mm_q8_0_tensor_f32",
GgmlType::Q4_K => "kernel_mul_mm_q4_K_tensor_f32",
GgmlType::Q5_K => "kernel_mul_mm_q5_K_tensor_f32",
GgmlType::Q6_K => "kernel_mul_mm_q6_K_tensor_f32",
GgmlType::Q5_1 => "kernel_mul_mm_q5_1_tensor_f32",
GgmlType::IQ4_NL => "kernel_mul_mm_iq4_nl_tensor_f32",
}
}
}
/// Cached outcome of the tensor-kernel probe; written at most once per process.
static TENSOR_MM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Reports whether the `_tensor` matrix-matrix kernels can be used.
///
/// The first call tries to obtain the `kernel_mul_mm_q4_0_tensor_f32` pipeline from
/// `registry`; success means "available". That answer is stored in
/// `TENSOR_MM_AVAILABLE` and returned by every later call without probing again, so the
/// `registry`/`device` of later calls are not consulted.
///
/// When the `MLX_LOG_TENSOR_PROBE` environment variable is set, the probe outcome is
/// printed to stderr once, at probe time.
fn probe_tensor_mm(registry: &mut KernelRegistry, device: &MlxDevice) -> bool {
    let run_probe = || {
        let available = registry
            .get_pipeline("kernel_mul_mm_q4_0_tensor_f32", device.metal_device())
            .is_ok();

        if std::env::var("MLX_LOG_TENSOR_PROBE").is_ok() {
            let verdict = if available {
                "OK (using tensor variant)"
            } else {
                "FAILED (falling back to simdgroup MMA)"
            };
            eprintln!("[mlx-native] tensor_mm probe: {}", verdict);
        }

        available
    };

    *TENSOR_MM_AVAILABLE.get_or_init(run_probe)
}
/// Routing cutoff on `m` used by `quantized_matmul_ggml`: calls with
/// `m > MM_ROUTING_THRESHOLD` are encoded with the matrix-matrix kernel, all others
/// with the matrix-vector kernel.
pub const MM_ROUTING_THRESHOLD: u32 = 8;
/// Problem description for `quantized_matmul_ggml`.
///
/// The buffer-size validation in `quantized_matmul_ggml` treats the input as
/// `[m x k]` f32, the weight as `n` rows of `k / block_values` quantized blocks, and
/// the output as `[m x n]` f32.
#[derive(Debug, Clone, Copy)]
pub struct GgmlQuantizedMatmulParams {
/// Number of input rows (also the number of output rows).
pub m: u32,
/// Number of weight rows (also the number of output columns).
pub n: u32,
/// Shared inner dimension; must be a multiple of the format's block size.
pub k: u32,
/// Quantization format of the weight buffer.
pub ggml_type: GgmlType,
}
/// Argument block handed to the matrix-vector kernels as raw bytes (argument index 3
/// in `dispatch_mv`).
///
/// The struct is `#[repr(C)]` and `bytemuck::Pod` so it can be viewed as a byte slice.
/// NOTE(review): the field order and types presumably mirror a kernel-side struct that
/// is not visible here — confirm before changing the layout.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlMatvecGpuParams {
    /// Set to `k` by `dispatch_mv`.
    ne00: i64,
    /// Set to `n` by `dispatch_mv`.
    ne01: i64,
    /// Always 1 in `dispatch_mv`.
    ne02: i64,
    /// Set to `k` by `dispatch_mv`.
    ne10: i64,
    /// Always 1 in `dispatch_mv`.
    ne12: i64,
    /// Set to `n` by `dispatch_mv`.
    ne0: i64,
    /// Set to `m` by `dispatch_mv`.
    ne1: i64,
    /// Always 1 in `dispatch_mv`.
    r2: u32,
    /// Always 1 in `dispatch_mv`.
    r3: u32,
}
/// Argument block handed to the matrix-matrix kernels as raw bytes (argument index 0
/// in `dispatch_mm`).
///
/// The struct is `#[repr(C)]` and `bytemuck::Pod` so it can be viewed as a byte slice;
/// the explicit `_pad*` fields keep the layout free of implicit padding.
/// NOTE(review): the field order and types presumably mirror a kernel-side struct that
/// is not visible here — confirm before changing the layout.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlMatmulMmGpuParams {
    /// Set to `k` by `dispatch_mm`.
    ne00: i32,
    /// Always 1 in `dispatch_mm`.
    ne02: i32,
    /// Bytes per weight row: `(k / block_values) * block_bytes`.
    nb01: u64,
    /// `nb01 * n`.
    nb02: u64,
    /// Always 0 in `dispatch_mm`.
    nb03: u64,
    /// Always 1 in `dispatch_mm`.
    ne12: i32,
    /// Explicit padding; always 0.
    _pad0: u32,
    /// Size in bytes of one f32 input element.
    nb10: u64,
    /// Bytes per input row: `k * size_of(f32)`.
    nb11: u64,
    /// `nb11 * m`.
    nb12: u64,
    /// Always 0 in `dispatch_mm`.
    nb13: u64,
    /// Set to `n` by `dispatch_mm`.
    ne0: i32,
    /// Set to `m` by `dispatch_mm`.
    ne1: i32,
    /// Always 1 in `dispatch_mm`.
    r2: i16,
    /// Always 1 in `dispatch_mm`.
    r3: i16,
    /// Explicit trailing padding; always 0.
    _pad1: u32,
}
pub fn quantized_matmul_ggml(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
output: &mut MlxBuffer,
params: &GgmlQuantizedMatmulParams,
) -> Result<()> {
let qk = params.ggml_type.block_values();
let block_bytes = params.ggml_type.block_bytes();
match params.ggml_type {
GgmlType::Q4_0
| GgmlType::Q8_0
| GgmlType::Q4_K
| GgmlType::Q5_K
| GgmlType::Q6_K
| GgmlType::Q5_1
| GgmlType::IQ4_NL => {}
other => {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_ggml does not support {:?} — use a different dispatch path",
other
)));
}
}
if params.m == 0 || params.k == 0 || params.n == 0 {
return Err(MlxError::InvalidArgument(
"M, K, and N must all be > 0".into(),
));
}
if params.k % qk != 0 {
return Err(MlxError::InvalidArgument(format!(
"K ({}) must be divisible by block QK ({})",
params.k, qk
)));
}
let blocks_per_row = params.k / qk;
let expected_weight_bytes =
(params.n as usize) * (blocks_per_row as usize) * (block_bytes as usize);
if weight.byte_len() < expected_weight_bytes {
return Err(MlxError::InvalidArgument(format!(
"Weight buffer too small: expected {} bytes for {:?} [{}x{}], got {}",
expected_weight_bytes,
params.ggml_type,
params.n,
params.k,
weight.byte_len()
)));
}
let expected_input_bytes =
(params.m as usize) * (params.k as usize) * DType::F32.size_of();
if input.byte_len() < expected_input_bytes {
return Err(MlxError::InvalidArgument(format!(
"Input buffer too small: expected {} bytes for [{}x{}] f32, got {}",
expected_input_bytes, params.m, params.k, input.byte_len()
)));
}
let expected_output_bytes =
(params.m as usize) * (params.n as usize) * DType::F32.size_of();
if output.byte_len() < expected_output_bytes {
return Err(MlxError::InvalidArgument(format!(
"Output buffer too small: expected {} bytes for [{}x{}] f32, got {}",
expected_output_bytes, params.m, params.n, output.byte_len()
)));
}
let mm_supported = true;
if params.m > MM_ROUTING_THRESHOLD && params.k >= 32 && mm_supported {
dispatch_mm(encoder, registry, device, input, weight, output, params)
} else {
dispatch_mv(encoder, registry, device, input, weight, output, params)
}
}
#[doc(hidden)]
pub fn dispatch_mm_for_test(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
output: &mut MlxBuffer,
params: &GgmlQuantizedMatmulParams,
) -> Result<()> {
let qk = params.ggml_type.block_values();
match params.ggml_type {
GgmlType::Q4_0
| GgmlType::Q8_0
| GgmlType::Q4_K
| GgmlType::Q5_K
| GgmlType::Q6_K
| GgmlType::Q5_1
| GgmlType::IQ4_NL => {}
other => {
return Err(MlxError::InvalidArgument(format!(
"dispatch_mm_for_test does not support {:?}", other
)));
}
}
if params.m == 0 || params.k == 0 || params.n == 0 {
return Err(MlxError::InvalidArgument(
"M, K, and N must all be > 0".into(),
));
}
if params.k % qk != 0 {
return Err(MlxError::InvalidArgument(format!(
"K ({}) must be divisible by block QK ({})", params.k, qk
)));
}
dispatch_mm(encoder, registry, device, input, weight, output, params)
}
/// Encodes the matrix-vector kernel for `params.ggml_type`.
///
/// Performs no argument validation; callers are expected to have checked the format,
/// dimensions and buffer sizes already.
///
/// Kernel arguments: 0 = weight, 1 = input, 2 = output, 3 = `GgmlMatvecGpuParams` bytes.
/// The grid is `ceil(n / align) x m x 1` threadgroups, where `align` and the
/// threadgroup shape depend on the format.
///
/// # Errors
///
/// Propagates the error from the pipeline lookup.
///
/// # Panics
///
/// Hits `unreachable!()` if the pipeline lookup succeeds for a format that has no
/// launch geometry here (`F32`, `F16`, `I16`).
fn dispatch_mv(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &mut MlxBuffer,
    params: &GgmlQuantizedMatmulParams,
) -> Result<()> {
    let format = params.ggml_type;

    // Resolve the pipeline first so an unknown kernel surfaces as an error before the
    // geometry match below can panic.
    let pipeline = registry.get_pipeline(format.kernel_name(), device.metal_device())?;

    let depth = i64::from(params.k);
    let weight_rows = i64::from(params.n);
    let input_rows = i64::from(params.m);
    let gpu_params = GgmlMatvecGpuParams {
        ne00: depth,
        ne01: weight_rows,
        ne02: 1,
        ne10: depth,
        ne12: 1,
        ne0: weight_rows,
        ne1: input_rows,
        r2: 1,
        r3: 1,
    };

    // (threads in x, threads in y, weight rows handled per threadgroup)
    let (nth0, nth1, align): (u64, u64, usize) = match format {
        GgmlType::Q4_K | GgmlType::Q5_K | GgmlType::Q6_K => (2, 32, 2),
        GgmlType::Q4_0 | GgmlType::Q8_0 | GgmlType::Q5_1 | GgmlType::IQ4_NL => (8, 8, 8),
        _ => unreachable!(),
    };

    let groups_along_n = div_ceil(params.n as usize, align) as u64;
    let grid = metal::MTLSize::new(groups_along_n, u64::from(params.m), 1);
    let group_shape = metal::MTLSize::new(nth0, nth1, 1);

    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(weight)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        group_shape,
    );
    Ok(())
}
/// Encodes the matrix-matrix kernel for `params.ggml_type`.
///
/// Uses the `_tensor` kernel variant when `probe_tensor_mm` reports it as available and
/// the plain `mul_mm` variant otherwise. Performs no argument validation; callers are
/// expected to have checked the format, dimensions and buffer sizes already.
///
/// Kernel arguments: 0 = `GgmlMatmulMmGpuParams` bytes, 1 = weight, 2 = input,
/// 3 = output. The grid is `ceil(m / NR1) x ceil(n / NR0) x 1` threadgroups of
/// `THREADS_PER_TG` threads, with `SHMEM_BYTES` of threadgroup memory at index 0.
///
/// # Errors
///
/// Propagates the error from the pipeline lookup.
fn dispatch_mm(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &mut MlxBuffer,
    params: &GgmlQuantizedMatmulParams,
) -> Result<()> {
    // Launch geometry shared by both kernel variants.
    const NR0: u64 = 64;
    const NR1: u64 = 32;
    const THREADS_PER_TG: u64 = 128;
    const SHMEM_BYTES: u64 = 8192;

    let format = params.ggml_type;
    let kernel_name = if probe_tensor_mm(registry, device) {
        format.mm_tensor_kernel_name()
    } else {
        format.mm_kernel_name()
    };
    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;

    // Byte strides: one weight row is (k / block size) blocks, one input row is k f32s.
    let f32_bytes = DType::F32.size_of() as u64;
    let weight_row_bytes =
        u64::from(params.k / format.block_values()) * u64::from(format.block_bytes());
    let input_row_bytes = u64::from(params.k) * f32_bytes;

    let gpu_params = GgmlMatmulMmGpuParams {
        ne00: params.k as i32,
        ne02: 1,
        nb01: weight_row_bytes,
        nb02: weight_row_bytes * u64::from(params.n),
        nb03: 0,
        ne12: 1,
        _pad0: 0,
        nb10: f32_bytes,
        nb11: input_row_bytes,
        nb12: input_row_bytes * u64::from(params.m),
        nb13: 0,
        ne0: params.n as i32,
        ne1: params.m as i32,
        r2: 1,
        r3: 1,
        _pad1: 0,
    };

    // Number of tiles of width `tile` needed to cover `extent`, rounded up.
    let tiles = |extent: u32, tile: u64| (u64::from(extent) + tile - 1) / tile;
    let grid = metal::MTLSize::new(tiles(params.m, NR1), tiles(params.n, NR0), 1);
    let group_shape = metal::MTLSize::new(THREADS_PER_TG, 1, 1);

    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(weight)),
            (2, KernelArg::Buffer(input)),
            (3, KernelArg::Buffer(output)),
        ],
        &[(0, SHMEM_BYTES)],
        grid,
        group_shape,
    );
    Ok(())
}
/// Returns `a / b` rounded up to the next integer.
///
/// Computed as quotient plus one-if-remainder rather than `(a + b - 1) / b`, so the
/// result is correct for every `a` (the addition form overflows when `a` is close to
/// `usize::MAX`).
///
/// # Panics
///
/// Panics if `b` is zero.
fn div_ceil(a: usize, b: usize) -> usize {
    a / b + usize::from(a % b != 0)
}
/// Argument block handed to the `*_tensor_bf16_perm021` kernels as raw bytes (argument
/// index 0 in `quantized_matmul_mm_tensor_perm021`).
///
/// `#[repr(C)]` and `bytemuck::Pod` let the struct be viewed as a byte slice; the
/// explicit `_pad*` fields keep the layout free of implicit padding.
/// NOTE(review): the field order and types presumably mirror a kernel-side struct that
/// is not visible here — confirm before changing the layout.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlMatmulMmTensorPerm021GpuParams {
// ne00 is set to `k`; ne02 is always 1.
ne00: i32, ne02: i32,
// nb01 is the byte size of one weight row; nb02 is `nb01 * n`.
nb01: u64, nb02: u64,
// Always 0.
nb03: u64,
// Always 1.
ne12: i32,
// Explicit padding; always 0.
_pad0: u32,
// nb10 is 2 (bytes per input element); nb11 and nb12 are always 0.
nb10: u64, nb11: u64, nb12: u64,
// Always 0.
nb13: u64,
// ne0 is set to `n`, ne1 to `m`; r2 is always 1.
ne0: i32, ne1: i32, r2: i16,
// Always 1.
r3: i16,
// Copied from `GgmlQuantizedMatmulPerm021Params::head_dim`.
head_dim: i32,
// Set to `m`.
seq_len: i32,
// Explicit trailing padding; always 0.
_pad_trailing: u32,
}
/// Problem description for `quantized_matmul_mm_tensor_perm021`.
#[derive(Debug, Clone, Copy)]
pub struct GgmlQuantizedMatmulPerm021Params {
/// Passed to the kernel as both `ne1` and `seq_len`.
pub m: u32,
/// Number of weight rows; passed to the kernel as `ne0`.
pub n: u32,
/// Inner dimension; must be divisible by `head_dim`.
pub k: u32,
/// Must be a positive multiple of 32; `k / head_dim` is used as the head count when
/// sizing the input buffer.
pub head_dim: u32,
/// Quantization format of the weight buffer; only `Q4_0`, `Q8_0` and `Q6_K` are
/// accepted.
pub ggml_type: GgmlType,
}
/// Validates the arguments and encodes a `*_tensor_bf16_perm021` matrix-matrix kernel.
///
/// `input_bf16` must hold at least `(k / head_dim) * m * head_dim` two-byte elements
/// and `weight` at least `n` rows of `k / block_values` quantized blocks.
/// NOTE(review): the `perm021` name together with `head_dim`/`seq_len` suggests the
/// input is laid out `[n_heads, m, head_dim]` — confirm against the kernel source.
/// NOTE(review): the size of `output` is not validated here because the kernel's output
/// element type is not visible from this file.
///
/// Kernel arguments: 0 = `GgmlMatmulMmTensorPerm021GpuParams` bytes, 1 = weight,
/// 2 = input, 3 = output. The grid is `ceil(m / 32) x ceil(n / 64) x 1` threadgroups of
/// 128 threads with 8192 bytes of threadgroup memory at index 0.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when
/// * `ggml_type` is not `Q4_0`, `Q8_0` or `Q6_K`,
/// * `head_dim` is zero or not a multiple of 32,
/// * `k` is not divisible by `head_dim`,
/// * any of `m`, `n`, `k` is zero,
/// * `k` is not a multiple of the format's block size, or
/// * `input_bf16` or `weight` is too small for the stated shape.
///
/// Errors from the pipeline lookup are propagated unchanged.
pub fn quantized_matmul_mm_tensor_perm021(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input_bf16: &MlxBuffer,
    weight: &MlxBuffer,
    output: &mut MlxBuffer,
    params: &GgmlQuantizedMatmulPerm021Params,
) -> Result<()> {
    let kernel_name = match params.ggml_type {
        GgmlType::Q4_0 => "kernel_mul_mm_q4_0_tensor_bf16_perm021",
        GgmlType::Q8_0 => "kernel_mul_mm_q8_0_tensor_bf16_perm021",
        GgmlType::Q6_K => "kernel_mul_mm_q6_K_tensor_bf16_perm021",
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "quantized_matmul_mm_tensor_perm021: unsupported ggml_type {:?} \
                 (only Q4_0 / Q8_0 / Q6_K are instantiated)",
                other
            )));
        }
    };
    if params.head_dim == 0 || params.head_dim % 32 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: head_dim {} must be a positive \
             multiple of 32 (NK tile width)",
            params.head_dim
        )));
    }
    if params.k % params.head_dim != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: k ({}) must be divisible by \
             head_dim ({})",
            params.k, params.head_dim
        )));
    }

    // A zero dimension would produce an empty dispatch grid (m or n) or an empty
    // reduction (k, which passes the divisibility check above); reject it up front.
    if params.m == 0 || params.n == 0 || params.k == 0 {
        return Err(MlxError::InvalidArgument(
            "quantized_matmul_mm_tensor_perm021: m, n, and k must all be > 0".into(),
        ));
    }

    // `head_dim % 32 == 0` does not imply whole blocks for 256-value formats (Q6_K):
    // without this check `k / qk` below would silently truncate the row stride.
    let qk = params.ggml_type.block_values();
    if params.k % qk != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: k ({}) must be divisible by \
             block QK ({})",
            params.k, qk
        )));
    }

    let n_heads = params.k / params.head_dim;
    // Two bytes per input element.
    let expected_input_bytes = (n_heads as usize) * (params.m as usize)
        * (params.head_dim as usize) * 2;
    if input_bf16.byte_len() < expected_input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: input_bf16 buffer too small \
             (have {}, need {})",
            input_bf16.byte_len(), expected_input_bytes
        )));
    }

    let block_bytes = params.ggml_type.block_bytes();
    let blocks_per_row = params.k / qk;
    let nb01 = (blocks_per_row as u64) * (block_bytes as u64);

    // The weight buffer must cover n rows of whole blocks.
    let expected_weight_bytes =
        (params.n as usize) * (blocks_per_row as usize) * (block_bytes as usize);
    if weight.byte_len() < expected_weight_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: weight buffer too small \
             (have {}, need {})",
            weight.byte_len(), expected_weight_bytes
        )));
    }

    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;

    let gpu_params = GgmlMatmulMmTensorPerm021GpuParams {
        ne00: params.k as i32,
        ne02: 1,
        nb01,
        nb02: nb01 * (params.n as u64),
        nb03: 0,
        ne12: 1,
        _pad0: 0,
        nb10: 2,
        nb11: 0,
        nb12: 0,
        nb13: 0,
        ne0: params.n as i32,
        ne1: params.m as i32,
        r2: 1,
        r3: 1,
        head_dim: params.head_dim as i32,
        seq_len: params.m as i32,
        _pad_trailing: 0,
    };

    // Launch geometry: one threadgroup per NR1 x NR0 tile of the m x n output grid.
    const NR0: u64 = 64;
    const NR1: u64 = 32;
    const THREADS_PER_TG: u64 = 128;
    const SHMEM_BYTES: u64 = 8192;
    let threadgroups = metal::MTLSize::new(
        (params.m as u64 + NR1 - 1) / NR1,
        (params.n as u64 + NR0 - 1) / NR0,
        1,
    );
    let threads_per_tg = metal::MTLSize::new(THREADS_PER_TG, 1, 1);
    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(weight)),
            (2, KernelArg::Buffer(input_bf16)),
            (3, KernelArg::Buffer(output)),
        ],
        &[(0, SHMEM_BYTES)],
        threadgroups,
        threads_per_tg,
    );
    Ok(())
}