Skip to main content

dynamo_runtime/pipeline/
network.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Network layer for distributed communication
5//!
6//! Provides request distribution across multiple transport protocols:
7//! - HTTP/2 for standard deployments
8//! - TCP with length-prefixed protocol for high-performance scenarios
9//! - NATS for legacy/messaging-based deployments
10
11pub mod codec;
12pub mod egress;
13pub mod ingress;
14pub mod manager;
15pub mod tcp;
16
17use crate::SystemHealth;
18use std::sync::{Arc, OnceLock};
19
20use anyhow::Result;
21use async_trait::async_trait;
22use bytes::Bytes;
23use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
24use derive_builder::Builder;
25use futures::StreamExt;
26// io::Cursor, TryStreamExt
27use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
28use serde::{Deserialize, Serialize};
29
30use super::{
31    AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO, SegmentSource,
32    ServiceBackend, ServiceEngine, SingleIn, Source, context,
33};
34use crate::metrics::MetricsHierarchy;
35use ingress::push_handler::WorkHandlerMetrics;
36use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
37
/// Shared default maximum TCP message size (32 MiB) across request-plane components.
pub(crate) const DEFAULT_TCP_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;

/// Process-wide cache backing [`get_tcp_max_message_size`]; filled on first use.
static TCP_MAX_MESSAGE_SIZE: OnceLock<usize> = OnceLock::new();

/// Read the configured TCP max message size once and share it across client,
/// server, and zero-copy decoder code paths.
///
/// The limit (in bytes) comes from the `DYN_TCP_MAX_MESSAGE_SIZE` environment variable.
/// It falls back to [`DEFAULT_TCP_MAX_MESSAGE_SIZE`] when the variable is:
/// - unset or not valid Unicode,
/// - not parseable as a `usize`, or
/// - zero.
///
/// The result is cached for the lifetime of the process, so later changes to the
/// variable have no effect.
pub(crate) fn get_tcp_max_message_size() -> usize {
    *TCP_MAX_MESSAGE_SIZE.get_or_init(|| {
        let raw = std::env::var("DYN_TCP_MAX_MESSAGE_SIZE").ok();
        parse_tcp_max_message_size(raw.as_deref())
    })
}

/// Interprets a raw `DYN_TCP_MAX_MESSAGE_SIZE` value, falling back to the default.
///
/// Surrounding whitespace (e.g. a trailing newline from an env file) is ignored.
/// Zero is rejected: a 0-byte cap is a misconfiguration, not a usable limit.
fn parse_tcp_max_message_size(raw: Option<&str>) -> usize {
    raw.and_then(|value| value.trim().parse::<usize>().ok())
        .filter(|&limit| limit > 0)
        .unwrap_or(DEFAULT_TCP_MAX_MESSAGE_SIZE)
}
53
/// Convenience bound for pipeline payloads that serde can both serialize and
/// deserialize into an owned value (`for<'de> Deserialize<'de>`).
///
/// Never implemented by hand: the blanket impl below covers every [`PipelineIO`]
/// type that meets these bounds.
pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
56
/// `WorkQueueConsumer` is a generic interface for consuming (receiving) payloads from a
/// work queue.
///
/// Each successful dequeue yields one raw payload as [`Bytes`]; failures are reported as
/// a `String` message.
#[async_trait]
pub trait WorkQueueConsumer {
    /// Returns the next payload from the queue.
    // NOTE(review): whether this waits for an item or errors on an empty queue is left to
    // implementors and is not visible here — confirm before relying on either behavior.
    async fn dequeue(&self) -> Result<Bytes, String>;
}
62
/// Identifies whether a transport stream carries requests or responses.
///
/// Serialized in `snake_case`: `"request"` / `"response"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StreamType {
    Request,
    Response,
}
69
/// Shape of the request side of a call, carried in [`RequestControlMessage`].
///
/// Mirrors the pipeline's single-in / many-in request kinds. Serialized as
/// `"single_in"` / `"many_in"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub(crate) enum RequestType {
    SingleIn,
    ManyIn,
}
76
/// Shape of the response side of a call, carried in [`RequestControlMessage`].
///
/// Mirrors the pipeline's single-out / many-out response kinds. Serialized as
/// `"single_out"` / `"many_out"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub(crate) enum ResponseType {
    SingleOut,
    ManyOut,
}
83
/// Control message describing a request.
///
/// It carries the request id, the request/response shapes, and the [`ConnectionInfo`]
/// for the stream associated with the request.
///
/// `metadata` and `frontend_send_ts_ns` are optional on the wire: they are omitted when
/// empty/absent and default when missing, so messages lacking them still deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct RequestControlMessage {
    /// Identifier of the request.
    pub(crate) id: String,
    /// Shape of the request side (single item or many).
    pub(crate) request_type: RequestType,
    /// Shape of the response side (single item or many).
    pub(crate) response_type: ResponseType,
    /// Connection details for the stream associated with this request.
    pub(crate) connection_info: ConnectionInfo,
    /// Free-form string key/value metadata.
    ///
    /// Not serialized when empty; defaults to empty when absent. `BTreeMap` keeps key
    /// order deterministic.
    #[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
    pub(crate) metadata: std::collections::BTreeMap<String, String>,
    /// Wall-clock send timestamp (nanos since UNIX epoch) for transport latency breakdown.
    /// Uses `SystemTime` so accuracy depends on NTP sync between frontend and backend hosts.
    /// Reliable for single-machine profiling; treat cross-host values as approximate.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub(crate) frontend_send_ts_ns: Option<u64>,
}
98
/// In-band control signal for a stream.
///
/// `StreamSender::send_control` sends it JSON-encoded as a header-only
/// [`TwoPartMessage`]. Variants serialize as `"stop"`, `"kill"`, and `"sentinel"`.
// NOTE(review): the receiver-side meaning of Stop vs Kill vs Sentinel is defined by the
// stream consumers, which are not visible here — confirm before documenting per variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ControlMessage {
    Stop,
    Kill,
    Sentinel,
}
106
/// This is the first message in a `ResponseStream`.
///
/// It is not a message that gets processed by the general pipeline. It is a control
/// message that is awaited before the [`AsyncEngine::generate`] method is allowed to
/// return.
///
/// If an error is present, the [`AsyncEngine::generate`] method will return the error
/// instead of returning the `ResponseStream`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ResponseStreamPrologue {
    /// `None` when the stream is ready.
    ///
    /// `Some(message)` carries the error that [`AsyncEngine::generate`] returns in
    /// place of the stream.
    error: Option<String>,
}
117
/// One-shot channel through which a registered stream half (`T`) is delivered.
///
/// Delivery happens when the connection is established (see [`PendingConnections`]);
/// otherwise an error message (`String`) is delivered instead.
pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
119
/// Holds the optional one-shot cleanup closure of a `RegisteredStream`.
///
/// The `Drop` impl lives on this wrapper rather than on `RegisteredStream`. Rust forbids
/// moving fields out of a type that implements `Drop`, so keeping it here lets
/// `RegisteredStream::into_parts()` destructure the public fields directly.
struct Cleanup(Option<Box<dyn FnOnce() + Send + 'static>>);

impl Drop for Cleanup {
    fn drop(&mut self) {
        // An empty slot means nothing was installed, or it was disarmed.
        let Some(run_cleanup) = self.0.take() else {
            return;
        };
        run_cleanup();
    }
}
131
132/// Awaitable handle for a stream sender or receiver. Drop without calling
133/// [`into_parts()`] runs the optional cleanup closure, removing the
134/// registration from the stream server's maps.
135pub struct RegisteredStream<T> {
136    pub connection_info: ConnectionInfo,
137    pub stream_provider: StreamProvider<T>,
138    cleanup: Cleanup,
139}
140
141impl<T> std::fmt::Debug for RegisteredStream<T> {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        f.debug_struct("RegisteredStream")
144            .field("connection_info", &self.connection_info)
145            .finish_non_exhaustive()
146    }
147}
148
149impl<T> RegisteredStream<T> {
150    pub(crate) fn new(connection_info: ConnectionInfo, stream_provider: StreamProvider<T>) -> Self {
151        Self {
152            connection_info,
153            stream_provider,
154            cleanup: Cleanup(None),
155        }
156    }
157
158    pub(crate) fn with_cleanup<F>(mut self, cleanup: F) -> Self
159    where
160        F: FnOnce() + Send + 'static,
161    {
162        self.cleanup.0 = Some(Box::new(cleanup));
163        self
164    }
165
166    /// Consume the registration, disarming the RAII cleanup. Caller takes
167    /// responsibility for cleanup if the stream provider is never awaited.
168    pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
169        let Self {
170            connection_info,
171            stream_provider,
172            mut cleanup,
173        } = self;
174        cleanup.0.take();
175        (connection_info, stream_provider)
176    }
177}
178
179/// After registering a stream, the [`PendingConnections`] object is returned to the caller. This
180/// object can be used to await the connection to be established.
181pub struct PendingConnections {
182    pub send_stream: Option<RegisteredStream<StreamSender>>,
183    pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
184}
185
186impl PendingConnections {
187    pub fn into_parts(
188        self,
189    ) -> (
190        Option<RegisteredStream<StreamSender>>,
191        Option<RegisteredStream<StreamReceiver>>,
192    ) {
193        (self.send_stream, self.recv_stream)
194    }
195}
196
/// A [`ResponseService`] implements a service in which a context (for a specific subject)
/// is associated with a stream of responses.
#[async_trait::async_trait]
pub trait ResponseService {
    /// Registers the stream(s) requested by `options` and returns handles to await their
    /// connections.
    async fn register(&self, options: StreamOptions) -> PendingConnections;
}
203
#[cfg(test)]
mod registered_stream_tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Placeholder connection info; its contents are irrelevant to these tests.
    fn dummy_conn_info() -> ConnectionInfo {
        ConnectionInfo {
            transport: String::from("test"),
            info: String::from("{}"),
        }
    }

    /// Builds a cleanup closure that records in `fired` that it ran.
    fn mark_fired(fired: &Arc<AtomicBool>) -> impl FnOnce() + Send + 'static {
        let fired = Arc::clone(fired);
        move || fired.store(true, Ordering::SeqCst)
    }

    /// Drop without `into_parts()` must run the cleanup closure.
    #[test]
    fn drop_runs_cleanup() {
        let fired = Arc::new(AtomicBool::new(false));
        let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
        let registration =
            RegisteredStream::new(dummy_conn_info(), rx).with_cleanup(mark_fired(&fired));

        drop(registration);

        assert!(
            fired.load(Ordering::SeqCst),
            "cleanup must fire when RegisteredStream is dropped"
        );
    }

    /// `into_parts()` must disarm the cleanup.
    ///
    /// Dropping the returned halves must NOT trigger the closure; the caller has taken
    /// ownership of cleanup responsibility.
    #[test]
    fn into_parts_disarms_cleanup() {
        let fired = Arc::new(AtomicBool::new(false));
        let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
        let registration =
            RegisteredStream::new(dummy_conn_info(), rx).with_cleanup(mark_fired(&fired));

        let (conn_info, provider) = registration.into_parts();
        drop(conn_info);
        drop(provider);

        assert!(
            !fired.load(Ordering::SeqCst),
            "into_parts() must disarm the cleanup closure"
        );
    }

    /// `RegisteredStream` with no cleanup configured must drop cleanly.
    #[test]
    fn drop_without_cleanup_is_a_noop() {
        let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
        let registration: RegisteredStream<()> = RegisteredStream::new(dummy_conn_info(), rx);
        // Must not panic; there is nothing else observable to assert.
        drop(registration);
    }
}
265
266// #[derive(Debug, Clone, Serialize, Deserialize)]
267// struct Handshake {
268//     request_id: String,
269//     worker_id: Option<String>,
270//     error: Option<String>,
271// }
272
273// impl Handshake {
274//     pub fn validate(&self) -> Result<(), String> {
275//         if let Some(e) = &self.error {
276//             return Err(e.clone());
277//         }
278//         Ok(())
279//     }
280// }
281
282// this probably needs to be come a ResponseStreamSender
283// since the prologue in this scenario sender telling the receiver
284// that all is good and it's ready to send
285//
286// in the RequestStreamSender, the prologue would be coming from the
287// receiver, so the sender would have to await the prologue which if
288// was not an error, would indicate the RequestStreamReceiver is read
289// to receive data.
290pub struct StreamSender {
291    tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
292    prologue: Option<ResponseStreamPrologue>,
293}
294
295impl StreamSender {
296    pub async fn send(&self, data: Bytes) -> Result<()> {
297        Ok(self.tx.send(TwoPartMessage::from_data(data)).await?)
298    }
299
300    pub async fn send_control(&self, control: ControlMessage) -> Result<()> {
301        let bytes = serde_json::to_vec(&control)?;
302        Ok(self
303            .tx
304            .send(TwoPartMessage::from_header(bytes.into()))
305            .await?)
306    }
307
308    #[allow(clippy::needless_update)]
309    pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
310        // leaving the original logic in place for now
311        // error overrides the dissolved prologue, but the only field on `ResponseStreamPrologue` is `error`
312        // so the second argument can never be used, and the value of error passed by the caller would always be used
313        if let Some(_prologue) = self.prologue.take() {
314            // let prologue = ResponseStreamPrologue { error, ..prologue };
315            let prologue = ResponseStreamPrologue { error };
316            let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
317                Ok(b) => b.into(),
318                Err(err) => {
319                    tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
320                    return Err("Invalid prologue".to_string());
321                }
322            };
323            self.tx
324                .send(TwoPartMessage::from_header(header_bytes))
325                .await
326                .map_err(|e| e.to_string())?;
327        } else {
328            panic!("Prologue already sent; or not set; logic error");
329        }
330        Ok(())
331    }
332}
333
/// Receiving half of a registered stream.
///
/// Yields raw [`Bytes`] payloads from a bounded Tokio mpsc channel.
pub struct StreamReceiver {
    rx: tokio::sync::mpsc::Receiver<Bytes>,
}
337
/// Connection Info is encoded as JSON and then serialized again as part of the Transport
/// Layer.
///
/// The double serialization is not performance critical, as it is only done once per
/// connection. The primary reason for storing the `ConnectionInfo` as a JSON string is
/// type erasure. The Transport Layer checks the [`ConnectionInfo::transport`] type and
/// routes it to the appropriate instance of the Transport. That Transport then
/// deserializes the [`ConnectionInfo::info`] field into its internal connection info
/// object.
///
/// Optionally, this object could become strongly typed. All possible combinations of
/// transport and connection info would then need to be enumerated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionInfo {
    /// Transport identifier used to route this connection to the matching transport.
    pub transport: String,
    /// Transport-specific connection details, encoded as a JSON string.
    pub info: String,
}
352
/// When registering a new TransportStream on the server, the caller specifies whether the
/// stream is a sender, a receiver, or both.
///
/// Senders and Receivers share a Context, but result in separate TCP socket connections
/// to the server. Internally, we may use broadcast channels to coordinate the internal
/// control messages between the sender and receiver socket connections.
///
/// Build instances with `StreamOptions::builder()`. `context`, `enable_request_stream`,
/// and `enable_response_stream` have no builder default and must be set before `build()`.
#[derive(Clone, Builder)]
pub struct StreamOptions {
    /// Engine context shared by the streams registered with these options.
    pub context: Arc<dyn AsyncEngineContext>,

    /// Register with the server that this connection will have a server-side Sender
    /// that can be picked up by the Request/Forward pipeline.
    ///
    /// TODO - note, this option is currently not implemented and will cause a panic
    pub enable_request_stream: bool,

    /// Register with the server that this connection will have a server-side Receiver
    /// that can be picked up by the Response/Reverse pipeline.
    pub enable_response_stream: bool,

    /// The number of messages the send side buffers before blocking (default 8).
    #[builder(default = "8")]
    pub send_buffer_count: usize,

    /// The number of messages the receive side buffers before blocking (default 8).
    #[builder(default = "8")]
    pub recv_buffer_count: usize,
}
382
383impl StreamOptions {
384    pub fn builder() -> StreamOptionsBuilder {
385        StreamOptionsBuilder::default()
386    }
387}
388
/// Client-side pipeline edge that delegates every request to a type-erased
/// [`AsyncTransportEngine`].
pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
    /// Transport that performs the actual network round trip.
    transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
}
392
#[cfg(test)]
mod tests {
    use super::{ConnectionInfo, RequestControlMessage, RequestType, ResponseType};
    use std::collections::BTreeMap;

    /// Control message with both optional fields in their empty/absent state.
    fn minimal_message() -> RequestControlMessage {
        RequestControlMessage {
            id: "request-123".to_string(),
            request_type: RequestType::SingleIn,
            response_type: ResponseType::ManyOut,
            connection_info: ConnectionInfo {
                transport: "tcp".to_string(),
                info: "{}".to_string(),
            },
            metadata: BTreeMap::new(),
            frontend_send_ts_ns: None,
        }
    }

    /// Messages without the optional fields must still deserialize, with defaults.
    #[test]
    fn request_control_message_defaults_missing_metadata() {
        let json = r#"{
            "id": "request-123",
            "request_type": "single_in",
            "response_type": "many_out",
            "connection_info": {
                "transport": "tcp",
                "info": "{}"
            }
        }"#;

        let message: RequestControlMessage =
            serde_json::from_str(json).expect("control message should deserialize");

        assert_eq!(message.id, "request-123");
        assert!(matches!(message.request_type, RequestType::SingleIn));
        assert!(matches!(message.response_type, ResponseType::ManyOut));
        assert_eq!(message.connection_info.transport, "tcp");
        assert_eq!(message.connection_info.info, "{}");
        assert!(message.metadata.is_empty());
        assert!(message.frontend_send_ts_ns.is_none());
    }

    /// Empty `metadata` and absent `frontend_send_ts_ns` must be omitted on the wire.
    ///
    /// This is the serialize-side counterpart of the defaulting test above.
    #[test]
    fn request_control_message_omits_empty_optional_fields() {
        let value =
            serde_json::to_value(minimal_message()).expect("control message should serialize");
        let object = value
            .as_object()
            .expect("control message should serialize to a JSON object");

        assert!(!object.contains_key("metadata"));
        assert!(!object.contains_key("frontend_send_ts_ns"));
        assert_eq!(object["request_type"], "single_in");
        assert_eq!(object["response_type"], "many_out");
    }

    /// Populated optional fields must survive a serialize/deserialize round trip.
    #[test]
    fn request_control_message_round_trips_optional_fields() {
        let mut message = minimal_message();
        message
            .metadata
            .insert("trace_id".to_string(), "abc".to_string());
        message.frontend_send_ts_ns = Some(1_700_000_000_000_000_000);

        let json = serde_json::to_string(&message).expect("control message should serialize");
        let decoded: RequestControlMessage =
            serde_json::from_str(&json).expect("control message should deserialize");

        assert_eq!(decoded.id, "request-123");
        assert_eq!(
            decoded.metadata.get("trace_id").map(String::as_str),
            Some("abc")
        );
        assert_eq!(decoded.frontend_send_ts_ns, Some(1_700_000_000_000_000_000));
    }
}
421
422#[async_trait]
423impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
424    for Egress<SingleIn<T>, ManyOut<U>>
425where
426    T: Data + Serialize,
427    U: for<'de> Deserialize<'de> + Data,
428{
429    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
430        self.transport_engine.generate(request).await
431    }
432}
433
/// Entry point that wraps a pipeline [`SegmentSource`], plus optional work-handler
/// metrics and a health-check notifier.
///
/// Every slot is a `OnceLock`, so each can be filled at most once, through `&self`,
/// after the `Arc<Ingress>` is created.
pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
    /// Pipeline segment this ingress drives; set once via `attach`.
    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
    /// Per-endpoint work-handler metrics; set once via `add_metrics`.
    metrics: OnceLock<Arc<WorkHandlerMetrics>>,
    /// Endpoint-specific notifier for health check timer resets
    endpoint_health_check_notifier: OnceLock<Arc<tokio::sync::Notify>>,
}
440
441impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
442    pub fn new() -> Arc<Self> {
443        Arc::new(Self {
444            segment: OnceLock::new(),
445            metrics: OnceLock::new(),
446            endpoint_health_check_notifier: OnceLock::new(),
447        })
448    }
449
450    pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
451        self.segment
452            .set(segment)
453            .map_err(|_| anyhow::anyhow!("Segment already set"))
454    }
455
456    pub fn add_metrics(
457        &self,
458        endpoint: &crate::component::Endpoint,
459        metrics_labels: Option<&[(&str, &str)]>,
460    ) -> Result<()> {
461        let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
462            .map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;
463
464        // Register global transport breakdown metrics (idempotent)
465        crate::metrics::work_handler_perf::ensure_work_handler_perf_metrics_registered(
466            endpoint.get_metrics_registry(),
467        );
468
469        // Register worker-pool saturation metrics (idempotent). These are
470        // process-global and shared across all endpoints attached to the
471        // same shared TCP server.
472        crate::metrics::work_handler_pool::ensure_work_handler_pool_metrics_registered(
473            endpoint.get_metrics_registry(),
474        );
475
476        self.metrics
477            .set(Arc::new(metrics))
478            .map_err(|_| anyhow::anyhow!("Metrics already set"))
479    }
480
481    pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
482        let ingress = Ingress::new();
483        ingress.attach(segment)?;
484        Ok(ingress)
485    }
486
487    pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
488        let ingress = Ingress::new();
489        ingress.attach(segment)?;
490        Ok(ingress)
491    }
492
493    pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
494        let frontend = SegmentSource::<Req, Resp>::new();
495        let backend = ServiceBackend::from_engine(engine);
496
497        // create the pipeline
498        let pipeline = frontend.link(backend)?.link(frontend)?;
499
500        let ingress = Ingress::new();
501        ingress.attach(pipeline)?;
502
503        Ok(ingress)
504    }
505
506    /// Helper method to access metrics if available
507    fn metrics(&self) -> Option<&Arc<WorkHandlerMetrics>> {
508        self.metrics.get()
509    }
510}
511
/// Handler for work that arrives as a raw `Bytes` payload.
// NOTE(review): the payload's encoding is defined by implementors/callers that are not
// visible here — confirm before documenting it.
#[async_trait]
pub trait PushWorkHandler: Send + Sync {
    /// Processes one `payload`. `request_id` is passed through when the caller has one.
    ///
    /// # Errors
    /// Returns a [`PipelineError`] if the payload cannot be handled.
    async fn handle_payload(
        &self,
        payload: Bytes,
        request_id: Option<String>,
    ) -> Result<(), PipelineError>;

    /// Add metrics to the handler for `endpoint`.
    ///
    /// `metrics_labels` optionally supplies extra label pairs.
    fn add_metrics(
        &self,
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<()>;

    /// Set the endpoint-specific notifier for health check timer resets.
    ///
    /// The default implementation ignores the notifier and returns `Ok(())`, so existing
    /// handlers need not override it.
    fn set_endpoint_health_check_notifier(
        &self,
        _notifier: Arc<tokio::sync::Notify>,
    ) -> Result<()> {
        // Default implementation for backwards compatibility
        Ok(())
    }
}
536
/// `NetworkStreamWrapper` is a simple wrapper used to detect proper stream termination
/// in network communication between ingress and egress components.
///
/// **Purpose**: This wrapper solves the problem of detecting whether a stream ended
/// gracefully or was cut off prematurely (e.g., due to network issues).
///
/// **Design Rationale**:
/// - Cannot use `Annotated` directly because the generic type `U` varies:
///   - Sometimes `U = Annotated<...>`
///   - Sometimes `U = LLMEngineOutput<...>`
/// - Using `Annotated` would require double-wrapping like `Annotated<Annotated<...>>`
/// - A simple wrapper is cleaner and more straightforward
///
/// **Stream Flow**:
/// ```text
/// At AsyncEngine:
///   response 1 -> response 2 -> response 3 -> <end>
///
/// Between ingress/egress:
///   response 1 <end=false> -> response 2 <end=false> -> response 3 <end=false> -> (null) <end=true>
///
/// At client:
///   response 1 -> response 2 -> response 3 -> <end>
/// ```
///
/// **Error Handling**:
/// If the stream is cut off before proper termination, the egress is responsible for
/// injecting an error response to communicate the incomplete stream to the client:
/// ```text
/// At AsyncEngine:
///   response 1 -> ... <without end flag>
///
/// At egress:
///   response 1 <end=false> -> <stream ended without end flag -> convert to error>
///
/// At client:
///   response 1 -> error response
/// ```
///
/// The detection must be done at the egress level, because premature stream termination
/// can be due to network issues that only the egress component can detect.
///
/// TODO: Detect end-of-stream using Server-Sent Events (SSE). This will be removed.
#[derive(Serialize, Deserialize, Debug)]
pub struct NetworkStreamWrapper<U> {
    /// The wrapped response item; `None` on the terminal end-of-stream marker.
    /// Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<U>,
    /// `true` only on the terminal marker that signals the stream ended gracefully.
    pub complete_final: bool,
}