#[cfg(feature = "zstd")]
mod zstd_ffi;
#[cfg(all(feature = "zstd-pure", not(feature = "zstd")))]
mod zstd_pure;
use crate::coding::{Decode, Encode};
use byteorder::{ReadBytesExt, WriteBytesExt};
use std::io::{Read, Write};
#[cfg(zstd_any)]
use std::sync::Arc;
/// Interface implemented by each zstd backend.
///
/// Every operation is an associated function without a `self` receiver, so a
/// backend is selected by type rather than by value. The trait only exists
/// when the `zstd_any` cfg is set.
#[cfg(zstd_any)]
pub trait CompressionProvider {
/// Compresses `data` at the requested `level` and returns the compressed bytes.
///
/// NOTE(review): no level range is enforced by this signature; presumably
/// callers pass a level in `1..=22` — confirm against the implementations.
fn compress(data: &[u8], level: i32) -> crate::Result<Vec<u8>>;
/// Decompresses `data` and returns the decompressed bytes.
///
/// NOTE(review): `capacity` is presumably an upper bound (or size hint) for
/// the decompressed output — confirm against the implementations.
fn decompress(data: &[u8], capacity: usize) -> crate::Result<Vec<u8>>;
/// Like `compress`, but additionally takes the raw dictionary bytes in `dict_raw`.
fn compress_with_dict(data: &[u8], level: i32, dict_raw: &[u8]) -> crate::Result<Vec<u8>>;
/// Like `decompress`, but additionally takes the raw dictionary bytes in `dict_raw`.
///
/// NOTE(review): presumably `dict_raw` must be the same dictionary that was
/// used for compression — confirm against the implementations.
fn decompress_with_dict(
data: &[u8],
dict_raw: &[u8],
capacity: usize,
) -> crate::Result<Vec<u8>>;
}
/// Alias for the active zstd backend: resolves to `zstd_ffi::ZstdFfiProvider`
/// whenever the `zstd` feature is enabled.
#[cfg(feature = "zstd")]
pub type ZstdBackend = zstd_ffi::ZstdFfiProvider;
/// Alias for the active zstd backend: resolves to `zstd_pure::ZstdPureProvider`
/// when `zstd-pure` is enabled and `zstd` is not (`zstd` wins if both are on).
#[cfg(all(feature = "zstd-pure", not(feature = "zstd")))]
pub type ZstdBackend = zstd_pure::ZstdPureProvider;
/// A zstd dictionary: the raw dictionary bytes plus a 32-bit id derived from them.
///
/// The bytes are held in an `Arc<[u8]>`, so cloning the dictionary shares the
/// underlying buffer instead of copying it.
#[cfg(zstd_any)]
#[derive(Clone)]
pub struct ZstdDictionary {
// Content-derived identifier; set once in `new` via `compute_dict_id`.
id: u32,
// Reference-counted dictionary bytes.
raw: Arc<[u8]>,
}
#[cfg(zstd_any)]
impl ZstdDictionary {
    /// Creates a dictionary from `raw`.
    ///
    /// The bytes are copied into reference-counted storage and the id is
    /// derived from their content, so equal input always yields an equal id.
    #[must_use]
    pub fn new(raw: &[u8]) -> Self {
        let id = compute_dict_id(raw);
        let raw: Arc<[u8]> = raw.into();
        Self { id, raw }
    }

    /// Returns the 32-bit id computed from the dictionary bytes.
    #[must_use]
    pub fn id(&self) -> u32 {
        self.id
    }

    /// Returns the dictionary bytes exactly as they were passed to `new`.
    #[must_use]
    pub fn raw(&self) -> &[u8] {
        self.raw.as_ref()
    }
}
#[cfg(zstd_any)]
impl std::fmt::Debug for ZstdDictionary {
    /// Formats the dictionary as its id (zero-padded hex) and byte length,
    /// deliberately leaving the raw bytes out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte_len = self.raw.len();
        let mut out = f.debug_struct("ZstdDictionary");
        out.field("id", &format_args!("{:#010x}", self.id));
        out.field("size", &byte_len);
        out.finish()
    }
}
/// Derives a dictionary id from `raw`: the 64-bit XXH3 hash of the bytes,
/// reduced to its low 32 bits.
#[cfg(zstd_any)]
#[expect(
    clippy::cast_possible_truncation,
    reason = "intentionally truncated to 32-bit fingerprint"
)]
fn compute_dict_id(raw: &[u8]) -> u32 {
    let digest: u64 = xxhash_rust::xxh3::xxh3_64(raw);
    digest as u32
}
/// Compression algorithm selector.
///
/// The set of variants depends on the enabled features, and the enum is
/// `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum CompressionType {
/// No compression; serialized as the single tag byte `0`.
None,
/// LZ4 compression (requires the `lz4` feature); serialized as the single tag byte `1`.
#[cfg(feature = "lz4")]
Lz4,
/// Zstd compression at the wrapped level; serialized as tag `3` followed by the level byte.
///
/// Prefer the `CompressionType::zstd` constructor, which rejects levels
/// outside `1..=22`; decoding rejects such levels as well.
#[cfg(zstd_any)]
Zstd(i32),
/// Zstd compression with a dictionary; serialized as tag `4`, the level
/// byte, and `dict_id` as a little-endian `u32`.
///
/// Prefer the `CompressionType::zstd_dict` constructor, which rejects levels
/// outside `1..=22`; decoding rejects such levels as well.
#[cfg(zstd_any)]
ZstdDict {
// Zstd compression level, expected to be in `1..=22`.
level: i32,
// Identifier of the dictionary (see `ZstdDictionary::id`).
dict_id: u32,
},
}
impl CompressionType {
    /// Checks that `level` is an accepted zstd level (`1..=22`).
    ///
    /// # Errors
    ///
    /// Returns an `Error::Io` naming the offending level when it is out of range.
    #[cfg(zstd_any)]
    fn validate_zstd_level(level: i32) -> crate::Result<()> {
        match level {
            1..=22 => Ok(()),
            _ => Err(crate::Error::Io(std::io::Error::other(format!(
                "invalid zstd compression level {level}, expected 1..=22"
            )))),
        }
    }

    /// Builds a [`CompressionType::Zstd`] after validating `level`.
    ///
    /// # Errors
    ///
    /// Fails when `level` is outside `1..=22`.
    #[cfg(zstd_any)]
    pub fn zstd(level: i32) -> crate::Result<Self> {
        Self::validate_zstd_level(level).map(|()| Self::Zstd(level))
    }

    /// Builds a [`CompressionType::ZstdDict`] after validating `level`.
    ///
    /// `dict_id` is stored as given; it is not checked against any dictionary.
    ///
    /// # Errors
    ///
    /// Fails when `level` is outside `1..=22`.
    #[cfg(zstd_any)]
    pub fn zstd_dict(level: i32, dict_id: u32) -> crate::Result<Self> {
        Self::validate_zstd_level(level).map(|()| Self::ZstdDict { level, dict_id })
    }
}
impl std::fmt::Display for CompressionType {
    /// Writes the short lowercase name of the algorithm; zstd levels and
    /// dictionary ids are not part of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::None => "none",
            #[cfg(feature = "lz4")]
            Self::Lz4 => "lz4",
            #[cfg(zstd_any)]
            Self::Zstd(_) => "zstd",
            #[cfg(zstd_any)]
            Self::ZstdDict { .. } => "zstd+dict",
        };
        // `write_str` emits the label verbatim, exactly like `write!(f, "{}", ..)`.
        f.write_str(label)
    }
}
impl Encode for CompressionType {
    /// Serializes the compression type as a one-byte tag plus a
    /// variant-specific payload:
    ///
    /// - `None`: tag `0`, no payload
    /// - `Lz4`: tag `1`, no payload
    /// - `Zstd`: tag `3`, then the level as one signed byte
    /// - `ZstdDict`: tag `4`, the level as one signed byte, then `dict_id`
    ///   as a little-endian `u32`
    ///
    /// # Errors
    ///
    /// Returns an error if writing fails, or if a zstd level is outside
    /// `1..=22`. The level is checked before any byte is written, so an
    /// invalid value never produces partial output. Because the zstd variants
    /// can be constructed directly, this check is enforced in release builds
    /// too: an unchecked `as i8` cast would silently wrap an out-of-range
    /// level into a different, possibly valid-looking, byte.
    fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), crate::Error> {
        match self {
            Self::None => {
                writer.write_u8(0)?;
            }
            #[cfg(feature = "lz4")]
            Self::Lz4 => {
                writer.write_u8(1)?;
            }
            #[cfg(zstd_any)]
            Self::Zstd(level) => {
                // Validate first so that nothing is emitted for a bad level.
                Self::validate_zstd_level(*level)?;
                writer.write_u8(3)?;
                #[expect(
                    clippy::cast_possible_truncation,
                    reason = "level range 1..=22 fits i8"
                )]
                writer.write_i8(*level as i8)?;
            }
            #[cfg(zstd_any)]
            Self::ZstdDict { level, dict_id } => {
                // Validate first so that nothing is emitted for a bad level.
                Self::validate_zstd_level(*level)?;
                writer.write_u8(4)?;
                #[expect(
                    clippy::cast_possible_truncation,
                    reason = "level range 1..=22 fits i8"
                )]
                writer.write_i8(*level as i8)?;
                byteorder::WriteBytesExt::write_u32::<byteorder::LittleEndian>(writer, *dict_id)?;
            }
        }
        Ok(())
    }
}
impl Decode for CompressionType {
fn decode_from<R: Read>(reader: &mut R) -> Result<Self, crate::Error> {
let tag = reader.read_u8()?;
match tag {
0 => Ok(Self::None),
#[cfg(feature = "lz4")]
1 => Ok(Self::Lz4),
#[cfg(zstd_any)]
3 => {
let level = i32::from(reader.read_i8()?);
Self::validate_zstd_level(level)?;
Ok(Self::Zstd(level))
}
#[cfg(zstd_any)]
4 => {
let level = i32::from(reader.read_i8()?);
Self::validate_zstd_level(level)?;
let dict_id = byteorder::ReadBytesExt::read_u32::<byteorder::LittleEndian>(reader)?;
Ok(Self::ZstdDict { level, dict_id })
}
tag => Err(crate::Error::InvalidTag(("CompressionType", tag))),
}
}
}
// Unit tests for the wire format of `CompressionType` and for `ZstdDictionary`.
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::useless_vec,
    clippy::expect_used,
    reason = "test code"
)]
mod tests {
    use super::*;
    use test_log::test;

    #[test]
    fn compression_serialize_none() {
        let serialized = CompressionType::None.encode_into_vec();
        assert_eq!(1, serialized.len());
    }

    #[test]
    fn compression_roundtrip_none() {
        let serialized = CompressionType::None.encode_into_vec();
        assert_eq!(serialized, [0]);
        let decoded = CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
        assert_eq!(CompressionType::None, decoded);
    }

    #[test]
    fn compression_decode_rejects_unknown_tag() {
        let bytes = [0xFF_u8];
        let result = CompressionType::decode_from(&mut &bytes[..]);
        assert!(
            matches!(result, Err(crate::Error::InvalidTag(_))),
            "unknown tag should be reported as InvalidTag"
        );
    }

    #[test]
    fn compression_decode_rejects_empty_input() {
        let bytes: [u8; 0] = [];
        let result = CompressionType::decode_from(&mut &bytes[..]);
        assert!(result.is_err(), "empty input should be rejected on decode");
    }

    #[cfg(feature = "lz4")]
    mod lz4 {
        use super::*;
        use test_log::test;

        #[test]
        fn compression_serialize_lz4() {
            let serialized = CompressionType::Lz4.encode_into_vec();
            assert_eq!(1, serialized.len());
        }

        #[test]
        fn compression_roundtrip_lz4() {
            let serialized = CompressionType::Lz4.encode_into_vec();
            assert_eq!(serialized, [1]);
            let decoded =
                CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
            assert_eq!(CompressionType::Lz4, decoded);
        }
    }

    #[cfg(zstd_any)]
    mod zstd {
        use super::*;
        use test_log::test;

        #[test]
        fn compression_serialize_zstd() {
            let serialized = CompressionType::Zstd(3).encode_into_vec();
            assert_eq!(2, serialized.len());
        }

        #[test]
        fn compression_roundtrip_zstd() {
            for level in [1, 3, 9, 19] {
                let original = CompressionType::Zstd(level);
                let serialized = original.encode_into_vec();
                let decoded =
                    CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                assert_eq!(original, decoded);
            }
        }

        #[test]
        fn compression_zstd_accepts_boundary_levels() {
            // Both ends of the inclusive range must be accepted and survive a roundtrip.
            for level in [1, 22] {
                let built = CompressionType::zstd(level).expect("boundary level is valid");
                assert_eq!(built, CompressionType::Zstd(level));
                let serialized = built.encode_into_vec();
                let decoded =
                    CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                assert_eq!(built, decoded);
            }
        }

        #[test]
        fn compression_display_zstd() {
            assert_eq!(format!("{}", CompressionType::Zstd(3)), "zstd");
        }

        #[test]
        fn compression_zstd_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd(invalid_level);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_decode_rejects_invalid_level() {
            let valid = CompressionType::Zstd(3).encode_into_vec();
            assert_eq!(valid.len(), 2);

            let corrupted = vec![valid[0], 0];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");

            let corrupted = vec![valid[0], 23];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 23 should be rejected on decode");
        }

        #[test]
        fn compression_zstd_decode_rejects_negative_level() {
            let valid = CompressionType::Zstd(3).encode_into_vec();
            // The level byte is signed, so 0xFF is read back as -1.
            let corrupted = vec![valid[0], 0xFF];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level -1 should be rejected on decode");
        }

        #[test]
        fn compression_zstd_decode_rejects_truncated_input() {
            let zstd = CompressionType::Zstd(3).encode_into_vec();
            let result = CompressionType::decode_from(&mut &zstd[..1]);
            assert!(result.is_err(), "missing level byte should be rejected");

            let dict = CompressionType::ZstdDict {
                level: 3,
                dict_id: 0xDEAD_BEEF,
            }
            .encode_into_vec();
            // Every strict prefix that still contains the tag must fail to decode.
            for cut in 1..dict.len() {
                let result = CompressionType::decode_from(&mut &dict[..cut]);
                assert!(
                    result.is_err(),
                    "payload truncated to {cut} bytes should be rejected"
                );
            }
        }

        #[test]
        fn compression_serialize_zstd_dict() {
            let serialized = CompressionType::ZstdDict {
                level: 3,
                dict_id: 0xDEAD_BEEF,
            }
            .encode_into_vec();
            assert_eq!(serialized, [4, 3, 0xEF, 0xBE, 0xAD, 0xDE]);
        }

        #[test]
        fn compression_roundtrip_zstd_dict() {
            for level in [1, 3, 9, 19] {
                for dict_id in [0, 1, 0xDEAD_BEEF, u32::MAX] {
                    let original = CompressionType::ZstdDict { level, dict_id };
                    let serialized = original.encode_into_vec();
                    let decoded =
                        CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                    assert_eq!(original, decoded);
                }
            }
        }

        #[test]
        fn compression_zstd_dict_accepts_boundary_levels() {
            for level in [1, 22] {
                let built =
                    CompressionType::zstd_dict(level, 42).expect("boundary level is valid");
                assert_eq!(built, CompressionType::ZstdDict { level, dict_id: 42 });
                let serialized = built.encode_into_vec();
                let decoded =
                    CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                assert_eq!(built, decoded);
            }
        }

        #[test]
        fn compression_display_zstd_dict() {
            assert_eq!(
                format!(
                    "{}",
                    CompressionType::ZstdDict {
                        level: 3,
                        dict_id: 42
                    }
                ),
                "zstd+dict"
            );
        }

        #[test]
        fn compression_zstd_dict_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd_dict(invalid_level, 42);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_dict_decode_rejects_invalid_level() {
            let mut buf = CompressionType::ZstdDict {
                level: 3,
                dict_id: 42,
            }
            .encode_into_vec();
            assert_eq!(buf[0], 4);
            buf[1] = 0;
            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");
        }

        #[test]
        fn zstd_dictionary_id_deterministic() {
            let dict_bytes = b"sample dictionary content for testing";
            let d1 = ZstdDictionary::new(dict_bytes);
            let d2 = ZstdDictionary::new(dict_bytes);
            assert_eq!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_different_content_different_id() {
            let d1 = ZstdDictionary::new(b"dictionary one");
            let d2 = ZstdDictionary::new(b"dictionary two");
            assert_ne!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_raw_roundtrip() {
            let raw = b"my dictionary bytes";
            let dict = ZstdDictionary::new(raw);
            assert_eq!(dict.raw(), raw);
        }

        #[test]
        fn zstd_dictionary_debug_format() {
            let dict = ZstdDictionary::new(b"test");
            let debug = format!("{dict:?}");
            assert!(debug.contains("ZstdDictionary"));
            assert!(debug.contains("size: 4"));
        }
    }
}