use bytes::{Buf, BytesMut};
use std::io::{BufRead, BufReader, Read, Result, Seek, SeekFrom};
pub struct PeekReader<R: Read> {
source: BufReader<R>,
buf: BytesMut,
}
impl<R: Read> PeekReader<R> {
    /// Creates a `PeekReader` whose underlying [`BufReader`] uses a buffer of `capacity` bytes.
    pub fn with_capacity(capacity: usize, inner: R) -> Self {
        Self {
            source: BufReader::with_capacity(capacity, inner),
            buf: BytesMut::new(),
        }
    }

    /// Returns up to `amt` upcoming bytes without consuming them.
    ///
    /// Bytes pulled from the source to satisfy the request are kept in the peek buffer and are
    /// served first by later `read`, `fill_buf` and `peek` calls. The returned slice is shorter
    /// than `amt` only when the source reports end of stream (a read returning `0`).
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying reader except
    /// [`ErrorKind::Interrupted`](std::io::ErrorKind::Interrupted), which is retried. Bytes that
    /// were successfully read before the error remain buffered. No placeholder bytes are ever
    /// left behind.
    pub fn peek(&mut self, amt: usize) -> Result<&[u8]> {
        // Number of genuine (actually read) bytes at the front of `self.buf`.
        let mut filled = self.buf.len();
        if filled < amt {
            // Grow to the target length so the source can read straight into the tail.
            self.buf.resize(amt, 0);
            while filled < amt {
                match self.source.read(&mut self.buf[filled..]) {
                    Ok(0) => break,
                    Ok(count) => filled += count,
                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(e) => {
                        // Drop the zero padding so it is never served as stream data.
                        self.buf.truncate(filled);
                        return Err(e);
                    }
                }
            }
            // After EOF, discard whatever padding was not overwritten by real data.
            self.buf.truncate(filled);
        }
        Ok(&self.buf[..filled.min(amt)])
    }
}
impl<R: Read> Read for PeekReader<R> {
    /// Reads into `buf`, draining previously peeked bytes before touching the source.
    ///
    /// A single call never mixes peeked bytes with fresh source bytes. While the peek buffer is
    /// non-empty, at most `min(buf.len(), pending)` bytes are copied out of it. Otherwise the call
    /// is forwarded to the underlying buffered reader. An empty `buf` always yields `Ok(0)`.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        match self.buf.remaining() {
            0 => self.source.read(buf),
            pending => {
                let take = pending.min(buf.len());
                let (dest, _) = buf.split_at_mut(take);
                self.buf.copy_to_slice(dest);
                Ok(take)
            }
        }
    }
}
impl<R: Read + Seek> Seek for PeekReader<R> {
    /// Seeks to `pos` and discards any peeked-but-unconsumed bytes.
    ///
    /// The underlying source is already `self.buf.len()` bytes past the logical read position,
    /// because those bytes were pulled out by `peek`. `SeekFrom::Current` offsets are therefore
    /// adjusted by that amount, so relative seeks (and `stream_position`) are measured from what
    /// the caller has actually consumed. This mirrors std `BufReader`'s seek.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying `seek`. On the common paths the peek buffer is only
    /// discarded after the seek succeeds.
    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
        let result = match pos {
            SeekFrom::Current(offset) => {
                // An in-memory buffer never exceeds isize::MAX bytes, so this cast cannot truncate.
                let pending = self.buf.len() as i64;
                match offset.checked_sub(pending) {
                    Some(adjusted) => self.source.seek(SeekFrom::Current(adjusted))?,
                    None => {
                        // `offset - pending` would overflow: rewind over the pending bytes first,
                        // then apply the caller's offset from the logical position.
                        self.source.seek(SeekFrom::Current(-pending))?;
                        self.buf.clear();
                        self.source.seek(SeekFrom::Current(offset))?
                    }
                }
            }
            absolute => self.source.seek(absolute)?,
        };
        self.buf.clear();
        Ok(result)
    }
}
impl<R: Read> BufRead for PeekReader<R> {
    /// Exposes the peek buffer while it holds bytes, otherwise the source's own buffer.
    fn fill_buf(&mut self) -> Result<&[u8]> {
        if !self.buf.has_remaining() {
            return self.source.fill_buf();
        }
        Ok(&self.buf[..])
    }

    /// Marks `amt` bytes of the slice last returned by `fill_buf` as consumed.
    ///
    /// # Panics
    ///
    /// Panics if the peek buffer is non-empty and `amt` exceeds the number of bytes in it.
    fn consume(&mut self, amt: usize) {
        if !self.buf.has_remaining() {
            self.source.consume(amt);
            return;
        }
        assert!(amt <= self.buf.remaining());
        self.buf.advance(amt);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Builds a reader over the lowercase alphabet with a 64-byte inner buffer.
    fn make_peek() -> PeekReader<Cursor<&'static [u8]>> {
        let alphabet: &'static [u8] = b"abcdefghijklmnopqrstuvwxyz";
        PeekReader::with_capacity(64, Cursor::new(alphabet))
    }

    /// Issues a single `read` with room for `amt` bytes and returns only the bytes filled.
    fn read_bytes<R: Read>(peek: &mut PeekReader<R>, amt: usize) -> Vec<u8> {
        let mut out = vec![0u8; amt];
        let filled = peek.read(&mut out).unwrap();
        out.truncate(filled);
        out
    }

    #[test]
    fn read() {
        let mut reader = make_peek();
        // Plain reads come straight from the source.
        assert_eq!(&read_bytes(&mut reader, 3), b"abc");
        assert_eq!(&read_bytes(&mut reader, 3), b"def");
        // Peeking does not advance; a longer peek widens the buffered window.
        assert_eq!(reader.peek(2).unwrap(), b"gh");
        assert_eq!(reader.peek(1).unwrap(), b"g");
        assert_eq!(reader.peek(4).unwrap(), b"ghij");
        // Reads drain the peeked bytes first...
        assert_eq!(&read_bytes(&mut reader, 3), b"ghi");
        assert_eq!(reader.peek(2).unwrap(), b"jk");
        // ...and never mix them with fresh source bytes in one call.
        assert_eq!(&read_bytes(&mut reader, 3), b"jk");
        assert_eq!(&read_bytes(&mut reader, 3), b"lmn");
    }

    #[test]
    fn seek() {
        let mut reader = make_peek();
        assert_eq!(reader.peek(4).unwrap(), b"abcd");
        // Seeking forward discards the peeked bytes.
        reader.seek(SeekFrom::Start(10)).unwrap();
        assert_eq!(&read_bytes(&mut reader, 3), b"klm");
        assert_eq!(reader.peek(4).unwrap(), b"nopq");
        // Seeking backward works too.
        reader.seek(SeekFrom::Start(5)).unwrap();
        assert_eq!(reader.peek(4).unwrap(), b"fghi");
    }

    #[test]
    fn buf() {
        let mut reader = make_peek();
        // With nothing peeked, fill_buf exposes the source's buffer.
        assert_eq!(reader.fill_buf().unwrap(), b"abcdefghijklmnopqrstuvwxyz");
        reader.consume(5);
        assert_eq!(reader.fill_buf().unwrap(), b"fghijklmnopqrstuvwxyz");
        // Once bytes are peeked, fill_buf exposes only the peek buffer.
        assert_eq!(reader.peek(5).unwrap(), b"fghij");
        assert_eq!(reader.fill_buf().unwrap(), b"fghij");
        reader.consume(3);
        assert_eq!(reader.fill_buf().unwrap(), b"ij");
        // Exhausting the peek buffer falls back to the source.
        reader.consume(2);
        assert_eq!(reader.fill_buf().unwrap(), b"klmnopqrstuvwxyz");
    }

    #[test]
    fn eof() {
        let mut reader = make_peek();
        reader.seek(SeekFrom::Start(24)).unwrap();
        // A peek past the end returns only what is available.
        assert_eq!(reader.peek(4).unwrap(), b"yz");
        assert_eq!(&read_bytes(&mut reader, 3), b"yz");
        assert_eq!(reader.peek(4).unwrap(), b"");
    }
}