// connection_utils/utils/duplex_stream.rs

use std::fmt;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite};
7
8/// Join AsyncRead and AsyncWrite into a single `AsyncRead + AsyncWrite` object.
9pub struct DuplexStream {
10    reader: Box<dyn AsyncRead + Unpin + Send + 'static>,
11    writer: Box<dyn AsyncWrite + Unpin + Send + 'static>,
12}
13
14impl DuplexStream {
15    pub fn new(
16        reader: impl AsyncRead + Unpin + Send + 'static,
17        writer: impl AsyncWrite + Unpin + Send + 'static,
18    ) -> Box<impl AsyncRead + AsyncWrite + Unpin + Send + 'static> {
19        return Box::new(
20            DuplexStream {
21                reader: Box::new(reader),
22                writer: Box::new(writer),
23            },
24        );
25    }
26}
27
28impl AsyncRead for DuplexStream {
29    fn poll_read(
30        mut self: Pin<&mut DuplexStream>,
31        cx: &mut Context<'_>,
32        buf: &mut tokio::io::ReadBuf<'_>,
33    ) -> Poll<io::Result<()>> {
34        return AsyncRead::poll_read(Pin::new(&mut self.reader), cx, buf);
35    }
36}
37
38impl AsyncWrite for DuplexStream {
39    fn poll_write(
40        mut self: Pin<&mut DuplexStream>,
41        cx: &mut Context<'_>,
42        buf: &[u8],
43    ) -> Poll<Result<usize, io::Error>> {
44        return AsyncWrite::poll_write(Pin::new(&mut self.writer), cx, buf);
45    }
46
47    fn poll_flush(
48        mut self: Pin<&mut DuplexStream>,
49        cx: &mut Context<'_>,
50    ) -> Poll<Result<(), io::Error>> {
51        return AsyncWrite::poll_flush(Pin::new(&mut self.writer), cx);
52    }
53
54    fn poll_shutdown(
55        mut self: Pin<&mut DuplexStream>,
56        cx: &mut Context<'_>,
57    ) -> Poll<Result<(), io::Error>> {
58        return AsyncWrite::poll_shutdown(Pin::new(&mut self.writer), cx);
59    }
60}
61
62impl fmt::Debug for DuplexStream {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        return f.debug_struct("DuplexStream")
65            .finish();
66    }
67}
68
#[cfg(test)]
mod tests {
    use cs_utils::traits::Random;
    use rstest::rstest;
    use tokio::io::split;

    use super::*;
    use crate::{
        mocks::{channel_mock_pair, ChannelMockOptions},
        test::{test_async_stream, TestOptions},
    };

    mod data_transfer {
        use super::*;

        // Splits each end of a mock channel pair into read/write halves,
        // re-joins them with `DuplexStream`, and runs the shared stream test
        // with a payload of `test_data_size` bytes.
        #[rstest]
        #[case(128)]
        #[case(256)]
        #[case(512)]
        #[case(1_024)]
        #[case(2_048)]
        #[case(4_096)]
        #[case(8_192)]
        #[case(16_384)]
        #[case(32_768)]
        #[tokio::test]
        async fn transfers_binary_data(#[case] test_data_size: usize) {
            let (left, right) = channel_mock_pair(
                ChannelMockOptions::random(),
                ChannelMockOptions::random(),
            );

            let (left_reader, left_writer) = split(left);
            let (right_reader, right_writer) = split(right);

            test_async_stream(
                DuplexStream::new(left_reader, left_writer),
                DuplexStream::new(right_reader, right_writer),
                TestOptions::random().with_data_len(test_data_size),
            )
            .await;
        }
    }
}