1use super::{AsyncEngineContextProvider, ResponseStream};
5use crate::error::{BackendError, DynamoError, ErrorType, match_error_chain};
6use crate::{
7 component::{
8 Client, DeviceType, Endpoint, RoutingOccupancyState, get_or_create_routing_occupancy_state,
9 },
10 dynamo_nvtx_range,
11 engine::{AsyncEngine, AsyncEngineContext, Data},
12 metrics::frontend_perf::STAGE_DURATION_SECONDS,
13 pipeline::{
14 AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
15 error::{PipelineError, PipelineErrorExt},
16 },
17 protocols::maybe_error::MaybeError,
18 traits::DistributedRuntimeProvider,
19};
20use async_trait::async_trait;
21use futures::Stream;
22use rand::Rng;
23use serde::{Deserialize, Serialize};
24use std::{
25 collections::HashMap,
26 marker::PhantomData,
27 pin::Pin,
28 sync::{
29 Arc,
30 atomic::{AtomicU64, Ordering},
31 },
32 task::Poll,
33 time::Instant,
34};
35use tokio_stream::StreamExt;
36use tracing::Instrument;
37
38fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
40 const INHIBITED: &[ErrorType] = &[
41 ErrorType::CannotConnect,
42 ErrorType::Disconnected,
43 ErrorType::ConnectionTimeout,
44 ErrorType::ResponseTimeout,
45 ErrorType::Backend(BackendError::EngineShutdown),
46 ];
47 match_error_chain(err, INHIBITED, &[])
48}
49
50fn response_inactivity_timeout() -> Option<std::time::Duration> {
54 use crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS;
55 std::env::var(DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS)
56 .ok()
57 .and_then(|s| s.parse::<u64>().ok())
58 .filter(|&secs| secs > 0)
59 .map(std::time::Duration::from_secs)
60}
61
62struct OccupancyPermit {
63 state: Arc<RoutingOccupancyState>,
64 instance_id: u64,
65 armed: bool,
66}
67
68impl OccupancyPermit {
69 fn new(state: Arc<RoutingOccupancyState>, instance_id: u64) -> Self {
70 Self {
71 state,
72 instance_id,
73 armed: true,
74 }
75 }
76
77 fn into_tracked_stream<U: Data>(mut self, stream: ManyOut<U>) -> ManyOut<U> {
78 self.armed = false;
79 let engine_ctx = stream.context();
80 ResponseStream::new(
81 Box::pin(OccupancyTrackedStream {
82 inner: stream,
83 state: self.state.clone(),
84 instance_id: self.instance_id,
85 }),
86 engine_ctx,
87 )
88 }
89
90 fn instance_id(&self) -> u64 {
91 self.instance_id
92 }
93}
94
95impl Drop for OccupancyPermit {
96 fn drop(&mut self) {
97 if self.armed {
98 self.state.decrement(self.instance_id);
99 }
100 }
101}
102
103#[async_trait]
106pub trait WorkerLoadMonitor: Send + Sync {
107 async fn start_monitoring(&self) -> anyhow::Result<()>;
110}
111
112#[derive(Clone)]
113pub struct PushRouter<T, U>
114where
115 T: Data + Serialize,
116 U: Data + for<'de> Deserialize<'de>,
117{
118 pub client: Client,
121
122 router_mode: RouterMode,
129
130 round_robin_counter: Arc<AtomicU64>,
132
133 addressed: Arc<AddressedPushRouter>,
136
137 fault_detection_enabled: bool,
142
143 response_timeout: Option<std::time::Duration>,
146
147 occupancy_state: Option<Arc<RoutingOccupancyState>>,
149
150 _phantom: PhantomData<(T, U)>,
154}
155
/// Strategy a `PushRouter` uses to choose a worker instance.
#[derive(Default, Debug, Clone, Copy, PartialEq)]
pub enum RouterMode {
    /// Cycle through the available instances using a shared counter.
    #[default]
    RoundRobin,
    /// Pick an available instance at random.
    Random,
    /// Sample two distinct instances and take the one with the lower load.
    PowerOfTwoChoices,
    /// KV-aware routing; `PushRouter::generate` refuses this mode.
    KV,
    /// The caller names the target instance; `PushRouter::generate` refuses
    /// this mode and `PushRouter::direct` is used instead.
    Direct,
    /// Pick the instance with the lowest tracked load.
    LeastLoaded,
    /// Choose between the CPU and non-CPU instance groups by a load ratio,
    /// then pick the least-loaded instance in the chosen group.
    DeviceAwareWeighted,
}

impl RouterMode {
    /// Returns `true` only for [`RouterMode::KV`].
    pub fn is_kv_routing(&self) -> bool {
        matches!(self, RouterMode::KV)
    }

    /// Returns `true` only for [`RouterMode::Direct`].
    pub fn is_direct_routing(&self) -> bool {
        matches!(self, RouterMode::Direct)
    }
}
178
179fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64]) -> u64 {
182 let count = instance_ids.len();
183 if count == 1 {
184 return instance_ids[0];
185 }
186 let mut rng = rand::rng();
187 let idx1 = rng.random_range(0..count);
188 let idx2 = (idx1 + 1 + rng.random_range(0..count - 1)) % count;
189 let id1 = instance_ids[idx1];
190 let id2 = instance_ids[idx2];
191 let load1 = occupancy_state.load(id1);
192 let load2 = occupancy_state.load(id2);
193 let selected = if load1 <= load2 { id1 } else { id2 };
194 tracing::debug!(
195 candidate_a = id1,
196 candidate_a_load = load1,
197 candidate_b = id2,
198 candidate_b_load = load2,
199 selected = selected,
200 "p2c selection"
201 );
202 selected
203}
204
205fn device_aware_candidate_group(
217 state: &RoutingOccupancyState,
218 instance_ids: &[u64],
219 device_type_map: &HashMap<u64, Option<DeviceType>>,
220 non_cpu_to_cpu_ratio: usize,
221) -> Vec<u64> {
222 let cpu_ids: Vec<u64> = instance_ids
223 .iter()
224 .copied()
225 .filter(|id| matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
226 .collect();
227 let non_cpu_ids: Vec<u64> = instance_ids
228 .iter()
229 .copied()
230 .filter(|id| !matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
231 .collect();
232
233 if cpu_ids.is_empty() {
234 return non_cpu_ids;
235 }
236 if non_cpu_ids.is_empty() {
237 return cpu_ids;
238 }
239
240 let total_non_cpu_inflight: u64 = non_cpu_ids.iter().map(|id| state.load(*id)).sum();
242 let total_cpu_inflight: u64 = cpu_ids.iter().map(|id| state.load(*id)).sum();
243 let cpu_count = cpu_ids.len() as u64;
244 let non_cpu_count = non_cpu_ids.len() as u64;
245 let allowed_cpu_inflight = total_non_cpu_inflight.saturating_mul(cpu_count)
246 / ((non_cpu_to_cpu_ratio as u64).saturating_mul(non_cpu_count));
247
248 if total_cpu_inflight < allowed_cpu_inflight {
249 cpu_ids
250 } else {
251 non_cpu_ids
252 }
253}
254
255async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
256 let manager = endpoint.drt().network_manager();
258 let req_client = manager.create_client()?;
259 let resp_transport = endpoint.drt().tcp_server().await?;
260
261 tracing::debug!(
262 transport = req_client.transport_name(),
263 "Creating AddressedPushRouter with request plane client"
264 );
265
266 AddressedPushRouter::new(req_client, resp_transport)
267}
268
269impl<T, U> PushRouter<T, U>
270where
271 T: Data + Serialize,
272 U: Data + for<'de> Deserialize<'de> + MaybeError,
273{
274 pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
276 Self::from_client_with_monitor(client, router_mode, None).await
277 }
278
279 pub async fn from_client_no_fault_detection(
285 client: Client,
286 router_mode: RouterMode,
287 ) -> anyhow::Result<Self> {
288 let addressed = addressed_router(&client.endpoint).await?;
289
290 let occupancy_state = if matches!(
291 router_mode,
292 RouterMode::PowerOfTwoChoices
293 | RouterMode::LeastLoaded
294 | RouterMode::DeviceAwareWeighted
295 ) {
296 Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
297 } else {
298 None
299 };
300
301 Ok(PushRouter {
302 client,
303 addressed,
304 router_mode,
305 round_robin_counter: Arc::new(AtomicU64::new(0)),
306 fault_detection_enabled: false,
307 response_timeout: response_inactivity_timeout(),
308 occupancy_state,
309 _phantom: PhantomData,
310 })
311 }
312
313 pub async fn from_client_with_monitor(
320 client: Client,
321 router_mode: RouterMode,
322 worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
323 ) -> anyhow::Result<Self> {
324 let addressed = addressed_router(&client.endpoint).await?;
325
326 if let Some(monitor) = worker_monitor.as_ref() {
328 monitor.start_monitoring().await?;
329 }
330
331 let occupancy_state = if matches!(
332 router_mode,
333 RouterMode::PowerOfTwoChoices
334 | RouterMode::LeastLoaded
335 | RouterMode::DeviceAwareWeighted
336 ) {
337 Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
338 } else {
339 None
340 };
341
342 let router = PushRouter {
343 client,
344 addressed,
345 router_mode,
346 round_robin_counter: Arc::new(AtomicU64::new(0)),
347 fault_detection_enabled: true,
348 response_timeout: response_inactivity_timeout(),
349 occupancy_state,
350 _phantom: PhantomData,
351 };
352
353 Ok(router)
354 }
355
356 pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
358 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
359
360 let instance_id = {
361 let instance_ids = self.client.instance_ids_avail();
362 let count = instance_ids.len();
363 if count == 0 {
364 return Err(anyhow::anyhow!(
365 "no instances found for endpoint {}",
366 self.client.endpoint.id()
367 ));
368 }
369 instance_ids[counter % count]
370 };
371 tracing::trace!("round robin router selected {instance_id}");
372
373 self.generate_with_fault_detection(instance_id, request)
374 .await
375 }
376
377 pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
379 let instance_id = {
380 let instance_ids = self.client.instance_ids_avail();
381 let count = instance_ids.len();
382 if count == 0 {
383 return Err(anyhow::anyhow!(
384 "no instances found for endpoint {}",
385 self.client.endpoint.id()
386 ));
387 }
388 let counter = rand::rng().random::<u64>() as usize;
389 instance_ids[counter % count]
390 };
391 tracing::trace!("random router selected {instance_id}");
392
393 self.generate_with_fault_detection(instance_id, request)
394 .await
395 }
396
397 pub async fn power_of_two_choices(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
400 let state = self.occupancy_state()?;
401 let instance_id = {
402 let instance_ids = self
403 .client
404 .instance_ids_avail()
405 .iter()
406 .copied()
407 .collect::<Vec<_>>();
408 if instance_ids.is_empty() {
409 return Err(anyhow::anyhow!(
410 "no instances found for endpoint {}",
411 self.client.endpoint.id()
412 ));
413 }
414 p2c_select_from(state.as_ref(), &instance_ids)
415 };
416 state.increment(instance_id);
417 let permit = OccupancyPermit::new(state, instance_id);
418
419 match self
420 .generate_with_fault_detection(instance_id, request)
421 .await
422 {
423 Ok(stream) => Ok(permit.into_tracked_stream(stream)),
424 Err(err) => Err(err),
425 }
426 }
427
428 pub async fn direct(
430 &self,
431 request: SingleIn<T>,
432 instance_id: u64,
433 ) -> anyhow::Result<ManyOut<U>> {
434 let found = if self.fault_detection_enabled {
438 self.client.instance_ids_avail().contains(&instance_id)
439 } else {
440 self.client.instance_ids().contains(&instance_id)
441 };
442
443 if !found {
444 return Err(anyhow::anyhow!(
445 "instance_id={instance_id} not found for endpoint {}",
446 self.client.endpoint.id()
447 ));
448 }
449
450 self.generate_with_fault_detection(instance_id, request)
451 .await
452 }
453
454 pub async fn device_aware_weighted(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
463 let state = self.occupancy_state()?;
464 let instance_ids = self
465 .client
466 .instance_ids_avail()
467 .iter()
468 .copied()
469 .collect::<Vec<_>>();
470
471 if instance_ids.is_empty() {
472 return Err(anyhow::anyhow!(
473 "no instances found for endpoint {}",
474 self.client.endpoint.id()
475 ));
476 }
477
478 let endpoint_id = self.client.endpoint.id();
480
481 let instances = self.client.instances();
483 let device_type_map: std::collections::HashMap<u64, Option<DeviceType>> = instances
484 .iter()
485 .map(|inst| (inst.instance_id, inst.device_type.clone()))
486 .collect();
487
488 let cuda_to_cpu_ratio = std::env::var("DYN_ENCODER_CUDA_TO_CPU_RATIO")
490 .ok()
491 .and_then(|v| v.parse::<usize>().ok())
492 .filter(|v| *v >= 1)
493 .unwrap_or(8);
494 let candidates = device_aware_candidate_group(
495 state.as_ref(),
496 &instance_ids,
497 &device_type_map,
498 cuda_to_cpu_ratio,
499 );
500
501 let instance_id = state
503 .select_exact_min_and_increment(&candidates)
504 .await
505 .ok_or_else(|| {
506 anyhow::anyhow!(
507 "no instances in selected device group for endpoint {}",
508 endpoint_id
509 )
510 })?;
511 let permit = OccupancyPermit::new(state.clone(), instance_id);
512 let is_cpu = matches!(
513 device_type_map.get(&instance_id),
514 Some(Some(DeviceType::Cpu))
515 );
516 tracing::info!(
517 endpoint = %endpoint_id,
518 selected_instance = instance_id,
519 is_cpu,
520 "DeviceAwareWeighted selected instance"
521 );
522
523 match self
524 .generate_with_fault_detection(instance_id, request)
525 .await
526 {
527 Ok(stream) => Ok(permit.into_tracked_stream(stream)),
528 Err(err) => Err(err),
529 }
530 }
531
532 pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
534 let state = self.occupancy_state()?;
535 let instance_ids = self
536 .client
537 .instance_ids_avail()
538 .iter()
539 .copied()
540 .collect::<Vec<_>>();
541 let instance_id = state
542 .select_exact_min_and_increment(&instance_ids)
543 .await
544 .ok_or_else(|| {
545 anyhow::anyhow!(
546 "no instances found for endpoint {}",
547 self.client.endpoint.id()
548 )
549 })?;
550 let permit = OccupancyPermit::new(state.clone(), instance_id);
551 tracing::trace!(
552 "least loaded router selected {instance_id} (connections: {})",
553 state.load(instance_id)
554 );
555
556 match self
557 .generate_with_fault_detection(instance_id, request)
558 .await
559 {
560 Ok(stream) => Ok(permit.into_tracked_stream(stream)),
561 Err(err) => Err(err),
562 }
563 }
564
565 pub fn select_next_worker(&self) -> Option<u64> {
569 let instance_ids = self.client.instance_ids_avail();
570 let count = instance_ids.len();
571 if count == 0 {
572 return None;
573 }
574
575 match self.router_mode {
576 RouterMode::RoundRobin => {
577 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
578 Some(instance_ids[counter % count])
579 }
580 RouterMode::Random => {
581 let counter = rand::rng().random::<u64>() as usize;
582 Some(instance_ids[counter % count])
583 }
584 RouterMode::PowerOfTwoChoices
585 | RouterMode::Direct
586 | RouterMode::LeastLoaded
587 | RouterMode::DeviceAwareWeighted => None,
588 RouterMode::KV => {
589 panic!(
590 "select_next_worker should not be called for {:?} routing mode",
591 self.router_mode
592 )
593 }
594 }
595 }
596
597 pub fn peek_next_worker(&self) -> Option<u64> {
601 let instance_ids = self.client.instance_ids_avail();
602 let count = instance_ids.len();
603 if count == 0 {
604 return None;
605 }
606
607 match self.router_mode {
608 RouterMode::RoundRobin => {
609 let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
611 Some(instance_ids[counter % count])
612 }
613 RouterMode::Random => {
614 let counter = rand::rng().random::<u64>() as usize;
617 Some(instance_ids[counter % count])
618 }
619 RouterMode::PowerOfTwoChoices
620 | RouterMode::Direct
621 | RouterMode::LeastLoaded
622 | RouterMode::DeviceAwareWeighted => None,
623 RouterMode::KV => {
624 panic!(
625 "peek_next_worker should not be called for {:?} routing mode",
626 self.router_mode
627 )
628 }
629 }
630 }
631
632 fn occupancy_state(&self) -> anyhow::Result<Arc<RoutingOccupancyState>> {
633 self.occupancy_state.clone().ok_or_else(|| {
634 anyhow::anyhow!(
635 "routing occupancy state not initialized for endpoint {}",
636 self.client.endpoint.id()
637 )
638 })
639 }
640
641 async fn generate_with_fault_detection(
652 &self,
653 mut instance_id: u64,
654 request: SingleIn<T>,
655 ) -> anyhow::Result<ManyOut<U>> {
656 let route_start = Instant::now();
657 let request_id = request.id().to_string();
658 let route_span = if matches!(self.router_mode, RouterMode::KV) {
659 tracing::Span::none()
660 } else {
661 tracing::info_span!(
662 "router.route_request",
663 request_id = %request_id,
664 worker_id = instance_id,
665 router_mode = ?self.router_mode,
666 )
667 };
668
669 if self.fault_detection_enabled {
671 let free_instances = self.client.instance_ids_free();
672 if free_instances.is_empty() {
673 let all_instances = self.client.instance_ids();
675 if !all_instances.is_empty() {
676 tracing::warn!(
677 instance_id,
678 total_workers = all_instances.len(),
679 "Rejecting request: all workers are busy"
680 );
681 let cause = PipelineError::ServiceOverloaded(
682 "All workers are busy, please retry later".to_string(),
683 );
684 return Err(DynamoError::builder()
685 .error_type(ErrorType::ResourceExhausted)
686 .message("All workers are busy, please retry later")
687 .cause(cause)
688 .build()
689 .into());
690 }
691 }
692 }
693
694 let (address, _transport_kind) = {
699 use crate::component::TransportType;
700
701 let resolve_transport = |id: u64| {
702 let instances = self.client.instances();
703 instances
704 .iter()
705 .find(|i| i.instance_id == id)
706 .map(|instance| match &instance.transport {
707 TransportType::Http(http_endpoint) => {
708 tracing::debug!(
709 instance_id = id,
710 http_endpoint = %http_endpoint,
711 "Using HTTP transport for instance"
712 );
713 (http_endpoint.clone(), "transport.http.request")
714 }
715 TransportType::Tcp(tcp_endpoint) => {
716 tracing::debug!(
717 instance_id = id,
718 tcp_endpoint = %tcp_endpoint,
719 "Using TCP transport for instance"
720 );
721 (tcp_endpoint.clone(), "transport.tcp.request")
722 }
723 TransportType::Nats(subject) => {
724 tracing::debug!(
725 instance_id = id,
726 subject = %subject,
727 "Using NATS transport for instance"
728 );
729 (subject.clone(), "transport.nats.request")
730 }
731 })
732 };
733
734 if let Some(result) = resolve_transport(instance_id) {
735 result
736 } else {
737 let avail = self.client.instance_ids_avail();
740 let fallback_id = avail.iter().copied().find(|&id| id != instance_id);
741 match fallback_id {
742 Some(id) => {
743 tracing::warn!(
744 original_instance = instance_id,
745 fallback_instance = id,
746 "Instance disappeared during routing, reselecting"
747 );
748 instance_id = id;
749 resolve_transport(id).ok_or_else(|| {
750 anyhow::anyhow!(
751 "Fallback instance {} also not found for endpoint {}",
752 id,
753 self.client.endpoint.id()
754 )
755 })?
756 }
757 None => {
758 return Err(anyhow::anyhow!(
759 "Instance {} not found and no other instances available \
760 for endpoint {}",
761 instance_id,
762 self.client.endpoint.id()
763 ));
764 }
765 }
766 }
767 };
768
769 let request = request.map(|req| AddressedRequest::new(req, address));
770
771 STAGE_DURATION_SECONDS
772 .with_label_values(&["route"])
773 .observe(route_start.elapsed().as_secs_f64());
774
775 let _nvtx_transport = dynamo_nvtx_range!(_transport_kind);
776 let stream: anyhow::Result<ManyOut<U>> = self
777 .addressed
778 .generate(request)
779 .instrument(route_span)
780 .await;
781 match stream {
782 Ok(stream) => {
783 if !self.fault_detection_enabled {
784 return Ok(stream);
785 }
786 let engine_ctx = stream.context();
787 let client = self.client.clone();
788 let client_for_timeout = self.client.clone();
789 let stream = stream.map(move |res| {
790 if let Some(err) = res.err()
792 && is_inhibited(&err)
793 {
794 tracing::debug!(
795 "Reporting instance {instance_id} down due to migratable error: {err}"
796 );
797 client.report_instance_down(instance_id);
798 }
799 res
800 });
801
802 let stream: Pin<Box<dyn Stream<Item = U> + Send>> = if let Some(timeout) =
806 self.response_timeout
807 {
808 Box::pin(async_stream::stream! {
809 let mut inner = Box::pin(stream);
810 loop {
811 tokio::select! {
812 biased;
813 item = inner.next() => {
814 match item {
815 Some(item) => yield item,
816 None => break,
817 }
818 }
819 _ = tokio::time::sleep(timeout) => {
820 tracing::warn!(
821 instance_id,
822 timeout_secs = timeout.as_secs(),
823 "backend response inactivity timeout — quarantining worker"
824 );
825 client_for_timeout.report_instance_down(instance_id);
826 yield U::from_err(
827 crate::error::DynamoError::builder()
828 .error_type(crate::error::ErrorType::ResponseTimeout)
829 .message("backend response inactivity timeout")
830 .build()
831 );
832 break;
833 }
834 }
835 }
836 })
837 } else {
838 Box::pin(stream)
839 };
840
841 Ok(ResponseStream::new(stream, engine_ctx))
842 }
843 Err(err) => {
844 if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
845 tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
846 self.client.report_instance_down(instance_id);
847 }
848 Err(err)
849 }
850 }
851 }
852}
853
854#[async_trait]
855impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
856where
857 T: Data + Serialize,
858 U: Data + for<'de> Deserialize<'de> + MaybeError,
859{
860 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
861 match self.router_mode {
862 RouterMode::Random => self.random(request).await,
863 RouterMode::RoundRobin => self.round_robin(request).await,
864 RouterMode::PowerOfTwoChoices => self.power_of_two_choices(request).await,
865 RouterMode::KV => {
866 anyhow::bail!("KV routing should not call generate on PushRouter");
867 }
868 RouterMode::Direct => {
869 anyhow::bail!(
870 "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
871 );
872 }
873 RouterMode::LeastLoaded => self.least_loaded(request).await,
874 RouterMode::DeviceAwareWeighted => self.device_aware_weighted(request).await,
875 }
876 }
877}
878
879struct OccupancyTrackedStream<U: Data> {
880 inner: ManyOut<U>,
881 state: Arc<RoutingOccupancyState>,
882 instance_id: u64,
883}
884
885impl<U: Data> Drop for OccupancyTrackedStream<U> {
886 fn drop(&mut self) {
887 self.state.decrement(self.instance_id);
888 }
889}
890
891impl<U: Data> std::fmt::Debug for OccupancyTrackedStream<U> {
892 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
893 f.debug_struct("OccupancyTrackedStream")
894 .field("instance_id", &self.instance_id)
895 .finish()
896 }
897}
898
899impl<U: Data> Stream for OccupancyTrackedStream<U> {
900 type Item = U;
901
902 fn poll_next(
903 mut self: Pin<&mut Self>,
904 cx: &mut std::task::Context<'_>,
905 ) -> Poll<Option<Self::Item>> {
906 self.inner.as_mut().poll_next(cx)
907 }
908}
909
910impl<U: Data> AsyncEngineContextProvider for OccupancyTrackedStream<U> {
911 fn context(&self) -> Arc<dyn AsyncEngineContext> {
912 self.inner.context()
913 }
914}
915
916impl<U: Data> crate::engine::AsyncEngineStream<U> for OccupancyTrackedStream<U> {}
917
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        DistributedRuntime, Runtime,
        distributed::DistributedConfig,
        error::DynamoError,
        pipeline::{ResponseStream, context::Controller},
    };
    use serde::{Deserialize, Serialize};

    /// Minimal response type that can carry an error, as required by the
    /// `MaybeError` bound on `PushRouter`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct TestResponse {
        error: Option<DynamoError>,
    }

    impl MaybeError for TestResponse {
        fn from_err(err: impl std::error::Error + 'static) -> Self {
            let boxed = Box::new(err) as Box<dyn std::error::Error + 'static>;
            Self {
                error: Some(DynamoError::from(boxed)),
            }
        }

        fn err(&self) -> Option<DynamoError> {
            self.error.clone()
        }
    }

    #[test]
    fn p2c_selects_lower_load_worker() {
        let state = RoutingOccupancyState::default();
        // Worker 1 carries load 10, worker 2 carries load 1.
        for _ in 0..10 {
            state.increment(1);
        }
        state.increment(2);

        // With two candidates both are always sampled, so the outcome is
        // deterministic.
        let chosen = p2c_select_from(&state, &[1, 2]);
        assert_eq!(chosen, 2);
    }

    #[test]
    fn p2c_selects_single_worker() {
        let state = RoutingOccupancyState::default();
        assert_eq!(p2c_select_from(&state, &[42]), 42);
    }

    #[test]
    fn p2c_treats_missing_counts_as_zero() {
        let state = RoutingOccupancyState::default();
        for _ in 0..5 {
            state.increment(1);
        }
        // Worker 2 has never been incremented and must win against load 5.
        let chosen = p2c_select_from(&state, &[1, 2]);
        assert_eq!(chosen, 2);
    }

    #[test]
    fn p2c_returns_valid_worker_on_tie() {
        let state = RoutingOccupancyState::default();
        for _ in 0..3 {
            state.increment(1);
            state.increment(2);
        }

        for _ in 0..100 {
            let chosen = p2c_select_from(&state, &[1, 2]);
            assert!(chosen == 1 || chosen == 2);
        }
    }

    #[test]
    fn occupancy_permit_decrements_before_stream_creation() {
        let state = Arc::new(RoutingOccupancyState::default());
        state.increment(42);
        let permit = OccupancyPermit::new(state.clone(), 42);
        assert_eq!(state.load(42), 1);

        // Dropping an armed permit releases the count.
        drop(permit);
        assert_eq!(state.load(42), 0);
    }

    #[test]
    fn occupancy_tracked_stream_decrements_on_drop() {
        let state = Arc::new(RoutingOccupancyState::default());
        state.increment(7);
        let permit = OccupancyPermit::new(state.clone(), 7);
        let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
        let tracked = permit.into_tracked_stream(ResponseStream::new(
            Box::pin(tokio_stream::iter(vec![1u64])),
            ctx,
        ));

        // Converting the permit must not release the count...
        assert_eq!(state.load(7), 1);
        // ...but dropping the tracked stream must.
        drop(tracked);
        assert_eq!(state.load(7), 0);
    }

    #[test]
    fn p2c_lifecycle_tracks_inflight_counts_with_shared_tracker() {
        let state = Arc::new(RoutingOccupancyState::default());
        let mut permits = Vec::new();
        for _ in 0..5 {
            let selected = p2c_select_from(&state, &[1, 2]);
            state.increment(selected);
            permits.push(OccupancyPermit::new(state.clone(), selected));
        }

        let total = state.load(1) + state.load(2);
        assert_eq!(total, 5, "5 in-flight requests should be tracked");

        drop(permits);
        let total = state.load(1) + state.load(2);
        assert_eq!(total, 0, "All guards dropped, counts should be 0");
    }

    #[test]
    fn p2c_never_selects_dominated_worker() {
        let state = RoutingOccupancyState::default();
        for _ in 0..100 {
            state.increment(3);
        }

        // Worker 3 is always paired with a load-0 worker, so it can never win.
        let mut selected = [0u32; 3];
        for _ in 0..1000 {
            match p2c_select_from(&state, &[1, 2, 3]) {
                1 => selected[0] += 1,
                2 => selected[1] += 1,
                3 => selected[2] += 1,
                _ => panic!("unexpected worker id"),
            }
        }
        assert_eq!(
            selected[2], 0,
            "Worker 3 (load=100) should never be selected against load=0 workers, but got {} times",
            selected[2]
        );
    }

    #[tokio::test]
    async fn least_loaded_selects_exact_min_and_tracks_counts() {
        let state = Arc::new(RoutingOccupancyState::default());
        // Loads: worker 1 -> 2, worker 2 -> 1, worker 3 -> 0.
        state.increment(1);
        state.increment(1);
        state.increment(2);

        let selected = state
            .select_exact_min_and_increment(&[1, 2, 3])
            .await
            .unwrap();
        assert_eq!(selected, 3);

        // Selection already incremented the count; the permit only releases it.
        let permit = OccupancyPermit::new(state.clone(), selected);
        assert_eq!(state.load(selected), 1);
        drop(permit);
        assert_eq!(state.load(selected), 0);
    }

    #[tokio::test]
    async fn least_loaded_select_and_peek_return_none_with_available_worker() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_least_loaded_router".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::LeastLoaded)
            .await
            .unwrap();

        // Load-based modes do not support stateless next-worker queries.
        assert_eq!(router.select_next_worker(), None);
        assert_eq!(router.peek_next_worker(), None);

        rt.shutdown();
    }

    #[tokio::test]
    async fn device_aware_cpu_only_selects_least_loaded_instance() {
        let state = RoutingOccupancyState::default();
        // Loads: worker 1 -> 3, worker 2 -> 0, worker 3 -> 1.
        for _ in 0..3 {
            state.increment(1);
        }
        state.increment(3);

        let instance_ids = vec![1, 2, 3];
        let device_type_map = HashMap::from([
            (1, Some(DeviceType::Cpu)),
            (2, Some(DeviceType::Cpu)),
            (3, Some(DeviceType::Cpu)),
        ]);

        // With no non-CPU instance the whole CPU group is returned.
        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
        assert_eq!(candidates, vec![1, 2, 3]);

        let selected = state
            .select_exact_min_and_increment(&candidates)
            .await
            .unwrap();
        assert_eq!(selected, 2);
    }

    #[tokio::test]
    async fn device_aware_non_cpu_only_selects_least_loaded_instance() {
        let state = RoutingOccupancyState::default();
        // Loads: worker 1 -> 3, worker 2 -> 0, worker 3 -> 1.
        for _ in 0..3 {
            state.increment(1);
        }
        state.increment(3);

        let instance_ids = vec![1, 2, 3];
        let device_type_map = HashMap::from([
            (1, Some(DeviceType::Cuda)),
            (2, Some(DeviceType::Cuda)),
            (3, Some(DeviceType::Cuda)),
        ]);

        // With no CPU instance the whole non-CPU group is returned.
        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
        assert_eq!(candidates, vec![1, 2, 3]);

        let selected = state
            .select_exact_min_and_increment(&candidates)
            .await
            .unwrap();
        assert_eq!(selected, 2);
    }

    #[test]
    fn device_aware_group_uses_ratio_budget() {
        let state = RoutingOccupancyState::default();
        // Non-CPU workers 3 and 4 each carry load 4 (total 8).
        for _ in 0..4 {
            state.increment(3);
            state.increment(4);
        }
        // CPU worker 1 carries load 3; CPU worker 2 stays at 0 (total 3).
        for _ in 0..3 {
            state.increment(1);
        }
        // Budget with ratio 2: 8 * 2 / (2 * 2) = 4, and 3 < 4, so the CPU
        // group is selected.
        let instance_ids = vec![1, 2, 3, 4];
        let device_type_map = HashMap::from([
            (1, Some(DeviceType::Cpu)),
            (2, Some(DeviceType::Cpu)),
            (3, Some(DeviceType::Cuda)),
            (4, Some(DeviceType::Cuda)),
        ]);

        let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 2);
        assert_eq!(candidates, vec![1, 2]);

        // Within the CPU group, worker 2 (load 0) is the minimum.
        let selected =
            futures::executor::block_on(state.select_exact_min_and_increment(&candidates)).unwrap();
        assert_eq!(selected, 2);
    }

    #[tokio::test]
    async fn device_aware_weighted_select_and_peek_return_none_with_available_worker() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_device_aware_router".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let router =
            PushRouter::<u64, TestResponse>::from_client(client, RouterMode::DeviceAwareWeighted)
                .await
                .unwrap();

        // Load-based modes do not support stateless next-worker queries.
        assert_eq!(router.select_next_worker(), None);
        assert_eq!(router.peek_next_worker(), None);

        rt.shutdown();
    }

    #[tokio::test]
    async fn transport_resolution_falls_back_when_selected_instance_disappears() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_transport_fallback".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let real_id = client.instance_ids()[0];

        // Put an id with no registered instance first in the available list so
        // the first round-robin pick cannot be resolved to a transport.
        let stale_id = real_id + 1000;
        client.override_instance_avail(vec![stale_id, real_id]);

        let router =
            PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
                .await
                .unwrap();

        let request = SingleIn::new(42u64);
        let result = router.generate(request).await;

        // The request itself may still fail for other reasons; the only
        // requirement is that it does not fail with a "not found" resolution
        // error.
        if let Err(err) = &result {
            let msg = format!("{err}");
            assert!(
                !msg.contains("not found"),
                "Transport resolution should have fallen back, but got: {msg}"
            );
        }

        rt.shutdown();
    }

    #[tokio::test]
    async fn transport_resolution_errors_when_no_instances_available() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_transport_no_fallback".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let router =
            PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
                .await
                .unwrap();

        // The only available id has no registered instance, so there is
        // nothing to fall back to.
        let stale_id = 99999;
        client.override_instance_avail(vec![stale_id]);

        let request = SingleIn::new(42u64);
        let result = router.generate(request).await;

        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("not found") && msg.contains("no other instances available"),
            "Expected clear error about missing instance with no fallback, got: {msg}"
        );

        rt.shutdown();
    }
}