Skip to main content

dynamo_runtime/pipeline/network/
tcp.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! TCP Transport Module
5//!
6//! The TCP Transport module consists of two main components: Client and Server. The Client is
7//! the downstream node that is responsible for connecting back to the upstream node (Server).
8//!
9//! Both Client and Server are given a Stream object that they can specialize for their specific
10//! needs, i.e. if they are SingleIn/ManyIn or SingleOut/ManyOut.
11//!
12//! The Request object will carry the Transport Type and Connection details, i.e. how the receiver
13//! of a Request is able to communicate back to the source of the Request.
14//!
15//! There are two types of TcpStream:
//! - CallHome stream - the address of the listening socket is forwarded via some mechanism to the
//!   downstream node, which then connects back to the source of the CallHome stream. To match the
//!   socket with an awaiting data stream, the CallHomeHandshake is used.
19
20pub mod client;
21pub mod server;
22
23pub mod test_utils;
24
25use super::ControlMessage;
26use serde::{Deserialize, Serialize};
27
28#[allow(unused_imports)]
29use super::{
30    ConnectionInfo, PendingConnections, RegisteredStream, ResponseService, StreamOptions,
31    StreamReceiver, StreamSender, StreamType, codec::TwoPartCodec,
32};
33
/// Transport identifier written to, and required in, `ConnectionInfo::transport` when converting
/// to and from [`TcpStreamConnectionInfo`].
const TCP_TRANSPORT: &str = "tcp_server";
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct TcpStreamConnectionInfo {
38    pub address: String,
39    pub subject: String,
40    pub context: String,
41    pub stream_type: StreamType,
42}
43
44impl From<TcpStreamConnectionInfo> for ConnectionInfo {
45    fn from(info: TcpStreamConnectionInfo) -> Self {
46        // Need to consider the below. If failure should be fatal, keep the below with .expect()
47        // But if there is a default value, we can use:
48        // unwrap_or_else(|e| {
49        //     eprintln!("Failed to serialize TcpStreamConnectionInfo: {:?}", e);
50        //     "{}".to_string() // Provide a fallback empty JSON string or default value
51        ConnectionInfo {
52            transport: TCP_TRANSPORT.to_string(),
53            info: serde_json::to_string(&info)
54                .expect("Failed to serialize TcpStreamConnectionInfo"),
55        }
56    }
57}
58
59impl TryFrom<ConnectionInfo> for TcpStreamConnectionInfo {
60    type Error = anyhow::Error;
61
62    fn try_from(info: ConnectionInfo) -> Result<Self, Self::Error> {
63        if info.transport != TCP_TRANSPORT {
64            return Err(anyhow::anyhow!(
65                "Invalid transport; TcpClient requires the transport to be `tcp_server`; however {} was passed",
66                info.transport
67            ));
68        }
69
70        serde_json::from_str(&info.info)
71            .map_err(|e| anyhow::anyhow!("Failed parse ConnectionInfo: {:?}", e))
72    }
73}
74
75/// First message sent over a CallHome stream which will map the newly created socket to a specific
76/// response data stream which was registered with the same subject.
77///
78/// This is a transport specific message as part of forming/completing a CallHome TcpStream.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80struct CallHomeHandshake {
81    subject: String,
82    stream_type: StreamType,
83}
84
#[cfg(test)]
mod tests {
    use crate::engine::AsyncEngineContextProvider;

    use super::*;
    use crate::pipeline::Context;

    /// Small JSON-serializable payload sent from the client to the server in the test below.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestMessage {
        foo: String,
    }

    /// Round trip over a response stream: the server registers a stream, the client connects
    /// back using the advertised connection info, sends the prologue and one message, and the
    /// server receives that message unchanged.
    #[tokio::test]
    async fn test_tcp_stream_client_server() {
        println!("Test Started");
        // NOTE(review): the port is fixed; this may collide with another process or test bound
        // to the same port — confirm this is acceptable for the test environment.
        let options = server::ServerOptions::builder().port(9124).build().unwrap();
        let server = server::TcpStreamServer::new(options).await.unwrap();
        println!("Server created");

        let context_rank0 = Context::new(());

        // Only the response stream is enabled for this test.
        let options = StreamOptions::builder()
            .context(context_rank0.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending_connection = server.register(options).await;

        // Connection info the client needs in order to connect back to the server.
        let connection_info = pending_connection
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        // set up the other rank, sharing the id of the first rank's context
        let context_rank1 = Context::with_id((), context_rank0.id().to_string());

        // connect to the server socket
        let mut send_stream = client::TcpClient::create_response_stream(
            context_rank1.context(),
            connection_info,
            None,
        )
        .await
        .unwrap();
        println!("Client connected");

        // the client can now set up its end of the stream and if it errors, it can send a message
        // to the server to stop the stream
        //
        // this step must be done before the next step on the server can complete, i.e.
        // the server's stream is now blocked on receiving the prologue message
        //
        // let's improve this and use an enum like Ok/Err; currently, None means good-to-go, and
        // Some(String) means an error happened on this downstream node and we need to alert the
        // upstream node that an error occurred
        send_stream.send_prologue(None).await.unwrap();

        // [server] next - now pending connections should be connected
        let recv_stream = pending_connection
            .recv_stream
            .unwrap()
            .stream_provider
            .await
            .unwrap();

        println!("Server paired");

        let msg = TestMessage {
            foo: "bar".to_string(),
        };

        let payload = serde_json::to_vec(&msg).unwrap();

        send_stream.send(payload.into()).await.unwrap();

        println!("Client sent message");

        let data = recv_stream.unwrap().rx.recv().await.unwrap();

        println!("Server received message");

        let recv_msg = serde_json::from_slice::<TestMessage>(&data).unwrap();

        assert_eq!(msg.foo, recv_msg.foo);
        println!("message match");

        drop(send_stream);

        // TODO: after the sender is dropped, also assert that a further `recv()` on the server
        // side yields `None`. This requires keeping the receiving stream alive past the first
        // `recv()` above instead of consuming it in a temporary.
    }
}