dynamo_runtime/pipeline/network/egress/
push_router.rs1use super::{AsyncEngineContextProvider, ResponseStream};
5use crate::error::{BackendError, ErrorType, match_error_chain};
6
7fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
9 const INHIBITED: &[ErrorType] = &[
10 ErrorType::CannotConnect,
11 ErrorType::Disconnected,
12 ErrorType::ConnectionTimeout,
13 ErrorType::Backend(BackendError::EngineShutdown),
14 ];
15 match_error_chain(err, INHIBITED, &[])
16}
17use crate::{
18 component::{Client, Endpoint},
19 engine::{AsyncEngine, Data},
20 pipeline::{
21 AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
22 error::{PipelineError, PipelineErrorExt},
23 },
24 protocols::maybe_error::MaybeError,
25 traits::DistributedRuntimeProvider,
26};
27use async_trait::async_trait;
28use rand::Rng;
29use serde::{Deserialize, Serialize};
30use std::{
31 future::Future,
32 marker::PhantomData,
33 sync::{
34 Arc,
35 atomic::{AtomicU64, Ordering},
36 },
37};
38use tokio_stream::StreamExt;
39use tracing::Instrument;
40
41#[async_trait]
44pub trait WorkerLoadMonitor: Send + Sync {
45 async fn start_monitoring(&self) -> anyhow::Result<()>;
48}
49
50#[derive(Clone)]
51pub struct PushRouter<T, U>
52where
53 T: Data + Serialize,
54 U: Data + for<'de> Deserialize<'de>,
55{
56 pub client: Client,
59
60 router_mode: RouterMode,
67
68 round_robin_counter: Arc<AtomicU64>,
70
71 addressed: Arc<AddressedPushRouter>,
74
75 busy_threshold: Option<f64>,
78
79 fault_detection_enabled: bool,
84
85 _phantom: PhantomData<(T, U)>,
89}
90
/// Strategy a `PushRouter` uses to choose the instance that receives a request.
///
/// The enum is field-less, so equality is total; it therefore derives `Eq` in addition to
/// `PartialEq`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterMode {
    /// Cycle through the available instances in order. This is the default mode.
    #[default]
    RoundRobin,
    /// Pick an available instance at random for each request.
    Random,
    /// KV routing. `PushRouter::generate` refuses this mode; the caller is expected to
    /// choose the instance itself.
    KV,
    /// Direct routing to a caller-specified instance. `PushRouter::generate` refuses this
    /// mode as well.
    Direct,
}
99
100impl RouterMode {
101 pub fn is_kv_routing(&self) -> bool {
102 *self == RouterMode::KV
103 }
104
105 pub fn is_direct_routing(&self) -> bool {
106 *self == RouterMode::Direct
107 }
108}
109
110async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
111 let manager = endpoint.drt().network_manager();
113 let req_client = manager.create_client()?;
114 let resp_transport = endpoint.drt().tcp_server().await?;
115
116 tracing::debug!(
117 transport = req_client.transport_name(),
118 "Creating AddressedPushRouter with request plane client"
119 );
120
121 AddressedPushRouter::new(req_client, resp_transport)
122}
123
124impl<T, U> PushRouter<T, U>
125where
126 T: Data + Serialize,
127 U: Data + for<'de> Deserialize<'de> + MaybeError,
128{
129 pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
131 Self::from_client_with_threshold(client, router_mode, None, None).await
132 }
133
134 pub async fn from_client_no_fault_detection(
140 client: Client,
141 router_mode: RouterMode,
142 ) -> anyhow::Result<Self> {
143 let addressed = addressed_router(&client.endpoint).await?;
144
145 Ok(PushRouter {
146 client: client.clone(),
147 addressed,
148 router_mode,
149 round_robin_counter: Arc::new(AtomicU64::new(0)),
150 busy_threshold: None,
151 fault_detection_enabled: false,
152 _phantom: PhantomData,
153 })
154 }
155
156 pub async fn from_client_with_threshold(
158 client: Client,
159 router_mode: RouterMode,
160 busy_threshold: Option<f64>,
161 worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
162 ) -> anyhow::Result<Self> {
163 let addressed = addressed_router(&client.endpoint).await?;
164
165 if let Some(monitor) = worker_monitor.as_ref() {
167 monitor.start_monitoring().await?;
168 }
169
170 let router = PushRouter {
171 client: client.clone(),
172 addressed,
173 router_mode,
174 round_robin_counter: Arc::new(AtomicU64::new(0)),
175 busy_threshold,
176 fault_detection_enabled: true,
177 _phantom: PhantomData,
178 };
179
180 Ok(router)
181 }
182
183 pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
185 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
186
187 let instance_id = {
188 let instance_ids = self.client.instance_ids_avail();
189 let count = instance_ids.len();
190 if count == 0 {
191 return Err(anyhow::anyhow!(
192 "no instances found for endpoint {}",
193 self.client.endpoint.id()
194 ));
195 }
196 instance_ids[counter % count]
197 };
198 tracing::trace!("round robin router selected {instance_id}");
199
200 self.generate_with_fault_detection(instance_id, request)
201 .await
202 }
203
204 pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
206 let instance_id = {
207 let instance_ids = self.client.instance_ids_avail();
208 let count = instance_ids.len();
209 if count == 0 {
210 return Err(anyhow::anyhow!(
211 "no instances found for endpoint {}",
212 self.client.endpoint.id()
213 ));
214 }
215 let counter = rand::rng().random::<u64>() as usize;
216 instance_ids[counter % count]
217 };
218 tracing::trace!("random router selected {instance_id}");
219
220 self.generate_with_fault_detection(instance_id, request)
221 .await
222 }
223
224 pub async fn direct(
226 &self,
227 request: SingleIn<T>,
228 instance_id: u64,
229 ) -> anyhow::Result<ManyOut<U>> {
230 let found = if self.fault_detection_enabled {
234 self.client.instance_ids_avail().contains(&instance_id)
235 } else {
236 self.client.instance_ids().contains(&instance_id)
237 };
238
239 if !found {
240 return Err(anyhow::anyhow!(
241 "instance_id={instance_id} not found for endpoint {}",
242 self.client.endpoint.id()
243 ));
244 }
245
246 self.generate_with_fault_detection(instance_id, request)
247 .await
248 }
249
250 pub fn select_next_worker(&self) -> Option<u64> {
254 let instance_ids = self.client.instance_ids_avail();
255 let count = instance_ids.len();
256 if count == 0 {
257 return None;
258 }
259
260 match self.router_mode {
261 RouterMode::RoundRobin => {
262 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
263 Some(instance_ids[counter % count])
264 }
265 RouterMode::Random => {
266 let counter = rand::rng().random::<u64>() as usize;
267 Some(instance_ids[counter % count])
268 }
269 _ => {
270 panic!(
271 "select_next_worker should not be called for {:?} routing mode",
272 self.router_mode
273 )
274 }
275 }
276 }
277
278 pub fn peek_next_worker(&self) -> Option<u64> {
281 let instance_ids = self.client.instance_ids_avail();
282 let count = instance_ids.len();
283 if count == 0 {
284 return None;
285 }
286
287 match self.router_mode {
288 RouterMode::RoundRobin => {
289 let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
291 Some(instance_ids[counter % count])
292 }
293 RouterMode::Random => {
294 let counter = rand::rng().random::<u64>() as usize;
297 Some(instance_ids[counter % count])
298 }
299 _ => {
300 panic!(
301 "peek_next_worker should not be called for {:?} routing mode",
302 self.router_mode
303 )
304 }
305 }
306 }
307
308 async fn generate_with_fault_detection(
319 &self,
320 instance_id: u64,
321 request: SingleIn<T>,
322 ) -> anyhow::Result<ManyOut<U>> {
323 let request_id = request.id().to_string();
324 let route_span = if matches!(self.router_mode, RouterMode::KV) {
325 tracing::Span::none()
326 } else {
327 tracing::info_span!(
328 "router.route_request",
329 request_id = %request_id,
330 worker_id = instance_id,
331 router_mode = ?self.router_mode,
332 )
333 };
334
335 if self.fault_detection_enabled && self.busy_threshold.is_some() {
337 let free_instances = self.client.instance_ids_free();
338 if free_instances.is_empty() {
339 let all_instances = self.client.instance_ids();
341 if !all_instances.is_empty() {
342 tracing::warn!(
343 instance_id,
344 total_workers = all_instances.len(),
345 "Rejecting request: all workers are busy"
346 );
347 return Err(PipelineError::ServiceOverloaded(
348 "All workers are busy, please retry later".to_string(),
349 )
350 .into());
351 }
352 }
353 }
354
355 let address = {
357 use crate::component::TransportType;
358
359 let instances = self.client.instances();
361 let instance = instances
362 .iter()
363 .find(|i| i.instance_id == instance_id)
364 .ok_or_else(|| {
365 anyhow::anyhow!("Instance {} not found in available instances", instance_id)
366 })?;
367
368 match &instance.transport {
369 TransportType::Http(http_endpoint) => {
370 tracing::debug!(
371 instance_id = instance_id,
372 http_endpoint = %http_endpoint,
373 "Using HTTP transport for instance"
374 );
375 http_endpoint.clone()
376 }
377 TransportType::Tcp(tcp_endpoint) => {
378 tracing::debug!(
379 instance_id = instance_id,
380 tcp_endpoint = %tcp_endpoint,
381 "Using TCP transport for instance"
382 );
383 tcp_endpoint.clone()
384 }
385 TransportType::Nats(subject) => {
386 tracing::debug!(
387 instance_id = instance_id,
388 subject = %subject,
389 "Using NATS transport for instance"
390 );
391 subject.clone()
392 }
393 }
394 };
395
396 let request = request.map(|req| AddressedRequest::new(req, address));
397
398 let stream: anyhow::Result<ManyOut<U>> = self
399 .addressed
400 .generate(request)
401 .instrument(route_span)
402 .await;
403 match stream {
404 Ok(stream) => {
405 if !self.fault_detection_enabled {
406 return Ok(stream);
407 }
408 let engine_ctx = stream.context();
409 let client = self.client.clone();
410 let stream = stream.map(move |res| {
411 if let Some(err) = res.err()
413 && is_inhibited(&err)
414 {
415 tracing::debug!(
416 "Reporting instance {instance_id} down due to migratable error: {err}"
417 );
418 client.report_instance_down(instance_id);
419 }
420 res
421 });
422 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
423 }
424 Err(err) => {
425 if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
426 tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
427 self.client.report_instance_down(instance_id);
428 }
429 Err(err)
430 }
431 }
432 }
433}
434
435#[async_trait]
436impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
437where
438 T: Data + Serialize,
439 U: Data + for<'de> Deserialize<'de> + MaybeError,
440{
441 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
442 match self.router_mode {
443 RouterMode::Random => self.random(request).await,
444 RouterMode::RoundRobin => self.round_robin(request).await,
445 RouterMode::KV => {
446 anyhow::bail!("KV routing should not call generate on PushRouter");
447 }
448 RouterMode::Direct => {
449 anyhow::bail!(
450 "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
451 );
452 }
453 }
454 }
455}