Skip to main content

dynamo_runtime/pipeline/network/egress/
push_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{AsyncEngineContextProvider, ResponseStream};
5use crate::error::{BackendError, DynamoError, ErrorType, match_error_chain};
6use crate::{
7    component::{
8        Client, DeviceType, Endpoint, Instance, RoutingOccupancyState,
9        get_or_create_routing_occupancy_state,
10    },
11    discovery::EndpointInstanceId,
12    dynamo_nvtx_range,
13    engine::{AsyncEngine, AsyncEngineContext, Data},
14    metrics::frontend_perf::{STAGE_DURATION_SECONDS, STAGE_ROUTE},
15    pipeline::{
16        AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
17        error::{PipelineError, PipelineErrorExt},
18    },
19    protocols::{EndpointId, maybe_error::MaybeError},
20    traits::DistributedRuntimeProvider,
21};
22use async_trait::async_trait;
23use futures::Stream;
24use rand::Rng;
25use serde::{Deserialize, Serialize};
26use std::{
27    collections::HashMap,
28    marker::PhantomData,
29    pin::Pin,
30    sync::{
31        Arc,
32        atomic::{AtomicU64, Ordering},
33    },
34    task::Poll,
35    time::Instant,
36};
37use tokio_stream::StreamExt;
38use tracing::Instrument;
39
40/// Check if an error chain indicates the worker should be reported as down.
41fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
42    const INHIBITED: &[ErrorType] = &[
43        ErrorType::CannotConnect,
44        ErrorType::Disconnected,
45        ErrorType::ConnectionTimeout,
46        ErrorType::ResponseTimeout,
47        ErrorType::Backend(BackendError::EngineShutdown),
48    ];
49    match_error_chain(err, INHIBITED, &[])
50}
51
52/// Read the backend response inactivity timeout from the environment.
53/// Reuses `DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS` — the same env var
54/// as the HTTP-layer safety net in `disconnect.rs`.
55fn response_inactivity_timeout() -> Option<std::time::Duration> {
56    use crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS;
57    std::env::var(DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS)
58        .ok()
59        .and_then(|s| s.parse::<u64>().ok())
60        .filter(|&secs| secs > 0)
61        .map(std::time::Duration::from_secs)
62}
63
64struct OccupancyPermit {
65    state: Arc<RoutingOccupancyState>,
66    instance_id: u64,
67    armed: bool,
68}
69
70impl OccupancyPermit {
71    fn new(state: Arc<RoutingOccupancyState>, instance_id: u64) -> Self {
72        Self {
73            state,
74            instance_id,
75            armed: true,
76        }
77    }
78
79    fn into_tracked_stream<U: Data>(mut self, stream: ManyOut<U>) -> ManyOut<U> {
80        self.armed = false;
81        let engine_ctx = stream.context();
82        ResponseStream::new(
83            Box::pin(OccupancyTrackedStream {
84                inner: stream,
85                state: self.state.clone(),
86                instance_id: self.instance_id,
87            }),
88            engine_ctx,
89        )
90    }
91
92    fn instance_id(&self) -> u64 {
93        self.instance_id
94    }
95}
96
97impl Drop for OccupancyPermit {
98    fn drop(&mut self) {
99        if self.armed {
100            self.state.decrement(self.instance_id);
101        }
102    }
103}
104
105/// Trait for monitoring worker load and determining busy state.
106/// Implementations can define custom load metrics and busy thresholds.
107#[async_trait]
108pub trait WorkerLoadMonitor: Send + Sync {
109    /// Start background monitoring of worker load.
110    /// This should spawn background tasks that update the client's free instances.
111    async fn start_monitoring(&self) -> anyhow::Result<()>;
112}
113
114#[derive(Clone)]
115pub struct PushRouter<T, U>
116where
117    T: Data + Serialize,
118    U: Data + for<'de> Deserialize<'de>,
119{
120    // TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
121    /// The Client is how we gather remote endpoint information from etcd.
122    pub client: Client,
123
124    /// How we choose which instance to send traffic to.
125    ///
126    /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
127    /// not using it as an AsyncEngine.
128    /// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
129    /// dynamo-llm's KV Routing does this.
130    router_mode: RouterMode,
131
132    /// Number of round robin requests handled. Used to decide which server is next.
133    round_robin_counter: Arc<AtomicU64>,
134
135    /// The next step in the chain. PushRouter (this object) picks an instances,
136    /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
137    addressed: Arc<AddressedPushRouter>,
138
139    /// When false, `generate_with_fault_detection` skips fault detection logic:
140    /// it won't call `report_instance_down` on errors, and it uses the raw discovery
141    /// instance list instead of the filtered avail list. Use for recovery/query paths
142    /// where transient failures are expected.
143    fault_detection_enabled: bool,
144
145    /// Cached response inactivity timeout. Read once at construction from
146    /// [`environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS`](crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS) to avoid a syscall per request.
147    response_timeout: Option<std::time::Duration>,
148
149    /// Shared request occupancy state for tracked routing modes.
150    occupancy_state: Option<Arc<RoutingOccupancyState>>,
151
152    /// An internal Rust type. This says that PushRouter is generic over the T and U types,
153    /// which are the input and output types of it's `generate` function. It allows the
154    /// compiler to specialize us at compile time.
155    _phantom: PhantomData<(T, U)>,
156}
157
158#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
159#[serde(rename_all = "snake_case")]
160pub enum RouterMode {
161    #[default]
162    RoundRobin,
163    Random,
164    PowerOfTwoChoices,
165    KV,
166    Direct,
167    LeastLoaded,
168    /// Device-aware weighted routing for heterogeneous workers.
169    DeviceAwareWeighted,
170}
171
172impl RouterMode {
173    pub fn is_kv_routing(&self) -> bool {
174        *self == RouterMode::KV
175    }
176
177    pub fn is_direct_routing(&self) -> bool {
178        *self == RouterMode::Direct
179    }
180}
181
182/// Pick the instance with lower in-flight count from two random candidates.
183/// Returns the single instance if only one is available.
184fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64]) -> u64 {
185    let count = instance_ids.len();
186    if count == 1 {
187        return instance_ids[0];
188    }
189    let mut rng = rand::rng();
190    let idx1 = rng.random_range(0..count);
191    let idx2 = (idx1 + 1 + rng.random_range(0..count - 1)) % count;
192    let id1 = instance_ids[idx1];
193    let id2 = instance_ids[idx2];
194    let load1 = occupancy_state.load(id1);
195    let load2 = occupancy_state.load(id2);
196    let selected = if load1 <= load2 { id1 } else { id2 };
197    tracing::debug!(
198        candidate_a = id1,
199        candidate_a_load = load1,
200        candidate_b = id2,
201        candidate_b_load = load2,
202        selected = selected,
203        "p2c selection"
204    );
205    selected
206}
207
208/// Select the target device group for the next request in `DeviceAwareWeighted` mode.
209///
210/// If only one class exists (all CPU or all non-CPU), returns that class directly.
211/// If both classes exist, compares capability-normalized load and returns the less-loaded group.
212///
213/// Budget check (integer form):
214/// `allowed_cpu_inflight = total_non_cpu_inflight * cpu_count / (ratio * non_cpu_count)`
215/// and choose CPU when `total_cpu_inflight < allowed_cpu_inflight`.
216///
217/// `ratio` is `non_cpu_to_cpu_ratio` (from `DYN_ENCODER_CUDA_TO_CPU_RATIO`,
218/// default `8` in `device_aware_weighted`).
219fn device_aware_candidate_group(
220    state: &RoutingOccupancyState,
221    instance_ids: &[u64],
222    device_type_map: &HashMap<u64, Option<DeviceType>>,
223    non_cpu_to_cpu_ratio: usize,
224) -> Vec<u64> {
225    let cpu_ids: Vec<u64> = instance_ids
226        .iter()
227        .copied()
228        .filter(|id| matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
229        .collect();
230    let non_cpu_ids: Vec<u64> = instance_ids
231        .iter()
232        .copied()
233        .filter(|id| !matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
234        .collect();
235
236    if cpu_ids.is_empty() {
237        return non_cpu_ids;
238    }
239    if non_cpu_ids.is_empty() {
240        return cpu_ids;
241    }
242
243    // Both classes exist: compute a budget for CPU in-flight requests.
244    let total_non_cpu_inflight: u64 = non_cpu_ids.iter().map(|id| state.load(*id)).sum();
245    let total_cpu_inflight: u64 = cpu_ids.iter().map(|id| state.load(*id)).sum();
246    let cpu_count = cpu_ids.len() as u64;
247    let non_cpu_count = non_cpu_ids.len() as u64;
248    let allowed_cpu_inflight = total_non_cpu_inflight.saturating_mul(cpu_count)
249        / ((non_cpu_to_cpu_ratio as u64).saturating_mul(non_cpu_count));
250
251    if total_cpu_inflight < allowed_cpu_inflight {
252        cpu_ids
253    } else {
254        non_cpu_ids
255    }
256}
257
258/// At most one `list_and_watch` per endpoint, across all `PushRouter`
259/// instances. Entry removed on watcher exit so a later router can re-arm.
260static ENDPOINT_WATCHER_ACTIVE: std::sync::OnceLock<dashmap::DashMap<EndpointId, ()>> =
261    std::sync::OnceLock::new();
262
/// Watch discovery for instance removals and cancel pending response-stream
/// registrations on the removed instance, unblocking queued requests with
/// a migratable `Disconnected` error. Uses raw `list_and_watch` events
/// (not a coalesced snapshot diff) so a rapid remove→re-add of the same
/// identity is not silently swallowed. Keyed by full `EndpointInstanceId`.
///
/// - `endpoint`: the endpoint whose discovery events are watched.
/// - `addressed`: the router whose pending streams are cancelled on `Removed`
///   events and whose per-instance tombstone is cleared on `Added` events.
/// - `cancel_token`: stops the watcher, including while it waits to reconnect.
///
/// Returns immediately; the watch itself runs on a spawned Tokio task. If a watcher
/// is already registered for this endpoint in `ENDPOINT_WATCHER_ACTIVE`, this call is
/// a no-op, so only the first caller's `addressed` router is wired to the watcher.
/// NOTE(review): later routers for the same endpoint rely on that first watcher —
/// confirm `cancel_instance_streams` also reaches their pending streams (e.g. via
/// shared response-transport state).
fn spawn_instance_removal_watcher(
    endpoint: Endpoint,
    addressed: Arc<AddressedPushRouter>,
    cancel_token: tokio_util::sync::CancellationToken,
) {
    use crate::discovery::{
        DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
    };
    use tokio_stream::StreamExt as _;

    // One watcher per endpoint: if one is already running, skip.
    let guard = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
    let endpoint_id = endpoint.id();
    if guard.insert(endpoint_id.clone(), ()).is_some() {
        tracing::debug!(
            ?endpoint_id,
            "Instance removal watcher already running for this endpoint, skipping"
        );
        return;
    }

    let endpoint_name = endpoint.name().to_string();

    // NOTE(review): `tokio::spawn` panics when called outside a Tokio runtime. In that
    // case the map entry inserted above is never released, because `GuardRelease` is
    // only created inside the task.
    tokio::spawn(async move {
        // Release on every exit path (including panic); a leaked entry
        // silently disables removal cancellation until process restart.
        struct GuardRelease(EndpointId);
        impl Drop for GuardRelease {
            fn drop(&mut self) {
                if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
                    map.remove(&self.0);
                }
            }
        }
        let _release = GuardRelease(endpoint_id);

        let namespace = endpoint.component().namespace().name();
        let component = endpoint.component().name().to_string();

        // Reconnect on transient discovery failure; cancel-aware backoff.
        const RECONNECT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(5);
        'reconnect: loop {
            // A fresh query per (re)connect attempt.
            let query = DiscoveryQuery::Endpoint {
                namespace: namespace.clone(),
                component: component.clone(),
                endpoint: endpoint_name.clone(),
            };

            let mut stream = match endpoint.drt().discovery().list_and_watch(query, None).await {
                Ok(s) => s,
                Err(e) => {
                    tracing::warn!(
                        endpoint = %endpoint_name,
                        "Failed to start instance removal watcher (will retry): {e}"
                    );
                    // Wait before retrying, but exit promptly if cancelled.
                    tokio::select! {
                        _ = tokio::time::sleep(RECONNECT_BACKOFF) => continue 'reconnect,
                        _ = cancel_token.cancelled() => break 'reconnect,
                    }
                }
            };

            loop {
                tokio::select! {
                    event = stream.next() => {
                        match event {
                            Some(Ok(DiscoveryEvent::Removed(id))) => {
                                // Only endpoint-scoped instance ids are acted on.
                                if let DiscoveryInstanceId::Endpoint(eid) = &id {
                                    let n = addressed.cancel_instance_streams(eid).await;
                                    if n > 0 {
                                        tracing::warn!(
                                            namespace = %eid.namespace,
                                            component = %eid.component,
                                            endpoint = %eid.endpoint,
                                            instance_id = eid.instance_id,
                                            cancelled = n,
                                            "Cancelled pending response streams for removed \
                                             instance (discovery-driven cleanup)"
                                        );
                                    }
                                }
                            }
                            Some(Ok(DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)))) => {
                                // The instance (re)appeared: clear its tombstone on the
                                // addressed router (presumably recorded when it was
                                // removed — confirm in AddressedPushRouter).
                                let eid: EndpointInstanceId = inst.endpoint_instance_id();
                                addressed.clear_instance_tombstone(&eid).await;
                            }
                            // Any other event kind is ignored.
                            Some(Ok(_)) => {}
                            Some(Err(e)) => {
                                // Log and keep reading from the same stream.
                                tracing::warn!(
                                    endpoint = %endpoint_name,
                                    "Instance removal watcher stream error: {e}"
                                );
                            }
                            None => {
                                // NOTE(review): this reconnects immediately with no backoff,
                                // unlike the failure path above. A stream that keeps ending
                                // right away would make this loop spin.
                                tracing::warn!(
                                    endpoint = %endpoint_name,
                                    "Instance removal watcher stream ended; reconnecting"
                                );
                                continue 'reconnect;
                            }
                        }
                    }
                    _ = cancel_token.cancelled() => {
                        break 'reconnect;
                    }
                }
            }
        }

        tracing::debug!(endpoint = %endpoint_name, "Instance removal watcher exiting");
    });
}
380
381async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
382    // Get network manager and create client (no mode checks!)
383    let manager = endpoint.drt().network_manager();
384    let req_client = manager.create_client()?;
385    let resp_transport = endpoint.drt().tcp_server().await?;
386
387    tracing::debug!(
388        transport = req_client.transport_name(),
389        "Creating AddressedPushRouter with request plane client"
390    );
391
392    AddressedPushRouter::new(req_client, resp_transport)
393}
394
395impl<T, U> PushRouter<T, U>
396where
397    T: Data + Serialize,
398    U: Data + for<'de> Deserialize<'de> + MaybeError,
399{
400    /// Create a new PushRouter without a worker load monitor (no busy detection)
401    pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
402        Self::from_client_with_monitor(client, router_mode, None).await
403    }
404
405    /// Create a new PushRouter with fault detection disabled.
406    ///
407    /// Unlike `from_client`, this router will not call `report_instance_down` on
408    /// transient errors, and `direct()` uses the raw discovery instance list instead
409    /// of the filtered avail list. Use for recovery/query paths.
410    pub async fn from_client_no_fault_detection(
411        client: Client,
412        router_mode: RouterMode,
413    ) -> anyhow::Result<Self> {
414        let addressed = addressed_router(&client.endpoint).await?;
415
416        let occupancy_state = if matches!(
417            router_mode,
418            RouterMode::PowerOfTwoChoices
419                | RouterMode::LeastLoaded
420                | RouterMode::DeviceAwareWeighted
421        ) {
422            Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
423        } else {
424            None
425        };
426
427        // Cancel orphaned pending response streams when workers die.
428        spawn_instance_removal_watcher(
429            client.endpoint.clone(),
430            addressed.clone(),
431            client.endpoint.drt().primary_token(),
432        );
433
434        Ok(PushRouter {
435            client,
436            addressed,
437            router_mode,
438            round_robin_counter: Arc::new(AtomicU64::new(0)),
439            fault_detection_enabled: false,
440            response_timeout: response_inactivity_timeout(),
441            occupancy_state,
442            _phantom: PhantomData,
443        })
444    }
445
446    /// Create a new PushRouter with an optional worker load monitor.
447    ///
448    /// The rejection path is gated by `fault_detection_enabled` (true here);
449    /// busy detection itself is driven by the monitor via `client.update_free_instances(...)`.
450    /// If no thresholds are configured on the monitor (or no monitor is provided),
451    /// `client.instance_ids_free()` returns all instances and the gate never rejects.
452    pub async fn from_client_with_monitor(
453        client: Client,
454        router_mode: RouterMode,
455        worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
456    ) -> anyhow::Result<Self> {
457        let addressed = addressed_router(&client.endpoint).await?;
458
459        // Start worker monitor if provided and in dynamic mode
460        if let Some(monitor) = worker_monitor.as_ref() {
461            monitor.start_monitoring().await?;
462        }
463
464        let occupancy_state = if matches!(
465            router_mode,
466            RouterMode::PowerOfTwoChoices
467                | RouterMode::LeastLoaded
468                | RouterMode::DeviceAwareWeighted
469        ) {
470            Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
471        } else {
472            None
473        };
474
475        // Cancel orphaned pending response streams when workers die.
476        spawn_instance_removal_watcher(
477            client.endpoint.clone(),
478            addressed.clone(),
479            client.endpoint.drt().primary_token(),
480        );
481
482        let router = PushRouter {
483            client,
484            addressed,
485            router_mode,
486            round_robin_counter: Arc::new(AtomicU64::new(0)),
487            fault_detection_enabled: true,
488            response_timeout: response_inactivity_timeout(),
489            occupancy_state,
490            _phantom: PhantomData,
491        };
492
493        Ok(router)
494    }
495
496    /// Issue a request to the next available instance in a round-robin fashion
497    pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
498        let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
499
500        let instance_id = {
501            let instance_ids = self.client.instance_ids_avail();
502            let count = instance_ids.len();
503            if count == 0 {
504                return Err(anyhow::anyhow!(
505                    "no instances found for endpoint {}",
506                    self.client.endpoint.id()
507                ));
508            }
509            instance_ids[counter % count]
510        };
511        tracing::trace!("round robin router selected {instance_id}");
512
513        self.generate_with_fault_detection(instance_id, request)
514            .await
515    }
516
517    /// Issue a request to a random endpoint
518    pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
519        let instance_id = {
520            let instance_ids = self.client.instance_ids_avail();
521            let count = instance_ids.len();
522            if count == 0 {
523                return Err(anyhow::anyhow!(
524                    "no instances found for endpoint {}",
525                    self.client.endpoint.id()
526                ));
527            }
528            let counter = rand::rng().random::<u64>() as usize;
529            instance_ids[counter % count]
530        };
531        tracing::trace!("random router selected {instance_id}");
532
533        self.generate_with_fault_detection(instance_id, request)
534            .await
535    }
536
537    /// Issue a request using power-of-two-choices: pick 2 random healthy workers,
538    /// route to the one with fewer in-flight requests.
539    pub async fn power_of_two_choices(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
540        let state = self.occupancy_state()?;
541        let instance_id = {
542            let instance_ids = self
543                .client
544                .instance_ids_avail()
545                .iter()
546                .copied()
547                .collect::<Vec<_>>();
548            if instance_ids.is_empty() {
549                return Err(anyhow::anyhow!(
550                    "no instances found for endpoint {}",
551                    self.client.endpoint.id()
552                ));
553            }
554            p2c_select_from(state.as_ref(), &instance_ids)
555        };
556        state.increment(instance_id);
557        let permit = OccupancyPermit::new(state, instance_id);
558
559        match self
560            .generate_with_fault_detection(instance_id, request)
561            .await
562        {
563            Ok(stream) => Ok(permit.into_tracked_stream(stream)),
564            Err(err) => Err(err),
565        }
566    }
567
568    /// Issue a request to a specific endpoint
569    pub async fn direct(
570        &self,
571        request: SingleIn<T>,
572        instance_id: u64,
573    ) -> anyhow::Result<ManyOut<U>> {
574        // When fault detection is disabled, check the raw discovery list
575        // (not filtered by report_instance_down) so transient failures
576        // don't poison the instance for subsequent retries.
577        let found = if self.fault_detection_enabled {
578            self.client.instance_ids_avail().contains(&instance_id)
579        } else {
580            self.client.instance_ids().contains(&instance_id)
581        };
582
583        if !found {
584            return Err(anyhow::anyhow!(
585                "instance_id={instance_id} not found for endpoint {}",
586                self.client.endpoint.id()
587            ));
588        }
589
590        self.generate_with_fault_detection(instance_id, request)
591            .await
592    }
593
594    /// Issue a request using device-aware weighted routing.
595    ///
596    /// Instances are partitioned by device type (CPU vs non-CPU), then the router
597    /// applies a budget policy and selects the least-loaded instance within the
598    /// chosen group.
599    ///
600    /// If only one device class exists (all CPU or all non-CPU), this naturally
601    /// degenerates to least-loaded routing over the available instances.
602    pub async fn device_aware_weighted(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
603        let state = self.occupancy_state()?;
604        let instance_ids = self
605            .client
606            .instance_ids_avail()
607            .iter()
608            .copied()
609            .collect::<Vec<_>>();
610
611        if instance_ids.is_empty() {
612            return Err(anyhow::anyhow!(
613                "no instances found for endpoint {}",
614                self.client.endpoint.id()
615            ));
616        }
617
618        // Apply a unified policy for all endpoints.
619        let endpoint_id = self.client.endpoint.id();
620
621        // For encoder endpoints, partition by device type
622        let instances = self.client.instances();
623        let device_type_map: std::collections::HashMap<u64, Option<DeviceType>> = instances
624            .iter()
625            .map(|inst| (inst.instance_id, inst.device_type.clone()))
626            .collect();
627
628        // Apply budget-based routing to determine which group to send to
629        let cuda_to_cpu_ratio = std::env::var("DYN_ENCODER_CUDA_TO_CPU_RATIO")
630            .ok()
631            .and_then(|v| v.parse::<usize>().ok())
632            .filter(|v| *v >= 1)
633            .unwrap_or(8);
634        let candidates = device_aware_candidate_group(
635            state.as_ref(),
636            &instance_ids,
637            &device_type_map,
638            cuda_to_cpu_ratio,
639        );
640
641        // Select least-loaded within the chosen group
642        let instance_id = state
643            .select_exact_min_and_increment(&candidates)
644            .await
645            .ok_or_else(|| {
646                anyhow::anyhow!(
647                    "no instances in selected device group for endpoint {}",
648                    endpoint_id
649                )
650            })?;
651        let permit = OccupancyPermit::new(state.clone(), instance_id);
652        let is_cpu = matches!(
653            device_type_map.get(&instance_id),
654            Some(Some(DeviceType::Cpu))
655        );
656        tracing::info!(
657            endpoint = %endpoint_id,
658            selected_instance = instance_id,
659            is_cpu,
660            "DeviceAwareWeighted selected instance"
661        );
662
663        match self
664            .generate_with_fault_detection(instance_id, request)
665            .await
666        {
667            Ok(stream) => Ok(permit.into_tracked_stream(stream)),
668            Err(err) => Err(err),
669        }
670    }
671
672    /// Issue a request to the instance with the fewest active connections.
673    pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
674        let state = self.occupancy_state()?;
675        let instance_ids = self
676            .client
677            .instance_ids_avail()
678            .iter()
679            .copied()
680            .collect::<Vec<_>>();
681        let instance_id = state
682            .select_exact_min_and_increment(&instance_ids)
683            .await
684            .ok_or_else(|| {
685                anyhow::anyhow!(
686                    "no instances found for endpoint {}",
687                    self.client.endpoint.id()
688                )
689            })?;
690        let permit = OccupancyPermit::new(state.clone(), instance_id);
691        tracing::trace!(
692            "least loaded router selected {instance_id} (connections: {})",
693            state.load(instance_id)
694        );
695
696        match self
697            .generate_with_fault_detection(instance_id, request)
698            .await
699        {
700            Ok(stream) => Ok(permit.into_tracked_stream(stream)),
701            Err(err) => Err(err),
702        }
703    }
704
705    /// Select the next worker according to the routing mode.
706    /// Increments round-robin counter if applicable.
707    /// Returns None for modes that require request lifecycle tracking or explicit routing hints.
708    pub fn select_next_worker(&self) -> Option<u64> {
709        let instance_ids = self.client.instance_ids_avail();
710        let count = instance_ids.len();
711        if count == 0 {
712            return None;
713        }
714
715        match self.router_mode {
716            RouterMode::RoundRobin => {
717                let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
718                Some(instance_ids[counter % count])
719            }
720            RouterMode::Random => {
721                let counter = rand::rng().random::<u64>() as usize;
722                Some(instance_ids[counter % count])
723            }
724            RouterMode::PowerOfTwoChoices
725            | RouterMode::Direct
726            | RouterMode::LeastLoaded
727            | RouterMode::DeviceAwareWeighted => None,
728            RouterMode::KV => {
729                panic!(
730                    "select_next_worker should not be called for {:?} routing mode",
731                    self.router_mode
732                )
733            }
734        }
735    }
736
737    /// Peek the next worker according to the routing mode without incrementing the counter.
738    /// Useful for checking if a worker is suitable before committing to it.
739    /// Returns None for modes that require request lifecycle tracking or explicit routing hints.
740    pub fn peek_next_worker(&self) -> Option<u64> {
741        let instance_ids = self.client.instance_ids_avail();
742        let count = instance_ids.len();
743        if count == 0 {
744            return None;
745        }
746
747        match self.router_mode {
748            RouterMode::RoundRobin => {
749                // Just peek at the current counter value without incrementing
750                let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
751                Some(instance_ids[counter % count])
752            }
753            RouterMode::Random => {
754                // For random, peeking implies a fresh random selection since it's stateless.
755                // Note: The caller must realize that select_next_worker() will pick a DIFFERENT random worker.
756                let counter = rand::rng().random::<u64>() as usize;
757                Some(instance_ids[counter % count])
758            }
759            RouterMode::PowerOfTwoChoices
760            | RouterMode::Direct
761            | RouterMode::LeastLoaded
762            | RouterMode::DeviceAwareWeighted => None,
763            RouterMode::KV => {
764                panic!(
765                    "peek_next_worker should not be called for {:?} routing mode",
766                    self.router_mode
767                )
768            }
769        }
770    }
771
772    fn occupancy_state(&self) -> anyhow::Result<Arc<RoutingOccupancyState>> {
773        self.occupancy_state.clone().ok_or_else(|| {
774            anyhow::anyhow!(
775                "routing occupancy state not initialized for endpoint {}",
776                self.client.endpoint.id()
777            )
778        })
779    }
780
781    /*
782    pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
783        let subject = self.client.endpoint.subject();
784        tracing::debug!("static got subject: {subject}");
785        let request = request.map(|req| AddressedRequest::new(req, subject));
786        tracing::debug!("router generate");
787        self.addressed.generate(request).await
788    }
789    */
790
791    async fn generate_with_fault_detection(
792        &self,
793        mut instance_id: u64,
794        request: SingleIn<T>,
795    ) -> anyhow::Result<ManyOut<U>> {
796        let route_start = Instant::now();
797        let request_id = request.id().to_string();
798        let route_span = if matches!(self.router_mode, RouterMode::KV) {
799            tracing::Span::none()
800        } else {
801            tracing::info_span!(
802                "router.route_request",
803                request_id = %request_id,
804                worker_id = instance_id,
805                router_mode = ?self.router_mode,
806            )
807        };
808
809        // Check if all workers are busy (when fault detection is enabled).
810        if self.fault_detection_enabled {
811            let free_instances = self.client.instance_ids_free();
812            if free_instances.is_empty() {
813                // Check if we actually have any instances at all
814                let all_instances = self.client.instance_ids();
815                if !all_instances.is_empty() {
816                    tracing::warn!(
817                        instance_id,
818                        total_workers = all_instances.len(),
819                        "Rejecting request: all workers are busy"
820                    );
821                    let cause = PipelineError::ServiceOverloaded(
822                        "All workers are busy, please retry later".to_string(),
823                    );
824                    return Err(DynamoError::builder()
825                        .error_type(ErrorType::ResourceExhausted)
826                        .message("All workers are busy, please retry later")
827                        .cause(cause)
828                        .build()
829                        .into());
830                }
831            }
832        }
833
834        // Resolve transport address; if the selected instance disappeared
835        // between selection and dispatch, fall back to another available one.
836        let (address, _transport_kind, instance) = {
837            use crate::component::TransportType;
838
839            let resolve_transport = |id: u64| {
840                let instances = self.client.instances();
841                instances
842                    .iter()
843                    .find(|i| i.instance_id == id)
844                    .map(|instance| {
845                        let (addr, kind) = match &instance.transport {
846                            TransportType::Http(http_endpoint) => {
847                                tracing::debug!(
848                                    instance_id = id,
849                                    http_endpoint = %http_endpoint,
850                                    "Using HTTP transport for instance"
851                                );
852                                (http_endpoint.clone(), "transport.http.request")
853                            }
854                            TransportType::Tcp(tcp_endpoint) => {
855                                tracing::debug!(
856                                    instance_id = id,
857                                    tcp_endpoint = %tcp_endpoint,
858                                    "Using TCP transport for instance"
859                                );
860                                (tcp_endpoint.clone(), "transport.tcp.request")
861                            }
862                            TransportType::Nats(subject) => {
863                                tracing::debug!(
864                                    instance_id = id,
865                                    subject = %subject,
866                                    "Using NATS transport for instance"
867                                );
868                                (subject.clone(), "transport.nats.request")
869                            }
870                        };
871                        (addr, kind, instance.clone())
872                    })
873            };
874
875            if let Some(result) = resolve_transport(instance_id) {
876                result
877            } else {
878                // Instance vanished — pick a different one from the current
879                // availability list and retry the lookup once.
880                let avail = self.client.instance_ids_avail();
881                let fallback_id = avail.iter().copied().find(|&id| id != instance_id);
882                match fallback_id {
883                    Some(id) => {
884                        tracing::warn!(
885                            original_instance = instance_id,
886                            fallback_instance = id,
887                            "Instance disappeared during routing, reselecting"
888                        );
889                        instance_id = id;
890                        resolve_transport(id).ok_or_else(|| {
891                            anyhow::anyhow!(
892                                "Fallback instance {} also not found for endpoint {}",
893                                id,
894                                self.client.endpoint.id()
895                            )
896                        })?
897                    }
898                    None => {
899                        return Err(anyhow::anyhow!(
900                            "Instance {} not found and no other instances available \
901                             for endpoint {}",
902                            instance_id,
903                            self.client.endpoint.id()
904                        ));
905                    }
906                }
907            }
908        };
909
910        let request = request.map(|req| AddressedRequest::with_instance(req, address, instance));
911
912        STAGE_DURATION_SECONDS
913            .with_label_values(&[STAGE_ROUTE])
914            .observe(route_start.elapsed().as_secs_f64());
915
916        let _nvtx_transport = dynamo_nvtx_range!(_transport_kind);
917        let stream: anyhow::Result<ManyOut<U>> = self
918            .addressed
919            .generate(request)
920            .instrument(route_span)
921            .await;
922        match stream {
923            Ok(stream) => {
924                if !self.fault_detection_enabled {
925                    return Ok(stream);
926                }
927                let engine_ctx = stream.context();
928                let client = self.client.clone();
929                let client_for_timeout = self.client.clone();
930                let stream = stream.map(move |res| {
931                    // Check if the error is migratable (indicates worker/connection failure)
932                    if let Some(err) = res.err()
933                        && is_inhibited(&err)
934                    {
935                        tracing::debug!(
936                            "Reporting instance {instance_id} down due to migratable error: {err}"
937                        );
938                        client.report_instance_down(instance_id);
939                    }
940                    res
941                });
942
943                // Request-plane inactivity timeout: emit a ResponseTimeout error item
944                // when the backend stops producing output. This triggers is_inhibited()
945                // → report_instance_down() to quarantine the worker.
946                let stream: Pin<Box<dyn Stream<Item = U> + Send>> = if let Some(timeout) =
947                    self.response_timeout
948                {
949                    Box::pin(async_stream::stream! {
950                        let mut inner = Box::pin(stream);
951                        loop {
952                            tokio::select! {
953                                biased;
954                                item = inner.next() => {
955                                    match item {
956                                        Some(item) => yield item,
957                                        None => break,
958                                    }
959                                }
960                                _ = tokio::time::sleep(timeout) => {
961                                    tracing::warn!(
962                                        instance_id,
963                                        timeout_secs = timeout.as_secs(),
964                                        "backend response inactivity timeout — quarantining worker"
965                                    );
966                                    client_for_timeout.report_instance_down(instance_id);
967                                    yield U::from_err(
968                                        crate::error::DynamoError::builder()
969                                            .error_type(crate::error::ErrorType::ResponseTimeout)
970                                            .message("backend response inactivity timeout")
971                                            .build()
972                                    );
973                                    break;
974                                }
975                            }
976                        }
977                    })
978                } else {
979                    Box::pin(stream)
980                };
981
982                Ok(ResponseStream::new(stream, engine_ctx))
983            }
984            Err(err) => {
985                if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
986                    tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
987                    self.client.report_instance_down(instance_id);
988                }
989                Err(err)
990            }
991        }
992    }
993}
994
995#[async_trait]
996impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
997where
998    T: Data + Serialize,
999    U: Data + for<'de> Deserialize<'de> + MaybeError,
1000{
1001    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
1002        match self.router_mode {
1003            RouterMode::Random => self.random(request).await,
1004            RouterMode::RoundRobin => self.round_robin(request).await,
1005            RouterMode::PowerOfTwoChoices => self.power_of_two_choices(request).await,
1006            RouterMode::KV => {
1007                anyhow::bail!("KV routing should not call generate on PushRouter");
1008            }
1009            RouterMode::Direct => {
1010                anyhow::bail!(
1011                    "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
1012                );
1013            }
1014            RouterMode::LeastLoaded => self.least_loaded(request).await,
1015            RouterMode::DeviceAwareWeighted => self.device_aware_weighted(request).await,
1016        }
1017    }
1018}
1019
/// Response-stream wrapper that holds one occupancy slot for `instance_id`.
///
/// Items and the engine context are forwarded from `inner` unchanged. When the
/// wrapper is dropped, its `Drop` impl calls `state.decrement(instance_id)`,
/// releasing the slot.
struct OccupancyTrackedStream<U: Data> {
    // The wrapped response stream.
    inner: ManyOut<U>,
    // Occupancy counters (shared via `Arc`) that are decremented on drop.
    state: Arc<RoutingOccupancyState>,
    // Worker whose occupancy count this stream holds.
    instance_id: u64,
}
1025
1026impl<U: Data> Drop for OccupancyTrackedStream<U> {
1027    fn drop(&mut self) {
1028        self.state.decrement(self.instance_id);
1029    }
1030}
1031
1032impl<U: Data> std::fmt::Debug for OccupancyTrackedStream<U> {
1033    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1034        f.debug_struct("OccupancyTrackedStream")
1035            .field("instance_id", &self.instance_id)
1036            .finish()
1037    }
1038}
1039
1040impl<U: Data> Stream for OccupancyTrackedStream<U> {
1041    type Item = U;
1042
1043    fn poll_next(
1044        mut self: Pin<&mut Self>,
1045        cx: &mut std::task::Context<'_>,
1046    ) -> Poll<Option<Self::Item>> {
1047        self.inner.as_mut().poll_next(cx)
1048    }
1049}
1050
impl<U: Data> AsyncEngineContextProvider for OccupancyTrackedStream<U> {
    // Delegates to the wrapped stream's context; the wrapper adds no context of
    // its own.
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        self.inner.context()
    }
}
1056
// Marker impl: with `Stream` and `AsyncEngineContextProvider` implemented above,
// the wrapper can stand in wherever an `AsyncEngineStream<U>` is expected.
impl<U: Data> crate::engine::AsyncEngineStream<U> for OccupancyTrackedStream<U> {}
1058
// Unit tests for the routing helpers and occupancy stream wrappers, plus a few
// tests that build a real `PushRouter` against a process-local runtime.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        DistributedRuntime, Runtime,
        distributed::DistributedConfig,
        error::DynamoError,
        pipeline::{ResponseStream, context::Controller},
    };
    use serde::{Deserialize, Serialize};

    /// Minimal response type implementing `MaybeError`, used to instantiate
    /// `PushRouter<u64, TestResponse>` in the runtime-backed tests.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct TestResponse {
        error: Option<DynamoError>,
    }

    impl MaybeError for TestResponse {
        fn from_err(err: impl std::error::Error + 'static) -> Self {
            Self {
                error: Some(DynamoError::from(
                    Box::new(err) as Box<dyn std::error::Error + 'static>
                )),
            }
        }

        fn err(&self) -> Option<DynamoError> {
            self.error.clone()
        }
    }

    #[test]
    fn p2c_selects_lower_load_worker() {
        let state = RoutingOccupancyState::default();
        for _ in 0..10 {
            state.increment(1);
        }
        state.increment(2);

        // With only two workers, p2c_select_from must pick both and choose id=2 (lower load).
        let result = p2c_select_from(&state, &[1, 2]);
        assert_eq!(result, 2);
    }

    #[test]
    fn p2c_selects_single_worker() {
        let state = RoutingOccupancyState::default();
        assert_eq!(p2c_select_from(&state, &[42]), 42);
    }

    #[test]
    fn p2c_treats_missing_counts_as_zero() {
        let state = RoutingOccupancyState::default();
        for _ in 0..5 {
            state.increment(1);
        }
        // Worker 2 has no entry — should be treated as 0, so it wins.
        let result = p2c_select_from(&state, &[1, 2]);
        assert_eq!(result, 2);
    }

    #[test]
    fn p2c_returns_valid_worker_on_tie() {
        let state = RoutingOccupancyState::default();
        for _ in 0..3 {
            state.increment(1);
            state.increment(2);
        }

        // On equal load either worker is acceptable; only validity is checked.
        for _ in 0..100 {
            let result = p2c_select_from(&state, &[1, 2]);
            assert!(result == 1 || result == 2);
        }
    }

    // `OccupancyPermit::new` does not increment by itself (the load is still 1
    // after the explicit `increment`). Dropping a permit that was never turned
    // into a tracked stream still releases the slot.
    #[test]
    fn occupancy_permit_decrements_before_stream_creation() {
        let state = Arc::new(RoutingOccupancyState::default());
        state.increment(42);
        let permit = OccupancyPermit::new(state.clone(), 42);
        assert_eq!(state.load(42), 1);
        drop(permit);
        assert_eq!(state.load(42), 0);
    }

    // Converting a permit into a tracked stream hands the slot to the stream.
    // The count stays at 1 after conversion and drops to 0 only when the stream
    // itself is dropped.
    #[test]
    fn occupancy_tracked_stream_decrements_on_drop() {
        let state = Arc::new(RoutingOccupancyState::default());
        state.increment(7);
        let permit = OccupancyPermit::new(state.clone(), 7);
        let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
        let stream = permit.into_tracked_stream(ResponseStream::new(
            Box::pin(tokio_stream::iter(vec![1u64])),
            ctx,
        ));
        assert_eq!(state.load(7), 1);
        drop(stream);
        assert_eq!(state.load(7), 0);
    }

    #[test]
    fn p2c_lifecycle_tracks_inflight_counts_with_shared_tracker() {
        let state = Arc::new(RoutingOccupancyState::default());
        let mut permits = Vec::new();
        for _ in 0..5 {
            let selected = p2c_select_from(&state, &[1, 2]);
            state.increment(selected);
            permits.push(OccupancyPermit::new(state.clone(), selected));
        }

        let total = state.load(1) + state.load(2);
        assert_eq!(total, 5, "5 in-flight requests should be tracked");

        drop(permits);
        let total = state.load(1) + state.load(2);
        assert_eq!(total, 0, "All guards dropped, counts should be 0");
    }

    #[test]
    fn p2c_never_selects_dominated_worker() {
        let state = RoutingOccupancyState::default();
        for _ in 0..100 {
            state.increment(3);
        }

        let mut selected = [0u32; 3];
        for _ in 0..1000 {
            let result = p2c_select_from(&state, &[1, 2, 3]);
            match result {
                1 => selected[0] += 1,
                2 => selected[1] += 1,
                3 => selected[2] += 1,
                _ => panic!("unexpected worker id"),
            }
        }
        assert_eq!(
            selected[2], 0,
            "Worker 3 (load=100) should never be selected against load=0 workers, but got {} times",
            selected[2]
        );
    }

    #[tokio::test]
    async fn least_loaded_selects_exact_min_and_tracks_counts() {
        let state = Arc::new(RoutingOccupancyState::default());
        state.increment(1);
        state.increment(1);
        state.increment(2);

        // Worker 3 has no recorded load, so it is the exact minimum.
        let selected = state
            .select_exact_min_and_increment(&[1, 2, 3])
            .await
            .unwrap();
        assert_eq!(selected, 3);

        let permit = OccupancyPermit::new(state.clone(), selected);
        assert_eq!(state.load(selected), 1);
        drop(permit);
        assert_eq!(state.load(selected), 0);
    }

    // LeastLoaded has no synchronous selection, so both select and peek report
    // `None` even when a worker is registered.
    #[tokio::test]
    async fn least_loaded_select_and_peek_return_none_with_available_worker() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_least_loaded_router".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::LeastLoaded)
            .await
            .unwrap();

        assert_eq!(router.select_next_worker(), None);
        assert_eq!(router.peek_next_worker(), None);

        rt.shutdown();
    }

    #[tokio::test]
    async fn device_aware_cpu_only_selects_least_loaded_instance() {
        let state = RoutingOccupancyState::default();
        // All candidates are CPU. Make worker 2 the least-loaded one.
        for _ in 0..3 {
            state.increment(1);
        }
        state.increment(3);

        let instance_ids = vec![1, 2, 3];
        let device_type_map = HashMap::from([
            (1, Some(DeviceType::Cpu)),
            (2, Some(DeviceType::Cpu)),
            (3, Some(DeviceType::Cpu)),
        ]);

        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
        assert_eq!(candidates, vec![1, 2, 3]);

        let selected = state
            .select_exact_min_and_increment(&candidates)
            .await
            .unwrap();
        assert_eq!(selected, 2);
    }

    #[tokio::test]
    async fn device_aware_non_cpu_only_selects_least_loaded_instance() {
        let state = RoutingOccupancyState::default();
        // All candidates are non-CPU. Make worker 2 the least-loaded one.
        for _ in 0..3 {
            state.increment(1);
        }
        state.increment(3);

        let instance_ids = vec![1, 2, 3];
        let device_type_map = HashMap::from([
            (1, Some(DeviceType::Cuda)),
            (2, Some(DeviceType::Cuda)),
            (3, Some(DeviceType::Cuda)),
        ]);

        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
        assert_eq!(candidates, vec![1, 2, 3]);

        let selected = state
            .select_exact_min_and_increment(&candidates)
            .await
            .unwrap();
        assert_eq!(selected, 2);
    }

    #[test]
    fn device_aware_group_uses_ratio_budget() {
        let state = RoutingOccupancyState::default();
        // CPU ids: 1,2 ; non-CPU ids: 3,4
        for _ in 0..4 {
            state.increment(3);
            state.increment(4);
        }
        // CPU inflight can differ across instances; budgeting uses total CPU inflight.
        for _ in 0..3 {
            state.increment(1);
        }
        // total_non_cpu_inflight=8, cpu_count=2, non_cpu_count=2, ratio=2
        // allowed_cpu_inflight = 8*2/(2*2)=4
        // total_cpu_inflight=3 < 4 => choose CPU group.
        let instance_ids = vec![1, 2, 3, 4];
        let device_type_map = HashMap::from([
            (1, Some(DeviceType::Cpu)),
            (2, Some(DeviceType::Cpu)),
            (3, Some(DeviceType::Cuda)),
            (4, Some(DeviceType::Cuda)),
        ]);

        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 2);
        assert_eq!(candidates, vec![1, 2]);

        // Within selected CPU group, final choice should be the least-loaded instance (id=2).
        let selected =
            futures::executor::block_on(state.select_exact_min_and_increment(&candidates)).unwrap();
        assert_eq!(selected, 2);
    }

    // DeviceAwareWeighted, like LeastLoaded, reports `None` from both the
    // synchronous select and peek paths.
    #[tokio::test]
    async fn device_aware_weighted_select_and_peek_return_none_with_available_worker() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_device_aware_router".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let router =
            PushRouter::<u64, TestResponse>::from_client(client, RouterMode::DeviceAwareWeighted)
                .await
                .unwrap();

        assert_eq!(router.select_next_worker(), None);
        assert_eq!(router.peek_next_worker(), None);

        rt.shutdown();
    }

    /// When the router selects an instance that has deregistered between selection
    /// and transport resolution, it should fall back to another available instance
    /// rather than returning a 500 error.
    #[tokio::test]
    async fn transport_resolution_falls_back_when_selected_instance_disappears() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_transport_fallback".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        // Register one real instance so it appears in instance_source.
        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let real_id = client.instance_ids()[0];

        // Inject a stale ID into instance_avail that does NOT exist in
        // instance_source. This simulates the race window where an instance
        // deregistered after selection but before transport resolution.
        let stale_id = real_id + 1000;
        client.override_instance_avail(vec![stale_id, real_id]);

        // Build a round-robin router. Depending on the counter, it may select
        // either the stale or the real instance first; both paths must get past
        // transport resolution.
        let router =
            PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
                .await
                .unwrap();

        // Round robin should succeed — even if it picks stale_id first, the
        // fallback logic should resolve transport via real_id.
        // We cannot fully test the network send without a worker, but we can
        // verify it doesn't fail at the transport resolution stage by checking
        // that the error (if any) is a transport/network error, not
        // "Instance not found".
        let request = SingleIn::new(42u64);
        let result = router.generate(request).await;

        // The request may fail at the network level (no actual worker), but it
        // must NOT fail with "Instance X not found" — that would mean the
        // fallback did not work.
        if let Err(err) = &result {
            let msg = format!("{err}");
            assert!(
                !msg.contains("not found"),
                "Transport resolution should have fallen back, but got: {msg}"
            );
        }

        rt.shutdown();
    }

    /// When no instances are available at all (both primary and fallback),
    /// the router should return a clear error.
    #[tokio::test]
    async fn transport_resolution_errors_when_no_instances_available() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_transport_no_fallback".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        // Register an instance so we can create the router (needs transport setup).
        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let router =
            PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
                .await
                .unwrap();

        // Override avail to contain only a stale ID with no real backing
        // instance AND no other available fallback.
        let stale_id = 99999;
        client.override_instance_avail(vec![stale_id]);

        let request = SingleIn::new(42u64);
        let result = router.generate(request).await;

        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("not found") && msg.contains("no other instances available"),
            "Expected clear error about missing instance with no fallback, got: {msg}"
        );

        rt.shutdown();
    }

    /// The watcher dedup guard must be released even if the spawned task panics.
    /// Without this, a panic anywhere in the watcher body would leave a stale
    /// `ENDPOINT_WATCHER_ACTIVE` entry, silently disabling orphaned-pending-
    /// request cancellation for that endpoint until process restart.
    ///
    /// We exercise the Drop-guard pattern directly against the same static
    /// rather than driving `spawn_instance_removal_watcher` end-to-end (which
    /// would require staging a panicking discovery stream). The test mirrors
    /// the production code's GuardRelease shape; if the production code stops
    /// using a Drop guard, the integration would regress and the existing
    /// orphan-cancellation tests would fail.
    #[tokio::test]
    async fn watcher_dedup_guard_released_on_panic() {
        let endpoint_id = EndpointId {
            namespace: "panic-test-ns".to_string(),
            component: "panic-test-comp".to_string(),
            name: "panic-test-endpoint".to_string(),
        };

        // Mimic the production code's pre-spawn dedup insert.
        let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
        map.insert(endpoint_id.clone(), ());

        let endpoint_id_clone = endpoint_id.clone();
        let join = tokio::spawn(async move {
            // Same shape as in spawn_instance_removal_watcher.
            struct GuardRelease(EndpointId);
            impl Drop for GuardRelease {
                fn drop(&mut self) {
                    if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
                        map.remove(&self.0);
                    }
                }
            }
            let _release = GuardRelease(endpoint_id_clone);
            panic!("simulated watcher-task panic");
        });

        let result = join.await;
        assert!(result.is_err() && result.unwrap_err().is_panic());
        assert!(
            !map.contains_key(&endpoint_id),
            "Drop guard must release the dedup entry even on panic"
        );
    }

    /// Normal-exit path: the Drop guard releases the entry when the task
    /// finishes without panicking. This is the everyday case (cancel_token
    /// fires or discovery stream closes).
    #[tokio::test]
    async fn watcher_dedup_guard_released_on_normal_exit() {
        let endpoint_id = EndpointId {
            namespace: "normal-test-ns".to_string(),
            component: "normal-test-comp".to_string(),
            name: "normal-test-endpoint".to_string(),
        };

        let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
        map.insert(endpoint_id.clone(), ());

        let endpoint_id_clone = endpoint_id.clone();
        tokio::spawn(async move {
            struct GuardRelease(EndpointId);
            impl Drop for GuardRelease {
                fn drop(&mut self) {
                    if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
                        map.remove(&self.0);
                    }
                }
            }
            let _release = GuardRelease(endpoint_id_clone);
            // task body returns normally
        })
        .await
        .unwrap();

        assert!(!map.contains_key(&endpoint_id));
    }
}