dynamo_llm/
kv_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::Result;
9use derive_builder::Builder;
10use dynamo_runtime::{
11    component::{Component, InstanceSource},
12    pipeline::{
13        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
14        SingleIn, async_trait,
15    },
16    prelude::*,
17    protocols::annotated::Annotated,
18    utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction},
19};
20use futures::stream::{self, StreamExt};
21use serde::{Deserialize, Serialize};
22
23pub mod approx;
24pub mod indexer;
25pub mod metrics_aggregator;
26pub mod protocols;
27pub mod publisher;
28pub mod recorder;
29pub mod scheduler;
30pub mod scoring;
31pub mod sequence;
32pub mod subscriber;
33
34use crate::{
35    discovery::{MODEL_ROOT_PATH, ModelEntry},
36    kv_router::{
37        approx::ApproxKvIndexer,
38        indexer::{
39            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
40            compute_block_hash_for_seq, compute_seq_hash_for_block,
41        },
42        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
43        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
44        scoring::ProcessedEndpoints,
45        subscriber::start_kv_router_background,
46    },
47    local_model::runtime_config::ModelRuntimeConfig,
48    preprocessor::PreprocessedRequest,
49    protocols::common::llm_backend::LLMEngineOutput,
50};
51
// [gluo TODO] These names should not need to be public;
// they should be discovered from the component instead.

// Endpoint used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

// Subjects used for metric publishing (push-based).
pub const KV_EVENT_SUBJECT: &str = "kv_events";
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// Subjects used for communication between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

// Names used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";
pub const ROUTER_SNAPSHOT_LOCK: &str = "router-snapshot-lock";
pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";
72
73/// A trait that users can implement to define custom selection logic
74pub trait WorkerSelector {
75    fn select_worker(
76        &self,
77        workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
78        request: &SchedulingRequest,
79        block_size: u32,
80    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
81}
82
83/// Override configuration for router settings that can be specified per-request
84#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
85pub struct RouterConfigOverride {
86    #[builder(default)]
87    pub overlap_score_weight: Option<f64>,
88
89    #[builder(default)]
90    pub router_temperature: Option<f64>,
91}
92
93/// KV Router configuration parameters
94#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
95pub struct KvRouterConfig {
96    pub overlap_score_weight: f64,
97
98    pub router_temperature: f64,
99
100    pub use_kv_events: bool,
101
102    pub router_replica_sync: bool,
103
104    /// Whether to track active blocks in the router (default: true)
105    pub router_track_active_blocks: bool,
106
107    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
108    pub router_snapshot_threshold: Option<u32>,
109
110    /// Whether to reset the router state on startup (default: false)
111    pub router_reset_states: bool,
112}
113
114impl Default for KvRouterConfig {
115    fn default() -> Self {
116        Self {
117            overlap_score_weight: 1.0,
118            router_temperature: 0.0,
119            use_kv_events: true,
120            router_replica_sync: false,
121            router_track_active_blocks: true,
122            router_snapshot_threshold: Some(1000000),
123            router_reset_states: false,
124        }
125    }
126}
127
128impl KvRouterConfig {
129    /// Create a new KvRouterConfig with optional weight values.
130    /// If a weight is None, the default value will be used.
131    #[allow(clippy::too_many_arguments)]
132    pub fn new(
133        overlap_score_weight: Option<f64>,
134        temperature: Option<f64>,
135        use_kv_events: Option<bool>,
136        replica_sync: Option<bool>,
137        track_active_blocks: Option<bool>,
138        router_snapshot_threshold: Option<Option<u32>>,
139        router_reset_states: Option<bool>,
140    ) -> Self {
141        let default = Self::default();
142        Self {
143            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
144            router_temperature: temperature.unwrap_or(default.router_temperature),
145            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
146            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
147            router_track_active_blocks: track_active_blocks
148                .unwrap_or(default.router_track_active_blocks),
149            router_snapshot_threshold: router_snapshot_threshold
150                .unwrap_or(default.router_snapshot_threshold),
151            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
152        }
153    }
154}
155
156// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
157// since both variants implement it
158pub enum Indexer {
159    /// Updates itself based on KV events emitted by backend workers.
160    /// Has the ability to persist and snapshot states.
161    KvIndexer(KvIndexer),
162
163    /// Predicts the cached blocks based on requests on a TTL basis.
164    /// Currently does not persist or snapshot states (WIP to enable that).
165    ApproxKvIndexer(ApproxKvIndexer),
166
167    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
168    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
169    None,
170}
171
172impl Indexer {
173    async fn find_matches(
174        &self,
175        sequence: Vec<LocalBlockHash>,
176    ) -> Result<OverlapScores, KvRouterError> {
177        match self {
178            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
179            Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
180            Indexer::None => Ok(OverlapScores {
181                scores: HashMap::new(),
182                frequencies: Vec::new(),
183            }),
184        }
185    }
186
187    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
188        match self {
189            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
190            Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
191            Indexer::None => {
192                panic!(
193                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
194                );
195            }
196        }
197    }
198}
199
200/// A KvRouter only decides which worker you should use. It doesn't send you there.
201/// TODO: Rename this to indicate it only selects a worker, it does not route.
202pub struct KvRouter {
203    indexer: Indexer,
204
205    // How about a Box<dyn KvIndexerInterface>
206    scheduler: KvScheduler,
207
208    block_size: u32,
209
210    kv_router_config: KvRouterConfig,
211
212    cancellation_token: tokio_util::sync::CancellationToken,
213}
214
215impl KvRouter {
216    pub async fn new(
217        component: Component,
218        block_size: u32,
219        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
220        kv_router_config: Option<KvRouterConfig>,
221        consumer_uuid: String,
222    ) -> Result<Self> {
223        let kv_router_config = kv_router_config.unwrap_or_default();
224
225        let cancellation_token = component
226            .drt()
227            .primary_lease()
228            .expect("Cannot KV route static workers")
229            .primary_token();
230
231        let generate_endpoint = component.endpoint("generate");
232        let client = generate_endpoint.client().await?;
233
234        let instances_rx = match client.instance_source.as_ref() {
235            InstanceSource::Dynamic(rx) => rx.clone(),
236            InstanceSource::Static => {
237                panic!("Expected dynamic instance source for KV routing");
238            }
239        };
240
241        // Create runtime config watcher using the generic etcd watcher
242        // TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality
243        let etcd_client = component
244            .drt()
245            .etcd_client()
246            .expect("Cannot KV route without etcd client");
247
248        let runtime_configs_watcher = watch_prefix_with_extraction(
249            etcd_client,
250            MODEL_ROOT_PATH,
251            key_extractors::lease_id,
252            |model_entry: ModelEntry| model_entry.runtime_config,
253            cancellation_token.clone(),
254        )
255        .await?;
256        let runtime_configs_rx = runtime_configs_watcher.receiver();
257
258        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
259            // When overlap_score_weight is zero, we don't need to track prefixes
260            Indexer::None
261        } else if kv_router_config.use_kv_events {
262            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
263            Indexer::KvIndexer(KvIndexer::new(
264                cancellation_token.clone(),
265                block_size,
266                kv_indexer_metrics,
267            ))
268        } else {
269            // hard code 120 seconds for now
270            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
271                cancellation_token.clone(),
272                block_size,
273                Duration::from_secs(120),
274            ))
275        };
276
277        let scheduler = KvScheduler::start(
278            component.clone(),
279            block_size,
280            instances_rx,
281            runtime_configs_rx,
282            selector,
283            kv_router_config.router_replica_sync,
284            consumer_uuid.clone(),
285        )
286        .await?;
287
288        // Start unified background process if using KvIndexer
289        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
290            start_kv_router_background(
291                component.clone(),
292                consumer_uuid,
293                kv_indexer.event_sender(),
294                kv_indexer.remove_worker_sender(),
295                kv_router_config
296                    .router_snapshot_threshold
297                    .map(|_| kv_indexer.snapshot_event_sender()),
298                cancellation_token.clone(),
299                kv_router_config.router_snapshot_threshold,
300                kv_router_config.router_reset_states,
301            )
302            .await?;
303        }
304
305        tracing::info!("KV Routing initialized");
306        Ok(Self {
307            indexer,
308            scheduler,
309            block_size,
310            kv_router_config,
311            cancellation_token,
312        })
313    }
314
315    /// Give these tokens, find the worker with the best match in it's KV cache.
316    /// Returned overlap amount is in number of blocks.
317    /// Now also takes optional context_id for request tracking
318    pub async fn find_best_match(
319        &self,
320        context_id: Option<&str>,
321        tokens: &[u32],
322        router_config_override: Option<&RouterConfigOverride>,
323        update_states: bool,
324    ) -> anyhow::Result<(i64, u32)> {
325        // Validate that context_id is provided when update_states is true
326        if update_states && context_id.is_none() {
327            panic!("context_id must be provided if update_states is true");
328        }
329
330        let isl_tokens = tokens.len();
331
332        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
333        let seq_hashes = compute_seq_hash_for_block(&block_hashes);
334
335        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
336
337        // Determine who needs seq_hashes
338        let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
339        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;
340
341        // Optimize cloning: only clone if both need it, otherwise move
342        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
343            match (approx_indexer_needs_it, scheduler_needs_it) {
344                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
345                (true, false) => (Some(seq_hashes), None),
346                (false, true) => (None, Some(seq_hashes)),
347                (false, false) => (None, None),
348            };
349
350        let best_worker_id = self
351            .scheduler
352            .schedule(
353                context_id.map(|s| s.to_string()),
354                isl_tokens,
355                maybe_seq_hashes_2,
356                overlap_scores.clone(),
357                router_config_override,
358                update_states,
359            )
360            .await?;
361
362        if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
363            indexer
364                .process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap())
365                .await
366                .unwrap();
367        };
368
369        let overlap_amount = overlap_scores
370            .scores
371            .get(&best_worker_id)
372            .copied()
373            .unwrap_or(0);
374        Ok((best_worker_id, overlap_amount))
375    }
376
377    pub async fn add_request(
378        &self,
379        request_id: String,
380        tokens: &[u32],
381        overlap_blocks: u32,
382        worker_id: i64,
383    ) {
384        let isl_tokens = tokens.len();
385
386        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
387            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
388            compute_seq_hash_for_block(&block_hashes)
389        });
390
391        self.scheduler
392            .add_request(
393                request_id,
394                maybe_seq_hashes,
395                isl_tokens,
396                overlap_blocks,
397                worker_id,
398            )
399            .await;
400    }
401
402    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
403        self.scheduler.mark_prefill_completed(request_id).await
404    }
405
406    pub async fn free(&self, request_id: &str) -> Result<()> {
407        self.scheduler.free(request_id).await
408    }
409
410    pub fn block_size(&self) -> u32 {
411        self.block_size
412    }
413
414    /// Get potential prefill and decode loads for all workers
415    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
416        let isl_tokens = tokens.len();
417        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
418        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
419
420        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
421            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
422            compute_seq_hash_for_block(&block_hashes)
423        });
424
425        Ok(self
426            .scheduler
427            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
428            .await)
429    }
430
431    /// Dump all events from the indexer
432    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
433        self.indexer.dump_events().await
434    }
435}
436
437// NOTE: KVRouter works like a PushRouter,
438// but without the reverse proxy functionality, but based on contract of 3 request types
439#[async_trait]
440impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
441    async fn generate(
442        &self,
443        request: SingleIn<RouterRequest>,
444    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
445        let (request, ctx) = request.into_parts();
446        let context_id = ctx.context().id().to_string();
447        // Handle different request types
448        let response = match request {
449            RouterRequest::New { tokens } => {
450                let (worker_id, overlap_blocks) = self
451                    .find_best_match(Some(&context_id), &tokens, None, true)
452                    .await?;
453
454                RouterResponse::New {
455                    worker_id,
456                    overlap_blocks,
457                }
458            }
459            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
460                success: self.mark_prefill_completed(&context_id).await.is_ok(),
461            },
462            RouterRequest::MarkFree => RouterResponse::FreeMarked {
463                success: self.free(&context_id).await.is_ok(),
464            },
465        };
466
467        let response = Annotated::from_data(response);
468        let stream = stream::iter(vec![response]);
469        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
470    }
471}
472
473pub struct KvPushRouter {
474    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
475    chooser: Arc<KvRouter>,
476}
477
478impl KvPushRouter {
479    pub fn new(
480        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
481        chooser: Arc<KvRouter>,
482    ) -> Self {
483        KvPushRouter { inner, chooser }
484    }
485
486    /// Find the best matching worker for the given tokens without updating states
487    pub async fn find_best_match(
488        &self,
489        tokens: &[u32],
490        router_config_override: Option<&RouterConfigOverride>,
491    ) -> Result<(i64, u32)> {
492        self.chooser
493            .find_best_match(None, tokens, router_config_override, false)
494            .await
495    }
496
497    /// Get potential prefill and decode loads for all workers
498    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
499        self.chooser.get_potential_loads(tokens).await
500    }
501
502    /// Dump all events from the KV router's indexer
503    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
504        self.chooser.dump_events().await
505    }
506}
507
508#[async_trait]
509impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
510    for KvPushRouter
511{
512    /// Generate method that handles KV-aware routing with three distinct behaviors:
513    ///
514    /// 1. **If `query_instance_id` annotation is set**:
515    ///    - Returns the best matching worker ID without routing the request
516    ///    - Does NOT update any router local states
517    ///    - Response includes worker_instance_id and token_data annotations
518    ///
519    /// 2. **If `backend_instance_id` is set in the request**:
520    ///    - Routes directly to the specified backend instance
521    ///    - DOES update router states to track this request (unless query_instance_id is also set)
522    ///    - Bypasses the normal KV matching logic
523    ///
524    /// 3. **If neither are set (default behavior)**:
525    ///    - Finds the best worker based on KV cache overlap
526    ///    - Updates router states to track the request
527    ///    - Routes to the selected worker
528    ///
529    /// The router state updates include tracking active sequences and managing
530    /// prefill/completion lifecycle for proper KV cache management.
531    async fn generate(
532        &self,
533        request: SingleIn<PreprocessedRequest>,
534    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
535        match self.inner.client.instance_source.as_ref() {
536            InstanceSource::Static => self.inner.r#static(request).await,
537            InstanceSource::Dynamic(_) => {
538                // Extract context ID for request tracking
539                let context_id = request.context().id().to_string();
540
541                // Check if this is a query_instance_id request first
542                let query_instance_id = request.has_annotation("query_instance_id");
543
544                let (instance_id, overlap_amount) = if let Some(id) = request.backend_instance_id {
545                    // If instance_id is set, use it and manually add the request to track it
546                    if !query_instance_id {
547                        self.chooser
548                            .add_request(context_id.clone(), &request.token_ids, 0, id)
549                            .await;
550                    }
551                    (id, 0)
552                } else {
553                    // Otherwise, find the best match
554                    self.chooser
555                        .find_best_match(
556                            Some(&context_id),
557                            &request.token_ids,
558                            request.router_config_override.as_ref(),
559                            !query_instance_id, // Don't update states if query_instance_id
560                        )
561                        .await?
562                };
563
564                // if request has the annotation "query_instance_id",
565                // then the request will not be routed to the worker,
566                // and instead the worker_instance_id will be returned.
567                let stream_context = request.context().clone();
568                if query_instance_id {
569                    let instance_id_str = instance_id.to_string();
570                    let response =
571                        Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
572
573                    // Return the tokens in nvext.token_data format
574                    let response_tokens =
575                        Annotated::from_annotation("token_data", &request.token_ids)?;
576                    tracing::trace!(
577                        "Tokens requested in the response through the query_instance_id annotation: {:?}",
578                        response_tokens
579                    );
580                    let stream = stream::iter(vec![response, response_tokens]);
581                    return Ok(ResponseStream::new(Box::pin(stream), stream_context));
582                }
583                let (mut backend_input, context) = request.into_parts();
584                backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
585                let updated_request = context.map(|_| backend_input);
586
587                let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
588                let stream_context = response_stream.context();
589                let chooser = self.chooser.clone();
590
591                let wrapped_stream = Box::pin(async_stream::stream! {
592                    if let Some(first_item) = response_stream.next().await {
593                        if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
594                            tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
595                        }
596                        yield first_item;
597                    }
598
599                    while let Some(item) = response_stream.next().await {
600                        yield item;
601                    }
602
603                    if let Err(e) = chooser.free(&context_id).await {
604                        tracing::warn!("Failed to free request {context_id}: {e:?}");
605                    }
606                });
607                Ok(ResponseStream::new(wrapped_stream, stream_context))
608            }
609        }
610    }
611}
612
613impl Drop for KvRouter {
614    fn drop(&mut self) {
615        tracing::info!("Dropping KvRouter - cancelling background tasks");
616        self.cancellation_token.cancel();
617    }
618}