1use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::Result;
9use derive_builder::Builder;
10use dynamo_runtime::{
11 component::{Component, InstanceSource},
12 pipeline::{
13 AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
14 SingleIn, async_trait,
15 },
16 prelude::*,
17 protocols::annotated::Annotated,
18 utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction},
19};
20use futures::stream::{self, StreamExt};
21use serde::{Deserialize, Serialize};
22
23pub mod approx;
24pub mod indexer;
25pub mod metrics_aggregator;
26pub mod protocols;
27pub mod publisher;
28pub mod recorder;
29pub mod scheduler;
30pub mod scoring;
31pub mod sequence;
32pub mod subscriber;
33
34use crate::{
35 discovery::{MODEL_ROOT_PATH, ModelEntry},
36 kv_router::{
37 approx::ApproxKvIndexer,
38 indexer::{
39 KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
40 compute_block_hash_for_seq, compute_seq_hash_for_block,
41 },
42 protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
43 scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
44 scoring::ProcessedEndpoints,
45 subscriber::start_kv_router_background,
46 },
47 local_model::runtime_config::ModelRuntimeConfig,
48 preprocessor::PreprocessedRequest,
49 protocols::common::llm_backend::LLMEngineOutput,
50};
51
// Well-known endpoint, subject, bucket and lock names shared by the KV router
// and its submodules.
// NOTE(review): the one-line descriptions below are inferred from the
// identifiers only; confirm against the modules that consume each name.

/// Endpoint name for worker load metrics.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject name for KV cache events.
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject name for KV hit-rate reports (hyphenated, unlike the other subjects).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject name for KV metrics.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject name for prefill events.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject name for active-sequence events.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name for persisted radix state.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File/object name for persisted radix state.
pub const RADIX_STATE_FILE: &str = "radix-state";
/// Lock name guarding router snapshots.
pub const ROUTER_SNAPSHOT_LOCK: &str = "router-snapshot-lock";
/// Lock name guarding router cleanup.
pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";
72
/// Strategy interface for choosing the worker that should serve a request.
///
/// `KvRouter::new` accepts an optional boxed implementation and forwards it to
/// `KvScheduler::start`.
pub trait WorkerSelector {
    /// Selects a worker for `request`.
    ///
    /// * `workers` - candidate workers keyed by their `i64` id, each with an
    ///   optional runtime configuration.
    /// * `request` - the scheduling request to place.
    /// * `block_size` - block size supplied by the caller
    ///   (presumably the KV block size in tokens — TODO confirm).
    ///
    /// Returns the selection result, or a `KvSchedulerError` on failure.
    fn select_worker(
        &self,
        workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
        request: &SchedulingRequest,
        block_size: u32,
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
82
/// Optional per-request routing overrides.
///
/// `KvRouter::find_best_match` passes a reference to this struct straight to
/// the scheduler. Both fields default to `None` in the generated builder.
/// NOTE(review): presumably a `None` field means "use the router-wide value";
/// that fallback lives in the scheduler, not in this file — confirm.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Override for `KvRouterConfig::overlap_score_weight`.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Override for `KvRouterConfig::router_temperature`.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}
92
/// Router-wide configuration for [`KvRouter`].
///
/// See the `Default` impl for the baseline values and `KvRouterConfig::new`
/// for a constructor that overrides only selected fields.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct KvRouterConfig {
    /// Weight of the prefix-overlap term. A value of exactly `0.0` makes
    /// `KvRouter::new` skip building an indexer altogether (`Indexer::None`).
    pub overlap_score_weight: f64,

    /// NOTE(review): not read in this file; presumably the sampling
    /// temperature used during worker selection — confirm in `scheduler`.
    pub router_temperature: f64,

    /// When `true` (and the overlap weight is non-zero) `KvRouter::new` builds
    /// a `KvIndexer`; when `false` it builds an `ApproxKvIndexer` instead.
    pub use_kv_events: bool,

    /// Forwarded to `KvScheduler::start`.
    pub router_replica_sync: bool,

    /// When `true`, sequence hashes are computed for each request and handed
    /// to the scheduler; when `false` the scheduler receives `None`.
    pub router_track_active_blocks: bool,

    /// Forwarded to `start_kv_router_background`. When `Some`, the indexer's
    /// snapshot event sender is passed along as well.
    pub router_snapshot_threshold: Option<u32>,

    /// Forwarded to `start_kv_router_background`.
    pub router_reset_states: bool,
}
113
114impl Default for KvRouterConfig {
115 fn default() -> Self {
116 Self {
117 overlap_score_weight: 1.0,
118 router_temperature: 0.0,
119 use_kv_events: true,
120 router_replica_sync: false,
121 router_track_active_blocks: true,
122 router_snapshot_threshold: Some(1000000),
123 router_reset_states: false,
124 }
125 }
126}
127
128impl KvRouterConfig {
129 #[allow(clippy::too_many_arguments)]
132 pub fn new(
133 overlap_score_weight: Option<f64>,
134 temperature: Option<f64>,
135 use_kv_events: Option<bool>,
136 replica_sync: Option<bool>,
137 track_active_blocks: Option<bool>,
138 router_snapshot_threshold: Option<Option<u32>>,
139 router_reset_states: Option<bool>,
140 ) -> Self {
141 let default = Self::default();
142 Self {
143 overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
144 router_temperature: temperature.unwrap_or(default.router_temperature),
145 use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
146 router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
147 router_track_active_blocks: track_active_blocks
148 .unwrap_or(default.router_track_active_blocks),
149 router_snapshot_threshold: router_snapshot_threshold
150 .unwrap_or(default.router_snapshot_threshold),
151 router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
152 }
153 }
154}
155
/// The overlap indexer backing a [`KvRouter`]; which variant is used is decided
/// in `KvRouter::new` from the router configuration.
pub enum Indexer {
    /// Built when `use_kv_events` is `true` and the overlap weight is non-zero.
    /// This is the only variant for which the router background task is started.
    KvIndexer(KvIndexer),

    /// Built when `use_kv_events` is `false` and the overlap weight is non-zero.
    /// It is told about every routing decision made by `find_best_match`.
    ApproxKvIndexer(ApproxKvIndexer),

    /// No indexer at all; used when `overlap_score_weight` is exactly `0.0`.
    /// Match queries return empty scores and event dumps panic.
    None,
}
171
172impl Indexer {
173 async fn find_matches(
174 &self,
175 sequence: Vec<LocalBlockHash>,
176 ) -> Result<OverlapScores, KvRouterError> {
177 match self {
178 Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
179 Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
180 Indexer::None => Ok(OverlapScores {
181 scores: HashMap::new(),
182 frequencies: Vec::new(),
183 }),
184 }
185 }
186
187 async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
188 match self {
189 Indexer::KvIndexer(indexer) => indexer.dump_events().await,
190 Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
191 Indexer::None => {
192 panic!(
193 "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
194 );
195 }
196 }
197 }
198}
199
/// Routes requests to workers by combining an overlap indexer with a scheduler.
pub struct KvRouter {
    /// Answers "which workers already hold blocks for this token sequence".
    indexer: Indexer,

    /// Picks the worker and tracks per-request state
    /// (add / prefill-completed / free).
    scheduler: KvScheduler,

    /// Block size passed to the block-hash computation and to the scheduler.
    block_size: u32,

    /// Effective configuration (caller-provided or default).
    kv_router_config: KvRouterConfig,

    /// Token handed to the indexer and background tasks; cancelled when the
    /// router is dropped.
    cancellation_token: tokio_util::sync::CancellationToken,
}
214
215impl KvRouter {
    /// Creates a router for `component` and starts its supporting tasks.
    ///
    /// * `component` - component whose `generate` endpoint is routed.
    /// * `block_size` - block size handed to the indexer and the scheduler.
    /// * `selector` - optional custom worker-selection strategy, forwarded to
    ///   the scheduler.
    /// * `kv_router_config` - configuration; `None` uses `KvRouterConfig::default()`.
    /// * `consumer_uuid` - identifier forwarded to the scheduler and to the
    ///   background subscriber.
    ///
    /// # Errors
    ///
    /// Propagates failures from creating the endpoint client, starting the
    /// runtime-config watcher, starting the scheduler, or starting the
    /// background subscriber.
    ///
    /// # Panics
    ///
    /// Panics if the runtime has no primary lease, if the client's instance
    /// source is static, or if no etcd client is available.
    pub async fn new(
        component: Component,
        block_size: u32,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
        kv_router_config: Option<KvRouterConfig>,
        consumer_uuid: String,
    ) -> Result<Self> {
        let kv_router_config = kv_router_config.unwrap_or_default();

        // The same token is cloned into every task started below and is
        // cancelled again in `Drop`.
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();

        let generate_endpoint = component.endpoint("generate");
        let client = generate_endpoint.client().await?;

        // Only a dynamic instance source provides the receiver the scheduler needs.
        let instances_rx = match client.instance_source.as_ref() {
            InstanceSource::Dynamic(rx) => rx.clone(),
            InstanceSource::Static => {
                panic!("Expected dynamic instance source for KV routing");
            }
        };

        let etcd_client = component
            .drt()
            .etcd_client()
            .expect("Cannot KV route without etcd client");

        // Watch model entries under MODEL_ROOT_PATH, keyed by the lease-id
        // extractor, keeping only each entry's runtime config.
        let runtime_configs_watcher = watch_prefix_with_extraction(
            etcd_client,
            MODEL_ROOT_PATH,
            key_extractors::lease_id,
            |model_entry: ModelEntry| model_entry.runtime_config,
            cancellation_token.clone(),
        )
        .await?;
        let runtime_configs_rx = runtime_configs_watcher.receiver();

        // Indexer choice: none when overlap is weighted at zero, event-driven
        // when KV events are enabled, approximate otherwise.
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            Indexer::None
        } else if kv_router_config.use_kv_events {
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
            Indexer::KvIndexer(KvIndexer::new(
                cancellation_token.clone(),
                block_size,
                kv_indexer_metrics,
            ))
        } else {
            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
                cancellation_token.clone(),
                block_size,
                // NOTE(review): presumably the lifetime of an approximate
                // entry (120 s) — confirm in `approx`.
                Duration::from_secs(120),
            ))
        };

        let scheduler = KvScheduler::start(
            component.clone(),
            block_size,
            instances_rx,
            runtime_configs_rx,
            selector,
            kv_router_config.router_replica_sync,
            consumer_uuid.clone(),
        )
        .await?;

        // Only the event-driven indexer is wired to the background subscriber.
        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
            start_kv_router_background(
                component.clone(),
                consumer_uuid,
                kv_indexer.event_sender(),
                kv_indexer.remove_worker_sender(),
                // The snapshot sender is supplied only when a threshold is configured.
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
        }

        tracing::info!("KV Routing initialized");
        Ok(Self {
            indexer,
            scheduler,
            block_size,
            kv_router_config,
            cancellation_token,
        })
    }
314
315 pub async fn find_best_match(
319 &self,
320 context_id: Option<&str>,
321 tokens: &[u32],
322 router_config_override: Option<&RouterConfigOverride>,
323 update_states: bool,
324 ) -> anyhow::Result<(i64, u32)> {
325 if update_states && context_id.is_none() {
327 panic!("context_id must be provided if update_states is true");
328 }
329
330 let isl_tokens = tokens.len();
331
332 let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
333 let seq_hashes = compute_seq_hash_for_block(&block_hashes);
334
335 let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
336
337 let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
339 let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;
340
341 let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
343 match (approx_indexer_needs_it, scheduler_needs_it) {
344 (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
345 (true, false) => (Some(seq_hashes), None),
346 (false, true) => (None, Some(seq_hashes)),
347 (false, false) => (None, None),
348 };
349
350 let best_worker_id = self
351 .scheduler
352 .schedule(
353 context_id.map(|s| s.to_string()),
354 isl_tokens,
355 maybe_seq_hashes_2,
356 overlap_scores.clone(),
357 router_config_override,
358 update_states,
359 )
360 .await?;
361
362 if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
363 indexer
364 .process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap())
365 .await
366 .unwrap();
367 };
368
369 let overlap_amount = overlap_scores
370 .scores
371 .get(&best_worker_id)
372 .copied()
373 .unwrap_or(0);
374 Ok((best_worker_id, overlap_amount))
375 }
376
377 pub async fn add_request(
378 &self,
379 request_id: String,
380 tokens: &[u32],
381 overlap_blocks: u32,
382 worker_id: i64,
383 ) {
384 let isl_tokens = tokens.len();
385
386 let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
387 let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
388 compute_seq_hash_for_block(&block_hashes)
389 });
390
391 self.scheduler
392 .add_request(
393 request_id,
394 maybe_seq_hashes,
395 isl_tokens,
396 overlap_blocks,
397 worker_id,
398 )
399 .await;
400 }
401
    /// Tells the scheduler that prefill has completed for `request_id`.
    ///
    /// Returns whatever the scheduler's `mark_prefill_completed` returns.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
        self.scheduler.mark_prefill_completed(request_id).await
    }
405
    /// Tells the scheduler to free the state it holds for `request_id`.
    ///
    /// Returns whatever the scheduler's `free` returns.
    pub async fn free(&self, request_id: &str) -> Result<()> {
        self.scheduler.free(request_id).await
    }
409
    /// Returns the block size this router was constructed with.
    pub fn block_size(&self) -> u32 {
        self.block_size
    }
413
414 pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
416 let isl_tokens = tokens.len();
417 let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
418 let overlap_scores = self.indexer.find_matches(block_hashes).await?;
419
420 let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
421 let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
422 compute_seq_hash_for_block(&block_hashes)
423 });
424
425 Ok(self
426 .scheduler
427 .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
428 .await)
429 }
430
    /// Dumps the events currently held by the router's indexer.
    ///
    /// # Panics
    ///
    /// Panics when the router was built without an indexer, which happens when
    /// `overlap_score_weight` is `0.0`.
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
435}
436
437#[async_trait]
440impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
441 async fn generate(
442 &self,
443 request: SingleIn<RouterRequest>,
444 ) -> Result<ManyOut<Annotated<RouterResponse>>> {
445 let (request, ctx) = request.into_parts();
446 let context_id = ctx.context().id().to_string();
447 let response = match request {
449 RouterRequest::New { tokens } => {
450 let (worker_id, overlap_blocks) = self
451 .find_best_match(Some(&context_id), &tokens, None, true)
452 .await?;
453
454 RouterResponse::New {
455 worker_id,
456 overlap_blocks,
457 }
458 }
459 RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
460 success: self.mark_prefill_completed(&context_id).await.is_ok(),
461 },
462 RouterRequest::MarkFree => RouterResponse::FreeMarked {
463 success: self.free(&context_id).await.is_ok(),
464 },
465 };
466
467 let response = Annotated::from_data(response);
468 let stream = stream::iter(vec![response]);
469 Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
470 }
471}
472
/// Push router that lets a [`KvRouter`] choose the target instance before
/// forwarding the request.
pub struct KvPushRouter {
    /// Underlying push router used to actually send the request.
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    /// Shared KV router that selects the worker and tracks request state.
    chooser: Arc<KvRouter>,
}
477
478impl KvPushRouter {
479 pub fn new(
480 inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
481 chooser: Arc<KvRouter>,
482 ) -> Self {
483 KvPushRouter { inner, chooser }
484 }
485
486 pub async fn find_best_match(
488 &self,
489 tokens: &[u32],
490 router_config_override: Option<&RouterConfigOverride>,
491 ) -> Result<(i64, u32)> {
492 self.chooser
493 .find_best_match(None, tokens, router_config_override, false)
494 .await
495 }
496
497 pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
499 self.chooser.get_potential_loads(tokens).await
500 }
501
502 pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
504 self.chooser.dump_events().await
505 }
506}
507
508#[async_trait]
509impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
510 for KvPushRouter
511{
512 async fn generate(
532 &self,
533 request: SingleIn<PreprocessedRequest>,
534 ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
535 match self.inner.client.instance_source.as_ref() {
536 InstanceSource::Static => self.inner.r#static(request).await,
537 InstanceSource::Dynamic(_) => {
538 let context_id = request.context().id().to_string();
540
541 let query_instance_id = request.has_annotation("query_instance_id");
543
544 let (instance_id, overlap_amount) = if let Some(id) = request.backend_instance_id {
545 if !query_instance_id {
547 self.chooser
548 .add_request(context_id.clone(), &request.token_ids, 0, id)
549 .await;
550 }
551 (id, 0)
552 } else {
553 self.chooser
555 .find_best_match(
556 Some(&context_id),
557 &request.token_ids,
558 request.router_config_override.as_ref(),
559 !query_instance_id, )
561 .await?
562 };
563
564 let stream_context = request.context().clone();
568 if query_instance_id {
569 let instance_id_str = instance_id.to_string();
570 let response =
571 Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
572
573 let response_tokens =
575 Annotated::from_annotation("token_data", &request.token_ids)?;
576 tracing::trace!(
577 "Tokens requested in the response through the query_instance_id annotation: {:?}",
578 response_tokens
579 );
580 let stream = stream::iter(vec![response, response_tokens]);
581 return Ok(ResponseStream::new(Box::pin(stream), stream_context));
582 }
583 let (mut backend_input, context) = request.into_parts();
584 backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
585 let updated_request = context.map(|_| backend_input);
586
587 let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
588 let stream_context = response_stream.context();
589 let chooser = self.chooser.clone();
590
591 let wrapped_stream = Box::pin(async_stream::stream! {
592 if let Some(first_item) = response_stream.next().await {
593 if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
594 tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
595 }
596 yield first_item;
597 }
598
599 while let Some(item) = response_stream.next().await {
600 yield item;
601 }
602
603 if let Err(e) = chooser.free(&context_id).await {
604 tracing::warn!("Failed to free request {context_id}: {e:?}");
605 }
606 });
607 Ok(ResponseStream::new(wrapped_stream, stream_context))
608 }
609 }
610 }
611}
612
/// Cancels the router's cancellation token when the router is dropped, which
/// signals every task that was handed a clone of it in `KvRouter::new`.
///
/// NOTE(review): the token is obtained from `primary_lease().primary_token()`.
/// If that token is shared with the rest of the runtime rather than private to
/// this router, cancelling it here affects more than the router's own tasks —
/// confirm, and consider a child token if so.
impl Drop for KvRouter {
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}