dynamo_runtime/pipeline/network/
tcp.rs1pub mod client;
21pub mod server;
22
23use super::ControlMessage;
24use serde::{Deserialize, Serialize};
25
26#[allow(unused_imports)]
27use super::{
28 ConnectionInfo, PendingConnections, RegisteredStream, ResponseService, StreamOptions,
29 StreamReceiver, StreamSender, StreamType, codec::TwoPartCodec,
30};
31
/// Transport identifier for TCP-backed streams.
///
/// Stored in `ConnectionInfo::transport` when a [`TcpStreamConnectionInfo`] is
/// converted into a [`ConnectionInfo`], and compared against that field when
/// converting back.
const TCP_TRANSPORT: &str = "tcp_server";
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct TcpStreamConnectionInfo {
36 pub address: String,
37 pub subject: String,
38 pub context: String,
39 pub stream_type: StreamType,
40}
41
42impl From<TcpStreamConnectionInfo> for ConnectionInfo {
43 fn from(info: TcpStreamConnectionInfo) -> Self {
44 ConnectionInfo {
50 transport: TCP_TRANSPORT.to_string(),
51 info: serde_json::to_string(&info)
52 .expect("Failed to serialize TcpStreamConnectionInfo"),
53 }
54 }
55}
56
57impl TryFrom<ConnectionInfo> for TcpStreamConnectionInfo {
58 type Error = anyhow::Error;
59
60 fn try_from(info: ConnectionInfo) -> Result<Self, Self::Error> {
61 if info.transport != TCP_TRANSPORT {
62 return Err(anyhow::anyhow!(
63 "Invalid transport; TcpClient requires the transport to be `tcp_server`; however {} was passed",
64 info.transport
65 ));
66 }
67
68 serde_json::from_str(&info.info)
69 .map_err(|e| anyhow::anyhow!("Failed parse ConnectionInfo: {:?}", e))
70 }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
78struct CallHomeHandshake {
79 subject: String,
80 stream_type: StreamType,
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::AsyncEngineContextProvider;
    use crate::pipeline::Context;

    /// Minimal JSON-serializable payload used for the round trip below.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct TestMessage {
        foo: String,
    }

    /// Round-trips one JSON message from a `TcpClient` response stream to a
    /// stream registered on a `TcpStreamServer`, and checks the payload matches.
    #[tokio::test]
    async fn test_tcp_stream_client_server() {
        println!("Test Started");
        // NOTE(review): the port is fixed at 9124, so this test presumably
        // fails if that port is already in use on the machine; confirm.
        let server_options = server::ServerOptions::builder().port(9124).build().unwrap();
        println!("Test Started");
        let server = server::TcpStreamServer::new(server_options).await.unwrap();
        println!("Server created");

        let server_ctx = Context::new(());

        // Register a stream with only the response direction enabled.
        let stream_options = StreamOptions::builder()
            .context(server_ctx.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending = server.register(stream_options).await;

        // Copy the connection info out by reference; `pending.recv_stream`
        // itself is consumed further down when the stream provider is awaited.
        let connection_info = pending
            .recv_stream
            .as_ref()
            .map(|registered| registered.connection_info.clone())
            .unwrap();

        // The client-side context reuses the id of the server-side context.
        let client_ctx = Context::with_id((), server_ctx.id().to_string());

        let mut send_stream =
            client::TcpClient::create_response_stream(client_ctx.context(), connection_info)
                .await
                .unwrap();
        println!("Client connected");

        send_stream.send_prologue(None).await.unwrap();

        let paired_stream = pending
            .recv_stream
            .unwrap()
            .stream_provider
            .await
            .unwrap();

        println!("Server paired");

        let sent = TestMessage {
            foo: "bar".to_string(),
        };

        let bytes = serde_json::to_vec(&sent).unwrap();

        send_stream.send(bytes.into()).await.unwrap();

        println!("Client sent message");

        // The unwrapped receiver is deliberately left as a temporary so it is
        // dropped at the end of this statement, before `send_stream` is dropped.
        let data = paired_stream.unwrap().rx.recv().await.unwrap();

        println!("Server received message");

        let received = serde_json::from_slice::<TestMessage>(&data).unwrap();

        assert_eq!(sent.foo, received.foo);
        println!("message match");

        drop(send_stream);
    }
}