use super::{AsyncEngineContextProvider, ResponseStream};
use crate::error::{BackendError, DynamoError, ErrorType, match_error_chain};
use crate::{
component::{
Client, DeviceType, Endpoint, Instance, RoutingInstances, RoutingOccupancyState,
get_or_create_routing_occupancy_state,
},
discovery::EndpointInstanceId,
dynamo_nvtx_range,
engine::{AsyncEngine, AsyncEngineContext, Data},
metrics::frontend_perf::{STAGE_DURATION_SECONDS, STAGE_ROUTE},
pipeline::{
AddressedPushRouter, AddressedRequest, Error, ManyIn, ManyOut, SingleIn,
error::{PipelineError, PipelineErrorExt},
},
protocols::{EndpointId, maybe_error::MaybeError},
traits::DistributedRuntimeProvider,
};
use async_trait::async_trait;
use futures::Stream;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
marker::PhantomData,
pin::Pin,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
task::Poll,
time::Instant,
};
use tokio_stream::StreamExt;
use tracing::Instrument;
fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
const INHIBITED: &[ErrorType] = &[
ErrorType::CannotConnect,
ErrorType::Disconnected,
ErrorType::ConnectionTimeout,
ErrorType::ResponseTimeout,
ErrorType::Backend(BackendError::EngineShutdown),
];
match_error_chain(err, INHIBITED, &[])
}
fn response_inactivity_timeout() -> Option<std::time::Duration> {
use crate::config::environment_names::llm::DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS;
std::env::var(DYN_HTTP_BACKEND_STREAM_TIMEOUT_SECS)
.ok()
.and_then(|s| s.parse::<u64>().ok())
.filter(|&secs| secs > 0)
.map(std::time::Duration::from_secs)
}
struct OccupancyPermit {
state: Arc<RoutingOccupancyState>,
instance_id: u64,
armed: bool,
}
impl OccupancyPermit {
fn new(state: Arc<RoutingOccupancyState>, instance_id: u64) -> Self {
Self {
state,
instance_id,
armed: true,
}
}
fn into_tracked_stream<U: Data>(mut self, stream: ManyOut<U>) -> ManyOut<U> {
self.armed = false;
let engine_ctx = stream.context();
ResponseStream::new(
Box::pin(OccupancyTrackedStream {
inner: stream,
state: self.state.clone(),
instance_id: self.instance_id,
}),
engine_ctx,
)
}
fn instance_id(&self) -> u64 {
self.instance_id
}
}
impl Drop for OccupancyPermit {
fn drop(&mut self) {
if self.armed {
self.state.decrement(self.instance_id);
}
}
}
#[async_trait]
pub trait WorkerLoadMonitor: Send + Sync {
async fn start_monitoring(&self) -> anyhow::Result<()>;
}
#[derive(Clone)]
pub struct PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
pub client: Client,
router_mode: RouterMode,
round_robin_counter: Arc<AtomicU64>,
addressed: Arc<AddressedPushRouter>,
fault_detection_enabled: bool,
response_timeout: Option<std::time::Duration>,
occupancy_state: Option<Arc<RoutingOccupancyState>>,
_phantom: PhantomData<(T, U)>,
}
#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RouterMode {
#[default]
RoundRobin,
Random,
PowerOfTwoChoices,
KV,
Direct,
LeastLoaded,
DeviceAwareWeighted,
}
impl RouterMode {
pub fn is_kv_routing(&self) -> bool {
*self == RouterMode::KV
}
pub fn is_direct_routing(&self) -> bool {
*self == RouterMode::Direct
}
}
fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64]) -> u64 {
let count = instance_ids.len();
if count == 1 {
return instance_ids[0];
}
let mut rng = rand::rng();
let idx1 = rng.random_range(0..count);
let idx2 = (idx1 + 1 + rng.random_range(0..count - 1)) % count;
let id1 = instance_ids[idx1];
let id2 = instance_ids[idx2];
let load1 = occupancy_state.load(id1);
let load2 = occupancy_state.load(id2);
let selected = if load1 <= load2 { id1 } else { id2 };
tracing::debug!(
candidate_a = id1,
candidate_a_load = load1,
candidate_b = id2,
candidate_b_load = load2,
selected = selected,
"p2c selection"
);
selected
}
fn device_aware_candidate_group(
state: &RoutingOccupancyState,
instance_ids: &[u64],
device_type_map: &HashMap<u64, Option<DeviceType>>,
non_cpu_to_cpu_ratio: usize,
) -> Vec<u64> {
let cpu_ids: Vec<u64> = instance_ids
.iter()
.copied()
.filter(|id| matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
.collect();
let non_cpu_ids: Vec<u64> = instance_ids
.iter()
.copied()
.filter(|id| !matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
.collect();
if cpu_ids.is_empty() {
return non_cpu_ids;
}
if non_cpu_ids.is_empty() {
return cpu_ids;
}
let total_non_cpu_inflight: u64 = non_cpu_ids.iter().map(|id| state.load(*id)).sum();
let total_cpu_inflight: u64 = cpu_ids.iter().map(|id| state.load(*id)).sum();
let cpu_count = cpu_ids.len() as u64;
let non_cpu_count = non_cpu_ids.len() as u64;
let allowed_cpu_inflight = total_non_cpu_inflight.saturating_mul(cpu_count)
/ ((non_cpu_to_cpu_ratio as u64).saturating_mul(non_cpu_count));
if total_cpu_inflight < allowed_cpu_inflight {
cpu_ids
} else {
non_cpu_ids
}
}
static ENDPOINT_WATCHER_ACTIVE: std::sync::OnceLock<dashmap::DashMap<EndpointId, ()>> =
std::sync::OnceLock::new();
fn spawn_instance_removal_watcher(
endpoint: Endpoint,
addressed: Arc<AddressedPushRouter>,
cancel_token: tokio_util::sync::CancellationToken,
) {
use crate::discovery::{
DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
};
use tokio_stream::StreamExt as _;
let guard = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
let endpoint_id = endpoint.id();
if guard.insert(endpoint_id.clone(), ()).is_some() {
tracing::debug!(
?endpoint_id,
"Instance removal watcher already running for this endpoint, skipping"
);
return;
}
let endpoint_name = endpoint.name().to_string();
tokio::spawn(async move {
struct GuardRelease(EndpointId);
impl Drop for GuardRelease {
fn drop(&mut self) {
if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
map.remove(&self.0);
}
}
}
let _release = GuardRelease(endpoint_id);
let namespace = endpoint.component().namespace().name();
let component = endpoint.component().name().to_string();
const RECONNECT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(5);
'reconnect: loop {
let query = DiscoveryQuery::Endpoint {
namespace: namespace.clone(),
component: component.clone(),
endpoint: endpoint_name.clone(),
};
let mut stream = match endpoint.drt().discovery().list_and_watch(query, None).await {
Ok(s) => s,
Err(e) => {
tracing::warn!(
endpoint = %endpoint_name,
"Failed to start instance removal watcher (will retry): {e}"
);
tokio::select! {
_ = tokio::time::sleep(RECONNECT_BACKOFF) => continue 'reconnect,
_ = cancel_token.cancelled() => break 'reconnect,
}
}
};
loop {
tokio::select! {
event = stream.next() => {
match event {
Some(Ok(DiscoveryEvent::Removed(id))) => {
if let DiscoveryInstanceId::Endpoint(eid) = &id {
let n = addressed.cancel_instance_streams(eid).await;
if n > 0 {
tracing::warn!(
namespace = %eid.namespace,
component = %eid.component,
endpoint = %eid.endpoint,
instance_id = eid.instance_id,
cancelled = n,
"Cancelled pending response streams for removed \
instance (discovery-driven cleanup)"
);
}
}
}
Some(Ok(DiscoveryEvent::Added(DiscoveryInstance::Endpoint(inst)))) => {
let eid: EndpointInstanceId = inst.endpoint_instance_id();
addressed.clear_instance_tombstone(&eid).await;
}
Some(Ok(_)) => {}
Some(Err(e)) => {
tracing::warn!(
endpoint = %endpoint_name,
"Instance removal watcher stream error: {e}"
);
}
None => {
tracing::warn!(
endpoint = %endpoint_name,
"Instance removal watcher stream ended; reconnecting"
);
continue 'reconnect;
}
}
}
_ = cancel_token.cancelled() => {
break 'reconnect;
}
}
}
}
tracing::debug!(endpoint = %endpoint_name, "Instance removal watcher exiting");
});
}
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
AddressedPushRouter::from_runtime_provider(endpoint).await
}
impl<T, U> PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
Self::from_client_with_monitor(client, router_mode, None).await
}
pub async fn from_client_no_fault_detection(
client: Client,
router_mode: RouterMode,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
let occupancy_state = if matches!(
router_mode,
RouterMode::PowerOfTwoChoices
| RouterMode::LeastLoaded
| RouterMode::DeviceAwareWeighted
) {
Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
} else {
None
};
spawn_instance_removal_watcher(
client.endpoint.clone(),
addressed.clone(),
client.endpoint.drt().primary_token(),
);
Ok(PushRouter {
client,
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
fault_detection_enabled: false,
response_timeout: response_inactivity_timeout(),
occupancy_state,
_phantom: PhantomData,
})
}
pub async fn from_client_with_monitor(
client: Client,
router_mode: RouterMode,
worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
if let Some(monitor) = worker_monitor.as_ref() {
monitor.start_monitoring().await?;
}
let occupancy_state = if matches!(
router_mode,
RouterMode::PowerOfTwoChoices
| RouterMode::LeastLoaded
| RouterMode::DeviceAwareWeighted
) {
Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
} else {
None
};
spawn_instance_removal_watcher(
client.endpoint.clone(),
addressed.clone(),
client.endpoint.drt().primary_token(),
);
let router = PushRouter {
client,
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
fault_detection_enabled: true,
response_timeout: response_inactivity_timeout(),
occupancy_state,
_phantom: PhantomData,
};
Ok(router)
}
fn empty_free_pool_error(&self, routing_instances: &RoutingInstances) -> anyhow::Error {
if !routing_instances.routable_ids().is_empty() {
let cause = PipelineError::ServiceOverloaded(
"All workers are busy, please retry later".to_string(),
);
return DynamoError::builder()
.error_type(ErrorType::ResourceExhausted)
.message("All workers are busy, please retry later")
.cause(cause)
.build()
.into();
}
anyhow::anyhow!(
"no instances found for endpoint {}",
self.client.endpoint.id()
)
}
pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
let instance_id = {
let routing_instances = self.client.routing_instances();
let count = routing_instances.free_ids().len();
if count == 0 {
return Err(self.empty_free_pool_error(&routing_instances));
}
routing_instances.free_ids()[counter % count]
};
tracing::trace!("round robin router selected {instance_id}");
self.generate_with_fault_detection(instance_id, request)
.await
}
pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let instance_id = {
let routing_instances = self.client.routing_instances();
let count = routing_instances.free_ids().len();
if count == 0 {
return Err(self.empty_free_pool_error(&routing_instances));
}
let counter = rand::rng().random::<u64>() as usize;
routing_instances.free_ids()[counter % count]
};
tracing::trace!("random router selected {instance_id}");
self.generate_with_fault_detection(instance_id, request)
.await
}
pub async fn power_of_two_choices(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let state = self.occupancy_state()?;
let instance_id = {
let routing_instances = self.client.routing_instances();
if routing_instances.free_ids().is_empty() {
return Err(self.empty_free_pool_error(&routing_instances));
}
p2c_select_from(state.as_ref(), routing_instances.free_ids())
};
state.increment(instance_id);
let permit = OccupancyPermit::new(state, instance_id);
match self
.generate_with_fault_detection(instance_id, request)
.await
{
Ok(stream) => Ok(permit.into_tracked_stream(stream)),
Err(err) => Err(err),
}
}
pub async fn direct(
&self,
request: SingleIn<T>,
instance_id: u64,
) -> anyhow::Result<ManyOut<U>> {
let found = {
if self.fault_detection_enabled {
let routing_instances = self.client.routing_instances();
routing_instances.routable_ids().contains(&instance_id)
} else {
self.client.instance_ids().contains(&instance_id)
}
};
if !found {
return Err(anyhow::anyhow!(
"instance_id={instance_id} not found for endpoint {}",
self.client.endpoint.id()
));
}
self.generate_with_fault_detection(instance_id, request)
.await
}
pub async fn device_aware_weighted(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let state = self.occupancy_state()?;
let routing_instances = self.client.routing_instances();
let instance_ids = routing_instances.free_ids().to_vec();
if instance_ids.is_empty() {
return Err(self.empty_free_pool_error(&routing_instances));
}
let endpoint_id = self.client.endpoint.id();
let instances = self.client.instances();
let device_type_map: std::collections::HashMap<u64, Option<DeviceType>> = instances
.iter()
.map(|inst| (inst.instance_id, inst.device_type.clone()))
.collect();
let cuda_to_cpu_ratio = std::env::var("DYN_ENCODER_CUDA_TO_CPU_RATIO")
.ok()
.and_then(|v| v.parse::<usize>().ok())
.filter(|v| *v >= 1)
.unwrap_or(8);
let candidates = device_aware_candidate_group(
state.as_ref(),
&instance_ids,
&device_type_map,
cuda_to_cpu_ratio,
);
let instance_id = state
.select_exact_min_and_increment(&candidates)
.await
.ok_or_else(|| self.empty_free_pool_error(&routing_instances))?;
let permit = OccupancyPermit::new(state.clone(), instance_id);
let is_cpu = matches!(
device_type_map.get(&instance_id),
Some(Some(DeviceType::Cpu))
);
tracing::info!(
endpoint = %endpoint_id,
selected_instance = instance_id,
is_cpu,
"DeviceAwareWeighted selected instance"
);
match self
.generate_with_fault_detection(instance_id, request)
.await
{
Ok(stream) => Ok(permit.into_tracked_stream(stream)),
Err(err) => Err(err),
}
}
pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let state = self.occupancy_state()?;
let routing_instances = self.client.routing_instances();
let instance_ids = routing_instances.free_ids().to_vec();
let instance_id = state
.select_exact_min_and_increment(&instance_ids)
.await
.ok_or_else(|| self.empty_free_pool_error(&routing_instances))?;
let permit = OccupancyPermit::new(state.clone(), instance_id);
tracing::trace!(
"least loaded router selected {instance_id} (connections: {})",
state.load(instance_id)
);
match self
.generate_with_fault_detection(instance_id, request)
.await
{
Ok(stream) => Ok(permit.into_tracked_stream(stream)),
Err(err) => Err(err),
}
}
pub fn select_next_worker(&self) -> Option<u64> {
let routing_instances = self.client.routing_instances();
let count = routing_instances.free_ids().len();
if count == 0 {
return None;
}
match self.router_mode {
RouterMode::RoundRobin => {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
Some(routing_instances.free_ids()[counter % count])
}
RouterMode::Random => {
let counter = rand::rng().random::<u64>() as usize;
Some(routing_instances.free_ids()[counter % count])
}
RouterMode::PowerOfTwoChoices
| RouterMode::Direct
| RouterMode::LeastLoaded
| RouterMode::DeviceAwareWeighted => None,
RouterMode::KV => {
panic!(
"select_next_worker should not be called for {:?} routing mode",
self.router_mode
)
}
}
}
pub fn peek_next_worker(&self) -> Option<u64> {
let routing_instances = self.client.routing_instances();
let count = routing_instances.free_ids().len();
if count == 0 {
return None;
}
match self.router_mode {
RouterMode::RoundRobin => {
let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
Some(routing_instances.free_ids()[counter % count])
}
RouterMode::Random => {
let counter = rand::rng().random::<u64>() as usize;
Some(routing_instances.free_ids()[counter % count])
}
RouterMode::PowerOfTwoChoices
| RouterMode::Direct
| RouterMode::LeastLoaded
| RouterMode::DeviceAwareWeighted => None,
RouterMode::KV => {
panic!(
"peek_next_worker should not be called for {:?} routing mode",
self.router_mode
)
}
}
}
fn occupancy_state(&self) -> anyhow::Result<Arc<RoutingOccupancyState>> {
self.occupancy_state.clone().ok_or_else(|| {
anyhow::anyhow!(
"routing occupancy state not initialized for endpoint {}",
self.client.endpoint.id()
)
})
}
async fn generate_with_fault_detection(
&self,
mut instance_id: u64,
request: SingleIn<T>,
) -> anyhow::Result<ManyOut<U>> {
let route_start = Instant::now();
let request_id = request.id().to_string();
let route_span = if matches!(self.router_mode, RouterMode::KV) {
tracing::Span::none()
} else {
tracing::info_span!(
"router.route_request",
request_id = %request_id,
worker_id = instance_id,
router_mode = ?self.router_mode,
)
};
if self.fault_detection_enabled {
let routing_instances = self.client.routing_instances();
let selected_worker_overloaded = routing_instances.is_overloaded(instance_id);
let counts = routing_instances.counts();
if tracing::enabled!(tracing::Level::DEBUG) {
tracing::debug!(
request_id = %request_id,
instance_id,
router_mode = ?self.router_mode,
free_workers = counts.free,
overloaded_workers = counts.overloaded,
total_workers = counts.discovered,
selected_worker_overloaded,
"checked worker overload state"
);
}
if selected_worker_overloaded {
tracing::warn!(
instance_id,
overloaded_workers = counts.overloaded,
total_workers = counts.discovered,
"Rejecting request: selected worker is overloaded"
);
let cause = PipelineError::ServiceOverloaded(
"Selected worker is overloaded, please retry later".to_string(),
);
return Err(DynamoError::builder()
.error_type(ErrorType::ResourceExhausted)
.message("Selected worker is overloaded, please retry later")
.cause(cause)
.build()
.into());
}
}
let (address, _transport_kind, instance) = {
use crate::component::TransportType;
let resolve_transport = |id: u64| {
let instances = self.client.instances();
instances
.iter()
.find(|i| i.instance_id == id)
.map(|instance| {
let (addr, kind) = match &instance.transport {
TransportType::Tcp(tcp_endpoint) => {
tracing::debug!(
instance_id = id,
tcp_endpoint = %tcp_endpoint,
"Using TCP transport for instance"
);
(tcp_endpoint.clone(), "transport.tcp.request")
}
TransportType::Nats(subject) => {
tracing::debug!(
instance_id = id,
subject = %subject,
"Using NATS transport for instance"
);
(subject.clone(), "transport.nats.request")
}
};
(addr, kind, instance.clone())
})
};
if let Some(result) = resolve_transport(instance_id) {
result
} else {
let routing_instances = self.client.routing_instances();
let fallback_id = routing_instances
.free_ids()
.iter()
.copied()
.find(|&id| id != instance_id);
match fallback_id {
Some(id) => {
tracing::warn!(
original_instance = instance_id,
fallback_instance = id,
"Instance disappeared during routing, reselecting"
);
instance_id = id;
resolve_transport(id).ok_or_else(|| {
anyhow::anyhow!(
"Fallback instance {} also not found for endpoint {}",
id,
self.client.endpoint.id()
)
})?
}
None => {
return Err(anyhow::anyhow!(
"Instance {} not found and no other instances available \
for endpoint {}",
instance_id,
self.client.endpoint.id()
));
}
}
}
};
let request = request.map(|req| AddressedRequest::with_instance(req, address, instance));
STAGE_DURATION_SECONDS
.with_label_values(&[STAGE_ROUTE])
.observe(route_start.elapsed().as_secs_f64());
let _nvtx_transport = dynamo_nvtx_range!(_transport_kind);
let stream: anyhow::Result<ManyOut<U>> = self
.addressed
.generate(request)
.instrument(route_span)
.await;
match stream {
Ok(stream) => {
if !self.fault_detection_enabled {
return Ok(stream);
}
let engine_ctx = stream.context();
let client = self.client.clone();
let client_for_timeout = self.client.clone();
let stream = stream.map(move |res| {
if let Some(err) = res.err()
&& is_inhibited(&err)
{
tracing::debug!(
"Reporting instance {instance_id} down due to migratable error: {err}"
);
client.report_instance_down(instance_id);
}
res
});
let stream: Pin<Box<dyn Stream<Item = U> + Send>> = if let Some(timeout) =
self.response_timeout
{
Box::pin(async_stream::stream! {
let mut inner = Box::pin(stream);
loop {
tokio::select! {
biased;
item = inner.next() => {
match item {
Some(item) => yield item,
None => break,
}
}
_ = tokio::time::sleep(timeout) => {
tracing::warn!(
instance_id,
timeout_secs = timeout.as_secs(),
"backend response inactivity timeout — quarantining worker"
);
client_for_timeout.report_instance_down(instance_id);
yield U::from_err(
crate::error::DynamoError::builder()
.error_type(crate::error::ErrorType::ResponseTimeout)
.message("backend response inactivity timeout")
.build()
);
break;
}
}
}
})
} else {
Box::pin(stream)
};
Ok(ResponseStream::new(stream, engine_ctx))
}
Err(err) => {
if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
self.client.report_instance_down(instance_id);
}
Err(err)
}
}
}
}
#[async_trait]
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::PowerOfTwoChoices => self.power_of_two_choices(request).await,
RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter");
}
RouterMode::Direct => {
anyhow::bail!(
"Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
);
}
RouterMode::LeastLoaded => self.least_loaded(request).await,
RouterMode::DeviceAwareWeighted => self.device_aware_weighted(request).await,
}
}
}
#[async_trait]
impl<T, U> AsyncEngine<ManyIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
async fn generate(&self, mut input: ManyIn<T>) -> Result<ManyOut<U>, Error> {
match self.router_mode {
RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter");
}
RouterMode::Direct => {
anyhow::bail!(
"Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
);
}
_ => {}
}
if input.next().await.is_none() {
anyhow::bail!("bidirectional input stream closed before first frame");
}
let instance_id = self
.select_next_worker()
.ok_or_else(|| anyhow::anyhow!("no instances available for bidirectional routing"))?;
anyhow::bail!(
"bidirectional remote dispatch is not yet implemented (selected instance {instance_id})"
)
}
}
struct OccupancyTrackedStream<U: Data> {
inner: ManyOut<U>,
state: Arc<RoutingOccupancyState>,
instance_id: u64,
}
impl<U: Data> Drop for OccupancyTrackedStream<U> {
fn drop(&mut self) {
self.state.decrement(self.instance_id);
}
}
impl<U: Data> std::fmt::Debug for OccupancyTrackedStream<U> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OccupancyTrackedStream")
.field("instance_id", &self.instance_id)
.finish()
}
}
impl<U: Data> Stream for OccupancyTrackedStream<U> {
type Item = U;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
self.inner.as_mut().poll_next(cx)
}
}
impl<U: Data> AsyncEngineContextProvider for OccupancyTrackedStream<U> {
fn context(&self) -> Arc<dyn AsyncEngineContext> {
self.inner.context()
}
}
impl<U: Data> crate::engine::AsyncEngineStream<U> for OccupancyTrackedStream<U> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
DistributedRuntime, Runtime,
distributed::DistributedConfig,
error::DynamoError,
pipeline::{ResponseStream, context::Controller},
};
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TestResponse {
error: Option<DynamoError>,
}
impl MaybeError for TestResponse {
fn from_err(err: impl std::error::Error + 'static) -> Self {
Self {
error: Some(DynamoError::from(
Box::new(err) as Box<dyn std::error::Error + 'static>
)),
}
}
fn err(&self) -> Option<DynamoError> {
self.error.clone()
}
}
#[test]
fn p2c_selects_lower_load_worker() {
let state = RoutingOccupancyState::default();
for _ in 0..10 {
state.increment(1);
}
state.increment(2);
let result = p2c_select_from(&state, &[1, 2]);
assert_eq!(result, 2);
}
#[test]
fn p2c_selects_single_worker() {
let state = RoutingOccupancyState::default();
assert_eq!(p2c_select_from(&state, &[42]), 42);
}
#[test]
fn p2c_treats_missing_counts_as_zero() {
let state = RoutingOccupancyState::default();
for _ in 0..5 {
state.increment(1);
}
let result = p2c_select_from(&state, &[1, 2]);
assert_eq!(result, 2);
}
#[test]
fn p2c_returns_valid_worker_on_tie() {
let state = RoutingOccupancyState::default();
for _ in 0..3 {
state.increment(1);
state.increment(2);
}
for _ in 0..100 {
let result = p2c_select_from(&state, &[1, 2]);
assert!(result == 1 || result == 2);
}
}
#[test]
fn occupancy_permit_decrements_before_stream_creation() {
let state = Arc::new(RoutingOccupancyState::default());
state.increment(42);
let permit = OccupancyPermit::new(state.clone(), 42);
assert_eq!(state.load(42), 1);
drop(permit);
assert_eq!(state.load(42), 0);
}
#[test]
fn occupancy_tracked_stream_decrements_on_drop() {
let state = Arc::new(RoutingOccupancyState::default());
state.increment(7);
let permit = OccupancyPermit::new(state.clone(), 7);
let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
let stream = permit.into_tracked_stream(ResponseStream::new(
Box::pin(tokio_stream::iter(vec![1u64])),
ctx,
));
assert_eq!(state.load(7), 1);
drop(stream);
assert_eq!(state.load(7), 0);
}
#[test]
fn p2c_lifecycle_tracks_inflight_counts_with_shared_tracker() {
let state = Arc::new(RoutingOccupancyState::default());
let mut permits = Vec::new();
for _ in 0..5 {
let selected = p2c_select_from(&state, &[1, 2]);
state.increment(selected);
permits.push(OccupancyPermit::new(state.clone(), selected));
}
let total = state.load(1) + state.load(2);
assert_eq!(total, 5, "5 in-flight requests should be tracked");
drop(permits);
let total = state.load(1) + state.load(2);
assert_eq!(total, 0, "All guards dropped, counts should be 0");
}
#[test]
fn p2c_never_selects_dominated_worker() {
let state = RoutingOccupancyState::default();
for _ in 0..100 {
state.increment(3);
}
let mut selected = [0u32; 3];
for _ in 0..1000 {
let result = p2c_select_from(&state, &[1, 2, 3]);
match result {
1 => selected[0] += 1,
2 => selected[1] += 1,
3 => selected[2] += 1,
_ => panic!("unexpected worker id"),
}
}
assert_eq!(
selected[2], 0,
"Worker 3 (load=100) should never be selected against load=0 workers, but got {} times",
selected[2]
);
}
#[tokio::test]
async fn least_loaded_selects_exact_min_and_tracks_counts() {
let state = Arc::new(RoutingOccupancyState::default());
state.increment(1);
state.increment(1);
state.increment(2);
let selected = state
.select_exact_min_and_increment(&[1, 2, 3])
.await
.unwrap();
assert_eq!(selected, 3);
let permit = OccupancyPermit::new(state.clone(), selected);
assert_eq!(state.load(selected), 1);
drop(permit);
assert_eq!(state.load(selected), 0);
}
#[tokio::test]
async fn bidirectional_generate_bails_with_no_instances() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_bidi_no_instances".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::RoundRobin)
.await
.unwrap();
let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
let input: ManyIn<u64> =
ResponseStream::new(Box::pin(tokio_stream::iter(vec![1u64, 2u64])), ctx);
let result = router.generate(input).await;
assert!(
result.is_err(),
"bidirectional generate must bail when no instances are registered"
);
rt.shutdown();
}
#[tokio::test]
async fn bidirectional_generate_bails_for_kv_router_mode() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_bidi_kv_mode".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::KV)
.await
.unwrap();
let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
let input: ManyIn<u64> = ResponseStream::new(Box::pin(tokio_stream::iter(vec![1u64])), ctx);
let result = router.generate(input).await;
assert!(
result.is_err(),
"bidirectional generate must bail for RouterMode::KV"
);
let err_msg = format!("{:?}", result.unwrap_err());
assert!(
err_msg.contains("KV") || err_msg.contains("kv"),
"error should mention KV: got {err_msg}"
);
rt.shutdown();
}
#[tokio::test]
async fn bidirectional_generate_bails_for_direct_router_mode() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_bidi_direct_mode".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::Direct)
.await
.unwrap();
let ctx: Arc<dyn AsyncEngineContext> = Arc::new(Controller::default());
let input: ManyIn<u64> = ResponseStream::new(Box::pin(tokio_stream::iter(vec![1u64])), ctx);
let result = router.generate(input).await;
assert!(
result.is_err(),
"bidirectional generate must bail for RouterMode::Direct"
);
let err_msg = format!("{:?}", result.unwrap_err());
assert!(
err_msg.contains("Direct") || err_msg.contains("direct"),
"error should mention Direct: got {err_msg}"
);
rt.shutdown();
}
#[tokio::test]
async fn least_loaded_select_and_peek_return_none_with_available_worker() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_least_loaded_router".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
client.wait_for_instances().await.unwrap();
let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::LeastLoaded)
.await
.unwrap();
assert_eq!(router.select_next_worker(), None);
assert_eq!(router.peek_next_worker(), None);
rt.shutdown();
}
#[tokio::test]
async fn selected_overloaded_worker_is_rejected_before_dispatch() {
const TEST_RECONCILE_INTERVAL: std::time::Duration = std::time::Duration::from_millis(50);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_selected_overloaded_worker_rejected".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
.await
.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
let instances = client.wait_for_instances().await.unwrap();
let worker_id = instances[0].id();
for _ in 0..10 {
if client.instance_ids_avail().contains(&worker_id) {
break;
}
tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
}
assert!(
client.instance_ids_avail().contains(&worker_id),
"worker should be routable before marking it overloaded"
);
client.set_overloaded_instances(&[worker_id]);
let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::RoundRobin)
.await
.unwrap();
let result = router.generate(SingleIn::new(42u64)).await;
assert!(result.is_err());
let msg = format!("{}", result.unwrap_err());
assert!(
msg.contains("All workers are busy"),
"expected empty-free-pool rejection, got: {msg}"
);
rt.shutdown();
}
#[tokio::test]
async fn round_robin_excludes_overloaded_workers_from_candidates() {
const TEST_RECONCILE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(3600);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_round_robin_excludes_overloaded".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
.await
.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
let instances = client.wait_for_instances().await.unwrap();
let real_id = instances[0].id();
for _ in 0..50 {
if client.instance_ids_avail().contains(&real_id) {
break;
}
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
client.override_instance_avail(vec![1, 2]);
client.set_overloaded_instances(&[1]);
let router = PushRouter::<u64, TestResponse>::from_client(client, RouterMode::RoundRobin)
.await
.unwrap();
for _ in 0..6 {
let selected = router
.peek_next_worker()
.expect("peek should succeed with a free worker");
assert_eq!(
selected, 2,
"overloaded worker 1 must not appear in the candidate set"
);
}
rt.shutdown();
}
#[tokio::test]
async fn device_aware_cpu_only_selects_least_loaded_instance() {
let state = RoutingOccupancyState::default();
for _ in 0..3 {
state.increment(1);
}
state.increment(3);
let instance_ids = vec![1, 2, 3];
let device_type_map = HashMap::from([
(1, Some(DeviceType::Cpu)),
(2, Some(DeviceType::Cpu)),
(3, Some(DeviceType::Cpu)),
]);
let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
assert_eq!(candidates, vec![1, 2, 3]);
let selected = state
.select_exact_min_and_increment(&candidates)
.await
.unwrap();
assert_eq!(selected, 2);
}
#[tokio::test]
async fn device_aware_non_cpu_only_selects_least_loaded_instance() {
let state = RoutingOccupancyState::default();
for _ in 0..3 {
state.increment(1);
}
state.increment(3);
let instance_ids = vec![1, 2, 3];
let device_type_map = HashMap::from([
(1, Some(DeviceType::Cuda)),
(2, Some(DeviceType::Cuda)),
(3, Some(DeviceType::Cuda)),
]);
let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
assert_eq!(candidates, vec![1, 2, 3]);
let selected = state
.select_exact_min_and_increment(&candidates)
.await
.unwrap();
assert_eq!(selected, 2);
}
#[test]
fn device_aware_group_uses_ratio_budget() {
let state = RoutingOccupancyState::default();
for _ in 0..4 {
state.increment(3);
state.increment(4);
}
for _ in 0..3 {
state.increment(1);
}
let instance_ids = vec![1, 2, 3, 4];
let device_type_map = HashMap::from([
(1, Some(DeviceType::Cpu)),
(2, Some(DeviceType::Cpu)),
(3, Some(DeviceType::Cuda)),
(4, Some(DeviceType::Cuda)),
]);
let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 2);
assert_eq!(candidates, vec![1, 2]);
let selected =
futures::executor::block_on(state.select_exact_min_and_increment(&candidates)).unwrap();
assert_eq!(selected, 2);
}
#[tokio::test]
async fn device_aware_weighted_select_and_peek_return_none_with_available_worker() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_device_aware_router".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
client.wait_for_instances().await.unwrap();
let router =
PushRouter::<u64, TestResponse>::from_client(client, RouterMode::DeviceAwareWeighted)
.await
.unwrap();
assert_eq!(router.select_next_worker(), None);
assert_eq!(router.peek_next_worker(), None);
rt.shutdown();
}
#[tokio::test]
async fn transport_resolution_falls_back_when_selected_instance_disappears() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_transport_fallback".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
client.wait_for_instances().await.unwrap();
let real_id = client.instance_ids()[0];
let stale_id = real_id + 1000;
client.override_instance_avail(vec![stale_id, real_id]);
let router =
PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
.await
.unwrap();
let request = SingleIn::new(42u64);
let result = router.generate(request).await;
if let Err(err) = &result {
let msg = format!("{err}");
assert!(
!msg.contains("not found"),
"Transport resolution should have fallen back, but got: {msg}"
);
}
rt.shutdown();
}
#[tokio::test]
async fn transport_resolution_errors_when_no_instances_available() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_transport_no_fallback".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
client.wait_for_instances().await.unwrap();
let router =
PushRouter::<u64, TestResponse>::from_client(client.clone(), RouterMode::RoundRobin)
.await
.unwrap();
let stale_id = 99999;
client.override_instance_avail(vec![stale_id]);
let request = SingleIn::new(42u64);
let result = router.generate(request).await;
assert!(result.is_err());
let msg = format!("{}", result.unwrap_err());
assert!(
msg.contains("not found") && msg.contains("no other instances available"),
"Expected clear error about missing instance with no fallback, got: {msg}"
);
rt.shutdown();
}
#[tokio::test]
async fn watcher_dedup_guard_released_on_panic() {
let endpoint_id = EndpointId {
namespace: "panic-test-ns".to_string(),
component: "panic-test-comp".to_string(),
name: "panic-test-endpoint".to_string(),
};
let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
map.insert(endpoint_id.clone(), ());
let endpoint_id_clone = endpoint_id.clone();
let join = tokio::spawn(async move {
struct GuardRelease(EndpointId);
impl Drop for GuardRelease {
fn drop(&mut self) {
if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
map.remove(&self.0);
}
}
}
let _release = GuardRelease(endpoint_id_clone);
panic!("simulated watcher-task panic");
});
let result = join.await;
assert!(result.is_err() && result.unwrap_err().is_panic());
assert!(
!map.contains_key(&endpoint_id),
"Drop guard must release the dedup entry even on panic"
);
}
#[tokio::test]
async fn watcher_dedup_guard_released_on_normal_exit() {
let endpoint_id = EndpointId {
namespace: "normal-test-ns".to_string(),
component: "normal-test-comp".to_string(),
name: "normal-test-endpoint".to_string(),
};
let map = ENDPOINT_WATCHER_ACTIVE.get_or_init(dashmap::DashMap::new);
map.insert(endpoint_id.clone(), ());
let endpoint_id_clone = endpoint_id.clone();
tokio::spawn(async move {
struct GuardRelease(EndpointId);
impl Drop for GuardRelease {
fn drop(&mut self) {
if let Some(map) = ENDPOINT_WATCHER_ACTIVE.get() {
map.remove(&self.0);
}
}
}
let _release = GuardRelease(endpoint_id_clone);
})
.await
.unwrap();
assert!(!map.contains_key(&endpoint_id));
}
}