use crate::{Error, Result};
/// Largest dictionary, in bytes, that this module will train or accept (112 KiB).
pub const MAX_DICT_SIZE: usize = 112 * 1024;

/// Number of training samples used by this module's own tests; presumably the
/// intended sample count for callers as well — confirm against call sites.
pub const N_TRAIN: usize = 32;

/// zstd compression level used by `ZstdDictCompressor::compress`.
pub const DEFAULT_LEVEL: i32 = 3;

/// First four bytes of a zstd dictionary: the magic number `0xEC30A437`
/// serialized little-endian, i.e. `[0x37, 0xA4, 0x30, 0xEC]`.
const ZSTD_MAGIC: [u8; 4] = 0xEC30_A437_u32.to_le_bytes();
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZstdDictionary(Vec<u8>);
impl ZstdDictionary {
fn new_checked(bytes: Vec<u8>) -> Result<Self> {
if bytes.is_empty() {
return Err(Error::CompressionError("zstd: empty dictionary".into()));
}
if bytes.len() > MAX_DICT_SIZE {
return Err(Error::CompressionError(format!(
"zstd: dictionary size {} exceeds MAX_DICT_SIZE ({})",
bytes.len(),
MAX_DICT_SIZE
)));
}
if bytes.len() < 4 || bytes[0..4] != ZSTD_MAGIC {
return Err(Error::CompressionError(
"zstd: invalid dictionary magic (expected 0xEC30A437)".into(),
));
}
Ok(Self(bytes))
}
pub fn from_bytes(bytes: Vec<u8>) -> Result<Self> {
Self::new_checked(bytes)
}
pub fn as_bytes(&self) -> &[u8] {
&self.0
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
pub struct ZstdDictCompressor;
impl ZstdDictCompressor {
pub fn train(samples: &[Vec<u8>], max_dict_size: usize) -> Result<ZstdDictionary> {
if samples.len() < 8 {
return Err(Error::CompressionError(format!(
"zstd: insufficient samples ({} provided, need >= 8)",
samples.len()
)));
}
let cap = max_dict_size.min(MAX_DICT_SIZE);
let bytes = zstd::dict::from_samples(samples, cap)
.map_err(|e| Error::CompressionError(format!("zstd: train: {e}")))?;
ZstdDictionary::new_checked(bytes)
}
pub fn compress(data: &[u8], dict: &ZstdDictionary) -> Result<Vec<u8>> {
Self::compress_with_level(data, dict, DEFAULT_LEVEL)
}
pub fn compress_with_level(data: &[u8], dict: &ZstdDictionary, level: i32) -> Result<Vec<u8>> {
let mut compressor = zstd::bulk::Compressor::with_dictionary(level, dict.as_bytes())
.map_err(|e| Error::CompressionError(format!("zstd: compressor init: {e}")))?;
compressor
.compress(data)
.map_err(|e| Error::CompressionError(format!("zstd: compress: {e}")))
}
pub fn decompress(data: &[u8], dict: &ZstdDictionary, max_output: usize) -> Result<Vec<u8>> {
let mut decompressor = zstd::bulk::Decompressor::with_dictionary(dict.as_bytes())
.map_err(|e| Error::CompressionError(format!("zstd: decompressor init: {e}")))?;
decompressor
.decompress(data, max_output)
.map_err(|e| Error::CompressionError(format!("zstd: decompress: {e}")))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` small, structurally similar JSON documents to train on.
    fn make_samples(count: usize) -> Vec<Vec<u8>> {
        (0..count)
            .map(|i| {
                format!(
                    r#"{{"id":{i},"name":"item-{i}","value":{val},"active":true}}"#,
                    val = i * 10
                )
                .into_bytes()
            })
            .collect()
    }

    /// One JSON object repeated 100 times: highly compressible input.
    fn repetitive_json() -> Vec<u8> {
        let item = br#"{"id":1,"name":"test","value":42,"active":true}"#;
        item.repeat(100)
    }

    /// Trains a dictionary on `N_TRAIN` samples at the maximum allowed size.
    fn trained_dict() -> ZstdDictionary {
        ZstdDictCompressor::train(&make_samples(N_TRAIN), MAX_DICT_SIZE)
            .expect("training on N_TRAIN samples should succeed")
    }

    #[test]
    fn test_train_compress_decompress_roundtrip() {
        let dict = trained_dict();
        let data = repetitive_json();
        let compressed = ZstdDictCompressor::compress(&data, &dict).unwrap();
        let decompressed =
            ZstdDictCompressor::decompress(&compressed, &dict, data.len() * 2).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_train_insufficient_samples_error() {
        let samples = make_samples(3);
        let err = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("insufficient samples"),
            "error should mention insufficient samples: {msg}"
        );
    }

    #[test]
    fn test_train_rejects_one_below_minimum_samples() {
        // Seven samples is the largest count that must still be refused.
        let samples = make_samples(7);
        let err = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("insufficient samples"),
            "7 samples should be rejected before training: {msg}"
        );
    }

    #[test]
    fn test_train_clamps_to_max_dict_size() {
        let samples = make_samples(N_TRAIN);
        let dict = ZstdDictCompressor::train(&samples, usize::MAX).unwrap();
        assert!(
            dict.len() <= MAX_DICT_SIZE,
            "dict size {} exceeds MAX_DICT_SIZE",
            dict.len()
        );
    }

    #[test]
    fn test_from_bytes_rejects_empty() {
        assert!(ZstdDictionary::from_bytes(vec![]).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_invalid_magic() {
        assert!(ZstdDictionary::from_bytes(vec![0x00, 0x01, 0x02, 0x03]).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_short_input() {
        // A correct-but-truncated magic prefix must not slip through.
        assert!(ZstdDictionary::from_bytes(ZSTD_MAGIC[..2].to_vec()).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_oversized() {
        // Exactly one byte over the limit.
        let mut bytes = ZSTD_MAGIC.to_vec();
        bytes.resize(MAX_DICT_SIZE + 1, 0);
        assert!(ZstdDictionary::from_bytes(bytes).is_err());
    }

    #[test]
    fn test_from_bytes_accepts_exactly_max_size() {
        let mut bytes = ZSTD_MAGIC.to_vec();
        bytes.resize(MAX_DICT_SIZE, 0);
        let dict = ZstdDictionary::from_bytes(bytes).unwrap();
        assert_eq!(dict.len(), MAX_DICT_SIZE);
    }

    #[test]
    fn test_from_bytes_roundtrips_trained_dictionary() {
        let dict = trained_dict();
        let restored = ZstdDictionary::from_bytes(dict.as_bytes().to_vec()).unwrap();
        assert_eq!(restored, dict);

        // Data compressed with the original decompresses with the restored copy,
        // even when `max_output` is exactly the decompressed length.
        let data = repetitive_json();
        let compressed = ZstdDictCompressor::compress(&data, &dict).unwrap();
        let decompressed =
            ZstdDictCompressor::decompress(&compressed, &restored, data.len()).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_with_level() {
        let dict = trained_dict();
        let data = repetitive_json();
        for level in [1, 9] {
            let c = ZstdDictCompressor::compress_with_level(&data, &dict, level).unwrap();
            let d = ZstdDictCompressor::decompress(&c, &dict, data.len() * 2).unwrap();
            assert_eq!(d, data, "level {level} roundtrip failed");
        }
    }

    #[test]
    fn test_decompress_rejects_output_larger_than_max() {
        let dict = trained_dict();
        let data = repetitive_json();
        let compressed = ZstdDictCompressor::compress(&data, &dict).unwrap();
        assert!(ZstdDictCompressor::decompress(&compressed, &dict, data.len() / 2).is_err());
    }

    #[test]
    fn test_decompress_rejects_non_zstd_input() {
        let dict = trained_dict();
        assert!(ZstdDictCompressor::decompress(b"definitely not zstd", &dict, 1024).is_err());
    }

    #[test]
    fn test_empty_input_roundtrip() {
        let dict = trained_dict();
        let compressed = ZstdDictCompressor::compress(&[], &dict).unwrap();
        let decompressed = ZstdDictCompressor::decompress(&compressed, &dict, 64).unwrap();
        assert!(decompressed.is_empty());
    }

    #[test]
    fn test_dictionary_equality() {
        let d1 = trained_dict();
        let d2 = d1.clone();
        assert_eq!(d1, d2);
    }

    #[test]
    fn test_is_empty_is_always_false_for_valid_dict() {
        let dict = trained_dict();
        assert!(!dict.is_empty());
    }
}