use std::{sync::Arc, collections::HashMap, pin::Pin, fmt};
use cs_utils::random_number;
use cs_trace::{Tracer, child};
use tokio_util::codec::Framed;
use anyhow::{Result, anyhow, bail};
use serde::{Serialize, Deserialize};
use futures::{StreamExt, stream::{SplitStream, SplitSink}, SinkExt};
use tokio::{sync::{mpsc::{Sender, self}, Mutex, Notify, watch, RwLock}, io::{split, duplex, WriteHalf, AsyncReadExt, ReadHalf, AsyncWriteExt}};
use crate::{Channel, create_framed_stream, TransportChannel, codecs::GenericCodec};
/// Identifier of a multiplexed channel; carried in every `ControlMessage`.
type TChannelId = u16;
/// Registry of local channels keyed by channel ID. Each entry holds the channel's
/// write half together with the receiver of its close flag.
type TChannels = Arc<Mutex<HashMap<TChannelId, Arc<Mutex<(WriteHalf<Box<dyn Channel>>, watch::Receiver<bool>)>>>>>;
/// Frame exchanged over the shared control stream of a `TransportConnection`.
///
/// NOTE(review): encoded with serde through `GenericCodec`; if that codec encodes
/// enum variants by index, reordering or inserting variants changes the wire
/// format — confirm before editing this enum.
#[derive(Serialize, Deserialize, Debug)]
pub enum ControlMessage {
    /// Bytes read from the sender's local end of channel `.0`; the receiver writes
    /// them into its own local end of that channel.
    Data(TChannelId, Vec<u8>),
    /// Open channel `.0` with label `.1` and buffer size `.2`. When `.3` is `false`
    /// this is a request from the remote side; when `true` it is the acknowledgement
    /// of a request this side sent.
    OpenChannel(TChannelId, String, u32, bool),
    /// The sender's end of channel `.0` reached EOF; the receiver closes its end.
    Close(TChannelId),
    /// Channel `.0` failed on the sender's side; `.1` describes the failure. The
    /// receiver logs it and closes its end of the channel.
    Error(TChannelId, String),
}
impl fmt::Display for ControlMessage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ControlMessage::Data(id, data) => {
return f.debug_tuple("ControlMessage::Data")
.field(id)
.field(&data.len())
.finish();
},
ControlMessage::OpenChannel(id, label, buffer_size, is_response) => {
return f.debug_tuple("ControlMessage::OpenChannel")
.field(id)
.field(label)
.field(buffer_size)
.field(is_response)
.finish();
},
ControlMessage::Close(id) => {
return f.debug_tuple("ControlMessage::Close")
.field(id)
.finish();
},
ControlMessage::Error(id, message) => {
return f.debug_tuple("ControlMessage::Error")
.field(id)
.field(message)
.finish();
},
};
}
}
async fn send_error(
trace: &Box<dyn Tracer>,
id: TChannelId,
message: String,
message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
) {
trace.error(
&format!("Channel {id} read error: {:?}", message),
);
let result = message_sender.lock().await
.send(ControlMessage::Error(id, message)).await;
if let Err(error) = result {
trace.error(
&format!("Failed to send channel error to the remote side: {:?}", error),
);
}
}
/// Pumps bytes read from the local end of channel `id` onto the shared control
/// stream as `ControlMessage::Data` frames.
///
/// The loop ends when:
/// * the read fails (the error is reported to the remote side);
/// * sending a frame fails (the error is reported to the remote side);
/// * the reader hits EOF (a `ControlMessage::Close` is sent and the local entry is
///   removed from `channels` via `close_channel`);
/// * the close flag from `on_close` was already set before the last read.
///
/// `buffer_size` is the size of the scratch buffer, i.e. the maximum payload of a
/// single `Data` frame.
async fn forward_channel_data(
    trace: Box<dyn Tracer>,
    id: u16,
    mut reader: ReadHalf<Box<dyn Channel>>,
    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
    on_close: watch::Receiver<bool>,
    channels: TChannels,
    buffer_size: u32,
) {
    let mut buf = vec![0u8; buffer_size as usize];
    loop {
        // Sampled before the read: a close signalled while `read` is pending is
        // honoured only after whatever that read returned has been forwarded.
        let is_closed = *on_close.borrow();
        let bytes_read = match reader.read(&mut buf[..]).await {
            Ok(number) => number,
            Err(error) => {
                send_error(&trace, id, format!("{error}"), message_sender).await;
                return;
            },
        };
        if bytes_read == 0 {
            // EOF. Go straight to `Close`: an empty `Data` frame carries nothing, and
            // if the peer has already dropped this channel it would only provoke a
            // spurious "No channel with ID ... found." error round-trip.
            trace.warn("got EOF, sending channel close message");
            let close_message_result = {
                message_sender.lock().await
                    .send(ControlMessage::Close(id)).await
            };
            // Tear down the local end even if the remote could not be told.
            if let Err(error) = close_channel(id, channels).await {
                trace.error(
                    &format!("failed to close local channel: {:?}", error),
                );
            }
            trace.info("channel is closed by EOF");
            if let Err(error) = close_message_result {
                send_error(&trace, id, format!("{}", error), message_sender).await;
            }
            return;
        }
        let data = buf[..bytes_read].to_vec();
        let result = {
            message_sender
                .lock().await
                .send(ControlMessage::Data(id, data)).await
        };
        if let Err(error) = result {
            send_error(&trace, id, format!("{}", error), message_sender).await;
            return;
        }
        if is_closed {
            trace.info("channel is closed by notification");
            return;
        }
    }
}
async fn send_channel_data(
trace: Box<dyn Tracer>,
id: u16,
data: Vec<u8>,
channels: TChannels,
message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
) {
let channel = {
let lock = channels.lock().await;
let channel = match lock.get(&id) {
Some(writer) => writer,
None => {
send_error(&trace, id, format!("No channel with ID {:?} found.", id), message_sender).await;
return;
},
};
Arc::clone(channel)
};
let (writer, on_close) = &mut *channel.lock().await;
let is_closed = *on_close.borrow();
if data.len() == 0 && is_closed {
trace.warn(
&format!("channel {id} already closed, skip writing"),
);
return;
}
if let Err(error) = writer.write_all(&data[..]).await {
send_error(&trace, id, format!("{}", error), Arc::clone(&message_sender)).await;
}
}
async fn close_channel(
id: u16,
channels: TChannels,
) -> Result<()> {
let mut lock = channels.lock().await;
let channel = {
let channel = match lock.remove(&id) {
Some(writer) => writer,
None => bail!("No channel found with ID {}.", id),
};
channel
};
let (writer, on_close) = &mut *channel.lock().await;
if *on_close.borrow() {
return Ok(());
}
writer.shutdown().await?;
return Ok(());
}
/// Creates a new in-memory channel pair for `id` and wires one end into the
/// connection. The other end is returned to the caller.
///
/// The wired end is split:
/// * its read half is pumped to the remote side by a spawned `forward_channel_data`
///   task;
/// * its write half, with a close-flag receiver, is registered in `channels` so that
///   incoming `Data` frames can be written into it.
async fn add_local_channel(
    trace: Box<dyn Tracer>,
    id: u16,
    label: String,
    buffer_size: u32,
    channels: TChannels,
    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
) -> Result<Box<dyn Channel>> {
    let (near_end, far_end) = duplex(buffer_size as usize);
    // The boxes stay inline in the call so they coerce to the expected boxed type.
    let (wired, handed_out) = TransportChannel::new_pair(
        id,
        label.clone(),
        (Box::new(near_end), Box::new(far_end)),
        buffer_size,
    );
    // Both the forwarding task and the registry entry observe the wired end's flag.
    let reader_on_close = wired.on_close();
    let writer_on_close = wired.on_close();
    let (read_half, write_half) = split(wired);

    let parent = &trace;
    let forward_trace = child!(parent, "forward-channel-data");
    tokio::spawn(forward_channel_data(
        forward_trace,
        id,
        read_half,
        Arc::clone(&message_sender),
        reader_on_close,
        Arc::clone(&channels),
        buffer_size,
    ));

    let entry = Arc::new(Mutex::new((write_half, writer_on_close)));
    channels.lock().await.insert(id, entry);

    trace.info(
        &format!("local channel opened: {}, {}", id, label),
    );
    Ok(handed_out)
}
async fn open_channel(
trace: Box<dyn Tracer>,
id: u16,
label: String,
buffer_size: u32,
is_response: bool,
channels: TChannels,
open_channel_requests: Arc<RwLock<HashMap<u16, Arc<Notify>>>>,
message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
on_remote_channel: mpsc::Sender<Box<dyn Channel>>,
) -> Result<()> {
if is_response {
let read_lock = open_channel_requests.read().await;
let notify = match read_lock.get(&id) {
None => bail!("No open channel notifier found."),
Some(notify) => notify,
};
notify.notify_waiters();
return Ok(());
}
let trace1 = &trace;
let trace1 = child!(trace1, "add-local-channel");
trace.trace("sending open channel response");
let channel = add_local_channel(
trace1,
id,
label.clone(),
buffer_size,
channels,
Arc::clone(&message_sender),
).await?;
{
message_sender
.lock().await
.send(ControlMessage::OpenChannel(id, label.clone(), buffer_size, true)).await?;
}
trace.trace("sent");
on_remote_channel
.send(channel).await
.map_err(|error| {
return anyhow!("{}", error);
})?;
return Ok(());
}
async fn handle_control_messages(
trace: Box<dyn Tracer>,
mut stream_source: SplitStream<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>>,
channels: TChannels,
open_channel_requests: Arc<RwLock<HashMap<u16, Arc<Notify>>>>,
message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
on_remote_channel: mpsc::Sender<Box<dyn Channel>>,
) -> Result<()> {
while let Some(maybe_message) = stream_source.next().await {
let message = maybe_message?;
trace.warn(
&format!("got control message: {}", message),
);
match message {
ControlMessage::Data(id, data) => {
let trace = &trace;
let trace = child!(trace, "send-channel-data");
tokio::spawn(send_channel_data(
trace,
id,
data,
Arc::clone(&channels),
Arc::clone(&message_sender),
));
},
ControlMessage::OpenChannel(id, label, buffer_size, is_response) => {
let trace = &trace;
let trace = child!(trace, "open-channel");
open_channel(
trace,
id,
label,
buffer_size,
is_response,
Arc::clone(&channels),
Arc::clone(&open_channel_requests),
Arc::clone(&message_sender),
Sender::clone(&on_remote_channel),
).await?;
},
ControlMessage::Close(id) => {
tokio::spawn(close_channel(id, Arc::clone(&channels)));
},
ControlMessage::Error(id, message) => {
trace.error(
&format!("remote channel {id} error: {:?}", message),
);
tokio::spawn(close_channel(id, Arc::clone(&channels)));
},
};
}
return Ok(());
}
/// Multiplexes many logical channels over one underlying `Channel`, exchanging
/// `ControlMessage` frames on a shared, framed control stream.
pub struct TransportConnection {
    /// Tracer for this connection; child tracers for background tasks derive from it.
    trace: Box<dyn Tracer>,
    /// Sending half of the framed control stream, shared with the background tasks.
    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
    /// Notifiers for locally initiated `OpenChannel` requests that await the remote
    /// acknowledgement, keyed by channel ID.
    open_channel_requests: Arc<RwLock<HashMap<u16, Arc<Notify>>>>,
    /// Local channels keyed by ID: write half plus close-flag receiver.
    channels: TChannels,
    /// Receiver of channels opened by the remote side; `None` while it has been taken
    /// with `on_remote_channel`.
    on_remote_channel: Option<mpsc::Receiver<Box<dyn Channel>>>,
}
impl TransportConnection {
pub fn new(
trace: &Box<dyn Tracer>,
channel: Box<dyn Channel>,
) -> Box<TransportConnection> {
let trace = child!(trace, "transport-channel");
let stream = create_framed_stream(channel);
let (channel_sink, channel_source) = stream.split();
let message_sender = Arc::new(Mutex::new(channel_sink));
let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
let channels = Arc::new(Mutex::new(HashMap::new()));
let (on_remote_channel_sender, on_remote_channel) = mpsc::channel(25);
let trace2 = &trace;
let trace2 = child!(trace2, "control-messages-handler");
tokio::spawn(handle_control_messages(
trace2,
channel_source,
Arc::clone(&channels),
Arc::clone(&open_channel_requests),
Arc::clone(&message_sender),
on_remote_channel_sender,
));
return Box::new(TransportConnection {
trace,
message_sender,
open_channel_requests,
channels,
on_remote_channel: Some(on_remote_channel),
});
}
pub fn on_remote_channel(&mut self) -> Result<mpsc::Receiver<Box<dyn Channel>>> {
match self.on_remote_channel.take() {
Some(on_remote_channel) => return Ok(on_remote_channel),
None => bail!("No on_remote_channel found."),
};
}
pub fn off_remote_channel(
&mut self,
on_channel: mpsc::Receiver<Box<dyn Channel>>,
) -> Result<()> {
if let Some(_) = self.on_remote_channel {
bail!("on_remote_channel already set.");
}
self.on_remote_channel.replace(on_channel);
return Ok(());
}
pub async fn channel(
&mut self,
label: impl AsRef<str> + ToString,
buffer_size: u32,
) -> Result<Box<dyn Channel>> {
let id = random_number(0..=u16::MAX);
let label = label.to_string();
self.trace.trace(
&format!("creating channel, ID: {}, label: {}", id, label),
);
let notify = Arc::new(Notify::new());
{
self.open_channel_requests
.write().await
.insert(id, Arc::clone(¬ify));
}
self.trace.trace(
&format!("sending open channel request"),
);
{
self.message_sender
.lock().await
.send(ControlMessage::OpenChannel(id, label.clone(), buffer_size, false)).await?;
}
self.trace.trace(
&format!("open channel request sent"),
);
notify.notified().await;
self.trace.trace(
&format!("got open channel response"),
);
let trace2 = &self.trace;
let trace2 = child!(trace2, "add-local-channel");
let channel = add_local_channel(
trace2,
id,
label,
buffer_size,
Arc::clone(&self.channels),
Arc::clone(&self.message_sender),
).await?;
self.trace.trace(
&format!("channel created: {}, {}", channel.id(), channel.label()),
);
return Ok(channel);
}
}
#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
use rstest::rstest;
use futures::StreamExt;
use cs_trace::create_trace;
use cs_utils::{random_str, random_str_rg, random_number, traits::Random, futures::wait_random};
use tokio::{sync::{Mutex, mpsc, watch, RwLock, Notify}, io::{split, AsyncWriteExt, AsyncReadExt}};
use crate::test::TestOptions;
use crate::create_framed_stream;
use crate::{TransportChannel, TransportConnection};
use crate::{connections::transport_connection::{send_channel_data, forward_channel_data, close_channel, add_local_channel, open_channel, ControlMessage}, mocks::{ChannelMockOptions, channel_mock_pair}};
mod send_channel_data {
    use super::*;
    // A payload handed to `send_channel_data` must come out of the peer end of the
    // registered channel unchanged, for several payload sizes.
    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[tokio::test]
    async fn sends_data_to_channel(
        #[case] data_len: usize,
    ) {
        let trace = create_trace!("test");
        let buffer_size = 4_096;
        let id = random_number(0..=u16::MAX);
        let data = random_str(data_len).as_bytes().to_vec();
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (channel1, mut channel2) = channel_mock_pair(options1, options2);
        let on_close = channel1.on_close();
        let (_reader, writer) = split(channel1);
        // Register channel1's write half under `id`, shaped like a `TChannels` entry.
        let mut channels = HashMap::new();
        channels.insert(id, Arc::new(Mutex::new((writer, on_close))));
        let channels = Arc::new(Mutex::new(channels));
        let data_to_send = data.clone();
        let data_to_receive = data.clone();
        // Control stream; only its sending half is handed to `send_channel_data`.
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (stream1, _stream2) = TransportChannel::new_pair(
            id,
            "control-stream",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let (stream1_tx, _stream1_rx) = create_framed_stream(stream1).split();
        let control_channel_sender = Arc::new(Mutex::new(stream1_tx));
        tokio::join!(
            Box::pin(async move {
                wait_random(1..=5).await;
                send_channel_data(
                    trace,
                    id,
                    data_to_send,
                    channels,
                    control_channel_sender,
                ).await;
            }),
            Box::pin(async move {
                wait_random(1..=5).await;
                // The test assumes a single read returns the whole payload
                // (every case is at most 4_096 bytes).
                let mut buf = [0; 4_096];
                let bytes_read = channel2
                    .read(&mut buf).await
                    .unwrap();
                let received_data = &buf[..bytes_read];
                assert_eq!(
                    received_data,
                    &data_to_receive[..],
                    "Must receive correct data.",
                );
            }),
        );
    }
    // An empty payload for a channel that was already shut down must not be written.
    // NOTE(review): a shut-down mock channel presumably reads 0 bytes anyway, so this
    // assertion may not distinguish "skipped" from "wrote nothing" — confirm.
    #[rstest]
    #[tokio::test]
    async fn does_not_send_if_already_closed() {
        let trace = create_trace!("test");
        let buffer_size = 4_096;
        let id = random_number(0..=u16::MAX);
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (mut channel1, mut channel2) = channel_mock_pair(options1, options2);
        let on_close = channel1.on_close();
        channel1
            .shutdown().await
            .unwrap();
        let (_reader, writer) = split(channel1);
        let mut channels = HashMap::new();
        channels.insert(id, Arc::new(Mutex::new((writer, on_close))));
        let channels = Arc::new(Mutex::new(channels));
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (stream1, _stream2) = TransportChannel::new_pair(
            id,
            "control-stream",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let (stream1_tx, _stream1_rx) = create_framed_stream(stream1).split();
        let control_channel_sender = Arc::new(Mutex::new(stream1_tx));
        tokio::join!(
            Box::pin(async move {
                wait_random(1..=5).await;
                send_channel_data(
                    trace,
                    id,
                    vec![],
                    channels,
                    control_channel_sender,
                ).await;
            }),
            Box::pin(async move {
                wait_random(1..=5).await;
                let mut buf = [0; 4_096];
                let bytes_read = channel2
                    .read(&mut buf).await
                    .unwrap();
                assert_eq!(
                    bytes_read,
                    0,
                    "Must 0 bytes.",
                );
            }),
        );
    }
    // With no channel under `id` (only under a different ID), the remote side must
    // receive `ControlMessage::Error` for `id` with a non-trivial message.
    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[tokio::test]
    async fn fails_if_no_channel_found(
        #[case] data_len: usize,
    ) {
        let trace = create_trace!("test");
        let buffer_size = 4_096;
        let id = random_number(0..=u16::MAX);
        let data = random_str(data_len).as_bytes().to_vec();
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (channel1, _channel2) = channel_mock_pair(options1, options2);
        let on_close = channel1.on_close();
        let (_reader, writer) = split(channel1);
        let mut channels = HashMap::new();
        // Draw a second ID guaranteed to differ from `id`.
        let another_id = {
            let mut another_id = random_number(0..=u16::MAX);
            while another_id == id {
                another_id = random_number(0..=u16::MAX);
            }
            another_id
        };
        channels.insert(another_id, Arc::new(Mutex::new((writer, on_close))));
        let channels = Arc::new(Mutex::new(channels));
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (stream1, stream2) = TransportChannel::new_pair(
            id,
            "control-stream",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let (stream1_tx, _stream1_rx) = create_framed_stream(stream1).split();
        let (_stream2_tx, mut stream2_rx) = create_framed_stream(stream2).split();
        tokio::try_join!(
            tokio::spawn(send_channel_data(
                trace,
                id,
                data,
                channels,
                Arc::new(Mutex::new(stream1_tx),
                ))),
            tokio::spawn(async move {
                let message = stream2_rx.next().await.unwrap().unwrap();
                match message {
                    ControlMessage::Error(received_id, error_message) => {
                        assert_eq!(
                            received_id,
                            id,
                            "Must receive error with correct id.",
                        );
                        assert!(
                            error_message.len() > 3,
                            "Received error message must be not empty.",
                        );
                    },
                    unexpected @ _ => panic!("Unexpected message: {:?}.", unexpected),
                };
            }),
        ).unwrap();
    }
}
mod handle_channel_reads {
    use crate::TransportChannel;
    use super::*;
    // Bytes written into the peer end of a local channel must arrive on the control
    // stream as `Data` frames for `id`, in order, followed by `Close(id)`.
    #[rstest]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[case(16_384)]
    #[tokio::test]
    async fn reads_from_a_local_channel(
        #[case] data_len: usize,
    ) {
        let trace = cs_trace::create_trace!("test");
        let buffer_size: u32 = 4_096;
        let id = random_number(0..=u16::MAX);
        let data = random_str(data_len)
            .as_bytes().to_vec();
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (channel1, mut channel2) = TransportChannel::new_pair(
            id,
            "transport-channel",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let on_close = channel1.on_close();
        let (reader, _writer) = split(channel1);
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (stream1, stream2) = TransportChannel::new_pair(
            id,
            "control-stream",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let stream1 = create_framed_stream(stream1);
        let stream2 = create_framed_stream(stream2);
        let (stream1_tx, _stream1_rx) = stream1.split();
        let (_stream2_tx, mut control_channel_receiver) = stream2.split();
        let control_channel_sender = Arc::new(Mutex::new(stream1_tx));
        let data_to_send = data.clone();
        let data_to_receive = data.clone();
        // NOTE(review): this map is never populated, so the "Channel must be deleted"
        // assertion below holds trivially; `close_channel` on EOF will report that
        // no channel was found.
        let channels = Arc::new(Mutex::new(HashMap::new()));
        let channels1 = Arc::clone(&channels);
        let channels2 = Arc::clone(&channels);
        tokio::join!(
            Box::pin(async move {
                wait_random(1..=5).await;
                forward_channel_data(
                    trace,
                    id,
                    reader,
                    control_channel_sender,
                    on_close,
                    channels1,
                    buffer_size,
                ).await;
            }),
            // `channel2` is moved into this future and dropped when it completes;
            // presumably that is what makes the reader hit EOF and emit `Close` — confirm.
            Box::pin(async move {
                wait_random(1..=5).await;
                let mut total_written = 0;
                while total_written < data_to_send.len() {
                    let written = channel2
                        .write(&data_to_send[total_written..]).await
                        .unwrap();
                    total_written += written;
                }
                assert!(
                    !channels2.lock().await.contains_key(&id),
                    "Channel must be deleted.",
                );
            }),
            // Reassemble `Data` payloads until `Close` arrives, then compare.
            Box::pin(async move {
                wait_random(1..=5).await;
                let mut received_data = vec![];
                while let Some(maybe_message) = control_channel_receiver.next().await {
                    let message = maybe_message.unwrap();
                    let (received_id, data) = match message {
                        ControlMessage::Data(id, data) => (id, data),
                        ControlMessage::Close(received_id) => {
                            assert_eq!(
                                received_id,
                                id,
                                "Message must have correct channel ID.",
                            );
                            break;
                        },
                        other @ _ => panic!("Unexpected message: {:?}", other),
                    };
                    assert_eq!(
                        received_id,
                        id,
                        "Message must have correct channel ID.",
                    );
                    received_data.extend_from_slice(&data[..]);
                }
                assert_eq!(
                    received_data,
                    data_to_receive,
                    "Must receive correct data.",
                );
            }),
        );
    }
}
mod close_channel {
    use crate::TransportChannel;
    use super::*;
    // `close_channel` must remove the entry from the map and shut the writer down,
    // which sets the channel's close flag. The six identical `()` cases repeat the
    // test with fresh random IDs and mock options.
    #[rstest]
    #[case(())]
    #[case(())]
    #[case(())]
    #[case(())]
    #[case(())]
    #[case(())]
    #[tokio::test]
    async fn shutsdown_a_channel_and_removes_reference(
        #[case] _case_num: (),
    ) {
        let buffer_size = 4_096;
        let id = random_number(0..=u16::MAX);
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (channel1, _channel2) = TransportChannel::new_pair(
            id,
            "transport-channel",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let on_close = channel1.on_close();
        let (_reader, writer) = split(channel1);
        let mut channels = HashMap::new();
        // Keep our own `on_close` receiver to observe the flag after the call.
        channels.insert(id, Arc::new(Mutex::new((writer, watch::Receiver::clone(&on_close)))));
        let channels = Arc::new(Mutex::new(channels));
        wait_random(1..=5).await;
        close_channel(
            id,
            Arc::clone(&channels),
        ).await.unwrap();
        {
            assert!(
                !(channels.lock().await.contains_key(&id)),
                "Must remove channel reference from the map.",
            );
        }
        assert!(
            *on_close.borrow(),
            "Must close the channel.",
        );
    }
    // Closing a channel whose writer was already shut down must still succeed and
    // still remove the map entry.
    #[rstest]
    #[case(())]
    #[case(())]
    #[case(())]
    #[case(())]
    #[case(())]
    #[case(())]
    #[tokio::test]
    async fn does_not_fails_if_channel_allready_closed(
        #[case] _case_num: (),
    ) {
        let buffer_size = 4_096;
        let id = random_number(0..=u16::MAX);
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (channel1, _channel2) = TransportChannel::new_pair(
            id,
            "transport-channel",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let on_close = channel1.on_close();
        let (_reader, mut writer) = split(channel1);
        let mut channels = HashMap::new();
        writer.shutdown().await
            .unwrap();
        assert!(
            *on_close.borrow(),
            "Must close the channel.",
        );
        channels.insert(id, Arc::new(Mutex::new((writer, watch::Receiver::clone(&on_close)))));
        assert!(
            channels.contains_key(&id),
            "Must contain channel before test",
        );
        let channels = Arc::new(Mutex::new(channels));
        wait_random(1..=5).await;
        close_channel(
            id,
            Arc::clone(&channels),
        ).await.unwrap();
        {
            assert!(
                !(channels.lock().await.contains_key(&id)),
                "Must remove channel reference from the map.",
            );
        }
        assert!(
            *on_close.borrow(),
            "Must close the channel.",
        );
    }
    // Closing an unknown ID returns an error; `unwrap()` turns it into the expected panic.
    #[rstest]
    #[case(())]
    #[case(())]
    #[case(())]
    #[case(())]
    #[case(())]
    #[case(())]
    #[tokio::test]
    #[should_panic]
    async fn fails_if_no_channel_found(
        #[case] _case_num: (),
    ) {
        let id = random_number(0..=u16::MAX);
        let channels = HashMap::new();
        let channels = Arc::new(Mutex::new(channels));
        wait_random(1..=5).await;
        close_channel(
            id,
            Arc::clone(&channels),
        ).await.unwrap();
    }
}
mod add_local_channel {
    use cs_trace::create_trace;
    use crate::create_framed_stream;
    use super::*;
    // After `add_local_channel`, the new channel's ID must be present in the map.
    #[tokio::test]
    async fn adds_channel_to_channels_map() {
        let trace = create_trace!("test");
        let buffer_size: u32 = 4_096;
        let id = random_number(0..=u16::MAX);
        let label = random_str_rg(8..=16);
        let channels = HashMap::new();
        let channels = Arc::new(Mutex::new(channels));
        {
            assert!(
                !(channels.lock().await).contains_key(&id),
                "Must not contain channel before the test.",
            );
        }
        // Control stream whose sending half the forwarding task will use.
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (stream1, _stream2) = TransportChannel::new_pair(
            id,
            "control-stream",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let stream1 = create_framed_stream(stream1);
        let (stream1_tx, _stream1_rx) = stream1.split();
        let control_sender = Arc::new(Mutex::new(stream1_tx));
        add_local_channel(
            trace,
            id,
            label,
            buffer_size,
            Arc::clone(&channels),
            control_sender,
        ).await.unwrap();
        {
            assert!(
                (channels.lock().await).contains_key(&id),
                "Must add channel to the map.",
            );
        }
    }
}
mod open_channel {
    use cs_trace::create_trace;
    use crate::create_framed_stream;
    use super::*;
    // A response (`is_response == true`) must wake the notifier registered for its
    // ID and must not create a local channel.
    #[tokio::test]
    async fn notifies_pending_channel_open_requests() {
        let trace = create_trace!("test");
        let buffer_size: u32 = 4_096;
        let id = random_number(0..=u16::MAX);
        let label = random_str_rg(8..=16);
        let is_response = true;
        let channels = Arc::new(Mutex::new(HashMap::new()));
        let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
        let (on_remote_channel, _on_remote_channel_receiver) = mpsc::channel(buffer_size as usize);
        let channel_open_notification = Arc::new(Notify::new());
        {
            open_channel_requests.write().await
                .insert(id, Arc::clone(&channel_open_notification));
        }
        let channels1 = Arc::clone(&channels);
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (stream1, _stream2) = TransportChannel::new_pair(
            id,
            "control-stream",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let stream1 = create_framed_stream(stream1);
        let (stream1_tx, _stream1_rx) = stream1.split();
        let control_sender = Arc::new(Mutex::new(stream1_tx));
        // The second future completes only once the notifier has been woken.
        tokio::join!(
            Box::pin(async move {
                open_channel(
                    trace,
                    id,
                    label,
                    buffer_size,
                    is_response,
                    channels1,
                    Arc::clone(&open_channel_requests),
                    control_sender,
                    on_remote_channel,
                ).await.unwrap();
            }),
            Box::pin(channel_open_notification.notified()),
        );
        assert!(
            !(channels.lock().await.contains_key(&id)),
            "Must not add channel into the map.",
        );
    }
    // A response with no pending request for its ID must be rejected.
    #[tokio::test]
    async fn fails_if_no_channel_notification_found() {
        let trace = create_trace!("test");
        let buffer_size: u32 = 4_096;
        let id = random_number(0..=u16::MAX);
        let label = random_str_rg(8..=16);
        let is_response = true;
        let channels = Arc::new(Mutex::new(HashMap::new()));
        let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
        let (on_remote_channel, _on_remote_channel_receiver) = mpsc::channel(buffer_size as usize);
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (stream1, _stream2) = TransportChannel::new_pair(
            id,
            "control-stream",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let stream1 = create_framed_stream(stream1);
        let (stream1_tx, _stream1_rx) = stream1.split();
        let control_sender = Arc::new(Mutex::new(stream1_tx));
        let result = open_channel(
            trace,
            id,
            label,
            buffer_size,
            is_response,
            Arc::clone(&channels),
            Arc::clone(&open_channel_requests),
            control_sender,
            on_remote_channel,
        ).await;
        assert!(
            result.is_err(),
            "Must fail if no channel notification present.",
        );
        assert!(
            !(channels.lock().await.contains_key(&id)),
            "Must not add channel into the map.",
        );
    }
    // A request (`is_response == false`) must be acknowledged with the same ID,
    // label and buffer size. The new channel must be handed to `on_remote_channel`
    // and registered in the map.
    #[tokio::test]
    async fn responds_to_channel_open_request() {
        let trace = create_trace!("test");
        let buffer_size: u32 = 4_096;
        let id = random_number(0..=u16::MAX);
        let label = random_str_rg(8..=16);
        let is_response = false;
        let channels = Arc::new(Mutex::new(HashMap::new()));
        let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
        let (on_remote_channel, mut on_remote_channel_receiver) = mpsc::channel(buffer_size as usize);
        let options1 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let options2 = ChannelMockOptions::random()
            .with_buffer_size(buffer_size);
        let (stream1, stream2) = TransportChannel::new_pair(
            id,
            "control-stream",
            channel_mock_pair(options1, options2),
            buffer_size,
        );
        let stream1 = create_framed_stream(stream1);
        let stream2 = create_framed_stream(stream2);
        let (stream1_tx, _stream1_rx) = stream1.split();
        let (_stream2_tx, mut control_receiver) = stream2.split();
        let control_sender = Arc::new(Mutex::new(stream1_tx));
        let channels1 = Arc::clone(&channels);
        let channels2 = Arc::clone(&channels);
        let label1 = label.clone();
        let label2 = label.clone();
        tokio::join!(
            Box::pin(async move {
                wait_random(1..=5).await;
                open_channel(
                    trace,
                    id,
                    label1,
                    buffer_size,
                    is_response,
                    channels1,
                    Arc::clone(&open_channel_requests),
                    control_sender,
                    on_remote_channel,
                ).await.unwrap();
            }),
            // Inspect the acknowledgement as the remote side would receive it.
            Box::pin(async move {
                let message = control_receiver.next().await.expect("Stream closed.").unwrap();
                match message {
                    ControlMessage::OpenChannel(recv_id, recv_label, recv_buffer_size, recv_is_response) => {
                        assert_eq!(
                            recv_id,
                            id,
                            "Must receive correct channel ID.",
                        );
                        assert_eq!(
                            recv_label,
                            label2,
                            "Must receive correct channel label.",
                        );
                        assert_eq!(
                            recv_buffer_size,
                            buffer_size,
                            "Must receive correct channel buffer_size.",
                        );
                        assert!(
                            recv_is_response,
                            "Must send a response.",
                        );
                    },
                    unexpected @ _ => panic!("Got unexpected control message: {:?}", unexpected),
                };
            }),
        );
        let _channel = on_remote_channel_receiver
            .recv().await
            .expect("Must send `on_remote_channel` notification.");
        assert!(
            (channels2.lock().await.contains_key(&id)),
            "Must add channel into the map.",
        );
    }
}
mod data_transfer {
    use futures::future;
    use cs_trace::{create_trace, child};
    use super::*;
    use crate::{test::test_stream, Channel};
    // Opens one channel from `local_connection` while `remote_connection` accepts it.
    // Both connections are moved in and handed back with their channel end. The
    // remote listener is re-installed so the helper can be called repeatedly.
    // (This local helper shadows the `open_channel` imported via `super::*`.)
    async fn open_channel(
        mut local_connection: Box<TransportConnection>,
        mut remote_connection: Box<TransportConnection>,
        buffer_size: u32,
    ) -> [(Box<TransportConnection>, Box<dyn Channel>); 2] {
        let (local, remote) = tokio::join!(
            Box::pin(async move {
                let local_channel = local_connection.channel("local-channel1", buffer_size).await
                    .expect("Cannot create a channel.");
                return (local_connection, local_channel);
            }),
            Box::pin(async move {
                let mut on_remote_channel = remote_connection
                    .on_remote_channel().unwrap();
                let remote_channel = on_remote_channel
                    .recv().await
                    .expect("Cannot receive a remote channel.");
                remote_connection.off_remote_channel(on_remote_channel)
                    .expect("Cannot set remote channel listener.");
                return (remote_connection, remote_channel);
            }),
        );
        return [local, remote];
    }
    // End-to-end: a single channel between two connections over a mock transport
    // must carry `data_len` bytes (checked by `test_stream`).
    #[rstest]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[case(16_384)]
    #[tokio::test]
    async fn transfers_data(
        #[case] data_len: usize,
    ) {
        let trace = create_trace!("test");
        let buffer_size: u32 = 2_048;
        let (channel1, channel2) = TransportChannel::new_pair(
            random_number(0..=u16::MAX),
            "transport-channels",
            channel_mock_pair(ChannelMockOptions::random(), ChannelMockOptions::random()),
            buffer_size,
        );
        let trace1 = &trace;
        let trace1 = child!(trace1, "local");
        let trace2 = &trace;
        let trace2 = child!(trace2, "remote");
        let local_connection = TransportConnection::new(&trace1, channel1);
        let remote_connection = TransportConnection::new(&trace2, channel2);
        let [
            (_local_connection, local_channel),
            (_remote_connection, remote_channel),
        ] = open_channel(
            local_connection,
            remote_connection,
            buffer_size,
        ).await;
        test_stream(
            local_channel,
            remote_channel,
            TestOptions::random()
                .with_data_len(data_len),
        ).await;
    }
    // End-to-end: 5..=10 channels are opened one after another on the same pair of
    // connections. Each channel's transfer is spawned as its own task, so the
    // transfers overlap; all of them must succeed.
    #[rstest]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[case(16_384)]
    #[tokio::test]
    async fn transfers_data_in_parallel(
        #[case] data_len: usize,
    ) {
        let trace = create_trace!("test");
        let buffer_size: u32 = 2_048;
        let (channel1, channel2) = TransportChannel::new_pair(
            random_number(0..=u16::MAX),
            "transport-channels",
            channel_mock_pair(ChannelMockOptions::random(), ChannelMockOptions::random()),
            buffer_size,
        );
        let trace1 = &trace;
        let trace1 = child!(trace1, "local");
        let trace2 = &trace;
        let trace2 = child!(trace2, "remote");
        let mut local_connection = TransportConnection::new(&trace1, channel1);
        let mut remote_connection = TransportConnection::new(&trace2, channel2);
        let mut tasks = vec![];
        for _ in 0..random_number(5..=10) {
            let [
                (local_connection1, local_channel),
                (remote_connection1, remote_channel),
            ] = open_channel(
                local_connection,
                remote_connection,
                buffer_size,
            ).await;
            local_connection = local_connection1;
            remote_connection = remote_connection1;
            tasks.push(
                tokio::spawn(test_stream(
                    local_channel,
                    remote_channel,
                    TestOptions::random()
                        .with_data_len(data_len),
                )),
            );
            wait_random(0..=50).await;
        }
        future::try_join_all(tasks).await
            .unwrap();
    }
}
}