Skip to main content

binrw/io/
take_seek.rs

1//! Types for seekable reader adapters which limit the number of bytes read from
2//! the underlying reader.
3
4use super::{Read, Result, Seek, SeekFrom};
5
/// Read adapter which limits the bytes read from an underlying reader, with
/// seek support.
///
/// This struct is generally created by importing the [`TakeSeekExt`] extension
/// and calling [`take_seek`] on a reader.
///
/// [`take_seek`]: TakeSeekExt::take_seek
#[derive(Debug)]
pub struct TakeSeek<T> {
    /// The wrapped reader that bytes are actually read from.
    inner: T,
    /// Position of `inner` as tracked by this adapter. It is advanced by reads,
    /// overwritten by seeks, and reported as-is by `stream_position`.
    pos: u64,
    /// Absolute position in `inner` past which no more bytes are returned.
    /// The remaining read budget is `end - pos`, saturating at zero.
    end: u64,
}
19
20impl<T> TakeSeek<T> {
21    /// Gets a reference to the underlying reader.
22    pub fn get_ref(&self) -> &T {
23        &self.inner
24    }
25
26    /// Gets a mutable reference to the underlying reader.
27    ///
28    /// Care should be taken to avoid modifying the internal I/O state of the
29    /// underlying reader as doing so may corrupt the internal limit of this
30    /// `TakeSeek`.
31    pub fn get_mut(&mut self) -> &mut T {
32        &mut self.inner
33    }
34
35    /// Consumes this wrapper, returning the wrapped value.
36    pub fn into_inner(self) -> T {
37        self.inner
38    }
39
40    /// Returns the number of bytes that can be read before this instance will
41    /// return EOF.
42    ///
43    /// # Note
44    ///
45    /// This instance may reach EOF after reading fewer bytes than indicated by
46    /// this method if the underlying [`Read`] instance reaches EOF.
47    pub fn limit(&self) -> u64 {
48        self.end.saturating_sub(self.pos)
49    }
50}
51
52impl<T: Seek> TakeSeek<T> {
53    /// Sets the number of bytes that can be read before this instance will
54    /// return EOF. This is the same as constructing a new `TakeSeek` instance,
55    /// so the amount of bytes read and the previous limit value don’t matter
56    /// when calling this method.
57    ///
58    /// # Panics
59    ///
60    /// Panics if the inner stream returns an error from `stream_position`.
61    pub fn set_limit(&mut self, limit: u64) {
62        let pos = self
63            .inner
64            .stream_position()
65            .expect("cannot get position for `set_limit`");
66        self.pos = pos;
67        self.end = pos + limit;
68    }
69}
70
71impl<T: Read> Read for TakeSeek<T> {
72    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
73        let limit = self.limit();
74
75        // Don't call into inner reader at all at EOF because it may still block
76        if limit == 0 {
77            return Ok(0);
78        }
79
80        // Lint: It is impossible for this cast to truncate because the value
81        // being cast is the minimum of two values, and one of the value types
82        // is already `usize`.
83        #[allow(clippy::cast_possible_truncation)]
84        let max = (buf.len() as u64).min(limit) as usize;
85        let n = self.inner.read(&mut buf[0..max])?;
86        self.pos += n as u64;
87        Ok(n)
88    }
89}
90
91impl<T: Seek> Seek for TakeSeek<T> {
92    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
93        let pos = match pos {
94            SeekFrom::End(end) => {
95                let inner_end = self.inner.seek(SeekFrom::End(0))?;
96                match self.end.min(inner_end).checked_add_signed(end) {
97                    Some(pos) => SeekFrom::Start(pos),
98                    None => {
99                        return Err(super::Error::new(
100                            super::ErrorKind::InvalidInput,
101                            "invalid seek to a negative or overflowing position",
102                        ))
103                    }
104                }
105            }
106            pos => pos,
107        };
108        self.pos = self.inner.seek(pos)?;
109        Ok(self.pos)
110    }
111
112    fn stream_position(&mut self) -> Result<u64> {
113        Ok(self.pos)
114    }
115}
116
/// An extension trait that implements `take_seek()` for compatible streams.
///
/// A blanket implementation is provided for every type that implements both
/// [`Read`] and [`Seek`].
pub trait TakeSeekExt {
    /// Creates an adapter which will read at most `limit` bytes from the
    /// wrapped stream, counting from the stream's current position.
    ///
    /// # Panics
    ///
    /// The blanket implementation for `Read + Seek` types panics if querying
    /// the stream's current position fails.
    fn take_seek(self, limit: u64) -> TakeSeek<Self>
    where
        Self: Sized;
}
125
126impl<T: Read + Seek> TakeSeekExt for T {
127    fn take_seek(mut self, limit: u64) -> TakeSeek<Self>
128    where
129        Self: Sized,
130    {
131        let pos = self
132            .stream_position()
133            .expect("cannot get position for `take_seek`");
134
135        TakeSeek {
136            inner: self,
137            pos,
138            end: pos + limit,
139        }
140    }
141}