dynamo_runtime/pipeline/network/egress/
push_router.rs1use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
5use crate::{
6 component::{Client, Endpoint, InstanceSource},
7 engine::{AsyncEngine, Data},
8 pipeline::{
9 AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
10 error::{PipelineError, PipelineErrorExt},
11 },
12 protocols::maybe_error::MaybeError,
13 traits::DistributedRuntimeProvider,
14};
15use async_nats::client::{
16 RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
17};
18use async_trait::async_trait;
19use rand::Rng;
20use serde::{Deserialize, Serialize};
21use std::{
22 future::Future,
23 marker::PhantomData,
24 sync::{
25 Arc,
26 atomic::{AtomicU64, Ordering},
27 },
28};
29use tokio_stream::StreamExt;
30
31#[async_trait]
34pub trait WorkerLoadMonitor: Send + Sync {
35 async fn start_monitoring(&self) -> anyhow::Result<()>;
38}
39
40#[derive(Clone)]
41pub struct PushRouter<T, U>
42where
43 T: Data + Serialize,
44 U: Data + for<'de> Deserialize<'de>,
45{
46 pub client: Client,
49
50 router_mode: RouterMode,
57
58 round_robin_counter: Arc<AtomicU64>,
60
61 addressed: Arc<AddressedPushRouter>,
64
65 busy_threshold: Option<f64>,
68
69 _phantom: PhantomData<(T, U)>,
73}
74
/// Strategy a [`PushRouter`] uses to pick the instance that receives a request.
///
/// `Eq`/`Hash` are derived alongside `PartialEq` because every variant's payload
/// (`u64`) supports total equality, which also lets modes be used as map/set keys.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RouterMode {
    /// Rotate through the available instances (the default).
    #[default]
    RoundRobin,
    /// Pick an available instance at random.
    Random,
    /// Always target the instance with this id.
    Direct(u64),
    /// KV-aware routing; handled outside `PushRouter::generate`.
    KV,
}

impl RouterMode {
    /// Returns `true` only for [`RouterMode::KV`].
    #[must_use]
    pub const fn is_kv_routing(&self) -> bool {
        matches!(self, RouterMode::KV)
    }
}
90
91async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
92 let manager = endpoint.drt().network_manager().await?;
94 let req_client = manager.create_client()?;
95 let resp_transport = endpoint.drt().tcp_server().await?;
96
97 tracing::debug!(
98 transport = req_client.transport_name(),
99 "Creating AddressedPushRouter with request plane client"
100 );
101
102 AddressedPushRouter::new(req_client, resp_transport)
103}
104
105impl<T, U> PushRouter<T, U>
106where
107 T: Data + Serialize,
108 U: Data + for<'de> Deserialize<'de> + MaybeError,
109{
110 pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
112 Self::from_client_with_threshold(client, router_mode, None, None).await
113 }
114
115 pub async fn from_client_with_threshold(
117 client: Client,
118 router_mode: RouterMode,
119 busy_threshold: Option<f64>,
120 worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
121 ) -> anyhow::Result<Self> {
122 let addressed = addressed_router(&client.endpoint).await?;
123
124 if let Some(monitor) = worker_monitor.as_ref()
126 && matches!(client.instance_source.as_ref(), InstanceSource::Dynamic(_))
127 {
128 monitor.start_monitoring().await?;
129 }
130
131 let router = PushRouter {
132 client: client.clone(),
133 addressed,
134 router_mode,
135 round_robin_counter: Arc::new(AtomicU64::new(0)),
136 busy_threshold,
137 _phantom: PhantomData,
138 };
139
140 Ok(router)
141 }
142
143 pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
145 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
146
147 let instance_id = {
148 let instance_ids = self.client.instance_ids_avail();
149 let count = instance_ids.len();
150 if count == 0 {
151 return Err(anyhow::anyhow!(
152 "no instances found for endpoint {:?}",
153 self.client.endpoint.etcd_root()
154 ));
155 }
156 instance_ids[counter % count]
157 };
158 tracing::trace!("round robin router selected {instance_id}");
159
160 self.generate_with_fault_detection(instance_id, request)
161 .await
162 }
163
164 pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
166 let instance_id = {
167 let instance_ids = self.client.instance_ids_avail();
168 let count = instance_ids.len();
169 if count == 0 {
170 return Err(anyhow::anyhow!(
171 "no instances found for endpoint {:?}",
172 self.client.endpoint.etcd_root()
173 ));
174 }
175 let counter = rand::rng().random::<u64>() as usize;
176 instance_ids[counter % count]
177 };
178 tracing::trace!("random router selected {instance_id}");
179
180 self.generate_with_fault_detection(instance_id, request)
181 .await
182 }
183
184 pub async fn direct(
186 &self,
187 request: SingleIn<T>,
188 instance_id: u64,
189 ) -> anyhow::Result<ManyOut<U>> {
190 let found = self.client.instance_ids_avail().contains(&instance_id);
191
192 if !found {
193 return Err(anyhow::anyhow!(
194 "instance_id={instance_id} not found for endpoint {:?}",
195 self.client.endpoint.etcd_root()
196 ));
197 }
198
199 self.generate_with_fault_detection(instance_id, request)
200 .await
201 }
202
203 pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
204 let subject = self.client.endpoint.subject();
205 tracing::debug!("static got subject: {subject}");
206 let request = request.map(|req| AddressedRequest::new(req, subject));
207 tracing::debug!("router generate");
208 self.addressed.generate(request).await
209 }
210
211 async fn generate_with_fault_detection(
212 &self,
213 instance_id: u64,
214 request: SingleIn<T>,
215 ) -> anyhow::Result<ManyOut<U>> {
216 if self.busy_threshold.is_some() {
218 let free_instances = self.client.instance_ids_free();
219 if free_instances.is_empty() {
220 let all_instances = self.client.instance_ids();
222 if !all_instances.is_empty() {
223 return Err(PipelineError::ServiceOverloaded(
224 "All workers are busy, please retry later".to_string(),
225 )
226 .into());
227 }
228 }
229 }
230
231 let address = {
233 use crate::component::TransportType;
234
235 let instances = self.client.instances();
237 let instance = instances
238 .iter()
239 .find(|i| i.instance_id == instance_id)
240 .ok_or_else(|| {
241 anyhow::anyhow!("Instance {} not found in available instances", instance_id)
242 })?;
243
244 match &instance.transport {
245 TransportType::Http(http_endpoint) => {
246 tracing::debug!(
247 instance_id = instance_id,
248 http_endpoint = %http_endpoint,
249 "Using HTTP transport for instance"
250 );
251 http_endpoint.clone()
252 }
253 TransportType::Tcp(tcp_endpoint) => {
254 tracing::debug!(
255 instance_id = instance_id,
256 tcp_endpoint = %tcp_endpoint,
257 "Using TCP transport for instance"
258 );
259 tcp_endpoint.clone()
260 }
261 TransportType::Nats(subject) => {
262 tracing::debug!(
263 instance_id = instance_id,
264 subject = %subject,
265 "Using NATS transport for instance"
266 );
267 subject.clone()
268 }
269 }
270 };
271
272 let request = request.map(|req| AddressedRequest::new(req, address));
273
274 let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
275 match stream {
276 Ok(stream) => {
277 let engine_ctx = stream.context();
278 let client = self.client.clone();
279 let stream = stream.map(move |res| {
280 if let Some(err) = res.err()
282 && format!("{:?}", err) == STREAM_ERR_MSG
283 {
284 tracing::debug!(
285 "Reporting instance {instance_id} down due to stream error: {err}"
286 );
287 client.report_instance_down(instance_id);
288 }
289 res
290 });
291 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
292 }
293 Err(err) => {
294 if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
295 && matches!(req_err.kind(), NatsNoResponders)
296 {
297 tracing::debug!(
298 "Reporting instance {instance_id} down due to request error: {req_err}"
299 );
300 self.client.report_instance_down(instance_id);
301 }
302 Err(err)
303 }
304 }
305 }
306}
307
308#[async_trait]
309impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
310where
311 T: Data + Serialize,
312 U: Data + for<'de> Deserialize<'de> + MaybeError,
313{
314 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
315 match self.client.instance_source.as_ref() {
316 InstanceSource::Static => self.r#static(request).await,
317 InstanceSource::Dynamic(_) => match self.router_mode {
318 RouterMode::Random => self.random(request).await,
319 RouterMode::RoundRobin => self.round_robin(request).await,
320 RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
321 RouterMode::KV => {
322 anyhow::bail!("KV routing should not call generate on PushRouter");
323 }
324 },
325 }
326 }
327}