use std::io::Cursor;
use crate::error::DecodeError;
/// Size limit of 16 MiB (16 * 1024 * 1024 bytes).
///
/// NOTE(review): presumably the `max_size` to pass to [`decompress`] when
/// inflating a single block — confirm against callers.
pub const MAX_BLOCK_DECOMPRESSED_SIZE: usize = 16 * 1024 * 1024;
/// Size limit of 256 MiB (256 * 1024 * 1024 bytes).
///
/// NOTE(review): presumably the `max_size` to pass to [`decompress`] when
/// inflating a whole payload — confirm against callers.
pub const MAX_PAYLOAD_DECOMPRESSED_SIZE: usize = 256 * 1024 * 1024;
/// Decompresses zstd-encoded `data`, refusing to produce more than `max_size` bytes.
///
/// The output is streamed through a reader capped at `max_size + 1` bytes, so an
/// oversized (or maliciously crafted) input is rejected after buffering at most
/// one byte beyond the limit instead of after the whole output has been
/// materialized in memory.
///
/// # Errors
///
/// * [`DecodeError::DecompressFailed`] if the decoder cannot be created or the
///   input is not valid zstd data.
/// * [`DecodeError::DecompressionBomb`] if the decompressed output exceeds
///   `max_size`. Because decoding stops as soon as the limit is crossed,
///   `actual` is the number of bytes produced before aborting (a lower bound on
///   the real decompressed size), and `limit` is `max_size`.
pub fn decompress(data: &[u8], max_size: usize) -> Result<Vec<u8>, DecodeError> {
    let decoder = zstd::stream::read::Decoder::new(Cursor::new(data))
        .map_err(|e| DecodeError::DecompressFailed(e.to_string()))?;

    // Allow exactly one byte past the limit: that is enough to tell "fits in
    // max_size" apart from "too large" without ever buffering the full output.
    let read_cap = u64::try_from(max_size)
        .unwrap_or(u64::MAX)
        .saturating_add(1);
    // Fully-qualified trait calls keep this block independent of whether
    // `std::io::Read` is imported at file level.
    let mut limited = std::io::Read::take(decoder, read_cap);

    let mut decompressed = Vec::new();
    std::io::Read::read_to_end(&mut limited, &mut decompressed)
        .map_err(|e| DecodeError::DecompressFailed(e.to_string()))?;

    if decompressed.len() > max_size {
        return Err(DecodeError::DecompressionBomb {
            actual: decompressed.len(),
            limit: max_size,
        });
    }

    Ok(decompressed)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a zstd fixture by compressing `data` at level 3.
    fn compress_test_data(data: &[u8]) -> Vec<u8> {
        let reader = Cursor::new(data);
        zstd::encode_all(reader, 3).unwrap()
    }

    /// Compressing and then decompressing yields the original bytes.
    #[test]
    fn decompress_roundtrip() {
        let source_text = "fn main() { println!(\"hello\"); }\n".repeat(50);
        let expected = source_text.as_bytes();
        let packed = compress_test_data(expected);

        let unpacked = decompress(&packed, MAX_BLOCK_DECOMPRESSED_SIZE).unwrap();

        assert_eq!(unpacked, expected);
    }

    /// Input that is not zstd data is reported as a decompression failure.
    #[test]
    fn decompress_rejects_invalid_data() {
        let outcome = decompress(b"this is not zstd data", MAX_BLOCK_DECOMPRESSED_SIZE);

        assert!(matches!(outcome, Err(DecodeError::DecompressFailed(_))));
    }

    /// Output larger than the supplied limit is reported as a decompression bomb.
    #[test]
    fn decompress_rejects_bomb() {
        let tiny_limit = 100;
        let payload = "x".repeat(10_000);
        let packed = compress_test_data(payload.as_bytes());

        let outcome = decompress(&packed, tiny_limit);

        assert!(matches!(
            outcome,
            Err(DecodeError::DecompressionBomb { .. })
        ));
    }

    /// Pins the published limits to their documented values.
    #[test]
    fn constants_are_correct() {
        assert_eq!(MAX_BLOCK_DECOMPRESSED_SIZE, 16 * 1024 * 1024);
        assert_eq!(MAX_PAYLOAD_DECOMPRESSED_SIZE, 256 * 1024 * 1024);
    }
}