connection-utils 0.8.0

Connection-related utilities.
Documentation
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, AsyncReadExt, split, WriteHalf, ReadHalf};

use cs_utils::{random_str, futures::wait_random};

use crate::test::TestOptions;

/// Writes all of `test_data` into `channel`, then returns the channel.
///
/// Before every `write` call the task awaits `wait_random(options.latency_range())`
/// (presumably a random delay simulating latency). A single `write` may accept only
/// part of the buffer, so the loop keeps writing the remainder until every byte has
/// been accepted. The channel is then flushed so that nothing stays behind in an
/// internal buffer.
///
/// # Panics
///
/// Panics if a write or the final flush fails, or if a write reports that
/// 0 bytes were written.
pub async fn test_async_stream_sends_data<TAsyncDuplex: AsyncRead + AsyncWrite + Send + ?Sized + 'static>(
    mut channel: Pin<Box<TAsyncDuplex>>,
    options: TestOptions,
    test_data: String,
) -> Pin<Box<TAsyncDuplex>> {
    let data = test_data.as_bytes();
    let mut sent = 0;

    while sent < data.len() {
        wait_random(options.latency_range()).await;

        let bytes_sent = channel
            .write(&data[sent..]).await
            .expect("Cannot send a message.");

        assert!(
            bytes_sent > 0,
            "No bytes sent.",
        );

        sent += bytes_sent;
    }

    // A buffered writer may still hold the tail of the data. Without a flush the
    // receiving side could wait forever for bytes that were never delivered.
    channel.flush().await.expect("Cannot flush the channel.");

    channel
}

/// Reads from `channel` until exactly `test_data` has been received, then
/// returns the channel.
///
/// Before every `read` call the task awaits `wait_random(options.latency_range())`
/// (presumably a random delay simulating latency). Incoming bytes are accumulated
/// raw and decoded as UTF-8 only once everything has arrived. This way a multi-byte
/// character split across two reads does not cause a spurious decoding failure.
/// After each read, the bytes received so far must be a prefix of `test_data`.
///
/// # Panics
///
/// Panics in any of these cases:
/// - a read fails;
/// - the channel reaches EOF before all of `test_data` has arrived;
/// - the received bytes stop matching `test_data`;
/// - the final data is not valid UTF-8 or does not equal `test_data`.
pub async fn test_async_stream_receives_data<TAsyncDuplex: AsyncRead + AsyncWrite + Send + ?Sized + 'static>(
    mut channel: Pin<Box<TAsyncDuplex>>,
    options: TestOptions,
    test_data: String,
) -> Pin<Box<TAsyncDuplex>> {
    let expected = test_data.as_bytes();
    let mut received: Vec<u8> = Vec::with_capacity(expected.len());
    let mut buffer = [0; 4096];

    // Checking the length before reading means an empty `test_data` returns
    // immediately, instead of blocking on a read that never completes.
    while received.len() < expected.len() {
        wait_random(options.latency_range()).await;

        let bytes_read = channel
            .read(&mut buffer).await
            .expect("Cannot receive message.");

        // `Ok(0)` signals EOF. Without this check the loop would spin forever.
        assert!(
            bytes_read > 0,
            "Channel closed before all test data was received.",
        );

        received.extend_from_slice(&buffer[..bytes_read]);

        assert!(
            expected.starts_with(&received),
            "Data corruption, received data is not subset of the test data.",
        );
    }

    let received_data = String::from_utf8(received)
        .expect("Cannot parse UTF8 message.");

    assert_eq!(
        received_data,
        test_data,
        "Sent and received data must match.",
    );

    channel
}

pub async fn test_async_stream_data_transfer<TAsyncDuplex: AsyncRead + AsyncWrite + Send + ?Sized + 'static>(
    channel1: Pin<Box<TAsyncDuplex>>,
    channel2: Pin<Box<TAsyncDuplex>>,
    options: TestOptions,
) -> (Pin<Box<TAsyncDuplex>>, Pin<Box<TAsyncDuplex>>) {
    let test_data = random_str(options.data_len());

    return tokio::try_join!(
        tokio::spawn(test_async_stream_sends_data(channel1, options.clone(), test_data.clone())),
        tokio::spawn(test_async_stream_receives_data(channel2, options.clone(), test_data.clone())),
    ).unwrap();
}

/// Writes all of `test_data` into the write half `channel`, then returns it.
///
/// Before every `write` call the task awaits `wait_random(options.latency_range())`
/// (presumably a random delay simulating latency). A single `write` may accept only
/// part of the buffer, so the loop keeps writing the remainder until every byte has
/// been accepted. The half is then flushed so that nothing stays behind in an
/// internal buffer.
///
/// # Panics
///
/// Panics if a write or the final flush fails, or if a write reports that
/// 0 bytes were written.
pub async fn test_async_half_sends_data<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
    mut channel: WriteHalf<TAsyncDuplex>,
    options: TestOptions,
    test_data: String,
) -> WriteHalf<TAsyncDuplex> {
    let data = test_data.as_bytes();
    let mut sent = 0;

    while sent < data.len() {
        wait_random(options.latency_range()).await;

        let bytes_sent = channel
            .write(&data[sent..]).await
            .expect("Cannot send a message.");

        assert!(
            bytes_sent > 0,
            "No bytes sent.",
        );

        sent += bytes_sent;
    }

    // A buffered writer may still hold the tail of the data. Without a flush the
    // receiving side could wait forever for bytes that were never delivered.
    channel.flush().await.expect("Cannot flush the channel.");

    channel
}

/// Reads from the read half `channel` until exactly `test_data` has been
/// received, then returns it.
///
/// Before every `read` call the task awaits `wait_random(options.latency_range())`
/// (presumably a random delay simulating latency). Incoming bytes are accumulated
/// raw and decoded as UTF-8 only once everything has arrived. This way a multi-byte
/// character split across two reads does not cause a spurious decoding failure.
/// After each read, the bytes received so far must be a prefix of `test_data`.
///
/// # Panics
///
/// Panics in any of these cases:
/// - a read fails;
/// - the channel reaches EOF before all of `test_data` has arrived;
/// - the received bytes stop matching `test_data`;
/// - the final data is not valid UTF-8 or does not equal `test_data`.
pub async fn test_async_half_receives_data<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
    mut channel: ReadHalf<TAsyncDuplex>,
    options: TestOptions,
    test_data: String,
) -> ReadHalf<TAsyncDuplex> {
    let expected = test_data.as_bytes();
    let mut received: Vec<u8> = Vec::with_capacity(expected.len());
    let mut buffer = [0; 4096];

    // Checking the length before reading means an empty `test_data` returns
    // immediately, instead of blocking on a read that never completes.
    while received.len() < expected.len() {
        wait_random(options.latency_range()).await;

        let bytes_read = channel
            .read(&mut buffer).await
            .expect("Cannot receive message.");

        // `Ok(0)` signals EOF. Without this check the loop would spin forever.
        assert!(
            bytes_read > 0,
            "Channel closed before all test data was received.",
        );

        received.extend_from_slice(&buffer[..bytes_read]);

        assert!(
            expected.starts_with(&received),
            "Data corruption, received data is not subset of the test data.",
        );
    }

    let received_data = String::from_utf8(received)
        .expect("Cannot parse UTF8 message.");

    assert_eq!(
        received_data,
        test_data,
        "Sent and received data must match.",
    );

    channel
}

pub async fn test_async_stream_duplex<TAsyncDuplex: AsyncRead + AsyncWrite + Send + ?Sized + 'static>(
    channel1: Pin<Box<TAsyncDuplex>>,
    channel2: Pin<Box<TAsyncDuplex>>,
    options: TestOptions,
) -> (Pin<Box<TAsyncDuplex>>, Pin<Box<TAsyncDuplex>>) {
    let test_data = random_str(options.data_len());

    let (channel1_source, channel1_sink) = split(channel1);
    let (channel2_source, channel2_sink) = split(channel2);

    let reversed_test_data: String = test_data
        .chars()
        .rev()
        .collect();

    let (
        channel2_sink,
        channel1_source,
        channel1_sink,
        channel2_source,
    ) = tokio::try_join!(
        tokio::spawn(test_async_half_sends_data(channel2_sink, options.clone(), test_data.clone())),
        tokio::spawn(test_async_half_receives_data(channel1_source, options.clone(), test_data.clone())),
        // for other direction, use reversed data string to send something different than `test_data`
        tokio::spawn(test_async_half_sends_data(channel1_sink, options.clone(), reversed_test_data.clone())),
        tokio::spawn(test_async_half_receives_data(channel2_source, options.clone(), reversed_test_data.clone())),
    ).unwrap();

    let channel1 = channel1_source.unsplit(channel1_sink);
    let channel2 = channel2_source.unsplit(channel2_sink);

    return (channel1, channel2);
}

/// Test an `AsyncRead + AsyncWrite` stream data transfer.
///
/// Pins both boxed channels and runs the full test suite on them:
/// 1. a `channel1` → `channel2` transfer;
/// 2. a `channel2` → `channel1` transfer;
/// 3. a concurrent transfer in both directions.
///
/// The channels are returned (pinned) in their original `(channel1, channel2)`
/// order. The stream type does not need to be `Unpin`, because the boxes are
/// pinned with `Pin::from`, which pins any boxed value.
///
/// # Panics
///
/// Panics if any part of the transfer test fails.
///
/// ### Examples
/// 
/// ```
/// use tokio::io::duplex;
/// 
/// #[tokio::main]
/// async fn main() {
///     // either `test` or the `all` feature must be enabled
///     #[cfg(any(feature = "test", feature = "all"))]
///     {
///         use connection_utils::test::{test_async_stream, TestOptions};
///
///         // create the stream pair to test
///         let (channel1, channel2) = duplex(4096);
///         
///         // configure the test, using a test data length of 4096
///         let options = TestOptions::default()
///             .with_data_len(4096);
///         
///         // test data transfer
///         test_async_stream(
///             Box::new(channel1),
///             Box::new(channel2),
///             options,
///         ).await;
///         
///         println!("👌 data transfer succeeded");
///     }
/// }
/// ```
pub async fn test_async_stream<TAsyncDuplex: AsyncRead + AsyncWrite + Send + ?Sized + 'static>(
    channel1: Box<TAsyncDuplex>,
    channel2: Box<TAsyncDuplex>,
    options: TestOptions,
) -> (Pin<Box<TAsyncDuplex>>, Pin<Box<TAsyncDuplex>>) {
    // `Pin::from(Box<T>)` needs no `Unpin` bound, unlike `Pin::new`.
    test_async_stream_pinned(
        Pin::from(channel1),
        Pin::from(channel2),
        options,
    ).await
}

/// Runs the full transfer test suite against two already-pinned, connected channels.
///
/// The suite has three steps:
/// 1. `channel1` sends to `channel2`;
/// 2. the roles are swapped, and `channel2` sends to `channel1`;
/// 3. both directions run concurrently.
///
/// The channels are returned in their original `(channel1, channel2)` order.
pub async fn test_async_stream_pinned<TAsyncDuplex: AsyncRead + AsyncWrite + Send + ?Sized + 'static>(
    channel1: Pin<Box<TAsyncDuplex>>,
    channel2: Pin<Box<TAsyncDuplex>>,
    options: TestOptions,
) -> (Pin<Box<TAsyncDuplex>>, Pin<Box<TAsyncDuplex>>) {
    // `first` writes, `second` reads
    let (first, second) =
        test_async_stream_data_transfer(channel1, channel2, options.clone()).await;

    // roles swapped: `second` writes, `first` reads
    let (second, first) =
        test_async_stream_data_transfer(second, first, options.clone()).await;

    // both directions at the same time
    test_async_stream_duplex(first, second, options).await
}