Skip to main content

ax_io/utils/
take.rs

1use core::{
2    cmp,
3    io::{BorrowedBuf, BorrowedCursor},
4};
5
6use crate::{BufRead, Error, IoBuf, Read, Result, Seek, SeekFrom};
7
/// Reader adapter that caps how many bytes can be read from an underlying reader.
///
/// Values of this type are normally obtained by calling [`take`] on a reader; the
/// documentation of [`take`] has the details.
///
/// See [`std::io::Take`] for the standard-library counterpart.
///
/// [`take`]: Read::take
#[derive(Debug)]
pub struct Take<T> {
    /// The wrapped reader.
    inner: T,
    /// Size of the limited window: the limit passed at construction or to the most
    /// recent `set_limit` call. `len - limit` is the current position in the window.
    len: u64,
    /// Number of bytes that may still be read before EOF is reported.
    limit: u64,
}
22
23impl<T> Take<T> {
24    pub(crate) fn new(inner: T, limit: u64) -> Self {
25        Take {
26            inner,
27            len: limit,
28            limit,
29        }
30    }
31
32    /// Returns the number of bytes that can be read before this instance will
33    /// return EOF.
34    pub fn limit(&self) -> u64 {
35        self.limit
36    }
37
38    /// Returns the number of bytes read so far.
39    pub fn position(&self) -> u64 {
40        self.len - self.limit
41    }
42
43    /// Sets the number of bytes that can be read before this instance will
44    /// return EOF. This is the same as constructing a new `Take` instance, so
45    /// the amount of bytes read and the previous limit value don't matter when
46    /// calling this method.
47    pub fn set_limit(&mut self, limit: u64) {
48        self.len = limit;
49        self.limit = limit;
50    }
51
52    /// Consumes the `Take`, returning the wrapped reader.
53    pub fn into_inner(self) -> T {
54        self.inner
55    }
56
57    /// Gets a reference to the underlying reader.
58    ///
59    /// Care should be taken to avoid modifying the internal I/O state of the
60    /// underlying reader as doing so may corrupt the internal limit of this
61    /// `Take`.
62    pub fn get_ref(&self) -> &T {
63        &self.inner
64    }
65
66    /// Gets a mutable reference to the underlying reader.
67    ///
68    /// Care should be taken to avoid modifying the internal I/O state of the
69    /// underlying reader as doing so may corrupt the internal limit of this
70    /// `Take`.
71    pub fn get_mut(&mut self) -> &mut T {
72        &mut self.inner
73    }
74}
75
impl<T: Read> Read for Take<T> {
    /// Reads into `buf`, never returning more bytes than the remaining limit.
    ///
    /// Returns `Ok(0)` without touching the inner reader once the limit is exhausted.
    /// Otherwise `buf` is truncated to at most `self.limit` bytes before being handed to
    /// the inner reader, and the limit is reduced by the number of bytes read.
    ///
    /// # Panics
    ///
    /// Panics if the inner reader reports reading more bytes than the remaining limit.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        // Don't call into inner reader at all at EOF because it may still block
        if self.limit == 0 {
            return Ok(0);
        }

        // The minimum is bounded by `buf.len()`, so casting back to `usize` cannot truncate.
        let max = cmp::min(buf.len() as u64, self.limit) as usize;
        let n = self.inner.read(&mut buf[..max])?;
        // Guards against a misbehaving inner reader; otherwise the subtraction below
        // could underflow.
        assert!(n as u64 <= self.limit, "number of read bytes exceeds limit");
        self.limit -= n as u64;
        Ok(n)
    }

    /// Reads into the cursor `buf`, filling at most the remaining limit.
    ///
    /// There are two cases.
    ///
    /// - If the limit is smaller than `buf`'s capacity, the inner reader is given a
    ///   temporary `BorrowedBuf` over only the first `limit` bytes of `buf.as_mut()`.
    ///   The bytes it fills are then committed to `buf` with `advance`.
    /// - Otherwise `buf` is passed straight through, and the limit is reduced by how much
    ///   `buf.written()` grew.
    ///
    /// In both cases an error from the inner reader is returned only after the bytes it
    /// did fill have been subtracted from the limit.
    fn read_buf(&mut self, mut buf: BorrowedCursor<'_>) -> Result<()> {
        // Don't call into inner reader at all at EOF because it may still block
        if self.limit == 0 {
            return Ok(());
        }

        if self.limit < buf.capacity() as u64 {
            // The condition above guarantees that `self.limit` fits in `usize`.
            let limit = self.limit as usize;

            // Sampled before `buf` is mutably borrowed through `ibuf` below.
            let is_init = buf.is_init();

            // SAFETY: no uninit data is written to ibuf
            let ibuf = unsafe { &mut buf.as_mut()[..limit] };

            let mut sliced_buf: BorrowedBuf<'_> = ibuf.into();

            if is_init {
                // SAFETY: `sliced_buf` views a prefix of the slice returned by
                // `buf.as_mut()`. If `buf` reported itself initialized (`is_init`), that
                // prefix is initialized too.
                unsafe { sliced_buf.set_init() };
            }

            let mut cursor = sliced_buf.unfilled();
            let result = self.inner.read_buf(cursor.reborrow());

            let should_init = cursor.is_init();
            let filled = sliced_buf.len();

            // cursor / sliced_buf / ibuf must drop here: they borrow `buf`, which is
            // used again below.

            // Avoid accidentally quadratic behaviour by initializing the whole
            // cursor if only part of it was initialized.
            // NOTE(review): when `is_init` was already true, this zero-fills bytes that were
            // already initialized. `should_init && !is_init` may be sufficient; confirm
            // against the `BorrowedBuf` API before changing.
            if should_init {
                // SAFETY: no uninit data is written
                let uninit = unsafe { &mut buf.as_mut()[limit..] };
                uninit.write_filled(0);
                // SAFETY: all bytes that were not initialized by `T::read_buf`
                // have just been written to.
                unsafe { buf.set_init() };
            }

            unsafe {
                // SAFETY: filled bytes have been filled and therefore initialized
                buf.advance(filled);
            }

            self.limit -= filled as u64;

            result
        } else {
            // The limit is at least the capacity, so the inner reader cannot overrun it.
            let written = buf.written();
            let result = self.inner.read_buf(buf.reborrow());
            self.limit -= (buf.written() - written) as u64;
            result
        }
    }
}
147
148impl<T: BufRead> BufRead for Take<T> {
149    fn fill_buf(&mut self) -> Result<&[u8]> {
150        // Don't call into inner reader at all at EOF because it may still block
151        if self.limit == 0 {
152            return Ok(&[]);
153        }
154
155        let buf = self.inner.fill_buf()?;
156        let cap = cmp::min(buf.len() as u64, self.limit) as usize;
157        Ok(&buf[..cap])
158    }
159
160    fn consume(&mut self, amt: usize) {
161        // Don't let callers reset the limit by passing an overlarge value
162        let amt = cmp::min(amt as u64, self.limit) as usize;
163        self.limit -= amt as u64;
164        self.inner.consume(amt);
165    }
166}
167
168impl<T: Seek> Seek for Take<T> {
169    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
170        let new_position = match pos {
171            SeekFrom::Start(v) => Some(v),
172            SeekFrom::Current(v) => self.position().checked_add_signed(v),
173            SeekFrom::End(v) => self.len.checked_add_signed(v),
174        };
175        let new_position = match new_position {
176            Some(v) if v <= self.len => v,
177            _ => return Err(Error::InvalidInput),
178        };
179        while new_position != self.position() {
180            if let Some(offset) = new_position.checked_signed_diff(self.position()) {
181                self.inner.seek_relative(offset)?;
182                self.limit = self.limit.wrapping_sub(offset as u64);
183                break;
184            }
185            let offset = if new_position > self.position() {
186                i64::MAX
187            } else {
188                i64::MIN
189            };
190            self.inner.seek_relative(offset)?;
191            self.limit = self.limit.wrapping_sub(offset as u64);
192        }
193        Ok(new_position)
194    }
195
196    fn stream_len(&mut self) -> Result<u64> {
197        Ok(self.len)
198    }
199
200    fn stream_position(&mut self) -> Result<u64> {
201        Ok(self.position())
202    }
203
204    fn seek_relative(&mut self, offset: i64) -> Result<()> {
205        if self
206            .position()
207            .checked_add_signed(offset)
208            .is_none_or(|p| p > self.len)
209        {
210            return Err(Error::InvalidInput);
211        }
212        self.inner.seek_relative(offset)?;
213        self.limit = self.limit.wrapping_sub(offset as u64);
214        Ok(())
215    }
216}
217
218impl<T: IoBuf> IoBuf for Take<T> {
219    fn remaining(&self) -> usize {
220        cmp::min(self.inner.remaining(), self.limit as usize)
221    }
222}