// inference_lab/kv_cache/manager.rs

1use super::block::Block;
2use crate::request::{BlockId, Request};
3use std::collections::HashMap;
4
5/// Manages KV cache blocks for all requests
6pub struct KVCacheManager {
7    /// Block size in tokens
8    block_size: u32,
9
10    /// Total number of blocks available
11    total_blocks: u32,
12
13    /// All blocks
14    blocks: Vec<Block>,
15
16    /// Free blocks (indices into blocks vec)
17    free_blocks: Vec<BlockId>,
18
19    /// Enable prefix caching
20    enable_prefix_caching: bool,
21
22    /// Prefix cache: maps block hash -> block_id
23    /// Its a big flat hash map - the hashes are supposed to be incremental hashes of all of the
24    /// tokens up to a certain point. If a sequence with block (1, 2, 3, 4, 5) is in the cache,
25    /// then we get the prefix cache entry by doing cat(cache[1], cache[2], ...).
26    prefix_cache: HashMap<u64, BlockId>,
27
28    /// Metrics
29    pub num_prefix_cache_hits: u64,
30    pub num_prefix_cache_misses: u64,
31    pub hit_size_count: u64,
32    pub hit_size_sum: u64,
33}
34
35impl KVCacheManager {
36    /// Create a new KV cache manager
37    pub fn new(
38        kv_cache_capacity: u64,
39        block_size: u32,
40        kv_cache_bytes_per_token: u64,
41        enable_prefix_caching: bool,
42    ) -> Self {
43        let bytes_per_block = block_size as u64 * kv_cache_bytes_per_token;
44        let total_blocks = (kv_cache_capacity / bytes_per_block) as u32;
45
46        let blocks = (0..total_blocks).map(Block::new).collect();
47
48        let free_blocks = (0..total_blocks).collect();
49
50        Self {
51            block_size,
52            total_blocks,
53            blocks,
54            free_blocks,
55            enable_prefix_caching,
56            prefix_cache: HashMap::new(),
57            num_prefix_cache_hits: 0,
58            num_prefix_cache_misses: 0,
59            hit_size_count: 0,
60            hit_size_sum: 0,
61        }
62    }
63
64    /// Try to allocate blocks for a request
65    /// Returns Some(Vec<BlockId>) if successful, None if insufficient blocks
66    pub fn allocate_blocks(&mut self, request: &Request, num_tokens: u32) -> Option<Vec<BlockId>> {
67        let blocks_needed = self.calculate_blocks_needed(request, num_tokens);
68        let hashes = request.get_prompt_block_hashes();
69
70        if self.free_blocks.len() < blocks_needed {
71            return None; // Not enough blocks
72        }
73
74        let mut allocated = Vec::new();
75        let mut evicted_hashes = Vec::new();
76        for i in 0..blocks_needed {
77            let block_id = self.free_blocks.pop().unwrap();
78            let evicted_hash = self.blocks[block_id as usize].allocate(hashes.get(i).cloned());
79            evicted_hashes.extend(evicted_hash);
80            allocated.push(block_id);
81        }
82
83        // Update prefix cache with newly allocated/deallocated blocks
84        if self.enable_prefix_caching {
85            // remove all the content hashes that were overwritten
86            for hash in evicted_hashes {
87                self.prefix_cache.remove(&hash);
88            }
89            // Store the new block hashes
90            for (i, &hash) in hashes.iter().enumerate() {
91                if let Some(&block_id) = allocated.get(i) {
92                    self.prefix_cache.insert(hash, block_id);
93                }
94            }
95        };
96
97        Some(allocated)
98    }
99
100    /// Calculate how many new blocks are needed for a request
101    fn calculate_blocks_needed(&self, request: &Request, num_new_tokens: u32) -> usize {
102        let total_tokens = request.num_computed_tokens + num_new_tokens;
103        let total_blocks_needed = total_tokens.div_ceil(self.block_size) as usize;
104        total_blocks_needed.saturating_sub(request.kv_blocks.len())
105    }
106
107    /// Free blocks from a request (due to preemption or completion)
108    pub fn free_blocks(&mut self, block_ids: &[BlockId]) {
109        for &block_id in block_ids {
110            let block = &mut self.blocks[block_id as usize];
111            block.release();
112
113            if block.is_free {
114                self.free_blocks.push(block_id);
115            }
116        }
117    }
118
119    /// Get number of free blocks
120    pub fn num_free_blocks(&self) -> usize {
121        self.free_blocks.len()
122    }
123
124    /// Get total number of blocks
125    pub fn total_blocks(&self) -> usize {
126        self.total_blocks as usize
127    }
128
129    /// Get cache utilization (0.0 to 1.0)
130    pub fn utilization(&self) -> f64 {
131        1.0 - (self.free_blocks.len() as f64 / self.total_blocks as f64)
132    }
133
134    /// Check for prefix cache hits
135    /// Returns the number of tokens that can be served from the cache
136    pub fn peek_prefix_cache(&mut self, request: &Request) -> u32 {
137        if !self.enable_prefix_caching {
138            return 0;
139        }
140
141        // Get block hashes from the request
142        let block_hashes = request.get_prompt_block_hashes();
143
144        if block_hashes.is_empty() {
145            // If there are no block hashes, then theres no caching, so don't increment anything
146            return 0;
147        }
148
149        // Check consecutive blocks from the start until we find a miss
150        let mut cached_blocks = 0;
151        for &hash in block_hashes {
152            if self.prefix_cache.contains_key(&hash) {
153                cached_blocks += 1;
154            } else {
155                // First cache miss = end of cached prefix
156                break;
157            }
158        }
159
160        cached_blocks * self.block_size
161    }
162
163    pub fn query_prefix_cache(&mut self, request: &Request) -> u32 {
164        let tokens = self.peek_prefix_cache(request);
165        self.hit_size_count += 1;
166        self.hit_size_sum += tokens as u64;
167
168        if tokens == 0 {
169            self.num_prefix_cache_misses += 1;
170        } else {
171            self.num_prefix_cache_hits += 1;
172        }
173        tokens
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a request with the given id and prompt length; the other constructor
    /// arguments are fixed (`0`, `0.0`, `50`).
    fn create_test_request(id: &str, prompt_tokens: u32) -> Request {
        Request::new(String::from(id), 0, 0.0, prompt_tokens, 50)
    }

    #[test]
    fn test_kv_cache_manager_creation() {
        // 16000 bytes / (16 tokens per block * 100 bytes per token) = 10 blocks.
        let mgr = KVCacheManager::new(16000, 16, 100, false);

        assert_eq!(mgr.block_size, 16);
        assert_eq!(mgr.total_blocks, 10);
        assert_eq!(mgr.num_free_blocks(), 10);
        assert_eq!(mgr.utilization(), 0.0);
    }

    #[test]
    fn test_block_allocation() {
        let mut mgr = KVCacheManager::new(16000, 16, 100, false);
        let mut req = create_test_request("req-1", 32);

        // 32 tokens with 16-token blocks -> 2 blocks.
        let first = mgr
            .allocate_blocks(&req, 32)
            .expect("10 free blocks should cover a 2-block allocation");
        assert_eq!(first.len(), 2);
        assert_eq!(mgr.num_free_blocks(), 8);

        // Record the allocation on the request.
        req.kv_blocks.extend(first);
        req.num_computed_tokens = 32;

        // 16 more tokens on top of 32 computed -> exactly one extra block.
        let second = mgr
            .allocate_blocks(&req, 16)
            .expect("8 free blocks should cover a 1-block allocation");
        assert_eq!(second.len(), 1);
        assert_eq!(mgr.num_free_blocks(), 7);
    }

    #[test]
    fn test_block_allocation_failure() {
        // 1600 bytes / (16 * 100) = exactly one block.
        let mut mgr = KVCacheManager::new(1600, 16, 100, false);
        assert_eq!(mgr.total_blocks, 1);

        // 32 tokens need 2 blocks, but only 1 exists.
        let req = create_test_request("req-1", 32);
        assert!(mgr.allocate_blocks(&req, 32).is_none());
    }

    #[test]
    fn test_block_free() {
        let mut mgr = KVCacheManager::new(16000, 16, 100, false);
        let req = create_test_request("req-1", 32);

        let taken = mgr.allocate_blocks(&req, 32).unwrap();
        assert_eq!(mgr.num_free_blocks(), 8);

        // Handing both blocks back restores the initial state.
        mgr.free_blocks(&taken);
        assert_eq!(mgr.num_free_blocks(), 10);
        assert_eq!(mgr.utilization(), 0.0);
    }

    #[test]
    fn test_utilization() {
        let mut mgr = KVCacheManager::new(16000, 16, 100, false);
        assert_eq!(mgr.utilization(), 0.0);

        let req = create_test_request("req-1", 32);
        let taken = mgr.allocate_blocks(&req, 32).unwrap();

        // 2 of 10 blocks in use.
        assert!((mgr.utilization() - 0.2).abs() < 1e-10);

        mgr.free_blocks(&taken);
        assert_eq!(mgr.utilization(), 0.0);
    }

    #[test]
    fn test_prefix_caching() {
        let mut mgr = KVCacheManager::new(16000, 16, 100, true);

        // One synthetic block hash; nothing is cached yet, so the query misses.
        let mut req1 = create_test_request("req-1", 16);
        req1.prompt_block_hashes = vec![12345];
        assert_eq!(mgr.query_prefix_cache(&req1), 0);
        assert_eq!(mgr.num_prefix_cache_misses, 1);

        // Allocating the block registers its hash in the prefix cache.
        let req1_blocks = mgr.allocate_blocks(&req1, 16).unwrap();
        req1.kv_blocks.extend(req1_blocks);

        // A second request sharing the hash hits one full block (16 tokens).
        let mut req2 = create_test_request("req-2", 16);
        req2.prompt_block_hashes = vec![12345];
        assert_eq!(mgr.query_prefix_cache(&req2), 16);
        assert_eq!(mgr.num_prefix_cache_hits, 1);

        // A request with an unseen hash misses.
        let mut req3 = create_test_request("req-3", 16);
        req3.prompt_block_hashes = vec![67890];
        assert_eq!(mgr.query_prefix_cache(&req3), 0);
        assert_eq!(mgr.num_prefix_cache_misses, 2);
    }
}