fixed_buffer_tokio/
async_read_write_take.rs

1#![forbid(unsafe_code)]
2
3use std::task::Context;
4use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
5use tokio::macros::support::{Pin, Poll};
6
/// Wraps a value that implements
/// [`AsyncRead`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncRead.html)+[`AsyncWrite`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWrite.html)
/// and passes reads and writes through to it,
/// while limiting the total number of bytes that can be read.
///
/// This is like [`tokio::io::Take`](https://docs.rs/tokio/latest/tokio/io/struct.Take.html),
/// except that it also passes through writes.
/// That makes it usable with AsyncRead+AsyncWrite objects like
/// [`tokio::net::TcpStream`](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html)
/// and
/// [`tokio_rustls::server::TlsStream`](https://docs.rs/tokio-rustls/latest/tokio_rustls/server/struct.TlsStream.html).
///
/// Once the read limit is reached, reads report end-of-stream without touching the
/// wrapped value; writes are never limited.
// Trait bounds live on the `impl` blocks that need them rather than on the struct,
// so the type itself can be named without repeating them.
#[derive(Debug)]
pub struct AsyncReadWriteTake<'a, RW> {
    /// The wrapped reader/writer, mutably borrowed for `'a`; the caller regains
    /// access to it once this wrapper is dropped.
    read_writer: &'a mut RW,
    /// How many more bytes reads may return before reporting end-of-stream.
    remaining_bytes: u64,
}
22
23impl<'a, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncReadWriteTake<'a, RW> {
24    /// See [`AsyncReadWriteTake`](struct.AsyncReadWriteTake.html).
25    pub fn new(read_writer: &'a mut RW, len: u64) -> AsyncReadWriteTake<'a, RW> {
26        Self {
27            read_writer,
28            remaining_bytes: len,
29        }
30    }
31}
32
33impl<'a, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncRead for AsyncReadWriteTake<'a, RW> {
34    fn poll_read(
35        self: Pin<&mut Self>,
36        cx: &mut Context<'_>,
37        buf: &mut ReadBuf<'_>,
38    ) -> Poll<Result<(), std::io::Error>> {
39        let mut_self = self.get_mut();
40        if mut_self.remaining_bytes == 0 {
41            return Poll::Ready(Ok(()));
42        }
43        let num_to_read = mut_self.remaining_bytes.min(buf.remaining() as u64) as usize;
44        let dest = &mut buf.initialize_unfilled()[0..num_to_read];
45        let mut buf2 = ReadBuf::new(dest);
46        match Pin::new(&mut mut_self.read_writer).poll_read(cx, &mut buf2) {
47            Poll::Ready(Ok(())) => {
48                let num_read = buf2.filled().len();
49                buf.advance(num_read);
50                mut_self.remaining_bytes -= num_read as u64;
51                Poll::Ready(Ok(()))
52            }
53            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
54            Poll::Pending => Poll::Pending,
55        }
56    }
57}
58
59impl<'a, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncWrite for AsyncReadWriteTake<'a, RW> {
60    fn poll_write(
61        self: Pin<&mut Self>,
62        cx: &mut Context<'_>,
63        buf: &[u8],
64    ) -> Poll<Result<usize, std::io::Error>> {
65        let mut_self = self.get_mut();
66        Pin::new(&mut mut_self.read_writer).poll_write(cx, buf)
67    }
68
69    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
70        let mut_self = self.get_mut();
71        Pin::new(&mut mut_self.read_writer).poll_flush(cx)
72    }
73
74    fn poll_shutdown(
75        self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77    ) -> Poll<Result<(), std::io::Error>> {
78        let mut_self = self.get_mut();
79        Pin::new(&mut mut_self.read_writer).poll_shutdown(cx)
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::super::*;
    use fixed_buffer::escape_ascii;

    /// Issues a single `AsyncReadExt::read` call on `reader` and returns its result,
    /// so each test step reads as one short line.
    async fn read_once<R: tokio::io::AsyncRead + Unpin>(
        reader: &mut R,
        buf: &mut [u8],
    ) -> std::io::Result<usize> {
        tokio::io::AsyncReadExt::read(reader, buf).await
    }

    #[tokio::test]
    async fn read_error() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Err(err1()), Ok(2), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        let first = read_once(&mut take, &mut buf).await.unwrap_err();
        assert_eq!("err1", first.to_string());
        assert_eq!(2, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(0, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("ab..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn empty() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(0, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("....", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn doesnt_read_when_zero() {
        let mut read_writer = FakeAsyncReadWriter::empty();
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 0);
        let mut buf = [b'.'; 4];
        assert_eq!(0, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("....", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn fewer_than_len() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(2), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(2, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(0, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("ab..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn fewer_than_len_in_multiple_reads() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(2), Ok(2), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 5);
        let mut buf = [b'.'; 4];
        assert_eq!(2, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(2, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("cd..", escape_ascii(&buf));
        assert_eq!(0, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("cd..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn exactly_len() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(3), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(3, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("abc.", escape_ascii(&buf));
        assert_eq!(0, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("abc.", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn exactly_len_in_multiple_reads() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(2), Ok(1), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(2, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(1, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("cb..", escape_ascii(&buf));
        assert_eq!(0, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("cb..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn doesnt_call_read_after_len_reached() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(3)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(3, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("abc.", escape_ascii(&buf));
        assert_eq!(0, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("abc.", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn doesnt_call_read_after_len_reached_in_multiple_reads() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(2), Ok(1)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(2, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(1, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("cb..", escape_ascii(&buf));
        assert_eq!(0, read_once(&mut take, &mut buf).await.unwrap());
        assert_eq!("cb..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn passes_writes_through() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(3)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 2);
        let written = tokio::io::AsyncWriteExt::write(&mut take, b"abc")
            .await
            .unwrap();
        assert_eq!(3, written);
        assert!(read_writer.is_empty());
    }
}