use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{encode_threadgroups_with_args, KernelArg};
/// Metal shader source text, embedded into the binary at compile time from
/// `../shaders/tq_dequantize_kv.metal` (path relative to this source file).
pub static TQ_DEQUANTIZE_KV_SHADER_SOURCE: &str =
include_str!("../shaders/tq_dequantize_kv.metal");
/// Registers the embedded `tq_dequantize_kv` shader source with `registry`
/// under the name `"tq_dequantize_kv"`.
pub fn register(registry: &mut KernelRegistry) {
    const KERNEL_NAME: &str = "tq_dequantize_kv";
    registry.register_source(KERNEL_NAME, TQ_DEQUANTIZE_KV_SHADER_SOURCE);
}
/// Parameter block handed to the `tq_dequantize_kv` kernel as raw bytes.
///
/// `#[repr(C)]` together with the `Pod` derive lets a value of this type be
/// viewed as a byte slice via `bytemuck::bytes_of`.
/// NOTE(review): field order and types presumably have to match the
/// parameter struct declared in the Metal shader — confirm before reordering.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct TqDequantizeKvParamsGpu {
/// Copied from the dispatcher's `head_dim` argument.
head_dim: u32,
/// Copied from the dispatcher's `num_kv_heads` argument.
num_kv_heads: u32,
/// Copied from the dispatcher's `read_pos` argument.
read_pos: u32,
/// Copied from the dispatcher's `cache_capacity` argument.
cache_capacity: u32,
/// Computed by the dispatcher as `max(head_dim / 256, 1)`.
norms_per_pos: u32,
/// Copied from the dispatcher's `scale_factor_d512` argument.
scale_factor_d512: f32,
}
/// Encodes one dispatch of the `tq_dequantize_kv` compute kernel.
///
/// Bindings: `packed` at buffer index 0, `norms` at index 1, `dst` at index 2,
/// and the parameter block passed inline as bytes at index 3. The grid is one
/// threadgroup per KV head, each with `min(head_dim, 1024)` threads.
///
/// # Parameters
/// - `encoder`: command encoder the dispatch is recorded on.
/// - `registry` / `device`: used to look up the `"tq_dequantize_kv"` pipeline.
/// - `packed`, `norms`: kernel inputs (presumably the quantized cache payload
///   and its per-position norms — confirm against the shader).
/// - `dst`: kernel output; must hold at least `num_kv_heads * head_dim`
///   elements.
/// - `num_kv_heads`, `head_dim`: dispatch dimensions; `head_dim` must be a
///   power of two.
/// - `cache_capacity`, `read_pos`, `scale_factor_d512`: forwarded unchanged to
///   the kernel through the parameter block.
///
/// # Returns
/// `Ok(())` once the dispatch is encoded. When `num_kv_heads` or `head_dim`
/// is zero, returns `Ok(())` without encoding anything.
///
/// # Errors
/// - [`MlxError::InvalidArgument`] if `head_dim` is not a power of two.
/// - [`MlxError::InvalidArgument`] if `dst` has fewer than
///   `num_kv_heads * head_dim` elements.
/// - Any error returned by the registry's pipeline lookup.
#[allow(clippy::too_many_arguments)] // flat argument list mirrors the kernel's parameter block
pub fn dispatch_tq_dequantize_kv(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    packed: &MlxBuffer,
    norms: &MlxBuffer,
    dst: &MlxBuffer,
    num_kv_heads: u32,
    head_dim: u32,
    cache_capacity: u32,
    read_pos: u32,
    scale_factor_d512: f32,
) -> Result<()> {
    // An empty grid has nothing to do: skip pipeline lookup and encoding.
    if num_kv_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    if !head_dim.is_power_of_two() {
        return Err(MlxError::InvalidArgument(format!(
            "tq_dequantize_kv: head_dim must be a power of two, got {}",
            head_dim
        )));
    }

    // Widen to u64 before multiplying so the product of two u32 values
    // cannot overflow.
    let required_dst = u64::from(num_kv_heads) * u64::from(head_dim);
    if (dst.element_count() as u64) < required_dst {
        return Err(MlxError::InvalidArgument(format!(
            "tq_dequantize_kv: dst has {} elements, need {}",
            dst.element_count(),
            required_dst
        )));
    }

    // At least one norm per position, one more for every full 256 lanes of head_dim.
    let norms_per_pos = (head_dim / 256).max(1);

    let params = TqDequantizeKvParamsGpu {
        head_dim,
        num_kv_heads,
        read_pos,
        cache_capacity,
        norms_per_pos,
        scale_factor_d512,
    };
    let params_bytes = bytemuck::bytes_of(&params);

    let pipeline = registry.get_pipeline("tq_dequantize_kv", device)?;

    // One threadgroup per KV head.
    let threadgroups = MTLSize {
        width: u64::from(num_kv_heads),
        height: 1,
        depth: 1,
    };
    // One thread per head_dim lane, capped at 1024 threads per threadgroup
    // (presumably the device's per-threadgroup limit — confirm).
    let threadgroup_size = MTLSize {
        width: u64::from(head_dim.min(1024)),
        height: 1,
        depth: 1,
    };

    encode_threadgroups_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(packed)),
            (1, KernelArg::Buffer(norms)),
            (2, KernelArg::Buffer(dst)),
            (3, KernelArg::Bytes(params_bytes)),
        ],
        threadgroups,
        threadgroup_size,
    );
    Ok(())
}
/// Parameter block handed to the `tq_dequantize_hb_kv` kernel as raw bytes.
///
/// `#[repr(C)]` together with the `Pod` derive lets a value of this type be
/// viewed as a byte slice via `bytemuck::bytes_of`.
/// NOTE(review): field order and types presumably have to match the
/// parameter struct declared in the Metal shader — confirm before reordering.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct TqDequantizeHbKvParamsGpu {
/// Copied from the dispatcher's `head_dim` argument.
head_dim: u32,
/// Copied from the dispatcher's `num_kv_heads` argument.
num_kv_heads: u32,
/// Copied from the dispatcher's `read_pos` argument.
read_pos: u32,
/// Copied from the dispatcher's `cache_capacity` argument.
cache_capacity: u32,
/// Computed by the dispatcher as `max(head_dim / 256, 1)`.
norms_per_pos: u32,
/// Copied from the dispatcher's `scale_factor_d512` argument.
scale_factor_d512: f32,
/// Codebook width in bits; the dispatcher only accepts 5, 6, or 8.
codebook_bits: u32, }
/// Encodes one dispatch of the `tq_dequantize_hb_kv` compute kernel.
///
/// Bindings: `packed` at buffer index 0, `norms` at index 1, `dst` at index 2,
/// and the parameter block passed inline as bytes at index 3. The grid is one
/// threadgroup per KV head, each with `min(head_dim, 1024)` threads.
///
/// # Parameters
/// - `encoder`: command encoder the dispatch is recorded on.
/// - `registry` / `device`: used to look up the `"tq_dequantize_hb_kv"`
///   pipeline.
/// - `packed`, `norms`: kernel inputs (presumably the quantized cache payload
///   and its per-position norms — confirm against the shader).
/// - `dst`: kernel output; must hold at least `num_kv_heads * head_dim`
///   elements.
/// - `num_kv_heads`, `head_dim`: dispatch dimensions.
/// - `cache_capacity`, `read_pos`, `scale_factor_d512`: forwarded unchanged to
///   the kernel through the parameter block.
/// - `codebook_bits`: must be 5, 6, or 8; forwarded to the kernel.
///
/// # Returns
/// `Ok(())` once the dispatch is encoded. When `num_kv_heads` or `head_dim`
/// is zero, returns `Ok(())` without validating or encoding anything.
///
/// # Errors
/// - [`MlxError::InvalidArgument`] if `codebook_bits` is not 5, 6, or 8.
/// - [`MlxError::InvalidArgument`] if `dst` has fewer than
///   `num_kv_heads * head_dim` elements.
/// - Any error returned by the registry's pipeline lookup.
///
/// NOTE(review): unlike `dispatch_tq_dequantize_kv`, this function does not
/// require `head_dim` to be a power of two — confirm that is intentional.
/// NOTE(review): the `register` function in this file only registers a source
/// named `"tq_dequantize_kv"`; confirm the `"tq_dequantize_hb_kv"` pipeline
/// requested here is resolvable by the registry.
#[allow(clippy::too_many_arguments)] // flat argument list mirrors the kernel's parameter block
pub fn dispatch_tq_dequantize_hb_kv(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    packed: &MlxBuffer,
    norms: &MlxBuffer,
    dst: &MlxBuffer,
    num_kv_heads: u32,
    head_dim: u32,
    cache_capacity: u32,
    read_pos: u32,
    scale_factor_d512: f32,
    codebook_bits: u32,
) -> Result<()> {
    // An empty grid has nothing to do: skip validation, lookup, and encoding.
    if num_kv_heads == 0 || head_dim == 0 {
        return Ok(());
    }

    if !matches!(codebook_bits, 5 | 6 | 8) {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_tq_dequantize_hb_kv: codebook_bits must be 5, 6, or 8, got {}",
            codebook_bits
        )));
    }

    // Same capacity guard as `dispatch_tq_dequantize_kv`: the grid below spans
    // num_kv_heads x head_dim lanes, so reject a destination that cannot hold
    // that many elements. Widen to u64 first so the product cannot overflow.
    let required_dst = u64::from(num_kv_heads) * u64::from(head_dim);
    if (dst.element_count() as u64) < required_dst {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_tq_dequantize_hb_kv: dst has {} elements, need {}",
            dst.element_count(),
            required_dst
        )));
    }

    // At least one norm per position, one more for every full 256 lanes of head_dim.
    let norms_per_pos = (head_dim / 256).max(1);

    let params = TqDequantizeHbKvParamsGpu {
        head_dim,
        num_kv_heads,
        read_pos,
        cache_capacity,
        norms_per_pos,
        scale_factor_d512,
        codebook_bits,
    };
    let params_bytes = bytemuck::bytes_of(&params);

    let pipeline = registry.get_pipeline("tq_dequantize_hb_kv", device)?;

    // One threadgroup per KV head.
    let threadgroups = MTLSize {
        width: u64::from(num_kv_heads),
        height: 1,
        depth: 1,
    };
    // One thread per head_dim lane, capped at 1024 threads per threadgroup
    // (presumably the device's per-threadgroup limit — confirm).
    let threadgroup_size = MTLSize {
        width: u64::from(head_dim.min(1024)),
        height: 1,
        depth: 1,
    };

    encode_threadgroups_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(packed)),
            (1, KernelArg::Buffer(norms)),
            (2, KernelArg::Buffer(dst)),
            (3, KernelArg::Bytes(params_bytes)),
        ],
        threadgroups,
        threadgroup_size,
    );
    Ok(())
}