use metal::{Buffer, Device, MTLResourceOptions, NSUInteger};
use crate::riir::backend::buftype::{KvCacheKBuf, KvCacheVBuf};
use crate::riir::backend::gpu::MetalBufferPool;
use crate::riir::backend::{BufId, BufferPool};
use crate::riir::variants::{Variant, MAX_SEQ_LEN, VARIANT};
/// Key/value cache bookkeeping for a single attention layer.
///
/// The struct stores only buffer-pool handles plus a length counter; the
/// backing buffers are created and registered by `KvCache::ensure_buffers`.
#[derive(Debug)]
pub struct KvCache {
    /// Pool handle of the K buffer; `None` until `ensure_buffers` has run.
    pub k_id: Option<BufId<KvCacheKBuf>>,
    /// Pool handle of the V buffer; `None` until `ensure_buffers` has run.
    pub v_id: Option<BufId<KvCacheVBuf>>,
    /// Current logical length of the cache (presumably the number of cached
    /// token positions; confirm against callers). `KvCache::truncate` can
    /// only lower it.
    pub len: i32,
}
impl KvCache {
    /// Creates an empty cache: no buffers registered and `len == 0`.
    pub fn new() -> Self {
        Self {
            k_id: None,
            v_id: None,
            len: 0,
        }
    }

    /// Allocates and registers the K and V buffers if they are missing.
    ///
    /// Each buffer holds `MAX_SEQ_LEN * num_kv_heads * head_dim` `f32`
    /// values and is created with `MTLResourceOptions::StorageModeShared`.
    /// A buffer that is already registered is left untouched, and when both
    /// are present the call returns immediately without touching the pool.
    pub fn ensure_buffers(&mut self, pool: &mut MetalBufferPool) {
        if self.k_id.is_some() && self.v_id.is_some() {
            return;
        }

        let entries = MAX_SEQ_LEN * VARIANT.num_kv_heads * VARIANT.head_dim;
        let bytes = entries * std::mem::size_of::<f32>();

        // Own a device handle so the pool itself can be borrowed mutably
        // for the registrations below.
        let device = pool.device().clone();
        let allocate = || {
            device.new_buffer(
                bytes as NSUInteger,
                MTLResourceOptions::StorageModeShared,
            )
        };

        if self.k_id.is_none() {
            self.k_id = Some(pool.register_borrowed(
                allocate(),
                bytes,
                "kv.k_cache",
                true,
            ));
        }
        if self.v_id.is_none() {
            self.v_id = Some(pool.register_borrowed(
                allocate(),
                bytes,
                "kv.v_cache",
                true,
            ));
        }
    }

    /// Lowers `len` to `new_len`.
    ///
    /// Requests outside `0..=len` (negative, or larger than the current
    /// length) are ignored. Buffer contents are not modified.
    pub fn truncate(&mut self, new_len: i32) {
        if (0..=self.len).contains(&new_len) {
            self.len = new_len;
        }
    }

    /// Returns the first `len` rows of the K buffer as an `f32` slice.
    ///
    /// A row is `num_kv_heads * head_dim` values.
    ///
    /// # Safety
    ///
    /// * `len` rows must fit inside the registered K buffer
    ///   (`len <= MAX_SEQ_LEN` for buffers created by `ensure_buffers`).
    /// * Nothing may write to that memory while the returned slice is alive.
    ///
    /// # Panics
    ///
    /// Panics if the K buffer has not been registered yet.
    pub unsafe fn k_slice<'p>(
        &self,
        pool: &'p MetalBufferPool,
        len: usize,
    ) -> &'p [f32] {
        let buf = pool.handle(
            self.k_id.expect("k_slice called before ensure_buffers"),
        );
        debug_assert!(
            len <= MAX_SEQ_LEN,
            "k_slice: {len} rows exceed MAX_SEQ_LEN",
        );
        let n = len * VARIANT.num_kv_heads * VARIANT.head_dim;
        // SAFETY: the caller guarantees the first `len` rows are in bounds
        // and not written to for the lifetime of the slice.
        unsafe {
            std::slice::from_raw_parts(buf.contents() as *const f32, n)
        }
    }

    /// Returns rows `start_row..end_row` of the K buffer as a mutable
    /// `f32` slice.
    ///
    /// # Safety
    ///
    /// * `end_row` rows must fit inside the registered K buffer
    ///   (`end_row <= MAX_SEQ_LEN` for buffers created by `ensure_buffers`).
    /// * No other reference to, or access of, that row range may exist while
    ///   the returned slice is alive. Note that this method takes `&self`,
    ///   so the borrow checker does not enforce exclusivity.
    ///
    /// # Panics
    ///
    /// Panics if the K buffer has not been registered yet, or if
    /// `start_row > end_row`.
    pub unsafe fn k_slice_mut<'p>(
        &self,
        pool: &'p MetalBufferPool,
        start_row: usize,
        end_row: usize,
    ) -> &'p mut [f32] {
        let buf = pool.handle(
            self.k_id
                .expect("k_slice_mut called before ensure_buffers"),
        );
        // A reversed range would wrap the subtraction below in release
        // builds and produce an enormous out-of-bounds slice.
        assert!(
            start_row <= end_row,
            "k_slice_mut: start_row {start_row} > end_row {end_row}",
        );
        debug_assert!(
            end_row <= MAX_SEQ_LEN,
            "k_slice_mut: end_row {end_row} exceeds MAX_SEQ_LEN",
        );
        let stride = VARIANT.num_kv_heads * VARIANT.head_dim;
        // SAFETY: the range is ordered (asserted above); the caller
        // guarantees it is in bounds and exclusively accessed.
        unsafe {
            let first =
                (buf.contents() as *mut f32).add(start_row * stride);
            std::slice::from_raw_parts_mut(
                first,
                (end_row - start_row) * stride,
            )
        }
    }

    /// Returns the first `len` rows of the V buffer as an `f32` slice.
    ///
    /// A row is `num_kv_heads * head_dim` values.
    ///
    /// # Safety
    ///
    /// * `len` rows must fit inside the registered V buffer
    ///   (`len <= MAX_SEQ_LEN` for buffers created by `ensure_buffers`).
    /// * Nothing may write to that memory while the returned slice is alive.
    ///
    /// # Panics
    ///
    /// Panics if the V buffer has not been registered yet.
    pub unsafe fn v_slice<'p>(
        &self,
        pool: &'p MetalBufferPool,
        len: usize,
    ) -> &'p [f32] {
        let buf = pool.handle(
            self.v_id.expect("v_slice called before ensure_buffers"),
        );
        debug_assert!(
            len <= MAX_SEQ_LEN,
            "v_slice: {len} rows exceed MAX_SEQ_LEN",
        );
        let n = len * VARIANT.num_kv_heads * VARIANT.head_dim;
        // SAFETY: the caller guarantees the first `len` rows are in bounds
        // and not written to for the lifetime of the slice.
        unsafe {
            std::slice::from_raw_parts(buf.contents() as *const f32, n)
        }
    }

    /// Returns rows `start_row..end_row` of the V buffer as a mutable
    /// `f32` slice.
    ///
    /// # Safety
    ///
    /// * `end_row` rows must fit inside the registered V buffer
    ///   (`end_row <= MAX_SEQ_LEN` for buffers created by `ensure_buffers`).
    /// * No other reference to, or access of, that row range may exist while
    ///   the returned slice is alive. Note that this method takes `&self`,
    ///   so the borrow checker does not enforce exclusivity.
    ///
    /// # Panics
    ///
    /// Panics if the V buffer has not been registered yet, or if
    /// `start_row > end_row`.
    pub unsafe fn v_slice_mut<'p>(
        &self,
        pool: &'p MetalBufferPool,
        start_row: usize,
        end_row: usize,
    ) -> &'p mut [f32] {
        let buf = pool.handle(
            self.v_id
                .expect("v_slice_mut called before ensure_buffers"),
        );
        // A reversed range would wrap the subtraction below in release
        // builds and produce an enormous out-of-bounds slice.
        assert!(
            start_row <= end_row,
            "v_slice_mut: start_row {start_row} > end_row {end_row}",
        );
        debug_assert!(
            end_row <= MAX_SEQ_LEN,
            "v_slice_mut: end_row {end_row} exceeds MAX_SEQ_LEN",
        );
        let stride = VARIANT.num_kv_heads * VARIANT.head_dim;
        // SAFETY: the range is ordered (asserted above); the caller
        // guarantees it is in bounds and exclusively accessed.
        unsafe {
            let first =
                (buf.contents() as *mut f32).add(start_row * stride);
            std::slice::from_raw_parts_mut(
                first,
                (end_row - start_row) * stride,
            )
        }
    }
}

impl Default for KvCache {
    /// Equivalent to [`KvCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// MLA cache for a single layer, held directly as Metal buffers.
///
/// Both buffers are created by `MlaKvCacheGpu::ensure_buffers`.
#[derive(Debug)]
pub struct MlaKvCacheGpu {
    /// Latent buffer: `MAX_SEQ_LEN` rows of `kv_lora_rank` `f32` values
    /// once allocated; `None` before that.
    pub latent_cache: Option<Buffer>,
    /// RoPE-K buffer: `MAX_SEQ_LEN` rows of `qk_rope_head_dim` `f32`
    /// values once allocated; `None` before that.
    pub rope_k_cache: Option<Buffer>,
    /// Current logical length of the cache, in rows. `truncate` can only
    /// lower it and zeroes the rows it drops.
    pub len: i32,
}
impl MlaKvCacheGpu {
pub fn new() -> Self {
Self {
latent_cache: None,
rope_k_cache: None,
len: 0,
}
}
pub fn ensure_buffers(&mut self, device: &Device) {
if self.latent_cache.is_none() {
let bytes = (MAX_SEQ_LEN * VARIANT.kv_lora_rank
* std::mem::size_of::<f32>())
as NSUInteger;
let buf = device.new_buffer(
bytes,
MTLResourceOptions::StorageModeShared,
);
self.latent_cache = Some(buf);
}
if self.rope_k_cache.is_none() {
let bytes = (MAX_SEQ_LEN * VARIANT.qk_rope_head_dim
* std::mem::size_of::<f32>())
as NSUInteger;
let buf = device.new_buffer(
bytes,
MTLResourceOptions::StorageModeShared,
);
self.rope_k_cache = Some(buf);
}
}
pub fn truncate(&mut self, new_len: i32) {
if new_len < 0 || new_len > self.len {
return;
}
let old_len = self.len;
if new_len < old_len {
if let Some(buf) = &self.latent_cache {
let stride_bytes =
VARIANT.kv_lora_rank * std::mem::size_of::<f32>();
let start = (new_len as usize) * stride_bytes;
let end = (old_len as usize) * stride_bytes;
unsafe {
let p = buf.contents() as *mut u8;
std::ptr::write_bytes(
p.add(start),
0,
end - start,
);
}
}
if let Some(buf) = &self.rope_k_cache {
let stride_bytes = VARIANT.qk_rope_head_dim
* std::mem::size_of::<f32>();
let start = (new_len as usize) * stride_bytes;
let end = (old_len as usize) * stride_bytes;
unsafe {
let p = buf.contents() as *mut u8;
std::ptr::write_bytes(
p.add(start),
0,
end - start,
);
}
}
}
self.len = new_len;
}
pub unsafe fn latent_slice(&self, len: usize) -> &[f32] {
let buf = self
.latent_cache
.as_ref()
.expect("latent_slice called before ensure_buffers");
let n = len * VARIANT.kv_lora_rank;
unsafe {
std::slice::from_raw_parts(buf.contents() as *const f32, n)
}
}
pub unsafe fn latent_slice_mut(
&mut self,
start_row: usize,
end_row: usize,
) -> &mut [f32] {
let buf = self
.latent_cache
.as_ref()
.expect("latent_slice_mut called before ensure_buffers");
let stride = VARIANT.kv_lora_rank;
unsafe {
let p =
(buf.contents() as *mut f32).add(start_row * stride);
std::slice::from_raw_parts_mut(
p,
(end_row - start_row) * stride,
)
}
}
pub unsafe fn rope_k_slice(&self, len: usize) -> &[f32] {
let buf = self
.rope_k_cache
.as_ref()
.expect("rope_k_slice called before ensure_buffers");
let n = len * VARIANT.qk_rope_head_dim;
unsafe {
std::slice::from_raw_parts(buf.contents() as *const f32, n)
}
}
pub unsafe fn rope_k_slice_mut(
&mut self,
start_row: usize,
end_row: usize,
) -> &mut [f32] {
let buf = self
.rope_k_cache
.as_ref()
.expect("rope_k_slice_mut called before ensure_buffers");
let stride = VARIANT.qk_rope_head_dim;
unsafe {
let p =
(buf.contents() as *mut f32).add(start_row * stride);
std::slice::from_raw_parts_mut(
p,
(end_row - start_row) * stride,
)
}
}
}
impl Default for MlaKvCacheGpu {
fn default() -> Self {
Self::new()
}
}
/// Host-side state for one linear-attention layer.
///
/// Both arrays are sized and zero-filled by `LinearAttnState::new` and
/// zeroed again by `LinearAttnState::reset`.
#[derive(Debug)]
pub struct LinearAttnState {
    /// Convolution state:
    /// `(CONV_KERNEL_SIZE - 1) * linear_conv_dim()` values.
    pub conv_state: Box<[f32]>,
    /// SSM state:
    /// `linear_num_v_heads * LINEAR_VALUE_DIM * LINEAR_KEY_DIM` values.
    pub ssm_state: Box<[f32]>,
}
impl LinearAttnState {
    /// Allocates zero-filled convolution and SSM state arrays.
    ///
    /// Sizes come from the active variant:
    /// * `conv_state`: `(CONV_KERNEL_SIZE - 1) * linear_conv_dim()`
    /// * `ssm_state`: `linear_num_v_heads * LINEAR_VALUE_DIM * LINEAR_KEY_DIM`
    pub fn new() -> Self {
        let conv_len =
            (Variant::CONV_KERNEL_SIZE - 1) * VARIANT.linear_conv_dim();
        let ssm_len = VARIANT.linear_num_v_heads
            * Variant::LINEAR_VALUE_DIM
            * Variant::LINEAR_KEY_DIM;
        Self {
            conv_state: vec![0.0_f32; conv_len].into_boxed_slice(),
            ssm_state: vec![0.0_f32; ssm_len].into_boxed_slice(),
        }
    }

    /// Zeroes both state arrays in place; their sizes are unchanged.
    pub fn reset(&mut self) {
        for state in [&mut self.conv_state, &mut self.ssm_state] {
            state.fill(0.0);
        }
    }
}

impl Default for LinearAttnState {
    /// Equivalent to [`LinearAttnState::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Per-layer inference state; one value per model layer.
#[derive(Debug)]
pub enum LayerState {
    /// Attention layer backed by a pool-registered K/V cache.
    FullAttn(KvCache),
    /// Attention layer backed by the MLA latent / RoPE-K buffers.
    Mla(MlaKvCacheGpu),
    /// Linear-attention layer carrying convolution and SSM state.
    LinearAttn(LinearAttnState),
}
impl LayerState {
    /// Returns `true` for the `FullAttn` and `Mla` variants and `false`
    /// for `LinearAttn`.
    pub fn is_full(&self) -> bool {
        match self {
            Self::FullAttn(_) | Self::Mla(_) => true,
            Self::LinearAttn(_) => false,
        }
    }
}
/// Builds one empty [`LayerState`] per layer of the active variant.
///
/// Full-attention layers get a `KvCache` or an `MlaKvCacheGpu` depending on
/// `VARIANT.attn_kind`; linear-attention layers get a `LinearAttnState`.
/// No Metal buffers are allocated here.
pub fn alloc_layer_states() -> Vec<LayerState> {
    use crate::riir::variants::{AttnKind, LayerKind};

    let mut states = Vec::new();
    for layer_idx in 0..VARIANT.num_layers {
        let state = match VARIANT.layer_kind(layer_idx) {
            LayerKind::LinearAttn => {
                LayerState::LinearAttn(LinearAttnState::new())
            }
            LayerKind::FullAttn => match VARIANT.attn_kind {
                AttnKind::Mla => LayerState::Mla(MlaKvCacheGpu::new()),
                AttnKind::Gqa => LayerState::FullAttn(KvCache::new()),
            },
        };
        states.push(state);
    }
    states
}
/// Resets every layer to its empty state.
///
/// Attention caches are truncated to length 0 and linear-attention state is
/// zeroed.
pub fn clear_all(layers: &mut [LayerState]) {
    layers.iter_mut().for_each(|layer| match layer {
        LayerState::LinearAttn(state) => state.reset(),
        LayerState::Mla(cache) => cache.truncate(0),
        LayerState::FullAttn(cache) => cache.truncate(0),
    });
}
/// Truncates every attention cache and resets all linear-attention state.
///
/// For each attention layer the target length is
/// `min(max(p0, 0), end)`, where `end` is `p1` when `0 <= p1 <= len` and the
/// layer's current `len` otherwise. Linear-attention layers are always
/// reset, whatever `p0` and `p1` are.
pub fn truncate(layers: &mut [LayerState], p0: i32, p1: i32) {
    let floor = p0.max(0);
    // Target length for a cache whose current length is `current`.
    let target_for = |current: i32| -> i32 {
        let upper = if (0..=current).contains(&p1) {
            p1
        } else {
            current
        };
        floor.min(upper)
    };

    for layer in layers.iter_mut() {
        match layer {
            LayerState::FullAttn(cache) => {
                let target = target_for(cache.len);
                cache.truncate(target);
            }
            LayerState::Mla(cache) => {
                let target = target_for(cache.len);
                cache.truncate(target);
            }
            LayerState::LinearAttn(state) => state.reset(),
        }
    }
}
/// Returns the largest `len` among the `FullAttn` and `Mla` layers.
///
/// Linear-attention layers are skipped. When `layers` contains no attention
/// layer at all (including when it is empty) the result is `-1`.
///
/// NOTE(review): an empty attention cache reports `0`, so the `-1` result
/// only appears when there is no attention layer; confirm that callers
/// expect that sentinel rather than `0`.
pub fn pos_max(layers: &[LayerState]) -> i32 {
    layers
        .iter()
        .filter_map(|layer| match layer {
            LayerState::FullAttn(cache) => Some(cache.len),
            LayerState::Mla(cache) => Some(cache.len),
            LayerState::LinearAttn(_) => None,
        })
        .fold(-1, i32::max)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Freshly allocated and freshly cleared states both report 0.
    #[test]
    fn empty_state_pos_max_is_zero() {
        let mut layers = alloc_layer_states();
        assert_eq!(pos_max(&layers), 0);

        clear_all(&mut layers);
        assert_eq!(pos_max(&layers), 0);
    }

    /// Truncating an empty state leaves it empty for any range.
    #[test]
    fn truncate_empty_is_noop() {
        let mut layers = alloc_layer_states();
        for (p0, p1) in [(0, -1), (5, 10), (-1, -1)] {
            truncate(&mut layers, p0, p1);
            assert_eq!(pos_max(&layers), 0);
        }
    }

    /// Truncation lowers an attention layer's length but never raises it.
    #[test]
    fn truncate_drops_full_attn_len() {
        let mut layers = alloc_layer_states();

        // Give the first attention layer (GQA or MLA) a non-zero length.
        let mut injected = false;
        for layer in layers.iter_mut() {
            match layer {
                LayerState::FullAttn(kv) => {
                    kv.len = 7;
                    injected = true;
                }
                LayerState::Mla(mla) => {
                    mla.len = 7;
                    injected = true;
                }
                LayerState::LinearAttn(_) => {}
            }
            if injected {
                break;
            }
        }
        assert!(
            injected,
            "variant must have at least one full-attn (GQA or MLA) layer",
        );
        assert_eq!(pos_max(&layers), 7);

        truncate(&mut layers, 3, -1);
        assert_eq!(pos_max(&layers), 3);

        // A target beyond the current length must not grow the cache.
        truncate(&mut layers, 10, -1);
        assert_eq!(pos_max(&layers), 3);

        clear_all(&mut layers);
        assert_eq!(pos_max(&layers), 0);
    }
}