//! cjc_runtime/paged_kv.rs — block-paged (vLLM-style) KV-cache.

use std::fmt;

use crate::error::RuntimeError;
use crate::tensor::Tensor;
5
// ---------------------------------------------------------------------------
// 2e. BlockPaged KV-Cache — vLLM-style block paging
// ---------------------------------------------------------------------------

/// Fixed-size block for the paged KV-cache. Each block holds up to
/// `BLOCK_TOKEN_COUNT` tokens of a fixed hidden dimension. Keeping this a
/// compile-time constant makes logical→physical addressing a simple
/// divide/modulo and keeps every page the same size.
const BLOCK_TOKEN_COUNT: usize = 16;
13
/// A single page/block in the KV-cache. Pre-allocated, fixed-size.
///
/// Invariant maintained by `KvBlock`'s methods: `used <= BLOCK_TOKEN_COUNT`
/// and `data.len() == BLOCK_TOKEN_COUNT * dim`.
#[derive(Debug, Clone)]
pub struct KvBlock {
    /// Data storage: row-major `[BLOCK_TOKEN_COUNT, dim]`. Pre-allocated and
    /// zeroed at construction; slots past `used` still hold zeros (or stale
    /// data after a cache `clear`).
    data: Vec<f64>,
    /// Hidden dimension per token (length of each row in `data`).
    dim: usize,
    /// Number of tokens currently written in this block (0..=BLOCK_TOKEN_COUNT).
    used: usize,
}
24
25impl KvBlock {
26    fn new(dim: usize) -> Self {
27        KvBlock {
28            data: vec![0.0; BLOCK_TOKEN_COUNT * dim],
29            dim,
30            used: 0,
31        }
32    }
33
34    fn is_full(&self) -> bool {
35        self.used >= BLOCK_TOKEN_COUNT
36    }
37
38    #[allow(dead_code)]
39    fn remaining(&self) -> usize {
40        BLOCK_TOKEN_COUNT - self.used
41    }
42
43    /// Write a single token vector into the block. Returns error if full.
44    fn write_token(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
45        if token.len() != self.dim {
46            return Err(RuntimeError::ShapeMismatch {
47                expected: self.dim,
48                got: token.len(),
49            });
50        }
51        if self.is_full() {
52            return Err(RuntimeError::InvalidOperation(
53                "KvBlock is full".to_string(),
54            ));
55        }
56        let base = self.used * self.dim;
57        self.data[base..base + self.dim].copy_from_slice(token);
58        self.used += 1;
59        Ok(())
60    }
61
62    /// Read token at position `idx` within this block.
63    fn read_token(&self, idx: usize) -> &[f64] {
64        let base = idx * self.dim;
65        &self.data[base..base + self.dim]
66    }
67}
68
/// A vLLM-style block-paged KV-cache. Instead of one contiguous pre-allocated
/// tensor (which may fragment or require realloc), this manages memory in
/// fixed-size 16-token blocks via a `BlockTable`.
///
/// Benefits:
/// - No single large allocation — blocks are page-sized
/// - Zero reallocation on append (new blocks allocated on demand from pool)
/// - Logical-to-physical mapping via block table
/// - Each block is independently cache-line friendly
///
/// Invariant: `current_len <= max_tokens`, and the first
/// `ceil(current_len / BLOCK_TOKEN_COUNT)` entries of `block_table` address
/// the blocks holding live tokens, in logical order.
#[derive(Debug, Clone)]
pub struct PagedKvCache {
    /// All allocated blocks (physical storage; allocated once in `new`).
    blocks: Vec<KvBlock>,
    /// Block table: maps logical block indices to physical block indices.
    /// `block_table[i]` = index into `blocks` for the i-th logical block.
    /// Identity-mapped at construction; kept as a table so blocks could be
    /// remapped later without moving data.
    block_table: Vec<usize>,
    /// Hidden dimension per token.
    dim: usize,
    /// Maximum total tokens allowed (fixed at construction).
    max_tokens: usize,
    /// Total tokens currently stored (the append cursor).
    current_len: usize,
}
92
93impl PagedKvCache {
94    /// Create a paged KV-cache for `max_tokens` tokens of dimension `dim`.
95    ///
96    /// Pre-allocates all blocks upfront to avoid any heap allocation during
97    /// the inference loop. The number of blocks = ceil(max_tokens / 16).
98    pub fn new(max_tokens: usize, dim: usize) -> Self {
99        let num_blocks = (max_tokens + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT;
100        let mut blocks = Vec::with_capacity(num_blocks);
101        let mut block_table = Vec::with_capacity(num_blocks);
102        for i in 0..num_blocks {
103            blocks.push(KvBlock::new(dim));
104            block_table.push(i); // identity mapping initially
105        }
106        PagedKvCache {
107            blocks,
108            block_table,
109            dim,
110            max_tokens,
111            current_len: 0,
112        }
113    }
114
115    /// Number of tokens currently stored.
116    pub fn len(&self) -> usize {
117        self.current_len
118    }
119
120    /// Whether no tokens are stored.
121    pub fn is_empty(&self) -> bool {
122        self.current_len == 0
123    }
124
125    /// Maximum tokens this cache can hold.
126    pub fn max_tokens(&self) -> usize {
127        self.max_tokens
128    }
129
130    /// Hidden dimension per token.
131    pub fn dim(&self) -> usize {
132        self.dim
133    }
134
135    /// Number of blocks allocated.
136    pub fn num_blocks(&self) -> usize {
137        self.blocks.len()
138    }
139
140    /// Number of blocks currently in use (partially or fully).
141    pub fn blocks_in_use(&self) -> usize {
142        if self.current_len == 0 { return 0; }
143        (self.current_len + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT
144    }
145
146    /// Append a single token vector. **Zero allocation** — writes into
147    /// the next available slot in the current block.
148    pub fn append(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
149        if token.len() != self.dim {
150            return Err(RuntimeError::ShapeMismatch {
151                expected: self.dim,
152                got: token.len(),
153            });
154        }
155        if self.current_len >= self.max_tokens {
156            return Err(RuntimeError::InvalidOperation(
157                format!(
158                    "PagedKvCache full: {} / {} tokens",
159                    self.current_len, self.max_tokens
160                ),
161            ));
162        }
163        let logical_block = self.current_len / BLOCK_TOKEN_COUNT;
164        let physical_block = self.block_table[logical_block];
165        self.blocks[physical_block].write_token(token)?;
166        self.current_len += 1;
167        Ok(())
168    }
169
170    /// Append a batch of tokens from a 2D tensor `[n, dim]`.
171    pub fn append_tensor(&mut self, t: &Tensor) -> Result<(), RuntimeError> {
172        if t.ndim() != 2 || t.shape()[1] != self.dim {
173            return Err(RuntimeError::InvalidOperation(
174                format!(
175                    "PagedKvCache.append_tensor: expected [n, {}], got {:?}",
176                    self.dim, t.shape()
177                ),
178            ));
179        }
180        let n = t.shape()[0];
181        if self.current_len + n > self.max_tokens {
182            return Err(RuntimeError::InvalidOperation(
183                format!(
184                    "PagedKvCache overflow: {} + {} > {}",
185                    self.current_len, n, self.max_tokens
186                ),
187            ));
188        }
189        let data = t.to_vec();
190        for i in 0..n {
191            let start = i * self.dim;
192            self.append(&data[start..start + self.dim])?;
193        }
194        Ok(())
195    }
196
197    /// Materialize all stored tokens into a contiguous Tensor `[current_len, dim]`.
198    ///
199    /// This is a read operation that copies data from blocks into a flat
200    /// buffer. The copy is required since blocks are non-contiguous.
201    pub fn as_tensor(&self) -> Tensor {
202        if self.current_len == 0 {
203            return Tensor::from_vec(vec![], &[0, self.dim])
204                .unwrap_or_else(|_| Tensor::zeros(&[0]));
205        }
206        let mut data = Vec::with_capacity(self.current_len * self.dim);
207        let mut remaining = self.current_len;
208        for &phys_idx in &self.block_table {
209            if remaining == 0 { break; }
210            let block = &self.blocks[phys_idx];
211            let tokens_in_block = remaining.min(block.used);
212            for t in 0..tokens_in_block {
213                data.extend_from_slice(block.read_token(t));
214            }
215            remaining -= tokens_in_block;
216        }
217        Tensor::from_vec(data, &[self.current_len, self.dim])
218            .expect("PagedKvCache::as_tensor shape mismatch")
219    }
220
221    /// Reset the cache to empty without deallocating blocks.
222    /// Block data is retained; only cursors are reset.
223    pub fn clear(&mut self) {
224        for block in &mut self.blocks {
225            block.used = 0;
226        }
227        self.current_len = 0;
228    }
229
230    /// Read a single token at logical position `idx`.
231    pub fn get_token(&self, idx: usize) -> Result<Vec<f64>, RuntimeError> {
232        if idx >= self.current_len {
233            return Err(RuntimeError::IndexOutOfBounds {
234                index: idx,
235                length: self.current_len,
236            });
237        }
238        let logical_block = idx / BLOCK_TOKEN_COUNT;
239        let offset_in_block = idx % BLOCK_TOKEN_COUNT;
240        let physical_block = self.block_table[logical_block];
241        Ok(self.blocks[physical_block].read_token(offset_in_block).to_vec())
242    }
243}
244
245impl fmt::Display for PagedKvCache {
246    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247        write!(
248            f,
249            "PagedKvCache(len={}, max={}, dim={}, blocks={}/{})",
250            self.current_len,
251            self.max_tokens,
252            self.dim,
253            self.blocks_in_use(),
254            self.blocks.len()
255        )
256    }
257}
258
259impl fmt::Display for Tensor {
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        write!(f, "Tensor(shape={:?}, data={:?})", self.shape, self.to_vec())
262    }
263}
264