dynamo_runtime/pipeline/
network.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! TODO - we need to reconcile what is in this crate with distributed::transports
5
6pub mod codec;
7pub mod egress;
8pub mod ingress;
9pub mod tcp;
10
11use crate::SystemHealth;
12use std::sync::{Arc, OnceLock};
13
14use anyhow::Result;
15use async_trait::async_trait;
16use bytes::Bytes;
17use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
18use derive_builder::Builder;
19use futures::StreamExt;
20// io::Cursor, TryStreamExt
21use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
22use serde::{Deserialize, Serialize};
23
24use super::{
25    AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO, SegmentSource,
26    ServiceBackend, ServiceEngine, SingleIn, Source, context,
27};
28use ingress::push_handler::WorkHandlerMetrics;
29
/// Message used to describe a response stream that ended before generation was
/// completed.
pub const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
32
33// Add Prometheus metrics types
34use crate::metrics::MetricsRegistry;
35use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
36
37pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
38impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
39
40/// `WorkQueueConsumer` is a generic interface for a work queue that can be used to send and receive
41#[async_trait]
42pub trait WorkQueueConsumer {
43    async fn dequeue(&self) -> Result<Bytes, String>;
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
47#[serde(rename_all = "snake_case")]
48pub enum StreamType {
49    Request,
50    Response,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
54#[serde(rename_all = "snake_case")]
55pub enum ControlMessage {
56    Stop,
57    Kill,
58    Sentinel,
59}
60
61/// This is the first message in a `ResponseStream`. This is not a message that gets process
62/// by the general pipeline, but is a control message that is awaited before the
63/// [`AsyncEngine::generate`] method is allowed to return.
64///
65/// If an error is present, the [`AsyncEngine::generate`] method will return the error instead
66/// of returning the `ResponseStream`.
67#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
68pub struct ResponseStreamPrologue {
69    error: Option<String>,
70}
71
72pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
73
74/// The [`RegisteredStream`] object is acquired from a [`StreamProvider`] and is used to provide
75/// an awaitable receiver which will the `T` which is either a stream writer for a request stream
76/// or a stream reader for a response stream.
77///
78/// make this an raii object linked to some stream provider
79/// if the object has not been awaited an the type T unwrapped, the registered stream
80/// on the stream provider will be informed and can clean up a stream that will never
81/// be connected.
82#[derive(Debug)]
83pub struct RegisteredStream<T> {
84    pub connection_info: ConnectionInfo,
85    pub stream_provider: StreamProvider<T>,
86}
87
88impl<T> RegisteredStream<T> {
89    pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
90        (self.connection_info, self.stream_provider)
91    }
92}
93
94/// After registering a stream, the [`PendingConnections`] object is returned to the caller. This
95/// object can be used to await the connection to be established.
96pub struct PendingConnections {
97    pub send_stream: Option<RegisteredStream<StreamSender>>,
98    pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
99}
100
101impl PendingConnections {
102    pub fn into_parts(
103        self,
104    ) -> (
105        Option<RegisteredStream<StreamSender>>,
106        Option<RegisteredStream<StreamReceiver>>,
107    ) {
108        (self.send_stream, self.recv_stream)
109    }
110}
111
112/// A [`ResponseService`] implements a services in which a context a specific subject with will
113/// be associated with a stream of responses.
114#[async_trait::async_trait]
115pub trait ResponseService {
116    async fn register(&self, options: StreamOptions) -> PendingConnections;
117}
118
119// #[derive(Debug, Clone, Serialize, Deserialize)]
120// struct Handshake {
121//     request_id: String,
122//     worker_id: Option<String>,
123//     error: Option<String>,
124// }
125
126// impl Handshake {
127//     pub fn validate(&self) -> Result<(), String> {
128//         if let Some(e) = &self.error {
129//             return Err(e.clone());
130//         }
131//         Ok(())
132//     }
133// }
134
135// this probably needs to be come a ResponseStreamSender
136// since the prologue in this scenario sender telling the receiver
137// that all is good and it's ready to send
138//
139// in the RequestStreamSender, the prologue would be coming from the
140// receiver, so the sender would have to await the prologue which if
141// was not an error, would indicate the RequestStreamReceiver is read
142// to receive data.
143pub struct StreamSender {
144    tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
145    prologue: Option<ResponseStreamPrologue>,
146}
147
148impl StreamSender {
149    pub async fn send(&self, data: Bytes) -> Result<()> {
150        Ok(self.tx.send(TwoPartMessage::from_data(data)).await?)
151    }
152
153    pub async fn send_control(&self, control: ControlMessage) -> Result<()> {
154        let bytes = serde_json::to_vec(&control)?;
155        Ok(self
156            .tx
157            .send(TwoPartMessage::from_header(bytes.into()))
158            .await?)
159    }
160
161    #[allow(clippy::needless_update)]
162    pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
163        if let Some(prologue) = self.prologue.take() {
164            let prologue = ResponseStreamPrologue { error, ..prologue };
165            let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
166                Ok(b) => b.into(),
167                Err(err) => {
168                    tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
169                    return Err("Invalid prologue".to_string());
170                }
171            };
172            self.tx
173                .send(TwoPartMessage::from_header(header_bytes))
174                .await
175                .map_err(|e| e.to_string())?;
176        } else {
177            panic!("Prologue already sent; or not set; logic error");
178        }
179        Ok(())
180    }
181}
182
183pub struct StreamReceiver {
184    rx: tokio::sync::mpsc::Receiver<Bytes>,
185}
186
187/// Connection Info is encoded as JSON and then again serialized has part of the Transport
188/// Layer. The double serialization is not performance critical as it is only done once per
189/// connection. The primary reason storing the ConnecitonInfo has a JSON string is for type
190/// erasure. The Transport Layer will check the [`ConnectionInfo::transport`] type and then
191/// route it to the appropriate instance of the Transport, which will then deserialize the
192/// [`ConnectionInfo::info`] field to its internal connection info object.
193///
194/// Optionally, this object could become strongly typed for which all possible combinations
195/// of transport and connection info would need to be enumerated.
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct ConnectionInfo {
198    pub transport: String,
199    pub info: String,
200}
201
202/// When registering a new TransportStream on the server, the caller specifies if the
203/// stream is a sender, receiver or both.
204///
205/// Senders and Receivers are with share a Context, but result in separate tcp socket
206/// connections to the server. Internally, we may use bcast channels to coordinate the
207/// internal control messages between the sender and receiver socket connections.
208#[derive(Clone, Builder)]
209pub struct StreamOptions {
210    /// Context
211    pub context: Arc<dyn AsyncEngineContext>,
212
213    /// Register with the server that this connection will have a server-side Sender
214    /// that can be picked up by the Request/Forward pipeline
215    ///
216    /// TODO - note, this option is currently not implemented and will cause a panic
217    pub enable_request_stream: bool,
218
219    /// Register with the server that this connection will have a server-side Receiver
220    /// that can be picked up by the Response/Reverse pipeline
221    pub enable_response_stream: bool,
222
223    /// The number of messages to buffer before blocking
224    #[builder(default = "8")]
225    pub send_buffer_count: usize,
226
227    /// The number of messages to buffer before blocking
228    #[builder(default = "8")]
229    pub recv_buffer_count: usize,
230}
231
232impl StreamOptions {
233    pub fn builder() -> StreamOptionsBuilder {
234        StreamOptionsBuilder::default()
235    }
236}
237
238pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
239    transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
240}
241
242#[async_trait]
243impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
244    for Egress<SingleIn<T>, ManyOut<U>>
245where
246    T: Data + Serialize,
247    U: for<'de> Deserialize<'de> + Data,
248{
249    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
250        self.transport_engine.generate(request).await
251    }
252}
253
254#[derive(Debug, Clone, Serialize, Deserialize)]
255#[serde(rename_all = "snake_case")]
256enum RequestType {
257    SingleIn,
258    ManyIn,
259}
260
261#[derive(Debug, Clone, Serialize, Deserialize)]
262#[serde(rename_all = "snake_case")]
263enum ResponseType {
264    SingleOut,
265    ManyOut,
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
269struct RequestControlMessage {
270    id: String,
271    request_type: RequestType,
272    response_type: ResponseType,
273    connection_info: ConnectionInfo,
274}
275
276pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
277    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
278    metrics: OnceLock<Arc<WorkHandlerMetrics>>,
279    /// Endpoint-specific notifier for health check timer resets
280    endpoint_health_check_notifier: OnceLock<Arc<tokio::sync::Notify>>,
281}
282
283impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
284    pub fn new() -> Arc<Self> {
285        Arc::new(Self {
286            segment: OnceLock::new(),
287            metrics: OnceLock::new(),
288            endpoint_health_check_notifier: OnceLock::new(),
289        })
290    }
291
292    pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
293        self.segment
294            .set(segment)
295            .map_err(|_| anyhow::anyhow!("Segment already set"))
296    }
297
298    pub fn add_metrics(
299        &self,
300        endpoint: &crate::component::Endpoint,
301        metrics_labels: Option<&[(&str, &str)]>,
302    ) -> Result<()> {
303        let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
304            .map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;
305
306        self.metrics
307            .set(Arc::new(metrics))
308            .map_err(|_| anyhow::anyhow!("Metrics already set"))
309    }
310
311    pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
312        let ingress = Ingress::new();
313        ingress.attach(segment)?;
314        Ok(ingress)
315    }
316
317    pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
318        let ingress = Ingress::new();
319        ingress.attach(segment)?;
320        Ok(ingress)
321    }
322
323    pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
324        let frontend = SegmentSource::<Req, Resp>::new();
325        let backend = ServiceBackend::from_engine(engine);
326
327        // create the pipeline
328        let pipeline = frontend.link(backend)?.link(frontend)?;
329
330        let ingress = Ingress::new();
331        ingress.attach(pipeline)?;
332
333        Ok(ingress)
334    }
335
336    /// Helper method to access metrics if available
337    fn metrics(&self) -> Option<&Arc<WorkHandlerMetrics>> {
338        self.metrics.get()
339    }
340}
341
342#[async_trait]
343pub trait PushWorkHandler: Send + Sync {
344    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
345
346    /// Add metrics to the handler
347    fn add_metrics(
348        &self,
349        endpoint: &crate::component::Endpoint,
350        metrics_labels: Option<&[(&str, &str)]>,
351    ) -> Result<()>;
352
353    /// Set the endpoint-specific notifier for health check timer resets
354    fn set_endpoint_health_check_notifier(
355        &self,
356        _notifier: Arc<tokio::sync::Notify>,
357    ) -> Result<()> {
358        // Default implementation for backwards compatibility
359        Ok(())
360    }
361}
362
363/*
364/// `NetworkStreamWrapper` is a simple wrapper used to detect proper stream termination
365/// in network communication between ingress and egress components.
366///
367/// **Purpose**: This wrapper solves the problem of detecting whether a stream ended
368/// gracefully or was cut off prematurely (e.g., due to network issues).
369///
370/// **Design Rationale**:
371/// - Cannot use `Annotated` directly because the generic type `U` varies:
372///   - Sometimes `U = Annotated<...>`
373///   - Sometimes `U = LLMEngineOutput<...>`
374/// - Using `Annotated` would require double-wrapping like `Annotated<Annotated<...>>`
375/// - A simple wrapper is cleaner and more straightforward
376///
377/// **Stream Flow**:
378/// ```
379/// At AsyncEngine:
380///   response 1 -> response 2 -> response 3 -> <end>
381///
382/// Between ingress/egress:
383///   response 1 <end=false> -> response 2 <end=false> -> response 3 <end=false> -> (null) <end=true>
384///
385/// At client:
386///   response 1 -> response 2 -> response 3 -> <end>
387/// ```
388///
389/// **Error Handling**:
390/// If the stream is cut off before proper termination, the egress is responsible for
391/// injecting an error response to communicate the incomplete stream to the client:
392/// ```
393/// At AsyncEngine:
394///   response 1 -> ... <without end flag>
395///
396/// At egress:
397///   response 1 <end=false> -> <stream ended without end flag -> convert to error>
398///
399/// At client:
400///   response 1 -> error response
401/// ```
402///
403/// The detection must be done at egress level because premature stream termination
404/// can be due to network issues that only the egress component can detect.
405*/
406/// TODO: Detect end-of-stream using Server-Sent Events (SSE). This will be removed.
407#[derive(Serialize, Deserialize, Debug)]
408pub struct NetworkStreamWrapper<U> {
409    #[serde(skip_serializing_if = "Option::is_none")]
410    pub data: Option<U>,
411    pub complete_final: bool,
412}