futures_util/io/
write_all_vectored.rs

1use futures_core::future::Future;
2use futures_core::ready;
3use futures_core::task::{Context, Poll};
4use futures_io::AsyncWrite;
5use futures_io::IoSlice;
6use std::io;
7use std::pin::Pin;
8
/// Future returned by the
/// [`write_all_vectored`](super::AsyncWriteExt::write_all_vectored) method.
///
/// It borrows both the destination writer and the list of buffers for its
/// whole lifetime; polling it to completion writes every buffer to the writer.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct WriteAllVectored<'a, W: ?Sized + Unpin> {
    /// Destination that the buffers are written into.
    writer: &'a mut W,
    /// Buffers (or the remaining parts of them) that still have to be written.
    bufs: &'a mut [IoSlice<'a>],
}
17
18impl<W: ?Sized + Unpin> Unpin for WriteAllVectored<'_, W> {}
19
20impl<'a, W: AsyncWrite + ?Sized + Unpin> WriteAllVectored<'a, W> {
21    pub(super) fn new(writer: &'a mut W, mut bufs: &'a mut [IoSlice<'a>]) -> Self {
22        IoSlice::advance_slices(&mut bufs, 0);
23        Self { writer, bufs }
24    }
25}
26
27impl<W: AsyncWrite + ?Sized + Unpin> Future for WriteAllVectored<'_, W> {
28    type Output = io::Result<()>;
29
30    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
31        let this = &mut *self;
32        while !this.bufs.is_empty() {
33            let n = ready!(Pin::new(&mut this.writer).poll_write_vectored(cx, this.bufs))?;
34            if n == 0 {
35                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
36            } else {
37                IoSlice::advance_slices(&mut this.bufs, n);
38            }
39        }
40
41        Poll::Ready(Ok(()))
42    }
43}
44
45#[cfg(test)]
46mod tests {
47    use std::cmp::min;
48    use std::future::Future;
49    use std::io;
50    use std::pin::Pin;
51    use std::task::{Context, Poll};
52    use std::vec;
53    use std::vec::Vec;
54
55    use crate::io::{AsyncWrite, AsyncWriteExt, IoSlice};
56    use crate::task::noop_waker;
57
58    /// Create a new writer that reads from at most `n_bufs` and reads
59    /// `per_call` bytes (in total) per call to write.
60    fn test_writer(n_bufs: usize, per_call: usize) -> TestWriter {
61        TestWriter { n_bufs, per_call, written: Vec::new() }
62    }
63
64    // TODO: maybe move this the future-test crate?
65    struct TestWriter {
66        n_bufs: usize,
67        per_call: usize,
68        written: Vec<u8>,
69    }
70
71    impl AsyncWrite for TestWriter {
72        fn poll_write(
73            self: Pin<&mut Self>,
74            cx: &mut Context<'_>,
75            buf: &[u8],
76        ) -> Poll<io::Result<usize>> {
77            self.poll_write_vectored(cx, &[IoSlice::new(buf)])
78        }
79
80        fn poll_write_vectored(
81            mut self: Pin<&mut Self>,
82            _cx: &mut Context<'_>,
83            bufs: &[IoSlice<'_>],
84        ) -> Poll<io::Result<usize>> {
85            let mut left = self.per_call;
86            let mut written = 0;
87            for buf in bufs.iter().take(self.n_bufs) {
88                let n = min(left, buf.len());
89                self.written.extend_from_slice(&buf[0..n]);
90                left -= n;
91                written += n;
92            }
93            Poll::Ready(Ok(written))
94        }
95
96        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
97            Poll::Ready(Ok(()))
98        }
99
100        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101            Poll::Ready(Ok(()))
102        }
103    }
104
105    // TODO: maybe move this the future-test crate?
106    macro_rules! assert_poll_ok {
107        ($e:expr, $expected:expr) => {
108            let expected = $expected;
109            match $e {
110                Poll::Ready(Ok(ok)) if ok == expected => {}
111                got => {
112                    panic!("unexpected result, got: {:?}, wanted: Ready(Ok({:?}))", got, expected)
113                }
114            }
115        };
116    }
117
118    #[test]
119    fn test_writer_read_from_one_buf() {
120        let waker = noop_waker();
121        let mut cx = Context::from_waker(&waker);
122
123        let mut dst = test_writer(1, 2);
124        let mut dst = Pin::new(&mut dst);
125
126        assert_poll_ok!(dst.as_mut().poll_write(&mut cx, &[]), 0);
127        assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, &[]), 0);
128
129        // Read at most 2 bytes.
130        assert_poll_ok!(dst.as_mut().poll_write(&mut cx, &[1, 1, 1]), 2);
131        let bufs = &[IoSlice::new(&[2, 2, 2])];
132        assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, bufs), 2);
133
134        // Only read from first buf.
135        let bufs = &[IoSlice::new(&[3]), IoSlice::new(&[4, 4])];
136        assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, bufs), 1);
137
138        assert_eq!(dst.written, &[1, 1, 2, 2, 3]);
139    }
140
141    #[test]
142    fn test_writer_read_from_multiple_bufs() {
143        let waker = noop_waker();
144        let mut cx = Context::from_waker(&waker);
145
146        let mut dst = test_writer(3, 3);
147        let mut dst = Pin::new(&mut dst);
148
149        // Read at most 3 bytes from two buffers.
150        let bufs = &[IoSlice::new(&[1]), IoSlice::new(&[2, 2, 2])];
151        assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, bufs), 3);
152
153        // Read at most 3 bytes from three buffers.
154        let bufs = &[IoSlice::new(&[3]), IoSlice::new(&[4]), IoSlice::new(&[5, 5])];
155        assert_poll_ok!(dst.as_mut().poll_write_vectored(&mut cx, bufs), 3);
156
157        assert_eq!(dst.written, &[1, 2, 2, 3, 4, 5]);
158    }
159
160    #[test]
161    fn test_write_all_vectored() {
162        let waker = noop_waker();
163        let mut cx = Context::from_waker(&waker);
164
165        #[rustfmt::skip] // Becomes unreadable otherwise.
166        let tests: Vec<(_, &'static [u8])> = vec![
167            (vec![], &[]),
168            (vec![IoSlice::new(&[]), IoSlice::new(&[])], &[]),
169            (vec![IoSlice::new(&[1])], &[1]),
170            (vec![IoSlice::new(&[1, 2])], &[1, 2]),
171            (vec![IoSlice::new(&[1, 2, 3])], &[1, 2, 3]),
172            (vec![IoSlice::new(&[1, 2, 3, 4])], &[1, 2, 3, 4]),
173            (vec![IoSlice::new(&[1, 2, 3, 4, 5])], &[1, 2, 3, 4, 5]),
174            (vec![IoSlice::new(&[1]), IoSlice::new(&[2])], &[1, 2]),
175            (vec![IoSlice::new(&[1, 1]), IoSlice::new(&[2, 2])], &[1, 1, 2, 2]),
176            (vec![IoSlice::new(&[1, 1, 1]), IoSlice::new(&[2, 2, 2])], &[1, 1, 1, 2, 2, 2]),
177            (vec![IoSlice::new(&[1, 1, 1, 1]), IoSlice::new(&[2, 2, 2, 2])], &[1, 1, 1, 1, 2, 2, 2, 2]),
178            (vec![IoSlice::new(&[1]), IoSlice::new(&[2]), IoSlice::new(&[3])], &[1, 2, 3]),
179            (vec![IoSlice::new(&[1, 1]), IoSlice::new(&[2, 2]), IoSlice::new(&[3, 3])], &[1, 1, 2, 2, 3, 3]),
180            (vec![IoSlice::new(&[1, 1, 1]), IoSlice::new(&[2, 2, 2]), IoSlice::new(&[3, 3, 3])], &[1, 1, 1, 2, 2, 2, 3, 3, 3]),
181        ];
182
183        for (mut input, wanted) in tests {
184            let mut dst = test_writer(2, 2);
185            {
186                let mut future = dst.write_all_vectored(&mut *input);
187                match Pin::new(&mut future).poll(&mut cx) {
188                    Poll::Ready(Ok(())) => {}
189                    other => panic!("unexpected result polling future: {:?}", other),
190                }
191            }
192            assert_eq!(&*dst.written, &*wanted);
193        }
194    }
195}