Skip to main content

dynamo_runtime/pipeline/network/egress/
addressed_router.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::BTreeMap;
5use std::sync::Arc;
6use std::time::Instant;
7
8use super::unified_client::RequestPlaneClient;
9use super::*;
10use crate::component::Instance;
11use crate::discovery::EndpointInstanceId;
12use crate::dynamo_nvtx_range;
13use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
14use crate::error::{DynamoError, ErrorType};
15use crate::logging::inject_trace_headers_into_map;
16use crate::metrics::frontend_perf::STAGE_DURATION_SECONDS;
17use crate::metrics::request_plane::{
18    REQUEST_PLANE_INFLIGHT, REQUEST_PLANE_QUEUE_SECONDS, REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS,
19    REQUEST_PLANE_SEND_SECONDS,
20};
21use crate::pipeline::network::ConnectionInfo;
22use crate::pipeline::network::NetworkStreamWrapper;
23use crate::pipeline::network::PendingConnections;
24use crate::pipeline::network::StreamOptions;
25use crate::pipeline::network::TwoPartCodec;
26use crate::pipeline::network::codec::TwoPartMessage;
27use crate::pipeline::network::tcp;
28use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
29use crate::protocols::maybe_error::MaybeError;
30use crate::traits::DistributedRuntimeProvider;
31
32use anyhow::{Error, Result};
33use futures::stream::Stream;
34use serde::Deserialize;
35use serde::Serialize;
36use std::pin::Pin;
37use std::task::{Context, Poll};
38use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
39use tracing::Instrument;
40
41const CONTROL_MESSAGE_MAX_BYTES: usize = 128 * 1024;
42
43fn serialize_control_message(control_message: &RequestControlMessage) -> Result<Vec<u8>, Error> {
44    let ctrl = serde_json::to_vec(control_message)?;
45    if ctrl.len() > CONTROL_MESSAGE_MAX_BYTES {
46        return Err(PipelineError::Generic(format!(
47            "request control message too large: {} bytes exceeds limit {}",
48            ctrl.len(),
49            CONTROL_MESSAGE_MAX_BYTES
50        ))
51        .into());
52    }
53    Ok(ctrl)
54}
55
56/// RAII guard that decrements REQUEST_PLANE_INFLIGHT on drop unless disarmed.
57/// Protects against gauge leaks when `?` operators cause early returns between
58/// the increment and `InflightDecStream` construction.
59struct InflightGuard {
60    armed: bool,
61}
62
63impl InflightGuard {
64    fn new() -> Self {
65        Self { armed: true }
66    }
67
68    /// Consume the guard without decrementing. Call this when `InflightDecStream`
69    /// takes over responsibility for the decrement.
70    fn disarm(mut self) {
71        self.armed = false;
72    }
73}
74
75impl Drop for InflightGuard {
76    fn drop(&mut self) {
77        if self.armed {
78            REQUEST_PLANE_INFLIGHT.dec();
79        }
80    }
81}
82
83/// Wrapper that decrements request-plane inflight gauge when the stream is dropped.
84struct InflightDecStream<S> {
85    inner: S,
86}
87
88impl<S, T> Stream for InflightDecStream<S>
89where
90    S: Stream<Item = T> + Unpin,
91{
92    type Item = T;
93
94    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95        Pin::new(&mut self.inner).poll_next(cx)
96    }
97}
98
99impl<S> Drop for InflightDecStream<S> {
100    fn drop(&mut self) {
101        REQUEST_PLANE_INFLIGHT.dec();
102    }
103}
104
105pub struct AddressedRequest<T> {
106    request: T,
107    address: String,
108    /// Carries endpoint name + instance_id so cancellation is scoped to the
109    /// exact (endpoint, instance) pair, not all endpoints on the same runtime.
110    instance: Option<Instance>,
111}
112
113impl<T> AddressedRequest<T> {
114    pub fn new(request: T, address: String) -> Self {
115        Self {
116            request,
117            address,
118            instance: None,
119        }
120    }
121
122    pub fn with_instance(request: T, address: String, instance: Instance) -> Self {
123        Self {
124            request,
125            address,
126            instance: Some(instance),
127        }
128    }
129
130    pub fn for_instance(request: T, instance: Instance) -> Self {
131        let address = instance.transport.address().to_string();
132        Self::with_instance(request, address, instance)
133    }
134
135    pub(crate) fn into_parts(self) -> (T, String, Option<Instance>) {
136        (self.request, self.address, self.instance)
137    }
138}
139
140pub struct AddressedPushRouter {
141    // Request transport (unified trait object - works with all transports)
142    req_client: Arc<dyn RequestPlaneClient>,
143
144    // Response transport (TCP streaming - unchanged)
145    resp_transport: Arc<tcp::server::TcpStreamServer>,
146}
147
148impl AddressedPushRouter {
149    /// Create a new router with a request plane client
150    ///
151    /// This is the unified constructor that works with any transport type.
152    /// The client is provided as a trait object, hiding the specific implementation.
153    pub fn new(
154        req_client: Arc<dyn RequestPlaneClient>,
155        resp_transport: Arc<tcp::server::TcpStreamServer>,
156    ) -> Result<Arc<Self>> {
157        Ok(Arc::new(Self {
158            req_client,
159            resp_transport,
160        }))
161    }
162
163    pub async fn from_runtime_provider(
164        provider: &impl DistributedRuntimeProvider,
165    ) -> Result<Arc<Self>> {
166        let manager = provider.drt().network_manager();
167        let req_client = manager.create_client()?;
168        let resp_transport = provider.drt().tcp_server().await?;
169
170        tracing::debug!(
171            transport = req_client.transport_name(),
172            "Creating AddressedPushRouter with request plane client"
173        );
174
175        Self::new(req_client, resp_transport)
176    }
177
178    /// Cancel all pending response-stream registrations for an instance.
179    pub async fn cancel_instance_streams(&self, instance_id: &EndpointInstanceId) -> usize {
180        self.resp_transport
181            .cancel_instance_streams(instance_id)
182            .await
183    }
184
185    /// Clear the tombstone after an instance reappears in discovery.
186    pub async fn clear_instance_tombstone(&self, instance_id: &EndpointInstanceId) {
187        self.resp_transport
188            .clear_instance_tombstone(instance_id)
189            .await
190    }
191}
192
193#[async_trait::async_trait]
194impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
195where
196    T: Data + Serialize,
197    U: Data + for<'de> Deserialize<'de> + MaybeError,
198{
199    async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
200        let queue_start = Instant::now();
201        REQUEST_PLANE_INFLIGHT.inc();
202        let inflight_guard = InflightGuard::new();
203
204        let request_id = request.context().id().to_string();
205        let (addressed_request, context) = request.transfer(());
206        let (request, address, instance_info) = addressed_request.into_parts();
207        let engine_ctx = context.context();
208        let engine_ctx_ = engine_ctx.clone();
209
210        // registration options for the data plane in a singe in / many out configuration
211        let options = StreamOptions::builder()
212            .context(engine_ctx.clone())
213            .enable_request_stream(false)
214            .enable_response_stream(true)
215            .build()
216            .unwrap();
217
218        // register our needs with the data plane
219        // todo - generalize this with a generic data plane object which hides the specific transports
220        let pending_connections: PendingConnections = self.resp_transport.register(options).await;
221
222        // validate and unwrap the RegisteredStream object
223        let pending_response_stream = match pending_connections.into_parts() {
224            (None, Some(recv_stream)) => recv_stream,
225            _ => {
226                panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
227            }
228        };
229
230        // separate out the connection info and the stream provider from the registered stream
231        let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
232
233        // Snapshot subject before connection_info is moved; used for cleanup.
234        let recv_subject: Option<String> =
235            serde_json::from_str::<tcp::TcpStreamConnectionInfo>(&connection_info.info)
236                .ok()
237                .map(|ci| ci.subject);
238
239        // If the instance is already tombstoned, fail fast with a migratable
240        // error instead of writing to the request plane.
241        if let (Some(subject), Some(inst)) = (&recv_subject, &instance_info) {
242            let endpoint_instance_id = inst.endpoint_instance_id();
243            if !self
244                .resp_transport
245                .associate_instance(subject, &endpoint_instance_id)
246                .await
247            {
248                return Err(anyhow::anyhow!(
249                    DynamoError::builder()
250                        .error_type(ErrorType::Disconnected)
251                        .message(
252                            "Worker removed before request could be sent (tombstoned instance)"
253                        )
254                        .build()
255                ));
256            }
257        }
258
259        // package up the connection info as part of the "header" component of the two part message
260        // used to issue the request on the
261        // todo -- this object should be automatically created by the register call, and achieved by to the two into_parts()
262        // calls. all the information here is provided by the [`StreamOptions`] object and/or the dataplane object
263        let control_message = RequestControlMessage {
264            id: engine_ctx.id().to_string(),
265            request_type: RequestType::SingleIn,
266            response_type: ResponseType::ManyOut,
267            connection_info,
268            metadata: context.metadata().clone(),
269            frontend_send_ts_ns: None,
270        };
271
272        // next build the two part message where we package the connection info and the request into
273        // a single Vec<u8> that can be sent over the wire.
274        // --- package this up in the WorkQueuePublisher ---
275        let ctrl = match serialize_control_message(&control_message) {
276            Ok(v) => v,
277            Err(e) => {
278                if let Some(subject) = &recv_subject {
279                    self.resp_transport.cancel_recv_stream(subject).await;
280                }
281                return Err(e);
282            }
283        };
284        let data = match serde_json::to_vec(&request) {
285            Ok(v) => v,
286            Err(e) => {
287                if let Some(subject) = &recv_subject {
288                    self.resp_transport.cancel_recv_stream(subject).await;
289                }
290                return Err(e.into());
291            }
292        };
293
294        tracing::trace!(
295            request_id,
296            "packaging two-part message; ctrl: {} bytes, data: {} bytes",
297            ctrl.len(),
298            data.len()
299        );
300
301        let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
302
303        // the request plane / work queue should provide a two part message codec that can be used
304        // or it should take a two part message directly
305        // todo - update this
306        let codec = TwoPartCodec::default();
307        let buffer = match codec.encode_message(msg) {
308            Ok(v) => v,
309            Err(e) => {
310                if let Some(subject) = &recv_subject {
311                    self.resp_transport.cancel_recv_stream(subject).await;
312                }
313                return Err(e.into());
314            }
315        };
316
317        REQUEST_PLANE_QUEUE_SECONDS.observe(queue_start.elapsed().as_secs_f64());
318        let tx_start = Instant::now();
319
320        // TRANSPORT ABSTRACT REQUIRED - END HERE
321
322        // Send request using unified client interface
323        tracing::trace!(
324            request_id,
325            transport = self.req_client.transport_name(),
326            address = %address,
327            "Sending request via request plane client"
328        );
329
330        // Prepare trace headers using shared helper
331        let mut headers = std::collections::HashMap::new();
332        inject_trace_headers_into_map(&mut headers);
333        headers.insert("request-id".to_string(), request_id.clone());
334
335        // Stamp send time right before the transport write so the network
336        // transit metric excludes serialization/encoding overhead.
337        let send_ts_ns = std::time::SystemTime::now()
338            .duration_since(std::time::UNIX_EPOCH)
339            .unwrap_or_default()
340            .as_nanos() as u64;
341        headers.insert("x-frontend-send-ts-ns".to_string(), send_ts_ns.to_string());
342
343        // Phase A: Frontend → Backend (network + queue + ack)
344        let _nvtx_send = dynamo_nvtx_range!("transport.tcp.send");
345        let send_result = self.req_client.send_request(address, buffer, headers).await;
346        drop(_nvtx_send);
347
348        if let Err(e) = send_result {
349            if let Some(subject) = &recv_subject {
350                self.resp_transport.cancel_recv_stream(subject).await;
351            }
352            return Err(e);
353        }
354        REQUEST_PLANE_SEND_SECONDS.observe(tx_start.elapsed().as_secs_f64());
355
356        let _nvtx_wait = dynamo_nvtx_range!("transport.tcp.wait_backend");
357        tracing::trace!(request_id, "awaiting transport handshake");
358
359        // RecvError → migratable Disconnected (watcher cancelled the subject
360        // or the worker died before establishing the response stream).
361        let response_stream = match response_stream_provider.await {
362            Ok(Ok(stream)) => stream,
363            Ok(Err(e)) => {
364                // generate() failed before any response bytes; migrate via
365                // CannotConnect since the dominant cause is a worker-local
366                // setup/version issue. The wire prologue carries only an
367                // opaque string today, so app-level rejections also retry
368                // -- safe because no side effects are visible yet. Follow-up:
369                // structured prologue error type for finer routing.
370                if let Some(subject) = &recv_subject {
371                    self.resp_transport.cancel_recv_stream(subject).await;
372                }
373                return Err(anyhow::anyhow!(
374                    DynamoError::builder()
375                        .error_type(ErrorType::CannotConnect)
376                        .message(format!(
377                            "Worker generate() failed before response stream: {e}"
378                        ))
379                        .build()
380                ));
381            }
382            Err(_recv_err) => {
383                // oneshot dropped: either the discovery watcher cancelled
384                // this subject or the worker died mid-handshake.
385                if let Some(subject) = &recv_subject {
386                    self.resp_transport.cancel_recv_stream(subject).await;
387                }
388                return Err(anyhow::anyhow!(
389                    DynamoError::builder()
390                        .error_type(ErrorType::Disconnected)
391                        .message("Worker disconnected before response stream was established")
392                        .build()
393                ));
394            }
395        };
396        drop(_nvtx_wait);
397
398        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
399        let mut is_complete_final = false;
400        let mut first_response = true;
401        let stream = tokio_stream::StreamNotifyClose::new(
402            tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
403        )
404        .filter_map(move |res| {
405            if let Some(res_bytes) = res {
406                if first_response {
407                    first_response = false;
408                    let roundtrip_ttft = tx_start.elapsed().as_secs_f64();
409                    REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS.observe(roundtrip_ttft);
410                    STAGE_DURATION_SECONDS
411                        .with_label_values(&["transport_roundtrip"])
412                        .observe(queue_start.elapsed().as_secs_f64());
413                }
414                if is_complete_final {
415                    let err = DynamoError::msg(
416                        "Response received after generation ended - this should never happen",
417                    );
418                    return Some(U::from_err(err));
419                }
420                match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
421                    Ok(item) => {
422                        is_complete_final = item.complete_final;
423                        if let Some(data) = item.data {
424                            Some(data)
425                        } else if is_complete_final {
426                            None
427                        } else {
428                            let err = DynamoError::msg(
429                                "Empty response received - this should never happen",
430                            );
431                            Some(U::from_err(err))
432                        }
433                    }
434                    Err(err) => {
435                        // legacy log print
436                        let json_str = String::from_utf8_lossy(&res_bytes);
437                        tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
438
439                        Some(U::from_err(DynamoError::msg(err.to_string())))
440                    }
441                }
442            } else if is_complete_final {
443                // end of stream
444                None
445            } else if engine_ctx_.is_stopped() {
446                // Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
447                // 'is_killed()' here because it implies the stream ended abnormally which should be
448                // handled by the error branch below.
449                tracing::debug!("Request cancelled and then trying to read a response");
450                None
451            } else {
452                // stream ended unexpectedly
453                let err = DynamoError::builder()
454                    .error_type(ErrorType::Disconnected)
455                    .message("Stream ended before generation completed")
456                    .build();
457                tracing::debug!("{err}");
458                Some(U::from_err(err))
459            }
460        });
461
462        inflight_guard.disarm();
463        let stream = InflightDecStream { inner: stream };
464        Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
465    }
466}
467
#[cfg(test)]
mod tests {
    use super::{
        CONTROL_MESSAGE_MAX_BYTES, ConnectionInfo, RequestControlMessage, RequestType,
        ResponseType, serialize_control_message,
    };
    use std::collections::BTreeMap;

    /// Builds a minimal control message whose only variable part is `metadata`.
    fn base_control_message(metadata: BTreeMap<String, String>) -> RequestControlMessage {
        let connection_info = ConnectionInfo {
            transport: String::from("tcp"),
            info: String::from("{}"),
        };
        RequestControlMessage {
            id: String::from("request-123"),
            request_type: RequestType::SingleIn,
            response_type: ResponseType::ManyOut,
            connection_info,
            metadata,
            frontend_send_ts_ns: None,
        }
    }

    /// A small metadata entry keeps the encoded message within the limit.
    #[test]
    fn serialize_control_message_succeeds_under_limit() {
        let metadata = BTreeMap::from([("x-tiny-blob".to_string(), "alpha".to_string())]);
        let message = base_control_message(metadata);

        let encoded = serialize_control_message(&message)
            .expect("control message should serialize under the limit");
        assert!(encoded.len() <= CONTROL_MESSAGE_MAX_BYTES);
    }

    /// A metadata value as long as the limit itself pushes the message over it.
    #[test]
    fn serialize_control_message_errors_over_limit() {
        let oversized_value = "x".repeat(CONTROL_MESSAGE_MAX_BYTES);
        let metadata = BTreeMap::from([("x-large-blob".to_string(), oversized_value)]);
        let message = base_control_message(metadata);

        let rendered = serialize_control_message(&message)
            .expect_err("oversized control message should fail")
            .to_string();
        assert!(rendered.contains("request control message too large"));
        assert!(rendered.contains(&CONTROL_MESSAGE_MAX_BYTES.to_string()));
    }
}