bitfold_protocol/command_codec/
compression.rs

1//! Data compression and decompression utilities.
2
3use std::io::{self, Read, Write};
4
5use bitfold_core::config::CompressionAlgorithm;
6use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression};
7
8/// Compresses data using the specified algorithm.
9/// Returns compressed data with 1-byte header: `[algorithm_id][compressed_data]`
10/// Returns original data with header `[0][original_data]` if compression is disabled or ineffective.
11pub fn compress(
12    data: &[u8],
13    algorithm: CompressionAlgorithm,
14    threshold: usize,
15) -> io::Result<Vec<u8>> {
16    compress_with_buffer(data, algorithm, threshold, Vec::new())
17}
18
19/// Compresses data using the specified algorithm with a provided output buffer.
20/// This version reuses the output buffer to reduce allocations in hot paths.
21/// Returns the compressed data, reusing the provided buffer when possible.
22pub fn compress_with_buffer(
23    data: &[u8],
24    algorithm: CompressionAlgorithm,
25    threshold: usize,
26    mut output: Vec<u8>,
27) -> io::Result<Vec<u8>> {
28    output.clear();
29
30    // Don't compress small packets
31    if data.len() < threshold {
32        output.reserve(data.len() + 1);
33        output.push(0); // Uncompressed marker
34        output.extend_from_slice(data);
35        return Ok(output);
36    }
37
38    match algorithm {
39        CompressionAlgorithm::None => {
40            output.reserve(data.len() + 1);
41            output.push(0); // Uncompressed marker
42            output.extend_from_slice(data);
43            Ok(output)
44        }
45        CompressionAlgorithm::Zlib => {
46            let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
47            encoder.write_all(data)?;
48            let compressed = encoder.finish()?;
49
50            // Only use compression if it actually reduces size
51            if compressed.len() < data.len() {
52                output.reserve(compressed.len() + 1);
53                output.push(1); // Zlib marker
54                output.extend_from_slice(&compressed);
55                Ok(output)
56            } else {
57                output.reserve(data.len() + 1);
58                output.push(0); // Uncompressed marker
59                output.extend_from_slice(data);
60                Ok(output)
61            }
62        }
63        CompressionAlgorithm::Lz4 => {
64            let compressed = lz4::block::compress(data, None, false)?;
65
66            // Only use compression if it actually reduces size
67            if compressed.len() + 4 < data.len() {
68                output.reserve(compressed.len() + 5);
69                output.push(2); // LZ4 marker
70                output.extend_from_slice(&(data.len() as u32).to_be_bytes());
71                output.extend_from_slice(&compressed);
72                Ok(output)
73            } else {
74                output.reserve(data.len() + 1);
75                output.push(0); // Uncompressed marker
76                output.extend_from_slice(data);
77                Ok(output)
78            }
79        }
80    }
81}
82
83/// Decompresses data based on the 1-byte header.
84/// Header format: `[algorithm_id][data]`
85/// - 0: Uncompressed
86/// - 1: Zlib
87/// - 2: LZ4
88pub fn decompress(data: &[u8]) -> io::Result<Vec<u8>> {
89    if data.is_empty() {
90        return Err(io::Error::new(io::ErrorKind::InvalidData, "Empty data for decompression"));
91    }
92
93    let algorithm_id = data[0];
94    let payload = &data[1..];
95
96    match algorithm_id {
97        0 => {
98            // Uncompressed
99            Ok(payload.to_vec())
100        }
101        1 => {
102            // Zlib
103            let mut decoder = ZlibDecoder::new(payload);
104            let mut decompressed = Vec::new();
105            decoder.read_to_end(&mut decompressed)?;
106            Ok(decompressed)
107        }
108        2 => {
109            // LZ4 - first 4 bytes are original size
110            if payload.len() < 4 {
111                return Err(io::Error::new(io::ErrorKind::InvalidData, "LZ4 payload too short"));
112            }
113            let original_size =
114                u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]) as usize;
115            let compressed_data = &payload[4..];
116            let decompressed = lz4::block::decompress(compressed_data, Some(original_size as i32))?;
117            Ok(decompressed)
118        }
119        _ => Err(io::Error::new(
120            io::ErrorKind::InvalidData,
121            format!("Unknown compression algorithm: {}", algorithm_id),
122        )),
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// `None` must store the payload verbatim behind the uncompressed marker.
    #[test]
    fn test_compression_none() {
        let data = b"Test data that will not be compressed";
        let compressed = compress(data, CompressionAlgorithm::None, 10).unwrap();
        assert_eq!(compressed[0], 0); // Uncompressed marker
        assert_eq!(&compressed[1..], data);

        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    /// Repetitive input should be Zlib-compressed and round-trip exactly.
    #[test]
    fn test_compression_zlib() {
        let data = b"Test data that should compress well because it has lots of repetition repetition repetition";
        let compressed = compress(data, CompressionAlgorithm::Zlib, 10).unwrap();
        assert_eq!(compressed[0], 1); // Zlib marker
        assert!(compressed.len() < data.len() + 1);

        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    /// Repetitive input should be LZ4-compressed and round-trip exactly.
    #[test]
    fn test_compression_lz4() {
        let data = b"Test data that should compress well because it has lots of repetition repetition repetition";
        let compressed = compress(data, CompressionAlgorithm::Lz4, 10).unwrap();
        assert_eq!(compressed[0], 2); // LZ4 marker
        assert!(compressed.len() < data.len() + 5);

        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    /// Payloads shorter than the threshold are never compressed.
    #[test]
    fn test_compression_below_threshold() {
        let data = b"tiny";
        let compressed = compress(data, CompressionAlgorithm::Zlib, 100).unwrap();
        assert_eq!(compressed[0], 0); // Should not compress
        assert_eq!(&compressed[1..], data);
    }

    /// When compression cannot shrink the payload, the uncompressed frame is used.
    #[test]
    fn test_compression_ineffective() {
        // 20 bytes with no repetition: the zlib header, checksum and literal
        // encoding alone exceed the input length, so the fallback must be taken.
        let data = b"a1b2c3d4e5f6g7h8i9j0";
        let compressed = compress(data, CompressionAlgorithm::Zlib, 5).unwrap();
        assert_eq!(compressed[0], 0); // Uncompressed marker
        assert_eq!(&compressed[1..], data);

        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    /// An unrecognised algorithm id is rejected.
    #[test]
    fn test_decompression_unknown_algorithm() {
        let data = vec![99, 1, 2, 3]; // Invalid algorithm ID
        assert!(decompress(&data).is_err());
    }

    /// A frame without even a header byte is rejected.
    #[test]
    fn test_decompression_empty_input() {
        assert!(decompress(&[]).is_err());
    }

    /// An LZ4 frame too short to hold its 4-byte size field is rejected.
    #[test]
    fn test_decompression_lz4_truncated_header() {
        let data = vec![2, 0, 0]; // LZ4 marker + only 2 of the 4 size bytes
        assert!(decompress(&data).is_err());
    }

    /// A buffer with enough capacity is reused rather than reallocated.
    #[test]
    fn test_compress_with_buffer_reuse() {
        let data = b"Test data for buffer reuse";
        let buffer: Vec<u8> = Vec::with_capacity(100);
        let original_allocation = buffer.as_ptr();
        let compressed =
            compress_with_buffer(data, CompressionAlgorithm::None, 10, buffer).unwrap();
        assert_eq!(compressed[0], 0);
        assert_eq!(&compressed[1..], data);
        // Capacity was sufficient, so the same allocation must have been used.
        assert_eq!(compressed.as_ptr(), original_allocation);
    }

    /// Stale contents of the provided buffer never leak into the output.
    #[test]
    fn test_compress_with_buffer_clears_stale_contents() {
        let data = b"Test data for buffer reuse";
        let buffer = vec![0xAA; 64];
        let compressed =
            compress_with_buffer(data, CompressionAlgorithm::None, 10, buffer).unwrap();
        assert_eq!(compressed.len(), data.len() + 1);
        assert_eq!(compressed[0], 0);
        assert_eq!(&compressed[1..], data);
    }
}