use async_trait::async_trait;
use bytes::Bytes;
use dynamo_runtime::{
    component::Component,
    metrics::{MetricsRegistry, prometheus_names::kvrouter},
};
use prometheus::{IntCounterVec, Opts};
use serde::{Deserialize, Serialize};
use std::{
    cell::RefCell,
    collections::{HashMap, HashSet, VecDeque},
    iter,
    rc::Rc,
    sync::{Arc, OnceLock},
    thread::JoinHandle,
    time::{Duration, Instant},
};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use xxhash_rust::xxh3;

pub const XXH3_SEED: u64 = 1337;

use crate::kv_router::protocols::*;
use crate::tokens::SequenceHash;

#[derive(Debug, thiserror::Error)]
pub enum KvRouterError {
    #[error("Block not found")]
    BlockNotFound,

    #[error("Indexer is offline")]
    IndexerOffline,

    #[error("Indexer dropped the request")]
    IndexerDroppedRequest,
}

#[derive(Debug, thiserror::Error)]
pub enum KvCacheEventError {
    #[error("Failed to find parent block")]
    ParentBlockNotFound,

    #[error("Failed to find block")]
    BlockNotFound,
}

pub type WorkerId = i64;

type SharedRadixBlock = Rc<RefCell<RadixBlock>>;

pub fn compute_hash(data: &[u8]) -> u64 {
    xxh3::xxh3_64_with_seed(data, XXH3_SEED)
}

pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
    LocalBlockHash(compute_hash(data))
}

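/// Splits `tokens` into blocks of `kv_block_size` and hashes each full block
/// (little-endian token bytes). A trailing partial block is dropped by
/// `chunks_exact`.
///
/// Illustrative sketch (not compiled as a doctest); token values are arbitrary:
///
/// ```ignore
/// // 10 tokens with a block size of 4 yield 2 block hashes; the last 2 tokens
/// // do not form a complete block and are ignored.
/// let hashes = compute_block_hash_for_seq(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 4);
/// assert_eq!(hashes.len(), 2);
/// ```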
pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<LocalBlockHash> {
    tokens
        .chunks_exact(kv_block_size as usize)
        .map(|chunk| {
            let bytes: Vec<u8> = chunk
                .iter()
                .flat_map(|&num| num.to_le_bytes())
                .collect();

            compute_block_hash(&Bytes::from(bytes))
        })
        .collect()
}

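/// Folds per-block hashes into cumulative sequence hashes: the first entry is
/// the first block hash's raw value, and each subsequent entry hashes the pair
/// `[previous_sequence_hash, current_block_hash]` (little-endian bytes).
///
/// Illustrative sketch (not compiled as a doctest); hash values are arbitrary:
///
/// ```ignore
/// let blocks = vec![LocalBlockHash(11), LocalBlockHash(22)];
/// let seq = compute_seq_hash_for_block(&blocks);
/// assert_eq!(seq.len(), 2);
/// // seq[0] is the raw value of blocks[0]; seq[1] chains it with blocks[1].
/// ```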
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
    if block_hashes.is_empty() {
        return Vec::new();
    }

    let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
    sequence_hashes.push(block_hashes[0].0);

    for i in 1..block_hashes.len() {
        let parent_seq_hash = sequence_hashes[i - 1];
        let current_block_hash = block_hashes[i].0;

        let combined = [parent_seq_hash, current_block_hash];
        let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect();
        let seq_hash = compute_hash(&bytes);
        sequence_hashes.push(seq_hash);
    }

    sequence_hashes
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterEvent {
    worker_id: WorkerId,
    event: KvCacheEvent,
}

impl RouterEvent {
    pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
        Self { worker_id, event }
    }
}

#[derive(Debug)]
struct RadixBlock {
    children: HashMap<LocalBlockHash, SharedRadixBlock>,
    workers: HashSet<WorkerId>,
    recent_uses: VecDeque<Instant>,
}

impl RadixBlock {
    pub fn new() -> Self {
        Self {
            children: HashMap::new(),
            workers: HashSet::new(),
            recent_uses: VecDeque::new(),
        }
    }
}

pub struct RadixTree {
    root: SharedRadixBlock,

    lookup: HashMap<WorkerId, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
    expiration_duration: Option<Duration>,
}

impl Default for RadixTree {
    fn default() -> Self {
        Self::new()
    }
}

impl RadixTree {
    pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self {
        Self {
            root: Rc::new(RefCell::new(RadixBlock::new())),
            lookup: HashMap::new(),
            expiration_duration,
        }
    }

    pub fn new() -> Self {
        Self::new_with_frequency(None)
    }

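    /// Walks the tree from the root along `sequence`, crediting every worker
    /// that holds each matched block. The walk stops at the first block that is
    /// not in the tree, or early when `early_exit` is set and only a single
    /// worker holds the current block. When the tree was built with an
    /// expiration duration, per-block access frequencies are also reported.
    ///
    /// Illustrative sketch (not compiled as a doctest); `tree` and `worker_id`
    /// are placeholders:
    ///
    /// ```ignore
    /// let scores = tree.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)], false);
    /// // scores.scores[&worker_id] == number of leading blocks that worker has cached
    /// ```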
    pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
        let mut scores = OverlapScores::new();
        let mut current = self.root.clone();
        let now = Instant::now();
        for block_hash in sequence {
            let next_block = {
                let current_borrow = current.borrow();
                current_borrow.children.get(&block_hash).cloned()
            };
            if let Some(block) = next_block {
                scores.update_scores(&block.borrow().workers);

                if let Some(expiration_duration) = self.expiration_duration {
                    let mut block_mut = block.borrow_mut();

                    // Drop accesses older than the expiration window before
                    // reporting this block's frequency and recording this access.
                    while let Some(access_time) = block_mut.recent_uses.front() {
                        if now.duration_since(*access_time) > expiration_duration {
                            block_mut.recent_uses.pop_front();
                        } else {
                            break;
                        }
                    }
                    scores.add_frequency(block_mut.recent_uses.len());
                    block_mut.recent_uses.push_back(now);
                }

                if early_exit && block.borrow().workers.len() == 1 {
                    break;
                }

                current = block;
            } else {
                break;
            }
        }

        scores
    }

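    /// Applies a single [`RouterEvent`] to the tree.
    ///
    /// * `Stored` events attach new blocks under the parent block (looked up by
    ///   the worker's `ExternalSequenceBlockHash`, or the root when
    ///   `parent_hash` is `None`) and tag them with the worker id.
    /// * `Removed` events detach the worker from the listed blocks.
    /// * `Cleared` events drop every block registration for the worker.
    ///
    /// Illustrative sketch (not compiled as a doctest); `blocks` is a
    /// placeholder `Vec<KvCacheStoredBlockData>`:
    ///
    /// ```ignore
    /// let data = KvCacheEventData::Stored(KvCacheStoreData { parent_hash: None, blocks });
    /// let event = RouterEvent::new(0, KvCacheEvent { event_id: 1, data });
    /// tree.apply_event(event)?;
    /// ```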
    pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
        let (worker_id, event) = (event.worker_id, event.event);
        let (id, op) = (event.event_id, event.data);
        tracing::trace!(id, "Store operation: {:?}", op);

        let worker_lookup = self.lookup.entry(worker_id).or_default();

        match op {
            KvCacheEventData::Stored(op) => {
                let current = match op.parent_hash {
                    Some(parent) => worker_lookup.get(&parent),
                    None => Some(&self.root),
                };

                let mut current = match current {
                    Some(current) => current.clone(),
                    None => {
                        tracing::warn!(
                            worker_id = worker_id.to_string(),
                            id,
                            parent_hash = ?op.parent_hash,
                            "Failed to find parent block; skipping store operation"
                        );
                        return Err(KvCacheEventError::ParentBlockNotFound);
                    }
                };

                for block_id in op.blocks {
                    let mut inner = current.borrow_mut();
                    let block = match inner.children.get(&block_id.tokens_hash) {
                        Some(block) => block.clone(),
                        None => {
                            // Reuse the block this worker already registered under the
                            // same external hash, if any; otherwise create a fresh node.
                            let new_block = worker_lookup
                                .get(&block_id.block_hash)
                                .cloned()
                                .unwrap_or_else(|| Rc::new(RefCell::new(RadixBlock::new())));

                            inner
                                .children
                                .insert(block_id.tokens_hash, new_block.clone());

                            new_block
                        }
                    };

                    block.borrow_mut().workers.insert(worker_id);

                    worker_lookup.insert(block_id.block_hash, block.clone());

                    // Release the parent borrow before descending into the child.
                    drop(inner);

                    current = block;
                }
                Ok(())
            }
            KvCacheEventData::Removed(remove) => {
                for block in remove.block_hashes {
                    let entry = match worker_lookup.get(&block) {
                        Some(entry) => entry.clone(),
                        None => {
                            tracing::warn!(
                                worker_id = worker_id.to_string(),
                                id,
                                "Failed to find block to remove; skipping remove operation"
                            );
                            return Err(KvCacheEventError::BlockNotFound);
                        }
                    };

                    let mut guard = entry.borrow_mut();
                    guard.workers.remove(&worker_id);
                    if guard.workers.is_empty() {
                        guard.children.clear();
                    }
                    worker_lookup.remove(&block);
                }
                Ok(())
            }
            KvCacheEventData::Cleared => {
                self.clear_all_blocks(worker_id);
                Ok(())
            }
        }
    }

    pub fn remove_worker(&mut self, worker: WorkerId) {
        if let Some((_, blocks)) = self.lookup.remove_entry(&worker) {
            blocks.iter().for_each(|(_, block)| {
                block.borrow_mut().workers.remove(&worker);
            });
        }
    }

    pub fn clear_all_blocks(&mut self, worker: WorkerId) {
        if let Some(blocks) = self.lookup.get(&worker) {
            let blocks_to_clear: Vec<_> = blocks.values().collect();

            blocks_to_clear.iter().for_each(|block| {
                block.borrow_mut().workers.remove(&worker);
            });

            if let Some(worker_blocks) = self.lookup.get_mut(&worker) {
                worker_blocks.clear();
            }
        }
    }

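    /// Serializes the current tree back into a sequence of synthetic `Stored`
    /// [`RouterEvent`]s (one per block per worker) via a breadth-first walk, so
    /// that replaying the returned events into an empty indexer reproduces the
    /// same match behaviour. External block hashes are recovered from the
    /// per-worker lookup table by pointer identity, and a block's parent hash
    /// is taken from any worker that holds the parent block.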
    pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
        let mut events = Vec::new();
        let mut event_id = 0u64;

        let mut queue = VecDeque::new();

        let root_borrow = self.root.borrow();
        for (tokens_hash, child_block) in &root_borrow.children {
            queue.push_back((child_block.clone(), None, *tokens_hash));
        }
        drop(root_borrow);

        while let Some((current_block, parent_external_hash, tokens_hash)) = queue.pop_front() {
            let current_borrow = current_block.borrow();

            let find_external_hash = |worker_id: &WorkerId| {
                self.lookup.get(worker_id).and_then(|worker_blocks| {
                    worker_blocks
                        .iter()
                        .find(|(_, block)| Rc::ptr_eq(block, &current_block))
                        .map(|(hash, _)| *hash)
                })
            };

            for worker_id in &current_borrow.workers {
                let external_hash = find_external_hash(worker_id);

                if let Some(block_hash) = external_hash {
                    let event = RouterEvent {
                        worker_id: *worker_id,
                        event: KvCacheEvent {
                            event_id,
                            data: KvCacheEventData::Stored(KvCacheStoreData {
                                parent_hash: parent_external_hash,
                                blocks: vec![KvCacheStoredBlockData {
                                    block_hash,
                                    tokens_hash,
                                }],
                            }),
                        },
                    };
                    events.push(event);
                    event_id += 1;
                }
            }

            let any_external_hash = if !current_borrow.workers.is_empty() {
                current_borrow
                    .workers
                    .iter()
                    .next()
                    .and_then(find_external_hash)
            } else {
                None
            };

            for (child_tokens_hash, child_block) in &current_borrow.children {
                queue.push_back((child_block.clone(), any_external_hash, *child_tokens_hash));
            }
        }

        events
    }
}

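/// Prometheus counters for the KV indexer. `kv_cache_events_applied` is
/// labelled by `event_type` (`stored`/`removed`/`cleared`) and `status`
/// (`ok`/`parent_block_not_found`/`block_not_found`), so both successful and
/// rejected events are visible per event kind.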
#[derive(Clone)]
pub struct KvIndexerMetrics {
    pub kv_cache_events_applied: IntCounterVec,
}

pub const METRIC_STATUS_OK: &str = "ok";
pub const METRIC_STATUS_PARENT_NOT_FOUND: &str = "parent_block_not_found";
pub const METRIC_STATUS_BLOCK_NOT_FOUND: &str = "block_not_found";

pub const METRIC_EVENT_STORED: &str = "stored";
pub const METRIC_EVENT_REMOVED: &str = "removed";
pub const METRIC_EVENT_CLEARED: &str = "cleared";

static KV_INDEXER_METRICS: OnceLock<Arc<KvIndexerMetrics>> = OnceLock::new();

impl KvIndexerMetrics {
    fn new(kv_cache_events_applied: IntCounterVec) -> Self {
        Self {
            kv_cache_events_applied,
        }
    }

    /// Returns the process-wide metrics instance, creating and registering it
    /// against `component` on first use; later callers share the same counters.
    pub fn from_component(component: &Component) -> Arc<Self> {
        KV_INDEXER_METRICS.get_or_init(|| {
            match component.create_intcountervec(
                kvrouter::KV_CACHE_EVENTS_APPLIED,
                "Total number of KV cache events applied to index",
                &["event_type", "status"],
                &[],
            ) {
                Ok(kv_cache_events_applied) => Arc::new(Self::new(kv_cache_events_applied)),
                Err(e) => {
                    tracing::warn!("Failed to create kv indexer metrics from component: {}. Using unregistered metrics as fallback.", e);
                    Arc::new(Self::new_unregistered())
                }
            }
        }).clone()
    }

    pub fn new_unregistered() -> Self {
        Self {
            kv_cache_events_applied: IntCounterVec::new(
                Opts::new(
                    kvrouter::KV_CACHE_EVENTS_APPLIED,
                    "Total number of KV cache events applied to index",
                ),
                &["event_type", "status"],
            )
            .unwrap(),
        }
    }

    pub fn get_event_type(event_data: &KvCacheEventData) -> &'static str {
        match event_data {
            KvCacheEventData::Stored(_) => METRIC_EVENT_STORED,
            KvCacheEventData::Removed(_) => METRIC_EVENT_REMOVED,
            KvCacheEventData::Cleared => METRIC_EVENT_CLEARED,
        }
    }

    pub fn increment_event_applied(
        &self,
        event_type: &'static str,
        result: Result<(), KvCacheEventError>,
    ) {
        match result {
            Ok(_) => {
                self.kv_cache_events_applied
                    .with_label_values(&[event_type, METRIC_STATUS_OK])
                    .inc_by(1);
            }
            Err(e) => {
                let error_label = match e {
                    KvCacheEventError::ParentBlockNotFound => METRIC_STATUS_PARENT_NOT_FOUND,
                    KvCacheEventError::BlockNotFound => METRIC_STATUS_BLOCK_NOT_FOUND,
                };
                self.kv_cache_events_applied
                    .with_label_values(&[event_type, error_label])
                    .inc_by(1);
            }
        }
    }
}

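/// Result of a prefix match against the radix tree. `scores` maps each worker
/// to the number of leading blocks it has cached for the queried sequence;
/// `frequencies` (only populated when expiration tracking is enabled) holds,
/// per matched depth, how many times that block was recently accessed.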
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlapScores {
    pub scores: HashMap<WorkerId, u32>,
    pub frequencies: Vec<usize>,
}

impl Default for OverlapScores {
    fn default() -> Self {
        Self::new()
    }
}

impl OverlapScores {
    pub fn new() -> Self {
        Self {
            scores: HashMap::new(),
            frequencies: Vec::with_capacity(32),
        }
    }

    pub fn update_scores(&mut self, workers: &HashSet<WorkerId>) {
        for worker in workers {
            let score = self.scores.entry(*worker).or_insert(0);
            *score += 1;
        }
    }

    pub fn add_frequency(&mut self, frequency: usize) {
        if frequency != 0 {
            // Frequencies are pushed root-first and should never increase with
            // depth; the check only runs in debug builds.
            self.frequencies
                .last()
                .inspect(|elem| debug_assert!(**elem >= frequency));
            self.frequencies.push(frequency);
        }
    }
}

pub struct MatchRequest {
    sequence: Vec<LocalBlockHash>,
    early_exit: bool,
    resp: oneshot::Sender<OverlapScores>,
}

pub struct DumpRequest {
    pub resp: oneshot::Sender<Vec<RouterEvent>>,
}

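/// Common interface shared by [`KvIndexer`] and [`KvIndexerSharded`]: submit
/// cache events, query overlap scores for a block-hash sequence or a raw token
/// sequence, remove workers, dump the tree as replayable events, and shut down.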
#[async_trait]
pub trait KvIndexerInterface {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError>;

    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError>;

    async fn apply_event(&mut self, event: RouterEvent);

    async fn remove_worker(&mut self, worker: WorkerId);

    fn shutdown(&mut self);

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError>;
}

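/// Single-tree KV indexer. The [`RadixTree`] uses `Rc`/`RefCell` and is not
/// `Send`, so it lives on a dedicated OS thread running a single-threaded
/// tokio runtime with a `LocalSet`; this handle only holds channel senders and
/// the cancellation token.
///
/// Illustrative usage sketch (not compiled as a doctest); `worker_id`, `event`
/// and `tokens` are placeholders, and imports are omitted (bring
/// `KvIndexerInterface` into scope for the trait methods):
///
/// ```ignore
/// let token = CancellationToken::new();
/// let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
/// let indexer = KvIndexer::new(token.clone(), 32, metrics);
///
/// indexer.event_sender().send(RouterEvent::new(worker_id, event)).await?;
/// let overlaps = indexer.find_matches_for_request(&tokens).await?;
/// ```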
pub struct KvIndexer {
    cancel: CancellationToken,
    event_tx: mpsc::Sender<RouterEvent>,
    match_tx: mpsc::Sender<MatchRequest>,
    remove_worker_tx: mpsc::Sender<WorkerId>,
    dump_tx: mpsc::Sender<DumpRequest>,
    task: OnceLock<std::thread::JoinHandle<()>>,
    kv_block_size: u32,
}

impl KvIndexer {
    pub fn new_with_frequency(
        token: CancellationToken,
        expiration_duration: Option<Duration>,
        kv_block_size: u32,
        metrics: Arc<KvIndexerMetrics>,
    ) -> Self {
        let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
        let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
        let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
        let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
        let cancel_clone = token.clone();

        let task = std::thread::spawn(move || {
            // The RadixTree is !Send, so run it on a single-threaded runtime
            // inside a LocalSet on this dedicated thread.
            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_all()
                .build()
                .unwrap();

            let local_set = tokio::task::LocalSet::new();

            runtime.block_on(local_set.run_until(async move {
                tokio::task::spawn_local(async move {
                    let cancel = cancel_clone;
                    let mut match_rx = match_rx;
                    let mut event_rx = event_rx;
                    let mut remove_worker_rx = remove_worker_rx;
                    let mut dump_rx = dump_rx;
                    let mut trie = RadixTree::new_with_frequency(expiration_duration);
                    loop {
                        tokio::select! {
                            biased;

                            _ = cancel.cancelled() => {
                                tracing::debug!("KvCacheIndexer progress loop shutting down");
                                return;
                            }

                            Some(worker) = remove_worker_rx.recv() => {
                                trie.remove_worker(worker);
                            }

                            Some(event) = event_rx.recv() => {
                                let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
                                let result = trie.apply_event(event);
                                metrics.increment_event_applied(event_type, result);
                            }

                            Some(dump_req) = dump_rx.recv() => {
                                let events = trie.dump_tree_as_events();
                                let _ = dump_req.resp.send(events);
                            }

                            Some(req) = match_rx.recv() => {
                                let matches = trie.find_matches(req.sequence, req.early_exit);
                                let _ = req.resp.send(matches);
                            }
                        }
                    }
                })
                .await
                .unwrap()
            }));

            tracing::debug!("KvCacheIndexer task completed");
        });

        let once = OnceLock::new();
        once.set(task).unwrap();

        Self {
            cancel: token,
            event_tx,
            match_tx,
            remove_worker_tx,
            dump_tx,
            task: once,
            kv_block_size,
        }
    }

    pub fn block_size(&self) -> u32 {
        self.kv_block_size
    }

    pub fn new(
        token: CancellationToken,
        kv_block_size: u32,
        metrics: Arc<KvIndexerMetrics>,
    ) -> Self {
        Self::new_with_frequency(token, None, kv_block_size, metrics)
    }

    pub fn event_sender(&self) -> mpsc::Sender<RouterEvent> {
        self.event_tx.clone()
    }

    pub fn snapshot_event_sender(&self) -> mpsc::Sender<DumpRequest> {
        self.dump_tx.clone()
    }

    pub fn remove_worker_sender(&self) -> mpsc::Sender<WorkerId> {
        self.remove_worker_tx.clone()
    }
}

#[async_trait]
impl KvIndexerInterface for KvIndexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        let (resp_tx, resp_rx) = oneshot::channel();
        let req = MatchRequest {
            sequence,
            early_exit: false,
            resp: resp_tx,
        };

        if let Err(e) = self.match_tx.send(req).await {
            tracing::error!(
                "Failed to send match request: {:?}; the indexer may be offline",
                e
            );
            return Err(KvRouterError::IndexerOffline);
        }

        resp_rx
            .await
            .map_err(|_| KvRouterError::IndexerDroppedRequest)
    }

    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError> {
        tracing::debug!(
            "Finding matches for request tokens: {:?} / len: {}",
            tokens,
            tokens.len()
        );
        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
        tracing::debug!("Computed sequence: {:?}", sequence);
        self.find_matches(sequence).await
    }

    async fn apply_event(&mut self, event: RouterEvent) {
        self.event_tx.send(event).await.unwrap();
    }

    async fn remove_worker(&mut self, worker: WorkerId) {
        self.remove_worker_tx.send(worker).await.unwrap();
    }

    fn shutdown(&mut self) {
        self.cancel.cancel();
        if let Some(task) = self.task.take() {
            task.join().expect("Failed to join kv indexer task");
        }
    }

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        let (resp_tx, resp_rx) = oneshot::channel();
        let dump_req = DumpRequest { resp: resp_tx };

        if let Err(e) = self.dump_tx.send(dump_req).await {
            tracing::error!("Failed to send dump request: {:?}", e);
            return Err(KvRouterError::IndexerOffline);
        }

        resp_rx
            .await
            .map_err(|_| KvRouterError::IndexerDroppedRequest)
    }
}

#[derive(Debug, Clone)]
pub struct ShardedMatchRequest {
    sequence: Vec<LocalBlockHash>,
    early_exit: bool,
    resp: mpsc::Sender<OverlapScores>,
}

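/// Sharded KV indexer: each shard owns its own [`RadixTree`] on a dedicated
/// thread, workers are assigned to the least-loaded shard on first sight, and
/// match requests are broadcast to every shard with the per-shard
/// [`OverlapScores`] merged before returning. Events for a worker always go to
/// that worker's shard.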
pub struct KvIndexerSharded {
    cancel: CancellationToken,
    kv_block_size: u32,
    worker_assignments: HashMap<WorkerId, usize>,
    worker_counts: Vec<usize>,

    event_tx: Vec<mpsc::Sender<RouterEvent>>,
    request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
    remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
    dump_tx: Vec<mpsc::Sender<DumpRequest>>,
    tasks: Vec<JoinHandle<()>>,
}

impl KvIndexerSharded {
    pub fn new_with_frequency(
        token: CancellationToken,
        num_shards: usize,
        expiration_duration: Option<Duration>,
        kv_block_size: u32,
        metrics: Arc<KvIndexerMetrics>,
    ) -> Self {
        let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
        let worker_counts: Vec<usize> = vec![0; num_shards];

        let mut event_tx = Vec::new();
        let mut remove_worker_tx = Vec::new();
        let mut dump_tx = Vec::new();
        let mut tasks = Vec::new();

        let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);

        for _ in 0..num_shards {
            let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
            let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
                mpsc::channel::<WorkerId>(16);
            let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16);
            let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
            let cancel = token.clone();
            let metrics = metrics.clone();

            event_tx.push(shard_event_tx);
            remove_worker_tx.push(shard_remove_worker_tx);
            dump_tx.push(shard_dump_tx);

            // As with KvIndexer, each shard's RadixTree is !Send and runs on
            // its own single-threaded runtime inside a LocalSet.
            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_all()
                .build()
                .unwrap();

            tasks.push(std::thread::spawn(move || {
                let local_set = tokio::task::LocalSet::new();

                runtime.block_on(local_set.run_until(async move {
                    tokio::task::spawn_local(async move {
                        let mut trie = RadixTree::new_with_frequency(expiration_duration);
                        loop {
                            tokio::select! {
                                biased;

                                _ = cancel.cancelled() => {
                                    tracing::trace!("KvCacheIndexer progress loop shutting down");
                                    return;
                                }

                                Some(worker) = shard_remove_worker_rx.recv() => {
                                    trie.remove_worker(worker);
                                }

                                Some(event) = shard_event_rx.recv() => {
                                    let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
                                    let result = trie.apply_event(event);
                                    metrics.increment_event_applied(event_type, result);
                                }

                                Some(dump_req) = shard_dump_rx.recv() => {
                                    let events = trie.dump_tree_as_events();
                                    let _ = dump_req.resp.send(events);
                                }

                                Ok(req) = shard_broadcast_rx.recv() => {
                                    let matches = trie.find_matches(req.sequence, req.early_exit);
                                    if let Err(e) = req.resp.send(matches).await {
                                        tracing::trace!("Failed to send match response: {:?}", e);
                                    }
                                }
                            }
                        }
                    })
                    .await
                    .unwrap()
                }));

                tracing::debug!("KvCacheIndexer task completed");
            }));
        }

        Self {
            cancel: token,
            kv_block_size,
            worker_assignments,
            worker_counts,
            event_tx,
            request_broadcast_tx,
            remove_worker_tx,
            dump_tx,
            tasks,
        }
    }

    pub fn block_size(&self) -> u32 {
        self.kv_block_size
    }

    pub fn new(
        token: CancellationToken,
        num_shards: usize,
        kv_block_size: u32,
        metrics: Arc<KvIndexerMetrics>,
    ) -> Self {
        Self::new_with_frequency(token, num_shards, None, kv_block_size, metrics)
    }
}

#[async_trait]
impl KvIndexerInterface for KvIndexerSharded {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        'match_loop: loop {
            let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len());
            self.request_broadcast_tx
                .send(ShardedMatchRequest {
                    sequence: sequence.clone(),
                    early_exit: false,
                    resp: match_tx,
                })
                .map_err(|_| KvRouterError::IndexerOffline)?;

            let mut scores = OverlapScores::new();

            for response_num in 0..self.event_tx.len() {
                match match_rx.recv().await {
                    Some(response) => {
                        scores.scores.extend(response.scores);

                        // Frequencies are summed element-wise across shards,
                        // padding with zeros when a shard matched deeper.
                        if response_num == 0 {
                            scores.frequencies = response.frequencies;
                        } else {
                            let diff = (response.frequencies.len() as i64)
                                - (scores.frequencies.len() as i64);

                            if diff > 0 {
                                scores.frequencies.extend(iter::repeat_n(0, diff as usize));
                            }

                            for i in 0..response.frequencies.len() {
                                scores.frequencies[i] += response.frequencies[i];
                            }
                        }
                    }
                    None => {
                        // The channel closed before every shard responded;
                        // retry the whole broadcast.
                        continue 'match_loop;
                    }
                }
            }
            return Ok(scores);
        }
    }

    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError> {
        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
        self.find_matches(sequence).await
    }

    async fn apply_event(&mut self, event: RouterEvent) {
        #[allow(clippy::map_entry)]
        if !self.worker_assignments.contains_key(&event.worker_id) {
            // First event from this worker: assign it to the least-loaded shard.
            let selected_shard = self
                .worker_counts
                .iter()
                .enumerate()
                .min_by_key(|&(_, value)| value)
                .unwrap()
                .0;

            self.worker_assignments
                .insert(event.worker_id, selected_shard);
            self.worker_counts[selected_shard] += 1;
        }

        self.event_tx[self.worker_assignments[&event.worker_id]]
            .send(event)
            .await
            .unwrap();
    }

    async fn remove_worker(&mut self, worker: WorkerId) {
        if let Some((_, shard)) = self.worker_assignments.remove_entry(&worker) {
            self.worker_counts[shard] -= 1;
            self.remove_worker_tx[shard].send(worker).await.unwrap();
        }
    }

    fn shutdown(&mut self) {
        self.cancel.cancel();
        while !self.tasks.is_empty() {
            self.tasks.pop().unwrap().join().unwrap();
        }
    }

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        let mut all_events = Vec::new();

        let mut receivers = Vec::new();

        for shard_dump_tx in &self.dump_tx {
            let (resp_tx, resp_rx) = oneshot::channel();
            let dump_req = DumpRequest { resp: resp_tx };

            if let Err(e) = shard_dump_tx.send(dump_req).await {
                tracing::error!("Failed to send dump request to shard: {:?}", e);
                return Err(KvRouterError::IndexerOffline);
            }

            receivers.push(resp_rx);
        }

        for resp_rx in receivers {
            match resp_rx.await {
                Ok(events) => all_events.extend(events),
                Err(_) => return Err(KvRouterError::IndexerDroppedRequest),
            }
        }

        Ok(all_events)
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use rstest::rstest;
    use rstest_reuse::{self, *};
    use tokio::time;
    use tokio_util::sync::CancellationToken;

    fn setup() {
        dynamo_runtime::logging::init();
    }

    fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
        hashes
            .iter()
            .map(|i| KvCacheStoredBlockData {
                tokens_hash: LocalBlockHash(*i),
                block_hash: ExternalSequenceBlockHash(*i * 100),
            })
            .collect()
    }

    fn add_blocks(
        hashes: Vec<u64>,
        parent_hash: Option<ExternalSequenceBlockHash>,
    ) -> KvCacheEventData {
        KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash,
            blocks: make_blocks(hashes),
        })
    }

    fn create_store_event(
        worker_id: WorkerId,
        event_id: u64,
        hashes: Vec<u64>,
        parent: Option<ExternalSequenceBlockHash>,
    ) -> RouterEvent {
        RouterEvent {
            worker_id,
            event: KvCacheEvent {
                event_id,
                data: add_blocks(hashes, parent),
            },
        }
    }

    fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
        RouterEvent {
            worker_id,
            event: KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData {
                    block_hashes: hashes
                        .iter()
                        .map(|i| ExternalSequenceBlockHash(*i * 100))
                        .collect(),
                }),
            },
        }
    }

    #[test]
    fn test_radix_tree() {
        setup();

        let mut trie = RadixTree::new();

        let worker_1 = 0;
        let worker_2 = 1;

        trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
            .unwrap();

        let scores = trie.find_matches(
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            false,
        );
        assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);

        assert_eq!(trie.lookup.len(), 1);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            1
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            1
        );

        trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
            .unwrap();

        let scores = trie.find_matches(
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            false,
        );
        assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
        assert_eq!(scores.scores.get(&worker_2).unwrap(), &1);

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 3);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            2
        );

        trie.apply_event(create_remove_event(worker_2, 2, vec![5]))
            .unwrap();
        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 2);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            2
        );

        trie.apply_event(create_remove_event(worker_2, 3, vec![4]))
            .unwrap();

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 1);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            2
        );

        trie.apply_event(create_store_event(
            worker_2,
            4,
            vec![2, 6, 7],
            Some(ExternalSequenceBlockHash(100)),
        ))
        .unwrap();

        let scores = trie.find_matches(
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            false,
        );
        assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
        assert_eq!(scores.scores.get(&worker_2).unwrap(), &2);

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 4);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            2
        );
        assert_eq!(
            trie.lookup
                .get(&worker_1)
                .unwrap()
                .get(&ExternalSequenceBlockHash(200))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.lookup
                .get(&worker_2)
                .unwrap()
                .get(&ExternalSequenceBlockHash(200))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
    }

    #[test]
    fn test_radix_tree_apply_event_errors() {
        let mut trie = RadixTree::new();
        let worker_0 = 0;

        let result = trie.apply_event(create_store_event(
            worker_0,
            0,
            vec![1, 2, 3],
            Some(ExternalSequenceBlockHash(12345)),
        ));
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            KvCacheEventError::ParentBlockNotFound
        ));

        let result = trie.apply_event(create_remove_event(worker_0, 0, vec![1, 2, 3]));
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            KvCacheEventError::BlockNotFound
        ));
    }

    #[test]
    fn test_remove_worker() {
        setup();
        let mut trie = RadixTree::new();

        let worker_0 = 0;
        let worker_1 = 1;

        assert!(
            trie.find_matches(vec![LocalBlockHash(0)], false)
                .scores
                .is_empty()
        );

        trie.apply_event(create_store_event(worker_0, 0, vec![0], None))
            .unwrap();
        trie.apply_event(create_store_event(worker_1, 0, vec![0], None))
            .unwrap();

        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
        assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1);

        trie.remove_worker(worker_0);

        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
        assert!(result.len() == 1 && result[&worker_1] == 1);
    }

    #[test]
    fn test_clear_all_blocks() {
        let mut trie = RadixTree::new();

        let worker_0 = 0;
        let worker_1 = 1;

        assert!(
            trie.find_matches(vec![LocalBlockHash(0)], false)
                .scores
                .is_empty()
        );

        trie.clear_all_blocks(worker_0);
        assert!(!trie.lookup.contains_key(&worker_0));

        trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None))
            .unwrap();
        trie.apply_event(create_store_event(worker_1, 0, vec![0, 2, 3], None))
            .unwrap();

        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
        assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1);

        trie.clear_all_blocks(worker_0);

        assert!(trie.lookup.contains_key(&worker_0));
        assert!(trie.lookup.get(&worker_0).unwrap().is_empty());
        let result = trie
            .find_matches(vec![LocalBlockHash(0), LocalBlockHash(2)], false)
            .scores;
        assert_eq!(result.len(), 1);
        assert_eq!(result[&worker_1], 2);
        let result = trie
            .find_matches(
                vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(3)],
                false,
            )
            .scores;
        assert_eq!(result.len(), 1);
        assert_eq!(result[&worker_1], 1);

        trie.apply_event(create_store_event(worker_0, 0, vec![4, 5], None))
            .unwrap();
        let result = trie
            .find_matches(vec![LocalBlockHash(4), LocalBlockHash(5)], false)
            .scores;
        assert_eq!(result.len(), 1);
        assert_eq!(result[&worker_0], 2);

        trie.clear_all_blocks(worker_0);
        trie.clear_all_blocks(worker_0);
        assert!(trie.lookup.contains_key(&worker_0));

        trie.clear_all_blocks(worker_0);
        trie.clear_all_blocks(worker_1);
        assert!(!trie.lookup.is_empty());
        assert!(trie.lookup.get(&worker_0).unwrap().is_empty());
        assert!(trie.lookup.get(&worker_1).unwrap().is_empty());

        trie.apply_event(create_store_event(worker_0, 0, vec![6], None))
            .unwrap();
        trie.apply_event(create_store_event(worker_1, 0, vec![6], None))
            .unwrap();
        trie.remove_worker(worker_0);
        trie.clear_all_blocks(worker_0);
        assert!(!trie.lookup.contains_key(&worker_0));
        let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
        assert_eq!(result.len(), 1);
        assert_eq!(result[&worker_1], 1);

        let worker_fake = 2;
        assert!(!trie.lookup.contains_key(&worker_fake));
        trie.clear_all_blocks(worker_fake);
        assert!(!trie.lookup.contains_key(&worker_fake));
        assert!(trie.lookup.contains_key(&worker_1));
        let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
        assert_eq!(result.len(), 1);
        assert_eq!(result[&worker_1], 1);
    }

    #[test]
    fn test_early_stopping() {
        setup();
        let mut trie = RadixTree::new();

        let worker_0 = 0;
        let worker_1 = 1;

        trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 2], None))
            .unwrap();
        trie.apply_event(create_store_event(worker_1, 0, vec![0], None))
            .unwrap();

        let result = trie
            .find_matches(
                vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)],
                true,
            )
            .scores;

        assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);

        let result = trie
            .find_matches(vec![LocalBlockHash(0), LocalBlockHash(1)], true)
            .scores;
        assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);
    }

    #[rstest]
    #[case(11)]
    #[case(32)]
    #[case(64)]
    fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
        setup();
        let sequence = (0..kv_block_size).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 1);

        let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 1);

        let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 2);
    }

    fn make_indexer(
        token: &CancellationToken,
        num_shards: usize,
        kv_block_size: u32,
    ) -> Box<dyn KvIndexerInterface> {
        let metrics = KvIndexerMetrics::new_unregistered();
        if num_shards == 1 {
            Box::new(KvIndexer::new(token.clone(), kv_block_size, metrics.into()))
        } else {
            Box::new(KvIndexerSharded::new(
                token.clone(),
                num_shards,
                kv_block_size,
                metrics.into(),
            ))
        }
    }

    #[template]
    #[rstest]
    fn indexer_template(
        #[values(1, 3, 8)] num_shards: usize,
        #[values(11, 32, 64)] kv_block_size: usize,
    ) {
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_kv_indexer_new(num_shards: usize, kv_block_size: u32) {
        setup();
        let token: CancellationToken = CancellationToken::new();
        let _ = make_indexer(&token, num_shards, kv_block_size);
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_find_matches(num_shards: usize, kv_block_size: u32) {
        setup();
        let token = CancellationToken::new();
        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        let sequence = vec![compute_block_hash(b"test data")];
        let scores = kv_indexer.find_matches(sequence).await;

        assert!(scores.unwrap().scores.is_empty());
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_find_matches_for_request(num_shards: usize, kv_block_size: u32) {
        setup();
        let token = CancellationToken::new();
        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        let tokens = vec![1, 2, 3, 4];
        let scores = kv_indexer.find_matches_for_request(&tokens).await;

        assert!(scores.unwrap().scores.is_empty());
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_apply_event(num_shards: usize, kv_block_size: u32) {
        setup();
        let worker_id = 0;

        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
        kv_indexer.apply_event(event).await;
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_shutdown(num_shards: usize, kv_block_size: u32) {
        setup();
        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        kv_indexer.shutdown();
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_frequency(num_shards: usize, kv_block_size: u32) {
        const ONE_MILLIS: Duration = Duration::from_millis(1);

        setup();
        let mut kv_indexer: Box<dyn KvIndexerInterface>;
        let token = CancellationToken::new();
        let expiration = Duration::from_millis(50);
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());

        if num_shards == 1 {
            kv_indexer = Box::new(KvIndexer::new_with_frequency(
                token,
                Some(expiration),
                kv_block_size,
                metrics,
            ));
        } else {
            kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
                token,
                num_shards,
                Some(expiration),
                kv_block_size,
                metrics,
            ));
        }

        let block_hashes = vec![
            LocalBlockHash(1),
            LocalBlockHash(2),
            LocalBlockHash(3),
            LocalBlockHash(4),
        ];

        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(
            overlap.frequencies.len(),
            0,
            "Should be no cached blocks yet"
        );

        let worker_id = 0;
        let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
        kv_indexer.apply_event(event).await;

        let mut overlap = OverlapScores::default();
        let timeout = Duration::from_millis(10);
        let start = Instant::now();
        while overlap.scores.is_empty() && Instant::now().duration_since(start) < timeout {
            time::sleep(ONE_MILLIS).await;
            overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        }
        assert_eq!(
            overlap.scores.len(),
            1,
            "One worker has these blocks cached"
        );
        assert_eq!(
            overlap.frequencies.len(),
            0,
            "Blocks have not previously been accessed"
        );

        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(overlap.scores.len(), 1, "Still one worker matches");
        assert_eq!(
            overlap.frequencies,
            vec![1, 1, 1, 1],
            "We should see the first access now"
        );

        time::sleep(expiration + Duration::from_millis(10)).await;

        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(
            overlap.frequencies.len(),
            0,
            "Blocks were accessed too long ago"
        );

        let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();

        let overlap = kv_indexer
            .find_matches(block_hashes[0..3].to_vec())
            .await
            .unwrap();
        assert_eq!(overlap.frequencies, vec![2, 2, 2]);

        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
    }

    #[test]
    fn test_router_event_new() {
        setup();
        let worker_id = 0;
        let kv_cache_event = KvCacheEvent {
            event_id: 1,
            data: KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash: None,
                blocks: vec![KvCacheStoredBlockData {
                    block_hash: ExternalSequenceBlockHash(0),
                    tokens_hash: LocalBlockHash(13226331709069118873),
                }],
            }),
        };
        let router_event = RouterEvent::new(worker_id, kv_cache_event);

        assert_eq!(router_event.worker_id, worker_id);
        assert_eq!(router_event.event.event_id, 1);
        if let KvCacheEventData::Stored(store_op) = &router_event.event.data {
            assert_eq!(store_op.blocks.len(), 1);
            assert_eq!(
                store_op.blocks[0].tokens_hash,
                compute_block_hash(b"test data")
            );
            assert_eq!(store_op.blocks[0].block_hash, ExternalSequenceBlockHash(0));
        } else {
            panic!("Expected KvCacheEventData::Stored");
        }
    }

    #[test]
    fn test_radix_tree_default() {
        setup();
        let radix_tree: RadixTree = Default::default();
        assert!(radix_tree.root.borrow().children.is_empty());
        assert!(radix_tree.root.borrow().workers.is_empty());
        assert!(radix_tree.lookup.is_empty());
    }

    #[test]
    fn test_overlap_scores_default() {
        setup();
        let overlap_scores: OverlapScores = Default::default();
        assert!(overlap_scores.scores.is_empty());
    }

    #[tokio::test]
    async fn test_dump_tree_as_events_round_trip() {
        setup();

        let kv_block_size = 32;
        let num_shards = 2;
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());

        let token1 = CancellationToken::new();
        let mut original_indexer =
            KvIndexerSharded::new(token1.clone(), num_shards, kv_block_size, metrics.clone());

        let worker_0 = 0;
        let worker_1 = 1;
        let worker_2 = 2;

        original_indexer
            .apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
            .await;

        original_indexer
            .apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
            .await;
        original_indexer
            .apply_event(create_store_event(
                worker_1,
                2,
                vec![4, 5],
                Some(ExternalSequenceBlockHash(100)),
            ))
            .await;

        original_indexer
            .apply_event(create_store_event(worker_2, 3, vec![6, 7], None))
            .await;

        original_indexer
            .apply_event(create_store_event(
                worker_0,
                4,
                vec![4],
                Some(ExternalSequenceBlockHash(100)),
            ))
            .await;

        tokio::time::sleep(Duration::from_millis(50)).await;

        let dump1 = original_indexer.dump_events().await.unwrap();
        println!("Dumped {} events", dump1.len());

        let token2 = CancellationToken::new();
        let mut reconstructed_indexer =
            KvIndexerSharded::new(token2.clone(), num_shards, kv_block_size, metrics);

        for event in &dump1 {
            reconstructed_indexer.apply_event(event.clone()).await;
        }

        tokio::time::sleep(Duration::from_millis(50)).await;

        let dump2 = reconstructed_indexer.dump_events().await.unwrap();

        let mut sorted_dump1 = dump1.clone();
        let mut sorted_dump2 = dump2.clone();

        let sort_key = |event: &RouterEvent| {
            if let KvCacheEventData::Stored(ref data) = event.event.data {
                (
                    event.worker_id,
                    data.blocks.first().map(|b| b.tokens_hash.0).unwrap_or(0),
                    data.parent_hash.map(|h| h.0).unwrap_or(0),
                )
            } else {
                (event.worker_id, 0, 0)
            }
        };

        sorted_dump1.sort_by_key(sort_key);
        sorted_dump2.sort_by_key(sort_key);

        assert_eq!(
            sorted_dump1.len(),
            sorted_dump2.len(),
            "Dumps have different lengths: {} vs {}",
            sorted_dump1.len(),
            sorted_dump2.len()
        );

        for (i, (event1, event2)) in sorted_dump1.iter().zip(sorted_dump2.iter()).enumerate() {
            assert_eq!(
                event1.worker_id, event2.worker_id,
                "Event {} worker_id mismatch",
                i
            );

            if let (KvCacheEventData::Stored(data1), KvCacheEventData::Stored(data2)) =
                (&event1.event.data, &event2.event.data)
            {
                assert_eq!(
                    data1.parent_hash, data2.parent_hash,
                    "Event {} parent_hash mismatch",
                    i
                );
                assert_eq!(
                    data1.blocks.len(),
                    data2.blocks.len(),
                    "Event {} blocks length mismatch",
                    i
                );

                for (j, (block1, block2)) in
                    data1.blocks.iter().zip(data2.blocks.iter()).enumerate()
                {
                    assert_eq!(
                        block1.tokens_hash, block2.tokens_hash,
                        "Event {} block {} tokens_hash mismatch",
                        i, j
                    );
                    assert_eq!(
                        block1.block_hash, block2.block_hash,
                        "Event {} block {} block_hash mismatch",
                        i, j
                    );
                }
            } else {
                panic!("Expected Stored events in both dumps");
            }
        }

        for test_seq in [
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            vec![LocalBlockHash(1), LocalBlockHash(4), LocalBlockHash(5)],
            vec![LocalBlockHash(6), LocalBlockHash(7)],
            vec![LocalBlockHash(1)],
        ] {
            let scores1 = original_indexer
                .find_matches(test_seq.clone())
                .await
                .unwrap();
            let scores2 = reconstructed_indexer
                .find_matches(test_seq.clone())
                .await
                .unwrap();

            let mut scores1_sorted: Vec<_> = scores1.scores.iter().collect();
            let mut scores2_sorted: Vec<_> = scores2.scores.iter().collect();
            scores1_sorted.sort_by_key(|(k, _)| *k);
            scores2_sorted.sort_by_key(|(k, _)| *k);

            assert_eq!(
                scores1_sorted, scores2_sorted,
                "Match scores differ for sequence {:?}",
                test_seq
            );
        }

        original_indexer.shutdown();
        reconstructed_indexer.shutdown();
    }

    #[test]
    fn test_increment_event_applied() {
        let metrics = KvIndexerMetrics::new_unregistered();

        metrics.increment_event_applied(METRIC_EVENT_STORED, Ok(()));
        assert_eq!(
            metrics
                .kv_cache_events_applied
                .get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_OK])
                .unwrap()
                .get(),
            1
        );

        metrics.increment_event_applied(
            METRIC_EVENT_STORED,
            Err(KvCacheEventError::ParentBlockNotFound),
        );
        assert_eq!(
            metrics
                .kv_cache_events_applied
                .get_metric_with_label_values(&[
                    METRIC_EVENT_STORED,
                    METRIC_STATUS_PARENT_NOT_FOUND
                ])
                .unwrap()
                .get(),
            1
        );

        metrics
            .increment_event_applied(METRIC_EVENT_REMOVED, Err(KvCacheEventError::BlockNotFound));
        assert_eq!(
            metrics
                .kv_cache_events_applied
                .get_metric_with_label_values(&[
                    METRIC_EVENT_REMOVED,
                    METRIC_STATUS_BLOCK_NOT_FOUND
                ])
                .unwrap()
                .get(),
            1
        );
    }
}