inference_lab/kv_cache/
manager.rs

1use super::block::Block;
2use crate::request::{BlockId, Request};
3use std::collections::HashMap;
4
5/// Manages KV cache blocks for all requests
6pub struct KVCacheManager {
7    /// Block size in tokens
8    block_size: u32,
9
10    /// Total number of blocks available
11    total_blocks: u32,
12
13    /// All blocks
14    blocks: Vec<Block>,
15
16    /// Free blocks (indices into blocks vec)
17    free_blocks: Vec<BlockId>,
18
19    /// Enable prefix caching
20    enable_prefix_caching: bool,
21
22    /// Prefix cache: hash -> block_id
23    prefix_cache: HashMap<u64, BlockId>,
24
25    /// Metrics
26    pub num_prefix_cache_hits: u64,
27    pub num_prefix_cache_misses: u64,
28}
29
30impl KVCacheManager {
31    /// Create a new KV cache manager
32    pub fn new(
33        kv_cache_capacity: u64,
34        block_size: u32,
35        kv_cache_bytes_per_token: u64,
36        enable_prefix_caching: bool,
37    ) -> Self {
38        let bytes_per_block = block_size as u64 * kv_cache_bytes_per_token;
39        let total_blocks = (kv_cache_capacity / bytes_per_block) as u32;
40
41        let blocks = (0..total_blocks).map(Block::new).collect();
42
43        let free_blocks = (0..total_blocks).collect();
44
45        Self {
46            block_size,
47            total_blocks,
48            blocks,
49            free_blocks,
50            enable_prefix_caching,
51            prefix_cache: HashMap::new(),
52            num_prefix_cache_hits: 0,
53            num_prefix_cache_misses: 0,
54        }
55    }
56
57    /// Try to allocate blocks for a request
58    /// Returns Some(Vec<BlockId>) if successful, None if insufficient blocks
59    pub fn allocate_blocks(&mut self, request: &Request, num_tokens: u32) -> Option<Vec<BlockId>> {
60        let blocks_needed = self.calculate_blocks_needed(request, num_tokens);
61
62        if self.free_blocks.len() < blocks_needed {
63            return None; // Not enough blocks
64        }
65
66        let mut allocated = Vec::new();
67        for _ in 0..blocks_needed {
68            let block_id = self.free_blocks.pop().unwrap();
69            self.blocks[block_id as usize].allocate();
70            allocated.push(block_id);
71        }
72
73        Some(allocated)
74    }
75
76    /// Calculate how many new blocks are needed for a request
77    fn calculate_blocks_needed(&self, request: &Request, num_new_tokens: u32) -> usize {
78        let total_tokens = request.num_computed_tokens + num_new_tokens;
79        let total_blocks_needed = ((total_tokens + self.block_size - 1) / self.block_size) as usize;
80        total_blocks_needed.saturating_sub(request.kv_blocks.len())
81    }
82
83    /// Free blocks from a request (due to preemption or completion)
84    pub fn free_blocks(&mut self, block_ids: &[BlockId]) {
85        for &block_id in block_ids {
86            let block = &mut self.blocks[block_id as usize];
87            block.release();
88
89            if block.is_free {
90                self.free_blocks.push(block_id);
91            }
92        }
93    }
94
95    /// Get number of free blocks
96    pub fn num_free_blocks(&self) -> usize {
97        self.free_blocks.len()
98    }
99
100    /// Get cache utilization (0.0 to 1.0)
101    pub fn utilization(&self) -> f64 {
102        1.0 - (self.free_blocks.len() as f64 / self.total_blocks as f64)
103    }
104
105    /// Check for prefix cache hits (simplified hash-based implementation)
106    pub fn check_prefix_cache(&mut self, request: &Request) -> u32 {
107        if !self.enable_prefix_caching {
108            return 0;
109        }
110
111        // Simplified: hash the prompt tokens
112        let prompt_hash = self.hash_prompt(&request.request_id, request.num_prompt_tokens);
113
114        if let Some(&_block_id) = self.prefix_cache.get(&prompt_hash) {
115            self.num_prefix_cache_hits += 1;
116            // Return number of cached tokens
117            // For simplicity, assume full blocks are cached up to block_size
118            self.block_size.min(request.num_prompt_tokens)
119        } else {
120            self.num_prefix_cache_misses += 1;
121            // Cache the prompt for future requests
122            if let Some(&first_block) = request.kv_blocks.first() {
123                self.prefix_cache.insert(prompt_hash, first_block);
124            }
125            0
126        }
127    }
128
129    /// Hash a prompt for prefix caching
130    fn hash_prompt(&self, request_id: &str, num_tokens: u32) -> u64 {
131        // Simplified hash - in reality would hash actual token IDs
132        use std::collections::hash_map::DefaultHasher;
133        use std::hash::{Hash, Hasher};
134
135        let mut hasher = DefaultHasher::new();
136        request_id.hash(&mut hasher);
137        num_tokens.hash(&mut hasher);
138        hasher.finish()
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a request with the given id and prompt length. The other constructor
    /// arguments are fixed to the values every test here uses.
    fn make_request(id: &str, prompt_tokens: u32) -> Request {
        Request::new(id.to_string(), 0, 0.0, prompt_tokens, 50)
    }

    /// 16000 bytes / (16 tokens per block * 100 bytes per token) = 10 blocks.
    fn ten_block_manager(enable_prefix_caching: bool) -> KVCacheManager {
        KVCacheManager::new(16000, 16, 100, enable_prefix_caching)
    }

    #[test]
    fn test_kv_cache_manager_creation() {
        let manager = ten_block_manager(false);

        assert_eq!(manager.block_size, 16);
        assert_eq!(manager.total_blocks, 10);
        assert_eq!(manager.num_free_blocks(), 10);
        assert_eq!(manager.utilization(), 0.0);
    }

    #[test]
    fn test_block_allocation() {
        let mut manager = ten_block_manager(false);
        let mut request = make_request("req-1", 32);

        // 32 tokens in blocks of 16 tokens -> 2 blocks.
        let initial = manager
            .allocate_blocks(&request, 32)
            .expect("2 of 10 blocks should be available");
        assert_eq!(initial.len(), 2);
        assert_eq!(manager.num_free_blocks(), 8);

        // Record the allocation on the request before growing it.
        request.kv_blocks.extend(initial);
        request.num_computed_tokens = 32;

        // 16 more tokens fill exactly one additional block.
        let extra = manager
            .allocate_blocks(&request, 16)
            .expect("1 more block should be available");
        assert_eq!(extra.len(), 1);
        assert_eq!(manager.num_free_blocks(), 7);
    }

    #[test]
    fn test_block_allocation_failure() {
        // 1600 bytes / (16 * 100) = exactly one block.
        let mut manager = KVCacheManager::new(1600, 16, 100, false);
        assert_eq!(manager.total_blocks, 1);

        // 32 tokens need two blocks, but only one exists.
        let request = make_request("req-1", 32);
        assert!(manager.allocate_blocks(&request, 32).is_none());
    }

    #[test]
    fn test_block_free() {
        let mut manager = ten_block_manager(false);
        let request = make_request("req-1", 32);

        let held = manager.allocate_blocks(&request, 32).unwrap();
        assert_eq!(manager.num_free_blocks(), 8);

        manager.free_blocks(&held);
        assert_eq!(manager.num_free_blocks(), 10);
        assert_eq!(manager.utilization(), 0.0);
    }

    #[test]
    fn test_utilization() {
        let mut manager = ten_block_manager(false);
        assert_eq!(manager.utilization(), 0.0);

        let held = manager
            .allocate_blocks(&make_request("req-1", 32), 32)
            .unwrap();

        // 2 of the 10 blocks are now in use.
        assert!((manager.utilization() - 0.2).abs() < 1e-10);

        manager.free_blocks(&held);
        assert_eq!(manager.utilization(), 0.0);
    }

    #[test]
    fn test_prefix_caching() {
        let mut manager = ten_block_manager(true);

        let mut original = make_request("req-1", 16);
        let blocks = manager.allocate_blocks(&original, 16).unwrap();
        original.kv_blocks.extend(blocks);

        // Empty cache: a miss, which also records the request's first block.
        assert_eq!(manager.check_prefix_cache(&original), 0);
        assert_eq!(manager.num_prefix_cache_misses, 1);

        // Same id and prompt length hash identically, so this hits for one full block.
        let repeat = make_request("req-1", 16);
        assert_eq!(manager.check_prefix_cache(&repeat), 16);
        assert_eq!(manager.num_prefix_cache_hits, 1);
    }
}