1use futures_core::ready;
2use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
3use std::io::{IoSliceMut, Result};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7pub(crate) trait AsyncBufReadExt: AsyncBufRead {
9 fn read_buf<CallbackReturn, Callback: FnOnce(&'_ [u8]) -> (usize, CallbackReturn) + Unpin>(
23 self: Pin<&mut Self>,
24 callback: Callback,
25 ) -> ReadBufFuture<'_, Self, CallbackReturn, Callback> {
26 ReadBufFuture {
27 source: self,
28 callback: Some(callback),
29 }
30 }
31}
32
33impl<R: AsyncBufRead + ?Sized> AsyncBufReadExt for R {}
34
35pub(crate) trait AsyncReadExt: AsyncRead {
37 #[cfg(test)]
39 fn read<'buffer>(
40 self: Pin<&mut Self>,
41 buffer: &'buffer mut [u8],
42 ) -> ReadFuture<'_, 'buffer, Self> {
43 ReadFuture {
44 source: self,
45 buffer,
46 }
47 }
48
49 #[cfg(test)]
51 fn read_vectored<'buffers>(
52 self: Pin<&mut Self>,
53 buffers: &'buffers mut [IoSliceMut<'buffers>],
54 ) -> ReadVectoredFuture<'_, 'buffers, Self> {
55 ReadVectoredFuture {
56 source: self,
57 buffers,
58 }
59 }
60
61 #[cfg(test)]
63 fn read_vectored_bounded<'bufs>(
64 self: Pin<&mut Self>,
65 bufs: &'bufs mut [IoSliceMut<'bufs>],
66 limit: u64,
67 ) -> ReadVectoredBoundedFuture<'_, 'bufs, Self> {
68 ReadVectoredBoundedFuture {
69 source: self,
70 bufs,
71 limit,
72 }
73 }
74
75 fn poll_read_vectored_bounded(
78 self: Pin<&mut Self>,
79 cx: &mut Context<'_>,
80 bufs: &mut [IoSliceMut<'_>],
81 limit: u64,
82 ) -> Poll<Result<usize>> {
83 if limit == 0 {
84 Ok(0).into()
85 } else {
86 let limit = std::cmp::min(limit, usize::MAX as u64) as usize;
87 let first_buffer = &mut bufs[0];
88 if first_buffer.len() >= limit {
89 self.poll_read(cx, &mut first_buffer[..limit])
92 } else {
93 let buf_count: usize = bufs
98 .iter()
99 .scan(0_usize, |size_so_far, elt| {
100 *size_so_far += elt.len();
101 Some(*size_so_far > limit)
102 })
103 .enumerate()
104 .find(|elt| elt.1)
105 .unwrap_or((bufs.len(), false))
106 .0;
107 self.poll_read_vectored(cx, &mut bufs[..buf_count])
108 }
109 }
110 }
111}
112
113impl<R: AsyncRead + ?Sized> AsyncReadExt for R {}
114
115pub(crate) trait AsyncWriteExt: AsyncWrite {
117 fn write_all<'a>(self: Pin<&'a mut Self>, data: &'a [u8]) -> WriteAllFuture<'a, Self> {
122 WriteAllFuture { sink: self, data }
123 }
124}
125
126impl<W: AsyncWrite + ?Sized> AsyncWriteExt for W {}
127
128#[derive(Debug)]
131pub(crate) struct ReadBufFuture<
132 'source,
133 Source: AsyncBufRead + ?Sized,
134 CallbackReturn,
135 Callback: FnOnce(&[u8]) -> (usize, CallbackReturn) + Unpin,
136> {
137 source: Pin<&'source mut Source>,
138 callback: Option<Callback>,
139}
140
141impl<
142 Source: AsyncBufRead + ?Sized,
143 CallbackReturn,
144 Callback: FnOnce(&[u8]) -> (usize, CallbackReturn) + Unpin,
145> Future for ReadBufFuture<'_, Source, CallbackReturn, Callback>
146{
147 type Output = Result<CallbackReturn>;
148
149 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150 let this = self.get_mut();
151 let data = ready!(this.source.as_mut().poll_fill_buf(cx))?;
152 let (consumed, ret) = (this.callback.take().unwrap())(data);
153 this.source.as_mut().consume(consumed);
154 Ok(ret).into()
155 }
156}
157
#[cfg(test)]
/// Future returned by `AsyncReadExt::read`.
///
/// Holds the pinned reader and the destination buffer for a single `poll_read`.
#[derive(Debug)]
pub(crate) struct ReadFuture<'source, 'buffer, Source>
where
    Source: AsyncRead + ?Sized,
{
    source: Pin<&'source mut Source>,
    buffer: &'buffer mut [u8],
}
165
#[cfg(test)]
impl<Source> Future for ReadFuture<'_, '_, Source>
where
    Source: AsyncRead + ?Sized,
{
    type Output = Result<usize>;

    /// Forwards to the reader's `poll_read` using the stored buffer.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let Self { source, buffer } = self.get_mut();
        source.as_mut().poll_read(cx, buffer)
    }
}
175
#[cfg(test)]
/// Future returned by `AsyncReadExt::read_vectored`.
///
/// Holds the pinned reader and the destination buffers for a single `poll_read_vectored`.
#[derive(Debug)]
pub(crate) struct ReadVectoredFuture<'source, 'buffers, Source>
where
    Source: AsyncRead + ?Sized,
{
    source: Pin<&'source mut Source>,
    buffers: &'buffers mut [IoSliceMut<'buffers>],
}
183
#[cfg(test)]
impl<Source> Future for ReadVectoredFuture<'_, '_, Source>
where
    Source: AsyncRead + ?Sized,
{
    type Output = Result<usize>;

    /// Forwards to the reader's `poll_read_vectored` using the stored buffers.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let Self { source, buffers } = self.get_mut();
        source.as_mut().poll_read_vectored(cx, buffers)
    }
}
193
#[cfg(test)]
/// Future returned by `AsyncReadExt::read_vectored_bounded`.
///
/// Holds the pinned reader, the destination buffers and the byte limit passed to
/// `poll_read_vectored_bounded`.
#[derive(Debug)]
pub(crate) struct ReadVectoredBoundedFuture<'source, 'bufs, Source>
where
    Source: AsyncRead + ?Sized,
{
    source: Pin<&'source mut Source>,
    bufs: &'bufs mut [IoSliceMut<'bufs>],
    limit: u64,
}
203
#[cfg(test)]
impl<Source> Future for ReadVectoredBoundedFuture<'_, '_, Source>
where
    Source: AsyncRead + ?Sized,
{
    type Output = Result<usize>;

    /// Forwards to `poll_read_vectored_bounded` using the stored buffers and limit.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let Self {
            source,
            bufs,
            limit,
        } = self.get_mut();
        source
            .as_mut()
            .poll_read_vectored_bounded(cx, bufs, *limit)
    }
}
215
216#[derive(Debug)]
218pub(crate) struct WriteAllFuture<'a, T: AsyncWrite + ?Sized> {
219 sink: Pin<&'a mut T>,
220 data: &'a [u8],
221}
222
223impl<T: AsyncWrite + ?Sized> Future for WriteAllFuture<'_, T> {
224 type Output = Result<()>;
225
226 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
227 while !self.data.is_empty() {
228 let data = self.data;
229 let bytes_written = ready!(self.sink.as_mut().poll_write(cx, data))?;
230 self.data = &self.data[bytes_written..];
231 }
232 Ok(()).into()
233 }
234}
235
#[cfg(test)]
/// Reads from `src` repeatedly until `buffer` is completely filled.
///
/// Fails with `ErrorKind::UnexpectedEof` if a read returns zero bytes before `buffer` is
/// full. Any error reported by the reader is propagated.
pub(crate) async fn read_all<Source: AsyncRead + ?Sized>(
    mut src: Pin<&mut Source>,
    mut buffer: &mut [u8],
) -> Result<()> {
    loop {
        if buffer.is_empty() {
            break Ok(());
        }
        let filled = src.as_mut().read(buffer).await?;
        if filled == 0 {
            break Err(std::io::ErrorKind::UnexpectedEof.into());
        }
        buffer = &mut buffer[filled..];
    }
}
252
#[cfg(test)]
mod test {
    use super::*;
    use futures_executor::block_on;

    #[test]
    /// A single `read` of a 4-byte buffer from an 8-byte source fills the buffer completely.
    fn read() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let mut buffer = [0_u8; 4];
            let bytes_read = Pin::new(&mut src).read(&mut buffer[..]).await.unwrap();
            assert_eq!(bytes_read, 4);
            assert_eq!(&buffer, b"abcd");
        });
    }

    #[test]
    /// A single `read_vectored` fills both a 4-byte and a 2-byte buffer, in order.
    fn read_vectored() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let mut buf1 = [0_u8; 4];
            let mut buf2 = [0_u8; 2];
            let mut slices = [IoSliceMut::new(&mut buf1), IoSliceMut::new(&mut buf2)];
            let bytes_read = Pin::new(&mut src).read_vectored(&mut slices).await.unwrap();
            assert_eq!(bytes_read, 6);
            assert_eq!(&buf1, b"abcd");
            assert_eq!(&buf2, b"ef");
        });
    }

    #[test]
    /// A limit (3) smaller than the first buffer (4) reads only `limit` bytes into the first
    /// buffer and leaves the second buffer untouched.
    fn poll_read_vectored_bounded_one_partial() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let mut buf1 = [0_u8; 4];
            let mut buf2 = [0_u8; 4];
            let mut slices = [
                IoSliceMut::new(&mut buf1[..]),
                IoSliceMut::new(&mut buf2[..]),
            ];
            let bytes_read = Pin::new(&mut src)
                .read_vectored_bounded(&mut slices, 3)
                .await
                .unwrap();
            assert_eq!(bytes_read, 3);
            assert_eq!(&buf1, b"abc\0");
            assert_eq!(&buf2, b"\0\0\0\0");
        });
    }

    #[test]
    /// A limit (5) that covers the first buffer (4) but not both buffers (8) fills only the
    /// first buffer. The second buffer is not split to use the leftover byte of the limit.
    fn poll_read_vectored_bounded_one_full() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let mut buf1 = [0_u8; 4];
            let mut buf2 = [0_u8; 4];
            let mut slices = [
                IoSliceMut::new(&mut buf1[..]),
                IoSliceMut::new(&mut buf2[..]),
            ];
            let bytes_read = Pin::new(&mut src)
                .read_vectored_bounded(&mut slices, 5)
                .await
                .unwrap();
            assert_eq!(bytes_read, 4);
            assert_eq!(&buf1, b"abcd");
            assert_eq!(&buf2, b"\0\0\0\0");
        });
    }

    #[test]
    /// A limit (10) at least as large as the total buffer capacity (8) fills both buffers.
    fn poll_read_vectored_bounded_two_full() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefghij"[..];
            let mut buf1 = [0_u8; 4];
            let mut buf2 = [0_u8; 4];
            let mut slices = [
                IoSliceMut::new(&mut buf1[..]),
                IoSliceMut::new(&mut buf2[..]),
            ];
            let bytes_read = Pin::new(&mut src)
                .read_vectored_bounded(&mut slices, 10)
                .await
                .unwrap();
            assert_eq!(bytes_read, 8);
            assert_eq!(&buf1, b"abcd");
            assert_eq!(&buf2, b"efgh");
        });
    }

    #[test]
    /// `write_all` against a sink that accepts the whole buffer in one call.
    fn write_all_fast() {
        // Mock sink: accepts every byte it is offered. Flush/close must never be called.
        struct Test {
            v: Vec<u8>,
        }
        impl AsyncWrite for Test {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize>> {
                self.v.extend_from_slice(buf);
                Ok(buf.len()).into()
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }

            fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }
        }
        let mut t = Test { v: vec![] };
        block_on(async { Pin::new(&mut t).write_all(b"abcdefgh").await }).unwrap();
        assert_eq!(t.v.as_slice(), b"abcdefgh");
    }

    #[test]
    /// `write_all` against a sink that accepts only one byte per call, which forces repeated
    /// `poll_write` calls.
    fn write_all_slow() {
        // Mock sink: accepts at most the first byte offered (0 for an empty buffer).
        // Flush/close must never be called.
        struct Test {
            v: Vec<u8>,
        }
        impl AsyncWrite for Test {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize>> {
                match buf.first() {
                    None => Ok(0).into(),
                    Some(&b) => {
                        self.v.push(b);
                        Ok(1).into()
                    }
                }
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }

            fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }
        }
        let mut t = Test { v: vec![] };
        block_on(async { Pin::new(&mut t).write_all(b"abcdefgh").await }).unwrap();
        assert_eq!(t.v.as_slice(), b"abcdefgh");
    }

    #[test]
    /// An error from the first `poll_write` is returned unchanged by `write_all`, and the
    /// write is not retried.
    fn write_all_error() {
        // Mock sink: fails on the first write. The assertion panics if a second write is
        // attempted. Flush/close must never be called.
        struct Test {
            already_called: bool,
        }
        impl AsyncWrite for Test {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _data: &[u8],
            ) -> Poll<Result<usize>> {
                assert!(!self.already_called);
                self.already_called = true;
                Err(std::io::Error::other("Test error message")).into()
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }

            fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }
        }
        let mut t = Test {
            already_called: false,
        };
        let e = block_on(async { Pin::new(&mut t).write_all(b"abcdefgh").await }).unwrap_err();
        assert_eq!(e.kind(), std::io::ErrorKind::Other);
        assert_eq!(format!("{e}"), "Test error message");
    }
}