// cjc_runtime/paged_kv.rs
//! Block-paged KV-cache -- vLLM-style block paging for transformer inference.
//!
//! Instead of a single contiguous pre-allocated tensor (which may fragment
//! or require reallocation), this module manages KV-cache memory in
//! fixed-size 16-token blocks via a logical-to-physical block table.
//!
//! # Benefits
//!
//! - No single large allocation -- blocks are page-sized (16 tokens each).
//! - Zero reallocation on append -- all blocks are pre-allocated at construction.
//! - Logical-to-physical mapping via block table enables flexible memory reuse.
//! - Each block is independently cache-line friendly.
//!
//! # NoGC guarantee
//!
//! All blocks are pre-allocated at construction time. [`PagedKvCache::append`]
//! performs zero heap allocations -- it only copies token data into existing
//! block storage.

use std::fmt;

use crate::error::RuntimeError;
use crate::tensor::Tensor;

// ---------------------------------------------------------------------------
// 2e. BlockPaged KV-Cache — vLLM-style block paging
// ---------------------------------------------------------------------------

/// Number of tokens stored per block in the paged KV-cache.
/// All block-index / offset arithmetic below divides and mods by this value.
const BLOCK_TOKEN_COUNT: usize = 16;

/// A single page/block in the KV-cache.
///
/// Pre-allocated at construction with capacity for [`BLOCK_TOKEN_COUNT`]
/// tokens. Data is stored as a flat `Vec<f64>` of logical shape
/// `[BLOCK_TOKEN_COUNT, dim]` (row-major: token slot `i` occupies
/// `data[i*dim .. (i+1)*dim]`), with a `used` cursor tracking how many
/// token slots have been written.
#[derive(Debug, Clone)]
pub struct KvBlock {
    /// Data storage: [BLOCK_TOKEN_COUNT, dim]. Pre-allocated and zeroed.
    data: Vec<f64>,
    /// Hidden dimension per token.
    dim: usize,
    /// Number of tokens currently written in this block (0..=BLOCK_TOKEN_COUNT).
    /// Slots at index >= `used` still hold zeros (or stale data after `clear`).
    used: usize,
}

48impl KvBlock {
49    /// Create a new zeroed block for tokens of the given hidden dimension.
50    fn new(dim: usize) -> Self {
51        KvBlock {
52            data: vec![0.0; BLOCK_TOKEN_COUNT * dim],
53            dim,
54            used: 0,
55        }
56    }
57
58    /// Return `true` if all token slots in this block have been written.
59    fn is_full(&self) -> bool {
60        self.used >= BLOCK_TOKEN_COUNT
61    }
62
63    /// Return the number of unused token slots remaining in this block.
64    #[allow(dead_code)]
65    fn remaining(&self) -> usize {
66        BLOCK_TOKEN_COUNT - self.used
67    }
68
69    /// Write a single token vector into the block. Returns error if full.
70    fn write_token(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
71        if token.len() != self.dim {
72            return Err(RuntimeError::ShapeMismatch {
73                expected: self.dim,
74                got: token.len(),
75            });
76        }
77        if self.is_full() {
78            return Err(RuntimeError::InvalidOperation(
79                "KvBlock is full".to_string(),
80            ));
81        }
82        let base = self.used * self.dim;
83        self.data[base..base + self.dim].copy_from_slice(token);
84        self.used += 1;
85        Ok(())
86    }
87
88    /// Read token at position `idx` within this block.
89    fn read_token(&self, idx: usize) -> &[f64] {
90        let base = idx * self.dim;
91        &self.data[base..base + self.dim]
92    }
93}
94
/// A vLLM-style block-paged KV-cache. Instead of one contiguous pre-allocated
/// tensor (which may fragment or require realloc), this manages memory in
/// fixed-size 16-token blocks via a block table.
///
/// Benefits:
/// - No single large allocation — blocks are page-sized
/// - Zero reallocation on append — every block is pre-allocated in [`PagedKvCache::new`]
/// - Logical-to-physical mapping via block table
/// - Each block is independently cache-line friendly
#[derive(Debug, Clone)]
pub struct PagedKvCache {
    /// All allocated blocks. Fully populated at construction; never grows.
    blocks: Vec<KvBlock>,
    /// Block table: maps logical block indices to physical block indices.
    /// `block_table[i]` = index into `blocks` for the i-th logical block.
    /// Starts as the identity mapping; the indirection exists so blocks can
    /// later be remapped/reused without moving token data.
    block_table: Vec<usize>,
    /// Hidden dimension per token.
    dim: usize,
    /// Maximum total tokens allowed.
    max_tokens: usize,
    /// Total tokens currently stored (across all blocks).
    current_len: usize,
}

119impl PagedKvCache {
120    /// Create a paged KV-cache for `max_tokens` tokens of dimension `dim`.
121    ///
122    /// Pre-allocates all blocks upfront to avoid any heap allocation during
123    /// the inference loop. The number of blocks = ceil(max_tokens / 16).
124    pub fn new(max_tokens: usize, dim: usize) -> Self {
125        let num_blocks = (max_tokens + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT;
126        let mut blocks = Vec::with_capacity(num_blocks);
127        let mut block_table = Vec::with_capacity(num_blocks);
128        for i in 0..num_blocks {
129            blocks.push(KvBlock::new(dim));
130            block_table.push(i); // identity mapping initially
131        }
132        PagedKvCache {
133            blocks,
134            block_table,
135            dim,
136            max_tokens,
137            current_len: 0,
138        }
139    }
140
141    /// Number of tokens currently stored.
142    pub fn len(&self) -> usize {
143        self.current_len
144    }
145
146    /// Whether no tokens are stored.
147    pub fn is_empty(&self) -> bool {
148        self.current_len == 0
149    }
150
151    /// Maximum tokens this cache can hold.
152    pub fn max_tokens(&self) -> usize {
153        self.max_tokens
154    }
155
156    /// Hidden dimension per token.
157    pub fn dim(&self) -> usize {
158        self.dim
159    }
160
161    /// Number of blocks allocated.
162    pub fn num_blocks(&self) -> usize {
163        self.blocks.len()
164    }
165
166    /// Number of blocks currently in use (partially or fully).
167    pub fn blocks_in_use(&self) -> usize {
168        if self.current_len == 0 { return 0; }
169        (self.current_len + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT
170    }
171
172    /// Append a single token vector. **Zero allocation** — writes into
173    /// the next available slot in the current block.
174    pub fn append(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
175        if token.len() != self.dim {
176            return Err(RuntimeError::ShapeMismatch {
177                expected: self.dim,
178                got: token.len(),
179            });
180        }
181        if self.current_len >= self.max_tokens {
182            return Err(RuntimeError::InvalidOperation(
183                format!(
184                    "PagedKvCache full: {} / {} tokens",
185                    self.current_len, self.max_tokens
186                ),
187            ));
188        }
189        let logical_block = self.current_len / BLOCK_TOKEN_COUNT;
190        let physical_block = self.block_table[logical_block];
191        self.blocks[physical_block].write_token(token)?;
192        self.current_len += 1;
193        Ok(())
194    }
195
196    /// Append a batch of tokens from a 2D tensor `[n, dim]`.
197    pub fn append_tensor(&mut self, t: &Tensor) -> Result<(), RuntimeError> {
198        if t.ndim() != 2 || t.shape()[1] != self.dim {
199            return Err(RuntimeError::InvalidOperation(
200                format!(
201                    "PagedKvCache.append_tensor: expected [n, {}], got {:?}",
202                    self.dim, t.shape()
203                ),
204            ));
205        }
206        let n = t.shape()[0];
207        if self.current_len + n > self.max_tokens {
208            return Err(RuntimeError::InvalidOperation(
209                format!(
210                    "PagedKvCache overflow: {} + {} > {}",
211                    self.current_len, n, self.max_tokens
212                ),
213            ));
214        }
215        let data = t.to_vec();
216        for i in 0..n {
217            let start = i * self.dim;
218            self.append(&data[start..start + self.dim])?;
219        }
220        Ok(())
221    }
222
223    /// Materialize all stored tokens into a contiguous Tensor `[current_len, dim]`.
224    ///
225    /// This is a read operation that copies data from blocks into a flat
226    /// buffer. The copy is required since blocks are non-contiguous.
227    pub fn as_tensor(&self) -> Tensor {
228        if self.current_len == 0 {
229            return Tensor::from_vec(vec![], &[0, self.dim])
230                .unwrap_or_else(|_| Tensor::zeros(&[0]));
231        }
232        let mut data = Vec::with_capacity(self.current_len * self.dim);
233        let mut remaining = self.current_len;
234        for &phys_idx in &self.block_table {
235            if remaining == 0 { break; }
236            let block = &self.blocks[phys_idx];
237            let tokens_in_block = remaining.min(block.used);
238            for t in 0..tokens_in_block {
239                data.extend_from_slice(block.read_token(t));
240            }
241            remaining -= tokens_in_block;
242        }
243        Tensor::from_vec(data, &[self.current_len, self.dim])
244            .expect("PagedKvCache::as_tensor shape mismatch")
245    }
246
247    /// Reset the cache to empty without deallocating blocks.
248    /// Block data is retained; only cursors are reset.
249    pub fn clear(&mut self) {
250        for block in &mut self.blocks {
251            block.used = 0;
252        }
253        self.current_len = 0;
254    }
255
256    /// Read a single token at logical position `idx`.
257    pub fn get_token(&self, idx: usize) -> Result<Vec<f64>, RuntimeError> {
258        if idx >= self.current_len {
259            return Err(RuntimeError::IndexOutOfBounds {
260                index: idx,
261                length: self.current_len,
262            });
263        }
264        let logical_block = idx / BLOCK_TOKEN_COUNT;
265        let offset_in_block = idx % BLOCK_TOKEN_COUNT;
266        let physical_block = self.block_table[logical_block];
267        Ok(self.blocks[physical_block].read_token(offset_in_block).to_vec())
268    }
269}
270
271impl fmt::Display for PagedKvCache {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        write!(
274            f,
275            "PagedKvCache(len={}, max={}, dim={}, blocks={}/{})",
276            self.current_len,
277            self.max_tokens,
278            self.dim,
279            self.blocks_in_use(),
280            self.blocks.len()
281        )
282    }
283}
284
285impl fmt::Display for Tensor {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        write!(f, "Tensor(shape={:?}, data={:?})", self.shape, self.to_vec())
288    }
289}
290