Skip to main content

ax_io/utils/
copy.rs

1#[cfg(feature = "alloc")]
2use alloc::{collections::vec_deque::VecDeque, vec::Vec};
3use core::{io::BorrowedBuf, mem::MaybeUninit};
4
5use crate::{BufReader, BufWriter, DEFAULT_BUF_SIZE, Error, Read, Result, Write};
6
7/// Copies the entire contents of a reader into a writer.
8///
9/// This function will continuously read data from `reader` and then
10/// write it into `writer` in a streaming fashion until `reader`
11/// returns EOF.
12///
13/// On success, the total number of bytes that were copied from
14/// `reader` to `writer` is returned.
15///
16/// See [`std::io::copy`] for more details.
17pub fn copy<R, W>(reader: &mut R, writer: &mut W) -> Result<u64>
18where
19    R: Read + ?Sized,
20    W: Write + ?Sized,
21{
22    let read_buf = BufferedReaderSpec::buffer_size(reader);
23    let write_buf = BufferedWriterSpec::buffer_size(writer);
24
25    if read_buf >= DEFAULT_BUF_SIZE && read_buf >= write_buf {
26        return BufferedReaderSpec::copy_to(reader, writer);
27    }
28
29    BufferedWriterSpec::copy_from(writer, reader)
30}
31
32/// Fallback [`copy`] implementation using a stack-allocated buffer.
33pub fn stack_buffer_copy<R, W>(reader: &mut R, writer: &mut W) -> Result<u64>
34where
35    R: Read + ?Sized,
36    W: Write + ?Sized,
37{
38    let buf: &mut [_] = &mut [MaybeUninit::uninit(); DEFAULT_BUF_SIZE];
39    let mut buf: BorrowedBuf<'_> = buf.into();
40
41    let mut len = 0;
42
43    loop {
44        match reader.read_buf(buf.unfilled()) {
45            Ok(()) => {}
46            Err(e) if e.canonicalize() == Error::Interrupted => continue,
47            Err(e) => return Err(e),
48        };
49
50        if buf.filled().is_empty() {
51            break;
52        }
53
54        len += buf.filled().len() as u64;
55        writer.write_all(buf.filled())?;
56        buf.clear();
57    }
58
59    Ok(len)
60}
61
/// Specialization of the read-write loop that reuses the internal
/// buffer of a BufReader. If there's no buffer then the writer side
/// should be used instead.
trait BufferedReaderSpec {
    /// Reports the reader-side buffer size in bytes. `usize::MAX` marks
    /// sources that are themselves the buffer (slices, VecDeque); 0 marks
    /// "no reusable buffer" (the unspecialized default).
    fn buffer_size(&self) -> usize;

    /// Drains `self` into `to`, returning the number of bytes copied.
    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64>;
}
70
impl<T> BufferedReaderSpec for T
where
    Self: Read,
    T: ?Sized,
{
    #[inline]
    default fn buffer_size(&self) -> usize {
        // No reusable internal buffer, so `copy` will pick the
        // writer-side strategy instead.
        0
    }

    default fn copy_to(&mut self, _to: &mut (impl Write + ?Sized)) -> Result<u64> {
        // `copy` only routes here when buffer_size() >= DEFAULT_BUF_SIZE,
        // which this default impl (returning 0) never reports.
        unreachable!("only called from specializations")
    }
}
85
86impl BufferedReaderSpec for &[u8] {
87    fn buffer_size(&self) -> usize {
88        // prefer this specialization since the source "buffer" is all we'll ever need,
89        // even if it's small
90        usize::MAX
91    }
92
93    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64> {
94        let len = self.len();
95        to.write_all(self)?;
96        *self = &self[len..];
97        Ok(len as u64)
98    }
99}
100
#[cfg(feature = "alloc")]
impl BufferedReaderSpec for VecDeque<u8> {
    fn buffer_size(&self) -> usize {
        // prefer this specialization since the source "buffer" is all we'll ever need,
        // even if it's small
        usize::MAX
    }

    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64> {
        // A VecDeque's storage may wrap around, so it exposes up to two
        // contiguous slices; write both, then drop the contents.
        let total = self.len();
        let (head, tail) = self.as_slices();
        to.write_all(head)?;
        to.write_all(tail)?;
        self.clear();
        Ok(total as u64)
    }
}
118
119impl<I> BufferedReaderSpec for BufReader<I>
120where
121    Self: Read,
122    I: ?Sized,
123{
124    fn buffer_size(&self) -> usize {
125        self.capacity()
126    }
127
128    fn copy_to(&mut self, to: &mut (impl Write + ?Sized)) -> Result<u64> {
129        let mut len = 0;
130
131        loop {
132            // Hack: this relies on `impl Read for BufReader` always calling fill_buf
133            // if the buffer is empty, even for empty slices.
134            // It can't be called directly here since specialization prevents us
135            // from adding I: Read
136            match self.read(&mut []) {
137                Ok(_) => {}
138                Err(e) if e.canonicalize() == Error::Interrupted => continue,
139                Err(e) => return Err(e),
140            }
141            let buf = self.buffer();
142            if self.buffer().is_empty() {
143                return Ok(len);
144            }
145
146            // In case the writer side is a BufWriter then its write_all
147            // implements an optimization that passes through large
148            // buffers to the underlying writer. That code path is #[cold]
149            // but we're still avoiding redundant memcopies when doing
150            // a copy between buffered inputs and outputs.
151            to.write_all(buf)?;
152            len += buf.len() as u64;
153            self.discard_buffer();
154        }
155    }
156}
157
/// Specialization of the read-write loop that either uses a stack buffer
/// or reuses the internal buffer of a BufWriter
trait BufferedWriterSpec: Write {
    /// Reports the writer-side buffer size in bytes; 0 marks "no reusable
    /// buffer" (the unspecialized default).
    fn buffer_size(&self) -> usize;

    /// Drains `reader` into `self`, returning the number of bytes copied.
    fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64>;
}
165
impl<W: Write + ?Sized> BufferedWriterSpec for W {
    #[inline]
    default fn buffer_size(&self) -> usize {
        // No reusable internal buffer on the writer side.
        0
    }

    default fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64> {
        // Fall back to a fixed-size stack buffer as the intermediary.
        stack_buffer_copy(reader, self)
    }
}
176
#[cfg(feature = "alloc")]
impl BufferedWriterSpec for Vec<u8> {
    fn buffer_size(&self) -> usize {
        // Report at least the default size so this path wins over a small
        // reader-side buffer; spare capacity can be filled without growing.
        let spare = self.capacity() - self.len();
        core::cmp::max(DEFAULT_BUF_SIZE, spare)
    }

    fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64> {
        // read_to_end appends directly into the Vec, so no intermediate
        // buffer is needed at all.
        let appended = reader.read_to_end(self)?;
        Ok(u64::try_from(appended).expect("usize overflowed u64"))
    }
}
189
impl<I: Write + ?Sized> BufferedWriterSpec for BufWriter<I> {
    fn buffer_size(&self) -> usize {
        self.capacity()
    }

    // Reads directly into the BufWriter's spare capacity so the data is
    // copied only once (reader -> internal buffer -> underlying writer).
    fn copy_from<R: Read + ?Sized>(&mut self, reader: &mut R) -> Result<u64> {
        // A tiny internal buffer would force many small reads; the stack
        // buffer path is the better fallback in that case.
        if self.capacity() < DEFAULT_BUF_SIZE {
            return stack_buffer_copy(reader, self);
        }

        let mut len = 0;
        // Number of bytes beyond the filled region already known to be
        // initialized, carried across iterations to avoid re-zeroing.
        #[cfg(borrowedbuf_init)]
        let mut init = 0;

        loop {
            // View the BufWriter's spare capacity as a read target.
            let buf = self.buffer_mut();
            let mut read_buf: BorrowedBuf<'_> = buf.spare_capacity_mut().into();

            #[cfg(borrowedbuf_init)]
            unsafe {
                // SAFETY: init is either 0 or the init_len from the previous iteration.
                read_buf.set_init(init);
            }

            if read_buf.capacity() >= DEFAULT_BUF_SIZE {
                let mut cursor = read_buf.unfilled();
                match reader.read_buf(cursor.reborrow()) {
                    Ok(()) => {
                        let bytes_read = cursor.written();

                        // A successful zero-byte read means EOF.
                        if bytes_read == 0 {
                            return Ok(len);
                        }

                        #[cfg(borrowedbuf_init)]
                        {
                            // Bytes the reader initialized but did not
                            // fill remain init for the next iteration.
                            init = read_buf.init_len() - bytes_read;
                        }
                        len += bytes_read as u64;

                        // SAFETY: BorrowedBuf guarantees all of its filled bytes are init
                        unsafe { buf.set_len(buf.len() + bytes_read) };

                        // Read again if the buffer still has enough capacity, as
                        // BufWriter itself would do. This will occur if the
                        // reader returns short reads.
                    }
                    Err(ref e) if e.canonicalize() == Error::Interrupted => {}
                    Err(e) => return Err(e),
                }
            } else {
                #[cfg(borrowedbuf_init)]
                {
                    // All the bytes that were already in the buffer are initialized,
                    // treat them as such when the buffer is flushed.
                    init += buf.len();
                }

                // Not enough spare capacity for a large read: flush to the
                // underlying writer to make room, then retry.
                self.flush_buf()?;
            }
        }
    }
}