#[cfg(feature = "zstd")]
mod zstd_backend;
use crate::coding::{Decode, Encode};
use byteorder::{ReadBytesExt, WriteBytesExt};
use std::io::{Read, Write};
#[cfg(zstd_any)]
use std::sync::Arc;
#[cfg(feature = "zstd")]
use once_cell::race::OnceBox;
/// Backend abstraction over a zstd implementation.
///
/// Every operation is an associated function (no receiver), so a backend is
/// selected purely at the type level, e.g. via the `ZstdBackend` alias.
#[cfg(zstd_any)]
pub trait CompressionProvider {
    /// Compresses `data` at the given zstd `level`.
    fn compress(data: &[u8], level: i32) -> crate::Result<Vec<u8>>;

    /// Compresses `data` and additionally returns a `Vec<u32>` describing the
    /// output layout.
    ///
    /// NOTE(review): the exact meaning of the `u32` entries is defined by the
    /// backend implementation, which is not visible here — confirm there.
    fn compress_with_layout(data: &[u8], level: i32) -> crate::Result<(Vec<u8>, Vec<u32>)>;

    /// Decompresses `data`; `capacity` is the caller-supplied output size hint.
    fn decompress(data: &[u8], capacity: usize) -> crate::Result<Vec<u8>>;

    /// Compresses `data` using the raw dictionary bytes `dict_raw`.
    fn compress_with_dict(data: &[u8], level: i32, dict_raw: &[u8]) -> crate::Result<Vec<u8>>;

    /// Decompresses `data` that was compressed with `dict`.
    fn decompress_with_dict(
        data: &[u8],
        dict: &ZstdDictionary,
        capacity: usize,
    ) -> crate::Result<Vec<u8>>;
}
#[cfg(feature = "zstd")]
/// The zstd backend used when the `zstd` feature is enabled; resolves to the
/// `ZstdProvider` type from the private `zstd_backend` module.
pub type ZstdBackend = zstd_backend::ZstdProvider;
/// A zstd dictionary together with a 64-bit fingerprint of its bytes.
///
/// Clones share the raw bytes and (with feature `zstd`) the lazily prepared
/// decoder handle through `Arc`s, so cloning never copies the dictionary.
#[cfg(zstd_any)]
pub struct ZstdDictionary {
    /// Dictionary bytes, shared between clones.
    raw: Arc<[u8]>,
    /// xxh3-64 of `raw`; dictionary equality is defined on this value.
    id: u64,
    /// Decoder-ready handle, parsed on first use and shared between clones.
    #[cfg(feature = "zstd")]
    prepared: Arc<OnceBox<structured_zstd::decoding::DictionaryHandle>>,
}
/// Cheap clone: bumps the reference counts of the shared byte buffer and of the
/// prepared-handle cell instead of copying either.
#[cfg(zstd_any)]
impl Clone for ZstdDictionary {
    fn clone(&self) -> Self {
        let raw = Arc::clone(&self.raw);
        #[cfg(feature = "zstd")]
        let prepared = Arc::clone(&self.prepared);

        Self {
            id: self.id,
            raw,
            #[cfg(feature = "zstd")]
            prepared,
        }
    }
}
/// Two dictionaries are equal when their 64-bit xxh3 fingerprints match; the
/// raw bytes themselves are not compared.
#[cfg(zstd_any)]
impl PartialEq for ZstdDictionary {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.id, other.id);
        lhs == rhs
    }
}

/// Fingerprint equality is reflexive, symmetric and transitive.
#[cfg(zstd_any)]
impl Eq for ZstdDictionary {}
#[cfg(zstd_any)]
impl ZstdDictionary {
    /// Wraps `raw` dictionary bytes and fingerprints them with xxh3-64.
    ///
    /// With feature `zstd`, parsing into a decoder handle is deferred until
    /// the first `prepared_handle` call.
    #[must_use]
    pub fn new(raw: &[u8]) -> Self {
        let id = compute_dict_id(raw);
        let shared: Arc<[u8]> = Arc::from(raw);
        Self {
            id,
            raw: shared,
            #[cfg(feature = "zstd")]
            prepared: Arc::new(OnceBox::new()),
        }
    }

    /// Returns the decoder-ready handle, parsing and caching it on first use.
    ///
    /// Bytes starting with the zstd dictionary magic are parsed as a finalized
    /// dictionary; anything else is treated as raw content and tagged with the
    /// lower 32 bits of the fingerprint (clamped to at least 1).
    ///
    /// # Errors
    ///
    /// Returns `Error::Io` when parsing fails. A failed parse leaves the cache
    /// empty, so a later call retries.
    #[cfg(feature = "zstd")]
    pub(crate) fn prepared_handle(
        &self,
    ) -> crate::Result<structured_zstd::decoding::DictionaryHandle> {
        use structured_zstd::decoding::{Dictionary, DictionaryHandle};

        // Magic number prefixing a finalized zstd dictionary.
        const DICT_MAGIC: [u8; 4] = [0x37, 0xA4, 0x30, 0xEC];

        let parse = || -> crate::Result<Box<DictionaryHandle>> {
            if !self.raw.starts_with(&DICT_MAGIC) {
                #[expect(
                    clippy::cast_possible_truncation,
                    reason = "intentional: lower 32 bits of xxh3 as internal dict id (matches compressor)"
                )]
                let content_id = (self.id as u32).max(1);
                let dictionary = Dictionary::from_raw_content(content_id, self.raw.to_vec())
                    .map_err(|e| crate::Error::Io(std::io::Error::other(e)))?;
                return Ok(Box::new(DictionaryHandle::from_dictionary(dictionary)));
            }

            let finalized = DictionaryHandle::decode_dict(&self.raw)
                .map_err(|e| crate::Error::Io(std::io::Error::other(e)))?;
            Ok(Box::new(finalized))
        };

        self.prepared.get_or_try_init(parse).cloned()
    }

    /// Public 32-bit dictionary id: the lower half of the xxh3-64 fingerprint.
    #[must_use]
    #[expect(
        clippy::cast_possible_truncation,
        reason = "intentional: public API returns 32-bit fingerprint"
    )]
    pub fn id(&self) -> u32 {
        self.id as u32
    }

    /// Full 64-bit xxh3 fingerprint of the dictionary bytes.
    #[cfg(feature = "zstd")]
    #[must_use]
    pub(crate) fn id64(&self) -> u64 {
        self.id
    }

    /// Borrows the raw dictionary bytes exactly as passed to [`Self::new`].
    #[must_use]
    pub fn raw(&self) -> &[u8] {
        &self.raw
    }
}
/// Prints the fingerprint (zero-padded hex) and byte length only; the
/// dictionary contents and the prepared handle are deliberately elided.
#[cfg(zstd_any)]
impl std::fmt::Debug for ZstdDictionary {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ZstdDictionary");
        out.field("id", &format_args!("{:#018x}", self.id));
        out.field("size", &self.raw.len());
        out.finish_non_exhaustive()
    }
}
/// Fingerprints dictionary bytes with 64-bit xxh3.
#[cfg(zstd_any)]
fn compute_dict_id(raw: &[u8]) -> u64 {
    use xxhash_rust::xxh3::xxh3_64;

    xxh3_64(raw)
}
/// Block compression algorithm.
///
/// Encoded as a one-byte tag (`0` none, `1` lz4, `3` zstd, `4` zstd with a
/// dictionary) followed by a variant-specific payload.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionType {
    /// Data is stored uncompressed.
    None,

    /// LZ4 block compression.
    #[cfg(feature = "lz4")]
    Lz4,

    /// zstd at the given level (valid range `1..=22`).
    #[cfg(zstd_any)]
    Zstd(i32),

    /// zstd at `level` using the dictionary identified by `dict_id`.
    #[cfg(zstd_any)]
    ZstdDict { level: i32, dict_id: u32 },
}
impl CompressionType {
    /// Returns the dictionary id of a `ZstdDict` value, or `0` for every other
    /// variant.
    #[must_use]
    pub fn dict_id(&self) -> u32 {
        #[cfg(zstd_any)]
        {
            if let Self::ZstdDict { dict_id, .. } = *self {
                return dict_id;
            }
        }
        0
    }

    /// Checks that `level` lies in zstd's supported range `1..=22`.
    #[cfg(zstd_any)]
    fn validate_zstd_level(level: i32) -> crate::Result<()> {
        if (1..=22).contains(&level) {
            Ok(())
        } else {
            Err(crate::Error::Io(std::io::Error::other(format!(
                "invalid zstd compression level {level}, expected 1..=22"
            ))))
        }
    }

    /// Builds a validated [`CompressionType::Zstd`].
    ///
    /// # Errors
    ///
    /// Returns an error if `level` is outside `1..=22`.
    #[cfg(zstd_any)]
    pub fn zstd(level: i32) -> crate::Result<Self> {
        Self::validate_zstd_level(level).map(|()| Self::Zstd(level))
    }

    /// Builds a validated [`CompressionType::ZstdDict`].
    ///
    /// # Errors
    ///
    /// Returns an error if `level` is outside `1..=22`.
    #[cfg(zstd_any)]
    pub fn zstd_dict(level: i32, dict_id: u32) -> crate::Result<Self> {
        Self::validate_zstd_level(level).map(|()| Self::ZstdDict { level, dict_id })
    }
}
/// Short lowercase algorithm name; zstd levels and dictionary ids are omitted.
impl std::fmt::Display for CompressionType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::None => "none",
            #[cfg(feature = "lz4")]
            Self::Lz4 => "lz4",
            #[cfg(zstd_any)]
            Self::Zstd(_) => "zstd",
            #[cfg(zstd_any)]
            Self::ZstdDict { .. } => "zstd+dict",
        };
        f.write_str(name)
    }
}
/// Serializes a [`CompressionType`] as a one-byte tag followed by its payload.
///
/// Layout: `0` = none; `1` = lz4; `3` = zstd, then the level as `i8`;
/// `4` = zstd+dict, then the level as `i8` and the dictionary id as a
/// little-endian `u32`.
///
/// # Errors
///
/// Returns an error if the writer fails, or if a zstd level lies outside
/// `1..=22`. The level is checked before any byte is written, so an invalid
/// level never leaves a partial or undecodable record in the output.
impl Encode for CompressionType {
    fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), crate::Error> {
        match self {
            Self::None => {
                writer.write_u8(0)?;
            }
            #[cfg(feature = "lz4")]
            Self::Lz4 => {
                writer.write_u8(1)?;
            }
            #[cfg(zstd_any)]
            Self::Zstd(level) => {
                // Variants are publicly constructible, so e.g. `Zstd(200)` can
                // reach this point. A `debug_assert!` would let release builds
                // persist a truncated level byte that `Decode` then rejects;
                // fail the write instead.
                Self::validate_zstd_level(*level)?;
                writer.write_u8(3)?;
                #[expect(
                    clippy::cast_possible_truncation,
                    reason = "validated above: level range 1..=22 fits i8"
                )]
                writer.write_i8(*level as i8)?;
            }
            #[cfg(zstd_any)]
            Self::ZstdDict { level, dict_id } => {
                // Same write-side validation as `Zstd`, for the same reason.
                Self::validate_zstd_level(*level)?;
                writer.write_u8(4)?;
                #[expect(
                    clippy::cast_possible_truncation,
                    reason = "validated above: level range 1..=22 fits i8"
                )]
                writer.write_i8(*level as i8)?;
                writer.write_u32::<byteorder::LittleEndian>(*dict_id)?;
            }
        }
        Ok(())
    }
}
impl Decode for CompressionType {
fn decode_from<R: Read>(reader: &mut R) -> Result<Self, crate::Error> {
let tag = reader.read_u8()?;
match tag {
0 => Ok(Self::None),
#[cfg(feature = "lz4")]
1 => Ok(Self::Lz4),
#[cfg(zstd_any)]
3 => {
let level = i32::from(reader.read_i8()?);
Self::validate_zstd_level(level)?;
Ok(Self::Zstd(level))
}
#[cfg(zstd_any)]
4 => {
let level = i32::from(reader.read_i8()?);
Self::validate_zstd_level(level)?;
let dict_id = byteorder::ReadBytesExt::read_u32::<byteorder::LittleEndian>(reader)?;
Ok(Self::ZstdDict { level, dict_id })
}
tag => Err(crate::Error::InvalidTag(("CompressionType", tag))),
}
}
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::useless_vec,
    clippy::expect_used,
    reason = "test code"
)]
mod tests {
    use super::*;
    use test_log::test;

    #[test]
    fn compression_serialize_none() {
        let serialized = CompressionType::None.encode_into_vec();
        assert_eq!(1, serialized.len());
    }

    #[test]
    fn compression_roundtrip_none() {
        let serialized = CompressionType::None.encode_into_vec();
        let decoded =
            CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
        assert_eq!(CompressionType::None, decoded);
        assert_eq!(0, decoded.dict_id());
    }

    #[test]
    fn compression_decode_rejects_empty_input() {
        let mut input: &[u8] = &[];
        assert!(
            CompressionType::decode_from(&mut input).is_err(),
            "decoding from an empty buffer must fail",
        );
    }

    #[test]
    fn compression_decode_rejects_unknown_tag() {
        let mut input: &[u8] = &[0xFF];
        let result = CompressionType::decode_from(&mut input);
        assert!(
            matches!(
                result,
                Err(crate::Error::InvalidTag(("CompressionType", 0xFF)))
            ),
            "unknown tag must surface InvalidTag",
        );
    }

    #[cfg(feature = "lz4")]
    mod lz4 {
        use super::*;
        use test_log::test;

        #[test]
        fn compression_serialize_lz4() {
            let serialized = CompressionType::Lz4.encode_into_vec();
            assert_eq!(1, serialized.len());
        }

        #[test]
        fn compression_roundtrip_lz4() {
            let serialized = CompressionType::Lz4.encode_into_vec();
            let decoded =
                CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
            assert_eq!(CompressionType::Lz4, decoded);
        }
    }

    #[cfg(zstd_any)]
    mod zstd {
        use super::*;
        use test_log::test;

        #[test]
        fn compression_serialize_zstd() {
            let serialized = CompressionType::Zstd(3).encode_into_vec();
            assert_eq!(2, serialized.len());
        }

        #[test]
        fn compression_roundtrip_zstd() {
            for level in [1, 3, 9, 19, 22] {
                let original = CompressionType::Zstd(level);
                let serialized = original.encode_into_vec();
                let decoded =
                    CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                assert_eq!(original, decoded);
            }
        }

        #[test]
        fn compression_display_zstd() {
            assert_eq!(format!("{}", CompressionType::Zstd(3)), "zstd");
        }

        #[test]
        fn compression_zstd_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd(invalid_level);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_constructors_accept_boundary_levels() {
            for level in [1, 22] {
                assert_eq!(
                    CompressionType::zstd(level).unwrap(),
                    CompressionType::Zstd(level),
                );
                assert_eq!(
                    CompressionType::zstd_dict(level, 7).unwrap(),
                    CompressionType::ZstdDict { level, dict_id: 7 },
                );
            }
        }

        #[test]
        fn compression_zstd_decode_rejects_invalid_level() {
            let valid = CompressionType::Zstd(3).encode_into_vec();
            assert_eq!(valid.len(), 2);

            let corrupted = vec![valid[0], 0];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");

            let corrupted = vec![valid[0], 23];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 23 should be rejected on decode");
        }

        #[test]
        fn compression_serialize_zstd_dict() {
            let serialized = CompressionType::ZstdDict {
                level: 3,
                dict_id: 0xDEAD_BEEF,
            }
            .encode_into_vec();
            assert_eq!(serialized, [4, 3, 0xEF, 0xBE, 0xAD, 0xDE]);
        }

        #[test]
        fn compression_roundtrip_zstd_dict() {
            for level in [1, 3, 9, 19, 22] {
                for dict_id in [0, 1, 0xDEAD_BEEF, u32::MAX] {
                    let original = CompressionType::ZstdDict { level, dict_id };
                    let serialized = original.encode_into_vec();
                    let decoded = CompressionType::decode_from(&mut &serialized[..])
                        .expect("decode failed");
                    assert_eq!(original, decoded);
                }
            }
        }

        #[test]
        fn compression_dict_id_accessor() {
            assert_eq!(CompressionType::Zstd(3).dict_id(), 0);
            assert_eq!(
                CompressionType::ZstdDict {
                    level: 3,
                    dict_id: 42
                }
                .dict_id(),
                42,
            );
        }

        #[test]
        fn compression_display_zstd_dict() {
            assert_eq!(
                format!(
                    "{}",
                    CompressionType::ZstdDict {
                        level: 3,
                        dict_id: 42
                    }
                ),
                "zstd+dict"
            );
        }

        #[test]
        fn compression_zstd_dict_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd_dict(invalid_level, 42);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_dict_decode_rejects_invalid_level() {
            let mut buf = CompressionType::ZstdDict {
                level: 3,
                dict_id: 42,
            }
            .encode_into_vec();
            assert_eq!(buf[0], 4);
            buf[1] = 0;
            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");
        }

        #[test]
        fn compression_zstd_dict_decode_rejects_truncated_dict_id() {
            let mut buf = CompressionType::ZstdDict {
                level: 3,
                dict_id: 42,
            }
            .encode_into_vec();
            buf.pop();
            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "truncated dict id must fail to decode");
        }

        #[test]
        fn zstd_dictionary_id_deterministic() {
            let dict_bytes = b"sample dictionary content for testing";
            let d1 = ZstdDictionary::new(dict_bytes);
            let d2 = ZstdDictionary::new(dict_bytes);
            assert_eq!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_different_content_different_id() {
            let d1 = ZstdDictionary::new(b"dictionary one");
            let d2 = ZstdDictionary::new(b"dictionary two");
            assert_ne!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_raw_roundtrip() {
            let raw = b"my dictionary bytes";
            let dict = ZstdDictionary::new(raw);
            assert_eq!(dict.raw(), raw);
        }

        #[test]
        fn zstd_dictionary_debug_format() {
            let dict = ZstdDictionary::new(b"test");
            let debug = format!("{dict:?}");
            assert!(debug.contains("ZstdDictionary"));
            assert!(debug.contains("size: 4"));
        }

        #[cfg(feature = "zstd")]
        #[test]
        fn prepared_handle_raw_content_dict_parses_and_memoizes() {
            let dict = ZstdDictionary::new(b"raw-content training bytes here");
            let h1 = dict
                .prepared_handle()
                .expect("first call must parse raw-content dict");
            let h2 = dict
                .prepared_handle()
                .expect("second call must hit the cache");
            assert_eq!(
                h1.id(),
                h2.id(),
                "cached handle must report the same dict id"
            );
        }

        #[cfg(feature = "zstd")]
        #[test]
        fn prepared_handle_rejects_corrupted_finalized_magic() {
            // Valid finalized-dictionary magic followed by garbage.
            let mut bad = vec![0x37, 0xA4, 0x30, 0xEC];
            bad.extend_from_slice(&[0xFF; 16]);
            let dict = ZstdDictionary::new(&bad);

            let result = dict.prepared_handle();
            assert!(
                result.is_err(),
                "corrupted finalized dict must surface parse error",
            );
            assert!(
                dict.prepared.get().is_none(),
                "failed parse must NOT populate the OnceBox — retry-on-failure contract",
            );
        }

        #[cfg(feature = "zstd")]
        #[test]
        fn prepared_handle_shared_across_clones() {
            let dict_a = ZstdDictionary::new(b"shared dict bytes for clone test");
            let dict_b = dict_a.clone();
            let _ = dict_a
                .prepared_handle()
                .expect("parse via dict_a must succeed");
            let h_b = dict_b
                .prepared_handle()
                .expect("dict_b must see cached handle");
            assert_eq!(h_b.id(), dict_a.id());
            assert!(
                dict_b.prepared.get().is_some(),
                "OnceBox must be populated on dict_b after dict_a parsed",
            );
        }

        #[cfg(feature = "zstd")]
        #[test]
        fn prepared_handle_is_lazy_and_populated_after_first_call() {
            let dict = ZstdDictionary::new(b"laziness test bytes");
            assert!(
                dict.prepared.get().is_none(),
                "ZstdDictionary::new must NOT eagerly parse the dictionary",
            );
            let _ = dict.prepared_handle().expect("explicit parse must succeed");
            assert!(
                dict.prepared.get().is_some(),
                "OnceBox must be populated after first prepared_handle call",
            );
        }
    }
}