inference_lab/kv_cache/
manager.rs

1use super::block::Block;
2use crate::request::{BlockId, Request};
3use std::collections::HashMap;
4
5/// Manages KV cache blocks for all requests
6pub struct KVCacheManager {
7    /// Block size in tokens
8    block_size: u32,
9
10    /// Total number of blocks available
11    total_blocks: u32,
12
13    /// All blocks
14    blocks: Vec<Block>,
15
16    /// Free blocks (indices into blocks vec)
17    free_blocks: Vec<BlockId>,
18
19    /// Enable prefix caching
20    enable_prefix_caching: bool,
21
22    /// Prefix cache: hash -> block_id
23    prefix_cache: HashMap<u64, BlockId>,
24
25    /// Metrics
26    pub num_prefix_cache_hits: u64,
27    pub num_prefix_cache_misses: u64,
28}
29
30impl KVCacheManager {
31    /// Create a new KV cache manager
32    pub fn new(
33        kv_cache_capacity: u64,
34        block_size: u32,
35        kv_cache_bytes_per_token: u64,
36        enable_prefix_caching: bool,
37    ) -> Self {
38        let bytes_per_block = block_size as u64 * kv_cache_bytes_per_token;
39        let total_blocks = (kv_cache_capacity / bytes_per_block) as u32;
40
41        let blocks = (0..total_blocks).map(Block::new).collect();
42
43        let free_blocks = (0..total_blocks).collect();
44
45        Self {
46            block_size,
47            total_blocks,
48            blocks,
49            free_blocks,
50            enable_prefix_caching,
51            prefix_cache: HashMap::new(),
52            num_prefix_cache_hits: 0,
53            num_prefix_cache_misses: 0,
54        }
55    }
56
57    /// Try to allocate blocks for a request
58    /// Returns Some(Vec<BlockId>) if successful, None if insufficient blocks
59    pub fn allocate_blocks(&mut self, request: &Request, num_tokens: u32) -> Option<Vec<BlockId>> {
60        let blocks_needed = self.calculate_blocks_needed(request, num_tokens);
61
62        if self.free_blocks.len() < blocks_needed {
63            return None; // Not enough blocks
64        }
65
66        let mut allocated = Vec::new();
67        for _ in 0..blocks_needed {
68            let block_id = self.free_blocks.pop().unwrap();
69            self.blocks[block_id as usize].allocate();
70            allocated.push(block_id);
71        }
72
73        Some(allocated)
74    }
75
76    /// Calculate how many new blocks are needed for a request
77    fn calculate_blocks_needed(&self, request: &Request, num_new_tokens: u32) -> usize {
78        let total_tokens = request.num_computed_tokens + num_new_tokens;
79        let total_blocks_needed =
80            ((total_tokens + self.block_size - 1) / self.block_size) as usize;
81        total_blocks_needed.saturating_sub(request.kv_blocks.len())
82    }
83
84    /// Free blocks from a request (due to preemption or completion)
85    pub fn free_blocks(&mut self, block_ids: &[BlockId]) {
86        for &block_id in block_ids {
87            let block = &mut self.blocks[block_id as usize];
88            block.release();
89
90            if block.is_free {
91                self.free_blocks.push(block_id);
92            }
93        }
94    }
95
96    /// Get number of free blocks
97    pub fn num_free_blocks(&self) -> usize {
98        self.free_blocks.len()
99    }
100
101    /// Get cache utilization (0.0 to 1.0)
102    pub fn utilization(&self) -> f64 {
103        1.0 - (self.free_blocks.len() as f64 / self.total_blocks as f64)
104    }
105
106    /// Check for prefix cache hits (simplified hash-based implementation)
107    pub fn check_prefix_cache(&mut self, request: &Request) -> u32 {
108        if !self.enable_prefix_caching {
109            return 0;
110        }
111
112        // Simplified: hash the prompt tokens
113        let prompt_hash = self.hash_prompt(&request.request_id, request.num_prompt_tokens);
114
115        if let Some(&_block_id) = self.prefix_cache.get(&prompt_hash) {
116            self.num_prefix_cache_hits += 1;
117            // Return number of cached tokens
118            // For simplicity, assume full blocks are cached up to block_size
119            self.block_size.min(request.num_prompt_tokens)
120        } else {
121            self.num_prefix_cache_misses += 1;
122            // Cache the prompt for future requests
123            if let Some(&first_block) = request.kv_blocks.first() {
124                self.prefix_cache.insert(prompt_hash, first_block);
125            }
126            0
127        }
128    }
129
130    /// Hash a prompt for prefix caching
131    fn hash_prompt(&self, request_id: &str, num_tokens: u32) -> u64 {
132        // Simplified hash - in reality would hash actual token IDs
133        use std::collections::hash_map::DefaultHasher;
134        use std::hash::{Hash, Hasher};
135
136        let mut hasher = DefaultHasher::new();
137        request_id.hash(&mut hasher);
138        num_tokens.hash(&mut hasher);
139        hasher.finish()
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a request with the given id and prompt length; the remaining
    /// `Request::new` arguments are fixed filler values.
    fn create_test_request(id: &str, prompt_tokens: u32) -> Request {
        Request::new(id.to_string(), 0, 0.0, prompt_tokens, 50)
    }

    /// Manager with room for exactly 10 blocks:
    /// 16000 bytes / (16 tokens per block * 100 bytes per token) = 10.
    fn ten_block_manager(enable_prefix_caching: bool) -> KVCacheManager {
        KVCacheManager::new(16000, 16, 100, enable_prefix_caching)
    }

    #[test]
    fn test_kv_cache_manager_creation() {
        let manager = ten_block_manager(false);

        assert_eq!(manager.block_size, 16);
        assert_eq!(manager.total_blocks, 10);
        assert_eq!(manager.num_free_blocks(), 10);
        assert_eq!(manager.utilization(), 0.0);
    }

    #[test]
    fn test_block_allocation() {
        let mut manager = ten_block_manager(false);
        let mut request = create_test_request("req-1", 32);

        // 32 tokens with 16-token blocks -> 2 blocks.
        let blocks = manager
            .allocate_blocks(&request, 32)
            .expect("10 free blocks should cover a 2-block request");
        assert_eq!(blocks.len(), 2);
        assert_eq!(manager.num_free_blocks(), 8);

        // Record the allocation on the request, as the caller would.
        request.kv_blocks.extend(blocks);
        request.num_computed_tokens = 32;

        // 16 more tokens need exactly one block beyond the two already held.
        let more_blocks = manager
            .allocate_blocks(&request, 16)
            .expect("8 free blocks should cover a 1-block request");
        assert_eq!(more_blocks.len(), 1);
        assert_eq!(manager.num_free_blocks(), 7);
    }

    #[test]
    fn test_block_allocation_rounds_up_partial_block() {
        let mut manager = ten_block_manager(false);
        let request = create_test_request("req-1", 17);

        // 17 tokens spill one token into a second 16-token block.
        let blocks = manager
            .allocate_blocks(&request, 17)
            .expect("10 free blocks should cover a 2-block request");
        assert_eq!(blocks.len(), 2);
        assert_eq!(manager.num_free_blocks(), 8);
    }

    #[test]
    fn test_block_allocation_failure() {
        // 1600 bytes / (16 * 100) leaves room for a single block.
        let mut manager = KVCacheManager::new(1600, 16, 100, false);
        assert_eq!(manager.total_blocks, 1);

        let request = create_test_request("req-1", 32);

        // 32 tokens need 2 blocks, but only 1 exists.
        assert!(manager.allocate_blocks(&request, 32).is_none());
        // A failed allocation must not consume anything.
        assert_eq!(manager.num_free_blocks(), 1);
    }

    #[test]
    fn test_block_free() {
        let mut manager = ten_block_manager(false);
        let request = create_test_request("req-1", 32);

        let blocks = manager
            .allocate_blocks(&request, 32)
            .expect("10 free blocks should cover a 2-block request");
        assert_eq!(manager.num_free_blocks(), 8);

        manager.free_blocks(&blocks);
        assert_eq!(manager.num_free_blocks(), 10);
        assert_eq!(manager.utilization(), 0.0);
    }

    #[test]
    fn test_utilization() {
        let mut manager = ten_block_manager(false);
        assert_eq!(manager.utilization(), 0.0);

        let request = create_test_request("req-1", 32);
        let blocks = manager
            .allocate_blocks(&request, 32)
            .expect("10 free blocks should cover a 2-block request");

        // 2 of 10 blocks in use; compare with a tolerance because 0.2 is not
        // exactly representable.
        assert!((manager.utilization() - 0.2).abs() < 1e-10);

        manager.free_blocks(&blocks);
        assert_eq!(manager.utilization(), 0.0);
    }

    #[test]
    fn test_prefix_caching() {
        let mut manager = ten_block_manager(true);

        let mut request1 = create_test_request("req-1", 16);
        let blocks1 = manager
            .allocate_blocks(&request1, 16)
            .expect("10 free blocks should cover a 1-block request");
        request1.kv_blocks.extend(blocks1);

        // First lookup: nothing cached yet, so it misses and records the prompt.
        assert_eq!(manager.check_prefix_cache(&request1), 0);
        assert_eq!(manager.num_prefix_cache_misses, 1);

        // Second lookup with the same id and prompt length hits.
        let request2 = create_test_request("req-1", 16);
        assert_eq!(manager.check_prefix_cache(&request2), 16);
        assert_eq!(manager.num_prefix_cache_hits, 1);
    }

    #[test]
    fn test_prefix_caching_disabled() {
        let mut manager = ten_block_manager(false);
        let request = create_test_request("req-1", 16);

        // With prefix caching off, lookups report nothing cached and leave the
        // counters untouched, no matter how often the same prompt is seen.
        assert_eq!(manager.check_prefix_cache(&request), 0);
        assert_eq!(manager.check_prefix_cache(&request), 0);
        assert_eq!(manager.num_prefix_cache_hits, 0);
        assert_eq!(manager.num_prefix_cache_misses, 0);
    }
}