use std::collections::{BTreeSet, VecDeque};
/// Identifier of one page managed by a [`KvPagePool`].
///
/// The wrapped `u32` is the page's index: the pool uses it directly to look
/// up the page's descriptor.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct KvPageId(pub u32);
/// Bookkeeping record describing a single page of a [`KvPagePool`].
#[derive(Clone, Copy, Debug)]
pub struct KvPageDesc {
    /// Byte offset of the page; the pool initialises it to
    /// `page_index * bytes_per_page`.
    pub offset: usize,
    /// Size of the page in bytes.
    pub bytes: usize,
    /// Fill counter for the page. The pool resets it to 0 each time the page
    /// is handed out; what exactly is counted is up to the caller
    /// (presumably tokens written so far -- confirm against callers).
    pub filled: u16,
}
/// Fixed-capacity pool of equally sized pages addressed by [`KvPageId`].
///
/// Every page gets a [`KvPageDesc`] when the pool is created; afterwards the
/// pool only tracks which page indices are currently free.
pub struct KvPagePool {
    /// Indices of the pages available for allocation. The set is ordered, so
    /// `alloc` always hands out the lowest free index. Invariant: every
    /// element is `< descs.len()`.
    free: BTreeSet<u32>,
    /// One descriptor per page, indexed by page id.
    descs: Vec<KvPageDesc>,
    /// Size of each page in bytes, as given to [`KvPagePool::new`].
    pub bytes_per_page: usize,
    /// Tokens per page, as given to [`KvPagePool::new`]. Stored for callers;
    /// the pool itself never reads it.
    pub tokens_per_page: u16,
}

impl KvPagePool {
    /// Creates a pool of `num_pages` pages, all of them initially free.
    ///
    /// Page `i` is described by offset `i * bytes_per_page`, size
    /// `bytes_per_page` and a `filled` counter of 0.
    pub fn new(num_pages: u32, bytes_per_page: usize, tokens_per_page: u16) -> Self {
        let descs: Vec<KvPageDesc> = (0..num_pages)
            .map(|i| KvPageDesc {
                offset: (i as usize) * bytes_per_page,
                bytes: bytes_per_page,
                filled: 0,
            })
            .collect();
        Self {
            free: (0..num_pages).collect(),
            descs,
            bytes_per_page,
            tokens_per_page,
        }
    }

    /// Total number of pages in the pool, free or not.
    pub fn capacity(&self) -> u32 {
        // `descs` was built from a `0..num_pages` range over `u32`, so its
        // length always fits.
        self.descs.len() as u32
    }

    /// Number of pages currently available for allocation.
    pub fn free_count(&self) -> u32 {
        self.free.len() as u32
    }

    /// Number of pages currently handed out.
    pub fn used_count(&self) -> u32 {
        // Cannot underflow: `free` only ever holds indices below `capacity()`.
        self.capacity() - self.free_count()
    }

    /// Takes the lowest-numbered free page out of the pool.
    ///
    /// The page's `filled` counter is reset to 0 before it is returned.
    /// Returns `None` when no page is free.
    pub fn alloc(&mut self) -> Option<KvPageId> {
        let id = self.free.iter().next().copied()?;
        self.free.remove(&id);
        // In bounds thanks to the invariant on `free`.
        self.descs[id as usize].filled = 0;
        Some(KvPageId(id))
    }

    /// Returns page `id` to the pool.
    ///
    /// Freeing a page that is already free has no effect. Ids that do not
    /// belong to this pool (`id.0 >= capacity()`) are ignored: recording them
    /// would push `free_count()` above `capacity()` and make a later `alloc`
    /// hand out a page that has no descriptor.
    pub fn free(&mut self, id: KvPageId) {
        if id.0 < self.capacity() {
            self.free.insert(id.0);
        }
    }

    /// Shared access to the descriptor of page `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not a page of this pool.
    pub fn descriptor(&self, id: KvPageId) -> &KvPageDesc {
        &self.descs[id.0 as usize]
    }

    /// Mutable access to the descriptor of page `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not a page of this pool.
    pub fn descriptor_mut(&mut self, id: KvPageId) -> &mut KvPageDesc {
        &mut self.descs[id.0 as usize]
    }
}
/// Ordered list of page ids that can be looked up by token position.
#[derive(Debug, Clone, Default)]
pub struct KvBlockTable {
    /// Pages in the order they were pushed; slot `i` serves token positions
    /// `i * tokens_per_page .. (i + 1) * tokens_per_page`.
    pages: Vec<KvPageId>,
    /// Sequence length. The table only resets it (in `release`); keeping it
    /// up to date is left to the caller.
    pub seq_len: u32,
}

impl KvBlockTable {
    /// Creates an empty table with `seq_len` 0.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `page` as the next slot of the table.
    pub fn push_page(&mut self, page: KvPageId) {
        self.pages.push(page);
    }

    /// Returns the page holding token position `t`, given the number of
    /// tokens stored per page.
    ///
    /// Returns `None` when `t` lies beyond the pages pushed so far, and also
    /// when `tokens_per_page` is 0 (no position can be mapped to a page then,
    /// and dividing by it would panic).
    pub fn page_for_token(&self, t: u32, tokens_per_page: u16) -> Option<KvPageId> {
        // `checked_div` yields `None` for a zero divisor instead of panicking.
        let slot = t.checked_div(u32::from(tokens_per_page))?;
        self.pages.get(slot as usize).copied()
    }

    /// Number of pages currently held by the table.
    pub fn page_count(&self) -> usize {
        self.pages.len()
    }

    /// Hands every page back to `pool`, leaving the table empty and
    /// `seq_len` at 0.
    pub fn release(&mut self, pool: &mut KvPagePool) {
        for page in self.pages.drain(..) {
            pool.free(page);
        }
        self.seq_len = 0;
    }

    /// The pages of the table, in slot order.
    pub fn pages(&self) -> &[KvPageId] {
        &self.pages
    }
}
/// One unit of work placed into a batch: a sequence and the tokens to run
/// for it.
#[derive(Clone, Debug)]
pub struct BatchEntry {
    /// Identifier of the sequence this entry belongs to.
    pub seq_id: u64,
    /// Whether the entry is prefill or decode work.
    pub kind: BatchKind,
    /// Tokens to process for this entry; their count is what is charged
    /// against a batch's token budget.
    pub input_tokens: Vec<u32>,
    /// Length already cached for the sequence. When a prefill is split across
    /// batches, the remainder's value is advanced by the size of the chunk
    /// that was emitted.
    pub cached_len: u32,
}
/// The two kinds of work a [`BatchEntry`] can carry.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BatchKind {
    /// Prefill work; may be split into chunks when it exceeds the remaining
    /// token budget of a batch.
    Prefill,
    /// Decode work; admitted to a batch before any prefill work.
    Decode,
}
/// Assembles batches from a decode queue and a prefill queue under two
/// limits: a token budget and a maximum entry count.
pub struct BatchConstructor {
    /// Upper bound on the summed `input_tokens` lengths of one batch.
    pub max_tokens_per_batch: usize,
    /// Upper bound on the number of entries in one batch.
    pub max_entries: usize,
}

impl BatchConstructor {
    /// Creates a constructor with the given token budget and entry limit.
    pub fn new(max_tokens_per_batch: usize, max_entries: usize) -> Self {
        Self {
            max_tokens_per_batch,
            max_entries,
        }
    }

    /// Builds one batch, removing the scheduled work from the queues.
    ///
    /// Decode entries are taken first, in queue order, until the entry limit
    /// is reached or the next decode entry no longer fits the token budget.
    /// Prefill entries then fill what is left. A prefill that is larger than
    /// the remaining budget is split: the leading tokens join the batch as a
    /// `Prefill` entry carrying the old `cached_len`, and the rest goes back
    /// to the front of `prefill_queue` with `cached_len` advanced by the
    /// chunk size. Splitting always ends the batch.
    pub fn build(
        &self,
        decode_queue: &mut VecDeque<BatchEntry>,
        prefill_queue: &mut VecDeque<BatchEntry>,
    ) -> Vec<BatchEntry> {
        let mut out: Vec<BatchEntry> = Vec::new();
        let mut spent = 0usize;

        // Phase 1: decode entries, never split.
        loop {
            if out.len() >= self.max_entries {
                break;
            }
            let cost = match decode_queue.front() {
                Some(head) => head.input_tokens.len(),
                None => break,
            };
            if spent + cost > self.max_tokens_per_batch {
                break;
            }
            // The head was just observed, so this moves exactly one entry.
            out.extend(decode_queue.pop_front());
            spent += cost;
        }

        // Phase 2: prefill entries, splitting the first one that overflows.
        while out.len() < self.max_entries {
            let wanted = match prefill_queue.front() {
                None => break,
                Some(head) => head.input_tokens.len(),
            };
            let budget = self.max_tokens_per_batch.saturating_sub(spent);
            if budget == 0 {
                break;
            }
            if wanted > budget {
                if let Some(mut rest) = prefill_queue.pop_front() {
                    let leading: Vec<u32> = rest.input_tokens.drain(..budget).collect();
                    // The chunk keeps the pre-split `cached_len`; only the
                    // remainder is advanced.
                    out.push(BatchEntry {
                        seq_id: rest.seq_id,
                        kind: BatchKind::Prefill,
                        input_tokens: leading,
                        cached_len: rest.cached_len,
                    });
                    rest.cached_len += budget as u32;
                    prefill_queue.push_front(rest);
                }
                break;
            }
            out.extend(prefill_queue.pop_front());
            spent += wanted;
        }

        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocates `n` pages from `pool` and collects them into a fresh table.
    fn table_with_pages(pool: &mut KvPagePool, n: usize) -> KvBlockTable {
        let mut table = KvBlockTable::new();
        for _ in 0..n {
            table.push_page(pool.alloc().unwrap());
        }
        table
    }

    /// Shorthand for spelling out a `BatchEntry`.
    fn entry(seq_id: u64, kind: BatchKind, input_tokens: Vec<u32>, cached_len: u32) -> BatchEntry {
        BatchEntry {
            seq_id,
            kind,
            input_tokens,
            cached_len,
        }
    }

    #[test]
    fn pool_alloc_free_round_trip() {
        let mut pool = KvPagePool::new(4, 1024, 16);
        assert_eq!(pool.free_count(), 4);

        let first = pool.alloc().unwrap();
        let second = pool.alloc().unwrap();
        assert_eq!(pool.free_count(), 2);

        pool.free(first);
        pool.free(second);
        assert_eq!(pool.free_count(), 4);
    }

    #[test]
    fn pool_returns_none_when_exhausted() {
        let mut pool = KvPagePool::new(2, 64, 4);
        let _first = pool.alloc().unwrap();
        let _second = pool.alloc().unwrap();
        assert!(pool.alloc().is_none());
    }

    #[test]
    fn pool_descriptor_offsets_are_unique_and_aligned() {
        let pool = KvPagePool::new(4, 256, 16);
        for index in 0..4u32 {
            let desc = pool.descriptor(KvPageId(index));
            assert_eq!(desc.offset, index as usize * 256);
            assert_eq!(desc.bytes, 256);
        }
    }

    #[test]
    fn block_table_page_for_token() {
        let mut pool = KvPagePool::new(8, 64, 4);
        let bt = table_with_pages(&mut pool, 3);
        let pages = bt.pages().to_vec();

        // (token position, expected page) with 4 tokens per page.
        let cases = [
            (0u32, Some(pages[0])),
            (7, Some(pages[1])),
            (11, Some(pages[2])),
            (12, None),
        ];
        for (token, expected) in cases.iter() {
            assert_eq!(bt.page_for_token(*token, 4), *expected);
        }
    }

    #[test]
    fn block_table_release_returns_pages() {
        let mut pool = KvPagePool::new(8, 64, 4);
        let mut bt = table_with_pages(&mut pool, 3);
        assert_eq!(pool.free_count(), 5);

        bt.release(&mut pool);
        assert_eq!(pool.free_count(), 8);
        assert_eq!(bt.page_count(), 0);
    }

    #[test]
    fn batch_constructor_decodes_first_then_prefill() {
        let bc = BatchConstructor::new(8, 16);
        let mut decodes: VecDeque<BatchEntry> = (0..3u64)
            .map(|i| entry(i, BatchKind::Decode, vec![100 + i as u32], 50))
            .collect();
        let mut prefills: VecDeque<BatchEntry> = (0..2u64)
            .map(|i| entry(100 + i, BatchKind::Prefill, vec![1; 3], 0))
            .collect();

        let batch = bc.build(&mut decodes, &mut prefills);
        assert_eq!(batch.len(), 5);

        // Three decode entries lead, the two prefill entries follow.
        let (decode_part, prefill_part) = batch.split_at(3);
        for scheduled in decode_part {
            assert_eq!(scheduled.kind, BatchKind::Decode);
        }
        for scheduled in prefill_part {
            assert_eq!(scheduled.kind, BatchKind::Prefill);
        }

        let total_tokens = batch
            .iter()
            .fold(0usize, |sum, scheduled| sum + scheduled.input_tokens.len());
        assert_eq!(total_tokens, 8);

        // The second prefill was split: one token is left behind.
        assert_eq!(prefills.len(), 1);
        let leftover = &prefills[0];
        assert_eq!(leftover.input_tokens.len(), 1);
        assert_eq!(leftover.cached_len, 2);
    }

    #[test]
    fn batch_constructor_respects_max_entries() {
        let bc = BatchConstructor::new(1024, 2);
        let mut decodes: VecDeque<BatchEntry> = (0..5u64)
            .map(|i| entry(i, BatchKind::Decode, vec![1], 0))
            .collect();
        let mut prefills: VecDeque<BatchEntry> = VecDeque::new();

        let batch = bc.build(&mut decodes, &mut prefills);
        assert_eq!(batch.len(), 2);
        assert_eq!(decodes.len(), 3);
    }
}