#![allow(clippy::must_use_candidate)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::missing_errors_doc)]
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use thiserror::Error;
/// Errors produced by the paged KV-cache.
///
/// Display text for each variant is generated by `thiserror` from its `#[error]` string.
///
/// NOTE(review): several variants appear to overlap (`OutOfMemory` vs `OutOfPages`,
/// `SequenceNotFound(u64)` vs `UnknownSequence(SeqId)`, `PageTableCorruption { seq_id: u64 }`
/// using a raw id instead of `SeqId`). Which ones are actually returned is decided in the
/// `include!`d implementation files, not visible here — confirm before consolidating.
#[derive(Debug, Error)]
pub enum PagedCacheError {
/// Not enough pages to satisfy a request (`needed` pages requested, `available` free).
#[error("Out of memory: need {needed} pages, have {available}")]
OutOfMemory {
needed: usize,
available: usize,
},
/// No sequence with the given raw id is known.
#[error("Sequence not found: {0}")]
SequenceNotFound(u64),
/// Access to `page_id` at `offset` was rejected as invalid.
#[error("Invalid page access: page {page_id} at offset {offset}")]
InvalidPageAccess {
page_id: u64,
offset: usize,
},
/// The page table of sequence `seq_id` was found to be inconsistent.
#[error("Page table corruption for sequence {seq_id}")]
PageTableCorruption {
seq_id: u64,
},
/// No sequence with the given typed id is known.
#[error("Unknown sequence: {0:?}")]
UnknownSequence(SeqId),
/// `position` is beyond the `allocated` capacity of sequence `seq_id`.
#[error("Out of bounds: seq {seq_id:?} pos {position}, allocated {allocated}")]
OutOfBounds {
seq_id: SeqId,
position: usize,
allocated: usize,
},
/// A request for `requested` pages exceeded the `available` free pages.
#[error("Out of pages: need {requested}, have {available}")]
OutOfPages {
requested: usize,
available: usize,
},
}
/// Process-unique identifier for a cached sequence.
///
/// Fresh values come from a single process-wide atomic counter starting at 0, so each call
/// to [`SeqId::new`] (and therefore [`SeqId::default`]) returns a value not handed out
/// before in this process (the `u64` counter wraps only after 2^64 allocations).
///
/// NOTE(review): because `Deserialize` is derived, a deserialized `SeqId` can carry any
/// `u64` and bypasses the counter, so it may collide with a freshly allocated id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SeqId(u64);

impl SeqId {
    /// Allocates the next identifier from the process-wide counter.
    ///
    /// `Relaxed` ordering suffices: the atomic `fetch_add` alone guarantees uniqueness,
    /// and no other memory is synchronised through this counter.
    pub fn new() -> Self {
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        SeqId(id)
    }

    /// Returns the raw numeric identifier.
    pub fn value(&self) -> u64 {
        let SeqId(raw) = *self;
        raw
    }
}

impl Default for SeqId {
    /// Same as [`SeqId::new`]: every default value is a freshly allocated identifier.
    fn default() -> Self {
        SeqId::new()
    }
}
/// Identifier of a [`KvPage`].
///
/// Unlike [`SeqId`], the numeric value is supplied by the caller rather than generated here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PageId(u64);

impl PageId {
    /// Wraps the raw page number `id`.
    pub fn new(id: u64) -> Self {
        PageId(id)
    }

    /// Returns the raw page number.
    pub fn value(&self) -> u64 {
        let PageId(raw) = *self;
        raw
    }
}
/// One fixed-capacity page of key/value storage.
///
/// The page does not remember its own `block_size`. `is_full` and `remaining_capacity` take
/// it as an argument; presumably this is the same value that was passed to [`KvPage::new`]
/// (TODO confirm at call sites).
#[derive(Debug, Clone)]
pub struct KvPage {
    /// Identifier of this page.
    pub id: PageId,
    /// Key storage: `block_size * num_heads * head_dim` values, zero-initialised.
    /// NOTE(review): element layout (e.g. token-major) is decided by writers elsewhere.
    pub keys: Vec<f32>,
    /// Value storage, same length and layout as `keys`.
    pub values: Vec<f32>,
    /// Number of token slots already filled (starts at 0).
    pub num_tokens: usize,
    /// Number of holders of this page; starts at 1, and `> 1` means the page is shared.
    pub ref_count: usize,
}

impl KvPage {
    /// Creates an empty page with zeroed key/value buffers and a reference count of 1.
    ///
    /// # Panics
    ///
    /// Panics if `block_size * num_heads * head_dim` overflows `usize`. The unchecked
    /// product would silently wrap in release builds and yield an undersized buffer.
    pub fn new(id: PageId, block_size: usize, num_heads: usize, head_dim: usize) -> Self {
        let page_size = block_size
            .checked_mul(num_heads)
            .and_then(|n| n.checked_mul(head_dim))
            .expect("KvPage size overflows usize (block_size * num_heads * head_dim)");
        Self {
            id,
            keys: vec![0.0; page_size],
            values: vec![0.0; page_size],
            num_tokens: 0,
            ref_count: 1,
        }
    }

    /// Returns `true` once at least `block_size` token slots are filled.
    pub fn is_full(&self, block_size: usize) -> bool {
        self.num_tokens >= block_size
    }

    /// Returns `true` if more than one holder references this page.
    pub fn is_shared(&self) -> bool {
        self.ref_count > 1
    }

    /// Number of free token slots left out of `block_size`, clamped at 0.
    pub fn remaining_capacity(&self, block_size: usize) -> usize {
        block_size.saturating_sub(self.num_tokens)
    }
}
/// Paged key/value cache: a pool of physical [`KvPage`]s plus a per-sequence page table.
///
/// Only the state is declared here; the methods live in the `include!`d files at the end
/// of this module.
pub struct PagedKvCache {
// Backing storage for all pages.
// NOTE(review): presumably indexed by `PageId::value()` — confirm in the included files.
physical_pages: Vec<KvPage>,
// Page list of each sequence; presumably in token order — TODO confirm.
page_tables: HashMap<SeqId, Vec<PageId>>,
// Pages not currently assigned (by name; maintained by the included implementation).
free_pages: VecDeque<PageId>,
// Token slots per page; plays the role of `block_size` in `KvPage::new` / `KvPage::is_full`.
block_size: usize,
// Per-page geometry forwarded to `KvPage::new`.
num_heads: usize,
head_dim: usize,
// Size of the page pool.
total_pages: usize,
// Running counters exposed as `PagedCacheStats`.
stats: PagedCacheStats,
}
/// Counters describing cache activity; all start at 0 via `Default`.
///
/// NOTE(review): the counters are updated in the `include!`d implementation files. The
/// meanings below are inferred from the field names — confirm there before relying on them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PagedCacheStats {
/// Sequences ever allocated (cumulative).
pub sequences_allocated: u64,
/// Sequences ever freed (cumulative).
pub sequences_freed: u64,
/// Pages ever handed out (cumulative).
pub pages_allocated: u64,
/// Pages ever returned (cumulative).
pub pages_freed: u64,
/// Sequences currently live (gauge).
pub active_sequences: u64,
/// Pages currently in use (gauge).
pub used_pages: u64,
/// Copy-on-write page copies performed.
pub cow_operations: u64,
/// Defragmentation passes run.
pub defrag_operations: u64,
/// Pages relocated by defragmentation.
pub pages_moved: u64,
}
/// Snapshot describing how fragmented the page pool is.
///
/// NOTE(review): produced by code in the `include!`d files; the meanings and units below
/// are inferred from the field names — confirm there.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FragmentationStats {
/// Number of free gaps between used pages.
pub holes: usize,
/// Unused token slots in partially filled pages.
pub wasted_capacity: usize,
/// Fraction of capacity lost to fragmentation; presumably in `[0, 1]` — TODO confirm.
pub fragmentation_ratio: f32,
/// Size of the largest contiguous run of free pages.
pub largest_free_region: usize,
/// Mean `num_tokens` across pages in use.
pub avg_tokens_per_page: f32,
}
// The rest of this module's implementation is split across sibling files. `include!`
// pastes each file's items verbatim into this module, so they share the `use` imports
// above and can read private fields such as `PagedKvCache::physical_pages`.
include!("contiguous.rs");
include!("mod_compute_prefix.rs");
include!("mod_quantized.rs");
include!("mod_quantized_paged.rs");