use std::time::Duration;
use bytes::Bytes;
use msg_socket::{DEFAULT_QUEUE_SIZE, RepSocket, ReqOptions, ReqSocket};
use msg_transport::{
tcp::Tcp,
tcp_tls::{self, TcpTls},
};
use openssl::ssl::{SslAcceptor, SslMethod};
use tokio_stream::StreamExt;
mod helpers {
    //! Shared TLS fixtures (acceptor/connector builders) for the req/rep tests.

    use std::path::PathBuf;

    use openssl::ssl::{
        SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, SslContextBuilder,
        SslFiletype, SslMethod,
    };

    /// Directory containing the test PKI. The relative path resolves against the process
    /// working directory, which `cargo test` sets to the package root.
    const CERTIFICATES_DIR: &str = "../testdata/certificates";

    /// Returns the path of `file_name` inside [`CERTIFICATES_DIR`].
    ///
    /// # Panics
    ///
    /// Panics if the file does not exist. The message names the kind of file (`what`) and the
    /// resolved path, so a wrong working directory is easy to diagnose.
    fn certificate_file(file_name: &str, what: &str) -> PathBuf {
        let path = PathBuf::from(CERTIFICATES_DIR).join(file_name);
        assert!(path.exists(), "{what} file does not exist: {}", path.display());
        path
    }

    /// Loads `{role}-cert.pem`, `{role}-key.pem` and the shared `ca-cert.pem` into `ctx`.
    ///
    /// All three files are checked for existence before any of them is loaded.
    ///
    /// # Panics
    ///
    /// Panics if a file is missing or OpenSSL rejects it.
    fn load_identity(ctx: &mut SslContextBuilder, role: &str) {
        let certificate = certificate_file(&format!("{role}-cert.pem"), "Certificate");
        let private_key = certificate_file(&format!("{role}-key.pem"), "Private key");
        let ca_certificate = certificate_file("ca-cert.pem", "CA Certificate");

        ctx.set_certificate_file(certificate, SslFiletype::PEM).unwrap();
        ctx.set_private_key_file(private_key, SslFiletype::PEM).unwrap();
        ctx.set_ca_file(ca_certificate).unwrap();
    }

    /// Builds a TLS acceptor (Mozilla "intermediate" profile) loaded with the server
    /// certificate, the server private key and the test CA.
    ///
    /// # Panics
    ///
    /// Panics if any certificate file is missing or cannot be loaded.
    pub fn default_acceptor_builder() -> SslAcceptorBuilder {
        let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
        // `SslAcceptorBuilder` derefs to `SslContextBuilder`.
        load_identity(&mut builder, "server");
        builder
    }

    /// Builds a TLS connector loaded with the client certificate, the client private key and
    /// the test CA. The client certificate lets it take part in mutual TLS.
    ///
    /// # Panics
    ///
    /// Panics if any certificate file is missing or cannot be loaded.
    pub fn default_connector_builder() -> SslConnectorBuilder {
        let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
        // `SslConnectorBuilder` derefs to `SslContextBuilder`.
        load_identity(&mut builder, "client");
        builder
    }
}
/// Round-trips a single message over plain TCP through an echoing reply socket.
#[tokio::test]
async fn reqrep_works() {
    let _ = tracing_subscriber::fmt::try_init();

    let mut rep = RepSocket::new(Tcp::default());
    let mut req = ReqSocket::new(Tcp::default());

    // OS-assigned port; the request side connects to whatever address was actually bound.
    rep.bind("0.0.0.0:0").await.unwrap();
    let server_addr = rep.local_addr().unwrap();
    req.connect(server_addr).await.unwrap();

    // Echo server: each incoming request is answered with its own payload.
    tokio::spawn(async move {
        while let Some(incoming) = rep.next().await {
            let echoed = incoming.msg().clone();
            incoming.respond(echoed).unwrap();
        }
    });

    let hello = Bytes::from_static(b"hello");
    let outgoing = hello.clone();
    let response = req.request(outgoing).await.unwrap();
    assert_eq!(hello, response, "expected {hello:?}, got {response:?}");
}
/// Round-trips a single message over TLS, where only the server presents a certificate.
#[tokio::test]
async fn reqrep_tls_works() {
    let _ = tracing_subscriber::fmt::try_init();

    // Server: TLS acceptor built from the shared test fixtures.
    let acceptor = helpers::default_acceptor_builder().build();
    let mut rep = RepSocket::new(TcpTls::new_server(tcp_tls::config::Server::new(acceptor.into())));
    rep.bind("0.0.0.0:0").await.unwrap();

    // Client: connector that verifies the server under the name "localhost".
    let connector = helpers::default_connector_builder().build();
    let client_config =
        tcp_tls::config::Client::new("localhost".to_string()).with_ssl_connector(connector);
    let mut req = ReqSocket::new(TcpTls::new_client(client_config));
    req.connect(rep.local_addr().unwrap()).await.unwrap();

    // Echo server: each incoming request is answered with its own payload.
    tokio::spawn(async move {
        while let Some(incoming) = rep.next().await {
            let echoed = incoming.msg().clone();
            incoming.respond(echoed).unwrap();
        }
    });

    let hello = Bytes::from_static(b"hello");
    let response = req.request(Bytes::clone(&hello)).await.unwrap();
    assert_eq!(hello, response, "expected {hello:?}, got {response:?}");
}
/// Checks that a `tcp_tls::Control::SwapAcceptor` control message takes effect at runtime.
///
/// Phase 1 performs a normal TLS request/response with the default (certificate-bearing)
/// acceptor. Phase 2 swaps in an acceptor that has no certificate or key loaded. It then
/// asserts that a fresh client's request does not complete within one second.
#[tokio::test]
async fn reqrep_tls_control_works() {
    let _ = tracing_subscriber::fmt::try_init();
    let server_config =
        tcp_tls::config::Server::new(helpers::default_acceptor_builder().build().into());
    let tcp_tls_server = TcpTls::new_server(server_config);
    let mut rep = RepSocket::new(tcp_tls_server);
    rep.bind("0.0.0.0:0").await.unwrap();
    let domain = "localhost".to_string();
    let ssl_connector = helpers::default_connector_builder().build();
    let tcp_tls_client =
        TcpTls::new_client(tcp_tls::config::Client::new(domain).with_ssl_connector(ssl_connector));
    let mut req = ReqSocket::new(tcp_tls_client);
    req.connect(rep.local_addr().unwrap()).await.unwrap();
    // Answer exactly one request (`if let`, not `while let`), then hand the rep socket back so
    // the test can reconfigure it below.
    let handle = tokio::spawn(async move {
        if let Some(request) = rep.next().await {
            let msg = request.msg().clone();
            request.respond(msg).unwrap();
        }
        rep
    });
    // Phase 1: the request succeeds with the certificate-bearing acceptor.
    let hello = Bytes::from_static(b"hello");
    let response = req.request(hello.clone()).await.unwrap();
    assert_eq!(hello, response, "expected {hello:?}, got {response:?}");
    // Discard the first client; phase 2 uses a brand-new one.
    drop(req);
    let mut rep = handle.await.unwrap();
    let domain = "localhost".to_string();
    let ssl_connector = helpers::default_connector_builder().build();
    let tcp_tls_client =
        TcpTls::new_client(tcp_tls::config::Client::new(domain).with_ssl_connector(ssl_connector));
    let mut req = ReqSocket::new(tcp_tls_client);
    // A bare Mozilla-intermediate acceptor: no certificate, private key or CA file is loaded,
    // so TLS handshakes against it are expected to fail.
    let acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap().build();
    // The swap must happen before the new client connects so its handshake hits the new acceptor.
    rep.control(tcp_tls::Control::SwapAcceptor(acceptor.into())).await.unwrap();
    // NOTE(review): presumably a non-blocking connect that does not wait for the (failing)
    // handshake; an awaited `connect` would not return Ok here — confirm against msg_socket.
    req.connect_sync(rep.local_addr().copied().unwrap());
    tokio::spawn(async move {
        if let Some(request) = rep.next().await {
            let msg = request.msg().clone();
            request.respond(msg).unwrap();
        }
    });
    // Phase 2: `unwrap_err` asserts that the 1s timeout elapsed, i.e. no response arrived.
    let hello = Bytes::from_static(b"hello");
    tokio::time::timeout(Duration::from_secs(1), req.request(hello.clone())).await.unwrap_err();
}
/// Round-trips a single message over mutual TLS: the server rejects clients that do not
/// present a certificate.
#[tokio::test]
async fn reqrep_mutual_tls_works() {
    let _ = tracing_subscriber::fmt::try_init();

    // Server: require and verify a client certificate.
    let client_cert_policy = openssl::ssl::SslVerifyMode::PEER
        | openssl::ssl::SslVerifyMode::FAIL_IF_NO_PEER_CERT;
    let mut acceptor_builder = helpers::default_acceptor_builder();
    acceptor_builder.set_verify(client_cert_policy);
    let acceptor = acceptor_builder.build();
    let mut rep = RepSocket::new(TcpTls::new_server(tcp_tls::config::Server::new(acceptor.into())));
    rep.bind("0.0.0.0:0").await.unwrap();

    // Client: the default connector already carries the client certificate and key.
    let connector = helpers::default_connector_builder().build();
    let client_config =
        tcp_tls::config::Client::new("localhost".to_string()).with_ssl_connector(connector);
    let mut req = ReqSocket::new(TcpTls::new_client(client_config));
    req.connect(rep.local_addr().unwrap()).await.unwrap();

    // Echo server: each incoming request is answered with its own payload.
    tokio::spawn(async move {
        while let Some(incoming) = rep.next().await {
            let echoed = incoming.msg().clone();
            incoming.respond(echoed).unwrap();
        }
    });

    let hello = Bytes::from_static(b"hello");
    let response = req.request(Bytes::clone(&hello)).await.unwrap();
    assert_eq!(hello, response, "expected {hello:?}, got {response:?}");
}
/// Verifies that a req socket which connects *before* the rep socket is bound still gets its
/// request served once the rep socket comes up.
#[tokio::test]
async fn reqrep_late_bind_works() {
    let _ = tracing_subscriber::fmt::try_init();
    let mut rep = RepSocket::new(Tcp::default());
    let mut req = ReqSocket::new(Tcp::default());

    // Reserve an ephemeral IPv4 loopback port instead of hard-coding one, so the test can't
    // collide with other processes or parallel test runs. Using an explicit IPv4 address also
    // avoids `localhost` resolving to different families for bind and connect. The probe
    // listener is dropped immediately, so nothing is listening until `rep.bind` below.
    let port = std::net::TcpListener::bind("127.0.0.1:0")
        .and_then(|probe| probe.local_addr())
        .unwrap()
        .port();
    let local_addr = format!("127.0.0.1:{port}");

    req.connect(local_addr.as_str()).await.unwrap();

    let hello = Bytes::from_static(b"hello");
    // `Bytes::clone` is a cheap refcount bump; the clone is moved into the task.
    let outgoing = hello.clone();
    let reply = tokio::spawn(async move { req.request(outgoing).await.unwrap() });

    // Give the request side time to start (and fail) connecting before the server exists.
    tokio::time::sleep(Duration::from_millis(1000)).await;
    rep.bind(local_addr.as_str()).await.unwrap();

    let msg = rep.next().await.unwrap();
    let payload = msg.msg().clone();
    msg.respond(payload).unwrap();

    let response = reply.await.unwrap();
    assert_eq!(hello, response, "expected {hello:?}, got {response:?}");
}
/// Verifies that the req socket starts rejecting requests once it holds
/// `HWM + DEFAULT_QUEUE_SIZE` unanswered requests. The first rejected request is expected to
/// be number `HWM + DEFAULT_QUEUE_SIZE + 1`.
#[tokio::test]
async fn reqrep_hwm_reached() {
    let _ = tracing_subscriber::fmt::try_init();

    const HWM: usize = 2;
    const TOTAL_CAPACITY: usize = HWM + DEFAULT_QUEUE_SIZE;

    let mut rep = RepSocket::new(Tcp::default());
    let options =
        ReqOptions::default().with_max_pending_requests(HWM).with_timeout(Duration::from_secs(30));
    let mut req = ReqSocket::with_options(Tcp::default(), options);
    rep.bind("0.0.0.0:0").await.unwrap();
    req.connect(rep.local_addr().unwrap()).await.unwrap();
    tokio::time::sleep(Duration::from_millis(100)).await;

    // Accept requests but never answer them, so they all stay pending on the req side. They
    // are kept in `requests` (not dropped) for as long as the stream yields. A plain `while let`
    // replaces the previous single-arm `select!`, which panicked once the stream ended because
    // every branch was then disabled.
    tokio::spawn(async move {
        let mut requests = Vec::new();
        while let Some(request) = rep.next().await {
            requests.push(request);
        }
    });

    let req = std::sync::Arc::new(req);
    let expected_limit = TOTAL_CAPACITY + 1;
    // Safety valve: if the limit is never hit, stop and fail via the assertion below instead
    // of spawning requests forever.
    let max_attempts = expected_limit * 4;

    // Request tasks that receive an error signal it here.
    let (failure_tx, mut failure_rx) = tokio::sync::mpsc::channel(1);
    let mut sent_count = 0;
    while sent_count < max_attempts {
        let req = std::sync::Arc::clone(&req);
        let failure_tx = failure_tx.clone();
        let i = sent_count;
        tokio::spawn(async move {
            if req.request(Bytes::from(format!("request{i}"))).await.is_err() {
                let _ = failure_tx.send(()).await;
            }
        });
        sent_count += 1;
        // Let the just-spawned request run far enough to be accepted or rejected.
        tokio::time::sleep(Duration::from_millis(1)).await;
        if failure_rx.try_recv().is_ok() {
            break;
        }
    }

    assert_eq!(
        sent_count, expected_limit,
        "Expected to send {} requests before HWM, but sent {}",
        expected_limit, sent_count
    );
}