Skip to main content

dynamo_runtime/pipeline/network/egress/
addressed_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5use std::time::Instant;
6
7use super::unified_client::RequestPlaneClient;
8use super::*;
9use crate::dynamo_nvtx_range;
10use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
11use crate::error::{DynamoError, ErrorType};
12use crate::logging::inject_trace_headers_into_map;
13use crate::metrics::frontend_perf::STAGE_DURATION_SECONDS;
14use crate::metrics::request_plane::{
15    REQUEST_PLANE_INFLIGHT, REQUEST_PLANE_QUEUE_SECONDS, REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS,
16    REQUEST_PLANE_SEND_SECONDS,
17};
18use crate::pipeline::network::ConnectionInfo;
19use crate::pipeline::network::NetworkStreamWrapper;
20use crate::pipeline::network::PendingConnections;
21use crate::pipeline::network::StreamOptions;
22use crate::pipeline::network::TwoPartCodec;
23use crate::pipeline::network::codec::TwoPartMessage;
24use crate::pipeline::network::tcp;
25use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
26use crate::protocols::maybe_error::MaybeError;
27
28use anyhow::{Error, Result};
29use futures::stream::Stream;
30use serde::Deserialize;
31use serde::Serialize;
32use std::pin::Pin;
33use std::task::{Context, Poll};
34use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
35use tracing::Instrument;
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39enum RequestType {
40    SingleIn,
41    ManyIn,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45#[serde(rename_all = "snake_case")]
46enum ResponseType {
47    SingleOut,
48    ManyOut,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52struct RequestControlMessage {
53    id: String,
54    request_type: RequestType,
55    response_type: ResponseType,
56    connection_info: ConnectionInfo,
57    /// Wall-clock send timestamp (nanos since UNIX epoch) for transport latency breakdown.
58    /// Uses `SystemTime` so accuracy depends on NTP sync between frontend and backend hosts.
59    /// Reliable for single-machine profiling; treat cross-host values as approximate.
60    #[serde(default, skip_serializing_if = "Option::is_none")]
61    frontend_send_ts_ns: Option<u64>,
62}
63
64/// RAII guard that decrements REQUEST_PLANE_INFLIGHT on drop unless disarmed.
65/// Protects against gauge leaks when `?` operators cause early returns between
66/// the increment and `InflightDecStream` construction.
67struct InflightGuard {
68    armed: bool,
69}
70
71impl InflightGuard {
72    fn new() -> Self {
73        Self { armed: true }
74    }
75
76    /// Consume the guard without decrementing. Call this when `InflightDecStream`
77    /// takes over responsibility for the decrement.
78    fn disarm(mut self) {
79        self.armed = false;
80    }
81}
82
83impl Drop for InflightGuard {
84    fn drop(&mut self) {
85        if self.armed {
86            REQUEST_PLANE_INFLIGHT.dec();
87        }
88    }
89}
90
91/// Wrapper that decrements request-plane inflight gauge when the stream is dropped.
92struct InflightDecStream<S> {
93    inner: S,
94}
95
96impl<S, T> Stream for InflightDecStream<S>
97where
98    S: Stream<Item = T> + Unpin,
99{
100    type Item = T;
101
102    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
103        Pin::new(&mut self.inner).poll_next(cx)
104    }
105}
106
107impl<S> Drop for InflightDecStream<S> {
108    fn drop(&mut self) {
109        REQUEST_PLANE_INFLIGHT.dec();
110    }
111}
112
/// A request payload paired with the address it is to be sent to.
pub struct AddressedRequest<T> {
    /// The request payload.
    request: T,
    /// Destination address handed to the request plane client.
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Pairs `request` with its destination `address`.
    pub fn new(request: T, address: String) -> Self {
        AddressedRequest { request, address }
    }

    /// Consumes the wrapper and returns `(request, address)`.
    pub(crate) fn into_parts(self) -> (T, String) {
        let AddressedRequest { request, address } = self;
        (request, address)
    }
}
127
128pub struct AddressedPushRouter {
129    // Request transport (unified trait object - works with all transports)
130    req_client: Arc<dyn RequestPlaneClient>,
131
132    // Response transport (TCP streaming - unchanged)
133    resp_transport: Arc<tcp::server::TcpStreamServer>,
134}
135
136impl AddressedPushRouter {
137    /// Create a new router with a request plane client
138    ///
139    /// This is the unified constructor that works with any transport type.
140    /// The client is provided as a trait object, hiding the specific implementation.
141    pub fn new(
142        req_client: Arc<dyn RequestPlaneClient>,
143        resp_transport: Arc<tcp::server::TcpStreamServer>,
144    ) -> Result<Arc<Self>> {
145        Ok(Arc::new(Self {
146            req_client,
147            resp_transport,
148        }))
149    }
150}
151
152#[async_trait::async_trait]
153impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
154where
155    T: Data + Serialize,
156    U: Data + for<'de> Deserialize<'de> + MaybeError,
157{
158    async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
159        let queue_start = Instant::now();
160        REQUEST_PLANE_INFLIGHT.inc();
161        let inflight_guard = InflightGuard::new();
162
163        let request_id = request.context().id().to_string();
164        let (addressed_request, context) = request.transfer(());
165        let (request, address) = addressed_request.into_parts();
166        let engine_ctx = context.context();
167        let engine_ctx_ = engine_ctx.clone();
168
169        // registration options for the data plane in a singe in / many out configuration
170        let options = StreamOptions::builder()
171            .context(engine_ctx.clone())
172            .enable_request_stream(false)
173            .enable_response_stream(true)
174            .build()
175            .unwrap();
176
177        // register our needs with the data plane
178        // todo - generalize this with a generic data plane object which hides the specific transports
179        let pending_connections: PendingConnections = self.resp_transport.register(options).await;
180
181        // validate and unwrap the RegisteredStream object
182        let pending_response_stream = match pending_connections.into_parts() {
183            (None, Some(recv_stream)) => recv_stream,
184            _ => {
185                panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
186            }
187        };
188
189        // separate out the connection info and the stream provider from the registered stream
190        let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
191
192        // package up the connection info as part of the "header" component of the two part message
193        // used to issue the request on the
194        // todo -- this object should be automatically created by the register call, and achieved by to the two into_parts()
195        // calls. all the information here is provided by the [`StreamOptions`] object and/or the dataplane object
196        let control_message = RequestControlMessage {
197            id: engine_ctx.id().to_string(),
198            request_type: RequestType::SingleIn,
199            response_type: ResponseType::ManyOut,
200            connection_info,
201            frontend_send_ts_ns: None,
202        };
203
204        // next build the two part message where we package the connection info and the request into
205        // a single Vec<u8> that can be sent over the wire.
206        // --- package this up in the WorkQueuePublisher ---
207        let ctrl = serde_json::to_vec(&control_message)?;
208        let data = serde_json::to_vec(&request)?;
209
210        tracing::trace!(
211            request_id,
212            "packaging two-part message; ctrl: {} bytes, data: {} bytes",
213            ctrl.len(),
214            data.len()
215        );
216
217        let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
218
219        // the request plane / work queue should provide a two part message codec that can be used
220        // or it should take a two part message directly
221        // todo - update this
222        let codec = TwoPartCodec::default();
223        let buffer = {
224            let _nvtx = dynamo_nvtx_range!("codec.encode");
225            codec.encode_message(msg)?
226        };
227
228        REQUEST_PLANE_QUEUE_SECONDS.observe(queue_start.elapsed().as_secs_f64());
229        let tx_start = Instant::now();
230
231        // TRANSPORT ABSTRACT REQUIRED - END HERE
232
233        // Send request using unified client interface
234        tracing::trace!(
235            request_id,
236            transport = self.req_client.transport_name(),
237            address = %address,
238            "Sending request via request plane client"
239        );
240
241        // Prepare trace headers using shared helper
242        let mut headers = std::collections::HashMap::new();
243        inject_trace_headers_into_map(&mut headers);
244        headers.insert("request-id".to_string(), request_id.clone());
245
246        // Stamp send time right before the transport write so the network
247        // transit metric excludes serialization/encoding overhead.
248        let send_ts_ns = std::time::SystemTime::now()
249            .duration_since(std::time::UNIX_EPOCH)
250            .unwrap_or_default()
251            .as_nanos() as u64;
252        headers.insert("x-frontend-send-ts-ns".to_string(), send_ts_ns.to_string());
253
254        // Phase A: Frontend → Backend (network + queue + ack)
255        let _nvtx_send = dynamo_nvtx_range!("transport.tcp.send");
256        let _response = self
257            .req_client
258            .send_request(address, buffer, headers)
259            .await?;
260        drop(_nvtx_send);
261        REQUEST_PLANE_SEND_SECONDS.observe(tx_start.elapsed().as_secs_f64());
262
263        let _nvtx_wait = dynamo_nvtx_range!("transport.tcp.wait_backend");
264        tracing::trace!(request_id, "awaiting transport handshake");
265        let response_stream = response_stream_provider
266            .await
267            .map_err(|_| PipelineError::DetachedStreamReceiver)?
268            .map_err(PipelineError::ConnectionFailed)?;
269        drop(_nvtx_wait);
270
271        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
272        let mut is_complete_final = false;
273        let mut first_response = true;
274        let stream = tokio_stream::StreamNotifyClose::new(
275            tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
276        )
277        .filter_map(move |res| {
278            if let Some(res_bytes) = res {
279                if first_response {
280                    first_response = false;
281                    let roundtrip_ttft = tx_start.elapsed().as_secs_f64();
282                    REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS.observe(roundtrip_ttft);
283                    STAGE_DURATION_SECONDS
284                        .with_label_values(&["transport_roundtrip"])
285                        .observe(queue_start.elapsed().as_secs_f64());
286                }
287                if is_complete_final {
288                    let err = DynamoError::msg(
289                        "Response received after generation ended - this should never happen",
290                    );
291                    return Some(U::from_err(err));
292                }
293                match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
294                    Ok(item) => {
295                        is_complete_final = item.complete_final;
296                        if let Some(data) = item.data {
297                            Some(data)
298                        } else if is_complete_final {
299                            None
300                        } else {
301                            let err = DynamoError::msg(
302                                "Empty response received - this should never happen",
303                            );
304                            Some(U::from_err(err))
305                        }
306                    }
307                    Err(err) => {
308                        // legacy log print
309                        let json_str = String::from_utf8_lossy(&res_bytes);
310                        tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
311
312                        Some(U::from_err(DynamoError::msg(err.to_string())))
313                    }
314                }
315            } else if is_complete_final {
316                // end of stream
317                None
318            } else if engine_ctx_.is_stopped() {
319                // Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
320                // 'is_killed()' here because it implies the stream ended abnormally which should be
321                // handled by the error branch below.
322                tracing::debug!("Request cancelled and then trying to read a response");
323                None
324            } else {
325                // stream ended unexpectedly
326                let err = DynamoError::builder()
327                    .error_type(ErrorType::Disconnected)
328                    .message("Stream ended before generation completed")
329                    .build();
330                tracing::debug!("{err}");
331                Some(U::from_err(err))
332            }
333        });
334
335        inflight_guard.disarm();
336        let stream = InflightDecStream { inner: stream };
337        Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
338    }
339}