// cjc_runtime/paged_kv.rs
//! Block-paged KV-cache -- vLLM-style block paging for transformer inference.
//!
//! Instead of a single contiguous pre-allocated tensor (which may fragment
//! or require reallocation), this module manages KV-cache memory in
//! fixed-size 16-token blocks via a logical-to-physical block table.
//!
//! # Benefits
//!
//! - No single large allocation -- blocks are page-sized (16 tokens each).
//! - Zero reallocation on append -- all blocks are pre-allocated at construction.
//! - Logical-to-physical mapping via block table enables flexible memory reuse.
//! - Each block is independently cache-line friendly.
//!
//! # NoGC guarantee
//!
//! All blocks are pre-allocated at construction time. [`PagedKvCache::append`]
//! performs zero heap allocations on the success path -- it only copies token
//! data into existing block storage.
20use std::fmt;
21
22use crate::error::RuntimeError;
23use crate::tensor::Tensor;
24
25// ---------------------------------------------------------------------------
26// 2e. BlockPaged KV-Cache — vLLM-style block paging
27// ---------------------------------------------------------------------------
28
/// Number of tokens stored per block in the paged KV-cache.
///
/// 16 tokens per page mirrors vLLM's default block size; with f64 storage a
/// block of dimension `d` occupies `16 * d * 8` bytes.
const BLOCK_TOKEN_COUNT: usize = 16;

/// A single page/block in the KV-cache.
///
/// Pre-allocated at construction with capacity for [`BLOCK_TOKEN_COUNT`]
/// tokens. Data is stored as a flat `Vec<f64>` of shape
/// `[BLOCK_TOKEN_COUNT, dim]`, with a `used` cursor tracking how many
/// token slots have been written.
#[derive(Debug, Clone)]
pub struct KvBlock {
 /// Data storage: [BLOCK_TOKEN_COUNT, dim]. Pre-allocated and zeroed.
 data: Vec<f64>,
 /// Hidden dimension per token.
 dim: usize,
 /// Number of tokens currently written in this block (0..=BLOCK_TOKEN_COUNT).
 /// Invariant maintained by `write_token`, which rejects writes when full.
 used: usize,
}
47
48impl KvBlock {
49 /// Create a new zeroed block for tokens of the given hidden dimension.
50 fn new(dim: usize) -> Self {
51 KvBlock {
52 data: vec![0.0; BLOCK_TOKEN_COUNT * dim],
53 dim,
54 used: 0,
55 }
56 }
57
58 /// Return `true` if all token slots in this block have been written.
59 fn is_full(&self) -> bool {
60 self.used >= BLOCK_TOKEN_COUNT
61 }
62
63 /// Return the number of unused token slots remaining in this block.
64 #[allow(dead_code)]
65 fn remaining(&self) -> usize {
66 BLOCK_TOKEN_COUNT - self.used
67 }
68
69 /// Write a single token vector into the block. Returns error if full.
70 fn write_token(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
71 if token.len() != self.dim {
72 return Err(RuntimeError::ShapeMismatch {
73 expected: self.dim,
74 got: token.len(),
75 });
76 }
77 if self.is_full() {
78 return Err(RuntimeError::InvalidOperation(
79 "KvBlock is full".to_string(),
80 ));
81 }
82 let base = self.used * self.dim;
83 self.data[base..base + self.dim].copy_from_slice(token);
84 self.used += 1;
85 Ok(())
86 }
87
88 /// Read token at position `idx` within this block.
89 fn read_token(&self, idx: usize) -> &[f64] {
90 let base = idx * self.dim;
91 &self.data[base..base + self.dim]
92 }
93}
94
/// A vLLM-style block-paged KV-cache. Instead of one contiguous pre-allocated
/// tensor (which may fragment or require realloc), this manages memory in
/// fixed-size 16-token blocks via a `BlockTable`.
///
/// Benefits:
/// - No single large allocation — blocks are page-sized
/// - Zero reallocation on append (all blocks pre-allocated at construction)
/// - Logical-to-physical mapping via block table
/// - Each block is independently cache-line friendly
#[derive(Debug, Clone)]
pub struct PagedKvCache {
 /// All allocated blocks.
 blocks: Vec<KvBlock>,
 /// Block table: maps logical block indices to physical block indices.
 /// `block_table[i]` = index into `blocks` for the i-th logical block.
 /// Initialized to the identity mapping by `new`.
 block_table: Vec<usize>,
 /// Hidden dimension per token.
 dim: usize,
 /// Maximum total tokens allowed.
 max_tokens: usize,
 /// Total tokens currently stored (0..=max_tokens).
 current_len: usize,
}
118
119impl PagedKvCache {
120 /// Create a paged KV-cache for `max_tokens` tokens of dimension `dim`.
121 ///
122 /// Pre-allocates all blocks upfront to avoid any heap allocation during
123 /// the inference loop. The number of blocks = ceil(max_tokens / 16).
124 pub fn new(max_tokens: usize, dim: usize) -> Self {
125 let num_blocks = (max_tokens + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT;
126 let mut blocks = Vec::with_capacity(num_blocks);
127 let mut block_table = Vec::with_capacity(num_blocks);
128 for i in 0..num_blocks {
129 blocks.push(KvBlock::new(dim));
130 block_table.push(i); // identity mapping initially
131 }
132 PagedKvCache {
133 blocks,
134 block_table,
135 dim,
136 max_tokens,
137 current_len: 0,
138 }
139 }
140
141 /// Number of tokens currently stored.
142 pub fn len(&self) -> usize {
143 self.current_len
144 }
145
146 /// Whether no tokens are stored.
147 pub fn is_empty(&self) -> bool {
148 self.current_len == 0
149 }
150
151 /// Maximum tokens this cache can hold.
152 pub fn max_tokens(&self) -> usize {
153 self.max_tokens
154 }
155
156 /// Hidden dimension per token.
157 pub fn dim(&self) -> usize {
158 self.dim
159 }
160
161 /// Number of blocks allocated.
162 pub fn num_blocks(&self) -> usize {
163 self.blocks.len()
164 }
165
166 /// Number of blocks currently in use (partially or fully).
167 pub fn blocks_in_use(&self) -> usize {
168 if self.current_len == 0 { return 0; }
169 (self.current_len + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT
170 }
171
172 /// Append a single token vector. **Zero allocation** — writes into
173 /// the next available slot in the current block.
174 pub fn append(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
175 if token.len() != self.dim {
176 return Err(RuntimeError::ShapeMismatch {
177 expected: self.dim,
178 got: token.len(),
179 });
180 }
181 if self.current_len >= self.max_tokens {
182 return Err(RuntimeError::InvalidOperation(
183 format!(
184 "PagedKvCache full: {} / {} tokens",
185 self.current_len, self.max_tokens
186 ),
187 ));
188 }
189 let logical_block = self.current_len / BLOCK_TOKEN_COUNT;
190 let physical_block = self.block_table[logical_block];
191 self.blocks[physical_block].write_token(token)?;
192 self.current_len += 1;
193 Ok(())
194 }
195
196 /// Append a batch of tokens from a 2D tensor `[n, dim]`.
197 pub fn append_tensor(&mut self, t: &Tensor) -> Result<(), RuntimeError> {
198 if t.ndim() != 2 || t.shape()[1] != self.dim {
199 return Err(RuntimeError::InvalidOperation(
200 format!(
201 "PagedKvCache.append_tensor: expected [n, {}], got {:?}",
202 self.dim, t.shape()
203 ),
204 ));
205 }
206 let n = t.shape()[0];
207 if self.current_len + n > self.max_tokens {
208 return Err(RuntimeError::InvalidOperation(
209 format!(
210 "PagedKvCache overflow: {} + {} > {}",
211 self.current_len, n, self.max_tokens
212 ),
213 ));
214 }
215 let data = t.to_vec();
216 for i in 0..n {
217 let start = i * self.dim;
218 self.append(&data[start..start + self.dim])?;
219 }
220 Ok(())
221 }
222
223 /// Materialize all stored tokens into a contiguous Tensor `[current_len, dim]`.
224 ///
225 /// This is a read operation that copies data from blocks into a flat
226 /// buffer. The copy is required since blocks are non-contiguous.
227 pub fn as_tensor(&self) -> Tensor {
228 if self.current_len == 0 {
229 return Tensor::from_vec(vec![], &[0, self.dim])
230 .unwrap_or_else(|_| Tensor::zeros(&[0]));
231 }
232 let mut data = Vec::with_capacity(self.current_len * self.dim);
233 let mut remaining = self.current_len;
234 for &phys_idx in &self.block_table {
235 if remaining == 0 { break; }
236 let block = &self.blocks[phys_idx];
237 let tokens_in_block = remaining.min(block.used);
238 for t in 0..tokens_in_block {
239 data.extend_from_slice(block.read_token(t));
240 }
241 remaining -= tokens_in_block;
242 }
243 Tensor::from_vec(data, &[self.current_len, self.dim])
244 .expect("PagedKvCache::as_tensor shape mismatch")
245 }
246
247 /// Reset the cache to empty without deallocating blocks.
248 /// Block data is retained; only cursors are reset.
249 pub fn clear(&mut self) {
250 for block in &mut self.blocks {
251 block.used = 0;
252 }
253 self.current_len = 0;
254 }
255
256 /// Read a single token at logical position `idx`.
257 pub fn get_token(&self, idx: usize) -> Result<Vec<f64>, RuntimeError> {
258 if idx >= self.current_len {
259 return Err(RuntimeError::IndexOutOfBounds {
260 index: idx,
261 length: self.current_len,
262 });
263 }
264 let logical_block = idx / BLOCK_TOKEN_COUNT;
265 let offset_in_block = idx % BLOCK_TOKEN_COUNT;
266 let physical_block = self.block_table[logical_block];
267 Ok(self.blocks[physical_block].read_token(offset_in_block).to_vec())
268 }
269}
270
271impl fmt::Display for PagedKvCache {
272 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273 write!(
274 f,
275 "PagedKvCache(len={}, max={}, dim={}, blocks={}/{})",
276 self.current_len,
277 self.max_tokens,
278 self.dim,
279 self.blocks_in_use(),
280 self.blocks.len()
281 )
282 }
283}
284
285impl fmt::Display for Tensor {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 write!(f, "Tensor(shape={:?}, data={:?})", self.shape, self.to_vec())
288 }
289}
290