1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::{AsyncEngineContextProvider, ResponseStream};
use crate::error::{BackendError, ErrorType, match_error_chain};
/// Check if an error chain indicates the worker should be reported as down.
fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
const INHIBITED: &[ErrorType] = &[
ErrorType::CannotConnect,
ErrorType::Disconnected,
ErrorType::ConnectionTimeout,
ErrorType::Backend(BackendError::EngineShutdown),
];
match_error_chain(err, INHIBITED, &[])
}
use crate::{
component::{Client, Endpoint},
engine::{AsyncEngine, Data},
pipeline::{
AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
error::{PipelineError, PipelineErrorExt},
},
protocols::maybe_error::MaybeError,
traits::DistributedRuntimeProvider,
};
use async_trait::async_trait;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::{
future::Future,
marker::PhantomData,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
};
use tokio_stream::StreamExt;
use tracing::Instrument;
/// Trait for monitoring worker load and determining busy state.
/// Implementations can define custom load metrics and busy thresholds.
#[async_trait]
pub trait WorkerLoadMonitor: Send + Sync {
/// Start background monitoring of worker load.
/// This should spawn background tasks that update the client's free instances.
async fn start_monitoring(&self) -> anyhow::Result<()>;
}
#[derive(Clone)]
pub struct PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
// TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
/// The Client is how we gather remote endpoint information from etcd.
pub client: Client,
/// How we choose which instance to send traffic to.
///
/// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
/// not using it as an AsyncEngine.
/// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
/// dynamo-llm's KV Routing does this.
router_mode: RouterMode,
/// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>,
/// The next step in the chain. PushRouter (this object) picks an instances,
/// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>,
/// Threshold for determining when a worker is busy (0.0 to 1.0)
/// If None, busy detection is disabled
busy_threshold: Option<f64>,
/// When false, `generate_with_fault_detection` skips fault detection logic:
/// it won't call `report_instance_down` on errors, and it uses the raw discovery
/// instance list instead of the filtered avail list. Use for recovery/query paths
/// where transient failures are expected.
fault_detection_enabled: bool,
/// An internal Rust type. This says that PushRouter is generic over the T and U types,
/// which are the input and output types of it's `generate` function. It allows the
/// compiler to specialize us at compile time.
_phantom: PhantomData<(T, U)>,
}
#[derive(Default, Debug, Clone, Copy, PartialEq)]
pub enum RouterMode {
#[default]
RoundRobin,
Random,
KV,
Direct,
}
impl RouterMode {
pub fn is_kv_routing(&self) -> bool {
*self == RouterMode::KV
}
pub fn is_direct_routing(&self) -> bool {
*self == RouterMode::Direct
}
}
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
// Get network manager and create client (no mode checks!)
let manager = endpoint.drt().network_manager();
let req_client = manager.create_client()?;
let resp_transport = endpoint.drt().tcp_server().await?;
tracing::debug!(
transport = req_client.transport_name(),
"Creating AddressedPushRouter with request plane client"
);
AddressedPushRouter::new(req_client, resp_transport)
}
impl<T, U> PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
/// Create a new PushRouter without busy threshold (no busy detection)
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
Self::from_client_with_threshold(client, router_mode, None, None).await
}
/// Create a new PushRouter with fault detection disabled.
///
/// Unlike `from_client`, this router will not call `report_instance_down` on
/// transient errors, and `direct()` uses the raw discovery instance list instead
/// of the filtered avail list. Use for recovery/query paths.
pub async fn from_client_no_fault_detection(
client: Client,
router_mode: RouterMode,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
Ok(PushRouter {
client: client.clone(),
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
busy_threshold: None,
fault_detection_enabled: false,
_phantom: PhantomData,
})
}
/// Create a new PushRouter with optional busy threshold and worker load monitor
pub async fn from_client_with_threshold(
client: Client,
router_mode: RouterMode,
busy_threshold: Option<f64>,
worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
// Start worker monitor if provided and in dynamic mode
if let Some(monitor) = worker_monitor.as_ref() {
monitor.start_monitoring().await?;
}
let router = PushRouter {
client: client.clone(),
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
busy_threshold,
fault_detection_enabled: true,
_phantom: PhantomData,
};
Ok(router)
}
/// Issue a request to the next available instance in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
let instance_id = {
let instance_ids = self.client.instance_ids_avail();
let count = instance_ids.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no instances found for endpoint {}",
self.client.endpoint.id()
));
}
instance_ids[counter % count]
};
tracing::trace!("round robin router selected {instance_id}");
self.generate_with_fault_detection(instance_id, request)
.await
}
/// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let instance_id = {
let instance_ids = self.client.instance_ids_avail();
let count = instance_ids.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no instances found for endpoint {}",
self.client.endpoint.id()
));
}
let counter = rand::rng().random::<u64>() as usize;
instance_ids[counter % count]
};
tracing::trace!("random router selected {instance_id}");
self.generate_with_fault_detection(instance_id, request)
.await
}
/// Issue a request to a specific endpoint
pub async fn direct(
&self,
request: SingleIn<T>,
instance_id: u64,
) -> anyhow::Result<ManyOut<U>> {
// When fault detection is disabled, check the raw discovery list
// (not filtered by report_instance_down) so transient failures
// don't poison the instance for subsequent retries.
let found = if self.fault_detection_enabled {
self.client.instance_ids_avail().contains(&instance_id)
} else {
self.client.instance_ids().contains(&instance_id)
};
if !found {
return Err(anyhow::anyhow!(
"instance_id={instance_id} not found for endpoint {}",
self.client.endpoint.id()
));
}
self.generate_with_fault_detection(instance_id, request)
.await
}
/// Select the next worker according to the routing mode.
/// Increments round-robin counter if applicable.
/// Panics if called on Direct or KV mode - those have their own selection mechanisms.
pub fn select_next_worker(&self) -> Option<u64> {
let instance_ids = self.client.instance_ids_avail();
let count = instance_ids.len();
if count == 0 {
return None;
}
match self.router_mode {
RouterMode::RoundRobin => {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
Some(instance_ids[counter % count])
}
RouterMode::Random => {
let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count])
}
_ => {
panic!(
"select_next_worker should not be called for {:?} routing mode",
self.router_mode
)
}
}
}
/// Peek the next worker according to the routing mode without incrementing the counter.
/// Useful for checking if a worker is suitable before committing to it.
pub fn peek_next_worker(&self) -> Option<u64> {
let instance_ids = self.client.instance_ids_avail();
let count = instance_ids.len();
if count == 0 {
return None;
}
match self.router_mode {
RouterMode::RoundRobin => {
// Just peek at the current counter value without incrementing
let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
Some(instance_ids[counter % count])
}
RouterMode::Random => {
// For random, peeking implies a fresh random selection since it's stateless.
// Note: The caller must realize that select_next_worker() will pick a DIFFERENT random worker.
let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count])
}
_ => {
panic!(
"peek_next_worker should not be called for {:?} routing mode",
self.router_mode
)
}
}
}
/*
pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let subject = self.client.endpoint.subject();
tracing::debug!("static got subject: {subject}");
let request = request.map(|req| AddressedRequest::new(req, subject));
tracing::debug!("router generate");
self.addressed.generate(request).await
}
*/
async fn generate_with_fault_detection(
&self,
instance_id: u64,
request: SingleIn<T>,
) -> anyhow::Result<ManyOut<U>> {
let request_id = request.id().to_string();
let route_span = if matches!(self.router_mode, RouterMode::KV) {
tracing::Span::none()
} else {
tracing::info_span!(
"router.route_request",
request_id = %request_id,
worker_id = instance_id,
router_mode = ?self.router_mode,
)
};
// Check if all workers are busy (only if busy threshold is set and fault detection enabled)
if self.fault_detection_enabled && self.busy_threshold.is_some() {
let free_instances = self.client.instance_ids_free();
if free_instances.is_empty() {
// Check if we actually have any instances at all
let all_instances = self.client.instance_ids();
if !all_instances.is_empty() {
tracing::warn!(
instance_id,
total_workers = all_instances.len(),
"Rejecting request: all workers are busy"
);
return Err(PipelineError::ServiceOverloaded(
"All workers are busy, please retry later".to_string(),
)
.into());
}
}
}
// Get the address based on discovered transport type
let address = {
use crate::component::TransportType;
// Get the instance and use its actual transport type
let instances = self.client.instances();
let instance = instances
.iter()
.find(|i| i.instance_id == instance_id)
.ok_or_else(|| {
anyhow::anyhow!("Instance {} not found in available instances", instance_id)
})?;
match &instance.transport {
TransportType::Http(http_endpoint) => {
tracing::debug!(
instance_id = instance_id,
http_endpoint = %http_endpoint,
"Using HTTP transport for instance"
);
http_endpoint.clone()
}
TransportType::Tcp(tcp_endpoint) => {
tracing::debug!(
instance_id = instance_id,
tcp_endpoint = %tcp_endpoint,
"Using TCP transport for instance"
);
tcp_endpoint.clone()
}
TransportType::Nats(subject) => {
tracing::debug!(
instance_id = instance_id,
subject = %subject,
"Using NATS transport for instance"
);
subject.clone()
}
}
};
let request = request.map(|req| AddressedRequest::new(req, address));
let stream: anyhow::Result<ManyOut<U>> = self
.addressed
.generate(request)
.instrument(route_span)
.await;
match stream {
Ok(stream) => {
if !self.fault_detection_enabled {
return Ok(stream);
}
let engine_ctx = stream.context();
let client = self.client.clone();
let stream = stream.map(move |res| {
// Check if the error is migratable (indicates worker/connection failure)
if let Some(err) = res.err()
&& is_inhibited(&err)
{
tracing::debug!(
"Reporting instance {instance_id} down due to migratable error: {err}"
);
client.report_instance_down(instance_id);
}
res
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
Err(err) => {
if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
self.client.report_instance_down(instance_id);
}
Err(err)
}
}
}
}
#[async_trait]
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter");
}
RouterMode::Direct => {
anyhow::bail!(
"Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
);
}
}
}
}