use std::io::{self, Read};
use byteorder::{BigEndian, ByteOrder};
use snap;
use crate::{Error, Result};
pub fn compress(src: &[u8]) -> Result<Vec<u8>> {
let mut buf = vec![0; snap::raw::max_compress_len(src.len())];
let len = snap::raw::Encoder::new().compress(src, &mut buf)?;
buf.truncate(len);
Ok(buf)
}
pub fn compress_xerial(src: &[u8]) -> Result<Vec<u8>> {
let compressed = compress(src)?;
let chunk_len = i32::try_from(compressed.len()).map_err(|_| Error::CodecError)?;
let mut out = Vec::with_capacity(MAGIC.len() + 8 + 4 + compressed.len());
out.extend_from_slice(MAGIC);
out.extend_from_slice(&1i32.to_be_bytes()); out.extend_from_slice(&1i32.to_be_bytes()); out.extend_from_slice(&chunk_len.to_be_bytes());
out.extend_from_slice(&compressed);
Ok(out)
}
fn uncompress_to(src: &[u8], dst: &mut Vec<u8>) -> Result<()> {
let min_len = snap::raw::decompress_len(src)?;
if min_len > 0 {
let off = dst.len();
dst.resize(off + min_len, 0);
let uncompressed_len = {
let buf = &mut dst.as_mut_slice()[off..off + min_len];
snap::raw::Decoder::new().decompress(src, buf)?
};
dst.truncate(off + uncompressed_len);
}
Ok(())
}
const MAGIC: &[u8] = &[0x82, b'S', b'N', b'A', b'P', b'P', b'Y', 0];
macro_rules! next_i32 {
($slice:expr) => {{
if $slice.len() < 4 {
return Err(Error::UnexpectedEOF);
}
{
let n = BigEndian::read_i32($slice);
$slice = &$slice[4..];
n
}
}};
}
fn validate_stream(mut stream: &[u8]) -> Result<&[u8]> {
if stream.len() < MAGIC.len() {
return Err(Error::UnexpectedEOF);
}
if &stream[..MAGIC.len()] != MAGIC {
return Err(Error::InvalidSnappy(snap::Error::Header));
}
stream = &stream[MAGIC.len()..];
let version = next_i32!(stream);
if version != 1 {
return Err(Error::InvalidSnappy(snap::Error::Header));
}
let compat = next_i32!(stream);
if compat != 1 {
return Err(Error::InvalidSnappy(snap::Error::Header));
}
Ok(stream)
}
#[test]
fn test_validate_stream() {
let header = [
0x82, 0x53, 0x4e, 0x41, 0x50, 0x50, 0x59, 0x00, 0, 0, 0, 1, 0, 0, 0, 1, 0x56,
];
let rest = validate_stream(&header).unwrap();
assert_eq!(rest, &[0x56]);
}
pub struct SnappyReader<'a> {
compressed_data: &'a [u8],
uncompressed_pos: usize,
uncompressed_chunk: Vec<u8>,
}
impl SnappyReader<'_> {
pub fn new(mut stream: &[u8]) -> Result<SnappyReader<'_>> {
stream = validate_stream(stream)?;
Ok(SnappyReader {
compressed_data: stream,
uncompressed_pos: 0,
uncompressed_chunk: Vec::new(),
})
}
fn read_inner(&mut self, buf: &mut [u8]) -> Result<usize> {
if self.uncompressed_pos < self.uncompressed_chunk.len() {
return self.read_uncompressed(buf);
}
if self.next_chunk()? {
self.read_uncompressed(buf)
} else {
Ok(0)
}
}
fn read_uncompressed(&mut self, buf: &mut [u8]) -> Result<usize> {
let n = (&self.uncompressed_chunk[self.uncompressed_pos..]).read(buf)?;
self.uncompressed_pos += n;
Ok(n)
}
fn next_chunk(&mut self) -> Result<bool> {
if self.compressed_data.is_empty() {
return Ok(false);
}
self.uncompressed_pos = 0;
let chunk_size = next_i32!(self.compressed_data);
if chunk_size <= 0 {
return Err(Error::InvalidSnappy(snap::Error::UnsupportedChunkLength {
len: u64::from(chunk_size.unsigned_abs()),
header: false,
}));
}
let chunk_size = usize::try_from(chunk_size).map_err(|_| Error::CodecError)?;
self.uncompressed_chunk.clear();
uncompress_to(
&self.compressed_data[..chunk_size],
&mut self.uncompressed_chunk,
)?;
self.compressed_data = &self.compressed_data[chunk_size..];
Ok(true)
}
fn read_to_end_inner(&mut self, buf: &mut Vec<u8>) -> Result<usize> {
let init_len = buf.len();
if self.uncompressed_pos < self.uncompressed_chunk.len() {
let rest = &self.uncompressed_chunk[self.uncompressed_pos..];
buf.extend_from_slice(rest);
self.uncompressed_pos += rest.len();
}
while !self.compressed_data.is_empty() {
let chunk_size = next_i32!(self.compressed_data);
if chunk_size <= 0 {
return Err(Error::InvalidSnappy(snap::Error::UnsupportedChunkLength {
len: u64::from(chunk_size.unsigned_abs()),
header: false,
}));
}
let chunk_size = usize::try_from(chunk_size).map_err(|_| Error::CodecError)?;
let (c1, c2) = self.compressed_data.split_at(chunk_size);
uncompress_to(c1, buf)?;
self.compressed_data = c2;
}
Ok(buf.len() - init_len)
}
}
macro_rules! to_io_error {
($expr:expr) => {
match $expr {
Ok(n) => Ok(n),
Err(Error::Io(io_error)) => Err(io_error),
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e.to_string())),
}
};
}
impl Read for SnappyReader<'_> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
to_io_error!(self.read_inner(buf))
}
fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
to_io_error!(self.read_to_end_inner(buf))
}
}
#[cfg(test)]
mod tests {
use std::io::Read;
use std::str;
use super::{SnappyReader, compress, uncompress_to};
use crate::error::{Error, Result};
fn uncompress(src: &[u8]) -> Result<Vec<u8>> {
let mut v = Vec::new();
uncompress_to(src, &mut v).map(|()| v)
}
#[test]
fn test_compress() {
let msg = "This is test".as_bytes();
let compressed = compress(msg).unwrap();
let expected = &[
12, 44, 84, 104, 105, 115, 32, 105, 115, 32, 116, 101, 115, 116,
];
assert_eq!(&compressed, expected);
}
#[test]
fn test_uncompress() {
let compressed = &[
12, 44, 84, 104, 105, 115, 32, 105, 115, 32, 116, 101, 115, 116,
];
let uncompressed = String::from_utf8(uncompress(compressed).unwrap()).unwrap();
assert_eq!(&uncompressed, "This is test");
}
#[test]
fn test_uncompress_invalid_input() {
let compressed = &[
12, 42, 84, 104, 105, 115, 32, 105, 115, 32, 116, 101, 115, 116,
];
let uncompressed = uncompress(compressed);
assert!(uncompressed.is_err());
assert!(matches!(uncompressed.err(), Some(Error::InvalidSnappy(_))));
}
static ORIGINAL: &str = include_str!("../../test-data/fetch1.txt");
static COMPRESSED: &[u8] = include_bytes!("../../test-data/fetch1.snappy.chunked.4k");
#[test]
fn test_snappy_reader_read() {
let mut buf = Vec::new();
let mut r = SnappyReader::new(COMPRESSED).unwrap();
let mut tmp_buf = [0u8; 1024];
loop {
match r.read(&mut tmp_buf).unwrap() {
0 => break,
n => buf.extend_from_slice(&tmp_buf[..n]),
}
}
assert_eq!(ORIGINAL, str::from_utf8(&buf[..]).unwrap());
}
#[test]
fn test_snappy_reader_read_to_end() {
let mut buf = Vec::new();
let mut r = SnappyReader::new(COMPRESSED).unwrap();
r.read_to_end(&mut buf).unwrap();
assert_eq!(ORIGINAL, str::from_utf8(&buf[..]).unwrap());
}
}