1use super::block::Block;
2use crate::request::{BlockId, Request};
3use std::collections::HashMap;
4
5pub struct KVCacheManager {
7 block_size: u32,
9
10 total_blocks: u32,
12
13 blocks: Vec<Block>,
15
16 free_blocks: Vec<BlockId>,
18
19 enable_prefix_caching: bool,
21
22 prefix_cache: HashMap<u64, BlockId>,
24
25 pub num_prefix_cache_hits: u64,
27 pub num_prefix_cache_misses: u64,
28}
29
30impl KVCacheManager {
31 pub fn new(
33 kv_cache_capacity: u64,
34 block_size: u32,
35 kv_cache_bytes_per_token: u64,
36 enable_prefix_caching: bool,
37 ) -> Self {
38 let bytes_per_block = block_size as u64 * kv_cache_bytes_per_token;
39 let total_blocks = (kv_cache_capacity / bytes_per_block) as u32;
40
41 let blocks = (0..total_blocks).map(Block::new).collect();
42
43 let free_blocks = (0..total_blocks).collect();
44
45 Self {
46 block_size,
47 total_blocks,
48 blocks,
49 free_blocks,
50 enable_prefix_caching,
51 prefix_cache: HashMap::new(),
52 num_prefix_cache_hits: 0,
53 num_prefix_cache_misses: 0,
54 }
55 }
56
57 pub fn allocate_blocks(&mut self, request: &Request, num_tokens: u32) -> Option<Vec<BlockId>> {
60 let blocks_needed = self.calculate_blocks_needed(request, num_tokens);
61
62 if self.free_blocks.len() < blocks_needed {
63 return None; }
65
66 let mut allocated = Vec::new();
67 for _ in 0..blocks_needed {
68 let block_id = self.free_blocks.pop().unwrap();
69 self.blocks[block_id as usize].allocate();
70 allocated.push(block_id);
71 }
72
73 Some(allocated)
74 }
75
76 fn calculate_blocks_needed(&self, request: &Request, num_new_tokens: u32) -> usize {
78 let total_tokens = request.num_computed_tokens + num_new_tokens;
79 let total_blocks_needed = ((total_tokens + self.block_size - 1) / self.block_size) as usize;
80 total_blocks_needed.saturating_sub(request.kv_blocks.len())
81 }
82
83 pub fn free_blocks(&mut self, block_ids: &[BlockId]) {
85 for &block_id in block_ids {
86 let block = &mut self.blocks[block_id as usize];
87 block.release();
88
89 if block.is_free {
90 self.free_blocks.push(block_id);
91 }
92 }
93 }
94
95 pub fn num_free_blocks(&self) -> usize {
97 self.free_blocks.len()
98 }
99
100 pub fn utilization(&self) -> f64 {
102 1.0 - (self.free_blocks.len() as f64 / self.total_blocks as f64)
103 }
104
105 pub fn check_prefix_cache(&mut self, request: &Request) -> u32 {
107 if !self.enable_prefix_caching {
108 return 0;
109 }
110
111 let prompt_hash = self.hash_prompt(&request.request_id, request.num_prompt_tokens);
113
114 if let Some(&_block_id) = self.prefix_cache.get(&prompt_hash) {
115 self.num_prefix_cache_hits += 1;
116 self.block_size.min(request.num_prompt_tokens)
119 } else {
120 self.num_prefix_cache_misses += 1;
121 if let Some(&first_block) = request.kv_blocks.first() {
123 self.prefix_cache.insert(prompt_hash, first_block);
124 }
125 0
126 }
127 }
128
129 fn hash_prompt(&self, request_id: &str, num_tokens: u32) -> u64 {
131 use std::collections::hash_map::DefaultHasher;
133 use std::hash::{Hash, Hasher};
134
135 let mut hasher = DefaultHasher::new();
136 request_id.hash(&mut hasher);
137 num_tokens.hash(&mut hasher);
138 hasher.finish()
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with the given id and prompt length; the remaining
    /// `Request::new` arguments are fixed placeholder values.
    fn create_test_request(id: &str, prompt_tokens: u32) -> Request {
        Request::new(id.to_owned(), 0, 0.0, prompt_tokens, 50)
    }

    /// 16 tokens * 100 bytes = 1600 bytes per block, so 16000 bytes of
    /// capacity yields a pool of exactly 10 blocks.
    fn ten_block_manager(enable_prefix_caching: bool) -> KVCacheManager {
        KVCacheManager::new(16000, 16, 100, enable_prefix_caching)
    }

    #[test]
    fn test_kv_cache_manager_creation() {
        let cache = ten_block_manager(false);

        assert_eq!(cache.block_size, 16);
        assert_eq!(cache.total_blocks, 10);
        assert_eq!(cache.num_free_blocks(), 10);
        assert_eq!(cache.utilization(), 0.0);
    }

    #[test]
    fn test_block_allocation() {
        let mut cache = ten_block_manager(false);
        let mut req = create_test_request("req-1", 32);

        // 32 tokens at 16 tokens per block -> 2 blocks.
        let first_batch = cache
            .allocate_blocks(&req, 32)
            .expect("first allocation should succeed");
        assert_eq!(first_batch.len(), 2);
        assert_eq!(cache.num_free_blocks(), 8);

        // Record the allocation on the request, then grow it by one block.
        req.kv_blocks.extend(first_batch);
        req.num_computed_tokens = 32;

        let second_batch = cache
            .allocate_blocks(&req, 16)
            .expect("second allocation should succeed");
        assert_eq!(second_batch.len(), 1);
        assert_eq!(cache.num_free_blocks(), 7);
    }

    #[test]
    fn test_block_allocation_failure() {
        // Capacity for a single block only.
        let mut cache = KVCacheManager::new(1600, 16, 100, false);
        assert_eq!(cache.total_blocks, 1);

        // 32 tokens need two blocks, which the pool cannot provide.
        let req = create_test_request("req-1", 32);
        let outcome = cache.allocate_blocks(&req, 32);

        assert!(outcome.is_none());
    }

    #[test]
    fn test_block_free() {
        let mut cache = ten_block_manager(false);
        let req = create_test_request("req-1", 32);

        let held = cache.allocate_blocks(&req, 32).unwrap();
        assert_eq!(cache.num_free_blocks(), 8);

        cache.free_blocks(&held);
        assert_eq!(cache.num_free_blocks(), 10);
        assert_eq!(cache.utilization(), 0.0);
    }

    #[test]
    fn test_utilization() {
        let mut cache = ten_block_manager(false);
        assert_eq!(cache.utilization(), 0.0);

        let req = create_test_request("req-1", 32);
        let held = cache.allocate_blocks(&req, 32).unwrap();

        // 2 of 10 blocks in use.
        let in_use = cache.utilization();
        assert!((in_use - 0.2).abs() < 1e-10);

        cache.free_blocks(&held);
        assert_eq!(cache.utilization(), 0.0);
    }

    #[test]
    fn test_prefix_caching() {
        let mut cache = ten_block_manager(true);

        // First lookup: nothing cached yet, so it is a miss that records the
        // request's first block.
        let mut first_req = create_test_request("req-1", 16);
        let first_blocks = cache.allocate_blocks(&first_req, 16).unwrap();
        first_req.kv_blocks.extend(first_blocks);

        let cached_tokens = cache.check_prefix_cache(&first_req);
        assert_eq!(cached_tokens, 0);
        assert_eq!(cache.num_prefix_cache_misses, 1);

        // Second lookup with the same id and prompt length hits.
        let second_req = create_test_request("req-1", 16);
        let cached_tokens = cache.check_prefix_cache(&second_req);
        assert_eq!(cached_tokens, 16);
        assert_eq!(cache.num_prefix_cache_hits, 1);
    }
}