Skip to main content

dynamo_runtime/pipeline/network/egress/
addressed_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5use std::time::Instant;
6
7use super::unified_client::RequestPlaneClient;
8use super::*;
9use crate::component::Instance;
10use crate::discovery::EndpointInstanceId;
11use crate::dynamo_nvtx_range;
12use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
13use crate::error::{DynamoError, ErrorType};
14use crate::logging::inject_trace_headers_into_map;
15use crate::metrics::frontend_perf::STAGE_DURATION_SECONDS;
16use crate::metrics::request_plane::{
17    REQUEST_PLANE_INFLIGHT, REQUEST_PLANE_QUEUE_SECONDS, REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS,
18    REQUEST_PLANE_SEND_SECONDS,
19};
20use crate::pipeline::network::ConnectionInfo;
21use crate::pipeline::network::NetworkStreamWrapper;
22use crate::pipeline::network::PendingConnections;
23use crate::pipeline::network::StreamOptions;
24use crate::pipeline::network::TwoPartCodec;
25use crate::pipeline::network::codec::TwoPartMessage;
26use crate::pipeline::network::tcp;
27use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
28use crate::protocols::maybe_error::MaybeError;
29
30use anyhow::{Error, Result};
31use futures::stream::Stream;
32use serde::Deserialize;
33use serde::Serialize;
34use std::pin::Pin;
35use std::task::{Context, Poll};
36use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
37use tracing::Instrument;
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40#[serde(rename_all = "snake_case")]
41enum RequestType {
42    SingleIn,
43    ManyIn,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48enum ResponseType {
49    SingleOut,
50    ManyOut,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54struct RequestControlMessage {
55    id: String,
56    request_type: RequestType,
57    response_type: ResponseType,
58    connection_info: ConnectionInfo,
59    /// Wall-clock send timestamp (nanos since UNIX epoch) for transport latency breakdown.
60    /// Uses `SystemTime` so accuracy depends on NTP sync between frontend and backend hosts.
61    /// Reliable for single-machine profiling; treat cross-host values as approximate.
62    #[serde(default, skip_serializing_if = "Option::is_none")]
63    frontend_send_ts_ns: Option<u64>,
64}
65
/// RAII guard that decrements `REQUEST_PLANE_INFLIGHT` on drop unless disarmed.
///
/// Protects against gauge leaks when an early return happens between the
/// gauge increment and the construction of `InflightDecStream`. The guard does
/// not increment the gauge itself; the caller does that before creating it.
///
/// Marked `#[must_use]` because a guard that is not bound to a variable is
/// dropped immediately, which would decrement the gauge right away.
#[must_use = "an unbound guard is dropped immediately and decrements the inflight gauge"]
struct InflightGuard {
    /// `true` while the guard is still responsible for the decrement.
    armed: bool,
}

impl InflightGuard {
    /// Create an armed guard.
    fn new() -> Self {
        Self { armed: true }
    }

    /// Consume the guard without decrementing. Call this when
    /// `InflightDecStream` takes over responsibility for the decrement.
    fn disarm(mut self) {
        // `self` is still dropped at the end of this function; clearing the
        // flag first turns that drop into a no-op.
        self.armed = false;
    }
}
84
85impl Drop for InflightGuard {
86    fn drop(&mut self) {
87        if self.armed {
88            REQUEST_PLANE_INFLIGHT.dec();
89        }
90    }
91}
92
93/// Wrapper that decrements request-plane inflight gauge when the stream is dropped.
94struct InflightDecStream<S> {
95    inner: S,
96}
97
98impl<S, T> Stream for InflightDecStream<S>
99where
100    S: Stream<Item = T> + Unpin,
101{
102    type Item = T;
103
104    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
105        Pin::new(&mut self.inner).poll_next(cx)
106    }
107}
108
109impl<S> Drop for InflightDecStream<S> {
110    fn drop(&mut self) {
111        REQUEST_PLANE_INFLIGHT.dec();
112    }
113}
114
115pub struct AddressedRequest<T> {
116    request: T,
117    address: String,
118    /// Carries endpoint name + instance_id so cancellation is scoped to the
119    /// exact (endpoint, instance) pair, not all endpoints on the same runtime.
120    instance: Option<Instance>,
121}
122
123impl<T> AddressedRequest<T> {
124    pub fn new(request: T, address: String) -> Self {
125        Self {
126            request,
127            address,
128            instance: None,
129        }
130    }
131
132    pub fn with_instance(request: T, address: String, instance: Instance) -> Self {
133        Self {
134            request,
135            address,
136            instance: Some(instance),
137        }
138    }
139
140    pub(crate) fn into_parts(self) -> (T, String, Option<Instance>) {
141        (self.request, self.address, self.instance)
142    }
143}
144
145pub struct AddressedPushRouter {
146    // Request transport (unified trait object - works with all transports)
147    req_client: Arc<dyn RequestPlaneClient>,
148
149    // Response transport (TCP streaming - unchanged)
150    resp_transport: Arc<tcp::server::TcpStreamServer>,
151}
152
153impl AddressedPushRouter {
154    /// Create a new router with a request plane client
155    ///
156    /// This is the unified constructor that works with any transport type.
157    /// The client is provided as a trait object, hiding the specific implementation.
158    pub fn new(
159        req_client: Arc<dyn RequestPlaneClient>,
160        resp_transport: Arc<tcp::server::TcpStreamServer>,
161    ) -> Result<Arc<Self>> {
162        Ok(Arc::new(Self {
163            req_client,
164            resp_transport,
165        }))
166    }
167
168    /// Cancel all pending response-stream registrations for an instance.
169    pub async fn cancel_instance_streams(&self, instance_id: &EndpointInstanceId) -> usize {
170        self.resp_transport
171            .cancel_instance_streams(instance_id)
172            .await
173    }
174
175    /// Clear the tombstone after an instance reappears in discovery.
176    pub async fn clear_instance_tombstone(&self, instance_id: &EndpointInstanceId) {
177        self.resp_transport
178            .clear_instance_tombstone(instance_id)
179            .await
180    }
181}
182
183#[async_trait::async_trait]
184impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
185where
186    T: Data + Serialize,
187    U: Data + for<'de> Deserialize<'de> + MaybeError,
188{
189    async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
190        let queue_start = Instant::now();
191        REQUEST_PLANE_INFLIGHT.inc();
192        let inflight_guard = InflightGuard::new();
193
194        let request_id = request.context().id().to_string();
195        let (addressed_request, context) = request.transfer(());
196        let (request, address, instance_info) = addressed_request.into_parts();
197        let engine_ctx = context.context();
198        let engine_ctx_ = engine_ctx.clone();
199
200        // registration options for the data plane in a singe in / many out configuration
201        let options = StreamOptions::builder()
202            .context(engine_ctx.clone())
203            .enable_request_stream(false)
204            .enable_response_stream(true)
205            .build()
206            .unwrap();
207
208        // register our needs with the data plane
209        // todo - generalize this with a generic data plane object which hides the specific transports
210        let pending_connections: PendingConnections = self.resp_transport.register(options).await;
211
212        // validate and unwrap the RegisteredStream object
213        let pending_response_stream = match pending_connections.into_parts() {
214            (None, Some(recv_stream)) => recv_stream,
215            _ => {
216                panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
217            }
218        };
219
220        // separate out the connection info and the stream provider from the registered stream
221        let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
222
223        // Snapshot subject before connection_info is moved; used for cleanup.
224        let recv_subject: Option<String> =
225            serde_json::from_str::<tcp::TcpStreamConnectionInfo>(&connection_info.info)
226                .ok()
227                .map(|ci| ci.subject);
228
229        // If the instance is already tombstoned, fail fast with a migratable
230        // error instead of writing to the request plane.
231        if let (Some(subject), Some(inst)) = (&recv_subject, &instance_info) {
232            let endpoint_instance_id = inst.endpoint_instance_id();
233            if !self
234                .resp_transport
235                .associate_instance(subject, &endpoint_instance_id)
236                .await
237            {
238                return Err(anyhow::anyhow!(
239                    DynamoError::builder()
240                        .error_type(ErrorType::Disconnected)
241                        .message(
242                            "Worker removed before request could be sent (tombstoned instance)"
243                        )
244                        .build()
245                ));
246            }
247        }
248
249        // package up the connection info as part of the "header" component of the two part message
250        // used to issue the request on the
251        // todo -- this object should be automatically created by the register call, and achieved by to the two into_parts()
252        // calls. all the information here is provided by the [`StreamOptions`] object and/or the dataplane object
253        let control_message = RequestControlMessage {
254            id: engine_ctx.id().to_string(),
255            request_type: RequestType::SingleIn,
256            response_type: ResponseType::ManyOut,
257            connection_info,
258            frontend_send_ts_ns: None,
259        };
260
261        // next build the two part message where we package the connection info and the request into
262        // a single Vec<u8> that can be sent over the wire.
263        // --- package this up in the WorkQueuePublisher ---
264        let ctrl = match serde_json::to_vec(&control_message) {
265            Ok(v) => v,
266            Err(e) => {
267                if let Some(subject) = &recv_subject {
268                    self.resp_transport.cancel_recv_stream(subject).await;
269                }
270                return Err(e.into());
271            }
272        };
273        let data = match serde_json::to_vec(&request) {
274            Ok(v) => v,
275            Err(e) => {
276                if let Some(subject) = &recv_subject {
277                    self.resp_transport.cancel_recv_stream(subject).await;
278                }
279                return Err(e.into());
280            }
281        };
282
283        tracing::trace!(
284            request_id,
285            "packaging two-part message; ctrl: {} bytes, data: {} bytes",
286            ctrl.len(),
287            data.len()
288        );
289
290        let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
291
292        // the request plane / work queue should provide a two part message codec that can be used
293        // or it should take a two part message directly
294        // todo - update this
295        let codec = TwoPartCodec::default();
296        let buffer = match codec.encode_message(msg) {
297            Ok(v) => v,
298            Err(e) => {
299                if let Some(subject) = &recv_subject {
300                    self.resp_transport.cancel_recv_stream(subject).await;
301                }
302                return Err(e.into());
303            }
304        };
305
306        REQUEST_PLANE_QUEUE_SECONDS.observe(queue_start.elapsed().as_secs_f64());
307        let tx_start = Instant::now();
308
309        // TRANSPORT ABSTRACT REQUIRED - END HERE
310
311        // Send request using unified client interface
312        tracing::trace!(
313            request_id,
314            transport = self.req_client.transport_name(),
315            address = %address,
316            "Sending request via request plane client"
317        );
318
319        // Prepare trace headers using shared helper
320        let mut headers = std::collections::HashMap::new();
321        inject_trace_headers_into_map(&mut headers);
322        headers.insert("request-id".to_string(), request_id.clone());
323
324        // Stamp send time right before the transport write so the network
325        // transit metric excludes serialization/encoding overhead.
326        let send_ts_ns = std::time::SystemTime::now()
327            .duration_since(std::time::UNIX_EPOCH)
328            .unwrap_or_default()
329            .as_nanos() as u64;
330        headers.insert("x-frontend-send-ts-ns".to_string(), send_ts_ns.to_string());
331
332        // Phase A: Frontend → Backend (network + queue + ack)
333        let _nvtx_send = dynamo_nvtx_range!("transport.tcp.send");
334        let send_result = self.req_client.send_request(address, buffer, headers).await;
335        drop(_nvtx_send);
336
337        if let Err(e) = send_result {
338            if let Some(subject) = &recv_subject {
339                self.resp_transport.cancel_recv_stream(subject).await;
340            }
341            return Err(e);
342        }
343        REQUEST_PLANE_SEND_SECONDS.observe(tx_start.elapsed().as_secs_f64());
344
345        let _nvtx_wait = dynamo_nvtx_range!("transport.tcp.wait_backend");
346        tracing::trace!(request_id, "awaiting transport handshake");
347
348        // RecvError → migratable Disconnected (watcher cancelled the subject
349        // or the worker died before establishing the response stream).
350        let response_stream = match response_stream_provider.await {
351            Ok(Ok(stream)) => stream,
352            Ok(Err(e)) => {
353                // generate() failed before any response bytes; migrate via
354                // CannotConnect since the dominant cause is a worker-local
355                // setup/version issue. The wire prologue carries only an
356                // opaque string today, so app-level rejections also retry
357                // -- safe because no side effects are visible yet. Follow-up:
358                // structured prologue error type for finer routing.
359                if let Some(subject) = &recv_subject {
360                    self.resp_transport.cancel_recv_stream(subject).await;
361                }
362                return Err(anyhow::anyhow!(
363                    DynamoError::builder()
364                        .error_type(ErrorType::CannotConnect)
365                        .message(format!(
366                            "Worker generate() failed before response stream: {e}"
367                        ))
368                        .build()
369                ));
370            }
371            Err(_recv_err) => {
372                // oneshot dropped: either the discovery watcher cancelled
373                // this subject or the worker died mid-handshake.
374                if let Some(subject) = &recv_subject {
375                    self.resp_transport.cancel_recv_stream(subject).await;
376                }
377                return Err(anyhow::anyhow!(
378                    DynamoError::builder()
379                        .error_type(ErrorType::Disconnected)
380                        .message("Worker disconnected before response stream was established")
381                        .build()
382                ));
383            }
384        };
385        drop(_nvtx_wait);
386
387        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
388        let mut is_complete_final = false;
389        let mut first_response = true;
390        let stream = tokio_stream::StreamNotifyClose::new(
391            tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
392        )
393        .filter_map(move |res| {
394            if let Some(res_bytes) = res {
395                if first_response {
396                    first_response = false;
397                    let roundtrip_ttft = tx_start.elapsed().as_secs_f64();
398                    REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS.observe(roundtrip_ttft);
399                    STAGE_DURATION_SECONDS
400                        .with_label_values(&["transport_roundtrip"])
401                        .observe(queue_start.elapsed().as_secs_f64());
402                }
403                if is_complete_final {
404                    let err = DynamoError::msg(
405                        "Response received after generation ended - this should never happen",
406                    );
407                    return Some(U::from_err(err));
408                }
409                match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
410                    Ok(item) => {
411                        is_complete_final = item.complete_final;
412                        if let Some(data) = item.data {
413                            Some(data)
414                        } else if is_complete_final {
415                            None
416                        } else {
417                            let err = DynamoError::msg(
418                                "Empty response received - this should never happen",
419                            );
420                            Some(U::from_err(err))
421                        }
422                    }
423                    Err(err) => {
424                        // legacy log print
425                        let json_str = String::from_utf8_lossy(&res_bytes);
426                        tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
427
428                        Some(U::from_err(DynamoError::msg(err.to_string())))
429                    }
430                }
431            } else if is_complete_final {
432                // end of stream
433                None
434            } else if engine_ctx_.is_stopped() {
435                // Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
436                // 'is_killed()' here because it implies the stream ended abnormally which should be
437                // handled by the error branch below.
438                tracing::debug!("Request cancelled and then trying to read a response");
439                None
440            } else {
441                // stream ended unexpectedly
442                let err = DynamoError::builder()
443                    .error_type(ErrorType::Disconnected)
444                    .message("Stream ended before generation completed")
445                    .build();
446                tracing::debug!("{err}");
447                Some(U::from_err(err))
448            }
449        });
450
451        inflight_guard.disarm();
452        let stream = InflightDecStream { inner: stream };
453        Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
454    }
455}