1use std::collections::{HashMap, VecDeque};
4use std::fmt;
5
6use super::types::{
7 BlockId, CacheStats, EvictionStrategy, PagedKvError, PagedKvResult, SeqId, SequenceInfo,
8};
9
/// Block-based (paged) bookkeeping for a key/value cache.
///
/// Token capacity is divided into a fixed pool of `num_blocks` blocks of
/// `block_size` tokens each. Every live sequence owns a list of block ids.
/// Blocks may be shared between sequences (via `fork`) and are reference
/// counted; a block returns to the free list once its count drops to zero.
/// This type tracks ids and counts only; it holds no key/value tensor data.
#[derive(Debug)]
pub struct PagedKvCache {
    /// Number of token slots per block.
    block_size: usize,
    /// Attention head count; only used for memory accounting.
    num_heads: usize,
    /// Per-head dimension; only used for memory accounting.
    head_dim: usize,
    /// Total number of blocks in the pool (allocated or free).
    num_blocks: usize,
    /// Unallocated block ids: allocation pops from the front, frees push to the back.
    free_blocks: VecDeque<BlockId>,
    /// Live sequences keyed by their id.
    sequences: HashMap<SeqId, SequenceInfo>,
    /// Reference count for every currently allocated block; entries are
    /// removed when the count reaches zero.
    block_refs: HashMap<BlockId, u32>,
    /// Policy used to choose which sequence to evict.
    eviction_strategy: EvictionStrategy,
    /// Utilization fraction (clamped to `[0.0, 1.0]` by the builder) at or
    /// above which eviction is reported as needed.
    eviction_threshold: f64,
    /// Cumulative allocation / free / eviction / fork counters.
    stats: CacheStats,
}
38
39impl PagedKvCache {
40 pub fn new(num_blocks: usize, block_size: usize, num_heads: usize, head_dim: usize) -> Self {
48 let free_blocks: VecDeque<BlockId> = (0..num_blocks as u32).map(BlockId).collect();
50
51 Self {
52 block_size,
53 num_heads,
54 head_dim,
55 num_blocks,
56 free_blocks,
57 sequences: HashMap::new(),
58 block_refs: HashMap::new(),
59 eviction_strategy: EvictionStrategy::default(),
60 eviction_threshold: 0.9,
61 stats: CacheStats::default(),
62 }
63 }
64
65 pub fn with_eviction_strategy(mut self, strategy: EvictionStrategy) -> Self {
67 self.eviction_strategy = strategy;
68 self
69 }
70
71 pub fn with_eviction_threshold(mut self, threshold: f64) -> Self {
73 self.eviction_threshold = threshold.clamp(0.0, 1.0);
74 self
75 }
76
77 pub fn block_size(&self) -> usize {
79 self.block_size
80 }
81
82 pub fn total_blocks(&self) -> usize {
84 self.num_blocks
85 }
86
87 pub fn free_block_count(&self) -> usize {
89 self.free_blocks.len()
90 }
91
92 pub fn used_block_count(&self) -> usize {
94 self.num_blocks - self.free_blocks.len()
95 }
96
97 pub fn utilization(&self) -> f64 {
99 if self.num_blocks == 0 {
100 return 0.0;
101 }
102 self.used_block_count() as f64 / self.num_blocks as f64
103 }
104
105 pub fn block_memory_bytes(&self) -> usize {
107 2 * self.block_size * self.num_heads * self.head_dim * 2
109 }
110
111 pub fn total_memory_bytes(&self) -> usize {
113 self.num_blocks * self.block_memory_bytes()
114 }
115
116 pub fn used_memory_bytes(&self) -> usize {
118 self.used_block_count() * self.block_memory_bytes()
119 }
120
121 pub fn needs_eviction(&self) -> bool {
123 self.utilization() >= self.eviction_threshold
124 }
125
126 pub fn num_sequences(&self) -> usize {
128 self.sequences.len()
129 }
130
131 pub fn get_sequence(&self, seq_id: SeqId) -> Option<&SequenceInfo> {
133 self.sequences.get(&seq_id)
134 }
135
136 pub fn stats(&self) -> &CacheStats {
138 &self.stats
139 }
140
141 pub fn eviction_strategy(&self) -> &EvictionStrategy {
143 &self.eviction_strategy
144 }
145
146 fn blocks_needed(&self, num_tokens: usize) -> usize {
148 num_tokens.div_ceil(self.block_size)
149 }
150
151 fn allocate_block(&mut self) -> PagedKvResult<BlockId> {
153 if let Some(block_id) = self.free_blocks.pop_front() {
154 self.block_refs.insert(block_id, 1);
155 self.stats.total_allocations += 1;
156
157 let used = self.used_block_count();
159 if used > self.stats.peak_blocks_used {
160 self.stats.peak_blocks_used = used;
161 }
162
163 Ok(block_id)
164 } else {
165 Err(PagedKvError::OutOfMemory {
166 requested: 1,
167 available: 0,
168 })
169 }
170 }
171
172 fn free_block(&mut self, block_id: BlockId) -> PagedKvResult<()> {
174 if let Some(refs) = self.block_refs.get_mut(&block_id) {
175 *refs -= 1;
176 if *refs == 0 {
177 self.block_refs.remove(&block_id);
178 self.free_blocks.push_back(block_id);
179 self.stats.total_frees += 1;
180 }
181 Ok(())
182 } else {
183 Err(PagedKvError::BlockNotFound(block_id))
184 }
185 }
186
187 pub fn allocate(&mut self, seq_id: SeqId, num_tokens: usize) -> PagedKvResult<()> {
189 if self.sequences.contains_key(&seq_id) {
190 return Err(PagedKvError::InvalidOperation(format!(
191 "Sequence {} already exists",
192 seq_id
193 )));
194 }
195
196 let blocks_needed = self.blocks_needed(num_tokens);
197
198 if blocks_needed > self.free_blocks.len() {
200 return Err(PagedKvError::OutOfMemory {
201 requested: blocks_needed,
202 available: self.free_blocks.len(),
203 });
204 }
205
206 let mut block_ids = Vec::with_capacity(blocks_needed);
208 for _ in 0..blocks_needed {
209 block_ids.push(self.allocate_block()?);
210 }
211
212 let mut seq_info = SequenceInfo::new(seq_id);
214 seq_info.num_tokens = num_tokens;
215 seq_info.block_ids = block_ids;
216 seq_info.touch();
217
218 self.sequences.insert(seq_id, seq_info);
219 Ok(())
220 }
221
222 pub fn append(&mut self, seq_id: SeqId, num_new_tokens: usize) -> PagedKvResult<()> {
224 let (old_tokens, additional_blocks) = {
226 let seq_info = self
227 .sequences
228 .get(&seq_id)
229 .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
230
231 let old_tokens = seq_info.num_tokens;
232 let new_tokens = old_tokens + num_new_tokens;
233 let old_blocks = self.blocks_needed(old_tokens);
234 let new_blocks = self.blocks_needed(new_tokens);
235 let additional = new_blocks.saturating_sub(old_blocks);
236
237 (old_tokens, additional)
238 };
239
240 if additional_blocks > self.free_blocks.len() {
242 return Err(PagedKvError::OutOfMemory {
243 requested: additional_blocks,
244 available: self.free_blocks.len(),
245 });
246 }
247
248 let mut new_block_ids = Vec::with_capacity(additional_blocks);
250 for _ in 0..additional_blocks {
251 new_block_ids.push(self.allocate_block()?);
252 }
253
254 let seq_info = self
256 .sequences
257 .get_mut(&seq_id)
258 .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
259
260 seq_info.block_ids.extend(new_block_ids);
261 seq_info.num_tokens = old_tokens + num_new_tokens;
262 seq_info.touch();
263 Ok(())
264 }
265
266 pub fn free(&mut self, seq_id: SeqId) -> PagedKvResult<()> {
268 let seq_info = self
269 .sequences
270 .remove(&seq_id)
271 .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
272
273 for block_id in seq_info.block_ids {
274 self.free_block(block_id)?;
275 }
276
277 Ok(())
278 }
279
280 pub fn fork(&mut self, src_seq: SeqId, dst_seq: SeqId) -> PagedKvResult<()> {
285 if self.sequences.contains_key(&dst_seq) {
286 return Err(PagedKvError::InvalidOperation(format!(
287 "Destination sequence {} already exists",
288 dst_seq
289 )));
290 }
291
292 let src_info = self
293 .sequences
294 .get(&src_seq)
295 .ok_or(PagedKvError::SequenceNotFound(src_seq))?
296 .clone();
297
298 for block_id in &src_info.block_ids {
300 if let Some(refs) = self.block_refs.get_mut(block_id) {
301 *refs += 1;
302 }
303 }
304
305 let mut dst_info = SequenceInfo::new(dst_seq);
307 dst_info.num_tokens = src_info.num_tokens;
308 dst_info.block_ids = src_info.block_ids.clone();
309 dst_info.touch();
310
311 self.sequences.insert(dst_seq, dst_info);
312 self.stats.total_forks += 1;
313 Ok(())
314 }
315
316 pub fn select_eviction_target(&self) -> Option<SeqId> {
318 if self.sequences.is_empty() {
319 return None;
320 }
321
322 match &self.eviction_strategy {
323 EvictionStrategy::LRU => {
324 self.sequences
326 .values()
327 .min_by_key(|s| s.last_access)
328 .map(|s| s.seq_id)
329 }
330 EvictionStrategy::LFU => {
331 self.sequences
333 .values()
334 .min_by_key(|s| s.access_count)
335 .map(|s| s.seq_id)
336 }
337 EvictionStrategy::LongestFirst => {
338 self.sequences
340 .values()
341 .max_by_key(|s| s.num_tokens)
342 .map(|s| s.seq_id)
343 }
344 EvictionStrategy::Priority { .. } => {
345 self.sequences
347 .values()
348 .min_by_key(|s| s.priority)
349 .map(|s| s.seq_id)
350 }
351 EvictionStrategy::StreamingLLM { .. } => {
352 self.sequences
355 .values()
356 .min_by_key(|s| s.last_access)
357 .map(|s| s.seq_id)
358 }
359 }
360 }
361
362 pub fn evict(&mut self) -> PagedKvResult<SeqId> {
364 let target = self
365 .select_eviction_target()
366 .ok_or(PagedKvError::InvalidOperation(
367 "No sequences to evict".to_string(),
368 ))?;
369
370 self.free(target)?;
371 self.stats.total_evictions += 1;
372 Ok(target)
373 }
374
375 pub fn evict_to_threshold(&mut self, target_util: f64) -> PagedKvResult<Vec<SeqId>> {
377 let mut evicted = Vec::new();
378 while self.utilization() > target_util && !self.sequences.is_empty() {
379 evicted.push(self.evict()?);
380 }
381 Ok(evicted)
382 }
383
384 pub fn apply_streaming_llm(
389 &mut self,
390 seq_id: SeqId,
391 sink_tokens: usize,
392 window_tokens: usize,
393 ) -> PagedKvResult<usize> {
394 let (num_tokens, blocks_to_remove) = {
396 let seq_info = self
397 .sequences
398 .get(&seq_id)
399 .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
400
401 let keep_tokens = sink_tokens + window_tokens;
402 if seq_info.num_tokens <= keep_tokens {
403 return Ok(0); }
405
406 let old_blocks = self.blocks_needed(seq_info.num_tokens);
407 let new_blocks = self.blocks_needed(keep_tokens);
408 let blocks_to_free = old_blocks.saturating_sub(new_blocks);
409
410 let blocks: Vec<BlockId> = seq_info
412 .block_ids
413 .iter()
414 .skip(sink_tokens / self.block_size + 1)
415 .take(blocks_to_free)
416 .cloned()
417 .collect();
418
419 (seq_info.num_tokens, blocks)
420 };
421
422 let keep_tokens = sink_tokens + window_tokens;
423 let evict_tokens = num_tokens - keep_tokens;
424
425 for block_id in &blocks_to_remove {
427 self.free_block(*block_id)?;
428 }
429
430 if let Some(seq_info) = self.sequences.get_mut(&seq_id) {
432 for block_id in blocks_to_remove {
433 seq_info.block_ids.retain(|&id| id != block_id);
434 }
435 seq_info.num_tokens = keep_tokens;
436 }
437
438 Ok(evict_tokens)
439 }
440
441 pub fn sequence_ids(&self) -> Vec<SeqId> {
443 self.sequences.keys().cloned().collect()
444 }
445}
446
447impl fmt::Display for PagedKvCache {
448 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449 writeln!(f, "PagedKvCache")?;
450 writeln!(
451 f,
452 " Strategy: {} (block_size={})",
453 self.eviction_strategy, self.block_size
454 )?;
455 writeln!(
456 f,
457 " Blocks: {}/{} ({:.1}% used)",
458 self.used_block_count(),
459 self.num_blocks,
460 self.utilization() * 100.0
461 )?;
462 writeln!(
463 f,
464 " Memory: {:.2} MB / {:.2} MB",
465 self.used_memory_bytes() as f64 / 1_000_000.0,
466 self.total_memory_bytes() as f64 / 1_000_000.0
467 )?;
468 writeln!(f, " Sequences: {} active", self.num_sequences())?;
469 writeln!(
470 f,
471 " Stats: allocs={}, frees={}, evictions={}, forks={}",
472 self.stats.total_allocations,
473 self.stats.total_frees,
474 self.stats.total_evictions,
475 self.stats.total_forks
476 )?;
477 Ok(())
478 }
479}