use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
use dynamo_kv_router::{
PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
indexer::KvRouterError,
protocols::KV_EVENT_SUBJECT,
protocols::{
BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, PrefillLoadHint, RouterEvent,
RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
compute_block_hash_for_seq,
},
};
use dynamo_runtime::{
component::{Client, Endpoint},
discovery::DiscoveryQuery,
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
async_trait,
},
protocols::EndpointId,
protocols::annotated::Annotated,
traits::DistributedRuntimeProvider,
};
use futures::stream;
use tracing::Instrument;
use validator::Validate;
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::protocols;
pub use dynamo_kv_router::scheduling;
pub use dynamo_kv_router::selector;
pub mod agent_controller;
pub mod indexer;
pub mod metrics;
pub mod prefill_router;
pub mod publisher;
pub mod push_router;
pub mod scheduler;
pub mod sequence;
pub mod sticky_sessions;
pub use agent_controller::AgentController;
pub use indexer::{Indexer, ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
pub use prefill_router::PrefillRouter;
pub use push_router::{DirectRoutingRouter, KvPushRouter};
pub use sticky_sessions::StickySessionRouter;
use crate::{
discovery::RuntimeConfigWatch,
kv_router::{
scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
sequence::{SequenceError, SequenceRequest},
},
local_model::runtime_config::ModelRuntimeConfig,
};
use std::collections::HashSet;
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";
pub const PREFILL_SUBJECT: &str = "prefill_events";
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
format!("worker_kv_indexer_query_dp{dp_rank}")
}
fn log_routing_input_hashes(
request_id: Option<&str>,
block_size: u32,
tokens: &[u32],
local_hashes: &[LocalBlockHash],
) {
if !tracing::enabled!(tracing::Level::DEBUG) {
return;
}
let local_hash_ids: Vec<u64> = local_hashes.iter().map(|hash| hash.0).collect();
tracing::debug!(
request_id = request_id.unwrap_or(""),
isl_tokens = tokens.len(),
block_size,
num_blocks = local_hashes.len(),
local_hashes = ?local_hash_ids,
"[ROUTING_INPUT] request local hashes"
);
}
pub const KV_ROUTER_ENDPOINT: &str = "router-discovery";
pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId {
EndpointId {
namespace,
component,
name: KV_ROUTER_ENDPOINT.to_string(),
}
}
pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery {
DiscoveryQuery::Endpoint {
namespace,
component,
endpoint: KV_ROUTER_ENDPOINT.to_string(),
}
}
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
indexer: Indexer,
scheduler: KvScheduler<Sel>,
workers_with_configs: RuntimeConfigWatch,
block_size: u32,
kv_router_config: KvRouterConfig,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
is_eagle: bool,
_served_indexer_handle: Option<ServedIndexerHandle>,
}
impl<Sel> KvRouter<Sel>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
#[allow(clippy::too_many_arguments)]
pub async fn new(
endpoint: Endpoint,
client: Client,
workers_with_configs: RuntimeConfigWatch,
block_size: u32,
selector: Sel,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str,
model_name: Option<String>,
is_eagle: bool,
) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default();
kv_router_config.validate()?;
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
let min_initial_workers = min_initial_workers_from_env()?;
let indexer = Indexer::new(
component,
&kv_router_config,
block_size,
model_name.as_deref(),
)
.await?;
if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
let mut startup_watch = workers_with_configs.clone();
let _ = startup_watch
.wait_for(|m| m.len() >= min_initial_workers)
.await
.map_err(|_| {
anyhow::anyhow!(
"runtime config watch closed before {} workers appeared",
min_initial_workers
)
})?;
}
let scheduler = KvScheduler::start(
component.clone(),
block_size,
workers_with_configs.clone(),
selector,
&kv_router_config,
prefill_load_estimator.clone(),
worker_type,
)
.await?;
if kv_router_config.use_remote_indexer {
tracing::info!("Skipping KV event subscription (using remote indexer)");
} else if kv_router_config.should_subscribe_to_kv_events() {
indexer::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
.await?;
} else {
tracing::info!(
"Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
kv_router_config.use_kv_events,
kv_router_config.overlap_score_weight,
);
}
let served_indexer_handle = if kv_router_config.serve_indexer {
let model_name = model_name.clone().ok_or_else(|| {
anyhow::anyhow!("model_name is required when serve_indexer is configured")
})?;
Some(
ensure_served_indexer_service(
component.clone(),
ServedIndexerMode::from_use_kv_events(kv_router_config.use_kv_events),
model_name,
indexer.clone(),
)
.await?,
)
} else {
None
};
tracing::info!("KV Routing initialized");
Ok(Self {
indexer,
scheduler,
workers_with_configs,
block_size,
kv_router_config,
prefill_load_estimator,
cancellation_token,
client,
is_eagle,
_served_indexer_handle: served_indexer_handle,
})
}
pub fn client(&self) -> &Client {
&self.client
}
pub fn indexer(&self) -> &Indexer {
&self.indexer
}
pub fn kv_router_config(&self) -> &KvRouterConfig {
&self.kv_router_config
}
pub fn is_eagle(&self) -> bool {
self.is_eagle
}
pub async fn record_routing_decision(
&self,
mut tokens_with_hashes: TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
self.indexer
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn find_best_match(
&self,
context_id: Option<&str>,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
expected_output_tokens: Option<u32>,
pinned_worker: Option<WorkerWithDpRank>,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
let start = Instant::now();
if update_states && context_id.is_none() {
anyhow::bail!("context_id must be provided when update_states is true");
}
let isl_tokens = tokens.len();
let hash_options = BlockHashOptions {
block_mm_infos,
lora_name: lora_name.as_deref(),
is_eagle: Some(self.is_eagle),
};
let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
.in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options));
log_routing_input_hashes(context_id, self.block_size, tokens, &block_hashes);
let hash_elapsed = start.elapsed();
let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
self.block_size,
router_config_override,
hash_options,
Some(&block_hashes),
)
});
let seq_hash_elapsed = start.elapsed();
let overlap_scores = self
.indexer
.find_matches(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches"))
.await?;
let find_matches_elapsed = start.elapsed();
let response = self
.scheduler
.schedule(
context_id.map(|s| s.to_string()),
isl_tokens,
maybe_seq_hashes,
overlap_scores,
router_config_override,
update_states,
lora_name,
priority_jump,
expected_output_tokens,
pinned_worker,
allowed_worker_ids,
)
.instrument(tracing::info_span!("kv_router.schedule"))
.await?;
let total_elapsed = start.elapsed();
if let Some(m) = metrics::RoutingOverheadMetrics::get() {
m.observe(
hash_elapsed,
seq_hash_elapsed,
find_matches_elapsed,
total_elapsed,
);
}
#[cfg(feature = "bench")]
tracing::info!(
isl_tokens,
hash_us = hash_elapsed.as_micros() as u64,
seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64,
find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64,
schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"find_best_match completed"
);
Ok((response.best_worker, response.overlap_blocks))
}
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
self.scheduler.register_workers(worker_ids);
}
#[allow(clippy::too_many_arguments)]
pub async fn add_request(
&self,
request_id: String,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
overlap_blocks: u32,
expected_output_tokens: Option<u32>,
worker: WorkerWithDpRank,
lora_name: Option<String>,
router_config_override: Option<&RouterConfigOverride>,
) {
let isl_tokens = tokens.len();
let hash_options = BlockHashOptions {
block_mm_infos,
lora_name: lora_name.as_deref(),
is_eagle: Some(self.is_eagle),
};
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
self.block_size,
router_config_override,
hash_options,
None,
);
let track_prefill_tokens = self
.kv_router_config
.track_prefill_tokens(router_config_override);
let prefill_load_hint =
self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
if let Err(e) = self
.scheduler
.add_request(SequenceRequest {
request_id: request_id.clone(),
token_sequence: maybe_seq_hashes,
isl: isl_tokens,
overlap: overlap_blocks,
track_prefill_tokens,
expected_output_tokens,
prefill_load_hint,
worker,
lora_name,
})
.await
{
tracing::warn!("Failed to add request {request_id}: {e}");
}
}
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.scheduler.mark_prefill_completed(request_id).await
}
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.scheduler.free(request_id).await
}
pub fn pending_count(&self) -> usize {
self.scheduler.pending_count()
}
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
}
let Some(estimator) = &self.prefill_load_estimator else {
return None;
};
match estimator.predict_prefill_duration(1, effective_isl, prefix) {
Ok(expected_prefill_duration) => Some(PrefillLoadHint {
initial_effective_prefill_tokens: effective_isl,
expected_prefill_duration: Some(expected_prefill_duration),
}),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict prefill duration for direct add_request path: {error}"
);
None
}
}
}
pub fn worker_type(&self) -> &'static str {
self.scheduler.worker_type()
}
pub fn unique_dp_rank_for_worker(&self, worker_id: WorkerId) -> Option<u32> {
let configs = self.workers_with_configs.borrow();
let config = configs.get(&worker_id)?;
(config.data_parallel_size == 1).then_some(config.data_parallel_start_rank)
}
pub fn add_output_block(
&self,
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.scheduler.add_output_block(request_id, decay_fraction)
}
pub fn block_size(&self) -> u32 {
self.block_size
}
pub async fn get_overlap_blocks(
&self,
tokens: &[u32],
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
worker: WorkerWithDpRank,
lora_name: Option<&str>,
) -> Result<u32, KvRouterError> {
let block_hashes = compute_block_hash_for_seq(
tokens,
self.block_size,
BlockHashOptions {
block_mm_infos,
lora_name,
is_eagle: Some(self.is_eagle),
},
);
log_routing_input_hashes(None, self.block_size, tokens, &block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
}
pub async fn get_potential_loads(
&self,
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
lora_name: Option<&str>,
) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let hash_options = BlockHashOptions {
block_mm_infos,
lora_name,
is_eagle: Some(self.is_eagle),
};
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options);
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
self.block_size,
router_config_override,
hash_options,
Some(&block_hashes),
);
let track_prefill_tokens = self
.kv_router_config
.track_prefill_tokens(router_config_override);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
Ok(self.scheduler.get_potential_loads(
maybe_seq_hashes,
isl_tokens,
overlap_scores,
track_prefill_tokens,
))
}
pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
}
}
#[async_trait]
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
for KvRouter<Sel>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
async fn generate(
&self,
request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts();
let context_id = ctx.context().id().to_string();
let response = match request {
RouterRequest::New {
tokens,
block_mm_infos,
} => {
let (best_worker, overlap_blocks) = self
.find_best_match(
Some(&context_id),
&tokens,
block_mm_infos.as_deref(),
None,
true,
None,
0.0,
None,
None,
None,
)
.await?;
RouterResponse::New {
worker_id: best_worker.worker_id,
dp_rank: best_worker.dp_rank,
overlap_blocks,
}
}
RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
success: self.mark_prefill_completed(&context_id).await.is_ok(),
},
RouterRequest::MarkFree { request_id } => {
let request_id = match request_id.as_deref() {
Some(request_id) if !request_id.trim().is_empty() => request_id,
_ => &context_id,
};
RouterResponse::FreeMarked {
success: self.free(request_id).await.is_ok(),
}
}
};
let response = Annotated::from_data(response);
let stream = stream::iter(vec![response]);
Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
}
}
impl<Sel> Drop for KvRouter<Sel>
where
Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
fn drop(&mut self) {
tracing::info!("Dropping KvRouter - cancelling background tasks");
self.cancellation_token.cancel();
}
}