use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{encode_threadgroups_with_args_and_shared, KernelArg};
/// Metal shader source text, embedded at compile time from
/// `../shaders/hadamard_quantize_kv.metal` (path relative to this file).
///
/// [`register`] hands this string to the kernel registry under the name
/// `"hadamard_quantize_kv"`.
pub static HADAMARD_QUANTIZE_KV_SHADER_SOURCE: &str =
include_str!("../shaders/hadamard_quantize_kv.metal");
/// Registers the embedded shader source with `registry` under the source name
/// `"hadamard_quantize_kv"`.
pub fn register(registry: &mut KernelRegistry) {
    // Name under which the shader source text is stored in the registry.
    const SOURCE_NAME: &str = "hadamard_quantize_kv";

    registry.register_source(SOURCE_NAME, HADAMARD_QUANTIZE_KV_SHADER_SOURCE);
}
/// Parameter block passed by value (as raw bytes via `bytemuck::bytes_of`) to the
/// `hadamard_quantize_kv*` and `hadamard_quantize_kv_fast*` kernels.
///
/// `#[repr(C)]` fixes the field order and layout; the struct is seven 4-byte fields.
/// NOTE(review): presumably this mirrors a params struct declared in the Metal shader —
/// keep field order and types in sync with the shader; verify before editing.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct HadamardQuantizeParams {
/// Head dimension supplied by the caller.
head_dim: u32,
/// Number of KV heads; also the number of threadgroups dispatched along X.
num_kv_heads: u32,
/// Cache position this dispatch writes to.
write_pos: u32,
/// Capacity of the cache in positions.
cache_capacity: u32,
/// 1 when the caller passed `is_sliding == true`, otherwise 0.
is_sliding: u32,
/// Scale factor forwarded to the kernel; callers passing `None` get `1.0`.
scale_factor_d512: f32,
/// 1 when an RMS scratch buffer was supplied by the caller, otherwise 0.
rms_probe_enabled: u32,
}
/// Encodes one Hadamard-quantize dispatch that writes a single cache position.
///
/// Kernel selection depends on `head_dim`:
/// * `256` → `hadamard_quantize_kv_fast_d256`, `512` → `hadamard_quantize_kv_fast_d512`;
///   both are dispatched with 32 threads per threadgroup and take a fifth buffer
///   (index 4): `rms_scratch` when supplied, otherwise `norms` is bound there as a
///   placeholder.
/// * anything else → `hadamard_quantize_kv`, dispatched with `head_dim` threads per
///   threadgroup and `2 * head_dim * 4` bytes of threadgroup memory at index 0.
///
/// In every case `num_kv_heads` threadgroups are dispatched, with `src`, `packed`,
/// `norms` bound at buffer indices 0, 1, 2 and the parameter block at index 3.
///
/// # Arguments
/// * `src` – input buffer; must hold at least `num_kv_heads * head_dim` elements.
/// * `packed` – output buffer; must hold at least
///   `num_kv_heads * cache_capacity * head_dim / 2` bytes.
/// * `norms` – output buffer; must hold at least
///   `num_kv_heads * cache_capacity * max(head_dim / 256, 1)` elements.
/// * `write_pos` – cache position to write; must be `< cache_capacity` unless
///   `is_sliding` is set.
/// * `scale_factor_d512` – forwarded to the kernel; `None` is sent as `1.0`.
/// * `rms_scratch` – when `Some`, `rms_probe_enabled` is set to 1 in the parameter block.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when `head_dim` is not a power of two, exceeds
/// 4096, or is odd; when a non-sliding `write_pos` is out of range; or when any buffer
/// is too small. Errors from `registry.get_pipeline` are propagated.
///
/// Returns `Ok(())` without encoding anything when `num_kv_heads` or `head_dim` is zero.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_hadamard_quantize_kv(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    packed: &MlxBuffer,
    norms: &MlxBuffer,
    num_kv_heads: u32,
    head_dim: u32,
    cache_capacity: u32,
    write_pos: u32,
    is_sliding: bool,
    scale_factor_d512: Option<f32>,
    rms_scratch: Option<&MlxBuffer>,
) -> Result<()> {
    use super::encode_helpers::encode_threadgroups_with_args;

    if num_kv_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    // --- Shape validation -------------------------------------------------------------
    if !head_dim.is_power_of_two() {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: head_dim must be a power of two, got {}",
            head_dim
        )));
    }
    // NOTE(review): the generic kernel is dispatched with `head_dim` threads per
    // threadgroup; confirm the pipeline's max threads-per-threadgroup admits values this
    // large, since this guard only accounts for threadgroup memory.
    if head_dim > 4096 {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: head_dim {} exceeds Metal 32 KB threadgroup limit \
             (max 4096 for 2x f32 shared memory)",
            head_dim
        )));
    }
    // Only reachable for head_dim == 1 (the sole odd power of two).
    if head_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: head_dim must be even for nibble packing, got {}",
            head_dim
        )));
    }
    if !is_sliding && write_pos >= cache_capacity {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: global cache write_pos({}) >= cache_capacity({})",
            write_pos, cache_capacity
        )));
    }

    // --- Buffer capacity validation (all arithmetic widened to u64) ---------------------
    let heads = u64::from(num_kv_heads);
    let dim = u64::from(head_dim);
    let capacity = u64::from(cache_capacity);

    let required_src = heads * dim;
    if (src.element_count() as u64) < required_src {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: src has {} elements but need {} \
             (num_kv_heads={} * head_dim={})",
            src.element_count(),
            required_src,
            num_kv_heads,
            head_dim,
        )));
    }

    let required_packed_bytes = heads * capacity * (dim / 2);
    if (packed.byte_len() as u64) < required_packed_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: packed buffer has {} bytes but need {} \
             (num_kv_heads={} * cache_capacity={} * head_dim/2={})",
            packed.byte_len(),
            required_packed_bytes,
            num_kv_heads,
            cache_capacity,
            head_dim / 2,
        )));
    }

    let norms_per_pos = u64::from((head_dim / 256).max(1));
    let required_norms = heads * capacity * norms_per_pos;
    if (norms.element_count() as u64) < required_norms {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv: norms buffer has {} elements but need {} \
             (num_kv_heads={} * cache_capacity={} * norms_per_pos={})",
            norms.element_count(),
            required_norms,
            num_kv_heads,
            cache_capacity,
            norms_per_pos,
        )));
    }

    // --- Kernel selection and encoding --------------------------------------------------
    let use_fast_kernel = matches!(head_dim, 256 | 512);
    let kernel_name = match head_dim {
        256 => "hadamard_quantize_kv_fast_d256",
        512 => "hadamard_quantize_kv_fast_d512",
        _ => "hadamard_quantize_kv",
    };
    let pipeline = registry.get_pipeline(kernel_name, device)?;

    let params = HadamardQuantizeParams {
        head_dim,
        num_kv_heads,
        write_pos,
        cache_capacity,
        is_sliding: u32::from(is_sliding),
        scale_factor_d512: scale_factor_d512.unwrap_or(1.0_f32),
        rms_probe_enabled: u32::from(rms_scratch.is_some()),
    };
    let params_bytes = bytemuck::bytes_of(&params);

    if use_fast_kernel {
        // The fast kernels always take a buffer at index 4; `norms` stands in when the
        // caller did not supply a scratch buffer (the probe flag is 0 in that case).
        let scratch_binding = rms_scratch.unwrap_or(norms);
        encode_threadgroups_with_args(
            encoder,
            pipeline,
            &[
                (0, KernelArg::Buffer(src)),
                (1, KernelArg::Buffer(packed)),
                (2, KernelArg::Buffer(norms)),
                (3, KernelArg::Bytes(params_bytes)),
                (4, KernelArg::Buffer(scratch_binding)),
            ],
            MTLSize::new(heads, 1, 1),
            MTLSize::new(32, 1, 1),
        );
    } else {
        // NOTE(review): `rms_probe_enabled` may be 1 here while no buffer is bound at
        // index 4 — confirm the generic kernel ignores the flag.
        let shared_mem_bytes = 2u64 * dim * 4;
        encode_threadgroups_with_args_and_shared(
            encoder,
            pipeline,
            &[
                (0, KernelArg::Buffer(src)),
                (1, KernelArg::Buffer(packed)),
                (2, KernelArg::Buffer(norms)),
                (3, KernelArg::Bytes(params_bytes)),
            ],
            &[(0, shared_mem_bytes)],
            MTLSize::new(heads, 1, 1),
            MTLSize::new(dim, 1, 1),
        );
    }
    Ok(())
}
/// Encodes one Hadamard-quantize dispatch per token for a run of `n_tokens` tokens.
///
/// Token `i` (0-based) reads `src` at byte offset
/// `(src_tok_offset + i) * num_kv_heads * head_dim * 4` and writes cache position
/// `write_pos_start + i`. Kernel selection and threadgroup shapes are the same as in the
/// single-position dispatcher: the `_fast_d256` / `_fast_d512` kernels for
/// `head_dim` 256 / 512 (32 threads per threadgroup, `norms` bound at index 4 as a
/// placeholder with the RMS probe disabled), and the generic `hadamard_quantize_kv`
/// kernel otherwise.
///
/// NOTE(review): the byte offset assumes 4-byte source elements (presumably f32) —
/// confirm against the shader.
///
/// All validation happens before the first dispatch is encoded, so an `Err` return
/// leaves `encoder` untouched.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when `src`, `packed` or `norms` is too small;
/// when `head_dim` is not a power of two, exceeds 4096, or is odd; when a non-sliding
/// write range runs past `cache_capacity`; or when a sliding write range overflows
/// `u32`. Errors from `registry.get_pipeline` are propagated.
///
/// Returns `Ok(())` without encoding anything when `n_tokens`, `num_kv_heads` or
/// `head_dim` is zero.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_hadamard_quantize_kv_seq(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    packed: &MlxBuffer,
    norms: &MlxBuffer,
    num_kv_heads: u32,
    head_dim: u32,
    cache_capacity: u32,
    write_pos_start: u32,
    n_tokens: u32,
    src_tok_offset: u32,
    is_sliding: bool,
    scale_factor_d512: Option<f32>,
) -> Result<()> {
    use super::encode_helpers::encode_threadgroups_with_args;

    if n_tokens == 0 || num_kv_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    let heads = u64::from(num_kv_heads);
    let dim = u64::from(head_dim);
    let capacity = u64::from(cache_capacity);

    let required_src = (u64::from(src_tok_offset) + u64::from(n_tokens)) * heads * dim;
    if (src.element_count() as u64) < required_src {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv_seq: src has {} elements but need {} \
             (src_tok_offset={} + n_tokens={} * num_kv_heads={} * head_dim={})",
            src.element_count(),
            required_src,
            src_tok_offset,
            n_tokens,
            num_kv_heads,
            head_dim,
        )));
    }
    if !head_dim.is_power_of_two() {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv_seq: head_dim must be a power of two, got {}",
            head_dim
        )));
    }
    if head_dim > 4096 {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv_seq: head_dim {} exceeds Metal 32 KB threadgroup limit",
            head_dim
        )));
    }
    if head_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv_seq: head_dim must be even for nibble packing, got {}",
            head_dim
        )));
    }

    // Validate the whole write range up front so that a failure cannot leave a partial
    // sequence of dispatches in the encoder.
    if is_sliding {
        // `n_tokens >= 1` here, so the subtraction cannot underflow.
        if write_pos_start.checked_add(n_tokens - 1).is_none() {
            return Err(MlxError::InvalidArgument(format!(
                "hadamard_quantize_kv_seq: write_pos_start({}) + n_tokens({}) overflows u32",
                write_pos_start, n_tokens
            )));
        }
    } else if u64::from(write_pos_start) + u64::from(n_tokens) > capacity {
        // Report the first offending token, exactly as a per-token check would.
        let first_bad_pos = write_pos_start.max(cache_capacity);
        let first_bad_idx = first_bad_pos - write_pos_start;
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv_seq: global cache write_pos({}) >= cache_capacity({}) at seq idx {}",
            first_bad_pos, cache_capacity, first_bad_idx
        )));
    }

    // Same output-capacity requirements the single-position dispatcher enforces for
    // these kernels.
    let required_packed_bytes = heads * capacity * (dim / 2);
    if (packed.byte_len() as u64) < required_packed_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv_seq: packed buffer has {} bytes but need {} \
             (num_kv_heads={} * cache_capacity={} * head_dim/2={})",
            packed.byte_len(),
            required_packed_bytes,
            num_kv_heads,
            cache_capacity,
            head_dim / 2,
        )));
    }
    let norms_per_pos = u64::from((head_dim / 256).max(1));
    let required_norms = heads * capacity * norms_per_pos;
    if (norms.element_count() as u64) < required_norms {
        return Err(MlxError::InvalidArgument(format!(
            "hadamard_quantize_kv_seq: norms buffer has {} elements but need {} \
             (num_kv_heads={} * cache_capacity={} * norms_per_pos={})",
            norms.element_count(),
            required_norms,
            num_kv_heads,
            cache_capacity,
            norms_per_pos,
        )));
    }

    let use_fast_kernel = matches!(head_dim, 256 | 512);
    let kernel_name = match head_dim {
        256 => "hadamard_quantize_kv_fast_d256",
        512 => "hadamard_quantize_kv_fast_d512",
        _ => "hadamard_quantize_kv",
    };
    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // Loop-invariant values.
    let bytes_per_token = heads * dim * 4;
    let effective_scale = scale_factor_d512.unwrap_or(1.0_f32);
    let sliding_flag = u32::from(is_sliding);

    for i in 0..n_tokens {
        let params = HadamardQuantizeParams {
            head_dim,
            num_kv_heads,
            // Cannot overflow: the full range was validated above.
            write_pos: write_pos_start + i,
            cache_capacity,
            is_sliding: sliding_flag,
            scale_factor_d512: effective_scale,
            rms_probe_enabled: 0,
        };
        let params_bytes = bytemuck::bytes_of(&params);
        // Widen before adding so the token index cannot wrap in u32.
        let src_offset = (u64::from(src_tok_offset) + u64::from(i)) * bytes_per_token;

        if use_fast_kernel {
            encode_threadgroups_with_args(
                encoder,
                pipeline,
                &[
                    (0, KernelArg::BufferWithOffset(src, src_offset)),
                    (1, KernelArg::Buffer(packed)),
                    (2, KernelArg::Buffer(norms)),
                    (3, KernelArg::Bytes(params_bytes)),
                    // Placeholder binding for index 4; the probe flag is 0.
                    (4, KernelArg::Buffer(norms)),
                ],
                MTLSize::new(heads, 1, 1),
                MTLSize::new(32, 1, 1),
            );
        } else {
            let shared_mem_bytes = 2u64 * dim * 4;
            encode_threadgroups_with_args_and_shared(
                encoder,
                pipeline,
                &[
                    (0, KernelArg::BufferWithOffset(src, src_offset)),
                    (1, KernelArg::Buffer(packed)),
                    (2, KernelArg::Buffer(norms)),
                    (3, KernelArg::Bytes(params_bytes)),
                ],
                &[(0, shared_mem_bytes)],
                MTLSize::new(heads, 1, 1),
                MTLSize::new(dim, 1, 1),
            );
        }
    }
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_hadamard_quantize_kv_fast_dual(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src_k: &MlxBuffer,
src_v: &MlxBuffer,
packed_k: &MlxBuffer,
packed_v: &MlxBuffer,
norms_k: &MlxBuffer,
norms_v: &MlxBuffer,
num_kv_heads: u32,
head_dim: u32,
cache_capacity: u32,
write_pos: u32,
is_sliding: bool,
scale_factor_d512: Option<f32>,
) -> Result<()> {
if num_kv_heads == 0 || head_dim == 0 {
return Ok(());
}
let kernel_name = match head_dim {
256 => "hadamard_quantize_kv_fast_dual_d256",
512 => "hadamard_quantize_kv_fast_dual_d512",
_ => {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_fast_dual: head_dim {} not supported (need 256 or 512)",
head_dim
)));
}
};
if !is_sliding && write_pos >= cache_capacity {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_fast_dual: global cache write_pos({}) >= cache_capacity({})",
write_pos, cache_capacity
)));
}
let required_src = (num_kv_heads as u64) * (head_dim as u64);
if (src_k.element_count() as u64) < required_src {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_fast_dual: src_k has {} elements but need {}",
src_k.element_count(), required_src
)));
}
if (src_v.element_count() as u64) < required_src {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_fast_dual: src_v has {} elements but need {}",
src_v.element_count(), required_src
)));
}
let required_packed_bytes =
(num_kv_heads as u64) * (cache_capacity as u64) * (head_dim as u64 / 2);
if (packed_k.byte_len() as u64) < required_packed_bytes {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_fast_dual: packed_k has {} bytes but need {}",
packed_k.byte_len(), required_packed_bytes
)));
}
if (packed_v.byte_len() as u64) < required_packed_bytes {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_fast_dual: packed_v has {} bytes but need {}",
packed_v.byte_len(), required_packed_bytes
)));
}
let norms_per_pos = (head_dim / 256).max(1) as u64;
let required_norms = (num_kv_heads as u64) * (cache_capacity as u64) * norms_per_pos;
if (norms_k.element_count() as u64) < required_norms {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_fast_dual: norms_k has {} elements but need {}",
norms_k.element_count(), required_norms
)));
}
if (norms_v.element_count() as u64) < required_norms {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_fast_dual: norms_v has {} elements but need {}",
norms_v.element_count(), required_norms
)));
}
let pipeline = registry.get_pipeline(kernel_name, device)?;
let params = HadamardQuantizeParams {
head_dim,
num_kv_heads,
write_pos,
cache_capacity,
is_sliding: if is_sliding { 1 } else { 0 },
scale_factor_d512: scale_factor_d512.unwrap_or(1.0_f32),
rms_probe_enabled: 0, };
let params_bytes = bytemuck::bytes_of(¶ms);
use super::encode_helpers::{encode_threadgroups_with_args, KernelArg as KA};
encode_threadgroups_with_args(
encoder,
pipeline,
&[
(0, KA::Buffer(src_k)),
(1, KA::Buffer(src_v)),
(2, KA::Buffer(packed_k)),
(3, KA::Buffer(packed_v)),
(4, KA::Buffer(norms_k)),
(5, KA::Buffer(norms_v)),
(6, KA::Bytes(params_bytes)),
],
MTLSize::new(num_kv_heads as u64, 1, 2), MTLSize::new(32, 1, 1), );
Ok(())
}
/// Parameter block passed by value (as raw bytes via `bytemuck::bytes_of`) to the
/// `hadamard_quantize_kv_hb*`, `kv_quantize_v_no_fwht*` and
/// `kv_copy_kf16_quantize_v_no_fwht*` kernels.
///
/// `#[repr(C)]` fixes the field order and layout; the struct is seven 4-byte fields.
/// NOTE(review): presumably this mirrors a params struct declared in the Metal shader —
/// keep field order and types in sync with the shader; verify before editing.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct HadamardQuantizeHbParams {
/// Head dimension; every dispatcher using this struct accepts only 256 or 512.
head_dim: u32,
/// Number of KV heads; also the number of threadgroups dispatched along X.
num_kv_heads: u32,
/// Cache position this dispatch writes to.
write_pos: u32,
/// Capacity of the cache in positions.
cache_capacity: u32,
/// 1 when the caller passed `is_sliding == true`, otherwise 0.
is_sliding: u32,
/// Scale factor forwarded unchanged from the caller.
scale_factor_d512: f32,
/// Codebook width in bits; the dispatchers reject anything other than 5, 6 or 8.
codebook_bits: u32, }
#[allow(clippy::too_many_arguments)]
pub fn dispatch_hadamard_quantize_kv_hb(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
packed: &MlxBuffer, norms: &MlxBuffer,
num_kv_heads: u32,
head_dim: u32,
cache_capacity: u32,
write_pos: u32,
is_sliding: bool,
scale_factor_d512: f32,
codebook_bits: u32, ) -> Result<()> {
if num_kv_heads == 0 || head_dim == 0 { return Ok(()); }
if !matches!(codebook_bits, 5 | 6 | 8) {
return Err(MlxError::InvalidArgument(format!(
"dispatch_hadamard_quantize_kv_hb: codebook_bits must be 5, 6, or 8, got {}", codebook_bits)));
}
let kernel_name = match head_dim {
256 => "hadamard_quantize_kv_hb_d256",
512 => "hadamard_quantize_kv_hb_d512",
_ => return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_hb: head_dim {} not supported (need 256 or 512)", head_dim))),
};
let pipeline = registry.get_pipeline(kernel_name, device)?;
let params = HadamardQuantizeHbParams {
head_dim,
num_kv_heads,
write_pos,
cache_capacity,
is_sliding: if is_sliding { 1 } else { 0 },
scale_factor_d512,
codebook_bits,
};
let params_bytes = bytemuck::bytes_of(¶ms);
use super::encode_helpers::{encode_threadgroups_with_args, KernelArg as KA};
encode_threadgroups_with_args(
encoder,
pipeline,
&[
(0, KA::Buffer(src)),
(1, KA::Buffer(packed)),
(2, KA::Buffer(norms)),
(3, KA::Bytes(params_bytes)),
],
MTLSize::new(num_kv_heads as u64, 1, 1),
MTLSize::new(32, 1, 1), );
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_quantize_v_no_fwht(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
packed: &MlxBuffer, norms: &MlxBuffer,
num_kv_heads: u32,
head_dim: u32,
cache_capacity: u32,
write_pos: u32,
is_sliding: bool,
scale_factor_d512: f32,
codebook_bits: u32, ) -> Result<()> {
if num_kv_heads == 0 || head_dim == 0 { return Ok(()); }
if !matches!(codebook_bits, 5 | 6 | 8) {
return Err(MlxError::InvalidArgument(format!(
"dispatch_kv_quantize_v_no_fwht: codebook_bits must be 5, 6, or 8, got {}",
codebook_bits)));
}
let kernel_name = match head_dim {
256 => "kv_quantize_v_no_fwht_d256",
512 => "kv_quantize_v_no_fwht_d512",
_ => return Err(MlxError::InvalidArgument(format!(
"kv_quantize_v_no_fwht: head_dim {} not supported (need 256 or 512)", head_dim))),
};
let pipeline = registry.get_pipeline(kernel_name, device)?;
let params = HadamardQuantizeHbParams {
head_dim,
num_kv_heads,
write_pos,
cache_capacity,
is_sliding: if is_sliding { 1 } else { 0 },
scale_factor_d512,
codebook_bits,
};
let params_bytes = bytemuck::bytes_of(¶ms);
use super::encode_helpers::{encode_threadgroups_with_args, KernelArg as KA};
encode_threadgroups_with_args(
encoder,
pipeline,
&[
(0, KA::Buffer(src)),
(1, KA::Buffer(packed)),
(2, KA::Buffer(norms)),
(3, KA::Bytes(params_bytes)),
],
MTLSize::new(num_kv_heads as u64, 1, 1),
MTLSize::new(32, 1, 1), );
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_copy_kf16_quantize_v_no_fwht(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src_k: &MlxBuffer,
src_v: &MlxBuffer,
cache_k: &MlxBuffer, packed_v: &MlxBuffer, norms_v: &MlxBuffer, num_kv_heads: u32,
head_dim: u32,
cache_capacity: u32,
write_pos: u32,
is_sliding: bool,
scale_factor_d512: f32,
codebook_bits: u32,
) -> Result<()> {
if num_kv_heads == 0 || head_dim == 0 { return Ok(()); }
if !matches!(codebook_bits, 5 | 6 | 8) {
return Err(MlxError::InvalidArgument(format!(
"dispatch_kv_copy_kf16_quantize_v_no_fwht: codebook_bits must be 5, 6, or 8, got {}",
codebook_bits)));
}
if cache_k.dtype() != crate::DType::F16 {
return Err(MlxError::InvalidArgument(format!(
"dispatch_kv_copy_kf16_quantize_v_no_fwht: cache_k must be DType::F16, got {:?}",
cache_k.dtype())));
}
let kernel_name = match head_dim {
256 => "kv_copy_kf16_quantize_v_no_fwht_d256",
512 => "kv_copy_kf16_quantize_v_no_fwht_d512",
_ => return Err(MlxError::InvalidArgument(format!(
"kv_copy_kf16_quantize_v_no_fwht: head_dim {} not supported (need 256 or 512)",
head_dim))),
};
let pipeline = registry.get_pipeline(kernel_name, device)?;
let params = HadamardQuantizeHbParams {
head_dim,
num_kv_heads,
write_pos,
cache_capacity,
is_sliding: if is_sliding { 1 } else { 0 },
scale_factor_d512,
codebook_bits,
};
let params_bytes = bytemuck::bytes_of(¶ms);
use super::encode_helpers::{encode_threadgroups_with_args, KernelArg as KA};
encode_threadgroups_with_args(
encoder,
pipeline,
&[
(0, KA::Buffer(src_k)),
(1, KA::Buffer(src_v)),
(2, KA::Buffer(cache_k)),
(3, KA::Buffer(packed_v)),
(4, KA::Buffer(norms_v)),
(5, KA::Bytes(params_bytes)),
],
MTLSize::new(num_kv_heads as u64, 1, 2),
MTLSize::new(32, 1, 1),
);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_hadamard_quantize_kv_hb_dual(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src_k: &MlxBuffer,
src_v: &MlxBuffer,
packed_k: &MlxBuffer,
packed_v: &MlxBuffer,
norms_k: &MlxBuffer,
norms_v: &MlxBuffer,
num_kv_heads: u32,
head_dim: u32,
cache_capacity: u32,
write_pos: u32,
is_sliding: bool,
scale_factor_d512: f32,
codebook_bits: u32,
) -> Result<()> {
if num_kv_heads == 0 || head_dim == 0 { return Ok(()); }
if !matches!(codebook_bits, 5 | 6 | 8) {
return Err(MlxError::InvalidArgument(format!(
"dispatch_hadamard_quantize_kv_hb_dual: codebook_bits must be 5, 6, or 8, got {}", codebook_bits)));
}
let kernel_name = match head_dim {
256 => "hadamard_quantize_kv_hb_dual_d256",
512 => "hadamard_quantize_kv_hb_dual_d512",
_ => return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_hb_dual: head_dim {} not supported (need 256 or 512)", head_dim))),
};
let pipeline = registry.get_pipeline(kernel_name, device)?;
let params = HadamardQuantizeHbParams {
head_dim,
num_kv_heads,
write_pos,
cache_capacity,
is_sliding: if is_sliding { 1 } else { 0 },
scale_factor_d512,
codebook_bits,
};
let params_bytes = bytemuck::bytes_of(¶ms);
use super::encode_helpers::{encode_threadgroups_with_args, KernelArg as KA};
encode_threadgroups_with_args(
encoder,
pipeline,
&[
(0, KA::Buffer(src_k)),
(1, KA::Buffer(src_v)),
(2, KA::Buffer(packed_k)),
(3, KA::Buffer(packed_v)),
(4, KA::Buffer(norms_k)),
(5, KA::Buffer(norms_v)),
(6, KA::Bytes(params_bytes)),
],
MTLSize::new(num_kv_heads as u64, 1, 2), MTLSize::new(32, 1, 1), );
Ok(())
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_quantize_v_no_fwht_seq(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
packed: &MlxBuffer,
norms: &MlxBuffer,
num_kv_heads: u32,
head_dim: u32,
cache_capacity: u32,
write_pos_start: u32,
n_tokens: u32,
src_tok_offset: u32,
is_sliding: bool,
scale_factor_d512: f32,
codebook_bits: u32,
) -> Result<()> {
if n_tokens == 0 || num_kv_heads == 0 || head_dim == 0 {
return Ok(());
}
if !matches!(codebook_bits, 5 | 6 | 8) {
return Err(MlxError::InvalidArgument(format!(
"dispatch_kv_quantize_v_no_fwht_seq: codebook_bits must be \
5, 6, or 8, got {}",
codebook_bits
)));
}
let kernel_name = match head_dim {
256 => "kv_quantize_v_no_fwht_d256",
512 => "kv_quantize_v_no_fwht_d512",
_ => {
return Err(MlxError::InvalidArgument(format!(
"kv_quantize_v_no_fwht_seq: head_dim {} not supported \
(need 256 or 512)",
head_dim
)))
}
};
let required_src = (src_tok_offset as u64 + n_tokens as u64)
* (num_kv_heads as u64)
* (head_dim as u64);
if (src.element_count() as u64) < required_src {
return Err(MlxError::InvalidArgument(format!(
"kv_quantize_v_no_fwht_seq: src has {} elements but need {} \
(src_tok_offset={} + n_tokens={} * num_kv_heads={} * head_dim={})",
src.element_count(), required_src,
src_tok_offset, n_tokens, num_kv_heads, head_dim,
)));
}
let pipeline = registry.get_pipeline(kernel_name, device)?;
let bytes_per_token = (num_kv_heads as u64) * (head_dim as u64) * 4;
use super::encode_helpers::{encode_threadgroups_with_args, KernelArg as KA};
for i in 0..n_tokens {
let write_pos = write_pos_start + i;
if !is_sliding && write_pos >= cache_capacity {
return Err(MlxError::InvalidArgument(format!(
"kv_quantize_v_no_fwht_seq: global cache write_pos({}) >= \
cache_capacity({}) at seq idx {}",
write_pos, cache_capacity, i
)));
}
let params = HadamardQuantizeHbParams {
head_dim, num_kv_heads, write_pos, cache_capacity,
is_sliding: if is_sliding { 1 } else { 0 },
scale_factor_d512, codebook_bits,
};
let params_bytes = bytemuck::bytes_of(¶ms);
let src_offset = ((src_tok_offset + i) as u64) * bytes_per_token;
encode_threadgroups_with_args(
encoder, pipeline,
&[
(0, KA::BufferWithOffset(src, src_offset)),
(1, KA::Buffer(packed)),
(2, KA::Buffer(norms)),
(3, KA::Bytes(params_bytes)),
],
MTLSize::new(num_kv_heads as u64, 1, 1),
MTLSize::new(32, 1, 1),
);
}
Ok(())
}
pub fn dispatch_hadamard_quantize_kv_hb_seq(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
packed: &MlxBuffer,
norms: &MlxBuffer,
num_kv_heads: u32,
head_dim: u32,
cache_capacity: u32,
write_pos_start: u32,
n_tokens: u32,
src_tok_offset: u32,
is_sliding: bool,
scale_factor_d512: f32,
codebook_bits: u32,
) -> Result<()> {
if n_tokens == 0 || num_kv_heads == 0 || head_dim == 0 {
return Ok(());
}
if !matches!(codebook_bits, 5 | 6 | 8) {
return Err(MlxError::InvalidArgument(format!(
"dispatch_hadamard_quantize_kv_hb_seq: codebook_bits must be \
5, 6, or 8, got {}",
codebook_bits
)));
}
let kernel_name = match head_dim {
256 => "hadamard_quantize_kv_hb_d256",
512 => "hadamard_quantize_kv_hb_d512",
_ => {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_hb_seq: head_dim {} not supported \
(need 256 or 512)",
head_dim
)))
}
};
let required_src = (src_tok_offset as u64 + n_tokens as u64)
* (num_kv_heads as u64)
* (head_dim as u64);
if (src.element_count() as u64) < required_src {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_hb_seq: src has {} elements but need {} \
(src_tok_offset={} + n_tokens={} * num_kv_heads={} * head_dim={})",
src.element_count(),
required_src,
src_tok_offset,
n_tokens,
num_kv_heads,
head_dim,
)));
}
let pipeline = registry.get_pipeline(kernel_name, device)?;
let bytes_per_token = (num_kv_heads as u64) * (head_dim as u64) * 4;
use super::encode_helpers::{encode_threadgroups_with_args, KernelArg as KA};
for i in 0..n_tokens {
let write_pos = write_pos_start + i;
if !is_sliding && write_pos >= cache_capacity {
return Err(MlxError::InvalidArgument(format!(
"hadamard_quantize_kv_hb_seq: global cache write_pos({}) >= \
cache_capacity({}) at seq idx {}",
write_pos, cache_capacity, i
)));
}
let params = HadamardQuantizeHbParams {
head_dim,
num_kv_heads,
write_pos,
cache_capacity,
is_sliding: if is_sliding { 1 } else { 0 },
scale_factor_d512,
codebook_bits,
};
let params_bytes = bytemuck::bytes_of(¶ms);
let src_offset = ((src_tok_offset + i) as u64) * bytes_per_token;
encode_threadgroups_with_args(
encoder,
pipeline,
&[
(0, KA::BufferWithOffset(src, src_offset)),
(1, KA::Buffer(packed)),
(2, KA::Buffer(norms)),
(3, KA::Bytes(params_bytes)),
],
MTLSize::new(num_kv_heads as u64, 1, 1),
MTLSize::new(32, 1, 1),
);
}
Ok(())
}