dynamo_llm/
kv_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use anyhow::Result;
7use dynamo_runtime::{
8    component::{Component, InstanceSource},
9    pipeline::{
10        async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
11        ResponseStream, SingleIn,
12    },
13    prelude::*,
14    protocols::annotated::Annotated,
15};
16use futures::stream::{self, StreamExt};
17
18pub mod indexer;
19pub mod metrics_aggregator;
20pub mod protocols;
21pub mod publisher;
22pub mod recorder;
23pub mod scheduler;
24pub mod scoring;
25
26use crate::{
27    kv_router::{
28        indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
29        metrics_aggregator::KvMetricsAggregator,
30        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
31        scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
32        scoring::ProcessedEndpoints,
33    },
34    preprocessor::PreprocessedRequest,
35    protocols::common::llm_backend::LLMEngineOutput,
36    tokens::TokenBlockSequence,
37};
38
39use dynamo_runtime::traits::events::EventSubscriber;
40
// TODO(gluo): these names should not need to be public; ideally they are
// discovered from the component rather than hard-coded here.

/// Subject that [`KvRouter::new`] subscribes to on the component in order to
/// receive serialized `RouterEvent`s and forward them to the indexer.
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Subject name for KV hit-rate events.
///
/// NOTE(review): not referenced in this file; meaning inferred from the name
/// only — confirm against the publisher/consumer.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";

/// Endpoint name for load metrics.
///
/// NOTE(review): not referenced in this file; presumably consumed by the
/// metrics aggregator — confirm.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
46
47/// A trait that users can implement to define custom selection logic
48pub trait WorkerSelector {
49    fn select_worker(
50        &self,
51        workers: &ProcessedEndpoints,
52        request: &SchedulingRequest,
53        block_size: usize,
54    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
55}
56
/// Tunable weights for KV-aware worker selection.
#[derive(Debug, Clone)]
pub struct KvRouterConfig {
    /// Weight applied to the overlap score during worker selection.
    /// Larger values favour KV cache reuse. Default: 2.0
    pub overlap_score_weight: f64,

    /// Weight applied to GPU cache usage during worker selection.
    /// Larger values steer away from workers whose KV caches are nearly full. Default: 1.0
    pub gpu_cache_usage_weight: f64,

    /// Weight applied to the number of waiting requests during worker selection.
    /// Larger values steer away from workers with queued requests. Default: 1.0
    pub waiting_requests_weight: f64,
}

impl Default for KvRouterConfig {
    /// Returns the stock weights: overlap 2.0, GPU cache usage 1.0, waiting requests 1.0.
    fn default() -> Self {
        KvRouterConfig {
            overlap_score_weight: 2.0,
            gpu_cache_usage_weight: 1.0,
            waiting_requests_weight: 1.0,
        }
    }
}

impl KvRouterConfig {
    /// Builds a configuration from optional overrides.
    ///
    /// Each argument that is `None` falls back, independently of the others,
    /// to the corresponding value of [`KvRouterConfig::default`]. Supplied
    /// values are stored as given; no range validation is performed.
    pub fn new(
        overlap_score_weight: Option<f64>,
        gpu_cache_usage_weight: Option<f64>,
        waiting_requests_weight: Option<f64>,
    ) -> Self {
        // Pull the defaults apart once so every field has a named fallback.
        let Self {
            overlap_score_weight: fallback_overlap,
            gpu_cache_usage_weight: fallback_gpu_usage,
            waiting_requests_weight: fallback_waiting,
        } = Self::default();

        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(fallback_overlap),
            gpu_cache_usage_weight: gpu_cache_usage_weight.unwrap_or(fallback_gpu_usage),
            waiting_requests_weight: waiting_requests_weight.unwrap_or(fallback_waiting),
        }
    }
}
101
102/// A KvRouter only decides which worker you should use. It doesn't send you there.
103/// TODO: Rename this to indicate it only selects a worker, it does not route.
104pub struct KvRouter {
105    indexer: KvIndexer,
106    scheduler: KvScheduler,
107    block_size: usize,
108}
109
110impl KvRouter {
111    pub async fn new(
112        component: Component,
113        block_size: usize,
114        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
115    ) -> Result<Self> {
116        let cancellation_token = component
117            .drt()
118            .primary_lease()
119            .expect("Cannot KV route static workers")
120            .primary_token();
121        tracing::info!("KV Routing initialized");
122        let metrics_aggregator =
123            KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
124        let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
125        let scheduler = KvScheduler::start(
126            component.namespace().clone(),
127            block_size,
128            metrics_aggregator.endpoints_watcher(),
129            selector,
130        )
131        .await?;
132
133        // [gluo TODO] try subscribe_with_type::<RouterEvent>,
134        // error checking below will be different.
135        let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
136        let kv_events_tx = indexer.event_sender();
137
138        tokio::spawn(async move {
139            while let Some(event) = kv_events_rx.next().await {
140                let event: RouterEvent = match serde_json::from_slice(&event.payload) {
141                    Ok(event) => event,
142                    Err(e) => {
143                        tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
144                        // Choosing warn and continue to process other events from other workers
145                        // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
146                        continue;
147                    }
148                };
149                if let Err(e) = kv_events_tx.send(event).await {
150                    tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e);
151                }
152            }
153        });
154
155        Ok(Self {
156            scheduler,
157            indexer,
158            block_size,
159        })
160    }
161
162    // [TODO] indexer needs to take 'lora_id' as parameter
163    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
164        // Extracting part of the code in KvRouter::generate() for only
165        // the decision making part, routing is done by the caller
166        let isl_tokens = token_ids.len();
167        let overlap_scores = self
168            .indexer
169            .find_matches_for_request(token_ids.as_slice())
170            .await?;
171        tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
172        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
173        Ok(worker_id)
174    }
175
176    /// Give these tokens, find the worker with the best match in it's KV cache.
177    /// Returned overlap amount is in number of blocks.
178    async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> {
179        let isl_tokens = tokens.len();
180        let block_size = self.block_size;
181
182        let (complete_blocks, _partial_block) =
183            TokenBlockSequence::split_tokens(tokens, block_size, 1337_u64);
184
185        let local_block_hashes = complete_blocks
186            .into_iter()
187            .map(|block| LocalBlockHash(block.block_hash()))
188            .collect();
189        let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
190        let worker_id = self
191            .scheduler
192            .schedule(overlap_scores.clone(), isl_tokens)
193            .await?;
194        let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
195        Ok((worker_id, overlap_amount))
196    }
197
198    /// Get the block size this router was configured with
199    pub fn block_size(&self) -> usize {
200        self.block_size
201    }
202}
203
204#[async_trait]
205impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
206    async fn generate(
207        &self,
208        request: SingleIn<RouterRequest>,
209    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
210        let (request, ctx) = request.into_parts();
211        let (worker_id, _) = self.find_best_match(&request.tokens).await?;
212
213        let response = RouterResponse { worker_id };
214        let response = Annotated::from_data(response);
215        let stream = stream::iter(vec![response]);
216        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
217    }
218}
219
220pub struct KvPushRouter {
221    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
222    chooser: Arc<KvRouter>,
223}
224
225impl KvPushRouter {
226    pub fn new(
227        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
228        chooser: Arc<KvRouter>,
229    ) -> Self {
230        KvPushRouter { inner, chooser }
231    }
232}
233
234#[async_trait]
235impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
236    for KvPushRouter
237{
238    async fn generate(
239        &self,
240        request: SingleIn<PreprocessedRequest>,
241    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
242        match self.inner.client.instance_source.as_ref() {
243            InstanceSource::Static => self.inner.r#static(request).await,
244            InstanceSource::Dynamic(_) => {
245                let (instance_id, overlap_amount) =
246                    self.chooser.find_best_match(&request.token_ids).await?;
247                // Update the request with the estimated prefix hit blocks
248                let (mut backend_input, context) = request.into_parts();
249                backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
250                let updated_request = context.map(|_| backend_input);
251                self.inner.direct(updated_request, instance_id).await
252            }
253        }
254    }
255}