Skip to main content

dynamo_runtime/pipeline/
network.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Network layer for distributed communication
5//!
6//! Provides request distribution across multiple transport protocols:
7//! - HTTP/2 for standard deployments
8//! - TCP with length-prefixed protocol for high-performance scenarios
9//! - NATS for legacy/messaging-based deployments
10
11pub mod codec;
12pub mod egress;
13pub mod ingress;
14pub mod manager;
15pub mod tcp;
16
17use crate::SystemHealth;
18use std::sync::{Arc, OnceLock};
19
20use anyhow::Result;
21use async_trait::async_trait;
22use bytes::Bytes;
23use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
24use derive_builder::Builder;
25use futures::StreamExt;
26// io::Cursor, TryStreamExt
27use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
28use serde::{Deserialize, Serialize};
29
30use super::{
31    AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO, SegmentSource,
32    ServiceBackend, ServiceEngine, SingleIn, Source, context,
33};
34use ingress::push_handler::WorkHandlerMetrics;
35
/// Error text describing a stream that ended before generation completed.
pub const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
38
39// Add Prometheus metrics types
40use crate::metrics::MetricsHierarchy;
41use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
42
43pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
44impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
45
46/// `WorkQueueConsumer` is a generic interface for a work queue that can be used to send and receive
47#[async_trait]
48pub trait WorkQueueConsumer {
49    async fn dequeue(&self) -> Result<Bytes, String>;
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
53#[serde(rename_all = "snake_case")]
54pub enum StreamType {
55    Request,
56    Response,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
60#[serde(rename_all = "snake_case")]
61pub enum ControlMessage {
62    Stop,
63    Kill,
64    Sentinel,
65}
66
67/// This is the first message in a `ResponseStream`. This is not a message that gets process
68/// by the general pipeline, but is a control message that is awaited before the
69/// [`AsyncEngine::generate`] method is allowed to return.
70///
71/// If an error is present, the [`AsyncEngine::generate`] method will return the error instead
72/// of returning the `ResponseStream`.
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
74pub struct ResponseStreamPrologue {
75    error: Option<String>,
76}
77
78pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
79
80/// The [`RegisteredStream`] object is acquired from a [`StreamProvider`] and is used to provide
81/// an awaitable receiver which will the `T` which is either a stream writer for a request stream
82/// or a stream reader for a response stream.
83///
84/// make this an raii object linked to some stream provider
85/// if the object has not been awaited an the type T unwrapped, the registered stream
86/// on the stream provider will be informed and can clean up a stream that will never
87/// be connected.
88#[derive(Debug)]
89pub struct RegisteredStream<T> {
90    pub connection_info: ConnectionInfo,
91    pub stream_provider: StreamProvider<T>,
92}
93
94impl<T> RegisteredStream<T> {
95    pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
96        (self.connection_info, self.stream_provider)
97    }
98}
99
100/// After registering a stream, the [`PendingConnections`] object is returned to the caller. This
101/// object can be used to await the connection to be established.
102pub struct PendingConnections {
103    pub send_stream: Option<RegisteredStream<StreamSender>>,
104    pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
105}
106
107impl PendingConnections {
108    pub fn into_parts(
109        self,
110    ) -> (
111        Option<RegisteredStream<StreamSender>>,
112        Option<RegisteredStream<StreamReceiver>>,
113    ) {
114        (self.send_stream, self.recv_stream)
115    }
116}
117
118/// A [`ResponseService`] implements a services in which a context a specific subject with will
119/// be associated with a stream of responses.
120#[async_trait::async_trait]
121pub trait ResponseService {
122    async fn register(&self, options: StreamOptions) -> PendingConnections;
123}
124
125// #[derive(Debug, Clone, Serialize, Deserialize)]
126// struct Handshake {
127//     request_id: String,
128//     worker_id: Option<String>,
129//     error: Option<String>,
130// }
131
132// impl Handshake {
133//     pub fn validate(&self) -> Result<(), String> {
134//         if let Some(e) = &self.error {
135//             return Err(e.clone());
136//         }
137//         Ok(())
138//     }
139// }
140
141// this probably needs to be come a ResponseStreamSender
142// since the prologue in this scenario sender telling the receiver
143// that all is good and it's ready to send
144//
145// in the RequestStreamSender, the prologue would be coming from the
146// receiver, so the sender would have to await the prologue which if
147// was not an error, would indicate the RequestStreamReceiver is read
148// to receive data.
149pub struct StreamSender {
150    tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
151    prologue: Option<ResponseStreamPrologue>,
152}
153
154impl StreamSender {
155    pub async fn send(&self, data: Bytes) -> Result<()> {
156        Ok(self.tx.send(TwoPartMessage::from_data(data)).await?)
157    }
158
159    pub async fn send_control(&self, control: ControlMessage) -> Result<()> {
160        let bytes = serde_json::to_vec(&control)?;
161        Ok(self
162            .tx
163            .send(TwoPartMessage::from_header(bytes.into()))
164            .await?)
165    }
166
167    #[allow(clippy::needless_update)]
168    pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
169        if let Some(prologue) = self.prologue.take() {
170            let prologue = ResponseStreamPrologue { error, ..prologue };
171            let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
172                Ok(b) => b.into(),
173                Err(err) => {
174                    tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
175                    return Err("Invalid prologue".to_string());
176                }
177            };
178            self.tx
179                .send(TwoPartMessage::from_header(header_bytes))
180                .await
181                .map_err(|e| e.to_string())?;
182        } else {
183            panic!("Prologue already sent; or not set; logic error");
184        }
185        Ok(())
186    }
187}
188
189pub struct StreamReceiver {
190    rx: tokio::sync::mpsc::Receiver<Bytes>,
191}
192
193/// Connection Info is encoded as JSON and then again serialized has part of the Transport
194/// Layer. The double serialization is not performance critical as it is only done once per
195/// connection. The primary reason storing the ConnecitonInfo has a JSON string is for type
196/// erasure. The Transport Layer will check the [`ConnectionInfo::transport`] type and then
197/// route it to the appropriate instance of the Transport, which will then deserialize the
198/// [`ConnectionInfo::info`] field to its internal connection info object.
199///
200/// Optionally, this object could become strongly typed for which all possible combinations
201/// of transport and connection info would need to be enumerated.
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct ConnectionInfo {
204    pub transport: String,
205    pub info: String,
206}
207
208/// When registering a new TransportStream on the server, the caller specifies if the
209/// stream is a sender, receiver or both.
210///
211/// Senders and Receivers are with share a Context, but result in separate tcp socket
212/// connections to the server. Internally, we may use bcast channels to coordinate the
213/// internal control messages between the sender and receiver socket connections.
214#[derive(Clone, Builder)]
215pub struct StreamOptions {
216    /// Context
217    pub context: Arc<dyn AsyncEngineContext>,
218
219    /// Register with the server that this connection will have a server-side Sender
220    /// that can be picked up by the Request/Forward pipeline
221    ///
222    /// TODO - note, this option is currently not implemented and will cause a panic
223    pub enable_request_stream: bool,
224
225    /// Register with the server that this connection will have a server-side Receiver
226    /// that can be picked up by the Response/Reverse pipeline
227    pub enable_response_stream: bool,
228
229    /// The number of messages to buffer before blocking
230    #[builder(default = "8")]
231    pub send_buffer_count: usize,
232
233    /// The number of messages to buffer before blocking
234    #[builder(default = "8")]
235    pub recv_buffer_count: usize,
236}
237
238impl StreamOptions {
239    pub fn builder() -> StreamOptionsBuilder {
240        StreamOptionsBuilder::default()
241    }
242}
243
244pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
245    transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
246}
247
248#[async_trait]
249impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
250    for Egress<SingleIn<T>, ManyOut<U>>
251where
252    T: Data + Serialize,
253    U: for<'de> Deserialize<'de> + Data,
254{
255    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
256        self.transport_engine.generate(request).await
257    }
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize)]
261#[serde(rename_all = "snake_case")]
262enum RequestType {
263    SingleIn,
264    ManyIn,
265}
266
267#[derive(Debug, Clone, Serialize, Deserialize)]
268#[serde(rename_all = "snake_case")]
269enum ResponseType {
270    SingleOut,
271    ManyOut,
272}
273
274#[derive(Debug, Clone, Serialize, Deserialize)]
275struct RequestControlMessage {
276    id: String,
277    request_type: RequestType,
278    response_type: ResponseType,
279    connection_info: ConnectionInfo,
280}
281
282pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
283    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
284    metrics: OnceLock<Arc<WorkHandlerMetrics>>,
285    /// Endpoint-specific notifier for health check timer resets
286    endpoint_health_check_notifier: OnceLock<Arc<tokio::sync::Notify>>,
287}
288
289impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
290    pub fn new() -> Arc<Self> {
291        Arc::new(Self {
292            segment: OnceLock::new(),
293            metrics: OnceLock::new(),
294            endpoint_health_check_notifier: OnceLock::new(),
295        })
296    }
297
298    pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
299        self.segment
300            .set(segment)
301            .map_err(|_| anyhow::anyhow!("Segment already set"))
302    }
303
304    pub fn add_metrics(
305        &self,
306        endpoint: &crate::component::Endpoint,
307        metrics_labels: Option<&[(&str, &str)]>,
308    ) -> Result<()> {
309        let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
310            .map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;
311
312        self.metrics
313            .set(Arc::new(metrics))
314            .map_err(|_| anyhow::anyhow!("Metrics already set"))
315    }
316
317    pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
318        let ingress = Ingress::new();
319        ingress.attach(segment)?;
320        Ok(ingress)
321    }
322
323    pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
324        let ingress = Ingress::new();
325        ingress.attach(segment)?;
326        Ok(ingress)
327    }
328
329    pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
330        let frontend = SegmentSource::<Req, Resp>::new();
331        let backend = ServiceBackend::from_engine(engine);
332
333        // create the pipeline
334        let pipeline = frontend.link(backend)?.link(frontend)?;
335
336        let ingress = Ingress::new();
337        ingress.attach(pipeline)?;
338
339        Ok(ingress)
340    }
341
342    /// Helper method to access metrics if available
343    fn metrics(&self) -> Option<&Arc<WorkHandlerMetrics>> {
344        self.metrics.get()
345    }
346}
347
348#[async_trait]
349pub trait PushWorkHandler: Send + Sync {
350    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
351
352    /// Add metrics to the handler
353    fn add_metrics(
354        &self,
355        endpoint: &crate::component::Endpoint,
356        metrics_labels: Option<&[(&str, &str)]>,
357    ) -> Result<()>;
358
359    /// Set the endpoint-specific notifier for health check timer resets
360    fn set_endpoint_health_check_notifier(
361        &self,
362        _notifier: Arc<tokio::sync::Notify>,
363    ) -> Result<()> {
364        // Default implementation for backwards compatibility
365        Ok(())
366    }
367}
368
369/*
370/// `NetworkStreamWrapper` is a simple wrapper used to detect proper stream termination
371/// in network communication between ingress and egress components.
372///
373/// **Purpose**: This wrapper solves the problem of detecting whether a stream ended
374/// gracefully or was cut off prematurely (e.g., due to network issues).
375///
376/// **Design Rationale**:
377/// - Cannot use `Annotated` directly because the generic type `U` varies:
378///   - Sometimes `U = Annotated<...>`
379///   - Sometimes `U = LLMEngineOutput<...>`
380/// - Using `Annotated` would require double-wrapping like `Annotated<Annotated<...>>`
381/// - A simple wrapper is cleaner and more straightforward
382///
383/// **Stream Flow**:
384/// ```
385/// At AsyncEngine:
386///   response 1 -> response 2 -> response 3 -> <end>
387///
388/// Between ingress/egress:
389///   response 1 <end=false> -> response 2 <end=false> -> response 3 <end=false> -> (null) <end=true>
390///
391/// At client:
392///   response 1 -> response 2 -> response 3 -> <end>
393/// ```
394///
395/// **Error Handling**:
396/// If the stream is cut off before proper termination, the egress is responsible for
397/// injecting an error response to communicate the incomplete stream to the client:
398/// ```
399/// At AsyncEngine:
400///   response 1 -> ... <without end flag>
401///
402/// At egress:
403///   response 1 <end=false> -> <stream ended without end flag -> convert to error>
404///
405/// At client:
406///   response 1 -> error response
407/// ```
408///
409/// The detection must be done at egress level because premature stream termination
410/// can be due to network issues that only the egress component can detect.
411*/
412/// TODO: Detect end-of-stream using Server-Sent Events (SSE). This will be removed.
413#[derive(Serialize, Deserialize, Debug)]
414pub struct NetworkStreamWrapper<U> {
415    #[serde(skip_serializing_if = "Option::is_none")]
416    pub data: Option<U>,
417    pub complete_final: bool,
418}