1use super::{AsyncEngineContextProvider, ResponseStream};
5use crate::error::{BackendError, DynamoError, ErrorType, match_error_chain};
6use crate::{
7 component::{
8 Client, DeviceType, Endpoint, Instance, RoutingInstances, RoutingOccupancyState,
9 get_or_create_routing_occupancy_state,
10 },
11 discovery::EndpointInstanceId,
12 dynamo_nvtx_range,
13 engine::{AsyncEngine, AsyncEngineContext, Data},
14 metrics::frontend_perf::{STAGE_DURATION_SECONDS, STAGE_ROUTE},
15 pipeline::{
16 AddressedPushRouter, AddressedRequest, Error, ManyIn, ManyOut, SingleIn,
17 error::{PipelineError, PipelineErrorExt},
18 },
19 protocols::{EndpointId, maybe_error::MaybeError},
20 traits::DistributedRuntimeProvider,
21};
22use async_trait::async_trait;
23use futures::Stream;
24use rand::Rng;
25use serde::{Deserialize, Serialize};
26use std::{
27 collections::HashMap,
28 marker::PhantomData,
29 pin::Pin,
30 sync::{
31 Arc,
32 atomic::{AtomicU64, Ordering},
33 },
34 task::Poll,
35 time::Instant,
36};
37use tokio_stream::StreamExt;
38use tracing::Instrument;
39
40fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
42 const INHIBITED: &[ErrorType] = &[
43 ErrorType::CannotConnect,
44 ErrorType::Disconnected,
45 ErrorType::ConnectionTimeout,
46 ErrorType::ResponseTimeout,
47 ErrorType::Backend(BackendError::EngineShutdown),
48 ];
49 match_error_chain(err, INHIBITED, &[])
50}
51
52fn response_inactivity_timeout() -> Option<std::time::Duration> {
56 use crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS;
57 std::env::var(DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS)
58 .ok()
59 .and_then(|s| s.parse::<u64>().ok())
60 .filter(|&secs| secs > 0)
61 .map(std::time::Duration::from_secs)
62}
63
64struct OccupancyPermit {
65 state: Arc<RoutingOccupancyState>,
66 instance_id: u64,
67 armed: bool,
68}
69
70impl OccupancyPermit {
71 fn new(state: Arc<RoutingOccupancyState>, instance_id: u64) -> Self {
72 Self {
73 state,
74 instance_id,
75 armed: true,
76 }
77 }
78
79 fn into_tracked_stream<U: Data>(mut self, stream: ManyOut<U>) -> ManyOut<U> {
80 self.armed = false;
81 let engine_ctx = stream.context();
82 ResponseStream::new(
83 Box::pin(OccupancyTrackedStream {
84 inner: stream,
85 state: self.state.clone(),
86 instance_id: self.instance_id,
87 }),
88 engine_ctx,
89 )
90 }
91
92 fn instance_id(&self) -> u64 {
93 self.instance_id
94 }
95}
96
97impl Drop for OccupancyPermit {
98 fn drop(&mut self) {
99 if self.armed {
100 self.state.decrement(self.instance_id);
101 }
102 }
103}
104
/// Hook for a component that monitors worker load.
///
/// `PushRouter::from_client_with_monitor` awaits `start_monitoring` once while
/// the router is being built and propagates any error it returns.
#[async_trait]
pub trait WorkerLoadMonitor: Send + Sync {
    /// Starts monitoring; an `Err` aborts router construction.
    async fn start_monitoring(&self) -> anyhow::Result<()>;
}
113
/// Router that picks a worker instance for a request and dispatches it through
/// an `AddressedPushRouter`.
///
/// `T` is the request payload type and `U` the response item type; neither is
/// stored, hence the `PhantomData` marker.
#[derive(Clone)]
pub struct PushRouter<T, U>
where
    T: Data + Serialize,
    U: Data + for<'de> Deserialize<'de>,
{
    /// Client used to read instance lists / routing state and to report
    /// instances as down.
    pub client: Client,

    /// Selection strategy used by `generate`, `select_next_worker` and
    /// `peek_next_worker`.
    router_mode: RouterMode,

    /// Counter advanced by round-robin selection; held in an `Arc`, so clones
    /// of the router share it.
    round_robin_counter: Arc<AtomicU64>,

    /// Dispatcher that sends an addressed request and returns the response
    /// stream.
    addressed: Arc<AddressedPushRouter>,

    /// When `true`, requests to overloaded workers are rejected and response
    /// streams are wrapped so that inhibited errors mark the instance down.
    fault_detection_enabled: bool,

    /// Optional inactivity timeout between response items. It is only applied
    /// on the fault-detection path of `generate_with_fault_detection`.
    response_timeout: Option<std::time::Duration>,

    /// In-flight request counters; `Some` only for the `PowerOfTwoChoices`,
    /// `LeastLoaded` and `DeviceAwareWeighted` modes.
    occupancy_state: Option<Arc<RoutingOccupancyState>>,

    /// Ties the otherwise unused `T` / `U` parameters to the type.
    _phantom: PhantomData<(T, U)>,
}
157
158#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
159#[serde(rename_all = "snake_case")]
160pub enum RouterMode {
161 #[default]
162 RoundRobin,
163 Random,
164 PowerOfTwoChoices,
165 KV,
166 Direct,
167 LeastLoaded,
168 DeviceAwareWeighted,
170}
171
172impl RouterMode {
173 pub fn is_kv_routing(&self) -> bool {
174 *self == RouterMode::KV
175 }
176
177 pub fn is_direct_routing(&self) -> bool {
178 *self == RouterMode::Direct
179 }
180}
181
182fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64]) -> u64 {
185 let count = instance_ids.len();
186 if count == 1 {
187 return instance_ids[0];
188 }
189 let mut rng = rand::rng();
190 let idx1 = rng.random_range(0..count);
191 let idx2 = (idx1 + 1 + rng.random_range(0..count - 1)) % count;
192 let id1 = instance_ids[idx1];
193 let id2 = instance_ids[idx2];
194 let load1 = occupancy_state.load(id1);
195 let load2 = occupancy_state.load(id2);
196 let selected = if load1 <= load2 { id1 } else { id2 };
197 tracing::debug!(
198 candidate_a = id1,
199 candidate_a_load = load1,
200 candidate_b = id2,
201 candidate_b_load = load2,
202 selected = selected,
203 "p2c selection"
204 );
205 selected
206}
207
208fn device_aware_candidate_group(
220 state: &RoutingOccupancyState,
221 instance_ids: &[u64],
222 device_type_map: &HashMap<u64, Option<DeviceType>>,
223 non_cpu_to_cpu_ratio: usize,
224) -> Vec<u64> {
225 let cpu_ids: Vec<u64> = instance_ids
226 .iter()
227 .copied()
228 .filter(|id| matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
229 .collect();
230 let non_cpu_ids: Vec<u64> = instance_ids
231 .iter()
232 .copied()
233 .filter(|id| !matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
234 .collect();
235
236 if cpu_ids.is_empty() {
237 return non_cpu_ids;
238 }
239 if non_cpu_ids.is_empty() {
240 return cpu_ids;
241 }
242
243 let total_non_cpu_inflight: u64 = non_cpu_ids.iter().map(|id| state.load(*id)).sum();
245 let total_cpu_inflight: u64 = cpu_ids.iter().map(|id| state.load(*id)).sum();
246 let cpu_count = cpu_ids.len() as u64;
247 let non_cpu_count = non_cpu_ids.len() as u64;
248 let allowed_cpu_inflight = total_non_cpu_inflight.saturating_mul(cpu_count)
249 / ((non_cpu_to_cpu_ratio as u64).saturating_mul(non_cpu_count));
250
251 if total_cpu_inflight < allowed_cpu_inflight {
252 cpu_ids
253 } else {
254 non_cpu_ids
255 }
256}
257
258static ENDPOINT_WATCHER_ACTIVE: std::sync::OnceLock<dashmap::DashMap<EndpointId, ()>> =
261 std::sync::OnceLock::new();
262
263fn spawn_instance_removal_watcher(
269 endpoint: Endpoint,
270 addressed: Arc<AddressedPushRouter>,
271 cancel_token: tokio_util::sync::CancellationToken,
272) {
273 use crate::discovery::{
274 DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
275 };
276 use tokio_stream::StreamExt as _;
277
278 let guard = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
280 let endpoint_id = endpoint.id();
281 if guard.insert(endpoint_id.clone(), ()).is_some() {
282 tracing::debug!(
283 ?endpoint_id,
284 "Instance removal watcher already running for this endpoint, skipping"
285 );
286 return;
287 }
288
289 let endpoint_name = endpoint.name().to_string();
290
291 tokio::spawn(async move {
292 struct GuardRelease(EndpointId);
295 impl Drop for GuardRelease {
296 fn drop(&mut self) {
297 if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
298 map.remove(&self.0);
299 }
300 }
301 }
302 let _release = GuardRelease(endpoint_id);
303
304 let namespace = endpoint.component().namespace().name();
305 let component = endpoint.component().name().to_string();
306
307 const RECONNECT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(5);
309 'reconnect: loop {
310 let query = DiscoveryQuery::Endpoint {
311 namespace: namespace.clone(),
312 component: component.clone(),
313 endpoint: endpoint_name.clone(),
314 };
315
316 let mut stream = match endpoint.drt().discovery().list_and_watch(query, None).await {
317 Ok(s) => s,
318 Err(e) => {
319 tracing::warn!(
320 endpoint = %endpoint_name,
321 "Failed to start instance removal watcher (will retry): {e}"
322 );
323 tokio::select! {
324 _ = tokio::time::sleep(RECONNECT_BACKOFF) => continue 'reconnect,
325 _ = cancel_token.cancelled() => break 'reconnect,
326 }
327 }
328 };
329
330 loop {
331 tokio::select! {
332 event = stream.next() => {
333 match event {
334 Some(Ok(DiscoveryEvent::Removed(id))) => {
335 if let DiscoveryInstanceId::Endpoint(eid) = &id {
336 let n = addressed.cancel_instance_streams(eid).await;
337 if n > 0 {
338 tracing::warn!(
339 namespace = %eid.namespace,
340 component = %eid.component,
341 endpoint = %eid.endpoint,
342 instance_id = eid.instance_id,
343 cancelled = n,
344 "Cancelled pending response streams for removed \
345 instance (discovery-driven cleanup)"
346 );
347 }
348 }
349 }
350 Some(Ok(DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)))) => {
351 let eid: EndpointInstanceId = inst.endpoint_instance_id();
352 addressed.clear_instance_tombstone(&eid).await;
353 }
354 Some(Ok(_)) => {}
355 Some(Err(e)) => {
356 tracing::warn!(
357 endpoint = %endpoint_name,
358 "Instance removal watcher stream error: {e}"
359 );
360 }
361 None => {
362 tracing::warn!(
363 endpoint = %endpoint_name,
364 "Instance removal watcher stream ended; reconnecting"
365 );
366 continue 'reconnect;
367 }
368 }
369 }
370 _ = cancel_token.cancelled() => {
371 break 'reconnect;
372 }
373 }
374 }
375 }
376
377 tracing::debug!(endpoint = %endpoint_name, "Instance removal watcher exiting");
378 });
379}
380
/// Obtains the `AddressedPushRouter` for `endpoint` by delegating to
/// `AddressedPushRouter::from_runtime_provider`.
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
    AddressedPushRouter::from_runtime_provider(endpoint).await
}
384
385impl<T, U> PushRouter<T, U>
386where
387 T: Data + Serialize,
388 U: Data + for<'de> Deserialize<'de> + MaybeError,
389{
    /// Builds a router for `client` using `router_mode`.
    ///
    /// Shorthand for [`Self::from_client_with_monitor`] without a worker load
    /// monitor; that constructor enables fault detection.
    pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
        Self::from_client_with_monitor(client, router_mode, None).await
    }
394
395 pub async fn from_client_no_fault_detection(
401 client: Client,
402 router_mode: RouterMode,
403 ) -> anyhow::Result<Self> {
404 let addressed = addressed_router(&client.endpoint).await?;
405
406 let occupancy_state = if matches!(
407 router_mode,
408 RouterMode::PowerOfTwoChoices
409 | RouterMode::LeastLoaded
410 | RouterMode::DeviceAwareWeighted
411 ) {
412 Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
413 } else {
414 None
415 };
416
417 spawn_instance_removal_watcher(
419 client.endpoint.clone(),
420 addressed.clone(),
421 client.endpoint.drt().primary_token(),
422 );
423
424 Ok(PushRouter {
425 client,
426 addressed,
427 router_mode,
428 round_robin_counter: Arc::new(AtomicU64::new(0)),
429 fault_detection_enabled: false,
430 response_timeout: response_inactivity_timeout(),
431 occupancy_state,
432 _phantom: PhantomData,
433 })
434 }
435
436 pub async fn from_client_with_monitor(
443 client: Client,
444 router_mode: RouterMode,
445 worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
446 ) -> anyhow::Result<Self> {
447 let addressed = addressed_router(&client.endpoint).await?;
448
449 if let Some(monitor) = worker_monitor.as_ref() {
451 monitor.start_monitoring().await?;
452 }
453
454 let occupancy_state = if matches!(
455 router_mode,
456 RouterMode::PowerOfTwoChoices
457 | RouterMode::LeastLoaded
458 | RouterMode::DeviceAwareWeighted
459 ) {
460 Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
461 } else {
462 None
463 };
464
465 spawn_instance_removal_watcher(
467 client.endpoint.clone(),
468 addressed.clone(),
469 client.endpoint.drt().primary_token(),
470 );
471
472 let router = PushRouter {
473 client,
474 addressed,
475 router_mode,
476 round_robin_counter: Arc::new(AtomicU64::new(0)),
477 fault_detection_enabled: true,
478 response_timeout: response_inactivity_timeout(),
479 occupancy_state,
480 _phantom: PhantomData,
481 };
482
483 Ok(router)
484 }
485
486 fn empty_free_pool_error(&self, routing_instances: &RoutingInstances) -> anyhow::Error {
489 if !routing_instances.routable_ids().is_empty() {
490 let cause = PipelineError::ServiceOverloaded(
491 "All workers are busy, please retry later".to_string(),
492 );
493 return DynamoError::builder()
494 .error_type(ErrorType::ResourceExhausted)
495 .message("All workers are busy, please retry later")
496 .cause(cause)
497 .build()
498 .into();
499 }
500 anyhow::anyhow!(
501 "no instances found for endpoint {}",
502 self.client.endpoint.id()
503 )
504 }
505
506 pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
508 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
509
510 let instance_id = {
511 let routing_instances = self.client.routing_instances();
512 let count = routing_instances.free_ids().len();
513 if count == 0 {
514 return Err(self.empty_free_pool_error(&routing_instances));
515 }
516 routing_instances.free_ids()[counter % count]
517 };
518 tracing::trace!("round robin router selected {instance_id}");
519
520 self.generate_with_fault_detection(instance_id, request)
521 .await
522 }
523
524 pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
526 let instance_id = {
527 let routing_instances = self.client.routing_instances();
528 let count = routing_instances.free_ids().len();
529 if count == 0 {
530 return Err(self.empty_free_pool_error(&routing_instances));
531 }
532 let counter = rand::rng().random::<u64>() as usize;
533 routing_instances.free_ids()[counter % count]
534 };
535 tracing::trace!("random router selected {instance_id}");
536
537 self.generate_with_fault_detection(instance_id, request)
538 .await
539 }
540
541 pub async fn power_of_two_choices(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
544 let state = self.occupancy_state()?;
545 let instance_id = {
546 let routing_instances = self.client.routing_instances();
547 if routing_instances.free_ids().is_empty() {
548 return Err(self.empty_free_pool_error(&routing_instances));
549 }
550 p2c_select_from(state.as_ref(), routing_instances.free_ids())
551 };
552 state.increment(instance_id);
553 let permit = OccupancyPermit::new(state, instance_id);
554
555 match self
556 .generate_with_fault_detection(instance_id, request)
557 .await
558 {
559 Ok(stream) => Ok(permit.into_tracked_stream(stream)),
560 Err(err) => Err(err),
561 }
562 }
563
564 pub async fn direct(
566 &self,
567 request: SingleIn<T>,
568 instance_id: u64,
569 ) -> anyhow::Result<ManyOut<U>> {
570 let found = {
574 if self.fault_detection_enabled {
575 let routing_instances = self.client.routing_instances();
576 routing_instances.routable_ids().contains(&instance_id)
577 } else {
578 self.client.instance_ids().contains(&instance_id)
579 }
580 };
581
582 if !found {
583 return Err(anyhow::anyhow!(
584 "instance_id={instance_id} not found for endpoint {}",
585 self.client.endpoint.id()
586 ));
587 }
588
589 self.generate_with_fault_detection(instance_id, request)
590 .await
591 }
592
593 pub async fn device_aware_weighted(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
602 let state = self.occupancy_state()?;
603 let routing_instances = self.client.routing_instances();
604 let instance_ids = routing_instances.free_ids().to_vec();
605
606 if instance_ids.is_empty() {
607 return Err(self.empty_free_pool_error(&routing_instances));
608 }
609
610 let endpoint_id = self.client.endpoint.id();
612
613 let instances = self.client.instances();
615 let device_type_map: std::collections::HashMap<u64, Option<DeviceType>> = instances
616 .iter()
617 .map(|inst| (inst.instance_id, inst.device_type.clone()))
618 .collect();
619
620 let cuda_to_cpu_ratio = std::env::var("DYN_ENCODER_CUDA_TO_CPU_RATIO")
622 .ok()
623 .and_then(|v| v.parse::<usize>().ok())
624 .filter(|v| *v >= 1)
625 .unwrap_or(8);
626 let candidates = device_aware_candidate_group(
627 state.as_ref(),
628 &instance_ids,
629 &device_type_map,
630 cuda_to_cpu_ratio,
631 );
632
633 let instance_id = state
635 .select_exact_min_and_increment(&candidates)
636 .await
637 .ok_or_else(|| self.empty_free_pool_error(&routing_instances))?;
638 let permit = OccupancyPermit::new(state.clone(), instance_id);
639 let is_cpu = matches!(
640 device_type_map.get(&instance_id),
641 Some(Some(DeviceType::Cpu))
642 );
643 tracing::info!(
644 endpoint = %endpoint_id,
645 selected_instance = instance_id,
646 is_cpu,
647 "DeviceAwareWeighted selected instance"
648 );
649
650 match self
651 .generate_with_fault_detection(instance_id, request)
652 .await
653 {
654 Ok(stream) => Ok(permit.into_tracked_stream(stream)),
655 Err(err) => Err(err),
656 }
657 }
658
659 pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
661 let state = self.occupancy_state()?;
662 let routing_instances = self.client.routing_instances();
663 let instance_ids = routing_instances.free_ids().to_vec();
664 let instance_id = state
665 .select_exact_min_and_increment(&instance_ids)
666 .await
667 .ok_or_else(|| self.empty_free_pool_error(&routing_instances))?;
668 let permit = OccupancyPermit::new(state.clone(), instance_id);
669 tracing::trace!(
670 "least loaded router selected {instance_id} (connections: {})",
671 state.load(instance_id)
672 );
673
674 match self
675 .generate_with_fault_detection(instance_id, request)
676 .await
677 {
678 Ok(stream) => Ok(permit.into_tracked_stream(stream)),
679 Err(err) => Err(err),
680 }
681 }
682
683 pub fn select_next_worker(&self) -> Option<u64> {
687 let routing_instances = self.client.routing_instances();
688 let count = routing_instances.free_ids().len();
689 if count == 0 {
690 return None;
691 }
692
693 match self.router_mode {
694 RouterMode::RoundRobin => {
695 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
696 Some(routing_instances.free_ids()[counter % count])
697 }
698 RouterMode::Random => {
699 let counter = rand::rng().random::<u64>() as usize;
700 Some(routing_instances.free_ids()[counter % count])
701 }
702 RouterMode::PowerOfTwoChoices
703 | RouterMode::Direct
704 | RouterMode::LeastLoaded
705 | RouterMode::DeviceAwareWeighted => None,
706 RouterMode::KV => {
707 panic!(
708 "select_next_worker should not be called for {:?} routing mode",
709 self.router_mode
710 )
711 }
712 }
713 }
714
715 pub fn peek_next_worker(&self) -> Option<u64> {
719 let routing_instances = self.client.routing_instances();
720 let count = routing_instances.free_ids().len();
721 if count == 0 {
722 return None;
723 }
724
725 match self.router_mode {
726 RouterMode::RoundRobin => {
727 let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
729 Some(routing_instances.free_ids()[counter % count])
730 }
731 RouterMode::Random => {
732 let counter = rand::rng().random::<u64>() as usize;
735 Some(routing_instances.free_ids()[counter % count])
736 }
737 RouterMode::PowerOfTwoChoices
738 | RouterMode::Direct
739 | RouterMode::LeastLoaded
740 | RouterMode::DeviceAwareWeighted => None,
741 RouterMode::KV => {
742 panic!(
743 "peek_next_worker should not be called for {:?} routing mode",
744 self.router_mode
745 )
746 }
747 }
748 }
749
750 fn occupancy_state(&self) -> anyhow::Result<Arc<RoutingOccupancyState>> {
751 self.occupancy_state.clone().ok_or_else(|| {
752 anyhow::anyhow!(
753 "routing occupancy state not initialized for endpoint {}",
754 self.client.endpoint.id()
755 )
756 })
757 }
758
    /// Dispatches `request` to `instance_id` and, when fault detection is
    /// enabled, wraps the response stream with failure reporting.
    ///
    /// Steps:
    /// 1. With fault detection enabled, reject the request with a
    ///    `ResourceExhausted` error if the selected worker is overloaded.
    /// 2. Resolve the instance's transport address from the client's instance
    ///    list. If the instance is no longer listed, fall back to the first
    ///    free instance with a different id.
    /// 3. Record the routing stage duration and dispatch through
    ///    `self.addressed`.
    /// 4. With fault detection enabled, report the instance down when a
    ///    response item (or the dispatch itself) carries an inhibited error,
    ///    and, if `response_timeout` is set, when no item arrives within that
    ///    duration.
    ///
    /// With fault detection disabled the stream is returned unwrapped, so
    /// `response_timeout` is not applied either.
    ///
    /// # Errors
    ///
    /// Fails when the worker is overloaded, when neither the selected nor a
    /// fallback instance can be resolved, or when dispatch fails.
    async fn generate_with_fault_detection(
        &self,
        mut instance_id: u64,
        request: SingleIn<T>,
    ) -> anyhow::Result<ManyOut<U>> {
        let route_start = Instant::now();
        let request_id = request.id().to_string();
        // No routing span is created in KV mode.
        let route_span = if matches!(self.router_mode, RouterMode::KV) {
            tracing::Span::none()
        } else {
            tracing::info_span!(
                "router.route_request",
                request_id = %request_id,
                worker_id = instance_id,
                router_mode = ?self.router_mode,
            )
        };

        if self.fault_detection_enabled {
            // Overload gate: checked against the originally selected instance.
            let routing_instances = self.client.routing_instances();
            let selected_worker_overloaded = routing_instances.is_overloaded(instance_id);
            let counts = routing_instances.counts();
            if tracing::enabled!(tracing::Level::DEBUG) {
                tracing::debug!(
                    request_id = %request_id,
                    instance_id,
                    router_mode = ?self.router_mode,
                    free_workers = counts.free,
                    overloaded_workers = counts.overloaded,
                    total_workers = counts.discovered,
                    selected_worker_overloaded,
                    "checked worker overload state"
                );
            }
            if selected_worker_overloaded {
                tracing::warn!(
                    instance_id,
                    overloaded_workers = counts.overloaded,
                    total_workers = counts.discovered,
                    "Rejecting request: selected worker is overloaded"
                );
                let cause = PipelineError::ServiceOverloaded(
                    "Selected worker is overloaded, please retry later".to_string(),
                );
                return Err(DynamoError::builder()
                    .error_type(ErrorType::ResourceExhausted)
                    .message("Selected worker is overloaded, please retry later")
                    .cause(cause)
                    .build()
                    .into());
            }
        }

        // Resolve (address, transport label, instance) for the target.
        let (address, _transport_kind, instance) = {
            use crate::component::TransportType;

            // Looks `id` up in the client's current instance list and maps its
            // transport to an address plus a label for the NVTX range below.
            let resolve_transport = |id: u64| {
                let instances = self.client.instances();
                instances
                    .iter()
                    .find(|i| i.instance_id == id)
                    .map(|instance| {
                        let (addr, kind) = match &instance.transport {
                            TransportType::Tcp(tcp_endpoint) => {
                                tracing::debug!(
                                    instance_id = id,
                                    tcp_endpoint = %tcp_endpoint,
                                    "Using TCP transport for instance"
                                );
                                (tcp_endpoint.clone(), "transport.tcp.request")
                            }
                            TransportType::Nats(subject) => {
                                tracing::debug!(
                                    instance_id = id,
                                    subject = %subject,
                                    "Using NATS transport for instance"
                                );
                                (subject.clone(), "transport.nats.request")
                            }
                        };
                        (addr, kind, instance.clone())
                    })
            };

            if let Some(result) = resolve_transport(instance_id) {
                result
            } else {
                // The selected instance is not listed any more: fall back to
                // the first free instance with a different id.
                let routing_instances = self.client.routing_instances();
                let fallback_id = routing_instances
                    .free_ids()
                    .iter()
                    .copied()
                    .find(|&id| id != instance_id);
                match fallback_id {
                    Some(id) => {
                        tracing::warn!(
                            original_instance = instance_id,
                            fallback_instance = id,
                            "Instance disappeared during routing, reselecting"
                        );
                        // From here on, failure reporting targets the fallback.
                        instance_id = id;
                        resolve_transport(id).ok_or_else(|| {
                            anyhow::anyhow!(
                                "Fallback instance {} also not found for endpoint {}",
                                id,
                                self.client.endpoint.id()
                            )
                        })?
                    }
                    None => {
                        return Err(anyhow::anyhow!(
                            "Instance {} not found and no other instances available \
                             for endpoint {}",
                            instance_id,
                            self.client.endpoint.id()
                        ));
                    }
                }
            }
        };

        let request = request.map(|req| AddressedRequest::with_instance(req, address, instance));

        // Routing stage ends here; the dispatch below is not part of it.
        STAGE_DURATION_SECONDS
            .with_label_values(&[STAGE_ROUTE])
            .observe(route_start.elapsed().as_secs_f64());

        let _nvtx_transport = dynamo_nvtx_range!(_transport_kind);
        let stream: anyhow::Result<ManyOut<U>> = self
            .addressed
            .generate(request)
            .instrument(route_span)
            .await;
        match stream {
            Ok(stream) => {
                // Without fault detection the stream is passed through as-is.
                if !self.fault_detection_enabled {
                    return Ok(stream);
                }
                let engine_ctx = stream.context();
                let client = self.client.clone();
                let client_for_timeout = self.client.clone();
                // Inspect every item; items are forwarded unchanged.
                let stream = stream.map(move |res| {
                    if let Some(err) = res.err()
                        && is_inhibited(&err)
                    {
                        tracing::debug!(
                            "Reporting instance {instance_id} down due to migratable error: {err}"
                        );
                        client.report_instance_down(instance_id);
                    }
                    res
                });

                let stream: Pin<Box<dyn Stream<Item = U> + Send>> = if let Some(timeout) =
                    self.response_timeout
                {
                    Box::pin(async_stream::stream! {
                        let mut inner = Box::pin(stream);
                        loop {
                            // A new sleep is created on every iteration, so the
                            // timeout measures the gap between items rather
                            // than the total stream duration.
                            tokio::select! {
                                // `biased` polls the branches in order, so a
                                // ready item wins over an expired timer.
                                biased;
                                item = inner.next() => {
                                    match item {
                                        Some(item) => yield item,
                                        None => break,
                                    }
                                }
                                _ = tokio::time::sleep(timeout) => {
                                    tracing::warn!(
                                        instance_id,
                                        timeout_secs = timeout.as_secs(),
                                        "backend response inactivity timeout — quarantining worker"
                                    );
                                    client_for_timeout.report_instance_down(instance_id);
                                    // Surface the timeout as a final error item.
                                    yield U::from_err(
                                        crate::error::DynamoError::builder()
                                            .error_type(crate::error::ErrorType::ResponseTimeout)
                                            .message("backend response inactivity timeout")
                                            .build()
                                    );
                                    break;
                                }
                            }
                        }
                    })
                } else {
                    Box::pin(stream)
                };

                Ok(ResponseStream::new(stream, engine_ctx))
            }
            Err(err) => {
                // Dispatch itself failed; report only inhibited error types.
                if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
                    tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
                    self.client.report_instance_down(instance_id);
                }
                Err(err)
            }
        }
    }
978}
979
980#[async_trait]
981impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
982where
983 T: Data + Serialize,
984 U: Data + for<'de> Deserialize<'de> + MaybeError,
985{
986 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
987 match self.router_mode {
988 RouterMode::Random => self.random(request).await,
989 RouterMode::RoundRobin => self.round_robin(request).await,
990 RouterMode::PowerOfTwoChoices => self.power_of_two_choices(request).await,
991 RouterMode::KV => {
992 anyhow::bail!("KV routing should not call generate on PushRouter");
993 }
994 RouterMode::Direct => {
995 anyhow::bail!(
996 "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
997 );
998 }
999 RouterMode::LeastLoaded => self.least_loaded(request).await,
1000 RouterMode::DeviceAwareWeighted => self.device_aware_weighted(request).await,
1001 }
1002 }
1003}
1004
1005#[async_trait]
1016impl<T, U> AsyncEngine<ManyIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
1017where
1018 T: Data + Serialize,
1019 U: Data + for<'de> Deserialize<'de> + MaybeError,
1020{
1021 async fn generate(&self, mut input: ManyIn<T>) -> Result<ManyOut<U>, Error> {
1022 match self.router_mode {
1023 RouterMode::KV => {
1024 anyhow::bail!("KV routing should not call generate on PushRouter");
1025 }
1026 RouterMode::Direct => {
1027 anyhow::bail!(
1028 "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
1029 );
1030 }
1031 _ => {}
1032 }
1033
1034 if input.next().await.is_none() {
1037 anyhow::bail!("bidirectional input stream closed before first frame");
1038 }
1039 let instance_id = self
1040 .select_next_worker()
1041 .ok_or_else(|| anyhow::anyhow!("no instances available for bidirectional routing"))?;
1042
1043 anyhow::bail!(
1047 "bidirectional remote dispatch is not yet implemented (selected instance {instance_id})"
1048 )
1049 }
1050}
1051
1052struct OccupancyTrackedStream<U: Data> {
1053 inner: ManyOut<U>,
1054 state: Arc<RoutingOccupancyState>,
1055 instance_id: u64,
1056}
1057
1058impl<U: Data> Drop for OccupancyTrackedStream<U> {
1059 fn drop(&mut self) {
1060 self.state.decrement(self.instance_id);
1061 }
1062}
1063
1064impl<U: Data> std::fmt::Debug for OccupancyTrackedStream<U> {
1065 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1066 f.debug_struct("OccupancyTrackedStream")
1067 .field("instance_id", &self.instance_id)
1068 .finish()
1069 }
1070}
1071
1072impl<U: Data> Stream for OccupancyTrackedStream<U> {
1073 type Item = U;
1074
1075 fn poll_next(
1076 mut self: Pin<&mut Self>,
1077 cx: &mut std::task::Context<'_>,
1078 ) -> Poll<Option<Self::Item>> {
1079 self.inner.as_mut().poll_next(cx)
1080 }
1081}
1082
1083impl<U: Data> AsyncEngineContextProvider for OccupancyTrackedStream<U> {
1084 fn context(&self) -> Arc<dyn AsyncEngineContext> {
1085 self.inner.context()
1086 }
1087}
1088
1089impl<U: Data> crate::engine::AsyncEngineStream<U> for OccupancyTrackedStream<U> {}
1090
1091#[cfg(test)]
1092mod tests {
1093 use super::*;
1094 use crate::{
1095 DistributedRuntime, Runtime,
1096 distributed::DistributedConfig,
1097 error::DynamoError,
1098 pipeline::{ResponseStream, context::Controller},
1099 };
1100 use serde::{Deserialize, Serialize};
1101
    /// Minimal response type for router tests: carries only an optional error
    /// so it can implement `MaybeError`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct TestResponse {
        error: Option<DynamoError>,
    }

    impl MaybeError for TestResponse {
        /// Wraps `err` in a `DynamoError` and stores it.
        fn from_err(err: impl std::error::Error + 'static) -> Self {
            Self {
                error: Some(DynamoError::from(
                    Box::new(err) as Box<dyn std::error::Error + 'static>
                )),
            }
        }

        /// Returns a clone of the stored error, if any.
        fn err(&self) -> Option<DynamoError> {
            self.error.clone()
        }
    }
1120
1121 #[test]
1122 fn p2c_selects_lower_load_worker() {
1123 let state = RoutingOccupancyState::default();
1124 for _ in 0..10 {
1125 state.increment(1);
1126 }
1127 state.increment(2);
1128
1129 let result = p2c_select_from(&state, &[1, 2]);
1131 assert_eq!(result, 2);
1132 }
1133
1134 #[test]
1135 fn p2c_selects_single_worker() {
1136 let state = RoutingOccupancyState::default();
1137 assert_eq!(p2c_select_from(&state, &[42]), 42);
1138 }
1139
1140 #[test]
1141 fn p2c_treats_missing_counts_as_zero() {
1142 let state = RoutingOccupancyState::default();
1143 for _ in 0..5 {
1144 state.increment(1);
1145 }
1146 let result = p2c_select_from(&state, &[1, 2]);
1148 assert_eq!(result, 2);
1149 }
1150
1151 #[test]
1152 fn p2c_returns_valid_worker_on_tie() {
1153 let state = RoutingOccupancyState::default();
1154 for _ in 0..3 {
1155 state.increment(1);
1156 state.increment(2);
1157 }
1158
1159 for _ in 0..100 {
1160 let result = p2c_select_from(&state, &[1, 2]);
1161 assert!(result == 1 || result == 2);
1162 }
1163 }
1164
1165 #[test]
1166 fn occupancy_permit_decrements_before_stream_creation() {
1167 let state = Arc::new(RoutingOccupancyState::default());
1168 state.increment(42);
1169 let permit = OccupancyPermit::new(state.clone(), 42);
1170 assert_eq!(state.load(42), 1);
1171 drop(permit);
1172 assert_eq!(state.load(42), 0);
1173 }
1174
1175 #[test]
1176 fn occupancy_tracked_stream_decrements_on_drop() {
1177 let state = Arc::new(RoutingOccupancyState::default());
1178 state.increment(7);
1179 let permit = OccupancyPermit::new(state.clone(), 7);
1180 let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
1181 let stream = permit.into_tracked_stream(ResponseStream::new(
1182 Box::pin(tokio_stream::iter(vec![1u64])),
1183 ctx,
1184 ));
1185 assert_eq!(state.load(7), 1);
1186 drop(stream);
1187 assert_eq!(state.load(7), 0);
1188 }
1189
1190 #[test]
1191 fn p2c_lifecycle_tracks_inflight_counts_with_shared_tracker() {
1192 let state = Arc::new(RoutingOccupancyState::default());
1193 let mut permits = Vec::new();
1194 for _ in 0..5 {
1195 let selected = p2c_select_from(&state, &[1, 2]);
1196 state.increment(selected);
1197 permits.push(OccupancyPermit::new(state.clone(), selected));
1198 }
1199
1200 let total = state.load(1) + state.load(2);
1201 assert_eq!(total, 5, "5 in-flight requests should be tracked");
1202
1203 drop(permits);
1204 let total = state.load(1) + state.load(2);
1205 assert_eq!(total, 0, "All guards dropped, counts should be 0");
1206 }
1207
1208 #[test]
1209 fn p2c_never_selects_dominated_worker() {
1210 let state = RoutingOccupancyState::default();
1211 for _ in 0..100 {
1212 state.increment(3);
1213 }
1214
1215 let mut selected = [0u32; 3];
1216 for _ in 0..1000 {
1217 let result = p2c_select_from(&state, &[1, 2, 3]);
1218 match result {
1219 1 => selected[0] += 1,
1220 2 => selected[1] += 1,
1221 3 => selected[2] += 1,
1222 _ => panic!("unexpected worker id"),
1223 }
1224 }
1225 assert_eq!(
1226 selected[2], 0,
1227 "Worker 3 (load=100) should never be selected against load=0 workers, but got {} times",
1228 selected[2]
1229 );
1230 }
1231
1232 #[tokio::test]
1233 async fn least_loaded_selects_exact_min_and_tracks_counts() {
1234 let state = Arc::new(RoutingOccupancyState::default());
1235 state.increment(1);
1236 state.increment(1);
1237 state.increment(2);
1238
1239 let selected = state
1240 .select_exact_min_and_increment(&[1, 2, 3])
1241 .await
1242 .unwrap();
1243 assert_eq!(selected, 3);
1244
1245 let permit = OccupancyPermit::new(state.clone(), selected);
1246 assert_eq!(state.load(selected), 1);
1247 drop(permit);
1248 assert_eq!(state.load(selected), 0);
1249 }
1250
    /// Bidirectional `generate` on a round-robin router must return an error
    /// when the endpoint has no registered instances.
    #[tokio::test]
    async fn bidirectional_generate_bails_with_no_instances() {
        // Process-local runtime; no instance is registered on the endpoint.
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_bidi_no_instances".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::RoundRobin)
            .await
            .unwrap();

        let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
        let input: ManyIn<u64> =
            ResponseStream::new(Box::pin(tokio_stream::iter(vec![1u64, 2u64])), ctx);
        let result = router.generate(input).await;
        assert!(
            result.is_err(),
            "bidirectional generate must bail when no instances are registered"
        );

        rt.shutdown();
    }
1277
1278 #[tokio::test]
1279 async fn bidirectional_generate_bails_for_kv_router_mode() {
1280 let rt = Runtime::from_current().unwrap();
1281 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1282 .await
1283 .unwrap();
1284 let ns = drt.namespace("test_bidi_kv_mode".to_string()).unwrap();
1285 let component = ns.component("test_component".to_string()).unwrap();
1286 let endpoint = component.endpoint("test_endpoint".to_string());
1287 let client = endpoint.client().await.unwrap();
1288
1289 let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::KV)
1290 .await
1291 .unwrap();
1292
1293 let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
1294 let input: ManyIn<u64> = ResponseStream::new(Box::pin(tokio_stream::iter(vec![1u64])), ctx);
1295 let result = router.generate(input).await;
1296 assert!(
1297 result.is_err(),
1298 "bidirectional generate must bail for RouterMode::KV"
1299 );
1300 let err_msg = format!("{:?}", result.unwrap_err());
1301 assert!(
1302 err_msg.contains("KV") || err_msg.contains("kv"),
1303 "error should mention KV: got {err_msg}"
1304 );
1305
1306 rt.shutdown();
1307 }
1308
1309 #[tokio::test]
1310 async fn bidirectional_generate_bails_for_direct_router_mode() {
1311 let rt = Runtime::from_current().unwrap();
1312 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1313 .await
1314 .unwrap();
1315 let ns = drt.namespace("test_bidi_direct_mode".to_string()).unwrap();
1316 let component = ns.component("test_component".to_string()).unwrap();
1317 let endpoint = component.endpoint("test_endpoint".to_string());
1318 let client = endpoint.client().await.unwrap();
1319
1320 let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::Direct)
1321 .await
1322 .unwrap();
1323
1324 let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
1325 let input: ManyIn<u64> = ResponseStream::new(Box::pin(tokio_stream::iter(vec![1u64])), ctx);
1326 let result = router.generate(input).await;
1327 assert!(
1328 result.is_err(),
1329 "bidirectional generate must bail for RouterMode::Direct"
1330 );
1331 let err_msg = format!("{:?}", result.unwrap_err());
1332 assert!(
1333 err_msg.contains("Direct") || err_msg.contains("direct"),
1334 "error should mention Direct: got {err_msg}"
1335 );
1336
1337 rt.shutdown();
1338 }
1339
1340 #[tokio::test]
1341 async fn least_loaded_select_and_peek_return_none_with_available_worker() {
1342 let rt = Runtime::from_current().unwrap();
1343 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1344 .await
1345 .unwrap();
1346 let ns = drt
1347 .namespace("test_least_loaded_router".to_string())
1348 .unwrap();
1349 let component = ns.component("test_component".to_string()).unwrap();
1350 let endpoint = component.endpoint("test_endpoint".to_string());
1351 let client = endpoint.client().await.unwrap();
1352
1353 endpoint.register_endpoint_instance().await.unwrap();
1354 client.wait_for_instances().await.unwrap();
1355
1356 let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::LeastLoaded)
1357 .await
1358 .unwrap();
1359
1360 assert_eq!(router.select_next_worker(), None);
1361 assert_eq!(router.peek_next_worker(), None);
1362
1363 rt.shutdown();
1364 }
1365
1366 #[tokio::test]
1367 async fn selected_overloaded_worker_is_rejected_before_dispatch() {
1368 const TEST_RECONCILE_INTERVAL: std::time::Duration = std::time::Duration::from_millis(50);
1369
1370 let rt = Runtime::from_current().unwrap();
1371 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1372 .await
1373 .unwrap();
1374 let ns = drt
1375 .namespace("test_selected_overloaded_worker_rejected".to_string())
1376 .unwrap();
1377 let component = ns.component("test_component".to_string()).unwrap();
1378 let endpoint = component.endpoint("test_endpoint".to_string());
1379 let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
1380 .await
1381 .unwrap();
1382
1383 endpoint.register_endpoint_instance().await.unwrap();
1384 let instances = client.wait_for_instances().await.unwrap();
1385 let worker_id = instances[0].id();
1386
1387 for _ in 0..10 {
1388 if client.instance_ids_avail().contains(&worker_id) {
1389 break;
1390 }
1391 tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
1392 }
1393 assert!(
1394 client.instance_ids_avail().contains(&worker_id),
1395 "worker should be routable before marking it overloaded"
1396 );
1397
1398 client.set_overloaded_instances(&[worker_id]);
1399 let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::RoundRobin)
1400 .await
1401 .unwrap();
1402
1403 let result = router.generate(SingleIn::new(42u64)).await;
1404 assert!(result.is_err());
1405 let msg = format!("{}", result.unwrap_err());
1406 assert!(
1412 msg.contains("All workers are busy"),
1413 "expected empty-free-pool rejection, got: {msg}"
1414 );
1415
1416 rt.shutdown();
1417 }
1418
1419 #[tokio::test]
1420 async fn round_robin_excludes_overloaded_workers_from_candidates() {
1421 const TEST_RECONCILE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(3600);
1427
1428 let rt = Runtime::from_current().unwrap();
1429 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1430 .await
1431 .unwrap();
1432 let ns = drt
1433 .namespace("test_round_robin_excludes_overloaded".to_string())
1434 .unwrap();
1435 let component = ns.component("test_component".to_string()).unwrap();
1436 let endpoint = component.endpoint("test_endpoint".to_string());
1437 let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
1438 .await
1439 .unwrap();
1440
1441 endpoint.register_endpoint_instance().await.unwrap();
1442 let instances = client.wait_for_instances().await.unwrap();
1443 let real_id = instances[0].id();
1444 for _ in 0..50 {
1445 if client.instance_ids_avail().contains(&real_id) {
1446 break;
1447 }
1448 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1449 }
1450
1451 client.override_instance_avail(vec![1, 2]);
1458 client.set_overloaded_instances(&[1]);
1459
1460 let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::RoundRobin)
1461 .await
1462 .unwrap();
1463
1464 for _ in 0..6 {
1467 let selected = router
1468 .peek_next_worker()
1469 .expect("peek should succeed with a free worker");
1470 assert_eq!(
1471 selected, 2,
1472 "overloaded worker 1 must not appear in the candidate set"
1473 );
1474 }
1475
1476 rt.shutdown();
1477 }
1478
1479 #[tokio::test]
1480 async fn device_aware_cpu_only_selects_least_loaded_instance() {
1481 let state = RoutingOccupancyState::default();
1482 for _ in 0..3 {
1484 state.increment(1);
1485 }
1486 state.increment(3);
1487
1488 let instance_ids = vec![1, 2, 3];
1489 let device_type_map = HashMap::from([
1490 (1, Some(DeviceType::Cpu)),
1491 (2, Some(DeviceType::Cpu)),
1492 (3, Some(DeviceType::Cpu)),
1493 ]);
1494
1495 let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
1496 assert_eq!(candidates, vec![1, 2, 3]);
1497
1498 let selected = state
1499 .select_exact_min_and_increment(&candidates)
1500 .await
1501 .unwrap();
1502 assert_eq!(selected, 2);
1503 }
1504
1505 #[tokio::test]
1506 async fn device_aware_non_cpu_only_selects_least_loaded_instance() {
1507 let state = RoutingOccupancyState::default();
1508 for _ in 0..3 {
1510 state.increment(1);
1511 }
1512 state.increment(3);
1513
1514 let instance_ids = vec![1, 2, 3];
1515 let device_type_map = HashMap::from([
1516 (1, Some(DeviceType::Cuda)),
1517 (2, Some(DeviceType::Cuda)),
1518 (3, Some(DeviceType::Cuda)),
1519 ]);
1520
1521 let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
1522 assert_eq!(candidates, vec![1, 2, 3]);
1523
1524 let selected = state
1525 .select_exact_min_and_increment(&candidates)
1526 .await
1527 .unwrap();
1528 assert_eq!(selected, 2);
1529 }
1530
/// With two CPU instances (1, 2) and two CUDA instances (3, 4) and a final
/// argument of 2, `device_aware_candidate_group` must return the CPU pair, and
/// the never-incremented instance 2 must then be selected.
///
/// NOTE(review): the final argument is presumably the ratio budget the test
/// name refers to — confirm against `device_aware_candidate_group`.
#[test]
fn device_aware_group_uses_ratio_budget() {
    let occupancy = RoutingOccupancyState::default();
    // Instances 3 and 4 are each incremented four times, alternating.
    for _ in 0..4 {
        for cuda_instance in [3, 4] {
            occupancy.increment(cuda_instance);
        }
    }
    // Instance 1 is incremented three times; instance 2 stays untouched.
    for _ in 0..3 {
        occupancy.increment(1);
    }

    let instance_ids = vec![1, 2, 3, 4];
    let device_type_map = HashMap::from([
        (1, Some(DeviceType::Cpu)),
        (2, Some(DeviceType::Cpu)),
        (3, Some(DeviceType::Cuda)),
        (4, Some(DeviceType::Cuda)),
    ]);

    let candidates =
        device_aware_candidate_group(&occupancy, &instance_ids, &device_type_map, 2);
    assert_eq!(candidates, vec![1, 2]);

    // Synchronous test: drive the async selection to completion in place.
    let selection = occupancy.select_exact_min_and_increment(&candidates);
    let winner = futures::executor::block_on(selection).unwrap();
    assert_eq!(winner, 2);
}
1562
1563 #[tokio::test]
1564 async fn device_aware_weighted_select_and_peek_return_none_with_available_worker() {
1565 let rt = Runtime::from_current().unwrap();
1566 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1567 .await
1568 .unwrap();
1569 let ns = drt
1570 .namespace("test_device_aware_router".to_string())
1571 .unwrap();
1572 let component = ns.component("test_component".to_string()).unwrap();
1573 let endpoint = component.endpoint("test_endpoint".to_string());
1574 let client = endpoint.client().await.unwrap();
1575
1576 endpoint.register_endpoint_instance().await.unwrap();
1577 client.wait_for_instances().await.unwrap();
1578
1579 let router =
1580 PushRouter::<u64, TestResponse>::from_client(client, RouterMode::DeviceAwareWeighted)
1581 .await
1582 .unwrap();
1583
1584 assert_eq!(router.select_next_worker(), None);
1585 assert_eq!(router.peek_next_worker(), None);
1586
1587 rt.shutdown();
1588 }
1589
1590 #[tokio::test]
1594 async fn transport_resolution_falls_back_when_selected_instance_disappears() {
1595 let rt = Runtime::from_current().unwrap();
1596 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1597 .await
1598 .unwrap();
1599 let ns = drt
1600 .namespace("test_transport_fallback".to_string())
1601 .unwrap();
1602 let component = ns.component("test_component".to_string()).unwrap();
1603 let endpoint = component.endpoint("test_endpoint".to_string());
1604 let client = endpoint.client().await.unwrap();
1605
1606 endpoint.register_endpoint_instance().await.unwrap();
1608 client.wait_for_instances().await.unwrap();
1609
1610 let real_id = client.instance_ids()[0];
1611
1612 let stale_id = real_id + 1000;
1616 client.override_instance_avail(vec![stale_id, real_id]);
1617
1618 let router =
1621 PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
1622 .await
1623 .unwrap();
1624
1625 let request = SingleIn::new(42u64);
1632 let result = router.generate(request).await;
1633
1634 if let Err(err) = &result {
1638 let msg = format!("{err}");
1639 assert!(
1640 !msg.contains("not found"),
1641 "Transport resolution should have fallen back, but got: {msg}"
1642 );
1643 }
1644
1645 rt.shutdown();
1646 }
1647
1648 #[tokio::test]
1651 async fn transport_resolution_errors_when_no_instances_available() {
1652 let rt = Runtime::from_current().unwrap();
1653 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1654 .await
1655 .unwrap();
1656 let ns = drt
1657 .namespace("test_transport_no_fallback".to_string())
1658 .unwrap();
1659 let component = ns.component("test_component".to_string()).unwrap();
1660 let endpoint = component.endpoint("test_endpoint".to_string());
1661 let client = endpoint.client().await.unwrap();
1662
1663 endpoint.register_endpoint_instance().await.unwrap();
1665 client.wait_for_instances().await.unwrap();
1666
1667 let router =
1668 PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
1669 .await
1670 .unwrap();
1671
1672 let stale_id = 99999;
1675 client.override_instance_avail(vec![stale_id]);
1676
1677 let request = SingleIn::new(42u64);
1678 let result = router.generate(request).await;
1679
1680 assert!(result.is_err());
1681 let msg = format!("{}", result.unwrap_err());
1682 assert!(
1683 msg.contains("not found") && msg.contains("no other instances available"),
1684 "Expected clear error about missing instance with no fallback, got: {msg}"
1685 );
1686
1687 rt.shutdown();
1688 }
1689
1690 #[tokio::test]
1702 async fn watcher_dedup_guard_released_on_panic() {
1703 let endpoint_id = EndpointId {
1704 namespace: "panic-test-ns".to_string(),
1705 component: "panic-test-comp".to_string(),
1706 name: "panic-test-endpoint".to_string(),
1707 };
1708
1709 let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
1711 map.insert(endpoint_id.clone(), ());
1712
1713 let endpoint_id_clone = endpoint_id.clone();
1714 let join = tokio::spawn(async move {
1715 struct GuardRelease(EndpointId);
1717 impl Drop for GuardRelease {
1718 fn drop(&mut self) {
1719 if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
1720 map.remove(&self.0);
1721 }
1722 }
1723 }
1724 let _release = GuardRelease(endpoint_id_clone);
1725 panic!("simulated watcher-task panic");
1726 });
1727
1728 let result = join.await;
1729 assert!(result.is_err() && result.unwrap_err().is_panic());
1730 assert!(
1731 !map.contains_key(&endpoint_id),
1732 "Drop guard must release the dedup entry even on panic"
1733 );
1734 }
1735
1736 #[tokio::test]
1740 async fn watcher_dedup_guard_released_on_normal_exit() {
1741 let endpoint_id = EndpointId {
1742 namespace: "normal-test-ns".to_string(),
1743 component: "normal-test-comp".to_string(),
1744 name: "normal-test-endpoint".to_string(),
1745 };
1746
1747 let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
1748 map.insert(endpoint_id.clone(), ());
1749
1750 let endpoint_id_clone = endpoint_id.clone();
1751 tokio::spawn(async move {
1752 struct GuardRelease(EndpointId);
1753 impl Drop for GuardRelease {
1754 fn drop(&mut self) {
1755 if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
1756 map.remove(&self.0);
1757 }
1758 }
1759 }
1760 let _release = GuardRelease(endpoint_id_clone);
1761 })
1763 .await
1764 .unwrap();
1765
1766 assert!(!map.contains_key(&endpoint_id));
1767 }
1768}