1use super::block::Block;
2use crate::request::{BlockId, Request};
3use std::collections::HashMap;
4
5pub struct KVCacheManager {
7 block_size: u32,
9
10 total_blocks: u32,
12
13 blocks: Vec<Block>,
15
16 free_blocks: Vec<BlockId>,
18
19 enable_prefix_caching: bool,
21
22 prefix_cache: HashMap<u64, BlockId>,
27
28 pub num_prefix_cache_hits: u64,
30 pub num_prefix_cache_misses: u64,
31 pub hit_size_count: u64,
32 pub hit_size_sum: u64,
33}
34
35impl KVCacheManager {
36 pub fn new(
38 kv_cache_capacity: u64,
39 block_size: u32,
40 kv_cache_bytes_per_token: u64,
41 enable_prefix_caching: bool,
42 ) -> Self {
43 let bytes_per_block = block_size as u64 * kv_cache_bytes_per_token;
44 let total_blocks = (kv_cache_capacity / bytes_per_block) as u32;
45
46 let blocks = (0..total_blocks).map(Block::new).collect();
47
48 let free_blocks = (0..total_blocks).collect();
49
50 Self {
51 block_size,
52 total_blocks,
53 blocks,
54 free_blocks,
55 enable_prefix_caching,
56 prefix_cache: HashMap::new(),
57 num_prefix_cache_hits: 0,
58 num_prefix_cache_misses: 0,
59 hit_size_count: 0,
60 hit_size_sum: 0,
61 }
62 }
63
64 pub fn allocate_blocks(&mut self, request: &Request, num_tokens: u32) -> Option<Vec<BlockId>> {
67 let blocks_needed = self.calculate_blocks_needed(request, num_tokens);
68 let hashes = request.get_prompt_block_hashes();
69
70 if self.free_blocks.len() < blocks_needed {
71 return None; }
73
74 let mut allocated = Vec::new();
75 let mut evicted_hashes = Vec::new();
76 for i in 0..blocks_needed {
77 let block_id = self.free_blocks.pop().unwrap();
78 let evicted_hash = self.blocks[block_id as usize].allocate(hashes.get(i).cloned());
79 evicted_hashes.extend(evicted_hash);
80 allocated.push(block_id);
81 }
82
83 if self.enable_prefix_caching {
85 for hash in evicted_hashes {
87 self.prefix_cache.remove(&hash);
88 }
89 for (i, &hash) in hashes.iter().enumerate() {
91 if let Some(&block_id) = allocated.get(i) {
92 self.prefix_cache.insert(hash, block_id);
93 }
94 }
95 };
96
97 Some(allocated)
98 }
99
100 fn calculate_blocks_needed(&self, request: &Request, num_new_tokens: u32) -> usize {
102 let total_tokens = request.num_computed_tokens + num_new_tokens;
103 let total_blocks_needed = total_tokens.div_ceil(self.block_size) as usize;
104 total_blocks_needed.saturating_sub(request.kv_blocks.len())
105 }
106
107 pub fn free_blocks(&mut self, block_ids: &[BlockId]) {
109 for &block_id in block_ids {
110 let block = &mut self.blocks[block_id as usize];
111 block.release();
112
113 if block.is_free {
114 self.free_blocks.push(block_id);
115 }
116 }
117 }
118
119 pub fn num_free_blocks(&self) -> usize {
121 self.free_blocks.len()
122 }
123
124 pub fn total_blocks(&self) -> usize {
126 self.total_blocks as usize
127 }
128
129 pub fn utilization(&self) -> f64 {
131 1.0 - (self.free_blocks.len() as f64 / self.total_blocks as f64)
132 }
133
134 pub fn peek_prefix_cache(&mut self, request: &Request) -> u32 {
137 if !self.enable_prefix_caching {
138 return 0;
139 }
140
141 let block_hashes = request.get_prompt_block_hashes();
143
144 if block_hashes.is_empty() {
145 return 0;
147 }
148
149 let mut cached_blocks = 0;
151 for &hash in block_hashes {
152 if self.prefix_cache.contains_key(&hash) {
153 cached_blocks += 1;
154 } else {
155 break;
157 }
158 }
159
160 cached_blocks * self.block_size
161 }
162
163 pub fn query_prefix_cache(&mut self, request: &Request) -> u32 {
164 let tokens = self.peek_prefix_cache(request);
165 self.hit_size_count += 1;
166 self.hit_size_sum += tokens as u64;
167
168 if tokens == 0 {
169 self.num_prefix_cache_misses += 1;
170 } else {
171 self.num_prefix_cache_hits += 1;
172 }
173 tokens
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with the given id and prompt length; the remaining
    /// constructor arguments are fixed values shared by every test.
    fn create_test_request(id: &str, prompt_tokens: u32) -> Request {
        Request::new(id.to_owned(), 0, 0.0, prompt_tokens, 50)
    }

    /// 16000 bytes / (16 tokens * 100 bytes per token) = 10 blocks.
    fn ten_block_manager(enable_prefix_caching: bool) -> KVCacheManager {
        KVCacheManager::new(16000, 16, 100, enable_prefix_caching)
    }

    #[test]
    fn test_kv_cache_manager_creation() {
        let manager = ten_block_manager(false);

        assert_eq!(manager.block_size, 16);
        assert_eq!(manager.total_blocks, 10);
        assert_eq!(manager.num_free_blocks(), 10);
        assert_eq!(manager.utilization(), 0.0);
    }

    #[test]
    fn test_block_allocation() {
        let mut manager = ten_block_manager(false);
        let mut request = create_test_request("req-1", 32);

        // 32 tokens at 16 tokens per block -> 2 blocks.
        let first = manager.allocate_blocks(&request, 32);
        assert!(first.is_some());

        let first = first.unwrap();
        assert_eq!(first.len(), 2);
        assert_eq!(manager.num_free_blocks(), 8);

        // Record the first allocation on the request, then ask for 16 more
        // tokens: only one extra block is required.
        request.kv_blocks.extend(first);
        request.num_computed_tokens = 32;

        let second = manager.allocate_blocks(&request, 16);
        assert!(second.is_some());
        assert_eq!(second.unwrap().len(), 1);
        assert_eq!(manager.num_free_blocks(), 7);
    }

    #[test]
    fn test_block_allocation_failure() {
        // Capacity for exactly one block.
        let mut manager = KVCacheManager::new(1600, 16, 100, false);
        assert_eq!(manager.total_blocks, 1);

        let request = create_test_request("req-1", 32);

        // Two blocks are needed but only one exists.
        let outcome = manager.allocate_blocks(&request, 32);
        assert!(outcome.is_none());
    }

    #[test]
    fn test_block_free() {
        let mut manager = ten_block_manager(false);
        let request = create_test_request("req-1", 32);

        let held = manager.allocate_blocks(&request, 32).unwrap();
        assert_eq!(manager.num_free_blocks(), 8);

        manager.free_blocks(&held);
        assert_eq!(manager.num_free_blocks(), 10);
        assert_eq!(manager.utilization(), 0.0);
    }

    #[test]
    fn test_utilization() {
        let mut manager = ten_block_manager(false);
        assert_eq!(manager.utilization(), 0.0);

        let request = create_test_request("req-1", 32);
        let held = manager.allocate_blocks(&request, 32).unwrap();

        // 2 of 10 blocks in use.
        let used_fraction = manager.utilization();
        assert!((used_fraction - 0.2).abs() < 1e-10);

        manager.free_blocks(&held);
        assert_eq!(manager.utilization(), 0.0);
    }

    #[test]
    fn test_prefix_caching() {
        let mut manager = ten_block_manager(true);

        // Nothing is cached yet, so the first lookup misses.
        let mut request1 = create_test_request("req-1", 16);
        request1.prompt_block_hashes = vec![12345];
        let cached = manager.query_prefix_cache(&request1);
        assert_eq!(cached, 0);
        assert_eq!(manager.num_prefix_cache_misses, 1);

        // Allocating for the request registers its block hash.
        let blocks1 = manager.allocate_blocks(&request1, 16).unwrap();
        request1.kv_blocks.extend(blocks1);

        // A second request with the same hash hits one full block.
        let mut request2 = create_test_request("req-2", 16);
        request2.prompt_block_hashes = vec![12345];
        let cached = manager.query_prefix_cache(&request2);
        assert_eq!(cached, 16);
        assert_eq!(manager.num_prefix_cache_hits, 1);

        // A different hash misses again.
        let mut request3 = create_test_request("req-3", 16);
        request3.prompt_block_hashes = vec![67890];
        let cached = manager.query_prefix_cache(&request3);
        assert_eq!(cached, 0);
        assert_eq!(manager.num_prefix_cache_misses, 2);
    }
}