use std::collections::HashMap;
/// Block size, in token slots per block, that [`PagedKvCache::new`] passes to
/// [`PagedKvCache::new_with_block_size`].
pub const DEFAULT_BLOCK_SIZE: usize = 16;
/// One physical page of key/value storage.
///
/// Both buffers are flat `f32` vectors holding
/// `block_size * num_kv_heads * head_dim` elements each.
#[derive(Debug, Clone)]
pub struct KvPage {
    /// Flat key storage for every slot of the page.
    pub keys: Vec<f32>,
    /// Flat value storage for every slot of the page.
    pub values: Vec<f32>,
}

impl KvPage {
    /// Builds a zero-filled page sized for `block_size` slots of
    /// `num_kv_heads * head_dim` elements each.
    fn new(block_size: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        let elems_per_page = block_size * num_kv_heads * head_dim;
        let keys = vec![0.0_f32; elems_per_page];
        // Keys and values share the same shape, so the value buffer starts as a copy.
        let values = keys.clone();
        Self { keys, values }
    }
}
/// Fixed-capacity pool of pre-allocated [`KvPage`]s that are handed out and
/// returned by index.
pub struct BlockPool {
    /// Backing storage; a page's index in this vector is its physical block index.
    pages: Vec<KvPage>,
    /// Indices of pages that are currently available, used as a stack.
    free_list: Vec<usize>,
    /// `in_use[i]` is `true` while page `i` is checked out through `allocate`.
    /// Lets `free` reject double frees and out-of-range indices in O(1).
    in_use: Vec<bool>,
    block_size: usize,
    num_layers: usize,
    num_kv_heads: usize,
    head_dim: usize,
}

impl BlockPool {
    /// Creates a pool of `capacity` zero-filled pages, all initially free.
    ///
    /// Each page holds `block_size * num_kv_heads * head_dim` key elements and
    /// the same number of value elements. `num_layers` is only recorded; it
    /// does not affect how many pages are created.
    pub fn new(
        capacity: usize,
        block_size: usize,
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> Self {
        let pages = (0..capacity)
            .map(|_| KvPage::new(block_size, num_kv_heads, head_dim))
            .collect();
        Self {
            pages,
            // Ascending order + `pop` means the highest index is handed out first.
            free_list: (0..capacity).collect(),
            in_use: vec![false; capacity],
            block_size,
            num_layers,
            num_kv_heads,
            head_dim,
        }
    }

    /// Takes a free page out of the pool and returns its index, or `None`
    /// when every page is already allocated.
    pub fn allocate(&mut self) -> Option<usize> {
        let idx = self.free_list.pop()?;
        self.in_use[idx] = true;
        Some(idx)
    }

    /// Returns page `idx` to the pool.
    ///
    /// Indices that are out of range or not currently allocated are ignored.
    /// Without this guard a double free would let the same page be handed to
    /// two owners, and an out-of-range index would later be returned by
    /// [`allocate`](Self::allocate) and panic on access.
    pub fn free(&mut self, idx: usize) {
        if let Some(flag) = self.in_use.get_mut(idx) {
            if *flag {
                *flag = false;
                self.free_list.push(idx);
            }
        }
    }

    /// Number of pages currently available for allocation.
    pub fn free_count(&self) -> usize {
        self.free_list.len()
    }

    /// Total number of pages owned by the pool.
    pub fn total_count(&self) -> usize {
        self.pages.len()
    }

    /// Fraction of pages currently allocated, or `0.0` for an empty pool.
    pub fn utilization(&self) -> f32 {
        let total = self.total_count();
        if total == 0 {
            return 0.0;
        }
        // `free` never lets the free list outgrow `pages`, so this cannot underflow.
        let used = total - self.free_count();
        used as f32 / total as f32
    }

    /// Shared access to page `idx`. Panics if `idx` is out of range.
    fn page(&self, idx: usize) -> &KvPage {
        &self.pages[idx]
    }

    /// Exclusive access to page `idx`. Panics if `idx` is out of range.
    fn page_mut(&mut self, idx: usize) -> &mut KvPage {
        &mut self.pages[idx]
    }

    /// Number of `f32` elements one token slot occupies (`num_kv_heads * head_dim`).
    fn slot_len(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }
}
/// Per-layer list of physical block indices, in logical block order.
pub struct BlockTable {
    /// Number of token slots each block holds.
    block_size: usize,
    /// `blocks[layer][logical_block]` is the physical block index.
    blocks: Vec<Vec<usize>>,
    num_layers: usize,
}

impl BlockTable {
    /// Creates a table with `num_layers` empty block lists.
    pub fn new(num_layers: usize, block_size: usize) -> Self {
        let blocks = (0..num_layers).map(|_| Vec::new()).collect();
        Self {
            block_size,
            blocks,
            num_layers,
        }
    }

    /// Appends `physical_idx` as the next logical block of `layer`.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is not a valid layer index.
    pub fn append_block(&mut self, layer: usize, physical_idx: usize) {
        debug_assert!(layer < self.num_layers);
        let layer_blocks = &mut self.blocks[layer];
        layer_blocks.push(physical_idx);
    }

    /// Looks up the physical block backing `logical_block` of `layer`.
    ///
    /// Returns `None` when either the layer or the logical block does not exist.
    pub fn get_block(&self, layer: usize, logical_block: usize) -> Option<usize> {
        let layer_blocks = self.blocks.get(layer)?;
        layer_blocks.get(logical_block).copied()
    }

    /// Number of blocks mapped for `layer`; `0` for an unknown layer.
    pub fn num_blocks(&self, layer: usize) -> usize {
        match self.blocks.get(layer) {
            Some(layer_blocks) => layer_blocks.len(),
            None => 0,
        }
    }

    /// Number of token slots covered by the blocks mapped for `layer`.
    pub fn token_capacity(&self, layer: usize) -> usize {
        self.block_size * self.num_blocks(layer)
    }
}
/// Errors returned by [`PagedKvCache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum PagedKvError {
/// The given sequence id is not registered in the cache.
#[error("sequence {0} not found")]
SequenceNotFound(u64),
/// The block pool could not supply a free block for an allocation.
#[error("out of memory: no free KV blocks")]
OutOfMemory,
/// No block is mapped for the requested layer / token position of the
/// sequence.
#[error("token position {pos} out of range for sequence {seq_id}")]
PositionOutOfRange { seq_id: u64, pos: usize },
/// A key or value slice had `actual` elements where `expected`
/// (`num_kv_heads * head_dim`) were required.
#[error("dimension mismatch: expected {expected}, got {actual}")]
DimMismatch { expected: usize, actual: usize },
}
/// Key/value cache that stores each sequence's per-layer data in fixed-size
/// blocks drawn from a shared [`BlockPool`].
pub struct PagedKvCache {
    /// Shared pool of physical pages.
    pool: BlockPool,
    /// Block table of every live sequence, keyed by sequence id.
    sequences: HashMap<u64, BlockTable>,
    /// Id that the next call to `create_sequence` will hand out.
    next_seq_id: u64,
}

impl PagedKvCache {
    /// Creates a cache with `capacity` physical blocks of
    /// [`DEFAULT_BLOCK_SIZE`] token slots each.
    pub fn new(capacity: usize, num_layers: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        Self::new_with_block_size(
            capacity,
            DEFAULT_BLOCK_SIZE,
            num_layers,
            num_kv_heads,
            head_dim,
        )
    }

    /// Creates a cache with `capacity` physical blocks of `block_size` token
    /// slots each.
    ///
    /// # Panics
    ///
    /// Construction itself does not validate `block_size`, but a zero
    /// `block_size` makes [`ensure_capacity`](Self::ensure_capacity),
    /// [`write_kv`](Self::write_kv) and [`read_kv`](Self::read_kv) panic with
    /// a division by zero.
    pub fn new_with_block_size(
        capacity: usize,
        block_size: usize,
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> Self {
        Self {
            pool: BlockPool::new(capacity, block_size, num_layers, num_kv_heads, head_dim),
            sequences: HashMap::new(),
            next_seq_id: 0,
        }
    }

    /// Registers a new, empty sequence and returns its id.
    ///
    /// Ids are handed out in increasing order starting at 0. No blocks are
    /// allocated until capacity is requested or data is written.
    pub fn create_sequence(&mut self) -> u64 {
        let id = self.next_seq_id;
        self.next_seq_id += 1;
        let table = BlockTable::new(self.pool.num_layers, self.pool.block_size);
        self.sequences.insert(id, table);
        id
    }

    /// Removes a sequence and returns all of its blocks to the pool.
    ///
    /// # Errors
    ///
    /// [`PagedKvError::SequenceNotFound`] if `seq_id` is not registered.
    pub fn drop_sequence(&mut self, seq_id: u64) -> Result<(), PagedKvError> {
        let table = self
            .sequences
            .remove(&seq_id)
            .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
        for &phys_idx in table.blocks.iter().flatten() {
            self.pool.free(phys_idx);
        }
        Ok(())
    }

    /// Makes sure every layer of `seq_id` has enough blocks mapped to hold
    /// `num_tokens` token positions.
    ///
    /// The call is all-or-nothing: when the pool cannot cover the combined
    /// shortfall of all layers, nothing is allocated and the sequence is left
    /// unchanged.
    ///
    /// # Errors
    ///
    /// * [`PagedKvError::SequenceNotFound`] if `seq_id` is not registered.
    /// * [`PagedKvError::OutOfMemory`] if the pool has fewer free blocks than
    ///   the layers need in total.
    ///
    /// # Panics
    ///
    /// Panics if the cache was built with a block size of zero.
    pub fn ensure_capacity(&mut self, seq_id: u64, num_tokens: usize) -> Result<(), PagedKvError> {
        let num_layers = self.pool.num_layers;
        let blocks_needed = num_tokens.div_ceil(self.pool.block_size);
        let table = self
            .sequences
            .get_mut(&seq_id)
            .ok_or(PagedKvError::SequenceNotFound(seq_id))?;

        // Check the whole request up front so a failure cannot leave some
        // layers grown and others not. Saturate so absurd requests report
        // OutOfMemory instead of overflowing.
        let total_deficit = (0..num_layers)
            .map(|layer| blocks_needed.saturating_sub(table.num_blocks(layer)))
            .fold(0_usize, usize::saturating_add);
        if total_deficit > self.pool.free_count() {
            return Err(PagedKvError::OutOfMemory);
        }

        for layer in 0..num_layers {
            let deficit = blocks_needed.saturating_sub(table.num_blocks(layer));
            for _ in 0..deficit {
                let phys = self.pool.allocate().ok_or(PagedKvError::OutOfMemory)?;
                table.append_block(layer, phys);
            }
        }
        Ok(())
    }

    /// Stores the key and value vectors of one token position in one layer,
    /// growing the sequence's block mapping first if needed.
    ///
    /// All arguments are validated before any block is allocated, so a
    /// rejected write leaves the pool untouched.
    ///
    /// # Errors
    ///
    /// * [`PagedKvError::DimMismatch`] if `key` or `value` does not have
    ///   `num_kv_heads * head_dim` elements (`key` is checked first).
    /// * [`PagedKvError::SequenceNotFound`] if `seq_id` is not registered.
    /// * [`PagedKvError::PositionOutOfRange`] if `layer` is not a valid layer
    ///   or `token_pos + 1` does not fit in `usize`.
    /// * [`PagedKvError::OutOfMemory`] if the pool cannot supply the blocks
    ///   required to reach `token_pos`.
    ///
    /// # Panics
    ///
    /// Panics if the cache was built with a block size of zero.
    pub fn write_kv(
        &mut self,
        seq_id: u64,
        layer: usize,
        token_pos: usize,
        key: &[f32],
        value: &[f32],
    ) -> Result<(), PagedKvError> {
        let slot_len = self.pool.slot_len();
        for actual in [key.len(), value.len()] {
            if actual != slot_len {
                return Err(PagedKvError::DimMismatch {
                    expected: slot_len,
                    actual,
                });
            }
        }
        if !self.sequences.contains_key(&seq_id) {
            return Err(PagedKvError::SequenceNotFound(seq_id));
        }

        let out_of_range = move || PagedKvError::PositionOutOfRange {
            seq_id,
            pos: token_pos,
        };
        // Reject a bad layer before growing the sequence; otherwise blocks
        // would be allocated for every layer only for the write to fail.
        if layer >= self.pool.num_layers {
            return Err(out_of_range());
        }
        let tokens_needed = token_pos.checked_add(1).ok_or_else(out_of_range)?;
        self.ensure_capacity(seq_id, tokens_needed)?;

        let block_size = self.pool.block_size;
        let phys = self
            .sequences
            .get(&seq_id)
            .ok_or(PagedKvError::SequenceNotFound(seq_id))?
            .get_block(layer, token_pos / block_size)
            .ok_or_else(out_of_range)?;

        // Slots are laid out back to back inside a page, one per token.
        let start = (token_pos % block_size) * slot_len;
        let page = self.pool.page_mut(phys);
        page.keys[start..start + slot_len].copy_from_slice(key);
        page.values[start..start + slot_len].copy_from_slice(value);
        Ok(())
    }

    /// Returns the `(key, value)` slices stored for one token position in one
    /// layer.
    ///
    /// Only the block mapping is checked. A position that lies inside a
    /// mapped block but was never written yields whatever the page currently
    /// holds: pages start zero-filled and are not cleared when returned to
    /// the pool.
    ///
    /// # Errors
    ///
    /// * [`PagedKvError::SequenceNotFound`] if `seq_id` is not registered.
    /// * [`PagedKvError::PositionOutOfRange`] if no block is mapped for
    ///   `layer` / `token_pos`.
    ///
    /// # Panics
    ///
    /// Panics if the cache was built with a block size of zero.
    pub fn read_kv(
        &self,
        seq_id: u64,
        layer: usize,
        token_pos: usize,
    ) -> Result<(&[f32], &[f32]), PagedKvError> {
        let table = self
            .sequences
            .get(&seq_id)
            .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
        let block_size = self.pool.block_size;
        let phys = table
            .get_block(layer, token_pos / block_size)
            .ok_or(PagedKvError::PositionOutOfRange {
                seq_id,
                pos: token_pos,
            })?;

        let slot_len = self.pool.slot_len();
        let start = (token_pos % block_size) * slot_len;
        let page = self.pool.page(phys);
        Ok((
            &page.keys[start..start + slot_len],
            &page.values[start..start + slot_len],
        ))
    }

    /// Fraction of the pool's blocks that are currently allocated.
    pub fn pool_utilization(&self) -> f32 {
        self.pool.utilization()
    }

    /// Number of token positions backed by blocks in layer 0 of `seq_id`, or
    /// `0` if the sequence is unknown.
    ///
    /// NOTE: this is the allocated capacity (block count × block size), not
    /// the number of tokens actually written, so it is always a multiple of
    /// the block size.
    pub fn sequence_length(&self, seq_id: u64) -> usize {
        self.sequences
            .get(&seq_id)
            .map_or(0, |table| table.token_capacity(0))
    }
}