dynamo_llm/kv_router/
scheduler.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use dynamo_runtime::component::Namespace;
17use dynamo_runtime::traits::events::EventPublisher;
18use rand::Rng;
19use serde::{Deserialize, Serialize};
20use std::borrow::BorrowMut;
21use std::collections::HashMap;
22
23use crate::kv_router::indexer::OverlapScores;
24pub use crate::kv_router::protocols::ForwardPassMetrics;
25use crate::kv_router::scoring::ProcessedEndpoints;
26use crate::kv_router::KV_HIT_RATE_SUBJECT;
27
28use super::protocols::WorkerSelectionResult;
29use super::WorkerSelector;
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct KVHitRateEvent {
33    pub worker_id: i64,
34    pub isl_blocks: usize,
35    pub overlap_blocks: usize,
36}
37
38#[derive(Debug, thiserror::Error)]
39pub enum KvSchedulerError {
40    #[error("no endpoints aviailable to route work")]
41    NoEndpoints,
42
43    #[error("all workers busy")]
44    AllWorkersBusy,
45
46    #[error("endpoint subscriber shutdown")]
47    SubscriberShutdown,
48}
49
50/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
51/// is cleaned (not optional)
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct Endpoint {
54    pub name: String,
55    pub subject: String,
56    pub data: ForwardPassMetrics,
57}
58
59impl Endpoint {
60    pub fn worker_id(&self) -> i64 {
61        i64::from_str_radix(
62            self.subject
63                .split("-")
64                .last()
65                .expect("invalid subject")
66                .to_string()
67                .as_str(),
68            16,
69        )
70        .expect("invalid worker id")
71    }
72}
73
74pub struct SchedulingRequest {
75    pub isl_tokens: usize,
76    pub overlap: OverlapScores,
77    resp_tx: tokio::sync::oneshot::Sender<i64>,
78}
79
80impl SchedulingRequest {
81    pub fn respond(self, worker_id: i64) {
82        if self.resp_tx.send(worker_id).is_err() {
83            tracing::trace!("failed to send response to requestor");
84        }
85    }
86}
87
88pub struct KvScheduler {
89    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
90}
91
92impl KvScheduler {
93    pub async fn start(
94        ns: Namespace,
95        block_size: usize,
96        endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>,
97        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
98    ) -> Result<Self, KvSchedulerError> {
99        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector));
100        let mut endpoints_rx = endpoints_rx;
101        let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone();
102
103        let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
104        tokio::spawn(async move {
105            let mut event_rx = event_rx;
106            while let Some(event) = event_rx.recv().await {
107                if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await {
108                    tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
109                }
110            }
111        });
112
113        // Channel to accept new scheduling requests
114        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
115        tracing::debug!("scheduler starting");
116        // Background task to handle scheduling requests
117        tokio::spawn(async move {
118            let mut request: SchedulingRequest;
119            let mut request_rx = request_rx;
120            tracing::debug!("scheduler background task started");
121
122            'outer: loop {
123                request = tokio::select! {
124                    biased;
125
126                    new_request = request_rx.recv() => {
127                        match new_request {
128                            Some(new_request) => {
129                                tracing::trace!("received request to be scheduled");
130                                new_request
131                            },
132                            None => {
133                                tracing::trace!("scheduler shutdown");
134                                break 'outer;
135                            }
136                        }
137                    }
138
139                    _ = endpoints_rx.changed() => {
140                        endpoints = endpoints_rx.borrow_and_update().clone();
141                        continue 'outer;
142                    }
143                };
144                tracing::debug!("selected");
145                loop {
146                    match selector.select_worker(&endpoints, &request, block_size) {
147                        Ok(selection) => {
148                            let worker_id = process_worker_selection(
149                                endpoints.borrow_mut(),
150                                selection,
151                                &event_tx,
152                            );
153                            request.respond(worker_id);
154                            continue 'outer;
155                        }
156                        Err(KvSchedulerError::AllWorkersBusy) => {
157                            tracing::trace!("all workers busy; waiting for more capacity");
158                            match endpoints_rx.changed().await {
159                                Ok(_) => {}
160                                Err(e) => {
161                                    tracing::error!("error waiting for endpoints change: {:?}", e);
162                                    break 'outer;
163                                }
164                            };
165                            endpoints = endpoints_rx.borrow_and_update().clone();
166                        }
167                        Err(e) => {
168                            tracing::error!("error scheduling request: {:?}", e);
169                            break 'outer;
170                        }
171                    }
172                }
173            }
174
175            tracing::trace!("background endpoint subscriber shutting down");
176        });
177
178        Ok(KvScheduler { request_tx })
179    }
180
181    pub async fn schedule(
182        &self,
183        overlap: OverlapScores,
184        isl_tokens: usize,
185    ) -> Result<i64, KvSchedulerError> {
186        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
187        let request = SchedulingRequest {
188            isl_tokens,
189            overlap,
190            resp_tx,
191        };
192        tracing::debug!("before sending request");
193        self.request_tx
194            .send(request)
195            .await
196            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
197        tracing::debug!("after sending request");
198
199        let res = resp_rx
200            .await
201            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
202        tracing::debug!("after receiving response");
203        Ok(res)
204    }
205}
206
207// This becomes the driver function that handles the selection result
208pub fn process_worker_selection(
209    workers: &mut ProcessedEndpoints,
210    selection: WorkerSelectionResult,
211    event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
212) -> i64 {
213    let worker = workers
214        .endpoints
215        .get_mut(&selection.worker_id)
216        .expect("worker not found");
217
218    // Update worker state
219    worker.data.request_active_slots += 1;
220    worker.data.kv_active_blocks += selection.required_blocks - selection.overlap_blocks as u64;
221
222    // Emit event
223    if let Err(e) = event_tx.send(KVHitRateEvent {
224        worker_id: selection.worker_id,
225        isl_blocks: selection.required_blocks as usize,
226        overlap_blocks: selection.overlap_blocks,
227    }) {
228        tracing::warn!("Failed to send KV hit rate event: {:?}", e);
229    }
230
231    selection.worker_id
232}
233
/// Default [`WorkerSelector`], matching the Python `_cost_function`.
///
/// Stateless, so it is freely copyable.
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultWorkerSelector;
237
238impl WorkerSelector for DefaultWorkerSelector {
239    fn select_worker(
240        &self,
241        workers: &ProcessedEndpoints,
242        request: &SchedulingRequest,
243        block_size: usize,
244    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
245        assert!(request.isl_tokens > 0);
246
247        let mut worker_scores = HashMap::new();
248        let mut max_active = 0.0;
249
250        // Calculate worker scores and find max waiting requests
251        for (worker_id, ep) in workers.endpoints.iter() {
252            // Calculate score similar to Python version
253            if let Some(score) = request.overlap.scores.get(worker_id) {
254                let score = *score as f64 * block_size as f64 / request.isl_tokens as f64;
255                worker_scores.insert(worker_id, score);
256            }
257
258            // Track max waiting requests
259            max_active = f64::max(max_active, ep.data.request_active_slots as f64);
260        }
261
262        if max_active == 0.0 {
263            return Err(KvSchedulerError::NoEndpoints);
264        }
265
266        // make immutable
267        let worker_scores = worker_scores;
268        let max_active = max_active;
269
270        // Calculate logits for each worker
271        let mut best_logit = f64::NEG_INFINITY;
272        let mut best_workers = Vec::new();
273
274        for (worker_id, ep) in workers.endpoints.iter() {
275            let worker_id = *worker_id;
276
277            // Get score or default to 0.0
278            let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0);
279
280            // Calculate normalized metrics
281            assert!(ep.data.kv_total_blocks > 0);
282            let gpu_cache_usage = ep.data.kv_active_blocks as f64 / ep.data.kv_total_blocks as f64;
283            let normalized_active = if max_active > 0.0 {
284                ep.data.request_active_slots as f64 / max_active
285            } else {
286                0.0
287            };
288
289            // Calculate logit using same formula as Python
290            let logit = 2.0 * score - gpu_cache_usage - normalized_active;
291
292            tracing::info!(
293                "Formula for {}: {:.3} = 2.0 * {:.3} - {:.3} - {:.3}",
294                worker_id,
295                logit,
296                score,
297                gpu_cache_usage,
298                normalized_active
299            );
300
301            // Track best workers
302            match logit.partial_cmp(&best_logit) {
303                Some(std::cmp::Ordering::Greater) => {
304                    best_logit = logit;
305                    best_workers.clear();
306                    best_workers.push(worker_id);
307                }
308                Some(std::cmp::Ordering::Equal) => {
309                    best_workers.push(worker_id);
310                }
311                _ => {}
312            }
313        }
314
315        // Return early if no valid workers found
316        if best_workers.is_empty() || best_logit == 0.0 {
317            return Err(KvSchedulerError::NoEndpoints);
318        }
319
320        let worker_id = if best_workers.len() == 1 {
321            best_workers[0]
322        } else {
323            // Randomly select from best workers
324            let mut rng = rand::rng();
325            best_workers[rng.random_range(0..best_workers.len())]
326        };
327
328        // Log selection metrics
329        tracing::info!("Selected worker: {}, logit: {:.3}", worker_id, best_logit);
330
331        let total_blocks = std::cmp::min(request.isl_tokens / block_size, 1) as u64;
332        let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
333
334        Ok(WorkerSelectionResult {
335            worker_id,
336            required_blocks: total_blocks,
337            overlap_blocks,
338        })
339    }
340}