Skip to main content

dynamo_runtime/pipeline/network/egress/
push_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{AsyncEngineContextProvider, ResponseStream};
5use crate::error::{BackendError, DynamoError, ErrorType, match_error_chain};
6use crate::{
7    component::{
8        Client, DeviceType, Endpoint, RoutingOccupancyState, get_or_create_routing_occupancy_state,
9    },
10    dynamo_nvtx_range,
11    engine::{AsyncEngine, AsyncEngineContext, Data},
12    metrics::frontend_perf::STAGE_DURATION_SECONDS,
13    pipeline::{
14        AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
15        error::{PipelineError, PipelineErrorExt},
16    },
17    protocols::maybe_error::MaybeError,
18    traits::DistributedRuntimeProvider,
19};
20use async_trait::async_trait;
21use futures::Stream;
22use rand::Rng;
23use serde::{Deserialize, Serialize};
24use std::{
25    collections::HashMap,
26    marker::PhantomData,
27    pin::Pin,
28    sync::{
29        Arc,
30        atomic::{AtomicU64, Ordering},
31    },
32    task::Poll,
33    time::Instant,
34};
35use tokio_stream::StreamExt;
36use tracing::Instrument;
37
38/// Check if an error chain indicates the worker should be reported as down.
39fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
40    const INHIBITED: &[ErrorType] = &[
41        ErrorType::CannotConnect,
42        ErrorType::Disconnected,
43        ErrorType::ConnectionTimeout,
44        ErrorType::ResponseTimeout,
45        ErrorType::Backend(BackendError::EngineShutdown),
46    ];
47    match_error_chain(err, INHIBITED, &[])
48}
49
50/// Read the backend response inactivity timeout from the environment.
51/// Reuses `DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS` — the same env var
52/// as the HTTP-layer safety net in `disconnect.rs`.
53fn response_inactivity_timeout() -> Option<std::time::Duration> {
54    use crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS;
55    std::env::var(DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS)
56        .ok()
57        .and_then(|s| s.parse::<u64>().ok())
58        .filter(|&secs| secs > 0)
59        .map(std::time::Duration::from_secs)
60}
61
62struct OccupancyPermit {
63    state: Arc<RoutingOccupancyState>,
64    instance_id: u64,
65    armed: bool,
66}
67
68impl OccupancyPermit {
69    fn new(state: Arc<RoutingOccupancyState>, instance_id: u64) -> Self {
70        Self {
71            state,
72            instance_id,
73            armed: true,
74        }
75    }
76
77    fn into_tracked_stream<U: Data>(mut self, stream: ManyOut<U>) -> ManyOut<U> {
78        self.armed = false;
79        let engine_ctx = stream.context();
80        ResponseStream::new(
81            Box::pin(OccupancyTrackedStream {
82                inner: stream,
83                state: self.state.clone(),
84                instance_id: self.instance_id,
85            }),
86            engine_ctx,
87        )
88    }
89
90    fn instance_id(&self) -> u64 {
91        self.instance_id
92    }
93}
94
95impl Drop for OccupancyPermit {
96    fn drop(&mut self) {
97        if self.armed {
98            self.state.decrement(self.instance_id);
99        }
100    }
101}
102
103/// Trait for monitoring worker load and determining busy state.
104/// Implementations can define custom load metrics and busy thresholds.
105#[async_trait]
106pub trait WorkerLoadMonitor: Send + Sync {
107    /// Start background monitoring of worker load.
108    /// This should spawn background tasks that update the client's free instances.
109    async fn start_monitoring(&self) -> anyhow::Result<()>;
110}
111
112#[derive(Clone)]
113pub struct PushRouter<T, U>
114where
115    T: Data + Serialize,
116    U: Data + for<'de> Deserialize<'de>,
117{
118    // TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
119    /// The Client is how we gather remote endpoint information from etcd.
120    pub client: Client,
121
122    /// How we choose which instance to send traffic to.
123    ///
124    /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
125    /// not using it as an AsyncEngine.
126    /// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
127    /// dynamo-llm's KV Routing does this.
128    router_mode: RouterMode,
129
130    /// Number of round robin requests handled. Used to decide which server is next.
131    round_robin_counter: Arc<AtomicU64>,
132
133    /// The next step in the chain. PushRouter (this object) picks an instances,
134    /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
135    addressed: Arc<AddressedPushRouter>,
136
137    /// When false, `generate_with_fault_detection` skips fault detection logic:
138    /// it won't call `report_instance_down` on errors, and it uses the raw discovery
139    /// instance list instead of the filtered avail list. Use for recovery/query paths
140    /// where transient failures are expected.
141    fault_detection_enabled: bool,
142
143    /// Cached response inactivity timeout. Read once at construction from
144    /// [`environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS`](crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS) to avoid a syscall per request.
145    response_timeout: Option<std::time::Duration>,
146
147    /// Shared request occupancy state for tracked routing modes.
148    occupancy_state: Option<Arc<RoutingOccupancyState>>,
149
150    /// An internal Rust type. This says that PushRouter is generic over the T and U types,
151    /// which are the input and output types of it's `generate` function. It allows the
152    /// compiler to specialize us at compile time.
153    _phantom: PhantomData<(T, U)>,
154}
155
/// Strategy a [`PushRouter`] uses to pick the instance that serves a request.
///
/// This is a fieldless `Copy` enum, so it also derives `Eq` and `Hash` and can be
/// used as a map or set key.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RouterMode {
    /// Cycle through the available instances using a shared counter.
    #[default]
    RoundRobin,
    /// Pick an available instance at random.
    Random,
    /// Sample two available instances and pick the one with fewer in-flight requests.
    PowerOfTwoChoices,
    /// Instance selection is done by the caller (e.g. dynamo-llm's KV routing),
    /// which invokes `round_robin`/`random`/`direct` itself; `generate` is not
    /// used in this mode.
    KV,
    /// The caller names the target instance explicitly.
    Direct,
    /// Pick the available instance with the fewest in-flight requests.
    LeastLoaded,
    /// Device-aware weighted routing for heterogeneous workers.
    DeviceAwareWeighted,
}

impl RouterMode {
    /// Returns `true` for [`RouterMode::KV`].
    pub fn is_kv_routing(&self) -> bool {
        matches!(self, RouterMode::KV)
    }

    /// Returns `true` for [`RouterMode::Direct`].
    pub fn is_direct_routing(&self) -> bool {
        matches!(self, RouterMode::Direct)
    }
}
178
179/// Pick the instance with lower in-flight count from two random candidates.
180/// Returns the single instance if only one is available.
181fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64]) -> u64 {
182    let count = instance_ids.len();
183    if count == 1 {
184        return instance_ids[0];
185    }
186    let mut rng = rand::rng();
187    let idx1 = rng.random_range(0..count);
188    let idx2 = (idx1 + 1 + rng.random_range(0..count - 1)) % count;
189    let id1 = instance_ids[idx1];
190    let id2 = instance_ids[idx2];
191    let load1 = occupancy_state.load(id1);
192    let load2 = occupancy_state.load(id2);
193    let selected = if load1 <= load2 { id1 } else { id2 };
194    tracing::debug!(
195        candidate_a = id1,
196        candidate_a_load = load1,
197        candidate_b = id2,
198        candidate_b_load = load2,
199        selected = selected,
200        "p2c selection"
201    );
202    selected
203}
204
205/// Select the target device group for the next request in `DeviceAwareWeighted` mode.
206///
207/// If only one class exists (all CPU or all non-CPU), returns that class directly.
208/// If both classes exist, compares capability-normalized load and returns the less-loaded group.
209///
210/// Budget check (integer form):
211/// `allowed_cpu_inflight = total_non_cpu_inflight * cpu_count / (ratio * non_cpu_count)`
212/// and choose CPU when `total_cpu_inflight < allowed_cpu_inflight`.
213///
214/// `ratio` is `non_cpu_to_cpu_ratio` (from `DYN_ENCODER_CUDA_TO_CPU_RATIO`,
215/// default `8` in `device_aware_weighted`).
216fn device_aware_candidate_group(
217    state: &RoutingOccupancyState,
218    instance_ids: &[u64],
219    device_type_map: &HashMap<u64, Option<DeviceType>>,
220    non_cpu_to_cpu_ratio: usize,
221) -> Vec<u64> {
222    let cpu_ids: Vec<u64> = instance_ids
223        .iter()
224        .copied()
225        .filter(|id| matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
226        .collect();
227    let non_cpu_ids: Vec<u64> = instance_ids
228        .iter()
229        .copied()
230        .filter(|id| !matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
231        .collect();
232
233    if cpu_ids.is_empty() {
234        return non_cpu_ids;
235    }
236    if non_cpu_ids.is_empty() {
237        return cpu_ids;
238    }
239
240    // Both classes exist: compute a budget for CPU in-flight requests.
241    let total_non_cpu_inflight: u64 = non_cpu_ids.iter().map(|id| state.load(*id)).sum();
242    let total_cpu_inflight: u64 = cpu_ids.iter().map(|id| state.load(*id)).sum();
243    let cpu_count = cpu_ids.len() as u64;
244    let non_cpu_count = non_cpu_ids.len() as u64;
245    let allowed_cpu_inflight = total_non_cpu_inflight.saturating_mul(cpu_count)
246        / ((non_cpu_to_cpu_ratio as u64).saturating_mul(non_cpu_count));
247
248    if total_cpu_inflight < allowed_cpu_inflight {
249        cpu_ids
250    } else {
251        non_cpu_ids
252    }
253}
254
255async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
256    // Get network manager and create client (no mode checks!)
257    let manager = endpoint.drt().network_manager();
258    let req_client = manager.create_client()?;
259    let resp_transport = endpoint.drt().tcp_server().await?;
260
261    tracing::debug!(
262        transport = req_client.transport_name(),
263        "Creating AddressedPushRouter with request plane client"
264    );
265
266    AddressedPushRouter::new(req_client, resp_transport)
267}
268
269impl<T, U> PushRouter<T, U>
270where
271    T: Data + Serialize,
272    U: Data + for<'de> Deserialize<'de> + MaybeError,
273{
274    /// Create a new PushRouter without a worker load monitor (no busy detection)
275    pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
276        Self::from_client_with_monitor(client, router_mode, None).await
277    }
278
279    /// Create a new PushRouter with fault detection disabled.
280    ///
281    /// Unlike `from_client`, this router will not call `report_instance_down` on
282    /// transient errors, and `direct()` uses the raw discovery instance list instead
283    /// of the filtered avail list. Use for recovery/query paths.
284    pub async fn from_client_no_fault_detection(
285        client: Client,
286        router_mode: RouterMode,
287    ) -> anyhow::Result<Self> {
288        let addressed = addressed_router(&client.endpoint).await?;
289
290        let occupancy_state = if matches!(
291            router_mode,
292            RouterMode::PowerOfTwoChoices
293                | RouterMode::LeastLoaded
294                | RouterMode::DeviceAwareWeighted
295        ) {
296            Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
297        } else {
298            None
299        };
300
301        Ok(PushRouter {
302            client,
303            addressed,
304            router_mode,
305            round_robin_counter: Arc::new(AtomicU64::new(0)),
306            fault_detection_enabled: false,
307            response_timeout: response_inactivity_timeout(),
308            occupancy_state,
309            _phantom: PhantomData,
310        })
311    }
312
313    /// Create a new PushRouter with an optional worker load monitor.
314    ///
315    /// The rejection path is gated by `fault_detection_enabled` (true here);
316    /// busy detection itself is driven by the monitor via `client.update_free_instances(...)`.
317    /// If no thresholds are configured on the monitor (or no monitor is provided),
318    /// `client.instance_ids_free()` returns all instances and the gate never rejects.
319    pub async fn from_client_with_monitor(
320        client: Client,
321        router_mode: RouterMode,
322        worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
323    ) -> anyhow::Result<Self> {
324        let addressed = addressed_router(&client.endpoint).await?;
325
326        // Start worker monitor if provided and in dynamic mode
327        if let Some(monitor) = worker_monitor.as_ref() {
328            monitor.start_monitoring().await?;
329        }
330
331        let occupancy_state = if matches!(
332            router_mode,
333            RouterMode::PowerOfTwoChoices
334                | RouterMode::LeastLoaded
335                | RouterMode::DeviceAwareWeighted
336        ) {
337            Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
338        } else {
339            None
340        };
341
342        let router = PushRouter {
343            client,
344            addressed,
345            router_mode,
346            round_robin_counter: Arc::new(AtomicU64::new(0)),
347            fault_detection_enabled: true,
348            response_timeout: response_inactivity_timeout(),
349            occupancy_state,
350            _phantom: PhantomData,
351        };
352
353        Ok(router)
354    }
355
356    /// Issue a request to the next available instance in a round-robin fashion
357    pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
358        let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
359
360        let instance_id = {
361            let instance_ids = self.client.instance_ids_avail();
362            let count = instance_ids.len();
363            if count == 0 {
364                return Err(anyhow::anyhow!(
365                    "no instances found for endpoint {}",
366                    self.client.endpoint.id()
367                ));
368            }
369            instance_ids[counter % count]
370        };
371        tracing::trace!("round robin router selected {instance_id}");
372
373        self.generate_with_fault_detection(instance_id, request)
374            .await
375    }
376
377    /// Issue a request to a random endpoint
378    pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
379        let instance_id = {
380            let instance_ids = self.client.instance_ids_avail();
381            let count = instance_ids.len();
382            if count == 0 {
383                return Err(anyhow::anyhow!(
384                    "no instances found for endpoint {}",
385                    self.client.endpoint.id()
386                ));
387            }
388            let counter = rand::rng().random::<u64>() as usize;
389            instance_ids[counter % count]
390        };
391        tracing::trace!("random router selected {instance_id}");
392
393        self.generate_with_fault_detection(instance_id, request)
394            .await
395    }
396
397    /// Issue a request using power-of-two-choices: pick 2 random healthy workers,
398    /// route to the one with fewer in-flight requests.
399    pub async fn power_of_two_choices(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
400        let state = self.occupancy_state()?;
401        let instance_id = {
402            let instance_ids = self
403                .client
404                .instance_ids_avail()
405                .iter()
406                .copied()
407                .collect::<Vec<_>>();
408            if instance_ids.is_empty() {
409                return Err(anyhow::anyhow!(
410                    "no instances found for endpoint {}",
411                    self.client.endpoint.id()
412                ));
413            }
414            p2c_select_from(state.as_ref(), &instance_ids)
415        };
416        state.increment(instance_id);
417        let permit = OccupancyPermit::new(state, instance_id);
418
419        match self
420            .generate_with_fault_detection(instance_id, request)
421            .await
422        {
423            Ok(stream) => Ok(permit.into_tracked_stream(stream)),
424            Err(err) => Err(err),
425        }
426    }
427
428    /// Issue a request to a specific endpoint
429    pub async fn direct(
430        &self,
431        request: SingleIn<T>,
432        instance_id: u64,
433    ) -> anyhow::Result<ManyOut<U>> {
434        // When fault detection is disabled, check the raw discovery list
435        // (not filtered by report_instance_down) so transient failures
436        // don't poison the instance for subsequent retries.
437        let found = if self.fault_detection_enabled {
438            self.client.instance_ids_avail().contains(&instance_id)
439        } else {
440            self.client.instance_ids().contains(&instance_id)
441        };
442
443        if !found {
444            return Err(anyhow::anyhow!(
445                "instance_id={instance_id} not found for endpoint {}",
446                self.client.endpoint.id()
447            ));
448        }
449
450        self.generate_with_fault_detection(instance_id, request)
451            .await
452    }
453
454    /// Issue a request using device-aware weighted routing.
455    ///
456    /// Instances are partitioned by device type (CPU vs non-CPU), then the router
457    /// applies a budget policy and selects the least-loaded instance within the
458    /// chosen group.
459    ///
460    /// If only one device class exists (all CPU or all non-CPU), this naturally
461    /// degenerates to least-loaded routing over the available instances.
462    pub async fn device_aware_weighted(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
463        let state = self.occupancy_state()?;
464        let instance_ids = self
465            .client
466            .instance_ids_avail()
467            .iter()
468            .copied()
469            .collect::<Vec<_>>();
470
471        if instance_ids.is_empty() {
472            return Err(anyhow::anyhow!(
473                "no instances found for endpoint {}",
474                self.client.endpoint.id()
475            ));
476        }
477
478        // Apply a unified policy for all endpoints.
479        let endpoint_id = self.client.endpoint.id();
480
481        // For encoder endpoints, partition by device type
482        let instances = self.client.instances();
483        let device_type_map: std::collections::HashMap<u64, Option<DeviceType>> = instances
484            .iter()
485            .map(|inst| (inst.instance_id, inst.device_type.clone()))
486            .collect();
487
488        // Apply budget-based routing to determine which group to send to
489        let cuda_to_cpu_ratio = std::env::var("DYN_ENCODER_CUDA_TO_CPU_RATIO")
490            .ok()
491            .and_then(|v| v.parse::<usize>().ok())
492            .filter(|v| *v >= 1)
493            .unwrap_or(8);
494        let candidates = device_aware_candidate_group(
495            state.as_ref(),
496            &instance_ids,
497            &device_type_map,
498            cuda_to_cpu_ratio,
499        );
500
501        // Select least-loaded within the chosen group
502        let instance_id = state
503            .select_exact_min_and_increment(&candidates)
504            .await
505            .ok_or_else(|| {
506                anyhow::anyhow!(
507                    "no instances in selected device group for endpoint {}",
508                    endpoint_id
509                )
510            })?;
511        let permit = OccupancyPermit::new(state.clone(), instance_id);
512        let is_cpu = matches!(
513            device_type_map.get(&instance_id),
514            Some(Some(DeviceType::Cpu))
515        );
516        tracing::info!(
517            endpoint = %endpoint_id,
518            selected_instance = instance_id,
519            is_cpu,
520            "DeviceAwareWeighted selected instance"
521        );
522
523        match self
524            .generate_with_fault_detection(instance_id, request)
525            .await
526        {
527            Ok(stream) => Ok(permit.into_tracked_stream(stream)),
528            Err(err) => Err(err),
529        }
530    }
531
532    /// Issue a request to the instance with the fewest active connections.
533    pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
534        let state = self.occupancy_state()?;
535        let instance_ids = self
536            .client
537            .instance_ids_avail()
538            .iter()
539            .copied()
540            .collect::<Vec<_>>();
541        let instance_id = state
542            .select_exact_min_and_increment(&instance_ids)
543            .await
544            .ok_or_else(|| {
545                anyhow::anyhow!(
546                    "no instances found for endpoint {}",
547                    self.client.endpoint.id()
548                )
549            })?;
550        let permit = OccupancyPermit::new(state.clone(), instance_id);
551        tracing::trace!(
552            "least loaded router selected {instance_id} (connections: {})",
553            state.load(instance_id)
554        );
555
556        match self
557            .generate_with_fault_detection(instance_id, request)
558            .await
559        {
560            Ok(stream) => Ok(permit.into_tracked_stream(stream)),
561            Err(err) => Err(err),
562        }
563    }
564
565    /// Select the next worker according to the routing mode.
566    /// Increments round-robin counter if applicable.
567    /// Returns None for modes that require request lifecycle tracking or explicit routing hints.
568    pub fn select_next_worker(&self) -> Option<u64> {
569        let instance_ids = self.client.instance_ids_avail();
570        let count = instance_ids.len();
571        if count == 0 {
572            return None;
573        }
574
575        match self.router_mode {
576            RouterMode::RoundRobin => {
577                let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
578                Some(instance_ids[counter % count])
579            }
580            RouterMode::Random => {
581                let counter = rand::rng().random::<u64>() as usize;
582                Some(instance_ids[counter % count])
583            }
584            RouterMode::PowerOfTwoChoices
585            | RouterMode::Direct
586            | RouterMode::LeastLoaded
587            | RouterMode::DeviceAwareWeighted => None,
588            RouterMode::KV => {
589                panic!(
590                    "select_next_worker should not be called for {:?} routing mode",
591                    self.router_mode
592                )
593            }
594        }
595    }
596
597    /// Peek the next worker according to the routing mode without incrementing the counter.
598    /// Useful for checking if a worker is suitable before committing to it.
599    /// Returns None for modes that require request lifecycle tracking or explicit routing hints.
600    pub fn peek_next_worker(&self) -> Option<u64> {
601        let instance_ids = self.client.instance_ids_avail();
602        let count = instance_ids.len();
603        if count == 0 {
604            return None;
605        }
606
607        match self.router_mode {
608            RouterMode::RoundRobin => {
609                // Just peek at the current counter value without incrementing
610                let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
611                Some(instance_ids[counter % count])
612            }
613            RouterMode::Random => {
614                // For random, peeking implies a fresh random selection since it's stateless.
615                // Note: The caller must realize that select_next_worker() will pick a DIFFERENT random worker.
616                let counter = rand::rng().random::<u64>() as usize;
617                Some(instance_ids[counter % count])
618            }
619            RouterMode::PowerOfTwoChoices
620            | RouterMode::Direct
621            | RouterMode::LeastLoaded
622            | RouterMode::DeviceAwareWeighted => None,
623            RouterMode::KV => {
624                panic!(
625                    "peek_next_worker should not be called for {:?} routing mode",
626                    self.router_mode
627                )
628            }
629        }
630    }
631
632    fn occupancy_state(&self) -> anyhow::Result<Arc<RoutingOccupancyState>> {
633        self.occupancy_state.clone().ok_or_else(|| {
634            anyhow::anyhow!(
635                "routing occupancy state not initialized for endpoint {}",
636                self.client.endpoint.id()
637            )
638        })
639    }
640
641    /*
642    pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
643        let subject = self.client.endpoint.subject();
644        tracing::debug!("static got subject: {subject}");
645        let request = request.map(|req| AddressedRequest::new(req, subject));
646        tracing::debug!("router generate");
647        self.addressed.generate(request).await
648    }
649    */
650
651    async fn generate_with_fault_detection(
652        &self,
653        mut instance_id: u64,
654        request: SingleIn<T>,
655    ) -> anyhow::Result<ManyOut<U>> {
656        let route_start = Instant::now();
657        let request_id = request.id().to_string();
658        let route_span = if matches!(self.router_mode, RouterMode::KV) {
659            tracing::Span::none()
660        } else {
661            tracing::info_span!(
662                "router.route_request",
663                request_id = %request_id,
664                worker_id = instance_id,
665                router_mode = ?self.router_mode,
666            )
667        };
668
669        // Check if all workers are busy (when fault detection is enabled).
670        if self.fault_detection_enabled {
671            let free_instances = self.client.instance_ids_free();
672            if free_instances.is_empty() {
673                // Check if we actually have any instances at all
674                let all_instances = self.client.instance_ids();
675                if !all_instances.is_empty() {
676                    tracing::warn!(
677                        instance_id,
678                        total_workers = all_instances.len(),
679                        "Rejecting request: all workers are busy"
680                    );
681                    let cause = PipelineError::ServiceOverloaded(
682                        "All workers are busy, please retry later".to_string(),
683                    );
684                    return Err(DynamoError::builder()
685                        .error_type(ErrorType::ResourceExhausted)
686                        .message("All workers are busy, please retry later")
687                        .cause(cause)
688                        .build()
689                        .into());
690                }
691            }
692        }
693
694        // Get the address based on discovered transport type.
695        // If the selected instance disappeared between selection and dispatch
696        // (e.g. deregistered during scale-down), fall back to another available
697        // instance rather than returning a spurious 500.
698        let (address, _transport_kind) = {
699            use crate::component::TransportType;
700
701            let resolve_transport = |id: u64| {
702                let instances = self.client.instances();
703                instances
704                    .iter()
705                    .find(|i| i.instance_id == id)
706                    .map(|instance| match &instance.transport {
707                        TransportType::Http(http_endpoint) => {
708                            tracing::debug!(
709                                instance_id = id,
710                                http_endpoint = %http_endpoint,
711                                "Using HTTP transport for instance"
712                            );
713                            (http_endpoint.clone(), "transport.http.request")
714                        }
715                        TransportType::Tcp(tcp_endpoint) => {
716                            tracing::debug!(
717                                instance_id = id,
718                                tcp_endpoint = %tcp_endpoint,
719                                "Using TCP transport for instance"
720                            );
721                            (tcp_endpoint.clone(), "transport.tcp.request")
722                        }
723                        TransportType::Nats(subject) => {
724                            tracing::debug!(
725                                instance_id = id,
726                                subject = %subject,
727                                "Using NATS transport for instance"
728                            );
729                            (subject.clone(), "transport.nats.request")
730                        }
731                    })
732            };
733
734            if let Some(result) = resolve_transport(instance_id) {
735                result
736            } else {
737                // Instance vanished — pick a different one from the current
738                // availability list and retry the lookup once.
739                let avail = self.client.instance_ids_avail();
740                let fallback_id = avail.iter().copied().find(|&id| id != instance_id);
741                match fallback_id {
742                    Some(id) => {
743                        tracing::warn!(
744                            original_instance = instance_id,
745                            fallback_instance = id,
746                            "Instance disappeared during routing, reselecting"
747                        );
748                        instance_id = id;
749                        resolve_transport(id).ok_or_else(|| {
750                            anyhow::anyhow!(
751                                "Fallback instance {} also not found for endpoint {}",
752                                id,
753                                self.client.endpoint.id()
754                            )
755                        })?
756                    }
757                    None => {
758                        return Err(anyhow::anyhow!(
759                            "Instance {} not found and no other instances available \
760                             for endpoint {}",
761                            instance_id,
762                            self.client.endpoint.id()
763                        ));
764                    }
765                }
766            }
767        };
768
769        let request = request.map(|req| AddressedRequest::new(req, address));
770
771        STAGE_DURATION_SECONDS
772            .with_label_values(&["route"])
773            .observe(route_start.elapsed().as_secs_f64());
774
775        let _nvtx_transport = dynamo_nvtx_range!(_transport_kind);
776        let stream: anyhow::Result<ManyOut<U>> = self
777            .addressed
778            .generate(request)
779            .instrument(route_span)
780            .await;
781        match stream {
782            Ok(stream) => {
783                if !self.fault_detection_enabled {
784                    return Ok(stream);
785                }
786                let engine_ctx = stream.context();
787                let client = self.client.clone();
788                let client_for_timeout = self.client.clone();
789                let stream = stream.map(move |res| {
790                    // Check if the error is migratable (indicates worker/connection failure)
791                    if let Some(err) = res.err()
792                        && is_inhibited(&err)
793                    {
794                        tracing::debug!(
795                            "Reporting instance {instance_id} down due to migratable error: {err}"
796                        );
797                        client.report_instance_down(instance_id);
798                    }
799                    res
800                });
801
802                // Request-plane inactivity timeout: emit a ResponseTimeout error item
803                // when the backend stops producing output. This triggers is_inhibited()
804                // → report_instance_down() to quarantine the worker.
805                let stream: Pin<Box<dyn Stream<Item = U> + Send>> = if let Some(timeout) =
806                    self.response_timeout
807                {
808                    Box::pin(async_stream::stream! {
809                        let mut inner = Box::pin(stream);
810                        loop {
811                            tokio::select! {
812                                biased;
813                                item = inner.next() => {
814                                    match item {
815                                        Some(item) => yield item,
816                                        None => break,
817                                    }
818                                }
819                                _ = tokio::time::sleep(timeout) => {
820                                    tracing::warn!(
821                                        instance_id,
822                                        timeout_secs = timeout.as_secs(),
823                                        "backend response inactivity timeout — quarantining worker"
824                                    );
825                                    client_for_timeout.report_instance_down(instance_id);
826                                    yield U::from_err(
827                                        crate::error::DynamoError::builder()
828                                            .error_type(crate::error::ErrorType::ResponseTimeout)
829                                            .message("backend response inactivity timeout")
830                                            .build()
831                                    );
832                                    break;
833                                }
834                            }
835                        }
836                    })
837                } else {
838                    Box::pin(stream)
839                };
840
841                Ok(ResponseStream::new(stream, engine_ctx))
842            }
843            Err(err) => {
844                if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
845                    tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
846                    self.client.report_instance_down(instance_id);
847                }
848                Err(err)
849            }
850        }
851    }
852}
853
854#[async_trait]
855impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
856where
857    T: Data + Serialize,
858    U: Data + for<'de> Deserialize<'de> + MaybeError,
859{
860    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
861        match self.router_mode {
862            RouterMode::Random => self.random(request).await,
863            RouterMode::RoundRobin => self.round_robin(request).await,
864            RouterMode::PowerOfTwoChoices => self.power_of_two_choices(request).await,
865            RouterMode::KV => {
866                anyhow::bail!("KV routing should not call generate on PushRouter");
867            }
868            RouterMode::Direct => {
869                anyhow::bail!(
870                    "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
871                );
872            }
873            RouterMode::LeastLoaded => self.least_loaded(request).await,
874            RouterMode::DeviceAwareWeighted => self.device_aware_weighted(request).await,
875        }
876    }
877}
878
879struct OccupancyTrackedStream<U: Data> {
880    inner: ManyOut<U>,
881    state: Arc<RoutingOccupancyState>,
882    instance_id: u64,
883}
884
885impl<U: Data> Drop for OccupancyTrackedStream<U> {
886    fn drop(&mut self) {
887        self.state.decrement(self.instance_id);
888    }
889}
890
891impl<U: Data> std::fmt::Debug for OccupancyTrackedStream<U> {
892    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
893        f.debug_struct("OccupancyTrackedStream")
894            .field("instance_id", &self.instance_id)
895            .finish()
896    }
897}
898
899impl<U: Data> Stream for OccupancyTrackedStream<U> {
900    type Item = U;
901
902    fn poll_next(
903        mut self: Pin<&mut Self>,
904        cx: &mut std::task::Context<'_>,
905    ) -> Poll<Option<Self::Item>> {
906        self.inner.as_mut().poll_next(cx)
907    }
908}
909
910impl<U: Data> AsyncEngineContextProvider for OccupancyTrackedStream<U> {
911    fn context(&self) -> Arc<dyn AsyncEngineContext> {
912        self.inner.context()
913    }
914}
915
916impl<U: Data> crate::engine::AsyncEngineStream<U> for OccupancyTrackedStream<U> {}
917
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        DistributedRuntime, Runtime,
        distributed::DistributedConfig,
        error::DynamoError,
        pipeline::{ResponseStream, context::Controller},
    };
    use serde::{Deserialize, Serialize};

    /// Minimal response type whose only payload is an optional error, used to
    /// instantiate `PushRouter<u64, TestResponse>` in the tests below.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct TestResponse {
        error: Option<DynamoError>,
    }

    impl MaybeError for TestResponse {
        fn from_err(err: impl std::error::Error + 'static) -> Self {
            Self {
                error: Some(DynamoError::from(
                    Box::new(err) as Box<dyn std::error::Error + 'static>
                )),
            }
        }

        fn err(&self) -> Option<DynamoError> {
            self.error.clone()
        }
    }

    #[test]
    fn p2c_selects_lower_load_worker() {
        let state = RoutingOccupancyState::default();
        for _ in 0..10 {
            state.increment(1);
        }
        state.increment(2);

        // With only two workers, p2c_select_from must pick both and choose id=2 (lower load).
        let result = p2c_select_from(&state, &[1, 2]);
        assert_eq!(result, 2);
    }

    #[test]
    fn p2c_selects_single_worker() {
        let state = RoutingOccupancyState::default();
        assert_eq!(p2c_select_from(&state, &[42]), 42);
    }

    #[test]
    fn p2c_treats_missing_counts_as_zero() {
        let state = RoutingOccupancyState::default();
        for _ in 0..5 {
            state.increment(1);
        }
        // Worker 2 has no entry — should be treated as 0, so it wins.
        let result = p2c_select_from(&state, &[1, 2]);
        assert_eq!(result, 2);
    }

    #[test]
    fn p2c_returns_valid_worker_on_tie() {
        let state = RoutingOccupancyState::default();
        for _ in 0..3 {
            state.increment(1);
            state.increment(2);
        }

        for _ in 0..100 {
            let result = p2c_select_from(&state, &[1, 2]);
            assert!(result == 1 || result == 2);
        }
    }

    #[test]
    fn occupancy_permit_decrements_before_stream_creation() {
        let state = Arc::new(RoutingOccupancyState::default());
        state.increment(42);
        let permit = OccupancyPermit::new(state.clone(), 42);
        assert_eq!(state.load(42), 1);
        drop(permit);
        assert_eq!(state.load(42), 0);
    }

    #[test]
    fn occupancy_tracked_stream_decrements_on_drop() {
        let state = Arc::new(RoutingOccupancyState::default());
        state.increment(7);
        let permit = OccupancyPermit::new(state.clone(), 7);
        let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
        let stream = permit.into_tracked_stream(ResponseStream::new(
            Box::pin(tokio_stream::iter(vec![1u64])),
            ctx,
        ));
        assert_eq!(state.load(7), 1);
        drop(stream);
        assert_eq!(state.load(7), 0);
    }

    #[test]
    fn p2c_lifecycle_tracks_inflight_counts_with_shared_tracker() {
        let state = Arc::new(RoutingOccupancyState::default());
        let mut permits = Vec::new();
        for _ in 0..5 {
            let selected = p2c_select_from(&state, &[1, 2]);
            state.increment(selected);
            permits.push(OccupancyPermit::new(state.clone(), selected));
        }

        let total = state.load(1) + state.load(2);
        assert_eq!(total, 5, "5 in-flight requests should be tracked");

        drop(permits);
        let total = state.load(1) + state.load(2);
        assert_eq!(total, 0, "All guards dropped, counts should be 0");
    }

    #[test]
    fn p2c_never_selects_dominated_worker() {
        let state = RoutingOccupancyState::default();
        for _ in 0..100 {
            state.increment(3);
        }

        let mut selected = [0u32; 3];
        for _ in 0..1000 {
            let result = p2c_select_from(&state, &[1, 2, 3]);
            match result {
                1 => selected[0] += 1,
                2 => selected[1] += 1,
                3 => selected[2] += 1,
                _ => panic!("unexpected worker id"),
            }
        }
        assert_eq!(
            selected[2], 0,
            "Worker 3 (load=100) should never be selected against load=0 workers, but got {} times",
            selected[2]
        );
    }

    #[tokio::test]
    async fn least_loaded_selects_exact_min_and_tracks_counts() {
        let state = Arc::new(RoutingOccupancyState::default());
        state.increment(1);
        state.increment(1);
        state.increment(2);

        let selected = state
            .select_exact_min_and_increment(&[1, 2, 3])
            .await
            .unwrap();
        assert_eq!(selected, 3);

        let permit = OccupancyPermit::new(state.clone(), selected);
        assert_eq!(state.load(selected), 1);
        drop(permit);
        assert_eq!(state.load(selected), 0);
    }

    #[tokio::test]
    async fn least_loaded_select_and_peek_return_none_with_available_worker() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_least_loaded_router".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::LeastLoaded)
            .await
            .unwrap();

        assert_eq!(router.select_next_worker(), None);
        assert_eq!(router.peek_next_worker(), None);

        rt.shutdown();
    }

    #[tokio::test]
    async fn device_aware_cpu_only_selects_least_loaded_instance() {
        let state = RoutingOccupancyState::default();
        // All candidates are CPU. Make worker 2 the least-loaded one.
        for _ in 0..3 {
            state.increment(1);
        }
        state.increment(3);

        let instance_ids = vec![1, 2, 3];
        let device_type_map = HashMap::from([
            (1, Some(DeviceType::Cpu)),
            (2, Some(DeviceType::Cpu)),
            (3, Some(DeviceType::Cpu)),
        ]);

        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
        assert_eq!(candidates, vec![1, 2, 3]);

        let selected = state
            .select_exact_min_and_increment(&candidates)
            .await
            .unwrap();
        assert_eq!(selected, 2);
    }

    #[tokio::test]
    async fn device_aware_non_cpu_only_selects_least_loaded_instance() {
        let state = RoutingOccupancyState::default();
        // All candidates are non-CPU. Make worker 2 the least-loaded one.
        for _ in 0..3 {
            state.increment(1);
        }
        state.increment(3);

        let instance_ids = vec![1, 2, 3];
        let device_type_map = HashMap::from([
            (1, Some(DeviceType::Cuda)),
            (2, Some(DeviceType::Cuda)),
            (3, Some(DeviceType::Cuda)),
        ]);

        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
        assert_eq!(candidates, vec![1, 2, 3]);

        let selected = state
            .select_exact_min_and_increment(&candidates)
            .await
            .unwrap();
        assert_eq!(selected, 2);
    }

    // Uses `#[tokio::test]` + `.await` like the other async-selection tests
    // rather than a nested `futures::executor::block_on`.
    #[tokio::test]
    async fn device_aware_group_uses_ratio_budget() {
        let state = RoutingOccupancyState::default();
        // CPU ids: 1,2 ; non-CPU ids: 3,4
        for _ in 0..4 {
            state.increment(3);
            state.increment(4);
        }
        // CPU inflight can differ across instances; budgeting uses total CPU inflight.
        for _ in 0..3 {
            state.increment(1);
        }
        // total_non_cpu_inflight=8, cpu_count=2, non_cpu_count=2, ratio=2
        // allowed_cpu_inflight = 8*2/(2*2)=4
        // total_cpu_inflight=3 < 4 => choose CPU group.
        let instance_ids = vec![1, 2, 3, 4];
        let device_type_map = HashMap::from([
            (1, Some(DeviceType::Cpu)),
            (2, Some(DeviceType::Cpu)),
            (3, Some(DeviceType::Cuda)),
            (4, Some(DeviceType::Cuda)),
        ]);

        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 2);
        assert_eq!(candidates, vec![1, 2]);

        // Within selected CPU group, final choice should be the least-loaded instance (id=2).
        let selected = state
            .select_exact_min_and_increment(&candidates)
            .await
            .unwrap();
        assert_eq!(selected, 2);
    }

    #[tokio::test]
    async fn device_aware_weighted_select_and_peek_return_none_with_available_worker() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_device_aware_router".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let router =
            PushRouter::<u64, TestResponse>::from_client(client, RouterMode::DeviceAwareWeighted)
                .await
                .unwrap();

        assert_eq!(router.select_next_worker(), None);
        assert_eq!(router.peek_next_worker(), None);

        rt.shutdown();
    }

    /// When the router selects an instance that has deregistered between selection
    /// and transport resolution, it should fall back to another available instance
    /// rather than returning a 500 error.
    #[tokio::test]
    async fn transport_resolution_falls_back_when_selected_instance_disappears() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_transport_fallback".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        // Register one real instance so it appears in instance_source.
        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let real_id = client.instance_ids()[0];

        // Inject a stale ID into instance_avail that does NOT exist in
        // instance_source. This simulates the race window where an instance
        // deregistered after selection but before transport resolution.
        let stale_id = real_id + 1000;
        client.override_instance_avail(vec![stale_id, real_id]);

        // Build a round-robin router over `[stale_id, real_id]` and route a
        // request through `generate()`. If round-robin picks `stale_id`, the
        // transport lookup must fall back to `real_id`.
        let router =
            PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
                .await
                .unwrap();

        // We cannot fully test the network send without a worker, but we can
        // verify it doesn't fail at the transport resolution stage by checking
        // that the error (if any) is a transport/network error, not
        // "Instance not found".
        // NOTE: if round-robin happens to pick `real_id` first, this single
        // call does not exercise the fallback path; the assertion below still
        // holds in that case.
        let request = SingleIn::new(42u64);
        let result = router.generate(request).await;

        // The request may fail at the network level (no actual worker), but it
        // must NOT fail with "Instance X not found" — that would mean the
        // fallback did not work.
        if let Err(err) = &result {
            let msg = format!("{err}");
            assert!(
                !msg.contains("not found"),
                "Transport resolution should have fallen back, but got: {msg}"
            );
        }

        rt.shutdown();
    }

    /// When no instances are available at all (both primary and fallback),
    /// the router should return a clear error.
    #[tokio::test]
    async fn transport_resolution_errors_when_no_instances_available() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_transport_no_fallback".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        // Register an instance so we can create the router (needs transport setup).
        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let router =
            PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
                .await
                .unwrap();

        // Override avail to contain only a stale ID with no real backing
        // instance AND no other available fallback.
        let stale_id = 99999;
        client.override_instance_avail(vec![stale_id]);

        let request = SingleIn::new(42u64);
        let result = router.generate(request).await;

        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("not found") && msg.contains("no other instances available"),
            "Expected clear error about missing instance with no fallback, got: {msg}"
        );

        rt.shutdown();
    }
}