Skip to main content

oxibonsai_model/
paged_kv_cache.rs

1//! PagedAttention / vLLM-style paged KV cache.
2//!
3//! This module implements a block-based key-value cache that mirrors the
4//! PagedAttention design from vLLM.  Physical memory is divided into fixed-size
5//! *pages* (blocks), each holding `block_size` token slots.  Logical sequences
6//! are given a [`BlockTable`] that maps their logical block indices to physical
7//! page indices obtained from a shared [`BlockPool`].  Allocation is *lazy*:
8//! pages are handed out on demand as sequences grow, and are returned to the
9//! pool when a sequence is dropped.
10//!
11//! # Architecture overview
12//!
13//! ```text
14//! ┌──────────────────────────────────────────────────┐
15//! │  PagedKvCache                                    │
16//! │                                                  │
17//! │  pool: BlockPool  ◄──── free_list (Vec<usize>)  │
18//! │  sequences: HashMap<seq_id, BlockTable>          │
19//! └──────────────────────────────────────────────────┘
20//! ```
21//!
22//! Each [`KvPage`] stores keys and values for *one* transformer layer and
23//! `block_size` token positions.  A [`BlockTable`] therefore holds
24//! `num_layers` independent block lists, one per layer.
25
26use std::collections::HashMap;
27
28// ---------------------------------------------------------------------------
29// Constants
30// ---------------------------------------------------------------------------
31
/// Token slots held by one KV page when the caller does not pick a block size.
pub const DEFAULT_BLOCK_SIZE: usize = 16;
34
35// ---------------------------------------------------------------------------
36// KvPage
37// ---------------------------------------------------------------------------
38
/// One physical KV page: `block_size` token slots belonging to a **single**
/// transformer layer.
///
/// `keys` and `values` are flat, row-major buffers whose logical shape is
/// `[block_size, num_kv_heads, head_dim]`.
#[derive(Debug, Clone)]
pub struct KvPage {
    /// Flattened keys: `block_size * num_kv_heads * head_dim` `f32` elements.
    pub keys: Vec<f32>,
    /// Flattened values; same length and layout as `keys`.
    pub values: Vec<f32>,
}

impl KvPage {
    /// Build a page whose key and value buffers are zero-filled and sized for
    /// `block_size` slots of `num_kv_heads * head_dim` elements each.
    fn new(block_size: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        let total_elems = block_size * num_kv_heads * head_dim;
        // Both buffers have the same shape, so share one constructor.
        let zeroed = || vec![0.0_f32; total_elems];
        KvPage {
            keys: zeroed(),
            values: zeroed(),
        }
    }
}
61
62// ---------------------------------------------------------------------------
63// BlockPool
64// ---------------------------------------------------------------------------
65
66/// Pre-allocated pool of [`KvPage`]s shared across all sequences.
67///
68/// Pages are handed out via [`BlockPool::allocate`] and returned via
69/// [`BlockPool::free`].  The pool never grows beyond its initial capacity.
70pub struct BlockPool {
71    /// All pages ever created (indexed by physical block index).
72    pages: Vec<KvPage>,
73    /// Indices of currently unused pages.
74    free_list: Vec<usize>,
75    /// Token slots per page.
76    block_size: usize,
77    /// Number of transformer layers each page covers.
78    num_layers: usize,
79    /// Number of KV-attention heads.
80    num_kv_heads: usize,
81    /// Dimensionality of each attention head.
82    head_dim: usize,
83}
84
85impl BlockPool {
86    /// Create a pool with `capacity` pages.
87    ///
88    /// Every page is pre-allocated and zeroed at construction time so that
89    /// subsequent allocations are O(1) pointer-hand-offs from the free list.
90    ///
91    /// # Arguments
92    ///
93    /// * `capacity`    – total number of physical pages.
94    /// * `block_size`  – token slots per page.
95    /// * `num_layers`  – number of transformer layers (informational; each
96    ///   logical layer has its own independent block list in [`BlockTable`]).
97    /// * `num_kv_heads` – number of KV attention heads.
98    /// * `head_dim`    – per-head dimension.
99    pub fn new(
100        capacity: usize,
101        block_size: usize,
102        num_layers: usize,
103        num_kv_heads: usize,
104        head_dim: usize,
105    ) -> Self {
106        let mut pages = Vec::with_capacity(capacity);
107        let mut free_list = Vec::with_capacity(capacity);
108        for idx in 0..capacity {
109            pages.push(KvPage::new(block_size, num_kv_heads, head_dim));
110            free_list.push(idx);
111        }
112        Self {
113            pages,
114            free_list,
115            block_size,
116            num_layers,
117            num_kv_heads,
118            head_dim,
119        }
120    }
121
122    /// Allocate one page and return its physical index.
123    ///
124    /// Returns `None` when the pool is exhausted (out-of-memory).
125    pub fn allocate(&mut self) -> Option<usize> {
126        self.free_list.pop()
127    }
128
129    /// Return page `idx` to the pool.
130    ///
131    /// The page contents are **not** zeroed on release; callers must overwrite
132    /// every slot they intend to read.
133    pub fn free(&mut self, idx: usize) {
134        self.free_list.push(idx);
135    }
136
137    /// Number of pages currently available for allocation.
138    pub fn free_count(&self) -> usize {
139        self.free_list.len()
140    }
141
142    /// Total number of pages in the pool (constant after construction).
143    pub fn total_count(&self) -> usize {
144        self.pages.len()
145    }
146
147    /// Fraction of pages currently in use, in `[0.0, 1.0]`.
148    ///
149    /// Returns `0.0` when the pool is empty (capacity == 0).
150    pub fn utilization(&self) -> f32 {
151        let total = self.total_count();
152        if total == 0 {
153            return 0.0;
154        }
155        let used = total - self.free_count();
156        used as f32 / total as f32
157    }
158
159    // ------------------------------------------------------------------
160    // Internal helpers
161    // ------------------------------------------------------------------
162
163    /// Return an immutable reference to the page at physical index `idx`.
164    fn page(&self, idx: usize) -> &KvPage {
165        &self.pages[idx]
166    }
167
168    /// Return a mutable reference to the page at physical index `idx`.
169    fn page_mut(&mut self, idx: usize) -> &mut KvPage {
170        &mut self.pages[idx]
171    }
172
173    /// Number of f32 elements per token slot within a page.
174    fn slot_len(&self) -> usize {
175        self.num_kv_heads * self.head_dim
176    }
177}
178
179// ---------------------------------------------------------------------------
180// BlockTable
181// ---------------------------------------------------------------------------
182
/// Per-sequence translation table from logical block numbers to physical page
/// indices.
///
/// Every transformer layer owns a separate list of physical blocks.
pub struct BlockTable {
    /// Token slots per block (expected to equal the [`BlockPool`]'s `block_size`).
    block_size: usize,
    /// `blocks[layer][logical_block]` is the physical page index.
    blocks: Vec<Vec<usize>>,
    /// Number of transformer layers; `blocks` is created with this many lists.
    num_layers: usize,
}

impl BlockTable {
    /// Build a table with one empty block list per layer.
    pub fn new(num_layers: usize, block_size: usize) -> Self {
        let blocks = (0..num_layers).map(|_| Vec::new()).collect();
        BlockTable {
            block_size,
            blocks,
            num_layers,
        }
    }

    /// Record `physical_idx` as the next logical block of `layer`.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers`: through the debug assertion in debug
    /// builds, and through the out-of-bounds index otherwise.
    pub fn append_block(&mut self, layer: usize, physical_idx: usize) {
        debug_assert!(layer < self.num_layers);
        let layer_blocks = &mut self.blocks[layer];
        layer_blocks.push(physical_idx);
    }

    /// Physical page index backing `logical_block` of `layer`.
    ///
    /// Yields `None` when the layer does not exist or the layer has no block
    /// at that logical index.
    pub fn get_block(&self, layer: usize, logical_block: usize) -> Option<usize> {
        self.blocks
            .get(layer)
            .and_then(|layer_blocks| layer_blocks.get(logical_block))
            .copied()
    }

    /// How many physical blocks `layer` currently maps.
    ///
    /// An out-of-range `layer` counts as having `0` blocks.
    pub fn num_blocks(&self, layer: usize) -> usize {
        match self.blocks.get(layer) {
            Some(layer_blocks) => layer_blocks.len(),
            None => 0,
        }
    }

    /// Token slots available to `layer`: mapped blocks times `block_size`.
    /// Trailing slots of the last block may be unused.
    pub fn token_capacity(&self, layer: usize) -> usize {
        let block_count = self.num_blocks(layer);
        block_count * self.block_size
    }
}
237
238// ---------------------------------------------------------------------------
239// PagedKvError
240// ---------------------------------------------------------------------------
241
242/// Errors returned by [`PagedKvCache`] operations.
243#[derive(Debug, thiserror::Error)]
244pub enum PagedKvError {
245    /// The requested sequence ID does not exist in the cache.
246    #[error("sequence {0} not found")]
247    SequenceNotFound(u64),
248
249    /// The pool has no free pages left.
250    #[error("out of memory: no free KV blocks")]
251    OutOfMemory,
252
253    /// The token position exceeds the sequence's allocated capacity.
254    #[error("token position {pos} out of range for sequence {seq_id}")]
255    PositionOutOfRange { seq_id: u64, pos: usize },
256
257    /// A key or value slice has the wrong length.
258    #[error("dimension mismatch: expected {expected}, got {actual}")]
259    DimMismatch { expected: usize, actual: usize },
260}
261
262// ---------------------------------------------------------------------------
263// PagedKvCache
264// ---------------------------------------------------------------------------
265
266/// Orchestrates a [`BlockPool`] and per-sequence [`BlockTable`]s.
267///
268/// This is the primary entry-point for vLLM-style paged KV management.
269/// Sequences are identified by opaque `u64` IDs assigned by
270/// [`PagedKvCache::create_sequence`].
271///
272/// # Example
273///
274/// ```rust
275/// use oxibonsai_model::paged_kv_cache::{PagedKvCache};
276///
277/// let mut cache = PagedKvCache::new(
278///     /*capacity=*/ 128,
279///     /*num_layers=*/ 32,
280///     /*num_kv_heads=*/ 8,
281///     /*head_dim=*/ 128,
282/// );
283///
284/// let seq = cache.create_sequence();
285/// cache.ensure_capacity(seq, 1).expect("failed to ensure capacity");
286///
287/// let key = vec![1.0_f32; 8 * 128];   // num_kv_heads * head_dim
288/// let val = vec![2.0_f32; 8 * 128];
289/// cache.write_kv(seq, 0, 0, &key, &val).expect("failed to write kv");
290///
291/// let (k, v) = cache.read_kv(seq, 0, 0).expect("failed to read kv");
292/// assert_eq!(k, key.as_slice());
293/// ```
294pub struct PagedKvCache {
295    pool: BlockPool,
296    sequences: HashMap<u64, BlockTable>,
297    next_seq_id: u64,
298}
299
300impl PagedKvCache {
301    /// Create a cache with [`DEFAULT_BLOCK_SIZE`] token slots per page.
302    ///
303    /// # Arguments
304    ///
305    /// * `capacity`    – maximum number of physical KV pages.
306    /// * `num_layers`  – number of transformer layers.
307    /// * `num_kv_heads` – number of KV attention heads.
308    /// * `head_dim`    – per-head feature dimension.
309    pub fn new(capacity: usize, num_layers: usize, num_kv_heads: usize, head_dim: usize) -> Self {
310        Self::new_with_block_size(
311            capacity,
312            DEFAULT_BLOCK_SIZE,
313            num_layers,
314            num_kv_heads,
315            head_dim,
316        )
317    }
318
319    /// Create a cache with a custom `block_size`.
320    ///
321    /// All other parameters are the same as [`PagedKvCache::new`].
322    pub fn new_with_block_size(
323        capacity: usize,
324        block_size: usize,
325        num_layers: usize,
326        num_kv_heads: usize,
327        head_dim: usize,
328    ) -> Self {
329        Self {
330            pool: BlockPool::new(capacity, block_size, num_layers, num_kv_heads, head_dim),
331            sequences: HashMap::new(),
332            next_seq_id: 0,
333        }
334    }
335
336    // ------------------------------------------------------------------
337    // Sequence lifecycle
338    // ------------------------------------------------------------------
339
340    /// Register a new sequence and return its ID.
341    ///
342    /// Sequences start with no allocated pages.  Call [`ensure_capacity`] or
343    /// [`write_kv`] to trigger lazy allocation.
344    ///
345    /// [`ensure_capacity`]: PagedKvCache::ensure_capacity
346    /// [`write_kv`]: PagedKvCache::write_kv
347    pub fn create_sequence(&mut self) -> u64 {
348        let id = self.next_seq_id;
349        self.next_seq_id += 1;
350        let num_layers = self.pool.num_layers;
351        let block_size = self.pool.block_size;
352        self.sequences
353            .insert(id, BlockTable::new(num_layers, block_size));
354        id
355    }
356
357    /// Drop a sequence, returning all its physical pages to the pool.
358    ///
359    /// # Errors
360    ///
361    /// Returns [`PagedKvError::SequenceNotFound`] if `seq_id` is unknown.
362    pub fn drop_sequence(&mut self, seq_id: u64) -> Result<(), PagedKvError> {
363        let table = self
364            .sequences
365            .remove(&seq_id)
366            .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
367
368        for layer_blocks in &table.blocks {
369            for &phys_idx in layer_blocks {
370                self.pool.free(phys_idx);
371            }
372        }
373        Ok(())
374    }
375
376    // ------------------------------------------------------------------
377    // Capacity management
378    // ------------------------------------------------------------------
379
380    /// Ensure the sequence can hold at least `num_tokens` positions in every
381    /// layer, allocating new pages from the pool as needed.
382    ///
383    /// This is the primary lazy-allocation entry-point.  [`write_kv`] calls
384    /// this internally, so explicit calls are optional.
385    ///
386    /// # Errors
387    ///
388    /// * [`PagedKvError::SequenceNotFound`] – unknown `seq_id`.
389    /// * [`PagedKvError::OutOfMemory`] – pool exhausted before the request
390    ///   could be satisfied; partially-allocated pages are **not** rolled back
391    ///   (vLLM behaviour — callers should drop the sequence on OOM).
392    ///
393    /// [`write_kv`]: PagedKvCache::write_kv
394    pub fn ensure_capacity(&mut self, seq_id: u64, num_tokens: usize) -> Result<(), PagedKvError> {
395        // We need to look up immutable fields before taking the mutable borrow.
396        let num_layers = self.pool.num_layers;
397        let block_size = self.pool.block_size;
398
399        // Compute the number of blocks needed to cover `num_tokens`.
400        let blocks_needed = num_tokens.div_ceil(block_size);
401
402        // We must not hold a reference into `self.sequences` while mutating
403        // `self.pool`, so we collect the per-layer deficits first.
404        let deficits: Vec<usize> = {
405            let table = self
406                .sequences
407                .get(&seq_id)
408                .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
409
410            (0..num_layers)
411                .map(|layer| {
412                    let have = table.num_blocks(layer);
413                    blocks_needed.saturating_sub(have)
414                })
415                .collect()
416        };
417
418        // Allocate the required pages and record them in the block table.
419        for (layer, deficit) in deficits.into_iter().enumerate() {
420            for _ in 0..deficit {
421                let phys = self.pool.allocate().ok_or(PagedKvError::OutOfMemory)?;
422                let table = self
423                    .sequences
424                    .get_mut(&seq_id)
425                    .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
426                table.append_block(layer, phys);
427            }
428        }
429        Ok(())
430    }
431
432    // ------------------------------------------------------------------
433    // KV I/O
434    // ------------------------------------------------------------------
435
436    /// Write a key-value pair for token at position `token_pos` in `layer`.
437    ///
438    /// `key` and `value` must each have exactly `num_kv_heads * head_dim`
439    /// elements.  Capacity is extended automatically if the token position
440    /// falls outside the current allocation.
441    ///
442    /// # Errors
443    ///
444    /// * [`PagedKvError::SequenceNotFound`]
445    /// * [`PagedKvError::OutOfMemory`]
446    /// * [`PagedKvError::DimMismatch`] – wrong slice length.
447    pub fn write_kv(
448        &mut self,
449        seq_id: u64,
450        layer: usize,
451        token_pos: usize,
452        key: &[f32],
453        value: &[f32],
454    ) -> Result<(), PagedKvError> {
455        let slot_len = self.pool.slot_len();
456
457        if key.len() != slot_len {
458            return Err(PagedKvError::DimMismatch {
459                expected: slot_len,
460                actual: key.len(),
461            });
462        }
463        if value.len() != slot_len {
464            return Err(PagedKvError::DimMismatch {
465                expected: slot_len,
466                actual: value.len(),
467            });
468        }
469
470        // Ensure we have enough blocks for token_pos + 1.
471        self.ensure_capacity(seq_id, token_pos + 1)?;
472
473        let block_size = self.pool.block_size;
474        let logical_block = token_pos / block_size;
475        let slot_in_block = token_pos % block_size;
476
477        let phys = {
478            let table = self
479                .sequences
480                .get(&seq_id)
481                .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
482            table
483                .get_block(layer, logical_block)
484                .ok_or(PagedKvError::PositionOutOfRange {
485                    seq_id,
486                    pos: token_pos,
487                })?
488        };
489
490        let offset = slot_in_block * slot_len;
491        let page = self.pool.page_mut(phys);
492        page.keys[offset..offset + slot_len].copy_from_slice(key);
493        page.values[offset..offset + slot_len].copy_from_slice(value);
494        Ok(())
495    }
496
497    /// Read the key-value pair for token at position `token_pos` in `layer`.
498    ///
499    /// Returns `(&key_slice, &value_slice)` where each slice has
500    /// `num_kv_heads * head_dim` elements.
501    ///
502    /// # Errors
503    ///
504    /// * [`PagedKvError::SequenceNotFound`]
505    /// * [`PagedKvError::PositionOutOfRange`] – token has not been written yet.
506    pub fn read_kv(
507        &self,
508        seq_id: u64,
509        layer: usize,
510        token_pos: usize,
511    ) -> Result<(&[f32], &[f32]), PagedKvError> {
512        let block_size = self.pool.block_size;
513        let slot_len = self.pool.slot_len();
514
515        let table = self
516            .sequences
517            .get(&seq_id)
518            .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
519
520        let logical_block = token_pos / block_size;
521        let slot_in_block = token_pos % block_size;
522
523        let phys =
524            table
525                .get_block(layer, logical_block)
526                .ok_or(PagedKvError::PositionOutOfRange {
527                    seq_id,
528                    pos: token_pos,
529                })?;
530
531        let offset = slot_in_block * slot_len;
532        let page = self.pool.page(phys);
533        Ok((
534            &page.keys[offset..offset + slot_len],
535            &page.values[offset..offset + slot_len],
536        ))
537    }
538
539    // ------------------------------------------------------------------
540    // Metrics
541    // ------------------------------------------------------------------
542
543    /// Fraction of pool pages currently in use (`[0.0, 1.0]`).
544    pub fn pool_utilization(&self) -> f32 {
545        self.pool.utilization()
546    }
547
548    /// Number of token positions written to `seq_id` (across **all** layers,
549    /// using layer 0 as the canonical length).
550    ///
551    /// Returns `0` for unknown sequences.
552    pub fn sequence_length(&self, seq_id: u64) -> usize {
553        self.sequences
554            .get(&seq_id)
555            .map_or(0, |t| t.token_capacity(0))
556    }
557}