use std::{pin::Pin, task::{Context, Poll, Waker}, io, fmt};
use futures::ready;
use tokio::{io::{AsyncRead, AsyncWrite, ReadBuf}, sync::watch};
use crate::Channel;
/// One endpoint of a paired channel layered over an arbitrary async duplex stream.
///
/// Endpoints are created two at a time by `new_pair`. Each endpoint publishes its own
/// closure on a `watch` channel and holds a receiver for its counterpart's closure flag.
pub struct TransportChannel<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static,
{
    /// Numeric identifier returned by `Channel::id`.
    id: u16,
    /// Human-readable label returned by `Channel::label`.
    label: String,
    /// The wrapped duplex stream that all reads, writes, flushes and shutdowns are forwarded to.
    channel: Pin<Box<TAsyncDuplex>>,
    /// Set once `poll_shutdown` on the wrapped stream has resolved.
    is_closed: bool,
    /// Set once the read side has reported its final result after closure.
    is_read_closed: bool,
    /// Set by `poll_read` when it observed that the counterpart has closed.
    is_shutdown_requested: bool,
    /// Waker of the task whose last read returned `Pending`; woken by `poll_shutdown`.
    read_waker: Option<Waker>,
    /// Receiver for this endpoint's own closure flag; cloned out by `Channel::on_close`.
    self_closed: watch::Receiver<bool>,
    /// Receiver for the counterpart's closure flag.
    remote_closed: watch::Receiver<bool>,
    /// Sender used to publish `true` when this endpoint has been shut down.
    local_closed: watch::Sender<bool>,
    /// Value returned by `Channel::buffer_size`.
    buffer_size: u32,
}
impl<TAsyncDuplex> TransportChannel<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static,
{
    /// Wraps two duplex streams into a pair of linked channels.
    ///
    /// Both endpoints share `id` and `buffer_size`; their labels are `label` suffixed with
    /// `-1` and `-2` respectively. Each endpoint owns the sender of its own closure flag,
    /// a receiver of that same flag, and a receiver of the other endpoint's flag.
    pub fn new_pair(
        id: u16,
        label: impl AsRef<str> + ToString,
        channels: (Box<TAsyncDuplex>, Box<TAsyncDuplex>),
        buffer_size: u32,
    ) -> (Box<dyn Channel>, Box<dyn Channel>) {
        let (first_stream, second_stream) = channels;
        let (first_close_tx, first_close_rx) = watch::channel(false);
        let (second_close_tx, second_close_rx) = watch::channel(false);
        let label = label.to_string();

        let first = Box::new(TransportChannel {
            id,
            label: format!("{label}-1"),
            channel: Pin::new(first_stream),
            is_closed: false,
            is_read_closed: false,
            is_shutdown_requested: false,
            read_waker: None,
            // Own flag (cloned, the original receiver goes to the counterpart below).
            self_closed: first_close_rx.clone(),
            // Counterpart's flag (cloned, the original stays with the counterpart itself).
            remote_closed: second_close_rx.clone(),
            local_closed: first_close_tx,
            buffer_size,
        });

        let second = Box::new(TransportChannel {
            id,
            label: format!("{label}-2"),
            channel: Pin::new(second_stream),
            is_closed: false,
            is_read_closed: false,
            is_shutdown_requested: false,
            read_waker: None,
            self_closed: second_close_rx,
            remote_closed: first_close_rx,
            local_closed: second_close_tx,
            buffer_size,
        });

        (first, second)
    }

    /// Returns the current value of the counterpart's closure flag.
    fn is_remote_closed(&self) -> bool {
        *self.remote_closed.borrow()
    }
}
impl<TAsyncDuplex> Channel for TransportChannel<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static,
{
    /// Identifier given to `new_pair`; both endpoints of a pair report the same value.
    fn id(&self) -> u16 {
        self.id
    }

    /// Label of this endpoint.
    fn label(&self) -> &String {
        &self.label
    }

    /// `true` once this endpoint's shutdown has completed.
    fn is_closed(&self) -> bool {
        self.is_closed
    }

    /// Hands out a fresh receiver for this endpoint's own closure flag.
    fn on_close(&self) -> watch::Receiver<bool> {
        self.self_closed.clone()
    }

    /// Buffer size given to `new_pair`.
    fn buffer_size(&self) -> u32 {
        self.buffer_size
    }
}
impl<TAsyncDuplex> AsyncRead for TransportChannel<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static,
{
    /// Reads from the wrapped stream while tracking local and remote closure.
    ///
    /// Steps, in order:
    /// 1. If an earlier read requested a shutdown that has not completed yet, the shutdown
    ///    is driven first. Once it resolves, the read side is marked closed and the
    ///    shutdown result is returned.
    /// 2. If both the endpoint and its read side are closed, `Ready(Ok(()))` is returned
    ///    without touching `buf`.
    /// 3. Otherwise the wrapped stream is polled once. If the endpoint was already closed,
    ///    the outcome of that poll (including `Pending` or an error) is replaced by
    ///    `Ready(Ok(()))` and the read side is marked closed; bytes the poll placed into
    ///    `buf` stay there.
    /// 4. On `Pending` the task's waker is stored so `poll_shutdown` can wake the reader.
    /// 5. On `Ready`, if the counterpart has signalled closure, a shutdown is requested.
    ///    When this poll produced no bytes the shutdown is started immediately and its
    ///    result is returned in place of the read result.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Step 1: finish a shutdown that a previous read asked for.
        if self.is_shutdown_requested && !self.is_closed {
            let shutdown_outcome = ready!(self.as_mut().poll_shutdown(cx));
            self.is_read_closed = true;
            return Poll::Ready(shutdown_outcome);
        }

        // Step 2: nothing left to deliver.
        if self.is_closed && self.is_read_closed {
            return Poll::Ready(Ok(()));
        }

        // Step 3: a single poll of the wrapped stream, measuring how much it appended.
        let len_before = buf.filled().len();
        let read_outcome = self.channel.as_mut().poll_read(cx, buf);
        let received = buf.filled().len() - len_before;

        if self.is_closed && !self.is_read_closed {
            self.is_read_closed = true;
            return Poll::Ready(Ok(()));
        }

        match read_outcome {
            // Step 4: remember who to wake when this endpoint is shut down.
            Poll::Pending => {
                self.read_waker = Some(cx.waker().clone());
                Poll::Pending
            }
            // Step 5: react to the counterpart's closure.
            Poll::Ready(outcome) => {
                self.read_waker = None;
                if self.is_remote_closed() {
                    self.is_shutdown_requested = true;
                    if received == 0 {
                        return self.poll_shutdown(cx);
                    }
                }
                Poll::Ready(outcome)
            }
        }
    }
}
impl<TAsyncDuplex> AsyncWrite for TransportChannel<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static,
{
    /// Forwards the write to the wrapped stream, unless the counterpart has closed.
    ///
    /// When the counterpart's closure flag is set, nothing is written and
    /// `Ready(Ok(0))` is returned.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if self.is_remote_closed() {
            Poll::Ready(Ok(0))
        } else {
            self.channel.as_mut().poll_write(cx, buf)
        }
    }

    /// Forwards the flush to the wrapped stream.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.channel.as_mut().poll_flush(cx)
    }

    /// Shuts the wrapped stream down and announces the closure.
    ///
    /// Returns `Ready(Ok(()))` immediately if the endpoint is already closed. Otherwise,
    /// once the wrapped stream's shutdown resolves (with success or error), the endpoint
    /// is marked closed, `true` is published on the closure flag, a stored reader waker
    /// is woken, and the wrapped stream's result is returned.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.is_closed {
            return Poll::Ready(Ok(()));
        }

        let outcome = ready!(self.channel.as_mut().poll_shutdown(cx));
        self.is_closed = true;

        // The send result is deliberately ignored.
        let _ = self.local_closed.send(true);

        if let Some(parked_reader) = self.read_waker.take() {
            parked_reader.wake();
        }

        Poll::Ready(outcome)
    }
}
impl<TAsyncDuplex> fmt::Debug for TransportChannel<TAsyncDuplex>
where
    TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static,
{
    /// Formats the channel by delegating to the `debug` helper with the type name
    /// `"TransportChannel"`.
    // NOTE(review): `debug` is not defined in this file; presumably a provided method of
    // the `Channel` trait. Confirm.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.debug("TransportChannel", f)
    }
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use cs_utils::{traits::Random, futures::wait_random, test::random_vec, random_number, random_str_rg};
use super::TransportChannel;
use crate::create_framed_stream;
use crate::mocks::{channel_mock_pair, ChannelMockOptions};
use crate::test::{test_framed_stream, test_async_stream, TestOptions, TestStreamMessage};
// Runs the shared async-stream test over a `TransportChannel` pair layered on mock
// channels, once per payload size.
#[rstest]
#[case(128)]
#[case(256)]
#[case(512)]
#[case(1_024)]
#[case(2_048)]
#[case(4_096)]
#[case(8_192)]
#[case(16_384)]
#[case(32_768)]
#[tokio::test]
async fn transfers_binary_data(#[case] test_data_size: usize) {
    let (mock_a, mock_b) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );

    let boxed_mocks = (Box::new(mock_a), Box::new(mock_b));
    let (endpoint_a, endpoint_b) =
        TransportChannel::new_pair(1, "in-memory-channel-1", boxed_mocks, 4_096);

    let options = TestOptions::random().with_data_len(test_data_size);
    test_async_stream(endpoint_a, endpoint_b, options).await;
}
// Runs the shared framed-stream test over a `TransportChannel` pair layered on mock
// channels, once per (randomly chosen) message count.
#[rstest]
#[case(random_number(6..=8))]
#[case(random_number(12..=16))]
#[case(random_number(25..=32))]
#[case(random_number(53..=64))]
#[case(random_number(100..=128))]
#[case(random_number(200..=256))]
#[tokio::test]
async fn transfers_stream_data(#[case] items_count: usize) {
    let (mock_a, mock_b) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );

    let boxed_mocks = (Box::new(mock_a), Box::new(mock_b));
    let (endpoint_a, endpoint_b) =
        TransportChannel::new_pair(1, "in-memory-channel-1", boxed_mocks, 4_096);

    let framed_a = create_framed_stream::<TestStreamMessage, _>(endpoint_a);
    let framed_b = create_framed_stream::<TestStreamMessage, _>(endpoint_b);

    let options = TestOptions::random().with_data_len(items_count);
    test_framed_stream(framed_a, framed_b, options).await;
}
// After the counterpart has written a payload, channel1 shuts its own write half down
// while concurrently reading to end. The read must yield exactly that payload, and the
// counterpart must not be reported as closed.
#[rstest]
#[case(128)]
#[case(256)]
#[case(512)]
#[case(1_024)]
#[case(2_048)]
#[tokio::test]
async fn reads_to_end_if_self_shutdown(
    #[case] test_data_size: usize,
) {
    let (channel1, channel2) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );
    let (channel1, channel2) = TransportChannel::new_pair(
        1,
        "in-memory-channel-1",
        (Box::new(channel1), Box::new(channel2)),
        4_096,
    );
    // Warm the pair up with the shared stream test and take the endpoints back.
    let (channel1, mut channel2) = test_async_stream(
        channel1,
        channel2,
        TestOptions::random()
            .with_data_len(test_data_size),
    ).await;
    wait_random(25..=50).await;
    let test_data = random_str_rg(8..=32);
    // `write_all` instead of `write`: a single `write` may accept only part of the
    // buffer, which would make the length assertion below fail spuriously.
    channel2.write_all(test_data.as_bytes()).await.unwrap();
    let (mut source, mut sink) = tokio::io::split(channel1);
    tokio::join!(
        // Reader: drain channel1 until EOF.
        Box::pin(async move {
            wait_random(0..=5).await;
            let mut buf = vec![];
            let bytes_read = source.read_to_end(&mut buf).await
                .expect("Cannot read to end.");
            assert_eq!(
                bytes_read,
                test_data.len(),
                "Closed channel must read {} bytes.",
                test_data.len(),
            );
        }),
        // Closer: shut down channel1's write half.
        Box::pin(async move {
            wait_random(0..=5).await;
            sink.shutdown().await.unwrap();
        }),
    );
    assert!(!channel2.is_closed(), "Channel2 must not be closed.");
}
// After the counterpart has written a payload, channel1 shuts its own write half down
// while concurrently issuing a single `read`. That read must return the whole payload,
// and the counterpart must not be reported as closed.
#[rstest]
#[case(128)]
#[case(256)]
#[case(512)]
#[case(1_024)]
#[case(2_048)]
#[tokio::test]
async fn reads_if_self_shutdown(
    #[case] test_data_size: usize,
) {
    let (channel1, channel2) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );
    let (channel1, channel2) = TransportChannel::new_pair(
        1,
        "in-memory-channel-1",
        (Box::new(channel1), Box::new(channel2)),
        4_096,
    );
    // Warm the pair up with the shared stream test and take the endpoints back.
    let (channel1, mut channel2) = test_async_stream(
        channel1,
        channel2,
        TestOptions::random()
            .with_data_len(test_data_size),
    ).await;
    wait_random(25..=50).await;
    let test_data = random_str_rg(8..=32);
    // `write_all` instead of `write`: a single `write` may accept only part of the
    // buffer, which would make the length assertion below fail spuriously.
    channel2.write_all(test_data.as_bytes()).await.unwrap();
    let (mut source, mut sink) = tokio::io::split(channel1);
    tokio::join!(
        // Reader: one `read` into a buffer larger than the payload.
        Box::pin(async move {
            wait_random(0..=5).await;
            let mut buf = [0; 1024];
            let bytes_read = source.read(&mut buf).await
                .expect("Cannot read to end.");
            assert_eq!(
                bytes_read,
                test_data.len(),
                "Closed channel must read {} bytes.",
                test_data.len(),
            );
        }),
        // Closer: shut down channel1's write half.
        Box::pin(async move {
            wait_random(0..=5).await;
            sink.shutdown().await.unwrap();
        }),
    );
    assert!(!channel2.is_closed(), "Channel2 must not be closed.");
}
// Checks that the read half of a framed channel terminates (its stream yields `None`)
// after the channel's own sink has been closed. The received messages are only drained;
// their number and content are not asserted.
#[rstest]
#[case(random_number(6..=8))]
#[case(random_number(12..=16))]
#[case(random_number(25..=32))]
#[case(random_number(53..=64))]
#[case(random_number(100..=128))]
#[case(random_number(200..=256))]
#[tokio::test]
async fn closes_stream_if_self_is_closed(
    #[case] items_count: u32,
) {
    let (channel1, channel2) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );
    let (channel1, channel2) = TransportChannel::new_pair(
        1,
        "in-memory-channel-1",
        (Box::new(channel1), Box::new(channel2)),
        4_096,
    );
    let channel1 = create_framed_stream::<TestStreamMessage, _>(channel1);
    let channel2 = create_framed_stream::<TestStreamMessage, _>(channel2);
    // Warm the pair up with the shared framed-stream test and take the endpoints back.
    let (channel1, mut channel2) = test_framed_stream(
        channel1,
        channel2,
        TestOptions::random()
            .with_data_len(10),
    ).await;
    let (mut sink, mut source) = channel1.split();
    // The generated messages are consumed directly by the sender task; no second copy
    // is needed because nothing compares against them afterwards.
    let messages_to_send = random_vec::<TestStreamMessage>(items_count);
    let mut received_messages = vec![];
    tokio::join!(
        // Reader: must finish, i.e. the stream must end once channel1 is closed.
        Box::pin(async move {
            while let Some(message) = source.next().await {
                received_messages.push(message);
            }
        }),
        // Sender: the counterpart sends everything, then channel1's own sink is closed.
        Box::pin(async move {
            for message in messages_to_send {
                channel2.send(message).await.unwrap();
            }
            sink.close().await.unwrap();
        }),
    );
}
// Checks that a framed channel's stream terminates and the channel reports itself
// closed once the counterpart has been closed. The received messages are only drained;
// their number and content are not asserted.
#[rstest]
#[case(random_number(6..=8))]
#[case(random_number(12..=16))]
#[case(random_number(25..=32))]
#[case(random_number(53..=64))]
#[case(random_number(100..=128))]
#[case(random_number(200..=256))]
#[tokio::test]
async fn closes_stream_if_remote_counterpart_is_closed(
    #[case] items_count: u32,
) {
    let (channel1, channel2) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );
    let (channel1, channel2) = TransportChannel::new_pair(
        1,
        "in-memory-channel-1",
        (Box::new(channel1), Box::new(channel2)),
        4_096,
    );
    let channel1 = create_framed_stream::<TestStreamMessage, _>(channel1);
    let channel2 = create_framed_stream::<TestStreamMessage, _>(channel2);
    // Warm the pair up with the shared framed-stream test and take the endpoints back.
    let (mut channel1, mut channel2) = test_framed_stream(
        channel1,
        channel2,
        TestOptions::random()
            .with_data_len(10),
    ).await;
    // The generated messages are consumed directly by the sender task; no second copy
    // is needed because nothing compares against them afterwards.
    let messages_to_send = random_vec::<TestStreamMessage>(items_count);
    let mut received_messages = vec![];
    tokio::join!(
        // Reader: drains channel1 until its stream ends, then checks the closed flag.
        Box::pin(async move {
            while let Some(message) = channel1.next().await {
                received_messages.push(message);
            }
            assert!(channel1.get_ref().is_closed(), "Channel must be closed.");
        }),
        // Sender: the counterpart sends everything and then closes itself.
        Box::pin(async move {
            for message in messages_to_send {
                channel2.send(message).await.unwrap();
            }
            channel2.close().await.unwrap();
        }),
    );
}
// The counterpart writes a payload and is then shut down while channel1 concurrently
// reads to end. The read must yield exactly that payload, after which channel1 must
// report itself closed.
#[rstest]
#[case(128)]
#[case(256)]
#[case(512)]
#[case(1_024)]
#[case(2_048)]
#[tokio::test]
async fn reads_to_end_if_remote_counterpart_is_closed(
    #[case] test_data_size: usize,
) {
    let (channel1, channel2) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );
    let (channel1, channel2) = TransportChannel::new_pair(
        1,
        "in-memory-channel-1",
        (Box::new(channel1), Box::new(channel2)),
        4_096,
    );
    // Warm the pair up with the shared stream test and take the endpoints back.
    let (mut channel1, mut channel2) = test_async_stream(
        channel1,
        channel2,
        TestOptions::random()
            .with_data_len(test_data_size),
    ).await;
    let test_data = random_str_rg(8..=32);
    // `write_all` instead of `write`: a single `write` may accept only part of the
    // buffer, which would make the length assertion below fail spuriously.
    channel2.write_all(test_data.as_bytes()).await.unwrap();
    tokio::join!(
        // Reader: drain channel1 until EOF, then check that it closed itself.
        Box::pin(async move {
            wait_random(0..=5).await;
            let mut buf = vec![];
            let bytes_read = channel1.read_to_end(&mut buf).await
                .expect("Cannot read to end.");
            assert_eq!(
                bytes_read,
                test_data.len(),
                "Closed channel must read {} bytes.",
                test_data.len(),
            );
            assert!(
                channel1.is_closed(),
                "Channel must be closed after remote counterpart is closed.",
            );
        }),
        // Closer: shut the counterpart down.
        Box::pin(async move {
            wait_random(0..=5).await;
            channel2.shutdown().await.unwrap();
        }),
    );
}
// The counterpart writes a payload and shuts down before channel1 reads anything.
// The first `read` on channel1 must return the whole payload, the second must return
// 0 bytes, and channel1 must then report itself closed.
#[rstest]
#[case(128)]
#[case(256)]
#[case(512)]
#[case(1_024)]
#[case(2_048)]
#[tokio::test]
async fn reads_if_remote_counterpart_is_closed(
    #[case] test_data_size: usize,
) {
    let (channel1, channel2) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );
    let (channel1, channel2) = TransportChannel::new_pair(
        1,
        "in-memory-channel-1",
        (Box::new(channel1), Box::new(channel2)),
        4_096,
    );
    // Warm the pair up with the shared stream test and take the endpoints back.
    let (mut channel1, mut channel2) = test_async_stream(
        channel1,
        channel2,
        TestOptions::random()
            .with_data_len(test_data_size),
    ).await;
    let test_data = random_str_rg(8..=32);
    // `write_all` instead of `write`: a single `write` may accept only part of the
    // buffer, which would make the length assertion below fail spuriously.
    channel2.write_all(test_data.as_bytes()).await.unwrap();
    channel2.shutdown().await.unwrap();
    assert!(
        channel2.is_closed(),
        "Channel2 must be closed.",
    );
    wait_random(3..=5).await;
    let mut buf = [0; 1024];
    // First read: the payload written before the counterpart closed.
    let bytes_read = channel1.read(&mut buf).await
        .expect("Cannot read to end.");
    assert_eq!(
        bytes_read,
        test_data.len(),
        "Closed channel must read {} bytes.",
        test_data.len(),
    );
    // Second read: nothing is left, so the read reports 0 bytes.
    let bytes_read = channel1.read(&mut buf).await
        .expect("Cannot read to end.");
    assert_eq!(
        bytes_read,
        0,
        "Closed channel must read 0 bytes.",
    );
    assert!(
        channel1.is_closed(),
        "Channel must be closed after remote counterpart is closed.",
    );
}
// After one endpoint is shut down, writing to that endpoint must fail, and writing to
// the other endpoint must succeed while accepting 0 bytes.
#[rstest]
#[case(128)]
#[case(256)]
#[case(512)]
#[case(1_024)]
#[case(2_048)]
#[tokio::test]
async fn fails_to_write_if_remote_counterpart_is_closed(#[case] test_data_size: usize) {
    let (mock_a, mock_b) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );
    let boxed_mocks = (Box::new(mock_a), Box::new(mock_b));
    let (endpoint_a, endpoint_b) =
        TransportChannel::new_pair(1, "in-memory-channel-1", boxed_mocks, 4_096);

    // Warm the pair up with the shared stream test and take the endpoints back.
    let options = TestOptions::random().with_data_len(test_data_size);
    let (mut open_side, mut closed_side) =
        test_async_stream(endpoint_a, endpoint_b, options).await;

    closed_side.shutdown().await.unwrap();

    // Writing to the endpoint that was just shut down must be an error.
    let write_after_shutdown = closed_side.write(b"anything").await;
    assert!(
        write_after_shutdown.is_err(),
        "Must fail to write to closed channel.",
    );
    assert!(
        closed_side.is_closed(),
        "Channel2 must be closed.",
    );

    wait_random(3..=5).await;

    // `write` (not `write_all`) is intentional: the accepted byte count is the assertion.
    let test_data = random_str_rg(24..=32);
    let bytes_written = open_side
        .write(test_data.as_bytes())
        .await
        .expect("Cannot write to channel.");
    assert_eq!(
        bytes_written,
        0,
        "Must write 0 bytes if remote channel is closed.",
    );
}
// After channel1's write half is shut down, further writes to it must fail, while data
// the counterpart wrote beforehand can still be read to end. Finally the re-joined
// channel must report itself closed.
#[rstest]
#[case(128)]
#[case(256)]
#[case(512)]
#[case(1_024)]
#[case(2_048)]
#[tokio::test]
async fn fails_to_write_if_self_is_closed(
    #[case] test_data_size: usize,
) {
    let (channel1, channel2) = channel_mock_pair(
        ChannelMockOptions::random(),
        ChannelMockOptions::random(),
    );
    let (channel1, channel2) = TransportChannel::new_pair(
        1,
        "in-memory-channel-1",
        (Box::new(channel1), Box::new(channel2)),
        4_096,
    );
    // Warm the pair up with the shared stream test and take the endpoints back.
    let (channel1, mut channel2) = test_async_stream(
        channel1,
        channel2,
        TestOptions::random()
            .with_data_len(test_data_size),
    ).await;
    let (mut source, mut sink) = tokio::io::split(channel1);
    let test_data = random_str_rg(24..=32);
    // `write_all` instead of `write`: a single `write` may accept only part of the
    // buffer, which would make the length assertion below fail spuriously.
    channel2.write_all(test_data.as_bytes()).await
        .expect("Cannot write data.");
    sink.shutdown().await.unwrap();
    // Plain `write` is intentional here: only the error outcome is checked.
    assert!(
        sink.write(b"something").await.is_err(),
        "Must fail to write to closed channel.",
    );
    let mut buf = vec![];
    let bytes_received = source.read_to_end(&mut buf).await
        .expect("Cannot read data.");
    assert_eq!(
        bytes_received,
        test_data.len(),
        "Must be able to read to end if channel is closed.",
    );
    let channel1 = source.unsplit(sink);
    assert!(
        channel1.is_closed(),
        "Channel must be closed.",
    );
}
}