/// Storage policy selector for the key/value cache.
///
/// `Standard` is the `Default`. Neither cache type in this module reads this
/// enum directly, so the concrete effect of each variant is decided by the
/// consumer — NOTE(review): confirm semantics against the caller.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum KvCachePolicy {
    /// Default policy.
    #[default]
    Standard,
    /// Presumably half-precision storage (inferred from the name — TODO
    /// confirm); the caches in this module store `f32` regardless.
    Fp16,
    /// Presumably a sliding attention window whose payload is the window
    /// length in positions — TODO confirm.
    SlidingWindow(usize),
}
/// Dense key/value cache with every slot preallocated up front.
///
/// Keys and values live in two flat `f32` buffers laid out as
/// `[layer][kv_head][position][head_dim]` (see `cache_offset`), each sized for
/// `max_seq_len` positions. `seq_len` is a logical cursor kept separately from
/// the data: storing a vector does not move it; use [`KvCache::advance`] or
/// [`KvCache::set_seq_len`].
#[derive(Debug)]
pub struct KvCache {
    num_layers: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
    /// Number of positions treated as filled; kept `<= max_seq_len`.
    seq_len: usize,
    keys: Vec<f32>,
    values: Vec<f32>,
}

impl KvCache {
    /// Creates a zero-filled cache for `num_layers * num_kv_heads` heads, each
    /// holding up to `max_seq_len` vectors of `head_dim` floats.
    ///
    /// # Panics
    /// Panics if the total element count overflows `usize`.
    pub fn new(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        let total = num_layers
            .checked_mul(num_kv_heads)
            .and_then(|n| n.checked_mul(max_seq_len))
            .and_then(|n| n.checked_mul(head_dim))
            .expect("KvCache dimensions overflow usize");
        Self {
            num_layers,
            num_kv_heads,
            head_dim,
            max_seq_len,
            seq_len: 0,
            keys: vec![0.0; total],
            values: vec![0.0; total],
        }
    }

    /// Number of positions currently marked as filled.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Capacity in positions for every (layer, head).
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Writes the key vector for position `pos` of (`layer`, `head`).
    ///
    /// Does not move `seq_len`. Indices and `key.len() == head_dim` are
    /// checked in debug builds only.
    pub fn store_key(&mut self, layer: usize, head: usize, pos: usize, key: &[f32]) {
        debug_assert!(layer < self.num_layers);
        debug_assert!(head < self.num_kv_heads);
        debug_assert!(pos < self.max_seq_len);
        debug_assert_eq!(key.len(), self.head_dim);
        let slot = self.span(layer, head, pos, 1);
        self.keys[slot].copy_from_slice(key);
    }

    /// Writes the value vector for position `pos` of (`layer`, `head`).
    ///
    /// Does not move `seq_len`. Indices and `value.len() == head_dim` are
    /// checked in debug builds only.
    pub fn store_value(&mut self, layer: usize, head: usize, pos: usize, value: &[f32]) {
        debug_assert!(layer < self.num_layers);
        debug_assert!(head < self.num_kv_heads);
        debug_assert!(pos < self.max_seq_len);
        debug_assert_eq!(value.len(), self.head_dim);
        let slot = self.span(layer, head, pos, 1);
        self.values[slot].copy_from_slice(value);
    }

    /// Borrows the first `seq_len` key vectors of (`layer`, `head`) as one
    /// contiguous `seq_len * head_dim` slice.
    ///
    /// `seq_len` must not exceed `max_seq_len`; a larger value would run into
    /// the next head's region (checked in debug builds).
    pub fn keys_for(&self, layer: usize, head: usize, seq_len: usize) -> &[f32] {
        self.debug_check_read(layer, head, seq_len);
        &self.keys[self.span(layer, head, 0, seq_len)]
    }

    /// Borrows the first `seq_len` value vectors of (`layer`, `head`) as one
    /// contiguous `seq_len * head_dim` slice.
    ///
    /// `seq_len` must not exceed `max_seq_len`; a larger value would run into
    /// the next head's region (checked in debug builds).
    pub fn values_for(&self, layer: usize, head: usize, seq_len: usize) -> &[f32] {
        self.debug_check_read(layer, head, seq_len);
        &self.values[self.span(layer, head, 0, seq_len)]
    }

    /// Marks one more position as filled.
    ///
    /// Saturates at `max_seq_len`, matching [`KvCache::set_seq_len`], so the
    /// cursor can never point past the preallocated buffers.
    pub fn advance(&mut self) {
        if self.seq_len < self.max_seq_len {
            self.seq_len += 1;
        }
    }

    /// Resets the cursor to zero.
    ///
    /// Stored data is left in place (not zeroed) and is overwritten by later
    /// stores.
    pub fn clear(&mut self) {
        self.seq_len = 0;
    }

    /// Flat index of the first element of `pos` for (`layer`, `head`).
    fn cache_offset(&self, layer: usize, head: usize, pos: usize) -> usize {
        ((layer * self.num_kv_heads + head) * self.max_seq_len + pos) * self.head_dim
    }

    /// Flat index range covering `len` consecutive position vectors starting
    /// at `pos` for (`layer`, `head`); positions of one head are contiguous.
    fn span(&self, layer: usize, head: usize, pos: usize, len: usize) -> std::ops::Range<usize> {
        let start = self.cache_offset(layer, head, pos);
        start..start + len * self.head_dim
    }

    /// Debug-build bounds checks shared by the read accessors.
    fn debug_check_read(&self, layer: usize, head: usize, seq_len: usize) {
        debug_assert!(layer < self.num_layers);
        debug_assert!(head < self.num_kv_heads);
        debug_assert!(
            seq_len <= self.max_seq_len,
            "seq_len {seq_len} exceeds max_seq_len {}",
            self.max_seq_len
        );
    }

    /// Number of positions in `start_pos..start_pos + block_size` that fall
    /// inside the preallocated `max_seq_len` window.
    fn rows_in_range(&self, start_pos: usize, block_size: usize) -> usize {
        self.max_seq_len.saturating_sub(start_pos).min(block_size)
    }

    /// Size in bytes of the key and value element storage.
    pub fn memory_bytes(&self) -> usize {
        (self.keys.len() + self.values.len()) * std::mem::size_of::<f32>()
    }

    /// Fraction of `max_seq_len` currently marked as filled (0.0 when the
    /// cache has zero capacity).
    pub fn utilization_ratio(&self) -> f64 {
        if self.max_seq_len == 0 {
            return 0.0;
        }
        self.seq_len as f64 / self.max_seq_len as f64
    }

    /// Number of transformer layers the cache was built for.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Number of KV heads per layer.
    pub fn num_kv_heads(&self) -> usize {
        self.num_kv_heads
    }

    /// Length of each stored key/value vector.
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Sets the cursor to `n`, clamped to `max_seq_len`.
    pub fn set_seq_len(&mut self, n: usize) {
        self.seq_len = n.min(self.max_seq_len);
    }

    /// Copies positions `start_pos..start_pos + block_size` of every KV head
    /// in `layer` into new buffers laid out `[kv_head][offset][head_dim]`.
    ///
    /// Both returned buffers always hold `num_kv_heads * block_size * head_dim`
    /// elements; positions at or beyond `max_seq_len` are left as zeros. The
    /// inverse is [`KvCache::inject_block`].
    pub fn extract_block(
        &self,
        layer: usize,
        start_pos: usize,
        block_size: usize,
    ) -> (Vec<f32>, Vec<f32>) {
        debug_assert!(layer < self.num_layers);
        let hd = self.head_dim;
        let per_layer = self.num_kv_heads * block_size * hd;
        let mut keys = vec![0.0f32; per_layer];
        let mut values = vec![0.0f32; per_layer];
        let rows = self.rows_in_range(start_pos, block_size);
        if rows == 0 {
            // Entire block lies outside the cache: nothing to copy, and
            // computing an offset for `start_pos` could overrun the buffer.
            return (keys, values);
        }
        for head in 0..self.num_kv_heads {
            // The in-range prefix of one head's block is contiguous on both sides.
            let src = self.span(layer, head, start_pos, rows);
            let dst = head * block_size * hd;
            keys[dst..dst + rows * hd].copy_from_slice(&self.keys[src.clone()]);
            values[dst..dst + rows * hd].copy_from_slice(&self.values[src]);
        }
        (keys, values)
    }

    /// Writes a block produced by [`KvCache::extract_block`] back into
    /// positions `start_pos..start_pos + block_size` of every KV head in
    /// `layer`.
    ///
    /// `keys`/`values` must be laid out `[kv_head][offset][head_dim]` with
    /// `num_kv_heads * block_size * head_dim` elements (checked in debug
    /// builds). Positions at or beyond `max_seq_len` are skipped. Does not
    /// move `seq_len`.
    pub fn inject_block(
        &mut self,
        layer: usize,
        start_pos: usize,
        block_size: usize,
        keys: &[f32],
        values: &[f32],
    ) {
        debug_assert!(layer < self.num_layers);
        let hd = self.head_dim;
        let per_layer = self.num_kv_heads * block_size * hd;
        debug_assert_eq!(keys.len(), per_layer);
        debug_assert_eq!(values.len(), per_layer);
        let rows = self.rows_in_range(start_pos, block_size);
        if rows == 0 {
            return;
        }
        for head in 0..self.num_kv_heads {
            let dst = self.span(layer, head, start_pos, rows);
            let src = head * block_size * hd;
            self.keys[dst.clone()].copy_from_slice(&keys[src..src + rows * hd]);
            self.values[dst].copy_from_slice(&values[src..src + rows * hd]);
        }
    }
}
/// Number of positions per page used by [`PagedKvCache::new`].
const DEFAULT_PAGE_SIZE: usize = 256;

/// One lazily allocated page of key/value storage for a single (layer, head).
#[derive(Debug, Clone)]
struct KvPage {
    /// `page_size * head_dim` key floats, laid out `[offset_in_page][head_dim]`.
    keys: Vec<f32>,
    /// `page_size * head_dim` value floats, same layout as `keys`.
    values: Vec<f32>,
    /// High-water mark: one past the largest in-page offset written so far.
    used: usize,
}

impl KvPage {
    /// Allocates a zero-filled page holding `page_size` vectors of `head_dim` floats.
    fn new(page_size: usize, head_dim: usize) -> Self {
        let len = page_size * head_dim;
        Self {
            keys: vec![0.0; len],
            values: vec![0.0; len],
            used: 0,
        }
    }
}

/// Key/value cache that allocates storage in fixed-size pages on first write.
///
/// Pages are kept per (layer, KV head) in `pages[layer][head]`; page `p`
/// covers positions `p * page_size .. (p + 1) * page_size`. Reads from
/// positions whose page was never allocated yield zeros.
#[derive(Debug)]
pub struct PagedKvCache {
    pages: Vec<Vec<Vec<KvPage>>>,
    num_layers: usize,
    num_kv_heads: usize,
    head_dim: usize,
    page_size: usize,
    max_seq_len: usize,
    /// Logical cursor; moved by `advance`, capped at `max_seq_len`.
    seq_len: usize,
}

impl PagedKvCache {
    /// Creates an empty cache using [`DEFAULT_PAGE_SIZE`] positions per page.
    pub fn new(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        Self::with_page_size(
            num_layers,
            num_kv_heads,
            head_dim,
            max_seq_len,
            DEFAULT_PAGE_SIZE,
        )
    }

    /// Creates an empty cache with `page_size` positions per page. No page
    /// memory is allocated until the first store.
    ///
    /// # Panics
    /// Panics if `page_size` is zero (every position lookup divides by it).
    pub fn with_page_size(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
        page_size: usize,
    ) -> Self {
        assert!(page_size > 0, "PagedKvCache page_size must be non-zero");
        Self {
            pages: vec![vec![Vec::new(); num_kv_heads]; num_layers],
            num_layers,
            num_kv_heads,
            head_dim,
            page_size,
            max_seq_len,
            seq_len: 0,
        }
    }

    /// Writes the key vector for position `pos` of (`layer`, `head`),
    /// allocating pages up to the one containing `pos` if needed.
    pub fn store_key(&mut self, layer: usize, head: usize, pos: usize, key: &[f32]) {
        self.store(layer, head, pos, key, |page| page.keys.as_mut_slice());
    }

    /// Writes the value vector for position `pos` of (`layer`, `head`),
    /// allocating pages up to the one containing `pos` if needed.
    pub fn store_value(&mut self, layer: usize, head: usize, pos: usize, value: &[f32]) {
        self.store(layer, head, pos, value, |page| page.values.as_mut_slice());
    }

    /// Returns a copy of the first `seq_len` key vectors of (`layer`, `head`),
    /// zero-filled where no page has been allocated.
    pub fn keys_for(&self, layer: usize, head: usize, seq_len: usize) -> Vec<f32> {
        self.gather(layer, head, seq_len, |page| page.keys.as_slice())
    }

    /// Returns a copy of the first `seq_len` value vectors of (`layer`, `head`),
    /// zero-filled where no page has been allocated.
    pub fn values_for(&self, layer: usize, head: usize, seq_len: usize) -> Vec<f32> {
        self.gather(layer, head, seq_len, |page| page.values.as_slice())
    }

    /// Number of positions currently marked as filled.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Marks one more position as filled, saturating at `max_seq_len` so the
    /// cursor never exceeds the cache's declared capacity.
    pub fn advance(&mut self) {
        if self.seq_len < self.max_seq_len {
            self.seq_len += 1;
        }
    }

    /// Resets the cursor and drops every allocated page.
    pub fn clear(&mut self) {
        self.seq_len = 0;
        self.pages
            .iter_mut()
            .flatten()
            .for_each(|head_pages| head_pages.clear());
    }

    /// Bytes held by allocated pages (two `page_size * head_dim` f32 buffers each).
    pub fn memory_usage_bytes(&self) -> usize {
        let bytes_per_page = 2 * self.page_size * self.head_dim * std::mem::size_of::<f32>();
        self.total_pages() * bytes_per_page
    }

    /// Fraction of allocated page slots below each page's high-water mark
    /// (0.0 when no page is allocated).
    pub fn utilization_ratio(&self) -> f64 {
        let (page_count, used_slots) = self
            .pages
            .iter()
            .flatten()
            .flatten()
            .fold((0usize, 0usize), |(n, used), page| (n + 1, used + page.used));
        if page_count == 0 {
            return 0.0;
        }
        used_slots as f64 / (page_count * self.page_size) as f64
    }

    /// Number of pages allocated across all layers and heads.
    pub fn total_pages(&self) -> usize {
        self.pages.iter().flatten().map(Vec::len).sum()
    }

    /// Positions per page.
    pub fn page_size(&self) -> usize {
        self.page_size
    }

    /// Shared implementation of `store_key` / `store_value`: `pick` selects
    /// which buffer of the target page receives `data`.
    fn store(
        &mut self,
        layer: usize,
        head: usize,
        pos: usize,
        data: &[f32],
        pick: impl Fn(&mut KvPage) -> &mut [f32],
    ) {
        debug_assert!(layer < self.num_layers);
        debug_assert!(head < self.num_kv_heads);
        debug_assert!(pos < self.max_seq_len);
        debug_assert_eq!(data.len(), self.head_dim);
        let head_dim = self.head_dim;
        let page_idx = pos / self.page_size;
        let offset_in_page = pos % self.page_size;
        self.ensure_page(layer, head, page_idx);
        let page = &mut self.pages[layer][head][page_idx];
        let start = offset_in_page * head_dim;
        pick(&mut *page)[start..start + head_dim].copy_from_slice(data);
        page.used = page.used.max(offset_in_page + 1);
    }

    /// Shared implementation of `keys_for` / `values_for`: concatenates the
    /// first `seq_len` vectors from the buffer chosen by `pick`, zero-filling
    /// positions whose page was never allocated.
    fn gather(
        &self,
        layer: usize,
        head: usize,
        seq_len: usize,
        pick: impl Fn(&KvPage) -> &[f32],
    ) -> Vec<f32> {
        let head_pages = &self.pages[layer][head];
        let mut out = Vec::with_capacity(seq_len * self.head_dim);
        // Positions within one page are contiguous, so each page contributes
        // a single slice copy (or a single zero run if it is missing).
        for (page_idx, page_start) in (0..seq_len).step_by(self.page_size).enumerate() {
            let len = (seq_len - page_start).min(self.page_size) * self.head_dim;
            match head_pages.get(page_idx) {
                Some(page) => out.extend_from_slice(&pick(page)[..len]),
                None => out.resize(out.len() + len, 0.0),
            }
        }
        out
    }

    /// Allocates pages for (`layer`, `head`) up to and including `page_idx`.
    fn ensure_page(&mut self, layer: usize, head: usize, page_idx: usize) {
        let (page_size, head_dim) = (self.page_size, self.head_dim);
        let head_pages = &mut self.pages[layer][head];
        // Guard required: `resize_with` would truncate if asked to shrink.
        if head_pages.len() <= page_idx {
            head_pages.resize_with(page_idx + 1, || KvPage::new(page_size, head_dim));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn kv_cache_store_and_retrieve() {
        let mut cache = KvCache::new(2, 8, 128, 16);
        let key = vec![1.0f32; 128];
        let value = vec![2.0f32; 128];
        cache.store_key(0, 0, 0, &key);
        cache.store_value(0, 0, 0, &value);
        cache.advance();
        let keys = cache.keys_for(0, 0, 1);
        let values = cache.values_for(0, 0, 1);
        assert_eq!(keys.len(), 128);
        assert_eq!(values.len(), 128);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((values[0] - 2.0).abs() < 1e-5);
    }
    #[test]
    fn kv_cache_multiple_positions() {
        let mut cache = KvCache::new(1, 1, 4, 8);
        cache.store_key(0, 0, 0, &[1.0, 2.0, 3.0, 4.0]);
        cache.advance();
        cache.store_key(0, 0, 1, &[5.0, 6.0, 7.0, 8.0]);
        cache.advance();
        let keys = cache.keys_for(0, 0, 2);
        assert_eq!(keys.len(), 8);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((keys[4] - 5.0).abs() < 1e-5);
    }
    #[test]
    fn kv_cache_memory_size() {
        let cache = KvCache::new(36, 8, 128, 4096);
        let expected = 36 * 8 * 4096 * 128 * 4 * 2;
        assert_eq!(cache.memory_bytes(), expected);
    }
    #[test]
    fn kv_cache_utilization() {
        let mut cache = KvCache::new(1, 1, 4, 10);
        assert!((cache.utilization_ratio() - 0.0).abs() < 1e-10);
        cache.advance();
        cache.advance();
        cache.advance();
        assert!((cache.utilization_ratio() - 0.3).abs() < 1e-10);
    }
    #[test]
    fn kv_cache_policy_default() {
        let policy = KvCachePolicy::default();
        assert_eq!(policy, KvCachePolicy::Standard);
    }
    #[test]
    fn kv_cache_set_seq_len_clamps_to_max() {
        let mut cache = KvCache::new(1, 1, 4, 8);
        cache.set_seq_len(4);
        assert_eq!(cache.seq_len(), 4);
        cache.set_seq_len(100);
        assert_eq!(cache.seq_len(), 8);
    }
    #[test]
    fn kv_cache_extract_inject_roundtrip() {
        let num_layers = 2;
        let num_kv_heads = 2;
        let head_dim = 4;
        let block_size = 4;
        let max_seq = 16;
        let mut cache = KvCache::new(num_layers, num_kv_heads, head_dim, max_seq);
        for head in 0..num_kv_heads {
            for pos in 0..block_size {
                let key: Vec<f32> = (0..head_dim)
                    .map(|d| (head as f32 + 1.0) * 100.0 + pos as f32 * 10.0 + d as f32)
                    .collect();
                let value: Vec<f32> = (0..head_dim)
                    .map(|d| (head as f32 + 1.0) * 1000.0 + pos as f32 * 10.0 + d as f32)
                    .collect();
                cache.store_key(1, head, pos, &key);
                cache.store_value(1, head, pos, &value);
            }
        }
        let (k_block, v_block) = cache.extract_block(1, 0, block_size);
        let per_layer = num_kv_heads * block_size * head_dim;
        assert_eq!(k_block.len(), per_layer);
        assert_eq!(v_block.len(), per_layer);
        let mut fresh = KvCache::new(num_layers, num_kv_heads, head_dim, max_seq);
        fresh.inject_block(1, 0, block_size, &k_block, &v_block);
        fresh.set_seq_len(block_size);
        let (k_block_2, v_block_2) = fresh.extract_block(1, 0, block_size);
        assert_eq!(k_block_2, k_block);
        assert_eq!(v_block_2, v_block);
        for head in 0..num_kv_heads {
            let original_keys = cache.keys_for(1, head, block_size);
            let restored_keys = fresh.keys_for(1, head, block_size);
            assert_eq!(
                original_keys, restored_keys,
                "head {head} keys must round-trip"
            );
            let original_values = cache.values_for(1, head, block_size);
            let restored_values = fresh.values_for(1, head, block_size);
            assert_eq!(
                original_values, restored_values,
                "head {head} values must round-trip"
            );
        }
    }
    #[test]
    fn kv_cache_extract_inject_at_offset() {
        let mut cache = KvCache::new(1, 1, 2, 16);
        for pos in 0..4 {
            let key = vec![pos as f32, pos as f32 + 0.5];
            let value = vec![-(pos as f32), -(pos as f32) - 0.5];
            cache.store_key(0, 0, 4 + pos, &key);
            cache.store_value(0, 0, 4 + pos, &value);
        }
        let (k, v) = cache.extract_block(0, 4, 4);
        let mut other = KvCache::new(1, 1, 2, 16);
        other.inject_block(0, 4, 4, &k, &v);
        // Read both caches once; the block occupies floats 8..16 (positions 4..8).
        let original_k = cache.keys_for(0, 0, 8);
        let restored_k = other.keys_for(0, 0, 8);
        let original_v = cache.values_for(0, 0, 8);
        let restored_v = other.values_for(0, 0, 8);
        for pos in 0..4 {
            let off = (4 + pos) * 2;
            for d in 0..2 {
                assert!((restored_k[off + d] - original_k[off + d]).abs() < 1e-6);
                assert!((restored_v[off + d] - original_v[off + d]).abs() < 1e-6);
            }
        }
        // Positions before the injected block must be untouched in `other`.
        assert!(restored_k[..8].iter().all(|&x| x == 0.0));
        assert!(restored_v[..8].iter().all(|&x| x == 0.0));
    }
    #[test]
    fn paged_kv_cache_store_and_retrieve() {
        let mut cache = PagedKvCache::with_page_size(2, 1, 4, 16, 4);
        let key = vec![1.0, 2.0, 3.0, 4.0];
        let value = vec![5.0, 6.0, 7.0, 8.0];
        cache.store_key(0, 0, 0, &key);
        cache.store_value(0, 0, 0, &value);
        cache.advance();
        let keys = cache.keys_for(0, 0, 1);
        let values = cache.values_for(0, 0, 1);
        assert_eq!(keys.len(), 4);
        assert_eq!(values.len(), 4);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((values[0] - 5.0).abs() < 1e-5);
    }
    #[test]
    fn paged_kv_cache_cross_page_boundary() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 16, 2);
        cache.store_key(0, 0, 0, &[1.0, 2.0, 3.0, 4.0]);
        cache.store_key(0, 0, 1, &[5.0, 6.0, 7.0, 8.0]);
        cache.store_key(0, 0, 2, &[9.0, 10.0, 11.0, 12.0]);
        let keys = cache.keys_for(0, 0, 3);
        assert_eq!(keys.len(), 12);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((keys[4] - 5.0).abs() < 1e-5);
        assert!((keys[8] - 9.0).abs() < 1e-5);
    }
    #[test]
    fn paged_kv_cache_values_cross_page_boundary() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 2, 16, 2);
        cache.store_value(0, 0, 0, &[1.0, 2.0]);
        cache.store_value(0, 0, 2, &[3.0, 4.0]);
        // Positions 4 and 5 fall in a page that was never allocated.
        let values = cache.values_for(0, 0, 6);
        assert_eq!(
            values,
            vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        );
        assert!(cache.keys_for(0, 0, 6).iter().all(|&x| x == 0.0));
    }
    #[test]
    fn paged_kv_cache_lazy_allocation() {
        let cache = PagedKvCache::with_page_size(1, 1, 4, 1024, 256);
        assert_eq!(cache.total_pages(), 0);
        assert_eq!(cache.memory_usage_bytes(), 0);
    }
    #[test]
    fn paged_kv_cache_memory_grows() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 1024, 4);
        assert_eq!(cache.memory_usage_bytes(), 0);
        cache.store_key(0, 0, 0, &[1.0; 4]);
        let one_page_bytes = 4 * 4 * 4 * 2;
        assert_eq!(cache.memory_usage_bytes(), one_page_bytes);
        cache.store_key(0, 0, 4, &[1.0; 4]);
        assert_eq!(cache.memory_usage_bytes(), one_page_bytes * 2);
    }
    #[test]
    fn paged_kv_cache_clear() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 16, 4);
        cache.store_key(0, 0, 0, &[1.0; 4]);
        cache.advance();
        assert!(cache.total_pages() > 0);
        cache.clear();
        assert_eq!(cache.total_pages(), 0);
        assert_eq!(cache.seq_len(), 0);
    }
    #[test]
    fn paged_kv_cache_utilization() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 16, 4);
        assert!((cache.utilization_ratio() - 0.0).abs() < 1e-10);
        cache.store_key(0, 0, 0, &[1.0; 4]);
        assert!((cache.utilization_ratio() - 0.25).abs() < 1e-10);
    }
}