Skip to main content

ax_io/utils/
take.rs

1use core::{
2    cmp,
3    io::{BorrowedBuf, BorrowedCursor},
4};
5
6use crate::{BufRead, Error, IoBuf, Read, Result, Seek, SeekFrom};
7
/// Adapter that caps how many bytes may be pulled from a wrapped reader.
///
/// Instances are normally obtained through [`take`]; refer to its
/// documentation for usage details.
///
/// See [`std::io::Take`] for the standard-library counterpart.
///
/// [`take`]: Read::take
#[derive(Debug)]
pub struct Take<T> {
    /// The wrapped reader.
    inner: T,
    /// Size of the readable window: the limit given at construction or by
    /// the most recent `set_limit` call.
    len: u64,
    /// Bytes still readable before EOF is reported; `len - limit` is the
    /// current position inside the window.
    limit: u64,
}
22
23impl<T> Take<T> {
24    pub(crate) fn new(inner: T, limit: u64) -> Self {
25        Take {
26            inner,
27            len: limit,
28            limit,
29        }
30    }
31
32    /// Returns the number of bytes that can be read before this instance will
33    /// return EOF.
34    pub fn limit(&self) -> u64 {
35        self.limit
36    }
37
38    /// Returns the number of bytes read so far.
39    pub fn position(&self) -> u64 {
40        self.len - self.limit
41    }
42
43    /// Sets the number of bytes that can be read before this instance will
44    /// return EOF. This is the same as constructing a new `Take` instance, so
45    /// the amount of bytes read and the previous limit value don't matter when
46    /// calling this method.
47    pub fn set_limit(&mut self, limit: u64) {
48        self.len = limit;
49        self.limit = limit;
50    }
51
52    /// Consumes the `Take`, returning the wrapped reader.
53    pub fn into_inner(self) -> T {
54        self.inner
55    }
56
57    /// Gets a reference to the underlying reader.
58    ///
59    /// Care should be taken to avoid modifying the internal I/O state of the
60    /// underlying reader as doing so may corrupt the internal limit of this
61    /// `Take`.
62    pub fn get_ref(&self) -> &T {
63        &self.inner
64    }
65
66    /// Gets a mutable reference to the underlying reader.
67    ///
68    /// Care should be taken to avoid modifying the internal I/O state of the
69    /// underlying reader as doing so may corrupt the internal limit of this
70    /// `Take`.
71    pub fn get_mut(&mut self) -> &mut T {
72        &mut self.inner
73    }
74}
75
impl<T: Read> Read for Take<T> {
    /// Reads into `buf`, never asking the inner reader for more than the
    /// remaining limit, and deducts the number of bytes read from the limit.
    ///
    /// Returns `Ok(0)` without touching the inner reader once the limit is
    /// exhausted.
    ///
    /// # Panics
    ///
    /// Panics if the inner reader reports more bytes than the remaining limit.
    /// It was only handed a slice of at most that length.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        // Don't call into inner reader at all at EOF because it may still block
        if self.limit == 0 {
            return Ok(0);
        }

        // The minimum is taken in u64 and is bounded by `buf.len()`, so the
        // cast back to usize cannot truncate.
        let max = cmp::min(buf.len() as u64, self.limit) as usize;
        let n = self.inner.read(&mut buf[..max])?;
        // Catches an inner reader that over-reports how much it wrote.
        assert!(n as u64 <= self.limit, "number of read bytes exceeds limit");
        self.limit -= n as u64;
        Ok(n)
    }

    /// Reads into the cursor `buf`, exposing at most the remaining limit to
    /// the inner reader, and deducts the bytes written from the limit.
    ///
    /// Returns `Ok(())` without touching the inner reader once the limit is
    /// exhausted. Bytes written before the inner reader returns an error are
    /// still deducted, and that error is then returned.
    fn read_buf(&mut self, mut buf: BorrowedCursor<'_>) -> Result<()> {
        // Don't call into inner reader at all at EOF because it may still block
        if self.limit == 0 {
            return Ok(());
        }

        if self.limit < buf.capacity() as u64 {
            // The cursor has more room than we may fill. Hand the inner reader
            // a `limit`-byte window over the cursor's unfilled space instead.
            // The condition above guarantees that `self.limit` fits in `usize`.
            let limit = self.limit as usize;

            // NOTE(review): `borrowedbuf_init` presumably selects the
            // `BorrowedBuf` API variant that still tracks initialized bytes;
            // confirm against the build configuration.
            #[cfg(borrowedbuf_init)]
            let extra_init = cmp::min(limit, buf.init_mut().len());

            // SAFETY: no uninit data is written to ibuf
            let ibuf = unsafe { &mut buf.as_mut()[..limit] };

            let mut sliced_buf: BorrowedBuf<'_> = ibuf.into();

            #[cfg(borrowedbuf_init)]
            // SAFETY: extra_init bytes of ibuf are known to be initialized
            unsafe {
                sliced_buf.set_init(extra_init);
            }

            let mut cursor = sliced_buf.unfilled();
            let result = self.inner.read_buf(cursor.reborrow());

            #[cfg(borrowedbuf_init)]
            let new_init = cursor.init_mut().len();
            let filled = sliced_buf.len();

            // cursor / sliced_buf / ibuf must drop here

            // Carry the progress made inside the window back to the caller's
            // cursor.
            #[cfg(borrowedbuf_init)]
            unsafe {
                // SAFETY: filled bytes have been filled and therefore initialized
                buf.advance_unchecked(filled);
                // SAFETY: new_init bytes of buf's unfilled buffer have been initialized
                buf.set_init(new_init);
            }
            #[cfg(not(borrowedbuf_init))]
            // SAFETY: filled bytes have been filled and therefore initialized
            unsafe {
                buf.advance(filled);
            }

            // Deducted before `result` is inspected, so bytes written ahead of
            // an error are still accounted for.
            self.limit -= filled as u64;

            result
        } else {
            // The cursor's capacity does not exceed the limit, so the inner
            // reader may write into it directly. Progress is measured through
            // the cursor's written-byte count.
            let written = buf.written();
            let result = self.inner.read_buf(buf.reborrow());
            self.limit -= (buf.written() - written) as u64;
            result
        }
    }
}
147
148impl<T: BufRead> BufRead for Take<T> {
149    fn fill_buf(&mut self) -> Result<&[u8]> {
150        // Don't call into inner reader at all at EOF because it may still block
151        if self.limit == 0 {
152            return Ok(&[]);
153        }
154
155        let buf = self.inner.fill_buf()?;
156        let cap = cmp::min(buf.len() as u64, self.limit) as usize;
157        Ok(&buf[..cap])
158    }
159
160    fn consume(&mut self, amt: usize) {
161        // Don't let callers reset the limit by passing an overlarge value
162        let amt = cmp::min(amt as u64, self.limit) as usize;
163        self.limit -= amt as u64;
164        self.inner.consume(amt);
165    }
166}
167
168impl<T: Seek> Seek for Take<T> {
169    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
170        let new_position = match pos {
171            SeekFrom::Start(v) => Some(v),
172            SeekFrom::Current(v) => self.position().checked_add_signed(v),
173            SeekFrom::End(v) => self.len.checked_add_signed(v),
174        };
175        let new_position = match new_position {
176            Some(v) if v <= self.len => v,
177            _ => return Err(Error::InvalidInput),
178        };
179        while new_position != self.position() {
180            if let Some(offset) = new_position.checked_signed_diff(self.position()) {
181                self.inner.seek_relative(offset)?;
182                self.limit = self.limit.wrapping_sub(offset as u64);
183                break;
184            }
185            let offset = if new_position > self.position() {
186                i64::MAX
187            } else {
188                i64::MIN
189            };
190            self.inner.seek_relative(offset)?;
191            self.limit = self.limit.wrapping_sub(offset as u64);
192        }
193        Ok(new_position)
194    }
195
196    fn stream_len(&mut self) -> Result<u64> {
197        Ok(self.len)
198    }
199
200    fn stream_position(&mut self) -> Result<u64> {
201        Ok(self.position())
202    }
203
204    fn seek_relative(&mut self, offset: i64) -> Result<()> {
205        if self
206            .position()
207            .checked_add_signed(offset)
208            .is_none_or(|p| p > self.len)
209        {
210            return Err(Error::InvalidInput);
211        }
212        self.inner.seek_relative(offset)?;
213        self.limit = self.limit.wrapping_sub(offset as u64);
214        Ok(())
215    }
216}
217
218impl<T: IoBuf> IoBuf for Take<T> {
219    fn remaining(&self) -> usize {
220        cmp::min(self.inner.remaining(), self.limit as usize)
221    }
222}