// dynamo_llm/kv_router/indexer.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! KV RadixTree
//!
//! This module implements a key-value (KV) store using a Radix Tree structure to efficiently manage and retrieve data blocks.
//! It is designed to support LLM (Large Language Model) inference by re-using a global KV cache.
//!
//! # Overview
//!
//! The main components of this module include:
//!
//! - **Radix Tree Structure**:
//!   - The `RadixTree` struct represents the main data structure, with nodes (`RadixBlock`) containing children and associated worker IDs.
//!   - It allows efficient storage and retrieval of data blocks based on their hashes.
//!
//! - **Event Handling**:
//!   - The `RouterEvent` struct represents events emitted by LLM workers, which can be applied to the Radix Tree to update its state.
//!   - The `KvIndexer` struct manages these events and match requests asynchronously using Tokio channels.
//!
//! - **Hash Computation**:
//!   - Functions like `compute_block_hash` and `compute_block_hash_for_seq` compute hashes for data blocks and sequences of tokens, facilitating quick lookups.
//!
//! - **Concurrency and Asynchronous Operations**:
//!   - The `KvIndexer` uses a single-threaded Tokio runtime to handle events and match requests concurrently, ensuring efficient processing without blocking.
//!
//! - **Match Requests**:
//!   - The `MatchRequest` struct represents requests to find matches in the Radix Tree, returning overlap scores indicating the best matches.
//!
//! # Purpose
//!
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
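//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`, so it is not compiled as a
//! doctest; crate paths and the event payload are elided):
//!
//! ```ignore
//! use tokio_util::sync::CancellationToken;
//!
//! // Build an indexer whose radix tree tracks blocks of 64 tokens.
//! let token = CancellationToken::new();
//! let indexer = KvIndexer::new(token.clone(), 64);
//!
//! // Workers publish KvCacheEvents; the router wraps them in RouterEvents
//! // and forwards them through the event channel.
//! let event_tx = indexer.event_sender();
//!
//! // Later (in an async context), score a request's token prefix against
//! // the blocks the workers have reported.
//! let scores = indexer.find_matches_for_request(&[1, 2, 3, 4]).await?;
//! println!("overlap scores: {:?}", scores.scores);
//! ```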
use bytes::Bytes;
// use prometheus::{IntCounter, IntGauge};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::{
    cell::RefCell,
    collections::{HashMap, HashSet, VecDeque},
    iter,
    rc::Rc,
    sync::OnceLock,
    thread::JoinHandle,
    time::{Duration, Instant},
};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing as log;
use xxhash_rust::xxh3;

pub const XXH3_SEED: u64 = 1337;

use crate::kv_router::protocols::*;

/// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)]
pub enum KvRouterError {
    #[error("Block not found")]
    BlockNotFound,

    #[error("Indexer is offline")]
    IndexerOffline,

    #[error("Indexer dropped the request")]
    IndexerDroppedRequest,
}

/// Identifier of an LLM worker which emits events to the router.
pub type WorkerId = i64;

/// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>;

pub fn compute_hash(data: &[u8]) -> u64 {
    xxh3::xxh3_64_with_seed(data, XXH3_SEED)
}

/// Compute the hash of a local block.
///
/// ### Arguments
///
/// * `data` - A byte slice representing the data to hash.
///
/// ### Returns
///
/// A `LocalBlockHash` representing the computed hash.
pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
    LocalBlockHash(compute_hash(data))
}

// /// Updated version of the `compute_block_hash` function that includes the lora_id.
// pub fn compute_block_hash_v2(token_id: &[u32], lora_id: u64) {
//     let mut bytes = Vec::new();
//     for token in token_id {
//         bytes.extend_from_slice(&token.to_le_bytes());
//     }
//     bytes.extend_from_slice(&lora_id.to_le_bytes());
//     let hash = xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED);
// }

/// Compute the block hashes for a sequence of tokens.
///
/// ### Arguments
///
/// * `tokens` - A slice of `u32` tokens.
/// * `kv_block_size` - The number of tokens per KV block.
///
/// ### Returns
///
/// A vector of `LocalBlockHash`, one per complete chunk of `kv_block_size`
/// tokens; any trailing partial chunk is ignored.
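///
/// A minimal sketch of the chunking behavior (marked `ignore`, so it is not
/// compiled as a doctest):
///
/// ```ignore
/// // With a block size of 4, nine tokens yield two hashes: one for
/// // [0, 1, 2, 3] and one for [4, 5, 6, 7]; the trailing token 8 is dropped.
/// let tokens: Vec<u32> = (0..9).collect();
/// let hashes = compute_block_hash_for_seq(&tokens, 4);
/// assert_eq!(hashes.len(), 2);
/// ```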
pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec<LocalBlockHash> {
    tokens
        .chunks_exact(kv_block_size) // Split into chunks of kv_block_size elements; drop any remainder
        .map(|chunk| {
            let bytes: Vec<u8> = chunk
                .iter()
                .flat_map(|&num| num.to_le_bytes()) // Convert each u32 to its little-endian bytes
                .collect();

            compute_block_hash(&Bytes::from(bytes)) // Convert the byte Vec to Bytes
        })
        .collect()
}

/// A [`KvCacheEvent`] emitted by a specific LLM worker, identified by its [`WorkerId`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterEvent {
    /// The ID of the worker emitting the event.
    worker_id: WorkerId,
    /// The cache event associated with the worker.
    event: KvCacheEvent,
}

impl RouterEvent {
    /// Create a new `RouterEvent`.
    ///
    /// ### Arguments
    ///
    /// * `worker_id` - The ID of the worker emitting the event.
    /// * `event` - The cache event.
    ///
    /// ### Returns
    ///
    /// A new `RouterEvent`.
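    ///
    /// A minimal sketch (marked `ignore`; the `KvCacheEvent` payload is
    /// elided):
    ///
    /// ```ignore
    /// let event = RouterEvent::new(0 /* worker_id */, kv_cache_event);
    /// ```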
    pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
        Self { worker_id, event }
    }
}

/// A block in the Radix Tree.
struct RadixBlock {
    /// A map of child blocks, keyed by their local block hash.
    children: HashMap<LocalBlockHash, SharedRadixBlock>,
    /// A set of worker IDs associated with this block.
    workers: HashSet<WorkerId>,
    /// A buffer of the times at which this block was recently traversed.
    recent_uses: VecDeque<Instant>,
}

impl RadixBlock {
    /// Create a new `RadixBlock`.
    ///
    /// ### Returns
    ///
    /// A new `RadixBlock`.
    pub fn new() -> Self {
        Self {
            children: HashMap::new(),
            workers: HashSet::new(),
            recent_uses: VecDeque::new(),
        }
    }
}

pub struct RadixTree {
    /// The root of the radix/prefix tree.
    /// This will only contain root blocks.
    root: SharedRadixBlock,

    /// A global lookup table for all blocks, which lets you jump into
    /// the radix tree at any point.
    /// Lookup is O(1) on average and O(N) in the worst case; however, even a
    /// constant-time lookup can be expensive when N is large.
    /// We should monitor the size of this table and consider using a proper radix tree.
    /// Transitioning to a radix tree only would require a change in the messaging structure,
    /// as the entire prefix would need to be sent. Alternatively, we could use block_depth
    /// integers to indicate how many blocks to skip and use a radix/prefix tree at each level.
    lookup: HashMap<WorkerId, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
    /// The time window the radix tree should consider when measuring the frequency of block accesses.
    expiration_duration: Option<Duration>,
}

impl Default for RadixTree {
    fn default() -> Self {
        Self::new()
    }
}

impl RadixTree {
    /// Create a new `RadixTree`.
    ///
    /// ### Arguments
    ///
    /// * `expiration_duration` - The time window used when measuring block access
    ///   frequency; `None` disables frequency tracking.
    ///
    /// ### Returns
    ///
    /// A new `RadixTree`.
    pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self {
        Self {
            root: Rc::new(RefCell::new(RadixBlock::new())),
            lookup: HashMap::new(),
            expiration_duration,
        }
    }

    pub fn new() -> Self {
        Self::new_with_frequency(None)
    }

    /// Traverse the radix tree to find the best match for a given sequence of [`LocalBlockHash`]es.
    ///
    /// ### Arguments
    ///
    /// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
    /// * `early_exit` - A boolean indicating whether to stop early once only a single worker matches.
    ///
    /// ### Returns
    ///
    /// An `OverlapScores` representing the match scores.
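    ///
    /// A minimal usage sketch (marked `ignore`, so it is not compiled as a
    /// doctest; event construction is elided):
    ///
    /// ```ignore
    /// let mut trie = RadixTree::new();
    /// // ... apply RouterEvents so that worker 0 stores blocks [1, 2, 3] ...
    /// let scores = trie.find_matches(
    ///     vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
    ///     false,
    /// );
    /// // Worker 0 matched all three blocks along the path.
    /// assert_eq!(scores.scores.get(&0), Some(&3));
    /// ```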
    pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
        let mut scores = OverlapScores::new();
        let mut current = self.root.clone();
        let now = Instant::now();
        for block_hash in sequence {
            let next_block = {
                let current_borrow = current.borrow();
                current_borrow.children.get(&block_hash).cloned()
            };

            if let Some(block) = next_block {
                scores.update_scores(&block.borrow().workers);

                if let Some(expiration_duration) = self.expiration_duration {
                    let mut block_mut = block.borrow_mut();

                    // Evict access times that have fallen outside the expiration window.
                    while let Some(access_time) = block_mut.recent_uses.front() {
                        if now.duration_since(*access_time) > expiration_duration {
                            block_mut.recent_uses.pop_front();
                        } else {
                            break;
                        }
                    }
                    scores.add_frequency(block_mut.recent_uses.len());
                    block_mut.recent_uses.push_back(now);
                }

                // Stop early once only a single worker holds the prefix.
                if early_exit && block.borrow().workers.len() == 1 {
                    break;
                }

                current = block;
            } else {
                break;
            }
        }

        scores
    }

    /// Apply a [`RouterEvent`] to the radix tree.
    ///
    /// ### Arguments
    ///
    /// * `event` - The `RouterEvent` to apply.
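    ///
    /// A minimal sketch (marked `ignore`; `create_store_event` mirrors the
    /// helper defined in the tests below and is not part of the public API):
    ///
    /// ```ignore
    /// let mut trie = RadixTree::new();
    /// // Worker 0 reports that it stored a root-level block whose tokens hash to 1.
    /// trie.apply_event(create_store_event(0, /* event_id */ 1, vec![1], None));
    /// assert_eq!(trie.find_matches(vec![LocalBlockHash(1)], false).scores[&0], 1);
    /// ```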
    pub fn apply_event(&mut self, event: RouterEvent) {
        let (worker_id, event) = (event.worker_id, event.event);
        let (id, op) = (event.event_id, event.data);
        log::debug!(id, "KV operation: {:?}", op);

        let worker_lookup = self.lookup.entry(worker_id).or_default();

        match op {
            KvCacheEventData::Stored(op) => {
                // find the parent block - if the event names a parent, it must already be
                // registered for this worker; otherwise the blocks attach to the tree's root.
                // this is the single most expensive lookup
                let current = match op.parent_hash {
                    Some(parent) => worker_lookup.get(&parent),
                    None => Some(&self.root),
                };

                let mut current = match current {
                    Some(current) => current.clone(),
                    None => {
                        log::warn!(
                            worker_id = worker_id.to_string(),
                            id,
                            parent_hash = ?op.parent_hash,
                            "Failed to find parent block; skipping store operation"
                        );
                        return;
                    }
                };

                for block_id in op.blocks {
                    let mut inner = current.borrow_mut();
                    let block = match inner.children.get(&block_id.tokens_hash) {
                        Some(block) => block.clone(),
                        None => {
                            // reuse the block if this worker already knows it;
                            // otherwise create a new one
                            let new_block = worker_lookup
                                .get(&block_id.block_hash)
                                .cloned()
                                .unwrap_or_else(|| Rc::new(RefCell::new(RadixBlock::new())));

                            // insert into radix tree
                            inner
                                .children
                                .insert(block_id.tokens_hash, new_block.clone());

                            new_block
                        }
                    };

                    // add our worker_id to the block
                    block.borrow_mut().workers.insert(worker_id);

                    // add the block to the worker_id lookup table
                    worker_lookup.insert(block_id.block_hash, block.clone());

                    // drop inner so we can shift current to this block
                    drop(inner);

                    current = block;
                }
            }
            KvCacheEventData::Removed(remove) => {
                // log::trace!(id, "KV Remove Operation: {:?}", op);
                // let mut worker_lookup = self.lookup.get(&worker_id).expect("Worker not found");

                for block in remove.block_hashes {
                    // entry in radix tree
                    // a small optimization would be to get the next block from the reduced set of children;
                    // to apply this optimization, we would need to know that the list of blocks is always
                    // sorted by parent -> child relationship
                    let entry = match worker_lookup.get(&block) {
                        Some(entry) => entry.clone(),
                        None => {
                            log::warn!(
                                worker_id = worker_id.to_string(),
                                id,
                                "Failed to find block to remove; skipping remove operation"
                            );
                            continue;
                        }
                    };

                    let mut guard = entry.borrow_mut();
                    guard.workers.remove(&worker_id);
                    if guard.workers.is_empty() {
                        // if no workers are using this block, the same is true for all of its children
                        guard.children.clear();
                    }
                    // remove the block from the lookup table
                    worker_lookup.remove(&block);
                }
            }
        }
    }

    pub fn remove_worker(&mut self, worker: WorkerId) {
        if let Some((_, blocks)) = self.lookup.remove_entry(&worker) {
            blocks.iter().for_each(|(_, block)| {
                block.borrow_mut().workers.remove(&worker);
            });
        }
    }
}

/// Scores representing the overlap of workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlapScores {
    // map of worker_id to score
    pub scores: HashMap<WorkerId, u32>,
    // Number of recent accesses for each matched block along the path. Entries with value 0 are omitted.
    pub frequencies: Vec<usize>,
}

impl Default for OverlapScores {
    fn default() -> Self {
        Self::new()
    }
}

impl OverlapScores {
    /// Create a new `OverlapScores`.
    ///
    /// ### Returns
    ///
    /// A new `OverlapScores`.
    pub fn new() -> Self {
        Self {
            scores: HashMap::new(),
            frequencies: Vec::with_capacity(32),
        }
    }

    /// Update the scores with a set of workers.
    ///
    /// ### Arguments
    ///
    /// * `workers` - A reference to a `HashSet` of `WorkerId`s.
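    ///
    /// A minimal sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let mut scores = OverlapScores::new();
    /// let workers: HashSet<WorkerId> = [0, 1].into_iter().collect();
    /// // Each call credits every listed worker with one matched block.
    /// scores.update_scores(&workers);
    /// scores.update_scores(&workers);
    /// assert_eq!(scores.scores[&0], 2);
    /// ```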
    pub fn update_scores(&mut self, workers: &HashSet<WorkerId>) {
        for worker in workers {
            let score = self.scores.entry(*worker).or_insert(0);
            *score += 1;
        }
    }

    /// Append an entry to the frequency list. Zero frequencies are omitted.
    pub fn add_frequency(&mut self, frequency: usize) {
        if frequency != 0 {
            // Frequencies are recorded along a root-to-leaf traversal,
            // so they should be non-increasing.
            self.frequencies
                .last()
                .inspect(|elem| debug_assert!(**elem >= frequency));
            self.frequencies.push(frequency);
        }
    }
}

/// A request to find matches in the Radix Tree.
pub struct MatchRequest {
    /// A vector of `LocalBlockHash` representing the sequence to match.
    sequence: Vec<LocalBlockHash>,
    /// A boolean indicating whether to stop early once only a single worker matches.
    early_exit: bool,
    /// A channel sender to send the `OverlapScores` response.
    resp: oneshot::Sender<OverlapScores>,
}

#[async_trait]
pub trait KvIndexerInterface {
    /// Find matches for a given sequence of `LocalBlockHash`es.
    ///
    /// ### Arguments
    ///
    /// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
    ///
    /// ### Returns
    ///
    /// An `OverlapScores` representing the match scores.
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError>;

    /// Find matches for a given sequence of tokens.
    ///
    /// ### Arguments
    ///
    /// * `tokens` - A slice of `u32` tokens.
    ///
    /// ### Returns
    ///
    /// An `OverlapScores` representing the match scores.
    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError>;

    /// Apply a `RouterEvent` to the KV store.
    ///
    /// ### Arguments
    ///
    /// * `event` - The `RouterEvent` to apply.
    async fn apply_event(&mut self, event: RouterEvent);

    /// Remove a worker's entries from the trie.
    ///
    /// ### Arguments
    ///
    /// * `worker` - The worker to remove from the trie.
    async fn remove_worker(&mut self, worker: WorkerId);

    /// Shutdown the KV Indexer.
    fn shutdown(&mut self);
}

/// The KV Indexer, managing the KV store and handling events and match requests.
pub struct KvIndexer {
    /// A `CancellationToken` for managing shutdown.
    cancel: CancellationToken,
    /// A sender for `RouterEvent`s.
    event_tx: mpsc::Sender<RouterEvent>,
    /// A sender for `MatchRequest`s.
    match_tx: mpsc::Sender<MatchRequest>,
    /// A sender for remove worker requests.
    remove_worker_tx: mpsc::Sender<WorkerId>,
    /// A handle to the background task managing the KV store.
    task: OnceLock<std::thread::JoinHandle<()>>,
    /// The size of the KV block this indexer can handle.
    kv_block_size: usize,
}

impl KvIndexer {
    /// Create a new `KvIndexer`.
    ///
    /// ### Arguments
    ///
    /// * `token` - A `CancellationToken` for managing shutdown.
    /// * `expiration_duration` - The amount of time that block usage should be buffered.
    /// * `kv_block_size` - The number of tokens per KV block.
    ///
    /// ### Returns
    ///
    /// A new `KvIndexer`.
    pub fn new_with_frequency(
        token: CancellationToken,
        expiration_duration: Option<Duration>,
        kv_block_size: usize,
    ) -> Self {
        let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
        let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
        let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
        let cancel_clone = token.clone();
        let task = std::thread::spawn(move || {
            // create a new tokio runtime which will only perform work on a single thread
            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1) // Single-threaded environment
                .enable_all()
                .build()
                .unwrap();

            let local_set = tokio::task::LocalSet::new();

            runtime.block_on(local_set.run_until(async move {
                tokio::task::spawn_local(async move {
                    let cancel = cancel_clone;
                    let mut match_rx = match_rx;
                    let mut event_rx = event_rx;
                    let mut remove_worker_rx = remove_worker_rx;
                    let mut trie = RadixTree::new_with_frequency(expiration_duration);
                    loop {
                        tokio::select! {
                            biased;

                            Some(worker) = remove_worker_rx.recv() => {
                                trie.remove_worker(worker);
                            }

                            Some(req) = match_rx.recv() => {
                                let matches = trie.find_matches(req.sequence, req.early_exit);
                                let _ = req.resp.send(matches);
                            }

                            _ = cancel.cancelled() => {
                                log::debug!("KvCacheIndexer progress loop shutting down");
                                return;
                            }

                            Some(event) = event_rx.recv() => {
                                trie.apply_event(event);
                            }
                        }
                    }
                })
                .await
                .unwrap()
            }));

            log::debug!("KvCacheIndexer task completed");
        });

        let once = OnceLock::new();
        once.set(task).unwrap();

        Self {
            cancel: token,
            event_tx,
            match_tx,
            remove_worker_tx,
            task: once,
            kv_block_size,
        }
    }

    pub fn block_size(&self) -> usize {
        self.kv_block_size
    }

    pub fn new(token: CancellationToken, kv_block_size: usize) -> Self {
        Self::new_with_frequency(token, None, kv_block_size)
    }

    /// Get a sender for `RouterEvent`s.
    ///
    /// ### Returns
    ///
    /// An `mpsc::Sender` for `RouterEvent`s.
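    ///
    /// A minimal sketch (marked `ignore`; `worker_id` and `kv_cache_event`
    /// are placeholders):
    ///
    /// ```ignore
    /// let tx = indexer.event_sender();
    /// tx.send(RouterEvent::new(worker_id, kv_cache_event)).await.unwrap();
    /// ```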
    pub fn event_sender(&self) -> mpsc::Sender<RouterEvent> {
        self.event_tx.clone()
    }
}

#[async_trait]
impl KvIndexerInterface for KvIndexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        let (resp_tx, resp_rx) = oneshot::channel();
        let req = MatchRequest {
            sequence,
            early_exit: false,
            resp: resp_tx,
        };

        if let Err(e) = self.match_tx.send(req).await {
            log::error!(
                "Failed to send match request: {:?}; the indexer may be offline",
                e
            );
            return Err(KvRouterError::IndexerOffline);
        }

        resp_rx
            .await
            .map_err(|_| KvRouterError::IndexerDroppedRequest)
    }

    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError> {
        log::debug!(
            "Finding matches for request tokens: {:?} / len: {}",
            tokens,
            tokens.len()
        );
        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
        log::debug!("Computed sequence: {:?}", sequence);
        self.find_matches(sequence).await
    }

    async fn apply_event(&mut self, event: RouterEvent) {
        self.event_tx.send(event).await.unwrap();
    }

    async fn remove_worker(&mut self, worker: WorkerId) {
        self.remove_worker_tx.send(worker).await.unwrap();
    }

    fn shutdown(&mut self) {
        self.cancel.cancel();
        if let Some(task) = self.task.take() {
            task.join().expect("Failed to join kv indexer task");
        }
    }
}

#[derive(Debug, Clone)]
pub struct ShardedMatchRequest {
    sequence: Vec<LocalBlockHash>,
    early_exit: bool,
    resp: mpsc::Sender<OverlapScores>,
}

/// The sharded KV Indexer, distributing workers across multiple radix tree shards
/// and handling events and match requests.
pub struct KvIndexerSharded {
    /// A `CancellationToken` for managing shutdown.
    cancel: CancellationToken,
    /// The size of the KV block this indexer can handle.
    kv_block_size: usize,
    worker_assignments: HashMap<WorkerId, usize>,
    worker_counts: Vec<usize>,

    event_tx: Vec<mpsc::Sender<RouterEvent>>,
    request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
    remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
    tasks: Vec<JoinHandle<()>>,
}

impl KvIndexerSharded {
    /// Create a new `KvIndexerSharded`.
    ///
    /// ### Arguments
    ///
    /// * `token` - A `CancellationToken` for managing shutdown.
    /// * `num_shards` - The number of indexer shards to create.
    /// * `expiration_duration` - The amount of time that block usage should be buffered.
    /// * `kv_block_size` - The number of tokens per KV block.
    ///
    /// ### Returns
    ///
    /// A new `KvIndexerSharded`.
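    ///
    /// A minimal sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let token = CancellationToken::new();
    /// // Four shards, no frequency tracking, 64-token KV blocks.
    /// let indexer = KvIndexerSharded::new_with_frequency(token, 4, None, 64);
    /// ```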
    pub fn new_with_frequency(
        token: CancellationToken,
        num_shards: usize,
        expiration_duration: Option<Duration>,
        kv_block_size: usize,
    ) -> Self {
        let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
        let worker_counts: Vec<usize> = vec![0; num_shards];

        let mut event_tx = Vec::new();
        let mut remove_worker_tx = Vec::new();
        let mut tasks = Vec::new();

        let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);

        for _ in 0..num_shards {
            let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
            let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
                mpsc::channel::<WorkerId>(16);
            let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
            let cancel = token.clone();

            event_tx.push(shard_event_tx);
            remove_worker_tx.push(shard_remove_worker_tx);

            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_all()
                .build()
                .unwrap();

            tasks.push(std::thread::spawn(move || {
                let local_set = tokio::task::LocalSet::new();

                runtime.block_on(local_set.run_until(async move {
                    tokio::task::spawn_local(async move {
                        let mut trie = RadixTree::new_with_frequency(expiration_duration);
                        loop {
                            tokio::select! {
                                biased;

                                Some(worker) = shard_remove_worker_rx.recv() => {
                                    trie.remove_worker(worker);
                                }

                                Ok(req) = shard_broadcast_rx.recv() => {
                                    let matches = trie.find_matches(req.sequence, req.early_exit);
                                    if let Err(e) = req.resp.send(matches).await {
                                        log::trace!("Failed to send match response: {:?}", e);
                                    }
                                }

                                _ = cancel.cancelled() => {
                                    log::debug!("KvCacheIndexer progress loop shutting down");
                                    return;
                                }

                                Some(event) = shard_event_rx.recv() => {
                                    trie.apply_event(event);
                                }
                            }
                        }
                    })
                    .await
                    .unwrap()
                }));

                log::debug!("KvCacheIndexer task completed");
            }));
        }

        Self {
            cancel: token,
            kv_block_size,
            worker_assignments,
            worker_counts,
            event_tx,
            request_broadcast_tx,
            remove_worker_tx,
            tasks,
        }
    }

    pub fn block_size(&self) -> usize {
        self.kv_block_size
    }

    pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self {
        Self::new_with_frequency(token, num_shards, None, kv_block_size)
    }
}

#[async_trait]
impl KvIndexerInterface for KvIndexerSharded {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        'match_loop: loop {
            // Broadcast the request to every shard; each shard answers on `match_tx`.
            let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len());
            self.request_broadcast_tx
                .send(ShardedMatchRequest {
                    sequence: sequence.clone(),
                    early_exit: false,
                    resp: match_tx,
                })
                .map_err(|_| KvRouterError::IndexerOffline)?;

            let mut scores = OverlapScores::new();

            for response_num in 0..self.event_tx.len() {
                match match_rx.recv().await {
                    Some(response) => {
                        // Worker IDs are disjoint across shards, so score maps merge directly.
                        scores.scores.extend(response.scores);

                        // Frequencies are per-block, so sum them element-wise, padding
                        // with zeros when one shard matched a longer prefix.
                        if response_num == 0 {
                            scores.frequencies = response.frequencies;
                        } else {
                            let diff = (response.frequencies.len() as i64)
                                - (scores.frequencies.len() as i64);

                            if diff > 0 {
                                scores.frequencies.extend(iter::repeat_n(0, diff as usize));
                            }

                            for i in 0..response.frequencies.len() {
                                scores.frequencies[i] += response.frequencies[i];
                            }
                        }
                    }
                    None => {
                        // This can only happen if the broadcast channel overflows.
                        // In this case, we don't want to recursively call find_matches again; otherwise, we could overflow the stack.
                        continue 'match_loop;
                    }
                }
            }
            return Ok(scores);
        }
    }

    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError> {
        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
        self.find_matches(sequence).await
    }

    async fn apply_event(&mut self, event: RouterEvent) {
        #[allow(clippy::map_entry)]
        if !self.worker_assignments.contains_key(&event.worker_id) {
            // Assign the new worker to the shard with the fewest workers.
            let selected_shard = self
                .worker_counts
                .iter()
                .enumerate()
                .min_by_key(|&(_, value)| value)
                .unwrap()
                .0;

            self.worker_assignments
                .insert(event.worker_id, selected_shard);
            self.worker_counts[selected_shard] += 1;
        }

        self.event_tx[self.worker_assignments[&event.worker_id]]
            .send(event)
            .await
            .unwrap();
    }

    async fn remove_worker(&mut self, worker: WorkerId) {
        if let Some((_, shard)) = self.worker_assignments.remove_entry(&worker) {
            self.worker_counts[shard] -= 1;
            self.remove_worker_tx[shard].send(worker).await.unwrap();
        }
    }

    /// Shutdown the KV Indexer.
    fn shutdown(&mut self) {
        self.cancel.cancel();
        while !self.tasks.is_empty() {
            self.tasks.pop().unwrap().join().unwrap();
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use rstest::rstest;
    use rstest_reuse::{self, *};
    use tokio::time;
    use tokio_util::sync::CancellationToken;

    fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
        hashes
            .iter()
            .map(|i| KvCacheStoredBlockData {
                tokens_hash: LocalBlockHash(*i),
                block_hash: ExternalSequenceBlockHash(*i * 100),
            })
            .collect()
    }

    fn add_blocks(
        hashes: Vec<u64>,
        parent_hash: Option<ExternalSequenceBlockHash>,
    ) -> KvCacheEventData {
        KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash,
            blocks: make_blocks(hashes),
        })
    }

    fn create_store_event(
        worker_id: WorkerId,
        event_id: u64,
        hashes: Vec<u64>,
        parent: Option<ExternalSequenceBlockHash>,
    ) -> RouterEvent {
        RouterEvent {
            worker_id,
            event: KvCacheEvent {
                event_id,
                data: add_blocks(hashes, parent),
            },
        }
    }

    fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
        RouterEvent {
            worker_id,
            event: KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData {
                    block_hashes: hashes
                        .iter()
                        .map(|i| ExternalSequenceBlockHash(*i * 100))
                        .collect(),
                }),
            },
        }
    }

    #[test]
    fn test_radix_tree() {
        let mut trie = RadixTree::new();

        let worker_1 = 0;
        let worker_2 = 1;

        trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None));

        let scores = trie.find_matches(
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            false,
        );
        assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);

        assert_eq!(trie.lookup.len(), 1);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            1
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            1
        );

        trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None));

        let scores = trie.find_matches(
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            false,
        );
        assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
        assert_eq!(scores.scores.get(&worker_2).unwrap(), &1);

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 3);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            2
        );

        trie.apply_event(create_remove_event(worker_2, 2, vec![5]));
        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 2);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            2
        );

        trie.apply_event(create_remove_event(worker_2, 3, vec![4]));

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 1);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            2
        );

        trie.apply_event(create_store_event(
            worker_2,
            4,
            vec![2, 6, 7],
            Some(ExternalSequenceBlockHash(100)),
        ));

        let scores = trie.find_matches(
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            false,
        );
        assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
        assert_eq!(scores.scores.get(&worker_2).unwrap(), &2);

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 4);
        assert_eq!(trie.root.borrow().workers.len(), 0);
        assert_eq!(trie.root.borrow().children.len(), 1);
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.root
                .borrow()
                .children
                .get(&LocalBlockHash(1))
                .unwrap()
                .borrow()
                .children
                .len(),
            2
        );
        assert_eq!(
            trie.lookup
                .get(&worker_1)
                .unwrap()
                .get(&ExternalSequenceBlockHash(200))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
        assert_eq!(
            trie.lookup
                .get(&worker_2)
                .unwrap()
                .get(&ExternalSequenceBlockHash(200))
                .unwrap()
                .borrow()
                .workers
                .len(),
            2
        );
    }

    #[test]
    fn test_remove_worker() {
        let mut trie = RadixTree::new();

        let worker_0 = 0;
        let worker_1 = 1;

        assert!(trie
            .find_matches(vec![LocalBlockHash(0)], false)
            .scores
            .is_empty());

        trie.apply_event(create_store_event(worker_0, 0, vec![0], None));
        trie.apply_event(create_store_event(worker_1, 0, vec![0], None));

        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
        assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1);

        trie.remove_worker(worker_0);

        let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
        assert!(result.len() == 1 && result[&worker_1] == 1);
    }

    #[test]
    fn test_early_stopping() {
        let mut trie = RadixTree::new();

        let worker_0 = 0;
        let worker_1 = 1;

        trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 2], None));
        trie.apply_event(create_store_event(worker_1, 0, vec![0], None));

        let result = trie
            .find_matches(
                vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)],
                true,
            )
            .scores;

        assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);

        let result = trie
            .find_matches(vec![LocalBlockHash(0), LocalBlockHash(1)], true)
            .scores;
        assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);
    }

    #[rstest]
    #[case(11)]
    #[case(32)]
    #[case(64)]
    fn test_compute_block_hash_for_seq(#[case] kv_block_size: usize) {
        // a sequence of exactly kv_block_size elements hashes to one block
        let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 1);

        // a sequence of kv_block_size + 1 elements: the trailing partial block is ignored
        let sequence = (0..(kv_block_size + 1))
            .map(|i| i as u32)
            .collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 1);

        // a sequence of 2 * kv_block_size + 1 elements yields two block hashes
        let sequence = (0..(2 * kv_block_size + 1))
            .map(|i| i as u32)
            .collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
        assert_eq!(hashes.len(), 2);
    }

    fn make_indexer(
        token: &CancellationToken,
        num_shards: usize,
        kv_block_size: usize,
    ) -> Box<dyn KvIndexerInterface> {
        if num_shards == 1 {
            Box::new(KvIndexer::new(token.clone(), kv_block_size))
        } else {
            Box::new(KvIndexerSharded::new(
                token.clone(),
                num_shards,
                kv_block_size,
            ))
        }
    }

    #[template]
    #[rstest]
    fn indexer_template(
        #[values(1, 3, 8)] num_shards: usize,
        #[values(11, 32, 64)] kv_block_size: usize,
    ) {
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_kv_indexer_new(num_shards: usize, kv_block_size: usize) {
        let token: CancellationToken = CancellationToken::new();
        let _ = make_indexer(&token, num_shards, kv_block_size);
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_find_matches(num_shards: usize, kv_block_size: usize) {
        let token = CancellationToken::new();
        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        let sequence = vec![compute_block_hash(b"test data")];
        let scores = kv_indexer.find_matches(sequence).await;

        assert!(scores.unwrap().scores.is_empty());
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_find_matches_for_request(num_shards: usize, kv_block_size: usize) {
        let token = CancellationToken::new();
        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        let tokens = vec![1, 2, 3, 4];
        let scores = kv_indexer.find_matches_for_request(&tokens).await;

        assert!(scores.unwrap().scores.is_empty());
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_apply_event(num_shards: usize, kv_block_size: usize) {
        let worker_id = 0;

        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
        kv_indexer.apply_event(event).await;

        // No assertion here, just ensuring it runs without panic
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_shutdown(num_shards: usize, kv_block_size: usize) {
        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);

        kv_indexer.shutdown();
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_frequency(num_shards: usize, kv_block_size: usize) {
        let mut kv_indexer: Box<dyn KvIndexerInterface>;
        let token = CancellationToken::new();
        let duration = Some(Duration::from_millis(50));

        if num_shards == 1 {
            kv_indexer = Box::new(KvIndexer::new_with_frequency(
                token,
                duration,
                kv_block_size,
            ));
        } else {
            kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
                token,
                num_shards,
                duration,
                kv_block_size,
            ));
        }

        let worker_id = 0;

        let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
        kv_indexer.apply_event(event).await;

        time::sleep(Duration::from_millis(5)).await;

        let block_hashes = vec![
            LocalBlockHash(1),
            LocalBlockHash(2),
            LocalBlockHash(3),
            LocalBlockHash(4),
        ];
        // The first lookup sees no prior accesses within the window.
        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();

        assert_eq!(scores.frequencies.len(), 0);

        // The second lookup observes the access recorded by the first one.
        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies, vec![1, 1, 1, 1]);

        // After the 50ms window expires, the recorded accesses are evicted.
        time::sleep(Duration::from_millis(100)).await;

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies.len(), 0);

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies, vec![1, 1, 1, 1]);

        let scores = kv_indexer
            .find_matches(block_hashes[0..3].to_vec())
            .await
            .unwrap();
        assert_eq!(scores.frequencies, vec![2, 2, 2]);

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies, vec![3, 3, 3, 2]);
    }

    #[test]
    fn test_router_event_new() {
        let worker_id = 0;
        let kv_cache_event = KvCacheEvent {
            event_id: 1,
            data: KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash: None,
                blocks: vec![KvCacheStoredBlockData {
                    block_hash: ExternalSequenceBlockHash(0),
                    // Precomputed hash of b"test data" (see the assertion below).
                    tokens_hash: LocalBlockHash(13226331709069118873),
                }],
            }),
        };
        let router_event = RouterEvent::new(worker_id, kv_cache_event);

        assert_eq!(router_event.worker_id, worker_id);
        assert_eq!(router_event.event.event_id, 1);
        if let KvCacheEventData::Stored(store_op) = &router_event.event.data {
            assert_eq!(store_op.blocks.len(), 1);
            assert_eq!(
                store_op.blocks[0].tokens_hash,
                compute_block_hash(b"test data")
            );
            assert_eq!(store_op.blocks[0].block_hash, ExternalSequenceBlockHash(0));
        } else {
            panic!("Expected KvCacheEventData::Stored");
        }
    }

    #[test]
    fn test_radix_tree_default() {
        let radix_tree: RadixTree = Default::default();
        assert!(radix_tree.root.borrow().children.is_empty());
        assert!(radix_tree.root.borrow().workers.is_empty());
        assert!(radix_tree.lookup.is_empty());
    }

    #[test]
    fn test_overlap_scores_default() {
        let overlap_scores: OverlapScores = Default::default();
        assert!(overlap_scores.scores.is_empty());
    }
}
1407}