dynamo_llm/kv_router/
scheduler.rs1use dynamo_runtime::component::Namespace;
17use dynamo_runtime::traits::events::EventPublisher;
18use rand::Rng;
19use serde::{Deserialize, Serialize};
20use std::borrow::BorrowMut;
21use std::collections::HashMap;
22
23use crate::kv_router::indexer::OverlapScores;
24pub use crate::kv_router::protocols::ForwardPassMetrics;
25use crate::kv_router::scoring::ProcessedEndpoints;
26use crate::kv_router::KV_HIT_RATE_SUBJECT;
27
28use super::protocols::WorkerSelectionResult;
29use super::WorkerSelector;
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct KVHitRateEvent {
33 pub worker_id: i64,
34 pub isl_blocks: usize,
35 pub overlap_blocks: usize,
36}
37
38#[derive(Debug, thiserror::Error)]
39pub enum KvSchedulerError {
40 #[error("no endpoints aviailable to route work")]
41 NoEndpoints,
42
43 #[error("all workers busy")]
44 AllWorkersBusy,
45
46 #[error("endpoint subscriber shutdown")]
47 SubscriberShutdown,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct Endpoint {
54 pub name: String,
55 pub subject: String,
56 pub data: ForwardPassMetrics,
57}
58
59impl Endpoint {
60 pub fn worker_id(&self) -> i64 {
61 i64::from_str_radix(
62 self.subject
63 .split("-")
64 .last()
65 .expect("invalid subject")
66 .to_string()
67 .as_str(),
68 16,
69 )
70 .expect("invalid worker id")
71 }
72}
73
74pub struct SchedulingRequest {
75 pub isl_tokens: usize,
76 pub overlap: OverlapScores,
77 resp_tx: tokio::sync::oneshot::Sender<i64>,
78}
79
80impl SchedulingRequest {
81 pub fn respond(self, worker_id: i64) {
82 if self.resp_tx.send(worker_id).is_err() {
83 tracing::trace!("failed to send response to requestor");
84 }
85 }
86}
87
88pub struct KvScheduler {
89 request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
90}
91
92impl KvScheduler {
93 pub async fn start(
94 ns: Namespace,
95 block_size: usize,
96 endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>,
97 selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
98 ) -> Result<Self, KvSchedulerError> {
99 let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector));
100 let mut endpoints_rx = endpoints_rx;
101 let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone();
102
103 let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
104 tokio::spawn(async move {
105 let mut event_rx = event_rx;
106 while let Some(event) = event_rx.recv().await {
107 if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await {
108 tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
109 }
110 }
111 });
112
113 let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
115 tracing::debug!("scheduler starting");
116 tokio::spawn(async move {
118 let mut request: SchedulingRequest;
119 let mut request_rx = request_rx;
120 tracing::debug!("scheduler background task started");
121
122 'outer: loop {
123 request = tokio::select! {
124 biased;
125
126 new_request = request_rx.recv() => {
127 match new_request {
128 Some(new_request) => {
129 tracing::trace!("received request to be scheduled");
130 new_request
131 },
132 None => {
133 tracing::trace!("scheduler shutdown");
134 break 'outer;
135 }
136 }
137 }
138
139 _ = endpoints_rx.changed() => {
140 endpoints = endpoints_rx.borrow_and_update().clone();
141 continue 'outer;
142 }
143 };
144 tracing::debug!("selected");
145 loop {
146 match selector.select_worker(&endpoints, &request, block_size) {
147 Ok(selection) => {
148 let worker_id = process_worker_selection(
149 endpoints.borrow_mut(),
150 selection,
151 &event_tx,
152 );
153 request.respond(worker_id);
154 continue 'outer;
155 }
156 Err(KvSchedulerError::AllWorkersBusy) => {
157 tracing::trace!("all workers busy; waiting for more capacity");
158 match endpoints_rx.changed().await {
159 Ok(_) => {}
160 Err(e) => {
161 tracing::error!("error waiting for endpoints change: {:?}", e);
162 break 'outer;
163 }
164 };
165 endpoints = endpoints_rx.borrow_and_update().clone();
166 }
167 Err(e) => {
168 tracing::error!("error scheduling request: {:?}", e);
169 break 'outer;
170 }
171 }
172 }
173 }
174
175 tracing::trace!("background endpoint subscriber shutting down");
176 });
177
178 Ok(KvScheduler { request_tx })
179 }
180
181 pub async fn schedule(
182 &self,
183 overlap: OverlapScores,
184 isl_tokens: usize,
185 ) -> Result<i64, KvSchedulerError> {
186 let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
187 let request = SchedulingRequest {
188 isl_tokens,
189 overlap,
190 resp_tx,
191 };
192 tracing::debug!("before sending request");
193 self.request_tx
194 .send(request)
195 .await
196 .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
197 tracing::debug!("after sending request");
198
199 let res = resp_rx
200 .await
201 .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
202 tracing::debug!("after receiving response");
203 Ok(res)
204 }
205}
206
207pub fn process_worker_selection(
209 workers: &mut ProcessedEndpoints,
210 selection: WorkerSelectionResult,
211 event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
212) -> i64 {
213 let worker = workers
214 .endpoints
215 .get_mut(&selection.worker_id)
216 .expect("worker not found");
217
218 worker.data.request_active_slots += 1;
220 worker.data.kv_active_blocks += selection.required_blocks - selection.overlap_blocks as u64;
221
222 if let Err(e) = event_tx.send(KVHitRateEvent {
224 worker_id: selection.worker_id,
225 isl_blocks: selection.required_blocks as usize,
226 overlap_blocks: selection.overlap_blocks,
227 }) {
228 tracing::warn!("Failed to send KV hit rate event: {:?}", e);
229 }
230
231 selection.worker_id
232}
233
/// Stateless worker-selection policy used when no custom selector is supplied.
///
/// Its `WorkerSelector` implementation ranks workers by a logit combining
/// cache overlap, KV-cache usage and relative active load.
#[derive(Default)]
pub struct DefaultWorkerSelector;
237
238impl WorkerSelector for DefaultWorkerSelector {
239 fn select_worker(
240 &self,
241 workers: &ProcessedEndpoints,
242 request: &SchedulingRequest,
243 block_size: usize,
244 ) -> Result<WorkerSelectionResult, KvSchedulerError> {
245 assert!(request.isl_tokens > 0);
246
247 let mut worker_scores = HashMap::new();
248 let mut max_active = 0.0;
249
250 for (worker_id, ep) in workers.endpoints.iter() {
252 if let Some(score) = request.overlap.scores.get(worker_id) {
254 let score = *score as f64 * block_size as f64 / request.isl_tokens as f64;
255 worker_scores.insert(worker_id, score);
256 }
257
258 max_active = f64::max(max_active, ep.data.request_active_slots as f64);
260 }
261
262 if max_active == 0.0 {
263 return Err(KvSchedulerError::NoEndpoints);
264 }
265
266 let worker_scores = worker_scores;
268 let max_active = max_active;
269
270 let mut best_logit = f64::NEG_INFINITY;
272 let mut best_workers = Vec::new();
273
274 for (worker_id, ep) in workers.endpoints.iter() {
275 let worker_id = *worker_id;
276
277 let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0);
279
280 assert!(ep.data.kv_total_blocks > 0);
282 let gpu_cache_usage = ep.data.kv_active_blocks as f64 / ep.data.kv_total_blocks as f64;
283 let normalized_active = if max_active > 0.0 {
284 ep.data.request_active_slots as f64 / max_active
285 } else {
286 0.0
287 };
288
289 let logit = 2.0 * score - gpu_cache_usage - normalized_active;
291
292 tracing::info!(
293 "Formula for {}: {:.3} = 2.0 * {:.3} - {:.3} - {:.3}",
294 worker_id,
295 logit,
296 score,
297 gpu_cache_usage,
298 normalized_active
299 );
300
301 match logit.partial_cmp(&best_logit) {
303 Some(std::cmp::Ordering::Greater) => {
304 best_logit = logit;
305 best_workers.clear();
306 best_workers.push(worker_id);
307 }
308 Some(std::cmp::Ordering::Equal) => {
309 best_workers.push(worker_id);
310 }
311 _ => {}
312 }
313 }
314
315 if best_workers.is_empty() || best_logit == 0.0 {
317 return Err(KvSchedulerError::NoEndpoints);
318 }
319
320 let worker_id = if best_workers.len() == 1 {
321 best_workers[0]
322 } else {
323 let mut rng = rand::rng();
325 best_workers[rng.random_range(0..best_workers.len())]
326 };
327
328 tracing::info!("Selected worker: {}, logit: {:.3}", worker_id, best_logit);
330
331 let total_blocks = std::cmp::min(request.isl_tokens / block_size, 1) as u64;
332 let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;
333
334 Ok(WorkerSelectionResult {
335 worker_id,
336 required_blocks: total_blocks,
337 overlap_blocks,
338 })
339 }
340}