use std::cmp;
use std::io::{self, Read, Write};
use byteorder::{ReadBytesExt, ByteOrder, LittleEndian as LE};
use compress::{Encoder, max_compress_len};
use crc32::crc32c;
use decompress::{Decoder, decompress_len};
use error::{Error, IntoInnerError, new_into_inner_error};
use MAX_BLOCK_SIZE;
// Size of the scratch buffers that hold the compressed form of one block:
// the bound `max_compress_len` reports for a `MAX_BLOCK_SIZE` input. It is
// computed lazily, once, on first use.
lazy_static! {
    static ref MAX_COMPRESS_BLOCK_SIZE: usize = max_compress_len(MAX_BLOCK_SIZE);
}
/// Payload of the stream identifier chunk (the `sNaPpY` magic).
const STREAM_BODY: &'static [u8] = b"sNaPpY";
/// The complete stream identifier chunk: type byte 0xff, a 3-byte
/// little-endian length of 6, then `STREAM_BODY`.
const STREAM_IDENTIFIER: &'static [u8] = b"\xff\x06\x00\x00sNaPpY";
/// Chunk type bytes this module recognizes. Each discriminant is the exact
/// byte that appears first in a chunk header.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum ChunkType {
    Compressed = 0x00,
    Uncompressed = 0x01,
    Padding = 0xFE,
    Stream = 0xFF,
}

impl ChunkType {
    /// Converts a raw chunk-type byte into a `ChunkType`.
    ///
    /// Any byte other than the four recognized values is handed back
    /// unchanged as `Err(byte)`, so the caller can decide how to treat it.
    fn from_u8(byte: u8) -> Result<ChunkType, u8> {
        let ty = match byte {
            0x00 => ChunkType::Compressed,
            0x01 => ChunkType::Uncompressed,
            0xFE => ChunkType::Padding,
            0xFF => ChunkType::Stream,
            other => return Err(other),
        };
        Ok(ty)
    }
}
/// A writer that compresses everything written to it into a chunked stream
/// (stream identifier `sNaPpY`) and forwards the chunks to a wrapped writer.
pub struct Writer<W: Write> {
    /// Pending uncompressed bytes; holds at most `MAX_BLOCK_SIZE` bytes
    /// before they are emitted as a chunk.
    src: Vec<u8>,
    /// Encoder state and wrapped writer. It is `None` only once
    /// `into_inner` has taken it.
    inner: Option<Inner<W>>,
}

/// The part of `Writer` that turns blocks of input into chunks.
struct Inner<W> {
    /// The wrapped writer that receives the chunk bytes.
    w: W,
    /// Whether the stream identifier chunk has been emitted yet.
    wrote_stream_ident: bool,
    /// Compressor reused for every block.
    enc: Encoder,
    /// Reusable 8-byte chunk header: type byte, 3-byte length, 4-byte checksum.
    chunk_header: [u8; 8],
    /// Scratch output buffer for the compressed form of one block.
    dst: Vec<u8>,
}
impl<W: Write> Writer<W> {
    /// Wraps `wtr` in a new compressing writer.
    ///
    /// Nothing is written to `wtr` until data is written and flushed.
    pub fn new(wtr: W) -> Writer<W> {
        let inner = Inner {
            w: wtr,
            enc: Encoder::new(),
            dst: vec![0; *MAX_COMPRESS_BLOCK_SIZE],
            wrote_stream_ident: false,
            chunk_header: [0; 8],
        };
        Writer {
            inner: Some(inner),
            src: Vec::with_capacity(MAX_BLOCK_SIZE),
        }
    }

    /// Calls `flush` and then returns the wrapped writer.
    ///
    /// # Errors
    ///
    /// If flushing fails, this writer is handed back together with the I/O
    /// error inside an `IntoInnerError`, so no buffered data is lost.
    pub fn into_inner(mut self) -> Result<W, IntoInnerError<Writer<W>>> {
        if let Err(err) = self.flush() {
            return Err(new_into_inner_error(self, err));
        }
        Ok(self.inner.take().unwrap().w)
    }
}
impl<W: Write> Drop for Writer<W> {
    /// Best-effort flush when the writer goes out of scope.
    ///
    /// Errors are discarded because `drop` has no way to report them. Call
    /// `flush` or `into_inner` to observe failures.
    fn drop(&mut self) {
        if self.inner.is_none() {
            // `into_inner` already took the wrapped writer; nothing to flush.
            return;
        }
        let _ = self.flush();
    }
}
impl<W: Write> Write for Writer<W> {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
let mut total = 0;
loop {
let free = self.src.capacity() - self.src.len();
let n =
if buf.len() <= free {
break;
} else if self.src.is_empty() {
try!(self.inner.as_mut().unwrap().write(buf))
} else {
self.src.extend_from_slice(&buf[0..free]);
try!(self.flush());
free
};
buf = &buf[n..];
total += n;
}
debug_assert!(buf.len() <= (self.src.capacity() - self.src.len()));
self.src.extend_from_slice(buf);
total += buf.len();
debug_assert!(self.src.capacity() == MAX_BLOCK_SIZE);
Ok(total)
}
fn flush(&mut self) -> io::Result<()> {
if self.src.is_empty() {
return Ok(());
}
try!(self.inner.as_mut().unwrap().write(&self.src));
self.src.truncate(0);
Ok(())
}
}
impl<W: Write> Inner<W> {
    /// Emits `buf` as a series of chunks, each covering at most
    /// `MAX_BLOCK_SIZE` input bytes, and returns the number of input bytes
    /// consumed (all of `buf` on success).
    ///
    /// The stream identifier is written ahead of the first call's output.
    /// Each chunk header consists of:
    /// - the type byte;
    /// - a 3-byte little-endian length that counts the 4-byte checksum plus
    ///   the payload;
    /// - the masked CRC-32C of the *uncompressed* block.
    ///
    /// The compressed bytes are used as the payload only when they are
    /// shorter than `len - len / 8`; otherwise the raw block is stored.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if !self.wrote_stream_ident {
            self.wrote_stream_ident = true;
            try!(self.w.write_all(STREAM_IDENTIFIER));
        }
        let mut written = 0;
        for block in buf.chunks(MAX_BLOCK_SIZE) {
            let checksum = crc32c_masked(block);
            let compressed_len = try!(self.enc.compress(block, &mut self.dst));
            let worth_compressing = compressed_len < block.len() - block.len() / 8;
            let chunk_type = if worth_compressing {
                ChunkType::Compressed
            } else {
                ChunkType::Uncompressed
            };
            let payload: &[u8] = if worth_compressing {
                &self.dst[0..compressed_len]
            } else {
                block
            };
            self.chunk_header[0] = chunk_type as u8;
            LE::write_uint(&mut self.chunk_header[1..], (4 + payload.len()) as u64, 3);
            LE::write_u32(&mut self.chunk_header[4..], checksum);
            try!(self.w.write_all(&self.chunk_header));
            try!(self.w.write_all(payload));
            written += block.len();
        }
        Ok(written)
    }
}
/// A reader that decodes a chunked compressed stream (stream identifier
/// `sNaPpY`) from a wrapped reader, verifying each chunk's checksum.
pub struct Reader<R: Read> {
    /// The wrapped reader supplying the framed bytes.
    r: R,
    /// Decompressor reused for every compressed chunk.
    dec: Decoder,
    /// Whether the leading stream identifier chunk has been seen yet.
    read_stream_ident: bool,
    /// Scratch buffer for chunk headers and raw chunk bodies.
    src: Vec<u8>,
    /// Decoded bytes of the most recent data chunk.
    dst: Vec<u8>,
    /// Start of the not-yet-returned range `dst[dsts..dste]`.
    dsts: usize,
    /// End of the not-yet-returned range `dst[dsts..dste]`.
    dste: usize,
}

impl<R: Read> Reader<R> {
    /// Wraps `rdr` in a new decoding reader.
    ///
    /// Nothing is read from `rdr` until the first call to `read`.
    pub fn new(rdr: R) -> Reader<R> {
        let dec = Decoder::new();
        let src = vec![0; *MAX_COMPRESS_BLOCK_SIZE];
        let dst = vec![0; MAX_BLOCK_SIZE];
        Reader {
            r: rdr,
            dec: dec,
            read_stream_ident: false,
            src: src,
            dst: dst,
            dsts: 0,
            dste: 0,
        }
    }
}
impl<R: Read> Read for Reader<R> {
    /// Reads decoded bytes into `buf`.
    ///
    /// Bytes left over from the most recently decoded chunk are returned
    /// first. Otherwise the next chunk header is read and the chunk is handled
    /// by type:
    /// - stream identifiers are validated;
    /// - padding and types 0x80..=0xFD are skipped;
    /// - uncompressed and compressed chunks are decoded into `dst` and their
    ///   checksum is verified.
    ///
    /// `Ok(0)` signals end of stream when no further chunk header is
    /// available (as decided by `read_exact_eof`).
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the first chunk is not a valid stream identifier;
    /// - a chunk type in 0x02..=0x7F appears;
    /// - a chunk length is too large for the buffers, or too small to hold
    ///   its 4-byte checksum;
    /// - a checksum does not match;
    /// - an I/O or decompression error occurs.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        macro_rules! fail {
            ($err:expr) => {
                return Err(io::Error::from($err));
            }
        }
        loop {
            // Serve bytes left over from the previously decoded chunk first.
            if self.dsts < self.dste {
                let len = cmp::min(self.dste - self.dsts, buf.len());
                let dste = self.dsts.checked_add(len).unwrap();
                buf[0..len].copy_from_slice(&self.dst[self.dsts..dste]);
                self.dsts = dste;
                return Ok(len);
            }
            // Chunk header: one type byte followed by a 3-byte little-endian length.
            if !try!(read_exact_eof(&mut self.r, &mut self.src[0..4])) {
                return Ok(0);
            }
            let ty = ChunkType::from_u8(self.src[0]);
            // The very first chunk must be the stream identifier.
            if !self.read_stream_ident {
                if ty != Ok(ChunkType::Stream) {
                    fail!(Error::StreamHeader { byte: self.src[0] });
                }
                self.read_stream_ident = true;
            }
            let len64 = LE::read_uint(&self.src[1..4], 3);
            if len64 > self.src.len() as u64 {
                fail!(Error::UnsupportedChunkLength {
                    len: len64,
                    header: false,
                });
            }
            let len = len64 as usize;
            match ty {
                // Unsupported chunk types: reject the stream.
                Err(b) if 0x02 <= b && b <= 0x7F => {
                    fail!(Error::UnsupportedChunkType { byte: b });
                }
                // Skippable chunk types: read and discard the body.
                Err(b) if 0x80 <= b && b <= 0xFD => {
                    try!(self.r.read_exact(&mut self.src[0..len]));
                }
                Err(b) => {
                    unreachable!("BUG: unhandled chunk type: {}", b);
                }
                Ok(ChunkType::Padding) => {
                    try!(self.r.read_exact(&mut self.src[0..len]));
                }
                Ok(ChunkType::Stream) => {
                    if len != STREAM_BODY.len() {
                        fail!(Error::UnsupportedChunkLength {
                            len: len64,
                            header: true,
                        })
                    }
                    try!(self.r.read_exact(&mut self.src[0..len]));
                    if &self.src[0..len] != STREAM_BODY {
                        fail!(Error::StreamHeaderMismatch {
                            bytes: self.src[0..len].to_vec(),
                        });
                    }
                }
                Ok(ChunkType::Uncompressed) => {
                    // The body starts with a 4-byte checksum. A shorter chunk
                    // is malformed, and `len - 4` below would underflow.
                    if len < 4 {
                        fail!(Error::UnsupportedChunkLength {
                            len: len64,
                            header: false,
                        });
                    }
                    let expected_sum = try!(self.r.read_u32::<LE>());
                    let n = len - 4;
                    if n > self.dst.len() {
                        fail!(Error::UnsupportedChunkLength {
                            len: n as u64,
                            header: false,
                        });
                    }
                    try!(self.r.read_exact(&mut self.dst[0..n]));
                    let got_sum = crc32c_masked(&self.dst[0..n]);
                    if expected_sum != got_sum {
                        fail!(Error::Checksum {
                            expected: expected_sum,
                            got: got_sum,
                        });
                    }
                    self.dsts = 0;
                    self.dste = n;
                }
                Ok(ChunkType::Compressed) => {
                    // Same checksum prefix as above; guard against underflow.
                    if len < 4 {
                        fail!(Error::UnsupportedChunkLength {
                            len: len64,
                            header: false,
                        });
                    }
                    let expected_sum = try!(self.r.read_u32::<LE>());
                    let sn = len - 4;
                    if sn > self.src.len() {
                        fail!(Error::UnsupportedChunkLength {
                            len: len64,
                            header: false,
                        });
                    }
                    try!(self.r.read_exact(&mut self.src[0..sn]));
                    // Only the `sn` bytes just read belong to this chunk; the
                    // rest of `src` may hold stale data from earlier chunks.
                    let dn = try!(decompress_len(&self.src[0..sn]));
                    if dn > self.dst.len() {
                        fail!(Error::UnsupportedChunkLength {
                            len: dn as u64,
                            header: false,
                        });
                    }
                    try!(self.dec.decompress(
                        &self.src[0..sn], &mut self.dst[0..dn]));
                    let got_sum = crc32c_masked(&self.dst[0..dn]);
                    if expected_sum != got_sum {
                        fail!(Error::Checksum {
                            expected: expected_sum,
                            got: got_sum,
                        });
                    }
                    self.dsts = 0;
                    self.dste = dn;
                }
            }
        }
    }
}
/// Fills `buf` completely from `rdr`, distinguishing clean end of input from
/// truncation.
///
/// Returns:
/// - `Ok(true)` when `buf` was filled (an empty `buf` trivially is);
/// - `Ok(false)` when `rdr` was already at end of input and no bytes were
///   read;
/// - `Err` with kind `UnexpectedEof` when input ended after only part of
///   `buf` was filled.
///
/// `Read::read_exact` cannot be used for this, because on `UnexpectedEof` it
/// does not say how many bytes were consumed. A truncated stream would then
/// be indistinguishable from a clean end.
///
/// Reads interrupted by `ErrorKind::Interrupted` are retried. Any other I/O
/// error is returned as-is.
fn read_exact_eof<R: Read>(rdr: &mut R, buf: &mut [u8]) -> io::Result<bool> {
    let mut filled = 0;
    while filled < buf.len() {
        match rdr.read(&mut buf[filled..]) {
            Ok(0) => break,
            Ok(n) => filled += n,
            Err(ref err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    }
    if filled == buf.len() {
        Ok(true)
    } else if filled == 0 {
        Ok(false)
    } else {
        Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "stream ended in the middle of a chunk header",
        ))
    }
}
/// Computes the masked CRC-32C checksum of `buf` as stored in chunk headers.
///
/// The raw CRC-32C is rotated right by 15 bits and then offset by a fixed
/// constant, using wrapping addition.
fn crc32c_masked(buf: &[u8]) -> u32 {
    const MASK_DELTA: u32 = 0xA282EAD8;
    crc32c(buf).rotate_right(15).wrapping_add(MASK_DELTA)
}