Skip to main content

bcp_encoder/
compression.rs

1use std::io::Cursor;
2
3use crate::error::CompressionError;
4
/// Size cutoff, in bytes, below which a block body is not worth
/// compressing.
///
/// zstd wraps its output in a frame whose fixed overhead (roughly a
/// dozen bytes of header) can exceed whatever is saved on a very small
/// input, so bodies shorter than this are meant to be stored
/// uncompressed. This module only defines the value; applying it is
/// left to the caller.
///
/// Default: 256 bytes.
pub const COMPRESSION_THRESHOLD: usize = 256;
14
/// zstd compression level used when none is specified (valid levels
/// run from 1 to 22).
///
/// Level 3 is a reasonable trade-off between speed and compression
/// ratio for typical code/text context blocks (RFC §4.6); higher
/// levels cost more latency for diminishing size gains.
const DEFAULT_COMPRESSION_LEVEL: i32 = 3;
21
22/// Compress a byte slice with zstd.
23///
24/// Returns `Some(compressed)` if compression reduced the size, or
25/// `None` if the compressed output is >= the input size. This
26/// ensures compression is never harmful โ€” the caller should store
27/// the block uncompressed when `None` is returned.
28///
29/// Uses the default compression level (3).
30///
31/// # Example
32///
33/// ```rust
34/// use bcp_encoder::compression::compress;
35///
36/// let data = "fn main() { }\n".repeat(100);
37/// match compress(data.as_bytes()) {
38///     Some(compressed) => assert!(compressed.len() < data.len()),
39///     None => { /* data was incompressible */ }
40/// }
41/// ```
42pub fn compress(data: &[u8]) -> Option<Vec<u8>> {
43    let compressed = zstd::encode_all(Cursor::new(data), DEFAULT_COMPRESSION_LEVEL).ok()?;
44    if compressed.len() < data.len() {
45        Some(compressed)
46    } else {
47        None
48    }
49}
50
51/// Decompress a zstd-compressed byte slice.
52///
53/// The `max_size` parameter provides an upper bound on the
54/// decompressed output to prevent decompression bombs โ€” if the
55/// decompressed data exceeds this limit, an error is returned
56/// without completing decompression.
57///
58/// # Errors
59///
60/// - [`CompressionError::DecompressFailed`] if zstd cannot decode
61///   the input (invalid frame, truncated data, etc.).
62/// - [`CompressionError::DecompressionBomb`] if the decompressed
63///   size exceeds `max_size`.
64pub fn decompress(data: &[u8], max_size: usize) -> Result<Vec<u8>, CompressionError> {
65    let decompressed = zstd::decode_all(Cursor::new(data))
66        .map_err(|e| CompressionError::DecompressFailed(e.to_string()))?;
67    if decompressed.len() > max_size {
68        return Err(CompressionError::DecompressionBomb {
69            actual: decompressed.len(),
70            limit: max_size,
71        });
72    }
73    Ok(decompressed)
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Generous output limit for tests that are not about the bomb check.
    const ONE_MIB: usize = 1024 * 1024;

    /// A six-byte input cannot come out smaller once zstd framing is added.
    #[test]
    fn compress_returns_none_for_small_incompressible_data() {
        assert_eq!(compress(b"abc123"), None);
    }

    /// Highly repetitive text must compress to something strictly smaller.
    #[test]
    fn compress_reduces_repetitive_data() {
        let source = "fn main() { }\n".repeat(100);
        let shrunk = compress(source.as_bytes()).map_or(false, |packed| packed.len() < source.len());
        assert!(shrunk);
    }

    /// Compressing and then decompressing yields the original bytes.
    #[test]
    fn compress_decompress_roundtrip() {
        let source = "pub fn hello() -> &'static str { \"world\" }\n".repeat(50);
        let packed = compress(source.as_bytes()).expect("should compress");
        let restored = decompress(&packed, ONE_MIB).expect("should decompress");
        assert_eq!(restored.as_slice(), source.as_bytes());
    }

    /// 10 000 bytes of output against a 100-byte limit is reported as a bomb.
    #[test]
    fn decompress_rejects_bomb() {
        let source = "x".repeat(10_000);
        let packed = compress(source.as_bytes()).expect("should compress");
        let outcome = decompress(&packed, 100);
        let is_bomb = matches!(outcome, Err(CompressionError::DecompressionBomb { .. }));
        assert!(is_bomb);
    }

    /// Bytes that are not a zstd frame produce a decode failure.
    #[test]
    fn decompress_rejects_invalid_data() {
        let outcome = decompress(b"this is not zstd data", ONE_MIB);
        let is_decode_failure = matches!(outcome, Err(CompressionError::DecompressFailed(_)));
        assert!(is_decode_failure);
    }

    /// Pins the documented default threshold.
    #[test]
    fn compression_threshold_is_256() {
        assert_eq!(256, COMPRESSION_THRESHOLD);
    }
}