dynamo_llm/
kv_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use anyhow::Result;
17use dynamo_runtime::{
18    component::Component,
19    pipeline::{
20        async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream,
21        SingleIn,
22    },
23    prelude::*,
24    protocols::annotated::Annotated,
25};
26use futures::stream::{self, StreamExt};
27use std::sync::Arc;
28
29pub mod indexer;
30pub mod metrics_aggregator;
31pub mod protocols;
32pub mod publisher;
33pub mod recorder;
34pub mod scheduler;
35pub mod scoring;
36
37use crate::{
38    kv_router::{
39        indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
40        metrics_aggregator::KvMetricsAggregator,
41        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
42        scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
43        scoring::ProcessedEndpoints,
44    },
45    tokens::Tokens,
46};
47
48use dynamo_runtime::traits::events::EventSubscriber;
49
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Subject on which workers publish KV cache events; `KvRouter::new` subscribes to it.
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Subject name for KV hit-rate reporting.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";

/// Endpoint name used for worker load metrics.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
55
56/// A trait that users can implement to define custom selection logic
57pub trait WorkerSelector {
58    fn select_worker(
59        &self,
60        workers: &ProcessedEndpoints,
61        request: &SchedulingRequest,
62        block_size: usize,
63    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
64}
65
66pub struct KvRouter {
67    indexer: KvIndexer,
68    scheduler: KvScheduler,
69    block_size: usize,
70}
71
72impl KvRouter {
73    pub async fn new(
74        component: Component,
75        block_size: usize,
76        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
77    ) -> Result<Arc<Self>> {
78        let cancellation_token = component
79            .drt()
80            .primary_lease()
81            .expect("Cannot KV route static workers")
82            .primary_token();
83
84        let metrics_aggregator =
85            KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
86        let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
87        let scheduler = KvScheduler::start(
88            component.namespace().clone(),
89            block_size,
90            metrics_aggregator.endpoints_watcher(),
91            selector,
92        )
93        .await?;
94
95        // [gluo TODO] try subscribe_with_type::<RouterEvent>,
96        // error checking below will be different.
97        let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
98        let kv_events_tx = indexer.event_sender();
99
100        tokio::spawn(async move {
101            while let Some(event) = kv_events_rx.next().await {
102                let event: RouterEvent = match serde_json::from_slice(&event.payload) {
103                    Ok(event) => {
104                        tracing::debug!("received kv event: {:?}", event);
105                        event
106                    }
107                    Err(e) => {
108                        tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
109                        // Choosing warn and continue to process other events from other workers
110                        // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
111                        continue;
112                    }
113                };
114                if let Err(e) = kv_events_tx.send(event).await {
115                    tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
116                }
117            }
118        });
119
120        Ok(Arc::new(Self {
121            scheduler,
122            indexer,
123            block_size,
124        }))
125    }
126
127    // [TODO] indexer needs to take 'lora_id' as parameter
128    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
129        // Extracting part of the code in KvRouter::generate() for only
130        // the decision making part, routing is done by the caller
131        let isl_tokens = token_ids.len();
132        let overlap_scores = self
133            .indexer
134            .find_matches_for_request(token_ids.as_slice())
135            .await?;
136        tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
137        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
138        Ok(worker_id)
139    }
140}
141
142#[async_trait]
143impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
144    async fn generate(
145        &self,
146        request: SingleIn<RouterRequest>,
147    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
148        let (request, ctx) = request.into_parts();
149        let isl_tokens = request.tokens.len();
150        let block_size = self.block_size;
151
152        // Compute the block hashes in a blocking task
153        let local_block_hashes: Vec<LocalBlockHash> = tokio::task::spawn_blocking(move || {
154            Tokens::compute_block_hash(&request.tokens, block_size)
155                .into_iter()
156                .map(LocalBlockHash)
157                .collect()
158        })
159        .await?;
160
161        let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
162        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
163
164        let response = RouterResponse { worker_id };
165        let response = Annotated::from_data(response);
166        let stream = stream::iter(vec![response]);
167        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
168    }
169}