//! connection-utils 0.8.0
//!
//! Connection related utilities.
use std::{pin::Pin, task::{Context, Poll}, io, ops::RangeInclusive, fmt};

use futures::{Future, ready};
use cs_utils::{random_number, random_str, futures::wait_random, traits::Random};
use tokio::{io::{duplex, AsyncRead, AsyncWrite, ReadBuf, DuplexStream}, sync::watch};

use crate::Channel;

/// Mock [`Channel`] that wraps an async duplex stream and injects a delay before every
/// read and every write except the first of each kind.
///
/// Each delay future is produced by `wait_random(options.latency_range)`; it is
/// presumably a random sleep within that range (see `cs_utils::futures::wait_random`).
pub struct ChannelMock<TAsyncDuplex = DuplexStream>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Identifier returned by `Channel::id`.
    id: u16,
    /// Label returned by `Channel::label`.
    label: String,
    /// The wrapped stream that performs the actual I/O.
    channel: Pin<Box<TAsyncDuplex>>,
    /// Options this mock was created with (latency range, buffer size, ...).
    options: ChannelMockOptions,
    /// Delay that must elapse before the next read; replaced after each completed read.
    read_delay_future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
    /// Delay that must elapse before the next write; replaced after each completed write.
    write_delay_future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
    /// Receiver side of the close signal; becomes `true` once `poll_shutdown` completes.
    on_close: watch::Receiver<bool>,
    /// Sender side of the close signal, used by `poll_shutdown`.
    on_close_sender: watch::Sender<bool>,
}

impl<TAsyncDuplex> ChannelMock<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Wraps `channel` in a [`ChannelMock`] configured by `options` and returns it as a
    /// boxed [`Channel`] trait object.
    ///
    /// The id and label are copied out of `options`. No delay is pending initially, and
    /// the close signal starts out as `false`.
    pub fn new(
        channel: Box<TAsyncDuplex>,
        options: ChannelMockOptions,
    ) -> Box<dyn Channel> {
        let (on_close_sender, on_close) = watch::channel(false);
        let id = options.id;
        let label = options.label.clone();

        let mock = Self {
            id,
            label,
            channel: Pin::new(channel),
            options,
            read_delay_future: None,
            write_delay_future: None,
            on_close,
            on_close_sender,
        };

        Box::new(mock)
    }
}

/// Configuration for a `ChannelMock`.
///
/// Start from `Default::default()` (or `Random::random()`) and adjust the result with the
/// `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChannelMockOptions {
    /// Identifier reported by the mock channel's `id()`.
    id: u16,
    /// Label reported by the mock channel's `label()`.
    label: String,
    /// Inclusive range passed to `wait_random` to produce each per-operation delay.
    /// NOTE(review): the unit is not visible here (presumably milliseconds); confirm in `cs_utils`.
    latency_range: RangeInclusive<u64>,
    /// Buffer size reported by the mock channel's `buffer_size()`.
    buffer_size: u32,
}

impl ChannelMockOptions {
    pub fn with_id(
        self,
        id: u16,
    ) -> ChannelMockOptions {
        return ChannelMockOptions {
            id,
            ..self
        };
    }

    pub fn with_label(
        self,
        label: impl AsRef<str> + ToString,
    ) -> ChannelMockOptions {
        return ChannelMockOptions {
            label: label.to_string(),
            ..self
        };
    }

    pub fn with_latency(
        self,
        latency_range: RangeInclusive<u64>,
    ) -> ChannelMockOptions {
        return ChannelMockOptions {
            latency_range,
            ..self
        };
    }

    pub fn with_buffer_size(
        self,
        buffer_size: u32,
    ) -> ChannelMockOptions {
        return ChannelMockOptions {
            buffer_size,
            ..self
        };
    }
}

impl Random for ChannelMockOptions {
    /// Builds default options with a random latency range `lower..=upper`.
    ///
    /// `lower` comes from `0..5` and `upper` from `5..=50`, so `lower <= upper` always holds.
    fn random() -> Self {
        let lower = random_number(0..5);
        let upper = random_number(5..=50);

        Self::default().with_latency(lower..=upper)
    }
}

impl Default for ChannelMockOptions {
    /// Options with a random id and a random `channel-mock-…` label.
    ///
    /// Latency is zero (`0..=0`) and the buffer size is 4096.
    fn default() -> Self {
        let id = random_number(0..=u16::MAX);
        let label = format!("channel-mock-{}", random_str(8));

        Self {
            id,
            label,
            latency_range: 0..=0,
            buffer_size: 4_096,
        }
    }
}

/// Creates two connected [`ChannelMock`]s backed by an in-memory tokio duplex pipe.
///
/// Bytes written to one channel can be read from the other. `options1` configures the
/// first channel and `options2` the second.
///
/// NOTE(review): only `options1.buffer_size` sizes the underlying `duplex` pipe.
/// `options2.buffer_size` is only what the second channel reports from `buffer_size()`,
/// so the two can disagree.
pub fn channel_mock_pair(
    options1: ChannelMockOptions,
    options2: ChannelMockOptions,
) -> (Box<dyn Channel>, Box<dyn Channel>) {
    let pipe_capacity = options1.buffer_size as usize;
    let (first_end, second_end) = duplex(pipe_capacity);

    let first = ChannelMock::new(Box::new(first_end), options1);
    let second = ChannelMock::new(Box::new(second_end), options2);

    (first, second)
}

impl<TAsyncDuplex> Channel for ChannelMock<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// The identifier copied from the creation options.
    fn id(&self) -> u16 {
        self.id
    }

    /// The label copied from the creation options.
    fn label(&self) ->  &String {
        &self.label
    }

    /// Current value of the close signal (`true` once `poll_shutdown` has completed).
    fn is_closed(&self) -> bool {
        *self.on_close.borrow()
    }

    /// A fresh receiver subscribed to the close signal.
    fn on_close(&self) -> watch::Receiver<bool> {
        self.on_close.clone()
    }

    /// The buffer size stored in the creation options.
    fn buffer_size(&self) -> u32 {
        self.options.buffer_size
    }
}

impl<TAsyncDuplex> fmt::Debug for ChannelMock<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Delegates to the `debug` helper (presumably provided by the `Channel` trait),
    /// tagged with the type name `"ChannelMock"`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.debug("ChannelMock", formatter)
    }
}

impl<TAsyncDuplex> AsyncRead for ChannelMock<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Reads from the inner stream after any delay scheduled by the previous read has elapsed.
    ///
    /// Every completed read (success, EOF or error) schedules a new delay from
    /// `latency_range` for the following read. The very first read is therefore not delayed.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `ChannelMock` is `Unpin` (the inner stream is already behind `Pin<Box<_>>`),
        // so working through a plain `&mut Self` is sound.
        let this = self.get_mut();

        if let Some(pending_delay) = this.read_delay_future.as_mut() {
            ready!(pending_delay.as_mut().poll(cx));
            this.read_delay_future = None;
        }

        let outcome = ready!(this.channel.as_mut().poll_read(cx, buf));

        // Arm the delay that the *next* read must wait out.
        let next_delay = wait_random(this.options.latency_range.clone());
        this.read_delay_future = Some(Box::pin(next_delay));

        Poll::Ready(outcome)
    }
}

impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for ChannelMock<TAsyncDuplex> {
    /// Writes `buf` to the inner stream after any delay scheduled by the previous write
    /// has elapsed.
    ///
    /// Every completed write (success or error) schedules a new delay from `latency_range`.
    /// The next write or shutdown must wait that delay out. The very first write is not delayed.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // if a delay from the previous write is pending, wait until it completes
        if let Some(write_delay_future) = self.write_delay_future.as_mut() {
            ready!(write_delay_future.as_mut().poll(cx));

            self.write_delay_future.take();
        }

        let result = ready!(self.channel.as_mut().poll_write(cx, buf));

        // schedule the delay the next write (or shutdown) has to wait for
        self.write_delay_future = Some(Box::pin(wait_random(self.options.latency_range.clone())));

        Poll::Ready(result)
    }

    /// Flushes the inner stream directly; no latency is applied.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        self.channel.as_mut().poll_flush(cx)
    }

    /// Shuts down the inner stream's write side, then sets the close signal to `true`.
    ///
    /// It first waits for the delay scheduled by the previous write. The close signal is
    /// set even if the inner shutdown returns an error.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        // Shutdown is a write-side operation, so it honours the *write* latency.
        // Waiting on (and `take()`-ing) `read_delay_future` here would ignore the
        // pending write delay and silently cancel the delay owed by the next read.
        if let Some(write_delay_future) = self.write_delay_future.as_mut() {
            ready!(write_delay_future.as_mut().poll(cx));

            self.write_delay_future.take();
        }

        let result = ready!(self.channel.as_mut().poll_shutdown(cx));

        // `send` only fails when no receiver is alive. `self.on_close` is a receiver on
        // this same channel, so the result can be ignored.
        let _err = self.on_close_sender.send(true);

        Poll::Ready(result)
    }
}

#[cfg(test)]
mod tests {
    // Round-trip tests for `channel_mock_pair`. Both ends use random mock options,
    // and the shared test helpers from `crate::test` do the transfers.
    use rstest::rstest;

    use cs_utils::{random_number, traits::Random};

    use crate::test::{test_async_stream, test_framed_stream, TestOptions, TestStreamMessage};
    use crate::utils::create_framed_stream;

    use super::channel_mock_pair;

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[case(16_384)]
    #[case(32_768)]
    #[case(65_536)]
    #[tokio::test]
    async fn transfers_binary_data(
        #[case] test_data_len: usize,
    ) {
        let (local, remote) = channel_mock_pair(Random::random(), Random::random());
        let options = TestOptions::random().with_data_len(test_data_len);

        test_async_stream(local, remote, options).await;
    }

    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn transfers_stream_data(
        #[case] items_count: usize,
    ) {
        let (local, remote) = channel_mock_pair(Random::random(), Random::random());

        // Wrap both raw channels so they exchange `TestStreamMessage` frames.
        let local = create_framed_stream::<TestStreamMessage, _>(local);
        let remote = create_framed_stream::<TestStreamMessage, _>(remote);
        let options = TestOptions::random().with_data_len(items_count);

        test_framed_stream(local, remote, options).await;
    }
}