Skip to main content

futures_copy/
copy_buf.rs

1//! Functionality to copy from an `AsyncBufRead` to an `AsyncWrite`.
2
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll, ready},
7};
8
9use crate::{
10    arc_io_result::{ArcIoResult, ArcIoResultExt},
11    fuse_buf_reader::FuseBufReader,
12};
13use futures::{AsyncBufRead, AsyncWrite};
14use pin_project::pin_project;
15
16/// Return a future to copy all bytes interactively from `reader` to `writer`.
17///
18/// Unlike [`futures::io::copy`], this future makes sure that
19/// if `reader` pauses (returns `Pending`),
20/// all as-yet-received bytes are still flushed to `writer`.
21///
22/// The future continues copying data until either an error occurs
23/// (in which case it yields an error),
24/// or the reader returns an EOF
25/// (in which case it flushes any pending data,
26/// and returns the number of bytes copied).
27///
28/// # Limitations
29///
30/// See the crate-level documentation for
31/// [discussion of this function's limitations](crate#Limitations).
32pub fn copy_buf<R, W>(reader: R, writer: W) -> CopyBuf<R, W>
33where
34    R: AsyncBufRead,
35    W: AsyncWrite,
36{
37    CopyBuf {
38        reader: FuseBufReader::new(reader),
39        writer,
40        copied: 0,
41    }
42}
43
44/// A future returned by [`copy_buf`].
45#[derive(Debug)]
46#[pin_project]
47#[must_use = "futures do nothing unless you `.await` or poll them"]
48pub struct CopyBuf<R, W> {
49    /// The reader that we're taking data from.
50    ///
51    /// This is `FuseBufReader` to make our logic simpler.
52    #[pin]
53    reader: FuseBufReader<R>,
54
55    /// The writer that we're pushing
56    #[pin]
57    writer: W,
58
59    /// The number of bytes written to the writer so far.
60    copied: u64,
61}
62
63impl<R, W> CopyBuf<R, W> {
64    /// Consume this CopyBuf future, and return the underlying reader and writer.
65    pub fn into_inner(self) -> (R, W) {
66        (self.reader.into_inner(), self.writer)
67    }
68}
69
70impl<R, W> Future for CopyBuf<R, W>
71where
72    R: AsyncBufRead,
73    W: AsyncWrite,
74{
75    type Output = std::io::Result<u64>;
76
77    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
78        let this = self.project();
79        let () = ready!(poll_copy_r_to_w(
80            cx,
81            this.reader,
82            this.writer,
83            this.copied,
84            false
85        ))
86        .io_result()?;
87        Poll::Ready(Ok(*this.copied))
88    }
89}
90
91/// Core implementation function:
92/// Try to make progress copying bytes from `reader` to `writer`,
93/// and add the number of bytes written to `*total_copied`.
94///
95/// Returns `Ready` when an error has occurred,
96/// or when the reader has reached EOF and the writer has been flushed.
97/// Otherwise, returns `Pending`, and registers itself with `cx`.
98///
99/// (This is a separate function so we can use it to implement CopyBuf and CopyBufBidirectional.)
100pub(crate) fn poll_copy_r_to_w<R, W>(
101    cx: &mut Context<'_>,
102    mut reader: Pin<&mut FuseBufReader<R>>,
103    mut writer: Pin<&mut W>,
104    total_copied: &mut u64,
105    flush_on_err: bool,
106) -> Poll<ArcIoResult<()>>
107where
108    R: AsyncBufRead,
109    W: AsyncWrite,
110{
111    // TODO: Instead of using poll_fill_buf() unconditionally,
112    // it might be a neat idea to use the buffer by reference and just keep writing
113    // if the buffer is already "full enough".  The futures::io AsyncBufRead API
114    // doesn't really make that possible, though.  If specialization is ever stabilized,
115    // we could have a special implementation for BufReader, I guess.
116
117    // TODO: We assume that 'flush' is pretty fast when it has nothing to do.
118    // If that's wrong, we may need to remember whether we've written data but not flushed it.
119
120    loop {
121        match reader.as_mut().poll_fill_buf(cx) {
122            Poll::Pending => {
123                // If there's nothing to read now, we may need to make sure that the writer
124                // is flushed.
125                let () = ready!(writer.as_mut().poll_flush(cx))?;
126                return Poll::Pending;
127            }
128            Poll::Ready(Err(e)) => {
129                //  On error, flush, and propagate the error.
130                if flush_on_err {
131                    let _ignore_flush_error = ready!(writer.as_mut().poll_flush(cx));
132                }
133                return Poll::Ready(Err(e));
134            }
135            Poll::Ready(Ok(&[])) => {
136                // On EOF, we have already written all the data; make sure we flush it,
137                // and then return the amount that we copied.
138                let () = ready!(writer.as_mut().poll_flush(cx))?;
139                return Poll::Ready(Ok(()));
140            }
141            Poll::Ready(Ok(data)) => {
142                // If there is pending data, we copy as much as we can.
143                // We return "pending" if we can't write any.
144                let n_written: usize = ready!(writer.as_mut().poll_write(cx, data))?;
145                // Remove the data from the reader.
146                reader.as_mut().consume(n_written);
147                *total_copied += n_written as u64;
148            }
149        }
150    }
151}
152
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use crate::test::{ErrorRW, PausedRead};

    use futures::{
        AsyncReadExt as _,
        future::poll_fn,
        io::{BufReader, Cursor},
    };
    use std::io;
    use tor_rtcompat::SpawnExt as _;
    use tor_rtmock::{MockRuntime, io::stream_pair};

    /// Copy `data` from one in-memory cursor to another,
    /// and check the reported length and the output.
    async fn test_copy_cursor(data: &[u8]) {
        let mut sink: Vec<u8> = Vec::new();
        let mut source = BufReader::new(Cursor::new(data));
        let mut dest = Cursor::new(&mut sink);

        let total = copy_buf(&mut source, &mut dest).await.unwrap();
        assert_eq!(total, data.len() as u64);
        assert_eq!(&sink[..], data);
    }

    /// Copy `data` in two hops (cursor to stream pair, stream pair to cursor),
    /// with each hop running as its own task, and check both results and the output.
    async fn test_copy_stream(rt: &MockRuntime, data: &[u8]) {
        let expected_len = data.len() as u64;

        let source = BufReader::new(Cursor::new(data.to_vec()));
        let (pipe_w, pipe_r) = stream_pair();
        let pipe_r = BufReader::new(pipe_r);
        let mut sink = Cursor::new(Vec::<u8>::new());

        let first_hop = rt.spawn_with_handle(copy_buf(source, pipe_w)).unwrap();
        let second_hop = rt
            .spawn_with_handle(async move {
                let outcome = copy_buf(pipe_r, &mut sink).await;
                (outcome, sink)
            })
            .unwrap();

        let first_outcome = first_hop.await;
        let (second_outcome, sink) = second_hop.await;

        assert_eq!(first_outcome.unwrap(), expected_len);
        assert_eq!(second_outcome.unwrap(), expected_len);
        assert_eq!(&sink.into_inner()[..], data);
    }

    /// Copy `data` from a source that is chained with `PausedRead`,
    /// and check that the data arrives even though the copy task is still pending.
    async fn test_copy_stream_paused(rt: &MockRuntime, data: &[u8]) {
        let source = BufReader::new(Cursor::new(data.to_vec()).chain(PausedRead));
        let (pipe_w, mut pipe_r) = stream_pair();
        let mut copier = rt.spawn_with_handle(copy_buf(source, pipe_w)).unwrap();

        let mut received = vec![0_u8; data.len()];
        pipe_r.read_exact(&mut received[..]).await.unwrap();
        assert_eq!(&received[..], data);

        // Should not be able to ever end.
        let copier_status = poll_fn(|cx| Poll::Ready(Pin::new(&mut copier).poll(cx))).await;
        assert!(copier_status.is_pending());
    }

    /// Copy `data` in two hops from a source that is chained with an `ErrorRW`,
    /// and check that the first hop reports the error while the second delivers all of `data`.
    async fn test_copy_stream_error(rt: &MockRuntime, data: &[u8]) {
        let failure = io::ErrorKind::ResourceBusy;
        let expected_len = data.len() as u64;

        let source = BufReader::new(Cursor::new(data.to_vec()).chain(ErrorRW(failure)));
        let (pipe_w, pipe_r) = stream_pair();
        let pipe_r = BufReader::new(pipe_r);
        let mut sink = Cursor::new(Vec::<u8>::new());

        let first_hop = rt.spawn_with_handle(copy_buf(source, pipe_w)).unwrap();
        let second_hop = rt
            .spawn_with_handle(async move {
                let outcome = copy_buf(pipe_r, &mut sink).await;
                (outcome, sink)
            })
            .unwrap();

        let first_outcome = first_hop.await;
        let (second_outcome, sink) = second_hop.await;

        assert_eq!(first_outcome.unwrap_err().kind(), failure);
        assert_eq!(second_outcome.unwrap(), expected_len);
        assert_eq!(&sink.into_inner()[..], data);
    }

    /// Run every copy scenario above against `data`.
    fn test_copy(data: &[u8]) {
        MockRuntime::test_with_various(async |runtime| {
            test_copy_cursor(data).await;
            test_copy_stream(&runtime, data).await;
            test_copy_stream_paused(&runtime, data).await;
            test_copy_stream_error(&runtime, data).await;
        });
    }

    #[test]
    fn copy_nothing() {
        test_copy(&[]);
    }

    #[test]
    fn copy_small() {
        test_copy(b"hEllo world");
    }

    #[test]
    fn copy_huge() {
        let payload: Vec<u8> = (0..=77).cycle().take(1_500_000).collect();
        test_copy(&payload[..]);
    }
}