Skip to main content

dynamo_runtime/pipeline/network/egress/push_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{AsyncEngineContextProvider, ResponseStream};
5use crate::error::{BackendError, ErrorType, match_error_chain};
6
7/// Check if an error chain indicates the worker should be reported as down.
8fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
9    const INHIBITED: &[ErrorType] = &[
10        ErrorType::CannotConnect,
11        ErrorType::Disconnected,
12        ErrorType::ConnectionTimeout,
13        ErrorType::Backend(BackendError::EngineShutdown),
14    ];
15    match_error_chain(err, INHIBITED, &[])
16}
17use crate::{
18    component::{Client, Endpoint},
19    engine::{AsyncEngine, Data},
20    pipeline::{
21        AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
22        error::{PipelineError, PipelineErrorExt},
23    },
24    protocols::maybe_error::MaybeError,
25    traits::DistributedRuntimeProvider,
26};
27use async_trait::async_trait;
28use rand::Rng;
29use serde::{Deserialize, Serialize};
30use std::{
31    future::Future,
32    marker::PhantomData,
33    sync::{
34        Arc,
35        atomic::{AtomicU64, Ordering},
36    },
37};
38use tokio_stream::StreamExt;
39use tracing::Instrument;
40
41/// Trait for monitoring worker load and determining busy state.
42/// Implementations can define custom load metrics and busy thresholds.
43#[async_trait]
44pub trait WorkerLoadMonitor: Send + Sync {
45    /// Start background monitoring of worker load.
46    /// This should spawn background tasks that update the client's free instances.
47    async fn start_monitoring(&self) -> anyhow::Result<()>;
48}
49
50#[derive(Clone)]
51pub struct PushRouter<T, U>
52where
53    T: Data + Serialize,
54    U: Data + for<'de> Deserialize<'de>,
55{
56    // TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
57    /// The Client is how we gather remote endpoint information from etcd.
58    pub client: Client,
59
60    /// How we choose which instance to send traffic to.
61    ///
62    /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
63    /// not using it as an AsyncEngine.
64    /// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
65    /// dynamo-llm's KV Routing does this.
66    router_mode: RouterMode,
67
68    /// Number of round robin requests handled. Used to decide which server is next.
69    round_robin_counter: Arc<AtomicU64>,
70
71    /// The next step in the chain. PushRouter (this object) picks an instances,
72    /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
73    addressed: Arc<AddressedPushRouter>,
74
75    /// Threshold for determining when a worker is busy (0.0 to 1.0)
76    /// If None, busy detection is disabled
77    busy_threshold: Option<f64>,
78
79    /// When false, `generate_with_fault_detection` skips fault detection logic:
80    /// it won't call `report_instance_down` on errors, and it uses the raw discovery
81    /// instance list instead of the filtered avail list. Use for recovery/query paths
82    /// where transient failures are expected.
83    fault_detection_enabled: bool,
84
85    /// An internal Rust type. This says that PushRouter is generic over the T and U types,
86    /// which are the input and output types of it's `generate` function. It allows the
87    /// compiler to specialize us at compile time.
88    _phantom: PhantomData<(T, U)>,
89}
90
/// Strategy a [`PushRouter`] uses to pick the instance that receives a request.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterMode {
    /// Cycle through the available instances in order (the default).
    #[default]
    RoundRobin,
    /// Pick an available instance uniformly at random.
    Random,
    /// KV-aware routing. The caller selects the instance itself; `generate` on the
    /// router is not used in this mode.
    KV,
    /// Send to an explicitly chosen instance; callers use a wrapper instead of
    /// calling `generate` on the router directly.
    Direct,
}

impl RouterMode {
    /// Returns `true` for [`RouterMode::KV`].
    pub const fn is_kv_routing(&self) -> bool {
        matches!(self, RouterMode::KV)
    }

    /// Returns `true` for [`RouterMode::Direct`].
    pub const fn is_direct_routing(&self) -> bool {
        matches!(self, RouterMode::Direct)
    }
}
109
110async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
111    // Get network manager and create client (no mode checks!)
112    let manager = endpoint.drt().network_manager();
113    let req_client = manager.create_client()?;
114    let resp_transport = endpoint.drt().tcp_server().await?;
115
116    tracing::debug!(
117        transport = req_client.transport_name(),
118        "Creating AddressedPushRouter with request plane client"
119    );
120
121    AddressedPushRouter::new(req_client, resp_transport)
122}
123
124impl<T, U> PushRouter<T, U>
125where
126    T: Data + Serialize,
127    U: Data + for<'de> Deserialize<'de> + MaybeError,
128{
129    /// Create a new PushRouter without busy threshold (no busy detection)
130    pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
131        Self::from_client_with_threshold(client, router_mode, None, None).await
132    }
133
134    /// Create a new PushRouter with fault detection disabled.
135    ///
136    /// Unlike `from_client`, this router will not call `report_instance_down` on
137    /// transient errors, and `direct()` uses the raw discovery instance list instead
138    /// of the filtered avail list. Use for recovery/query paths.
139    pub async fn from_client_no_fault_detection(
140        client: Client,
141        router_mode: RouterMode,
142    ) -> anyhow::Result<Self> {
143        let addressed = addressed_router(&client.endpoint).await?;
144
145        Ok(PushRouter {
146            client: client.clone(),
147            addressed,
148            router_mode,
149            round_robin_counter: Arc::new(AtomicU64::new(0)),
150            busy_threshold: None,
151            fault_detection_enabled: false,
152            _phantom: PhantomData,
153        })
154    }
155
156    /// Create a new PushRouter with optional busy threshold and worker load monitor
157    pub async fn from_client_with_threshold(
158        client: Client,
159        router_mode: RouterMode,
160        busy_threshold: Option<f64>,
161        worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
162    ) -> anyhow::Result<Self> {
163        let addressed = addressed_router(&client.endpoint).await?;
164
165        // Start worker monitor if provided and in dynamic mode
166        if let Some(monitor) = worker_monitor.as_ref() {
167            monitor.start_monitoring().await?;
168        }
169
170        let router = PushRouter {
171            client: client.clone(),
172            addressed,
173            router_mode,
174            round_robin_counter: Arc::new(AtomicU64::new(0)),
175            busy_threshold,
176            fault_detection_enabled: true,
177            _phantom: PhantomData,
178        };
179
180        Ok(router)
181    }
182
183    /// Issue a request to the next available instance in a round-robin fashion
184    pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
185        let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
186
187        let instance_id = {
188            let instance_ids = self.client.instance_ids_avail();
189            let count = instance_ids.len();
190            if count == 0 {
191                return Err(anyhow::anyhow!(
192                    "no instances found for endpoint {}",
193                    self.client.endpoint.id()
194                ));
195            }
196            instance_ids[counter % count]
197        };
198        tracing::trace!("round robin router selected {instance_id}");
199
200        self.generate_with_fault_detection(instance_id, request)
201            .await
202    }
203
204    /// Issue a request to a random endpoint
205    pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
206        let instance_id = {
207            let instance_ids = self.client.instance_ids_avail();
208            let count = instance_ids.len();
209            if count == 0 {
210                return Err(anyhow::anyhow!(
211                    "no instances found for endpoint {}",
212                    self.client.endpoint.id()
213                ));
214            }
215            let counter = rand::rng().random::<u64>() as usize;
216            instance_ids[counter % count]
217        };
218        tracing::trace!("random router selected {instance_id}");
219
220        self.generate_with_fault_detection(instance_id, request)
221            .await
222    }
223
224    /// Issue a request to a specific endpoint
225    pub async fn direct(
226        &self,
227        request: SingleIn<T>,
228        instance_id: u64,
229    ) -> anyhow::Result<ManyOut<U>> {
230        // When fault detection is disabled, check the raw discovery list
231        // (not filtered by report_instance_down) so transient failures
232        // don't poison the instance for subsequent retries.
233        let found = if self.fault_detection_enabled {
234            self.client.instance_ids_avail().contains(&instance_id)
235        } else {
236            self.client.instance_ids().contains(&instance_id)
237        };
238
239        if !found {
240            return Err(anyhow::anyhow!(
241                "instance_id={instance_id} not found for endpoint {}",
242                self.client.endpoint.id()
243            ));
244        }
245
246        self.generate_with_fault_detection(instance_id, request)
247            .await
248    }
249
250    /// Select the next worker according to the routing mode.
251    /// Increments round-robin counter if applicable.
252    /// Panics if called on Direct or KV mode - those have their own selection mechanisms.
253    pub fn select_next_worker(&self) -> Option<u64> {
254        let instance_ids = self.client.instance_ids_avail();
255        let count = instance_ids.len();
256        if count == 0 {
257            return None;
258        }
259
260        match self.router_mode {
261            RouterMode::RoundRobin => {
262                let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
263                Some(instance_ids[counter % count])
264            }
265            RouterMode::Random => {
266                let counter = rand::rng().random::<u64>() as usize;
267                Some(instance_ids[counter % count])
268            }
269            _ => {
270                panic!(
271                    "select_next_worker should not be called for {:?} routing mode",
272                    self.router_mode
273                )
274            }
275        }
276    }
277
278    /// Peek the next worker according to the routing mode without incrementing the counter.
279    /// Useful for checking if a worker is suitable before committing to it.
280    pub fn peek_next_worker(&self) -> Option<u64> {
281        let instance_ids = self.client.instance_ids_avail();
282        let count = instance_ids.len();
283        if count == 0 {
284            return None;
285        }
286
287        match self.router_mode {
288            RouterMode::RoundRobin => {
289                // Just peek at the current counter value without incrementing
290                let counter = self.round_robin_counter.load(Ordering::Relaxed) as usize;
291                Some(instance_ids[counter % count])
292            }
293            RouterMode::Random => {
294                // For random, peeking implies a fresh random selection since it's stateless.
295                // Note: The caller must realize that select_next_worker() will pick a DIFFERENT random worker.
296                let counter = rand::rng().random::<u64>() as usize;
297                Some(instance_ids[counter % count])
298            }
299            _ => {
300                panic!(
301                    "peek_next_worker should not be called for {:?} routing mode",
302                    self.router_mode
303                )
304            }
305        }
306    }
307
308    /*
309    pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
310        let subject = self.client.endpoint.subject();
311        tracing::debug!("static got subject: {subject}");
312        let request = request.map(|req| AddressedRequest::new(req, subject));
313        tracing::debug!("router generate");
314        self.addressed.generate(request).await
315    }
316    */
317
318    async fn generate_with_fault_detection(
319        &self,
320        instance_id: u64,
321        request: SingleIn<T>,
322    ) -> anyhow::Result<ManyOut<U>> {
323        let request_id = request.id().to_string();
324        let route_span = if matches!(self.router_mode, RouterMode::KV) {
325            tracing::Span::none()
326        } else {
327            tracing::info_span!(
328                "router.route_request",
329                request_id = %request_id,
330                worker_id = instance_id,
331                router_mode = ?self.router_mode,
332            )
333        };
334
335        // Check if all workers are busy (only if busy threshold is set and fault detection enabled)
336        if self.fault_detection_enabled && self.busy_threshold.is_some() {
337            let free_instances = self.client.instance_ids_free();
338            if free_instances.is_empty() {
339                // Check if we actually have any instances at all
340                let all_instances = self.client.instance_ids();
341                if !all_instances.is_empty() {
342                    tracing::warn!(
343                        instance_id,
344                        total_workers = all_instances.len(),
345                        "Rejecting request: all workers are busy"
346                    );
347                    return Err(PipelineError::ServiceOverloaded(
348                        "All workers are busy, please retry later".to_string(),
349                    )
350                    .into());
351                }
352            }
353        }
354
355        // Get the address based on discovered transport type
356        let address = {
357            use crate::component::TransportType;
358
359            // Get the instance and use its actual transport type
360            let instances = self.client.instances();
361            let instance = instances
362                .iter()
363                .find(|i| i.instance_id == instance_id)
364                .ok_or_else(|| {
365                    anyhow::anyhow!("Instance {} not found in available instances", instance_id)
366                })?;
367
368            match &instance.transport {
369                TransportType::Http(http_endpoint) => {
370                    tracing::debug!(
371                        instance_id = instance_id,
372                        http_endpoint = %http_endpoint,
373                        "Using HTTP transport for instance"
374                    );
375                    http_endpoint.clone()
376                }
377                TransportType::Tcp(tcp_endpoint) => {
378                    tracing::debug!(
379                        instance_id = instance_id,
380                        tcp_endpoint = %tcp_endpoint,
381                        "Using TCP transport for instance"
382                    );
383                    tcp_endpoint.clone()
384                }
385                TransportType::Nats(subject) => {
386                    tracing::debug!(
387                        instance_id = instance_id,
388                        subject = %subject,
389                        "Using NATS transport for instance"
390                    );
391                    subject.clone()
392                }
393            }
394        };
395
396        let request = request.map(|req| AddressedRequest::new(req, address));
397
398        let stream: anyhow::Result<ManyOut<U>> = self
399            .addressed
400            .generate(request)
401            .instrument(route_span)
402            .await;
403        match stream {
404            Ok(stream) => {
405                if !self.fault_detection_enabled {
406                    return Ok(stream);
407                }
408                let engine_ctx = stream.context();
409                let client = self.client.clone();
410                let stream = stream.map(move |res| {
411                    // Check if the error is migratable (indicates worker/connection failure)
412                    if let Some(err) = res.err()
413                        && is_inhibited(&err)
414                    {
415                        tracing::debug!(
416                            "Reporting instance {instance_id} down due to migratable error: {err}"
417                        );
418                        client.report_instance_down(instance_id);
419                    }
420                    res
421                });
422                Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
423            }
424            Err(err) => {
425                if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
426                    tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
427                    self.client.report_instance_down(instance_id);
428                }
429                Err(err)
430            }
431        }
432    }
433}
434
435#[async_trait]
436impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
437where
438    T: Data + Serialize,
439    U: Data + for<'de> Deserialize<'de> + MaybeError,
440{
441    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
442        match self.router_mode {
443            RouterMode::Random => self.random(request).await,
444            RouterMode::RoundRobin => self.round_robin(request).await,
445            RouterMode::KV => {
446                anyhow::bail!("KV routing should not call generate on PushRouter");
447            }
448            RouterMode::Direct => {
449                anyhow::bail!(
450                    "Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
451                );
452            }
453        }
454    }
455}