connection-utils 0.8.0

Connection-related utilities.
Documentation
use std::fmt;
use std::{pin::Pin, sync::Arc};

use cs_utils::random_number;
use anyhow::{Result, anyhow};
use tokio_util::codec::Framed;
use futures::stream::SplitSink;
use futures::{SinkExt, Future, future};
use serde::{Serialize, de::DeserializeOwned};
use tokio::{io::AsyncWrite, sync::{Mutex, mpsc::{Receiver, self, Sender}}, task::JoinHandle};

use crate::codecs::GenericCodec;

/// Drain `messages_channel` and write every received message into the shared
/// framed sink `stream`.
///
/// `stream` is shared behind an `Arc<Mutex<_>>`, so several of these jobs can
/// target the same sink. Each job holds the mutex for the whole duration of a
/// single `send`, which means sends coming from different jobs are serialized,
/// not executed in parallel. `SinkExt::send` completes only after the item has
/// been processed into the sink, including flushing, so no separate flush is
/// needed.
///
/// Each message is logged to stdout before locking, after the lock is
/// acquired, and after it has been sent.
///
/// # Returns
///
/// `Ok(())` once every `Sender` of `messages_channel` has been dropped and all
/// queued messages have been written.
///
/// # Errors
///
/// Returns the first error produced by the sink. Messages still queued in the
/// channel at that point are not sent.
async fn send_job<
    T: Serialize + DeserializeOwned + Send + fmt::Debug + fmt::Display + 'static,
    TAsyncDuplex: AsyncWrite + Send + ?Sized + 'static,
>(
    mut messages_channel: Receiver<T>,
    stream: Arc<Mutex<SplitSink<Framed<Pin<Box<TAsyncDuplex>>, GenericCodec<T>>, T>>>,
) -> Result<()> {
    while let Some(message) = messages_channel.recv().await {
        // `message` is moved into the sink below; keep a rendered copy for logging.
        let message_str = message.to_string();
        println!("[send_job]> sending message: {}", message_str);

        // Scope the guard so the lock is released before the final log line.
        {
            let mut sink = stream.lock().await;

            println!("[send_job]> got lock for {}", message_str);

            if let Err(error) = sink.send(message).await {
                println!("[send_job]> error: {}", error);

                return Err(error.into());
            }
        }

        println!("[send_job]> message {} sent", message_str);
    }

    println!("[send_job]> end");

    Ok(())
}

/// Build a `send_job` together with the bounded channel that feeds it.
///
/// Returns the sending half of a fresh `mpsc` channel (capacity 100) and the
/// boxed job future, which has not been polled yet. When awaited, the job
/// drains the channel's receiving half into `stream`.
fn create_send_job<
    T: Serialize + DeserializeOwned + Send + fmt::Debug + fmt::Display + 'static,
    TAsyncDuplex: AsyncWrite + Send + ?Sized + 'static,
>(
    stream: Arc<Mutex<SplitSink<Framed<Pin<Box<TAsyncDuplex>>, GenericCodec<T>>, T>>>,
) -> (Sender<T>, Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>) {
    let (tx, rx) = mpsc::channel(100);

    (tx, Box::pin(send_job(rx, stream)))
}

/// Create a `mpsc` channel that points to the framed stream.
pub fn channel_into_framed_stream<
    T: Serialize + DeserializeOwned + Send + fmt::Debug + fmt::Display + 'static,
    TAsyncDuplex: AsyncWrite + Send + ?Sized + 'static,
>(
    stream: SplitSink<Framed<Pin<Box<TAsyncDuplex>>, GenericCodec<T>>, T>,
) -> (Sender<T>, JoinHandle<Result<()>>) {
    let (sender, mut receiver) = mpsc::channel(100);
    let stream = Arc::new(Mutex::new(stream));

    let mut senders = vec![];
    let mut jobs = vec![];

    while jobs.len() < 5 {
        let (sender, job) = create_send_job(Arc::clone(&stream));

        senders.push(sender);
        jobs.push(job);
    }

    let forward_future = Box::pin(async move {
        while let Some(message) = receiver.recv().await {
            let sender_index = random_number(0..senders.len());
            let sender = &mut senders[sender_index];
            
            sender.send(message).await
                .map_err(|error| {
                    println!("[forward_future]> error: {}", error);

                    return anyhow!("{}", error);
                })?;
        }

        return Ok(());
    });

    jobs.push(forward_future);

    let join_handle = tokio::spawn(async move {
        let result = future::try_join_all(jobs).await;

        println!("[forward_future]> result: {:?}", result);

        result?;

        return Ok(());
    });
    
    return (sender, join_handle);
}

#[cfg(test)]
mod tests {
    use rstest::rstest;
    use tokio::io::duplex;
    use futures::StreamExt;
    use tokio::sync::mpsc::Sender;
    use cs_utils::{test::random_vec, random_number};

    use crate::{utils::create_framed_stream, test::TestStreamMessage};
    use super::channel_into_framed_stream;

    /// Round-trips `data_len` random messages.
    ///
    /// The messages are pushed through several clones of the channel returned by
    /// `channel_into_framed_stream` and read back from the other end of an
    /// in-memory duplex. Sent and received messages are compared as sorted lists,
    /// so delivery order is not checked.
    #[rstest]
    #[case(64)]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[tokio::test]
    async fn transfers_messages(
        #[case] data_len: usize,
    ) {
        const SENDERS_COUNT: usize = 5;

        let (reader_end, writer_end) = duplex(16_384);

        let mut incoming = create_framed_stream::<TestStreamMessage, _>(Box::new(reader_end));
        // Keep the unused source half bound (not `_`) so it lives until the end of the test.
        let (outgoing_sink, _outgoing_source) =
            create_framed_stream::<TestStreamMessage, _>(Box::new(writer_end)).split();

        let (sender, _) = channel_into_framed_stream(outgoing_sink);

        // Several producers share the same underlying channel.
        let senders: Vec<Sender<TestStreamMessage>> = (0..SENDERS_COUNT)
            .map(|_| Sender::clone(&sender))
            .collect();

        let mut expected = random_vec::<TestStreamMessage>(data_len as u32);
        let to_send = expected.clone();

        let produce = Box::pin(async move {
            // Iterate back-to-front, i.e. the same order as repeatedly popping the vector.
            for message in to_send.into_iter().rev() {
                let index = random_number(0..senders.len());

                senders[index].send(message).await
                    .expect("Cannot send message.");
            }
        });

        let consume = Box::pin(async move {
            let mut received = Vec::with_capacity(expected.len());

            while let Some(frame) = incoming.next().await {
                received.push(frame.expect("Failed to receive message."));

                if received.len() == expected.len() {
                    break;
                }
            }

            received.sort();
            expected.sort();

            assert_eq!(
                received,
                expected,
                "Must receive the correct data.",
            );
        });

        tokio::join!(produce, consume);
    }
}