use std::io::{self, BufRead, Read, Write, Seek, SeekFrom};
use std::cmp;
use std::ptr;
/// Default initial capacity, in bytes, of the internal accumulation buffer.
pub const DEFAULT_BUF_CAPACITY: usize = 4096;

/// Default buffer increment, in bytes (see `AccReader::with_increment`).
pub const DEFAULT_BUF_INCREMENT: usize = 1024;

/// A reader that keeps every byte it has pulled from the wrapped source in an
/// in-memory buffer, together with a cursor into that buffer.
///
/// Nothing is ever removed from the buffer, so already-read data stays
/// available for as long as the `AccReader` lives.
pub struct AccReader<R: Read> {
    /// Underlying stream that new bytes are pulled from.
    source: R,
    /// All bytes obtained from `source` so far.
    buf: Vec<u8>,
    /// Cursor into `buf`; never greater than `buf.len()`.
    pos: usize,
    /// Number of bytes `fill_buf` requests from `source` when the buffered
    /// data is exhausted.
    inc: usize,
}

impl<R: Read> AccReader<R> {
    /// Wraps `source` using `DEFAULT_BUF_CAPACITY` and `DEFAULT_BUF_INCREMENT`.
    #[inline]
    pub fn new(source: R) -> AccReader<R> {
        AccReader::with_initial_capacity_and_increment(
            DEFAULT_BUF_CAPACITY,
            DEFAULT_BUF_INCREMENT,
            source,
        )
    }

    /// Wraps `source` with an initial buffer capacity of `cap` bytes and the
    /// default increment.
    #[inline]
    pub fn with_initial_capacity(cap: usize, source: R) -> AccReader<R> {
        AccReader::with_initial_capacity_and_increment(cap, DEFAULT_BUF_INCREMENT, source)
    }

    /// Wraps `source` with the default initial capacity and an increment of
    /// `inc` bytes.
    #[inline]
    pub fn with_increment(inc: usize, source: R) -> AccReader<R> {
        AccReader::with_initial_capacity_and_increment(DEFAULT_BUF_CAPACITY, inc, source)
    }

    /// Wraps `source` with an initial buffer capacity of `cap` bytes and an
    /// increment of `inc` bytes. The cursor starts at position 0.
    #[inline]
    pub fn with_initial_capacity_and_increment(cap: usize, inc: usize, source: R) -> AccReader<R> {
        AccReader {
            source,
            buf: Vec::with_capacity(cap),
            pos: 0,
            inc,
        }
    }

    /// Consumes the reader and returns the wrapped source.
    ///
    /// The accumulated buffer is dropped.
    #[inline]
    pub fn into_inner(self) -> R {
        self.source
    }

    /// Pulls at most `n` further bytes from the source and appends them to
    /// the buffer. The cursor is not moved.
    ///
    /// Reaching end of stream before `n` bytes arrived is *not* an error;
    /// callers are expected to inspect the buffer length afterwards.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error other than `ErrorKind::Interrupted`
    /// reported by the source. Bytes received before the error remain in the
    /// buffer.
    fn read_up_to(&mut self, n: u64) -> io::Result<()> {
        // `take` caps the amount read at `n`, while `read_to_end` grows the
        // buffer only as data actually arrives. A huge `n` therefore neither
        // pre-allocates `n` bytes nor overflows a length computation, and the
        // source never sees uninitialized memory.
        self.source
            .by_ref()
            .take(n)
            .read_to_end(&mut self.buf)
            .map(|_| ())
    }
}
impl<R: Read> Read for AccReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let need_to_read = cmp::min(self.buf.len() - self.pos, buf.len());
if need_to_read > 0 {
unsafe {
ptr::copy_nonoverlapping(
self.buf.as_ptr().offset(self.pos as isize),
buf.as_mut_ptr(),
need_to_read
);
}
self.pos += need_to_read;
Ok(need_to_read)
} else { let read = try!(self.source.read(buf));
let _ = self.buf.write_all(&buf[..read]);
self.pos += read;
Ok(read)
}
}
}
impl<R: Read> BufRead for AccReader<R> {
    /// Returns the buffered bytes located after the cursor.
    ///
    /// When no such bytes exist, one `read` of up to `inc` bytes is issued on
    /// the source and its output is appended to the buffer first. An empty
    /// slice is returned when that read yields nothing.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the source's `read`; the buffer is
    /// left as it was before the call.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        if self.pos == self.buf.len() {
            let old_len = self.buf.len();
            // Zero-fill the scratch area rather than using `set_len`, so the
            // source is never handed uninitialized memory.
            self.buf.resize(old_len + self.inc, 0);
            let outcome = self.source.read(&mut self.buf[old_len..]);
            match outcome {
                // Keep only the bytes the source actually produced.
                Ok(n) => self.buf.truncate(old_len + n),
                Err(e) => {
                    self.buf.truncate(old_len);
                    return Err(e);
                }
            }
        }
        Ok(&self.buf[self.pos..])
    }

    /// Advances the cursor by `amt` bytes, clamped to the end of the buffered
    /// data.
    fn consume(&mut self, amt: usize) {
        // Saturate so an oversized `amt` clamps instead of overflowing.
        self.pos = cmp::min(self.pos.saturating_add(amt), self.buf.len());
    }
}
impl<R: Read> Seek for AccReader<R> {
fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
match pos {
SeekFrom::End(n) => {
if n > 0 {
Err(io::Error::new(io::ErrorKind::UnexpectedEof, "seeking beyond end of stream"))
} else {
try!(self.source.read_to_end(&mut self.buf));
let d = (-n) as u64;
if d > self.buf.len() as u64 {
Err(io::Error::new(io::ErrorKind::InvalidInput, "seeking before the begining of stream"))
} else {
self.pos = (self.buf.len() as u64 - d) as usize;
Ok(self.pos as u64)
}
}
}
SeekFrom::Start(n) if n <= self.buf.len() as u64 => {
self.pos = n as usize;
Ok(self.pos as u64)
}
SeekFrom::Start(n) => { let need_to_read = n - self.buf.len() as u64;
try!(self.read_up_to(need_to_read));
if n > self.buf.len() as u64 { Err(io::Error::new(io::ErrorKind::UnexpectedEof, "seeking beyond end of stream"))
} else {
self.pos = n as usize;
Ok(n)
}
}
SeekFrom::Current(0) => { Ok(self.pos as u64) }
SeekFrom::Current(n) if n < 0 => {
let d = (-n) as u64;
if d > self.pos as u64 {
Err(io::Error::new(io::ErrorKind::InvalidInput, "seeking before the beginning of stream"))
} else {
self.pos = (self.pos as u64 - d) as usize;
Ok(self.pos as u64)
}
}
SeekFrom::Current(n) => { let new_pos = self.pos as u64 + n as u64;
if new_pos > self.buf.len() as u64 {
let need_to_read = new_pos - self.buf.len() as u64;
try!(self.read_up_to(need_to_read));
if new_pos > self.buf.len() as u64 { Err(io::Error::new(io::ErrorKind::UnexpectedEof, "seeking beyond end of stream"))
} else {
self.pos = new_pos as usize;
Ok(new_pos)
}
} else {
self.pos = new_pos as usize;
Ok(self.pos as u64)
}
}
}
}
}
#[cfg(test)]
mod tests {
    use std::io::{self, BufRead, Read, Seek, SeekFrom};
    use super::*;

    /// Consecutive `read` calls hand out the source bytes in order and finish
    /// with a short read for the odd trailing byte.
    #[test]
    fn test_acc_reader_read() {
        let data: &[u8] = &[5, 6, 7, 0, 1, 2, 3];
        let mut acc = AccReader::new(data);
        let mut chunk = [0u8; 2];

        let full_chunks: [[u8; 2]; 3] = [[5, 6], [7, 0], [1, 2]];
        for expected in full_chunks.iter() {
            let got = acc.read(&mut chunk).unwrap();
            assert_eq!(got, 2);
            assert_eq!(chunk, *expected);
        }

        // Seven bytes in total: the last read delivers a single byte.
        let got = acc.read(&mut chunk).unwrap();
        assert_eq!(got, 1);
        assert_eq!(chunk[0], 3);
    }

    /// With an increment of 3, `fill_buf` exposes the stream three bytes at a
    /// time and reports an empty slice once the source is drained.
    #[test]
    fn test_acc_reader_buf_read() {
        let data: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4];
        let mut acc = AccReader::with_initial_capacity_and_increment(3, 3, data);

        let refills: [&[u8]; 3] = [&[5, 6, 7], &[0, 1, 2], &[3, 4]];
        for expected in refills.iter() {
            assert_eq!(acc.fill_buf().ok(), Some(*expected));
            acc.consume(expected.len());
        }

        let empty: &[u8] = &[];
        assert_eq!(acc.fill_buf().ok(), Some(empty));
    }

    /// Seeking from start, current position and end lands on the expected
    /// offsets, and out-of-range targets are rejected with the right kinds.
    #[test]
    fn test_acc_reader_seek() {
        let data: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4];
        let total = data.len() as u64;
        let mut pair = [0u8; 2];

        let mut acc = AccReader::new(data);

        // (seek request, position it must report, next two bytes read there)
        let steps: [(SeekFrom, u64, [u8; 2]); 3] = [
            (SeekFrom::Start(2), 2, [7, 0]),
            (SeekFrom::Current(-1), 3, [0, 1]),
            (SeekFrom::End(-3), 5, [2, 3]),
        ];
        for &(whence, landed, bytes) in steps.iter() {
            assert_eq!(acc.seek(whence).unwrap(), landed);
            acc.read_exact(&mut pair).unwrap();
            assert_eq!(pair, bytes);
        }

        let past_end = acc.seek(SeekFrom::End(3)).unwrap_err();
        assert_eq!(past_end.kind(), io::ErrorKind::UnexpectedEof);
        let before_start = acc.seek(SeekFrom::Current(-128)).unwrap_err();
        assert_eq!(before_start.kind(), io::ErrorKind::InvalidInput);

        // Seeking exactly to the end is allowed; nothing is left to read.
        let mut acc = AccReader::new(data);
        assert_eq!(acc.seek(SeekFrom::Start(total)).unwrap(), total);
        assert_eq!(acc.read(&mut pair).unwrap(), 0);

        // A fresh reader cannot seek past the end of a short stream.
        let mut acc = AccReader::new(data);
        let too_far = acc.seek(SeekFrom::Start(128)).unwrap_err();
        assert_eq!(too_far.kind(), io::ErrorKind::UnexpectedEof);
    }
}