Skip to main content

cbtop/paged_kv/
cache.rs

1//! PagedKvCache implementation with allocation, eviction, and copy-on-write.
2
3use std::collections::{HashMap, VecDeque};
4use std::fmt;
5
6use super::types::{
7    BlockId, CacheStats, EvictionStrategy, PagedKvError, PagedKvResult, SeqId, SequenceInfo,
8};
9
10/// Paged KV cache for efficient memory management.
11///
12/// Based on vLLM's PagedAttention algorithm. Manages KV cache memory
13/// using fixed-size blocks to prevent fragmentation and enable
14/// efficient memory sharing.
15#[derive(Debug)]
16pub struct PagedKvCache {
17    /// Block size (tokens per block)
18    block_size: usize,
19    /// Number of attention heads
20    num_heads: usize,
21    /// Head dimension
22    head_dim: usize,
23    /// Total number of physical blocks
24    num_blocks: usize,
25    /// Free block indices
26    free_blocks: VecDeque<BlockId>,
27    /// Sequence -> info mapping
28    sequences: HashMap<SeqId, SequenceInfo>,
29    /// Block reference counts (for COW)
30    block_refs: HashMap<BlockId, u32>,
31    /// Eviction strategy
32    eviction_strategy: EvictionStrategy,
33    /// Memory threshold for eviction (0.0-1.0)
34    eviction_threshold: f64,
35    /// Cache statistics
36    stats: CacheStats,
37}
38
39impl PagedKvCache {
40    /// Create a new PagedKvCache.
41    ///
42    /// # Arguments
43    /// - `num_blocks`: Total number of physical blocks
44    /// - `block_size`: Tokens per block
45    /// - `num_heads`: Number of attention heads
46    /// - `head_dim`: Dimension of each head
47    pub fn new(num_blocks: usize, block_size: usize, num_heads: usize, head_dim: usize) -> Self {
48        // Initialize free blocks
49        let free_blocks: VecDeque<BlockId> = (0..num_blocks as u32).map(BlockId).collect();
50
51        Self {
52            block_size,
53            num_heads,
54            head_dim,
55            num_blocks,
56            free_blocks,
57            sequences: HashMap::new(),
58            block_refs: HashMap::new(),
59            eviction_strategy: EvictionStrategy::default(),
60            eviction_threshold: 0.9,
61            stats: CacheStats::default(),
62        }
63    }
64
65    /// Set eviction strategy.
66    pub fn with_eviction_strategy(mut self, strategy: EvictionStrategy) -> Self {
67        self.eviction_strategy = strategy;
68        self
69    }
70
71    /// Set eviction threshold (0.0-1.0).
72    pub fn with_eviction_threshold(mut self, threshold: f64) -> Self {
73        self.eviction_threshold = threshold.clamp(0.0, 1.0);
74        self
75    }
76
77    /// Get block size.
78    pub fn block_size(&self) -> usize {
79        self.block_size
80    }
81
82    /// Get total number of blocks.
83    pub fn total_blocks(&self) -> usize {
84        self.num_blocks
85    }
86
87    /// Get number of free blocks.
88    pub fn free_block_count(&self) -> usize {
89        self.free_blocks.len()
90    }
91
92    /// Get number of used blocks.
93    pub fn used_block_count(&self) -> usize {
94        self.num_blocks - self.free_blocks.len()
95    }
96
97    /// Memory utilization percentage (0.0-1.0).
98    pub fn utilization(&self) -> f64 {
99        if self.num_blocks == 0 {
100            return 0.0;
101        }
102        self.used_block_count() as f64 / self.num_blocks as f64
103    }
104
105    /// Calculate memory for a block in bytes.
106    pub fn block_memory_bytes(&self) -> usize {
107        // KV cache: 2 (K+V) * block_size * num_heads * head_dim * 2 (f16)
108        2 * self.block_size * self.num_heads * self.head_dim * 2
109    }
110
111    /// Total memory capacity in bytes.
112    pub fn total_memory_bytes(&self) -> usize {
113        self.num_blocks * self.block_memory_bytes()
114    }
115
116    /// Used memory in bytes.
117    pub fn used_memory_bytes(&self) -> usize {
118        self.used_block_count() * self.block_memory_bytes()
119    }
120
121    /// Check if eviction is needed.
122    pub fn needs_eviction(&self) -> bool {
123        self.utilization() >= self.eviction_threshold
124    }
125
126    /// Get number of active sequences.
127    pub fn num_sequences(&self) -> usize {
128        self.sequences.len()
129    }
130
131    /// Get sequence info.
132    pub fn get_sequence(&self, seq_id: SeqId) -> Option<&SequenceInfo> {
133        self.sequences.get(&seq_id)
134    }
135
136    /// Get cache statistics.
137    pub fn stats(&self) -> &CacheStats {
138        &self.stats
139    }
140
141    /// Get eviction strategy.
142    pub fn eviction_strategy(&self) -> &EvictionStrategy {
143        &self.eviction_strategy
144    }
145
146    /// Calculate blocks needed for tokens.
147    fn blocks_needed(&self, num_tokens: usize) -> usize {
148        num_tokens.div_ceil(self.block_size)
149    }
150
151    /// Allocate a single block.
152    fn allocate_block(&mut self) -> PagedKvResult<BlockId> {
153        if let Some(block_id) = self.free_blocks.pop_front() {
154            self.block_refs.insert(block_id, 1);
155            self.stats.total_allocations += 1;
156
157            // Track peak usage
158            let used = self.used_block_count();
159            if used > self.stats.peak_blocks_used {
160                self.stats.peak_blocks_used = used;
161            }
162
163            Ok(block_id)
164        } else {
165            Err(PagedKvError::OutOfMemory {
166                requested: 1,
167                available: 0,
168            })
169        }
170    }
171
172    /// Free a single block.
173    fn free_block(&mut self, block_id: BlockId) -> PagedKvResult<()> {
174        if let Some(refs) = self.block_refs.get_mut(&block_id) {
175            *refs -= 1;
176            if *refs == 0 {
177                self.block_refs.remove(&block_id);
178                self.free_blocks.push_back(block_id);
179                self.stats.total_frees += 1;
180            }
181            Ok(())
182        } else {
183            Err(PagedKvError::BlockNotFound(block_id))
184        }
185    }
186
187    /// Allocate blocks for a new sequence.
188    pub fn allocate(&mut self, seq_id: SeqId, num_tokens: usize) -> PagedKvResult<()> {
189        if self.sequences.contains_key(&seq_id) {
190            return Err(PagedKvError::InvalidOperation(format!(
191                "Sequence {} already exists",
192                seq_id
193            )));
194        }
195
196        let blocks_needed = self.blocks_needed(num_tokens);
197
198        // Check if we have enough blocks
199        if blocks_needed > self.free_blocks.len() {
200            return Err(PagedKvError::OutOfMemory {
201                requested: blocks_needed,
202                available: self.free_blocks.len(),
203            });
204        }
205
206        // Allocate blocks
207        let mut block_ids = Vec::with_capacity(blocks_needed);
208        for _ in 0..blocks_needed {
209            block_ids.push(self.allocate_block()?);
210        }
211
212        // Create sequence info
213        let mut seq_info = SequenceInfo::new(seq_id);
214        seq_info.num_tokens = num_tokens;
215        seq_info.block_ids = block_ids;
216        seq_info.touch();
217
218        self.sequences.insert(seq_id, seq_info);
219        Ok(())
220    }
221
222    /// Append tokens to an existing sequence.
223    pub fn append(&mut self, seq_id: SeqId, num_new_tokens: usize) -> PagedKvResult<()> {
224        // First, calculate how many blocks we need (immutably)
225        let (old_tokens, additional_blocks) = {
226            let seq_info = self
227                .sequences
228                .get(&seq_id)
229                .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
230
231            let old_tokens = seq_info.num_tokens;
232            let new_tokens = old_tokens + num_new_tokens;
233            let old_blocks = self.blocks_needed(old_tokens);
234            let new_blocks = self.blocks_needed(new_tokens);
235            let additional = new_blocks.saturating_sub(old_blocks);
236
237            (old_tokens, additional)
238        };
239
240        // Check if we have enough blocks
241        if additional_blocks > self.free_blocks.len() {
242            return Err(PagedKvError::OutOfMemory {
243                requested: additional_blocks,
244                available: self.free_blocks.len(),
245            });
246        }
247
248        // Allocate the blocks
249        let mut new_block_ids = Vec::with_capacity(additional_blocks);
250        for _ in 0..additional_blocks {
251            new_block_ids.push(self.allocate_block()?);
252        }
253
254        // Update sequence info
255        let seq_info = self
256            .sequences
257            .get_mut(&seq_id)
258            .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
259
260        seq_info.block_ids.extend(new_block_ids);
261        seq_info.num_tokens = old_tokens + num_new_tokens;
262        seq_info.touch();
263        Ok(())
264    }
265
266    /// Free all blocks for a sequence.
267    pub fn free(&mut self, seq_id: SeqId) -> PagedKvResult<()> {
268        let seq_info = self
269            .sequences
270            .remove(&seq_id)
271            .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
272
273        for block_id in seq_info.block_ids {
274            self.free_block(block_id)?;
275        }
276
277        Ok(())
278    }
279
280    /// Copy-on-write fork for beam search.
281    ///
282    /// Creates a new sequence that shares blocks with the source sequence.
283    /// Blocks are only copied when modified (copy-on-write).
284    pub fn fork(&mut self, src_seq: SeqId, dst_seq: SeqId) -> PagedKvResult<()> {
285        if self.sequences.contains_key(&dst_seq) {
286            return Err(PagedKvError::InvalidOperation(format!(
287                "Destination sequence {} already exists",
288                dst_seq
289            )));
290        }
291
292        let src_info = self
293            .sequences
294            .get(&src_seq)
295            .ok_or(PagedKvError::SequenceNotFound(src_seq))?
296            .clone();
297
298        // Increment reference counts for shared blocks
299        for block_id in &src_info.block_ids {
300            if let Some(refs) = self.block_refs.get_mut(block_id) {
301                *refs += 1;
302            }
303        }
304
305        // Create new sequence with shared blocks
306        let mut dst_info = SequenceInfo::new(dst_seq);
307        dst_info.num_tokens = src_info.num_tokens;
308        dst_info.block_ids = src_info.block_ids.clone();
309        dst_info.touch();
310
311        self.sequences.insert(dst_seq, dst_info);
312        self.stats.total_forks += 1;
313        Ok(())
314    }
315
316    /// Select sequence to evict based on strategy.
317    pub fn select_eviction_target(&self) -> Option<SeqId> {
318        if self.sequences.is_empty() {
319            return None;
320        }
321
322        match &self.eviction_strategy {
323            EvictionStrategy::LRU => {
324                // Evict least recently used
325                self.sequences
326                    .values()
327                    .min_by_key(|s| s.last_access)
328                    .map(|s| s.seq_id)
329            }
330            EvictionStrategy::LFU => {
331                // Evict least frequently used
332                self.sequences
333                    .values()
334                    .min_by_key(|s| s.access_count)
335                    .map(|s| s.seq_id)
336            }
337            EvictionStrategy::LongestFirst => {
338                // Evict longest sequence (most blocks)
339                self.sequences
340                    .values()
341                    .max_by_key(|s| s.num_tokens)
342                    .map(|s| s.seq_id)
343            }
344            EvictionStrategy::Priority { .. } => {
345                // Evict lowest priority
346                self.sequences
347                    .values()
348                    .min_by_key(|s| s.priority)
349                    .map(|s| s.seq_id)
350            }
351            EvictionStrategy::StreamingLLM { .. } => {
352                // StreamingLLM doesn't evict sequences, it evicts tokens
353                // For simplicity, fall back to LRU for sequence eviction
354                self.sequences
355                    .values()
356                    .min_by_key(|s| s.last_access)
357                    .map(|s| s.seq_id)
358            }
359        }
360    }
361
362    /// Evict a sequence to free memory.
363    pub fn evict(&mut self) -> PagedKvResult<SeqId> {
364        let target = self
365            .select_eviction_target()
366            .ok_or(PagedKvError::InvalidOperation(
367                "No sequences to evict".to_string(),
368            ))?;
369
370        self.free(target)?;
371        self.stats.total_evictions += 1;
372        Ok(target)
373    }
374
375    /// Evict until memory utilization is below threshold.
376    pub fn evict_to_threshold(&mut self, target_util: f64) -> PagedKvResult<Vec<SeqId>> {
377        let mut evicted = Vec::new();
378        while self.utilization() > target_util && !self.sequences.is_empty() {
379            evicted.push(self.evict()?);
380        }
381        Ok(evicted)
382    }
383
384    /// Apply StreamingLLM eviction to a sequence.
385    ///
386    /// Keeps sink tokens at the beginning and a recent window at the end,
387    /// evicting middle tokens.
388    pub fn apply_streaming_llm(
389        &mut self,
390        seq_id: SeqId,
391        sink_tokens: usize,
392        window_tokens: usize,
393    ) -> PagedKvResult<usize> {
394        // Get sequence info immutably first to compute values
395        let (num_tokens, blocks_to_remove) = {
396            let seq_info = self
397                .sequences
398                .get(&seq_id)
399                .ok_or(PagedKvError::SequenceNotFound(seq_id))?;
400
401            let keep_tokens = sink_tokens + window_tokens;
402            if seq_info.num_tokens <= keep_tokens {
403                return Ok(0); // Nothing to evict
404            }
405
406            let old_blocks = self.blocks_needed(seq_info.num_tokens);
407            let new_blocks = self.blocks_needed(keep_tokens);
408            let blocks_to_free = old_blocks.saturating_sub(new_blocks);
409
410            // Collect blocks to remove
411            let blocks: Vec<BlockId> = seq_info
412                .block_ids
413                .iter()
414                .skip(sink_tokens / self.block_size + 1)
415                .take(blocks_to_free)
416                .cloned()
417                .collect();
418
419            (seq_info.num_tokens, blocks)
420        };
421
422        let keep_tokens = sink_tokens + window_tokens;
423        let evict_tokens = num_tokens - keep_tokens;
424
425        // Free the blocks
426        for block_id in &blocks_to_remove {
427            self.free_block(*block_id)?;
428        }
429
430        // Update sequence info
431        if let Some(seq_info) = self.sequences.get_mut(&seq_id) {
432            for block_id in blocks_to_remove {
433                seq_info.block_ids.retain(|&id| id != block_id);
434            }
435            seq_info.num_tokens = keep_tokens;
436        }
437
438        Ok(evict_tokens)
439    }
440
441    /// Get all sequence IDs.
442    pub fn sequence_ids(&self) -> Vec<SeqId> {
443        self.sequences.keys().cloned().collect()
444    }
445}
446
447impl fmt::Display for PagedKvCache {
448    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449        writeln!(f, "PagedKvCache")?;
450        writeln!(
451            f,
452            "  Strategy: {} (block_size={})",
453            self.eviction_strategy, self.block_size
454        )?;
455        writeln!(
456            f,
457            "  Blocks: {}/{} ({:.1}% used)",
458            self.used_block_count(),
459            self.num_blocks,
460            self.utilization() * 100.0
461        )?;
462        writeln!(
463            f,
464            "  Memory: {:.2} MB / {:.2} MB",
465            self.used_memory_bytes() as f64 / 1_000_000.0,
466            self.total_memory_bytes() as f64 / 1_000_000.0
467        )?;
468        writeln!(f, "  Sequences: {} active", self.num_sequences())?;
469        writeln!(
470            f,
471            "  Stats: allocs={}, frees={}, evictions={}, forks={}",
472            self.stats.total_allocations,
473            self.stats.total_frees,
474            self.stats.total_evictions,
475            self.stats.total_forks
476        )?;
477        Ok(())
478    }
479}