use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::queue::SchedulerQueue;
use super::sequence::{
ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use dynamo_tokens::SequenceHash;
/// Estimated load a request would add to one (worker, dp_rank) pair, as
/// reported by `KvScheduler::get_potential_loads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
    // Worker the estimate refers to.
    pub worker_id: WorkerId,
    // Data-parallel rank within that worker.
    pub dp_rank: DpRank,
    // Tokens the worker would still need to prefill; defaults to the full ISL
    // when no per-worker estimate exists.
    pub potential_prefill_tokens: usize,
    // KV blocks the request would occupy on this worker; 0 when no estimate exists.
    pub potential_decode_blocks: usize,
}
/// Errors surfaced by the KV scheduler and the default worker selector.
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    // Returned by `select_worker` when no worker survives the allow-list
    // filter (or no workers are registered at all).
    #[error("no endpoints available to route work")]
    NoEndpoints,
    // Returned by `schedule` when the background scheduling task is gone and
    // the internal channels can no longer deliver the request or its response.
    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
    // Returned by `start` when creating the multi-worker sequence state fails.
    #[error("failed to initialize event publisher: {0}")]
    InitFailed(String),
}
/// Outcome of scheduling a single request.
#[derive(Debug)]
pub struct SchedulingResponse {
    // The (worker, dp_rank) chosen to serve the request.
    pub best_worker: WorkerWithDpRank,
    // Overlap score (cached blocks) recorded for the chosen worker.
    pub overlap_blocks: u32,
}
/// A request awaiting worker selection, carrying the routing inputs plus the
/// oneshot channel used to answer the caller.
pub struct SchedulingRequest {
    // Optional external request id.
    pub maybe_request_id: Option<String>,
    // Block-level sequence hashes of the prompt, when available.
    pub token_seq: Option<Vec<SequenceHash>>,
    // Input sequence length in tokens; must be > 0 (asserted by the selector).
    pub isl_tokens: usize,
    // Per-worker prefix-cache overlap scores computed upstream.
    pub overlaps: OverlapScores,
    // Per-worker decode-block estimates. Starts empty in `schedule`;
    // presumably populated during queue processing — see `SchedulerQueue`.
    pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
    // Per-worker remaining-prefill-token estimates. Starts empty in `schedule`;
    // presumably populated during queue processing — see `SchedulerQueue`.
    pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
    // Per-request overrides for router tuning knobs (weight, temperature).
    pub router_config_override: Option<RouterConfigOverride>,
    // Whether scheduling should also update shared sequence state.
    pub update_states: bool,
    // LoRA adapter requested, if any.
    pub lora_name: Option<String>,
    // Priority boost for this request while queued.
    pub priority_jump: f64,
    // When set, restrict selection to these worker ids.
    pub allowed_worker_ids: Option<HashSet<WorkerId>>,
    // Consumed by `respond`; `None` once a response has been sent.
    resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>,
}
impl SchedulingRequest {
    /// Delivers the scheduling outcome to the waiting caller.
    ///
    /// The oneshot sender is consumed on first use, so calling this twice on
    /// the same request is a programming error; it is logged, not a panic.
    pub fn respond(&mut self, result: Result<SchedulingResponse, KvSchedulerError>) {
        match self.resp_tx.take() {
            Some(tx) => {
                // A send failure means the caller dropped its receiver.
                if tx.send(result).is_err() {
                    tracing::error!("failed to send response to requestor");
                }
            }
            None => {
                tracing::error!("respond called multiple times on same request");
            }
        }
    }
}
/// Routes requests to workers based on KV-cache overlap and projected load.
///
/// Holds the channel into the background scheduling task plus shared handles
/// to the active-sequence state and the scheduler queue; constructed via `start`.
pub struct KvScheduler {
    // Feeds `SchedulingRequest`s to the background task spawned in `start`.
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
    // Shared per-worker active-sequence bookkeeping.
    slots: Arc<ActiveSequencesMulti>,
    // Queue that applies the worker selector to pending requests.
    queue: Arc<SchedulerQueue>,
}
impl KvScheduler {
    /// Builds the scheduler: creates the shared multi-worker sequence state,
    /// spawns a task mirroring worker-set changes into that state, and spawns
    /// the background scheduling loop fed by `schedule`.
    ///
    /// # Errors
    /// Returns `InitFailed` when the sequence state cannot be created.
    pub async fn start(
        component: Component,
        block_size: u32,
        workers_with_configs: RuntimeConfigWatch,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
        kv_router_config: &KvRouterConfig,
        worker_type: &'static str,
    ) -> Result<Self, KvSchedulerError> {
        // Fall back to the default scoring policy when no selector is supplied.
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
        let router_id = component.drt().discovery().instance_id();
        let slots = create_multi_worker_sequences(
            component.clone(),
            block_size as usize,
            initial_workers,
            kv_router_config.router_replica_sync,
            router_id,
            worker_type,
        )
        .await
        .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
        // Monitoring task: propagate worker/DP-layout changes from the config
        // watch into the shared slot state.
        let slots_monitor = slots.clone();
        let mut monitor_rx = workers_with_configs.clone();
        let monitor_cancel_token = component.drt().child_token();
        tokio::spawn(async move {
            tracing::trace!("KvScheduler workers monitoring task started");
            let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
            loop {
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
                        tracing::trace!("KvScheduler workers monitoring task shutting down");
                        break;
                    }
                    result = monitor_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                            break;
                        }
                    }
                }
                // Only push an update when the observed worker set actually changed.
                let current_workers = monitor_rx.borrow_and_update().clone();
                if current_workers != last_workers {
                    // worker_id -> (data-parallel start rank, data-parallel size)
                    let dp_range: HashMap<u64, (u32, u32)> = current_workers
                        .iter()
                        .map(|(&id, c)| (id, (c.data_parallel_start_rank, c.data_parallel_size)))
                        .collect();
                    slots_monitor.update_workers(&dp_range);
                    last_workers = current_workers;
                }
            }
        });
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        // NOTE(review): this loop uses the primary token while the monitor task
        // above uses a child token — confirm the shutdown-scope difference is intended.
        let scheduler_cancel_token = component.drt().primary_token();
        let queue = Arc::new(SchedulerQueue::new(
            slots.clone(),
            workers_with_configs.clone(),
            kv_router_config.router_queue_threshold,
            block_size,
            selector,
        ));
        let queue_clone = queue.clone();
        // Scheduling loop: enqueue incoming requests and periodically
        // re-evaluate the queue even when no new events arrive.
        tokio::spawn(async move {
            let mut request_rx = request_rx;
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
            tracing::trace!("scheduler background task started");
            loop {
                tokio::select! {
                    _ = scheduler_cancel_token.cancelled() => {
                        tracing::trace!("scheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        // `None` means all senders (the KvScheduler handle) dropped.
                        let Some(request) = request else {
                            tracing::warn!("scheduler shutdown");
                            break;
                        };
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
                    }
                    _ = recheck_interval.tick() => {
                        queue_clone.update().await;
                    }
                }
            }
            tracing::trace!("background endpoint subscriber shutting down");
        });
        Ok(KvScheduler {
            request_tx,
            slots,
            queue,
        })
    }
    /// Submits a request to the background scheduler and awaits the selected worker.
    ///
    /// # Errors
    /// Returns `SubscriberShutdown` when the background task is gone (either the
    /// send into the queue or the await on the oneshot response fails).
    #[allow(clippy::too_many_arguments)]
    pub async fn schedule(
        &self,
        maybe_request_id: Option<String>,
        isl_tokens: usize,
        token_seq: Option<Vec<SequenceHash>>,
        overlaps: OverlapScores,
        router_config_override: Option<&RouterConfigOverride>,
        update_states: bool,
        lora_name: Option<String>,
        priority_jump: f64,
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> Result<SchedulingResponse, KvSchedulerError> {
        #[cfg(feature = "bench")]
        let start = Instant::now();
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            maybe_request_id,
            token_seq,
            isl_tokens,
            overlaps,
            // Load maps start empty here; presumably filled during queue
            // processing — see `SchedulerQueue`.
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
            router_config_override: router_config_override.cloned(),
            update_states,
            lora_name,
            priority_jump,
            allowed_worker_ids,
            resp_tx: Some(resp_tx),
        };
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();
        // Double `?`: outer for a dropped sender, inner for a scheduler error.
        let response = resp_rx
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)??;
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );
        Ok(response)
    }
    /// Registers a new active sequence with the shared slot state.
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        self.slots.add_request(req).await
    }
    /// Marks the request's prefill phase complete, then re-evaluates the queue
    /// since capacity may have been freed.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        self.slots
            .mark_prefill_completed(&request_id.to_string())
            .await?;
        self.queue.update().await;
        Ok(())
    }
    /// Releases all state held for the request, then re-evaluates the queue.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.slots.free(&request_id.to_string()).await?;
        self.queue.update().await;
        Ok(())
    }
    /// The worker-type label this scheduler was started with.
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }
    /// Records one generated output block for the request; `decay_fraction`
    /// is forwarded to the sequence tracker (semantics defined there).
    pub fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
    }
    /// Estimates, per (worker, dp_rank), the prefill tokens and decode blocks a
    /// hypothetical request would add — without actually scheduling anything.
    pub fn get_potential_loads(
        &self,
        token_seq: Option<Vec<SequenceHash>>,
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
        let (decode_blocks, prefill_tokens) = self
            .slots
            .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps);
        // Union of workers appearing in either estimate map.
        let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
        let mut loads = Vec::new();
        for worker in workers {
            loads.push(PotentialLoad {
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
                // No prefill estimate means no cached prefix: assume the full ISL.
                potential_prefill_tokens: prefill_tokens
                    .get(&worker)
                    .copied()
                    .unwrap_or(isl_tokens),
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
            });
        }
        loads
    }
    /// Number of active requests per LoRA adapter name.
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
}
fn softmax_sample(
logits: &HashMap<WorkerWithDpRank, f64>,
temperature: f64,
) -> Vec<WorkerWithDpRank> {
if logits.is_empty() {
panic!("Empty logits for softmax sampling");
}
if temperature == 0.0 {
let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));
let min_keys: Vec<_> = logits
.iter()
.filter(|&(_, &v)| v == min_logit)
.map(|(k, _)| *k)
.collect();
return min_keys;
}
let keys: Vec<_> = logits.keys().copied().collect();
let values: Vec<_> = logits.values().copied().collect();
let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let probabilities = if min_val == max_val {
vec![1.0 / keys.len() as f64; keys.len()]
} else {
let range = max_val - min_val;
let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect();
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
let sum: f64 = probs.iter().sum();
probs.iter_mut().for_each(|p| *p /= sum);
probs
};
let mut rng = rand::rng();
let sample: f64 = rng.random();
let mut cumsum = 0.0;
for (i, &prob) in probabilities.iter().enumerate() {
cumsum += prob;
if sample <= cumsum {
return vec![keys[i]];
}
}
vec![keys[keys.len() - 1]]
}
/// Worker selector implementing the default cost formula
/// `logit = overlap_score_weight * prefill_blocks + decode_blocks`
/// (lower is better), sampled with a configurable temperature.
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    // Tuning knobs (overlap weight, temperature, …) used when a request
    // carries no per-request override.
    pub kv_router_config: KvRouterConfig,
}
impl DefaultWorkerSelector {
pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
Self {
kv_router_config: kv_router_config.unwrap_or_default(),
}
}
}
impl WorkerSelector for DefaultWorkerSelector {
    /// Scores every eligible (worker, dp_rank) pair and picks one via
    /// temperature-controlled softmax sampling over the scores.
    ///
    /// Lower logits win: the score combines remaining prefill work (weighted
    /// by `overlap_score_weight`) with the worker's decode-block load. If
    /// sampling returns several tied candidates (temperature 0), ties are
    /// broken by the smaller prefix-tree size, then uniformly at random.
    ///
    /// # Errors
    /// `NoEndpoints` when no worker survives the `allowed_worker_ids` filter.
    ///
    /// # Panics
    /// Panics when `request.isl_tokens == 0`.
    fn select_worker(
        &self,
        workers: &HashMap<WorkerId, ModelRuntimeConfig>,
        request: &SchedulingRequest,
        block_size: u32,
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);
        let allowed_ids = request.allowed_worker_ids.as_ref();
        // Bail out early if the allow-list (or an empty worker map) leaves no candidates.
        if allowed_ids.map_or(workers.is_empty(), |ids| {
            !workers.keys().any(|wid| ids.contains(wid))
        }) {
            return Err(KvSchedulerError::NoEndpoints);
        }
        let isl = request.isl_tokens;
        let request_blocks = isl.div_ceil(block_size as usize);
        let overlaps = &request.overlaps.scores;
        let decode_blocks = &request.decode_blocks;
        let prefill_tokens = &request.prefill_tokens;
        let mut worker_logits = HashMap::new();
        // Per-request override takes precedence over the configured weight.
        let overlap_weight = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.overlap_score_weight)
            .unwrap_or(self.kv_router_config.overlap_score_weight);
        for (worker_id, config) in workers
            .iter()
            .filter(|(wid, _)| allowed_ids.is_none_or(|ids| ids.contains(wid)))
        {
            let data_parallel_size = config.data_parallel_size;
            let data_parallel_start_rank = config.data_parallel_start_rank;
            // Score each data-parallel rank of the worker independently.
            for dp_rank in data_parallel_start_rank..data_parallel_start_rank + data_parallel_size {
                let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
                let overlap = *overlaps.get(&worker).unwrap_or(&0);
                // No prefill estimate => assume the whole ISL still needs prefill.
                let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
                let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
                // No decode estimate => approximate with the prefill block count.
                let decode_block = *decode_blocks
                    .get(&worker)
                    .unwrap_or(&(potential_prefill_block.floor() as usize))
                    as f64;
                // Lower logit = less projected work = preferred by softmax_sample.
                let logit = overlap_weight * potential_prefill_block + decode_block;
                worker_logits.insert(worker, logit);
                tracing::debug!(
                    "Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
                    = {overlap_weight:.1} * prefill_blocks + decode_blocks \
                    = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
                    worker.worker_id,
                    worker.dp_rank
                );
            }
        }
        let temperature = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.router_temperature)
            .unwrap_or(self.kv_router_config.router_temperature);
        // Temperature 0 may return multiple tied candidates; > 0 returns one.
        let candidates = softmax_sample(&worker_logits, temperature);
        let best_worker = if candidates.len() > 1 {
            tracing::debug!(
                "Multiple workers tied with same logit, using tree size as tie-breaker"
            );
            // Prefer the candidate with the smallest prefix-tree size; if all
            // tree sizes are also equal, pick uniformly at random.
            let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = candidates
                .iter()
                .map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
                .collect();
            if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
                let idx = rand::rng().random_range(0..candidates.len());
                candidates[idx]
            } else {
                *tree_sizes.iter().min_by_key(|(s, _)| *s).unwrap().1
            }
        } else {
            candidates[0]
        };
        let best_logit = worker_logits[&best_worker];
        let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
        // Optional ", total blocks: N" suffix for the log line below.
        let total_blocks_info = workers
            .get(&best_worker.worker_id)
            .and_then(|cfg| cfg.total_kv_blocks)
            .map(|blocks| format!(", total blocks: {}", blocks))
            .unwrap_or_default();
        let tree_size = request
            .overlaps
            .tree_sizes
            .get(&best_worker)
            .copied()
            .unwrap_or(0);
        tracing::info!(
            "Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
            best_worker.worker_id,
            best_worker.dp_rank,
            best_logit,
            best_overlap,
            tree_size,
            total_blocks_info
        );
        Ok(WorkerSelectionResult {
            worker: best_worker,
            required_blocks: request_blocks as u64,
            overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax_sample_single_key() {
        // With a single candidate, sampling must return it at any temperature.
        let worker = WorkerWithDpRank::from_worker_id(42);
        for temperature in [0.1, 1.0, 10.0] {
            let mut logits = HashMap::new();
            logits.insert(worker, 0.5);
            let result = softmax_sample(&logits, temperature);
            assert_eq!(result.len(), 1, "Should return exactly one worker");
            assert_eq!(result[0], worker, "Should return the only available worker");
        }
        // The logit's sign and magnitude must not matter for a lone candidate.
        for logit in [-100.0, 100.0, 0.0] {
            let mut logits = HashMap::new();
            logits.insert(worker, logit);
            let result = softmax_sample(&logits, 1.0);
            assert_eq!(result.len(), 1);
            assert_eq!(result[0], worker);
        }
    }

    #[test]
    fn test_softmax_sample_zero_temperature() {
        let worker1 = WorkerWithDpRank::from_worker_id(1);
        let worker2 = WorkerWithDpRank::from_worker_id(2);
        let worker3 = WorkerWithDpRank::from_worker_id(3);
        let worker4 = WorkerWithDpRank::from_worker_id(4);

        // Unique minimum: exactly that worker is returned.
        let logits: HashMap<_, _> = [
            (worker1, 5.0),
            (worker2, 3.0),
            (worker3, 7.0),
            (worker4, 3.5),
        ]
        .into_iter()
        .collect();
        let result = softmax_sample(&logits, 0.0);
        assert_eq!(
            result.len(),
            1,
            "Should return one worker when there's no tie"
        );
        assert_eq!(
            result[0], worker2,
            "Should return worker with smallest logit when temperature is 0"
        );

        // Tied minimum: every tied worker is returned.
        let worker5 = WorkerWithDpRank::from_worker_id(5);
        let worker6 = WorkerWithDpRank::from_worker_id(6);
        let logits: HashMap<_, _> = [
            (worker1, 5.0),
            (worker2, 3.0),
            (worker5, 3.0),
            (worker6, 7.0),
        ]
        .into_iter()
        .collect();
        let result = softmax_sample(&logits, 0.0);
        assert_eq!(
            result.len(),
            2,
            "Should return all workers with smallest logit when tied"
        );
        assert!(
            result.contains(&worker2) && result.contains(&worker5),
            "Should contain both tied workers"
        );

        // Negative logits: the most negative value is the minimum.
        let worker10 = WorkerWithDpRank::from_worker_id(10);
        let worker20 = WorkerWithDpRank::from_worker_id(20);
        let worker30 = WorkerWithDpRank::from_worker_id(30);
        let logits: HashMap<_, _> = [(worker10, -1.0), (worker20, -5.0), (worker30, 0.0)]
            .into_iter()
            .collect();
        let result = softmax_sample(&logits, 0.0);
        assert_eq!(result.len(), 1);
        assert_eq!(
            result[0], worker20,
            "Should handle negative logits correctly"
        );
    }
}