1use bytes::{Buf, BytesMut};
21use std::io::{BufRead, BufReader, Read, Result, Seek, SeekFrom};
22
23pub struct PeekReader<R: Read> {
24 source: BufReader<R>,
25 buf: BytesMut,
26}
27
28impl<R: Read> PeekReader<R> {
29 pub fn with_capacity(capacity: usize, inner: R) -> Self {
30 Self {
31 source: BufReader::with_capacity(capacity, inner),
32 buf: BytesMut::new(),
33 }
34 }
35
36 pub fn peek(&mut self, amt: usize) -> Result<&[u8]> {
39 if self.buf.remaining() < amt {
40 let mut extend = amt - self.buf.remaining();
41 self.buf.resize(amt, 0);
42 while extend > 0 {
43 let start = self.buf.len() - extend;
44 let count = self.source.read(&mut self.buf[start..])?;
45 if count == 0 {
46 self.buf.truncate(start);
48 break;
49 }
50 extend -= count;
51 }
52 }
53 Ok(&self.buf[..self.buf.len().min(amt)])
54 }
55
56 }
59
60impl<R: Read> Read for PeekReader<R> {
61 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
62 if buf.is_empty() {
63 return Ok(0);
64 }
65 if self.buf.has_remaining() {
66 let count = buf.len().min(self.buf.remaining());
67 self.buf.copy_to_slice(&mut buf[..count]);
68 return Ok(count);
69 }
70 self.source.read(buf)
71 }
72}
73
74impl<R: Read + Seek> Seek for PeekReader<R> {
75 fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
76 self.buf.clear();
77 self.source.seek(pos)
78 }
79}
80
81impl<R: Read> BufRead for PeekReader<R> {
82 fn fill_buf(&mut self) -> Result<&[u8]> {
83 if self.buf.has_remaining() {
84 Ok(&self.buf)
85 } else {
86 self.source.fill_buf()
87 }
88 }
89
90 fn consume(&mut self, amt: usize) {
91 if self.buf.has_remaining() {
92 assert!(amt <= self.buf.remaining());
93 self.buf.advance(amt);
94 } else {
95 self.source.consume(amt);
96 }
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Full contents of the fixture stream used by every test.
    const ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz";

    /// Builds a `PeekReader` over [`ALPHABET`] with a 64-byte inner buffer.
    fn alphabet_reader() -> PeekReader<Cursor<&'static [u8]>> {
        let source: Cursor<&'static [u8]> = Cursor::new(ALPHABET);
        PeekReader::with_capacity(64, source)
    }

    /// Issues exactly one `read` call asking for `limit` bytes and returns
    /// whatever it produced.
    fn read_once<R: Read>(reader: &mut PeekReader<R>, limit: usize) -> Vec<u8> {
        let mut out = vec![0u8; limit];
        let got = reader.read(&mut out).unwrap();
        out.truncate(got);
        out
    }

    #[test]
    fn read() {
        let mut reader = alphabet_reader();
        assert_eq!(&read_once(&mut reader, 3), b"abc");
        assert_eq!(&read_once(&mut reader, 3), b"def");
        // Peeks of different lengths all start at the same position.
        assert_eq!(reader.peek(2).unwrap(), b"gh");
        assert_eq!(reader.peek(1).unwrap(), b"g");
        assert_eq!(reader.peek(4).unwrap(), b"ghij");
        assert_eq!(&read_once(&mut reader, 3), b"ghi");
        assert_eq!(reader.peek(2).unwrap(), b"jk");
        // This read is served from the peek buffer alone, so it comes up short.
        assert_eq!(&read_once(&mut reader, 3), b"jk");
        assert_eq!(&read_once(&mut reader, 3), b"lmn");
    }

    #[test]
    fn seek() {
        let mut reader = alphabet_reader();
        assert_eq!(reader.peek(4).unwrap(), b"abcd");
        reader.seek(SeekFrom::Start(10)).unwrap();
        assert_eq!(&read_once(&mut reader, 3), b"klm");
        assert_eq!(reader.peek(4).unwrap(), b"nopq");
        reader.seek(SeekFrom::Start(5)).unwrap();
        assert_eq!(reader.peek(4).unwrap(), b"fghi");
    }

    #[test]
    fn buf() {
        let mut reader = alphabet_reader();
        assert_eq!(reader.fill_buf().unwrap(), ALPHABET);
        reader.consume(5);
        assert_eq!(reader.fill_buf().unwrap(), &ALPHABET[5..]);
        assert_eq!(reader.peek(5).unwrap(), b"fghij");
        // Once something is peeked, fill_buf exposes only the peeked bytes.
        assert_eq!(reader.fill_buf().unwrap(), b"fghij");
        reader.consume(3);
        assert_eq!(reader.fill_buf().unwrap(), b"ij");
        reader.consume(2);
        assert_eq!(reader.fill_buf().unwrap(), &ALPHABET[10..]);
    }

    #[test]
    fn eof() {
        let mut reader = alphabet_reader();
        reader.seek(SeekFrom::Start(24)).unwrap();
        assert_eq!(reader.peek(4).unwrap(), b"yz");
        assert_eq!(&read_once(&mut reader, 3), b"yz");
        assert_eq!(reader.peek(4).unwrap(), b"");
    }
}