dynamo_llm/kv_router/
indexer.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! KV RadixTree
5//!
6//! This module implements a key-value (KV) store using a Radix Tree structure to efficiently manage and retrieve data blocks.
7//! It is designed to support LLM (Large Language Model) inference by re-using a global KV cache.
8//!
9//! # Overview
10//!
11//! The main components of this module include:
12//!
13//! - **Radix Tree Structure**:
14//!   - The `RadixTree` struct represents the main data structure, with nodes (`RadixBlock`) containing children and associated worker IDs.
15//!   - It allows efficient storage and retrieval of data blocks based on their hashes.
16//!
17//! - **Event Handling**:
18//!   - The `RouterEvent` struct represents events emitted by LLM workers, which can be applied to the Radix Tree to update its state.
19//!   - The `KvIndexer` struct manages these events and match requests asynchronously using Tokio channels.
20//!
21//! - **Hash Computation**:
22//!   - Functions like `compute_block_hash` and `compute_block_hash_for_seq` compute hashes for data blocks and sequences of tokens, facilitating quick lookups.
23//!
24//! - **Concurrency and Asynchronous Operations**:
25//!   - The `KvIndexer` uses a single-threaded Tokio runtime to handle events and match requests concurrently, ensuring efficient processing without blocking.
26//!
27//! - **Match Requests**:
28//!   - The `MatchRequest` struct represents requests to find matches in the Radix Tree, returning overlap scores indicating the best matches.
29//!
30//! # Purpose
31//!
32//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
33
34use async_trait::async_trait;
35use bytes::Bytes;
36use dynamo_runtime::{
37    component::Component,
38    metrics::{MetricsRegistry, prometheus_names::kvrouter},
39};
40use prometheus::{IntCounterVec, Opts};
41use serde::{Deserialize, Serialize};
42use std::{
43    cell::RefCell,
44    collections::{HashMap, HashSet, VecDeque},
45    iter,
46    rc::Rc,
47    sync::{Arc, OnceLock},
48    thread::JoinHandle,
49    time::{Duration, Instant},
50};
51use tokio::sync::{broadcast, mpsc, oneshot};
52use tokio_util::sync::CancellationToken;
53use xxhash_rust::xxh3;
54
/// Seed handed to xxh3 by [`compute_hash`]; block and sequence hashes are only
/// comparable when they were produced with this same seed.
pub const XXH3_SEED: u64 = 1337;
56
57use crate::kv_router::protocols::*;
58use crate::tokens::SequenceHash;
59
60/// Errors that can occur in the KV Router.
61#[derive(Debug, thiserror::Error)]
62pub enum KvRouterError {
63    #[error("Block not found")]
64    BlockNotFound,
65
66    #[error("Indexer is offline")]
67    IndexerOffline,
68
69    #[error("Indexer is dropped request")]
70    IndexerDroppedRequest,
71}
72
73/// Errors that can occur during KV Cache Event processing.
74#[derive(Debug, thiserror::Error)]
75pub enum KvCacheEventError {
76    #[error("Failed to find parent block")]
77    ParentBlockNotFound,
78
79    #[error("Failed to find block")]
80    BlockNotFound,
81}
82
/// Identifier of an LLM worker that emits KV cache events to the router.
pub type WorkerId = i64;
85
86/// A shared reference to a [`RadixBlock`].
87type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
88
89pub fn compute_hash(data: &[u8]) -> u64 {
90    xxh3::xxh3_64_with_seed(data, XXH3_SEED)
91}
92
93/// Compute the hash of a local block.
94///
95/// ### Arguments
96///
97/// * `data` - A byte slice representing the data to hash.
98///
99/// ### Returns
100///
101/// A `LocalBlockHash` representing the computed hash.
102pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
103    LocalBlockHash(compute_hash(data))
104}
105
106// /// Updated version of the `compute_block_hash` function that included the lora_id
107// pub fn compute_block_hash_v2(token_id: &[u32], lora_id: u64) {
108//     let mut bytes = Vec::new();
109//     for token in token_id {
110//         bytes.extend_from_slice(&token.to_le_bytes());
111//     }
112//     bytes.extend_from_slice(&lora_id.to_le_bytes());
113//     let hash = xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED);
114// }
115
116/// Compute the hash for a sequence of tokens.
117///
118/// ### Arguments
119///
120/// * `tokens` - A vector of `u32` tokens.
121///
122/// ### Returns
123///
124/// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
125pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<LocalBlockHash> {
126    tokens
127        .chunks_exact(kv_block_size as usize) // Split into chunks of kv_block_size elements
128        .map(|chunk| {
129            let bytes: Vec<u8> = chunk
130                .iter()
131                .flat_map(|&num| num.to_le_bytes()) // Convert each i32 to its little-endian bytes
132                .collect();
133
134            compute_block_hash(&Bytes::from(bytes)) // Convert the byte Vec to Bytes
135        })
136        .collect()
137}
138
139/// Compute rolling sequence hashes for a vector of block hashes.
140///
141/// This mirrors the behavior in tokens.rs where:
142/// - The first block's sequence hash equals its block hash
143/// - Subsequent blocks' sequence hash = hash([parent_sequence_hash, current_block_hash], seed)
144///
145/// ### Arguments
146///
147/// * `block_hashes` - A vector of `LocalBlockHash` values representing the block hashes.
148///
149/// ### Returns
150///
151/// A vector of u64 values representing the sequence hashes for each block.
152pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
153    if block_hashes.is_empty() {
154        return Vec::new();
155    }
156
157    let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
158    sequence_hashes.push(block_hashes[0].0);
159
160    for i in 1..block_hashes.len() {
161        let parent_seq_hash = sequence_hashes[i - 1];
162        let current_block_hash = block_hashes[i].0;
163
164        let combined = [parent_seq_hash, current_block_hash];
165        let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect();
166        let seq_hash = compute_hash(&bytes);
167        sequence_hashes.push(seq_hash);
168    }
169
170    sequence_hashes
171}
172
173/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct RouterEvent {
176    /// The ID of the worker emitting the event.
177    worker_id: WorkerId,
178    /// The cache event associated with the worker.
179    event: KvCacheEvent,
180}
181
182impl RouterEvent {
183    /// Create a new `RouterEvent`.
184    ///
185    /// ### Arguments
186    ///
187    /// * `worker_id` - The ID of the worker emitting the event.
188    /// * `event` - The cache event.
189    ///
190    /// ### Returns
191    ///
192    /// A new `RouterEvent`.
193    pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
194        Self { worker_id, event }
195    }
196}
197
198/// A block in the Radix Tree.
199#[derive(Debug)]
200struct RadixBlock {
201    /// A map of child blocks, keyed by their local block hash.
202    children: HashMap<LocalBlockHash, SharedRadixBlock>,
203    /// A set of worker IDs associated with this block.
204    workers: HashSet<WorkerId>,
205    /// A buffer of times that this block was last traversed
206    recent_uses: VecDeque<Instant>,
207}
208
209impl RadixBlock {
210    /// Create a new `RadixBlock`.
211    ///
212    /// ### Returns
213    ///
214    /// A new `RadixBlock`.
215    pub fn new() -> Self {
216        Self {
217            children: HashMap::new(),
218            workers: HashSet::new(),
219            recent_uses: VecDeque::new(),
220        }
221    }
222}
223
224pub struct RadixTree {
225    /// This is the root of the radix/prefix tree
226    /// This will only contain root blocks
227    root: SharedRadixBlock,
228
229    /// This is a global lookup table for all blocks which will let you jump into
230    /// the radix tree at any point
231    /// Lookup is best case O(1) and worst case O(N); however, even constant in-time
232    /// could be expensive if N is large
233    /// We should monitor the size of this table and consider using a proper radix tree.
234    /// Transitioning to a radix tree only would require a change in the messaging structure
235    /// as the entire prefix would need to be sent. Alternatively, we could use block_depth
236    /// integers to indicate how many blocks to skip and use a radix/prefix tree at each level.
237    lookup: HashMap<WorkerId, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
238    /// The time buffer the radix tree should check when considering frequence of block accesses
239    expiration_duration: Option<Duration>,
240}
241
242impl Default for RadixTree {
243    fn default() -> Self {
244        Self::new()
245    }
246}
247
248impl RadixTree {
249    /// Create a new `RadixTree`.
250    ///
251    /// ### Returns
252    ///
253    /// A new `RadixTree`.
254    pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self {
255        Self {
256            root: Rc::new(RefCell::new(RadixBlock::new())),
257            lookup: HashMap::new(),
258            expiration_duration,
259        }
260    }
261
262    pub fn new() -> Self {
263        Self::new_with_frequency(None)
264    }
265
266    /// Traverse the radix tree to find the best match for a given sequence of [`LocalBlockHash`]es.
267    ///
268    /// ### Arguments
269    ///
270    /// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
271    /// * `early_exit` - A boolean indicating whether to exit early if a single match is found.
272    ///
273    /// ### Returns
274    ///
275    /// An `OverlapScores` representing the match scores.
276    pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
277        let mut scores = OverlapScores::new();
278        let mut current = self.root.clone();
279        let now = Instant::now();
280        for block_hash in sequence {
281            let next_block = {
282                let current_borrow = current.borrow();
283                current_borrow.children.get(&block_hash).cloned()
284            };
285            if let Some(block) = next_block {
286                scores.update_scores(&block.borrow().workers);
287
288                if let Some(expiration_duration) = self.expiration_duration {
289                    let mut block_mut = block.borrow_mut();
290
291                    while let Some(access_time) = block_mut.recent_uses.front() {
292                        if now.duration_since(*access_time) > expiration_duration {
293                            block_mut.recent_uses.pop_front();
294                        } else {
295                            break;
296                        }
297                    }
298                    scores.add_frequency(block_mut.recent_uses.len());
299                    block_mut.recent_uses.push_back(now);
300                }
301
302                if early_exit && block.borrow().workers.len() == 1 {
303                    break;
304                }
305
306                current = block;
307            } else {
308                break;
309            }
310        }
311
312        scores
313    }
314
315    /// Apply a [`RouterEvent`] to the radix tree.
316    ///
317    /// ### Arguments
318    ///
319    /// * `event` - The `RouterEvent` to apply.
320    pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
321        let (worker_id, event) = (event.worker_id, event.event);
322        let (id, op) = (event.event_id, event.data);
323        tracing::trace!(id, "Store operation: {:?}", op);
324
325        let worker_lookup = self.lookup.entry(worker_id).or_default();
326
327        match op {
328            KvCacheEventData::Stored(op) => {
329                // find the parent block - if the parent exists it must be on our worker, if not,
330                // we check the radix tree's root to find it.
331                // this is the single most expensive lookup
332                let current = match op.parent_hash {
333                    Some(parent) => worker_lookup.get(&parent),
334                    None => Some(&self.root),
335                };
336
337                let mut current = match current {
338                    Some(current) => current.clone(),
339                    None => {
340                        tracing::warn!(
341                            worker_id = worker_id.to_string(),
342                            id,
343                            parent_hash = ?op.parent_hash,
344                            "Failed to find parent block; skipping store operation"
345                        );
346                        return Err(KvCacheEventError::ParentBlockNotFound);
347                    }
348                };
349
350                for block_id in op.blocks {
351                    let mut inner = current.borrow_mut();
352                    let block = match inner.children.get(&block_id.tokens_hash) {
353                        Some(block) => block.clone(),
354                        None => {
355                            // create new block - automatically added to the lookup table
356                            let new_block = worker_lookup
357                                .get(&block_id.block_hash)
358                                .cloned()
359                                .unwrap_or_else(|| Rc::new(RefCell::new(RadixBlock::new())));
360
361                            // insert into radix tree
362                            inner
363                                .children
364                                .insert(block_id.tokens_hash, new_block.clone());
365
366                            new_block
367                        }
368                    };
369
370                    // add our worker_id to the block
371                    block.borrow_mut().workers.insert(worker_id);
372
373                    // add the block to the worker_id lookup table
374                    worker_lookup.insert(block_id.block_hash, block.clone());
375
376                    // drop inner so we can shift current to this block
377                    drop(inner);
378
379                    current = block;
380                }
381                Ok(())
382            }
383            KvCacheEventData::Removed(remove) => {
384                // tracing::trace!(id, "KV Remove Operation: {:?}", op);
385                // let mut worker_lookup = self.lookup.get(&worker_id).expect("Worker not found");
386
387                for block in remove.block_hashes {
388                    // entry in radix tree
389                    // a small optimization would be to get the next block from the reduced set of children
390                    // in order to apply this optimization, we would need to know the list of blocks is always sorted
391                    // by parent -> child relationship
392                    let entry = match worker_lookup.get(&block) {
393                        Some(entry) => entry.clone(),
394                        None => {
395                            tracing::warn!(
396                                worker_id = worker_id.to_string(),
397                                id,
398                                "Failed to find block to remove; skipping remove operation"
399                            );
400                            return Err(KvCacheEventError::BlockNotFound);
401                        }
402                    };
403
404                    let mut guard = entry.borrow_mut();
405                    guard.workers.remove(&worker_id);
406                    if guard.workers.is_empty() {
407                        // if no worker are using this block, that is true for all children
408                        guard.children.clear();
409                    }
410                    // remove the block from the lookup table
411                    worker_lookup.remove(&block);
412                }
413                Ok(())
414            }
415            KvCacheEventData::Cleared => {
416                self.clear_all_blocks(worker_id);
417                Ok(())
418            }
419        }
420    }
421
422    pub fn remove_worker(&mut self, worker: WorkerId) {
423        if let Some((_, blocks)) = self.lookup.remove_entry(&worker) {
424            blocks.iter().for_each(|(_, block)| {
425                block.borrow_mut().workers.remove(&worker);
426            });
427        }
428    }
429
430    pub fn clear_all_blocks(&mut self, worker: WorkerId) {
431        // Check if the worker has any blocks to clear
432        if let Some(blocks) = self.lookup.get(&worker) {
433            let blocks_to_clear: Vec<_> = blocks.values().collect();
434
435            // Remove the worker from each block's workers set
436            blocks_to_clear.iter().for_each(|block| {
437                block.borrow_mut().workers.remove(&worker);
438            });
439
440            // Clear the worker's blocks
441            if let Some(worker_blocks) = self.lookup.get_mut(&worker) {
442                worker_blocks.clear();
443            }
444        }
445    }
446
447    /// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
448    /// Uses BFS traversal to ensure that the tree reconstruction is unique,
449    /// though the exact event ordering will be lost.
450    pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
451        let mut events = Vec::new();
452        let mut event_id = 0u64;
453
454        // BFS queue: (current_block, parent_external_hash, tokens_hash)
455        let mut queue = VecDeque::new();
456
457        // Process root's children first
458        let root_borrow = self.root.borrow();
459        for (tokens_hash, child_block) in &root_borrow.children {
460            queue.push_back((child_block.clone(), None, *tokens_hash));
461        }
462        drop(root_borrow);
463
464        while let Some((current_block, parent_external_hash, tokens_hash)) = queue.pop_front() {
465            let current_borrow = current_block.borrow();
466
467            // Closure to find external hash for a block in a worker's lookup
468            let find_external_hash = |worker_id: &WorkerId| {
469                self.lookup.get(worker_id).and_then(|worker_blocks| {
470                    worker_blocks
471                        .iter()
472                        .find(|(_, block)| Rc::ptr_eq(block, &current_block))
473                        .map(|(hash, _)| *hash)
474                })
475            };
476
477            // For each worker that has this block
478            for worker_id in &current_borrow.workers {
479                // Find the external hash for this block from the worker's lookup
480                let external_hash = find_external_hash(worker_id);
481
482                if let Some(block_hash) = external_hash {
483                    // Create a store event for this worker
484                    let event = RouterEvent {
485                        worker_id: *worker_id,
486                        event: KvCacheEvent {
487                            event_id,
488                            data: KvCacheEventData::Stored(KvCacheStoreData {
489                                parent_hash: parent_external_hash,
490                                blocks: vec![KvCacheStoredBlockData {
491                                    block_hash,
492                                    tokens_hash,
493                                }],
494                            }),
495                        },
496                    };
497                    events.push(event);
498                    event_id += 1;
499                }
500            }
501
502            // Add children to queue for BFS traversal
503            // We need to find any external hash for this block to use as parent
504            let any_external_hash = if !current_borrow.workers.is_empty() {
505                current_borrow
506                    .workers
507                    .iter()
508                    .next()
509                    .and_then(find_external_hash)
510            } else {
511                None
512            };
513
514            for (child_tokens_hash, child_block) in &current_borrow.children {
515                queue.push_back((child_block.clone(), any_external_hash, *child_tokens_hash));
516            }
517        }
518
519        events
520    }
521}
522
523/// Metrics for the KV Indexer.
524#[derive(Clone)]
525pub struct KvIndexerMetrics {
526    /// Counter of events applied.
527    pub kv_cache_events_applied: IntCounterVec,
528}
529
/// `status` label value: the event was applied successfully.
pub const METRIC_STATUS_OK: &'static str = "ok";
/// `status` label value: a `Stored` event's parent block was not found.
pub const METRIC_STATUS_PARENT_NOT_FOUND: &'static str = "parent_block_not_found";
/// `status` label value: a block named by the event was not found.
pub const METRIC_STATUS_BLOCK_NOT_FOUND: &'static str = "block_not_found";

/// `event_type` label value for `Stored` events.
pub const METRIC_EVENT_STORED: &'static str = "stored";
/// `event_type` label value for `Removed` events.
pub const METRIC_EVENT_REMOVED: &'static str = "removed";
/// `event_type` label value for `Cleared` events.
pub const METRIC_EVENT_CLEARED: &'static str = "cleared";
539
540static KV_INDEXER_METRICS: OnceLock<Arc<KvIndexerMetrics>> = OnceLock::new();
541
542impl KvIndexerMetrics {
543    fn new(kv_cache_events_applied: IntCounterVec) -> Self {
544        Self {
545            kv_cache_events_applied,
546        }
547    }
548
549    /// Creates a new KvIndexerMetrics from a Component, memoizing the result in
550    /// KV_INDEXER_METRICS to avoid duplicate registration issues.
551    pub fn from_component(component: &Component) -> Arc<Self> {
552        KV_INDEXER_METRICS.get_or_init(|| {
553            match component.create_intcountervec(
554                kvrouter::KV_CACHE_EVENTS_APPLIED,
555                "Total number of KV cache events applied to index",
556                &["event_type", "status"],
557                &[],
558            ) {
559                Ok(kv_cache_events_applied) => Arc::new(Self::new(kv_cache_events_applied)),
560                Err(e) => {
561                    tracing::warn!("Failed to create kv indexer metrics from component: {}. Using unregistered metrics as fallback.", e);
562                    Arc::new(Self::new_unregistered())
563                }
564            }
565        }).clone()
566    }
567
568    /// Creates a new KvIndexerMetrics which is not registered with a MetricsRegistry.
569    /// This may be used for tests or as a fallback for when a MetricsRegistry is not available / has errored.
570    pub fn new_unregistered() -> Self {
571        Self {
572            kv_cache_events_applied: IntCounterVec::new(
573                Opts::new(
574                    kvrouter::KV_CACHE_EVENTS_APPLIED,
575                    "Total number of KV cache events applied to index",
576                ),
577                &["event_type", "status"],
578            )
579            .unwrap(),
580        }
581    }
582
583    pub fn get_event_type(event_data: &KvCacheEventData) -> &'static str {
584        match event_data {
585            KvCacheEventData::Stored(_) => METRIC_EVENT_STORED,
586            KvCacheEventData::Removed(_) => METRIC_EVENT_REMOVED,
587            KvCacheEventData::Cleared => METRIC_EVENT_CLEARED,
588        }
589    }
590
591    pub fn increment_event_applied(
592        &self,
593        event_type: &'static str,
594        result: Result<(), KvCacheEventError>,
595    ) {
596        match result {
597            Ok(_) => {
598                self.kv_cache_events_applied
599                    .with_label_values(&[event_type, METRIC_STATUS_OK])
600                    .inc_by(1);
601            }
602            Err(e) => {
603                let error_label = match e {
604                    KvCacheEventError::ParentBlockNotFound => METRIC_STATUS_PARENT_NOT_FOUND,
605                    KvCacheEventError::BlockNotFound => METRIC_STATUS_BLOCK_NOT_FOUND,
606                };
607                self.kv_cache_events_applied
608                    .with_label_values(&[event_type, error_label])
609                    .inc_by(1);
610            }
611        }
612    }
613}
614
615/// Scores representing the overlap of workers.
616#[derive(Debug, Clone, Serialize, Deserialize)]
617pub struct OverlapScores {
618    // map of worker_id to score
619    pub scores: HashMap<WorkerId, u32>,
620    // List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
621    pub frequencies: Vec<usize>,
622}
623
624impl Default for OverlapScores {
625    fn default() -> Self {
626        Self::new()
627    }
628}
629
630impl OverlapScores {
631    /// Create a new `OverlapScores`.
632    ///
633    /// ### Returns
634    ///
635    /// A new `OverlapScores`.
636    pub fn new() -> Self {
637        Self {
638            scores: HashMap::new(),
639            frequencies: Vec::with_capacity(32),
640        }
641    }
642
643    /// Update the scores with a set of workers.
644    ///
645    /// ### Arguments
646    ///
647    /// * `workers` - A reference to a `HashSet` of `WorkerId`s.
648    pub fn update_scores(&mut self, workers: &HashSet<WorkerId>) {
649        for worker in workers {
650            let score = self.scores.entry(*worker).or_insert(0);
651            *score += 1;
652        }
653    }
654
655    /// Add an entry in the frequency list.
656    pub fn add_frequency(&mut self, frequency: usize) {
657        if frequency != 0 {
658            self.frequencies
659                .last()
660                .inspect(|elem| debug_assert!(**elem >= frequency));
661            self.frequencies.push(frequency);
662        }
663    }
664}
665
666/// A request to find matches in the Radix Tree.
667pub struct MatchRequest {
668    /// A vector of `LocalBlockHash` representing the sequence to match.
669    sequence: Vec<LocalBlockHash>,
670    /// A boolean indicating whether to exit early if a single match is found.
671    early_exit: bool,
672    /// A channel sender to send the `OverlapScores` response.
673    resp: oneshot::Sender<OverlapScores>,
674}
675
676/// A request to dump the tree as events
677pub struct DumpRequest {
678    /// Channel to send the dumped events
679    pub resp: oneshot::Sender<Vec<RouterEvent>>,
680}
681
682#[async_trait]
683pub trait KvIndexerInterface {
684    /// Find matches for a given sequence of `LocalBlockHash`es.
685    ///
686    /// ### Arguments
687    ///
688    /// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
689    ///
690    /// ### Returns
691    ///
692    /// An `OverlapScores` representing the match scores.
693    async fn find_matches(
694        &self,
695        sequence: Vec<LocalBlockHash>,
696    ) -> Result<OverlapScores, KvRouterError>;
697
698    /// Find matches for a given sequence of tokens.
699    ///
700    /// ### Arguments
701    ///
702    /// * `tokens` - A vector of `u32` tokens.
703    ///
704    /// ### Returns
705    ///
706    /// An `OverlapScores` representing the match scores.
707    async fn find_matches_for_request(
708        &self,
709        tokens: &[u32],
710    ) -> Result<OverlapScores, KvRouterError>;
711
712    /// Apply a `RouterEvent` to the KV store.
713    ///
714    /// ### Arguments
715    ///
716    /// * `event` - The `RouterEvent` to apply.
717    async fn apply_event(&mut self, event: RouterEvent);
718
719    /// Remove a worker's entries from the trie.
720    ///
721    /// ### Arguments
722    ///
723    /// * `worker` - The worker to remove from the trie.
724    async fn remove_worker(&mut self, worker: WorkerId);
725
726    /// Shutdown the KV Indexer.
727    fn shutdown(&mut self);
728
729    /// Dump the entire tree as RouterEvents.
730    ///
731    /// ### Returns
732    ///
733    /// A vector of RouterEvents representing the current state of the tree.
734    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError>;
735}
736
737/// The KV Indexer, managing the KV store and handling events and match requests.
738pub struct KvIndexer {
739    /// A `CancellationToken` for managing shutdown.
740    cancel: CancellationToken,
741    /// A sender for `RouterEvent`s.
742    event_tx: mpsc::Sender<RouterEvent>,
743    /// A sender for `MatchRequest`s.
744    match_tx: mpsc::Sender<MatchRequest>,
745    /// A sender for remove worker requests.
746    remove_worker_tx: mpsc::Sender<WorkerId>,
747    /// A sender for dump requests.
748    dump_tx: mpsc::Sender<DumpRequest>,
749    /// A handle to the background task managing the KV store.
750    task: OnceLock<std::thread::JoinHandle<()>>,
751    /// The size of the KV block this indexer can handle.
752    kv_block_size: u32,
753}
754
755impl KvIndexer {
756    /// Create a new `KvIndexer`.
757    ///
758    /// ### Arguments
759    ///
760    /// * `token` - A `CancellationToken` for managing shutdown.
761    /// * `expiration_duration` - The amount of time that block usage should be buffered.
762    ///
763    /// ### Returns
764    ///
765    /// A new `KvIndexer`.
766    pub fn new_with_frequency(
767        token: CancellationToken,
768        expiration_duration: Option<Duration>,
769        kv_block_size: u32,
770        metrics: Arc<KvIndexerMetrics>,
771    ) -> Self {
772        let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
773        let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
774        let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
775        let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
776        let cancel_clone = token.clone();
777
778        let task = std::thread::spawn(move || {
779            // create a new tokio runtime which will only perform work on a single thread
780            let runtime = tokio::runtime::Builder::new_multi_thread()
781                .worker_threads(1) // Single-threaded environment
782                .enable_all()
783                .build()
784                .unwrap();
785
786            let local_set = tokio::task::LocalSet::new();
787
788            runtime.block_on(local_set.run_until(async move {
789                tokio::task::spawn_local(async move {
790                    let cancel = cancel_clone;
791                    let mut match_rx = match_rx;
792                    let mut event_rx = event_rx;
793                    let mut remove_worker_rx = remove_worker_rx;
794                    let mut dump_rx = dump_rx;
795                    let mut trie = RadixTree::new_with_frequency(expiration_duration);
796                    loop {
797                        tokio::select! {
798                            biased;
799
800                            _ = cancel.cancelled() => {
801                                tracing::debug!("KvCacheIndexer progress loop shutting down");
802                                return;
803                            }
804
805                            Some(worker) = remove_worker_rx.recv() => {
806                                trie.remove_worker(worker);
807                            }
808
809                            Some(event) = event_rx.recv() => {
810                                let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
811                                let result = trie.apply_event(event);
812                                metrics.increment_event_applied(event_type, result);
813                            }
814
815                            Some(dump_req) = dump_rx.recv() => {
816                                let events = trie.dump_tree_as_events();
817                                let _ = dump_req.resp.send(events);
818                            }
819
820                            Some(req) = match_rx.recv() => {
821                                let matches = trie.find_matches(req.sequence, req.early_exit);
822                                let _ = req.resp.send(matches);
823                            }
824                        }
825                    }
826                })
827                .await
828                .unwrap()
829            }));
830
831            tracing::debug!("KvCacheIndexer task completed");
832        });
833
834        let once = OnceLock::new();
835        once.set(task).unwrap();
836
837        Self {
838            cancel: token,
839            event_tx,
840            match_tx,
841            remove_worker_tx,
842            dump_tx,
843            task: once,
844            kv_block_size,
845        }
846    }
847
    /// Returns the KV block size (in tokens) used to hash request token sequences
    /// into block hashes for this indexer.
    pub fn block_size(&self) -> u32 {
        self.kv_block_size
    }
851
852    pub fn new(
853        token: CancellationToken,
854        kv_block_size: u32,
855        metrics: Arc<KvIndexerMetrics>,
856    ) -> Self {
857        Self::new_with_frequency(token, None, kv_block_size, metrics)
858    }
859
860    /// Get a sender for `RouterEvent`s.
861    ///
862    /// ### Returns
863    ///
864    /// A `mpsc::Sender` for `RouterEvent`s.
865    pub fn event_sender(&self) -> mpsc::Sender<RouterEvent> {
866        self.event_tx.clone()
867    }
868
869    /// Get a sender for dump requests (snapshot events).
870    ///
871    /// ### Returns
872    ///
873    /// A `mpsc::Sender` for `DumpRequest`s.
874    pub fn snapshot_event_sender(&self) -> mpsc::Sender<DumpRequest> {
875        self.dump_tx.clone()
876    }
877
878    /// Get a sender for worker removal requests.
879    ///
880    /// ### Returns
881    ///
882    /// A `mpsc::Sender` for `WorkerId`s.
883    pub fn remove_worker_sender(&self) -> mpsc::Sender<WorkerId> {
884        self.remove_worker_tx.clone()
885    }
886}
887
888#[async_trait]
889impl KvIndexerInterface for KvIndexer {
890    async fn find_matches(
891        &self,
892        sequence: Vec<LocalBlockHash>,
893    ) -> Result<OverlapScores, KvRouterError> {
894        let (resp_tx, resp_rx) = oneshot::channel();
895        let req = MatchRequest {
896            sequence,
897            early_exit: false,
898            resp: resp_tx,
899        };
900
901        if let Err(e) = self.match_tx.send(req).await {
902            tracing::error!(
903                "Failed to send match request: {:?}; the indexer maybe offline",
904                e
905            );
906            return Err(KvRouterError::IndexerOffline);
907        }
908
909        resp_rx
910            .await
911            .map_err(|_| KvRouterError::IndexerDroppedRequest)
912    }
913
914    async fn find_matches_for_request(
915        &self,
916        tokens: &[u32],
917    ) -> Result<OverlapScores, KvRouterError> {
918        tracing::debug!(
919            "Finding matches for request tokens: {:?} / len: {}",
920            tokens,
921            tokens.len()
922        );
923        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
924        tracing::debug!("Computed sequence: {:?}", sequence);
925        self.find_matches(sequence).await
926    }
927
928    async fn apply_event(&mut self, event: RouterEvent) {
929        self.event_tx.send(event).await.unwrap();
930    }
931
932    async fn remove_worker(&mut self, worker: WorkerId) {
933        self.remove_worker_tx.send(worker).await.unwrap();
934    }
935
936    fn shutdown(&mut self) {
937        self.cancel.cancel();
938        if let Some(task) = self.task.take() {
939            task.join().expect("Failed to join kv indexer task");
940        }
941    }
942
943    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
944        let (resp_tx, resp_rx) = oneshot::channel();
945        let dump_req = DumpRequest { resp: resp_tx };
946
947        if let Err(e) = self.dump_tx.send(dump_req).await {
948            tracing::error!("Failed to send dump request: {:?}", e);
949            return Err(KvRouterError::IndexerOffline);
950        }
951
952        resp_rx
953            .await
954            .map_err(|_| KvRouterError::IndexerDroppedRequest)
955    }
956}
957
/// A match request broadcast to every shard of a `KvIndexerSharded`.
///
/// Each shard replies on `resp` with the scores for the workers it owns, and the caller
/// merges the per-shard replies. `Clone` is required because `tokio::sync::broadcast`
/// clones the value for each receiving shard.
#[derive(Debug, Clone)]
pub struct ShardedMatchRequest {
    /// Block hashes of the sequence to match.
    sequence: Vec<LocalBlockHash>,
    /// Passed through as the `early_exit` argument of `RadixTree::find_matches`.
    early_exit: bool,
    /// Channel on which every shard sends its partial `OverlapScores`.
    resp: mpsc::Sender<OverlapScores>,
}
964
/// A sharded KV Indexer that partitions the RadixTree across multiple independent shards.
///
/// ## Sharding Strategy
/// - Each worker is **permanently assigned** to a single shard on first event
///   (until it is removed via `remove_worker`)
/// - All KV blocks from a worker exist only in that worker's assigned shard
/// - New workers are assigned to the shard with the fewest workers (load balancing)
///
/// ## Operation
/// - **Events**: Routed directly to the worker's assigned shard
/// - **Match requests**: Broadcast to all shards (scatter-gather pattern)
/// - **Threading**: Each shard runs in its own OS thread, which drives a Tokio runtime
///   (built with one worker thread) and runs the shard loop on a `LocalSet`
///
/// This design ensures no cross-shard synchronization for writes while enabling
/// parallel processing and better scalability.
pub struct KvIndexerSharded {
    /// A `CancellationToken` for managing shutdown.
    cancel: CancellationToken,
    /// The size of the KV block this indexer can handle.
    kv_block_size: u32,
    /// Shard index each known worker is assigned to.
    worker_assignments: HashMap<WorkerId, usize>,
    /// Number of workers currently assigned to each shard (indexed by shard).
    worker_counts: Vec<usize>,

    /// Per-shard senders for `RouterEvent`s (indexed by shard).
    event_tx: Vec<mpsc::Sender<RouterEvent>>,
    /// Broadcasts every match request to all shards.
    request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
    /// Per-shard senders for worker-removal requests (indexed by shard).
    remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
    /// Per-shard senders for dump (snapshot) requests (indexed by shard).
    dump_tx: Vec<mpsc::Sender<DumpRequest>>,
    /// Join handles of the per-shard threads.
    tasks: Vec<JoinHandle<()>>,
}
993
994impl KvIndexerSharded {
995    /// Create a new `KvIndexerSharded`.
996    ///
997    /// ### Arguments
998    ///
999    /// * `token` - A `CancellationToken` for managing shutdown.
1000    /// * `shards` - A list of kvindexer shards.
1001    /// * `expiration_duration` - The amount of time that block usage should be buffered.
1002    ///
1003    /// ### Returns
1004    ///
1005    /// A new `KvIndexer`.
1006    pub fn new_with_frequency(
1007        token: CancellationToken,
1008        num_shards: usize,
1009        expiration_duration: Option<Duration>,
1010        kv_block_size: u32,
1011        metrics: Arc<KvIndexerMetrics>,
1012    ) -> Self {
1013        let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
1014        let worker_counts: Vec<usize> = vec![0; num_shards];
1015
1016        let mut event_tx = Vec::new();
1017        let mut remove_worker_tx = Vec::new();
1018        let mut dump_tx = Vec::new(); // Add dump channels
1019        let mut tasks = Vec::new();
1020
1021        let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);
1022
1023        for _ in 0..num_shards {
1024            let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
1025            let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
1026                mpsc::channel::<WorkerId>(16);
1027            let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16); // Add dump channel
1028            let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
1029            let cancel = token.clone();
1030            let metrics = metrics.clone();
1031
1032            event_tx.push(shard_event_tx);
1033            remove_worker_tx.push(shard_remove_worker_tx);
1034            dump_tx.push(shard_dump_tx); // Store dump sender
1035
1036            let runtime = tokio::runtime::Builder::new_multi_thread()
1037                .worker_threads(1)
1038                .enable_all()
1039                .build()
1040                .unwrap();
1041
1042            tasks.push(std::thread::spawn(move || {
1043                let local_set = tokio::task::LocalSet::new();
1044
1045                runtime.block_on(local_set.run_until(async move {
1046                    tokio::task::spawn_local(async move {
1047                        let mut trie = RadixTree::new_with_frequency(expiration_duration);
1048                        loop {
1049                            tokio::select! {
1050                                biased;
1051
1052                                _ = cancel.cancelled() => {
1053                                    tracing::trace!("KvCacheIndexer progress loop shutting down");
1054                                    return;
1055                                }
1056
1057                                Some(worker) = shard_remove_worker_rx.recv() => {
1058                                    trie.remove_worker(worker);
1059                                }
1060
1061                                Some(event) = shard_event_rx.recv() => {
1062                                    let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
1063                                    let result = trie.apply_event(event);
1064                                    metrics.increment_event_applied(event_type, result);
1065                                }
1066
1067                                Some(dump_req) = shard_dump_rx.recv() => {
1068                                    let events = trie.dump_tree_as_events();
1069                                    let _ = dump_req.resp.send(events);
1070                                }
1071
1072                                Ok(req) = shard_broadcast_rx.recv() => {
1073                                    let matches = trie.find_matches(req.sequence, req.early_exit);
1074                                    if let Err(e) = req.resp.send(matches).await {
1075                                        tracing::trace!("Failed to send match response: {:?}", e);
1076                                    }
1077                                }
1078                            }
1079                        }
1080                    })
1081                    .await
1082                    .unwrap()
1083                }));
1084
1085                tracing::debug!("KvCacheIndexer task completed");
1086            }));
1087        }
1088
1089        Self {
1090            cancel: token,
1091            kv_block_size,
1092            worker_assignments,
1093            worker_counts,
1094            event_tx,
1095            request_broadcast_tx,
1096            remove_worker_tx,
1097            dump_tx, // Add dump_tx field
1098            tasks,
1099        }
1100    }
1101
1102    pub fn block_size(&self) -> u32 {
1103        self.kv_block_size
1104    }
1105
1106    pub fn new(
1107        token: CancellationToken,
1108        num_shards: usize,
1109        kv_block_size: u32,
1110        metrics: Arc<KvIndexerMetrics>,
1111    ) -> Self {
1112        Self::new_with_frequency(token, num_shards, None, kv_block_size, metrics)
1113    }
1114}
1115
1116#[async_trait]
1117impl KvIndexerInterface for KvIndexerSharded {
1118    async fn find_matches(
1119        &self,
1120        sequence: Vec<LocalBlockHash>,
1121    ) -> Result<OverlapScores, KvRouterError> {
1122        'match_loop: loop {
1123            let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len());
1124            self.request_broadcast_tx
1125                .send(ShardedMatchRequest {
1126                    sequence: sequence.clone(),
1127                    early_exit: false,
1128                    resp: match_tx,
1129                })
1130                .map_err(|_| KvRouterError::IndexerOffline)?;
1131
1132            let mut scores = OverlapScores::new();
1133
1134            for response_num in 0..self.event_tx.len() {
1135                match match_rx.recv().await {
1136                    Some(response) => {
1137                        scores.scores.extend(response.scores);
1138
1139                        if response_num == 0 {
1140                            scores.frequencies = response.frequencies;
1141                        } else {
1142                            let diff = (response.frequencies.len() as i64)
1143                                - (scores.frequencies.len() as i64);
1144
1145                            if diff > 0 {
1146                                scores.frequencies.extend(iter::repeat_n(0, diff as usize));
1147                            }
1148
1149                            for i in 0..response.frequencies.len() {
1150                                scores.frequencies[i] += response.frequencies[i];
1151                            }
1152                        }
1153                    }
1154                    None => {
1155                        // This can only happen if the broadcast channel overflows.
1156                        // In this case, we don't want to recursively call find_matches again. Otherwise, we could overflow the stack.
1157                        continue 'match_loop;
1158                    }
1159                }
1160            }
1161            return Ok(scores);
1162        }
1163    }
1164
1165    async fn find_matches_for_request(
1166        &self,
1167        tokens: &[u32],
1168    ) -> Result<OverlapScores, KvRouterError> {
1169        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
1170        self.find_matches(sequence).await
1171    }
1172
1173    async fn apply_event(&mut self, event: RouterEvent) {
1174        #[allow(clippy::map_entry)]
1175        if !self.worker_assignments.contains_key(&event.worker_id) {
1176            // Get the shard with the smallest amount of workers.
1177            let selected_shard = self
1178                .worker_counts
1179                .iter()
1180                .enumerate()
1181                .min_by_key(|&(_, value)| value)
1182                .unwrap()
1183                .0;
1184
1185            self.worker_assignments
1186                .insert(event.worker_id, selected_shard);
1187            self.worker_counts[selected_shard] += 1;
1188        }
1189
1190        self.event_tx[self.worker_assignments[&event.worker_id]]
1191            .send(event)
1192            .await
1193            .unwrap();
1194    }
1195
1196    async fn remove_worker(&mut self, worker: WorkerId) {
1197        if let Some((_, shard)) = self.worker_assignments.remove_entry(&worker) {
1198            self.worker_counts[shard] -= 1;
1199            self.remove_worker_tx[shard].send(worker).await.unwrap();
1200        }
1201    }
1202
1203    /// Shutdown the KV Indexer.
1204    fn shutdown(&mut self) {
1205        self.cancel.cancel();
1206        while !self.tasks.is_empty() {
1207            self.tasks.pop().unwrap().join().unwrap();
1208        }
1209    }
1210
1211    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
1212        let mut all_events = Vec::new();
1213
1214        // Create channels for each shard
1215        let mut receivers = Vec::new();
1216
1217        for shard_dump_tx in &self.dump_tx {
1218            let (resp_tx, resp_rx) = oneshot::channel();
1219            let dump_req = DumpRequest { resp: resp_tx };
1220
1221            if let Err(e) = shard_dump_tx.send(dump_req).await {
1222                tracing::error!("Failed to send dump request to shard: {:?}", e);
1223                return Err(KvRouterError::IndexerOffline);
1224            }
1225
1226            receivers.push(resp_rx);
1227        }
1228
1229        // Collect results from all shards
1230        for resp_rx in receivers {
1231            match resp_rx.await {
1232                Ok(events) => all_events.extend(events),
1233                Err(_) => return Err(KvRouterError::IndexerDroppedRequest),
1234            }
1235        }
1236
1237        Ok(all_events)
1238    }
1239}
1240
1241#[cfg(test)]
1242mod tests {
1243
1244    use super::*;
1245    use rstest::rstest;
1246    use rstest_reuse::{self, *};
1247    use tokio::time;
1248    use tokio_util::sync::CancellationToken;
1249
    /// Common test prelude: initializes `dynamo_runtime` logging.
    fn setup() {
        dynamo_runtime::logging::init();
    }
1253
1254    fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
1255        hashes
1256            .iter()
1257            .map(|i| KvCacheStoredBlockData {
1258                tokens_hash: LocalBlockHash(*i),
1259                block_hash: ExternalSequenceBlockHash(*i * 100),
1260            })
1261            .collect()
1262    }
1263
1264    fn add_blocks(
1265        hashes: Vec<u64>,
1266        parent_hash: Option<ExternalSequenceBlockHash>,
1267    ) -> KvCacheEventData {
1268        KvCacheEventData::Stored(KvCacheStoreData {
1269            parent_hash,
1270            blocks: make_blocks(hashes),
1271        })
1272    }
1273
1274    fn create_store_event(
1275        worker_id: WorkerId,
1276        event_id: u64,
1277        hashes: Vec<u64>,
1278        parent: Option<ExternalSequenceBlockHash>,
1279    ) -> RouterEvent {
1280        RouterEvent {
1281            worker_id,
1282            event: KvCacheEvent {
1283                event_id,
1284                data: add_blocks(hashes, parent),
1285            },
1286        }
1287    }
1288
1289    fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
1290        RouterEvent {
1291            worker_id,
1292            event: KvCacheEvent {
1293                event_id,
1294                data: KvCacheEventData::Removed(KvCacheRemoveData {
1295                    block_hashes: hashes
1296                        .iter()
1297                        .map(|i| ExternalSequenceBlockHash(*i * 100))
1298                        .collect(),
1299                }),
1300            },
1301        }
1302    }
1303
1304    #[test]
1305    fn test_radix_tree() {
1306        setup();
1307
1308        let mut trie = RadixTree::new();
1309
1310        let worker_1 = 0;
1311        let worker_2 = 1;
1312
1313        trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
1314            .unwrap();
1315
1316        let scores = trie.find_matches(
1317            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
1318            false,
1319        );
1320        assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
1321
1322        assert_eq!(trie.lookup.len(), 1);
1323        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
1324        assert_eq!(trie.root.borrow().workers.len(), 0);
1325        assert_eq!(trie.root.borrow().children.len(), 1);
1326        assert_eq!(
1327            trie.root
1328                .borrow()
1329                .children
1330                .get(&LocalBlockHash(1))
1331                .unwrap()
1332                .borrow()
1333                .workers
1334                .len(),
1335            1
1336        );
1337        assert_eq!(
1338            trie.root
1339                .borrow()
1340                .children
1341                .get(&LocalBlockHash(1))
1342                .unwrap()
1343                .borrow()
1344                .children
1345                .len(),
1346            1
1347        );
1348
1349        trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
1350            .unwrap();
1351
1352        let scores = trie.find_matches(
1353            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
1354            false,
1355        );
1356        assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
1357        assert_eq!(scores.scores.get(&worker_2).unwrap(), &1);
1358
1359        assert_eq!(trie.lookup.len(), 2);
1360        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
1361        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 3);
1362        assert_eq!(trie.root.borrow().workers.len(), 0);
1363        assert_eq!(trie.root.borrow().children.len(), 1);
1364        assert_eq!(
1365            trie.root
1366                .borrow()
1367                .children
1368                .get(&LocalBlockHash(1))
1369                .unwrap()
1370                .borrow()
1371                .workers
1372                .len(),
1373            2
1374        );
1375        assert_eq!(
1376            trie.root
1377                .borrow()
1378                .children
1379                .get(&LocalBlockHash(1))
1380                .unwrap()
1381                .borrow()
1382                .children
1383                .len(),
1384            2
1385        );
1386
1387        trie.apply_event(create_remove_event(worker_2, 2, vec![5]))
1388            .unwrap();
1389        assert_eq!(trie.lookup.len(), 2);
1390        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
1391        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 2);
1392        assert_eq!(trie.root.borrow().workers.len(), 0);
1393        assert_eq!(trie.root.borrow().children.len(), 1);
1394        assert_eq!(
1395            trie.root
1396                .borrow()
1397                .children
1398                .get(&LocalBlockHash(1))
1399                .unwrap()
1400                .borrow()
1401                .workers
1402                .len(),
1403            2
1404        );
1405        assert_eq!(
1406            trie.root
1407                .borrow()
1408                .children
1409                .get(&LocalBlockHash(1))
1410                .unwrap()
1411                .borrow()
1412                .children
1413                .len(),
1414            2
1415        );
1416
1417        trie.apply_event(create_remove_event(worker_2, 3, vec![4]))
1418            .unwrap();
1419
1420        assert_eq!(trie.lookup.len(), 2);
1421        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
1422        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 1);
1423        assert_eq!(trie.root.borrow().workers.len(), 0);
1424        assert_eq!(trie.root.borrow().children.len(), 1);
1425        assert_eq!(
1426            trie.root
1427                .borrow()
1428                .children
1429                .get(&LocalBlockHash(1))
1430                .unwrap()
1431                .borrow()
1432                .workers
1433                .len(),
1434            2
1435        );
1436        assert_eq!(
1437            trie.root
1438                .borrow()
1439                .children
1440                .get(&LocalBlockHash(1))
1441                .unwrap()
1442                .borrow()
1443                .children
1444                .len(),
1445            2
1446        );
1447
1448        trie.apply_event(create_store_event(
1449            worker_2,
1450            4,
1451            vec![2, 6, 7],
1452            Some(ExternalSequenceBlockHash(100)),
1453        ))
1454        .unwrap();
1455
1456        let scores = trie.find_matches(
1457            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
1458            false,
1459        );
1460        assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
1461        assert_eq!(scores.scores.get(&worker_2).unwrap(), &2);
1462
1463        assert_eq!(trie.lookup.len(), 2);
1464        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
1465        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 4);
1466        assert_eq!(trie.root.borrow().workers.len(), 0);
1467        assert_eq!(trie.root.borrow().children.len(), 1);
1468        assert_eq!(
1469            trie.root
1470                .borrow()
1471                .children
1472                .get(&LocalBlockHash(1))
1473                .unwrap()
1474                .borrow()
1475                .workers
1476                .len(),
1477            2
1478        );
1479        assert_eq!(
1480            trie.root
1481                .borrow()
1482                .children
1483                .get(&LocalBlockHash(1))
1484                .unwrap()
1485                .borrow()
1486                .children
1487                .len(),
1488            2
1489        );
1490        assert_eq!(
1491            trie.lookup
1492                .get(&worker_1)
1493                .unwrap()
1494                .get(&ExternalSequenceBlockHash(200))
1495                .unwrap()
1496                .borrow()
1497                .workers
1498                .len(),
1499            2
1500        );
1501        assert_eq!(
1502            trie.lookup
1503                .get(&worker_2)
1504                .unwrap()
1505                .get(&ExternalSequenceBlockHash(200))
1506                .unwrap()
1507                .borrow()
1508                .workers
1509                .len(),
1510            2
1511        );
1512    }
1513
1514    #[test]
1515    fn test_radix_tree_apply_event_errors() {
1516        let mut trie = RadixTree::new();
1517        let worker_0 = 0;
1518
1519        // Parent block not found
1520        let result = trie.apply_event(create_store_event(
1521            worker_0,
1522            0,
1523            vec![1, 2, 3],
1524            Some(ExternalSequenceBlockHash(12345)),
1525        ));
1526        assert!(result.is_err());
1527        assert!(matches!(
1528            result.unwrap_err(),
1529            KvCacheEventError::ParentBlockNotFound
1530        ));
1531
1532        // Block not found for remove event.
1533        let result = trie.apply_event(create_remove_event(worker_0, 0, vec![1, 2, 3]));
1534        assert!(result.is_err());
1535        assert!(matches!(
1536            result.unwrap_err(),
1537            KvCacheEventError::BlockNotFound
1538        ));
1539    }
1540
1541    #[test]
1542    fn test_remove_worker() {
1543        setup();
1544        let mut trie = RadixTree::new();
1545
1546        let worker_0 = 0;
1547        let worker_1 = 1;
1548
1549        assert!(
1550            trie.find_matches(vec![LocalBlockHash(0)], false)
1551                .scores
1552                .is_empty()
1553        );
1554
1555        trie.apply_event(create_store_event(worker_0, 0, vec![0], None))
1556            .unwrap();
1557        trie.apply_event(create_store_event(worker_1, 0, vec![0], None))
1558            .unwrap();
1559
1560        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
1561        assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1);
1562
1563        trie.remove_worker(worker_0);
1564
1565        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
1566        assert!(result.len() == 1 && result[&worker_1] == 1);
1567    }
1568
1569    #[test]
1570    fn test_clear_all_blocks() {
1571        let mut trie = RadixTree::new();
1572
1573        let worker_0 = 0;
1574        let worker_1 = 1;
1575
1576        assert!(
1577            trie.find_matches(vec![LocalBlockHash(0)], false)
1578                .scores
1579                .is_empty()
1580        );
1581
1582        // Test clearing an empty worker
1583        trie.clear_all_blocks(worker_0);
1584        assert!(!trie.lookup.contains_key(&worker_0));
1585
1586        // Test clearing a worker with shared blocks
1587        trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None))
1588            .unwrap();
1589        trie.apply_event(create_store_event(worker_1, 0, vec![0, 2, 3], None))
1590            .unwrap();
1591
1592        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
1593        assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1);
1594
1595        trie.clear_all_blocks(worker_0);
1596
1597        assert!(trie.lookup.contains_key(&worker_0));
1598        assert!(trie.lookup.get(&worker_0).unwrap().is_empty());
1599        let result = trie
1600            .find_matches(vec![LocalBlockHash(0), LocalBlockHash(2)], false)
1601            .scores;
1602        assert_eq!(result.len(), 1);
1603        assert_eq!(result[&worker_1], 2);
1604        let result = trie
1605            .find_matches(
1606                vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(3)],
1607                false,
1608            )
1609            .scores;
1610        assert_eq!(result.len(), 1);
1611        assert_eq!(result[&worker_1], 1);
1612
1613        // Test re-adding blocks after clearing worker
1614        trie.apply_event(create_store_event(worker_0, 0, vec![4, 5], None))
1615            .unwrap();
1616        let result = trie
1617            .find_matches(vec![LocalBlockHash(4), LocalBlockHash(5)], false)
1618            .scores;
1619        assert_eq!(result.len(), 1);
1620        assert_eq!(result[&worker_0], 2);
1621
1622        // Test multiple clears
1623        trie.clear_all_blocks(worker_0);
1624        trie.clear_all_blocks(worker_0);
1625        assert!(trie.lookup.contains_key(&worker_0));
1626
1627        // Test clearing all workers
1628        trie.clear_all_blocks(worker_0);
1629        trie.clear_all_blocks(worker_1);
1630        assert!(!trie.lookup.is_empty());
1631        assert!(trie.lookup.get(&worker_0).unwrap().is_empty());
1632        assert!(trie.lookup.get(&worker_1).unwrap().is_empty());
1633
1634        // Test clearing a worker that has been removed
1635        trie.apply_event(create_store_event(worker_0, 0, vec![6], None))
1636            .unwrap();
1637        trie.apply_event(create_store_event(worker_1, 0, vec![6], None))
1638            .unwrap();
1639        trie.remove_worker(worker_0);
1640        trie.clear_all_blocks(worker_0);
1641        assert!(!trie.lookup.contains_key(&worker_0));
1642        let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
1643        assert_eq!(result.len(), 1);
1644        assert_eq!(result[&worker_1], 1);
1645
1646        // Test clearing a worker that doesn't exist
1647        let worker_fake = 2;
1648        assert!(!trie.lookup.contains_key(&worker_fake));
1649        trie.clear_all_blocks(worker_fake);
1650        assert!(!trie.lookup.contains_key(&worker_fake));
1651        assert!(trie.lookup.contains_key(&worker_1));
1652        let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
1653        assert_eq!(result.len(), 1);
1654        assert_eq!(result[&worker_1], 1);
1655    }
1656
1657    #[test]
1658    fn test_early_stopping() {
1659        setup();
1660        let mut trie = RadixTree::new();
1661
1662        let worker_0 = 0;
1663        let worker_1 = 1;
1664
1665        trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 2], None))
1666            .unwrap();
1667        trie.apply_event(create_store_event(worker_1, 0, vec![0], None))
1668            .unwrap();
1669
1670        let result = trie
1671            .find_matches(
1672                vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)],
1673                true,
1674            )
1675            .scores;
1676
1677        assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);
1678
1679        let result = trie
1680            .find_matches(vec![LocalBlockHash(0), LocalBlockHash(1)], true)
1681            .scores;
1682        assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);
1683    }
1684
    #[rstest]
    #[case(11)]
    #[case(32)]
    #[case(64)]
    fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
        setup();
        // Exactly one block's worth of tokens yields one hash.
        let sequence = (0..kv_block_size).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 1);

        // One full block plus one extra token: the trailing partial block is not hashed.
        let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 1);

        // Two full blocks plus one extra token: two hashes.
        let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 2);
    }
1706
1707    fn make_indexer(
1708        token: &CancellationToken,
1709        num_shards: usize,
1710        kv_block_size: u32,
1711    ) -> Box<dyn KvIndexerInterface> {
1712        let metrics = KvIndexerMetrics::new_unregistered();
1713        if num_shards == 1 {
1714            Box::new(KvIndexer::new(token.clone(), kv_block_size, metrics.into()))
1715        } else {
1716            Box::new(KvIndexerSharded::new(
1717                token.clone(),
1718                num_shards,
1719                kv_block_size,
1720                metrics.into(),
1721            ))
1722        }
1723    }
1724
1725    #[template]
1726    #[rstest]
1727    fn indexer_template(
1728        #[values(1, 3, 8)] num_shards: usize,
1729        #[values(11, 32, 64)] kv_block_size: usize,
1730    ) {
1731    }
1732
1733    #[tokio::test]
1734    #[apply(indexer_template)]
1735    async fn test_kv_indexer_new(num_shards: usize, kv_block_size: u32) {
1736        setup();
1737        let token: CancellationToken = CancellationToken::new();
1738        let _ = make_indexer(&token, num_shards, kv_block_size);
1739    }
1740
1741    #[tokio::test]
1742    #[apply(indexer_template)]
1743    async fn test_find_matches(num_shards: usize, kv_block_size: u32) {
1744        setup();
1745        let token = CancellationToken::new();
1746        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
1747
1748        let sequence = vec![compute_block_hash(b"test data")];
1749        let scores = kv_indexer.find_matches(sequence).await;
1750
1751        assert!(scores.unwrap().scores.is_empty());
1752    }
1753
1754    #[tokio::test]
1755    #[apply(indexer_template)]
1756    async fn test_find_matches_for_request(num_shards: usize, kv_block_size: u32) {
1757        setup();
1758        let token = CancellationToken::new();
1759        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
1760
1761        let tokens = vec![1, 2, 3, 4];
1762        let scores = kv_indexer.find_matches_for_request(&tokens).await;
1763
1764        assert!(scores.unwrap().scores.is_empty());
1765    }
1766
1767    #[tokio::test]
1768    #[apply(indexer_template)]
1769    async fn test_apply_event(num_shards: usize, kv_block_size: u32) {
1770        setup();
1771        let worker_id = 0;
1772
1773        let token = CancellationToken::new();
1774        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
1775
1776        let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
1777        kv_indexer.apply_event(event).await;
1778
1779        // No assertion here, just ensuring it runs without panic
1780    }
1781
1782    #[tokio::test]
1783    #[apply(indexer_template)]
1784    async fn test_shutdown(num_shards: usize, kv_block_size: u32) {
1785        setup();
1786        let token = CancellationToken::new();
1787        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
1788
1789        kv_indexer.shutdown();
1790    }
1791
1792    #[tokio::test]
1793    #[apply(indexer_template)]
1794    async fn test_frequency(num_shards: usize, kv_block_size: u32) {
1795        const ONE_MILLIS: Duration = Duration::from_millis(1);
1796
1797        setup();
1798        let mut kv_indexer: Box<dyn KvIndexerInterface>;
1799        let token = CancellationToken::new();
1800        let expiration = Duration::from_millis(50);
1801        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
1802
1803        if num_shards == 1 {
1804            kv_indexer = Box::new(KvIndexer::new_with_frequency(
1805                token,
1806                Some(expiration),
1807                kv_block_size,
1808                metrics,
1809            ));
1810        } else {
1811            kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
1812                token,
1813                num_shards,
1814                Some(expiration),
1815                kv_block_size,
1816                metrics,
1817            ));
1818        }
1819
1820        // The blocks
1821        let block_hashes = vec![
1822            LocalBlockHash(1),
1823            LocalBlockHash(2),
1824            LocalBlockHash(3),
1825            LocalBlockHash(4),
1826        ];
1827
1828        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
1829        assert_eq!(
1830            overlap.frequencies.len(),
1831            0,
1832            "Should be no cached blocks yet"
1833        );
1834
1835        // Blocks go in cache
1836        let worker_id = 0;
1837        let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
1838        kv_indexer.apply_event(event).await;
1839
1840        // First access
1841        // The store event is applied async so poll briefly
1842        let mut overlap = OverlapScores::default();
1843        let timeout = Duration::from_millis(10);
1844        let start = Instant::now();
1845        while overlap.scores.is_empty() && Instant::now().duration_since(start) < timeout {
1846            time::sleep(ONE_MILLIS).await;
1847            overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
1848        }
1849        assert_eq!(
1850            overlap.scores.len(),
1851            1,
1852            "One worker has these blocks cached"
1853        );
1854        assert_eq!(
1855            overlap.frequencies.len(),
1856            0,
1857            "Blocks have not previously been accessed"
1858        );
1859
1860        // Second access
1861        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
1862        assert_eq!(overlap.scores.len(), 1, "Still one worker matches");
1863        assert_eq!(
1864            overlap.frequencies,
1865            vec![1, 1, 1, 1],
1866            "We should see the first access now"
1867        );
1868
1869        // Let those two accesses expire
1870        time::sleep(expiration + Duration::from_millis(10)).await;
1871
1872        // New first access
1873        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
1874        assert_eq!(
1875            overlap.frequencies.len(),
1876            0,
1877            "Blocks were accessed too long ago"
1878        );
1879
1880        // New second access
1881        let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
1882
1883        // Access only the first three blocks
1884        let overlap = kv_indexer
1885            .find_matches(block_hashes[0..3].to_vec())
1886            .await
1887            .unwrap();
1888        // We see the previous two new accesses
1889        assert_eq!(overlap.frequencies, vec![2, 2, 2]);
1890
1891        // The third access did not touch the last block
1892        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
1893        assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
1894    }
1895
1896    #[test]
1897    fn test_router_event_new() {
1898        setup();
1899        let worker_id = 0;
1900        let kv_cache_event = KvCacheEvent {
1901            event_id: 1,
1902            data: KvCacheEventData::Stored(KvCacheStoreData {
1903                parent_hash: None,
1904                blocks: vec![KvCacheStoredBlockData {
1905                    block_hash: ExternalSequenceBlockHash(0),
1906                    tokens_hash: LocalBlockHash(13226331709069118873),
1907                }],
1908            }),
1909        };
1910        let router_event = RouterEvent::new(worker_id, kv_cache_event);
1911
1912        assert_eq!(router_event.worker_id, worker_id);
1913        assert_eq!(router_event.event.event_id, 1);
1914        if let KvCacheEventData::Stored(store_op) = &router_event.event.data {
1915            assert_eq!(store_op.blocks.len(), 1);
1916            assert_eq!(
1917                store_op.blocks[0].tokens_hash,
1918                compute_block_hash(b"test data")
1919            );
1920            assert_eq!(store_op.blocks[0].block_hash, ExternalSequenceBlockHash(0));
1921        } else {
1922            panic!("Expected KvCacheEventData::Stored");
1923        }
1924    }
1925
1926    #[test]
1927    fn test_radix_tree_default() {
1928        setup();
1929        let radix_tree: RadixTree = Default::default();
1930        assert!(radix_tree.root.borrow().children.is_empty());
1931        assert!(radix_tree.root.borrow().workers.is_empty());
1932        assert!(radix_tree.lookup.is_empty());
1933    }
1934
1935    #[test]
1936    fn test_overlap_scores_default() {
1937        setup();
1938        let overlap_scores: OverlapScores = Default::default();
1939        assert!(overlap_scores.scores.is_empty());
1940    }
1941
1942    #[tokio::test]
1943    async fn test_dump_tree_as_events_round_trip() {
1944        setup();
1945
1946        // Configuration
1947        let kv_block_size = 32;
1948        let num_shards = 2;
1949        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
1950
1951        // Build a non-trivial indexer with events
1952        let token1 = CancellationToken::new();
1953        let mut original_indexer =
1954            KvIndexerSharded::new(token1.clone(), num_shards, kv_block_size, metrics.clone());
1955
1956        let worker_0 = 0;
1957        let worker_1 = 1;
1958        let worker_2 = 2;
1959
1960        // Apply events to the original indexer
1961        original_indexer
1962            .apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
1963            .await;
1964
1965        original_indexer
1966            .apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
1967            .await;
1968        original_indexer
1969            .apply_event(create_store_event(
1970                worker_1,
1971                2,
1972                vec![4, 5],
1973                Some(ExternalSequenceBlockHash(100)),
1974            ))
1975            .await;
1976
1977        original_indexer
1978            .apply_event(create_store_event(worker_2, 3, vec![6, 7], None))
1979            .await;
1980
1981        original_indexer
1982            .apply_event(create_store_event(
1983                worker_0,
1984                4,
1985                vec![4],
1986                Some(ExternalSequenceBlockHash(100)),
1987            ))
1988            .await;
1989
1990        // Allow some time for events to be processed
1991        tokio::time::sleep(Duration::from_millis(50)).await;
1992
1993        // Dump the original indexer
1994        let dump1 = original_indexer.dump_events().await.unwrap();
1995        println!("Dumped {} events", dump1.len());
1996
1997        // Create a new indexer and apply all dumped events
1998        let token2 = CancellationToken::new();
1999        let mut reconstructed_indexer =
2000            KvIndexerSharded::new(token2.clone(), num_shards, kv_block_size, metrics);
2001
2002        for event in &dump1 {
2003            reconstructed_indexer.apply_event(event.clone()).await;
2004        }
2005
2006        // Allow some time for events to be processed
2007        tokio::time::sleep(Duration::from_millis(50)).await;
2008
2009        // Dump the reconstructed indexer
2010        let dump2 = reconstructed_indexer.dump_events().await.unwrap();
2011
2012        // Sort both dumps for comparison (order might differ due to HashMap iteration and sharding)
2013        let mut sorted_dump1 = dump1.clone();
2014        let mut sorted_dump2 = dump2.clone();
2015
2016        // Sort by (worker_id, tokens_hash, parent_hash)
2017        let sort_key = |event: &RouterEvent| {
2018            if let KvCacheEventData::Stored(ref data) = event.event.data {
2019                (
2020                    event.worker_id,
2021                    data.blocks.first().map(|b| b.tokens_hash.0).unwrap_or(0),
2022                    data.parent_hash.map(|h| h.0).unwrap_or(0),
2023                )
2024            } else {
2025                (event.worker_id, 0, 0)
2026            }
2027        };
2028
2029        sorted_dump1.sort_by_key(sort_key);
2030        sorted_dump2.sort_by_key(sort_key);
2031
2032        // Verify the dumps have the same length
2033        assert_eq!(
2034            sorted_dump1.len(),
2035            sorted_dump2.len(),
2036            "Dumps have different lengths: {} vs {}",
2037            sorted_dump1.len(),
2038            sorted_dump2.len()
2039        );
2040
2041        // Verify each event matches
2042        for (i, (event1, event2)) in sorted_dump1.iter().zip(sorted_dump2.iter()).enumerate() {
2043            assert_eq!(
2044                event1.worker_id, event2.worker_id,
2045                "Event {} worker_id mismatch",
2046                i
2047            );
2048
2049            if let (KvCacheEventData::Stored(data1), KvCacheEventData::Stored(data2)) =
2050                (&event1.event.data, &event2.event.data)
2051            {
2052                assert_eq!(
2053                    data1.parent_hash, data2.parent_hash,
2054                    "Event {} parent_hash mismatch",
2055                    i
2056                );
2057                assert_eq!(
2058                    data1.blocks.len(),
2059                    data2.blocks.len(),
2060                    "Event {} blocks length mismatch",
2061                    i
2062                );
2063
2064                for (j, (block1, block2)) in
2065                    data1.blocks.iter().zip(data2.blocks.iter()).enumerate()
2066                {
2067                    assert_eq!(
2068                        block1.tokens_hash, block2.tokens_hash,
2069                        "Event {} block {} tokens_hash mismatch",
2070                        i, j
2071                    );
2072                    assert_eq!(
2073                        block1.block_hash, block2.block_hash,
2074                        "Event {} block {} block_hash mismatch",
2075                        i, j
2076                    );
2077                }
2078            } else {
2079                panic!("Expected Stored events in both dumps");
2080            }
2081        }
2082
2083        // Also verify that both indexers produce the same match results
2084        for test_seq in [
2085            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
2086            vec![LocalBlockHash(1), LocalBlockHash(4), LocalBlockHash(5)],
2087            vec![LocalBlockHash(6), LocalBlockHash(7)],
2088            vec![LocalBlockHash(1)],
2089        ] {
2090            let scores1 = original_indexer
2091                .find_matches(test_seq.clone())
2092                .await
2093                .unwrap();
2094            let scores2 = reconstructed_indexer
2095                .find_matches(test_seq.clone())
2096                .await
2097                .unwrap();
2098
2099            // Sort the scores to compare
2100            let mut scores1_sorted: Vec<_> = scores1.scores.iter().collect();
2101            let mut scores2_sorted: Vec<_> = scores2.scores.iter().collect();
2102            scores1_sorted.sort_by_key(|(k, _)| *k);
2103            scores2_sorted.sort_by_key(|(k, _)| *k);
2104
2105            assert_eq!(
2106                scores1_sorted, scores2_sorted,
2107                "Match scores differ for sequence {:?}",
2108                test_seq
2109            );
2110        }
2111
2112        // Clean up
2113        original_indexer.shutdown();
2114        reconstructed_indexer.shutdown();
2115    }
2116
2117    #[test]
2118    fn test_increment_event_applied() {
2119        let metrics = KvIndexerMetrics::new_unregistered();
2120
2121        metrics.increment_event_applied(METRIC_EVENT_STORED, Ok(()));
2122        assert_eq!(
2123            metrics
2124                .kv_cache_events_applied
2125                .get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_OK])
2126                .unwrap()
2127                .get(),
2128            1
2129        );
2130
2131        metrics.increment_event_applied(
2132            METRIC_EVENT_STORED,
2133            Err(KvCacheEventError::ParentBlockNotFound),
2134        );
2135        assert_eq!(
2136            metrics
2137                .kv_cache_events_applied
2138                .get_metric_with_label_values(&[
2139                    METRIC_EVENT_STORED,
2140                    METRIC_STATUS_PARENT_NOT_FOUND
2141                ])
2142                .unwrap()
2143                .get(),
2144            1
2145        );
2146
2147        metrics
2148            .increment_event_applied(METRIC_EVENT_REMOVED, Err(KvCacheEventError::BlockNotFound));
2149        assert_eq!(
2150            metrics
2151                .kv_cache_events_applied
2152                .get_metric_with_label_values(&[
2153                    METRIC_EVENT_REMOVED,
2154                    METRIC_STATUS_BLOCK_NOT_FOUND
2155                ])
2156                .unwrap()
2157                .get(),
2158            1
2159        );
2160    }
2161}