use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::env_flags::env_default_true;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
// Block geometry of the quantized formats handled in this file.
//
// Each `QK*` constant is the number of values stored in one block of the
// format (returned by `GgmlType::block_values`); each `BLOCK_*_BYTES` constant
// is the size of one such block in bytes (returned by `GgmlType::block_bytes`).

// Q4_0
const QK4_0: u32 = 32;
const BLOCK_Q4_0_BYTES: u32 = 18;
// Q8_0
const QK8_0: u32 = 32;
const BLOCK_Q8_0_BYTES: u32 = 34;
// Q4_K
const QK4_K: u32 = 256;
const BLOCK_Q4_K_BYTES: u32 = 144;
// Q5_K
const QK5_K: u32 = 256;
const BLOCK_Q5_K_BYTES: u32 = 176;
// Q6_K
const QK6_K: u32 = 256;
const BLOCK_Q6_K_BYTES: u32 = 210;
// Q5_1
const QK5_1: u32 = 32;
const BLOCK_Q5_1_BYTES: u32 = 24;
// IQ4_NL (the value-count constant is spelled `QK4_NL`)
const QK4_NL: u32 = 32;
const BLOCK_IQ4_NL_BYTES: u32 = 18;
/// Storage formats understood by the GGML matmul helpers in this file.
///
/// `F32`, `F16` and `I16` are scalar formats (one value per block). The other
/// variants are block-quantized formats whose block length and block size are
/// reported by `block_values` and `block_bytes`. Only the block-quantized
/// variants have kernel names; the scalar variants map to the `"unsupported"`
/// placeholder name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// The variants keep their underscore spelling (e.g. `Q4_0`), so the naming
// lint is silenced for this enum.
#[allow(non_camel_case_types)]
pub enum GgmlType {
/// Scalar format, 4 bytes per value.
F32,
/// Scalar format, 2 bytes per value.
F16,
/// Block-quantized format: 32 values per 18-byte block.
Q4_0,
/// Block-quantized format: 32 values per 34-byte block.
Q8_0,
/// Block-quantized format: 256 values per 144-byte block.
Q4_K,
/// Block-quantized format: 256 values per 176-byte block.
Q5_K,
/// Block-quantized format: 256 values per 210-byte block.
Q6_K,
/// Scalar format, 2 bytes per value.
I16,
/// Block-quantized format: 32 values per 24-byte block.
Q5_1,
/// Block-quantized format: 32 values per 18-byte block.
IQ4_NL,
}
impl GgmlType {
pub fn block_values(self) -> u32 {
match self {
GgmlType::F32 => 1,
GgmlType::F16 => 1,
GgmlType::Q4_0 => QK4_0,
GgmlType::Q8_0 => QK8_0,
GgmlType::Q4_K => QK4_K,
GgmlType::Q5_K => QK5_K,
GgmlType::Q6_K => QK6_K,
GgmlType::I16 => 1,
GgmlType::Q5_1 => QK5_1,
GgmlType::IQ4_NL => QK4_NL,
}
}
pub fn block_bytes(self) -> u32 {
match self {
GgmlType::F32 => 4,
GgmlType::F16 => 2,
GgmlType::Q4_0 => BLOCK_Q4_0_BYTES,
GgmlType::Q8_0 => BLOCK_Q8_0_BYTES,
GgmlType::Q4_K => BLOCK_Q4_K_BYTES,
GgmlType::Q5_K => BLOCK_Q5_K_BYTES,
GgmlType::Q6_K => BLOCK_Q6_K_BYTES,
GgmlType::I16 => 2,
GgmlType::Q5_1 => BLOCK_Q5_1_BYTES,
GgmlType::IQ4_NL => BLOCK_IQ4_NL_BYTES,
}
}
fn kernel_name(self) -> &'static str {
match self {
GgmlType::F32 | GgmlType::F16 | GgmlType::I16 => "unsupported",
GgmlType::Q4_0 => "kernel_mul_mv_q4_0_f32",
GgmlType::Q8_0 => "kernel_mul_mv_q8_0_f32",
GgmlType::Q4_K => "kernel_mul_mv_q4_K_f32",
GgmlType::Q5_K => "kernel_mul_mv_q5_K_f32",
GgmlType::Q6_K => "kernel_mul_mv_q6_K_f32",
GgmlType::Q5_1 => "kernel_mul_mv_q5_1_f32",
GgmlType::IQ4_NL => "kernel_mul_mv_iq4_nl_f32",
}
}
fn mm_kernel_name(self) -> &'static str {
match self {
GgmlType::F32
| GgmlType::F16
| GgmlType::I16 => "unsupported",
GgmlType::Q4_0 => "kernel_mul_mm_q4_0_f32",
GgmlType::Q8_0 => "kernel_mul_mm_q8_0_f32",
GgmlType::Q4_K => "kernel_mul_mm_q4_K_f32",
GgmlType::Q5_K => "kernel_mul_mm_q5_K_f32",
GgmlType::Q6_K => "kernel_mul_mm_q6_K_f32",
GgmlType::Q5_1 => "kernel_mul_mm_q5_1_f32",
GgmlType::IQ4_NL => "kernel_mul_mm_iq4_nl_f32",
}
}
fn mm_tensor_kernel_name(self) -> &'static str {
match self {
GgmlType::F32
| GgmlType::F16
| GgmlType::I16 => "unsupported",
GgmlType::Q4_0 => "kernel_mul_mm_q4_0_tensor_f32",
GgmlType::Q8_0 => "kernel_mul_mm_q8_0_tensor_f32",
GgmlType::Q4_K => "kernel_mul_mm_q4_K_tensor_f32",
GgmlType::Q5_K => "kernel_mul_mm_q5_K_tensor_f32",
GgmlType::Q6_K => "kernel_mul_mm_q6_K_tensor_f32",
GgmlType::Q5_1 => "kernel_mul_mm_q5_1_tensor_f32",
GgmlType::IQ4_NL => "kernel_mul_mm_iq4_nl_tensor_f32",
}
}
fn mm_tensor_v2_kernel_name(self) -> &'static str {
match self {
GgmlType::F32
| GgmlType::F16
| GgmlType::I16 => "unsupported",
GgmlType::Q4_0 => "kernel_mul_mm_q4_0_tensor_v2_f32",
GgmlType::Q8_0 => "kernel_mul_mm_q8_0_tensor_v2_f32",
GgmlType::Q4_K => "kernel_mul_mm_q4_K_tensor_v2_f32",
GgmlType::Q5_K => "kernel_mul_mm_q5_K_tensor_v2_f32",
GgmlType::Q6_K => "kernel_mul_mm_q6_K_tensor_v2_f32",
GgmlType::Q5_1 => "kernel_mul_mm_q5_1_tensor_v2_f32",
GgmlType::IQ4_NL => "kernel_mul_mm_iq4_nl_tensor_v2_f32",
}
}
}
/// Process-wide cache for the result of [`probe_tensor_mm`]; filled on the
/// first call and never recomputed afterwards.
static TENSOR_MM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Reports whether the `tensor` matrix-matrix kernels can be used.
///
/// The answer is computed once and cached in `TENSOR_MM_AVAILABLE`:
///
/// * If the environment variable `HF2Q_DISABLE_TENSOR_MM` equals `"1"` the
///   probe answers `false` without touching the registry.
/// * Otherwise it asks `registry` for the `kernel_mul_mm_q4_0_tensor_f32`
///   pipeline and answers `true` only if that request succeeds.
///
/// When `MLX_LOG_TENSOR_PROBE` is set (to any value) the outcome is written to
/// stderr.
fn probe_tensor_mm(registry: &mut KernelRegistry, device: &MlxDevice) -> bool {
    let available = TENSOR_MM_AVAILABLE.get_or_init(|| {
        let disabled = matches!(
            std::env::var("HF2Q_DISABLE_TENSOR_MM").as_deref(),
            Ok("1")
        );

        // Short-circuit: the pipeline is only requested when not disabled.
        let usable = !disabled
            && registry
                .get_pipeline_with_constants(
                    "kernel_mul_mm_q4_0_tensor_f32",
                    device.metal_device(),
                    &[],
                    &[(700, 1), (701, 1), (702, 1)],
                )
                .is_ok();

        if std::env::var("MLX_LOG_TENSOR_PROBE").is_ok() {
            if disabled {
                eprintln!("[mlx-native] tensor_mm probe: DISABLED via HF2Q_DISABLE_TENSOR_MM=1");
            } else {
                let verdict = if usable {
                    "OK (using tensor variant)"
                } else {
                    "FAILED (falling back to simdgroup MMA)"
                };
                eprintln!("[mlx-native] tensor_mm probe: {}", verdict);
            }
        }

        usable
    });
    *available
}
/// Routing cutoff used by `quantized_matmul_ggml`: calls whose `m` is strictly
/// greater than this value (and whose `k` is at least 32) go to the
/// matrix-matrix kernels; all other calls use the matrix-vector kernels.
pub const MM_ROUTING_THRESHOLD: u32 = 8;
/// Problem description for [`quantized_matmul_ggml`] (also accepted by
/// `dispatch_mm_for_test`).
#[derive(Debug, Clone, Copy)]
pub struct GgmlQuantizedMatmulParams {
/// Number of input rows. The input buffer is validated as `m * k` f32 values
/// and the output buffer as `m * n` f32 values.
pub m: u32,
/// Number of weight rows. The weight buffer is validated as
/// `n * (k / block_values)` blocks of `ggml_type`.
pub n: u32,
/// Values per row, shared by input and weight. Must be non-zero and a
/// multiple of `ggml_type.block_values()`.
pub k: u32,
/// Storage format of the weight buffer.
pub ggml_type: GgmlType,
}
/// Parameter block handed to the matrix-vector kernels as raw bytes
/// (`dispatch_mv` binds it at argument index 3).
///
/// The struct is `#[repr(C)]` and `Pod`, so field order and widths are part of
/// the byte layout the kernel receives; do not reorder fields.
/// NOTE(review): presumably this mirrors a struct declared in the shader
/// source — confirm before changing any field.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlMatvecGpuParams {
    /// Set to `k` by `dispatch_mv`.
    ne00: i64,
    /// Set to `n` by `dispatch_mv`.
    ne01: i64,
    /// Always `1` in `dispatch_mv`.
    ne02: i64,
    /// Set to `k` by `dispatch_mv`.
    ne10: i64,
    /// Always `1` in `dispatch_mv`.
    ne12: i64,
    /// Set to `n` by `dispatch_mv`.
    ne0: i64,
    /// Set to `m` by `dispatch_mv`.
    ne1: i64,
    /// Always `1` in `dispatch_mv`.
    r2: u32,
    /// Always `1` in `dispatch_mv`.
    r3: u32,
}
/// Parameter block handed to the matrix-matrix kernels as raw bytes
/// (`dispatch_mm` and `dispatch_mm_v2_f16` bind it at argument index 0).
///
/// The struct is `#[repr(C)]` and `Pod`, so field order, widths and the
/// explicit padding fields are part of the byte layout the kernel receives;
/// do not reorder fields.
/// NOTE(review): presumably this mirrors a struct declared in the shader
/// source — confirm before changing any field.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlMatmulMmGpuParams {
    /// Set to `k` by the dispatchers.
    ne00: i32,
    /// Always `1` in the dispatchers.
    ne02: i32,
    /// Bytes per weight row.
    nb01: u64,
    /// `nb01 * n`.
    nb02: u64,
    /// Always `0` in the dispatchers.
    nb03: u64,
    /// Always `1` in the dispatchers.
    ne12: i32,
    /// Explicit padding so the following `u64` starts on an 8-byte boundary.
    _pad0: u32,
    /// Bytes per input element (size of an f32).
    nb10: u64,
    /// Bytes per input row (`k` f32 values).
    nb11: u64,
    /// `nb11 * m`.
    nb12: u64,
    /// Always `0` in the dispatchers.
    nb13: u64,
    /// Set to `n` by the dispatchers.
    ne0: i32,
    /// Set to `m` by the dispatchers.
    ne1: i32,
    /// Always `1` in the dispatchers.
    r2: i16,
    /// Always `1` in the dispatchers.
    r3: i16,
    /// Explicit trailing padding; brings the struct size to a multiple of 8.
    _pad1: u32,
}
/// Validates a quantized matmul request and encodes it onto `encoder`.
///
/// `weight` holds `params.n` rows of `params.k` values stored in
/// `params.ggml_type` blocks, `input` holds `params.m * params.k` f32 values
/// and `output` receives `params.m * params.n` f32 values.
///
/// Requests with `m > MM_ROUTING_THRESHOLD` (and `k >= 32`) are encoded with
/// the matrix-matrix kernels via `dispatch_mm`; all others use the
/// matrix-vector kernels via `dispatch_mv`.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when
/// * `params.ggml_type` is not one of the block-quantized formats,
/// * any of `m`, `k`, `n` is zero,
/// * `k` is not a multiple of the format's block length,
/// * a required buffer size does not fit in `usize`, or
/// * any of the three buffers is smaller than required.
///
/// Errors from pipeline lookup in the chosen dispatcher are propagated.
pub fn quantized_matmul_ggml(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    params: &GgmlQuantizedMatmulParams,
) -> Result<()> {
    /// Multiplies the factors, yielding `None` if the product overflows.
    fn checked_byte_count(factors: [usize; 3]) -> Option<usize> {
        factors
            .iter()
            .try_fold(1usize, |acc, &factor| acc.checked_mul(factor))
    }

    let qk = params.ggml_type.block_values();
    let block_bytes = params.ggml_type.block_bytes();

    match params.ggml_type {
        GgmlType::Q4_0
        | GgmlType::Q8_0
        | GgmlType::Q4_K
        | GgmlType::Q5_K
        | GgmlType::Q6_K
        | GgmlType::Q5_1
        | GgmlType::IQ4_NL => {}
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "quantized_matmul_ggml does not support {:?} — use a different dispatch path",
                other
            )));
        }
    }

    if params.m == 0 || params.k == 0 || params.n == 0 {
        return Err(MlxError::InvalidArgument(
            "M, K, and N must all be > 0".into(),
        ));
    }
    if params.k % qk != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "K ({}) must be divisible by block QK ({})",
            params.k, qk
        )));
    }

    // The size products are computed with overflow checks: a wrapped product
    // would be small and let an undersized buffer pass the checks below.
    let blocks_per_row = params.k / qk;
    let expected_weight_bytes = checked_byte_count([
        params.n as usize,
        blocks_per_row as usize,
        block_bytes as usize,
    ])
    .ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "Weight buffer size overflows usize for {:?} [{}x{}]",
            params.ggml_type, params.n, params.k
        ))
    })?;
    if weight.byte_len() < expected_weight_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "Weight buffer too small: expected {} bytes for {:?} [{}x{}], got {}",
            expected_weight_bytes,
            params.ggml_type,
            params.n,
            params.k,
            weight.byte_len()
        )));
    }

    let expected_input_bytes = checked_byte_count([
        params.m as usize,
        params.k as usize,
        DType::F32.size_of(),
    ])
    .ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "Input buffer size overflows usize for [{}x{}] f32",
            params.m, params.k
        ))
    })?;
    if input.byte_len() < expected_input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "Input buffer too small: expected {} bytes for [{}x{}] f32, got {}",
            expected_input_bytes, params.m, params.k, input.byte_len()
        )));
    }

    let expected_output_bytes = checked_byte_count([
        params.m as usize,
        params.n as usize,
        DType::F32.size_of(),
    ])
    .ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "Output buffer size overflows usize for [{}x{}] f32",
            params.m, params.n
        ))
    })?;
    if output.byte_len() < expected_output_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "Output buffer too small: expected {} bytes for [{}x{}] f32, got {}",
            expected_output_bytes, params.m, params.n, output.byte_len()
        )));
    }

    // Every format accepted above has a matrix-matrix kernel, so routing
    // depends only on the problem shape.
    if params.m > MM_ROUTING_THRESHOLD && params.k >= 32 {
        dispatch_mm(encoder, registry, device, input, weight, output, params)
    } else {
        dispatch_mv(encoder, registry, device, input, weight, output, params)
    }
}
/// Encodes the `hf2q_mul_mm_tensor_v2_f16` matrix-matrix kernel for an F16
/// weight buffer.
///
/// The parameter block describes the weight as `n` rows of `k` F16 values and
/// the input as `m` rows of `k` f32 values. The grid is
/// `ceil(m / 128) x ceil(n / 64)` threadgroups of 128 threads, each with 4096
/// bytes of threadgroup memory.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when
/// * `f16_weight` does not have dtype `F16`,
/// * any of `m`, `k`, `n` is zero,
/// * any of `m`, `k`, `n` exceeds `i32::MAX` (the parameter block stores them
///   as `i32`), or
/// * `f16_weight` or `input` is smaller than the region described by the
///   strides sent to the kernel.
///
/// Errors from pipeline lookup are propagated.
pub fn dispatch_mm_v2_f16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    f16_weight: &MlxBuffer,
    input: &MlxBuffer,
    output: &MlxBuffer,
    m: u32,
    n: u32,
    k: u32,
) -> Result<()> {
    if f16_weight.dtype() != DType::F16 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_mm_v2_f16: f16_weight must be F16, got {:?}",
            f16_weight.dtype()
        )));
    }
    if m == 0 || k == 0 || n == 0 {
        return Err(MlxError::InvalidArgument(
            "dispatch_mm_v2_f16: M, K, N must all be > 0".into(),
        ));
    }

    // The parameter block carries the dimensions as i32; reject values that
    // would wrap instead of handing the kernel a negative size.
    let to_i32 = |value: u32, name: &str| -> Result<i32> {
        <i32 as std::convert::TryFrom<u32>>::try_from(value).map_err(|_| {
            MlxError::InvalidArgument(format!(
                "dispatch_mm_v2_f16: {} ({}) exceeds i32::MAX",
                name, value
            ))
        })
    };
    let ne00 = to_i32(k, "K")?;
    let ne0 = to_i32(n, "N")?;
    let ne1 = to_i32(m, "M")?;

    // With every dimension bounded by i32::MAX these u64 products cannot
    // overflow.
    let nb01 = (k as u64) * (DType::F16.size_of() as u64);
    let nb02 = nb01 * (n as u64);
    let nb11 = (k as u64) * (DType::F32.size_of() as u64);
    let nb12 = nb11 * (m as u64);

    // `nb02` / `nb12` are the total extents described to the kernel, so the
    // buffers must be at least that large.
    if (f16_weight.byte_len() as u64) < nb02 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_mm_v2_f16: f16_weight buffer too small: expected {} bytes for [{}x{}] f16, got {}",
            nb02,
            n,
            k,
            f16_weight.byte_len()
        )));
    }
    if (input.byte_len() as u64) < nb12 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_mm_v2_f16: input buffer too small: expected {} bytes for [{}x{}] f32, got {}",
            nb12,
            m,
            k,
            input.byte_len()
        )));
    }
    // NOTE(review): `output` is not size-checked because its element type is
    // not visible from this function — confirm and add a check.

    let gpu_params = GgmlMatmulMmGpuParams {
        ne00,
        ne02: 1,
        nb01,
        nb02,
        nb03: 0,
        ne12: 1,
        _pad0: 0,
        nb10: DType::F32.size_of() as u64,
        nb11,
        nb12,
        nb13: 0,
        ne0,
        ne1,
        r2: 1,
        r3: 1,
        _pad1: 0,
    };

    let pipeline = registry
        .get_pipeline_with_constants(
            "hf2q_mul_mm_tensor_v2_f16",
            device.metal_device(),
            &[],
            &[(700, 1), (701, 1), (702, 1)],
        )?
        .clone();

    const THREADS_PER_TG: u64 = 128;
    const SHMEM_BYTES: u64 = 4096;
    // Tile extents used to size the grid: 128 along `m`, 64 along `n`.
    let nra: u64 = 64;
    let nrb: u64 = 128;
    let tg_x = (m as u64 + nrb - 1) / nrb;
    let tg_y = (n as u64 + nra - 1) / nra;
    let threadgroups = metal::MTLSize::new(tg_x, tg_y, 1);
    let threads_per_tg = metal::MTLSize::new(THREADS_PER_TG, 1, 1);

    encoder.encode_threadgroups_with_args_and_shared(
        &pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(f16_weight)),
            (2, KernelArg::Buffer(input)),
            (3, KernelArg::Buffer(output)),
        ],
        &[(0, SHMEM_BYTES)],
        threadgroups,
        threads_per_tg,
    );
    Ok(())
}
/// Test hook that always takes the matrix-matrix path, regardless of `m`.
///
/// Performs the format, non-zero and `k`-divisibility checks and then hands
/// the request to `dispatch_mm`. Buffer sizes are not checked here.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` for a non-quantized `ggml_type`, a zero
/// dimension, or a `k` that is not a multiple of the block length; otherwise
/// returns whatever `dispatch_mm` returns.
#[doc(hidden)]
pub fn dispatch_mm_for_test(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    params: &GgmlQuantizedMatmulParams,
) -> Result<()> {
    let ty = params.ggml_type;

    let is_quantized = matches!(
        ty,
        GgmlType::Q4_0
            | GgmlType::Q8_0
            | GgmlType::Q4_K
            | GgmlType::Q5_K
            | GgmlType::Q6_K
            | GgmlType::Q5_1
            | GgmlType::IQ4_NL
    );
    if !is_quantized {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_mm_for_test does not support {:?}",
            ty
        )));
    }

    if [params.m, params.k, params.n].contains(&0) {
        return Err(MlxError::InvalidArgument(String::from(
            "M, K, and N must all be > 0",
        )));
    }

    let block_len = ty.block_values();
    if params.k % block_len != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "K ({}) must be divisible by block QK ({})",
            params.k, block_len
        )));
    }

    dispatch_mm(encoder, registry, device, input, weight, output, params)
}
/// Encodes a matrix-vector kernel launch for the request in `params`.
///
/// Kernel selection:
/// * `Q6_K` uses `kernel_mul_mv_q6_K_f32_nr2` when
///   `env_default_true("HF2Q_Q6K_MV_NR2")` is true.
/// * `Q8_0` uses `kernel_mul_mv_q8_0_f32_nr2` when `HF2Q_Q8_0_MV_NR2` equals
///   `"1"`; this variant is also given 256 bytes of threadgroup memory.
/// * Everything else uses `GgmlType::kernel_name`.
///
/// The grid is `ceil(n / step) x m` threadgroups, where the threadgroup shape
/// and `step` depend on the format and on the variant chosen above.
///
/// # Errors
///
/// Propagates pipeline lookup failures.
///
/// # Panics
///
/// Hits `unreachable!()` if a scalar format gets past pipeline lookup.
fn dispatch_mv(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    params: &GgmlQuantizedMatmulParams,
) -> Result<()> {
    /// Threadgroup memory given to the Q8_0 `nr2` variant.
    const Q8_0_NR2_SMEM_BYTES: u64 = 2 * 32 * std::mem::size_of::<f32>() as u64;

    let ty = params.ggml_type;

    // The environment is only consulted for the format each switch applies to.
    let use_q6k_nr2 = ty == GgmlType::Q6_K && env_default_true("HF2Q_Q6K_MV_NR2");
    let use_q8_0_nr2 = ty == GgmlType::Q8_0
        && matches!(std::env::var("HF2Q_Q8_0_MV_NR2").as_deref(), Ok("1"));

    let kernel_name = match (use_q6k_nr2, use_q8_0_nr2) {
        (true, _) => "kernel_mul_mv_q6_K_f32_nr2",
        (false, true) => "kernel_mul_mv_q8_0_f32_nr2",
        (false, false) => ty.kernel_name(),
    };

    let pipeline = registry
        .get_pipeline_with_constants(
            kernel_name,
            device.metal_device(),
            &[],
            &[(700, 1), (701, 1), (702, 1)],
        )?
        .clone();

    let k = params.k as i64;
    let n = params.n as i64;
    let gpu_params = GgmlMatvecGpuParams {
        ne00: k,
        ne01: n,
        ne02: 1,
        ne10: k,
        ne12: 1,
        ne0: n,
        ne1: params.m as i64,
        r2: 1,
        r3: 1,
    };

    // (threads x, threads y, step along `n` covered by one threadgroup).
    // The guarded arms are the `nr2` variants; they take precedence over the
    // per-format defaults below them.
    let (nth0, nth1, n_step): (u64, u64, usize) = match ty {
        GgmlType::Q8_0 if use_q8_0_nr2 => (32, 4, 2),
        GgmlType::Q6_K if use_q6k_nr2 => (2, 32, 4),
        GgmlType::Q4_0 | GgmlType::Q8_0 | GgmlType::Q5_1 | GgmlType::IQ4_NL => (8, 8, 8),
        GgmlType::Q4_K | GgmlType::Q5_K | GgmlType::Q6_K => (2, 32, 2),
        _ => unreachable!(),
    };

    let grid_x = div_ceil(params.n as usize, n_step) as u64;
    let grid_y = u64::from(params.m);
    let threadgroups = metal::MTLSize::new(grid_x, grid_y, 1);
    let threads_per_tg = metal::MTLSize::new(nth0, nth1, 1);

    if !use_q8_0_nr2 {
        encoder.encode_threadgroups_with_args(
            &pipeline,
            &[
                (0, KernelArg::Buffer(weight)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(output)),
                (3, KernelArg::Bytes(as_bytes(&gpu_params))),
            ],
            threadgroups,
            threads_per_tg,
        );
    } else {
        encoder.encode_threadgroups_with_args_and_shared(
            &pipeline,
            &[
                (0, KernelArg::Buffer(weight)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(output)),
                (3, KernelArg::Bytes(as_bytes(&gpu_params))),
            ],
            &[(0, Q8_0_NR2_SMEM_BYTES)],
            threadgroups,
            threads_per_tg,
        );
    }
    Ok(())
}
fn dispatch_mm(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
output: &MlxBuffer, params: &GgmlQuantizedMatmulParams,
) -> Result<()> {
let use_tensor = probe_tensor_mm(registry, device);
let use_v2_large_tile = use_tensor
&& match std::env::var("HF2Q_LARGE_TILE_MM").as_deref() {
Ok("0") | Ok("false") | Ok("off") => false,
_ => true,
};
let kernel_name = if use_v2_large_tile {
params.ggml_type.mm_tensor_v2_kernel_name()
} else if use_tensor {
params.ggml_type.mm_tensor_kernel_name()
} else {
params.ggml_type.mm_kernel_name()
};
let pipeline = registry
.get_pipeline_with_constants(
kernel_name,
device.metal_device(),
&[],
&[(700, 1), (701, 1), (702, 1)],
)?
.clone();
let qk = params.ggml_type.block_values();
let block_bytes = params.ggml_type.block_bytes();
let blocks_per_row = params.k / qk;
let nb01 = (blocks_per_row as u64) * (block_bytes as u64);
let nb11 = (params.k as u64) * DType::F32.size_of() as u64;
let gpu_params = GgmlMatmulMmGpuParams {
ne00: params.k as i32,
ne02: 1,
nb01,
nb02: nb01 * (params.n as u64),
nb03: 0,
ne12: 1,
_pad0: 0,
nb10: DType::F32.size_of() as u64,
nb11,
nb12: nb11 * (params.m as u64),
nb13: 0,
ne0: params.n as i32,
ne1: params.m as i32,
r2: 1,
r3: 1,
_pad1: 0,
};
const THREADS_PER_TG: u64 = 128;
let (tg_x, tg_y, shmem_bytes) = if use_v2_large_tile {
let nra: u64 = 64; let nrb: u64 = 128; (
(params.m as u64 + nrb - 1) / nrb, (params.n as u64 + nra - 1) / nra, 4096u64,
)
} else {
let nr0: u64 = 64;
let nr1: u64 = 32;
(
(params.m as u64 + nr1 - 1) / nr1,
(params.n as u64 + nr0 - 1) / nr0,
8192u64,
)
};
let threadgroups = metal::MTLSize::new(tg_x, tg_y, 1);
let threads_per_tg = metal::MTLSize::new(THREADS_PER_TG, 1, 1);
encoder.encode_threadgroups_with_args_and_shared(
&pipeline,
&[
(0, KernelArg::Bytes(as_bytes(&gpu_params))),
(1, KernelArg::Buffer(weight)),
(2, KernelArg::Buffer(input)),
(3, KernelArg::Buffer(output)),
],
&[(0, shmem_bytes)],
threadgroups,
threads_per_tg,
);
Ok(())
}
/// Returns `a / b` rounded up to the next integer.
///
/// Computed as quotient plus a remainder flag, so it cannot overflow for any
/// `a` (the `(a + b - 1) / b` form overflows when `a` is close to
/// `usize::MAX`).
///
/// # Panics
///
/// Panics if `b` is zero.
fn div_ceil(a: usize, b: usize) -> usize {
    a / b + usize::from(a % b != 0)
}
/// Parameter block handed to the `perm021` matrix-matrix kernels as raw bytes
/// (bound at argument index 0 by both `perm021` dispatchers).
///
/// The leading fields have the same order and types as
/// `GgmlMatmulMmGpuParams` up to `r3`; `head_dim`, `seq_len` and a trailing pad
/// follow. The struct is `#[repr(C)]` and `Pod`, so field order, widths and the
/// explicit padding fields are part of the byte layout the kernel receives;
/// do not reorder fields.
/// NOTE(review): presumably this mirrors a struct declared in the shader
/// source — confirm before changing any field.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlMatmulMmTensorPerm021GpuParams {
    /// Set to `k` by the dispatchers.
    ne00: i32,
    /// Always `1` in the dispatchers.
    ne02: i32,
    /// Bytes per weight row.
    nb01: u64,
    /// `nb01 * n`.
    nb02: u64,
    /// Always `0` in the dispatchers.
    nb03: u64,
    /// Always `1` in the dispatchers.
    ne12: i32,
    /// Explicit padding so the following `u64` starts on an 8-byte boundary.
    _pad0: u32,
    /// Bytes per input element; the dispatchers pass `2`.
    nb10: u64,
    /// Always `0` in the dispatchers.
    nb11: u64,
    /// Always `0` in the dispatchers.
    nb12: u64,
    /// Always `0` in the dispatchers.
    nb13: u64,
    /// Set to `n` by the dispatchers.
    ne0: i32,
    /// Set to `m` by the dispatchers.
    ne1: i32,
    /// Always `1` in the dispatchers.
    r2: i16,
    /// Always `1` in the dispatchers.
    r3: i16,
    /// Set to the caller's `head_dim`.
    head_dim: i32,
    /// Set to `m` by the dispatchers.
    seq_len: i32,
    /// Explicit trailing padding; brings the struct size to a multiple of 8.
    _pad_trailing: u32,
}
/// Problem description for [`quantized_matmul_mm_tensor_perm021`] and
/// [`quantized_matmul_mm_tensor_perm021_f16`].
#[derive(Debug, Clone, Copy)]
pub struct GgmlQuantizedMatmulPerm021Params {
/// Passed to the kernel both as `ne1` and as `seq_len`.
pub m: u32,
/// Number of weight rows; passed to the kernel as `ne0`.
pub n: u32,
/// Values per weight row; must be a multiple of `head_dim`.
pub k: u32,
/// Must be a positive multiple of 32. `k / head_dim` is the head count used
/// when validating the input buffer size.
pub head_dim: u32,
/// Weight format. The quantized dispatcher accepts `Q4_0`, `Q8_0` and
/// `Q6_K`; the `_f16` dispatcher does not read this field.
pub ggml_type: GgmlType,
}
/// Encodes the `perm021` tensor matrix-matrix kernel for a quantized weight
/// and a 2-byte-per-element input.
///
/// `input_bf16` must hold at least `(k / head_dim) * m * head_dim` 2-byte
/// elements and `weight` at least `n` rows of `k / block_values` blocks of
/// `params.ggml_type`.
/// NOTE(review): the name suggests the input is laid out
/// `[n_heads, m, head_dim]` and permuted by the kernel — confirm against the
/// shader.
///
/// The grid is `ceil(m / 32) x ceil(n / 64)` threadgroups of 128 threads with
/// 8192 bytes of threadgroup memory.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when
/// * `params.ggml_type` is not `Q4_0`, `Q8_0` or `Q6_K`,
/// * any of `m`, `n`, `k` is zero (the grid would be empty),
/// * `m`, `n`, `k` or `head_dim` exceeds `i32::MAX`,
/// * `head_dim` is zero or not a multiple of 32,
/// * `k` is not a multiple of `head_dim` or of the format's block length, or
/// * `input_bf16` or `weight` is too small.
///
/// Errors from pipeline lookup are propagated.
pub fn quantized_matmul_mm_tensor_perm021(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input_bf16: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    params: &GgmlQuantizedMatmulPerm021Params,
) -> Result<()> {
    let kernel_name = match params.ggml_type {
        GgmlType::Q4_0 => "kernel_mul_mm_q4_0_tensor_bf16_perm021",
        GgmlType::Q8_0 => "kernel_mul_mm_q8_0_tensor_bf16_perm021",
        GgmlType::Q6_K => "kernel_mul_mm_q6_K_tensor_bf16_perm021",
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "quantized_matmul_mm_tensor_perm021: unsupported ggml_type {:?} \
                 (only Q4_0 / Q8_0 / Q6_K are instantiated)",
                other
            )));
        }
    };

    // A zero dimension would produce an empty grid; reject it like the other
    // dispatchers in this file do.
    if params.m == 0 || params.n == 0 || params.k == 0 {
        return Err(MlxError::InvalidArgument(
            "quantized_matmul_mm_tensor_perm021: M, N, and K must all be > 0".into(),
        ));
    }

    // The parameter block carries these values as i32; reject values that
    // would wrap instead of handing the kernel a negative size.
    let to_i32 = |value: u32, name: &str| -> Result<i32> {
        <i32 as std::convert::TryFrom<u32>>::try_from(value).map_err(|_| {
            MlxError::InvalidArgument(format!(
                "quantized_matmul_mm_tensor_perm021: {} ({}) exceeds i32::MAX",
                name, value
            ))
        })
    };
    let ne00 = to_i32(params.k, "k")?;
    let ne0 = to_i32(params.n, "n")?;
    let ne1 = to_i32(params.m, "m")?;
    let head_dim_i32 = to_i32(params.head_dim, "head_dim")?;

    if params.head_dim == 0 || params.head_dim % 32 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: head_dim {} must be a positive \
             multiple of 32 (NK tile width)",
            params.head_dim
        )));
    }
    if params.k % params.head_dim != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: k ({}) must be divisible by \
             head_dim ({})",
            params.k, params.head_dim
        )));
    }

    // `k` being a multiple of 32 is not enough for Q6_K (block length 256):
    // without this check `k / qk` would truncate and the row stride sent to
    // the kernel would be wrong.
    let qk = params.ggml_type.block_values();
    if params.k % qk != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: k ({}) must be divisible by \
             block QK ({})",
            params.k, qk
        )));
    }

    let n_heads = params.k / params.head_dim;
    let expected_input_bytes = (n_heads as usize) * (params.m as usize)
        * (params.head_dim as usize) * 2;
    if input_bf16.byte_len() < expected_input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: input_bf16 buffer too small \
             (have {}, need {})",
            input_bf16.byte_len(), expected_input_bytes
        )));
    }

    let block_bytes = params.ggml_type.block_bytes();
    let blocks_per_row = params.k / qk;
    let nb01 = (blocks_per_row as u64) * (block_bytes as u64);
    // `nb02` is the total weight extent described to the kernel.
    let nb02 = nb01.checked_mul(params.n as u64).ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: weight size overflows for \
             [n={}, k={}] {:?}",
            params.n, params.k, params.ggml_type
        ))
    })?;
    if (weight.byte_len() as u64) < nb02 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021: weight buffer too small \
             (have {}, need {} bytes for [n={}, k={}] {:?})",
            weight.byte_len(), nb02, params.n, params.k, params.ggml_type
        )));
    }

    let pipeline = registry
        .get_pipeline_with_constants(
            kernel_name,
            device.metal_device(),
            &[],
            &[(700, 1), (701, 1), (702, 1)],
        )?
        .clone();

    let gpu_params = GgmlMatmulMmTensorPerm021GpuParams {
        ne00,
        ne02: 1,
        nb01,
        nb02,
        nb03: 0,
        ne12: 1,
        _pad0: 0,
        nb10: 2,
        nb11: 0,
        nb12: 0,
        nb13: 0,
        ne0,
        ne1,
        r2: 1,
        r3: 1,
        head_dim: head_dim_i32,
        seq_len: ne1,
        _pad_trailing: 0,
    };

    const NR0: u64 = 64;
    const NR1: u64 = 32;
    const THREADS_PER_TG: u64 = 128;
    const SHMEM_BYTES: u64 = 8192;
    let threadgroups = metal::MTLSize::new(
        (params.m as u64 + NR1 - 1) / NR1,
        (params.n as u64 + NR0 - 1) / NR0,
        1,
    );
    let threads_per_tg = metal::MTLSize::new(THREADS_PER_TG, 1, 1);

    encoder.encode_threadgroups_with_args_and_shared(
        &pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(weight)),
            (2, KernelArg::Buffer(input_bf16)),
            (3, KernelArg::Buffer(output)),
        ],
        &[(0, SHMEM_BYTES)],
        threadgroups,
        threads_per_tg,
    );
    Ok(())
}
/// Encodes the `kernel_mul_mm_f16_tensor_bf16_perm021` kernel: the `perm021`
/// matrix-matrix path for a 2-byte-per-element weight.
///
/// `input_bf16` must hold at least `(k / head_dim) * m * head_dim` 2-byte
/// elements and `weight_f16` at least `n * k` 2-byte elements.
/// `params.ggml_type` is not read by this function.
///
/// The grid is `ceil(m / 32) x ceil(n / 64)` threadgroups of 128 threads with
/// 8192 bytes of threadgroup memory.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when
/// * any of `m`, `n`, `k` is zero (the grid would be empty),
/// * `m`, `n`, `k` or `head_dim` exceeds `i32::MAX`,
/// * `head_dim` is zero or not a multiple of 32,
/// * `k` is not a multiple of `head_dim`, or
/// * `input_bf16` or `weight_f16` is too small.
///
/// Errors from pipeline lookup are propagated.
pub fn quantized_matmul_mm_tensor_perm021_f16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input_bf16: &MlxBuffer,
    weight_f16: &MlxBuffer,
    output: &MlxBuffer,
    params: &GgmlQuantizedMatmulPerm021Params,
) -> Result<()> {
    // A zero dimension would produce an empty grid; reject it like the other
    // dispatchers in this file do.
    if params.m == 0 || params.n == 0 || params.k == 0 {
        return Err(MlxError::InvalidArgument(
            "quantized_matmul_mm_tensor_perm021_f16: M, N, and K must all be > 0".into(),
        ));
    }

    // The parameter block carries these values as i32; reject values that
    // would wrap instead of handing the kernel a negative size.
    let to_i32 = |value: u32, name: &str| -> Result<i32> {
        <i32 as std::convert::TryFrom<u32>>::try_from(value).map_err(|_| {
            MlxError::InvalidArgument(format!(
                "quantized_matmul_mm_tensor_perm021_f16: {} ({}) exceeds i32::MAX",
                name, value
            ))
        })
    };
    let ne00 = to_i32(params.k, "k")?;
    let ne0 = to_i32(params.n, "n")?;
    let ne1 = to_i32(params.m, "m")?;
    let head_dim_i32 = to_i32(params.head_dim, "head_dim")?;

    if params.head_dim == 0 || params.head_dim % 32 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021_f16: head_dim {} must be a positive \
             multiple of 32 (NK tile width)",
            params.head_dim
        )));
    }
    if params.k % params.head_dim != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021_f16: k ({}) must be divisible by \
             head_dim ({})",
            params.k, params.head_dim
        )));
    }

    let n_heads = params.k / params.head_dim;
    let expected_input_bytes = (n_heads as usize) * (params.m as usize)
        * (params.head_dim as usize) * 2;
    if input_bf16.byte_len() < expected_input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021_f16: input_bf16 buffer too small \
             (have {}, need {})",
            input_bf16.byte_len(), expected_input_bytes
        )));
    }

    let expected_weight_bytes = (params.n as usize) * (params.k as usize) * 2;
    if weight_f16.byte_len() < expected_weight_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "quantized_matmul_mm_tensor_perm021_f16: weight_f16 buffer too small \
             (have {}, need {} bytes for [n={}, k={}] half)",
            weight_f16.byte_len(), expected_weight_bytes, params.n, params.k
        )));
    }

    let pipeline = registry
        .get_pipeline_with_constants(
            "kernel_mul_mm_f16_tensor_bf16_perm021",
            device.metal_device(),
            &[],
            &[(700, 1), (701, 1), (702, 1)],
        )?
        .clone();

    // Bytes per weight row: `k` 2-byte elements.
    let nb01: u64 = (params.k as u64) * 2;
    let gpu_params = GgmlMatmulMmTensorPerm021GpuParams {
        ne00,
        ne02: 1,
        nb01,
        nb02: nb01 * (params.n as u64),
        nb03: 0,
        ne12: 1,
        _pad0: 0,
        nb10: 2,
        nb11: 0,
        nb12: 0,
        nb13: 0,
        ne0,
        ne1,
        r2: 1,
        r3: 1,
        head_dim: head_dim_i32,
        seq_len: ne1,
        _pad_trailing: 0,
    };

    const NR0: u64 = 64;
    const NR1: u64 = 32;
    const THREADS_PER_TG: u64 = 128;
    const SHMEM_BYTES: u64 = 8192;
    let threadgroups = metal::MTLSize::new(
        (params.m as u64 + NR1 - 1) / NR1,
        (params.n as u64 + NR0 - 1) / NR0,
        1,
    );
    let threads_per_tg = metal::MTLSize::new(THREADS_PER_TG, 1, 1);

    encoder.encode_threadgroups_with_args_and_shared(
        &pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(weight_f16)),
            (2, KernelArg::Buffer(input_bf16)),
            (3, KernelArg::Buffer(output)),
        ],
        &[(0, SHMEM_BYTES)],
        threadgroups,
        threads_per_tg,
    );
    Ok(())
}