1use super::{AsyncEngineContextProvider, ResponseStream};
5use crate::error::{BackendError, DynamoError, ErrorType, match_error_chain};
6use crate::{
7 component::{
8 Client, DeviceType, Endpoint, Instance, RoutingOccupancyState,
9 get_or_create_routing_occupancy_state,
10 },
11 discovery::EndpointInstanceId,
12 dynamo_nvtx_range,
13 engine::{AsyncEngine, AsyncEngineContext, Data},
14 metrics::frontend_perf::{STAGE_DURATION_SECONDS, STAGE_ROUTE},
15 pipeline::{
16 AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
17 error::{PipelineError, PipelineErrorExt},
18 },
19 protocols::{EndpointId, maybe_error::MaybeError},
20 traits::DistributedRuntimeProvider,
21};
22use async_trait::async_trait;
23use futures::Stream;
24use rand::Rng;
25use serde::{Deserialize, Serialize};
26use std::{
27 collections::HashMap,
28 marker::PhantomData,
29 pin::Pin,
30 sync::{
31 Arc,
32 atomic::{AtomicU64, Ordering},
33 },
34 task::Poll,
35 time::Instant,
36};
37use tokio_stream::StreamExt;
38use tracing::Instrument;
39
40fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
42 const INHIBITED: &[ErrorType] = &[
43 ErrorType::CannotConnect,
44 ErrorType::Disconnected,
45 ErrorType::ConnectionTimeout,
46 ErrorType::ResponseTimeout,
47 ErrorType::Backend(BackendError::EngineShutdown),
48 ];
49 match_error_chain(err, INHIBITED, &[])
50}
51
52fn response_inactivity_timeout() -> Option<std::time::Duration> {
56 use crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS;
57 std::env::var(DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS)
58 .ok()
59 .and_then(|s| s.parse::<u64>().ok())
60 .filter(|&secs| secs > 0)
61 .map(std::time::Duration::from_secs)
62}
63
64struct OccupancyPermit {
65 state: Arc<RoutingOccupancyState>,
66 instance_id: u64,
67 armed: bool,
68}
69
70impl OccupancyPermit {
71 fn new(state: Arc<RoutingOccupancyState>, instance_id: u64) -> Self {
72 Self {
73 state,
74 instance_id,
75 armed: true,
76 }
77 }
78
79 fn into_tracked_stream<U: Data>(mut self, stream: ManyOut<U>) -> ManyOut<U> {
80 self.armed = false;
81 let engine_ctx = stream.context();
82 ResponseStream::new(
83 Box::pin(OccupancyTrackedStream {
84 inner: stream,
85 state: self.state.clone(),
86 instance_id: self.instance_id,
87 }),
88 engine_ctx,
89 )
90 }
91
92 fn instance_id(&self) -> u64 {
93 self.instance_id
94 }
95}
96
97impl Drop for OccupancyPermit {
98 fn drop(&mut self) {
99 if self.armed {
100 self.state.decrement(self.instance_id);
101 }
102 }
103}
104
/// Worker-load monitor that can be started when a [`PushRouter`] is constructed
/// (see `PushRouter::from_client_with_monitor`).
#[async_trait]
pub trait WorkerLoadMonitor: Send + Sync {
    /// Starts monitoring; an error here is propagated out of router construction.
    async fn start_monitoring(&self) -> anyhow::Result<()>;
}
113
/// Request router that selects a worker instance of `client`'s endpoint according to
/// [`RouterMode`] and forwards the request through an [`AddressedPushRouter`].
///
/// `T` is the request payload type; `U` is the streamed response item type.
#[derive(Clone)]
pub struct PushRouter<T, U>
where
    T: Data + Serialize,
    U: Data + for<'de> Deserialize<'de>,
{
    /// Client used to list the endpoint's instances (all / available / free) and to report
    /// instances down.
    pub client: Client,

    /// Instance-selection strategy used by `generate`.
    router_mode: RouterMode,

    /// Round-robin cursor shared by clones of this router.
    round_robin_counter: Arc<AtomicU64>,

    /// Transport-level router that sends an already-addressed request to one instance.
    addressed: Arc<AddressedPushRouter>,

    /// When `true`: requests are rejected while workers exist but none is free, instances are
    /// reported down on inhibiting errors, and `response_timeout` is applied to streams.
    fault_detection_enabled: bool,

    /// Optional response-stream inactivity timeout (applied only with fault detection).
    response_timeout: Option<std::time::Duration>,

    /// Per-instance in-flight counts; `Some` only for `PowerOfTwoChoices`, `LeastLoaded` and
    /// `DeviceAwareWeighted`.
    occupancy_state: Option<Arc<RoutingOccupancyState>>,

    _phantom: PhantomData<(T, U)>,
}
157
/// Strategy used by [`PushRouter`] to choose a worker instance.
///
/// Serialized with serde `rename_all = "snake_case"`.
/// NOTE(review): under serde's snake_case rule `KV` becomes `"k_v"` on the wire — confirm that
/// is the intended name.
#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RouterMode {
    /// Cycle through available instances (the default).
    #[default]
    RoundRobin,
    /// Pick a random available instance.
    Random,
    /// Compare two distinct random available instances and take the less loaded one.
    PowerOfTwoChoices,
    /// Presumably KV-cache-aware routing; rejected by `PushRouter::generate`, so it is driven
    /// from elsewhere.
    KV,
    /// Caller-chosen instance; rejected by `PushRouter::generate` (use `DirectRoutingRouter`).
    Direct,
    /// Pick the available instance with the fewest in-flight requests.
    LeastLoaded,
    /// Choose between CPU and non-CPU instances by an in-flight ratio budget, then pick the
    /// least-loaded instance in that group.
    DeviceAwareWeighted,
}
171
172impl RouterMode {
173 pub fn is_kv_routing(&self) -> bool {
174 *self == RouterMode::KV
175 }
176
177 pub fn is_direct_routing(&self) -> bool {
178 *self == RouterMode::Direct
179 }
180}
181
182fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64]) -> u64 {
185 let count = instance_ids.len();
186 if count == 1 {
187 return instance_ids[0];
188 }
189 let mut rng = rand::rng();
190 let idx1 = rng.random_range(0..count);
191 let idx2 = (idx1 + 1 + rng.random_range(0..count - 1)) % count;
192 let id1 = instance_ids[idx1];
193 let id2 = instance_ids[idx2];
194 let load1 = occupancy_state.load(id1);
195 let load2 = occupancy_state.load(id2);
196 let selected = if load1 <= load2 { id1 } else { id2 };
197 tracing::debug!(
198 candidate_a = id1,
199 candidate_a_load = load1,
200 candidate_b = id2,
201 candidate_b_load = load2,
202 selected = selected,
203 "p2c selection"
204 );
205 selected
206}
207
208fn device_aware_candidate_group(
220 state: &RoutingOccupancyState,
221 instance_ids: &[u64],
222 device_type_map: &HashMap<u64, Option<DeviceType>>,
223 non_cpu_to_cpu_ratio: usize,
224) -> Vec<u64> {
225 let cpu_ids: Vec<u64> = instance_ids
226 .iter()
227 .copied()
228 .filter(|id| matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
229 .collect();
230 let non_cpu_ids: Vec<u64> = instance_ids
231 .iter()
232 .copied()
233 .filter(|id| !matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
234 .collect();
235
236 if cpu_ids.is_empty() {
237 return non_cpu_ids;
238 }
239 if non_cpu_ids.is_empty() {
240 return cpu_ids;
241 }
242
243 let total_non_cpu_inflight: u64 = non_cpu_ids.iter().map(|id| state.load(*id)).sum();
245 let total_cpu_inflight: u64 = cpu_ids.iter().map(|id| state.load(*id)).sum();
246 let cpu_count = cpu_ids.len() as u64;
247 let non_cpu_count = non_cpu_ids.len() as u64;
248 let allowed_cpu_inflight = total_non_cpu_inflight.saturating_mul(cpu_count)
249 / ((non_cpu_to_cpu_ratio as u64).saturating_mul(non_cpu_count));
250
251 if total_cpu_inflight < allowed_cpu_inflight {
252 cpu_ids
253 } else {
254 non_cpu_ids
255 }
256}
257
/// Process-wide registry of endpoints that currently have an instance-removal watcher task
/// running. `spawn_instance_removal_watcher` inserts an endpoint before spawning (skipping
/// the spawn if already present); the task removes it again when it exits.
static ENDPOINT_WATCHER_ACTIVE: std::sync::OnceLock<dashmap::DashMap<EndpointId, ()>> =
    std::sync::OnceLock::new();
262
263fn spawn_instance_removal_watcher(
269 endpoint: Endpoint,
270 addressed: Arc<AddressedPushRouter>,
271 cancel_token: tokio_util::sync::CancellationToken,
272) {
273 use crate::discovery::{
274 DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
275 };
276 use tokio_stream::StreamExt as _;
277
278 let guard = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
280 let endpoint_id = endpoint.id();
281 if guard.insert(endpoint_id.clone(), ()).is_some() {
282 tracing::debug!(
283 ?endpoint_id,
284 "Instance removal watcher already running for this endpoint, skipping"
285 );
286 return;
287 }
288
289 let endpoint_name = endpoint.name().to_string();
290
291 tokio::spawn(async move {
292 struct GuardRelease(EndpointId);
295 impl Drop for GuardRelease {
296 fn drop(&mut self) {
297 if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
298 map.remove(&self.0);
299 }
300 }
301 }
302 let _release = GuardRelease(endpoint_id);
303
304 let namespace = endpoint.component().namespace().name();
305 let component = endpoint.component().name().to_string();
306
307 const RECONNECT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(5);
309 'reconnect: loop {
310 let query = DiscoveryQuery::Endpoint {
311 namespace: namespace.clone(),
312 component: component.clone(),
313 endpoint: endpoint_name.clone(),
314 };
315
316 let mut stream = match endpoint.drt().discovery().list_and_watch(query, None).await {
317 Ok(s) => s,
318 Err(e) => {
319 tracing::warn!(
320 endpoint = %endpoint_name,
321 "Failed to start instance removal watcher (will retry): {e}"
322 );
323 tokio::select! {
324 _ = tokio::time::sleep(RECONNECT_BACKOFF) => continue 'reconnect,
325 _ = cancel_token.cancelled() => break 'reconnect,
326 }
327 }
328 };
329
330 loop {
331 tokio::select! {
332 event = stream.next() => {
333 match event {
334 Some(Ok(DiscoveryEvent::Removed(id))) => {
335 if let DiscoveryInstanceId::Endpoint(eid) = &id {
336 let n = addressed.cancel_instance_streams(eid).await;
337 if n > 0 {
338 tracing::warn!(
339 namespace = %eid.namespace,
340 component = %eid.component,
341 endpoint = %eid.endpoint,
342 instance_id = eid.instance_id,
343 cancelled = n,
344 "Cancelled pending response streams for removed \
345 instance (discovery-driven cleanup)"
346 );
347 }
348 }
349 }
350 Some(Ok(DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)))) => {
351 let eid: EndpointInstanceId = inst.endpoint_instance_id();
352 addressed.clear_instance_tombstone(&eid).await;
353 }
354 Some(Ok(_)) => {}
355 Some(Err(e)) => {
356 tracing::warn!(
357 endpoint = %endpoint_name,
358 "Instance removal watcher stream error: {e}"
359 );
360 }
361 None => {
362 tracing::warn!(
363 endpoint = %endpoint_name,
364 "Instance removal watcher stream ended; reconnecting"
365 );
366 continue 'reconnect;
367 }
368 }
369 }
370 _ = cancel_token.cancelled() => {
371 break 'reconnect;
372 }
373 }
374 }
375 }
376
377 tracing::debug!(endpoint = %endpoint_name, "Instance removal watcher exiting");
378 });
379}
380
381async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
382 let manager = endpoint.drt().network_manager();
384 let req_client = manager.create_client()?;
385 let resp_transport = endpoint.drt().tcp_server().await?;
386
387 tracing::debug!(
388 transport = req_client.transport_name(),
389 "Creating AddressedPushRouter with request plane client"
390 );
391
392 AddressedPushRouter::new(req_client, resp_transport)
393}
394
395impl<T, U> PushRouter<T, U>
396where
397 T: Data + Serialize,
398 U: Data + for<'de> Deserialize<'de> + MaybeError,
399{
400 pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
402 Self::from_client_with_monitor(client, router_mode, None).await
403 }
404
405 pub async fn from_client_no_fault_detection(
411 client: Client,
412 router_mode: RouterMode,
413 ) -> anyhow::Result<Self> {
414 let addressed = addressed_router(&client.endpoint).await?;
415
416 let occupancy_state = if matches!(
417 router_mode,
418 RouterMode::PowerOfTwoChoices
419 | RouterMode::LeastLoaded
420 | RouterMode::DeviceAwareWeighted
421 ) {
422 Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
423 } else {
424 None
425 };
426
427 spawn_instance_removal_watcher(
429 client.endpoint.clone(),
430 addressed.clone(),
431 client.endpoint.drt().primary_token(),
432 );
433
434 Ok(PushRouter {
435 client,
436 addressed,
437 router_mode,
438 round_robin_counter: Arc::new(AtomicU64::new(0)),
439 fault_detection_enabled: false,
440 response_timeout: response_inactivity_timeout(),
441 occupancy_state,
442 _phantom: PhantomData,
443 })
444 }
445
446 pub async fn from_client_with_monitor(
453 client: Client,
454 router_mode: RouterMode,
455 worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
456 ) -> anyhow::Result<Self> {
457 let addressed = addressed_router(&client.endpoint).await?;
458
459 if let Some(monitor) = worker_monitor.as_ref() {
461 monitor.start_monitoring().await?;
462 }
463
464 let occupancy_state = if matches!(
465 router_mode,
466 RouterMode::PowerOfTwoChoices
467 | RouterMode::LeastLoaded
468 | RouterMode::DeviceAwareWeighted
469 ) {
470 Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
471 } else {
472 None
473 };
474
475 spawn_instance_removal_watcher(
477 client.endpoint.clone(),
478 addressed.clone(),
479 client.endpoint.drt().primary_token(),
480 );
481
482 let router = PushRouter {
483 client,
484 addressed,
485 router_mode,
486 round_robin_counter: Arc::new(AtomicU64::new(0)),
487 fault_detection_enabled: true,
488 response_timeout: response_inactivity_timeout(),
489 occupancy_state,
490 _phantom: PhantomData,
491 };
492
493 Ok(router)
494 }
495
496 pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
498 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
499
500 let instance_id = {
501 let instance_ids = self.client.instance_ids_avail();
502 let count = instance_ids.len();
503 if count == 0 {
504 return Err(anyhow::anyhow!(
505 "no instances found for endpoint {}",
506 self.client.endpoint.id()
507 ));
508 }
509 instance_ids[counter % count]
510 };
511 tracing::trace!("round robin router selected {instance_id}");
512
513 self.generate_with_fault_detection(instance_id, request)
514 .await
515 }
516
517 pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
519 let instance_id = {
520 let instance_ids = self.client.instance_ids_avail();
521 let count = instance_ids.len();
522 if count == 0 {
523 return Err(anyhow::anyhow!(
524 "no instances found for endpoint {}",
525 self.client.endpoint.id()
526 ));
527 }
528 let counter = rand::rng().random::<u64>() as usize;
529 instance_ids[counter % count]
530 };
531 tracing::trace!("random router selected {instance_id}");
532
533 self.generate_with_fault_detection(instance_id, request)
534 .await
535 }
536
537 pub async fn power_of_two_choices(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
540 let state = self.occupancy_state()?;
541 let instance_id = {
542 let instance_ids = self
543 .client
544 .instance_ids_avail()
545 .iter()
546 .copied()
547 .collect::<Vec<_>>();
548 if instance_ids.is_empty() {
549 return Err(anyhow::anyhow!(
550 "no instances found for endpoint {}",
551 self.client.endpoint.id()
552 ));
553 }
554 p2c_select_from(state.as_ref(), &instance_ids)
555 };
556 state.increment(instance_id);
557 let permit = OccupancyPermit::new(state, instance_id);
558
559 match self
560 .generate_with_fault_detection(instance_id, request)
561 .await
562 {
563 Ok(stream) => Ok(permit.into_tracked_stream(stream)),
564 Err(err) => Err(err),
565 }
566 }
567
568 pub async fn direct(
570 &self,
571 request: SingleIn<T>,
572 instance_id: u64,
573 ) -> anyhow::Result<ManyOut<U>> {
574 let found = if self.fault_detection_enabled {
578 self.client.instance_ids_avail().contains(&instance_id)
579 } else {
580 self.client.instance_ids().contains(&instance_id)
581 };
582
583 if !found {
584 return Err(anyhow::anyhow!(
585 "instance_id={instance_id} not found for endpoint {}",
586 self.client.endpoint.id()
587 ));
588 }
589
590 self.generate_with_fault_detection(instance_id, request)
591 .await
592 }
593
594 pub async fn device_aware_weighted(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
603 let state = self.occupancy_state()?;
604 let instance_ids = self
605 .client
606 .instance_ids_avail()
607 .iter()
608 .copied()
609 .collect::<Vec<_>>();
610
611 if instance_ids.is_empty() {
612 return Err(anyhow::anyhow!(
613 "no instances found for endpoint {}",
614 self.client.endpoint.id()
615 ));
616 }
617
618 let endpoint_id = self.client.endpoint.id();
620
621 let instances = self.client.instances();
623 let device_type_map: std::collections::HashMap<u64, Option<DeviceType>> = instances
624 .iter()
625 .map(|inst| (inst.instance_id, inst.device_type.clone()))
626 .collect();
627
628 let cuda_to_cpu_ratio = std::env::var("DYN_ENCODER_CUDA_TO_CPU_RATIO")
630 .ok()
631 .and_then(|v| v.parse::<usize>().ok())
632 .filter(|v| *v >= 1)
633 .unwrap_or(8);
634 let candidates = device_aware_candidate_group(
635 state.as_ref(),
636 &instance_ids,
637 &device_type_map,
638 cuda_to_cpu_ratio,
639 );
640
641 let instance_id = state
643 .select_exact_min_and_increment(&candidates)
644 .await
645 .ok_or_else(|| {
646 anyhow::anyhow!(
647 "no instances in selected device group for endpoint {}",
648 endpoint_id
649 )
650 })?;
651 let permit = OccupancyPermit::new(state.clone(), instance_id);
652 let is_cpu = matches!(
653 device_type_map.get(&instance_id),
654 Some(Some(DeviceType::Cpu))
655 );
656 tracing::info!(
657 endpoint = %endpoint_id,
658 selected_instance = instance_id,
659 is_cpu,
660 "DeviceAwareWeighted selected instance"
661 );
662
663 match self
664 .generate_with_fault_detection(instance_id, request)
665 .await
666 {
667 Ok(stream) => Ok(permit.into_tracked_stream(stream)),
668 Err(err) => Err(err),
669 }
670 }
671
672 pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
674 let state = self.occupancy_state()?;
675 let instance_ids = self
676 .client
677 .instance_ids_avail()
678 .iter()
679 .copied()
680 .collect::<Vec<_>>();
681 let instance_id = state
682 .select_exact_min_and_increment(&instance_ids)
683 .await
684 .ok_or_else(|| {
685 anyhow::anyhow!(
686 "no instances found for endpoint {}",
687 self.client.endpoint.id()
688 )
689 })?;
690 let permit = OccupancyPermit::new(state.clone(), instance_id);
691 tracing::trace!(
692 "least loaded router selected {instance_id} (connections: {})",
693 state.load(instance_id)
694 );
695
696 match self
697 .generate_with_fault_detection(instance_id, request)
698 .await
699 {
700 Ok(stream) => Ok(permit.into_tracked_stream(stream)),
701 Err(err) => Err(err),
702 }
703 }
704
705 pub fn select_next_worker(&self) -> Option<u64> {
709 let instance_ids = self.client.instance_ids_avail();
710 let count = instance_ids.len();
711 if count == 0 {
712 return None;
713 }
714
715 match self.router_mode {
716 RouterMode::RoundRobin => {
717 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
718 Some(instance_ids[counter % count])
719 }
720 RouterMode::Random => {
721 let counter = rand::rng().random::<u64>() as usize;
722 Some(instance_ids[counter % count])
723 }
724 RouterMode::PowerOfTwoChoices
725 | RouterMode::Direct
726 | RouterMode::LeastLoaded
727 | RouterMode::DeviceAwareWeighted => None,
728 RouterMode::KV => {
729 panic!(
730 "select_next_worker should not be called for {:?} routing mode",
731 self.router_mode
732 )
733 }
734 }
735 }
736
737 pub fn peek_next_worker(&self) -> Option<u64> {
741 let instance_ids = self.client.instance_ids_avail();
742 let count = instance_ids.len();
743 if count == 0 {
744 return None;
745 }
746
747 match self.router_mode {
748 RouterMode::RoundRobin => {
749 let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
751 Some(instance_ids[counter % count])
752 }
753 RouterMode::Random => {
754 let counter = rand::rng().random::<u64>() as usize;
757 Some(instance_ids[counter % count])
758 }
759 RouterMode::PowerOfTwoChoices
760 | RouterMode::Direct
761 | RouterMode::LeastLoaded
762 | RouterMode::DeviceAwareWeighted => None,
763 RouterMode::KV => {
764 panic!(
765 "peek_next_worker should not be called for {:?} routing mode",
766 self.router_mode
767 )
768 }
769 }
770 }
771
772 fn occupancy_state(&self) -> anyhow::Result<Arc<RoutingOccupancyState>> {
773 self.occupancy_state.clone().ok_or_else(|| {
774 anyhow::anyhow!(
775 "routing occupancy state not initialized for endpoint {}",
776 self.client.endpoint.id()
777 )
778 })
779 }
780
781 async fn generate_with_fault_detection(
792 &self,
793 mut instance_id: u64,
794 request: SingleIn<T>,
795 ) -> anyhow::Result<ManyOut<U>> {
796 let route_start = Instant::now();
797 let request_id = request.id().to_string();
798 let route_span = if matches!(self.router_mode, RouterMode::KV) {
799 tracing::Span::none()
800 } else {
801 tracing::info_span!(
802 "router.route_request",
803 request_id = %request_id,
804 worker_id = instance_id,
805 router_mode = ?self.router_mode,
806 )
807 };
808
809 if self.fault_detection_enabled {
811 let free_instances = self.client.instance_ids_free();
812 if free_instances.is_empty() {
813 let all_instances = self.client.instance_ids();
815 if !all_instances.is_empty() {
816 tracing::warn!(
817 instance_id,
818 total_workers = all_instances.len(),
819 "Rejecting request: all workers are busy"
820 );
821 let cause = PipelineError::ServiceOverloaded(
822 "All workers are busy, please retry later".to_string(),
823 );
824 return Err(DynamoError::builder()
825 .error_type(ErrorType::ResourceExhausted)
826 .message("All workers are busy, please retry later")
827 .cause(cause)
828 .build()
829 .into());
830 }
831 }
832 }
833
834 let (address, _transport_kind, instance) = {
837 use crate::component::TransportType;
838
839 let resolve_transport = |id: u64| {
840 let instances = self.client.instances();
841 instances
842 .iter()
843 .find(|i| i.instance_id == id)
844 .map(|instance| {
845 let (addr, kind) = match &instance.transport {
846 TransportType::Http(http_endpoint) => {
847 tracing::debug!(
848 instance_id = id,
849 http_endpoint = %http_endpoint,
850 "Using HTTP transport for instance"
851 );
852 (http_endpoint.clone(), "transport.http.request")
853 }
854 TransportType::Tcp(tcp_endpoint) => {
855 tracing::debug!(
856 instance_id = id,
857 tcp_endpoint = %tcp_endpoint,
858 "Using TCP transport for instance"
859 );
860 (tcp_endpoint.clone(), "transport.tcp.request")
861 }
862 TransportType::Nats(subject) => {
863 tracing::debug!(
864 instance_id = id,
865 subject = %subject,
866 "Using NATS transport for instance"
867 );
868 (subject.clone(), "transport.nats.request")
869 }
870 };
871 (addr, kind, instance.clone())
872 })
873 };
874
875 if let Some(result) = resolve_transport(instance_id) {
876 result
877 } else {
878 let avail = self.client.instance_ids_avail();
881 let fallback_id = avail.iter().copied().find(|&id| id != instance_id);
882 match fallback_id {
883 Some(id) => {
884 tracing::warn!(
885 original_instance = instance_id,
886 fallback_instance = id,
887 "Instance disappeared during routing, reselecting"
888 );
889 instance_id = id;
890 resolve_transport(id).ok_or_else(|| {
891 anyhow::anyhow!(
892 "Fallback instance {} also not found for endpoint {}",
893 id,
894 self.client.endpoint.id()
895 )
896 })?
897 }
898 None => {
899 return Err(anyhow::anyhow!(
900 "Instance {} not found and no other instances available \
901 for endpoint {}",
902 instance_id,
903 self.client.endpoint.id()
904 ));
905 }
906 }
907 }
908 };
909
910 let request = request.map(|req| AddressedRequest::with_instance(req, address, instance));
911
912 STAGE_DURATION_SECONDS
913 .with_label_values(&[STAGE_ROUTE])
914 .observe(route_start.elapsed().as_secs_f64());
915
916 let _nvtx_transport = dynamo_nvtx_range!(_transport_kind);
917 let stream: anyhow::Result<ManyOut<U>> = self
918 .addressed
919 .generate(request)
920 .instrument(route_span)
921 .await;
922 match stream {
923 Ok(stream) => {
924 if !self.fault_detection_enabled {
925 return Ok(stream);
926 }
927 let engine_ctx = stream.context();
928 let client = self.client.clone();
929 let client_for_timeout = self.client.clone();
930 let stream = stream.map(move |res| {
931 if let Some(err) = res.err()
933 && is_inhibited(&err)
934 {
935 tracing::debug!(
936 "Reporting instance {instance_id} down due to migratable error: {err}"
937 );
938 client.report_instance_down(instance_id);
939 }
940 res
941 });
942
943 let stream: Pin<Box<dyn Stream<Item = U> + Send>> = if let Some(timeout) =
947 self.response_timeout
948 {
949 Box::pin(async_stream::stream! {
950 let mut inner = Box::pin(stream);
951 loop {
952 tokio::select! {
953 biased;
954 item = inner.next() => {
955 match item {
956 Some(item) => yield item,
957 None => break,
958 }
959 }
960 _ = tokio::time::sleep(timeout) => {
961 tracing::warn!(
962 instance_id,
963 timeout_secs = timeout.as_secs(),
964 "backend response inactivity timeout — quarantining worker"
965 );
966 client_for_timeout.report_instance_down(instance_id);
967 yield U::from_err(
968 crate::error::DynamoError::builder()
969 .error_type(crate::error::ErrorType::ResponseTimeout)
970 .message("backend response inactivity timeout")
971 .build()
972 );
973 break;
974 }
975 }
976 }
977 })
978 } else {
979 Box::pin(stream)
980 };
981
982 Ok(ResponseStream::new(stream, engine_ctx))
983 }
984 Err(err) => {
985 if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
986 tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
987 self.client.report_instance_down(instance_id);
988 }
989 Err(err)
990 }
991 }
992 }
993}
994
995#[async_trait]
996impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
997where
998 T: Data + Serialize,
999 U: Data + for<'de> Deserialize<'de> + MaybeError,
1000{
1001 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
1002 match self.router_mode {
1003 RouterMode::Random => self.random(request).await,
1004 RouterMode::RoundRobin => self.round_robin(request).await,
1005 RouterMode::PowerOfTwoChoices => self.power_of_two_choices(request).await,
1006 RouterMode::KV => {
1007 anyhow::bail!("KV routing should not call generate on PushRouter");
1008 }
1009 RouterMode::Direct => {
1010 anyhow::bail!(
1011 "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
1012 );
1013 }
1014 RouterMode::LeastLoaded => self.least_loaded(request).await,
1015 RouterMode::DeviceAwareWeighted => self.device_aware_weighted(request).await,
1016 }
1017 }
1018}
1019
1020struct OccupancyTrackedStream<U: Data> {
1021 inner: ManyOut<U>,
1022 state: Arc<RoutingOccupancyState>,
1023 instance_id: u64,
1024}
1025
1026impl<U: Data> Drop for OccupancyTrackedStream<U> {
1027 fn drop(&mut self) {
1028 self.state.decrement(self.instance_id);
1029 }
1030}
1031
1032impl<U: Data> std::fmt::Debug for OccupancyTrackedStream<U> {
1033 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1034 f.debug_struct("OccupancyTrackedStream")
1035 .field("instance_id", &self.instance_id)
1036 .finish()
1037 }
1038}
1039
1040impl<U: Data> Stream for OccupancyTrackedStream<U> {
1041 type Item = U;
1042
1043 fn poll_next(
1044 mut self: Pin<&mut Self>,
1045 cx: &mut std::task::Context<'_>,
1046 ) -> Poll<Option<Self::Item>> {
1047 self.inner.as_mut().poll_next(cx)
1048 }
1049}
1050
1051impl<U: Data> AsyncEngineContextProvider for OccupancyTrackedStream<U> {
1052 fn context(&self) -> Arc<dyn AsyncEngineContext> {
1053 self.inner.context()
1054 }
1055}
1056
1057impl<U: Data> crate::engine::AsyncEngineStream<U> for OccupancyTrackedStream<U> {}
1058
1059#[cfg(test)]
1060mod tests {
1061 use super::*;
1062 use crate::{
1063 DistributedRuntime, Runtime,
1064 distributed::DistributedConfig,
1065 error::DynamoError,
1066 pipeline::{ResponseStream, context::Controller},
1067 };
1068 use serde::{Deserialize, Serialize};
1069
1070 #[derive(Clone, Debug, Deserialize, Serialize)]
1071 struct TestResponse {
1072 error: Option<DynamoError>,
1073 }
1074
1075 impl MaybeError for TestResponse {
1076 fn from_err(err: impl std::error::Error + 'static) -> Self {
1077 Self {
1078 error: Some(DynamoError::from(
1079 Box::new(err) as Box<dyn std::error::Error + 'static>
1080 )),
1081 }
1082 }
1083
1084 fn err(&self) -> Option<DynamoError> {
1085 self.error.clone()
1086 }
1087 }
1088
1089 #[test]
1090 fn p2c_selects_lower_load_worker() {
1091 let state = RoutingOccupancyState::default();
1092 for _ in 0..10 {
1093 state.increment(1);
1094 }
1095 state.increment(2);
1096
1097 let result = p2c_select_from(&state, &[1, 2]);
1099 assert_eq!(result, 2);
1100 }
1101
1102 #[test]
1103 fn p2c_selects_single_worker() {
1104 let state = RoutingOccupancyState::default();
1105 assert_eq!(p2c_select_from(&state, &[42]), 42);
1106 }
1107
1108 #[test]
1109 fn p2c_treats_missing_counts_as_zero() {
1110 let state = RoutingOccupancyState::default();
1111 for _ in 0..5 {
1112 state.increment(1);
1113 }
1114 let result = p2c_select_from(&state, &[1, 2]);
1116 assert_eq!(result, 2);
1117 }
1118
1119 #[test]
1120 fn p2c_returns_valid_worker_on_tie() {
1121 let state = RoutingOccupancyState::default();
1122 for _ in 0..3 {
1123 state.increment(1);
1124 state.increment(2);
1125 }
1126
1127 for _ in 0..100 {
1128 let result = p2c_select_from(&state, &[1, 2]);
1129 assert!(result == 1 || result == 2);
1130 }
1131 }
1132
1133 #[test]
1134 fn occupancy_permit_decrements_before_stream_creation() {
1135 let state = Arc::new(RoutingOccupancyState::default());
1136 state.increment(42);
1137 let permit = OccupancyPermit::new(state.clone(), 42);
1138 assert_eq!(state.load(42), 1);
1139 drop(permit);
1140 assert_eq!(state.load(42), 0);
1141 }
1142
1143 #[test]
1144 fn occupancy_tracked_stream_decrements_on_drop() {
1145 let state = Arc::new(RoutingOccupancyState::default());
1146 state.increment(7);
1147 let permit = OccupancyPermit::new(state.clone(), 7);
1148 let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
1149 let stream = permit.into_tracked_stream(ResponseStream::new(
1150 Box::pin(tokio_stream::iter(vec![1u64])),
1151 ctx,
1152 ));
1153 assert_eq!(state.load(7), 1);
1154 drop(stream);
1155 assert_eq!(state.load(7), 0);
1156 }
1157
1158 #[test]
1159 fn p2c_lifecycle_tracks_inflight_counts_with_shared_tracker() {
1160 let state = Arc::new(RoutingOccupancyState::default());
1161 let mut permits = Vec::new();
1162 for _ in 0..5 {
1163 let selected = p2c_select_from(&state, &[1, 2]);
1164 state.increment(selected);
1165 permits.push(OccupancyPermit::new(state.clone(), selected));
1166 }
1167
1168 let total = state.load(1) + state.load(2);
1169 assert_eq!(total, 5, "5 in-flight requests should be tracked");
1170
1171 drop(permits);
1172 let total = state.load(1) + state.load(2);
1173 assert_eq!(total, 0, "All guards dropped, counts should be 0");
1174 }
1175
1176 #[test]
1177 fn p2c_never_selects_dominated_worker() {
1178 let state = RoutingOccupancyState::default();
1179 for _ in 0..100 {
1180 state.increment(3);
1181 }
1182
1183 let mut selected = [0u32; 3];
1184 for _ in 0..1000 {
1185 let result = p2c_select_from(&state, &[1, 2, 3]);
1186 match result {
1187 1 => selected[0] += 1,
1188 2 => selected[1] += 1,
1189 3 => selected[2] += 1,
1190 _ => panic!("unexpected worker id"),
1191 }
1192 }
1193 assert_eq!(
1194 selected[2], 0,
1195 "Worker 3 (load=100) should never be selected against load=0 workers, but got {} times",
1196 selected[2]
1197 );
1198 }
1199
1200 #[tokio::test]
1201 async fn least_loaded_selects_exact_min_and_tracks_counts() {
1202 let state = Arc::new(RoutingOccupancyState::default());
1203 state.increment(1);
1204 state.increment(1);
1205 state.increment(2);
1206
1207 let selected = state
1208 .select_exact_min_and_increment(&[1, 2, 3])
1209 .await
1210 .unwrap();
1211 assert_eq!(selected, 3);
1212
1213 let permit = OccupancyPermit::new(state.clone(), selected);
1214 assert_eq!(state.load(selected), 1);
1215 drop(permit);
1216 assert_eq!(state.load(selected), 0);
1217 }
1218
1219 #[tokio::test]
1220 async fn least_loaded_select_and_peek_return_none_with_available_worker() {
1221 let rt = Runtime::from_current().unwrap();
1222 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1223 .await
1224 .unwrap();
1225 let ns = drt
1226 .namespace("test_least_loaded_router".to_string())
1227 .unwrap();
1228 let component = ns.component("test_component".to_string()).unwrap();
1229 let endpoint = component.endpoint("test_endpoint".to_string());
1230 let client = endpoint.client().await.unwrap();
1231
1232 endpoint.register_endpoint_instance().await.unwrap();
1233 client.wait_for_instances().await.unwrap();
1234
1235 let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::LeastLoaded)
1236 .await
1237 .unwrap();
1238
1239 assert_eq!(router.select_next_worker(), None);
1240 assert_eq!(router.peek_next_worker(), None);
1241
1242 rt.shutdown();
1243 }
1244
1245 #[tokio::test]
1246 async fn device_aware_cpu_only_selects_least_loaded_instance() {
1247 let state = RoutingOccupancyState::default();
1248 for _ in 0..3 {
1250 state.increment(1);
1251 }
1252 state.increment(3);
1253
1254 let instance_ids = vec![1, 2, 3];
1255 let device_type_map = HashMap::from([
1256 (1, Some(DeviceType::Cpu)),
1257 (2, Some(DeviceType::Cpu)),
1258 (3, Some(DeviceType::Cpu)),
1259 ]);
1260
1261 let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
1262 assert_eq!(candidates, vec![1, 2, 3]);
1263
1264 let selected = state
1265 .select_exact_min_and_increment(&candidates)
1266 .await
1267 .unwrap();
1268 assert_eq!(selected, 2);
1269 }
1270
1271 #[tokio::test]
1272 async fn device_aware_non_cpu_only_selects_least_loaded_instance() {
1273 let state = RoutingOccupancyState::default();
1274 for _ in 0..3 {
1276 state.increment(1);
1277 }
1278 state.increment(3);
1279
1280 let instance_ids = vec![1, 2, 3];
1281 let device_type_map = HashMap::from([
1282 (1, Some(DeviceType::Cuda)),
1283 (2, Some(DeviceType::Cuda)),
1284 (3, Some(DeviceType::Cuda)),
1285 ]);
1286
1287 let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
1288 assert_eq!(candidates, vec![1, 2, 3]);
1289
1290 let selected = state
1291 .select_exact_min_and_increment(&candidates)
1292 .await
1293 .unwrap();
1294 assert_eq!(selected, 2);
1295 }
1296
1297 #[test]
1298 fn device_aware_group_uses_ratio_budget() {
1299 let state = RoutingOccupancyState::default();
1300 for _ in 0..4 {
1302 state.increment(3);
1303 state.increment(4);
1304 }
1305 for _ in 0..3 {
1307 state.increment(1);
1308 }
1309 let instance_ids = vec![1, 2, 3, 4];
1313 let device_type_map = HashMap::from([
1314 (1, Some(DeviceType::Cpu)),
1315 (2, Some(DeviceType::Cpu)),
1316 (3, Some(DeviceType::Cuda)),
1317 (4, Some(DeviceType::Cuda)),
1318 ]);
1319
1320 let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 2);
1321 assert_eq!(candidates, vec![1, 2]);
1322
1323 let selected =
1325 futures::executor::block_on(state.select_exact_min_and_increment(&candidates)).unwrap();
1326 assert_eq!(selected, 2);
1327 }
1328
1329 #[tokio::test]
1330 async fn device_aware_weighted_select_and_peek_return_none_with_available_worker() {
1331 let rt = Runtime::from_current().unwrap();
1332 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1333 .await
1334 .unwrap();
1335 let ns = drt
1336 .namespace("test_device_aware_router".to_string())
1337 .unwrap();
1338 let component = ns.component("test_component".to_string()).unwrap();
1339 let endpoint = component.endpoint("test_endpoint".to_string());
1340 let client = endpoint.client().await.unwrap();
1341
1342 endpoint.register_endpoint_instance().await.unwrap();
1343 client.wait_for_instances().await.unwrap();
1344
1345 let router =
1346 PushRouter::<u64, TestResponse>::from_client(client, RouterMode::DeviceAwareWeighted)
1347 .await
1348 .unwrap();
1349
1350 assert_eq!(router.select_next_worker(), None);
1351 assert_eq!(router.peek_next_worker(), None);
1352
1353 rt.shutdown();
1354 }
1355
1356 #[tokio::test]
1360 async fn transport_resolution_falls_back_when_selected_instance_disappears() {
1361 let rt = Runtime::from_current().unwrap();
1362 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1363 .await
1364 .unwrap();
1365 let ns = drt
1366 .namespace("test_transport_fallback".to_string())
1367 .unwrap();
1368 let component = ns.component("test_component".to_string()).unwrap();
1369 let endpoint = component.endpoint("test_endpoint".to_string());
1370 let client = endpoint.client().await.unwrap();
1371
1372 endpoint.register_endpoint_instance().await.unwrap();
1374 client.wait_for_instances().await.unwrap();
1375
1376 let real_id = client.instance_ids()[0];
1377
1378 let stale_id = real_id + 1000;
1382 client.override_instance_avail(vec![stale_id, real_id]);
1383
1384 let router =
1387 PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
1388 .await
1389 .unwrap();
1390
1391 let request = SingleIn::new(42u64);
1398 let result = router.generate(request).await;
1399
1400 if let Err(err) = &result {
1404 let msg = format!("{err}");
1405 assert!(
1406 !msg.contains("not found"),
1407 "Transport resolution should have fallen back, but got: {msg}"
1408 );
1409 }
1410
1411 rt.shutdown();
1412 }
1413
1414 #[tokio::test]
1417 async fn transport_resolution_errors_when_no_instances_available() {
1418 let rt = Runtime::from_current().unwrap();
1419 let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1420 .await
1421 .unwrap();
1422 let ns = drt
1423 .namespace("test_transport_no_fallback".to_string())
1424 .unwrap();
1425 let component = ns.component("test_component".to_string()).unwrap();
1426 let endpoint = component.endpoint("test_endpoint".to_string());
1427 let client = endpoint.client().await.unwrap();
1428
1429 endpoint.register_endpoint_instance().await.unwrap();
1431 client.wait_for_instances().await.unwrap();
1432
1433 let router =
1434 PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
1435 .await
1436 .unwrap();
1437
1438 let stale_id = 99999;
1441 client.override_instance_avail(vec![stale_id]);
1442
1443 let request = SingleIn::new(42u64);
1444 let result = router.generate(request).await;
1445
1446 assert!(result.is_err());
1447 let msg = format!("{}", result.unwrap_err());
1448 assert!(
1449 msg.contains("not found") && msg.contains("no other instances available"),
1450 "Expected clear error about missing instance with no fallback, got: {msg}"
1451 );
1452
1453 rt.shutdown();
1454 }
1455
1456 #[tokio::test]
1468 async fn watcher_dedup_guard_released_on_panic() {
1469 let endpoint_id = EndpointId {
1470 namespace: "panic-test-ns".to_string(),
1471 component: "panic-test-comp".to_string(),
1472 name: "panic-test-endpoint".to_string(),
1473 };
1474
1475 let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
1477 map.insert(endpoint_id.clone(), ());
1478
1479 let endpoint_id_clone = endpoint_id.clone();
1480 let join = tokio::spawn(async move {
1481 struct GuardRelease(EndpointId);
1483 impl Drop for GuardRelease {
1484 fn drop(&mut self) {
1485 if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
1486 map.remove(&self.0);
1487 }
1488 }
1489 }
1490 let _release = GuardRelease(endpoint_id_clone);
1491 panic!("simulated watcher-task panic");
1492 });
1493
1494 let result = join.await;
1495 assert!(result.is_err() && result.unwrap_err().is_panic());
1496 assert!(
1497 !map.contains_key(&endpoint_id),
1498 "Drop guard must release the dedup entry even on panic"
1499 );
1500 }
1501
1502 #[tokio::test]
1506 async fn watcher_dedup_guard_released_on_normal_exit() {
1507 let endpoint_id = EndpointId {
1508 namespace: "normal-test-ns".to_string(),
1509 component: "normal-test-comp".to_string(),
1510 name: "normal-test-endpoint".to_string(),
1511 };
1512
1513 let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
1514 map.insert(endpoint_id.clone(), ());
1515
1516 let endpoint_id_clone = endpoint_id.clone();
1517 tokio::spawn(async move {
1518 struct GuardRelease(EndpointId);
1519 impl Drop for GuardRelease {
1520 fn drop(&mut self) {
1521 if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
1522 map.remove(&self.0);
1523 }
1524 }
1525 }
1526 let _release = GuardRelease(endpoint_id_clone);
1527 })
1529 .await
1530 .unwrap();
1531
1532 assert!(!map.contains_key(&endpoint_id));
1533 }
1534}