use std::io::Read;
use kimberlite_types::{BoundedSize, CompressionKind};
use crate::StorageError;
/// Upper bound, in bytes, on the output of a single decompression (1024^3 = 1 GiB).
///
/// The LZ4 codec rejects inputs whose size prefix claims more than this, and the
/// zstd codec stops reading and fails once the decoded stream exceeds it.
const MAX_DECOMPRESSED_SIZE: usize = 1024 * 1024 * 1024;
/// Common interface implemented by every compression codec in this module.
///
/// Implementors must be `Send + Sync`.
pub trait Codec: Send + Sync {
/// Returns the [`CompressionKind`] tag that identifies this codec.
fn kind(&self) -> CompressionKind;
/// Compresses `input` and returns the encoded bytes in a new buffer.
///
/// # Errors
///
/// Returns a [`StorageError`] if the codec fails to encode the input.
fn compress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError>;
/// Decompresses `input` and returns the decoded bytes in a new buffer.
///
/// # Errors
///
/// Returns a [`StorageError`] if the codec fails to decode the input.
fn decompress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError>;
}
/// Pass-through codec: data is stored verbatim, with no compression applied.
#[derive(Debug, Clone, Copy)]
pub struct NoneCodec;

impl Codec for NoneCodec {
    /// Identifies this codec as [`CompressionKind::None`].
    fn kind(&self) -> CompressionKind {
        CompressionKind::None
    }

    /// Returns an owned copy of `input`. This never fails.
    fn compress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
        let verbatim = Vec::from(input);
        Ok(verbatim)
    }

    /// Returns an owned copy of `input`. This never fails.
    fn decompress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
        let verbatim = input.to_owned();
        Ok(verbatim)
    }
}
/// LZ4 codec backed by `lz4_flex`, using its size-prepended framing
/// (a 4-byte little-endian uncompressed length followed by the LZ4 block).
#[derive(Debug, Clone, Copy)]
pub struct Lz4Codec;

impl Codec for Lz4Codec {
    /// Identifies this codec as [`CompressionKind::Lz4`].
    fn kind(&self) -> CompressionKind {
        CompressionKind::Lz4
    }

    /// Compresses `input` with LZ4 and prepends the uncompressed size. This never fails.
    fn compress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
        let framed = lz4_flex::compress_prepend_size(input);
        Ok(framed)
    }

    /// Decompresses a size-prepended LZ4 buffer.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::DecompressionFailed`] when:
    /// - `input` is shorter than the 4-byte size prefix,
    /// - the size claimed by the prefix does not fit in
    ///   `BoundedSize<MAX_DECOMPRESSED_SIZE>`, or
    /// - `lz4_flex` rejects the payload.
    fn decompress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
        // Every failure in this function has the same shape; build it in one place.
        let fail = |reason: String| StorageError::DecompressionFailed {
            codec: "lz4",
            reason,
        };

        // Pull out the size prefix with a checked slice so there is no panic path.
        let prefix = match input
            .get(..4)
            .and_then(|head| <[u8; 4]>::try_from(head).ok())
        {
            Some(bytes) => bytes,
            None => {
                return Err(fail(format!(
                    "input too short: need 4-byte size prefix, got {} bytes",
                    input.len()
                )));
            }
        };

        // Validate the claimed size *before* handing the buffer to lz4_flex, so
        // an absurd prefix is rejected up front rather than acted upon.
        let claimed = u32::from_le_bytes(prefix);
        let _within_limit =
            BoundedSize::<MAX_DECOMPRESSED_SIZE>::try_from(claimed).map_err(|e| {
                fail(format!(
                    "claimed size {} exceeds MAX_DECOMPRESSED_SIZE ({})",
                    e.value, e.max
                ))
            })?;

        lz4_flex::decompress_size_prepended(input).map_err(|e| fail(e.to_string()))
    }
}
/// Zstandard codec configured with a fixed compression level.
#[derive(Debug, Clone, Copy)]
pub struct ZstdCodec {
    /// Compression level handed to zstd when compressing.
    pub level: i32,
}

impl ZstdCodec {
    /// Creates a codec that compresses at the given `level`.
    pub fn new(level: i32) -> Self {
        ZstdCodec { level }
    }
}

impl Default for ZstdCodec {
    /// Creates a codec using compression level 3.
    fn default() -> Self {
        Self::new(3)
    }
}
impl Codec for ZstdCodec {
    /// Identifies this codec as [`CompressionKind::Zstd`].
    fn kind(&self) -> CompressionKind {
        CompressionKind::Zstd
    }

    /// Compresses `input` into a zstd frame at `self.level`.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::CompressionFailed`] if the zstd encoder reports an error.
    fn compress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
        match zstd::encode_all(input, self.level) {
            Ok(frame) => Ok(frame),
            Err(e) => Err(StorageError::CompressionFailed {
                codec: "zstd",
                reason: e.to_string(),
            }),
        }
    }

    /// Decompresses a zstd stream, refusing to produce more than
    /// `MAX_DECOMPRESSED_SIZE` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::DecompressionFailed`] if the decoder cannot be
    /// created, if decoding fails, or if the stream still has data after
    /// `MAX_DECOMPRESSED_SIZE` bytes have been produced.
    fn decompress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
        // Every failure in this function has the same shape; build it in one place.
        let fail = |reason: String| StorageError::DecompressionFailed {
            codec: "zstd",
            reason,
        };
        let cap = MAX_DECOMPRESSED_SIZE as u64;

        let decoder = zstd::Decoder::new(input)
            .map_err(|e| fail(format!("failed to create decoder: {e}")))?;

        // `take` guarantees that at most `cap` bytes are ever written to `output`.
        let mut capped = decoder.take(cap);
        let mut output = Vec::new();
        let copied = std::io::copy(&mut capped, &mut output)
            .map_err(|e| fail(format!("decompression failed: {e}")))?;

        // Stopping short of the cap means the stream ended on its own.
        if copied != cap {
            return Ok(output);
        }

        // Exactly `cap` bytes were produced: either the payload is exactly that
        // large, or it was cut off. One more byte from the raw decoder tells
        // the two cases apart.
        let mut probe = [0u8; 1];
        let extra = capped
            .into_inner()
            .read(&mut probe)
            .map_err(|e| fail(format!("probe read failed: {e}")))?;

        if extra > 0 {
            return Err(fail(format!(
                "decompressed size exceeds MAX_DECOMPRESSED_SIZE ({MAX_DECOMPRESSED_SIZE} bytes)"
            )));
        }

        Ok(output)
    }
}
/// Owns one instance of each supported codec and dispatches on [`CompressionKind`].
#[derive(Debug)]
pub struct CodecRegistry {
    lz4: Lz4Codec,
    zstd: ZstdCodec,
    none: NoneCodec,
}

impl CodecRegistry {
    /// Creates a registry whose zstd codec uses `ZstdCodec::default()`.
    pub fn new() -> Self {
        let zstd = ZstdCodec::default();
        Self {
            lz4: Lz4Codec,
            zstd,
            none: NoneCodec,
        }
    }

    /// Creates a registry whose zstd codec compresses at `level`.
    pub fn with_zstd_level(level: i32) -> Self {
        Self {
            zstd: ZstdCodec::new(level),
            ..Self::new()
        }
    }

    /// Returns the codec registered for `kind`.
    pub fn get(&self, kind: CompressionKind) -> &dyn Codec {
        let codec: &dyn Codec = match kind {
            CompressionKind::Zstd => &self.zstd,
            CompressionKind::Lz4 => &self.lz4,
            CompressionKind::None => &self.none,
        };
        codec
    }

    /// Compresses `data` with the codec registered for `kind`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the selected codec's `compress` returns.
    pub fn compress(&self, kind: CompressionKind, data: &[u8]) -> Result<Vec<u8>, StorageError> {
        let codec = self.get(kind);
        codec.compress(data)
    }

    /// Decompresses `data` with the codec registered for `kind`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the selected codec's `decompress` returns.
    pub fn decompress(&self, kind: CompressionKind, data: &[u8]) -> Result<Vec<u8>, StorageError> {
        let codec = self.get(kind);
        codec.decompress(data)
    }
}

impl Default for CodecRegistry {
    /// Equivalent to [`CodecRegistry::new`].
    fn default() -> Self {
        CodecRegistry::new()
    }
}
// Unit and property tests for the codecs and the registry.
//
// NOTE(review): several tests below allocate buffers at or above
// MAX_DECOMPRESSED_SIZE (1 GiB), so this module is memory-hungry.
#[cfg(test)]
mod tests {
use super::*;
// NoneCodec must hand back exactly the bytes it was given.
#[test]
fn none_codec_roundtrip() {
let codec = NoneCodec;
let data = b"hello world";
let compressed = codec.compress(data).unwrap();
let decompressed = codec.decompress(&compressed).unwrap();
assert_eq!(data.as_slice(), &decompressed);
}
// compress followed by decompress must reproduce the input for LZ4.
#[test]
fn lz4_codec_roundtrip() {
let codec = Lz4Codec;
let data = b"hello world hello world hello world";
let compressed = codec.compress(data).unwrap();
let decompressed = codec.decompress(&compressed).unwrap();
assert_eq!(data.as_slice(), &decompressed);
}
// A size prefix above MAX_DECOMPRESSED_SIZE must be rejected by the prefix
// guard, with an error message that names the limit.
#[test]
fn lz4_rejects_oversized_size_prefix() {
let codec = Lz4Codec;
// Little-endian prefix 0xFFFF_FFFF claims ~4 GiB; the trailing byte is a dummy body.
let bomb = [0xFF, 0xFF, 0xFF, 0xFF, 0x00];
let err = codec
.decompress(&bomb)
.expect_err("oversized size prefix must be rejected");
match err {
StorageError::DecompressionFailed { codec: c, reason } => {
assert_eq!(c, "lz4");
assert!(
reason.contains("exceeds MAX_DECOMPRESSED_SIZE"),
"expected size-prefix guard error, got: {reason}"
);
}
other => panic!("expected DecompressionFailed, got {other:?}"),
}
}
// Inputs shorter than the 4-byte size prefix must produce an error.
#[test]
fn lz4_rejects_short_input() {
let codec = Lz4Codec;
assert!(codec.decompress(&[]).is_err());
assert!(codec.decompress(&[0x00, 0x00, 0x00]).is_err());
}
// compress followed by decompress must reproduce the input for zstd.
#[test]
fn zstd_codec_roundtrip() {
let codec = ZstdCodec::default();
let data = b"hello world hello world hello world";
let compressed = codec.compress(data).unwrap();
let decompressed = codec.decompress(&compressed).unwrap();
assert_eq!(data.as_slice(), &decompressed);
}
// 10,000 identical bytes must compress to fewer bytes than the input with LZ4.
#[test]
fn lz4_compresses_repetitive_data() {
let codec = Lz4Codec;
let data: Vec<u8> = vec![42; 10_000];
let compressed = codec.compress(&data).unwrap();
assert!(compressed.len() < data.len());
}
// 10,000 identical bytes must compress to fewer bytes than the input with zstd.
#[test]
fn zstd_compresses_repetitive_data() {
let codec = ZstdCodec::default();
let data: Vec<u8> = vec![42; 10_000];
let compressed = codec.compress(&data).unwrap();
assert!(compressed.len() < data.len());
}
// The registry must return, for each kind, a codec that reports that same kind.
#[test]
fn codec_registry_lookup() {
let registry = CodecRegistry::new();
assert_eq!(
registry.get(CompressionKind::None).kind(),
CompressionKind::None
);
assert_eq!(
registry.get(CompressionKind::Lz4).kind(),
CompressionKind::Lz4
);
assert_eq!(
registry.get(CompressionKind::Zstd).kind(),
CompressionKind::Zstd
);
}
// Roundtrip through the registry's compress/decompress helpers for every kind.
#[test]
fn codec_registry_roundtrip() {
let registry = CodecRegistry::new();
let data = b"test data for codec registry roundtrip";
for kind in [
CompressionKind::None,
CompressionKind::Lz4,
CompressionKind::Zstd,
] {
let compressed = registry.compress(kind, data).unwrap();
let decompressed = registry.decompress(kind, &compressed).unwrap();
assert_eq!(
data.as_slice(),
&decompressed,
"roundtrip failed for {kind}"
);
}
}
// Edge case: an empty input must roundtrip to an empty output for every kind.
#[test]
fn empty_data_roundtrip() {
let registry = CodecRegistry::new();
let data = b"";
for kind in [
CompressionKind::None,
CompressionKind::Lz4,
CompressionKind::Zstd,
] {
let compressed = registry.compress(kind, data).unwrap();
let decompressed = registry.decompress(kind, &compressed).unwrap();
assert_eq!(
data.as_slice(),
&decompressed,
"empty roundtrip failed for {kind}"
);
}
}
// A zero-filled payload 1 MiB larger than the limit compresses to a tiny
// frame; decompressing it must fail with the size-limit error.
// NOTE(review): allocates more than 1 GiB for the payload alone.
#[test]
fn zstd_rejects_decompression_bomb() {
let bomb_size = MAX_DECOMPRESSED_SIZE + 1024 * 1024; let payload: Vec<u8> = vec![0u8; bomb_size];
let codec = ZstdCodec::default();
let compressed = codec.compress(&payload).unwrap();
// Sanity check that the input really is bomb-like (compresses below 1%).
assert!(
compressed.len() < bomb_size / 100,
"compressed size {} should be <1% of original {}",
compressed.len(),
bomb_size
);
let result = codec.decompress(&compressed);
assert!(result.is_err(), "decompression bomb should be rejected");
let err = result.unwrap_err();
match err {
StorageError::DecompressionFailed { codec: c, reason } => {
assert_eq!(c, "zstd");
assert!(
reason.contains("exceeds MAX_DECOMPRESSED_SIZE"),
"error should mention size limit: {reason}"
);
}
_ => panic!("wrong error type: {err:?}"),
}
}
// A payload of half the limit must decompress successfully and match the input.
// NOTE(review): holds both the payload and the decompressed copy in memory.
#[test]
fn zstd_allows_large_but_under_limit_data() {
let size = MAX_DECOMPRESSED_SIZE / 2;
let payload: Vec<u8> = vec![42u8; size];
let codec = ZstdCodec::default();
let compressed = codec.compress(&payload).unwrap();
let decompressed = codec.decompress(&compressed).unwrap();
assert_eq!(decompressed.len(), size);
assert_eq!(decompressed, payload);
}
// NOTE(review): this attribute looks redundant, since the enclosing `tests`
// module is already gated by `#[cfg(test)]`.
#[cfg(test)]
mod proptest_codec {
use super::*;
use proptest::prelude::*;
proptest! {
// Arbitrary byte vectors shorter than 1 MiB must roundtrip through zstd.
#[test]
fn zstd_roundtrip_under_limit(data in prop::collection::vec(any::<u8>(), 0..1024*1024)) {
let codec = ZstdCodec::default();
let compressed = codec.compress(&data).unwrap();
let decompressed = codec.decompress(&compressed).unwrap();
assert_eq!(data, decompressed);
}
// Payloads 10-90 MiB over the limit, filled with one repeated byte, must be rejected.
// NOTE(review): every generated case allocates more than 1 GiB; presumably this is
// slow under proptest's default case count. Consider capping the number of cases.
#[test]
fn zstd_rejects_oversized_payloads(
byte in any::<u8>(),
multiplier in 1u32..10
) {
let size = MAX_DECOMPRESSED_SIZE + (multiplier as usize * 10 * 1024 * 1024);
let payload = vec![byte; size];
let codec = ZstdCodec::default();
let compressed = codec.compress(&payload).unwrap();
let result = codec.decompress(&compressed);
assert!(result.is_err(), "oversized payload should be rejected");
}
}
}
}