use std::io;
/// Compression level handed to zstd, wrapped in a newtype.
///
/// The wrapped `i32` is public, so any level can be built directly with
/// `CompressionLevel(n)`; the associated constants name the commonly used ones.
#[derive(Debug, Clone, Copy)]
pub struct CompressionLevel(pub i32);

impl CompressionLevel {
    /// Level 1.
    pub const FAST: Self = CompressionLevel(1);
    /// Level 3.
    pub const DEFAULT: Self = CompressionLevel(3);
    /// Level 9.
    pub const BETTER: Self = CompressionLevel(9);
    /// Level 19.
    pub const BEST: Self = CompressionLevel(19);
    /// Level 22.
    pub const MAX: Self = CompressionLevel(22);
}

impl Default for CompressionLevel {
    /// Returns [`CompressionLevel::FAST`] (level 1).
    ///
    /// NOTE(review): this is *not* [`CompressionLevel::DEFAULT`] (level 3);
    /// confirm the mismatch between the trait default and the constant's name
    /// is intentional.
    fn default() -> Self {
        CompressionLevel::FAST
    }
}
/// Raw bytes of a zstd dictionary, as consumed by [`compress_with_dict`] and
/// [`decompress_with_dict`].
#[derive(Clone)]
pub struct CompressionDict {
    /// The dictionary content exactly as it is handed to zstd.
    raw_dict: crate::directories::OwnedBytes,
}

impl CompressionDict {
    /// Trains a dictionary from `samples` with `zstd::dict::from_samples`,
    /// passing `dict_size` as the size argument.
    ///
    /// # Errors
    ///
    /// Any training failure is re-wrapped with `io::Error::other`.
    pub fn train(samples: &[&[u8]], dict_size: usize) -> io::Result<Self> {
        match zstd::dict::from_samples(samples, dict_size) {
            Ok(trained) => Ok(Self::from_bytes(trained)),
            Err(err) => Err(io::Error::other(err)),
        }
    }

    /// Wraps an already-built dictionary held in a `Vec<u8>`.
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        let raw_dict = crate::directories::OwnedBytes::new(bytes);
        Self { raw_dict }
    }

    /// Wraps an already-built dictionary held in an `OwnedBytes` buffer.
    pub fn from_owned_bytes(bytes: crate::directories::OwnedBytes) -> Self {
        CompressionDict { raw_dict: bytes }
    }

    /// Borrows the dictionary content.
    pub fn as_bytes(&self) -> &[u8] {
        let buffer = &self.raw_dict;
        buffer.as_slice()
    }

    /// Length of the dictionary as reported by the underlying buffer.
    pub fn len(&self) -> usize {
        let buffer = &self.raw_dict;
        buffer.len()
    }

    /// Whether the underlying buffer reports itself as empty.
    pub fn is_empty(&self) -> bool {
        let buffer = &self.raw_dict;
        buffer.is_empty()
    }
}
pub fn compress(data: &[u8], level: CompressionLevel) -> io::Result<Vec<u8>> {
thread_local! {
static COMPRESSOR: std::cell::RefCell<Option<(i32, zstd::bulk::Compressor<'static>)>> =
const { std::cell::RefCell::new(None) };
}
COMPRESSOR.with(|cell| {
let mut slot = cell.borrow_mut();
if slot.as_ref().is_none_or(|(l, _)| *l != level.0) {
let cmp = zstd::bulk::Compressor::new(level.0).map_err(io::Error::other)?;
*slot = Some((level.0, cmp));
}
slot.as_mut()
.unwrap()
.1
.compress(data)
.map_err(io::Error::other)
})
}
/// Compresses `data` with zstd at the given `level`, using `dict` as the
/// compression dictionary.
///
/// A `zstd::bulk::Compressor` loaded with the dictionary is cached per thread
/// and reused while both the level and the dictionary identity stay the same.
///
/// The dictionary identity is its buffer address, its length and a
/// fingerprint of its first and last bytes. The address alone is not enough:
/// once a dictionary is dropped, a different one can be placed at the same
/// address, and the cached compressor would then silently keep using the old
/// dictionary. The fingerprint samples a bounded number of bytes so the check
/// stays cheap; two dictionaries that share address, length, head and tail
/// would still be treated as the same one.
///
/// # Errors
///
/// Failures to create the compressor or to compress are re-wrapped with
/// `io::Error::other`.
pub fn compress_with_dict(
    data: &[u8],
    level: CompressionLevel,
    dict: &CompressionDict,
) -> io::Result<Vec<u8>> {
    // Builds the cache key: (address, length, fingerprint of head and tail).
    fn dict_identity(bytes: &[u8]) -> (usize, usize, u64) {
        const SAMPLE: usize = 64;
        let head = &bytes[..bytes.len().min(SAMPLE)];
        let tail = &bytes[bytes.len().saturating_sub(SAMPLE)..];
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        std::hash::Hasher::write(&mut hasher, head);
        std::hash::Hasher::write(&mut hasher, tail);
        (
            bytes.as_ptr() as usize,
            bytes.len(),
            std::hash::Hasher::finish(&hasher),
        )
    }

    thread_local! {
        static DICT_CMP: std::cell::RefCell<
            Option<((usize, usize, u64), i32, zstd::bulk::Compressor<'static>)>,
        > = const { std::cell::RefCell::new(None) };
    }

    let dict_bytes = dict.as_bytes();
    let dict_key = dict_identity(dict_bytes);
    DICT_CMP.with(|cell| {
        let mut slot = cell.borrow_mut();
        if slot
            .as_ref()
            .is_none_or(|(k, l, _)| *k != dict_key || *l != level.0)
        {
            let cmp = zstd::bulk::Compressor::with_dictionary(level.0, dict_bytes)
                .map_err(io::Error::other)?;
            *slot = Some((dict_key, level.0, cmp));
        }
        let Some((_, _, cmp)) = slot.as_mut() else {
            unreachable!("the cache slot is always populated above")
        };
        cmp.compress(data).map_err(io::Error::other)
    })
}
/// Capacity (512 KiB) passed to the bulk decompressor by `decompress` and
/// `decompress_with_dict`. When the bulk call fails, both functions retry
/// with a streaming decoder instead.
const DECOMPRESS_CAPACITY: usize = 512 * 1024;
/// Decompresses a zstd frame produced without a dictionary.
///
/// A per-thread `zstd::bulk::Decompressor` is tried first with an output
/// capacity of `DECOMPRESS_CAPACITY`. If that call fails for any reason, the
/// input is decoded again with `zstd::decode_all`, and that second result
/// (success or error) is what the caller receives.
///
/// # Errors
///
/// Returns the error from creating the per-thread decompressor, or the error
/// from the streaming fallback when both decoding attempts fail.
pub fn decompress(data: &[u8]) -> io::Result<Vec<u8>> {
    thread_local! {
        static DECOMPRESSOR: std::cell::RefCell<Option<zstd::bulk::Decompressor<'static>>> =
            const { std::cell::RefCell::new(None) };
    }
    DECOMPRESSOR.with(|cell| {
        let mut slot = cell.borrow_mut();
        // Create the context lazily so a failure is reported as an `io::Error`
        // instead of panicking inside the thread-local initializer.
        if slot.is_none() {
            *slot = Some(zstd::bulk::Decompressor::new()?);
        }
        let Some(dc) = slot.as_mut() else {
            unreachable!("the cache slot is always populated above")
        };
        dc.decompress(data, DECOMPRESS_CAPACITY)
            .or_else(|_| zstd::decode_all(data))
    })
}
/// Decompresses a zstd frame that was compressed with `dict`.
///
/// A `zstd::bulk::Decompressor` loaded with the dictionary is cached per
/// thread and reused while the dictionary identity stays the same. It is tried
/// first with an output capacity of `DECOMPRESS_CAPACITY`; if that call fails
/// for any reason, the input is decoded again with a streaming
/// `zstd::Decoder` using the same dictionary.
///
/// The dictionary identity is its buffer address, its length and a
/// fingerprint of its first and last bytes. The address alone is not enough:
/// once a dictionary is dropped, a different one can be placed at the same
/// address, and the cached decompressor would then keep using the old
/// dictionary. The fingerprint samples a bounded number of bytes so the check
/// stays cheap; two dictionaries that share address, length, head and tail
/// would still be treated as the same one.
///
/// # Errors
///
/// Returns the error from creating the decompressor (re-wrapped with
/// `io::Error::other`), or the error from the streaming fallback when both
/// decoding attempts fail.
pub fn decompress_with_dict(data: &[u8], dict: &CompressionDict) -> io::Result<Vec<u8>> {
    // Builds the cache key: (address, length, fingerprint of head and tail).
    fn dict_identity(bytes: &[u8]) -> (usize, usize, u64) {
        const SAMPLE: usize = 64;
        let head = &bytes[..bytes.len().min(SAMPLE)];
        let tail = &bytes[bytes.len().saturating_sub(SAMPLE)..];
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        std::hash::Hasher::write(&mut hasher, head);
        std::hash::Hasher::write(&mut hasher, tail);
        (
            bytes.as_ptr() as usize,
            bytes.len(),
            std::hash::Hasher::finish(&hasher),
        )
    }

    thread_local! {
        static DICT_DC: std::cell::RefCell<
            Option<((usize, usize, u64), zstd::bulk::Decompressor<'static>)>,
        > = const { std::cell::RefCell::new(None) };
    }

    let dict_bytes = dict.as_bytes();
    let dict_key = dict_identity(dict_bytes);
    DICT_DC.with(|cell| {
        let mut slot = cell.borrow_mut();
        if slot.as_ref().is_none_or(|(k, _)| *k != dict_key) {
            let dc = zstd::bulk::Decompressor::with_dictionary(dict_bytes)
                .map_err(io::Error::other)?;
            *slot = Some((dict_key, dc));
        }
        let Some((_, dc)) = slot.as_mut() else {
            unreachable!("the cache slot is always populated above")
        };
        dc.decompress(data, DECOMPRESS_CAPACITY).or_else(|_| {
            // Bulk decoding failed; retry with a streaming decoder.
            let mut decoder = zstd::Decoder::with_dictionary(data, dict_bytes)?;
            let mut output = Vec::new();
            io::Read::read_to_end(&mut decoder, &mut output)?;
            Ok(output)
        })
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compressible data survives a compress/decompress cycle and shrinks.
    #[test]
    fn test_roundtrip() {
        let data = b"Hello, World! This is a test of compression.".repeat(100);
        let compressed = compress(&data, CompressionLevel::default()).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(data, decompressed.as_slice());
        assert!(compressed.len() < data.len());
    }

    /// Empty input round-trips to empty output.
    #[test]
    fn test_empty_data() {
        let data: &[u8] = &[];
        let compressed = compress(data, CompressionLevel::default()).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert!(decompressed.is_empty());
    }

    /// Switching levels between calls on one thread still round-trips.
    #[test]
    fn test_compression_levels() {
        let data = b"Test data for compression levels".repeat(100);
        for level in [1, 3, 9, 19] {
            let compressed = compress(&data, CompressionLevel(level)).unwrap();
            let decompressed = decompress(&compressed).unwrap();
            assert_eq!(data.as_slice(), decompressed.as_slice());
        }
    }

    /// Output larger than `DECOMPRESS_CAPACITY` does not fit the bulk buffer,
    /// so this exercises the streaming fallback in `decompress`.
    #[test]
    fn test_roundtrip_exceeding_bulk_capacity() {
        let data = vec![0xAB_u8; DECOMPRESS_CAPACITY * 2];
        let compressed = compress(&data, CompressionLevel::default()).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(data, decompressed);
    }

    /// Data compressed with a dictionary decompresses with the same one.
    #[test]
    fn test_dict_roundtrip() {
        let dict = CompressionDict::from_bytes(
            b"Hello, World! This is a test of compression.".repeat(8),
        );
        assert!(!dict.is_empty());
        assert_eq!(dict.len(), dict.as_bytes().len());

        let data = b"Hello, World! This is a test of compression.".repeat(20);
        let compressed = compress_with_dict(&data, CompressionLevel::default(), &dict).unwrap();
        let decompressed = decompress_with_dict(&compressed, &dict).unwrap();
        assert_eq!(data, decompressed);
    }
}