use std::collections::{HashMap, HashSet};

use crate::error::DnnError;
/// Page allocator that hands out indices from a fixed pool of pages to sequences.
///
/// Each sequence, keyed by a `u64` id, owns an ordered list of page indices. Pages not
/// owned by any sequence sit in the free list.
#[derive(Clone, Debug)]
pub struct KvManager {
    /// Number of pages in the pool; valid page indices are `0..total_pages`.
    total_pages: usize,
    /// Tokens per page (token capacity is `page count * page_size`).
    page_size: usize,
    /// Stored head dimension; also used to size the error reported when the pool is empty.
    head_dim: usize,
    /// Indices of unowned pages; the next allocation pops from the end.
    free_pages: Vec<usize>,
    /// Page indices owned by each sequence, in allocation order.
    page_tables: HashMap<u64, Vec<usize>>,
}
impl KvManager {
#[must_use]
pub fn new(total_pages: usize, page_size: usize, head_dim: usize) -> Self {
let free_pages: Vec<usize> = (0..total_pages).rev().collect();
Self {
total_pages,
page_size,
head_dim,
free_pages,
page_tables: HashMap::new(),
}
}
pub fn allocate_page(&mut self, seq_id: u64) -> Result<usize, DnnError> {
let page = self.free_pages.pop().ok_or_else(|| {
DnnError::WorkspaceRequired(self.page_size * self.head_dim * std::mem::size_of::<f32>())
})?;
self.page_tables.entry(seq_id).or_default().push(page);
Ok(page)
}
pub fn free_sequence(&mut self, seq_id: u64) {
if let Some(pages) = self.page_tables.remove(&seq_id) {
self.free_pages.extend(pages);
}
}
#[must_use]
pub fn page_count(&self, seq_id: u64) -> usize {
self.page_tables.get(&seq_id).map_or(0, |v| v.len())
}
#[must_use]
pub fn token_capacity(&self, seq_id: u64) -> usize {
self.page_count(seq_id) * self.page_size
}
#[must_use]
pub fn free_page_count(&self) -> usize {
self.free_pages.len()
}
#[must_use]
pub fn total_pages(&self) -> usize {
self.total_pages
}
#[must_use]
pub fn page_size(&self) -> usize {
self.page_size
}
#[must_use]
pub fn head_dim(&self) -> usize {
self.head_dim
}
#[must_use]
pub fn checkpoint(&self, seq_id: u64) -> Option<KvCheckpoint> {
let pages = self.page_tables.get(&seq_id)?;
Some(KvCheckpoint {
seq_id,
page_snapshot: pages.clone(),
token_count: pages.len() * self.page_size,
})
}
pub fn restore(&mut self, checkpoint: &KvCheckpoint) -> Result<(), DnnError> {
for &p in &checkpoint.page_snapshot {
if p >= self.total_pages {
return Err(DnnError::InvalidArgument(format!(
"checkpoint contains page {} which is out of pool bounds ({})",
p, self.total_pages
)));
}
}
let seq_id = checkpoint.seq_id;
let current: Vec<usize> = self.page_tables.get(&seq_id).cloned().unwrap_or_default();
let snap_len = checkpoint.page_snapshot.len();
if current.len() > snap_len {
let excess = ¤t[snap_len..];
self.free_pages.extend_from_slice(excess);
}
if checkpoint.page_snapshot.is_empty() {
self.page_tables.remove(&seq_id);
} else {
self.page_tables
.insert(seq_id, checkpoint.page_snapshot.clone());
}
Ok(())
}
}
/// Snapshot of one sequence's page table, as produced by `KvManager::checkpoint`.
#[derive(Clone, Debug)]
pub struct KvCheckpoint {
    /// Sequence the snapshot belongs to.
    pub seq_id: u64,
    /// The sequence's page indices, in page-table order, at checkpoint time.
    pub page_snapshot: Vec<usize>,
    /// Token capacity at checkpoint time (`page_snapshot.len() * page_size` when built by
    /// `KvManager::checkpoint`); not a count of tokens actually written.
    pub token_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pool of `pages` pages with 16 tokens per page and head dimension 64.
    fn manager_with(pages: usize) -> KvManager {
        KvManager::new(pages, 16, 64)
    }

    #[test]
    fn test_kvmanager_allocates_pages() {
        let mut mgr = manager_with(4);
        let first = mgr.allocate_page(1).expect("allocate page 0");
        let second = mgr.allocate_page(1).expect("allocate page 1");
        assert_ne!(first, second, "each allocated page must be distinct");
        assert_eq!(mgr.page_count(1), 2);
        assert_eq!(mgr.free_page_count(), 2);
    }

    #[test]
    fn test_kvmanager_free_sequence_returns_pages() {
        let mut mgr = manager_with(4);
        for _ in 0..2 {
            mgr.allocate_page(7).expect("alloc");
        }
        assert_eq!(mgr.free_page_count(), 2);

        mgr.free_sequence(7);
        assert_eq!(mgr.free_page_count(), 4);
        assert_eq!(mgr.page_count(7), 0);
    }

    #[test]
    fn test_kvmanager_out_of_pages_error() {
        let mut mgr = manager_with(2);
        mgr.allocate_page(1).expect("page 0");
        mgr.allocate_page(1).expect("page 1");

        let exhausted = mgr.allocate_page(1);
        assert!(matches!(exhausted, Err(DnnError::WorkspaceRequired(_))));
    }

    #[test]
    fn test_kvmanager_token_capacity() {
        let mut mgr = manager_with(8);
        for label in ["p0", "p1"] {
            mgr.allocate_page(3).expect(label);
        }
        assert_eq!(mgr.token_capacity(3), 32);
    }

    #[test]
    fn test_kvmanager_free_unknown_seq_is_noop() {
        let mut mgr = manager_with(4);
        mgr.free_sequence(99);
        assert_eq!(mgr.free_page_count(), 4);
    }

    #[test]
    fn test_checkpoint_restore_roundtrip() {
        let mut mgr = manager_with(8);
        mgr.allocate_page(5).expect("p0");

        let cp = mgr.checkpoint(5).expect("checkpoint should exist");
        assert_eq!(cp.seq_id, 5);
        assert_eq!(cp.page_snapshot.len(), 1);

        for label in ["p1", "p2"] {
            mgr.allocate_page(5).expect(label);
        }
        assert_eq!(mgr.page_count(5), 3);

        mgr.restore(&cp).expect("restore");
        assert_eq!(mgr.page_count(5), 1);
        assert_eq!(mgr.free_page_count(), 7);
    }

    #[test]
    fn test_restore_frees_pages_allocated_after_checkpoint() {
        let mut mgr = manager_with(4);
        for label in ["pre-cp page 0", "pre-cp page 1"] {
            mgr.allocate_page(10).expect(label);
        }
        let cp = mgr.checkpoint(10).expect("cp");

        for label in ["post-cp page 0", "post-cp page 1"] {
            mgr.allocate_page(10).expect(label);
        }
        assert_eq!(mgr.free_page_count(), 0);

        mgr.restore(&cp).expect("restore");
        assert_eq!(mgr.free_page_count(), 2);
        assert_eq!(mgr.page_count(10), 2);
    }

    #[test]
    fn test_checkpoint_none_for_unknown_seq() {
        let mgr = manager_with(4);
        assert!(mgr.checkpoint(999).is_none());
    }
}