use std::borrow::Cow;
use crate::config::Compression;
use crate::error::{PersistError, Result};
/// Scheme tag for a payload kept as-is, without compression.
pub(crate) const SCHEME_NONE: u8 = 0;
/// Scheme tag for a payload compressed with zstd.
pub(crate) const SCHEME_ZSTD: u8 = 1;
/// Scheme tag for a payload compressed as a raw LZ4 block.
pub(crate) const SCHEME_LZ4: u8 = 2;
/// Ratio bound `decode` uses to reject a declared uncompressed size that is
/// implausibly large relative to the number of compressed bytes supplied.
const MAX_DECOMPRESS_RATIO: usize = 1 << 12;
/// Returns the one-byte tag that identifies `scheme` to [`decode`].
///
/// The zstd level is not encoded in the tag: every `Compression::Zstd { .. }`
/// maps to [`SCHEME_ZSTD`].
pub(crate) fn scheme_tag(scheme: Compression) -> u8 {
    match scheme {
        Compression::Lz4 => SCHEME_LZ4,
        Compression::Zstd { .. } => SCHEME_ZSTD,
        Compression::None => SCHEME_NONE,
    }
}
/// Compresses `raw` according to `scheme`.
///
/// `Compression::None` hands back a borrow of `raw` without copying; the
/// other schemes return a newly allocated buffer.
///
/// # Errors
///
/// Propagates whatever the scheme-specific encoder returns: an out-of-range
/// zstd level, a compressor failure, or `PersistError::Unsupported` when the
/// scheme's cargo feature is not enabled.
pub(crate) fn encode(scheme: Compression, raw: &[u8]) -> Result<Cow<'_, [u8]>> {
    let encoded = match scheme {
        Compression::Lz4 => encode_lz4(raw)?,
        Compression::Zstd { level } => encode_zstd(raw, level)?,
        Compression::None => Cow::Borrowed(raw),
    };
    Ok(encoded)
}
pub(crate) fn decode(tag: u8, data: &[u8], uncompressed_len: usize) -> Result<Vec<u8>> {
if tag != SCHEME_NONE && uncompressed_len > data.len().saturating_mul(MAX_DECOMPRESS_RATIO) {
return Err(PersistError::InvalidPayload {
reason: "declared uncompressed size exceeds the decompression-ratio guard",
});
}
match tag {
SCHEME_NONE => Ok(data.to_vec()),
SCHEME_ZSTD => decode_zstd(data, uncompressed_len),
SCHEME_LZ4 => decode_lz4(data, uncompressed_len),
_ => Err(PersistError::InvalidPayload {
reason: "unknown compression scheme tag",
}),
}
}
/// Compresses `raw` into a single zstd frame at `level`.
///
/// Only levels `1..=22` are accepted; anything else is rejected before the
/// compressor is invoked.
#[cfg(feature = "zstd")]
fn encode_zstd(raw: &[u8], level: i32) -> Result<Cow<'_, [u8]>> {
    match level {
        1..=22 => match zstd::encode_all(raw, level) {
            Ok(compressed) => Ok(Cow::Owned(compressed)),
            Err(_) => Err(PersistError::Compression {
                reason: "zstd compression failed",
            }),
        },
        _ => Err(PersistError::Compression {
            reason: "zstd level must be in 1..=22",
        }),
    }
}
/// Fallback used when the `zstd` feature is off: always reports the scheme as unsupported.
#[cfg(not(feature = "zstd"))]
fn encode_zstd(_raw: &[u8], _level: i32) -> Result<Cow<'_, [u8]>> {
    let unsupported = PersistError::Unsupported {
        feature: "Zstd compression",
        available_in: "the `zstd` cargo feature",
    };
    Err(unsupported)
}
/// Decompresses a zstd payload, producing at most `uncompressed_len` bytes.
///
/// `zstd::bulk::decompress` treats its second argument as the output capacity
/// and fails instead of growing past it. A payload that under-declares its size
/// therefore cannot inflate memory beyond the recorded length. Note that the
/// capacity is reserved up front, so `uncompressed_len` must already have been
/// sanity-checked (as `decode` does with its ratio guard).
///
/// # Errors
///
/// `PersistError::Compression` if the data is not valid zstd, would exceed
/// `uncompressed_len`, or decompresses to a different length than recorded.
#[cfg(feature = "zstd")]
fn decode_zstd(data: &[u8], uncompressed_len: usize) -> Result<Vec<u8>> {
    // `zstd::decode_all` would keep growing its output regardless of the recorded
    // length and only let us compare afterwards. That turns a small crafted payload
    // into an unbounded allocation. Cap the output at the recorded size instead.
    let out = zstd::bulk::decompress(data, uncompressed_len).map_err(|_| {
        PersistError::Compression {
            reason: "zstd decompression failed",
        }
    })?;
    // The cap only bounds the output from above; a short output is still a mismatch.
    if out.len() != uncompressed_len {
        return Err(PersistError::Compression {
            reason: "zstd decompressed length does not match the recorded length",
        });
    }
    Ok(out)
}
/// Fallback used when the `zstd` feature is off: always reports the scheme as unsupported.
#[cfg(not(feature = "zstd"))]
fn decode_zstd(_data: &[u8], _uncompressed_len: usize) -> Result<Vec<u8>> {
    Err(PersistError::Unsupported {
        feature: "Zstd decompression",
        available_in: "the `zstd` cargo feature",
    })
}
/// Compresses `raw` as a raw LZ4 block.
///
/// The block does not embed the uncompressed length. Decoding therefore needs
/// that length supplied separately (see `decode_lz4`).
#[cfg(feature = "lz4")]
fn encode_lz4(raw: &[u8]) -> Result<Cow<'_, [u8]>> {
    let compressed = lz4_flex::block::compress(raw);
    Ok(compressed.into())
}
/// Fallback used when the `lz4` feature is off: always reports the scheme as unsupported.
#[cfg(not(feature = "lz4"))]
fn encode_lz4(_raw: &[u8]) -> Result<Cow<'_, [u8]>> {
    let unsupported = PersistError::Unsupported {
        feature: "LZ4 compression",
        available_in: "the `lz4` cargo feature",
    };
    Err(unsupported)
}
/// Decompresses a raw LZ4 block whose uncompressed size is `uncompressed_len`.
///
/// # Errors
///
/// `PersistError::Compression` if the block is malformed or does not expand
/// to exactly `uncompressed_len` bytes.
#[cfg(feature = "lz4")]
fn decode_lz4(data: &[u8], uncompressed_len: usize) -> Result<Vec<u8>> {
    let decompressed = match lz4_flex::block::decompress(data, uncompressed_len) {
        Ok(bytes) => bytes,
        Err(_) => {
            return Err(PersistError::Compression {
                reason: "lz4 decompression failed",
            })
        }
    };
    // The length passed to the decoder is only an upper bound on the output,
    // so a short result must still be rejected here.
    if decompressed.len() == uncompressed_len {
        Ok(decompressed)
    } else {
        Err(PersistError::Compression {
            reason: "lz4 decompressed length does not match the recorded length",
        })
    }
}
/// Fallback used when the `lz4` feature is off: always reports the scheme as unsupported.
#[cfg(not(feature = "lz4"))]
fn decode_lz4(_data: &[u8], _uncompressed_len: usize) -> Result<Vec<u8>> {
    let unsupported = PersistError::Unsupported {
        feature: "LZ4 decompression",
        available_in: "the `lz4` cargo feature",
    };
    Err(unsupported)
}
#[cfg(all(test, any(feature = "zstd", feature = "lz4")))]
mod tests {
    // Test code unwraps freely: any failure should abort the test with a panic.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Encodes 16 KiB of little-endian `u32` counters with `scheme`, decodes
    /// them with the matching tag, and checks the bytes survive unchanged.
    fn round_trip(scheme: Compression) {
        let original: Vec<u8> = (0u32..4096).flat_map(u32::to_le_bytes).collect();
        let compressed = encode(scheme, &original).unwrap();
        let restored = decode(scheme_tag(scheme), &compressed, original.len()).unwrap();
        assert_eq!(restored, original);
    }

    #[cfg(feature = "zstd")]
    #[test]
    fn zstd_round_trips() {
        round_trip(Compression::Zstd { level: 3 });
    }

    #[cfg(feature = "zstd")]
    #[test]
    fn zstd_rejects_bad_level() {
        let outcome = encode(Compression::Zstd { level: 99 }, b"data");
        assert!(matches!(outcome, Err(PersistError::Compression { .. })));
    }

    #[cfg(feature = "lz4")]
    #[test]
    fn lz4_round_trips() {
        round_trip(Compression::Lz4);
    }

    #[test]
    fn none_is_identity() {
        let input = b"verbatim";
        let passthrough = encode(Compression::None, input).unwrap();
        assert_eq!(&*passthrough, input);
        let decoded = decode(SCHEME_NONE, input, input.len()).unwrap();
        assert_eq!(decoded, input);
    }

    #[cfg(feature = "lz4")]
    #[test]
    fn decompression_bomb_claim_is_rejected() {
        let encoded = encode(Compression::Lz4, b"small").unwrap();
        let one_past_limit = encoded.len() * MAX_DECOMPRESS_RATIO + 1;
        let verdict = decode(SCHEME_LZ4, &encoded, one_past_limit);
        assert!(matches!(verdict, Err(PersistError::InvalidPayload { .. })));
    }
}