pub mod client;
pub mod server;
pub mod test_utils;
use super::ControlMessage;
use serde::{Deserialize, Serialize};
#[allow(unused_imports)]
use super::{
ConnectionInfo, PendingConnections, RegisteredStream, ResponseService, StreamOptions,
StreamReceiver, StreamSender, StreamType, codec::TwoPartCodec,
};
/// Transport tag written into `ConnectionInfo::transport` when a
/// [`TcpStreamConnectionInfo`] is converted into a `ConnectionInfo`, and
/// required to match when converting back.
const TCP_TRANSPORT: &str = "tcp_server";
/// Connection details for a TCP-backed stream.
///
/// This value travels as a JSON string inside `ConnectionInfo::info`, tagged with
/// the `tcp_server` transport (see the `From`/`TryFrom` impls in this module).
/// Field order determines the serialized JSON field order, so keep it stable.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TcpStreamConnectionInfo {
    /// Address of the TCP endpoint — presumably `host:port` of the server; confirm in `server`.
    pub address: String,
    /// Subject identifying the stream on the server (semantics defined by `server`/`client`).
    pub subject: String,
    /// Context identifier associated with the stream — TODO confirm expected format.
    pub context: String,
    /// Kind of stream this connection carries.
    pub stream_type: StreamType,
}
impl From<TcpStreamConnectionInfo> for ConnectionInfo {
    /// Wraps TCP connection details in a transport-tagged [`ConnectionInfo`].
    ///
    /// The details are encoded as a JSON string and the transport is set to
    /// `tcp_server`.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` fails to serialize the value.
    fn from(info: TcpStreamConnectionInfo) -> Self {
        let encoded =
            serde_json::to_string(&info).expect("Failed to serialize TcpStreamConnectionInfo");
        Self {
            transport: String::from(TCP_TRANSPORT),
            info: encoded,
        }
    }
}
impl TryFrom<ConnectionInfo> for TcpStreamConnectionInfo {
    type Error = anyhow::Error;

    /// Decodes TCP connection details from a transport-tagged [`ConnectionInfo`].
    ///
    /// This is the inverse of `From<TcpStreamConnectionInfo> for ConnectionInfo`.
    ///
    /// # Errors
    ///
    /// Returns an error if `info.transport` is not `tcp_server`, or if
    /// `info.info` is not valid JSON for a [`TcpStreamConnectionInfo`].
    fn try_from(info: ConnectionInfo) -> Result<Self, Self::Error> {
        // Refuse payloads meant for a different transport before trying to parse them.
        if info.transport != TCP_TRANSPORT {
            anyhow::bail!(
                "Invalid transport; TcpClient requires the transport to be `{}`; however {} was passed",
                TCP_TRANSPORT,
                info.transport
            );
        }
        // Display formatting gives a readable message, e.g. "expected value at line 1 column 1".
        serde_json::from_str(&info.info)
            .map_err(|e| anyhow::anyhow!("Failed to parse ConnectionInfo: {}", e))
    }
}
/// Serializable handshake carrying a stream `subject` and its [`StreamType`].
///
/// NOTE(review): not used within this file. Presumably sent by the client when it
/// connects back ("calls home") to the server so the server can match the
/// connection to a registered stream. Confirm in `client`/`server`.
/// Field order determines the serialized field order, so keep it stable.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CallHomeHandshake {
    /// Subject of the stream being connected.
    subject: String,
    /// Kind of stream being connected.
    stream_type: StreamType,
}
#[cfg(test)]
mod tests {
    use crate::engine::AsyncEngineContextProvider;
    use crate::pipeline::Context;

    use super::*;

    /// Minimal JSON-serializable payload used to check a round trip over the stream.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestMessage {
        foo: String,
    }

    /// End-to-end check of the TCP stream transport.
    ///
    /// A server registers a response stream. A client connects to it using the
    /// returned connection info and sends one JSON message. The server side must
    /// receive an identical payload.
    // NOTE(review): binds the fixed port 9124, which can collide with other tests or
    // processes on the host. Consider an ephemeral port if `ServerOptions` supports one.
    #[tokio::test]
    async fn test_tcp_stream_client_server() {
        println!("Test Started");
        let options = server::ServerOptions::builder().port(9124).build().unwrap();
        println!("Server options built");
        let server = server::TcpStreamServer::new(options).await.unwrap();
        println!("Server created");

        // Server side ("rank 0"): register a stream that only carries responses.
        let context_rank0 = Context::new(());
        let options = StreamOptions::builder()
            .context(context_rank0.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();
        let pending_connection = server.register(options).await;
        let connection_info = pending_connection
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        // Client side ("rank 1"): reuse the server context's id so both ends refer to
        // the same logical request.
        let context_rank1 = Context::with_id((), context_rank0.id().to_string());
        let mut send_stream =
            client::TcpClient::create_response_stream(context_rank1.context(), connection_info)
                .await
                .unwrap();
        println!("Client connected");

        // The prologue is sent before awaiting the server-side pairing below. Presumably
        // the server completes pairing only after receiving it; confirm in `server`.
        send_stream.send_prologue(None).await.unwrap();
        let recv_stream = pending_connection
            .recv_stream
            .unwrap()
            .stream_provider
            .await
            .unwrap();
        println!("Server paired");

        let msg = TestMessage {
            foo: "bar".to_string(),
        };
        let payload = serde_json::to_vec(&msg).unwrap();
        send_stream.send(payload.into()).await.unwrap();
        println!("Client sent message");

        let data = recv_stream.unwrap().rx.recv().await.unwrap();
        println!("Server received message");
        let recv_msg = serde_json::from_slice::<TestMessage>(&data).unwrap();
        assert_eq!(msg.foo, recv_msg.foo);
        println!("message match");

        // Dropping the sender closes the client side of the stream explicitly.
        drop(send_stream);
    }
}