dynamo_runtime/pipeline/network/egress/push_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG};
5use crate::utils::worker_monitor::WorkerMonitor;
6use crate::{
7    component::{Client, Endpoint, InstanceSource},
8    engine::{AsyncEngine, Data},
9    pipeline::{
10        AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
11        error::{PipelineError, PipelineErrorExt},
12    },
13    protocols::maybe_error::MaybeError,
14    traits::DistributedRuntimeProvider,
15};
16use async_nats::client::{
17    RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
18};
19use async_trait::async_trait;
20use rand::Rng;
21use serde::{Deserialize, Serialize};
22use std::{
23    future::Future,
24    marker::PhantomData,
25    sync::{
26        Arc,
27        atomic::{AtomicU64, Ordering},
28    },
29};
30use tokio_stream::StreamExt;
31
32#[derive(Clone)]
33pub struct PushRouter<T, U>
34where
35    T: Data + Serialize,
36    U: Data + for<'de> Deserialize<'de>,
37{
38    // TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
39    /// The Client is how we gather remote endpoint information from etcd.
40    pub client: Client,
41
42    /// How we choose which instance to send traffic to.
43    ///
44    /// Setting this to KV means we never intend to call `generate` on this PushRouter. We are
45    /// not using it as an AsyncEngine.
46    /// Instead we will decide whether to call random/round_robin/direct ourselves and call them directly.
47    /// dynamo-llm's KV Routing does this.
48    router_mode: RouterMode,
49
50    /// Number of round robin requests handled. Used to decide which server is next.
51    round_robin_counter: Arc<AtomicU64>,
52
53    /// The next step in the chain. PushRouter (this object) picks an instances,
54    /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
55    addressed: Arc<AddressedPushRouter>,
56
57    /// Worker monitor for tracking KV cache usage
58    worker_monitor: Option<Arc<WorkerMonitor>>,
59
60    /// Threshold for determining when a worker is busy (0.0 to 1.0)
61    /// If None, busy detection is disabled
62    busy_threshold: Option<f64>,
63
64    /// An internal Rust type. This says that PushRouter is generic over the T and U types,
65    /// which are the input and output types of it's `generate` function. It allows the
66    /// compiler to specialize us at compile time.
67    _phantom: PhantomData<(T, U)>,
68}
69
/// Strategy a `PushRouter` uses to pick the instance that receives a request.
///
/// Every variant holds only `Copy`, totally-ordered-equality data, so the mode is `Eq` and
/// `Hash` and can be used as a map or set key.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RouterMode {
    /// Cycle through the available instances in turn (the default).
    #[default]
    RoundRobin,
    /// Pick a random available instance.
    Random,
    /// Always send to the instance with this id.
    Direct(i64),
    /// Marker value, KV routing itself is in dynamo-llm.
    KV,
}
79
80impl RouterMode {
81    pub fn is_kv_routing(&self) -> bool {
82        *self == RouterMode::KV
83    }
84}
85
86async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
87    AddressedPushRouter::new(
88        endpoint.drt().nats_client.client().clone(),
89        endpoint.drt().tcp_server().await?,
90    )
91}
92
93impl<T, U> PushRouter<T, U>
94where
95    T: Data + Serialize,
96    U: Data + for<'de> Deserialize<'de> + MaybeError,
97{
98    /// Create a new PushRouter without busy threshold (no busy detection)
99    pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
100        Self::from_client_with_threshold(client, router_mode, None).await
101    }
102
103    /// Create a new PushRouter with optional busy threshold
104    pub async fn from_client_with_threshold(
105        client: Client,
106        router_mode: RouterMode,
107        busy_threshold: Option<f64>,
108    ) -> anyhow::Result<Self> {
109        let addressed = addressed_router(&client.endpoint).await?;
110
111        // Create worker monitor only if we have a threshold and are in dynamic mode
112        let worker_monitor = match (busy_threshold, client.instance_source.as_ref()) {
113            (Some(threshold), InstanceSource::Dynamic(_)) => {
114                let monitor = Arc::new(WorkerMonitor::new_with_threshold(
115                    Arc::new(client.clone()),
116                    threshold,
117                ));
118                monitor.start_monitoring().await?;
119                Some(monitor)
120            }
121            _ => None,
122        };
123
124        let router = PushRouter {
125            client: client.clone(),
126            addressed,
127            router_mode,
128            round_robin_counter: Arc::new(AtomicU64::new(0)),
129            worker_monitor,
130            busy_threshold,
131            _phantom: PhantomData,
132        };
133
134        Ok(router)
135    }
136
137    /// Issue a request to the next available instance in a round-robin fashion
138    pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
139        let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
140
141        let instance_id = {
142            let instance_ids = self.client.instance_ids_avail();
143            let count = instance_ids.len();
144            if count == 0 {
145                return Err(anyhow::anyhow!(
146                    "no instances found for endpoint {:?}",
147                    self.client.endpoint.etcd_root()
148                ));
149            }
150            instance_ids[counter % count]
151        };
152        tracing::trace!("round robin router selected {instance_id}");
153
154        self.generate_with_fault_detection(instance_id, request)
155            .await
156    }
157
158    /// Issue a request to a random endpoint
159    pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
160        let instance_id = {
161            let instance_ids = self.client.instance_ids_avail();
162            let count = instance_ids.len();
163            if count == 0 {
164                return Err(anyhow::anyhow!(
165                    "no instances found for endpoint {:?}",
166                    self.client.endpoint.etcd_root()
167                ));
168            }
169            let counter = rand::rng().random::<u64>() as usize;
170            instance_ids[counter % count]
171        };
172        tracing::trace!("random router selected {instance_id}");
173
174        self.generate_with_fault_detection(instance_id, request)
175            .await
176    }
177
178    /// Issue a request to a specific endpoint
179    pub async fn direct(
180        &self,
181        request: SingleIn<T>,
182        instance_id: i64,
183    ) -> anyhow::Result<ManyOut<U>> {
184        let found = self.client.instance_ids_avail().contains(&instance_id);
185
186        if !found {
187            return Err(anyhow::anyhow!(
188                "instance_id={instance_id} not found for endpoint {:?}",
189                self.client.endpoint.etcd_root()
190            ));
191        }
192
193        self.generate_with_fault_detection(instance_id, request)
194            .await
195    }
196
197    pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
198        let subject = self.client.endpoint.subject();
199        tracing::debug!("static got subject: {subject}");
200        let request = request.map(|req| AddressedRequest::new(req, subject));
201        tracing::debug!("router generate");
202        self.addressed.generate(request).await
203    }
204
205    async fn generate_with_fault_detection(
206        &self,
207        instance_id: i64,
208        request: SingleIn<T>,
209    ) -> anyhow::Result<ManyOut<U>> {
210        // Check if all workers are busy (only if busy threshold is set)
211        if self.busy_threshold.is_some() {
212            let free_instances = self.client.instance_ids_free();
213            if free_instances.is_empty() {
214                // Check if we actually have any instances at all
215                let all_instances = self.client.instance_ids();
216                if !all_instances.is_empty() {
217                    return Err(PipelineError::ServiceOverloaded(
218                        "All workers are busy, please retry later".to_string(),
219                    )
220                    .into());
221                }
222            }
223        }
224
225        let subject = self.client.endpoint.subject_to(instance_id);
226        let request = request.map(|req| AddressedRequest::new(req, subject));
227
228        let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
229        match stream {
230            Ok(stream) => {
231                let engine_ctx = stream.context();
232                let client = self.client.clone();
233                let stream = stream.map(move |res| {
234                    // TODO: Standardize error type to avoid using string matching DIS-364
235                    if let Some(err) = res.err()
236                        && format!("{:?}", err) == STREAM_ERR_MSG
237                    {
238                        tracing::debug!(
239                            "Reporting instance {instance_id} down due to stream error: {err}"
240                        );
241                        client.report_instance_down(instance_id);
242                    }
243                    res
244                });
245                Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
246            }
247            Err(err) => {
248                if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
249                    && matches!(req_err.kind(), NatsNoResponders)
250                {
251                    tracing::debug!(
252                        "Reporting instance {instance_id} down due to request error: {req_err}"
253                    );
254                    self.client.report_instance_down(instance_id);
255                }
256                Err(err)
257            }
258        }
259    }
260}
261
262#[async_trait]
263impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
264where
265    T: Data + Serialize,
266    U: Data + for<'de> Deserialize<'de> + MaybeError,
267{
268    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
269        match self.client.instance_source.as_ref() {
270            InstanceSource::Static => self.r#static(request).await,
271            InstanceSource::Dynamic(_) => match self.router_mode {
272                RouterMode::Random => self.random(request).await,
273                RouterMode::RoundRobin => self.round_robin(request).await,
274                RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
275                RouterMode::KV => {
276                    anyhow::bail!("KV routing should not call generate on PushRouter");
277                }
278            },
279        }
280    }
281}