use std::pin::Pin;
use connection_utils::Disconnected;
use cs_trace::{Tracer, create_trace};
use tokio::io::{AsyncRead, AsyncWrite};
mod rpc;
pub mod disconnected;
pub mod connected;
/// A connection wrapper around a single duplex byte stream.
///
/// Values are created through [`MultiplexedConnection::new`], which hands them out
/// as a `Box<dyn Disconnected>`.
pub struct MultiplexedConnection<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + 'static,
{
    /// Tracer for this connection; `new` creates it with the label "rpc-connection".
    trace: Box<dyn Tracer>,
    /// The underlying duplex stream, pinned on the heap.
    // NOTE(review): wrapped in `Option`, presumably so the stream can be taken out
    // when the connection transitions to another state — confirm in the submodules.
    stream: Option<Pin<Box<TAsyncDuplex>>>,
}
impl<TAsyncDuplex> MultiplexedConnection<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + 'static,
{
    /// Wraps `stream` in a new connection and returns it as a `Box<dyn Disconnected>`.
    ///
    /// The stream is pinned on the heap and stored alongside a tracer labelled
    /// `"rpc-connection"`. Callers move on from the returned value via the
    /// `Disconnected` trait (e.g. `connect()` / `listen()`).
    pub fn new(stream: TAsyncDuplex) -> Box<dyn Disconnected> {
        let trace = create_trace!("rpc-connection");

        Box::new(MultiplexedConnection {
            trace,
            stream: Some(Box::pin(stream)),
        })
    }
}
#[cfg(test)]
mod tests {
    pub use cs_trace::{create_trace_listener, TraceListenerOptions, Trace, SubscriberInitExt, create_trace, child};
    use cs_utils::futures::wait;
    use connection_utils::test::test_async_stream;
    use cs_utils::random_str;
    use rstest::rstest;
    use tokio::io::duplex;
    use tokio::try_join;
    use cs_utils::random_str_rg;
    use super::MultiplexedConnection;

    /// Number of random strings concatenated into each test payload.
    const TEST_DATA_CHUNKS: usize = 7;

    /// Which end of the channel pair is handed to `test_async_stream` as its first argument.
    #[derive(Clone, Copy, Debug)]
    enum SendingSide {
        /// The channel opened on the connecting side via `connection.channel(label)`.
        Local,
        /// The channel delivered to the listening side through `on_remote_channel()`.
        Remote,
    }

    /// Builds a connected channel pair over an in-memory duplex and streams random data over it.
    ///
    /// One task `connect()`s and opens a channel with a random label. The other task
    /// `listen()`s and polls `on_remote_channel()` until that channel arrives. Both tasks
    /// assert that the label matches. The resulting channels are then passed to
    /// `test_async_stream`, with the `sender` side's channel as the first argument.
    ///
    /// # Panics
    /// Panics (failing the test) if connecting, listening, or opening the channel fails,
    /// if either task panics, or if the labels differ.
    async fn run_channel_test(sender: SendingSide, str_min_size: usize, str_max_size: usize) {
        let (duplex1, duplex2) = duplex(4096);
        let channel_label = format!("channel-label-{}", random_str(4));
        let channel_label1 = channel_label.clone();
        let channel_label2 = channel_label;

        let (local_channel, remote_channel) = try_join!(
            tokio::spawn(async move {
                let mut connection1 = MultiplexedConnection::new(duplex1)
                    .connect().await
                    .expect("Error while connecting.");
                // NOTE(review): presumably gives the listening side time to come up
                // before the channel is opened — confirm whether it is still required.
                wait(50).await;
                let channel = connection1
                    .channel(channel_label1.clone()).await
                    .unwrap();
                assert_eq!(
                    channel.label(),
                    &channel_label1,
                    "Channel labels must match.",
                );
                channel
            }),
            tokio::spawn(async move {
                let mut connection2 = MultiplexedConnection::new(duplex2)
                    .listen().await
                    .expect("Error while listening.");
                let mut on_remote_channel = connection2.on_remote_channel().unwrap();
                // Poll until the peer's channel shows up on this side.
                let channel = loop {
                    if let Ok(channel) = on_remote_channel.try_recv() {
                        break channel;
                    }
                    wait(50).await;
                };
                assert_eq!(
                    channel.label(),
                    &channel_label2,
                    "Channel labels must match.",
                );
                channel
            }),
        ).unwrap();

        let test_data = (0..TEST_DATA_CHUNKS)
            .map(|_| random_str_rg(str_min_size..=str_max_size))
            .collect::<Vec<_>>()
            .join("");

        // Each arm is a statement so the two generic instantiations need not share a type.
        match sender {
            SendingSide::Local => {
                test_async_stream(local_channel, remote_channel, test_data).await;
            }
            SendingSide::Remote => {
                test_async_stream(remote_channel, local_channel, test_data).await;
            }
        }
    }

    #[rstest]
    #[case::size_8_32(8, 32)]
    #[case::size_128_512(128, 512)]
    #[case::size_2048_4096(2048, 4096)]
    #[case::size_4096_8192(4096, 8192)]
    #[case::size_8192_16384(8192, 16384)]
    #[tokio::test]
    async fn sends_data_from_local_channel(
        #[case] str_min_size: usize,
        #[case] str_max_size: usize,
    ) {
        run_channel_test(SendingSide::Local, str_min_size, str_max_size).await;
    }

    #[rstest]
    #[case::size_8_32(8, 32)]
    #[case::size_128_512(128, 512)]
    #[case::size_2048_4096(2048, 4096)]
    #[case::size_4096_8192(4096, 8192)]
    #[case::size_8192_16384(8192, 16384)]
    #[tokio::test]
    async fn sends_data_from_remote_channel(
        #[case] str_min_size: usize,
        #[case] str_max_size: usize,
    ) {
        run_channel_test(SendingSide::Remote, str_min_size, str_max_size).await;
    }
}