Skip to main content

dynamo_runtime/pipeline/network/egress/
push_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{AsyncEngineContextProvider, ResponseStream};
5use crate::error::{BackendError, DynamoError, ErrorType, match_error_chain};
6use crate::{
7    component::{
8        Client, DeviceType, Endpoint, Instance, RoutingInstances, RoutingOccupancyState,
9        get_or_create_routing_occupancy_state,
10    },
11    discovery::EndpointInstanceId,
12    dynamo_nvtx_range,
13    engine::{AsyncEngine, AsyncEngineContext, Data},
14    metrics::frontend_perf::{STAGE_DURATION_SECONDS, STAGE_ROUTE},
15    pipeline::{
16        AddressedPushRouter, AddressedRequest, Error, ManyIn, ManyOut, SingleIn,
17        error::{PipelineError, PipelineErrorExt},
18    },
19    protocols::{EndpointId, maybe_error::MaybeError},
20    traits::DistributedRuntimeProvider,
21};
22use async_trait::async_trait;
23use futures::Stream;
24use rand::Rng;
25use serde::{Deserialize, Serialize};
26use std::{
27    collections::HashMap,
28    marker::PhantomData,
29    pin::Pin,
30    sync::{
31        Arc,
32        atomic::{AtomicU64, Ordering},
33    },
34    task::Poll,
35    time::Instant,
36};
37use tokio_stream::StreamExt;
38use tracing::Instrument;
39
40/// Check if an error chain indicates the worker should be reported as down.
41fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
42    const INHIBITED: &[ErrorType] = &[
43        ErrorType::CannotConnect,
44        ErrorType::Disconnected,
45        ErrorType::ConnectionTimeout,
46        ErrorType::ResponseTimeout,
47        ErrorType::Backend(BackendError::EngineShutdown),
48    ];
49    match_error_chain(err, INHIBITED, &[])
50}
51
52/// Read the backend response inactivity timeout from the environment.
53/// Reuses `DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS` — the same env var
54/// as the HTTP-layer safety net in `disconnect.rs`.
55fn response_inactivity_timeout() -> Option<std::time::Duration> {
56    use crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS;
57    std::env::var(DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS)
58        .ok()
59        .and_then(|s| s.parse::<u64>().ok())
60        .filter(|&secs| secs > 0)
61        .map(std::time::Duration::from_secs)
62}
63
/// RAII handle for one in-flight slot on `instance_id` in a shared
/// [`RoutingOccupancyState`].
///
/// Routers create it right after counting the request against the instance
/// (via `increment` or `select_exact_min_and_increment`). While `armed`,
/// dropping the permit calls `state.decrement(instance_id)`, so an early
/// return or a failed request gives the slot back. `into_tracked_stream`
/// disarms it and hands the slot to a wrapped response stream.
struct OccupancyPermit {
    /// Occupancy counters; presumably shared per endpoint (obtained via
    /// `get_or_create_routing_occupancy_state`).
    state: Arc<RoutingOccupancyState>,
    /// Instance whose counter this permit holds a slot on.
    instance_id: u64,
    /// `true` until ownership of the slot moves to a tracked stream.
    armed: bool,
}
69
70impl OccupancyPermit {
71    fn new(state: Arc<RoutingOccupancyState>, instance_id: u64) -> Self {
72        Self {
73            state,
74            instance_id,
75            armed: true,
76        }
77    }
78
79    fn into_tracked_stream<U: Data>(mut self, stream: ManyOut<U>) -> ManyOut<U> {
80        self.armed = false;
81        let engine_ctx = stream.context();
82        ResponseStream::new(
83            Box::pin(OccupancyTrackedStream {
84                inner: stream,
85                state: self.state.clone(),
86                instance_id: self.instance_id,
87            }),
88            engine_ctx,
89        )
90    }
91
92    fn instance_id(&self) -> u64 {
93        self.instance_id
94    }
95}
96
97impl Drop for OccupancyPermit {
98    fn drop(&mut self) {
99        if self.armed {
100            self.state.decrement(self.instance_id);
101        }
102    }
103}
104
/// Trait for monitoring worker load and determining overload state.
/// Implementations can define custom load metrics and overload thresholds.
///
/// A monitor is passed to `PushRouter::from_client_with_monitor`, which awaits
/// `start_monitoring` while constructing the router.
#[async_trait]
pub trait WorkerLoadMonitor: Send + Sync {
    /// Start background monitoring of worker load.
    /// This should spawn background tasks that update the client's overloaded instances.
    ///
    /// An `Err` returned here is propagated by `from_client_with_monitor` and
    /// aborts router construction.
    async fn start_monitoring(&self) -> anyhow::Result<()>;
}
113
#[derive(Clone)]
pub struct PushRouter<T, U>
where
    T: Data + Serialize,
    U: Data + for<'de> Deserialize<'de>,
{
    // TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
    /// The Client is how we gather remote endpoint information from etcd.
    pub client: Client,

    /// How we choose which instance to send traffic to.
    ///
    /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
    /// not using it as an AsyncEngine.
    /// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
    /// dynamo-llm's KV Routing does this.
    router_mode: RouterMode,

    /// Number of round robin requests handled. Used to decide which server is next.
    /// Behind an `Arc`, so clones of this router share one counter.
    round_robin_counter: Arc<AtomicU64>,

    /// The next step in the chain. PushRouter (this object) picks an instance,
    /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
    addressed: Arc<AddressedPushRouter>,

    /// Whether fault detection is active. `true` when built via `from_client` /
    /// `from_client_with_monitor`, `false` via `from_client_no_fault_detection`.
    ///
    /// When false, `generate_with_fault_detection` skips the selected-worker overload
    /// rejection (and, per the constructor docs, does not call `report_instance_down`
    /// on errors), and `direct()` validates the target against the raw discovery
    /// instance list (`client.instance_ids()`) instead of the routable list. Use for
    /// recovery/query paths where transient failures are expected.
    fault_detection_enabled: bool,

    /// Cached response inactivity timeout. Read once at construction from
    /// [`environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS`](crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS)
    /// so the environment is not re-read and re-parsed on every request.
    response_timeout: Option<std::time::Duration>,

    /// Shared request occupancy state for tracked routing modes
    /// (`PowerOfTwoChoices`, `LeastLoaded`, `DeviceAwareWeighted`); `None` for all other modes.
    occupancy_state: Option<Arc<RoutingOccupancyState>>,

    /// Marks PushRouter as generic over T and U — the input and output types of its
    /// `generate` function — even though no field stores a value of either type.
    _phantom: PhantomData<(T, U)>,
}
157
158#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
159#[serde(rename_all = "snake_case")]
160pub enum RouterMode {
161    #[default]
162    RoundRobin,
163    Random,
164    PowerOfTwoChoices,
165    KV,
166    Direct,
167    LeastLoaded,
168    /// Device-aware weighted routing for heterogeneous workers.
169    DeviceAwareWeighted,
170}
171
172impl RouterMode {
173    pub fn is_kv_routing(&self) -> bool {
174        *self == RouterMode::KV
175    }
176
177    pub fn is_direct_routing(&self) -> bool {
178        *self == RouterMode::Direct
179    }
180}
181
182/// Pick the instance with lower in-flight count from two random candidates.
183/// Returns the single instance if only one is available.
184fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64]) -> u64 {
185    let count = instance_ids.len();
186    if count == 1 {
187        return instance_ids[0];
188    }
189    let mut rng = rand::rng();
190    let idx1 = rng.random_range(0..count);
191    let idx2 = (idx1 + 1 + rng.random_range(0..count - 1)) % count;
192    let id1 = instance_ids[idx1];
193    let id2 = instance_ids[idx2];
194    let load1 = occupancy_state.load(id1);
195    let load2 = occupancy_state.load(id2);
196    let selected = if load1 <= load2 { id1 } else { id2 };
197    tracing::debug!(
198        candidate_a = id1,
199        candidate_a_load = load1,
200        candidate_b = id2,
201        candidate_b_load = load2,
202        selected = selected,
203        "p2c selection"
204    );
205    selected
206}
207
208/// Select the target device group for the next request in `DeviceAwareWeighted` mode.
209///
210/// If only one class exists (all CPU or all non-CPU), returns that class directly.
211/// If both classes exist, compares capability-normalized load and returns the less-loaded group.
212///
213/// Budget check (integer form):
214/// `allowed_cpu_inflight = total_non_cpu_inflight * cpu_count / (ratio * non_cpu_count)`
215/// and choose CPU when `total_cpu_inflight < allowed_cpu_inflight`.
216///
217/// `ratio` is `non_cpu_to_cpu_ratio` (from `DYN_ENCODER_CUDA_TO_CPU_RATIO`,
218/// default `8` in `device_aware_weighted`).
219fn device_aware_candidate_group(
220    state: &RoutingOccupancyState,
221    instance_ids: &[u64],
222    device_type_map: &HashMap<u64, Option<DeviceType>>,
223    non_cpu_to_cpu_ratio: usize,
224) -> Vec<u64> {
225    let cpu_ids: Vec<u64> = instance_ids
226        .iter()
227        .copied()
228        .filter(|id| matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
229        .collect();
230    let non_cpu_ids: Vec<u64> = instance_ids
231        .iter()
232        .copied()
233        .filter(|id| !matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
234        .collect();
235
236    if cpu_ids.is_empty() {
237        return non_cpu_ids;
238    }
239    if non_cpu_ids.is_empty() {
240        return cpu_ids;
241    }
242
243    // Both classes exist: compute a budget for CPU in-flight requests.
244    let total_non_cpu_inflight: u64 = non_cpu_ids.iter().map(|id| state.load(*id)).sum();
245    let total_cpu_inflight: u64 = cpu_ids.iter().map(|id| state.load(*id)).sum();
246    let cpu_count = cpu_ids.len() as u64;
247    let non_cpu_count = non_cpu_ids.len() as u64;
248    let allowed_cpu_inflight = total_non_cpu_inflight.saturating_mul(cpu_count)
249        / ((non_cpu_to_cpu_ratio as u64).saturating_mul(non_cpu_count));
250
251    if total_cpu_inflight < allowed_cpu_inflight {
252        cpu_ids
253    } else {
254        non_cpu_ids
255    }
256}
257
/// At most one `list_and_watch` per endpoint, across all `PushRouter`
/// instances. Entry removed on watcher exit so a later router can re-arm.
///
/// Used as a concurrent set: keys are the endpoints that currently own a
/// removal watcher, and the `()` value is unused. Lazily initialized by
/// `spawn_instance_removal_watcher`.
static ENDPOINT_WATCHER_ACTIVE: std::sync::OnceLock<dashmap::DashMap<EndpointId, ()>> =
    std::sync::OnceLock::new();
262
263/// Watch discovery for instance removals and cancel pending response-stream
264/// registrations on the removed instance, unblocking queued requests with
265/// a migratable `Disconnected` error. Uses raw `list_and_watch` events
266/// (not a coalesced snapshot diff) so a rapid remove→re-add of the same
267/// identity is not silently swallowed. Keyed by full `EndpointInstanceId`.
268fn spawn_instance_removal_watcher(
269    endpoint: Endpoint,
270    addressed: Arc<AddressedPushRouter>,
271    cancel_token: tokio_util::sync::CancellationToken,
272) {
273    use crate::discovery::{
274        DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
275    };
276    use tokio_stream::StreamExt as _;
277
278    // One watcher per endpoint: if one is already running, skip.
279    let guard = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
280    let endpoint_id = endpoint.id();
281    if guard.insert(endpoint_id.clone(), ()).is_some() {
282        tracing::debug!(
283            ?endpoint_id,
284            "Instance removal watcher already running for this endpoint, skipping"
285        );
286        return;
287    }
288
289    let endpoint_name = endpoint.name().to_string();
290
291    tokio::spawn(async move {
292        // Release on every exit path (including panic); a leaked entry
293        // silently disables removal cancellation until process restart.
294        struct GuardRelease(EndpointId);
295        impl Drop for GuardRelease {
296            fn drop(&mut self) {
297                if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
298                    map.remove(&self.0);
299                }
300            }
301        }
302        let _release = GuardRelease(endpoint_id);
303
304        let namespace = endpoint.component().namespace().name();
305        let component = endpoint.component().name().to_string();
306
307        // Reconnect on transient discovery failure; cancel-aware backoff.
308        const RECONNECT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(5);
309        'reconnect: loop {
310            let query = DiscoveryQuery::Endpoint {
311                namespace: namespace.clone(),
312                component: component.clone(),
313                endpoint: endpoint_name.clone(),
314            };
315
316            let mut stream = match endpoint.drt().discovery().list_and_watch(query, None).await {
317                Ok(s) => s,
318                Err(e) => {
319                    tracing::warn!(
320                        endpoint = %endpoint_name,
321                        "Failed to start instance removal watcher (will retry): {e}"
322                    );
323                    tokio::select! {
324                        _ = tokio::time::sleep(RECONNECT_BACKOFF) => continue 'reconnect,
325                        _ = cancel_token.cancelled() => break 'reconnect,
326                    }
327                }
328            };
329
330            loop {
331                tokio::select! {
332                    event = stream.next() => {
333                        match event {
334                            Some(Ok(DiscoveryEvent::Removed(id))) => {
335                                if let DiscoveryInstanceId::Endpoint(eid) = &id {
336                                    let n = addressed.cancel_instance_streams(eid).await;
337                                    if n > 0 {
338                                        tracing::warn!(
339                                            namespace = %eid.namespace,
340                                            component = %eid.component,
341                                            endpoint = %eid.endpoint,
342                                            instance_id = eid.instance_id,
343                                            cancelled = n,
344                                            "Cancelled pending response streams for removed \
345                                             instance (discovery-driven cleanup)"
346                                        );
347                                    }
348                                }
349                            }
350                            Some(Ok(DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)))) => {
351                                let eid: EndpointInstanceId = inst.endpoint_instance_id();
352                                addressed.clear_instance_tombstone(&eid).await;
353                            }
354                            Some(Ok(_)) => {}
355                            Some(Err(e)) => {
356                                tracing::warn!(
357                                    endpoint = %endpoint_name,
358                                    "Instance removal watcher stream error: {e}"
359                                );
360                            }
361                            None => {
362                                tracing::warn!(
363                                    endpoint = %endpoint_name,
364                                    "Instance removal watcher stream ended; reconnecting"
365                                );
366                                continue 'reconnect;
367                            }
368                        }
369                    }
370                    _ = cancel_token.cancelled() => {
371                        break 'reconnect;
372                    }
373                }
374            }
375        }
376
377        tracing::debug!(endpoint = %endpoint_name, "Instance removal watcher exiting");
378    });
379}
380
/// Build the [`AddressedPushRouter`] for `endpoint` by delegating to
/// `AddressedPushRouter::from_runtime_provider`, returning its result unchanged.
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
    AddressedPushRouter::from_runtime_provider(endpoint).await
}
384
385impl<T, U> PushRouter<T, U>
386where
387    T: Data + Serialize,
388    U: Data + for<'de> Deserialize<'de> + MaybeError,
389{
390    /// Create a new PushRouter without a worker load monitor (no overload detection)
391    pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
392        Self::from_client_with_monitor(client, router_mode, None).await
393    }
394
395    /// Create a new PushRouter with fault detection disabled.
396    ///
397    /// Unlike `from_client`, this router will not call `report_instance_down` on
398    /// transient errors, and `direct()` uses the raw discovery instance list instead
399    /// of the filtered avail list. Use for recovery/query paths.
400    pub async fn from_client_no_fault_detection(
401        client: Client,
402        router_mode: RouterMode,
403    ) -> anyhow::Result<Self> {
404        let addressed = addressed_router(&client.endpoint).await?;
405
406        let occupancy_state = if matches!(
407            router_mode,
408            RouterMode::PowerOfTwoChoices
409                | RouterMode::LeastLoaded
410                | RouterMode::DeviceAwareWeighted
411        ) {
412            Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
413        } else {
414            None
415        };
416
417        // Cancel orphaned pending response streams when workers die.
418        spawn_instance_removal_watcher(
419            client.endpoint.clone(),
420            addressed.clone(),
421            client.endpoint.drt().primary_token(),
422        );
423
424        Ok(PushRouter {
425            client,
426            addressed,
427            router_mode,
428            round_robin_counter: Arc::new(AtomicU64::new(0)),
429            fault_detection_enabled: false,
430            response_timeout: response_inactivity_timeout(),
431            occupancy_state,
432            _phantom: PhantomData,
433        })
434    }
435
436    /// Create a new PushRouter with an optional worker load monitor.
437    ///
438    /// The rejection path is gated by `fault_detection_enabled` (true here);
439    /// overload detection itself is driven by the monitor via `client.set_overloaded_instances(...)`.
440    /// If no thresholds are configured on the monitor (or no monitor is provided),
441    /// the routing snapshot reports at least one free instance and the gate never rejects.
442    pub async fn from_client_with_monitor(
443        client: Client,
444        router_mode: RouterMode,
445        worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
446    ) -> anyhow::Result<Self> {
447        let addressed = addressed_router(&client.endpoint).await?;
448
449        // Start worker monitor if provided and in dynamic mode
450        if let Some(monitor) = worker_monitor.as_ref() {
451            monitor.start_monitoring().await?;
452        }
453
454        let occupancy_state = if matches!(
455            router_mode,
456            RouterMode::PowerOfTwoChoices
457                | RouterMode::LeastLoaded
458                | RouterMode::DeviceAwareWeighted
459        ) {
460            Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
461        } else {
462            None
463        };
464
465        // Cancel orphaned pending response streams when workers die.
466        spawn_instance_removal_watcher(
467            client.endpoint.clone(),
468            addressed.clone(),
469            client.endpoint.drt().primary_token(),
470        );
471
472        let router = PushRouter {
473            client,
474            addressed,
475            router_mode,
476            round_robin_counter: Arc::new(AtomicU64::new(0)),
477            fault_detection_enabled: true,
478            response_timeout: response_inactivity_timeout(),
479            occupancy_state,
480            _phantom: PhantomData,
481        };
482
483        Ok(router)
484    }
485
486    /// `ResourceExhausted` when workers are routable but all overloaded;
487    /// `anyhow!("no instances found")` when no routable workers exist.
488    fn empty_free_pool_error(&self, routing_instances: &RoutingInstances) -> anyhow::Error {
489        if !routing_instances.routable_ids().is_empty() {
490            let cause = PipelineError::ServiceOverloaded(
491                "All workers are busy, please retry later".to_string(),
492            );
493            return DynamoError::builder()
494                .error_type(ErrorType::ResourceExhausted)
495                .message("All workers are busy, please retry later")
496                .cause(cause)
497                .build()
498                .into();
499        }
500        anyhow::anyhow!(
501            "no instances found for endpoint {}",
502            self.client.endpoint.id()
503        )
504    }
505
506    /// Issue a request to the next available instance in a round-robin fashion
507    pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
508        let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
509
510        let instance_id = {
511            let routing_instances = self.client.routing_instances();
512            let count = routing_instances.free_ids().len();
513            if count == 0 {
514                return Err(self.empty_free_pool_error(&routing_instances));
515            }
516            routing_instances.free_ids()[counter % count]
517        };
518        tracing::trace!("round robin router selected {instance_id}");
519
520        self.generate_with_fault_detection(instance_id, request)
521            .await
522    }
523
524    /// Issue a request to a random endpoint
525    pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
526        let instance_id = {
527            let routing_instances = self.client.routing_instances();
528            let count = routing_instances.free_ids().len();
529            if count == 0 {
530                return Err(self.empty_free_pool_error(&routing_instances));
531            }
532            let counter = rand::rng().random::<u64>() as usize;
533            routing_instances.free_ids()[counter % count]
534        };
535        tracing::trace!("random router selected {instance_id}");
536
537        self.generate_with_fault_detection(instance_id, request)
538            .await
539    }
540
541    /// Issue a request using power-of-two-choices: pick 2 random healthy workers,
542    /// route to the one with fewer in-flight requests.
543    pub async fn power_of_two_choices(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
544        let state = self.occupancy_state()?;
545        let instance_id = {
546            let routing_instances = self.client.routing_instances();
547            if routing_instances.free_ids().is_empty() {
548                return Err(self.empty_free_pool_error(&routing_instances));
549            }
550            p2c_select_from(state.as_ref(), routing_instances.free_ids())
551        };
552        state.increment(instance_id);
553        let permit = OccupancyPermit::new(state, instance_id);
554
555        match self
556            .generate_with_fault_detection(instance_id, request)
557            .await
558        {
559            Ok(stream) => Ok(permit.into_tracked_stream(stream)),
560            Err(err) => Err(err),
561        }
562    }
563
564    /// Issue a request to a specific endpoint
565    pub async fn direct(
566        &self,
567        request: SingleIn<T>,
568        instance_id: u64,
569    ) -> anyhow::Result<ManyOut<U>> {
570        // When fault detection is disabled, check the raw discovery list
571        // (not filtered by report_instance_down) so transient failures
572        // don't poison the instance for subsequent retries.
573        let found = {
574            if self.fault_detection_enabled {
575                let routing_instances = self.client.routing_instances();
576                routing_instances.routable_ids().contains(&instance_id)
577            } else {
578                self.client.instance_ids().contains(&instance_id)
579            }
580        };
581
582        if !found {
583            return Err(anyhow::anyhow!(
584                "instance_id={instance_id} not found for endpoint {}",
585                self.client.endpoint.id()
586            ));
587        }
588
589        self.generate_with_fault_detection(instance_id, request)
590            .await
591    }
592
593    /// Issue a request using device-aware weighted routing.
594    ///
595    /// Instances are partitioned by device type (CPU vs non-CPU), then the router
596    /// applies a budget policy and selects the least-loaded instance within the
597    /// chosen group.
598    ///
599    /// If only one device class exists (all CPU or all non-CPU), this naturally
600    /// degenerates to least-loaded routing over the available instances.
601    pub async fn device_aware_weighted(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
602        let state = self.occupancy_state()?;
603        let routing_instances = self.client.routing_instances();
604        let instance_ids = routing_instances.free_ids().to_vec();
605
606        if instance_ids.is_empty() {
607            return Err(self.empty_free_pool_error(&routing_instances));
608        }
609
610        // Apply a unified policy for all endpoints.
611        let endpoint_id = self.client.endpoint.id();
612
613        // For encoder endpoints, partition by device type
614        let instances = self.client.instances();
615        let device_type_map: std::collections::HashMap<u64, Option<DeviceType>> = instances
616            .iter()
617            .map(|inst| (inst.instance_id, inst.device_type.clone()))
618            .collect();
619
620        // Apply budget-based routing to determine which group to send to
621        let cuda_to_cpu_ratio = std::env::var("DYN_ENCODER_CUDA_TO_CPU_RATIO")
622            .ok()
623            .and_then(|v| v.parse::<usize>().ok())
624            .filter(|v| *v >= 1)
625            .unwrap_or(8);
626        let candidates = device_aware_candidate_group(
627            state.as_ref(),
628            &instance_ids,
629            &device_type_map,
630            cuda_to_cpu_ratio,
631        );
632
633        // Empty group: budget-selected device class has no free workers.
634        let instance_id = state
635            .select_exact_min_and_increment(&candidates)
636            .await
637            .ok_or_else(|| self.empty_free_pool_error(&routing_instances))?;
638        let permit = OccupancyPermit::new(state.clone(), instance_id);
639        let is_cpu = matches!(
640            device_type_map.get(&instance_id),
641            Some(Some(DeviceType::Cpu))
642        );
643        tracing::info!(
644            endpoint = %endpoint_id,
645            selected_instance = instance_id,
646            is_cpu,
647            "DeviceAwareWeighted selected instance"
648        );
649
650        match self
651            .generate_with_fault_detection(instance_id, request)
652            .await
653        {
654            Ok(stream) => Ok(permit.into_tracked_stream(stream)),
655            Err(err) => Err(err),
656        }
657    }
658
659    /// Issue a request to the instance with the fewest active connections.
660    pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
661        let state = self.occupancy_state()?;
662        let routing_instances = self.client.routing_instances();
663        let instance_ids = routing_instances.free_ids().to_vec();
664        let instance_id = state
665            .select_exact_min_and_increment(&instance_ids)
666            .await
667            .ok_or_else(|| self.empty_free_pool_error(&routing_instances))?;
668        let permit = OccupancyPermit::new(state.clone(), instance_id);
669        tracing::trace!(
670            "least loaded router selected {instance_id} (connections: {})",
671            state.load(instance_id)
672        );
673
674        match self
675            .generate_with_fault_detection(instance_id, request)
676            .await
677        {
678            Ok(stream) => Ok(permit.into_tracked_stream(stream)),
679            Err(err) => Err(err),
680        }
681    }
682
683    /// Select the next worker according to the routing mode.
684    /// Increments round-robin counter if applicable.
685    /// Returns None for modes that require request lifecycle tracking or explicit routing hints.
686    pub fn select_next_worker(&self) -> Option<u64> {
687        let routing_instances = self.client.routing_instances();
688        let count = routing_instances.free_ids().len();
689        if count == 0 {
690            return None;
691        }
692
693        match self.router_mode {
694            RouterMode::RoundRobin => {
695                let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
696                Some(routing_instances.free_ids()[counter % count])
697            }
698            RouterMode::Random => {
699                let counter = rand::rng().random::<u64>() as usize;
700                Some(routing_instances.free_ids()[counter % count])
701            }
702            RouterMode::PowerOfTwoChoices
703            | RouterMode::Direct
704            | RouterMode::LeastLoaded
705            | RouterMode::DeviceAwareWeighted => None,
706            RouterMode::KV => {
707                panic!(
708                    "select_next_worker should not be called for {:?} routing mode",
709                    self.router_mode
710                )
711            }
712        }
713    }
714
715    /// Peek the next worker according to the routing mode without incrementing the counter.
716    /// Useful for checking if a worker is suitable before committing to it.
717    /// Returns None for modes that require request lifecycle tracking or explicit routing hints.
718    pub fn peek_next_worker(&self) -> Option<u64> {
719        let routing_instances = self.client.routing_instances();
720        let count = routing_instances.free_ids().len();
721        if count == 0 {
722            return None;
723        }
724
725        match self.router_mode {
726            RouterMode::RoundRobin => {
727                // Just peek at the current counter value without incrementing
728                let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
729                Some(routing_instances.free_ids()[counter % count])
730            }
731            RouterMode::Random => {
732                // For random, peeking implies a fresh random selection since it's stateless.
733                // Note: The caller must realize that select_next_worker() will pick a DIFFERENT random worker.
734                let counter = rand::rng().random::<u64>() as usize;
735                Some(routing_instances.free_ids()[counter % count])
736            }
737            RouterMode::PowerOfTwoChoices
738            | RouterMode::Direct
739            | RouterMode::LeastLoaded
740            | RouterMode::DeviceAwareWeighted => None,
741            RouterMode::KV => {
742                panic!(
743                    "peek_next_worker should not be called for {:?} routing mode",
744                    self.router_mode
745                )
746            }
747        }
748    }
749
750    fn occupancy_state(&self) -> anyhow::Result<Arc<RoutingOccupancyState>> {
751        self.occupancy_state.clone().ok_or_else(|| {
752            anyhow::anyhow!(
753                "routing occupancy state not initialized for endpoint {}",
754                self.client.endpoint.id()
755            )
756        })
757    }
758
759    /*
760    pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
761        let subject = self.client.endpoint.subject();
762        tracing::debug!("static got subject: {subject}");
763        let request = request.map(|req| AddressedRequest::new(req, subject));
764        tracing::debug!("router generate");
765        self.addressed.generate(request).await
766    }
767    */
768
769    async fn generate_with_fault_detection(
770        &self,
771        mut instance_id: u64,
772        request: SingleIn<T>,
773    ) -> anyhow::Result<ManyOut<U>> {
774        let route_start = Instant::now();
775        let request_id = request.id().to_string();
776        let route_span = if matches!(self.router_mode, RouterMode::KV) {
777            tracing::Span::none()
778        } else {
779            tracing::info_span!(
780                "router.route_request",
781                request_id = %request_id,
782                worker_id = instance_id,
783                router_mode = ?self.router_mode,
784            )
785        };
786
787        // Check if the selected worker is overloaded (when fault detection is enabled).
788        if self.fault_detection_enabled {
789            let routing_instances = self.client.routing_instances();
790            let selected_worker_overloaded = routing_instances.is_overloaded(instance_id);
791            let counts = routing_instances.counts();
792            if tracing::enabled!(tracing::Level::DEBUG) {
793                tracing::debug!(
794                    request_id = %request_id,
795                    instance_id,
796                    router_mode = ?self.router_mode,
797                    free_workers = counts.free,
798                    overloaded_workers = counts.overloaded,
799                    total_workers = counts.discovered,
800                    selected_worker_overloaded,
801                    "checked worker overload state"
802                );
803            }
804            if selected_worker_overloaded {
805                tracing::warn!(
806                    instance_id,
807                    overloaded_workers = counts.overloaded,
808                    total_workers = counts.discovered,
809                    "Rejecting request: selected worker is overloaded"
810                );
811                let cause = PipelineError::ServiceOverloaded(
812                    "Selected worker is overloaded, please retry later".to_string(),
813                );
814                return Err(DynamoError::builder()
815                    .error_type(ErrorType::ResourceExhausted)
816                    .message("Selected worker is overloaded, please retry later")
817                    .cause(cause)
818                    .build()
819                    .into());
820            }
821        }
822
823        // Resolve transport address; if the selected instance disappeared
824        // between selection and dispatch, fall back to another available one.
825        let (address, _transport_kind, instance) = {
826            use crate::component::TransportType;
827
828            let resolve_transport = |id: u64| {
829                let instances = self.client.instances();
830                instances
831                    .iter()
832                    .find(|i| i.instance_id == id)
833                    .map(|instance| {
834                        let (addr, kind) = match &instance.transport {
835                            TransportType::Tcp(tcp_endpoint) => {
836                                tracing::debug!(
837                                    instance_id = id,
838                                    tcp_endpoint = %tcp_endpoint,
839                                    "Using TCP transport for instance"
840                                );
841                                (tcp_endpoint.clone(), "transport.tcp.request")
842                            }
843                            TransportType::Nats(subject) => {
844                                tracing::debug!(
845                                    instance_id = id,
846                                    subject = %subject,
847                                    "Using NATS transport for instance"
848                                );
849                                (subject.clone(), "transport.nats.request")
850                            }
851                        };
852                        (addr, kind, instance.clone())
853                    })
854            };
855
856            if let Some(result) = resolve_transport(instance_id) {
857                result
858            } else {
859                // Instance vanished — pick another from free_ids (same filter
860                // as pre-selection) and retry the lookup once.
861                let routing_instances = self.client.routing_instances();
862                let fallback_id = routing_instances
863                    .free_ids()
864                    .iter()
865                    .copied()
866                    .find(|&id| id != instance_id);
867                match fallback_id {
868                    Some(id) => {
869                        tracing::warn!(
870                            original_instance = instance_id,
871                            fallback_instance = id,
872                            "Instance disappeared during routing, reselecting"
873                        );
874                        instance_id = id;
875                        resolve_transport(id).ok_or_else(|| {
876                            anyhow::anyhow!(
877                                "Fallback instance {} also not found for endpoint {}",
878                                id,
879                                self.client.endpoint.id()
880                            )
881                        })?
882                    }
883                    None => {
884                        return Err(anyhow::anyhow!(
885                            "Instance {} not found and no other instances available \
886                             for endpoint {}",
887                            instance_id,
888                            self.client.endpoint.id()
889                        ));
890                    }
891                }
892            }
893        };
894
895        let request = request.map(|req| AddressedRequest::with_instance(req, address, instance));
896
897        STAGE_DURATION_SECONDS
898            .with_label_values(&[STAGE_ROUTE])
899            .observe(route_start.elapsed().as_secs_f64());
900
901        let _nvtx_transport = dynamo_nvtx_range!(_transport_kind);
902        let stream: anyhow::Result<ManyOut<U>> = self
903            .addressed
904            .generate(request)
905            .instrument(route_span)
906            .await;
907        match stream {
908            Ok(stream) => {
909                if !self.fault_detection_enabled {
910                    return Ok(stream);
911                }
912                let engine_ctx = stream.context();
913                let client = self.client.clone();
914                let client_for_timeout = self.client.clone();
915                let stream = stream.map(move |res| {
916                    // Check if the error is migratable (indicates worker/connection failure)
917                    if let Some(err) = res.err()
918                        && is_inhibited(&err)
919                    {
920                        tracing::debug!(
921                            "Reporting instance {instance_id} down due to migratable error: {err}"
922                        );
923                        client.report_instance_down(instance_id);
924                    }
925                    res
926                });
927
928                // Request-plane inactivity timeout: emit a ResponseTimeout error item
929                // when the backend stops producing output. This triggers is_inhibited()
930                // → report_instance_down() to quarantine the worker.
931                let stream: Pin<Box<dyn Stream<Item = U> + Send>> = if let Some(timeout) =
932                    self.response_timeout
933                {
934                    Box::pin(async_stream::stream! {
935                        let mut inner = Box::pin(stream);
936                        loop {
937                            tokio::select! {
938                                biased;
939                                item = inner.next() => {
940                                    match item {
941                                        Some(item) => yield item,
942                                        None => break,
943                                    }
944                                }
945                                _ = tokio::time::sleep(timeout) => {
946                                    tracing::warn!(
947                                        instance_id,
948                                        timeout_secs = timeout.as_secs(),
949                                        "backend response inactivity timeout — quarantining worker"
950                                    );
951                                    client_for_timeout.report_instance_down(instance_id);
952                                    yield U::from_err(
953                                        crate::error::DynamoError::builder()
954                                            .error_type(crate::error::ErrorType::ResponseTimeout)
955                                            .message("backend response inactivity timeout")
956                                            .build()
957                                    );
958                                    break;
959                                }
960                            }
961                        }
962                    })
963                } else {
964                    Box::pin(stream)
965                };
966
967                Ok(ResponseStream::new(stream, engine_ctx))
968            }
969            Err(err) => {
970                if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
971                    tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
972                    self.client.report_instance_down(instance_id);
973                }
974                Err(err)
975            }
976        }
977    }
978}
979
980#[async_trait]
981impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
982where
983    T: Data + Serialize,
984    U: Data + for<'de> Deserialize<'de> + MaybeError,
985{
986    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
987        match self.router_mode {
988            RouterMode::Random => self.random(request).await,
989            RouterMode::RoundRobin => self.round_robin(request).await,
990            RouterMode::PowerOfTwoChoices => self.power_of_two_choices(request).await,
991            RouterMode::KV => {
992                anyhow::bail!("KV routing should not call generate on PushRouter");
993            }
994            RouterMode::Direct => {
995                anyhow::bail!(
996                    "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
997                );
998            }
999            RouterMode::LeastLoaded => self.least_loaded(request).await,
1000            RouterMode::DeviceAwareWeighted => self.device_aware_weighted(request).await,
1001        }
1002    }
1003}
1004
1005/// Bidirectional `AsyncEngine` impl for streaming-input workloads (e.g. the
1006/// OpenAI Realtime API). Selects a sticky instance on the first inbound frame
1007/// and binds the whole input stream to that worker. Required so engines of
1008/// shape `BidirectionalStreamingEngine<T, U>` can be stored as a `PushRouter`
1009/// in `WorkerSet`.
1010///
1011/// Remote per-frame dispatch over `AddressedPushRouter` / `PushWorkHandler`
1012/// is not yet implemented; this impl currently bails after selecting the
1013/// worker. KV and Direct modes inherit the same `bail!` invariants as the
1014/// unary impl.
1015#[async_trait]
1016impl<T, U> AsyncEngine<ManyIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
1017where
1018    T: Data + Serialize,
1019    U: Data + for<'de> Deserialize<'de> + MaybeError,
1020{
1021    async fn generate(&self, mut input: ManyIn<T>) -> Result<ManyOut<U>, Error> {
1022        match self.router_mode {
1023            RouterMode::KV => {
1024                anyhow::bail!("KV routing should not call generate on PushRouter");
1025            }
1026            RouterMode::Direct => {
1027                anyhow::bail!(
1028                    "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
1029                );
1030            }
1031            _ => {}
1032        }
1033
1034        // Wait for the first frame so the sticky-instance pick reflects the
1035        // session's actual start, not router construction time.
1036        if input.next().await.is_none() {
1037            anyhow::bail!("bidirectional input stream closed before first frame");
1038        }
1039        let instance_id = self
1040            .select_next_worker()
1041            .ok_or_else(|| anyhow::anyhow!("no instances available for bidirectional routing"))?;
1042
1043        // Per-frame remote dispatch over AddressedPushRouter / PushWorkHandler
1044        // is tracked in #9361. Until that lands, callers must register engines
1045        // in-process via ModelManager rather than rely on discovered workers.
1046        anyhow::bail!(
1047            "bidirectional remote dispatch is not yet implemented (selected instance {instance_id})"
1048        )
1049    }
1050}
1051
1052struct OccupancyTrackedStream<U: Data> {
1053    inner: ManyOut<U>,
1054    state: Arc<RoutingOccupancyState>,
1055    instance_id: u64,
1056}
1057
1058impl<U: Data> Drop for OccupancyTrackedStream<U> {
1059    fn drop(&mut self) {
1060        self.state.decrement(self.instance_id);
1061    }
1062}
1063
1064impl<U: Data> std::fmt::Debug for OccupancyTrackedStream<U> {
1065    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1066        f.debug_struct("OccupancyTrackedStream")
1067            .field("instance_id", &self.instance_id)
1068            .finish()
1069    }
1070}
1071
1072impl<U: Data> Stream for OccupancyTrackedStream<U> {
1073    type Item = U;
1074
1075    fn poll_next(
1076        mut self: Pin<&mut Self>,
1077        cx: &mut std::task::Context<'_>,
1078    ) -> Poll<Option<Self::Item>> {
1079        self.inner.as_mut().poll_next(cx)
1080    }
1081}
1082
1083impl<U: Data> AsyncEngineContextProvider for OccupancyTrackedStream<U> {
1084    fn context(&self) -> Arc<dyn AsyncEngineContext> {
1085        self.inner.context()
1086    }
1087}
1088
1089impl<U: Data> crate::engine::AsyncEngineStream<U> for OccupancyTrackedStream<U> {}
1090
1091#[cfg(test)]
1092mod tests {
1093    use super::*;
1094    use crate::{
1095        DistributedRuntime, Runtime,
1096        distributed::DistributedConfig,
1097        error::DynamoError,
1098        pipeline::{ResponseStream, context::Controller},
1099    };
1100    use serde::{Deserialize, Serialize};
1101
1102    #[derive(Clone, Debug, Deserialize, Serialize)]
1103    struct TestResponse {
1104        error: Option<DynamoError>,
1105    }
1106
1107    impl MaybeError for TestResponse {
1108        fn from_err(err: impl std::error::Error + 'static) -> Self {
1109            Self {
1110                error: Some(DynamoError::from(
1111                    Box::new(err) as Box<dyn std::error::Error + 'static>
1112                )),
1113            }
1114        }
1115
1116        fn err(&self) -> Option<DynamoError> {
1117            self.error.clone()
1118        }
1119    }
1120
1121    #[test]
1122    fn p2c_selects_lower_load_worker() {
1123        let state = RoutingOccupancyState::default();
1124        for _ in 0..10 {
1125            state.increment(1);
1126        }
1127        state.increment(2);
1128
1129        // With only two workers, p2c_select_from must pick both and choose id=2 (lower load).
1130        let result = p2c_select_from(&state, &[1, 2]);
1131        assert_eq!(result, 2);
1132    }
1133
1134    #[test]
1135    fn p2c_selects_single_worker() {
1136        let state = RoutingOccupancyState::default();
1137        assert_eq!(p2c_select_from(&state, &[42]), 42);
1138    }
1139
1140    #[test]
1141    fn p2c_treats_missing_counts_as_zero() {
1142        let state = RoutingOccupancyState::default();
1143        for _ in 0..5 {
1144            state.increment(1);
1145        }
1146        // Worker 2 has no entry — should be treated as 0, so it wins.
1147        let result = p2c_select_from(&state, &[1, 2]);
1148        assert_eq!(result, 2);
1149    }
1150
1151    #[test]
1152    fn p2c_returns_valid_worker_on_tie() {
1153        let state = RoutingOccupancyState::default();
1154        for _ in 0..3 {
1155            state.increment(1);
1156            state.increment(2);
1157        }
1158
1159        for _ in 0..100 {
1160            let result = p2c_select_from(&state, &[1, 2]);
1161            assert!(result == 1 || result == 2);
1162        }
1163    }
1164
1165    #[test]
1166    fn occupancy_permit_decrements_before_stream_creation() {
1167        let state = Arc::new(RoutingOccupancyState::default());
1168        state.increment(42);
1169        let permit = OccupancyPermit::new(state.clone(), 42);
1170        assert_eq!(state.load(42), 1);
1171        drop(permit);
1172        assert_eq!(state.load(42), 0);
1173    }
1174
1175    #[test]
1176    fn occupancy_tracked_stream_decrements_on_drop() {
1177        let state = Arc::new(RoutingOccupancyState::default());
1178        state.increment(7);
1179        let permit = OccupancyPermit::new(state.clone(), 7);
1180        let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
1181        let stream = permit.into_tracked_stream(ResponseStream::new(
1182            Box::pin(tokio_stream::iter(vec![1u64])),
1183            ctx,
1184        ));
1185        assert_eq!(state.load(7), 1);
1186        drop(stream);
1187        assert_eq!(state.load(7), 0);
1188    }
1189
1190    #[test]
1191    fn p2c_lifecycle_tracks_inflight_counts_with_shared_tracker() {
1192        let state = Arc::new(RoutingOccupancyState::default());
1193        let mut permits = Vec::new();
1194        for _ in 0..5 {
1195            let selected = p2c_select_from(&state, &[1, 2]);
1196            state.increment(selected);
1197            permits.push(OccupancyPermit::new(state.clone(), selected));
1198        }
1199
1200        let total = state.load(1) + state.load(2);
1201        assert_eq!(total, 5, "5 in-flight requests should be tracked");
1202
1203        drop(permits);
1204        let total = state.load(1) + state.load(2);
1205        assert_eq!(total, 0, "All guards dropped, counts should be 0");
1206    }
1207
1208    #[test]
1209    fn p2c_never_selects_dominated_worker() {
1210        let state = RoutingOccupancyState::default();
1211        for _ in 0..100 {
1212            state.increment(3);
1213        }
1214
1215        let mut selected = [0u32; 3];
1216        for _ in 0..1000 {
1217            let result = p2c_select_from(&state, &[1, 2, 3]);
1218            match result {
1219                1 => selected[0] += 1,
1220                2 => selected[1] += 1,
1221                3 => selected[2] += 1,
1222                _ => panic!("unexpected worker id"),
1223            }
1224        }
1225        assert_eq!(
1226            selected[2], 0,
1227            "Worker 3 (load=100) should never be selected against load=0 workers, but got {} times",
1228            selected[2]
1229        );
1230    }
1231
1232    #[tokio::test]
1233    async fn least_loaded_selects_exact_min_and_tracks_counts() {
1234        let state = Arc::new(RoutingOccupancyState::default());
1235        state.increment(1);
1236        state.increment(1);
1237        state.increment(2);
1238
1239        let selected = state
1240            .select_exact_min_and_increment(&[1, 2, 3])
1241            .await
1242            .unwrap();
1243        assert_eq!(selected, 3);
1244
1245        let permit = OccupancyPermit::new(state.clone(), selected);
1246        assert_eq!(state.load(selected), 1);
1247        drop(permit);
1248        assert_eq!(state.load(selected), 0);
1249    }
1250
1251    #[tokio::test]
1252    async fn bidirectional_generate_bails_with_no_instances() {
1253        let rt = Runtime::from_current().unwrap();
1254        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1255            .await
1256            .unwrap();
1257        let ns = drt.namespace("test_bidi_no_instances".to_string()).unwrap();
1258        let component = ns.component("test_component".to_string()).unwrap();
1259        let endpoint = component.endpoint("test_endpoint".to_string());
1260        let client = endpoint.client().await.unwrap();
1261
1262        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::RoundRobin)
1263            .await
1264            .unwrap();
1265
1266        let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
1267        let input: ManyIn<u64> =
1268            ResponseStream::new(Box::pin(tokio_stream::iter(vec![1u64, 2u64])), ctx);
1269        let result = router.generate(input).await;
1270        assert!(
1271            result.is_err(),
1272            "bidirectional generate must bail when no instances are registered"
1273        );
1274
1275        rt.shutdown();
1276    }
1277
1278    #[tokio::test]
1279    async fn bidirectional_generate_bails_for_kv_router_mode() {
1280        let rt = Runtime::from_current().unwrap();
1281        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1282            .await
1283            .unwrap();
1284        let ns = drt.namespace("test_bidi_kv_mode".to_string()).unwrap();
1285        let component = ns.component("test_component".to_string()).unwrap();
1286        let endpoint = component.endpoint("test_endpoint".to_string());
1287        let client = endpoint.client().await.unwrap();
1288
1289        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::KV)
1290            .await
1291            .unwrap();
1292
1293        let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
1294        let input: ManyIn<u64> = ResponseStream::new(Box::pin(tokio_stream::iter(vec![1u64])), ctx);
1295        let result = router.generate(input).await;
1296        assert!(
1297            result.is_err(),
1298            "bidirectional generate must bail for RouterMode::KV"
1299        );
1300        let err_msg = format!("{:?}", result.unwrap_err());
1301        assert!(
1302            err_msg.contains("KV") || err_msg.contains("kv"),
1303            "error should mention KV: got {err_msg}"
1304        );
1305
1306        rt.shutdown();
1307    }
1308
1309    #[tokio::test]
1310    async fn bidirectional_generate_bails_for_direct_router_mode() {
1311        let rt = Runtime::from_current().unwrap();
1312        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1313            .await
1314            .unwrap();
1315        let ns = drt.namespace("test_bidi_direct_mode".to_string()).unwrap();
1316        let component = ns.component("test_component".to_string()).unwrap();
1317        let endpoint = component.endpoint("test_endpoint".to_string());
1318        let client = endpoint.client().await.unwrap();
1319
1320        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::Direct)
1321            .await
1322            .unwrap();
1323
1324        let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
1325        let input: ManyIn<u64> = ResponseStream::new(Box::pin(tokio_stream::iter(vec![1u64])), ctx);
1326        let result = router.generate(input).await;
1327        assert!(
1328            result.is_err(),
1329            "bidirectional generate must bail for RouterMode::Direct"
1330        );
1331        let err_msg = format!("{:?}", result.unwrap_err());
1332        assert!(
1333            err_msg.contains("Direct") || err_msg.contains("direct"),
1334            "error should mention Direct: got {err_msg}"
1335        );
1336
1337        rt.shutdown();
1338    }
1339
1340    #[tokio::test]
1341    async fn least_loaded_select_and_peek_return_none_with_available_worker() {
1342        let rt = Runtime::from_current().unwrap();
1343        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1344            .await
1345            .unwrap();
1346        let ns = drt
1347            .namespace("test_least_loaded_router".to_string())
1348            .unwrap();
1349        let component = ns.component("test_component".to_string()).unwrap();
1350        let endpoint = component.endpoint("test_endpoint".to_string());
1351        let client = endpoint.client().await.unwrap();
1352
1353        endpoint.register_endpoint_instance().await.unwrap();
1354        client.wait_for_instances().await.unwrap();
1355
1356        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::LeastLoaded)
1357            .await
1358            .unwrap();
1359
1360        assert_eq!(router.select_next_worker(), None);
1361        assert_eq!(router.peek_next_worker(), None);
1362
1363        rt.shutdown();
1364    }
1365
1366    #[tokio::test]
1367    async fn selected_overloaded_worker_is_rejected_before_dispatch() {
1368        const TEST_RECONCILE_INTERVAL: std::time::Duration = std::time::Duration::from_millis(50);
1369
1370        let rt = Runtime::from_current().unwrap();
1371        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1372            .await
1373            .unwrap();
1374        let ns = drt
1375            .namespace("test_selected_overloaded_worker_rejected".to_string())
1376            .unwrap();
1377        let component = ns.component("test_component".to_string()).unwrap();
1378        let endpoint = component.endpoint("test_endpoint".to_string());
1379        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
1380            .await
1381            .unwrap();
1382
1383        endpoint.register_endpoint_instance().await.unwrap();
1384        let instances = client.wait_for_instances().await.unwrap();
1385        let worker_id = instances[0].id();
1386
1387        for _ in 0..10 {
1388            if client.instance_ids_avail().contains(&worker_id) {
1389                break;
1390            }
1391            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
1392        }
1393        assert!(
1394            client.instance_ids_avail().contains(&worker_id),
1395            "worker should be routable before marking it overloaded"
1396        );
1397
1398        client.set_overloaded_instances(&[worker_id]);
1399        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::RoundRobin)
1400            .await
1401            .unwrap();
1402
1403        let result = router.generate(SingleIn::new(42u64)).await;
1404        assert!(result.is_err());
1405        let msg = format!("{}", result.unwrap_err());
1406        // With pre-selection filtering on free_ids, the single-overloaded-worker
1407        // case is now caught before selection rather than after — the chosen
1408        // worker is never overloaded because the candidate pool excludes it.
1409        // The post-selection check in route() remains as a race-condition
1410        // backstop.
1411        assert!(
1412            msg.contains("All workers are busy"),
1413            "expected empty-free-pool rejection, got: {msg}"
1414        );
1415
1416        rt.shutdown();
1417    }
1418
1419    #[tokio::test]
1420    async fn round_robin_excludes_overloaded_workers_from_candidates() {
1421        // Long reconcile interval so the synthetic override below survives
1422        // the test. We still register a real endpoint instance up front so
1423        // the initial reconcile (which fires immediately when the monitor
1424        // task spawns) settles on a non-empty source — without that, the
1425        // first reconcile would clobber the override before it takes effect.
1426        const TEST_RECONCILE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(3600);
1427
1428        let rt = Runtime::from_current().unwrap();
1429        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1430            .await
1431            .unwrap();
1432        let ns = drt
1433            .namespace("test_round_robin_excludes_overloaded".to_string())
1434            .unwrap();
1435        let component = ns.component("test_component".to_string()).unwrap();
1436        let endpoint = component.endpoint("test_endpoint".to_string());
1437        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
1438            .await
1439            .unwrap();
1440
1441        endpoint.register_endpoint_instance().await.unwrap();
1442        let instances = client.wait_for_instances().await.unwrap();
1443        let real_id = instances[0].id();
1444        for _ in 0..50 {
1445            if client.instance_ids_avail().contains(&real_id) {
1446                break;
1447            }
1448            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1449        }
1450
1451        // Now override with two synthetic IDs and mark one overloaded.
1452        // round_robin must never select the overloaded one — that's the
1453        // whole point of selecting from free_ids instead of routable_ids.
1454        // The post-selection overload check in route() would otherwise 503
1455        // one of N requests on each pass, which is the bug this PR closes
1456        // for non-KV selectors.
1457        client.override_instance_avail(vec![1, 2]);
1458        client.set_overloaded_instances(&[1]);
1459
1460        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::RoundRobin)
1461            .await
1462            .unwrap();
1463
1464        // Round-robin over N requests should land on worker 2 every time.
1465        // We use peek_next_worker for a side-effect-free probe.
1466        for _ in 0..6 {
1467            let selected = router
1468                .peek_next_worker()
1469                .expect("peek should succeed with a free worker");
1470            assert_eq!(
1471                selected, 2,
1472                "overloaded worker 1 must not appear in the candidate set"
1473            );
1474        }
1475
1476        rt.shutdown();
1477    }
1478
1479    #[tokio::test]
1480    async fn device_aware_cpu_only_selects_least_loaded_instance() {
1481        let state = RoutingOccupancyState::default();
1482        // All candidates are CPU. Make worker 2 the least-loaded one.
1483        for _ in 0..3 {
1484            state.increment(1);
1485        }
1486        state.increment(3);
1487
1488        let instance_ids = vec![1, 2, 3];
1489        let device_type_map = HashMap::from([
1490            (1, Some(DeviceType::Cpu)),
1491            (2, Some(DeviceType::Cpu)),
1492            (3, Some(DeviceType::Cpu)),
1493        ]);
1494
1495        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
1496        assert_eq!(candidates, vec![1, 2, 3]);
1497
1498        let selected = state
1499            .select_exact_min_and_increment(&candidates)
1500            .await
1501            .unwrap();
1502        assert_eq!(selected, 2);
1503    }
1504
1505    #[tokio::test]
1506    async fn device_aware_non_cpu_only_selects_least_loaded_instance() {
1507        let state = RoutingOccupancyState::default();
1508        // All candidates are non-CPU. Make worker 2 the least-loaded one.
1509        for _ in 0..3 {
1510            state.increment(1);
1511        }
1512        state.increment(3);
1513
1514        let instance_ids = vec![1, 2, 3];
1515        let device_type_map = HashMap::from([
1516            (1, Some(DeviceType::Cuda)),
1517            (2, Some(DeviceType::Cuda)),
1518            (3, Some(DeviceType::Cuda)),
1519        ]);
1520
1521        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
1522        assert_eq!(candidates, vec![1, 2, 3]);
1523
1524        let selected = state
1525            .select_exact_min_and_increment(&candidates)
1526            .await
1527            .unwrap();
1528        assert_eq!(selected, 2);
1529    }
1530
1531    #[test]
1532    fn device_aware_group_uses_ratio_budget() {
1533        let state = RoutingOccupancyState::default();
1534        // CPU ids: 1,2 ; non-CPU ids: 3,4
1535        for _ in 0..4 {
1536            state.increment(3);
1537            state.increment(4);
1538        }
1539        // CPU inflight can differ across instances; budgeting uses total CPU inflight.
1540        for _ in 0..3 {
1541            state.increment(1);
1542        }
1543        // total_non_cpu_inflight=8, cpu_count=2, non_cpu_count=2, ratio=2
1544        // allowed_cpu_inflight = 8*2/(2*2)=4
1545        // total_cpu_inflight=3 < 4 => choose CPU group.
1546        let instance_ids = vec![1, 2, 3, 4];
1547        let device_type_map = HashMap::from([
1548            (1, Some(DeviceType::Cpu)),
1549            (2, Some(DeviceType::Cpu)),
1550            (3, Some(DeviceType::Cuda)),
1551            (4, Some(DeviceType::Cuda)),
1552        ]);
1553
1554        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 2);
1555        assert_eq!(candidates, vec![1, 2]);
1556
1557        // Within selected CPU group, final choice should be the least-loaded instance (id=2).
1558        let selected =
1559            futures::executor::block_on(state.select_exact_min_and_increment(&candidates)).unwrap();
1560        assert_eq!(selected, 2);
1561    }
1562
1563    #[tokio::test]
1564    async fn device_aware_weighted_select_and_peek_return_none_with_available_worker() {
1565        let rt = Runtime::from_current().unwrap();
1566        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1567            .await
1568            .unwrap();
1569        let ns = drt
1570            .namespace("test_device_aware_router".to_string())
1571            .unwrap();
1572        let component = ns.component("test_component".to_string()).unwrap();
1573        let endpoint = component.endpoint("test_endpoint".to_string());
1574        let client = endpoint.client().await.unwrap();
1575
1576        endpoint.register_endpoint_instance().await.unwrap();
1577        client.wait_for_instances().await.unwrap();
1578
1579        let router =
1580            PushRouter::<u64, TestResponse>::from_client(client, RouterMode::DeviceAwareWeighted)
1581                .await
1582                .unwrap();
1583
1584        assert_eq!(router.select_next_worker(), None);
1585        assert_eq!(router.peek_next_worker(), None);
1586
1587        rt.shutdown();
1588    }
1589
1590    /// When the router selects an instance that has deregistered between selection
1591    /// and transport resolution, it should fall back to another available instance
1592    /// rather than returning a 500 error.
1593    #[tokio::test]
1594    async fn transport_resolution_falls_back_when_selected_instance_disappears() {
1595        let rt = Runtime::from_current().unwrap();
1596        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1597            .await
1598            .unwrap();
1599        let ns = drt
1600            .namespace("test_transport_fallback".to_string())
1601            .unwrap();
1602        let component = ns.component("test_component".to_string()).unwrap();
1603        let endpoint = component.endpoint("test_endpoint".to_string());
1604        let client = endpoint.client().await.unwrap();
1605
1606        // Register one real instance so it appears in instance_source.
1607        endpoint.register_endpoint_instance().await.unwrap();
1608        client.wait_for_instances().await.unwrap();
1609
1610        let real_id = client.instance_ids()[0];
1611
1612        // Inject a stale ID into instance_avail that does NOT exist in
1613        // instance_source. This simulates the race window where an instance
1614        // deregistered after selection but before transport resolution.
1615        let stale_id = real_id + 1000;
1616        client.override_instance_avail(vec![stale_id, real_id]);
1617
1618        // Build a router and call direct() targeting the *real* instance to
1619        // verify the router can still resolve transport for known instances.
1620        let router =
1621            PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
1622                .await
1623                .unwrap();
1624
1625        // Round robin should succeed — even if it picks stale_id first, the
1626        // fallback logic should resolve transport via real_id.
1627        // We cannot fully test the network send without a worker, but we can
1628        // verify it doesn't fail at the transport resolution stage by checking
1629        // that the error (if any) is a transport/network error, not
1630        // "Instance not found".
1631        let request = SingleIn::new(42u64);
1632        let result = router.generate(request).await;
1633
1634        // The request may fail at the network level (no actual worker), but it
1635        // must NOT fail with "Instance X not found" — that would mean the
1636        // fallback did not work.
1637        if let Err(err) = &result {
1638            let msg = format!("{err}");
1639            assert!(
1640                !msg.contains("not found"),
1641                "Transport resolution should have fallen back, but got: {msg}"
1642            );
1643        }
1644
1645        rt.shutdown();
1646    }
1647
1648    /// When no instances are available at all (both primary and fallback),
1649    /// the router should return a clear error.
1650    #[tokio::test]
1651    async fn transport_resolution_errors_when_no_instances_available() {
1652        let rt = Runtime::from_current().unwrap();
1653        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1654            .await
1655            .unwrap();
1656        let ns = drt
1657            .namespace("test_transport_no_fallback".to_string())
1658            .unwrap();
1659        let component = ns.component("test_component".to_string()).unwrap();
1660        let endpoint = component.endpoint("test_endpoint".to_string());
1661        let client = endpoint.client().await.unwrap();
1662
1663        // Register an instance so we can create the router (needs transport setup).
1664        endpoint.register_endpoint_instance().await.unwrap();
1665        client.wait_for_instances().await.unwrap();
1666
1667        let router =
1668            PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
1669                .await
1670                .unwrap();
1671
1672        // Override avail to contain only a stale ID with no real backing
1673        // instance AND no other available fallback.
1674        let stale_id = 99999;
1675        client.override_instance_avail(vec![stale_id]);
1676
1677        let request = SingleIn::new(42u64);
1678        let result = router.generate(request).await;
1679
1680        assert!(result.is_err());
1681        let msg = format!("{}", result.unwrap_err());
1682        assert!(
1683            msg.contains("not found") && msg.contains("no other instances available"),
1684            "Expected clear error about missing instance with no fallback, got: {msg}"
1685        );
1686
1687        rt.shutdown();
1688    }
1689
1690    /// The watcher dedup guard must be released even if the spawned task panics.
1691    /// Without this, a panic anywhere in the watcher body would leave a stale
1692    /// `ENDPOINT_WATCHER_ACTIVE` entry, silently disabling orphaned-pending-
1693    /// request cancellation for that endpoint until process restart.
1694    ///
1695    /// We exercise the Drop-guard pattern directly against the same static
1696    /// rather than driving `spawn_instance_removal_watcher` end-to-end (which
1697    /// would require staging a panicking discovery stream). The test mirrors
1698    /// the production code's GuardRelease shape; if the production code stops
1699    /// using a Drop guard, the integration would regress and the existing
1700    /// orphan-cancellation tests would fail.
1701    #[tokio::test]
1702    async fn watcher_dedup_guard_released_on_panic() {
1703        let endpoint_id = EndpointId {
1704            namespace: "panic-test-ns".to_string(),
1705            component: "panic-test-comp".to_string(),
1706            name: "panic-test-endpoint".to_string(),
1707        };
1708
1709        // Mimic the production code's pre-spawn dedup insert.
1710        let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
1711        map.insert(endpoint_id.clone(), ());
1712
1713        let endpoint_id_clone = endpoint_id.clone();
1714        let join = tokio::spawn(async move {
1715            // Same shape as in spawn_instance_removal_watcher.
1716            struct GuardRelease(EndpointId);
1717            impl Drop for GuardRelease {
1718                fn drop(&mut self) {
1719                    if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
1720                        map.remove(&self.0);
1721                    }
1722                }
1723            }
1724            let _release = GuardRelease(endpoint_id_clone);
1725            panic!("simulated watcher-task panic");
1726        });
1727
1728        let result = join.await;
1729        assert!(result.is_err() && result.unwrap_err().is_panic());
1730        assert!(
1731            !map.contains_key(&endpoint_id),
1732            "Drop guard must release the dedup entry even on panic"
1733        );
1734    }
1735
1736    /// Normal-exit path: the Drop guard releases the entry when the task
1737    /// finishes without panicking. This is the everyday case (cancel_token
1738    /// fires or discovery stream closes).
1739    #[tokio::test]
1740    async fn watcher_dedup_guard_released_on_normal_exit() {
1741        let endpoint_id = EndpointId {
1742            namespace: "normal-test-ns".to_string(),
1743            component: "normal-test-comp".to_string(),
1744            name: "normal-test-endpoint".to_string(),
1745        };
1746
1747        let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
1748        map.insert(endpoint_id.clone(), ());
1749
1750        let endpoint_id_clone = endpoint_id.clone();
1751        tokio::spawn(async move {
1752            struct GuardRelease(EndpointId);
1753            impl Drop for GuardRelease {
1754                fn drop(&mut self) {
1755                    if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
1756                        map.remove(&self.0);
1757                    }
1758                }
1759            }
1760            let _release = GuardRelease(endpoint_id_clone);
1761            // task body returns normally
1762        })
1763        .await
1764        .unwrap();
1765
1766        assert!(!map.contains_key(&endpoint_id));
1767    }
1768}