use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, watch};

use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult;
use super::sequence::ActiveSequencesMultiWorker;

use crate::tokens::SequenceHash;

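/// Event published on `KV_HIT_RATE_SUBJECT` after each scheduling decision,
/// recording how many of the request's ISL blocks were already cached on the
/// selected worker.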
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    pub worker_id: i64,
    pub isl_blocks: usize,
    pub overlap_blocks: u32,
}

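/// Per-worker estimate of the extra load a request would create if routed there.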
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
    pub worker_id: i64,
    pub potential_prefill_tokens: usize,
    pub potential_decode_blocks: usize,
}

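/// Errors returned when the scheduler cannot produce a routing decision.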
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    #[error("no endpoints available to route work")]
    NoEndpoints,

    #[error("all workers busy")]
    AllWorkersBusy,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
}

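/// The scheduler's decision: the chosen worker and its cached-block overlap.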
#[derive(Debug)]
pub struct SchedulingResponse {
    pub best_worker_id: i64,
    pub overlap_blocks: u32,
}

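/// A request in flight through the scheduler, carrying the token sequence,
/// overlap scores, and the per-worker load estimates consumed by the selector.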
pub struct SchedulingRequest {
    pub maybe_request_id: Option<String>,
    pub token_seq: Option<Vec<SequenceHash>>,
    pub isl_tokens: usize,
    pub overlaps: OverlapScores,
    pub decode_blocks: HashMap<i64, usize>,
    pub prefill_tokens: HashMap<i64, usize>,
    pub router_config_override: Option<RouterConfigOverride>,
    /// When true, the decision is also recorded in the slot tracker.
    pub update_states: bool,
    /// Oneshot sender consumed by `respond` to deliver the decision.
    resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
}

impl SchedulingRequest {
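    /// Delivers the scheduling decision to the caller. The oneshot sender is
    /// consumed on first use; later calls (or a dropped receiver) log an error.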
    pub fn respond(&mut self, response: SchedulingResponse) {
        if let Some(tx) = self.resp_tx.take() {
            if tx.send(response).is_err() {
                tracing::error!("failed to send response to requestor");
            }
        } else {
            tracing::error!("respond called multiple times on same request");
        }
    }
}

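/// Routes requests to workers. A background task scores each candidate worker
/// by estimated prefill/decode cost and answers callers over oneshot channels.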
pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
    slots: Arc<ActiveSequencesMultiWorker>,
}

impl KvScheduler {
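    /// Spawns the worker-monitoring and scheduling background tasks, returning
    /// a handle that can submit scheduling requests.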
    pub async fn start(
        component: Component,
        block_size: u32,
        instances_rx: watch::Receiver<Vec<Instance>>,
        runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
        replica_sync: bool,
        router_uuid: String,
    ) -> Result<Self, KvSchedulerError> {
        let selector = selector.unwrap_or_else(|| Box::new(DefaultWorkerSelector::default()));
        let instances: Vec<Instance> = instances_rx.borrow().clone();
        let runtime_configs: HashMap<i64, ModelRuntimeConfig> = runtime_configs_rx.borrow().clone();

        // Seed the shared worker map from the initial watch snapshots.
        let workers_with_configs: Arc<RwLock<HashMap<i64, Option<ModelRuntimeConfig>>>> = {
            let mut initial_map = HashMap::new();
            for instance in &instances {
                let worker_id = instance.instance_id;
                let config = runtime_configs.get(&worker_id).cloned();
                if config.is_some() {
                    tracing::info!("Runtime config found for worker_id: {}", worker_id);
                }
                initial_map.insert(worker_id, config);
            }
            Arc::new(RwLock::new(initial_map))
        };

        let worker_ids: Vec<i64> = instances
            .iter()
            .map(|instance| instance.instance_id)
            .collect();
        let slots = Arc::new(ActiveSequencesMultiWorker::new(
            component.clone(),
            block_size as usize,
            worker_ids,
            replica_sync,
            router_uuid,
        ));

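        // Monitor task: whenever the instances or runtime-config watch channels
        // change, rebuild `workers_with_configs` and resync the slot tracker.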
        let workers_monitor = workers_with_configs.clone();
        let slots_monitor = slots.clone();
        let mut instances_monitor_rx = instances_rx.clone();
        let mut configs_monitor_rx = runtime_configs_rx.clone();
        let monitor_cancel_token = component.drt().primary_token();
        tokio::spawn(async move {
            tracing::trace!("workers monitoring task started");
            loop {
                // Wake on cancellation or on either watch channel changing.
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
                        tracing::trace!("workers monitoring task shutting down");
                        break;
                    }
                    result = instances_monitor_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("endpoint watch sender shutdown in monitor");
                            break;
                        }
                    }
                    result = configs_monitor_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("runtime configs watch sender shutdown in monitor");
                            break;
                        }
                    }
                }

                // Snapshot the latest instances and configs.
                let new_instances = instances_monitor_rx.borrow_and_update().clone();
                let new_configs = configs_monitor_rx.borrow_and_update().clone();

                let worker_ids: Vec<i64> = new_instances
                    .iter()
                    .map(|instance| instance.instance_id)
                    .collect();
                slots_monitor.update_workers(worker_ids);

                // Rebuild the shared worker map under the write lock.
                let mut workers_map = workers_monitor.write().await;
                workers_map.clear();
                for instance in &new_instances {
                    let worker_id = instance.instance_id;
                    let config = new_configs.get(&worker_id).cloned();
                    if config.is_some() {
                        tracing::info!("Runtime config found for worker_id: {}", worker_id);
                    }
                    workers_map.insert(worker_id, config);
                }
                tracing::trace!(
                    "Updated workers_with_configs with {} workers",
                    workers_map.len()
                );
            }
            tracing::trace!("workers monitoring task shutting down");
        });

        let slots_clone = slots.clone();
        let workers_scheduler = workers_with_configs.clone();
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();
        let ns_clone = component.namespace().clone();

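        // Scheduler task: drain requests one at a time, compute per-worker load
        // estimates, select a worker, and reply over the request's oneshot channel.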
        tokio::spawn(async move {
            let mut request_rx = request_rx;
            tracing::trace!("scheduler background task started");

            loop {
                if scheduler_cancel_token.is_cancelled() {
                    tracing::trace!("scheduler background task shutting down");
                    break;
                }

                let Some(mut request) = request_rx.recv().await else {
                    tracing::warn!("scheduler shutdown");
                    break;
                };
                tracing::trace!("received request to be scheduled");

                // Estimate what this request would cost on each worker.
                let (decode_blocks, prefill_tokens) = slots_clone
                    .potential_blocks_and_tokens(
                        request.token_seq.clone(),
                        request.isl_tokens,
                        request.overlaps.clone(),
                    )
                    .await;
                request.decode_blocks = decode_blocks;
                request.prefill_tokens = prefill_tokens;

                let workers = workers_scheduler.read().await.clone();

                match selector.select_worker(&workers, &request, block_size) {
                    Ok(selection) => {
                        let event = KVHitRateEvent {
                            worker_id: selection.worker_id,
                            isl_blocks: selection.required_blocks as usize,
                            overlap_blocks: selection.overlap_blocks,
                        };
                        if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await {
                            tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
                        }

                        let response = SchedulingResponse {
                            best_worker_id: selection.worker_id,
                            overlap_blocks: selection.overlap_blocks,
                        };
                        request.respond(response);

                        if !request.update_states {
                            continue;
                        }

                        let Some(request_id) = request.maybe_request_id else {
                            tracing::error!(
                                "No request_id provided to add_request to the slot tracker"
                            );
                            continue;
                        };

                        if let Err(e) = slots_clone
                            .add_request(
                                request_id.clone(),
                                request.token_seq,
                                request.isl_tokens,
                                selection.overlap_blocks,
                                selection.worker_id,
                            )
                            .await
                        {
                            tracing::warn!(
                                "Failed to add request {request_id} to local slot tracker: {e:?}"
                            );
                        }
                    }
                    Err(KvSchedulerError::NoEndpoints) => {
                        tracing::trace!("no endpoints available; waiting for endpoints update");
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        continue;
                    }
                    Err(KvSchedulerError::AllWorkersBusy) => {
                        tracing::trace!("all workers busy; waiting for more capacity");
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        continue;
                    }
                    Err(e) => {
                        tracing::error!("error scheduling request: {:?}", e);
                        break;
                    }
                }
            }

            tracing::trace!("scheduler background task shutting down");
        });

        Ok(KvScheduler { request_tx, slots })
    }

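    /// Submits a request to the scheduler task and waits for the selected
    /// worker id. Fails with `SubscriberShutdown` if the scheduler task is gone
    /// or drops the request without responding.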
    pub async fn schedule(
        &self,
        maybe_request_id: Option<String>,
        isl_tokens: usize,
        token_seq: Option<Vec<SequenceHash>>,
        overlaps: OverlapScores,
        router_config_override: Option<&RouterConfigOverride>,
        update_states: bool,
    ) -> Result<i64, KvSchedulerError> {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            maybe_request_id,
            token_seq,
            isl_tokens,
            overlaps,
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
            router_config_override: router_config_override.cloned(),
            update_states,
            resp_tx: Some(resp_tx),
        };

        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        let response = resp_rx
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;

        Ok(response.best_worker_id)
    }

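    /// Records a request against a specific worker in the slot tracker,
    /// bypassing worker selection.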
    pub async fn add_request(
        &self,
        request_id: String,
        token_sequence: Option<Vec<SequenceHash>>,
        isl: usize,
        overlap: u32,
        worker_id: i64,
    ) {
        let _ = self
            .slots
            .add_request(request_id, token_sequence, isl, overlap, worker_id)
            .await;
    }

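    /// Marks the request's prefill phase as complete in the slot tracker.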
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
        self.slots
            .mark_prefill_completed(&request_id.to_string())
            .await
    }

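    /// Releases the request's state from the slot tracker once it finishes.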
    pub async fn free(&self, request_id: &str) -> Result<()> {
        self.slots.free(&request_id.to_string()).await
    }

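    /// Returns, for every worker with an estimate, the prefill tokens and
    /// decode blocks this request would add if routed there.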
    pub async fn get_potential_loads(
        &self,
        token_seq: Option<Vec<SequenceHash>>,
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
        let (decode_blocks, prefill_tokens) = self
            .slots
            .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
            .await;

        // Union of workers that appear in either estimate map.
        let mut worker_ids: HashSet<i64> = HashSet::new();
        worker_ids.extend(decode_blocks.keys().copied());
        worker_ids.extend(prefill_tokens.keys().copied());

        let mut loads = Vec::new();
        for worker_id in worker_ids {
            loads.push(PotentialLoad {
                worker_id,
                potential_prefill_tokens: prefill_tokens
                    .get(&worker_id)
                    .copied()
                    .unwrap_or(isl_tokens),
                potential_decode_blocks: decode_blocks.get(&worker_id).copied().unwrap_or(0),
            });
        }

        loads
    }
}

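/// Samples a worker id from `logits`, treating lower logits as better (they are
/// costs, not scores). At `temperature == 0.0` the choice is greedy: uniform
/// among the minimum-logit workers. Otherwise logits are range-normalized,
/// negated, scaled by temperature, and sampled through a softmax.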
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

    // Greedy path: pick uniformly among the workers with the smallest logit.
    if temperature == 0.0 {
        let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));

        let min_keys: Vec<_> = logits
            .iter()
            .filter(|&(_, &v)| v == min_logit)
            .map(|(k, _)| *k)
            .collect();

        let mut rng = rand::rng();
        let index = rng.random_range(0..min_keys.len());
        return min_keys[index];
    }

    let keys: Vec<_> = logits.keys().copied().collect();
    let values: Vec<_> = logits.values().copied().collect();

    let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    let probabilities = if min_val == max_val {
        // All logits equal: sample uniformly.
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
        // Normalize by the range, then negate so that smaller logits
        // (lower cost) receive higher probability.
        let normalized: Vec<_> = values
            .iter()
            .map(|&v| {
                let norm = v / (max_val - min_val);
                -norm
            })
            .collect();

        let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();

        // Subtract the max before exponentiating for numerical stability.
        let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();

        let sum_exp: f64 = exp_values.iter().sum();
        exp_values.iter().map(|&v| v / sum_exp).collect()
    };

    // Inverse-CDF sampling over the probability vector.
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (i, &prob) in probabilities.iter().enumerate() {
        cumsum += prob;
        if sample <= cumsum {
            return keys[i];
        }
    }

    // Fall back to the last key in case of floating-point round-off.
    keys[keys.len() - 1]
}

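/// Default selector: scores each worker by weighted potential prefill blocks
/// plus potential decode blocks, then samples with `softmax_sample`.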
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    pub kv_router_config: KvRouterConfig,
}

impl DefaultWorkerSelector {
    pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
        Self {
            kv_router_config: kv_router_config.unwrap_or_default(),
        }
    }
}

impl WorkerSelector for DefaultWorkerSelector {
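    /// Computes `logit = overlap_weight * potential_prefill_blocks +
    /// potential_decode_blocks` per worker (lower is better) and samples one.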
    fn select_worker(
        &self,
        workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
        request: &SchedulingRequest,
        block_size: u32,
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

        if workers.is_empty() {
            return Err(KvSchedulerError::NoEndpoints);
        }

        let isl = request.isl_tokens;
        let request_blocks = isl.div_ceil(block_size as usize);
        let overlaps = &request.overlaps.scores;

        let decode_blocks = &request.decode_blocks;
        let prefill_tokens = &request.prefill_tokens;

        let mut worker_logits = HashMap::new();
        let mut max_logit = f64::NEG_INFINITY;

        for worker_id in workers.keys() {
            let overlap = *overlaps.get(worker_id).unwrap_or(&0);

            // Without an estimate, assume the full ISL must be prefilled.
            let prefill_token = *prefill_tokens.get(worker_id).unwrap_or(&isl);
            let potential_prefill_block = (prefill_token as f64) / (block_size as f64);

            let decode_block = *decode_blocks
                .get(worker_id)
                .unwrap_or(&(potential_prefill_block.floor() as usize))
                as f64;

            let overlap_weight = request
                .router_config_override
                .as_ref()
                .and_then(|cfg| cfg.overlap_score_weight)
                .unwrap_or(self.kv_router_config.overlap_score_weight);

            // Lower logit is better; softmax_sample favors small values.
            let logit = overlap_weight * potential_prefill_block + decode_block;
            max_logit = max_logit.max(logit);

            worker_logits.insert(*worker_id, logit);

            tracing::info!(
                "Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \
                = {overlap_weight:.1} * prefill_blocks + decode_blocks \
                = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}"
            );
        }

        let temperature = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.router_temperature)
            .unwrap_or(self.kv_router_config.router_temperature);
        let best_worker_id = softmax_sample(&worker_logits, temperature);
        let best_logit = worker_logits[&best_worker_id];

        let best_overlap = *overlaps.get(&best_worker_id).unwrap_or(&0);
        let total_blocks_info = workers
            .get(&best_worker_id)
            .and_then(|cfg| cfg.as_ref())
            .and_then(|cfg| cfg.total_kv_blocks)
            .map(|blocks| format!(", total blocks: {}", blocks))
            .unwrap_or_default();

        tracing::info!(
            "Selected worker: {}, logit: {:.3}, cached blocks: {}{}",
            best_worker_id,
            best_logit,
            best_overlap,
            total_blocks_info
        );

        Ok(WorkerSelectionResult {
            worker_id: best_worker_id,
            required_blocks: request_blocks as u64,
            overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax_sample_single_key() {
        let mut logits = HashMap::new();
        let worker_id = 42;
        logits.insert(worker_id, 0.5);

        // A lone worker must be chosen at any temperature.
        for temperature in &[0.1, 1.0, 10.0] {
            let result = softmax_sample(&logits, *temperature);
            assert_eq!(result, worker_id, "Should return the only available worker");
        }

        // Extreme and zero logit values should not matter for a single key.
        logits.clear();
        logits.insert(worker_id, -100.0);
        assert_eq!(softmax_sample(&logits, 1.0), worker_id);

        logits.clear();
        logits.insert(worker_id, 100.0);
        assert_eq!(softmax_sample(&logits, 1.0), worker_id);

        logits.clear();
        logits.insert(worker_id, 0.0);
        assert_eq!(softmax_sample(&logits, 1.0), worker_id);
    }

    #[test]
    fn test_softmax_sample_zero_temperature() {
        let mut logits = HashMap::new();
        logits.insert(1, 5.0);
        logits.insert(2, 3.0);
        logits.insert(3, 7.0);
        logits.insert(4, 3.5);

        // Zero temperature is deterministic: always the smallest logit.
        for _ in 0..10 {
            let result = softmax_sample(&logits, 0.0);
            assert_eq!(
                result, 2,
                "Should return worker with smallest logit when temperature is 0"
            );
        }

        logits.clear();
        logits.insert(10, -1.0);
        logits.insert(20, -5.0);
        logits.insert(30, 0.0);

        let result = softmax_sample(&logits, 0.0);
        assert_eq!(result, 20, "Should handle negative logits correctly");
    }
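
    // A small addition (not in the original tests): when every logit is equal,
    // `softmax_sample` takes the uniform-probability branch, so the sampled id
    // must still be one of the input keys.
    #[test]
    fn test_softmax_sample_equal_logits_returns_valid_key() {
        let mut logits = HashMap::new();
        logits.insert(1, 2.0);
        logits.insert(2, 2.0);
        logits.insert(3, 2.0);

        for _ in 0..20 {
            let result = softmax_sample(&logits, 1.0);
            assert!(
                logits.contains_key(&result),
                "Sampled worker must be one of the input keys"
            );
        }
    }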
}