dynamo_runtime/pipeline/
network.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! TODO - we need to reconcile what is in this crate with distributed::transports
17
18pub mod codec;
19pub mod egress;
20pub mod ingress;
21pub mod tcp;
22
23use std::sync::{Arc, OnceLock};
24
25use anyhow::Result;
26use async_trait::async_trait;
27use bytes::Bytes;
28use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
29use derive_builder::Builder;
30use futures::StreamExt;
31// io::Cursor, TryStreamExt
32use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
33use serde::{Deserialize, Serialize};
34
35use super::{
36    context, AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO,
37    SegmentSource, ServiceBackend, ServiceEngine, SingleIn, Source,
38};
39
40pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
41impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
42
43/// `WorkQueueConsumer` is a generic interface for a work queue that can be used to send and receive
44#[async_trait]
45pub trait WorkQueueConsumer {
46    async fn dequeue(&self) -> Result<Bytes, String>;
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
50#[serde(rename_all = "snake_case")]
51pub enum StreamType {
52    Request,
53    Response,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
57#[serde(rename_all = "snake_case")]
58pub enum ControlMessage {
59    Stop,
60    Kill,
61    Sentinel,
62}
63
64/// This is the first message in a `ResponseStream`. This is not a message that gets process
65/// by the general pipeline, but is a control message that is awaited before the
66/// [`AsyncEngine::generate`] method is allowed to return.
67///
68/// If an error is present, the [`AsyncEngine::generate`] method will return the error instead
69/// of returning the `ResponseStream`.
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
71pub struct ResponseStreamPrologue {
72    error: Option<String>,
73}
74
75pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
76
77/// The [`RegisteredStream`] object is acquired from a [`StreamProvider`] and is used to provide
78/// an awaitable receiver which will the `T` which is either a stream writer for a request stream
79/// or a stream reader for a response stream.
80///
81/// make this an raii object linked to some stream provider
82/// if the object has not been awaited an the type T unwrapped, the registered stream
83/// on the stream provider will be informed and can clean up a stream that will never
84/// be connected.
85#[derive(Debug)]
86pub struct RegisteredStream<T> {
87    pub connection_info: ConnectionInfo,
88    pub stream_provider: StreamProvider<T>,
89}
90
91impl<T> RegisteredStream<T> {
92    pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
93        (self.connection_info, self.stream_provider)
94    }
95}
96
97/// After registering a stream, the [`PendingConnections`] object is returned to the caller. This
98/// object can be used to await the connection to be established.
99pub struct PendingConnections {
100    pub send_stream: Option<RegisteredStream<StreamSender>>,
101    pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
102}
103
104impl PendingConnections {
105    pub fn into_parts(
106        self,
107    ) -> (
108        Option<RegisteredStream<StreamSender>>,
109        Option<RegisteredStream<StreamReceiver>>,
110    ) {
111        (self.send_stream, self.recv_stream)
112    }
113}
114
115/// A [`ResponseService`] implements a services in which a context a specific subject with will
116/// be associated with a stream of responses.
117#[async_trait::async_trait]
118pub trait ResponseService {
119    async fn register(&self, options: StreamOptions) -> PendingConnections;
120}
121
122// #[derive(Debug, Clone, Serialize, Deserialize)]
123// struct Handshake {
124//     request_id: String,
125//     worker_id: Option<String>,
126//     error: Option<String>,
127// }
128
129// impl Handshake {
130//     pub fn validate(&self) -> Result<(), String> {
131//         if let Some(e) = &self.error {
132//             return Err(e.clone());
133//         }
134//         Ok(())
135//     }
136// }
137
138// this probably needs to be come a ResponseStreamSender
139// since the prologue in this scenario sender telling the receiver
140// that all is good and it's ready to send
141//
142// in the RequestStreamSender, the prologue would be coming from the
143// receiver, so the sender would have to await the prologue which if
144// was not an error, would indicate the RequestStreamReceiver is read
145// to receive data.
146pub struct StreamSender {
147    tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
148    prologue: Option<ResponseStreamPrologue>,
149}
150
151impl StreamSender {
152    pub async fn send(&self, data: Bytes) -> Result<()> {
153        Ok(self.tx.send(TwoPartMessage::from_data(data)).await?)
154    }
155
156    pub async fn send_control(&self, control: ControlMessage) -> Result<()> {
157        let bytes = serde_json::to_vec(&control)?;
158        Ok(self
159            .tx
160            .send(TwoPartMessage::from_header(bytes.into()))
161            .await?)
162    }
163
164    #[allow(clippy::needless_update)]
165    pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
166        if let Some(prologue) = self.prologue.take() {
167            let prologue = ResponseStreamPrologue { error, ..prologue };
168            let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
169                Ok(b) => b.into(),
170                Err(err) => {
171                    tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
172                    return Err("Invalid prologue".to_string());
173                }
174            };
175            self.tx
176                .send(TwoPartMessage::from_header(header_bytes))
177                .await
178                .map_err(|e| e.to_string())?;
179        } else {
180            panic!("Prologue already sent; or not set; logic error");
181        }
182        Ok(())
183    }
184}
185
186pub struct StreamReceiver {
187    rx: tokio::sync::mpsc::Receiver<Bytes>,
188}
189
190/// Connection Info is encoded as JSON and then again serialized has part of the Transport
191/// Layer. The double serialization is not performance critical as it is only done once per
192/// connection. The primary reason storing the ConnecitonInfo has a JSON string is for type
193/// erasure. The Transport Layer will check the [`ConnectionInfo::transport`] type and then
194/// route it to the appropriate instance of the Transport, which will then deserialize the
195/// [`ConnectionInfo::info`] field to its internal connection info object.
196///
197/// Optionally, this object could become strongly typed for which all possible combinations
198/// of transport and connection info would need to be enumerated.
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct ConnectionInfo {
201    pub transport: String,
202    pub info: String,
203}
204
205/// When registering a new TransportStream on the server, the caller specifies if the
206/// stream is a sender, receiver or both.
207///
208/// Senders and Receivers are with share a Context, but result in separate tcp socket
209/// connections to the server. Internally, we may use bcast channels to coordinate the
210/// internal control messages between the sender and receiver socket connections.
211#[derive(Clone, Builder)]
212pub struct StreamOptions {
213    /// Context
214    pub context: Arc<dyn AsyncEngineContext>,
215
216    /// Register with the server that this connection will have a server-side Sender
217    /// that can be picked up by the Request/Forward pipeline
218    ///
219    /// TODO - note, this option is currently not implemented and will cause a panic
220    pub enable_request_stream: bool,
221
222    /// Register with the server that this connection will have a server-side Receiver
223    /// that can be picked up by the Response/Reverse pipeline
224    pub enable_response_stream: bool,
225
226    /// The number of messages to buffer before blocking
227    #[builder(default = "8")]
228    pub send_buffer_count: usize,
229
230    /// The number of messages to buffer before blocking
231    #[builder(default = "8")]
232    pub recv_buffer_count: usize,
233}
234
235impl StreamOptions {
236    pub fn builder() -> StreamOptionsBuilder {
237        StreamOptionsBuilder::default()
238    }
239}
240
241pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
242    transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
243}
244
245#[async_trait]
246impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
247    for Egress<SingleIn<T>, ManyOut<U>>
248where
249    T: Data + Serialize,
250    U: for<'de> Deserialize<'de> + Data,
251{
252    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
253        self.transport_engine.generate(request).await
254    }
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
258#[serde(rename_all = "snake_case")]
259enum RequestType {
260    SingleIn,
261    ManyIn,
262}
263
264#[derive(Debug, Clone, Serialize, Deserialize)]
265#[serde(rename_all = "snake_case")]
266enum ResponseType {
267    SingleOut,
268    ManyOut,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272struct RequestControlMessage {
273    id: String,
274    request_type: RequestType,
275    response_type: ResponseType,
276    connection_info: ConnectionInfo,
277}
278
279pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
280    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
281}
282
283impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> {
284    pub fn new() -> Arc<Self> {
285        Arc::new(Self {
286            segment: OnceLock::new(),
287        })
288    }
289
290    pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
291        self.segment
292            .set(segment)
293            .map_err(|_| anyhow::anyhow!("Segment already set"))
294    }
295
296    pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
297        let ingress = Ingress::new();
298        ingress.attach(segment)?;
299        Ok(ingress)
300    }
301
302    pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
303        let ingress = Ingress::new();
304        ingress.attach(segment)?;
305        Ok(ingress)
306    }
307
308    pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
309        let frontend = SegmentSource::<Req, Resp>::new();
310        let backend = ServiceBackend::from_engine(engine);
311
312        // create the pipeline
313        let pipeline = frontend.link(backend)?.link(frontend)?;
314
315        let ingress = Ingress::new();
316        ingress.attach(pipeline)?;
317
318        Ok(ingress)
319    }
320}
321
322#[async_trait]
323pub trait PushWorkHandler: Send + Sync {
324    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
325}