1use bytes::Bytes;
47use async_trait::async_trait;
49use serde::{Deserialize, Serialize};
50use std::{
51 cell::RefCell,
52 collections::{HashMap, HashSet, VecDeque},
53 iter,
54 rc::Rc,
55 sync::OnceLock,
56 thread::JoinHandle,
57 time::{Duration, Instant},
58};
59use tokio::sync::{broadcast, mpsc, oneshot};
60use tokio_util::sync::CancellationToken;
61use tracing as log;
62use xxhash_rust::xxh3;
63
/// Seed passed to XXH3 by [`compute_hash`].
pub const XXH3_SEED: u64 = 1_337;
65
66use crate::kv_router::protocols::*;
67
68#[derive(Debug, thiserror::Error)]
70pub enum KvRouterError {
71 #[error("Block not found")]
72 BlockNotFound,
73
74 #[error("Indexer is offline")]
75 IndexerOffline,
76
77 #[error("Indexer is dropped request")]
78 IndexerDroppedRequest,
79}
80
81pub type WorkerId = i64;
83
84type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
86
87pub fn compute_hash(data: &[u8]) -> u64 {
88 xxh3::xxh3_64_with_seed(data, XXH3_SEED)
89}
90
91pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
101 LocalBlockHash(compute_hash(data))
102}
103
104pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec<LocalBlockHash> {
124 tokens
125 .chunks_exact(kv_block_size) .map(|chunk| {
127 let bytes: Vec<u8> = chunk
128 .iter()
129 .flat_map(|&num| num.to_le_bytes()) .collect();
131
132 compute_block_hash(&Bytes::from(bytes)) })
134 .collect()
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct RouterEvent {
140 worker_id: WorkerId,
142 event: KvCacheEvent,
144}
145
146impl RouterEvent {
147 pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
158 Self { worker_id, event }
159 }
160}
161
162struct RadixBlock {
164 children: HashMap<LocalBlockHash, SharedRadixBlock>,
166 workers: HashSet<WorkerId>,
168 recent_uses: VecDeque<Instant>,
170}
171
172impl RadixBlock {
173 pub fn new() -> Self {
179 Self {
180 children: HashMap::new(),
181 workers: HashSet::new(),
182 recent_uses: VecDeque::new(),
183 }
184 }
185}
186
187pub struct RadixTree {
188 root: SharedRadixBlock,
191
192 lookup: HashMap<WorkerId, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
201 expiration_duration: Option<Duration>,
203}
204
205impl Default for RadixTree {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
impl RadixTree {
    /// Creates an empty tree.
    ///
    /// When `expiration_duration` is `Some`, [`RadixTree::find_matches`] records a timestamp
    /// on every block it visits and reports how many earlier visits are still inside that
    /// window. With `None`, no access history is kept.
    pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self {
        Self {
            root: Rc::new(RefCell::new(RadixBlock::new())),
            lookup: HashMap::new(),
            expiration_duration,
        }
    }

    /// Creates an empty tree without access-frequency tracking.
    pub fn new() -> Self {
        Self::new_with_frequency(None)
    }

    /// Walks the tree from the root along `sequence` and scores every worker by the number
    /// of visited blocks it holds.
    ///
    /// The walk stops at the first hash that has no matching child. When `early_exit` is
    /// set, it also stops right after scoring a block held by exactly one worker.
    ///
    /// If the tree has an expiration duration, each visited block first discards access
    /// timestamps older than that duration. It then reports the number that remain through
    /// `OverlapScores::add_frequency`, which ignores zeros, and records the current access.
    /// Note that this mutates the nodes even though `self` is borrowed immutably.
    pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
        let mut scores = OverlapScores::new();
        let mut current = self.root.clone();
        // A single timestamp for the whole query, shared by every visited block.
        let now = Instant::now();
        for block_hash in sequence {
            // Clone the child handle inside a scope so the shared borrow of `current` ends
            // before `current` is reassigned below.
            let next_block = {
                let current_borrow = current.borrow();
                current_borrow.children.get(&block_hash).cloned()
            };

            if let Some(block) = next_block {
                scores.update_scores(&block.borrow().workers);

                if let Some(expiration_duration) = self.expiration_duration {
                    let mut block_mut = block.borrow_mut();

                    // Timestamps are pushed to the back, so pruning from the front can stop
                    // at the first one still inside the window.
                    while let Some(access_time) = block_mut.recent_uses.front() {
                        if now.duration_since(*access_time) > expiration_duration {
                            block_mut.recent_uses.pop_front();
                        } else {
                            break;
                        }
                    }
                    // Report prior uses still in the window, then record this one.
                    scores.add_frequency(block_mut.recent_uses.len());
                    block_mut.recent_uses.push_back(now);
                }

                // Checked after scoring, so the single-worker block itself still counts.
                if early_exit && block.borrow().workers.len() == 1 {
                    break;
                }

                current = block;
            } else {
                // No child for this hash: the matched prefix ends here.
                break;
            }
        }

        scores
    }

    /// Applies one KV-cache event reported by a worker.
    ///
    /// `Stored`:
    /// - The walk starts at the node this worker stored under `parent_hash`, or at the root
    ///   when `parent_hash` is `None`.
    /// - It then descends through `blocks` in order. For each block it reuses the existing
    ///   child for that `tokens_hash`. Failing that, it reuses the node the worker already
    ///   recorded under that `block_hash`, or else creates a new node and attaches it.
    /// - The worker is added to the node, and the node is recorded in the worker's lookup
    ///   under `block_hash`.
    /// - If the parent cannot be found, the whole event is skipped with a warning.
    ///
    /// `Removed`:
    /// - For each hash, the worker is removed from the node it stored under that hash, and
    ///   the lookup entry is dropped.
    /// - A node left with no workers has its children cleared.
    /// - Unknown hashes are skipped with a warning.
    pub fn apply_event(&mut self, event: RouterEvent) {
        let (worker_id, event) = (event.worker_id, event.event);
        let (id, op) = (event.event_id, event.data);
        // Logged for every event kind, removals included.
        log::debug!(id, "Store operation: {:?}", op);

        // NOTE(review): this creates an (empty) lookup for the worker even when the event
        // is later skipped, e.g. a removal from a worker that never stored anything.
        let worker_lookup = self.lookup.entry(worker_id).or_default();

        match op {
            KvCacheEventData::Stored(op) => {
                let current = match op.parent_hash {
                    // The parent must be a block this same worker stored earlier.
                    Some(parent) => worker_lookup.get(&parent),
                    None => Some(&self.root),
                };

                let mut current = match current {
                    Some(current) => current.clone(),
                    None => {
                        log::warn!(
                            worker_id = worker_id.to_string(),
                            id,
                            parent_hash = ?op.parent_hash,
                            "Failed to find parent block; skipping store operation"
                        );
                        return;
                    }
                };

                for block_id in op.blocks {
                    // NOTE(review): `inner` stays mutably borrowed until `drop(inner)` below.
                    // If `block` were the same node as `current` (a block hash aliasing the
                    // current node in the worker's lookup), `block.borrow_mut()` would
                    // panic. Confirm events can never alias this way.
                    let mut inner = current.borrow_mut();
                    let block = match inner.children.get(&block_id.tokens_hash) {
                        Some(block) => block.clone(),
                        None => {
                            // Reuse the node this worker already recorded under this block
                            // hash if there is one; otherwise start a fresh node.
                            let new_block = worker_lookup
                                .get(&block_id.block_hash)
                                .cloned()
                                .unwrap_or_else(|| Rc::new(RefCell::new(RadixBlock::new())));

                            inner
                                .children
                                .insert(block_id.tokens_hash, new_block.clone());

                            new_block
                        }
                    };

                    block.borrow_mut().workers.insert(worker_id);

                    worker_lookup.insert(block_id.block_hash, block.clone());

                    // Release the parent before descending into the child.
                    drop(inner);

                    current = block;
                }
            }
            KvCacheEventData::Removed(remove) => {
                for block in remove.block_hashes {
                    let entry = match worker_lookup.get(&block) {
                        Some(entry) => entry.clone(),
                        None => {
                            log::warn!(
                                worker_id = worker_id.to_string(),
                                id,
                                "Failed to find block to remove; skipping remove operation"
                            );
                            continue;
                        }
                    };

                    let mut guard = entry.borrow_mut();
                    guard.workers.remove(&worker_id);
                    if guard.workers.is_empty() {
                        // NOTE(review): the now-empty node is not detached from its parent's
                        // `children`, so it stays reachable from the root and still counts
                        // in `find_matches` frequency tracking.
                        guard.children.clear();
                    }
                    worker_lookup.remove(&block);
                }
            }
        }
    }

    /// Forgets `worker`: removes it from every node recorded in its lookup and then drops
    /// the lookup itself.
    ///
    /// NOTE(review): nodes left without workers stay in the tree, and unlike a `Removed`
    /// event their children are not cleared.
    pub fn remove_worker(&mut self, worker: WorkerId) {
        if let Some((_, blocks)) = self.lookup.remove_entry(&worker) {
            blocks.iter().for_each(|(_, block)| {
                block.borrow_mut().workers.remove(&worker);
            });
        }
    }
}
388
389#[derive(Debug, Clone, Serialize, Deserialize)]
391pub struct OverlapScores {
392 pub scores: HashMap<WorkerId, u32>,
394 pub frequencies: Vec<usize>,
396}
397
398impl Default for OverlapScores {
399 fn default() -> Self {
400 Self::new()
401 }
402}
403
404impl OverlapScores {
405 pub fn new() -> Self {
411 Self {
412 scores: HashMap::new(),
413 frequencies: Vec::with_capacity(32),
414 }
415 }
416
417 pub fn update_scores(&mut self, workers: &HashSet<WorkerId>) {
423 for worker in workers {
424 let score = self.scores.entry(*worker).or_insert(0);
425 *score += 1;
426 }
427 }
428
429 pub fn add_frequency(&mut self, frequency: usize) {
431 if frequency != 0 {
432 self.frequencies
433 .last()
434 .inspect(|elem| debug_assert!(**elem >= frequency));
435 self.frequencies.push(frequency);
436 }
437 }
438}
439
440pub struct MatchRequest {
442 sequence: Vec<LocalBlockHash>,
444 early_exit: bool,
446 resp: oneshot::Sender<OverlapScores>,
448}
449
450#[async_trait]
451pub trait KvIndexerInterface {
452 async fn find_matches(
462 &self,
463 sequence: Vec<LocalBlockHash>,
464 ) -> Result<OverlapScores, KvRouterError>;
465
466 async fn find_matches_for_request(
476 &self,
477 tokens: &[u32],
478 ) -> Result<OverlapScores, KvRouterError>;
479
480 async fn apply_event(&mut self, event: RouterEvent);
486
487 async fn remove_worker(&mut self, worker: WorkerId);
493
494 fn shutdown(&mut self);
496}
497
498pub struct KvIndexer {
500 cancel: CancellationToken,
502 event_tx: mpsc::Sender<RouterEvent>,
504 match_tx: mpsc::Sender<MatchRequest>,
506 remove_worker_tx: mpsc::Sender<WorkerId>,
508 task: OnceLock<std::thread::JoinHandle<()>>,
510 kv_block_size: usize,
512}
513
514impl KvIndexer {
515 pub fn new_with_frequency(
526 token: CancellationToken,
527 expiration_duration: Option<Duration>,
528 kv_block_size: usize,
529 ) -> Self {
530 let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
531 let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
532 let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
533 let cancel_clone = token.clone();
534 let task = std::thread::spawn(move || {
535 let runtime = tokio::runtime::Builder::new_multi_thread()
537 .worker_threads(1) .enable_all()
539 .build()
540 .unwrap();
541
542 let local_set = tokio::task::LocalSet::new();
543
544 runtime.block_on(local_set.run_until(async move {
545 tokio::task::spawn_local(async move {
546 let cancel = cancel_clone;
547 let mut match_rx = match_rx;
548 let mut event_rx = event_rx;
549 let mut remove_worker_rx = remove_worker_rx;
550 let mut trie = RadixTree::new_with_frequency(expiration_duration);
551 loop {
552 tokio::select! {
553 biased;
554
555 Some(worker) = remove_worker_rx.recv() => {
556 trie.remove_worker(worker);
557 }
558
559 Some(req) = match_rx.recv() => {
560 let matches = trie.find_matches(req.sequence, req.early_exit);
561 let _ = req.resp.send(matches);
562 }
563
564 _ = cancel.cancelled() => {
565 log::debug!("KvCacheIndexer progress loop shutting down");
566 return;
567 }
568
569 Some(event) = event_rx.recv() => {
570 trie.apply_event(event);
571 }
572 }
573 }
574 })
575 .await
576 .unwrap()
577 }));
578
579 log::debug!("KvCacheIndexer task completed");
580 });
581
582 let once = OnceLock::new();
583 once.set(task).unwrap();
584
585 Self {
586 cancel: token,
587 event_tx,
588 match_tx,
589 remove_worker_tx,
590 task: once,
591 kv_block_size,
592 }
593 }
594
595 pub fn block_size(&self) -> usize {
596 self.kv_block_size
597 }
598
599 pub fn new(token: CancellationToken, kv_block_size: usize) -> Self {
600 Self::new_with_frequency(token, None, kv_block_size)
601 }
602
603 pub fn event_sender(&self) -> mpsc::Sender<RouterEvent> {
609 self.event_tx.clone()
610 }
611}
612
613#[async_trait]
614impl KvIndexerInterface for KvIndexer {
615 async fn find_matches(
616 &self,
617 sequence: Vec<LocalBlockHash>,
618 ) -> Result<OverlapScores, KvRouterError> {
619 let (resp_tx, resp_rx) = oneshot::channel();
620 let req = MatchRequest {
621 sequence,
622 early_exit: false,
623 resp: resp_tx,
624 };
625
626 if let Err(e) = self.match_tx.send(req).await {
627 log::error!(
628 "Failed to send match request: {:?}; the indexer maybe offline",
629 e
630 );
631 return Err(KvRouterError::IndexerOffline);
632 }
633
634 resp_rx
635 .await
636 .map_err(|_| KvRouterError::IndexerDroppedRequest)
637 }
638
639 async fn find_matches_for_request(
640 &self,
641 tokens: &[u32],
642 ) -> Result<OverlapScores, KvRouterError> {
643 log::debug!(
644 "Finding matches for request tokens: {:?} / len: {}",
645 tokens,
646 tokens.len()
647 );
648 let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
649 log::debug!("Computed sequence: {:?}", sequence);
650 self.find_matches(sequence).await
651 }
652
653 async fn apply_event(&mut self, event: RouterEvent) {
654 self.event_tx.send(event).await.unwrap();
655 }
656
657 async fn remove_worker(&mut self, worker: WorkerId) {
658 self.remove_worker_tx.send(worker).await.unwrap();
659 }
660
661 fn shutdown(&mut self) {
662 self.cancel.cancel();
663 if let Some(task) = self.task.take() {
664 task.join().expect("Failed to join kv indexer task");
665 }
666 }
667}
668
669#[derive(Debug, Clone)]
670pub struct ShardedMatchRequest {
671 sequence: Vec<LocalBlockHash>,
672 early_exit: bool,
673 resp: mpsc::Sender<OverlapScores>,
674}
675
676pub struct KvIndexerSharded {
678 cancel: CancellationToken,
680 kv_block_size: usize,
682 worker_assignments: HashMap<WorkerId, usize>,
683 worker_counts: Vec<usize>,
684
685 event_tx: Vec<mpsc::Sender<RouterEvent>>,
686 request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
687 remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
688 tasks: Vec<JoinHandle<()>>,
689}
690
691impl KvIndexerSharded {
692 pub fn new_with_frequency(
704 token: CancellationToken,
705 num_shards: usize,
706 expiration_duration: Option<Duration>,
707 kv_block_size: usize,
708 ) -> Self {
709 let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
710 let worker_counts: Vec<usize> = vec![0; num_shards];
711
712 let mut event_tx = Vec::new();
713 let mut remove_worker_tx = Vec::new();
714 let mut tasks = Vec::new();
715
716 let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);
717
718 for _ in 0..num_shards {
719 let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
720 let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
721 mpsc::channel::<WorkerId>(16);
722 let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
723 let cancel = token.clone();
724
725 event_tx.push(shard_event_tx);
726 remove_worker_tx.push(shard_remove_worker_tx);
727
728 let runtime = tokio::runtime::Builder::new_multi_thread()
729 .worker_threads(1)
730 .enable_all()
731 .build()
732 .unwrap();
733
734 tasks.push(std::thread::spawn(move || {
735 let local_set = tokio::task::LocalSet::new();
736
737 runtime.block_on(local_set.run_until(async move {
738 tokio::task::spawn_local(async move {
739 let mut trie = RadixTree::new_with_frequency(expiration_duration);
740 loop {
741 tokio::select! {
742 biased;
743
744 Some(worker) = shard_remove_worker_rx.recv() => {
745 trie.remove_worker(worker);
746 }
747
748 Ok(req) = shard_broadcast_rx.recv() => {
749 let matches = trie.find_matches(req.sequence, req.early_exit);
750 if let Err(e) = req.resp.send(matches).await {
751 log::trace!("Failed to send match response: {:?}", e);
752 }
753 }
754
755 _ = cancel.cancelled() => {
756 log::debug!("KvCacheIndexer progress loop shutting down");
757 return;
758 }
759
760 Some(event) = shard_event_rx.recv() => {
761 trie.apply_event(event);
762 }
763 }
764 }
765 })
766 .await
767 .unwrap()
768 }));
769
770 log::debug!("KvCacheIndexer task completed");
771 }));
772 }
773
774 Self {
775 cancel: token,
776 kv_block_size,
777 worker_assignments,
778 worker_counts,
779 event_tx,
780 request_broadcast_tx,
781 remove_worker_tx,
782 tasks,
783 }
784 }
785
786 pub fn block_size(&self) -> usize {
787 self.kv_block_size
788 }
789
790 pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self {
791 Self::new_with_frequency(token, num_shards, None, kv_block_size)
792 }
793}
794
795#[async_trait]
796impl KvIndexerInterface for KvIndexerSharded {
797 async fn find_matches(
798 &self,
799 sequence: Vec<LocalBlockHash>,
800 ) -> Result<OverlapScores, KvRouterError> {
801 'match_loop: loop {
802 let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len());
803 self.request_broadcast_tx
804 .send(ShardedMatchRequest {
805 sequence: sequence.clone(),
806 early_exit: false,
807 resp: match_tx,
808 })
809 .map_err(|_| KvRouterError::IndexerOffline)?;
810
811 let mut scores = OverlapScores::new();
812
813 for response_num in 0..self.event_tx.len() {
814 match match_rx.recv().await {
815 Some(response) => {
816 scores.scores.extend(response.scores);
817
818 if response_num == 0 {
819 scores.frequencies = response.frequencies;
820 } else {
821 let diff = (response.frequencies.len() as i64)
822 - (scores.frequencies.len() as i64);
823
824 if diff > 0 {
825 scores.frequencies.extend(iter::repeat_n(0, diff as usize));
826 }
827
828 for i in 0..response.frequencies.len() {
829 scores.frequencies[i] += response.frequencies[i];
830 }
831 }
832 }
833 None => {
834 continue 'match_loop;
837 }
838 }
839 }
840 return Ok(scores);
841 }
842 }
843
844 async fn find_matches_for_request(
845 &self,
846 tokens: &[u32],
847 ) -> Result<OverlapScores, KvRouterError> {
848 let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
849 self.find_matches(sequence).await
850 }
851
852 async fn apply_event(&mut self, event: RouterEvent) {
853 #[allow(clippy::map_entry)]
854 if !self.worker_assignments.contains_key(&event.worker_id) {
855 let selected_shard = self
857 .worker_counts
858 .iter()
859 .enumerate()
860 .min_by_key(|&(_, value)| value)
861 .unwrap()
862 .0;
863
864 self.worker_assignments
865 .insert(event.worker_id, selected_shard);
866 self.worker_counts[selected_shard] += 1;
867 }
868
869 self.event_tx[self.worker_assignments[&event.worker_id]]
870 .send(event)
871 .await
872 .unwrap();
873 }
874
875 async fn remove_worker(&mut self, worker: WorkerId) {
876 if let Some((_, shard)) = self.worker_assignments.remove_entry(&worker) {
877 self.worker_counts[shard] -= 1;
878 self.remove_worker_tx[shard].send(worker).await.unwrap();
879 }
880 }
881
882 fn shutdown(&mut self) {
884 self.cancel.cancel();
885 while !self.tasks.is_empty() {
886 self.tasks.pop().unwrap().join().unwrap();
887 }
888 }
889}
890
#[cfg(test)]
mod tests {

    use super::*;
    use rstest::rstest;
    use rstest_reuse::{self, *};
    use tokio::time;
    use tokio_util::sync::CancellationToken;

    /// Builds stored-block entries whose `tokens_hash` is each value in `hashes` and whose
    /// `block_hash` is that value times 100.
    fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
        hashes
            .iter()
            .map(|i| KvCacheStoredBlockData {
                tokens_hash: LocalBlockHash(*i),
                block_hash: ExternalSequenceBlockHash(*i * 100),
            })
            .collect()
    }

    /// Wraps `make_blocks(hashes)` in a `Stored` payload chained after `parent_hash`.
    fn add_blocks(
        hashes: Vec<u64>,
        parent_hash: Option<ExternalSequenceBlockHash>,
    ) -> KvCacheEventData {
        KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash,
            blocks: make_blocks(hashes),
        })
    }

    /// Builds a store event for `worker_id`.
    fn create_store_event(
        worker_id: WorkerId,
        event_id: u64,
        hashes: Vec<u64>,
        parent: Option<ExternalSequenceBlockHash>,
    ) -> RouterEvent {
        RouterEvent {
            worker_id,
            event: KvCacheEvent {
                event_id,
                data: add_blocks(hashes, parent),
            },
        }
    }

    /// Builds a remove event for `worker_id`, using the same `hash * 100` block hashes as
    /// `make_blocks`.
    fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
        RouterEvent {
            worker_id,
            event: KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData {
                    block_hashes: hashes
                        .iter()
                        .map(|i| ExternalSequenceBlockHash(*i * 100))
                        .collect(),
                }),
            },
        }
    }

    /// Returns the child of `block` keyed by `LocalBlockHash(hash)`; panics if absent.
    fn child(block: &SharedRadixBlock, hash: u64) -> SharedRadixBlock {
        block
            .borrow()
            .children
            .get(&LocalBlockHash(hash))
            .unwrap()
            .clone()
    }

    /// Checks the following, in order:
    /// - the set of workers in `trie.lookup` and each worker's lookup size;
    /// - that the root holds no workers and a single child;
    /// - the worker and child counts of the root's child `LocalBlockHash(1)`.
    fn assert_tree_shape(
        trie: &RadixTree,
        lookup_sizes: &[(WorkerId, usize)],
        first_block_workers: usize,
        first_block_children: usize,
    ) {
        assert_eq!(trie.lookup.len(), lookup_sizes.len());
        for (worker, expected) in lookup_sizes {
            assert_eq!(trie.lookup.get(worker).unwrap().len(), *expected);
        }
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        let first = child(&trie.root, 1);
        assert_eq!(first.borrow().workers.len(), first_block_workers);
        assert_eq!(first.borrow().children.len(), first_block_children);
    }

    /// Scores the prefix `[1, 2, 3]` against `trie` without early exit.
    fn match_123(trie: &RadixTree) -> HashMap<WorkerId, u32> {
        trie.find_matches(
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            false,
        )
        .scores
    }

    #[test]
    fn test_radix_tree() {
        let mut trie = RadixTree::new();

        let worker_1 = 0;
        let worker_2 = 1;

        trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None));
        let scores = match_123(&trie);
        assert_eq!(scores.get(&worker_1).unwrap(), &3);
        assert_tree_shape(&trie, &[(worker_1, 3)], 1, 1);

        trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None));
        let scores = match_123(&trie);
        assert_eq!(scores.get(&worker_1).unwrap(), &3);
        assert_eq!(scores.get(&worker_2).unwrap(), &1);
        assert_tree_shape(&trie, &[(worker_1, 3), (worker_2, 3)], 2, 2);

        trie.apply_event(create_remove_event(worker_2, 2, vec![5]));
        assert_tree_shape(&trie, &[(worker_1, 3), (worker_2, 2)], 2, 2);

        trie.apply_event(create_remove_event(worker_2, 3, vec![4]));
        assert_tree_shape(&trie, &[(worker_1, 3), (worker_2, 1)], 2, 2);

        trie.apply_event(create_store_event(
            worker_2,
            4,
            vec![2, 6, 7],
            Some(ExternalSequenceBlockHash(100)),
        ));
        let scores = match_123(&trie);
        assert_eq!(scores.get(&worker_1).unwrap(), &3);
        assert_eq!(scores.get(&worker_2).unwrap(), &2);
        assert_tree_shape(&trie, &[(worker_1, 3), (worker_2, 4)], 2, 2);

        // Both workers now share the node stored under block hash 200.
        for worker in [worker_1, worker_2] {
            assert_eq!(
                trie.lookup
                    .get(&worker)
                    .unwrap()
                    .get(&ExternalSequenceBlockHash(200))
                    .unwrap()
                    .borrow()
                    .workers
                    .len(),
                2
            );
        }
    }

    #[test]
    fn test_remove_worker() {
        let mut trie = RadixTree::new();

        let worker_0 = 0;
        let worker_1 = 1;

        assert!(trie
            .find_matches(vec![LocalBlockHash(0)], false)
            .scores
            .is_empty());

        trie.apply_event(create_store_event(worker_0, 0, vec![0], None));
        trie.apply_event(create_store_event(worker_1, 0, vec![0], None));

        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
        assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1);

        trie.remove_worker(worker_0);

        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
        assert!(result.len() == 1 && result[&worker_1] == 1);
    }

    #[test]
    fn test_early_stopping() {
        let mut trie = RadixTree::new();

        let worker_0 = 0;
        let worker_1 = 1;

        trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 2], None));
        trie.apply_event(create_store_event(worker_1, 0, vec![0], None));

        let result = trie
            .find_matches(
                vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)],
                true,
            )
            .scores;

        assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);

        let result = trie
            .find_matches(vec![LocalBlockHash(0), LocalBlockHash(1)], true)
            .scores;
        assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);
    }

    #[rstest]
    #[case(11)]
    #[case(32)]
    #[case(64)]
    fn test_compute_block_hash_for_seq(#[case] kv_block_size: usize) {
        let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 1);

        let sequence = (0..(kv_block_size + 1))
            .map(|i| i as u32)
            .collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 1);

        let sequence = (0..(2 * kv_block_size + 1))
            .map(|i| i as u32)
            .collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 2);
    }

    /// Builds a single-tree indexer for one shard, otherwise a sharded one.
    fn make_indexer(
        token: &CancellationToken,
        num_shards: usize,
        kv_block_size: usize,
    ) -> Box<dyn KvIndexerInterface> {
        if num_shards == 1 {
            Box::new(KvIndexer::new(token.clone(), kv_block_size))
        } else {
            Box::new(KvIndexerSharded::new(
                token.clone(),
                num_shards,
                kv_block_size,
            ))
        }
    }

    #[template]
    #[rstest]
    fn indexer_template(
        #[values(1, 3, 8)] num_shards: usize,
        #[values(11, 32, 64)] kv_block_size: usize,
    ) {
    }

    // Every indexer test ends with `shutdown()` so the background threads and their runtimes
    // are stopped and joined instead of being leaked for the rest of the test run.

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_kv_indexer_new(num_shards: usize, kv_block_size: usize) {
        let token: CancellationToken = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
        kv_indexer.shutdown();
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_find_matches(num_shards: usize, kv_block_size: usize) {
        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        let sequence = vec![compute_block_hash(b"test data")];
        let scores = kv_indexer.find_matches(sequence).await;

        assert!(scores.unwrap().scores.is_empty());
        kv_indexer.shutdown();
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_find_matches_for_request(num_shards: usize, kv_block_size: usize) {
        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        let tokens = vec![1, 2, 3, 4];
        let scores = kv_indexer.find_matches_for_request(&tokens).await;

        assert!(scores.unwrap().scores.is_empty());
        kv_indexer.shutdown();
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_apply_event(num_shards: usize, kv_block_size: usize) {
        let worker_id = 0;

        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
        kv_indexer.apply_event(event).await;
        kv_indexer.shutdown();
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_shutdown(num_shards: usize, kv_block_size: usize) {
        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        kv_indexer.shutdown();
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_frequency(num_shards: usize, kv_block_size: usize) {
        let token = CancellationToken::new();
        let duration = Some(Duration::from_millis(50));

        let mut kv_indexer: Box<dyn KvIndexerInterface> = if num_shards == 1 {
            Box::new(KvIndexer::new_with_frequency(
                token,
                duration,
                kv_block_size,
            ))
        } else {
            Box::new(KvIndexerSharded::new_with_frequency(
                token,
                num_shards,
                duration,
                kv_block_size,
            ))
        };

        let worker_id = 0;

        let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
        kv_indexer.apply_event(event).await;

        // Events have the lowest priority in the indexer loop; give it time to apply.
        time::sleep(Duration::from_millis(5)).await;

        let block_hashes = vec![
            LocalBlockHash(1),
            LocalBlockHash(2),
            LocalBlockHash(3),
            LocalBlockHash(4),
        ];
        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();

        assert_eq!(scores.frequencies.len(), 0);

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies, vec![1, 1, 1, 1]);

        // Wait past the 50 ms window so earlier visits expire.
        time::sleep(Duration::from_millis(100)).await;

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies.len(), 0);

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies, vec![1, 1, 1, 1]);

        let scores = kv_indexer
            .find_matches(block_hashes[0..3].to_vec())
            .await
            .unwrap();
        assert_eq!(scores.frequencies, vec![2, 2, 2]);

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies, vec![3, 3, 3, 2]);

        kv_indexer.shutdown();
    }

    #[test]
    fn test_router_event_new() {
        let worker_id = 0;
        let kv_cache_event = KvCacheEvent {
            event_id: 1,
            data: KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash: None,
                blocks: vec![KvCacheStoredBlockData {
                    block_hash: ExternalSequenceBlockHash(0),
                    tokens_hash: LocalBlockHash(13226331709069118873),
                }],
            }),
        };
        let router_event = RouterEvent::new(worker_id, kv_cache_event);

        assert_eq!(router_event.worker_id, worker_id);
        assert_eq!(router_event.event.event_id, 1);
        if let KvCacheEventData::Stored(store_op) = &router_event.event.data {
            assert_eq!(store_op.blocks.len(), 1);
            assert_eq!(
                store_op.blocks[0].tokens_hash,
                compute_block_hash(b"test data")
            );
            assert_eq!(store_op.blocks[0].block_hash, ExternalSequenceBlockHash(0));
        } else {
            panic!("Expected KvCacheEventData::Stored");
        }
    }

    #[test]
    fn test_radix_tree_default() {
        let radix_tree: RadixTree = Default::default();
        assert!(radix_tree.root.borrow().children.is_empty());
        assert!(radix_tree.root.borrow().workers.is_empty());
        assert!(radix_tree.lookup.is_empty());
    }

    #[test]
    fn test_overlap_scores_default() {
        let overlap_scores: OverlapScores = Default::default();
        assert!(overlap_scores.scores.is_empty());
    }
}