dynamo_runtime/pipeline/network/egress/
push_router.rs1use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
5use crate::{
6 component::{Client, Endpoint, InstanceSource},
7 engine::{AsyncEngine, Data},
8 pipeline::{
9 AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
10 error::{PipelineError, PipelineErrorExt},
11 },
12 protocols::maybe_error::MaybeError,
13 traits::DistributedRuntimeProvider,
14};
15use async_nats::client::{
16 RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
17};
18use async_trait::async_trait;
19use rand::Rng;
20use serde::{Deserialize, Serialize};
21use std::{
22 future::Future,
23 marker::PhantomData,
24 sync::{
25 Arc,
26 atomic::{AtomicU64, Ordering},
27 },
28};
29use tokio_stream::StreamExt;
30
31#[async_trait]
34pub trait WorkerLoadMonitor: Send + Sync {
35 async fn start_monitoring(&self) -> anyhow::Result<()>;
38}
39
40#[derive(Clone)]
41pub struct PushRouter<T, U>
42where
43 T: Data + Serialize,
44 U: Data + for<'de> Deserialize<'de>,
45{
46 pub client: Client,
49
50 router_mode: RouterMode,
57
58 round_robin_counter: Arc<AtomicU64>,
60
61 addressed: Arc<AddressedPushRouter>,
64
65 busy_threshold: Option<f64>,
68
69 _phantom: PhantomData<(T, U)>,
73}
74
/// Strategy a [`PushRouter`] uses to pick the instance that receives a request.
///
/// All payloads are plain integers, so the mode is `Eq` and `Hash` and can be used as a map or
/// set key.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RouterMode {
    /// Cycle through the available instances in turn (the default).
    #[default]
    RoundRobin,
    /// Pick a random available instance for each request.
    Random,
    /// Always target the given instance id; fails if it is not currently available.
    Direct(u64),
    /// Marker for KV-aware routing. `PushRouter`'s `generate` rejects this mode; the actual KV
    /// selection presumably lives in a separate router (NOTE(review): confirm).
    KV,
}
84
85impl RouterMode {
86 pub fn is_kv_routing(&self) -> bool {
87 *self == RouterMode::KV
88 }
89}
90
91async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
92 let Some(nats_client) = endpoint.drt().nats_client() else {
93 anyhow::bail!("Missing NATS. Please ensure it is running and accessible.");
94 };
95 AddressedPushRouter::new(
96 nats_client.client().clone(),
97 endpoint.drt().tcp_server().await?,
98 )
99}
100
101impl<T, U> PushRouter<T, U>
102where
103 T: Data + Serialize,
104 U: Data + for<'de> Deserialize<'de> + MaybeError,
105{
106 pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
108 Self::from_client_with_threshold(client, router_mode, None, None).await
109 }
110
111 pub async fn from_client_with_threshold(
113 client: Client,
114 router_mode: RouterMode,
115 busy_threshold: Option<f64>,
116 worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
117 ) -> anyhow::Result<Self> {
118 let addressed = addressed_router(&client.endpoint).await?;
119
120 if let Some(monitor) = worker_monitor.as_ref()
122 && matches!(client.instance_source.as_ref(), InstanceSource::Dynamic(_))
123 {
124 monitor.start_monitoring().await?;
125 }
126
127 let router = PushRouter {
128 client: client.clone(),
129 addressed,
130 router_mode,
131 round_robin_counter: Arc::new(AtomicU64::new(0)),
132 busy_threshold,
133 _phantom: PhantomData,
134 };
135
136 Ok(router)
137 }
138
139 pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
141 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
142
143 let instance_id = {
144 let instance_ids = self.client.instance_ids_avail();
145 let count = instance_ids.len();
146 if count == 0 {
147 return Err(anyhow::anyhow!(
148 "no instances found for endpoint {:?}",
149 self.client.endpoint.etcd_root()
150 ));
151 }
152 instance_ids[counter % count]
153 };
154 tracing::trace!("round robin router selected {instance_id}");
155
156 self.generate_with_fault_detection(instance_id, request)
157 .await
158 }
159
160 pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
162 let instance_id = {
163 let instance_ids = self.client.instance_ids_avail();
164 let count = instance_ids.len();
165 if count == 0 {
166 return Err(anyhow::anyhow!(
167 "no instances found for endpoint {:?}",
168 self.client.endpoint.etcd_root()
169 ));
170 }
171 let counter = rand::rng().random::<u64>() as usize;
172 instance_ids[counter % count]
173 };
174 tracing::trace!("random router selected {instance_id}");
175
176 self.generate_with_fault_detection(instance_id, request)
177 .await
178 }
179
180 pub async fn direct(
182 &self,
183 request: SingleIn<T>,
184 instance_id: u64,
185 ) -> anyhow::Result<ManyOut<U>> {
186 let found = self.client.instance_ids_avail().contains(&instance_id);
187
188 if !found {
189 return Err(anyhow::anyhow!(
190 "instance_id={instance_id} not found for endpoint {:?}",
191 self.client.endpoint.etcd_root()
192 ));
193 }
194
195 self.generate_with_fault_detection(instance_id, request)
196 .await
197 }
198
199 pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
200 let subject = self.client.endpoint.subject();
201 tracing::debug!("static got subject: {subject}");
202 let request = request.map(|req| AddressedRequest::new(req, subject));
203 tracing::debug!("router generate");
204 self.addressed.generate(request).await
205 }
206
207 async fn generate_with_fault_detection(
208 &self,
209 instance_id: u64,
210 request: SingleIn<T>,
211 ) -> anyhow::Result<ManyOut<U>> {
212 if self.busy_threshold.is_some() {
214 let free_instances = self.client.instance_ids_free();
215 if free_instances.is_empty() {
216 let all_instances = self.client.instance_ids();
218 if !all_instances.is_empty() {
219 return Err(PipelineError::ServiceOverloaded(
220 "All workers are busy, please retry later".to_string(),
221 )
222 .into());
223 }
224 }
225 }
226
227 let subject = self.client.endpoint.subject_to(instance_id);
228 let request = request.map(|req| AddressedRequest::new(req, subject));
229
230 let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
231 match stream {
232 Ok(stream) => {
233 let engine_ctx = stream.context();
234 let client = self.client.clone();
235 let stream = stream.map(move |res| {
236 if let Some(err) = res.err()
238 && format!("{:?}", err) == STREAM_ERR_MSG
239 {
240 tracing::debug!(
241 "Reporting instance {instance_id} down due to stream error: {err}"
242 );
243 client.report_instance_down(instance_id);
244 }
245 res
246 });
247 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
248 }
249 Err(err) => {
250 if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
251 && matches!(req_err.kind(), NatsNoResponders)
252 {
253 tracing::debug!(
254 "Reporting instance {instance_id} down due to request error: {req_err}"
255 );
256 self.client.report_instance_down(instance_id);
257 }
258 Err(err)
259 }
260 }
261 }
262}
263
264#[async_trait]
265impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
266where
267 T: Data + Serialize,
268 U: Data + for<'de> Deserialize<'de> + MaybeError,
269{
270 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
271 match self.client.instance_source.as_ref() {
272 InstanceSource::Static => self.r#static(request).await,
273 InstanceSource::Dynamic(_) => match self.router_mode {
274 RouterMode::Random => self.random(request).await,
275 RouterMode::RoundRobin => self.round_robin(request).await,
276 RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
277 RouterMode::KV => {
278 anyhow::bail!("KV routing should not call generate on PushRouter");
279 }
280 },
281 }
282 }
283}