Skip to main content

dynamo_runtime/pipeline/
network.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Network layer for distributed communication
5//!
6//! Provides request distribution across multiple transport protocols:
7//! - HTTP/2 for standard deployments
8//! - TCP with length-prefixed protocol for high-performance scenarios
9//! - NATS for legacy/messaging-based deployments
10
11pub mod codec;
12pub mod egress;
13pub mod ingress;
14pub mod manager;
15pub mod tcp;
16
17use crate::SystemHealth;
18use std::sync::{Arc, OnceLock};
19
20use anyhow::Result;
21use async_trait::async_trait;
22use bytes::Bytes;
23use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
24use derive_builder::Builder;
25use futures::StreamExt;
26// io::Cursor, TryStreamExt
27use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
28use serde::{Deserialize, Serialize};
29
30use super::{
31    AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO, SegmentSource,
32    ServiceBackend, ServiceEngine, SingleIn, Source, context,
33};
34use ingress::push_handler::WorkHandlerMetrics;
35
36// Add Prometheus metrics types
37use crate::metrics::MetricsHierarchy;
38use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
39
40pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
41impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
42
43/// `WorkQueueConsumer` is a generic interface for a work queue that can be used to send and receive
44#[async_trait]
45pub trait WorkQueueConsumer {
46    async fn dequeue(&self) -> Result<Bytes, String>;
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
50#[serde(rename_all = "snake_case")]
51pub enum StreamType {
52    Request,
53    Response,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
57#[serde(rename_all = "snake_case")]
58pub enum ControlMessage {
59    Stop,
60    Kill,
61    Sentinel,
62}
63
64/// This is the first message in a `ResponseStream`. This is not a message that gets process
65/// by the general pipeline, but is a control message that is awaited before the
66/// [`AsyncEngine::generate`] method is allowed to return.
67///
68/// If an error is present, the [`AsyncEngine::generate`] method will return the error instead
69/// of returning the `ResponseStream`.
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
71pub struct ResponseStreamPrologue {
72    error: Option<String>,
73}
74
75pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
76
77/// The [`RegisteredStream`] object is acquired from a [`StreamProvider`] and is used to provide
78/// an awaitable receiver which will the `T` which is either a stream writer for a request stream
79/// or a stream reader for a response stream.
80///
81/// make this an raii object linked to some stream provider
82/// if the object has not been awaited an the type T unwrapped, the registered stream
83/// on the stream provider will be informed and can clean up a stream that will never
84/// be connected.
85#[derive(Debug)]
86pub struct RegisteredStream<T> {
87    pub connection_info: ConnectionInfo,
88    pub stream_provider: StreamProvider<T>,
89}
90
91impl<T> RegisteredStream<T> {
92    pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
93        (self.connection_info, self.stream_provider)
94    }
95}
96
97/// After registering a stream, the [`PendingConnections`] object is returned to the caller. This
98/// object can be used to await the connection to be established.
99pub struct PendingConnections {
100    pub send_stream: Option<RegisteredStream<StreamSender>>,
101    pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
102}
103
104impl PendingConnections {
105    pub fn into_parts(
106        self,
107    ) -> (
108        Option<RegisteredStream<StreamSender>>,
109        Option<RegisteredStream<StreamReceiver>>,
110    ) {
111        (self.send_stream, self.recv_stream)
112    }
113}
114
115/// A [`ResponseService`] implements a services in which a context a specific subject with will
116/// be associated with a stream of responses.
117#[async_trait::async_trait]
118pub trait ResponseService {
119    async fn register(&self, options: StreamOptions) -> PendingConnections;
120}
121
122// #[derive(Debug, Clone, Serialize, Deserialize)]
123// struct Handshake {
124//     request_id: String,
125//     worker_id: Option<String>,
126//     error: Option<String>,
127// }
128
129// impl Handshake {
130//     pub fn validate(&self) -> Result<(), String> {
131//         if let Some(e) = &self.error {
132//             return Err(e.clone());
133//         }
134//         Ok(())
135//     }
136// }
137
138// this probably needs to be come a ResponseStreamSender
139// since the prologue in this scenario sender telling the receiver
140// that all is good and it's ready to send
141//
142// in the RequestStreamSender, the prologue would be coming from the
143// receiver, so the sender would have to await the prologue which if
144// was not an error, would indicate the RequestStreamReceiver is read
145// to receive data.
146pub struct StreamSender {
147    tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
148    prologue: Option<ResponseStreamPrologue>,
149}
150
151impl StreamSender {
152    pub async fn send(&self, data: Bytes) -> Result<()> {
153        Ok(self.tx.send(TwoPartMessage::from_data(data)).await?)
154    }
155
156    pub async fn send_control(&self, control: ControlMessage) -> Result<()> {
157        let bytes = serde_json::to_vec(&control)?;
158        Ok(self
159            .tx
160            .send(TwoPartMessage::from_header(bytes.into()))
161            .await?)
162    }
163
164    #[allow(clippy::needless_update)]
165    pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
166        // leaving the original logic in place for now
167        // error overrides the dissolved prologue, but the only field on `ResponseStreamPrologue` is `error`
168        // so the second argument can never be used, and the value of error passed by the caller would always be used
169        if let Some(_prologue) = self.prologue.take() {
170            // let prologue = ResponseStreamPrologue { error, ..prologue };
171            let prologue = ResponseStreamPrologue { error };
172            let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
173                Ok(b) => b.into(),
174                Err(err) => {
175                    tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
176                    return Err("Invalid prologue".to_string());
177                }
178            };
179            self.tx
180                .send(TwoPartMessage::from_header(header_bytes))
181                .await
182                .map_err(|e| e.to_string())?;
183        } else {
184            panic!("Prologue already sent; or not set; logic error");
185        }
186        Ok(())
187    }
188}
189
190pub struct StreamReceiver {
191    rx: tokio::sync::mpsc::Receiver<Bytes>,
192}
193
194/// Connection Info is encoded as JSON and then again serialized has part of the Transport
195/// Layer. The double serialization is not performance critical as it is only done once per
196/// connection. The primary reason storing the ConnecitonInfo has a JSON string is for type
197/// erasure. The Transport Layer will check the [`ConnectionInfo::transport`] type and then
198/// route it to the appropriate instance of the Transport, which will then deserialize the
199/// [`ConnectionInfo::info`] field to its internal connection info object.
200///
201/// Optionally, this object could become strongly typed for which all possible combinations
202/// of transport and connection info would need to be enumerated.
203#[derive(Debug, Clone, Serialize, Deserialize)]
204pub struct ConnectionInfo {
205    pub transport: String,
206    pub info: String,
207}
208
209/// When registering a new TransportStream on the server, the caller specifies if the
210/// stream is a sender, receiver or both.
211///
212/// Senders and Receivers are with share a Context, but result in separate tcp socket
213/// connections to the server. Internally, we may use bcast channels to coordinate the
214/// internal control messages between the sender and receiver socket connections.
215#[derive(Clone, Builder)]
216pub struct StreamOptions {
217    /// Context
218    pub context: Arc<dyn AsyncEngineContext>,
219
220    /// Register with the server that this connection will have a server-side Sender
221    /// that can be picked up by the Request/Forward pipeline
222    ///
223    /// TODO - note, this option is currently not implemented and will cause a panic
224    pub enable_request_stream: bool,
225
226    /// Register with the server that this connection will have a server-side Receiver
227    /// that can be picked up by the Response/Reverse pipeline
228    pub enable_response_stream: bool,
229
230    /// The number of messages to buffer before blocking
231    #[builder(default = "8")]
232    pub send_buffer_count: usize,
233
234    /// The number of messages to buffer before blocking
235    #[builder(default = "8")]
236    pub recv_buffer_count: usize,
237}
238
239impl StreamOptions {
240    pub fn builder() -> StreamOptionsBuilder {
241        StreamOptionsBuilder::default()
242    }
243}
244
245pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
246    transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
247}
248
249#[async_trait]
250impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
251    for Egress<SingleIn<T>, ManyOut<U>>
252where
253    T: Data + Serialize,
254    U: for<'de> Deserialize<'de> + Data,
255{
256    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
257        self.transport_engine.generate(request).await
258    }
259}
260
261#[derive(Debug, Clone, Serialize, Deserialize)]
262#[serde(rename_all = "snake_case")]
263enum RequestType {
264    SingleIn,
265    ManyIn,
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
269#[serde(rename_all = "snake_case")]
270enum ResponseType {
271    SingleOut,
272    ManyOut,
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
276struct RequestControlMessage {
277    id: String,
278    request_type: RequestType,
279    response_type: ResponseType,
280    connection_info: ConnectionInfo,
281    /// Wall-clock send timestamp (nanos since UNIX epoch) for transport latency breakdown.
282    /// Uses `SystemTime` so accuracy depends on NTP sync between frontend and backend hosts.
283    /// Reliable for single-machine profiling; treat cross-host values as approximate.
284    #[serde(default, skip_serializing_if = "Option::is_none")]
285    frontend_send_ts_ns: Option<u64>,
286}
287
288pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
289    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
290    metrics: OnceLock<Arc<WorkHandlerMetrics>>,
291    /// Endpoint-specific notifier for health check timer resets
292    endpoint_health_check_notifier: OnceLock<Arc<tokio::sync::Notify>>,
293}
294
295impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
296    pub fn new() -> Arc<Self> {
297        Arc::new(Self {
298            segment: OnceLock::new(),
299            metrics: OnceLock::new(),
300            endpoint_health_check_notifier: OnceLock::new(),
301        })
302    }
303
304    pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
305        self.segment
306            .set(segment)
307            .map_err(|_| anyhow::anyhow!("Segment already set"))
308    }
309
310    pub fn add_metrics(
311        &self,
312        endpoint: &crate::component::Endpoint,
313        metrics_labels: Option<&[(&str, &str)]>,
314    ) -> Result<()> {
315        let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
316            .map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;
317
318        // Register global transport breakdown metrics (idempotent)
319        crate::metrics::work_handler_perf::ensure_work_handler_perf_metrics_registered(
320            endpoint.get_metrics_registry(),
321        );
322
323        self.metrics
324            .set(Arc::new(metrics))
325            .map_err(|_| anyhow::anyhow!("Metrics already set"))
326    }
327
328    pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
329        let ingress = Ingress::new();
330        ingress.attach(segment)?;
331        Ok(ingress)
332    }
333
334    pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
335        let ingress = Ingress::new();
336        ingress.attach(segment)?;
337        Ok(ingress)
338    }
339
340    pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
341        let frontend = SegmentSource::<Req, Resp>::new();
342        let backend = ServiceBackend::from_engine(engine);
343
344        // create the pipeline
345        let pipeline = frontend.link(backend)?.link(frontend)?;
346
347        let ingress = Ingress::new();
348        ingress.attach(pipeline)?;
349
350        Ok(ingress)
351    }
352
353    /// Helper method to access metrics if available
354    fn metrics(&self) -> Option<&Arc<WorkHandlerMetrics>> {
355        self.metrics.get()
356    }
357}
358
359#[async_trait]
360pub trait PushWorkHandler: Send + Sync {
361    async fn handle_payload(
362        &self,
363        payload: Bytes,
364        request_id: Option<String>,
365    ) -> Result<(), PipelineError>;
366
367    /// Add metrics to the handler
368    fn add_metrics(
369        &self,
370        endpoint: &crate::component::Endpoint,
371        metrics_labels: Option<&[(&str, &str)]>,
372    ) -> Result<()>;
373
374    /// Set the endpoint-specific notifier for health check timer resets
375    fn set_endpoint_health_check_notifier(
376        &self,
377        _notifier: Arc<tokio::sync::Notify>,
378    ) -> Result<()> {
379        // Default implementation for backwards compatibility
380        Ok(())
381    }
382}
383
384/*
385/// `NetworkStreamWrapper` is a simple wrapper used to detect proper stream termination
386/// in network communication between ingress and egress components.
387///
388/// **Purpose**: This wrapper solves the problem of detecting whether a stream ended
389/// gracefully or was cut off prematurely (e.g., due to network issues).
390///
391/// **Design Rationale**:
392/// - Cannot use `Annotated` directly because the generic type `U` varies:
393///   - Sometimes `U = Annotated<...>`
394///   - Sometimes `U = LLMEngineOutput<...>`
395/// - Using `Annotated` would require double-wrapping like `Annotated<Annotated<...>>`
396/// - A simple wrapper is cleaner and more straightforward
397///
398/// **Stream Flow**:
399/// ```
400/// At AsyncEngine:
401///   response 1 -> response 2 -> response 3 -> <end>
402///
403/// Between ingress/egress:
404///   response 1 <end=false> -> response 2 <end=false> -> response 3 <end=false> -> (null) <end=true>
405///
406/// At client:
407///   response 1 -> response 2 -> response 3 -> <end>
408/// ```
409///
410/// **Error Handling**:
411/// If the stream is cut off before proper termination, the egress is responsible for
412/// injecting an error response to communicate the incomplete stream to the client:
413/// ```
414/// At AsyncEngine:
415///   response 1 -> ... <without end flag>
416///
417/// At egress:
418///   response 1 <end=false> -> <stream ended without end flag -> convert to error>
419///
420/// At client:
421///   response 1 -> error response
422/// ```
423///
424/// The detection must be done at egress level because premature stream termination
425/// can be due to network issues that only the egress component can detect.
426*/
427/// TODO: Detect end-of-stream using Server-Sent Events (SSE). This will be removed.
428#[derive(Serialize, Deserialize, Debug)]
429pub struct NetworkStreamWrapper<U> {
430    #[serde(skip_serializing_if = "Option::is_none")]
431    pub data: Option<U>,
432    pub complete_final: bool,
433}