use bytes::{BufMut, Bytes, BytesMut};
use crate::CompressionError;
/// Sixteen-byte preamble that opens every xerial-framed snappy stream.
///
/// Layout: an 8-byte magic (`0x82 "SNAPPY" 0x00`) followed by two big-endian
/// 32-bit fields that are both 1. In snappy-java's stream format these are the
/// writer version and the minimum compatible version.
#[rustfmt::skip]
const XERIAL_HEADER: [u8; 16] = [
    // Magic.
    0x82, b'S', b'N', b'A', b'P', b'P', b'Y', 0x00,
    // Version.
    0x00, 0x00, 0x00, 0x01,
    // Minimum compatible version.
    0x00, 0x00, 0x00, 0x01,
];

/// Upper bound, in uncompressed bytes, on the input slice encoded into a
/// single frame (32 KiB).
const XERIAL_CHUNK_SIZE: usize = 32 * 1024;
/// Compresses `data` into a xerial-framed snappy stream.
///
/// The output starts with [`XERIAL_HEADER`]. The input is then split into
/// slices of at most [`XERIAL_CHUNK_SIZE`] bytes, and each slice is written as
/// one frame: a big-endian `u32` byte count followed by that many bytes of raw
/// snappy data. Empty input produces the header alone.
///
/// # Errors
///
/// Returns [`CompressionError::InvalidData`] if the snappy encoder reports an
/// error for a chunk.
///
/// # Panics
///
/// Panics if a compressed frame length does not fit in a `u32`. Frames are
/// built from at most [`XERIAL_CHUNK_SIZE`] input bytes, so this indicates a
/// broken invariant rather than bad input.
pub fn compress(data: &[u8]) -> Result<Bytes, CompressionError> {
    let mut out = BytesMut::with_capacity(XERIAL_HEADER.len() + data.len());
    out.put_slice(&XERIAL_HEADER);

    let mut encoder = snap::raw::Encoder::new();

    // One scratch buffer, sized for the largest chunk this call can produce,
    // is reused for every frame instead of allocating a new one per chunk.
    let largest_chunk = XERIAL_CHUNK_SIZE.min(data.len());
    let mut scratch = vec![0u8; snap::raw::max_compress_len(largest_chunk)];

    for chunk in data.chunks(XERIAL_CHUNK_SIZE) {
        let n = encoder
            .compress(chunk, &mut scratch)
            .map_err(|e| CompressionError::InvalidData(format!("snappy encode: {e}")))?;
        // Frame = big-endian length prefix + compressed bytes.
        out.put_u32(u32::try_from(n).expect("chunk size fits u32"));
        out.put_slice(&scratch[..n]);
    }

    Ok(out.freeze())
}
/// Decompresses a xerial-framed snappy stream, producing at most `max_output`
/// bytes.
///
/// `data` must begin with a 16-byte header whose first eight bytes equal the
/// magic in [`XERIAL_HEADER`]; the remaining eight header bytes are skipped
/// without validation. The rest of `data` is read as a sequence of frames,
/// each a big-endian `u32` length followed by that many bytes of raw snappy
/// data. A stream with no frames yields empty output.
///
/// # Errors
///
/// * [`CompressionError::InvalidData`] if the header is too short, the magic
///   does not match, a frame length prefix or frame body is truncated, or the
///   snappy decoder rejects a frame.
/// * [`CompressionError::TooLarge`] if the output produced so far plus the
///   next frame's declared decompressed length would exceed `max_output`.
///   This is checked before any space is reserved for that frame.
pub fn decompress(data: &[u8], max_output: usize) -> Result<Bytes, CompressionError> {
    if data.len() < XERIAL_HEADER.len() {
        return Err(CompressionError::InvalidData(
            "snappy payload too short for xerial header".into(),
        ));
    }
    // Only the 8-byte magic is compared; the two trailing header fields are ignored.
    if data[..8] != XERIAL_HEADER[..8] {
        return Err(CompressionError::InvalidData(
            "snappy missing xerial magic".into(),
        ));
    }

    let mut rest = &data[XERIAL_HEADER.len()..];
    // Initial capacity is a guess (2x the input), never more than the caller's cap.
    let mut out = BytesMut::with_capacity(data.len().saturating_mul(2).min(max_output));
    let mut decoder = snap::raw::Decoder::new();

    while !rest.is_empty() {
        if rest.len() < 4 {
            return Err(CompressionError::InvalidData(
                "snappy chunk header truncated".into(),
            ));
        }
        let (prefix, after_prefix) = rest.split_at(4);
        let len = u32::from_be_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;

        if after_prefix.len() < len {
            return Err(CompressionError::InvalidData(
                "snappy chunk body truncated".into(),
            ));
        }
        let (block, tail) = after_prefix.split_at(len);
        rest = tail;

        // The declared length comes from the input, so it is checked against
        // the remaining budget before any memory is reserved for it.
        let max_out = snap::raw::decompress_len(block)
            .map_err(|e| CompressionError::InvalidData(format!("snappy decode_len: {e}")))?;
        let start = out.len();
        if start.saturating_add(max_out) > max_output {
            return Err(CompressionError::TooLarge { limit: max_output });
        }

        // Decode straight into the tail of `out`; this avoids a temporary
        // buffer and a second copy for every frame. `start + max_out` cannot
        // overflow because it was just shown to be <= `max_output`.
        out.resize(start + max_out, 0);
        let n = decoder
            .decompress(block, &mut out[start..])
            .map_err(|e| CompressionError::InvalidData(format!("snappy decode: {e}")))?;
        out.truncate(start + n);
    }

    Ok(out.freeze())
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;

    /// Short sample payload that fits in a single frame.
    const HELLO: &[u8] = b"hello kafka, this is a moderately repetitive payload to compress";
    /// Output cap large enough that no payload in these tests reaches it by accident.
    const BIG_CAP: usize = 256 * 1024 * 1024;

    /// A small payload survives compress followed by decompress.
    #[test]
    fn roundtrip() {
        let z = compress(HELLO).unwrap();
        let back = decompress(&z, BIG_CAP).unwrap();
        assert!(back.as_ref() == HELLO);
    }

    /// Empty input compresses to the bare header and decompresses to nothing.
    #[test]
    fn roundtrip_empty() {
        let z = compress(&[]).unwrap();
        assert!(z.len() == XERIAL_HEADER.len());
        let back = decompress(&z, BIG_CAP).unwrap();
        assert!(back.is_empty());
    }

    /// Input spanning several frames (with a short final one) is reassembled
    /// byte-for-byte.
    #[test]
    fn roundtrip_multi_chunk() {
        let data: Vec<u8> = (0u8..=250)
            .cycle()
            .take(3 * XERIAL_CHUNK_SIZE + 17)
            .collect();
        let z = compress(&data).unwrap();
        let back = decompress(&z, BIG_CAP).unwrap();
        assert!(back.as_ref() == data.as_slice());
    }

    /// Input shorter than the 16-byte header is rejected.
    #[test]
    fn decompress_truncated_header() {
        assert!(matches!(
            decompress(&XERIAL_HEADER[..4], BIG_CAP),
            Err(CompressionError::InvalidData(_))
        ));
    }

    /// Input that is long enough but does not start with the magic is rejected.
    #[test]
    fn decompress_missing_magic() {
        let bytes = [0u8; 20];
        assert!(matches!(
            decompress(&bytes, BIG_CAP),
            Err(CompressionError::InvalidData(_))
        ));
    }

    /// A frame whose length prefix is cut short is rejected.
    #[test]
    fn decompress_truncated_chunk_header() {
        let mut bytes = XERIAL_HEADER.to_vec();
        // Only two of the four length-prefix bytes are present.
        bytes.extend_from_slice(&[0, 0]);
        assert!(matches!(
            decompress(&bytes, BIG_CAP),
            Err(CompressionError::InvalidData(_))
        ));
    }

    /// A frame whose body is shorter than its length prefix is rejected.
    #[test]
    fn decompress_truncated_chunk() {
        let mut bytes = XERIAL_HEADER.to_vec();
        // The prefix claims a 100-byte frame...
        bytes.extend_from_slice(&[0, 0, 0, 100]);
        // ...but only a single byte follows.
        bytes.push(0);
        assert!(matches!(
            decompress(&bytes, BIG_CAP),
            Err(CompressionError::InvalidData(_))
        ));
    }

    /// Output exactly as large as the cap is accepted; one byte less is not.
    #[test]
    fn output_limit_is_inclusive() {
        let z = compress(HELLO).unwrap();
        assert!(decompress(&z, HELLO.len()).is_ok());
        assert!(matches!(
            decompress(&z, HELLO.len() - 1),
            Err(CompressionError::TooLarge { .. })
        ));
    }

    /// Highly compressible input is refused under a small cap and accepted
    /// under a large one.
    #[test]
    fn decompression_bomb_rejected() {
        // 64 MiB of zeros.
        let bomb = vec![0u8; 64 * 1024 * 1024];
        let z = compress(&bomb).unwrap();
        assert!(matches!(
            decompress(&z, 1024),
            Err(CompressionError::TooLarge { limit: 1024 })
        ));
        let back = decompress(&z, BIG_CAP).unwrap();
        assert!(back.len() == bomb.len());
    }
}