use std::collections::{HashMap, VecDeque};
use super::{TransformerError, TransformerResult};
/// Unique identifier for a generation sequence tracked by the cache.
pub type SequenceId = u64;
/// Index of a physical cache block in the pool (`0..num_blocks`).
pub type BlockId = u32;
/// Policy used to pick a victim block when the free list is exhausted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheEvictionPolicy {
/// Evict the block with the smallest `last_access` clock value.
Lru,
/// Evict the block with the smallest allocation order (oldest allocated).
Fifo,
/// Evict the block with the fewest recorded accesses.
Frequency,
}
/// Sizing and policy parameters used to construct a `PagedKvCache`.
#[derive(Debug, Clone)]
pub struct PagedKvCacheConfig {
/// Tokens stored per physical block; must be > 0.
pub block_size: usize,
/// Total physical blocks in the pool; must be > 0.
pub num_blocks: usize,
/// Transformer geometry — used only for memory accounting; must be > 0.
pub num_layers: usize,
/// Attention heads per layer; must be > 0.
pub num_heads: usize,
/// Dimension per head; must be > 0.
pub head_dim: usize,
/// Victim-selection policy applied when the free list is empty.
pub eviction_policy: CacheEvictionPolicy,
}
/// Monotonic activity counters; derived rates live on the impl.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
/// Blocks handed out by `allocate_block`.
pub total_allocations: u64,
/// Blocks released via `free_sequence` (counted per block freed).
pub total_deallocations: u64,
/// Successful prefix lookups plus prefix shares.
pub cache_hits: u64,
/// Failed prefix lookups.
pub cache_misses: u64,
/// Blocks forcibly reclaimed by the eviction policy.
pub evictions: u64,
/// Copy-on-write block duplications performed.
pub cow_copies: u64,
}
/// Derived metrics over the raw counters.
impl CacheStats {
    /// Fraction of prefix lookups that were hits; 0.0 when nothing was looked up.
    pub fn hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }

    /// Fraction of `total_blocks` currently in use; 0.0 for an empty pool.
    pub fn utilization(&self, total_blocks: usize, free_blocks: usize) -> f64 {
        if total_blocks == 0 {
            0.0
        } else {
            total_blocks.saturating_sub(free_blocks) as f64 / total_blocks as f64
        }
    }
}
/// Per-block bookkeeping used for sharing and eviction decisions.
#[derive(Debug, Clone)]
struct BlockMeta {
/// Number of sequences / registered prefixes referencing this block.
ref_count: u32,
/// Tokens currently stored in the block (at most `block_size`).
num_tokens: usize,
/// Logical clock value of the most recent access (LRU policy).
last_access: u64,
/// Total recorded accesses (Frequency policy).
access_count: u64,
/// Monotonic allocation sequence number (FIFO policy).
alloc_order: u64,
}
/// Paged KV-cache block manager: maps sequences to fixed-size blocks with
/// ref-counted sharing, copy-on-write, prefix reuse, and policy-driven
/// eviction. Tracks bookkeeping only — no tensor data is stored here.
#[derive(Debug)]
pub struct PagedKvCache {
/// Tokens per block.
block_size: usize,
/// Total physical blocks in the pool.
num_blocks: usize,
/// Per-sequence ordered list of block ids.
block_table: HashMap<SequenceId, Vec<BlockId>>,
/// Blocks available for allocation (popped from the front).
free_blocks: VecDeque<BlockId>,
/// Metadata for every block currently in use.
block_meta: HashMap<BlockId, BlockMeta>,
num_layers: usize,
num_heads: usize,
head_dim: usize,
/// Victim-selection policy when the free list is empty.
eviction_policy: CacheEvictionPolicy,
/// Logical clock, bumped on allocations and accesses (drives LRU).
clock: u64,
/// Monotonic allocation counter (drives FIFO).
alloc_counter: u64,
/// Pinned prefix registrations keyed by caller-supplied hash.
prefix_table: HashMap<u64, Vec<BlockId>>,
/// Activity counters exposed via `stats()`.
stats: CacheStats,
}
impl PagedKvCache {
pub fn new(config: PagedKvCacheConfig) -> TransformerResult<Self> {
if config.block_size == 0 {
return Err(TransformerError::CacheError(
"block_size must be > 0".to_string(),
));
}
if config.num_blocks == 0 {
return Err(TransformerError::CacheError(
"num_blocks must be > 0".to_string(),
));
}
if config.num_layers == 0 || config.num_heads == 0 || config.head_dim == 0 {
return Err(TransformerError::CacheError(
"num_layers, num_heads, head_dim must all be > 0".to_string(),
));
}
let free_blocks: VecDeque<BlockId> = (0..config.num_blocks as BlockId).collect();
Ok(Self {
block_size: config.block_size,
num_blocks: config.num_blocks,
block_table: HashMap::new(),
free_blocks,
block_meta: HashMap::new(),
num_layers: config.num_layers,
num_heads: config.num_heads,
head_dim: config.head_dim,
eviction_policy: config.eviction_policy,
clock: 0,
alloc_counter: 0,
prefix_table: HashMap::new(),
stats: CacheStats::default(),
})
}
/// Acquires one physical block for `seq_id`, evicting a victim when the
/// free list is empty, and appends it to the sequence's block table.
///
/// # Errors
/// Fails only when eviction cannot find a reclaimable block.
pub fn allocate_block(&mut self, seq_id: SequenceId) -> TransformerResult<BlockId> {
    let block_id = match self.free_blocks.pop_front() {
        Some(id) => id,
        None => self.evict_block()?,
    };
    self.clock += 1;
    self.alloc_counter += 1;
    let meta = BlockMeta {
        ref_count: 1,
        num_tokens: 0,
        last_access: self.clock,
        access_count: 1,
        alloc_order: self.alloc_counter,
    };
    self.block_meta.insert(block_id, meta);
    self.block_table.entry(seq_id).or_default().push(block_id);
    self.stats.total_allocations += 1;
    Ok(block_id)
}
/// Allocates `ceil(num_tokens / block_size)` blocks for `seq_id` and sets
/// each block's token count (all full except possibly the last).
///
/// On any allocation failure the blocks acquired so far are rolled back:
/// freed and stripped from the tail of the sequence's block table, so a
/// partial allocation is never observable.
pub fn allocate_blocks_for_tokens(
&mut self,
seq_id: SequenceId,
num_tokens: usize,
) -> TransformerResult<Vec<BlockId>> {
let num_blocks_needed = num_tokens.div_ceil(self.block_size);
let mut allocated = Vec::with_capacity(num_blocks_needed);
for _ in 0..num_blocks_needed {
match self.allocate_block(seq_id) {
Ok(block_id) => allocated.push(block_id),
Err(e) => {
// Roll back: release every block acquired in this call...
for &bid in &allocated {
self.free_block_internal(bid);
}
// ...and drop them from the table. allocate_block appended
// them, so they are exactly the trailing `allocated.len()` entries.
if let Some(table) = self.block_table.get_mut(&seq_id) {
let len = table.len();
table.truncate(len.saturating_sub(allocated.len()));
if table.is_empty() {
self.block_table.remove(&seq_id);
}
}
return Err(e);
}
}
}
// The last block holds the remainder; a clean multiple means it is full.
if let Some(&last_block) = allocated.last() {
let remainder = num_tokens % self.block_size;
let tokens_in_last = if remainder == 0 {
self.block_size
} else {
remainder
};
if let Some(meta) = self.block_meta.get_mut(&last_block) {
meta.num_tokens = tokens_in_last;
}
}
// Every block before the last is completely filled.
for &bid in allocated.iter().take(allocated.len().saturating_sub(1)) {
if let Some(meta) = self.block_meta.get_mut(&bid) {
meta.num_tokens = self.block_size;
}
}
Ok(allocated)
}
/// Records one more token for `seq_id` and returns the block it lands in.
///
/// Opens a fresh block when the sequence has none or its last block is
/// full; otherwise increments the last block's token count. A shared last
/// block (`ref_count > 1`) is first detached via copy-on-write so the
/// write does not corrupt other sequences sharing it.
pub fn append_token(&mut self, seq_id: SequenceId) -> TransformerResult<BlockId> {
// Read-only pass: decide whether a new block is required.
let needs_new_block = match self.block_table.get(&seq_id) {
None => true,
Some(blocks) => {
if blocks.is_empty() {
true
} else {
let last_block = blocks[blocks.len() - 1];
match self.block_meta.get(&last_block) {
Some(meta) => meta.num_tokens >= self.block_size,
// Missing metadata is treated as "unusable" — start a new block.
None => true,
}
}
}
};
if needs_new_block {
// NOTE(review): this copy-on-writes a FULL shared last block even
// though the new token goes into a fresh block and the shared block
// is not modified — presumably deliberate ownership semantics, but
// it costs an extra block and breaks prefix sharing; confirm intent.
if let Some(blocks) = self.block_table.get(&seq_id) {
if let Some(&last_block) = blocks.last() {
if let Some(meta) = self.block_meta.get(&last_block) {
if meta.ref_count > 1 {
self.copy_on_write(seq_id, blocks.len() - 1)?;
}
}
}
}
let block_id = self.allocate_block(seq_id)?;
if let Some(meta) = self.block_meta.get_mut(&block_id) {
meta.num_tokens = 1;
}
Ok(block_id)
} else {
let blocks = self
.block_table
.get(&seq_id)
.ok_or_else(|| TransformerError::CacheError("sequence not found".to_string()))?;
let last_block = blocks[blocks.len() - 1];
if let Some(meta) = self.block_meta.get(&last_block) {
if meta.ref_count > 1 {
// Shared partially-filled block: detach a private copy, then
// re-read the table since copy_on_write rewrote the last entry.
let block_idx = blocks.len() - 1;
self.copy_on_write(seq_id, block_idx)?;
let blocks = self.block_table.get(&seq_id).ok_or_else(|| {
TransformerError::CacheError("sequence not found".to_string())
})?;
let new_last = blocks[blocks.len() - 1];
if let Some(meta) = self.block_meta.get_mut(&new_last) {
meta.num_tokens += 1;
self.clock += 1;
meta.last_access = self.clock;
meta.access_count += 1;
}
return Ok(new_last);
}
}
// Exclusive partially-filled block: bump counts in place.
if let Some(meta) = self.block_meta.get_mut(&last_block) {
meta.num_tokens += 1;
self.clock += 1;
meta.last_access = self.clock;
meta.access_count += 1;
}
Ok(last_block)
}
}
/// Removes `seq_id` from the table and releases each of its blocks
/// (shared blocks only lose one reference; exclusive ones return to the
/// free list). Each released block counts as one deallocation.
///
/// # Errors
/// Fails when the sequence is unknown.
pub fn free_sequence(&mut self, seq_id: SequenceId) -> TransformerResult<()> {
    let Some(blocks) = self.block_table.remove(&seq_id) else {
        return Err(TransformerError::CacheError(format!(
            "sequence {seq_id} not found"
        )));
    };
    for block_id in blocks {
        self.free_block_internal(block_id);
        self.stats.total_deallocations += 1;
    }
    Ok(())
}
/// Makes `dst_seq` reference the first `num_prefix_blocks` blocks of
/// `src_seq` without copying, bumping each shared block's ref count and
/// recording one cache hit.
///
/// # Errors
/// Fails when the source sequence is unknown or shorter than requested.
pub fn share_prefix(
    &mut self,
    src_seq: SequenceId,
    dst_seq: SequenceId,
    num_prefix_blocks: usize,
) -> TransformerResult<()> {
    let src_blocks = self
        .block_table
        .get(&src_seq)
        .ok_or_else(|| {
            TransformerError::CacheError(format!("source sequence {src_seq} not found"))
        })?
        .clone();
    if num_prefix_blocks > src_blocks.len() {
        return Err(TransformerError::CacheError(format!(
            "requested {} prefix blocks but source has only {}",
            num_prefix_blocks,
            src_blocks.len()
        )));
    }
    let shared = &src_blocks[..num_prefix_blocks];
    for block_id in shared {
        if let Some(meta) = self.block_meta.get_mut(block_id) {
            meta.ref_count += 1;
        }
    }
    self.block_table
        .entry(dst_seq)
        .or_default()
        .extend_from_slice(shared);
    self.stats.cache_hits += 1;
    Ok(())
}
/// Pins `block_ids` under `prefix_hash` so they survive sequence frees.
///
/// Each listed block's ref count is bumped. If the hash was already
/// registered, the previously pinned blocks are released — the original
/// overwrote the table entry without decrementing their ref counts, so
/// re-registering a hash permanently leaked the old blocks.
pub fn register_prefix(&mut self, prefix_hash: u64, block_ids: Vec<BlockId>) {
    // Bump the new entry first so blocks present in BOTH the old and new
    // registrations never momentarily drop to ref_count == 0.
    for &bid in &block_ids {
        if let Some(meta) = self.block_meta.get_mut(&bid) {
            meta.ref_count += 1;
        }
    }
    // Release the displaced registration, if any.
    if let Some(old_blocks) = self.prefix_table.insert(prefix_hash, block_ids) {
        for bid in old_blocks {
            self.free_block_internal(bid);
        }
    }
}
/// Looks up a registered prefix, counting a hit or a miss in the stats.
pub fn lookup_prefix(&mut self, prefix_hash: u64) -> Option<&[BlockId]> {
    // Probe with contains_key first: the stats mutation must happen before
    // the returned borrow of `prefix_table` begins.
    let hit = self.prefix_table.contains_key(&prefix_hash);
    if hit {
        self.stats.cache_hits += 1;
    } else {
        self.stats.cache_misses += 1;
    }
    self.prefix_table.get(&prefix_hash).map(Vec::as_slice)
}
/// Returns the ordered block ids backing `seq_id`, if the sequence exists.
pub fn get_block_table(&self, seq_id: SequenceId) -> Option<&[BlockId]> {
    self.block_table.get(&seq_id).map(Vec::as_slice)
}
/// Total tokens cached for `seq_id`: every block but the last is assumed
/// full; the last contributes its recorded token count. Unknown or empty
/// sequences yield 0.
pub fn num_cached_tokens(&self, seq_id: SequenceId) -> usize {
    let Some(blocks) = self.block_table.get(&seq_id) else {
        return 0;
    };
    let Some(last) = blocks.last() else {
        return 0;
    };
    let last_tokens = self.block_meta.get(last).map_or(0, |m| m.num_tokens);
    (blocks.len() - 1) * self.block_size + last_tokens
}
/// Blocks currently available on the free list.
pub fn num_free_blocks(&self) -> usize {
self.free_blocks.len()
}
/// Total physical blocks in the pool.
pub fn total_blocks(&self) -> usize {
self.num_blocks
}
/// Tokens stored per block.
pub fn block_size(&self) -> usize {
self.block_size
}
/// Configured number of transformer layers.
pub fn num_layers(&self) -> usize {
self.num_layers
}
/// Configured number of attention heads per layer.
pub fn num_heads(&self) -> usize {
self.num_heads
}
/// Configured dimension per head.
pub fn head_dim(&self) -> usize {
self.head_dim
}
/// Activity counters accumulated so far.
pub fn stats(&self) -> &CacheStats {
&self.stats
}
/// Fraction of in-use blocks that are only partially filled; 0.0 when no
/// blocks are in use. Blocks shared by several sequences are counted once
/// per table entry, matching how they inflate apparent usage.
pub fn fragmentation(&self) -> f64 {
    let mut used = 0usize;
    let mut partial = 0usize;
    let metas = self
        .block_table
        .values()
        .flatten()
        .filter_map(|bid| self.block_meta.get(bid));
    for meta in metas {
        used += 1;
        if meta.num_tokens < self.block_size {
            partial += 1;
        }
    }
    if used == 0 {
        0.0
    } else {
        partial as f64 / used as f64
    }
}
/// Approximate bytes consumed by the blocks currently in use.
pub fn memory_usage_bytes(&self) -> usize {
// Invariant: the free list never exceeds num_blocks, so no underflow.
let used_blocks = self.num_blocks - self.free_blocks.len();
// Leading 2: separate K and V tensors. Trailing 2: bytes per element —
// presumably a 16-bit dtype; TODO confirm against the actual tensor storage.
let bytes_per_block =
2 * self.num_layers * self.num_heads * self.block_size * self.head_dim * 2;
used_blocks * bytes_per_block
}
/// Drops one reference to `block_id`; when the count reaches zero the
/// block's metadata is discarded and it rejoins the free list. Unknown
/// blocks are ignored.
fn free_block_internal(&mut self, block_id: BlockId) {
    let Some(meta) = self.block_meta.get_mut(&block_id) else {
        return;
    };
    meta.ref_count = meta.ref_count.saturating_sub(1);
    if meta.ref_count == 0 {
        self.block_meta.remove(&block_id);
        self.free_blocks.push_back(block_id);
    }
}
/// Replaces the block at `block_idx` in `seq_id`'s table with a private
/// copy, dropping one reference from the original (copy-on-write).
///
/// # Errors
/// Fails when the sequence/index is unknown, the source block's metadata
/// is missing, or no block can be evicted to back the copy.
fn copy_on_write(
    &mut self,
    seq_id: SequenceId,
    block_idx: usize,
) -> TransformerResult<BlockId> {
    let old_block_id = {
        let blocks = self.block_table.get(&seq_id).ok_or_else(|| {
            TransformerError::CacheError(format!("sequence {seq_id} not found"))
        })?;
        if block_idx >= blocks.len() {
            return Err(TransformerError::CacheError(format!(
                "block index {block_idx} out of range"
            )));
        }
        blocks[block_idx]
    };
    // Snapshot the source metadata BEFORE acquiring the new block. The
    // original read it afterwards, so a missing-metadata error leaked the
    // freshly popped/evicted block: it was neither registered in
    // `block_meta` nor returned to the free list.
    let old_meta = self
        .block_meta
        .get(&old_block_id)
        .cloned()
        .ok_or_else(|| TransformerError::CacheError("block metadata missing".to_string()))?;
    let new_block_id = match self.free_blocks.pop_front() {
        Some(id) => id,
        None => self.evict_block()?,
    };
    self.clock += 1;
    self.alloc_counter += 1;
    self.block_meta.insert(
        new_block_id,
        BlockMeta {
            ref_count: 1,
            num_tokens: old_meta.num_tokens,
            last_access: self.clock,
            access_count: 1,
            alloc_order: self.alloc_counter,
        },
    );
    // Drop one reference from the source block.
    if let Some(meta) = self.block_meta.get_mut(&old_block_id) {
        meta.ref_count = meta.ref_count.saturating_sub(1);
        if meta.ref_count == 0 {
            self.block_meta.remove(&old_block_id);
            self.free_blocks.push_back(old_block_id);
        }
    }
    // Relocate by VALUE, not by the saved index: evict_block may have
    // removed earlier entries from this sequence's table, shifting
    // positions so that `block_idx` no longer points at the source block.
    // The original wrote `blocks[block_idx] = new_block_id` and could
    // clobber the wrong slot after an eviction.
    if let Some(blocks) = self.block_table.get_mut(&seq_id) {
        if let Some(slot) = blocks.iter_mut().find(|b| **b == old_block_id) {
            *slot = new_block_id;
        }
    }
    self.stats.cow_copies += 1;
    Ok(new_block_id)
}
/// Forcibly reclaims one block according to the configured policy.
/// Blocks with `ref_count > 1` are never chosen. The victim is unlinked
/// from every sequence and registered prefix before being handed back.
///
/// # Errors
/// Fails when every block is shared and nothing can be reclaimed.
fn evict_block(&mut self) -> TransformerResult<BlockId> {
    let victim = match self.eviction_policy {
        CacheEvictionPolicy::Lru => self.find_lru_victim(),
        CacheEvictionPolicy::Fifo => self.find_fifo_victim(),
        CacheEvictionPolicy::Frequency => self.find_frequency_victim(),
    };
    let block_id = victim.ok_or_else(|| {
        TransformerError::CacheError("no blocks available for eviction".to_string())
    })?;
    // Unlink the victim everywhere it appears before discarding its metadata.
    for blocks in self.block_table.values_mut() {
        blocks.retain(|&b| b != block_id);
    }
    for blocks in self.prefix_table.values_mut() {
        blocks.retain(|&b| b != block_id);
    }
    self.block_meta.remove(&block_id);
    self.stats.evictions += 1;
    Ok(block_id)
}
fn find_lru_victim(&self) -> Option<BlockId> {
self.block_meta
.iter()
.filter(|(_, meta)| meta.ref_count <= 1)
.min_by_key(|(_, meta)| meta.last_access)
.map(|(&id, _)| id)
}
fn find_fifo_victim(&self) -> Option<BlockId> {
self.block_meta
.iter()
.filter(|(_, meta)| meta.ref_count <= 1)
.min_by_key(|(_, meta)| meta.alloc_order)
.map(|(&id, _)| id)
}
fn find_frequency_victim(&self) -> Option<BlockId> {
self.block_meta
.iter()
.filter(|(_, meta)| meta.ref_count <= 1)
.min_by_key(|(_, meta)| meta.access_count)
.map(|(&id, _)| id)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 16-token blocks, LRU eviction, large-model geometry.
    fn test_config(num_blocks: usize) -> PagedKvCacheConfig {
        PagedKvCacheConfig {
            block_size: 16,
            num_blocks,
            num_layers: 32,
            num_heads: 32,
            head_dim: 128,
            eviction_policy: CacheEvictionPolicy::Lru,
        }
    }

    #[test]
    fn test_basic_allocation() {
        let mut cache = PagedKvCache::new(test_config(100)).unwrap();
        assert_eq!(cache.num_free_blocks(), 100);
        let block = cache.allocate_block(1).unwrap();
        assert_eq!(cache.num_free_blocks(), 99);
        let table = cache.get_block_table(1);
        assert!(table.is_some());
        let table = table.unwrap();
        assert_eq!(table.len(), 1);
        assert_eq!(table[0], block);
    }

    #[test]
    fn test_allocate_blocks_for_tokens() {
        let mut cache = PagedKvCache::new(test_config(100)).unwrap();
        // 50 tokens at 16 per block -> 4 blocks.
        let blocks = cache.allocate_blocks_for_tokens(1, 50).unwrap();
        assert_eq!(blocks.len(), 4);
        assert_eq!(cache.num_free_blocks(), 96);
    }

    #[test]
    fn test_free_sequence() {
        let mut cache = PagedKvCache::new(test_config(100)).unwrap();
        cache.allocate_blocks_for_tokens(1, 32).unwrap();
        assert_eq!(cache.num_free_blocks(), 98);
        cache.free_sequence(1).unwrap();
        assert_eq!(cache.num_free_blocks(), 100);
        assert!(cache.get_block_table(1).is_none());
    }

    #[test]
    fn test_free_unknown_sequence() {
        let mut cache = PagedKvCache::new(test_config(10)).unwrap();
        assert!(cache.free_sequence(999).is_err());
    }

    #[test]
    fn test_append_token() {
        let mut cache = PagedKvCache::new(test_config(10)).unwrap();
        let b1 = cache.append_token(1).unwrap();
        assert_eq!(cache.num_cached_tokens(1), 1);
        assert_eq!(cache.num_free_blocks(), 9);
        // Tokens 2..=16 stay in the same block.
        for _ in 1..16 {
            assert_eq!(cache.append_token(1).unwrap(), b1);
        }
        assert_eq!(cache.num_cached_tokens(1), 16);
        // Token 17 spills into a fresh block.
        let b2 = cache.append_token(1).unwrap();
        assert_ne!(b1, b2);
        assert_eq!(cache.num_cached_tokens(1), 17);
        assert_eq!(cache.num_free_blocks(), 8);
    }

    #[test]
    fn test_copy_on_write() {
        let mut cache = PagedKvCache::new(test_config(20)).unwrap();
        cache.allocate_blocks_for_tokens(1, 32).unwrap();
        assert_eq!(cache.num_free_blocks(), 18);
        // Sharing consumes no new blocks.
        cache.share_prefix(1, 2, 2).unwrap();
        assert_eq!(cache.num_free_blocks(), 18);
        // Appending to the sharer must detach or allocate.
        let _ = cache.append_token(2).unwrap();
        assert!(cache.stats().cow_copies > 0 || cache.num_free_blocks() < 18);
    }

    #[test]
    fn test_prefix_sharing() {
        let mut cache = PagedKvCache::new(test_config(20)).unwrap();
        cache.allocate_blocks_for_tokens(1, 32).unwrap();
        let blocks = cache.get_block_table(1).unwrap().to_vec();
        let prefix_hash = 0x1234u64;
        cache.register_prefix(prefix_hash, blocks[..2].to_vec());
        assert!(cache.lookup_prefix(prefix_hash).is_some());
        assert_eq!(cache.lookup_prefix(prefix_hash).unwrap().len(), 2);
        assert!(cache.lookup_prefix(0xDEAD).is_none());
    }

    #[test]
    fn test_eviction_lru() {
        let mut cache = PagedKvCache::new(test_config(4)).unwrap();
        // Fill the pool with one block per sequence.
        for seq in 1..=4 {
            cache.allocate_blocks_for_tokens(seq, 16).unwrap();
        }
        assert_eq!(cache.num_free_blocks(), 0);
        let _new_block = cache.allocate_block(5).unwrap();
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_eviction_fifo() {
        let config = PagedKvCacheConfig {
            eviction_policy: CacheEvictionPolicy::Fifo,
            ..test_config(3)
        };
        let mut cache = PagedKvCache::new(config).unwrap();
        cache.allocate_block(1).unwrap();
        cache.allocate_block(2).unwrap();
        cache.allocate_block(3).unwrap();
        assert_eq!(cache.num_free_blocks(), 0);
        let _ = cache.allocate_block(4).unwrap();
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_eviction_frequency() {
        let config = PagedKvCacheConfig {
            eviction_policy: CacheEvictionPolicy::Frequency,
            ..test_config(3)
        };
        let mut cache = PagedKvCache::new(config).unwrap();
        cache.allocate_block(1).unwrap();
        cache.allocate_block(2).unwrap();
        cache.allocate_block(3).unwrap();
        let _ = cache.allocate_block(4).unwrap();
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = PagedKvCache::new(test_config(20)).unwrap();
        assert_eq!(cache.stats().hit_rate(), 0.0);
        cache.allocate_blocks_for_tokens(1, 32).unwrap();
        let blocks = cache.get_block_table(1).unwrap().to_vec();
        assert_eq!(cache.stats().total_allocations, 2);
        cache.register_prefix(42, blocks[..1].to_vec());
        cache.lookup_prefix(42);
        cache.lookup_prefix(99);
        assert!(cache.stats().hit_rate() > 0.0);
    }

    #[test]
    fn test_memory_usage() {
        // No blocks in use right after construction.
        let cache = PagedKvCache::new(test_config(100)).unwrap();
        assert_eq!(cache.memory_usage_bytes(), 0);
    }

    #[test]
    fn test_fragmentation() {
        let mut cache = PagedKvCache::new(test_config(20)).unwrap();
        assert_eq!(cache.fragmentation(), 0.0);
        // 17 tokens -> one full block plus one partial block.
        cache.allocate_blocks_for_tokens(1, 17).unwrap();
        let frag = cache.fragmentation();
        assert!(frag > 0.0);
    }

    #[test]
    fn test_invalid_config() {
        assert!(
            PagedKvCache::new(PagedKvCacheConfig {
                block_size: 0,
                ..test_config(10)
            })
            .is_err()
        );
        assert!(
            PagedKvCache::new(PagedKvCacheConfig {
                num_blocks: 0,
                ..test_config(10)
            })
            .is_err()
        );
    }

    #[test]
    fn test_num_cached_tokens_no_sequence() {
        let cache = PagedKvCache::new(test_config(10)).unwrap();
        assert_eq!(cache.num_cached_tokens(999), 0);
    }
}