use bytes::Bytes;
use crate::{
ChunkManifest, Codec, CodecError, CodecKind, DECOMPRESS_BOOTSTRAP_CAPACITY,
validate_decompress_manifest,
};
/// zstd codec that compresses and decompresses on the CPU.
///
/// The only state is the compression level, which is normalised into the
/// range `1..=22` at construction time.
#[derive(Debug, Clone)]
pub struct CpuZstd {
    /// zstd compression level; kept within `1..=22` by [`CpuZstd::new`].
    level: i32,
}

impl CpuZstd {
    /// Level used by [`CpuZstd::default`].
    pub const DEFAULT_LEVEL: i32 = 3;

    /// Creates a codec using `level`.
    ///
    /// Values below 1 are raised to 1 and values above 22 are lowered to 22,
    /// so any `i32` is accepted.
    pub fn new(level: i32) -> Self {
        const LOWEST: i32 = 1;
        const HIGHEST: i32 = 22;
        let bounded = level.clamp(LOWEST, HIGHEST);
        Self { level: bounded }
    }
}

impl Default for CpuZstd {
    /// Equivalent to `CpuZstd::new(CpuZstd::DEFAULT_LEVEL)`.
    fn default() -> Self {
        CpuZstd::new(CpuZstd::DEFAULT_LEVEL)
    }
}
/// Synchronously decodes a zstd payload and verifies it against `manifest`.
///
/// The decoder output is capped at `manifest.original_size + 1024` bytes, so a
/// frame that expands far beyond its claimed size is cut off and rejected.
///
/// # Errors
///
/// - [`CodecError::CodecMismatch`] if `manifest.codec` is not `CodecKind::CpuZstd`.
/// - Any error from `validate_decompress_manifest`, which runs before decoding.
/// - [`CodecError::Io`] if the zstd stream cannot be decoded, or if decoding
///   produces more bytes than `manifest.original_size` (decompression-bomb guard).
/// - [`CodecError::SizeMismatch`] if decoding produces fewer bytes than claimed.
/// - [`CodecError::CrcMismatch`] if the CRC32C of the output differs from
///   `manifest.crc32c`.
pub fn decompress_blocking(input: &[u8], manifest: &ChunkManifest) -> Result<Vec<u8>, CodecError> {
    use std::io::Read;

    if manifest.codec != CodecKind::CpuZstd {
        return Err(CodecError::CodecMismatch {
            expected: CodecKind::CpuZstd,
            got: manifest.codec,
        });
    }
    let capacity_hint = validate_decompress_manifest(manifest, input.len())?;
    let claimed_len = manifest.original_size;

    // Allow a little slack past the claimed size so an overrun is observable,
    // while still bounding how much a hostile frame can expand in memory.
    let read_cap = claimed_len.saturating_add(1024);
    let decoder = zstd::stream::Decoder::new(input).map_err(CodecError::Io)?;
    // Start from a bounded capacity; `read_to_end` grows the buffer only as
    // real output arrives, so an inflated claim cannot force a huge allocation.
    let mut output = Vec::with_capacity(capacity_hint.min(DECOMPRESS_BOOTSTRAP_CAPACITY));
    decoder
        .take(read_cap)
        .read_to_end(&mut output)
        .map_err(CodecError::Io)?;

    let produced_len = output.len() as u64;
    if produced_len > claimed_len {
        return Err(CodecError::Io(std::io::Error::other(format!(
            "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
            output.len(),
            claimed_len
        ))));
    }
    if produced_len != claimed_len {
        return Err(CodecError::SizeMismatch {
            expected: claimed_len,
            got: produced_len,
        });
    }

    let actual_crc = crc32c::crc32c(&output);
    if actual_crc == manifest.crc32c {
        Ok(output)
    } else {
        Err(CodecError::CrcMismatch {
            expected: manifest.crc32c,
            got: actual_crc,
        })
    }
}
/// Synchronously compresses `input` with zstd and builds the matching manifest.
///
/// `level` is clamped into `1..=22` before use. The manifest records the
/// uncompressed length, the compressed length and the CRC32C of the
/// *uncompressed* input, which `decompress_blocking` checks on the way back.
///
/// # Errors
///
/// Returns [`CodecError::Io`] if the zstd encoder fails.
pub fn compress_blocking(input: &[u8], level: i32) -> Result<(Vec<u8>, ChunkManifest), CodecError> {
    let level = level.clamp(1, 22);
    let original_crc = crc32c::crc32c(input);
    let compressed = zstd::stream::encode_all(input, level).map_err(CodecError::Io)?;
    // Build the manifest before moving `compressed` out, so the buffer is
    // returned as-is instead of being cloned just to read its length.
    let manifest = ChunkManifest {
        codec: CodecKind::CpuZstd,
        original_size: input.len() as u64,
        compressed_size: compressed.len() as u64,
        crc32c: original_crc,
    };
    Ok((compressed, manifest))
}
#[async_trait::async_trait]
impl Codec for CpuZstd {
fn kind(&self) -> CodecKind {
CodecKind::CpuZstd
}
async fn compress(&self, input: Bytes) -> Result<(Bytes, ChunkManifest), CodecError> {
let level = self.level;
let original_size = input.len() as u64;
let original_crc = crc32c::crc32c(&input);
let compressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
zstd::stream::encode_all(input.as_ref(), level)
})
.await??;
let compressed_size = compressed.len() as u64;
let manifest = ChunkManifest {
codec: CodecKind::CpuZstd,
original_size,
compressed_size,
crc32c: original_crc,
};
Ok((Bytes::from(compressed), manifest))
}
async fn decompress(
&self,
input: Bytes,
manifest: &ChunkManifest,
) -> Result<Bytes, CodecError> {
if manifest.codec != CodecKind::CpuZstd {
return Err(CodecError::CodecMismatch {
expected: CodecKind::CpuZstd,
got: manifest.codec,
});
}
let allocated_orig_size = validate_decompress_manifest(manifest, input.len())?;
let expected_crc = manifest.crc32c;
let expected_orig_size = manifest.original_size;
let decompressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
use std::io::Read;
let limit = expected_orig_size.saturating_add(1024);
let mut decoder = zstd::stream::Decoder::new(input.as_ref())?;
let mut buf =
Vec::with_capacity(allocated_orig_size.min(DECOMPRESS_BOOTSTRAP_CAPACITY));
(&mut decoder).take(limit).read_to_end(&mut buf)?;
if (buf.len() as u64) > expected_orig_size {
return Err(std::io::Error::other(format!(
"zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
buf.len(),
expected_orig_size
)));
}
Ok(buf)
})
.await??;
if decompressed.len() as u64 != expected_orig_size {
return Err(CodecError::SizeMismatch {
expected: expected_orig_size,
got: decompressed.len() as u64,
});
}
let actual_crc = crc32c::crc32c(&decompressed);
if actual_crc != expected_crc {
return Err(CodecError::CrcMismatch {
expected: expected_crc,
got: actual_crc,
});
}
Ok(Bytes::from(decompressed))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A short, repetitive input survives compress -> decompress, and the
    /// manifest records the codec and the original length.
    #[tokio::test]
    async fn roundtrip_small() {
        let codec = CpuZstd::default();
        let input = Bytes::from_static(b"hello squished s3 hello squished s3 hello squished s3");
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, input.len() as u64);
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    /// 1 MiB of a single repeated byte must compress to under 1% of its size
    /// and still round-trip exactly.
    #[tokio::test]
    async fn roundtrip_compressible() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024 * 1024]);
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert!(
            compressed.len() < input.len() / 100,
            "expected zstd to compress 1 MiB of x bytes very well, got {} bytes",
            compressed.len()
        );
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    /// Flipping one byte of the compressed payload must produce an error.
    /// Depending on where the flip lands, decoding may fail outright (`Io`),
    /// yield the wrong length (`SizeMismatch`), or yield the wrong bytes
    /// (`CrcMismatch`), so any of the three is accepted.
    #[tokio::test]
    async fn detects_corrupted_compressed_payload() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024]);
        let (mut compressed, manifest) = codec.compress(input).await.unwrap();
        let mut buf = compressed.to_vec();
        // Guard keeps index 5 in bounds. Note: if the payload were 8 bytes or
        // shorter, nothing would be flipped and `unwrap_err` below would panic.
        if buf.len() > 8 {
            buf[5] ^= 0xff;
        }
        compressed = Bytes::from(buf);
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(
            err,
            CodecError::Io(_) | CodecError::CrcMismatch { .. } | CodecError::SizeMismatch { .. }
        ));
    }

    /// A manifest tagged with a different codec is rejected with
    /// `CodecMismatch`; the payload bytes are irrelevant.
    #[tokio::test]
    async fn rejects_codec_mismatch() {
        let codec = CpuZstd::default();
        let manifest = ChunkManifest {
            codec: CodecKind::Passthrough,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let err = codec
            .decompress(Bytes::from_static(b"0123456789"), &manifest)
            .await
            .unwrap_err();
        assert!(matches!(err, CodecError::CodecMismatch { .. }));
    }

    /// Issue #89: a manifest claiming one byte more than
    /// `crate::MAX_DECOMPRESSED_BYTES` must be rejected with
    /// `ManifestSizeExceedsLimit`, reporting both the request and the limit.
    /// The body is a short junk byte sequence; the size check is expected to
    /// fire before any decoding. (The test name suggests the limit is 5 GiB;
    /// that value is defined elsewhere.)
    #[tokio::test]
    async fn issue_89_rejects_manifest_over_5gib() {
        let codec = CpuZstd::default();
        let body = Bytes::from_static(&[0x00, 0xd1, 0xd1, 0xd1, 0xd1, 0xd1]);
        let manifest = ChunkManifest {
            codec: CodecKind::CpuZstd,
            original_size: crate::MAX_DECOMPRESSED_BYTES + 1,
            compressed_size: body.len() as u64,
            crc32c: 0,
        };
        let err = codec.decompress(body, &manifest).await.unwrap_err();
        match err {
            CodecError::ManifestSizeExceedsLimit { requested, limit } => {
                assert_eq!(requested, crate::MAX_DECOMPRESSED_BYTES + 1);
                assert_eq!(limit, crate::MAX_DECOMPRESSED_BYTES);
            }
            other => panic!("expected ManifestSizeExceedsLimit, got {other:?}"),
        }
    }

    /// Issue #89: a manifest claiming `u32::MAX` bytes (~4 GiB) for a 6-byte
    /// junk body must fail cleanly rather than allocate the claimed size up
    /// front. The output buffer's initial capacity is capped at
    /// `DECOMPRESS_BOOTSTRAP_CAPACITY`. The junk body then fails to decode
    /// (`Io`) or decodes short (`SizeMismatch`).
    #[tokio::test]
    async fn issue_89_bootstrap_cap_keeps_4gib_claim_alloc_safe() {
        let codec = CpuZstd::default();
        let body = Bytes::from_static(&[0x00, 0xd1, 0xd1, 0xd1, 0xd1, 0xd1]);
        let manifest = ChunkManifest {
            codec: CodecKind::CpuZstd,
            original_size: u32::MAX as u64,
            compressed_size: body.len() as u64,
            crc32c: 0,
        };
        let err = codec.decompress(body, &manifest).await.unwrap_err();
        assert!(
            matches!(err, CodecError::Io(_) | CodecError::SizeMismatch { .. }),
            "expected Io or SizeMismatch, got {err:?}"
        );
    }

    /// The synchronous `compress_blocking` / `decompress_blocking` pair
    /// round-trips without a tokio runtime.
    #[test]
    fn blocking_roundtrip() {
        let input = b"hello squished s3 hello squished s3 hello squished s3";
        let (compressed, manifest) = compress_blocking(input, CpuZstd::DEFAULT_LEVEL).unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        let decompressed = decompress_blocking(&compressed, &manifest).unwrap();
        assert_eq!(decompressed, input);
    }
}