dynamo_runtime/pipeline/network/
tcp.rs1pub mod client;
21pub mod server;
22
23pub mod test_utils;
24
use super::ControlMessage;
use serde::{Deserialize, Serialize};

// NOTE(review): not every name below is used directly in this file; presumably they are
// imported here for the `client`/`server` submodules (via `super::`) — confirm before pruning.
#[allow(unused_imports)]
use super::{
    ConnectionInfo, PendingConnections, RegisteredStream, ResponseService, StreamOptions,
    StreamReceiver, StreamSender, StreamType, codec::TwoPartCodec,
};
33
/// Transport tag for TCP streams.
///
/// The `From<TcpStreamConnectionInfo>` impl writes it to `ConnectionInfo::transport`.
/// The `TryFrom<ConnectionInfo>` impl requires it and rejects any other value.
const TCP_TRANSPORT: &str = "tcp_server";
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct TcpStreamConnectionInfo {
38 pub address: String,
39 pub subject: String,
40 pub context: String,
41 pub stream_type: StreamType,
42}
43
44impl From<TcpStreamConnectionInfo> for ConnectionInfo {
45 fn from(info: TcpStreamConnectionInfo) -> Self {
46 ConnectionInfo {
52 transport: TCP_TRANSPORT.to_string(),
53 info: serde_json::to_string(&info)
54 .expect("Failed to serialize TcpStreamConnectionInfo"),
55 }
56 }
57}
58
59impl TryFrom<ConnectionInfo> for TcpStreamConnectionInfo {
60 type Error = anyhow::Error;
61
62 fn try_from(info: ConnectionInfo) -> Result<Self, Self::Error> {
63 if info.transport != TCP_TRANSPORT {
64 return Err(anyhow::anyhow!(
65 "Invalid transport; TcpClient requires the transport to be `tcp_server`; however {} was passed",
66 info.transport
67 ));
68 }
69
70 serde_json::from_str(&info.info)
71 .map_err(|e| anyhow::anyhow!("Failed parse ConnectionInfo: {:?}", e))
72 }
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
80struct CallHomeHandshake {
81 subject: String,
82 stream_type: StreamType,
83}
84
#[cfg(test)]
mod tests {
    use crate::engine::AsyncEngineContextProvider;

    use super::*;
    use crate::pipeline::Context;

    /// Payload sent from client to server as JSON.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TestMessage {
        foo: String,
    }

    /// End-to-end check: a client response stream pairs with a server-registered
    /// receive stream, and one JSON message arrives intact.
    ///
    /// NOTE(review): binds the fixed port 9124. This can collide with other
    /// processes or with concurrently running tests that use the same port.
    #[tokio::test]
    async fn test_tcp_stream_client_server() {
        println!("Test Started");
        let options = server::ServerOptions::builder()
            .port(9124)
            .build()
            .expect("failed to build ServerOptions");
        println!("Server options built");
        let server = server::TcpStreamServer::new(options)
            .await
            .expect("failed to create TcpStreamServer");
        println!("Server created");

        let context_rank0 = Context::new(());

        // Only the response stream is enabled for this registration.
        let options = StreamOptions::builder()
            .context(context_rank0.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .expect("failed to build StreamOptions");

        let pending_connection = server.register(options).await;

        // Connection details handed to the client so it can reach this registration.
        let connection_info = pending_connection
            .recv_stream
            .as_ref()
            .expect("recv_stream should be registered")
            .connection_info
            .clone();

        // The client-side context reuses the server-side context's id.
        let context_rank1 = Context::with_id((), context_rank0.id().to_string());

        let mut send_stream =
            client::TcpClient::create_response_stream(context_rank1.context(), connection_info)
                .await
                .expect("client failed to create response stream");
        println!("Client connected");

        send_stream
            .send_prologue(None)
            .await
            .expect("failed to send prologue");

        let recv_stream = pending_connection
            .recv_stream
            .expect("recv_stream should be registered")
            .stream_provider
            .await
            .expect("awaiting stream_provider failed");

        println!("Server paired");

        let msg = TestMessage {
            foo: "bar".to_string(),
        };

        let payload = serde_json::to_vec(&msg).expect("TestMessage should serialize to JSON");

        send_stream
            .send(payload.into())
            .await
            .expect("client failed to send message");

        println!("Client sent message");

        let data = recv_stream
            .expect("stream_provider returned an error")
            .rx
            .recv()
            .await
            .expect("expected a message on the receive stream");

        println!("Server received message");

        let recv_msg = serde_json::from_slice::<TestMessage>(&data)
            .expect("received payload is not a valid TestMessage");

        // Compare the whole message, not just one field.
        assert_eq!(msg, recv_msg);
        println!("message match");

        // Explicitly drop the client stream before the test ends
        // (presumably to close the client side of the connection — confirm).
        drop(send_stream);
    }
}