fixed_buffer_tokio/
async_read_write_chain.rs

1#![forbid(unsafe_code)]
2
3use std::task::Context;
4use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
5use tokio::macros::support::{Pin, Poll};
6
/// Chains a reader with a reader+writer, passing all writes through to the latter.
///
/// `R` implements [`AsyncRead`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncRead.html).
/// `RW` implements both
/// [`AsyncRead`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncRead.html) and
/// [`AsyncWrite`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWrite.html).
///
/// Reads are served from `R` until it reports EOF, and from `RW` after that.
/// All writes, flushes, and shutdowns go to `RW`.
///
/// This is like
/// [`tokio::io::AsyncReadExt::chain`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.chain),
/// except that it also passes writes through.
/// That makes it usable with `AsyncRead + AsyncWrite` objects like
/// [`tokio::net::TcpStream`](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html)
/// and
/// [`tokio_rustls::server::TlsStream`](https://docs.rs/tokio-rustls/latest/tokio_rustls/server/struct.TlsStream.html).
///
/// The I/O trait bounds are declared on the `impl` blocks rather than on the
/// struct itself, so the type can be named in generic code without repeating them.
pub struct AsyncReadWriteChain<'a, R, RW> {
    /// The first reader. It is replaced with `None` once it has reported EOF.
    reader: Option<&'a mut R>,
    /// Receives every write. It also serves reads once `reader` is exhausted.
    read_writer: &'a mut RW,
}
31
32impl<'a, R: AsyncRead + Send + Unpin, RW: AsyncRead + AsyncWrite + Send + Unpin>
33    AsyncReadWriteChain<'a, R, RW>
34{
35    /// See [`AsyncReadWriteChain`](struct.AsyncReadWriteChain.html).
36    pub fn new(reader: &'a mut R, read_writer: &'a mut RW) -> AsyncReadWriteChain<'a, R, RW> {
37        Self {
38            reader: Some(reader),
39            read_writer,
40        }
41    }
42}
43
44impl<'a, R: AsyncRead + Send + Unpin, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncRead
45    for AsyncReadWriteChain<'a, R, RW>
46{
47    fn poll_read(
48        self: Pin<&mut Self>,
49        cx: &mut Context<'_>,
50        buf: &mut ReadBuf<'_>,
51    ) -> Poll<Result<(), std::io::Error>> {
52        let mut_self = self.get_mut();
53        if let Some(ref mut reader) = mut_self.reader {
54            let before_len = buf.filled().len();
55            match Pin::new(&mut *reader).poll_read(cx, buf) {
56                Poll::Pending => return Poll::Pending,
57                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
58                Poll::Ready(Ok(())) => {
59                    let num_read = buf.filled().len() - before_len;
60                    if num_read > 0 {
61                        return Poll::Ready(Ok(()));
62                    } else {
63                        // EOF
64                        mut_self.reader = None;
65                        // Fall through.
66                    }
67                }
68            }
69        }
70        Pin::new(&mut mut_self.read_writer).poll_read(cx, buf)
71    }
72}
73
74impl<'a, R: AsyncRead + Send + Unpin, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncWrite
75    for AsyncReadWriteChain<'a, R, RW>
76{
77    fn poll_write(
78        self: Pin<&mut Self>,
79        cx: &mut Context<'_>,
80        buf: &[u8],
81    ) -> Poll<Result<usize, std::io::Error>> {
82        let mut_self = self.get_mut();
83        Pin::new(&mut mut_self.read_writer).poll_write(cx, buf)
84    }
85
86    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
87        let mut_self = self.get_mut();
88        Pin::new(&mut mut_self.read_writer).poll_flush(cx)
89    }
90
91    fn poll_shutdown(
92        self: Pin<&mut Self>,
93        cx: &mut Context<'_>,
94    ) -> Poll<Result<(), std::io::Error>> {
95        let mut_self = self.get_mut();
96        Pin::new(&mut mut_self.read_writer).poll_shutdown(cx)
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::super::*;
    use fixed_buffer::escape_ascii;

    /// Issues exactly one `AsyncReadExt::read` call against `source` into `dest`.
    async fn read_once(
        source: &mut (impl tokio::io::AsyncRead + Unpin),
        dest: &mut [u8],
    ) -> std::io::Result<usize> {
        tokio::io::AsyncReadExt::read(source, dest).await
    }

    #[tokio::test]
    async fn both_empty() {
        let mut reader = std::io::Cursor::new(b"");
        let mut read_writer: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 8];
        let count = read_once(&mut chain, &mut buf).await.unwrap();
        assert_eq!(0, count);
        assert_eq!("........", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn doesnt_read_second_when_first_has_data() {
        let mut reader = std::io::Cursor::new(b"abc");
        let mut read_writer = FakeAsyncReadWriter::empty();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];
        let count = read_once(&mut chain, &mut buf).await.unwrap();
        assert_eq!(3, count);
        assert_eq!("abc.", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn doesnt_read_second_when_first_returns_error() {
        let mut reader = FakeAsyncReadWriter::new(vec![Err(err1()), Err(err1())]);
        let mut read_writer = FakeAsyncReadWriter::empty();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];
        let first_err = read_once(&mut chain, &mut buf).await.unwrap_err();
        assert_eq!(std::io::ErrorKind::Other, first_err.kind());
        assert_eq!("err1", first_err.to_string());
        assert_eq!("....", escape_ascii(&buf));
        // An error does not retire the first reader, so the next read must fail too.
        read_once(&mut chain, &mut buf).await.unwrap_err();
    }

    #[tokio::test]
    async fn reads_second_when_first_empty() {
        let mut reader = std::io::Cursor::new(b"");
        let mut read_writer: AsyncFixedBuf<4> = AsyncFixedBuf::new();
        read_writer.write_str("abc").unwrap();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];
        let count = read_once(&mut chain, &mut buf).await.unwrap();
        assert_eq!(3, count);
        assert_eq!("abc.", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn reads_first_then_second() {
        let mut reader = std::io::Cursor::new(b"ab");
        let mut read_writer: AsyncFixedBuf<4> = AsyncFixedBuf::new();
        read_writer.write_str("cd").unwrap();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];

        let from_first = read_once(&mut chain, &mut buf).await.unwrap();
        assert_eq!(2, from_first);
        assert_eq!("ab..", escape_ascii(&buf));

        let from_second = read_once(&mut chain, &mut buf).await.unwrap();
        assert_eq!(2, from_second);
        assert_eq!("cd..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn returns_error_from_second() {
        let mut reader = std::io::Cursor::new(b"");
        let mut read_writer = FakeAsyncReadWriter::new(vec![Err(err1()), Err(err1())]);
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];
        let first_err = read_once(&mut chain, &mut buf).await.unwrap_err();
        assert_eq!(std::io::ErrorKind::Other, first_err.kind());
        assert_eq!("err1", first_err.to_string());
        assert_eq!("....", escape_ascii(&buf));
        read_once(&mut chain, &mut buf).await.unwrap_err();
    }

    #[tokio::test]
    async fn passes_writes_through() {
        let mut reader = std::io::Cursor::new(b"");
        let mut read_writer: AsyncFixedBuf<4> = AsyncFixedBuf::new();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let written = tokio::io::AsyncWriteExt::write(&mut chain, b"abc")
            .await
            .unwrap();
        assert_eq!(3, written);
        assert_eq!("abc", read_writer.escape_ascii());
    }
}