1use futures_core::ready;
2use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
3use std::future::Future;
4use std::io::{IoSliceMut, Result};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8pub trait AsyncBufReadExt: AsyncBufRead {
10 fn read_buf<CallbackReturn, Callback: FnOnce(&'_ [u8]) -> (usize, CallbackReturn) + Unpin>(
24 self: Pin<&mut Self>,
25 callback: Callback,
26 ) -> ReadBufFuture<'_, Self, CallbackReturn, Callback> {
27 ReadBufFuture {
28 source: self,
29 callback: Some(callback),
30 }
31 }
32}
33
34impl<R: AsyncBufRead + ?Sized> AsyncBufReadExt for R {}
35
36pub trait AsyncReadExt: AsyncRead {
38 fn read<'buffer>(
40 self: Pin<&mut Self>,
41 buffer: &'buffer mut [u8],
42 ) -> ReadFuture<'_, 'buffer, Self> {
43 ReadFuture {
44 source: self,
45 buffer,
46 }
47 }
48
49 fn read_vectored<'buffers>(
51 self: Pin<&mut Self>,
52 buffers: &'buffers mut [IoSliceMut<'buffers>],
53 ) -> ReadVectoredFuture<'_, 'buffers, Self> {
54 ReadVectoredFuture {
55 source: self,
56 buffers,
57 }
58 }
59
60 fn read_vectored_bounded<'bufs>(
62 self: Pin<&mut Self>,
63 bufs: &'bufs mut [IoSliceMut<'bufs>],
64 limit: u64,
65 ) -> ReadVectoredBoundedFuture<'_, 'bufs, Self> {
66 ReadVectoredBoundedFuture {
67 source: self,
68 bufs,
69 limit,
70 }
71 }
72
73 fn poll_read_vectored_bounded(
76 self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 bufs: &mut [IoSliceMut<'_>],
79 limit: u64,
80 ) -> Poll<Result<usize>> {
81 if limit == 0 {
82 Ok(0).into()
83 } else {
84 let limit = std::cmp::min(limit, usize::MAX as u64) as usize;
85 let first_buffer = &mut bufs[0];
86 if first_buffer.len() >= limit {
87 self.poll_read(cx, &mut first_buffer[..limit])
90 } else {
91 let buf_count: usize = bufs
96 .iter()
97 .scan(0_usize, |size_so_far, elt| {
98 *size_so_far += elt.len();
99 Some(*size_so_far > limit)
100 })
101 .enumerate()
102 .find(|elt| elt.1)
103 .unwrap_or((bufs.len(), false))
104 .0;
105 self.poll_read_vectored(cx, &mut bufs[..buf_count])
106 }
107 }
108 }
109}
110
111impl<R: AsyncRead + ?Sized> AsyncReadExt for R {}
112
113pub trait AsyncWriteExt: AsyncWrite {
115 fn write_all<'a>(self: Pin<&'a mut Self>, data: &'a [u8]) -> WriteAllFuture<'a, Self> {
120 WriteAllFuture { sink: self, data }
121 }
122}
123
124impl<W: AsyncWrite + ?Sized> AsyncWriteExt for W {}
125
126#[derive(Debug)]
129pub struct ReadBufFuture<
130 'source,
131 Source: AsyncBufRead + ?Sized,
132 CallbackReturn,
133 Callback: FnOnce(&[u8]) -> (usize, CallbackReturn) + Unpin,
134> {
135 source: Pin<&'source mut Source>,
136 callback: Option<Callback>,
137}
138
139impl<
140 Source: AsyncBufRead + ?Sized,
141 CallbackReturn,
142 Callback: FnOnce(&[u8]) -> (usize, CallbackReturn) + Unpin,
143 > Future for ReadBufFuture<'_, Source, CallbackReturn, Callback>
144{
145 type Output = Result<CallbackReturn>;
146
147 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
148 let this = self.get_mut();
149 let data = ready!(this.source.as_mut().poll_fill_buf(cx))?;
150 let (consumed, ret) = (this.callback.take().unwrap())(data);
151 this.source.as_mut().consume(consumed);
152 Ok(ret).into()
153 }
154}
155
156#[derive(Debug)]
158pub struct ReadFuture<'source, 'buffer, Source: AsyncRead + ?Sized> {
159 source: Pin<&'source mut Source>,
160 buffer: &'buffer mut [u8],
161}
162
163impl<Source: AsyncRead + ?Sized> Future for ReadFuture<'_, '_, Source> {
164 type Output = Result<usize>;
165
166 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167 let this = self.get_mut();
168 this.source.as_mut().poll_read(cx, this.buffer)
169 }
170}
171
172#[derive(Debug)]
174pub struct ReadVectoredFuture<'source, 'buffers, Source: AsyncRead + ?Sized> {
175 source: Pin<&'source mut Source>,
176 buffers: &'buffers mut [IoSliceMut<'buffers>],
177}
178
179impl<Source: AsyncRead + ?Sized> Future for ReadVectoredFuture<'_, '_, Source> {
180 type Output = Result<usize>;
181
182 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
183 let this = self.get_mut();
184 this.source.as_mut().poll_read_vectored(cx, this.buffers)
185 }
186}
187
188#[derive(Debug)]
191pub struct ReadVectoredBoundedFuture<'source, 'bufs, Source: AsyncRead + ?Sized> {
192 source: Pin<&'source mut Source>,
193 bufs: &'bufs mut [IoSliceMut<'bufs>],
194 limit: u64,
195}
196
197impl<Source: AsyncRead + ?Sized> Future for ReadVectoredBoundedFuture<'_, '_, Source> {
198 type Output = Result<usize>;
199
200 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
201 let this = self.get_mut();
202 this.source
203 .as_mut()
204 .poll_read_vectored_bounded(cx, this.bufs, this.limit)
205 }
206}
207
208#[derive(Debug)]
210pub struct WriteAllFuture<'a, T: AsyncWrite + ?Sized> {
211 sink: Pin<&'a mut T>,
212 data: &'a [u8],
213}
214
215impl<T: AsyncWrite + ?Sized> Future for WriteAllFuture<'_, T> {
216 type Output = Result<()>;
217
218 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219 while !self.data.is_empty() {
220 let data = self.data;
221 let bytes_written = ready!(self.sink.as_mut().poll_write(cx, data))?;
222 self.data = &self.data[bytes_written..];
223 }
224 Ok(()).into()
225 }
226}
227
/// Reads from `src` repeatedly until `buffer` is completely filled.
///
/// # Errors
///
/// Propagates any error from the underlying reads. Returns an error of kind
/// [`std::io::ErrorKind::UnexpectedEof`] if a read yields zero bytes before the buffer is full.
#[cfg(test)]
pub async fn read_all<'buffer, Source: AsyncRead + ?Sized>(
    mut src: Pin<&mut Source>,
    mut buffer: &'buffer mut [u8],
) -> Result<()> {
    loop {
        if buffer.is_empty() {
            break Ok(());
        }
        let filled = src.as_mut().read(buffer).await?;
        if filled == 0 {
            // The source ran dry before the caller's buffer was satisfied.
            break Err(std::io::ErrorKind::UnexpectedEof.into());
        }
        buffer = &mut buffer[filled..];
    }
}
244
#[cfg(test)]
mod test {
    use super::*;
    use futures_executor::block_on;
    use futures_io::AsyncWrite;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// A single `read` fills the buffer from the start of the source.
    #[test]
    fn test_read() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let mut buffer = [0u8; 4];
            let bytes_read = Pin::new(&mut src).read(&mut buffer[..]).await.unwrap();
            assert_eq!(bytes_read, 4);
            assert_eq!(&buffer, b"abcd");
        });
    }

    /// `read_vectored` fills the buffers in order.
    #[test]
    fn test_read_vectored() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let mut buf1 = [0u8; 4];
            let mut buf2 = [0u8; 2];
            let mut slices = [IoSliceMut::new(&mut buf1), IoSliceMut::new(&mut buf2)];
            let bytes_read = Pin::new(&mut src).read_vectored(&mut slices).await.unwrap();
            assert_eq!(bytes_read, 6);
            assert_eq!(&buf1, b"abcd");
            assert_eq!(&buf2, b"ef");
        });
    }

    /// `read_buf` passes the buffered data to the callback, consumes only the number of
    /// bytes the callback reports, and resolves to the callback's value.
    #[test]
    fn test_read_buf() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let seen = Pin::new(&mut src)
                .read_buf(|data| (3, data.to_vec()))
                .await
                .unwrap();
            assert_eq!(seen, b"abcdefgh");
            assert_eq!(src, b"defgh");
        });
    }

    /// A zero limit reads nothing and leaves the source untouched.
    #[test]
    fn test_poll_read_vectored_bounded_zero_limit() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let mut buf1 = [0_u8; 4];
            let mut slices = [IoSliceMut::new(&mut buf1[..])];
            let bytes_read = Pin::new(&mut src)
                .read_vectored_bounded(&mut slices, 0)
                .await
                .unwrap();
            assert_eq!(bytes_read, 0);
            assert_eq!(&buf1, b"\0\0\0\0");
            assert_eq!(src, b"abcdefgh");
        });
    }

    /// A limit smaller than the first buffer reads only `limit` bytes into that buffer.
    #[test]
    fn test_poll_read_vectored_bounded_one_partial() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let mut buf1 = [0_u8; 4];
            let mut buf2 = [0_u8; 4];
            let mut slices = [
                IoSliceMut::new(&mut buf1[..]),
                IoSliceMut::new(&mut buf2[..]),
            ];
            let bytes_read = Pin::new(&mut src)
                .read_vectored_bounded(&mut slices, 3)
                .await
                .unwrap();
            assert_eq!(bytes_read, 3);
            assert_eq!(&buf1, b"abc\0");
            assert_eq!(&buf2, b"\0\0\0\0");
        });
    }

    /// A limit that would straddle the second buffer reads only the first buffer.
    #[test]
    fn test_poll_read_vectored_bounded_one_full() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefgh"[..];
            let mut buf1 = [0_u8; 4];
            let mut buf2 = [0_u8; 4];
            let mut slices = [
                IoSliceMut::new(&mut buf1[..]),
                IoSliceMut::new(&mut buf2[..]),
            ];
            let bytes_read = Pin::new(&mut src)
                .read_vectored_bounded(&mut slices, 5)
                .await
                .unwrap();
            assert_eq!(bytes_read, 4);
            assert_eq!(&buf1, b"abcd");
            assert_eq!(&buf2, b"\0\0\0\0");
        });
    }

    /// A limit larger than all buffers combined fills every buffer.
    #[test]
    fn test_poll_read_vectored_bounded_two_full() {
        block_on(async {
            let mut src: &[u8] = &b"abcdefghij"[..];
            let mut buf1 = [0_u8; 4];
            let mut buf2 = [0_u8; 4];
            let mut slices = [
                IoSliceMut::new(&mut buf1[..]),
                IoSliceMut::new(&mut buf2[..]),
            ];
            let bytes_read = Pin::new(&mut src)
                .read_vectored_bounded(&mut slices, 10)
                .await
                .unwrap();
            assert_eq!(bytes_read, 8);
            assert_eq!(&buf1, b"abcd");
            assert_eq!(&buf2, b"efgh");
        });
    }

    /// `write_all` completes in one write when the sink accepts everything at once.
    #[test]
    fn test_write_all_fast() {
        struct Test {
            v: Vec<u8>,
        }
        impl AsyncWrite for Test {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                data: &[u8],
            ) -> Poll<Result<usize>> {
                self.v.extend_from_slice(data);
                Ok(data.len()).into()
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }

            fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }
        }
        let mut t = Test { v: vec![] };
        block_on(async { Pin::new(&mut t).write_all(b"abcdefgh").await }).unwrap();
        assert_eq!(t.v.as_slice(), b"abcdefgh");
    }

    /// `write_all` keeps writing when the sink accepts only one byte per call.
    #[test]
    fn test_write_all_slow() {
        struct Test {
            v: Vec<u8>,
        }
        impl AsyncWrite for Test {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                data: &[u8],
            ) -> Poll<Result<usize>> {
                match data.first() {
                    None => Ok(0).into(),
                    Some(&b) => {
                        self.v.push(b);
                        Ok(1).into()
                    }
                }
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }

            fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }
        }
        let mut t = Test { v: vec![] };
        block_on(async { Pin::new(&mut t).write_all(b"abcdefgh").await }).unwrap();
        assert_eq!(t.v.as_slice(), b"abcdefgh");
    }

    /// `write_all` stops at the first error and propagates it unchanged.
    #[test]
    fn test_write_all_error() {
        struct Test {
            already_called: bool,
        }
        impl AsyncWrite for Test {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _data: &[u8],
            ) -> Poll<Result<usize>> {
                assert!(!self.already_called);
                self.already_called = true;
                Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "Test error message",
                ))
                .into()
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }

            fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
                panic!("Should not be called");
            }
        }
        let mut t = Test {
            already_called: false,
        };
        let e = block_on(async { Pin::new(&mut t).write_all(b"abcdefgh").await }).unwrap_err();
        assert_eq!(e.kind(), std::io::ErrorKind::Other);
        assert_eq!(format!("{}", e), "Test error message");
    }
}