dynamo_runtime/pipeline/network/tcp.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! TCP Transport Module
17//!
18//! The TCP Transport module consists of two main components: Client and Server. The Client is
19//! the downstream node that is responsible for connecting back to the upstream node (Server).
20//!
21//! Both Client and Server are given a Stream object that they can specialize for their specific
22//! needs, i.e. if they are SingleIn/ManyIn or SingleOut/ManyOut.
23//!
24//! The Request object will carry the Transport Type and Connection details, i.e. how the receiver
25//! of a Request is able to communicate back to the source of the Request.
26//!
27//! There are two types of TcpStream:
//! - CallHome stream - the address of the listening socket is forwarded via some mechanism to the
//!   downstream node, which then connects back to the source of the CallHome stream. To match the
//!   socket with an awaiting data stream, the CallHomeHandshake is used.
31
32pub mod client;
33pub mod server;
34
35use super::ControlMessage;
36use serde::{Deserialize, Serialize};
37
38#[allow(unused_imports)]
39use super::{
40 ConnectionInfo, PendingConnections, RegisteredStream, ResponseService, StreamOptions,
41 StreamReceiver, StreamSender, StreamType, codec::TwoPartCodec,
42};
43
/// Transport identifier for TCP streams.
///
/// It is written into `ConnectionInfo::transport` when a `TcpStreamConnectionInfo` is converted
/// into a `ConnectionInfo`, and checked again when converting back.
const TCP_TRANSPORT: &str = "tcp_server";
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct TcpStreamConnectionInfo {
48 pub address: String,
49 pub subject: String,
50 pub context: String,
51 pub stream_type: StreamType,
52}
53
54impl From<TcpStreamConnectionInfo> for ConnectionInfo {
55 fn from(info: TcpStreamConnectionInfo) -> Self {
56 // Need to consider the below. If failure should be fatal, keep the below with .expect()
57 // But if there is a default value, we can use:
58 // unwrap_or_else(|e| {
59 // eprintln!("Failed to serialize TcpStreamConnectionInfo: {:?}", e);
60 // "{}".to_string() // Provide a fallback empty JSON string or default value
61 ConnectionInfo {
62 transport: TCP_TRANSPORT.to_string(),
63 info: serde_json::to_string(&info)
64 .expect("Failed to serialize TcpStreamConnectionInfo"),
65 }
66 }
67}
68
69impl TryFrom<ConnectionInfo> for TcpStreamConnectionInfo {
70 type Error = anyhow::Error;
71
72 fn try_from(info: ConnectionInfo) -> Result<Self, Self::Error> {
73 if info.transport != TCP_TRANSPORT {
74 return Err(anyhow::anyhow!(
75 "Invalid transport; TcpClient requires the transport to be `tcp_server`; however {} was passed",
76 info.transport
77 ));
78 }
79
80 serde_json::from_str(&info.info)
81 .map_err(|e| anyhow::anyhow!("Failed parse ConnectionInfo: {:?}", e))
82 }
83}
84
85/// First message sent over a CallHome stream which will map the newly created socket to a specific
86/// response data stream which was registered with the same subject.
87///
88/// This is a transport specific message as part of forming/completing a CallHome TcpStream.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90struct CallHomeHandshake {
91 subject: String,
92 stream_type: StreamType,
93}
94
#[cfg(test)]
mod tests {
    use crate::engine::AsyncEngineContextProvider;

    use super::*;
    use crate::pipeline::Context;

    /// Minimal JSON payload pushed through the stream by the round-trip test.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestMessage {
        foo: String,
    }

    /// Sends one message from a `TcpClient` response stream to a stream registered on a
    /// `TcpStreamServer` and checks that the payload arrives unchanged.
    // NOTE(review): the server is bound to the fixed port 9124; presumably this test fails if
    // that port is already in use (e.g. by a parallel test) - confirm whether `ServerOptions`
    // supports an ephemeral port.
    #[tokio::test]
    async fn test_tcp_stream_client_server() {
        println!("Test Started");
        let server_options = server::ServerOptions::builder().port(9124).build().unwrap();
        let server = server::TcpStreamServer::new(server_options).await.unwrap();
        println!("Server created");

        let context_rank0 = Context::new(());

        // Only the response stream is enabled, so `recv_stream` below is expected to be set.
        let stream_options = StreamOptions::builder()
            .context(context_rank0.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending_connection = server.register(stream_options).await;

        let connection_info = pending_connection
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        // set up the other rank; it reuses the id of the first context
        let context_rank1 = Context::with_id((), context_rank0.id().to_string());

        // connect to the server socket
        let mut send_stream =
            client::TcpClient::create_response_stream(context_rank1.context(), connection_info)
                .await
                .unwrap();
        println!("Client connected");

        // the client can now set up its end of the stream and, if it errors, it can send a
        // message to the server to stop the stream
        //
        // this step must be done before the next step on the server can complete, i.e.
        // the server's stream is now blocked on receiving the prologue message
        //
        // let's improve this and use an enum like Ok/Err; currently, None means good-to-go, and
        // Some(String) means an error happened on this downstream node and we need to alert the
        // upstream node that an error occurred
        send_stream.send_prologue(None).await.unwrap();

        // [server] next - now pending connections should be connected
        let recv_stream = pending_connection
            .recv_stream
            .unwrap()
            .stream_provider
            .await
            .unwrap();

        println!("Server paired");

        let msg = TestMessage {
            foo: "bar".to_string(),
        };

        let payload = serde_json::to_vec(&msg).unwrap();

        send_stream.send(payload.into()).await.unwrap();

        println!("Client sent message");

        let data = recv_stream.unwrap().rx.recv().await.unwrap();

        println!("Server received message");

        let recv_msg = serde_json::from_slice::<TestMessage>(&data).unwrap();

        assert_eq!(msg.foo, recv_msg.foo);
        println!("message match");

        drop(send_stream);

        // TODO: once the sender is dropped, also assert that the server-side receiver observes
        // end-of-stream (`rx.recv()` yielding `None`). This needs the receiver to be kept in a
        // binding above instead of being consumed by the `recv()` call.
    }
}
189}