use std::fmt;
use std::{pin::Pin, sync::Arc};
use cs_utils::random_number;
use anyhow::{Result, anyhow};
use tokio_util::codec::Framed;
use futures::stream::SplitSink;
use futures::{SinkExt, Future, future};
use serde::{Serialize, de::DeserializeOwned};
use tokio::{io::AsyncWrite, sync::{Mutex, mpsc::{Receiver, self, Sender}}, task::JoinHandle};
use crate::codecs::GenericCodec;
/// Drains `messages_channel`, writing every received message into the shared `stream` sink.
///
/// The sink mutex is held only while a single message is being sent. That lets several
/// `send_job` futures share one sink and take turns writing to it.
///
/// Returns `Ok(())` once `messages_channel` yields `None` (the channel is closed and empty).
///
/// # Errors
///
/// Returns the sink's send error, converted into `anyhow::Error`, as soon as a send fails.
/// Messages still queued in `messages_channel` at that point are not sent.
async fn send_job<
    T: Serialize + DeserializeOwned + Send + fmt::Debug + fmt::Display + 'static,
    TAsyncDuplex: AsyncWrite + Send + ?Sized + 'static,
>(
    mut messages_channel: Receiver<T>,
    stream: Arc<Mutex<SplitSink<Framed<Pin<Box<TAsyncDuplex>>, GenericCodec<T>>, T>>>,
) -> Result<()> {
    while let Some(message) = messages_channel.recv().await {
        // Render the message up front: `message` is moved into the sink below,
        // but it is still logged after the send completes.
        let message_str = message.to_string();
        println!("[send_job]> sending message: {}", message_str);
        {
            // Scoped so the guard is released before the "sent" log line and
            // before waiting for the next message.
            let mut lock = stream.lock().await;
            println!("[send_job]> got lock for {}", message_str);
            let result = lock.send(message).await;
            if let Err(error) = result {
                println!("[send_job]> error: {}", error);
                return Err(error.into());
            }
        }
        println!("[send_job]> message {} sent", message_str);
    }
    println!("[send_job]> end");
    Ok(())
}
/// Pairs a new bounded channel (capacity 100) with a boxed `send_job` future.
///
/// The future forwards everything received on that channel into `stream`. It is
/// lazy: nothing is written until the caller polls it, for example by spawning it
/// or joining it with other futures.
fn create_send_job<
    T: Serialize + DeserializeOwned + Send + fmt::Debug + fmt::Display + 'static,
    TAsyncDuplex: AsyncWrite + Send + ?Sized + 'static,
>(
    stream: Arc<Mutex<SplitSink<Framed<Pin<Box<TAsyncDuplex>>, GenericCodec<T>>, T>>>,
) -> (Sender<T>, Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>) {
    let (tx, rx) = mpsc::channel(100);
    (tx, Box::pin(send_job(rx, stream)))
}
pub fn channel_into_framed_stream<
T: Serialize + DeserializeOwned + Send + fmt::Debug + fmt::Display + 'static,
TAsyncDuplex: AsyncWrite + Send + ?Sized + 'static,
>(
stream: SplitSink<Framed<Pin<Box<TAsyncDuplex>>, GenericCodec<T>>, T>,
) -> (Sender<T>, JoinHandle<Result<()>>) {
let (sender, mut receiver) = mpsc::channel(100);
let stream = Arc::new(Mutex::new(stream));
let mut senders = vec![];
let mut jobs = vec![];
while jobs.len() < 5 {
let (sender, job) = create_send_job(Arc::clone(&stream));
senders.push(sender);
jobs.push(job);
}
let forward_future = Box::pin(async move {
while let Some(message) = receiver.recv().await {
let sender_index = random_number(0..senders.len());
let sender = &mut senders[sender_index];
sender.send(message).await
.map_err(|error| {
println!("[forward_future]> error: {}", error);
return anyhow!("{}", error);
})?;
}
return Ok(());
});
jobs.push(forward_future);
let join_handle = tokio::spawn(async move {
let result = future::try_join_all(jobs).await;
println!("[forward_future]> result: {:?}", result);
result?;
return Ok(());
});
return (sender, join_handle);
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use tokio::io::duplex;
    use futures::StreamExt;
    use tokio::sync::mpsc::Sender;
    use cs_utils::{test::random_vec, random_number};
    use crate::{utils::create_framed_stream, test::TestStreamMessage};
    use super::channel_into_framed_stream;

    /// Pushes `data_len` random messages through `channel_into_framed_stream` from several
    /// cloned senders. It checks that the other end of the in-memory duplex receives the same
    /// set of messages. Both sides are sorted before comparing, so delivery order is not checked.
    #[rstest]
    #[case(64)]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[tokio::test]
    async fn transfers_messages(
        #[case] data_len: usize,
    ) {
        let (near_end, far_end) = duplex(16_384);
        let mut reader = create_framed_stream::<TestStreamMessage, _>(Box::new(near_end));
        let writer = create_framed_stream::<TestStreamMessage, _>(Box::new(far_end));
        // Keep the read half bound (not `_`) so it lives for the whole test.
        let (writer_sink, _writer_source) = writer.split();

        let (sender, _) = channel_into_framed_stream(writer_sink);
        let producers: Vec<Sender<TestStreamMessage>> = (0..5)
            .map(|_| Sender::clone(&sender))
            .collect();

        let mut expected = random_vec::<TestStreamMessage>(data_len as u32);
        let mut outgoing = expected.clone();

        tokio::join!(
            async move {
                while let Some(item) = outgoing.pop() {
                    producers[random_number(0..producers.len())]
                        .send(item)
                        .await
                        .expect("Cannot send message.");
                }
            },
            async move {
                let mut received = vec![];
                while let Some(frame) = reader.next().await {
                    received.push(frame.expect("Failed to receive message."));
                    if received.len() == expected.len() {
                        break;
                    }
                }
                received.sort();
                expected.sort();
                assert_eq!(
                    received,
                    expected,
                    "Must receive the correct data.",
                );
            },
        );
    }
}