use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CapturedOpKind, CommandEncoder, DispatchRecord, KernelArg, as_bytes};
use crate::env_flags::{cached_env_default_true, cached_env_eq_one};
use std::sync::atomic::AtomicI8;
// Cache cell for the `HF2Q_Q6K_ID_MV_NR2` environment flag. It is handed to
// `cached_env_default_true` wherever this file decides whether to use the
// Q6_K `_nr2` mat-vec kernel. Starts at -1 (presumably the "not read yet"
// state; confirm against `crate::env_flags`).
static CACHED_Q6K_ID_MV_NR2: AtomicI8 = AtomicI8::new(-1);
// Cache cell for the `HF2Q_Q8_0_ID_MV_NR2` environment flag. It is handed to
// `cached_env_eq_one` wherever this file decides whether to use the Q8_0
// `_nr2` mat-vec kernel. Starts at -1 (presumably "not read yet"; confirm
// against `crate::env_flags`).
static CACHED_Q8_0_ID_MV_NR2: AtomicI8 = AtomicI8::new(-1);
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::ops::quantized_matmul_ggml::GgmlType;
/// Constant block passed by value (as raw bytes) to the `kernel_mul_mv_id_*`
/// Metal kernels dispatched from this file.
///
/// The struct is `#[repr(C)]` and `Pod`, so field order and widths are part of
/// the contract with the shader and must not be changed here.
/// NOTE(review): the field names look like ggml's `ne*` (element count)
/// convention; confirm the exact meaning against the kernel source.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlMatvecIdGpuParams {
    /// Set to `k` by every caller in this file.
    ne00: i64,
    /// Set to `n` by every caller in this file.
    ne01: i64,
    /// Always 1 in this file.
    ne02: i64,
    /// Set to `k` by every caller in this file.
    ne10: i64,
    /// Always 1 in this file.
    ne12: i64,
    /// Set to `n` by every caller in this file.
    ne0: i64,
    /// Number of output rows the dispatch covers (`n_tokens * top_k`, or the
    /// decode-time row count in the record builders).
    ne1: i64,
    /// Always 1 in this file.
    r2: u32,
    /// Always 1 in this file.
    r3: u32,
    /// Expert slots per token as seen by the kernel.
    top_k: u32,
    /// Token count as seen by the kernel.
    n_tokens: u32,
    /// Caller-provided expert stride, reinterpreted as `i64`.
    expert_stride: i64,
}
/// Shape and layout description for an expert-indexed quantized matmul.
///
/// The dispatch entry points in this file validate the buffers against these
/// values before encoding any GPU work.
#[derive(Debug, Clone, Copy)]
pub struct GgmlQuantizedMatmulIdParams {
    /// Number of input rows; the input buffer must hold `n_tokens * k` f32 values.
    pub n_tokens: u32,
    /// Expert ids per token; the ids buffer must hold `n_tokens * top_k` u32 values.
    pub top_k: u32,
    /// Output width; the output buffer must hold `n_tokens * top_k * n` f32 values.
    pub n: u32,
    /// Input row length; must be a multiple of the quantization block size.
    pub k: u32,
    /// Number of expert weight matrices packed into the weight buffer.
    pub n_experts: u32,
    /// Byte distance between consecutive experts inside the weight buffer.
    pub expert_stride: u64,
    /// Quantization format of the weight buffer.
    pub ggml_type: GgmlType,
}
impl GgmlType {
fn id_kernel_name(self) -> &'static str {
match self {
GgmlType::Q4_0 => "kernel_mul_mv_id_q4_0_f32",
GgmlType::Q8_0 => "kernel_mul_mv_id_q8_0_f32",
GgmlType::Q4_K => "kernel_mul_mv_id_q4_K_f32",
GgmlType::Q5_K => "kernel_mul_mv_id_q5_K_f32",
GgmlType::Q6_K => "kernel_mul_mv_id_q6_K_f32",
GgmlType::Q5_1 => "kernel_mul_mv_id_q5_1_f32",
GgmlType::IQ4_NL => "kernel_mul_mv_id_iq4_nl_f32",
GgmlType::F32 | GgmlType::F16 | GgmlType::I16 => "unsupported",
}
}
fn id_mm_kernel_name(self) -> &'static str {
match self {
GgmlType::Q4_0 => "kernel_mul_mm_id_q4_0_f32",
GgmlType::Q8_0 => "kernel_mul_mm_id_q8_0_f32",
GgmlType::Q5_K => "kernel_mul_mm_id_q5_K_f32",
GgmlType::Q6_K => "kernel_mul_mm_id_q6_K_f32",
GgmlType::Q4_K => "kernel_mul_mm_id_q4_K_f32",
GgmlType::Q5_1 => "kernel_mul_mm_id_q5_1_f32",
GgmlType::IQ4_NL => "kernel_mul_mm_id_iq4_nl_f32",
GgmlType::F32 | GgmlType::F16 | GgmlType::I16 => "unsupported",
}
}
fn id_mm_tensor_kernel_name(self) -> &'static str {
match self {
GgmlType::Q4_0 => "kernel_mul_mm_id_q4_0_tensor_f32",
GgmlType::Q8_0 => "kernel_mul_mm_id_q8_0_tensor_f32",
GgmlType::Q5_K => "kernel_mul_mm_id_q5_K_tensor_f32",
GgmlType::Q6_K => "kernel_mul_mm_id_q6_K_tensor_f32",
GgmlType::Q4_K => "kernel_mul_mm_id_q4_K_tensor_f32",
GgmlType::Q5_1 => "kernel_mul_mm_id_q5_1_tensor_f32",
GgmlType::IQ4_NL => "kernel_mul_mm_id_iq4_nl_tensor_f32",
GgmlType::F32 | GgmlType::F16 | GgmlType::I16 => "unsupported",
}
}
}
/// Process-wide result of the one-time probe performed by [`probe_tensor_mm_id`].
static TENSOR_MM_ID_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Reports whether the `_tensor` mat-mat kernels should be used.
///
/// The answer is computed on the first call and cached for the rest of the
/// process:
/// * if `HF2Q_DISABLE_TENSOR_MM_ID` is set (to any value) the answer is `false`;
/// * otherwise it is whether the registry can produce a pipeline for
///   `kernel_mul_mm_id_q4_0_tensor_f32`.
///
/// When `MLX_LOG_TENSOR_PROBE` is set, the outcome is printed to stderr.
fn probe_tensor_mm_id(registry: &mut KernelRegistry, device: &MlxDevice) -> bool {
    let available = TENSOR_MM_ID_AVAILABLE.get_or_init(|| {
        let verbose = std::env::var("MLX_LOG_TENSOR_PROBE").is_ok();
        let disabled = std::env::var("HF2Q_DISABLE_TENSOR_MM_ID").is_ok();

        if disabled {
            if verbose {
                eprintln!("[mlx-native] tensor_mm_id: DISABLED via HF2Q_DISABLE_TENSOR_MM_ID");
            }
            false
        } else {
            let compiled = registry
                .get_pipeline("kernel_mul_mm_id_q4_0_tensor_f32", device.metal_device())
                .is_ok();
            if verbose {
                let outcome = if compiled {
                    "OK (using tensor variant for MoE)"
                } else {
                    "FAILED (falling back to simdgroup MMA)"
                };
                eprintln!("[mlx-native] tensor_mm_id probe: {}", outcome);
            }
            compiled
        }
    });
    *available
}
/// Validates the shape parameters and buffer sizes shared by
/// [`quantized_matmul_id_ggml`] and [`quantized_matmul_id_ggml_pooled`].
///
/// `caller` is the public function name and is used as the prefix of every
/// error message.
///
/// The weight check uses the real extent addressed through `expert_stride`
/// (`(n_experts - 1) * expert_stride + per_expert_bytes`), not
/// `n_experts * per_expert_bytes`, so a padded stride cannot put the last
/// expert past the end of the buffer.
fn validate_id_matmul_args(
    caller: &str,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    ids: &MlxBuffer,
    output: &MlxBuffer,
    params: &GgmlQuantizedMatmulIdParams,
) -> Result<()> {
    let qk = params.ggml_type.block_values();
    let block_bytes = params.ggml_type.block_bytes();

    if params.n_tokens == 0 || params.k == 0 || params.n == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: n_tokens, K, and N must all be > 0",
            caller
        )));
    }
    if params.top_k == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: top_k must be > 0",
            caller
        )));
    }
    if params.n_experts == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: n_experts must be > 0",
            caller
        )));
    }
    if params.k % qk != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: K ({}) must be divisible by block QK ({})",
            caller, params.k, qk
        )));
    }

    let expected_input_bytes =
        (params.n_tokens as usize) * (params.k as usize) * DType::F32.size_of();
    if input.byte_len() < expected_input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "{}: input buffer too small: expected {} bytes for [{} x {}] f32, got {}",
            caller,
            expected_input_bytes,
            params.n_tokens,
            params.k,
            input.byte_len()
        )));
    }

    let blocks_per_row = params.k / qk;
    let per_expert_bytes =
        (params.n as usize) * (blocks_per_row as usize) * (block_bytes as usize);
    if params.expert_stride < per_expert_bytes as u64 {
        return Err(MlxError::InvalidArgument(format!(
            "{}: expert_stride ({}) < per_expert_bytes ({})",
            caller, params.expert_stride, per_expert_bytes
        )));
    }

    // The last expert starts at (n_experts - 1) * expert_stride and spans
    // per_expert_bytes. n_experts > 0 was checked above, so the subtraction
    // cannot underflow; saturating ops keep absurd strides from wrapping.
    let required_weight_bytes = params
        .expert_stride
        .saturating_mul(u64::from(params.n_experts - 1))
        .saturating_add(per_expert_bytes as u64);
    if (weight.byte_len() as u64) < required_weight_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "{}: weight buffer too small: expected {} bytes for {} experts, got {}",
            caller,
            required_weight_bytes,
            params.n_experts,
            weight.byte_len()
        )));
    }

    let total_rows = (params.n_tokens as usize) * (params.top_k as usize);
    let expected_ids_bytes = total_rows * DType::U32.size_of();
    if ids.byte_len() < expected_ids_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "{}: ids buffer too small: expected {} bytes for [{} * {}] u32, got {}",
            caller,
            expected_ids_bytes,
            params.n_tokens,
            params.top_k,
            ids.byte_len()
        )));
    }

    let expected_output_bytes = total_rows * (params.n as usize) * DType::F32.size_of();
    if output.byte_len() < expected_output_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "{}: output buffer too small: expected {} bytes for [{} x {}] f32, got {}",
            caller,
            expected_output_bytes,
            total_rows,
            params.n,
            output.byte_len()
        )));
    }

    Ok(())
}

/// Returns `true` when a request should take the mat-mat (`dispatch_id_mm*`)
/// path rather than the mat-vec path: more tokens than the routing threshold,
/// `top_k` of 1 or 8, and `k >= 32`.
fn should_route_to_id_mm(params: &GgmlQuantizedMatmulIdParams) -> bool {
    params.n_tokens > mm_id_routing_threshold()
        && matches!(params.top_k, 1 | 8)
        && params.k >= 32
}

/// Encodes an expert-indexed quantized matmul onto `encoder`.
///
/// Expected buffers (all sizes are validated before anything is encoded):
/// * `input`:  at least `n_tokens * k` f32 values;
/// * `weight`: `n_experts` quantized matrices spaced `expert_stride` bytes apart;
/// * `ids`:    at least `n_tokens * top_k` u32 expert ids;
/// * `output`: at least `n_tokens * top_k * n` f32 values.
///
/// Large batches (see [`MM_ID_ROUTING_THRESHOLD`]) with `top_k` of 1 or 8 are
/// routed to the mat-mat path, which allocates its scratch buffers on every
/// call; everything else uses the mat-vec path. Setting `HF2Q_LOG_MM_ID_ROUTE`
/// logs each mat-mat routing decision to stderr.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when a dimension is zero, `k` is not a
/// multiple of the block size, `expert_stride` is smaller than one expert, or
/// any buffer is too small. Errors from pipeline lookup and scratch
/// allocation are propagated.
#[allow(clippy::too_many_arguments)]
pub fn quantized_matmul_id_ggml(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    ids: &MlxBuffer,
    output: &MlxBuffer,
    params: &GgmlQuantizedMatmulIdParams,
) -> Result<()> {
    validate_id_matmul_args("quantized_matmul_id_ggml", input, weight, ids, output, params)?;

    if should_route_to_id_mm(params) {
        if std::env::var("HF2Q_LOG_MM_ID_ROUTE").is_ok() {
            eprintln!(
                "[mlx-native adr-022 AC-4] dispatch_id_mm engaged: type={:?} n_tokens={} top_k={} k={} n={} n_experts={}",
                params.ggml_type,
                params.n_tokens,
                params.top_k,
                params.k,
                params.n,
                params.n_experts,
            );
        }
        return dispatch_id_mm(
            encoder, registry, device, input, weight, ids, output, params,
        );
    }

    dispatch_id_mv(encoder, registry, device, input, weight, ids, output, params)
}

/// Same operation and validation as [`quantized_matmul_id_ggml`], but the
/// mat-mat path reuses the caller-owned `scratch` buffers instead of
/// allocating new ones on every call.
///
/// `scratch` is only touched when the mat-mat path is taken; it must then
/// have capacity for `params.n_experts` and `params.n_tokens`.
///
/// # Errors
/// Same as [`quantized_matmul_id_ggml`], plus `MlxError::InvalidArgument`
/// when `scratch` is too small for the request.
#[allow(clippy::too_many_arguments)]
pub fn quantized_matmul_id_ggml_pooled(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    ids: &MlxBuffer,
    output: &MlxBuffer,
    scratch: &mut IdMmScratch,
    params: &GgmlQuantizedMatmulIdParams,
) -> Result<()> {
    validate_id_matmul_args(
        "quantized_matmul_id_ggml_pooled",
        input,
        weight,
        ids,
        output,
        params,
    )?;

    if should_route_to_id_mm(params) {
        if std::env::var("HF2Q_LOG_MM_ID_ROUTE").is_ok() {
            eprintln!(
                "[mlx-native adr-022 AC-4 pooled] dispatch_id_mm_pooled engaged: type={:?} n_tokens={} top_k={} k={} n={} n_experts={}",
                params.ggml_type,
                params.n_tokens,
                params.top_k,
                params.k,
                params.n,
                params.n_experts,
            );
        }
        return dispatch_id_mm_pooled(
            encoder, registry, device, input, weight, ids, output, scratch, params,
        );
    }

    dispatch_id_mv(encoder, registry, device, input, weight, ids, output, params)
}
/// Default token count above which expert-indexed matmuls are routed to the
/// mat-mat kernels instead of the mat-vec kernels.
pub const MM_ID_ROUTING_THRESHOLD: u32 = 32;

/// Effective routing threshold for this process.
///
/// Resolved once and cached: if `HF2Q_MM_ID_ROUTING_THRESHOLD` is set and
/// parses as a `u32`, that value is used; otherwise (unset or unparsable) the
/// default [`MM_ID_ROUTING_THRESHOLD`] applies. An accepted override is
/// reported on stderr when `MLX_LOG_TENSOR_PROBE` is set.
fn mm_id_routing_threshold() -> u32 {
    static CACHED: std::sync::OnceLock<u32> = std::sync::OnceLock::new();

    let resolved = CACHED.get_or_init(|| {
        let requested = match std::env::var("HF2Q_MM_ID_ROUTING_THRESHOLD") {
            Ok(raw) => raw.parse::<u32>().ok(),
            Err(_) => None,
        };

        match requested {
            Some(v) => {
                if std::env::var("MLX_LOG_TENSOR_PROBE").is_ok() {
                    eprintln!(
                        "[mlx-native] mm_id_routing_threshold: OVERRIDE via HF2Q_MM_ID_ROUTING_THRESHOLD={v} (default {})",
                        MM_ID_ROUTING_THRESHOLD
                    );
                }
                v
            }
            None => MM_ID_ROUTING_THRESHOLD,
        }
    });
    *resolved
}
/// Encodes the mat-vec variant of the expert-indexed quantized matmul.
///
/// One threadgroup row is launched per output row (`n_tokens * top_k`), and
/// `ceil(n / align)` threadgroups per row, where `align` and the threadgroup
/// shape depend on the quantization format and on the optional `_nr2`
/// kernels:
/// * Q6_K uses `kernel_mul_mv_id_q6_K_f32_nr2` unless disabled through
///   `HF2Q_Q6K_ID_MV_NR2`;
/// * Q8_0 uses `kernel_mul_mv_id_q8_0_f32_nr2` only when
///   `HF2Q_Q8_0_ID_MV_NR2` enables it. That kernel also gets a threadgroup
///   memory allocation of `2 * 32 * size_of::<f32>()` bytes.
///
/// Buffer sizes are NOT validated here; callers are expected to have done so.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` for `F32`, `F16` and `I16` weights.
/// This is checked before any pipeline lookup, so callers see this error
/// rather than a lookup failure for the `"unsupported"` sentinel name.
/// Pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
fn dispatch_id_mv(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    ids: &MlxBuffer,
    output: &MlxBuffer,
    params: &GgmlQuantizedMatmulIdParams,
) -> Result<()> {
    // Default launch geometry per format. Doing this first rejects
    // unsupported formats before the registry is asked for a kernel.
    let (base_nth0, base_nth1, base_align) = match params.ggml_type {
        GgmlType::Q4_0 | GgmlType::Q8_0 | GgmlType::Q5_1 | GgmlType::IQ4_NL => {
            (8u64, 8u64, 8usize)
        }
        GgmlType::Q4_K | GgmlType::Q5_K | GgmlType::Q6_K => (2u64, 32u64, 2usize),
        GgmlType::F32 | GgmlType::F16 | GgmlType::I16 => {
            return Err(MlxError::InvalidArgument(format!(
                "quantized_matmul_id_ggml does not support {:?}",
                params.ggml_type
            )));
        }
    };

    let total_rows = (params.n_tokens as usize) * (params.top_k as usize);

    let use_q6k_id_nr2 = matches!(params.ggml_type, GgmlType::Q6_K)
        && cached_env_default_true(&CACHED_Q6K_ID_MV_NR2, "HF2Q_Q6K_ID_MV_NR2");
    let use_q8_0_id_nr2 = matches!(params.ggml_type, GgmlType::Q8_0)
        && cached_env_eq_one(&CACHED_Q8_0_ID_MV_NR2, "HF2Q_Q8_0_ID_MV_NR2");

    let kernel_name = if use_q6k_id_nr2 {
        "kernel_mul_mv_id_q6_K_f32_nr2"
    } else if use_q8_0_id_nr2 {
        "kernel_mul_mv_id_q8_0_f32_nr2"
    } else {
        params.ggml_type.id_kernel_name()
    };
    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;

    let gpu_params = GgmlMatvecIdGpuParams {
        ne00: params.k as i64,
        ne01: params.n as i64,
        ne02: 1,
        ne10: params.k as i64,
        ne12: 1,
        ne0: params.n as i64,
        ne1: total_rows as i64,
        r2: 1,
        r3: 1,
        top_k: params.top_k,
        n_tokens: params.n_tokens,
        expert_stride: params.expert_stride as i64,
    };

    // The `_nr2` kernels override the default geometry: the Q6_K variant
    // keeps the threadgroup shape but uses align 4; the Q8_0 variant uses a
    // 32x4 threadgroup and align 2. The two flags are mutually exclusive
    // because they are tied to different formats.
    let (nth0, nth1, align) = if use_q6k_id_nr2 {
        (base_nth0, base_nth1, 4usize)
    } else if use_q8_0_id_nr2 {
        (32u64, 4u64, 2usize)
    } else {
        (base_nth0, base_nth1, base_align)
    };

    let threadgroups = metal::MTLSize::new(
        div_ceil(params.n as usize, align) as u64,
        total_rows as u64,
        1,
    );
    let threads_per_tg = metal::MTLSize::new(nth0, nth1, 1);

    if use_q8_0_id_nr2 {
        let smem_bytes: u64 = 2 * 32 * std::mem::size_of::<f32>() as u64;
        encoder.encode_threadgroups_with_args_and_shared(
            pipeline,
            &[
                (0, KernelArg::Buffer(weight)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(output)),
                (3, KernelArg::Buffer(ids)),
                (4, KernelArg::Bytes(as_bytes(&gpu_params))),
            ],
            &[(0, smem_bytes)],
            threadgroups,
            threads_per_tg,
        );
    } else {
        encoder.encode_threadgroups_with_args(
            pipeline,
            &[
                (0, KernelArg::Buffer(weight)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(output)),
                (3, KernelArg::Buffer(ids)),
                (4, KernelArg::Bytes(as_bytes(&gpu_params))),
            ],
            threadgroups,
            threads_per_tg,
        );
    }
    Ok(())
}
/// Builds a [`DispatchRecord`] describing a single-token
/// `kernel_mul_mv_id_q6_K_f32_nr2` dispatch.
///
/// Returns `Ok(None)` when the `HF2Q_Q6K_ID_MV_NR2` flag disables the `_nr2`
/// kernel. Otherwise the record carries the pipeline, the serialized kernel
/// parameters (`n_tokens = 1`, `ne1 = top_k`), a grid of
/// `ceil(n / 4) x top_k` threadgroups of `2 x 32` threads, buffer slots
/// `0..=3` and the parameter bytes at slot 4.
///
/// # Errors
/// Propagates pipeline lookup failures from `registry`.
pub fn build_q6k_id_nr2_m1_record(
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    n: u32,
    k: u32,
    top_k: u32,
    expert_stride: u64,
) -> Result<Option<DispatchRecord>> {
    const KERNEL: &str = "kernel_mul_mv_id_q6_K_f32_nr2";
    // Divisor applied to `n` to obtain the grid width.
    const ALIGN: usize = 4;

    let nr2_enabled = cached_env_default_true(&CACHED_Q6K_ID_MV_NR2, "HF2Q_Q6K_ID_MV_NR2");
    if !nr2_enabled {
        return Ok(None);
    }

    let pipeline = registry.get_pipeline(KERNEL, device)?.clone();

    let (out_dim, in_dim) = (n as i64, k as i64);
    let kernel_args = GgmlMatvecIdGpuParams {
        ne00: in_dim,
        ne01: out_dim,
        ne02: 1,
        ne10: in_dim,
        ne12: 1,
        ne0: out_dim,
        ne1: top_k as i64,
        r2: 1,
        r3: 1,
        top_k,
        n_tokens: 1,
        expert_stride: expert_stride as i64,
    };

    let grid = metal::MTLSize::new(div_ceil(n as usize, ALIGN) as u64, top_k as u64, 1);
    let tg_shape = metal::MTLSize::new(2, 32, 1);

    let record = DispatchRecord {
        pipeline,
        threadgroups: grid,
        threads_per_tg: tg_shape,
        threadgroup_mem: Vec::new(),
        params_bytes: as_bytes(&kernel_args).to_vec(),
        params_slot: 4,
        buffer_slots: vec![0, 1, 2, 3],
        op_kind: CapturedOpKind::Other,
        kernel_name: KERNEL.to_string(),
    };
    Ok(Some(record))
}
/// Builds a [`DispatchRecord`] for a decode-time `kernel_mul_mv_id_q8_0_f32`
/// dispatch.
///
/// Returns `Ok(None)` when `HF2Q_Q8_0_ID_MV_NR2` selects the `_nr2` kernel,
/// since this record only describes the default Q8_0 kernel. Note the
/// parameter encoding: the kernel is told `top_k = 1` and
/// `n_tokens = real_top_k`, with `ne1 = real_top_k` output rows.
/// The grid is `ceil(n / 8) x real_top_k` threadgroups of `8 x 8` threads;
/// buffers use slots `0..=3` and the parameter bytes go to slot 4.
///
/// # Errors
/// Propagates pipeline lookup failures from `registry`.
pub fn build_q8_0_id_decode_record(
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    n: u32,
    k: u32,
    real_top_k: u32,
    expert_stride: u64,
) -> Result<Option<DispatchRecord>> {
    const KERNEL: &str = "kernel_mul_mv_id_q8_0_f32";
    // Divisor applied to `n` to obtain the grid width.
    const ALIGN: usize = 8;

    let nr2_selected = cached_env_eq_one(&CACHED_Q8_0_ID_MV_NR2, "HF2Q_Q8_0_ID_MV_NR2");
    if nr2_selected {
        return Ok(None);
    }

    let pipeline = registry.get_pipeline(KERNEL, device)?.clone();

    let (out_dim, in_dim) = (n as i64, k as i64);
    let kernel_args = GgmlMatvecIdGpuParams {
        ne00: in_dim,
        ne01: out_dim,
        ne02: 1,
        ne10: in_dim,
        ne12: 1,
        ne0: out_dim,
        ne1: real_top_k as i64,
        r2: 1,
        r3: 1,
        top_k: 1,
        n_tokens: real_top_k,
        expert_stride: expert_stride as i64,
    };

    let grid = metal::MTLSize::new(div_ceil(n as usize, ALIGN) as u64, real_top_k as u64, 1);
    let tg_shape = metal::MTLSize::new(8, 8, 1);

    let record = DispatchRecord {
        pipeline,
        threadgroups: grid,
        threads_per_tg: tg_shape,
        threadgroup_mem: Vec::new(),
        params_bytes: as_bytes(&kernel_args).to_vec(),
        params_slot: 4,
        buffer_slots: vec![0, 1, 2, 3],
        op_kind: CapturedOpKind::Other,
        kernel_name: KERNEL.to_string(),
    };
    Ok(Some(record))
}
#[allow(clippy::too_many_arguments)]
pub fn quantized_matmul_id_swiglu_q4_0(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
gate: &MlxBuffer,
up: &MlxBuffer,
weight: &MlxBuffer,
ids: &MlxBuffer,
output: &MlxBuffer,
params: &GgmlQuantizedMatmulIdParams,
) -> Result<()> {
if params.ggml_type != GgmlType::Q4_0 {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_swiglu_q4_0: expected Q4_0, got {:?}",
params.ggml_type
)));
}
let qk = GgmlType::Q4_0.block_values();
if params.n_tokens == 0 || params.k == 0 || params.n == 0 {
return Err(MlxError::InvalidArgument(
"quantized_matmul_id_swiglu_q4_0: n_tokens, K, and N must all be > 0".into(),
));
}
if params.top_k == 0 {
return Err(MlxError::InvalidArgument(
"quantized_matmul_id_swiglu_q4_0: top_k must be > 0".into(),
));
}
if params.k % qk != 0 {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_swiglu_q4_0: K ({}) must be divisible by block QK ({})",
params.k, qk
)));
}
let total_rows = (params.n_tokens as usize) * (params.top_k as usize);
let expected_in_bytes = total_rows * (params.k as usize) * DType::F32.size_of();
if gate.byte_len() < expected_in_bytes {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_swiglu_q4_0: gate buffer too small: expected {} bytes, got {}",
expected_in_bytes, gate.byte_len()
)));
}
if up.byte_len() < expected_in_bytes {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_swiglu_q4_0: up buffer too small: expected {} bytes, got {}",
expected_in_bytes, up.byte_len()
)));
}
let expected_out_bytes = total_rows * (params.n as usize) * DType::F32.size_of();
if output.byte_len() < expected_out_bytes {
return Err(MlxError::InvalidArgument(format!(
"quantized_matmul_id_swiglu_q4_0: output buffer too small: expected {} bytes, got {}",
expected_out_bytes, output.byte_len()
)));
}
let pipeline = registry.get_pipeline(
"kernel_mul_mv_id_q4_0_f32_swiglu",
device.metal_device(),
)?;
let gpu_params = GgmlMatvecIdGpuParams {
ne00: params.k as i64,
ne01: params.n as i64,
ne02: 1,
ne10: params.k as i64,
ne12: 1,
ne0: params.n as i64,
ne1: total_rows as i64,
r2: 1,
r3: 1,
top_k: params.top_k,
n_tokens: params.n_tokens,
expert_stride: params.expert_stride as i64,
};
let (nth0, nth1, align) = (8u64, 8u64, 8usize);
let n = params.n as usize;
let m = total_rows;
let threadgroups = metal::MTLSize::new(
div_ceil(n, align) as u64,
m as u64,
1,
);
let threads_per_tg = metal::MTLSize::new(nth0, nth1, 1);
encoder.encode_threadgroups_with_args(
pipeline,
&[
(0, KernelArg::Buffer(weight)),
(1, KernelArg::Buffer(gate)),
(2, KernelArg::Buffer(up)),
(3, KernelArg::Buffer(output)),
(4, KernelArg::Buffer(ids)),
(5, KernelArg::Bytes(as_bytes(&gpu_params))),
],
threadgroups,
threads_per_tg,
);
Ok(())
}
/// Reusable scratch buffers for the expert-indexed mat-mat path.
///
/// Both buffers are passed to the `kernel_mul_mm_id_map0_*` kernel and then
/// to the mat-mat kernel. NOTE(review): `htpe`/`hids` presumably hold
/// per-expert token counts and per-expert token ids; confirm against the
/// shader source.
pub struct IdMmScratch {
    /// `n_experts` u32 entries.
    pub htpe: MlxBuffer,
    /// `n_experts * max_n_tokens` u32 entries.
    pub hids: MlxBuffer,
    /// Expert count the buffers were sized for.
    n_experts_cap: u32,
    /// Token count the buffers were sized for.
    n_tokens_cap: u32,
}

impl IdMmScratch {
    /// Allocates scratch buffers large enough for `n_experts` experts and up
    /// to `max_n_tokens` tokens per dispatch.
    ///
    /// # Errors
    /// Propagates allocation failures from `device`.
    pub fn alloc(
        device: &MlxDevice,
        n_experts: u32,
        max_n_tokens: u32,
    ) -> Result<Self> {
        let experts = n_experts as usize;
        let tokens = max_n_tokens as usize;

        let htpe = device.alloc_buffer(
            experts * DType::U32.size_of(),
            DType::U32,
            vec![experts],
        )?;
        let hids = device.alloc_buffer(
            experts * tokens * DType::U32.size_of(),
            DType::U32,
            vec![experts, tokens],
        )?;

        Ok(Self {
            htpe,
            hids,
            n_experts_cap: n_experts,
            n_tokens_cap: max_n_tokens,
        })
    }

    /// Checks that a request for `n_experts` / `n_tokens` fits in the
    /// capacity this scratch was allocated with.
    ///
    /// # Errors
    /// Returns `MlxError::InvalidArgument` naming the first dimension that
    /// exceeds its capacity (experts are checked before tokens).
    fn check_capacity(&self, n_experts: u32, n_tokens: u32) -> Result<()> {
        let experts_fit = n_experts <= self.n_experts_cap;
        if !experts_fit {
            return Err(MlxError::InvalidArgument(format!(
                "IdMmScratch: n_experts ({}) > cap ({})",
                n_experts, self.n_experts_cap,
            )));
        }

        let tokens_fit = n_tokens <= self.n_tokens_cap;
        if !tokens_fit {
            return Err(MlxError::InvalidArgument(format!(
                "IdMmScratch: n_tokens ({}) > cap ({})",
                n_tokens, self.n_tokens_cap,
            )));
        }

        Ok(())
    }
}
#[allow(clippy::too_many_arguments)]
fn dispatch_id_mm_pooled(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
ids: &MlxBuffer,
output: &MlxBuffer,
scratch: &mut IdMmScratch,
params: &GgmlQuantizedMatmulIdParams,
) -> Result<()> {
scratch.check_capacity(params.n_experts, params.n_tokens)?;
let dispatch = GgmlIdMmDispatchParams {
n_tokens: params.n_tokens,
top_k: params.top_k,
n: params.n,
k: params.k,
n_experts: params.n_experts,
expert_stride: params.expert_stride,
ggml_type: params.ggml_type,
};
dispatch_id_mm_for_test(
encoder, registry, device,
input, weight, ids,
&mut scratch.htpe, &mut scratch.hids, output, &dispatch,
)
}
/// Mat-mat path with per-call scratch: allocates an [`IdMmScratch`] sized for
/// exactly this request and forwards to [`dispatch_id_mm_pooled`].
///
/// NOTE(review): the scratch buffers are dropped when this function returns,
/// which may be before the encoded GPU work has run. This relies on the
/// encoder / command buffer keeping the underlying buffers alive; confirm.
///
/// # Errors
/// Propagates allocation failures and every error from
/// [`dispatch_id_mm_pooled`].
#[allow(clippy::too_many_arguments)]
fn dispatch_id_mm(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    ids: &MlxBuffer,
    output: &MlxBuffer,
    params: &GgmlQuantizedMatmulIdParams,
) -> Result<()> {
    let mut one_shot_scratch =
        IdMmScratch::alloc(device, params.n_experts, params.n_tokens)?;

    dispatch_id_mm_pooled(
        encoder,
        registry,
        device,
        input,
        weight,
        ids,
        output,
        &mut one_shot_scratch,
        params,
    )
}
/// Integer division of `a` by `b`, rounded up.
///
/// Computed as quotient plus a remainder correction rather than
/// `(a + b - 1) / b`, so it cannot overflow for large `a`.
///
/// # Panics
/// Panics if `b == 0` (division by zero).
fn div_ceil(a: usize, b: usize) -> usize {
    a / b + usize::from(a % b != 0)
}
/// Constant block passed as raw bytes to the `kernel_mul_mm_id_map0_*`
/// kernels.
///
/// `#[repr(C)]` + `Pod`: field order and widths are part of the contract with
/// the shader and must not be changed here.
/// NOTE(review): names look like ggml's `ne*` (element count) / `nb*` (byte
/// stride) convention; confirm against the kernel source.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlIdMmMap0GpuParams {
    /// Filled with `n` by `dispatch_id_mm_for_test`.
    ne10: i32,
    /// Filled with `top_k`.
    ne11: i32,
    /// Filled with 0.
    nb11: u64,
    /// Filled with 0.
    nb12: u64,
    /// Filled with `n_tokens`.
    ne21: i32,
    /// Filled with `top_k`.
    ne20: i32,
    /// Filled with `top_k * size_of::<u32>()`.
    nb21: u64,
}
/// Constant block passed as raw bytes to the `kernel_mul_mm_id_*` mat-mat
/// kernels.
///
/// `#[repr(C)]` + `Pod`: field order, widths and the explicit `_pad*` fields
/// are part of the contract with the shader and must not be changed here.
/// NOTE(review): names look like ggml's `ne*` (element count) / `nb*` (byte
/// stride) convention; confirm against the kernel source.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GgmlIdMmMmGpuParams {
    /// Filled with `k`.
    ne00: i32,
    /// Filled with `n_experts`.
    ne02: i32,
    /// Filled with the byte size of one quantized weight row.
    nb01: u64,
    /// Filled with `expert_stride`.
    nb02: u64,
    /// Filled with 0.
    nb03: u64,
    /// Filled with `top_k`.
    ne11: i32,
    /// Explicit padding so the following `u64` sits on an 8-byte boundary.
    _pad0: u32,
    /// Filled with `size_of::<f32>()`.
    nb10: u64,
    /// Filled with 0.
    nb11: u64,
    /// Filled with `k * size_of::<f32>()`.
    nb12: u64,
    /// Filled with 0.
    nb13: u64,
    /// Filled with `top_k`.
    ne20: i32,
    /// Filled with `n_tokens`.
    ne21: i32,
    /// Filled with `n`.
    ne0: i32,
    /// Filled with `top_k`.
    ne1: i32,
    /// Filled with 1.
    r2: i16,
    /// Filled with 1.
    r3: i16,
    /// Explicit tail padding; filled with 0.
    _pad1: u32,
}
/// Shape and layout description consumed by [`dispatch_id_mm_for_test`].
///
/// Carries the same fields as `GgmlQuantizedMatmulIdParams`.
#[derive(Debug, Clone, Copy)]
pub struct GgmlIdMmDispatchParams {
    /// Number of input rows.
    pub n_tokens: u32,
    /// Expert ids per token.
    pub top_k: u32,
    /// Output width.
    pub n: u32,
    /// Input row length.
    pub k: u32,
    /// Number of expert weight matrices in the weight buffer.
    pub n_experts: u32,
    /// Byte distance between consecutive experts in the weight buffer.
    pub expert_stride: u64,
    /// Quantization format of the weight buffer.
    pub ggml_type: GgmlType,
}

impl GgmlIdMmDispatchParams {
    /// Minimum byte size of the `htpe` scratch buffer: one u32 per expert.
    pub fn htpe_bytes(&self) -> usize {
        let experts = self.n_experts as usize;
        experts * DType::U32.size_of()
    }

    /// Minimum byte size of the `hids` scratch buffer: one u32 per
    /// (expert, token) pair.
    pub fn hids_bytes(&self) -> usize {
        let experts = self.n_experts as usize;
        let tokens = self.n_tokens as usize;
        experts * tokens * DType::U32.size_of()
    }
}
#[doc(hidden)]
#[allow(clippy::too_many_arguments)]
pub fn dispatch_id_mm_for_test(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
ids: &MlxBuffer,
htpe: &MlxBuffer,
hids: &MlxBuffer,
output: &MlxBuffer,
params: &GgmlIdMmDispatchParams,
) -> Result<()> {
let qk = params.ggml_type.block_values();
match params.ggml_type {
GgmlType::Q4_0
| GgmlType::Q8_0
| GgmlType::Q4_K
| GgmlType::Q5_K
| GgmlType::Q6_K
| GgmlType::Q5_1
| GgmlType::IQ4_NL => {}
other => {
return Err(MlxError::InvalidArgument(format!(
"dispatch_id_mm_for_test does not support {:?}", other
)));
}
}
if params.n_tokens == 0 || params.k == 0 || params.n == 0
|| params.top_k == 0 || params.n_experts == 0
{
return Err(MlxError::InvalidArgument(
"n_tokens, K, N, top_k, n_experts must all be > 0".into(),
));
}
if params.k % qk != 0 {
return Err(MlxError::InvalidArgument(format!(
"K ({}) must be divisible by block QK ({})", params.k, qk
)));
}
if params.top_k != 1 && params.top_k != 8 {
return Err(MlxError::InvalidArgument(format!(
"dispatch_id_mm_for_test: top_k {} has no map0 instantiation (need 1 or 8)",
params.top_k
)));
}
let blocks_per_row = params.k / qk;
let block_bytes = params.ggml_type.block_bytes();
let per_expert_bytes =
(params.n as usize) * (blocks_per_row as usize) * (block_bytes as usize);
if (params.expert_stride as usize) < per_expert_bytes {
return Err(MlxError::InvalidArgument(format!(
"expert_stride ({}) < per_expert_bytes ({})",
params.expert_stride, per_expert_bytes
)));
}
if weight.byte_len() < per_expert_bytes * params.n_experts as usize {
return Err(MlxError::InvalidArgument(
"dispatch_id_mm_for_test: weight buffer too small".into(),
));
}
if input.byte_len()
< (params.n_tokens as usize) * (params.k as usize) * DType::F32.size_of()
{
return Err(MlxError::InvalidArgument(
"dispatch_id_mm_for_test: input buffer too small".into(),
));
}
let total_rows = (params.n_tokens as usize) * (params.top_k as usize);
if ids.byte_len() < total_rows * DType::U32.size_of() {
return Err(MlxError::InvalidArgument(
"dispatch_id_mm_for_test: ids buffer too small".into(),
));
}
if output.byte_len() < total_rows * (params.n as usize) * DType::F32.size_of() {
return Err(MlxError::InvalidArgument(
"dispatch_id_mm_for_test: output buffer too small".into(),
));
}
if htpe.byte_len() < params.htpe_bytes() {
return Err(MlxError::InvalidArgument(
"dispatch_id_mm_for_test: htpe buffer too small".into(),
));
}
if hids.byte_len() < params.hids_bytes() {
return Err(MlxError::InvalidArgument(
"dispatch_id_mm_for_test: hids buffer too small".into(),
));
}
let map0_kernel_name = match params.top_k {
1 => "kernel_mul_mm_id_map0_ne20_1",
8 => "kernel_mul_mm_id_map0_ne20_8",
other => return Err(MlxError::InvalidArgument(format!(
"dispatch_id_mm_for_test: no map0 instantiation for top_k={}",
other
))),
};
let map0_pipeline = registry.get_pipeline(map0_kernel_name, device.metal_device())?;
let map0_params = GgmlIdMmMap0GpuParams {
ne10: params.n.try_into().map_err(|_| {
MlxError::InvalidArgument("N out of i32 range".into())
})?,
ne11: params.top_k as i32,
nb11: 0,
nb12: 0,
ne21: params.n_tokens as i32,
ne20: params.top_k as i32,
nb21: (params.top_k as u64) * (DType::U32.size_of() as u64),
};
let map0_shmem =
(params.n_experts as u64) * (params.top_k as u64) * std::mem::size_of::<u16>() as u64;
let map0_threadgroups = metal::MTLSize::new(1, 1, 1);
let map0_threads = metal::MTLSize::new(params.n_experts as u64, 1, 1);
encoder.encode_threadgroups_with_args_and_shared(
map0_pipeline,
&[
(0, KernelArg::Bytes(as_bytes(&map0_params))),
(1, KernelArg::Buffer(ids)),
(2, KernelArg::Buffer(htpe)),
(3, KernelArg::Buffer(hids)),
],
&[(0, map0_shmem)],
map0_threadgroups,
map0_threads,
);
encoder.memory_barrier();
let use_tensor = probe_tensor_mm_id(registry, device);
let mm_kernel_name = if use_tensor {
params.ggml_type.id_mm_tensor_kernel_name()
} else {
params.ggml_type.id_mm_kernel_name()
};
let mm_pipeline = registry.get_pipeline(mm_kernel_name, device.metal_device())?;
let nb01 = (blocks_per_row as u64) * (block_bytes as u64);
let row_bytes = (params.k as u64) * (DType::F32.size_of() as u64);
let mm_params = GgmlIdMmMmGpuParams {
ne00: params.k as i32,
ne02: params.n_experts as i32,
nb01,
nb02: params.expert_stride,
nb03: 0,
ne11: params.top_k as i32,
_pad0: 0,
nb10: DType::F32.size_of() as u64,
nb11: 0, nb12: row_bytes, nb13: 0,
ne20: params.top_k as i32,
ne21: params.n_tokens as i32,
ne0: params.n as i32,
ne1: params.top_k as i32,
r2: 1,
r3: 1,
_pad1: 0,
};
const NR0: u64 = 64;
const NR1: u64 = 32;
const THREADS_PER_TG: u64 = 128;
let mm_threadgroups = metal::MTLSize::new(
(params.n_tokens as u64 + NR1 - 1) / NR1,
(params.n as u64 + NR0 - 1) / NR0,
params.n_experts as u64,
);
let mm_threads = metal::MTLSize::new(THREADS_PER_TG, 1, 1);
const MM_SHMEM_BYTES: u64 = 8192;
encoder.encode_threadgroups_with_args_and_shared(
mm_pipeline,
&[
(0, KernelArg::Bytes(as_bytes(&mm_params))),
(1, KernelArg::Buffer(weight)),
(2, KernelArg::Buffer(input)),
(3, KernelArg::Buffer(htpe)),
(4, KernelArg::Buffer(hids)),
(5, KernelArg::Buffer(output)),
],
&[(0, MM_SHMEM_BYTES)],
mm_threadgroups,
mm_threads,
);
Ok(())
}