use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{encode_with_args, KernelArg};
/// Metal Shading Language source for the KV-cache copy kernels, embedded at
/// compile time from `shaders/kv_cache_copy.metal`.
pub static KV_CACHE_COPY_SHADER_SOURCE: &str = include_str!("../shaders/kv_cache_copy.metal");
/// Registers the KV-cache copy shader source with `registry` under the
/// "kv_cache_copy" key so its pipelines can be compiled on demand via
/// `get_pipeline`.
pub fn register(registry: &mut KernelRegistry) {
registry.register_source("kv_cache_copy", KV_CACHE_COPY_SHADER_SOURCE);
}
/// Encodes the `kv_cache_copy` kernel, copying `n_new` rows of `row_size`
/// elements from `src` into `cache` starting at row `write_pos`.
///
/// For sliding caches (`is_sliding == true`) positions are assumed to wrap
/// modulo `cache_cap` inside the shader (TODO confirm against the kernel),
/// so no host-side overflow check is performed in that mode.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` if `src` holds fewer than
/// `n_new * row_size` elements, or if a non-sliding write would run past
/// `cache_cap`. Pipeline lookup/compilation errors propagate from `registry`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    write_pos: u32,
    row_size: u32,
    n_new: u32,
    cache_cap: u32,
    is_sliding: bool,
) -> Result<()> {
    // Nothing to copy — skip encoding an empty dispatch.
    if n_new == 0 || row_size == 0 {
        return Ok(());
    }
    // Widen to u64 so the product cannot overflow u32.
    let total_elements = (n_new as u64) * (row_size as u64);
    let src_elements = src.element_count() as u64;
    if src_elements < total_elements {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy: src has {} elements but need {} (n_new={} * row_size={})",
            src_elements, total_elements, n_new, row_size
        )));
    }
    // Only a global (non-sliding) cache can overflow on the host side.
    if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
            write_pos, n_new, cache_cap
        )));
    }
    let pipeline = registry.get_pipeline("kv_cache_copy", device)?;
    // Scalar kernel arguments are passed as raw native-endian bytes.
    let is_sliding_val = u32::from(is_sliding);
    let write_pos_bytes = write_pos.to_ne_bytes();
    let row_size_bytes = row_size.to_ne_bytes();
    let n_new_bytes = n_new.to_ne_bytes();
    let cache_cap_bytes = cache_cap.to_ne_bytes();
    let is_sliding_bytes = is_sliding_val.to_ne_bytes();
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&write_pos_bytes)),
            (3, KernelArg::Bytes(&row_size_bytes)),
            (4, KernelArg::Bytes(&n_new_bytes)),
            (5, KernelArg::Bytes(&cache_cap_bytes)),
            (6, KernelArg::Bytes(&is_sliding_bytes)),
        ],
        // One thread per element; threadgroup width capped at 256.
        MTLSize::new(total_elements, 1, 1),
        MTLSize::new(total_elements.min(256), 1, 1),
    );
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_batch_f32(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
cache: &MlxBuffer,
n_heads: u32,
head_dim: u32,
capacity: u32,
seq_pos: u32,
) -> Result<()> {
if n_heads == 0 || head_dim == 0 {
return Ok(());
}
let total_src = (n_heads as u64) * (head_dim as u64);
if (src.element_count() as u64) < total_src {
return Err(MlxError::InvalidArgument(format!(
"kv_cache_copy_batch_f32: src has {} elements but need {} (n_heads={} * head_dim={})",
src.element_count(), total_src, n_heads, head_dim
)));
}
let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32", device)?;
let n_heads_bytes = n_heads.to_ne_bytes();
let head_dim_bytes = head_dim.to_ne_bytes();
let capacity_bytes = capacity.to_ne_bytes();
let seq_pos_bytes = seq_pos.to_ne_bytes();
use super::encode_helpers::{encode_with_args, KernelArg};
encode_with_args(
encoder,
pipeline,
&[
(0, KernelArg::Buffer(src)),
(1, KernelArg::Buffer(cache)),
(2, KernelArg::Bytes(&n_heads_bytes)),
(3, KernelArg::Bytes(&head_dim_bytes)),
(4, KernelArg::Bytes(&capacity_bytes)),
(5, KernelArg::Bytes(&seq_pos_bytes)),
],
MTLSize::new(head_dim as u64, n_heads as u64, 1),
MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
);
Ok(())
}
/// Encodes the `kv_cache_copy_f32` kernel, copying `n_new` rows of
/// `row_size` f32 elements from `src` into `cache` starting at row
/// `write_pos`.
///
/// For sliding caches (`is_sliding == true`) positions are assumed to wrap
/// modulo `cache_cap` inside the shader (TODO confirm against the kernel),
/// so no host-side overflow check is performed in that mode.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` if `src` holds fewer than
/// `n_new * row_size` elements, or if a non-sliding write would run past
/// `cache_cap`. Pipeline lookup/compilation errors propagate from `registry`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    cache: &MlxBuffer,
    write_pos: u32,
    row_size: u32,
    n_new: u32,
    cache_cap: u32,
    is_sliding: bool,
) -> Result<()> {
    // Nothing to copy — skip encoding an empty dispatch.
    if n_new == 0 || row_size == 0 {
        return Ok(());
    }
    // Widen to u64 so the product cannot overflow u32.
    let total_elements = (n_new as u64) * (row_size as u64);
    let src_elements = src.element_count() as u64;
    if src_elements < total_elements {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_f32: src has {} elements but need {} (n_new={} * row_size={})",
            src_elements, total_elements, n_new, row_size
        )));
    }
    // Only a global (non-sliding) cache can overflow on the host side.
    if !is_sliding && (write_pos as u64 + n_new as u64) > cache_cap as u64 {
        return Err(MlxError::InvalidArgument(format!(
            "kv_cache_copy_f32: global cache overflow: write_pos({}) + n_new({}) > cache_cap({})",
            write_pos, n_new, cache_cap
        )));
    }
    let pipeline = registry.get_pipeline("kv_cache_copy_f32", device)?;
    // Scalar kernel arguments are passed as raw native-endian bytes.
    let is_sliding_val = u32::from(is_sliding);
    let write_pos_bytes = write_pos.to_ne_bytes();
    let row_size_bytes = row_size.to_ne_bytes();
    let n_new_bytes = n_new.to_ne_bytes();
    let cache_cap_bytes = cache_cap.to_ne_bytes();
    let is_sliding_bytes = is_sliding_val.to_ne_bytes();
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(cache)),
            (2, KernelArg::Bytes(&write_pos_bytes)),
            (3, KernelArg::Bytes(&row_size_bytes)),
            (4, KernelArg::Bytes(&n_new_bytes)),
            (5, KernelArg::Bytes(&cache_cap_bytes)),
            (6, KernelArg::Bytes(&is_sliding_bytes)),
        ],
        // One thread per element; threadgroup width capped at 256.
        MTLSize::new(total_elements, 1, 1),
        MTLSize::new(total_elements.min(256), 1, 1),
    );
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_batch_f32_to_f16(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
cache: &MlxBuffer,
n_heads: u32,
head_dim: u32,
capacity: u32,
seq_pos: u32,
) -> Result<()> {
if n_heads == 0 || head_dim == 0 {
return Ok(());
}
let total_src = (n_heads as u64) * (head_dim as u64);
if (src.element_count() as u64) < total_src {
return Err(MlxError::InvalidArgument(format!(
"kv_cache_copy_batch_f32_to_f16: src has {} elements but need {} (n_heads={} * head_dim={})",
src.element_count(), total_src, n_heads, head_dim
)));
}
let pipeline = registry.get_pipeline("kv_cache_copy_batch_f32_to_f16", device)?;
let n_heads_bytes = n_heads.to_ne_bytes();
let head_dim_bytes = head_dim.to_ne_bytes();
let capacity_bytes = capacity.to_ne_bytes();
let seq_pos_bytes = seq_pos.to_ne_bytes();
use super::encode_helpers::{encode_with_args, KernelArg};
encode_with_args(
encoder,
pipeline,
&[
(0, KernelArg::Buffer(src)),
(1, KernelArg::Buffer(cache)),
(2, KernelArg::Bytes(&n_heads_bytes)),
(3, KernelArg::Bytes(&head_dim_bytes)),
(4, KernelArg::Bytes(&capacity_bytes)),
(5, KernelArg::Bytes(&seq_pos_bytes)),
],
MTLSize::new(head_dim as u64, n_heads as u64, 1),
MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_seq_f32(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
cache: &MlxBuffer,
n_heads: u32,
head_dim: u32,
capacity: u32,
seq_pos_start: u32,
n_tokens: u32,
src_tok_offset: u32,
) -> Result<()> {
if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
return Ok(());
}
let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
* (n_heads as u64) * (head_dim as u64);
if (src.element_count() as u64) < total_src {
return Err(MlxError::InvalidArgument(format!(
"kv_cache_copy_seq_f32: src has {} elements, need {} ((src_tok_offset={} + n_tokens={}) * n_heads={} * head_dim={})",
src.element_count(), total_src, src_tok_offset, n_tokens, n_heads, head_dim
)));
}
let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32", device)?;
let n_heads_bytes = n_heads.to_ne_bytes();
let head_dim_bytes = head_dim.to_ne_bytes();
let capacity_bytes = capacity.to_ne_bytes();
let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
let n_tokens_bytes = n_tokens.to_ne_bytes();
let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
use super::encode_helpers::{encode_with_args, KernelArg};
encode_with_args(
encoder,
pipeline,
&[
(0, KernelArg::Buffer(src)),
(1, KernelArg::Buffer(cache)),
(2, KernelArg::Bytes(&n_heads_bytes)),
(3, KernelArg::Bytes(&head_dim_bytes)),
(4, KernelArg::Bytes(&capacity_bytes)),
(5, KernelArg::Bytes(&seq_pos_start_bytes)),
(6, KernelArg::Bytes(&n_tokens_bytes)),
(7, KernelArg::Bytes(&src_tok_offset_bytes)),
],
MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_kv_cache_copy_seq_f32_to_f16(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
cache: &MlxBuffer,
n_heads: u32,
head_dim: u32,
capacity: u32,
seq_pos_start: u32,
n_tokens: u32,
src_tok_offset: u32,
) -> Result<()> {
if n_heads == 0 || head_dim == 0 || n_tokens == 0 {
return Ok(());
}
let total_src = ((src_tok_offset as u64) + (n_tokens as u64))
* (n_heads as u64) * (head_dim as u64);
if (src.element_count() as u64) < total_src {
return Err(MlxError::InvalidArgument(format!(
"kv_cache_copy_seq_f32_to_f16: src has {} elements, need {}",
src.element_count(), total_src
)));
}
let pipeline = registry.get_pipeline("kv_cache_copy_seq_f32_to_f16", device)?;
let n_heads_bytes = n_heads.to_ne_bytes();
let head_dim_bytes = head_dim.to_ne_bytes();
let capacity_bytes = capacity.to_ne_bytes();
let seq_pos_start_bytes = seq_pos_start.to_ne_bytes();
let n_tokens_bytes = n_tokens.to_ne_bytes();
let src_tok_offset_bytes = src_tok_offset.to_ne_bytes();
use super::encode_helpers::{encode_with_args, KernelArg};
encode_with_args(
encoder,
pipeline,
&[
(0, KernelArg::Buffer(src)),
(1, KernelArg::Buffer(cache)),
(2, KernelArg::Bytes(&n_heads_bytes)),
(3, KernelArg::Bytes(&head_dim_bytes)),
(4, KernelArg::Bytes(&capacity_bytes)),
(5, KernelArg::Bytes(&seq_pos_start_bytes)),
(6, KernelArg::Bytes(&n_tokens_bytes)),
(7, KernelArg::Bytes(&src_tok_offset_bytes)),
],
MTLSize::new(head_dim as u64, n_heads as u64, n_tokens as u64),
MTLSize::new(std::cmp::min(256, head_dim as u64), 1, 1),
);
Ok(())
}