Skip to main content

dynamo_runtime/pipeline/
network.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Network layer for distributed communication
5//!
6//! Provides request distribution across multiple transport protocols:
7//! - HTTP/2 for standard deployments
8//! - TCP with length-prefixed protocol for high-performance scenarios
9//! - NATS for legacy/messaging-based deployments
10
11pub mod codec;
12pub mod egress;
13pub mod ingress;
14pub mod manager;
15pub mod tcp;
16
17use crate::SystemHealth;
18use std::sync::{Arc, OnceLock};
19
20use anyhow::Result;
21use async_trait::async_trait;
22use bytes::Bytes;
23use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
24use derive_builder::Builder;
25use futures::StreamExt;
26// io::Cursor, TryStreamExt
27use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
28use serde::{Deserialize, Serialize};
29
30use super::{
31    AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO, SegmentSource,
32    ServiceBackend, ServiceEngine, SingleIn, Source, context,
33};
34use crate::metrics::MetricsHierarchy;
35use ingress::push_handler::WorkHandlerMetrics;
36use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
37
/// Shared default maximum TCP message size across request-plane components.
///
/// Evaluates to `32 * 1024 * 1024` bytes (32 MiB).
pub(crate) const DEFAULT_TCP_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;

/// Process-wide cache of the resolved limit. It is written at most once, by the
/// first call to [`get_tcp_max_message_size`].
static TCP_MAX_MESSAGE_SIZE: OnceLock<usize> = OnceLock::new();

/// Read the configured TCP max message size once and share it across client,
/// server, and zero-copy decoder code paths.
///
/// The first call consults the `DYN_TCP_MAX_MESSAGE_SIZE` environment variable.
/// If the variable is missing, is not valid Unicode, or does not parse as a
/// `usize`, [`DEFAULT_TCP_MAX_MESSAGE_SIZE`] is used instead. The result is
/// cached, so changing the environment variable after the first call has no
/// effect.
pub(crate) fn get_tcp_max_message_size() -> usize {
    /// Resolve the limit from the environment, falling back to the default.
    fn resolve_from_env() -> usize {
        match std::env::var("DYN_TCP_MAX_MESSAGE_SIZE") {
            Ok(raw) => raw
                .parse::<usize>()
                .unwrap_or(DEFAULT_TCP_MAX_MESSAGE_SIZE),
            Err(_) => DEFAULT_TCP_MAX_MESSAGE_SIZE,
        }
    }

    *TCP_MAX_MESSAGE_SIZE.get_or_init(resolve_from_env)
}
53
54pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
55impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
56
57/// `WorkQueueConsumer` is a generic interface for a work queue that can be used to send and receive
58#[async_trait]
59pub trait WorkQueueConsumer {
60    async fn dequeue(&self) -> Result<Bytes, String>;
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
64#[serde(rename_all = "snake_case")]
65pub enum StreamType {
66    Request,
67    Response,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
71#[serde(rename_all = "snake_case")]
72pub enum ControlMessage {
73    Stop,
74    Kill,
75    Sentinel,
76}
77
78/// This is the first message in a `ResponseStream`. This is not a message that gets process
79/// by the general pipeline, but is a control message that is awaited before the
80/// [`AsyncEngine::generate`] method is allowed to return.
81///
82/// If an error is present, the [`AsyncEngine::generate`] method will return the error instead
83/// of returning the `ResponseStream`.
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
85pub struct ResponseStreamPrologue {
86    error: Option<String>,
87}
88
89pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
90
/// Drop guard holding an optional one-shot closure.
///
/// Owning `Drop` here (rather than on `RegisteredStream`) lets `into_parts()`
/// move the public fields out by plain destructure.
struct Cleanup(Option<Box<dyn FnOnce() + Send + 'static>>);

impl Drop for Cleanup {
    /// Run the stored closure, if one is still present.
    fn drop(&mut self) {
        // `take()` leaves `None` behind, so the closure can run at most once.
        let Some(run) = self.0.take() else {
            return;
        };
        run();
    }
}
102
103/// Awaitable handle for a stream sender or receiver. Drop without calling
104/// [`into_parts()`] runs the optional cleanup closure, removing the
105/// registration from the stream server's maps.
106pub struct RegisteredStream<T> {
107    pub connection_info: ConnectionInfo,
108    pub stream_provider: StreamProvider<T>,
109    cleanup: Cleanup,
110}
111
112impl<T> std::fmt::Debug for RegisteredStream<T> {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("RegisteredStream")
115            .field("connection_info", &self.connection_info)
116            .finish_non_exhaustive()
117    }
118}
119
120impl<T> RegisteredStream<T> {
121    pub(crate) fn new(connection_info: ConnectionInfo, stream_provider: StreamProvider<T>) -> Self {
122        Self {
123            connection_info,
124            stream_provider,
125            cleanup: Cleanup(None),
126        }
127    }
128
129    pub(crate) fn with_cleanup<F>(mut self, cleanup: F) -> Self
130    where
131        F: FnOnce() + Send + 'static,
132    {
133        self.cleanup.0 = Some(Box::new(cleanup));
134        self
135    }
136
137    /// Consume the registration, disarming the RAII cleanup. Caller takes
138    /// responsibility for cleanup if the stream provider is never awaited.
139    pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
140        let Self {
141            connection_info,
142            stream_provider,
143            mut cleanup,
144        } = self;
145        cleanup.0.take();
146        (connection_info, stream_provider)
147    }
148}
149
150/// After registering a stream, the [`PendingConnections`] object is returned to the caller. This
151/// object can be used to await the connection to be established.
152pub struct PendingConnections {
153    pub send_stream: Option<RegisteredStream<StreamSender>>,
154    pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
155}
156
157impl PendingConnections {
158    pub fn into_parts(
159        self,
160    ) -> (
161        Option<RegisteredStream<StreamSender>>,
162        Option<RegisteredStream<StreamReceiver>>,
163    ) {
164        (self.send_stream, self.recv_stream)
165    }
166}
167
168/// A [`ResponseService`] implements a services in which a context a specific subject with will
169/// be associated with a stream of responses.
170#[async_trait::async_trait]
171pub trait ResponseService {
172    async fn register(&self, options: StreamOptions) -> PendingConnections;
173}
174
#[cfg(test)]
mod registered_stream_tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Placeholder `ConnectionInfo`; its contents are irrelevant to these tests.
    fn dummy_conn_info() -> ConnectionInfo {
        ConnectionInfo {
            transport: String::from("test"),
            info: String::from("{}"),
        }
    }

    /// Drop without `into_parts()` must run the cleanup closure.
    #[test]
    fn drop_runs_cleanup() {
        let fired = Arc::new(AtomicBool::new(false));
        let observer = Arc::clone(&fired);

        let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
        let registration = RegisteredStream::new(dummy_conn_info(), rx)
            .with_cleanup(move || fired.store(true, Ordering::SeqCst));

        drop(registration);

        let ran = observer.load(Ordering::SeqCst);
        assert!(ran, "cleanup must fire when RegisteredStream is dropped");
    }

    /// `into_parts()` must disarm the cleanup. After the call, dropping the
    /// returned halves must NOT trigger the closure -- the caller has taken
    /// ownership of cleanup responsibility.
    #[test]
    fn into_parts_disarms_cleanup() {
        let fired = Arc::new(AtomicBool::new(false));
        let observer = Arc::clone(&fired);

        let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
        let registration = RegisteredStream::new(dummy_conn_info(), rx)
            .with_cleanup(move || fired.store(true, Ordering::SeqCst));

        let (conn, provider) = registration.into_parts();
        drop(conn);
        drop(provider);

        let ran = observer.load(Ordering::SeqCst);
        assert!(!ran, "into_parts() must disarm the cleanup closure");
    }

    /// `RegisteredStream` with no cleanup configured must drop cleanly.
    #[test]
    fn drop_without_cleanup_is_a_noop() {
        let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
        let registration: RegisteredStream<()> = RegisteredStream::new(dummy_conn_info(), rx);
        // Must not panic; there is nothing else observable to assert.
        drop(registration);
    }
}
236
237// #[derive(Debug, Clone, Serialize, Deserialize)]
238// struct Handshake {
239//     request_id: String,
240//     worker_id: Option<String>,
241//     error: Option<String>,
242// }
243
244// impl Handshake {
245//     pub fn validate(&self) -> Result<(), String> {
246//         if let Some(e) = &self.error {
247//             return Err(e.clone());
248//         }
249//         Ok(())
250//     }
251// }
252
253// this probably needs to be come a ResponseStreamSender
254// since the prologue in this scenario sender telling the receiver
255// that all is good and it's ready to send
256//
257// in the RequestStreamSender, the prologue would be coming from the
258// receiver, so the sender would have to await the prologue which if
259// was not an error, would indicate the RequestStreamReceiver is read
260// to receive data.
261pub struct StreamSender {
262    tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
263    prologue: Option<ResponseStreamPrologue>,
264}
265
266impl StreamSender {
267    pub async fn send(&self, data: Bytes) -> Result<()> {
268        Ok(self.tx.send(TwoPartMessage::from_data(data)).await?)
269    }
270
271    pub async fn send_control(&self, control: ControlMessage) -> Result<()> {
272        let bytes = serde_json::to_vec(&control)?;
273        Ok(self
274            .tx
275            .send(TwoPartMessage::from_header(bytes.into()))
276            .await?)
277    }
278
279    #[allow(clippy::needless_update)]
280    pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
281        // leaving the original logic in place for now
282        // error overrides the dissolved prologue, but the only field on `ResponseStreamPrologue` is `error`
283        // so the second argument can never be used, and the value of error passed by the caller would always be used
284        if let Some(_prologue) = self.prologue.take() {
285            // let prologue = ResponseStreamPrologue { error, ..prologue };
286            let prologue = ResponseStreamPrologue { error };
287            let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
288                Ok(b) => b.into(),
289                Err(err) => {
290                    tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
291                    return Err("Invalid prologue".to_string());
292                }
293            };
294            self.tx
295                .send(TwoPartMessage::from_header(header_bytes))
296                .await
297                .map_err(|e| e.to_string())?;
298        } else {
299            panic!("Prologue already sent; or not set; logic error");
300        }
301        Ok(())
302    }
303}
304
305pub struct StreamReceiver {
306    rx: tokio::sync::mpsc::Receiver<Bytes>,
307}
308
309/// Connection Info is encoded as JSON and then again serialized has part of the Transport
310/// Layer. The double serialization is not performance critical as it is only done once per
311/// connection. The primary reason storing the ConnecitonInfo has a JSON string is for type
312/// erasure. The Transport Layer will check the [`ConnectionInfo::transport`] type and then
313/// route it to the appropriate instance of the Transport, which will then deserialize the
314/// [`ConnectionInfo::info`] field to its internal connection info object.
315///
316/// Optionally, this object could become strongly typed for which all possible combinations
317/// of transport and connection info would need to be enumerated.
318#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct ConnectionInfo {
320    pub transport: String,
321    pub info: String,
322}
323
324/// When registering a new TransportStream on the server, the caller specifies if the
325/// stream is a sender, receiver or both.
326///
327/// Senders and Receivers are with share a Context, but result in separate tcp socket
328/// connections to the server. Internally, we may use bcast channels to coordinate the
329/// internal control messages between the sender and receiver socket connections.
330#[derive(Clone, Builder)]
331pub struct StreamOptions {
332    /// Context
333    pub context: Arc<dyn AsyncEngineContext>,
334
335    /// Register with the server that this connection will have a server-side Sender
336    /// that can be picked up by the Request/Forward pipeline
337    ///
338    /// TODO - note, this option is currently not implemented and will cause a panic
339    pub enable_request_stream: bool,
340
341    /// Register with the server that this connection will have a server-side Receiver
342    /// that can be picked up by the Response/Reverse pipeline
343    pub enable_response_stream: bool,
344
345    /// The number of messages to buffer before blocking
346    #[builder(default = "8")]
347    pub send_buffer_count: usize,
348
349    /// The number of messages to buffer before blocking
350    #[builder(default = "8")]
351    pub recv_buffer_count: usize,
352}
353
354impl StreamOptions {
355    pub fn builder() -> StreamOptionsBuilder {
356        StreamOptionsBuilder::default()
357    }
358}
359
360pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
361    transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
362}
363
364#[async_trait]
365impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
366    for Egress<SingleIn<T>, ManyOut<U>>
367where
368    T: Data + Serialize,
369    U: for<'de> Deserialize<'de> + Data,
370{
371    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
372        self.transport_engine.generate(request).await
373    }
374}
375
376#[derive(Debug, Clone, Serialize, Deserialize)]
377#[serde(rename_all = "snake_case")]
378enum RequestType {
379    SingleIn,
380    ManyIn,
381}
382
383#[derive(Debug, Clone, Serialize, Deserialize)]
384#[serde(rename_all = "snake_case")]
385enum ResponseType {
386    SingleOut,
387    ManyOut,
388}
389
390#[derive(Debug, Clone, Serialize, Deserialize)]
391struct RequestControlMessage {
392    id: String,
393    request_type: RequestType,
394    response_type: ResponseType,
395    connection_info: ConnectionInfo,
396    /// Wall-clock send timestamp (nanos since UNIX epoch) for transport latency breakdown.
397    /// Uses `SystemTime` so accuracy depends on NTP sync between frontend and backend hosts.
398    /// Reliable for single-machine profiling; treat cross-host values as approximate.
399    #[serde(default, skip_serializing_if = "Option::is_none")]
400    frontend_send_ts_ns: Option<u64>,
401}
402
403pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
404    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
405    metrics: OnceLock<Arc<WorkHandlerMetrics>>,
406    /// Endpoint-specific notifier for health check timer resets
407    endpoint_health_check_notifier: OnceLock<Arc<tokio::sync::Notify>>,
408}
409
410impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
411    pub fn new() -> Arc<Self> {
412        Arc::new(Self {
413            segment: OnceLock::new(),
414            metrics: OnceLock::new(),
415            endpoint_health_check_notifier: OnceLock::new(),
416        })
417    }
418
419    pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
420        self.segment
421            .set(segment)
422            .map_err(|_| anyhow::anyhow!("Segment already set"))
423    }
424
425    pub fn add_metrics(
426        &self,
427        endpoint: &crate::component::Endpoint,
428        metrics_labels: Option<&[(&str, &str)]>,
429    ) -> Result<()> {
430        let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
431            .map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;
432
433        // Register global transport breakdown metrics (idempotent)
434        crate::metrics::work_handler_perf::ensure_work_handler_perf_metrics_registered(
435            endpoint.get_metrics_registry(),
436        );
437
438        // Register worker-pool saturation metrics (idempotent). These are
439        // process-global and shared across all endpoints attached to the
440        // same shared TCP server.
441        crate::metrics::work_handler_pool::ensure_work_handler_pool_metrics_registered(
442            endpoint.get_metrics_registry(),
443        );
444
445        self.metrics
446            .set(Arc::new(metrics))
447            .map_err(|_| anyhow::anyhow!("Metrics already set"))
448    }
449
450    pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
451        let ingress = Ingress::new();
452        ingress.attach(segment)?;
453        Ok(ingress)
454    }
455
456    pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
457        let ingress = Ingress::new();
458        ingress.attach(segment)?;
459        Ok(ingress)
460    }
461
462    pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
463        let frontend = SegmentSource::<Req, Resp>::new();
464        let backend = ServiceBackend::from_engine(engine);
465
466        // create the pipeline
467        let pipeline = frontend.link(backend)?.link(frontend)?;
468
469        let ingress = Ingress::new();
470        ingress.attach(pipeline)?;
471
472        Ok(ingress)
473    }
474
475    /// Helper method to access metrics if available
476    fn metrics(&self) -> Option<&Arc<WorkHandlerMetrics>> {
477        self.metrics.get()
478    }
479}
480
481#[async_trait]
482pub trait PushWorkHandler: Send + Sync {
483    async fn handle_payload(
484        &self,
485        payload: Bytes,
486        request_id: Option<String>,
487    ) -> Result<(), PipelineError>;
488
489    /// Add metrics to the handler
490    fn add_metrics(
491        &self,
492        endpoint: &crate::component::Endpoint,
493        metrics_labels: Option<&[(&str, &str)]>,
494    ) -> Result<()>;
495
496    /// Set the endpoint-specific notifier for health check timer resets
497    fn set_endpoint_health_check_notifier(
498        &self,
499        _notifier: Arc<tokio::sync::Notify>,
500    ) -> Result<()> {
501        // Default implementation for backwards compatibility
502        Ok(())
503    }
504}
505
506/*
507/// `NetworkStreamWrapper` is a simple wrapper used to detect proper stream termination
508/// in network communication between ingress and egress components.
509///
510/// **Purpose**: This wrapper solves the problem of detecting whether a stream ended
511/// gracefully or was cut off prematurely (e.g., due to network issues).
512///
513/// **Design Rationale**:
514/// - Cannot use `Annotated` directly because the generic type `U` varies:
515///   - Sometimes `U = Annotated<...>`
516///   - Sometimes `U = LLMEngineOutput<...>`
517/// - Using `Annotated` would require double-wrapping like `Annotated<Annotated<...>>`
518/// - A simple wrapper is cleaner and more straightforward
519///
520/// **Stream Flow**:
521/// ```
522/// At AsyncEngine:
523///   response 1 -> response 2 -> response 3 -> <end>
524///
525/// Between ingress/egress:
526///   response 1 <end=false> -> response 2 <end=false> -> response 3 <end=false> -> (null) <end=true>
527///
528/// At client:
529///   response 1 -> response 2 -> response 3 -> <end>
530/// ```
531///
532/// **Error Handling**:
533/// If the stream is cut off before proper termination, the egress is responsible for
534/// injecting an error response to communicate the incomplete stream to the client:
535/// ```
536/// At AsyncEngine:
537///   response 1 -> ... <without end flag>
538///
539/// At egress:
540///   response 1 <end=false> -> <stream ended without end flag -> convert to error>
541///
542/// At client:
543///   response 1 -> error response
544/// ```
545///
546/// The detection must be done at egress level because premature stream termination
547/// can be due to network issues that only the egress component can detect.
548*/
549/// TODO: Detect end-of-stream using Server-Sent Events (SSE). This will be removed.
550#[derive(Serialize, Deserialize, Debug)]
551pub struct NetworkStreamWrapper<U> {
552    #[serde(skip_serializing_if = "Option::is_none")]
553    pub data: Option<U>,
554    pub complete_final: bool,
555}