dynamo_runtime/pipeline/network/egress/
push_router.rs1use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
5use crate::{
6 component::{Client, Endpoint},
7 engine::{AsyncEngine, Data},
8 pipeline::{
9 AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
10 error::{PipelineError, PipelineErrorExt},
11 },
12 protocols::maybe_error::MaybeError,
13 traits::DistributedRuntimeProvider,
14};
15use async_nats::client::{
16 RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
17};
18use async_trait::async_trait;
19use rand::Rng;
20use serde::{Deserialize, Serialize};
21use std::{
22 future::Future,
23 marker::PhantomData,
24 sync::{
25 Arc,
26 atomic::{AtomicU64, Ordering},
27 },
28};
29use tokio_stream::StreamExt;
30
31#[async_trait]
34pub trait WorkerLoadMonitor: Send + Sync {
35 async fn start_monitoring(&self) -> anyhow::Result<()>;
38}
39
40#[derive(Clone)]
41pub struct PushRouter<T, U>
42where
43 T: Data + Serialize,
44 U: Data + for<'de> Deserialize<'de>,
45{
46 pub client: Client,
49
50 router_mode: RouterMode,
57
58 round_robin_counter: Arc<AtomicU64>,
60
61 addressed: Arc<AddressedPushRouter>,
64
65 busy_threshold: Option<f64>,
68
69 _phantom: PhantomData<(T, U)>,
73}
74
/// Strategy a `PushRouter` uses to choose which instance receives a request.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterMode {
    /// Cycle through the available instances in order (the default).
    #[default]
    RoundRobin,
    /// Pick an available instance at random.
    Random,
    /// Always target the instance with this id.
    Direct(u64),
    /// KV-aware routing; `PushRouter::generate` rejects this mode.
    KV,
}
84
85impl RouterMode {
86 pub fn is_kv_routing(&self) -> bool {
87 *self == RouterMode::KV
88 }
89}
90
91async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
92 let manager = endpoint.drt().network_manager();
94 let req_client = manager.create_client()?;
95 let resp_transport = endpoint.drt().tcp_server().await?;
96
97 tracing::debug!(
98 transport = req_client.transport_name(),
99 "Creating AddressedPushRouter with request plane client"
100 );
101
102 AddressedPushRouter::new(req_client, resp_transport)
103}
104
105impl<T, U> PushRouter<T, U>
106where
107 T: Data + Serialize,
108 U: Data + for<'de> Deserialize<'de> + MaybeError,
109{
110 pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
112 Self::from_client_with_threshold(client, router_mode, None, None).await
113 }
114
115 pub async fn from_client_with_threshold(
117 client: Client,
118 router_mode: RouterMode,
119 busy_threshold: Option<f64>,
120 worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
121 ) -> anyhow::Result<Self> {
122 let addressed = addressed_router(&client.endpoint).await?;
123
124 if let Some(monitor) = worker_monitor.as_ref() {
126 monitor.start_monitoring().await?;
127 }
128
129 let router = PushRouter {
130 client: client.clone(),
131 addressed,
132 router_mode,
133 round_robin_counter: Arc::new(AtomicU64::new(0)),
134 busy_threshold,
135 _phantom: PhantomData,
136 };
137
138 Ok(router)
139 }
140
141 pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
143 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
144
145 let instance_id = {
146 let instance_ids = self.client.instance_ids_avail();
147 let count = instance_ids.len();
148 if count == 0 {
149 return Err(anyhow::anyhow!(
150 "no instances found for endpoint {}",
151 self.client.endpoint.id()
152 ));
153 }
154 instance_ids[counter % count]
155 };
156 tracing::trace!("round robin router selected {instance_id}");
157
158 self.generate_with_fault_detection(instance_id, request)
159 .await
160 }
161
162 pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
164 let instance_id = {
165 let instance_ids = self.client.instance_ids_avail();
166 let count = instance_ids.len();
167 if count == 0 {
168 return Err(anyhow::anyhow!(
169 "no instances found for endpoint {}",
170 self.client.endpoint.id()
171 ));
172 }
173 let counter = rand::rng().random::<u64>() as usize;
174 instance_ids[counter % count]
175 };
176 tracing::trace!("random router selected {instance_id}");
177
178 self.generate_with_fault_detection(instance_id, request)
179 .await
180 }
181
182 pub async fn direct(
184 &self,
185 request: SingleIn<T>,
186 instance_id: u64,
187 ) -> anyhow::Result<ManyOut<U>> {
188 let found = self.client.instance_ids_avail().contains(&instance_id);
189
190 if !found {
191 return Err(anyhow::anyhow!(
192 "instance_id={instance_id} not found for endpoint {}",
193 self.client.endpoint.id()
194 ));
195 }
196
197 self.generate_with_fault_detection(instance_id, request)
198 .await
199 }
200
201 pub fn select_next_worker(&self) -> Option<u64> {
205 let instance_ids = self.client.instance_ids_avail();
206 let count = instance_ids.len();
207 if count == 0 {
208 return None;
209 }
210
211 match self.router_mode {
212 RouterMode::RoundRobin => {
213 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
214 Some(instance_ids[counter % count])
215 }
216 RouterMode::Random => {
217 let counter = rand::rng().random::<u64>() as usize;
218 Some(instance_ids[counter % count])
219 }
220 _ => {
221 panic!(
222 "select_next_worker should not be called for {:?} routing mode",
223 self.router_mode
224 )
225 }
226 }
227 }
228
229 pub fn peek_next_worker(&self) -> Option<u64> {
232 let instance_ids = self.client.instance_ids_avail();
233 let count = instance_ids.len();
234 if count == 0 {
235 return None;
236 }
237
238 match self.router_mode {
239 RouterMode::RoundRobin => {
240 let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
242 Some(instance_ids[counter % count])
243 }
244 RouterMode::Random => {
245 let counter = rand::rng().random::<u64>() as usize;
248 Some(instance_ids[counter % count])
249 }
250 _ => {
251 panic!(
252 "peek_next_worker should not be called for {:?} routing mode",
253 self.router_mode
254 )
255 }
256 }
257 }
258
259 async fn generate_with_fault_detection(
270 &self,
271 instance_id: u64,
272 request: SingleIn<T>,
273 ) -> anyhow::Result<ManyOut<U>> {
274 if self.busy_threshold.is_some() {
276 let free_instances = self.client.instance_ids_free();
277 if free_instances.is_empty() {
278 let all_instances = self.client.instance_ids();
280 if !all_instances.is_empty() {
281 return Err(PipelineError::ServiceOverloaded(
282 "All workers are busy, please retry later".to_string(),
283 )
284 .into());
285 }
286 }
287 }
288
289 let address = {
291 use crate::component::TransportType;
292
293 let instances = self.client.instances();
295 let instance = instances
296 .iter()
297 .find(|i| i.instance_id == instance_id)
298 .ok_or_else(|| {
299 anyhow::anyhow!("Instance {} not found in available instances", instance_id)
300 })?;
301
302 match &instance.transport {
303 TransportType::Http(http_endpoint) => {
304 tracing::debug!(
305 instance_id = instance_id,
306 http_endpoint = %http_endpoint,
307 "Using HTTP transport for instance"
308 );
309 http_endpoint.clone()
310 }
311 TransportType::Tcp(tcp_endpoint) => {
312 tracing::debug!(
313 instance_id = instance_id,
314 tcp_endpoint = %tcp_endpoint,
315 "Using TCP transport for instance"
316 );
317 tcp_endpoint.clone()
318 }
319 TransportType::Nats(subject) => {
320 tracing::debug!(
321 instance_id = instance_id,
322 subject = %subject,
323 "Using NATS transport for instance"
324 );
325 subject.clone()
326 }
327 }
328 };
329
330 let request = request.map(|req| AddressedRequest::new(req, address));
331
332 let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
333 match stream {
334 Ok(stream) => {
335 let engine_ctx = stream.context();
336 let client = self.client.clone();
337 let stream = stream.map(move |res| {
338 if let Some(err) = res.err()
340 && format!("{:?}", err) == STREAM_ERR_MSG
341 {
342 tracing::debug!(
343 "Reporting instance {instance_id} down due to stream error: {err}"
344 );
345 client.report_instance_down(instance_id);
346 }
347 res
348 });
349 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
350 }
351 Err(err) => {
352 if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
353 && matches!(req_err.kind(), NatsNoResponders)
354 {
355 tracing::debug!(
356 "Reporting instance {instance_id} down due to request error: {req_err}"
357 );
358 self.client.report_instance_down(instance_id);
359 }
360 Err(err)
361 }
362 }
363 }
364}
365
366#[async_trait]
367impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
368where
369 T: Data + Serialize,
370 U: Data + for<'de> Deserialize<'de> + MaybeError,
371{
372 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
373 match self.router_mode {
375 RouterMode::Random => self.random(request).await,
376 RouterMode::RoundRobin => self.round_robin(request).await,
377 RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
378 RouterMode::KV => {
379 anyhow::bail!("KV routing should not call generate on PushRouter");
380 }
381 }
382 }
383}