libcoreinst/io/
peek.rs

1// Copyright 2022 Red Hat, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
// Read wrapper that allows peeking ahead in the stream without consuming
// the peeked bytes.  BufRead::fill_buf() does not provide this, since it
// is only guaranteed to return at least one byte before EOF.  For
// simplicity, we implement this as a thin wrapper around BufReader.
19
20use bytes::{Buf, BytesMut};
21use std::io::{BufRead, BufReader, Read, Result, Seek, SeekFrom};
22
23pub struct PeekReader<R: Read> {
24    source: BufReader<R>,
25    buf: BytesMut,
26}
27
28impl<R: Read> PeekReader<R> {
29    pub fn with_capacity(capacity: usize, inner: R) -> Self {
30        Self {
31            source: BufReader::with_capacity(capacity, inner),
32            buf: BytesMut::new(),
33        }
34    }
35
36    /// Return the next amt bytes without consuming them.  May return fewer
37    /// bytes at EOF.
38    pub fn peek(&mut self, amt: usize) -> Result<&[u8]> {
39        if self.buf.remaining() < amt {
40            let mut extend = amt - self.buf.remaining();
41            self.buf.resize(amt, 0);
42            while extend > 0 {
43                let start = self.buf.len() - extend;
44                let count = self.source.read(&mut self.buf[start..])?;
45                if count == 0 {
46                    // EOF
47                    self.buf.truncate(start);
48                    break;
49                }
50                extend -= count;
51            }
52        }
53        Ok(&self.buf[..self.buf.len().min(amt)])
54    }
55
56    // no direct access to inner source, since that would lose data if
57    // buf is non-empty
58}
59
60impl<R: Read> Read for PeekReader<R> {
61    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
62        if buf.is_empty() {
63            return Ok(0);
64        }
65        if self.buf.has_remaining() {
66            let count = buf.len().min(self.buf.remaining());
67            self.buf.copy_to_slice(&mut buf[..count]);
68            return Ok(count);
69        }
70        self.source.read(buf)
71    }
72}
73
74impl<R: Read + Seek> Seek for PeekReader<R> {
75    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
76        self.buf.clear();
77        self.source.seek(pos)
78    }
79}
80
81impl<R: Read> BufRead for PeekReader<R> {
82    fn fill_buf(&mut self) -> Result<&[u8]> {
83        if self.buf.has_remaining() {
84            Ok(&self.buf)
85        } else {
86            self.source.fill_buf()
87        }
88    }
89
90    fn consume(&mut self, amt: usize) {
91        if self.buf.has_remaining() {
92            assert!(amt <= self.buf.remaining());
93            self.buf.advance(amt);
94        } else {
95            self.source.consume(amt);
96        }
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Fixture data served by every reader under test.
    const ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz";

    /// Build a reader positioned at the start of the alphabet.
    fn make_peek() -> PeekReader<Cursor<&'static [u8]>> {
        // The BufReader capacity deliberately exceeds the input length:
        // these tests target the peek buffer, not BufReader's buffering.
        let cursor = Cursor::new(ALPHABET);
        PeekReader::with_capacity(64, cursor)
    }

    /// Issue a single `read()` of at most `amt` bytes and return exactly
    /// the bytes it produced.
    fn read_bytes<R: Read>(peek: &mut PeekReader<R>, amt: usize) -> Vec<u8> {
        let mut scratch = vec![0; amt];
        let got = peek.read(&mut scratch).unwrap();
        scratch[..got].to_vec()
    }

    #[test]
    fn read() {
        let mut reader = make_peek();
        // plain reads with an empty peek buffer
        for expected in [b"abc", b"def"] {
            assert_eq!(read_bytes(&mut reader, 3), expected);
        }
        // first peek fills the peek buffer
        assert_eq!(reader.peek(2).unwrap(), b"gh");
        // a shorter peek is served from the existing buffer
        assert_eq!(reader.peek(1).unwrap(), b"g");
        // a longer peek grows the buffer
        assert_eq!(reader.peek(4).unwrap(), b"ghij");
        // reading drains part of the peek buffer
        assert_eq!(read_bytes(&mut reader, 3), b"ghi");
        // peeking again tops the buffer back up
        assert_eq!(reader.peek(2).unwrap(), b"jk");
        // a read larger than the peek buffer only drains the buffer
        assert_eq!(read_bytes(&mut reader, 3), b"jk");
        // with the buffer empty, reads hit the source again
        assert_eq!(read_bytes(&mut reader, 3), b"lmn");
    }

    #[test]
    fn seek() {
        let mut reader = make_peek();
        // populate the peek buffer, then jump elsewhere
        assert_eq!(reader.peek(4).unwrap(), b"abcd");
        reader.seek(SeekFrom::Start(10)).unwrap();
        // reads come from the new position, not the stale buffer
        assert_eq!(read_bytes(&mut reader, 3), b"klm");
        // populate the peek buffer again, then jump backwards
        assert_eq!(reader.peek(4).unwrap(), b"nopq");
        reader.seek(SeekFrom::Start(5)).unwrap();
        // peeks also come from the new position
        assert_eq!(reader.peek(4).unwrap(), b"fghi");
    }

    #[test]
    fn buf() {
        let mut reader = make_peek();
        // BufRead fill and partial consume, straight from the source
        assert_eq!(reader.fill_buf().unwrap(), ALPHABET);
        reader.consume(5);
        // BufRead fill reflects the consumed prefix
        assert_eq!(reader.fill_buf().unwrap(), &ALPHABET[5..]);
        // peek moves five bytes into the peek buffer
        assert_eq!(reader.peek(5).unwrap(), b"fghij");
        // fill now exposes only the peek buffer; consume part of it
        assert_eq!(reader.fill_buf().unwrap(), b"fghij");
        reader.consume(3);
        // fill exposes the rest of the peek buffer; consume all of it
        assert_eq!(reader.fill_buf().unwrap(), b"ij");
        reader.consume(2);
        // with the peek buffer empty, fill falls through to the source
        assert_eq!(reader.fill_buf().unwrap(), &ALPHABET[10..]);
    }

    #[test]
    fn eof() {
        let mut reader = make_peek();
        // position two bytes before the end
        reader.seek(SeekFrom::Start(24)).unwrap();
        // peeking past the end yields only what is left
        assert_eq!(reader.peek(4).unwrap(), b"yz");
        // reading past the end yields the same remainder
        assert_eq!(read_bytes(&mut reader, 3), b"yz");
        // peeking at EOF yields nothing
        assert_eq!(reader.peek(4).unwrap(), b"");
    }
}