dynamo_llm/kv_router/scheduler.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, watch};

use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult;
use super::sequence::ActiveSequencesMultiWorker;

use crate::tokens::SequenceHash;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    pub worker_id: i64,
    pub isl_blocks: usize,
    pub overlap_blocks: u32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
    pub worker_id: i64,
    pub potential_prefill_tokens: usize,
    pub potential_decode_blocks: usize,
}
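
// Both structs above derive serde's Serialize/Deserialize. As an illustrative
// sketch (assuming a JSON encoding such as serde_json; the values are made up),
// a KVHitRateEvent would serialize to:
//   {"worker_id":42,"isl_blocks":24,"overlap_blocks":8}
// and a PotentialLoad to:
//   {"worker_id":42,"potential_prefill_tokens":256,"potential_decode_blocks":40}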

#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    #[error("no endpoints available to route work")]
    NoEndpoints,

    #[error("all workers busy")]
    AllWorkersBusy,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
}

#[derive(Debug)]
pub struct SchedulingResponse {
    pub best_worker_id: i64,
    pub overlap_blocks: u32,
}

pub struct SchedulingRequest {
    pub maybe_request_id: Option<String>,
    pub token_seq: Option<Vec<SequenceHash>>,
    pub isl_tokens: usize,
    pub overlaps: OverlapScores,
    pub decode_blocks: HashMap<i64, usize>,
    pub prefill_tokens: HashMap<i64, usize>,
    // Router config overrides for this specific request
    pub router_config_override: Option<RouterConfigOverride>,
    // Whether to update scheduler states (false for query_instance_id requests)
    pub update_states: bool,
    // Wrapped in Option so the sender can be taken and consumed without moving the struct
    resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
}

impl SchedulingRequest {
    pub fn respond(&mut self, response: SchedulingResponse) {
        // take() the sender so the response can be sent at most once
        if let Some(tx) = self.resp_tx.take() {
            if tx.send(response).is_err() {
                tracing::error!("failed to send response to requestor");
            }
        } else {
            tracing::error!("respond called multiple times on same request");
        }
    }
}

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
    slots: Arc<ActiveSequencesMultiWorker>,
}

impl KvScheduler {
    pub async fn start(
        component: Component,
        block_size: u32,
        instances_rx: watch::Receiver<Vec<Instance>>,
        runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
        replica_sync: bool,
        router_uuid: String,
    ) -> Result<Self, KvSchedulerError> {
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
        let instances: Vec<Instance> = instances_rx.borrow().clone();
        let runtime_configs: HashMap<i64, ModelRuntimeConfig> = runtime_configs_rx.borrow().clone();

        // Create shared workers_with_configs wrapped in Arc<RwLock>
        let workers_with_configs: Arc<RwLock<HashMap<i64, Option<ModelRuntimeConfig>>>> = {
            let mut initial_map = HashMap::new();
            for instance in &instances {
                let worker_id = instance.instance_id;
                let config = runtime_configs.get(&worker_id).cloned();
                if config.is_some() {
                    tracing::info!("Runtime config found for worker_id: {}", worker_id);
                }
                initial_map.insert(worker_id, config);
            }
            Arc::new(RwLock::new(initial_map))
        };

        let worker_ids: Vec<i64> = instances
            .iter()
            .map(|instance| instance.instance_id)
            .collect();
        let slots = Arc::new(ActiveSequencesMultiWorker::new(
            component.clone(),
            block_size as usize,
            worker_ids,
            replica_sync,
            router_uuid,
        ));

        // Spawn background task to monitor and update workers_with_configs
        let workers_monitor = workers_with_configs.clone();
        let slots_monitor = slots.clone();
        let mut instances_monitor_rx = instances_rx.clone();
        let mut configs_monitor_rx = runtime_configs_rx.clone();
        let monitor_cancel_token = component.drt().primary_token();
        tokio::spawn(async move {
            tracing::trace!("workers monitoring task started");
            loop {
                // Wait for either instances or configs to change
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
                        tracing::trace!("workers monitoring task shutting down");
                        break;
                    }
                    result = instances_monitor_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("endpoint watch sender shutdown in monitor");
                            break;
                        }
                    }
                    result = configs_monitor_rx.changed() => {
                        if result.is_err() {
                            tracing::warn!("runtime configs watch sender shutdown in monitor");
                            break;
                        }
                    }
                }

                // Get the latest values from both channels
                let new_instances = instances_monitor_rx.borrow_and_update().clone();
                let new_configs = configs_monitor_rx.borrow_and_update().clone();

                // Update workers when instances change
                let worker_ids: Vec<i64> = new_instances
                    .iter()
                    .map(|instance| instance.instance_id)
                    .collect();
                slots_monitor.update_workers(worker_ids);

                // Update the shared workers_with_configs
                let mut workers_map = workers_monitor.write().await;
                workers_map.clear();
                for instance in &new_instances {
                    let worker_id = instance.instance_id;
                    let config = new_configs.get(&worker_id).cloned();
                    if config.is_some() {
                        tracing::info!("Runtime config found for worker_id: {}", worker_id);
                    }
                    workers_map.insert(worker_id, config);
                }
                tracing::trace!(
                    "Updated workers_with_configs with {} workers",
                    workers_map.len()
                );
            }
            tracing::trace!("workers monitoring task shutting down");
        });

        let slots_clone = slots.clone();
        let workers_scheduler = workers_with_configs.clone();
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();
        let ns_clone = component.namespace().clone();

        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request_rx = request_rx;
            tracing::trace!("scheduler background task started");

            loop {
                // Check for cancellation at beginning of loop
                if scheduler_cancel_token.is_cancelled() {
                    tracing::trace!("scheduler background task shutting down");
                    break;
                }

                // Wait for a new request
                let Some(mut request) = request_rx.recv().await else {
                    tracing::warn!("scheduler shutdown");
                    break;
                };
                tracing::trace!("received request to be scheduled");

                let (decode_blocks, prefill_tokens) = slots_clone
                    .potential_blocks_and_tokens(
                        request.token_seq.clone(),
                        request.isl_tokens,
                        request.overlaps.clone(),
                    )
                    .await;
                request.decode_blocks = decode_blocks;
                request.prefill_tokens = prefill_tokens;

                // Read the current workers configuration
                let workers = workers_scheduler.read().await.clone();

                match selector.select_worker(&workers, &request, block_size) {
                    Ok(selection) => {
                        let event = KVHitRateEvent {
                            worker_id: selection.worker_id,
                            isl_blocks: selection.required_blocks as usize,
                            overlap_blocks: selection.overlap_blocks,
                        };
                        if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await {
                            tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
                        }

                        let response = SchedulingResponse {
                            best_worker_id: selection.worker_id,
                            overlap_blocks: selection.overlap_blocks,
                        };
                        request.respond(response);

                        // Skip state update if not requested
                        if !request.update_states {
                            continue;
                        }

                        let Some(request_id) = request.maybe_request_id else {
                            tracing::error!(
                                "No request_id provided to add_request to the slot tracker"
                            );
                            continue;
                        };

                        if let Err(e) = slots_clone
                            .add_request(
                                request_id.clone(),
                                request.token_seq,
                                request.isl_tokens,
                                selection.overlap_blocks,
                                selection.worker_id,
                            )
                            .await
                        {
                            tracing::warn!(
                                "Failed to add request {request_id} to local slot tracker: {e:?}"
                            );
                        }
                    }
                    Err(KvSchedulerError::NoEndpoints) => {
                        tracing::trace!("no endpoints available; waiting for endpoints update");
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        continue;
                    }
                    // TODO: this is not actually hooked up
                    Err(KvSchedulerError::AllWorkersBusy) => {
                        tracing::trace!("all workers busy; waiting for more capacity");
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        continue;
                    }
                    Err(e) => {
                        tracing::error!("error scheduling request: {:?}", e);
                        break;
                    }
                }
            }

            tracing::trace!("scheduler background task shutting down");
        });

        Ok(KvScheduler { request_tx, slots })
    }
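
    /// Pick a worker for a request and return its worker id.
    ///
    /// A minimal usage sketch (not from the original source; assumes an
    /// already-started scheduler plus overlap scores computed by the KV
    /// indexer, and elides error handling):
    ///
    /// ```ignore
    /// let worker_id = scheduler
    ///     .schedule(
    ///         Some("request-123".to_string()), // request id for the slot tracker
    ///         1024,                            // ISL tokens
    ///         token_seq,                       // Option<Vec<SequenceHash>>
    ///         overlaps,                        // OverlapScores from the indexer
    ///         None,                            // no per-request config override
    ///         true,                            // update scheduler state
    ///     )
    ///     .await?;
    /// ```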
    pub async fn schedule(
        &self,
        maybe_request_id: Option<String>,
        isl_tokens: usize,
        token_seq: Option<Vec<SequenceHash>>,
        overlaps: OverlapScores,
        router_config_override: Option<&RouterConfigOverride>,
        update_states: bool,
    ) -> Result<i64, KvSchedulerError> {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            maybe_request_id,
            token_seq,
            isl_tokens,
            overlaps,
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
            router_config_override: router_config_override.cloned(),
            update_states,
            resp_tx: Some(resp_tx),
        };

        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        let response = resp_rx
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;

        Ok(response.best_worker_id)
    }

    pub async fn add_request(
        &self,
        request_id: String,
        token_sequence: Option<Vec<SequenceHash>>,
        isl: usize,
        overlap: u32,
        worker_id: i64,
    ) {
        let _ = self
            .slots
            .add_request(request_id, token_sequence, isl, overlap, worker_id)
            .await;
    }

    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
        self.slots
            .mark_prefill_completed(&request_id.to_string())
            .await
    }

    pub async fn free(&self, request_id: &str) -> Result<()> {
        self.slots.free(&request_id.to_string()).await
    }
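
    /// Report the prefill/decode load each known worker would take on if this
    /// request were routed to it.
    ///
    /// A minimal sketch of consuming the result (illustrative, not from the
    /// original source):
    ///
    /// ```ignore
    /// for load in scheduler.get_potential_loads(token_seq, 1024, overlaps).await {
    ///     println!(
    ///         "worker {}: {} prefill tokens, {} decode blocks",
    ///         load.worker_id, load.potential_prefill_tokens, load.potential_decode_blocks
    ///     );
    /// }
    /// ```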
    pub async fn get_potential_loads(
        &self,
        token_seq: Option<Vec<SequenceHash>>,
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
        let (decode_blocks, prefill_tokens) = self
            .slots
            .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
            .await;

        // Get all unique worker IDs from both hashmaps
        let mut worker_ids: HashSet<i64> = HashSet::new();
        worker_ids.extend(decode_blocks.keys().copied());
        worker_ids.extend(prefill_tokens.keys().copied());

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
        for worker_id in worker_ids {
            loads.push(PotentialLoad {
                worker_id,
                potential_prefill_tokens: prefill_tokens
                    .get(&worker_id)
                    .copied()
                    .unwrap_or(isl_tokens),
                potential_decode_blocks: decode_blocks.get(&worker_id).copied().unwrap_or(0),
            });
        }

        loads
    }
}

// Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

    // Guard: at temperature 0, sampling degenerates to greedy selection of the
    // key with the smallest logit value (lower is better)
    if temperature == 0.0 {
        // Find the minimum logit value
        let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));

        // Collect all keys with the minimum logit value (to handle ties)
        let min_keys: Vec<_> = logits
            .iter()
            .filter(|&(_, &v)| v == min_logit)
            .map(|(k, _)| *k)
            .collect();

        // Randomly select from the minimum keys (handles the single-key case naturally)
        let mut rng = rand::rng();
        let index = rng.random_range(0..min_keys.len());
        return min_keys[index];
    }

    let keys: Vec<_> = logits.keys().copied().collect();
    let values: Vec<_> = logits.values().copied().collect();

    // Find min and max for normalization
    let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    let probabilities = if min_val == max_val {
        // All values are the same, so sample uniformly
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
        // Scale values into units of the observed range, then negate so that
        // lower logits map to higher scores. Full min-max normalization is
        // unnecessary here: subtracting the minimum only shifts every value by
        // a constant, which the softmax below cancels out.
        let normalized: Vec<_> = values.iter().map(|&v| -(v / (max_val - min_val))).collect();

        // Apply temperature and softmax (subtracting the max for numerical stability)
        let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();

        let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();

        let sum_exp: f64 = exp_values.iter().sum();
        exp_values.iter().map(|&v| v / sum_exp).collect()
    };

    // Sample from the probability distribution
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (i, &prob) in probabilities.iter().enumerate() {
        cumsum += prob;
        if sample <= cumsum {
            return keys[i];
        }
    }

    // Fallback to last key (floating-point rounding can leave cumsum slightly below 1.0)
    keys[keys.len() - 1]
}
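
// Worked example (illustrative numbers, not from the original source): with
// logits {1: 4.0, 2: 8.0} and temperature 1.0, the range is 8.0 - 4.0 = 4.0,
// so the negated scaled scores are -4.0/4.0 = -1.0 and -8.0/4.0 = -2.0. After
// the stability shift these become 0.0 and -1.0, giving probabilities
// e^0 / (e^0 + e^-1) ≈ 0.731 for worker 1 and ≈ 0.269 for worker 2: the
// lower-logit worker is picked about 73% of the time. Raising the temperature
// flattens this distribution; temperature 0 always picks worker 1.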

// Default implementation matching the Python _cost_function
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    pub kv_router_config: KvRouterConfig,
}

impl DefaultWorkerSelector {
    pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
        Self {
            kv_router_config: kv_router_config.unwrap_or_default(),
        }
    }
}

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
        workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
        request: &SchedulingRequest,
        block_size: u32,
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

        if workers.is_empty() {
            return Err(KvSchedulerError::NoEndpoints);
        }

        let isl = request.isl_tokens;
        let request_blocks = isl.div_ceil(block_size as usize);
        let overlaps = &request.overlaps.scores;

        let decode_blocks = &request.decode_blocks;
        let prefill_tokens = &request.prefill_tokens;

        let mut worker_logits = HashMap::new();
        let mut max_logit = f64::NEG_INFINITY;

        // Calculate logits for each worker
        for worker_id in workers.keys() {
            let overlap = *overlaps.get(worker_id).unwrap_or(&0);

            // The number of prefill tokens the worker would have if the request were scheduled there
            let prefill_token = *prefill_tokens.get(worker_id).unwrap_or(&isl);
            let potential_prefill_block = (prefill_token as f64) / (block_size as f64);

            // The number of decode blocks the worker would have if the request were scheduled there
            let decode_block = *decode_blocks
                .get(worker_id)
                .unwrap_or(&(potential_prefill_block.floor() as usize))
                as f64;

            // Use the per-request override if provided, otherwise the default config
            let overlap_weight = request
                .router_config_override
                .as_ref()
                .and_then(|cfg| cfg.overlap_score_weight)
                .unwrap_or(self.kv_router_config.overlap_score_weight);

            // Calculate logit (lower is better)
            let logit = overlap_weight * potential_prefill_block + decode_block;
            max_logit = max_logit.max(logit);

            worker_logits.insert(*worker_id, logit);

            tracing::info!(
                "Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \
                 = {overlap_weight:.1} * prefill_blocks + decode_blocks \
                 = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}"
            );
        }
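
        // Worked example (illustrative numbers, not from the original source):
        // with block_size = 16, an ISL of 384 tokens is 24 blocks. A worker
        // holding 8 of those blocks in cache would prefill roughly
        // 384 - 8 * 16 = 256 tokens = 16 blocks; if it would also carry 40
        // decode blocks, then with overlap_score_weight = 1.0 its logit is
        // 1.0 * 16 + 40 = 56. The softmax sampling below favors workers with
        // lower logits.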

        // Use softmax sampling to select a worker, taking the per-request
        // temperature override if provided, otherwise the default config
        let temperature = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.router_temperature)
            .unwrap_or(self.kv_router_config.router_temperature);
        let best_worker_id = softmax_sample(&worker_logits, temperature);
        let best_logit = worker_logits[&best_worker_id];

        let best_overlap = *overlaps.get(&best_worker_id).unwrap_or(&0);
        let total_blocks_info = workers
            .get(&best_worker_id)
            .and_then(|cfg| cfg.as_ref())
            .and_then(|cfg| cfg.total_kv_blocks)
            .map(|blocks| format!(", total blocks: {}", blocks))
            .unwrap_or_default();

        tracing::info!(
            "Selected worker: {}, logit: {:.3}, cached blocks: {}{}",
            best_worker_id,
            best_logit,
            best_overlap,
            total_blocks_info
        );

        Ok(WorkerSelectionResult {
            worker_id: best_worker_id,
            required_blocks: request_blocks as u64,
            overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax_sample_single_key() {
        // Test that with a single key, softmax_sample always returns that key
        let mut logits = HashMap::new();
        let worker_id = 42;
        logits.insert(worker_id, 0.5); // The value doesn't matter

        // Test with different temperatures
        for temperature in &[0.1, 1.0, 10.0] {
            let result = softmax_sample(&logits, *temperature);
            assert_eq!(result, worker_id, "Should return the only available worker");
        }

        // Test with different logit values
        logits.clear();
        logits.insert(worker_id, -100.0); // Very negative value
        assert_eq!(softmax_sample(&logits, 1.0), worker_id);

        logits.clear();
        logits.insert(worker_id, 100.0); // Very positive value
        assert_eq!(softmax_sample(&logits, 1.0), worker_id);

        logits.clear();
        logits.insert(worker_id, 0.0); // Zero value
        assert_eq!(softmax_sample(&logits, 1.0), worker_id);
    }

    #[test]
    fn test_softmax_sample_zero_temperature() {
        // Test that with temperature 0, softmax_sample returns the key with smallest logit
        let mut logits = HashMap::new();
        logits.insert(1, 5.0);
        logits.insert(2, 3.0); // This has the smallest logit
        logits.insert(3, 7.0);
        logits.insert(4, 3.5);

        // With temperature 0, should always return worker 2 (smallest logit)
        for _ in 0..10 {
            let result = softmax_sample(&logits, 0.0);
            assert_eq!(
                result, 2,
                "Should return worker with smallest logit when temperature is 0"
            );
        }

        // Test with negative values
        logits.clear();
        logits.insert(10, -1.0);
        logits.insert(20, -5.0); // This has the smallest logit
        logits.insert(30, 0.0);

        let result = softmax_sample(&logits, 0.0);
        assert_eq!(result, 20, "Should handle negative logits correctly");
    }
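
    // A minimal additional check (not from the original source): when all
    // logits are equal, softmax_sample takes the uniform-probability path, and
    // any returned key must still come from the input map.
    #[test]
    fn test_softmax_sample_equal_logits() {
        let mut logits = HashMap::new();
        logits.insert(1, 2.5);
        logits.insert(2, 2.5);
        logits.insert(3, 2.5);

        // Sampling repeatedly should only ever yield keys present in the map
        for _ in 0..20 {
            let result = softmax_sample(&logits, 1.0);
            assert!(
                logits.contains_key(&result),
                "Sampled key must exist in logits"
            );
        }
    }
}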
623}