use crate::riir::variants::{Variant, MAX_SEQ_LEN, VARIANT};
/// Key/value cache for a single full-attention layer.
///
/// `KvCache::new` allocates both buffers zero-filled with
/// `MAX_SEQ_LEN * VARIANT.num_kv_heads * VARIANT.head_dim` entries each.
/// `KvCache::truncate` treats every position as
/// `VARIANT.num_kv_heads * VARIANT.head_dim` consecutive `f32` entries.
#[derive(Debug)]
pub struct KvCache {
/// Buffer for cached keys.
pub k_cache: Box<[f32]>,
/// Buffer for cached values; same size as `k_cache`.
pub v_cache: Box<[f32]>,
/// Number of positions currently held; entries at or beyond
/// `len * stride` are zeroed whenever `truncate` shrinks the cache.
pub len: i32,
}
impl KvCache {
    /// Allocates an empty cache.
    ///
    /// Both buffers are zero-filled and hold
    /// `MAX_SEQ_LEN * VARIANT.num_kv_heads * VARIANT.head_dim` entries; `len` starts at 0.
    pub fn new() -> Self {
        let entries = MAX_SEQ_LEN * VARIANT.num_kv_heads * VARIANT.head_dim;
        Self {
            k_cache: vec![0.0f32; entries].into_boxed_slice(),
            v_cache: vec![0.0f32; entries].into_boxed_slice(),
            len: 0,
        }
    }

    /// Shrinks the cache to `new_len` positions and zeroes the dropped tail.
    ///
    /// A `new_len` that is negative or larger than the current `len` is
    /// silently ignored, so this method can never grow the cache.
    ///
    /// # Panics
    ///
    /// Panics if the current `len` (a public field) describes more entries
    /// than the buffers actually hold, because the zeroing slice would be
    /// out of bounds.
    pub fn truncate(&mut self, new_len: i32) {
        if new_len < 0 || new_len > self.len {
            return;
        }
        if new_len < self.len {
            let stride = VARIANT.num_kv_heads * VARIANT.head_dim;
            // Both casts are lossless here: 0 <= new_len < self.len.
            let dropped = (new_len as usize) * stride..(self.len as usize) * stride;
            self.k_cache[dropped.clone()].fill(0.0);
            self.v_cache[dropped].fill(0.0);
        }
        self.len = new_len;
    }
}

impl Default for KvCache {
    /// Equivalent to [`KvCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Recurrent state for a single linear-attention layer.
///
/// Both buffers are allocated zero-filled by `LinearAttnState::new` and are
/// zeroed again by `LinearAttnState::reset`.
#[derive(Debug)]
pub struct LinearAttnState {
/// Holds `(Variant::CONV_KERNEL_SIZE - 1) * VARIANT.linear_conv_dim()` entries.
pub conv_state: Box<[f32]>,
/// Holds `VARIANT.linear_num_v_heads * Variant::LINEAR_VALUE_DIM *
/// Variant::LINEAR_KEY_DIM` entries.
pub ssm_state: Box<[f32]>,
}
impl LinearAttnState {
    /// Allocates zero-filled convolution and SSM state buffers.
    ///
    /// `conv_state` gets `(Variant::CONV_KERNEL_SIZE - 1) * VARIANT.linear_conv_dim()`
    /// entries; `ssm_state` gets
    /// `VARIANT.linear_num_v_heads * Variant::LINEAR_VALUE_DIM * Variant::LINEAR_KEY_DIM`.
    pub fn new() -> Self {
        let conv_entries = (Variant::CONV_KERNEL_SIZE - 1) * VARIANT.linear_conv_dim();
        let ssm_entries =
            VARIANT.linear_num_v_heads * Variant::LINEAR_VALUE_DIM * Variant::LINEAR_KEY_DIM;
        Self {
            conv_state: vec![0.0f32; conv_entries].into_boxed_slice(),
            ssm_state: vec![0.0f32; ssm_entries].into_boxed_slice(),
        }
    }

    /// Zeroes both buffers in place without reallocating.
    pub fn reset(&mut self) {
        self.conv_state.fill(0.0);
        self.ssm_state.fill(0.0);
    }
}

impl Default for LinearAttnState {
    /// Equivalent to [`LinearAttnState::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Per-layer state; which variant a layer gets is decided by
/// `VARIANT.layer_kind` in `alloc_layer_states`.
#[derive(Debug)]
pub enum LayerState {
/// State of a full-attention layer: a key/value cache with a length.
FullAttn(KvCache),
/// State of a linear-attention layer: fixed-size conv and SSM buffers.
LinearAttn(LinearAttnState),
}
impl LayerState {
    /// Returns `true` when this layer carries a full-attention [`KvCache`].
    pub fn is_full(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::FullAttn(_) => true,
            Self::LinearAttn(_) => false,
        }
    }
}
/// Builds one freshly allocated [`LayerState`] per layer of `VARIANT`.
///
/// Layer `i` receives a [`KvCache`] or a [`LinearAttnState`] depending on
/// what `VARIANT.layer_kind(i)` reports.
pub fn alloc_layer_states() -> Vec<LayerState> {
    use super::variants::LayerKind;

    (0..VARIANT.num_layers)
        .map(|layer_idx| {
            let kind = VARIANT.layer_kind(layer_idx);
            match kind {
                LayerKind::LinearAttn => LayerState::LinearAttn(LinearAttnState::new()),
                LayerKind::FullAttn => LayerState::FullAttn(KvCache::new()),
            }
        })
        .collect()
}
/// Clears every layer: full-attention caches are truncated to length zero
/// and linear-attention states are reset.
pub fn clear_all(layers: &mut [LayerState]) {
    layers.iter_mut().for_each(|state| match state {
        LayerState::LinearAttn(linear) => linear.reset(),
        LayerState::FullAttn(cache) => cache.truncate(0),
    });
}
/// Truncates every layer's state according to the range `p0..p1`.
///
/// For full-attention layers the cache is cut to
/// `min(max(p0, 0), end)`, where `end` is `p1` when `0 <= p1 <= len` and the
/// cache's current `len` otherwise. A target beyond the current length
/// leaves the cache untouched (see [`KvCache::truncate`]).
///
/// Linear-attention layers are always reset in full, regardless of `p0`
/// and `p1` — including when the full-attention truncation is a no-op.
pub fn truncate(layers: &mut [LayerState], p0: i32, p1: i32) {
    let requested = if p0 < 0 { 0 } else { p0 };
    for state in layers.iter_mut() {
        match state {
            LayerState::LinearAttn(linear) => linear.reset(),
            LayerState::FullAttn(cache) => {
                let upper = if (0..=cache.len).contains(&p1) {
                    p1
                } else {
                    cache.len
                };
                cache.truncate(requested.min(upper));
            }
        }
    }
}
/// Returns the largest `len` among the full-attention layers.
///
/// The result is never below `-1`; `-1` is returned when `layers` contains
/// no full-attention layer at all.
pub fn pos_max(layers: &[LayerState]) -> i32 {
    layers
        .iter()
        .filter_map(|state| match state {
            LayerState::FullAttn(cache) => Some(cache.len),
            LayerState::LinearAttn(_) => None,
        })
        .fold(-1, i32::max)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh and freshly cleared state both report a maximum length of zero.
    #[test]
    fn empty_state_pos_max_is_zero() {
        let mut states = alloc_layer_states();
        assert_eq!(pos_max(&states), 0);

        clear_all(&mut states);
        assert_eq!(pos_max(&states), 0);
    }

    /// Truncating empty state never changes the reported length.
    #[test]
    fn truncate_empty_is_noop() {
        let mut states = alloc_layer_states();
        for (p0, p1) in [(0, -1), (5, 10), (-1, -1)] {
            truncate(&mut states, p0, p1);
            assert_eq!(pos_max(&states), 0);
        }
    }

    /// Truncation shrinks a full-attention cache but never grows it.
    #[test]
    fn truncate_drops_full_attn_len() {
        let mut states = alloc_layer_states();
        {
            let first_full = states
                .iter_mut()
                .find_map(|state| match state {
                    LayerState::FullAttn(cache) => Some(cache),
                    LayerState::LinearAttn(_) => None,
                })
                .expect("variant must have at least one full-attn layer");
            first_full.len = 7;
        }
        assert_eq!(pos_max(&states), 7);

        // Shrinking takes effect.
        truncate(&mut states, 3, -1);
        assert_eq!(pos_max(&states), 3);

        // A target past the current length is ignored.
        truncate(&mut states, 10, -1);
        assert_eq!(pos_max(&states), 3);

        clear_all(&mut states);
        assert_eq!(pos_max(&states), 0);
    }
}