1use std::fmt;
2
3use crate::error::RuntimeError;
4use crate::tensor::Tensor;
5
/// Number of token slots in one physical KV block (the "page size" of the cache).
const BLOCK_TOKEN_COUNT: usize = 16;
13
/// One fixed-size physical page of the KV cache.
///
/// Stores up to `BLOCK_TOKEN_COUNT` token vectors of width `dim` in a single
/// flat buffer.
#[derive(Debug, Clone)]
pub struct KvBlock {
    /// Flat storage: `BLOCK_TOKEN_COUNT` rows of `dim` values, row-major.
    data: Vec<f64>,
    /// Width of each token vector (row width of `data`).
    dim: usize,
    /// Number of token slots written so far (0..=BLOCK_TOKEN_COUNT).
    used: usize,
}
24
25impl KvBlock {
26 fn new(dim: usize) -> Self {
27 KvBlock {
28 data: vec![0.0; BLOCK_TOKEN_COUNT * dim],
29 dim,
30 used: 0,
31 }
32 }
33
34 fn is_full(&self) -> bool {
35 self.used >= BLOCK_TOKEN_COUNT
36 }
37
38 #[allow(dead_code)]
39 fn remaining(&self) -> usize {
40 BLOCK_TOKEN_COUNT - self.used
41 }
42
43 fn write_token(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
45 if token.len() != self.dim {
46 return Err(RuntimeError::ShapeMismatch {
47 expected: self.dim,
48 got: token.len(),
49 });
50 }
51 if self.is_full() {
52 return Err(RuntimeError::InvalidOperation(
53 "KvBlock is full".to_string(),
54 ));
55 }
56 let base = self.used * self.dim;
57 self.data[base..base + self.dim].copy_from_slice(token);
58 self.used += 1;
59 Ok(())
60 }
61
62 fn read_token(&self, idx: usize) -> &[f64] {
64 let base = idx * self.dim;
65 &self.data[base..base + self.dim]
66 }
67}
68
/// A paged key/value cache: token vectors are stored in fixed-size
/// `KvBlock` pages, addressed through an indirection table.
#[derive(Debug, Clone)]
pub struct PagedKvCache {
    /// Physical block pool; fully allocated up-front by `new`.
    blocks: Vec<KvBlock>,
    /// Maps logical block index -> physical index into `blocks`.
    /// Initialized to the identity mapping in `new`; presumably the
    /// indirection exists to allow future block remapping — confirm.
    block_table: Vec<usize>,
    /// Width of every cached token vector.
    dim: usize,
    /// Maximum number of tokens this cache can hold.
    max_tokens: usize,
    /// Number of tokens currently stored.
    current_len: usize,
}
92
93impl PagedKvCache {
94 pub fn new(max_tokens: usize, dim: usize) -> Self {
99 let num_blocks = (max_tokens + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT;
100 let mut blocks = Vec::with_capacity(num_blocks);
101 let mut block_table = Vec::with_capacity(num_blocks);
102 for i in 0..num_blocks {
103 blocks.push(KvBlock::new(dim));
104 block_table.push(i); }
106 PagedKvCache {
107 blocks,
108 block_table,
109 dim,
110 max_tokens,
111 current_len: 0,
112 }
113 }
114
115 pub fn len(&self) -> usize {
117 self.current_len
118 }
119
120 pub fn is_empty(&self) -> bool {
122 self.current_len == 0
123 }
124
125 pub fn max_tokens(&self) -> usize {
127 self.max_tokens
128 }
129
130 pub fn dim(&self) -> usize {
132 self.dim
133 }
134
135 pub fn num_blocks(&self) -> usize {
137 self.blocks.len()
138 }
139
140 pub fn blocks_in_use(&self) -> usize {
142 if self.current_len == 0 { return 0; }
143 (self.current_len + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT
144 }
145
146 pub fn append(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
149 if token.len() != self.dim {
150 return Err(RuntimeError::ShapeMismatch {
151 expected: self.dim,
152 got: token.len(),
153 });
154 }
155 if self.current_len >= self.max_tokens {
156 return Err(RuntimeError::InvalidOperation(
157 format!(
158 "PagedKvCache full: {} / {} tokens",
159 self.current_len, self.max_tokens
160 ),
161 ));
162 }
163 let logical_block = self.current_len / BLOCK_TOKEN_COUNT;
164 let physical_block = self.block_table[logical_block];
165 self.blocks[physical_block].write_token(token)?;
166 self.current_len += 1;
167 Ok(())
168 }
169
170 pub fn append_tensor(&mut self, t: &Tensor) -> Result<(), RuntimeError> {
172 if t.ndim() != 2 || t.shape()[1] != self.dim {
173 return Err(RuntimeError::InvalidOperation(
174 format!(
175 "PagedKvCache.append_tensor: expected [n, {}], got {:?}",
176 self.dim, t.shape()
177 ),
178 ));
179 }
180 let n = t.shape()[0];
181 if self.current_len + n > self.max_tokens {
182 return Err(RuntimeError::InvalidOperation(
183 format!(
184 "PagedKvCache overflow: {} + {} > {}",
185 self.current_len, n, self.max_tokens
186 ),
187 ));
188 }
189 let data = t.to_vec();
190 for i in 0..n {
191 let start = i * self.dim;
192 self.append(&data[start..start + self.dim])?;
193 }
194 Ok(())
195 }
196
197 pub fn as_tensor(&self) -> Tensor {
202 if self.current_len == 0 {
203 return Tensor::from_vec(vec![], &[0, self.dim])
204 .unwrap_or_else(|_| Tensor::zeros(&[0]));
205 }
206 let mut data = Vec::with_capacity(self.current_len * self.dim);
207 let mut remaining = self.current_len;
208 for &phys_idx in &self.block_table {
209 if remaining == 0 { break; }
210 let block = &self.blocks[phys_idx];
211 let tokens_in_block = remaining.min(block.used);
212 for t in 0..tokens_in_block {
213 data.extend_from_slice(block.read_token(t));
214 }
215 remaining -= tokens_in_block;
216 }
217 Tensor::from_vec(data, &[self.current_len, self.dim])
218 .expect("PagedKvCache::as_tensor shape mismatch")
219 }
220
221 pub fn clear(&mut self) {
224 for block in &mut self.blocks {
225 block.used = 0;
226 }
227 self.current_len = 0;
228 }
229
230 pub fn get_token(&self, idx: usize) -> Result<Vec<f64>, RuntimeError> {
232 if idx >= self.current_len {
233 return Err(RuntimeError::IndexOutOfBounds {
234 index: idx,
235 length: self.current_len,
236 });
237 }
238 let logical_block = idx / BLOCK_TOKEN_COUNT;
239 let offset_in_block = idx % BLOCK_TOKEN_COUNT;
240 let physical_block = self.block_table[logical_block];
241 Ok(self.blocks[physical_block].read_token(offset_in_block).to_vec())
242 }
243}
244
245impl fmt::Display for PagedKvCache {
246 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247 write!(
248 f,
249 "PagedKvCache(len={}, max={}, dim={}, blocks={}/{})",
250 self.current_len,
251 self.max_tokens,
252 self.dim,
253 self.blocks_in_use(),
254 self.blocks.len()
255 )
256 }
257}
258
259impl fmt::Display for Tensor {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 write!(f, "Tensor(shape={:?}, data={:?})", self.shape, self.to_vec())
262 }
263}
264