1use crate::mocker::evictor::LRUEvictor;
37use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost};
38use crate::mocker::sequence::ActiveSequence;
39use crate::tokens::blocks::UniqueBlock;
40use derive_getters::Getters;
41use std::collections::{HashMap, HashSet};
42use tokio::sync::mpsc;
43
44#[derive(Getters)]
45pub struct KvManager {
46 #[getter(copy)]
47 max_capacity: usize,
48
49 #[getter(copy)]
50 block_size: usize,
51
52 active_blocks: HashMap<UniqueBlock, usize>,
53
54 inactive_blocks: LRUEvictor<UniqueBlock>,
55
56 all_blocks: HashSet<UniqueBlock>,
57
58 move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>,
59}
60
61impl KvManager {
62 pub fn new(max_capacity: usize, block_size: usize) -> Self {
63 Self::new_with_sender(max_capacity, block_size, None)
64 }
65
66 pub fn new_with_sender(
67 max_capacity: usize,
68 block_size: usize,
69 move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>,
70 ) -> Self {
71 let active_blocks = HashMap::new();
72 let inactive_blocks = LRUEvictor::default();
73 let all_blocks = HashSet::new();
74
75 KvManager {
76 max_capacity,
77 block_size,
78 active_blocks,
79 inactive_blocks,
80 all_blocks,
81 move_block_response_tx,
82 }
83 }
84
85 fn send_block_response(
87 &self,
88 mut blocks: Vec<u64>,
89 reverse: bool,
90 store: bool,
91 parent_hash: Option<u64>,
92 ) {
93 if let Some(ref tx) = self.move_block_response_tx
94 && !blocks.is_empty()
95 {
96 if reverse {
97 blocks.reverse();
98 }
99 let response = if store {
100 MoveBlockResponse::Store(blocks, parent_hash)
101 } else {
102 MoveBlockResponse::Remove(blocks)
103 };
104 tx.send(response).unwrap();
105 }
106 }
107
108 pub fn process(&mut self, event: &MoveBlock) -> bool {
110 match event {
111 MoveBlock::Use(hashes) => {
112 let mut blocks_stored = Vec::<u64>::new();
113
114 let mut parent_block: Option<&UniqueBlock> = None;
115 for hash in hashes {
116 if let Some(ref_count) = self.active_blocks.get_mut(hash) {
118 *ref_count += 1;
120 parent_block = Some(hash);
121 continue;
122 }
123
124 if self.inactive_blocks.remove(hash) {
126 self.active_blocks.insert(hash.clone(), 1);
128 parent_block = Some(hash);
129 continue;
130 }
131
132 let active_count = self.active_blocks.len();
134 let inactive_count = self.inactive_blocks.len();
135
136 if active_count + inactive_count >= self.max_capacity {
138 let Some(evicted) = self.inactive_blocks.evict() else {
139 return false;
140 };
141 self.all_blocks.remove(&evicted);
142 if let UniqueBlock::FullBlock(evicted_full_block) = evicted {
143 self.send_block_response(vec![evicted_full_block], false, false, None);
144 }
145 }
146
147 self.active_blocks.insert(hash.clone(), 1);
149 self.all_blocks.insert(hash.clone());
150 if self.move_block_response_tx.is_some()
151 && let UniqueBlock::FullBlock(stored_full_block) = hash
152 {
153 blocks_stored.push(*stored_full_block);
154 }
155 }
156
157 let parent_hash = match parent_block {
158 None => None,
159 Some(UniqueBlock::FullBlock(block)) => Some(*block),
160 Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"),
161 };
162 self.send_block_response(blocks_stored, false, true, parent_hash);
163 }
164
165 MoveBlock::Destroy(hashes) => {
166 let mut blocks_destroyed = Vec::<u64>::new();
167
168 for hash in hashes.iter().rev() {
170 self.active_blocks.remove(hash).unwrap();
171 assert!(self.all_blocks.remove(hash));
173
174 if self.move_block_response_tx.is_some()
176 && let UniqueBlock::FullBlock(destroyed_full_block) = hash
177 {
178 blocks_destroyed.push(*destroyed_full_block);
179 }
180 }
181
182 self.send_block_response(blocks_destroyed, true, false, None);
183 }
184
185 MoveBlock::Deref(hashes) => {
186 for hash in hashes.iter().rev() {
188 if let Some(ref_count) = self.active_blocks.get_mut(hash) {
190 if *ref_count == 0 {
191 panic!("Negative reference count would be encountered after Deref.");
192 }
193 *ref_count -= 1;
194
195 if *ref_count == 0 {
197 self.active_blocks.remove(hash);
198 self.inactive_blocks.insert(hash.clone());
200 }
201 }
202 }
203 }
204
205 MoveBlock::Promote(uuid, hash, parent_hash) => {
206 let uuid_block = UniqueBlock::PartialBlock(*uuid);
207 let hash_block = UniqueBlock::FullBlock(*hash);
208
209 let Some(ref_count) = self.active_blocks.remove(&uuid_block) else {
210 let in_all_blocks = self.all_blocks.contains(&uuid_block);
211 panic!(
212 "Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}"
213 );
214 };
215
216 self.active_blocks.insert(hash_block.clone(), ref_count);
218
219 assert!(self.all_blocks.remove(&uuid_block));
221 self.all_blocks.insert(hash_block);
222 self.send_block_response(vec![*hash], false, true, *parent_hash);
223 }
224 }
225
226 true
228 }
229
230 pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize {
232 blocks
233 .iter()
234 .filter(|&block| !self.all_blocks.contains(block))
236 .count()
237 }
238
239 pub fn current_capacity(&self) -> usize {
241 let active = self.active_blocks.len();
242 let inactive = self.inactive_blocks.len();
243 active + inactive
244 }
245
246 pub fn current_capacity_perc(&self) -> f64 {
248 let current = self.current_capacity() as f64;
249 current / self.max_capacity as f64
250 }
251
252 pub fn num_active_blocks(&self) -> usize {
254 self.active_blocks.len()
255 }
256
257 pub fn get_active_perc(&self) -> f64 {
259 self.active_blocks.len() as f64 / self.max_capacity as f64
260 }
261
262 pub fn num_inactive_blocks(&self) -> usize {
264 self.inactive_blocks.len()
265 }
266
267 pub fn get_inactive_blocks(&self) -> Vec<&UniqueBlock> {
269 self.inactive_blocks.keys().collect()
270 }
271
272 pub fn get_active_blocks(&self) -> Vec<&UniqueBlock> {
274 self.active_blocks.keys().collect()
275 }
276
277 pub fn get_prefill_cost(&self, sequence: &ActiveSequence) -> PrefillCost {
279 let seq_blocks = sequence.unique_blocks();
280 let new_blocks = self.probe_new_blocks(seq_blocks);
281 let overlap_blocks = seq_blocks.len() - new_blocks;
282 let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size;
283
284 PrefillCost {
285 new_blocks,
286 new_tokens,
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Filling the cache with active blocks leaves nothing evictable, so the
    /// next new block must be rejected without disturbing resident blocks.
    #[test]
    fn test_failure_on_max_capacity() {
        let mut manager = KvManager::new(10, 16);

        fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> bool {
            let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
            manager.process(&MoveBlock::Use(blocks))
        }

        let response = use_blocks(&mut manager, (0..10).collect());
        assert!(response, "Expected success response");

        assert_eq!(manager.current_capacity(), 10);

        let response = use_blocks(&mut manager, vec![10]);
        assert!(
            !response,
            "Expected failure response when exceeding max capacity"
        );

        // The rejected block was not admitted and nothing was displaced.
        assert_eq!(manager.current_capacity(), 10);
        assert_eq!(manager.probe_new_blocks(&[UniqueBlock::FullBlock(10)]), 1);
    }

    /// Walks blocks through use / destroy / deref / eviction and checks both
    /// the manager's pools and the exact event stream at every step.
    #[test]
    fn test_block_lifecycle_stringent() {
        let (tx, mut rx) = mpsc::unbounded_channel::<MoveBlockResponse>();

        let mut manager = KvManager::new_with_sender(10, 16, Some(tx));

        fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
            let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
            manager.process(&MoveBlock::Use(blocks));
        }

        fn destroy_blocks(manager: &mut KvManager, ids: Vec<u64>) {
            let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
            manager.process(&MoveBlock::Destroy(blocks));
        }

        fn deref_blocks(manager: &mut KvManager, ids: Vec<u64>) {
            let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
            manager.process(&MoveBlock::Deref(blocks));
        }

        /// Pops the next event and checks its kind ("Store" / "Remove") and blocks.
        fn assert_block_response(
            rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
            expected_type: &str,
            expected_blocks: Vec<u64>,
            description: &str,
        ) {
            let response = rx
                .try_recv()
                .unwrap_or_else(|_| panic!("Expected {expected_type} response {description}"));

            match (&response, expected_type) {
                (MoveBlockResponse::Store(blocks, _parent_hash), "Store") => {
                    assert_eq!(
                        blocks.len(),
                        expected_blocks.len(),
                        "Expected {} blocks in Store response {}",
                        expected_blocks.len(),
                        description
                    );
                    assert_eq!(
                        *blocks, expected_blocks,
                        "Store blocks don't match expected {description}"
                    );
                }
                (MoveBlockResponse::Remove(blocks), "Remove") => {
                    assert_eq!(
                        blocks.len(),
                        expected_blocks.len(),
                        "Expected {} blocks in Remove response {}",
                        expected_blocks.len(),
                        description
                    );
                    assert_eq!(
                        *blocks, expected_blocks,
                        "Remove blocks don't match expected {description}"
                    );
                }
                _ => panic!("Expected {expected_type} response, got {response:?} {description}"),
            }
        }

        /// Checks that no event is pending.
        fn assert_no_response(
            rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
            description: &str,
        ) {
            assert!(rx.try_recv().is_err(), "Expected no response {description}");
        }

        /// Checks the exact set of active blocks and their reference counts.
        fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) {
            assert_eq!(
                manager.active_blocks().len(),
                expected_blocks.len(),
                "Active blocks count doesn't match expected"
            );

            for &(id, ref_count) in expected_blocks {
                let block = UniqueBlock::FullBlock(id);
                assert!(
                    manager.active_blocks().contains_key(&block),
                    "Block {id} not found in active blocks",
                );
                assert_eq!(
                    manager.active_blocks().get(&block),
                    Some(&ref_count),
                    "Block {id} has wrong reference count",
                );
            }
        }

        /// Checks the inactive pool size and that it contains `expected_blocks`.
        fn assert_inactive_blocks(
            manager: &KvManager,
            expected_size: usize,
            expected_blocks: &[u64],
        ) {
            let inactive_blocks = manager.get_inactive_blocks();
            let inactive_blocks_count = manager.inactive_blocks().len();

            assert_eq!(
                inactive_blocks_count, expected_size,
                "Inactive blocks count doesn't match expected"
            );

            for &id in expected_blocks {
                let block = UniqueBlock::FullBlock(id);
                assert!(
                    inactive_blocks.iter().any(|&b| *b == block),
                    "Block {id} not found in inactive blocks",
                );
            }
        }

        use_blocks(&mut manager, (0..5).collect());
        assert_block_response(&mut rx, "Store", vec![0, 1, 2, 3, 4], "after first use");

        // 0 and 1 are shared; only 5 and 6 are new.
        use_blocks(&mut manager, vec![0, 1, 5, 6]);
        assert_block_response(&mut rx, "Store", vec![5, 6], "after second use");

        assert_active_blocks(
            &manager,
            &[(0, 2), (1, 2), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)],
        );

        destroy_blocks(&mut manager, vec![4]);
        assert_block_response(&mut rx, "Remove", vec![4], "after destroy block 4");

        // Deref never emits events; 2 and 3 drop to zero and become inactive.
        deref_blocks(&mut manager, vec![0, 1, 2, 3]);
        assert_no_response(&mut rx, "after deref operation");

        assert_inactive_blocks(&manager, 2, &[3, 2]);
        assert_active_blocks(&manager, &[(0, 1), (1, 1), (5, 1), (6, 1)]);

        destroy_blocks(&mut manager, vec![6]);
        assert_block_response(&mut rx, "Remove", vec![6], "after destroy block 6");

        deref_blocks(&mut manager, vec![0, 1, 5]);
        assert_no_response(&mut rx, "after second deref operation");

        assert_inactive_blocks(&manager, 5, &[0, 1, 2, 3, 5]);
        assert_active_blocks(&manager, &[]);

        // 0, 1, 2 are revived from the inactive pool; only 7, 8, 9 are stored.
        use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]);
        assert_block_response(&mut rx, "Store", vec![7, 8, 9], "after [7, 8, 9] use");

        assert_inactive_blocks(&manager, 2, &[3, 5]);
        assert_active_blocks(&manager, &[(0, 1), (1, 1), (2, 1), (7, 1), (8, 1), (9, 1)]);

        let blocks_to_check: Vec<UniqueBlock> = vec![0, 1, 2, 3, 4]
            .into_iter()
            .map(UniqueBlock::FullBlock)
            .collect();
        assert_eq!(manager.probe_new_blocks(&blocks_to_check), 1);

        // The cache fills up while storing 12, so the LRU inactive block (3) is evicted.
        use_blocks(&mut manager, vec![10, 11, 12]);
        assert_block_response(&mut rx, "Remove", vec![3], "after block 3 eviction");
        assert_block_response(&mut rx, "Store", vec![10, 11, 12], "after [10, 11, 12] use");

        assert_inactive_blocks(&manager, 1, &[5]);

        use_blocks(&mut manager, vec![13]);
        assert_block_response(&mut rx, "Remove", vec![5], "after block 5 eviction");
        assert_block_response(&mut rx, "Store", vec![13], "after block 13 use");
        assert_no_response(&mut rx, "after final use");
    }
}