Skip to main content

dynamo_runtime/pipeline/network/egress/
push_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
5use crate::{
6    component::{Client, Endpoint},
7    engine::{AsyncEngine, Data},
8    pipeline::{
9        AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
10        error::{PipelineError, PipelineErrorExt},
11    },
12    protocols::maybe_error::MaybeError,
13    traits::DistributedRuntimeProvider,
14};
15use async_nats::client::{
16    RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
17};
18use async_trait::async_trait;
19use rand::Rng;
20use serde::{Deserialize, Serialize};
21use std::{
22    future::Future,
23    marker::PhantomData,
24    sync::{
25        Arc,
26        atomic::{AtomicU64, Ordering},
27    },
28};
29use tokio_stream::StreamExt;
30
31/// Trait for monitoring worker load and determining busy state.
32/// Implementations can define custom load metrics and busy thresholds.
33#[async_trait]
34pub trait WorkerLoadMonitor: Send + Sync {
35    /// Start background monitoring of worker load.
36    /// This should spawn background tasks that update the client's free instances.
37    async fn start_monitoring(&self) -> anyhow::Result<()>;
38}
39
40#[derive(Clone)]
41pub struct PushRouter<T, U>
42where
43    T: Data + Serialize,
44    U: Data + for<'de> Deserialize<'de>,
45{
46    // TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
47    /// The Client is how we gather remote endpoint information from etcd.
48    pub client: Client,
49
50    /// How we choose which instance to send traffic to.
51    ///
52    /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
53    /// not using it as an AsyncEngine.
54    /// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
55    /// dynamo-llm's KV Routing does this.
56    router_mode: RouterMode,
57
58    /// Number of round robin requests handled. Used to decide which server is next.
59    round_robin_counter: Arc<AtomicU64>,
60
61    /// The next step in the chain. PushRouter (this object) picks an instances,
62    /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
63    addressed: Arc<AddressedPushRouter>,
64
65    /// Threshold for determining when a worker is busy (0.0 to 1.0)
66    /// If None, busy detection is disabled
67    busy_threshold: Option<f64>,
68
69    /// An internal Rust type. This says that PushRouter is generic over the T and U types,
70    /// which are the input and output types of it's `generate` function. It allows the
71    /// compiler to specialize us at compile time.
72    _phantom: PhantomData<(T, U)>,
73}
74
/// Strategy a `PushRouter` uses to pick the instance that receives a request.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterMode {
    /// Cycle through the available instances using a shared counter (the default).
    #[default]
    RoundRobin,
    /// Pick an available instance at random for every request.
    Random,
    /// Always target the instance with this id.
    Direct(u64),
    /// Marker value: KV routing itself is in dynamo-llm, which calls the router's
    /// `random`/`round_robin`/`direct` methods directly instead of `generate`.
    KV,
}
84
85impl RouterMode {
86    pub fn is_kv_routing(&self) -> bool {
87        *self == RouterMode::KV
88    }
89}
90
91async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
92    // Get network manager and create client (no mode checks!)
93    let manager = endpoint.drt().network_manager();
94    let req_client = manager.create_client()?;
95    let resp_transport = endpoint.drt().tcp_server().await?;
96
97    tracing::debug!(
98        transport = req_client.transport_name(),
99        "Creating AddressedPushRouter with request plane client"
100    );
101
102    AddressedPushRouter::new(req_client, resp_transport)
103}
104
105impl<T, U> PushRouter<T, U>
106where
107    T: Data + Serialize,
108    U: Data + for<'de> Deserialize<'de> + MaybeError,
109{
110    /// Create a new PushRouter without busy threshold (no busy detection)
111    pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
112        Self::from_client_with_threshold(client, router_mode, None, None).await
113    }
114
115    /// Create a new PushRouter with optional busy threshold and worker load monitor
116    pub async fn from_client_with_threshold(
117        client: Client,
118        router_mode: RouterMode,
119        busy_threshold: Option<f64>,
120        worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
121    ) -> anyhow::Result<Self> {
122        let addressed = addressed_router(&client.endpoint).await?;
123
124        // Start worker monitor if provided and in dynamic mode
125        if let Some(monitor) = worker_monitor.as_ref() {
126            monitor.start_monitoring().await?;
127        }
128
129        let router = PushRouter {
130            client: client.clone(),
131            addressed,
132            router_mode,
133            round_robin_counter: Arc::new(AtomicU64::new(0)),
134            busy_threshold,
135            _phantom: PhantomData,
136        };
137
138        Ok(router)
139    }
140
141    /// Issue a request to the next available instance in a round-robin fashion
142    pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
143        let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
144
145        let instance_id = {
146            let instance_ids = self.client.instance_ids_avail();
147            let count = instance_ids.len();
148            if count == 0 {
149                return Err(anyhow::anyhow!(
150                    "no instances found for endpoint {}",
151                    self.client.endpoint.id()
152                ));
153            }
154            instance_ids[counter % count]
155        };
156        tracing::trace!("round robin router selected {instance_id}");
157
158        self.generate_with_fault_detection(instance_id, request)
159            .await
160    }
161
162    /// Issue a request to a random endpoint
163    pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
164        let instance_id = {
165            let instance_ids = self.client.instance_ids_avail();
166            let count = instance_ids.len();
167            if count == 0 {
168                return Err(anyhow::anyhow!(
169                    "no instances found for endpoint {}",
170                    self.client.endpoint.id()
171                ));
172            }
173            let counter = rand::rng().random::<u64>() as usize;
174            instance_ids[counter % count]
175        };
176        tracing::trace!("random router selected {instance_id}");
177
178        self.generate_with_fault_detection(instance_id, request)
179            .await
180    }
181
182    /// Issue a request to a specific endpoint
183    pub async fn direct(
184        &self,
185        request: SingleIn<T>,
186        instance_id: u64,
187    ) -> anyhow::Result<ManyOut<U>> {
188        let found = self.client.instance_ids_avail().contains(&instance_id);
189
190        if !found {
191            return Err(anyhow::anyhow!(
192                "instance_id={instance_id} not found for endpoint {}",
193                self.client.endpoint.id()
194            ));
195        }
196
197        self.generate_with_fault_detection(instance_id, request)
198            .await
199    }
200
201    /// Select the next worker according to the routing mode.
202    /// Increments round-robin counter if applicable.
203    /// Panics if called on Direct or KV mode - those have their own selection mechanisms.
204    pub fn select_next_worker(&self) -> Option<u64> {
205        let instance_ids = self.client.instance_ids_avail();
206        let count = instance_ids.len();
207        if count == 0 {
208            return None;
209        }
210
211        match self.router_mode {
212            RouterMode::RoundRobin => {
213                let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
214                Some(instance_ids[counter % count])
215            }
216            RouterMode::Random => {
217                let counter = rand::rng().random::<u64>() as usize;
218                Some(instance_ids[counter % count])
219            }
220            _ => {
221                panic!(
222                    "select_next_worker should not be called for {:?} routing mode",
223                    self.router_mode
224                )
225            }
226        }
227    }
228
229    /// Peek the next worker according to the routing mode without incrementing the counter.
230    /// Useful for checking if a worker is suitable before committing to it.
231    pub fn peek_next_worker(&self) -> Option<u64> {
232        let instance_ids = self.client.instance_ids_avail();
233        let count = instance_ids.len();
234        if count == 0 {
235            return None;
236        }
237
238        match self.router_mode {
239            RouterMode::RoundRobin => {
240                // Just peek at the current counter value without incrementing
241                let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
242                Some(instance_ids[counter % count])
243            }
244            RouterMode::Random => {
245                // For random, peeking implies a fresh random selection since it's stateless.
246                // Note: The caller must realize that select_next_worker() will pick a DIFFERENT random worker.
247                let counter = rand::rng().random::<u64>() as usize;
248                Some(instance_ids[counter % count])
249            }
250            _ => {
251                panic!(
252                    "peek_next_worker should not be called for {:?} routing mode",
253                    self.router_mode
254                )
255            }
256        }
257    }
258
259    /*
260    pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
261        let subject = self.client.endpoint.subject();
262        tracing::debug!("static got subject: {subject}");
263        let request = request.map(|req| AddressedRequest::new(req, subject));
264        tracing::debug!("router generate");
265        self.addressed.generate(request).await
266    }
267    */
268
269    async fn generate_with_fault_detection(
270        &self,
271        instance_id: u64,
272        request: SingleIn<T>,
273    ) -> anyhow::Result<ManyOut<U>> {
274        // Check if all workers are busy (only if busy threshold is set)
275        if self.busy_threshold.is_some() {
276            let free_instances = self.client.instance_ids_free();
277            if free_instances.is_empty() {
278                // Check if we actually have any instances at all
279                let all_instances = self.client.instance_ids();
280                if !all_instances.is_empty() {
281                    return Err(PipelineError::ServiceOverloaded(
282                        "All workers are busy, please retry later".to_string(),
283                    )
284                    .into());
285                }
286            }
287        }
288
289        // Get the address based on discovered transport type
290        let address = {
291            use crate::component::TransportType;
292
293            // Get the instance and use its actual transport type
294            let instances = self.client.instances();
295            let instance = instances
296                .iter()
297                .find(|i| i.instance_id == instance_id)
298                .ok_or_else(|| {
299                    anyhow::anyhow!("Instance {} not found in available instances", instance_id)
300                })?;
301
302            match &instance.transport {
303                TransportType::Http(http_endpoint) => {
304                    tracing::debug!(
305                        instance_id = instance_id,
306                        http_endpoint = %http_endpoint,
307                        "Using HTTP transport for instance"
308                    );
309                    http_endpoint.clone()
310                }
311                TransportType::Tcp(tcp_endpoint) => {
312                    tracing::debug!(
313                        instance_id = instance_id,
314                        tcp_endpoint = %tcp_endpoint,
315                        "Using TCP transport for instance"
316                    );
317                    tcp_endpoint.clone()
318                }
319                TransportType::Nats(subject) => {
320                    tracing::debug!(
321                        instance_id = instance_id,
322                        subject = %subject,
323                        "Using NATS transport for instance"
324                    );
325                    subject.clone()
326                }
327            }
328        };
329
330        let request = request.map(|req| AddressedRequest::new(req, address));
331
332        let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
333        match stream {
334            Ok(stream) => {
335                let engine_ctx = stream.context();
336                let client = self.client.clone();
337                let stream = stream.map(move |res| {
338                    // TODO: Standardize error type to avoid using string matching DIS-364
339                    if let Some(err) = res.err()
340                        && format!("{:?}", err) == STREAM_ERR_MSG
341                    {
342                        tracing::debug!(
343                            "Reporting instance {instance_id} down due to stream error: {err}"
344                        );
345                        client.report_instance_down(instance_id);
346                    }
347                    res
348                });
349                Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
350            }
351            Err(err) => {
352                if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
353                    && matches!(req_err.kind(), NatsNoResponders)
354                {
355                    tracing::debug!(
356                        "Reporting instance {instance_id} down due to request error: {req_err}"
357                    );
358                    self.client.report_instance_down(instance_id);
359                }
360                Err(err)
361            }
362        }
363    }
364}
365
366#[async_trait]
367impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
368where
369    T: Data + Serialize,
370    U: Data + for<'de> Deserialize<'de> + MaybeError,
371{
372    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
373        //InstanceSource::Static => self.r#static(request).await,
374        match self.router_mode {
375            RouterMode::Random => self.random(request).await,
376            RouterMode::RoundRobin => self.round_robin(request).await,
377            RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
378            RouterMode::KV => {
379                anyhow::bail!("KV routing should not call generate on PushRouter");
380            }
381        }
382    }
383}