oxibonsai_model/paged_kv_cache.rs
1//! PagedAttention / vLLM-style paged KV cache.
2//!
3//! This module implements a block-based key-value cache that mirrors the
4//! PagedAttention design from vLLM. Physical memory is divided into fixed-size
5//! *pages* (blocks), each holding `block_size` token slots. Logical sequences
6//! are given a [`BlockTable`] that maps their logical block indices to physical
7//! page indices obtained from a shared [`BlockPool`]. Allocation is *lazy*:
8//! pages are handed out on demand as sequences grow, and are returned to the
9//! pool when a sequence is dropped.
10//!
11//! # Architecture overview
12//!
13//! ```text
//! ┌──────────────────────────────────────────────────┐
//! │                   PagedKvCache                   │
//! │                                                  │
//! │  pool: BlockPool ◄──── free_list (Vec<usize>)    │
//! │  sequences: HashMap<seq_id, BlockTable>          │
//! └──────────────────────────────────────────────────┘
20//! ```
21//!
22//! Each [`KvPage`] stores keys and values for *one* transformer layer and
23//! `block_size` token positions. A [`BlockTable`] therefore holds
24//! `num_layers` independent block lists, one per layer.
25
26use std::collections::HashMap;
27
28// ---------------------------------------------------------------------------
29// Constants
30// ---------------------------------------------------------------------------
31
/// Number of token slots each KV page holds when the caller does not pick a
/// block size explicitly (this is the value [`PagedKvCache::new`] uses).
pub const DEFAULT_BLOCK_SIZE: usize = 16;
34
35// ---------------------------------------------------------------------------
36// KvPage
37// ---------------------------------------------------------------------------
38
/// One physical KV page: storage for `block_size` token slots of a **single**
/// layer.
///
/// `keys` and `values` are flat, row-major buffers whose logical shape is
/// `[block_size, num_kv_heads, head_dim]`.
#[derive(Debug, Clone)]
pub struct KvPage {
    /// Key buffer holding `block_size × num_kv_heads × head_dim` f32 elements.
    pub keys: Vec<f32>,
    /// Value buffer with the same length and layout as `keys`.
    pub values: Vec<f32>,
}

impl KvPage {
    /// Build a page whose key and value buffers are both filled with zeros.
    fn new(block_size: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        let elements_per_buffer = block_size * num_kv_heads * head_dim;
        // Keys and values are separate allocations of identical size.
        let zeroed = || vec![0.0_f32; elements_per_buffer];
        Self {
            keys: zeroed(),
            values: zeroed(),
        }
    }
}
61
62// ---------------------------------------------------------------------------
63// BlockPool
64// ---------------------------------------------------------------------------
65
66/// Pre-allocated pool of [`KvPage`]s shared across all sequences.
67///
68/// Pages are handed out via [`BlockPool::allocate`] and returned via
69/// [`BlockPool::free`]. The pool never grows beyond its initial capacity.
70pub struct BlockPool {
71 /// All pages ever created (indexed by physical block index).
72 pages: Vec<KvPage>,
73 /// Indices of currently unused pages.
74 free_list: Vec<usize>,
75 /// Token slots per page.
76 block_size: usize,
77 /// Number of transformer layers each page covers.
78 num_layers: usize,
79 /// Number of KV-attention heads.
80 num_kv_heads: usize,
81 /// Dimensionality of each attention head.
82 head_dim: usize,
83}
84
85impl BlockPool {
86 /// Create a pool with `capacity` pages.
87 ///
88 /// Every page is pre-allocated and zeroed at construction time so that
89 /// subsequent allocations are O(1) pointer-hand-offs from the free list.
90 ///
91 /// # Arguments
92 ///
93 /// * `capacity` – total number of physical pages.
94 /// * `block_size` – token slots per page.
95 /// * `num_layers` – number of transformer layers (informational; each
96 /// logical layer has its own independent block list in [`BlockTable`]).
97 /// * `num_kv_heads` – number of KV attention heads.
98 /// * `head_dim` – per-head dimension.
99 pub fn new(
100 capacity: usize,
101 block_size: usize,
102 num_layers: usize,
103 num_kv_heads: usize,
104 head_dim: usize,
105 ) -> Self {
106 let mut pages = Vec::with_capacity(capacity);
107 let mut free_list = Vec::with_capacity(capacity);
108 for idx in 0..capacity {
109 pages.push(KvPage::new(block_size, num_kv_heads, head_dim));
110 free_list.push(idx);
111 }
112 Self {
113 pages,
114 free_list,
115 block_size,
116 num_layers,
117 num_kv_heads,
118 head_dim,
119 }
120 }
121
122 /// Allocate one page and return its physical index.
123 ///
124 /// Returns `None` when the pool is exhausted (out-of-memory).
125 pub fn allocate(&mut self) -> Option<usize> {
126 self.free_list.pop()
127 }
128
129 /// Return page `idx` to the pool.
130 ///
131 /// The page contents are **not** zeroed on release; callers must overwrite
132 /// every slot they intend to read.
133 pub fn free(&mut self, idx: usize) {
134 self.free_list.push(idx);
135 }
136
137 /// Number of pages currently available for allocation.
138 pub fn free_count(&self) -> usize {
139 self.free_list.len()
140 }
141
142 /// Total number of pages in the pool (constant after construction).
143 pub fn total_count(&self) -> usize {
144 self.pages.len()
145 }
146
147 /// Fraction of pages currently in use, in `[0.0, 1.0]`.
148 ///
149 /// Returns `0.0` when the pool is empty (capacity == 0).
150 pub fn utilization(&self) -> f32 {
151 let total = self.total_count();
152 if total == 0 {
153 return 0.0;
154 }
155 let used = total - self.free_count();
156 used as f32 / total as f32
157 }
158
159 // ------------------------------------------------------------------
160 // Internal helpers
161 // ------------------------------------------------------------------
162
163 /// Return an immutable reference to the page at physical index `idx`.
164 fn page(&self, idx: usize) -> &KvPage {
165 &self.pages[idx]
166 }
167
168 /// Return a mutable reference to the page at physical index `idx`.
169 fn page_mut(&mut self, idx: usize) -> &mut KvPage {
170 &mut self.pages[idx]
171 }
172
173 /// Number of f32 elements per token slot within a page.
174 fn slot_len(&self) -> usize {
175 self.num_kv_heads * self.head_dim
176 }
177}
178
179// ---------------------------------------------------------------------------
180// BlockTable
181// ---------------------------------------------------------------------------
182
/// Per-sequence mapping from logical block indices to physical page indices.
///
/// Every transformer layer keeps its own list of physical blocks, so the
/// layers can be reasoned about independently of each other.
pub struct BlockTable {
    /// Token slots per block (must match the [`BlockPool`]'s `block_size`).
    block_size: usize,
    /// `blocks[layer][logical_block]` → physical page index.
    blocks: Vec<Vec<usize>>,
    /// Number of transformer layers.
    num_layers: usize,
}

impl BlockTable {
    /// Build a table with `num_layers` empty block lists.
    pub fn new(num_layers: usize, block_size: usize) -> Self {
        let blocks = (0..num_layers).map(|_| Vec::new()).collect();
        Self {
            block_size,
            blocks,
            num_layers,
        }
    }

    /// Record `physical_idx` as the next logical block of `layer`.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers`: the debug assertion fires in debug
    /// builds, and the out-of-bounds index panics otherwise.
    pub fn append_block(&mut self, layer: usize, physical_idx: usize) {
        debug_assert!(layer < self.num_layers);
        let layer_blocks = &mut self.blocks[layer];
        layer_blocks.push(physical_idx);
    }

    /// Physical page index backing logical block `logical_block` of `layer`.
    ///
    /// Yields `None` when the layer or the logical block index is out of
    /// range.
    pub fn get_block(&self, layer: usize, logical_block: usize) -> Option<usize> {
        self.blocks
            .get(layer)
            .and_then(|layer_blocks| layer_blocks.get(logical_block))
            .copied()
    }

    /// How many physical blocks `layer` currently has mapped.
    ///
    /// An out-of-range `layer` reports `0`.
    pub fn num_blocks(&self, layer: usize) -> usize {
        match self.blocks.get(layer) {
            Some(layer_blocks) => layer_blocks.len(),
            None => 0,
        }
    }

    /// Token slots reserved for `layer`, including trailing slots of the last
    /// block that may not be in use.
    pub fn token_capacity(&self, layer: usize) -> usize {
        let mapped = self.num_blocks(layer);
        mapped * self.block_size
    }
}
237
238// ---------------------------------------------------------------------------
239// PagedKvError
240// ---------------------------------------------------------------------------
241
242/// Errors returned by [`PagedKvCache`] operations.
243#[derive(Debug, thiserror::Error)]
244pub enum PagedKvError {
245 /// The requested sequence ID does not exist in the cache.
246 #[error("sequence {0} not found")]
247 SequenceNotFound(u64),
248
249 /// The pool has no free pages left.
250 #[error("out of memory: no free KV blocks")]
251 OutOfMemory,
252
253 /// The token position exceeds the sequence's allocated capacity.
254 #[error("token position {pos} out of range for sequence {seq_id}")]
255 PositionOutOfRange { seq_id: u64, pos: usize },
256
257 /// A key or value slice has the wrong length.
258 #[error("dimension mismatch: expected {expected}, got {actual}")]
259 DimMismatch { expected: usize, actual: usize },
260}
261
262// ---------------------------------------------------------------------------
263// PagedKvCache
264// ---------------------------------------------------------------------------
265
266/// Orchestrates a [`BlockPool`] and per-sequence [`BlockTable`]s.
267///
268/// This is the primary entry-point for vLLM-style paged KV management.
269/// Sequences are identified by opaque `u64` IDs assigned by
270/// [`PagedKvCache::create_sequence`].
271///
272/// # Example
273///
274/// ```rust
275/// use oxibonsai_model::paged_kv_cache::{PagedKvCache};
276///
277/// let mut cache = PagedKvCache::new(
278/// /*capacity=*/ 128,
279/// /*num_layers=*/ 32,
280/// /*num_kv_heads=*/ 8,
281/// /*head_dim=*/ 128,
282/// );
283///
284/// let seq = cache.create_sequence();
285/// cache.ensure_capacity(seq, 1).expect("failed to ensure capacity");
286///
287/// let key = vec![1.0_f32; 8 * 128]; // num_kv_heads * head_dim
288/// let val = vec![2.0_f32; 8 * 128];
289/// cache.write_kv(seq, 0, 0, &key, &val).expect("failed to write kv");
290///
291/// let (k, v) = cache.read_kv(seq, 0, 0).expect("failed to read kv");
292/// assert_eq!(k, key.as_slice());
293/// ```
294pub struct PagedKvCache {
295 pool: BlockPool,
296 sequences: HashMap<u64, BlockTable>,
297 next_seq_id: u64,
298}
299
300impl PagedKvCache {
301 /// Create a cache with [`DEFAULT_BLOCK_SIZE`] token slots per page.
302 ///
303 /// # Arguments
304 ///
305 /// * `capacity` – maximum number of physical KV pages.
306 /// * `num_layers` – number of transformer layers.
307 /// * `num_kv_heads` – number of KV attention heads.
308 /// * `head_dim` – per-head feature dimension.
309 pub fn new(capacity: usize, num_layers: usize, num_kv_heads: usize, head_dim: usize) -> Self {
310 Self::new_with_block_size(
311 capacity,
312 DEFAULT_BLOCK_SIZE,
313 num_layers,
314 num_kv_heads,
315 head_dim,
316 )
317 }
318
319 /// Create a cache with a custom `block_size`.
320 ///
321 /// All other parameters are the same as [`PagedKvCache::new`].
322 pub fn new_with_block_size(
323 capacity: usize,
324 block_size: usize,
325 num_layers: usize,
326 num_kv_heads: usize,
327 head_dim: usize,
328 ) -> Self {
329 Self {
330 pool: BlockPool::new(capacity, block_size, num_layers, num_kv_heads, head_dim),
331 sequences: HashMap::new(),
332 next_seq_id: 0,
333 }
334 }
335
336 // ------------------------------------------------------------------
337 // Sequence lifecycle
338 // ------------------------------------------------------------------
339
340 /// Register a new sequence and return its ID.
341 ///
342 /// Sequences start with no allocated pages. Call [`ensure_capacity`] or
343 /// [`write_kv`] to trigger lazy allocation.
344 ///
345 /// [`ensure_capacity`]: PagedKvCache::ensure_capacity
346 /// [`write_kv`]: PagedKvCache::write_kv
347 pub fn create_sequence(&mut self) -> u64 {
348 let id = self.next_seq_id;
349 self.next_seq_id += 1;
350 let num_layers = self.pool.num_layers;
351 let block_size = self.pool.block_size;
352 self.sequences
353 .insert(id, BlockTable::new(num_layers, block_size));
354 id
355 }
356
357 /// Drop a sequence, returning all its physical pages to the pool.
358 ///
359 /// # Errors
360 ///
361 /// Returns [`PagedKvError::SequenceNotFound`] if `seq_id` is unknown.
362 pub fn drop_sequence(&mut self, seq_id: u64) -> Result<(), PagedKvError> {
363 let table = self
364 .sequences
365 .remove(&seq_id)
366 .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
367
368 for layer_blocks in &table.blocks {
369 for &phys_idx in layer_blocks {
370 self.pool.free(phys_idx);
371 }
372 }
373 Ok(())
374 }
375
376 // ------------------------------------------------------------------
377 // Capacity management
378 // ------------------------------------------------------------------
379
380 /// Ensure the sequence can hold at least `num_tokens` positions in every
381 /// layer, allocating new pages from the pool as needed.
382 ///
383 /// This is the primary lazy-allocation entry-point. [`write_kv`] calls
384 /// this internally, so explicit calls are optional.
385 ///
386 /// # Errors
387 ///
388 /// * [`PagedKvError::SequenceNotFound`] – unknown `seq_id`.
389 /// * [`PagedKvError::OutOfMemory`] – pool exhausted before the request
390 /// could be satisfied; partially-allocated pages are **not** rolled back
391 /// (vLLM behaviour — callers should drop the sequence on OOM).
392 ///
393 /// [`write_kv`]: PagedKvCache::write_kv
394 pub fn ensure_capacity(&mut self, seq_id: u64, num_tokens: usize) -> Result<(), PagedKvError> {
395 // We need to look up immutable fields before taking the mutable borrow.
396 let num_layers = self.pool.num_layers;
397 let block_size = self.pool.block_size;
398
399 // Compute the number of blocks needed to cover `num_tokens`.
400 let blocks_needed = num_tokens.div_ceil(block_size);
401
402 // We must not hold a reference into `self.sequences` while mutating
403 // `self.pool`, so we collect the per-layer deficits first.
404 let deficits: Vec<usize> = {
405 let table = self
406 .sequences
407 .get(&seq_id)
408 .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
409
410 (0..num_layers)
411 .map(|layer| {
412 let have = table.num_blocks(layer);
413 blocks_needed.saturating_sub(have)
414 })
415 .collect()
416 };
417
418 // Allocate the required pages and record them in the block table.
419 for (layer, deficit) in deficits.into_iter().enumerate() {
420 for _ in 0..deficit {
421 let phys = self.pool.allocate().ok_or(PagedKvError::OutOfMemory)?;
422 let table = self
423 .sequences
424 .get_mut(&seq_id)
425 .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
426 table.append_block(layer, phys);
427 }
428 }
429 Ok(())
430 }
431
432 // ------------------------------------------------------------------
433 // KV I/O
434 // ------------------------------------------------------------------
435
436 /// Write a key-value pair for token at position `token_pos` in `layer`.
437 ///
438 /// `key` and `value` must each have exactly `num_kv_heads * head_dim`
439 /// elements. Capacity is extended automatically if the token position
440 /// falls outside the current allocation.
441 ///
442 /// # Errors
443 ///
444 /// * [`PagedKvError::SequenceNotFound`]
445 /// * [`PagedKvError::OutOfMemory`]
446 /// * [`PagedKvError::DimMismatch`] – wrong slice length.
447 pub fn write_kv(
448 &mut self,
449 seq_id: u64,
450 layer: usize,
451 token_pos: usize,
452 key: &[f32],
453 value: &[f32],
454 ) -> Result<(), PagedKvError> {
455 let slot_len = self.pool.slot_len();
456
457 if key.len() != slot_len {
458 return Err(PagedKvError::DimMismatch {
459 expected: slot_len,
460 actual: key.len(),
461 });
462 }
463 if value.len() != slot_len {
464 return Err(PagedKvError::DimMismatch {
465 expected: slot_len,
466 actual: value.len(),
467 });
468 }
469
470 // Ensure we have enough blocks for token_pos + 1.
471 self.ensure_capacity(seq_id, token_pos + 1)?;
472
473 let block_size = self.pool.block_size;
474 let logical_block = token_pos / block_size;
475 let slot_in_block = token_pos % block_size;
476
477 let phys = {
478 let table = self
479 .sequences
480 .get(&seq_id)
481 .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
482 table
483 .get_block(layer, logical_block)
484 .ok_or(PagedKvError::PositionOutOfRange {
485 seq_id,
486 pos: token_pos,
487 })?
488 };
489
490 let offset = slot_in_block * slot_len;
491 let page = self.pool.page_mut(phys);
492 page.keys[offset..offset + slot_len].copy_from_slice(key);
493 page.values[offset..offset + slot_len].copy_from_slice(value);
494 Ok(())
495 }
496
497 /// Read the key-value pair for token at position `token_pos` in `layer`.
498 ///
499 /// Returns `(&key_slice, &value_slice)` where each slice has
500 /// `num_kv_heads * head_dim` elements.
501 ///
502 /// # Errors
503 ///
504 /// * [`PagedKvError::SequenceNotFound`]
505 /// * [`PagedKvError::PositionOutOfRange`] – token has not been written yet.
506 pub fn read_kv(
507 &self,
508 seq_id: u64,
509 layer: usize,
510 token_pos: usize,
511 ) -> Result<(&[f32], &[f32]), PagedKvError> {
512 let block_size = self.pool.block_size;
513 let slot_len = self.pool.slot_len();
514
515 let table = self
516 .sequences
517 .get(&seq_id)
518 .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
519
520 let logical_block = token_pos / block_size;
521 let slot_in_block = token_pos % block_size;
522
523 let phys =
524 table
525 .get_block(layer, logical_block)
526 .ok_or(PagedKvError::PositionOutOfRange {
527 seq_id,
528 pos: token_pos,
529 })?;
530
531 let offset = slot_in_block * slot_len;
532 let page = self.pool.page(phys);
533 Ok((
534 &page.keys[offset..offset + slot_len],
535 &page.values[offset..offset + slot_len],
536 ))
537 }
538
539 // ------------------------------------------------------------------
540 // Metrics
541 // ------------------------------------------------------------------
542
543 /// Fraction of pool pages currently in use (`[0.0, 1.0]`).
544 pub fn pool_utilization(&self) -> f32 {
545 self.pool.utilization()
546 }
547
548 /// Number of token positions written to `seq_id` (across **all** layers,
549 /// using layer 0 as the canonical length).
550 ///
551 /// Returns `0` for unknown sequences.
552 pub fn sequence_length(&self, seq_id: u64) -> usize {
553 self.sequences
554 .get(&seq_id)
555 .map_or(0, |t| t.token_capacity(0))
556 }
557}