dynamo_runtime/pipeline/network/egress/
push_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
5use crate::{
6    component::{Client, Endpoint, InstanceSource},
7    engine::{AsyncEngine, Data},
8    pipeline::{
9        AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
10        error::{PipelineError, PipelineErrorExt},
11    },
12    protocols::maybe_error::MaybeError,
13    traits::DistributedRuntimeProvider,
14};
15use async_nats::client::{
16    RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
17};
18use async_trait::async_trait;
19use rand::Rng;
20use serde::{Deserialize, Serialize};
21use std::{
22    future::Future,
23    marker::PhantomData,
24    sync::{
25        Arc,
26        atomic::{AtomicU64, Ordering},
27    },
28};
29use tokio_stream::StreamExt;
30
/// Trait for monitoring worker load and determining busy state.
/// Implementations can define custom load metrics and busy thresholds.
///
/// Implementors must be `Send + Sync`; `PushRouter::from_client_with_threshold` accepts a
/// monitor as `Arc<dyn WorkerLoadMonitor>`.
#[async_trait]
pub trait WorkerLoadMonitor: Send + Sync {
    /// Start background monitoring of worker load.
    /// This should spawn background tasks that update the client's free instances.
    ///
    /// # Errors
    ///
    /// Returns an error if monitoring could not be started.
    /// `PushRouter::from_client_with_threshold` propagates that error to its caller, so a
    /// failure here prevents the router from being constructed.
    async fn start_monitoring(&self) -> anyhow::Result<()>;
}
39
40#[derive(Clone)]
41pub struct PushRouter<T, U>
42where
43    T: Data + Serialize,
44    U: Data + for<'de> Deserialize<'de>,
45{
46    // TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
47    /// The Client is how we gather remote endpoint information from etcd.
48    pub client: Client,
49
50    /// How we choose which instance to send traffic to.
51    ///
52    /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
53    /// not using it as an AsyncEngine.
54    /// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
55    /// dynamo-llm's KV Routing does this.
56    router_mode: RouterMode,
57
58    /// Number of round robin requests handled. Used to decide which server is next.
59    round_robin_counter: Arc<AtomicU64>,
60
61    /// The next step in the chain. PushRouter (this object) picks an instances,
62    /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
63    addressed: Arc<AddressedPushRouter>,
64
65    /// Threshold for determining when a worker is busy (0.0 to 1.0)
66    /// If None, busy detection is disabled
67    busy_threshold: Option<f64>,
68
69    /// An internal Rust type. This says that PushRouter is generic over the T and U types,
70    /// which are the input and output types of it's `generate` function. It allows the
71    /// compiler to specialize us at compile time.
72    _phantom: PhantomData<(T, U)>,
73}
74
/// Strategy a `PushRouter` uses to pick the instance that receives a request.
///
/// Equality is total (every payload is an integer), so the type derives `Eq` in addition to
/// `PartialEq`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterMode {
    /// Cycle through the available instances in order. This is the default.
    #[default]
    RoundRobin,
    /// Pick an available instance at random for every request.
    Random,
    /// Always send to the instance with the given instance id.
    Direct(u64),
    // Marker value, KV routing itself is in dynamo-llm
    /// KV-aware routing; the routing decision is made outside of this crate.
    KV,
}
84
85impl RouterMode {
86    pub fn is_kv_routing(&self) -> bool {
87        *self == RouterMode::KV
88    }
89}
90
91async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
92    // Get network manager and create client (no mode checks!)
93    let manager = endpoint.drt().network_manager().await?;
94    let req_client = manager.create_client()?;
95    let resp_transport = endpoint.drt().tcp_server().await?;
96
97    tracing::debug!(
98        transport = req_client.transport_name(),
99        "Creating AddressedPushRouter with request plane client"
100    );
101
102    AddressedPushRouter::new(req_client, resp_transport)
103}
104
105impl<T, U> PushRouter<T, U>
106where
107    T: Data + Serialize,
108    U: Data + for<'de> Deserialize<'de> + MaybeError,
109{
110    /// Create a new PushRouter without busy threshold (no busy detection)
111    pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
112        Self::from_client_with_threshold(client, router_mode, None, None).await
113    }
114
115    /// Create a new PushRouter with optional busy threshold and worker load monitor
116    pub async fn from_client_with_threshold(
117        client: Client,
118        router_mode: RouterMode,
119        busy_threshold: Option<f64>,
120        worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
121    ) -> anyhow::Result<Self> {
122        let addressed = addressed_router(&client.endpoint).await?;
123
124        // Start worker monitor if provided and in dynamic mode
125        if let Some(monitor) = worker_monitor.as_ref()
126            && matches!(client.instance_source.as_ref(), InstanceSource::Dynamic(_))
127        {
128            monitor.start_monitoring().await?;
129        }
130
131        let router = PushRouter {
132            client: client.clone(),
133            addressed,
134            router_mode,
135            round_robin_counter: Arc::new(AtomicU64::new(0)),
136            busy_threshold,
137            _phantom: PhantomData,
138        };
139
140        Ok(router)
141    }
142
143    /// Issue a request to the next available instance in a round-robin fashion
144    pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
145        let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
146
147        let instance_id = {
148            let instance_ids = self.client.instance_ids_avail();
149            let count = instance_ids.len();
150            if count == 0 {
151                return Err(anyhow::anyhow!(
152                    "no instances found for endpoint {:?}",
153                    self.client.endpoint.etcd_root()
154                ));
155            }
156            instance_ids[counter % count]
157        };
158        tracing::trace!("round robin router selected {instance_id}");
159
160        self.generate_with_fault_detection(instance_id, request)
161            .await
162    }
163
164    /// Issue a request to a random endpoint
165    pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
166        let instance_id = {
167            let instance_ids = self.client.instance_ids_avail();
168            let count = instance_ids.len();
169            if count == 0 {
170                return Err(anyhow::anyhow!(
171                    "no instances found for endpoint {:?}",
172                    self.client.endpoint.etcd_root()
173                ));
174            }
175            let counter = rand::rng().random::<u64>() as usize;
176            instance_ids[counter % count]
177        };
178        tracing::trace!("random router selected {instance_id}");
179
180        self.generate_with_fault_detection(instance_id, request)
181            .await
182    }
183
184    /// Issue a request to a specific endpoint
185    pub async fn direct(
186        &self,
187        request: SingleIn<T>,
188        instance_id: u64,
189    ) -> anyhow::Result<ManyOut<U>> {
190        let found = self.client.instance_ids_avail().contains(&instance_id);
191
192        if !found {
193            return Err(anyhow::anyhow!(
194                "instance_id={instance_id} not found for endpoint {:?}",
195                self.client.endpoint.etcd_root()
196            ));
197        }
198
199        self.generate_with_fault_detection(instance_id, request)
200            .await
201    }
202
203    pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
204        let subject = self.client.endpoint.subject();
205        tracing::debug!("static got subject: {subject}");
206        let request = request.map(|req| AddressedRequest::new(req, subject));
207        tracing::debug!("router generate");
208        self.addressed.generate(request).await
209    }
210
211    async fn generate_with_fault_detection(
212        &self,
213        instance_id: u64,
214        request: SingleIn<T>,
215    ) -> anyhow::Result<ManyOut<U>> {
216        // Check if all workers are busy (only if busy threshold is set)
217        if self.busy_threshold.is_some() {
218            let free_instances = self.client.instance_ids_free();
219            if free_instances.is_empty() {
220                // Check if we actually have any instances at all
221                let all_instances = self.client.instance_ids();
222                if !all_instances.is_empty() {
223                    return Err(PipelineError::ServiceOverloaded(
224                        "All workers are busy, please retry later".to_string(),
225                    )
226                    .into());
227                }
228            }
229        }
230
231        // Get the address based on discovered transport type
232        let address = {
233            use crate::component::TransportType;
234
235            // Get the instance and use its actual transport type
236            let instances = self.client.instances();
237            let instance = instances
238                .iter()
239                .find(|i| i.instance_id == instance_id)
240                .ok_or_else(|| {
241                    anyhow::anyhow!("Instance {} not found in available instances", instance_id)
242                })?;
243
244            match &instance.transport {
245                TransportType::Http(http_endpoint) => {
246                    tracing::debug!(
247                        instance_id = instance_id,
248                        http_endpoint = %http_endpoint,
249                        "Using HTTP transport for instance"
250                    );
251                    http_endpoint.clone()
252                }
253                TransportType::Tcp(tcp_endpoint) => {
254                    tracing::debug!(
255                        instance_id = instance_id,
256                        tcp_endpoint = %tcp_endpoint,
257                        "Using TCP transport for instance"
258                    );
259                    tcp_endpoint.clone()
260                }
261                TransportType::Nats(subject) => {
262                    tracing::debug!(
263                        instance_id = instance_id,
264                        subject = %subject,
265                        "Using NATS transport for instance"
266                    );
267                    subject.clone()
268                }
269            }
270        };
271
272        let request = request.map(|req| AddressedRequest::new(req, address));
273
274        let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
275        match stream {
276            Ok(stream) => {
277                let engine_ctx = stream.context();
278                let client = self.client.clone();
279                let stream = stream.map(move |res| {
280                    // TODO: Standardize error type to avoid using string matching DIS-364
281                    if let Some(err) = res.err()
282                        && format!("{:?}", err) == STREAM_ERR_MSG
283                    {
284                        tracing::debug!(
285                            "Reporting instance {instance_id} down due to stream error: {err}"
286                        );
287                        client.report_instance_down(instance_id);
288                    }
289                    res
290                });
291                Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
292            }
293            Err(err) => {
294                if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
295                    && matches!(req_err.kind(), NatsNoResponders)
296                {
297                    tracing::debug!(
298                        "Reporting instance {instance_id} down due to request error: {req_err}"
299                    );
300                    self.client.report_instance_down(instance_id);
301                }
302                Err(err)
303            }
304        }
305    }
306}
307
308#[async_trait]
309impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
310where
311    T: Data + Serialize,
312    U: Data + for<'de> Deserialize<'de> + MaybeError,
313{
314    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
315        match self.client.instance_source.as_ref() {
316            InstanceSource::Static => self.r#static(request).await,
317            InstanceSource::Dynamic(_) => match self.router_mode {
318                RouterMode::Random => self.random(request).await,
319                RouterMode::RoundRobin => self.round_robin(request).await,
320                RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
321                RouterMode::KV => {
322                    anyhow::bail!("KV routing should not call generate on PushRouter");
323                }
324            },
325        }
326    }
327}