dynamo_runtime/pipeline/network/egress/
push_router.rs1use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
5use crate::utils::worker_monitor::WorkerMonitor;
6use crate::{
7 component::{Client, Endpoint, InstanceSource},
8 engine::{AsyncEngine, Data},
9 pipeline::{
10 AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
11 error::{PipelineError, PipelineErrorExt},
12 },
13 protocols::maybe_error::MaybeError,
14 traits::DistributedRuntimeProvider,
15};
16use async_nats::client::{
17 RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
18};
19use async_trait::async_trait;
20use rand::Rng;
21use serde::{Deserialize, Serialize};
22use std::{
23 future::Future,
24 marker::PhantomData,
25 sync::{
26 Arc,
27 atomic::{AtomicU64, Ordering},
28 },
29};
30use tokio_stream::StreamExt;
31
32#[derive(Clone)]
33pub struct PushRouter<T, U>
34where
35 T: Data + Serialize,
36 U: Data + for<'de> Deserialize<'de>,
37{
38 pub client: Client,
41
42 router_mode: RouterMode,
49
50 round_robin_counter: Arc<AtomicU64>,
52
53 addressed: Arc<AddressedPushRouter>,
56
57 worker_monitor: Option<Arc<WorkerMonitor>>,
59
60 busy_threshold: Option<f64>,
63
64 _phantom: PhantomData<(T, U)>,
68}
69
/// Strategy a [`PushRouter`] uses to choose the target instance when the
/// client's instance source is dynamic.
///
/// `Eq` is derived alongside `PartialEq`: every payload is an integer, so
/// equality is total.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterMode {
    /// Cycle through the available instances in order (the default).
    #[default]
    RoundRobin,
    /// Pick an available instance at random for every request.
    Random,
    /// Always target the instance with the given id.
    Direct(i64),
    /// Marker for KV routing. `PushRouter::generate` rejects this mode with an
    /// error, so the routing itself must be performed by another component.
    KV,
}
79
80impl RouterMode {
81 pub fn is_kv_routing(&self) -> bool {
82 *self == RouterMode::KV
83 }
84}
85
86async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
87 AddressedPushRouter::new(
88 endpoint.drt().nats_client.client().clone(),
89 endpoint.drt().tcp_server().await?,
90 )
91}
92
93impl<T, U> PushRouter<T, U>
94where
95 T: Data + Serialize,
96 U: Data + for<'de> Deserialize<'de> + MaybeError,
97{
98 pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
100 Self::from_client_with_threshold(client, router_mode, None).await
101 }
102
103 pub async fn from_client_with_threshold(
105 client: Client,
106 router_mode: RouterMode,
107 busy_threshold: Option<f64>,
108 ) -> anyhow::Result<Self> {
109 let addressed = addressed_router(&client.endpoint).await?;
110
111 let worker_monitor = match (busy_threshold, client.instance_source.as_ref()) {
113 (Some(threshold), InstanceSource::Dynamic(_)) => {
114 let monitor = Arc::new(WorkerMonitor::new_with_threshold(
115 Arc::new(client.clone()),
116 threshold,
117 ));
118 monitor.start_monitoring().await?;
119 Some(monitor)
120 }
121 _ => None,
122 };
123
124 let router = PushRouter {
125 client: client.clone(),
126 addressed,
127 router_mode,
128 round_robin_counter: Arc::new(AtomicU64::new(0)),
129 worker_monitor,
130 busy_threshold,
131 _phantom: PhantomData,
132 };
133
134 Ok(router)
135 }
136
137 pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
139 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
140
141 let instance_id = {
142 let instance_ids = self.client.instance_ids_avail();
143 let count = instance_ids.len();
144 if count == 0 {
145 return Err(anyhow::anyhow!(
146 "no instances found for endpoint {:?}",
147 self.client.endpoint.etcd_root()
148 ));
149 }
150 instance_ids[counter % count]
151 };
152 tracing::trace!("round robin router selected {instance_id}");
153
154 self.generate_with_fault_detection(instance_id, request)
155 .await
156 }
157
158 pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
160 let instance_id = {
161 let instance_ids = self.client.instance_ids_avail();
162 let count = instance_ids.len();
163 if count == 0 {
164 return Err(anyhow::anyhow!(
165 "no instances found for endpoint {:?}",
166 self.client.endpoint.etcd_root()
167 ));
168 }
169 let counter = rand::rng().random::<u64>() as usize;
170 instance_ids[counter % count]
171 };
172 tracing::trace!("random router selected {instance_id}");
173
174 self.generate_with_fault_detection(instance_id, request)
175 .await
176 }
177
178 pub async fn direct(
180 &self,
181 request: SingleIn<T>,
182 instance_id: i64,
183 ) -> anyhow::Result<ManyOut<U>> {
184 let found = self.client.instance_ids_avail().contains(&instance_id);
185
186 if !found {
187 return Err(anyhow::anyhow!(
188 "instance_id={instance_id} not found for endpoint {:?}",
189 self.client.endpoint.etcd_root()
190 ));
191 }
192
193 self.generate_with_fault_detection(instance_id, request)
194 .await
195 }
196
197 pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
198 let subject = self.client.endpoint.subject();
199 tracing::debug!("static got subject: {subject}");
200 let request = request.map(|req| AddressedRequest::new(req, subject));
201 tracing::debug!("router generate");
202 self.addressed.generate(request).await
203 }
204
205 async fn generate_with_fault_detection(
206 &self,
207 instance_id: i64,
208 request: SingleIn<T>,
209 ) -> anyhow::Result<ManyOut<U>> {
210 if self.busy_threshold.is_some() {
212 let free_instances = self.client.instance_ids_free();
213 if free_instances.is_empty() {
214 let all_instances = self.client.instance_ids();
216 if !all_instances.is_empty() {
217 return Err(PipelineError::ServiceOverloaded(
218 "All workers are busy, please retry later".to_string(),
219 )
220 .into());
221 }
222 }
223 }
224
225 let subject = self.client.endpoint.subject_to(instance_id);
226 let request = request.map(|req| AddressedRequest::new(req, subject));
227
228 let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
229 match stream {
230 Ok(stream) => {
231 let engine_ctx = stream.context();
232 let client = self.client.clone();
233 let stream = stream.map(move |res| {
234 if let Some(err) = res.err()
236 && format!("{:?}", err) == STREAM_ERR_MSG
237 {
238 tracing::debug!(
239 "Reporting instance {instance_id} down due to stream error: {err}"
240 );
241 client.report_instance_down(instance_id);
242 }
243 res
244 });
245 Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
246 }
247 Err(err) => {
248 if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
249 && matches!(req_err.kind(), NatsNoResponders)
250 {
251 tracing::debug!(
252 "Reporting instance {instance_id} down due to request error: {req_err}"
253 );
254 self.client.report_instance_down(instance_id);
255 }
256 Err(err)
257 }
258 }
259 }
260}
261
262#[async_trait]
263impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
264where
265 T: Data + Serialize,
266 U: Data + for<'de> Deserialize<'de> + MaybeError,
267{
268 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
269 match self.client.instance_source.as_ref() {
270 InstanceSource::Static => self.r#static(request).await,
271 InstanceSource::Dynamic(_) => match self.router_mode {
272 RouterMode::Random => self.random(request).await,
273 RouterMode::RoundRobin => self.round_robin(request).await,
274 RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
275 RouterMode::KV => {
276 anyhow::bail!("KV routing should not call generate on PushRouter");
277 }
278 },
279 }
280 }
281}