1use super::block::Block;
2use crate::request::{BlockId, Request};
3use std::collections::HashMap;
4
5pub struct KVCacheManager {
7 block_size: u32,
9
10 total_blocks: u32,
12
13 blocks: Vec<Block>,
15
16 free_blocks: Vec<BlockId>,
18
19 enable_prefix_caching: bool,
21
22 prefix_cache: HashMap<u64, BlockId>,
24
25 pub num_prefix_cache_hits: u64,
27 pub num_prefix_cache_misses: u64,
28}
29
30impl KVCacheManager {
31 pub fn new(
33 kv_cache_capacity: u64,
34 block_size: u32,
35 kv_cache_bytes_per_token: u64,
36 enable_prefix_caching: bool,
37 ) -> Self {
38 let bytes_per_block = block_size as u64 * kv_cache_bytes_per_token;
39 let total_blocks = (kv_cache_capacity / bytes_per_block) as u32;
40
41 let blocks = (0..total_blocks).map(Block::new).collect();
42
43 let free_blocks = (0..total_blocks).collect();
44
45 Self {
46 block_size,
47 total_blocks,
48 blocks,
49 free_blocks,
50 enable_prefix_caching,
51 prefix_cache: HashMap::new(),
52 num_prefix_cache_hits: 0,
53 num_prefix_cache_misses: 0,
54 }
55 }
56
57 pub fn allocate_blocks(&mut self, request: &Request, num_tokens: u32) -> Option<Vec<BlockId>> {
60 let blocks_needed = self.calculate_blocks_needed(request, num_tokens);
61
62 if self.free_blocks.len() < blocks_needed {
63 return None; }
65
66 let mut allocated = Vec::new();
67 for _ in 0..blocks_needed {
68 let block_id = self.free_blocks.pop().unwrap();
69 self.blocks[block_id as usize].allocate();
70 allocated.push(block_id);
71 }
72
73 Some(allocated)
74 }
75
76 fn calculate_blocks_needed(&self, request: &Request, num_new_tokens: u32) -> usize {
78 let total_tokens = request.num_computed_tokens + num_new_tokens;
79 let total_blocks_needed =
80 ((total_tokens + self.block_size - 1) / self.block_size) as usize;
81 total_blocks_needed.saturating_sub(request.kv_blocks.len())
82 }
83
84 pub fn free_blocks(&mut self, block_ids: &[BlockId]) {
86 for &block_id in block_ids {
87 let block = &mut self.blocks[block_id as usize];
88 block.release();
89
90 if block.is_free {
91 self.free_blocks.push(block_id);
92 }
93 }
94 }
95
96 pub fn num_free_blocks(&self) -> usize {
98 self.free_blocks.len()
99 }
100
101 pub fn utilization(&self) -> f64 {
103 1.0 - (self.free_blocks.len() as f64 / self.total_blocks as f64)
104 }
105
106 pub fn check_prefix_cache(&mut self, request: &Request) -> u32 {
108 if !self.enable_prefix_caching {
109 return 0;
110 }
111
112 let prompt_hash = self.hash_prompt(&request.request_id, request.num_prompt_tokens);
114
115 if let Some(&_block_id) = self.prefix_cache.get(&prompt_hash) {
116 self.num_prefix_cache_hits += 1;
117 self.block_size.min(request.num_prompt_tokens)
120 } else {
121 self.num_prefix_cache_misses += 1;
122 if let Some(&first_block) = request.kv_blocks.first() {
124 self.prefix_cache.insert(prompt_hash, first_block);
125 }
126 0
127 }
128 }
129
130 fn hash_prompt(&self, request_id: &str, num_tokens: u32) -> u64 {
132 use std::collections::hash_map::DefaultHasher;
134 use std::hash::{Hash, Hasher};
135
136 let mut hasher = DefaultHasher::new();
137 request_id.hash(&mut hasher);
138 num_tokens.hash(&mut hasher);
139 hasher.finish()
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with the given id and prompt length; the remaining
    /// constructor arguments are fixed values that these tests do not vary.
    fn create_test_request(id: &str, prompt_tokens: u32) -> Request {
        Request::new(String::from(id), 0, 0.0, prompt_tokens, 50)
    }

    /// Manager with 16-token blocks at 100 bytes per token and a capacity of
    /// 16000 bytes, i.e. 16000 / (16 * 100) = 10 blocks.
    fn ten_block_manager(enable_prefix_caching: bool) -> KVCacheManager {
        KVCacheManager::new(16000, 16, 100, enable_prefix_caching)
    }

    #[test]
    fn test_kv_cache_manager_creation() {
        let cache = ten_block_manager(false);

        assert_eq!(cache.block_size, 16);
        assert_eq!(cache.total_blocks, 10);
        // A fresh manager has every block free and nothing in use.
        assert_eq!(cache.num_free_blocks(), 10);
        assert_eq!(cache.utilization(), 0.0);
    }

    #[test]
    fn test_block_allocation() {
        let mut cache = ten_block_manager(false);
        let mut req = create_test_request("req-1", 32);

        // 32 tokens in 16-token blocks need two blocks.
        let first = cache
            .allocate_blocks(&req, 32)
            .expect("initial allocation should succeed");
        assert_eq!(first.len(), 2);
        assert_eq!(cache.num_free_blocks(), 8);

        // Record the blocks on the request, then grow it by one more block.
        req.kv_blocks.extend(first);
        req.num_computed_tokens = 32;

        let second = cache
            .allocate_blocks(&req, 16)
            .expect("follow-up allocation should succeed");
        assert_eq!(second.len(), 1);
        assert_eq!(cache.num_free_blocks(), 7);
    }

    #[test]
    fn test_block_allocation_failure() {
        // 1600 / (16 * 100) = 1 block only.
        let mut cache = KVCacheManager::new(1600, 16, 100, false);
        assert_eq!(cache.total_blocks, 1);

        let req = create_test_request("req-1", 32);

        // Two blocks are required but only one exists.
        assert!(cache.allocate_blocks(&req, 32).is_none());
    }

    #[test]
    fn test_block_free() {
        let mut cache = ten_block_manager(false);
        let req = create_test_request("req-1", 32);

        let held = cache.allocate_blocks(&req, 32).unwrap();
        assert_eq!(cache.num_free_blocks(), 8);

        // Releasing the blocks restores the pool to its initial state.
        cache.free_blocks(&held);
        assert_eq!(cache.num_free_blocks(), 10);
        assert_eq!(cache.utilization(), 0.0);
    }

    #[test]
    fn test_utilization() {
        let mut cache = ten_block_manager(false);
        assert_eq!(cache.utilization(), 0.0);

        let req = create_test_request("req-1", 32);
        let held = cache.allocate_blocks(&req, 32).unwrap();

        // Two of ten blocks in use.
        let in_use = cache.utilization();
        assert!((in_use - 0.2).abs() < 1e-10);

        cache.free_blocks(&held);
        assert_eq!(cache.utilization(), 0.0);
    }

    #[test]
    fn test_prefix_caching() {
        let mut cache = ten_block_manager(true);

        let mut first_req = create_test_request("req-1", 16);
        let first_blocks = cache.allocate_blocks(&first_req, 16).unwrap();
        first_req.kv_blocks.extend(first_blocks);

        // First lookup: nothing cached yet, so it is a miss.
        let cached_tokens = cache.check_prefix_cache(&first_req);
        assert_eq!(cached_tokens, 0);
        assert_eq!(cache.num_prefix_cache_misses, 1);

        // Same id and prompt length: the second lookup hits.
        let second_req = create_test_request("req-1", 16);
        let cached_tokens = cache.check_prefix_cache(&second_req);
        assert_eq!(cached_tokens, 16);
        assert_eq!(cache.num_prefix_cache_hits, 1);
    }
}