1use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll, ready},
7};
8
9use crate::{
10 arc_io_result::{ArcIoResult, ArcIoResultExt},
11 fuse_buf_reader::FuseBufReader,
12};
13use futures::{AsyncBufRead, AsyncWrite};
14use pin_project::pin_project;
15
16pub fn copy_buf<R, W>(reader: R, writer: W) -> CopyBuf<R, W>
33where
34 R: AsyncBufRead,
35 W: AsyncWrite,
36{
37 CopyBuf {
38 reader: FuseBufReader::new(reader),
39 writer,
40 copied: 0,
41 }
42}
43
/// Future returned by [`copy_buf`]: copies data from a reader into a writer.
///
/// Resolves to the total number of bytes copied, or to an I/O error.
#[derive(Debug)]
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct CopyBuf<R, W> {
    /// The source of the data, wrapped in a [`FuseBufReader`].
    ///
    /// NOTE(review): presumably the fuse wrapper keeps reporting the same
    /// EOF/error once one has been seen, so re-polling after an early
    /// `Pending` is safe — confirm against `FuseBufReader`.
    #[pin]
    reader: FuseBufReader<R>,

    /// The destination for the copied data.
    #[pin]
    writer: W,

    /// Running total of bytes successfully written to `writer` so far.
    copied: u64,
}
62
63impl<R, W> CopyBuf<R, W> {
64 pub fn into_inner(self) -> (R, W) {
66 (self.reader.into_inner(), self.writer)
67 }
68}
69
70impl<R, W> Future for CopyBuf<R, W>
71where
72 R: AsyncBufRead,
73 W: AsyncWrite,
74{
75 type Output = std::io::Result<u64>;
76
77 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
78 let this = self.project();
79 let () = ready!(poll_copy_r_to_w(
80 cx,
81 this.reader,
82 this.writer,
83 this.copied,
84 false
85 ))
86 .io_result()?;
87 Poll::Ready(Ok(*this.copied))
88 }
89}
90
91pub(crate) fn poll_copy_r_to_w<R, W>(
101 cx: &mut Context<'_>,
102 mut reader: Pin<&mut FuseBufReader<R>>,
103 mut writer: Pin<&mut W>,
104 total_copied: &mut u64,
105 flush_on_err: bool,
106) -> Poll<ArcIoResult<()>>
107where
108 R: AsyncBufRead,
109 W: AsyncWrite,
110{
111 loop {
121 match reader.as_mut().poll_fill_buf(cx) {
122 Poll::Pending => {
123 let () = ready!(writer.as_mut().poll_flush(cx))?;
126 return Poll::Pending;
127 }
128 Poll::Ready(Err(e)) => {
129 if flush_on_err {
131 let _ignore_flush_error = ready!(writer.as_mut().poll_flush(cx));
132 }
133 return Poll::Ready(Err(e));
134 }
135 Poll::Ready(Ok(&[])) => {
136 let () = ready!(writer.as_mut().poll_flush(cx))?;
139 return Poll::Ready(Ok(()));
140 }
141 Poll::Ready(Ok(data)) => {
142 let n_written: usize = ready!(writer.as_mut().poll_write(cx, data))?;
145 reader.as_mut().consume(n_written);
147 *total_copied += n_written as u64;
148 }
149 }
150 }
151}
152
// Tests for `copy_buf`, run against in-memory cursors and mock streams.
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::test::{ErrorRW, PausedRead};

    use futures::{
        AsyncReadExt as _,
        future::poll_fn,
        io::{BufReader, Cursor},
    };
    use std::io;
    use tor_rtcompat::SpawnExt as _;
    use tor_rtmock::{MockRuntime, io::stream_pair};

    /// Copy `data` from an in-memory cursor into a `Vec`.
    ///
    /// Checks both the reported byte count and the copied contents.
    async fn test_copy_cursor(data: &[u8]) {
        let mut out: Vec<u8> = Vec::new();
        let r = Cursor::new(data);
        let mut w = Cursor::new(&mut out);

        let n_copied = copy_buf(&mut BufReader::new(r), &mut w).await.unwrap();
        assert_eq!(n_copied, data.len() as u64);
        assert_eq!(&out[..], data);
    }

    /// Copy `data` through a mock stream pair using two concurrent copies.
    ///
    /// The data goes cursor -> `w1` (task1), and then `r2` -> output cursor
    /// (task2). Both copies must report `data.len()` bytes, and the output
    /// must equal the input.
    ///
    /// NOTE(review): task2 can only finish if `r2` sees EOF. Presumably that
    /// happens when task1 completes and its future (owning `w1`) is dropped —
    /// relies on `stream_pair` semantics.
    async fn test_copy_stream(rt: &MockRuntime, data: &[u8]) {
        let out: Vec<u8> = Vec::new();
        let r1 = Cursor::new(data.to_vec());
        let (w1, r2) = stream_pair();
        let mut w2 = Cursor::new(out);
        let r1 = BufReader::new(r1);
        let r2 = BufReader::new(r2);
        let task1 = rt.spawn_with_handle(copy_buf(r1, w1)).unwrap();
        let task2 = rt
            .spawn_with_handle(async move {
                let copy_result = copy_buf(r2, &mut w2).await;
                (copy_result, w2)
            })
            .unwrap();

        let copy_result_1 = task1.await;
        let (copy_result_2, output) = task2.await;

        assert_eq!(copy_result_1.unwrap(), data.len() as u64);
        assert_eq!(copy_result_2.unwrap(), data.len() as u64);
        assert_eq!(&output.into_inner()[..], data);
    }

    /// The reader yields `data` and then never reaches EOF.
    ///
    /// All of `data` must still arrive at the other end of the stream, and the
    /// copy task must still be pending afterwards.
    ///
    /// NOTE(review): `PausedRead` is presumably a reader that stays `Pending`
    /// forever. This test then exercises the flush-while-reader-is-pending
    /// path — confirm against `crate::test`.
    async fn test_copy_stream_paused(rt: &MockRuntime, data: &[u8]) {
        let n = data.len();
        let r1 = BufReader::new(Cursor::new(data.to_vec()).chain(PausedRead));
        let (w1, mut r2) = stream_pair();
        let mut task1 = rt.spawn_with_handle(copy_buf(r1, w1)).unwrap();
        let mut buf = vec![0_u8; n];
        r2.read_exact(&mut buf[..]).await.unwrap();
        assert_eq!(&buf[..], data);

        // Poll the task exactly once, capturing the result instead of awaiting it.
        let task1_status = poll_fn(|cx| Poll::Ready(Pin::new(&mut task1).poll(cx))).await;
        assert!(task1_status.is_pending());
    }

    /// The reader yields `data` and then fails with `ResourceBusy`.
    ///
    /// The first copy must report that error. The downstream copy must still
    /// receive all of `data` and report `data.len()` bytes.
    ///
    /// NOTE(review): `ErrorRW` is presumably a reader that always fails with the
    /// given error kind — confirm against `crate::test`.
    async fn test_copy_stream_error(rt: &MockRuntime, data: &[u8]) {
        let out: Vec<u8> = Vec::new();
        let r1 = Cursor::new(data.to_vec()).chain(ErrorRW(io::ErrorKind::ResourceBusy));
        let (w1, r2) = stream_pair();
        let mut w2 = Cursor::new(out);
        let r1 = BufReader::new(r1);
        let r2 = BufReader::new(r2);
        let task1 = rt.spawn_with_handle(copy_buf(r1, w1)).unwrap();
        let task2 = rt
            .spawn_with_handle(async move {
                let copy_result = copy_buf(r2, &mut w2).await;
                (copy_result, w2)
            })
            .unwrap();

        let copy_result_1 = task1.await;
        let (copy_result_2, output) = task2.await;

        assert_eq!(
            copy_result_1.unwrap_err().kind(),
            io::ErrorKind::ResourceBusy
        );
        assert_eq!(copy_result_2.unwrap(), data.len() as u64);
        assert_eq!(&output.into_inner()[..], data);
    }

    /// Run every copy scenario above on `data`, under each mock-runtime
    /// configuration provided by `MockRuntime::test_with_various`.
    fn test_copy(data: &[u8]) {
        MockRuntime::test_with_various(async |rt| {
            test_copy_cursor(data).await;
            test_copy_stream(&rt, data).await;
            test_copy_stream_paused(&rt, data).await;
            test_copy_stream_error(&rt, data).await;
        });
    }

    /// Edge case: empty input.
    #[test]
    fn copy_nothing() {
        test_copy(&[]);
    }

    /// A short input that fits in a single buffer.
    #[test]
    fn copy_small() {
        test_copy(b"hEllo world");
    }

    /// A 1.5 MB input, large enough to need many fill/write iterations.
    #[test]
    fn copy_huge() {
        let huge: Vec<u8> = (0..=77).cycle().take(1_500_000).collect();
        test_copy(&huge[..]);
    }
}