dynamo_runtime/pipeline/network/egress/
push_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
5use crate::{
6    component::{Client, Endpoint, InstanceSource},
7    engine::{AsyncEngine, Data},
8    pipeline::{
9        AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
10        error::{PipelineError, PipelineErrorExt},
11    },
12    protocols::maybe_error::MaybeError,
13    traits::DistributedRuntimeProvider,
14};
15use async_nats::client::{
16    RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
17};
18use async_trait::async_trait;
19use rand::Rng;
20use serde::{Deserialize, Serialize};
21use std::{
22    future::Future,
23    marker::PhantomData,
24    sync::{
25        Arc,
26        atomic::{AtomicU64, Ordering},
27    },
28};
29use tokio_stream::StreamExt;
30
31/// Trait for monitoring worker load and determining busy state.
32/// Implementations can define custom load metrics and busy thresholds.
33#[async_trait]
34pub trait WorkerLoadMonitor: Send + Sync {
35    /// Start background monitoring of worker load.
36    /// This should spawn background tasks that update the client's free instances.
37    async fn start_monitoring(&self) -> anyhow::Result<()>;
38}
39
40#[derive(Clone)]
41pub struct PushRouter<T, U>
42where
43    T: Data + Serialize,
44    U: Data + for<'de> Deserialize<'de>,
45{
46    // TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
47    /// The Client is how we gather remote endpoint information from etcd.
48    pub client: Client,
49
50    /// How we choose which instance to send traffic to.
51    ///
52    /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
53    /// not using it as an AsyncEngine.
54    /// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
55    /// dynamo-llm's KV Routing does this.
56    router_mode: RouterMode,
57
58    /// Number of round robin requests handled. Used to decide which server is next.
59    round_robin_counter: Arc<AtomicU64>,
60
61    /// The next step in the chain. PushRouter (this object) picks an instances,
62    /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
63    addressed: Arc<AddressedPushRouter>,
64
65    /// Threshold for determining when a worker is busy (0.0 to 1.0)
66    /// If None, busy detection is disabled
67    busy_threshold: Option<f64>,
68
69    /// An internal Rust type. This says that PushRouter is generic over the T and U types,
70    /// which are the input and output types of it's `generate` function. It allows the
71    /// compiler to specialize us at compile time.
72    _phantom: PhantomData<(T, U)>,
73}
74
/// Strategy used to choose the instance that receives a request.
///
/// Equality between modes is total (every variant, including the `u64` payload of
/// [`RouterMode::Direct`], compares exactly), so the type is `Eq` as well as `PartialEq`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterMode {
    /// Cycle through the available instances in order. This is the default mode.
    #[default]
    RoundRobin,
    /// Pick one of the available instances at random.
    Random,
    /// Always send to the instance with the given id.
    Direct(u64),
    /// Marker value, KV routing itself is in dynamo-llm.
    KV,
}

impl RouterMode {
    /// Returns `true` only when this mode is [`RouterMode::KV`].
    pub fn is_kv_routing(&self) -> bool {
        matches!(self, RouterMode::KV)
    }
}
90
91async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
92    let Some(nats_client) = endpoint.drt().nats_client() else {
93        anyhow::bail!("Missing NATS. Please ensure it is running and accessible.");
94    };
95    AddressedPushRouter::new(
96        nats_client.client().clone(),
97        endpoint.drt().tcp_server().await?,
98    )
99}
100
101impl<T, U> PushRouter<T, U>
102where
103    T: Data + Serialize,
104    U: Data + for<'de> Deserialize<'de> + MaybeError,
105{
106    /// Create a new PushRouter without busy threshold (no busy detection)
107    pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
108        Self::from_client_with_threshold(client, router_mode, None, None).await
109    }
110
111    /// Create a new PushRouter with optional busy threshold and worker load monitor
112    pub async fn from_client_with_threshold(
113        client: Client,
114        router_mode: RouterMode,
115        busy_threshold: Option<f64>,
116        worker_monitor: Option<Arc<dyn WorkerLoadMonitor>>,
117    ) -> anyhow::Result<Self> {
118        let addressed = addressed_router(&client.endpoint).await?;
119
120        // Start worker monitor if provided and in dynamic mode
121        if let Some(monitor) = worker_monitor.as_ref()
122            && matches!(client.instance_source.as_ref(), InstanceSource::Dynamic(_))
123        {
124            monitor.start_monitoring().await?;
125        }
126
127        let router = PushRouter {
128            client: client.clone(),
129            addressed,
130            router_mode,
131            round_robin_counter: Arc::new(AtomicU64::new(0)),
132            busy_threshold,
133            _phantom: PhantomData,
134        };
135
136        Ok(router)
137    }
138
139    /// Issue a request to the next available instance in a round-robin fashion
140    pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
141        let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
142
143        let instance_id = {
144            let instance_ids = self.client.instance_ids_avail();
145            let count = instance_ids.len();
146            if count == 0 {
147                return Err(anyhow::anyhow!(
148                    "no instances found for endpoint {:?}",
149                    self.client.endpoint.etcd_root()
150                ));
151            }
152            instance_ids[counter % count]
153        };
154        tracing::trace!("round robin router selected {instance_id}");
155
156        self.generate_with_fault_detection(instance_id, request)
157            .await
158    }
159
160    /// Issue a request to a random endpoint
161    pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
162        let instance_id = {
163            let instance_ids = self.client.instance_ids_avail();
164            let count = instance_ids.len();
165            if count == 0 {
166                return Err(anyhow::anyhow!(
167                    "no instances found for endpoint {:?}",
168                    self.client.endpoint.etcd_root()
169                ));
170            }
171            let counter = rand::rng().random::<u64>() as usize;
172            instance_ids[counter % count]
173        };
174        tracing::trace!("random router selected {instance_id}");
175
176        self.generate_with_fault_detection(instance_id, request)
177            .await
178    }
179
180    /// Issue a request to a specific endpoint
181    pub async fn direct(
182        &self,
183        request: SingleIn<T>,
184        instance_id: u64,
185    ) -> anyhow::Result<ManyOut<U>> {
186        let found = self.client.instance_ids_avail().contains(&instance_id);
187
188        if !found {
189            return Err(anyhow::anyhow!(
190                "instance_id={instance_id} not found for endpoint {:?}",
191                self.client.endpoint.etcd_root()
192            ));
193        }
194
195        self.generate_with_fault_detection(instance_id, request)
196            .await
197    }
198
199    pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
200        let subject = self.client.endpoint.subject();
201        tracing::debug!("static got subject: {subject}");
202        let request = request.map(|req| AddressedRequest::new(req, subject));
203        tracing::debug!("router generate");
204        self.addressed.generate(request).await
205    }
206
207    async fn generate_with_fault_detection(
208        &self,
209        instance_id: u64,
210        request: SingleIn<T>,
211    ) -> anyhow::Result<ManyOut<U>> {
212        // Check if all workers are busy (only if busy threshold is set)
213        if self.busy_threshold.is_some() {
214            let free_instances = self.client.instance_ids_free();
215            if free_instances.is_empty() {
216                // Check if we actually have any instances at all
217                let all_instances = self.client.instance_ids();
218                if !all_instances.is_empty() {
219                    return Err(PipelineError::ServiceOverloaded(
220                        "All workers are busy, please retry later".to_string(),
221                    )
222                    .into());
223                }
224            }
225        }
226
227        let subject = self.client.endpoint.subject_to(instance_id);
228        let request = request.map(|req| AddressedRequest::new(req, subject));
229
230        let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
231        match stream {
232            Ok(stream) => {
233                let engine_ctx = stream.context();
234                let client = self.client.clone();
235                let stream = stream.map(move |res| {
236                    // TODO: Standardize error type to avoid using string matching DIS-364
237                    if let Some(err) = res.err()
238                        && format!("{:?}", err) == STREAM_ERR_MSG
239                    {
240                        tracing::debug!(
241                            "Reporting instance {instance_id} down due to stream error: {err}"
242                        );
243                        client.report_instance_down(instance_id);
244                    }
245                    res
246                });
247                Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
248            }
249            Err(err) => {
250                if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
251                    && matches!(req_err.kind(), NatsNoResponders)
252                {
253                    tracing::debug!(
254                        "Reporting instance {instance_id} down due to request error: {req_err}"
255                    );
256                    self.client.report_instance_down(instance_id);
257                }
258                Err(err)
259            }
260        }
261    }
262}
263
264#[async_trait]
265impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
266where
267    T: Data + Serialize,
268    U: Data + for<'de> Deserialize<'de> + MaybeError,
269{
270    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
271        match self.client.instance_source.as_ref() {
272            InstanceSource::Static => self.r#static(request).await,
273            InstanceSource::Dynamic(_) => match self.router_mode {
274                RouterMode::Random => self.random(request).await,
275                RouterMode::RoundRobin => self.round_robin(request).await,
276                RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
277                RouterMode::KV => {
278                    anyhow::bail!("KV routing should not call generate on PushRouter");
279                }
280            },
281        }
282    }
283}