use std::{pin::Pin, task::{Context, Poll}, io, ops::RangeInclusive, fmt};
use futures::{Future, ready};
use cs_utils::{random_number, random_str, futures::wait_random, traits::Random};
use tokio::{io::{duplex, AsyncRead, AsyncWrite, ReadBuf, DuplexStream}, sync::watch};
use crate::Channel;
/// Mock [`Channel`] that forwards all I/O to an async duplex stream (tokio's
/// `DuplexStream` by default) and injects a random delay between successive
/// reads and between successive writes.
pub struct ChannelMock<TAsyncDuplex = DuplexStream>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Identifier reported by `Channel::id`; copied from the options at construction.
    id: u16,
    /// Name reported by `Channel::label`; copied from the options at construction.
    label: String,
    /// Underlying stream every read/write/flush/shutdown is forwarded to.
    channel: Pin<Box<TAsyncDuplex>>,
    /// Options the mock was built with (latency range, buffer size, ...).
    options: ChannelMockOptions,
    /// Delay armed after each completed read; awaited before the next read proceeds.
    read_delay_future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
    /// Delay armed after each completed write; awaited before the next write proceeds.
    write_delay_future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
    /// Receiving side of the "closed" flag; set to `true` after `poll_shutdown` completes.
    on_close: watch::Receiver<bool>,
    /// Sending side used to publish the "closed" flag.
    on_close_sender: watch::Sender<bool>,
}
impl<TAsyncDuplex> ChannelMock<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Wraps `channel` in a latency-injecting mock and returns it as a boxed
    /// [`Channel`] trait object.
    ///
    /// `id` and `label` are copied out of `options`. No delay is armed yet, so the
    /// first read and the first write go straight to `channel`. The close flag
    /// starts as `false`.
    pub fn new(
        channel: Box<TAsyncDuplex>,
        options: ChannelMockOptions,
    ) -> Box<dyn Channel> {
        let (on_close_sender, on_close) = watch::channel(false);

        let mock = ChannelMock {
            id: options.id,
            label: options.label.clone(),
            channel: Pin::new(channel),
            read_delay_future: None,
            write_delay_future: None,
            on_close,
            on_close_sender,
            // Moved last: `id` and `label` above read from `options` first.
            options,
        };

        Box::new(mock)
    }
}
/// Configuration for a [`ChannelMock`].
///
/// Start from `Default::default()` or `Random::random()` and refine with the
/// consuming `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChannelMockOptions {
    /// Identifier reported by `Channel::id`.
    id: u16,
    /// Human-readable name reported by `Channel::label`.
    label: String,
    /// Inclusive range handed to `wait_random` after every completed read/write.
    /// NOTE(review): unit presumably milliseconds — confirm against `cs_utils::futures::wait_random`.
    latency_range: RangeInclusive<u64>,
    /// Capacity passed to tokio's `duplex` by `channel_mock_pair` and reported by
    /// `Channel::buffer_size`.
    buffer_size: u32,
}

impl ChannelMockOptions {
    /// Returns these options with `id` replaced; all other fields are kept.
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn with_id(
        self,
        id: u16,
    ) -> ChannelMockOptions {
        ChannelMockOptions {
            id,
            ..self
        }
    }

    /// Returns these options with `label` replaced by `label.to_string()`; all
    /// other fields are kept.
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn with_label(
        self,
        label: impl AsRef<str> + ToString,
    ) -> ChannelMockOptions {
        ChannelMockOptions {
            label: label.to_string(),
            ..self
        }
    }

    /// Returns these options with the injected-latency range replaced; all other
    /// fields are kept.
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn with_latency(
        self,
        latency_range: RangeInclusive<u64>,
    ) -> ChannelMockOptions {
        ChannelMockOptions {
            latency_range,
            ..self
        }
    }

    /// Returns these options with `buffer_size` replaced; all other fields are kept.
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn with_buffer_size(
        self,
        buffer_size: u32,
    ) -> ChannelMockOptions {
        ChannelMockOptions {
            buffer_size,
            ..self
        }
    }
}
impl Random for ChannelMockOptions {
    /// Default options with a randomised latency range.
    ///
    /// The lower bound is drawn from `0..5` and the upper bound from `5..=50`.
    /// Assuming `random_number` samples within the given range, the resulting
    /// range is therefore never empty.
    fn random() -> Self {
        let lower = random_number(0..5);
        let upper = random_number(5..=50);

        Self::default().with_latency(lower..=upper)
    }
}
impl Default for ChannelMockOptions {
    /// Options with no injected latency and a 4 KiB buffer.
    ///
    /// The id is random, and the label is `channel-mock-` followed by a random
    /// 8-character string.
    fn default() -> Self {
        // Drawn in the same order as the fields are declared (id first, then label).
        let id = random_number(0..=u16::MAX);
        let label = format!("channel-mock-{}", random_str(8));

        Self {
            id,
            label,
            latency_range: 0..=0,
            buffer_size: 4_096,
        }
    }
}
/// Creates two connected [`ChannelMock`]s backed by an in-memory tokio duplex.
///
/// Bytes written to one end can be read from the other. Each end gets its own
/// options (id, label, latency).
///
/// NOTE: only `options1.buffer_size` sizes the underlying duplex.
/// `options2.buffer_size` is stored on the second end (and reported by its
/// `buffer_size()`), but it does not change the real buffer capacity.
pub fn channel_mock_pair(
    options1: ChannelMockOptions,
    options2: ChannelMockOptions,
) -> (Box<dyn Channel>, Box<dyn Channel>) {
    let capacity = options1.buffer_size as usize;
    let (left_stream, right_stream) = duplex(capacity);

    let left = ChannelMock::new(Box::new(left_stream), options1);
    let right = ChannelMock::new(Box::new(right_stream), options2);

    (left, right)
}
impl<TAsyncDuplex> Channel for ChannelMock<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Identifier copied from the options at construction.
    fn id(&self) -> u16 {
        self.id
    }

    /// Label copied from the options at construction.
    fn label(&self) -> &String {
        &self.label
    }

    /// Current value of the close flag (set by this mock's `poll_shutdown`).
    fn is_closed(&self) -> bool {
        let closed = self.on_close.borrow();
        *closed
    }

    /// A fresh receiver subscribed to the close flag.
    fn on_close(&self) -> watch::Receiver<bool> {
        self.on_close.clone()
    }

    /// Buffer size from the options this mock was built with.
    fn buffer_size(&self) -> u32 {
        self.options.buffer_size
    }
}
impl<TAsyncDuplex> fmt::Debug for ChannelMock<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Delegates to the `debug` helper (presumably provided by the `Channel`
    /// trait), using the type name "ChannelMock".
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.debug("ChannelMock", f)
    }
}
impl<TAsyncDuplex> AsyncRead for ChannelMock<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Reads from the underlying stream, with injected latency.
    ///
    /// First it waits out any delay armed by the previous completed read. After
    /// the read completes (successfully or not), it arms a new random delay for
    /// the next read. The very first read is therefore not delayed.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if let Some(pending_delay) = &mut self.read_delay_future {
            ready!(pending_delay.as_mut().poll(cx));
        }
        // Either the delay just finished or there was none; drop it either way.
        self.read_delay_future = None;

        let outcome = ready!(self.channel.as_mut().poll_read(cx, buf));

        let latency = self.options.latency_range.clone();
        self.read_delay_future = Some(Box::pin(wait_random(latency)));

        Poll::Ready(outcome)
    }
}
impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for ChannelMock<TAsyncDuplex> {
    /// Writes to the underlying stream, with injected latency.
    ///
    /// First it waits out any delay armed by the previous completed write. After
    /// the write completes, it arms a new random delay for the next write. The
    /// very first write is therefore not delayed.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if let Some(write_delay_future) = self.write_delay_future.as_mut() {
            ready!(write_delay_future.as_mut().poll(cx));
            self.write_delay_future = None;
        }
        let result = ready!(self.channel.as_mut().poll_write(cx, buf));
        self.write_delay_future = Some(Box::pin(wait_random(self.options.latency_range.clone())));
        Poll::Ready(result)
    }

    /// Flushes the underlying stream directly; no latency is injected.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        self.channel.as_mut().poll_flush(cx)
    }

    /// Shuts down the underlying stream's write side, then sets the close flag to
    /// `true`. The flag is set whatever the shutdown result is.
    ///
    /// Shutdown belongs to the write half, so it waits out the pending *write*
    /// delay. It previously polled and consumed `read_delay_future`, which caused
    /// two problems:
    /// - it stole the latency meant for the next read;
    /// - when reads and writes are driven from different tasks, it re-registered
    ///   the read delay with the writer's waker, so the reader's wakeup was lost.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        if let Some(write_delay_future) = self.write_delay_future.as_mut() {
            ready!(write_delay_future.as_mut().poll(cx));
            self.write_delay_future = None;
        }
        let result = ready!(self.channel.as_mut().poll_shutdown(cx));
        // `send` only fails when every receiver is gone. `self.on_close` is a
        // receiver we hold, so this cannot fail; ignoring the result is deliberate.
        let _ = self.on_close_sender.send(true);
        Poll::Ready(result)
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use cs_utils::{traits::Random, random_number};
    use crate::utils::create_framed_stream;
    use crate::test::{TestStreamMessage, test_async_stream, test_framed_stream, TestOptions};
    use super::channel_mock_pair;

    /// Raw byte transfer across a mock pair with random latency, for a range of payload sizes.
    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[case(16_384)]
    #[case(32_768)]
    #[case(65_536)]
    #[tokio::test]
    async fn transfers_binary_data(
        #[case] test_data_len: usize,
    ) {
        let (local, remote) = channel_mock_pair(Random::random(), Random::random());
        let options = TestOptions::random().with_data_len(test_data_len);

        test_async_stream(local, remote, options).await;
    }

    /// Framed message transfer across a mock pair, for randomised item counts.
    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn transfers_stream_data(
        #[case] items_count: usize,
    ) {
        let (raw_local, raw_remote) = channel_mock_pair(Random::random(), Random::random());
        let framed_local = create_framed_stream::<TestStreamMessage, _>(raw_local);
        let framed_remote = create_framed_stream::<TestStreamMessage, _>(raw_remote);
        let options = TestOptions::random().with_data_len(items_count);

        test_framed_stream(framed_local, framed_remote, options).await;
    }
}