1#![forbid(unsafe_code)]
2
3use std::task::Context;
4use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
5use tokio::macros::support::{Pin, Poll};
6
/// Wraps a mutable reference to a reader-writer and caps how many bytes may be read through it.
///
/// After `len` bytes (as given to `new`) have been read, further reads report EOF by returning
/// `Ok(())` without filling the caller's buffer. Writes, flushes and shutdowns are passed
/// straight through to the wrapped value and are not counted against the limit.
///
/// Trait bounds live on the `impl` blocks rather than on the struct itself, so naming the type
/// does not require repeating them.
#[derive(Debug)]
pub struct AsyncReadWriteTake<'a, RW> {
    /// The wrapped reader-writer; all I/O is delegated to it.
    read_writer: &'a mut RW,
    /// Number of bytes that may still be read before reads start returning EOF.
    remaining_bytes: u64,
}
22
23impl<'a, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncReadWriteTake<'a, RW> {
24 pub fn new(read_writer: &'a mut RW, len: u64) -> AsyncReadWriteTake<'a, RW> {
26 Self {
27 read_writer,
28 remaining_bytes: len,
29 }
30 }
31}
32
33impl<'a, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncRead for AsyncReadWriteTake<'a, RW> {
34 fn poll_read(
35 self: Pin<&mut Self>,
36 cx: &mut Context<'_>,
37 buf: &mut ReadBuf<'_>,
38 ) -> Poll<Result<(), std::io::Error>> {
39 let mut_self = self.get_mut();
40 if mut_self.remaining_bytes == 0 {
41 return Poll::Ready(Ok(()));
42 }
43 let num_to_read = mut_self.remaining_bytes.min(buf.remaining() as u64) as usize;
44 let dest = &mut buf.initialize_unfilled()[0..num_to_read];
45 let mut buf2 = ReadBuf::new(dest);
46 match Pin::new(&mut mut_self.read_writer).poll_read(cx, &mut buf2) {
47 Poll::Ready(Ok(())) => {
48 let num_read = buf2.filled().len();
49 buf.advance(num_read);
50 mut_self.remaining_bytes -= num_read as u64;
51 Poll::Ready(Ok(()))
52 }
53 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
54 Poll::Pending => Poll::Pending,
55 }
56 }
57}
58
59impl<'a, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncWrite for AsyncReadWriteTake<'a, RW> {
60 fn poll_write(
61 self: Pin<&mut Self>,
62 cx: &mut Context<'_>,
63 buf: &[u8],
64 ) -> Poll<Result<usize, std::io::Error>> {
65 let mut_self = self.get_mut();
66 Pin::new(&mut mut_self.read_writer).poll_write(cx, buf)
67 }
68
69 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
70 let mut_self = self.get_mut();
71 Pin::new(&mut mut_self.read_writer).poll_flush(cx)
72 }
73
74 fn poll_shutdown(
75 self: Pin<&mut Self>,
76 cx: &mut Context<'_>,
77 ) -> Poll<Result<(), std::io::Error>> {
78 let mut_self = self.get_mut();
79 Pin::new(&mut mut_self.read_writer).poll_shutdown(cx)
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::super::*;
    use fixed_buffer::escape_ascii;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Performs a single `read` on `reader` into `buf` and unwraps the returned byte count.
    async fn read_count<R>(reader: &mut R, buf: &mut [u8]) -> usize
    where
        R: tokio::io::AsyncRead + Unpin,
    {
        reader.read(buf).await.unwrap()
    }

    #[tokio::test]
    async fn read_error() {
        // The first inner read fails; subsequent reads must still work and respect the limit.
        let mut read_writer = FakeAsyncReadWriter::new(vec![Err(err1()), Ok(2), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        let err = take.read(&mut buf).await.unwrap_err();
        assert_eq!("err1", err.to_string());
        assert_eq!(2, read_count(&mut take, &mut buf).await);
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(0, read_count(&mut take, &mut buf).await);
        assert_eq!("ab..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn empty() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(0, read_count(&mut take, &mut buf).await);
        assert_eq!("....", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn doesnt_read_when_zero() {
        // With a zero limit the wrapper must answer the read itself.
        let mut read_writer = FakeAsyncReadWriter::empty();
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 0);
        let mut buf = [b'.'; 4];
        assert_eq!(0, read_count(&mut take, &mut buf).await);
        assert_eq!("....", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn fewer_than_len() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(2), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(2, read_count(&mut take, &mut buf).await);
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(0, read_count(&mut take, &mut buf).await);
        assert_eq!("ab..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn fewer_than_len_in_multiple_reads() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(2), Ok(2), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 5);
        let mut buf = [b'.'; 4];
        assert_eq!(2, read_count(&mut take, &mut buf).await);
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(2, read_count(&mut take, &mut buf).await);
        assert_eq!("cd..", escape_ascii(&buf));
        assert_eq!(0, read_count(&mut take, &mut buf).await);
        assert_eq!("cd..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn exactly_len() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(3), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(3, read_count(&mut take, &mut buf).await);
        assert_eq!("abc.", escape_ascii(&buf));
        assert_eq!(0, read_count(&mut take, &mut buf).await);
        assert_eq!("abc.", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn exactly_len_in_multiple_reads() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(2), Ok(1), Ok(0)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(2, read_count(&mut take, &mut buf).await);
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(1, read_count(&mut take, &mut buf).await);
        assert_eq!("cb..", escape_ascii(&buf));
        assert_eq!(0, read_count(&mut take, &mut buf).await);
        assert_eq!("cb..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn doesnt_call_read_after_len_reached() {
        // The fake has no scripted result for a second read, so the final read must be
        // satisfied by the wrapper alone once the limit is reached.
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(3)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(3, read_count(&mut take, &mut buf).await);
        assert_eq!("abc.", escape_ascii(&buf));
        assert_eq!(0, read_count(&mut take, &mut buf).await);
        assert_eq!("abc.", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn doesnt_call_read_after_len_reached_in_multiple_reads() {
        // Same as above, but the limit is reached across two inner reads.
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(2), Ok(1)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 3);
        let mut buf = [b'.'; 4];
        assert_eq!(2, read_count(&mut take, &mut buf).await);
        assert_eq!("ab..", escape_ascii(&buf));
        assert_eq!(1, read_count(&mut take, &mut buf).await);
        assert_eq!("cb..", escape_ascii(&buf));
        assert_eq!(0, read_count(&mut take, &mut buf).await);
        assert_eq!("cb..", escape_ascii(&buf));
    }

    #[tokio::test]
    async fn passes_writes_through() {
        let mut read_writer = FakeAsyncReadWriter::new(vec![Ok(3)]);
        let mut take = AsyncReadWriteTake::new(&mut read_writer, 2);
        // Writes ignore the read limit: all 3 bytes are accepted even though the limit is 2.
        assert_eq!(3, take.write(b"abc").await.unwrap());
        assert!(read_writer.is_empty());
    }
}