use super::common::BinaryTransform;
use bytes::{Buf, Bytes};
use zstd::bulk::{compress, decompress, Compressor, Decompressor};
use zstd::stream::decode_all;
/// zstd compressor exposed through [`BinaryTransform`].
///
/// The derived `Default` yields `level == 0` and a default-constructed
/// reusable compression context.
#[derive(Default)]
pub struct ZstdCompressor {
    /// Compression level passed to the one-shot `compress` call in `transform`.
    level: i32,
    /// Reusable compression context used by `transform_mut`.
    inner: Compressor<'static>,
}
impl ZstdCompressor {
    /// Creates a compressor that uses compression `level` both for the
    /// one-shot `transform` path and for the reusable context behind
    /// `transform_mut`.
    ///
    /// zstd treats level `0` as "use the library's default level", which is
    /// also the level value produced by `ZstdCompressor::default()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying zstd compression context cannot be
    /// configured with `level`.
    pub fn new(level: i32) -> std::io::Result<Self> {
        // `Compressor::new` configures the level directly; there is no need to
        // go through `set_dictionary` with an empty dictionary to do so.
        Ok(Self {
            level,
            inner: Compressor::new(level)?,
        })
    }
}
impl BinaryTransform for ZstdCompressor {
    /// Compresses `data` in one shot at `self.level`; the reusable context in
    /// `self.inner` is not touched, so this only needs `&self`.
    fn transform(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
        let compressed = compress(&data[..], self.level)?;
        Ok(Bytes::from(compressed))
    }

    /// Compresses `data` through the reusable context held in `self.inner`.
    fn transform_mut(&mut self, data: Bytes) -> Result<Bytes, std::io::Error> {
        let compressed = self.inner.compress(&data[..])?;
        Ok(compressed.into())
    }
}
#[derive(Default)]
pub struct ZstdDecompressor {
inner: Decompressor<'static>,
}
impl ZstdCompressor {}
/// Largest `decompress_bound` result trusted for an up-front output allocation.
///
/// The bound is read from the frame header, which a caller-supplied payload
/// fully controls: a tiny frame can claim a multi-terabyte content size and make
/// a bulk decompression preallocate (and abort on) that much memory. Bounds above
/// this threshold are decoded with the streaming decoder instead, which grows its
/// output only as real data is produced. This is a preallocation threshold, not a
/// limit on the decompressed size of valid data.
const MAX_PREALLOCATED_BOUND: u64 = 64 * 1024 * 1024;

/// Returns the output capacity to preallocate for bulk-decompressing `data`.
///
/// Returns `None` when zstd cannot compute a bound, when the bound exceeds
/// [`MAX_PREALLOCATED_BOUND`], or when it does not fit in `usize`. In all of these
/// cases callers fall back to streaming decompression.
fn bulk_capacity(data: &[u8]) -> Option<usize> {
    let bound = zstd_safe::decompress_bound(data).ok()?;
    if bound > MAX_PREALLOCATED_BOUND {
        return None;
    }
    // Checked conversion: a plain `as usize` would silently truncate on narrow
    // targets and hand zstd a buffer that is too small.
    std::convert::TryFrom::try_from(bound).ok()
}

impl BinaryTransform for ZstdDecompressor {
    /// Decompresses `data` with a one-shot bulk call when a trustworthy size
    /// bound is available, otherwise with the streaming decoder.
    ///
    /// # Errors
    ///
    /// Returns the zstd error if `data` is not valid zstd-compressed input.
    fn transform(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
        match bulk_capacity(&data) {
            Some(capacity) => decompress(&data, capacity).map(Bytes::from),
            None => decode_all(data.reader()).map(Bytes::from),
        }
    }

    /// Like `transform`, but the bulk path reuses the decompression context in
    /// `self.inner`.
    ///
    /// # Errors
    ///
    /// Returns the zstd error if `data` is not valid zstd-compressed input.
    fn transform_mut(&mut self, data: Bytes) -> Result<Bytes, std::io::Error> {
        match bulk_capacity(&data) {
            Some(capacity) => self.inner.decompress(&data, capacity).map(Bytes::from),
            None => decode_all(data.reader()).map(Bytes::from),
        }
    }
}
#[cfg(test)]
mod test {
    use super::{ZstdCompressor, ZstdDecompressor};
    use crate::transform::BinaryTransform;
    use bytes::Bytes;
    use lipsum::lipsum;

    /// Compressible sample payload shared by the tests below.
    fn sample_data() -> Bytes {
        Bytes::from(lipsum(1000))
    }

    #[test]
    fn test_roundtrip_default() {
        let compressor = ZstdCompressor::default();
        let decompressor = ZstdDecompressor::default();
        let data = sample_data();
        let compressed = compressor
            .transform(data.clone())
            .expect("should compress via zstd");
        let uncompressed = decompressor
            .transform(compressed)
            .expect("should decompress from zstd");
        assert_eq!(data, uncompressed);
    }

    #[test]
    fn test_roundtrip_mut_default() {
        let mut compressor = ZstdCompressor::default();
        let mut decompressor = ZstdDecompressor::default();
        let data = sample_data();
        let compressed = compressor
            .transform_mut(data.clone())
            .expect("should compress via zstd");
        let uncompressed = decompressor
            .transform_mut(compressed)
            .expect("should decompress from zstd");
        assert_eq!(data, uncompressed);
    }

    /// `ZstdCompressor::new` with a non-default level must round-trip through
    /// both the one-shot path and the reusable-context path.
    #[test]
    fn test_roundtrip_explicit_level() {
        let mut compressor = ZstdCompressor::new(19).expect("level 19 should be accepted");
        let mut decompressor = ZstdDecompressor::default();
        let data = sample_data();

        let compressed = compressor
            .transform(data.clone())
            .expect("should compress via zstd");
        let uncompressed = decompressor
            .transform(compressed)
            .expect("should decompress from zstd");
        assert_eq!(data, uncompressed);

        let compressed = compressor
            .transform_mut(data.clone())
            .expect("should compress via zstd");
        let uncompressed = decompressor
            .transform_mut(compressed)
            .expect("should decompress from zstd");
        assert_eq!(data, uncompressed);
    }

    #[test]
    fn test_compressor_adds_size() {
        let mut compressor = ZstdCompressor::default();
        let data = sample_data();
        let compressed = compressor
            .transform_mut(data.clone())
            .expect("should compress via zstd");
        let size_bound = zstd_safe::decompress_bound(&compressed)
            .expect("should be able to determine decompress bound");
        assert_eq!(size_bound as usize, data.len());
    }
}