1use std::sync::Arc;
5
6use anyhow::Result;
7use dynamo_runtime::{
8 component::{Component, InstanceSource},
9 pipeline::{
10 async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
11 ResponseStream, SingleIn,
12 },
13 prelude::*,
14 protocols::annotated::Annotated,
15};
16use futures::stream::{self, StreamExt};
17
18pub mod indexer;
19pub mod metrics_aggregator;
20pub mod protocols;
21pub mod publisher;
22pub mod recorder;
23pub mod scheduler;
24pub mod scoring;
25
26use crate::{
27 kv_router::{
28 indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
29 metrics_aggregator::KvMetricsAggregator,
30 protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
31 scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
32 scoring::ProcessedEndpoints,
33 },
34 preprocessor::PreprocessedRequest,
35 protocols::common::llm_backend::LLMEngineOutput,
36 tokens::TokenBlockSequence,
37};
38
39use dynamo_runtime::traits::events::EventSubscriber;
40
/// Event subject that `KvRouter::new` subscribes to on its component; payloads
/// received on it are deserialized from JSON into `RouterEvent`s.
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject name for KV hit-rate messages.
/// NOTE(review): not referenced in this part of the file — confirm who publishes/consumes it.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Endpoint name for load metrics.
/// NOTE(review): not referenced in this part of the file — presumably used by the
/// metrics aggregator/publisher modules; verify.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
46
/// Pluggable policy for choosing which worker should serve a request.
///
/// A boxed implementation can be handed to `KvRouter::new`, which forwards it to
/// `KvScheduler::start`.
pub trait WorkerSelector {
    /// Picks a worker for `request` out of the candidates in `workers`.
    ///
    /// * `workers` - the processed endpoint set to choose from.
    /// * `request` - the scheduling request being placed.
    /// * `block_size` - presumably the KV block size in tokens, matching the value
    ///   given to `KvRouter::new` — TODO confirm against the scheduler.
    ///
    /// Returns the selection on success, or a `KvSchedulerError` if no choice
    /// could be made.
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: usize,
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
56
/// Tunable weights for KV-aware routing.
///
/// This type only stores the three multipliers; how they are combined into a
/// score is up to whichever component consumes the config.
#[derive(Debug, Clone)]
pub struct KvRouterConfig {
    /// Multiplier for the overlap-score term (default `2.0`).
    pub overlap_score_weight: f64,

    /// Multiplier for the GPU-cache-usage term (default `1.0`).
    pub gpu_cache_usage_weight: f64,

    /// Multiplier for the waiting-requests term (default `1.0`).
    pub waiting_requests_weight: f64,
}

impl Default for KvRouterConfig {
    /// Returns the stock weighting: overlap `2.0`, cache usage `1.0`, waiting `1.0`.
    fn default() -> Self {
        KvRouterConfig {
            overlap_score_weight: 2.0,
            gpu_cache_usage_weight: 1.0,
            waiting_requests_weight: 1.0,
        }
    }
}

impl KvRouterConfig {
    /// Builds a config from optional overrides.
    ///
    /// Every argument that is `None` falls back to the corresponding value from
    /// [`KvRouterConfig::default`]; every `Some(v)` is used as-is.
    pub fn new(
        overlap_score_weight: Option<f64>,
        gpu_cache_usage_weight: Option<f64>,
        waiting_requests_weight: Option<f64>,
    ) -> Self {
        // Pull the fallbacks apart once so each field reads as "override or default".
        let Self {
            overlap_score_weight: fallback_overlap,
            gpu_cache_usage_weight: fallback_cache_usage,
            waiting_requests_weight: fallback_waiting,
        } = Self::default();

        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(fallback_overlap),
            gpu_cache_usage_weight: gpu_cache_usage_weight.unwrap_or(fallback_cache_usage),
            waiting_requests_weight: waiting_requests_weight.unwrap_or(fallback_waiting),
        }
    }
}
101
/// Router that combines a KV block indexer with a scheduler to pick a worker
/// for each request.
pub struct KvRouter {
    /// Indexer queried for per-worker overlap scores; fed by the background
    /// event task started in `KvRouter::new`.
    indexer: KvIndexer,
    /// Scheduler that turns overlap scores plus request length into a worker id.
    scheduler: KvScheduler,
    /// Block size passed to the indexer, the scheduler and the token splitter.
    block_size: usize,
}
109
110impl KvRouter {
111 pub async fn new(
112 component: Component,
113 block_size: usize,
114 selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
115 ) -> Result<Self> {
116 let cancellation_token = component
117 .drt()
118 .primary_lease()
119 .expect("Cannot KV route static workers")
120 .primary_token();
121 tracing::info!("KV Routing initialized");
122 let metrics_aggregator =
123 KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
124 let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
125 let scheduler = KvScheduler::start(
126 component.namespace().clone(),
127 block_size,
128 metrics_aggregator.endpoints_watcher(),
129 selector,
130 )
131 .await?;
132
133 let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
136 let kv_events_tx = indexer.event_sender();
137
138 tokio::spawn(async move {
139 while let Some(event) = kv_events_rx.next().await {
140 let event: RouterEvent = match serde_json::from_slice(&event.payload) {
141 Ok(event) => event,
142 Err(e) => {
143 tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
144 continue;
147 }
148 };
149 if let Err(e) = kv_events_tx.send(event).await {
150 tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e);
151 }
152 }
153 });
154
155 Ok(Self {
156 scheduler,
157 indexer,
158 block_size,
159 })
160 }
161
162 pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
164 let isl_tokens = token_ids.len();
167 let overlap_scores = self
168 .indexer
169 .find_matches_for_request(token_ids.as_slice())
170 .await?;
171 tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
172 let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
173 Ok(worker_id)
174 }
175
176 async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> {
179 let isl_tokens = tokens.len();
180 let block_size = self.block_size;
181
182 let (complete_blocks, _partial_block) =
183 TokenBlockSequence::split_tokens(tokens, block_size, 1337_u64);
184
185 let local_block_hashes = complete_blocks
186 .into_iter()
187 .map(|block| LocalBlockHash(block.block_hash()))
188 .collect();
189 let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
190 let worker_id = self
191 .scheduler
192 .schedule(overlap_scores.clone(), isl_tokens)
193 .await?;
194 let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
195 Ok((worker_id, overlap_amount))
196 }
197
    /// Returns the block size this router was constructed with.
    pub fn block_size(&self) -> usize {
        self.block_size
    }
202}
203
204#[async_trait]
205impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
206 async fn generate(
207 &self,
208 request: SingleIn<RouterRequest>,
209 ) -> Result<ManyOut<Annotated<RouterResponse>>> {
210 let (request, ctx) = request.into_parts();
211 let (worker_id, _) = self.find_best_match(&request.tokens).await?;
212
213 let response = RouterResponse { worker_id };
214 let response = Annotated::from_data(response);
215 let stream = stream::iter(vec![response]);
216 Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
217 }
218}
219
220pub struct KvPushRouter {
221 inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
222 chooser: Arc<KvRouter>,
223}
224
225impl KvPushRouter {
226 pub fn new(
227 inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
228 chooser: Arc<KvRouter>,
229 ) -> Self {
230 KvPushRouter { inner, chooser }
231 }
232}
233
234#[async_trait]
235impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
236 for KvPushRouter
237{
238 async fn generate(
239 &self,
240 request: SingleIn<PreprocessedRequest>,
241 ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
242 match self.inner.client.instance_source.as_ref() {
243 InstanceSource::Static => self.inner.r#static(request).await,
244 InstanceSource::Dynamic(_) => {
245 let (instance_id, overlap_amount) =
246 self.chooser.find_best_match(&request.token_ids).await?;
247 let (mut backend_input, context) = request.into_parts();
249 backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
250 let updated_request = context.map(|_| backend_input);
251 self.inner.direct(updated_request, instance_id).await
252 }
253 }
254 }
255}