use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{encode_threadgroups_with_args_and_shared, KernelArg};
/// Metal Shading Language source for the Hadamard-transform KV quantization
/// kernels, embedded into the binary at compile time.
pub static HADAMARD_QUANTIZE_KV_SHADER_SOURCE: &str =
    include_str!("../shaders/hadamard_quantize_kv.metal");
/// Registers the Hadamard KV quantization shader source with `registry`
/// under the name `"hadamard_quantize_kv"`.
pub fn register(registry: &mut KernelRegistry) {
    let source_name = "hadamard_quantize_kv";
    registry.register_source(source_name, HADAMARD_QUANTIZE_KV_SHADER_SOURCE);
}
/// Parameter block passed to the Metal kernel as raw bytes
/// (via `bytemuck::bytes_of`).
///
/// `#[repr(C)]` keeps the field order and padding matching the corresponding
/// struct in `hadamard_quantize_kv.metal` — do not reorder or resize fields
/// without updating the shader side as well.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct HadamardQuantizeParams {
    // Per-head vector length; validated to be a power of two, even, <= 4096.
    head_dim: u32,
    // Number of KV heads; one threadgroup is launched per head.
    num_kv_heads: u32,
    // Cache slot being written — presumably wrapped modulo capacity for
    // sliding caches (handled in the shader; TODO confirm).
    write_pos: u32,
    // Number of token slots per head in the packed cache.
    cache_capacity: u32,
    // Boolean encoded as u32 for the C-compatible layout: 1 = sliding window.
    is_sliding: u32,
}
/// Encodes a GPU dispatch that Hadamard-transforms one token's per-head
/// vectors in `src`, quantizes them to 4-bit nibbles into `packed`, and
/// writes per-head scale/norm values into `norms` at slot `write_pos`.
///
/// One threadgroup is launched per KV head. For `head_dim` 256/512 a
/// specialized 32-thread fast kernel is selected; otherwise the generic
/// kernel runs with `head_dim` threads and `2 * head_dim` f32s of
/// threadgroup memory.
///
/// Buffer size contracts (validated here; the shader is not visible from
/// this file):
/// - `src`:    at least `num_kv_heads * head_dim` elements
/// - `packed`: at least `num_kv_heads * cache_capacity * head_dim / 2` bytes
/// - `norms`:  at least `num_kv_heads * cache_capacity` elements
///
/// # Errors
/// Returns `MlxError::InvalidArgument` if `head_dim` is not a power of two,
/// is odd, or exceeds 4096; if `cache_capacity` is zero; if `write_pos` is
/// out of range for a non-sliding cache; or if any buffer is too small.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_hadamard_quantize_kv(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    packed: &MlxBuffer,
    norms: &MlxBuffer,
    num_kv_heads: u32,
    head_dim: u32,
    cache_capacity: u32,
    write_pos: u32,
    is_sliding: bool,
) -> Result<()> {
    // Degenerate dispatch: nothing to encode.
    if num_kv_heads == 0 || head_dim == 0 {
        return Ok(());
    }
    if !head_dim.is_power_of_two() {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: head_dim must be a power of two, got {}",
            head_dim
        )));
    }
    // 2 * head_dim * 4 bytes of threadgroup memory must fit in Metal's 32 KB
    // threadgroup-memory limit: 2 * 4096 * 4 = 32768.
    if head_dim > 4096 {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: head_dim {} exceeds Metal 32 KB threadgroup limit \
(max 4096 for 2x f32 shared memory)",
            head_dim
        )));
    }
    // A power of two is even except for 1, so this only rejects head_dim == 1,
    // which cannot be nibble-packed (two values per byte).
    if head_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: head_dim must be even for nibble packing, got {}",
            head_dim
        )));
    }
    // FIX: previously a sliding cache with cache_capacity == 0 passed every
    // check below (the required byte/element counts degenerate to 0 and the
    // write_pos bound is skipped) and reached the GPU, where the kernel would
    // index/wrap by a zero capacity.
    if cache_capacity == 0 {
        return Err(MlxError::InvalidArgument(
            "hadamard_quantize_kv: cache_capacity must be non-zero".to_string(),
        ));
    }
    // Sliding caches wrap write_pos (presumably modulo cache_capacity in the
    // shader — TODO confirm), so the bound applies only to the global cache.
    if !is_sliding && write_pos >= cache_capacity {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: global cache write_pos({}) >= cache_capacity({})",
            write_pos, cache_capacity
        )));
    }
    // u64 arithmetic so the size requirements cannot overflow u32.
    let required_src = (num_kv_heads as u64) * (head_dim as u64);
    if (src.element_count() as u64) < required_src {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: src has {} elements but need {} \
(num_kv_heads={} * head_dim={})",
            src.element_count(),
            required_src,
            num_kv_heads,
            head_dim,
        )));
    }
    // Two 4-bit values per byte, hence head_dim / 2 bytes per vector.
    let required_packed_bytes =
        (num_kv_heads as u64) * (cache_capacity as u64) * (head_dim as u64 / 2);
    if (packed.byte_len() as u64) < required_packed_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: packed buffer has {} bytes but need {} \
(num_kv_heads={} * cache_capacity={} * head_dim/2={})",
            packed.byte_len(),
            required_packed_bytes,
            num_kv_heads,
            cache_capacity,
            head_dim / 2,
        )));
    }
    let required_norms = (num_kv_heads as u64) * (cache_capacity as u64);
    if (norms.element_count() as u64) < required_norms {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: norms buffer has {} elements but need {} \
(num_kv_heads={} * cache_capacity={})",
            norms.element_count(),
            required_norms,
            num_kv_heads,
            cache_capacity,
        )));
    }
    // Specialized kernels exist for the two common head dims; everything else
    // falls through to the generic shared-memory kernel.
    let kernel_name = match head_dim {
        256 => "hadamard_quantize_kv_fast_d256",
        512 => "hadamard_quantize_kv_fast_d512",
        _ => "hadamard_quantize_kv",
    };
    let pipeline = registry.get_pipeline(kernel_name, device)?;
    let params = HadamardQuantizeParams {
        head_dim,
        num_kv_heads,
        write_pos,
        cache_capacity,
        is_sliding: u32::from(is_sliding),
    };
    let params_bytes = bytemuck::bytes_of(&params);
    if kernel_name.starts_with("hadamard_quantize_kv_fast") {
        // Fast specializations run one 32-thread group per head and need no
        // explicitly-bound threadgroup memory.
        use super::encode_helpers::{encode_threadgroups_with_args, KernelArg as KA};
        encode_threadgroups_with_args(
            encoder,
            pipeline,
            &[
                (0, KA::Buffer(src)),
                (1, KA::Buffer(packed)),
                (2, KA::Buffer(norms)),
                (3, KA::Bytes(params_bytes)),
            ],
            MTLSize::new(num_kv_heads as u64, 1, 1),
            MTLSize::new(32, 1, 1),
        );
    } else {
        // Generic kernel: head_dim threads per group plus two f32 scratch
        // arrays (2 * head_dim * sizeof(f32)) of threadgroup memory.
        // NOTE(review): head_dim threads per threadgroup can exceed a
        // pipeline's maxTotalThreadsPerThreadgroup (commonly 1024) for
        // head_dim in (1024, 4096] — confirm the compiled pipeline reports a
        // sufficient limit for large head dims.
        let shared_mem_bytes = 2u64 * (head_dim as u64) * 4;
        encode_threadgroups_with_args_and_shared(
            encoder,
            pipeline,
            &[
                (0, KernelArg::Buffer(src)),
                (1, KernelArg::Buffer(packed)),
                (2, KernelArg::Buffer(norms)),
                (3, KernelArg::Bytes(params_bytes)),
            ],
            &[(0, shared_mem_bytes)],
            MTLSize::new(num_kv_heads as u64, 1, 1),
            MTLSize::new(head_dim as u64, 1, 1),
        );
    }
    Ok(())
}