rocksdb_fileformat/
compression.rs

1use crate::error::{Error, Result};
2use crate::types::CompressionType;
3
4/// Decompress data according to the specified compression type
5pub fn decompress(data: &[u8], compression_type: CompressionType) -> Result<Vec<u8>> {
6    match compression_type {
7        CompressionType::None => Ok(data.to_vec()),
8        CompressionType::Snappy => decompress_snappy(data),
9        CompressionType::Zlib => decompress_zlib(data),
10        CompressionType::LZ4 => decompress_lz4(data),
11        CompressionType::ZSTD => decompress_zstd(data),
12        _ => Err(Error::UnsupportedCompressionType(compression_type as u8)),
13    }
14}
15
16/// Compress data according to the specified compression type
17pub fn compress(data: &[u8], compression_type: CompressionType) -> Result<Vec<u8>> {
18    match compression_type {
19        CompressionType::None => Ok(data.to_vec()),
20        CompressionType::Snappy => compress_snappy(data),
21        CompressionType::Zlib => compress_zlib(data),
22        CompressionType::LZ4 => compress_lz4(data),
23        CompressionType::ZSTD => compress_zstd(data),
24        _ => Err(Error::UnsupportedCompressionType(compression_type as u8)),
25    }
26}
27
28fn compress_snappy(data: &[u8]) -> Result<Vec<u8>> {
29    snap::raw::Encoder::new()
30        .compress_vec(data)
31        .map_err(|e| Error::Decompression(format!("Snappy compression failed: {}", e)))
32}
33
34fn compress_zlib(data: &[u8]) -> Result<Vec<u8>> {
35    use flate2::Compression;
36    use flate2::write::ZlibEncoder;
37    use std::io::Write;
38
39    let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
40    encoder
41        .write_all(data)
42        .map_err(|e| Error::Decompression(format!("Zlib compression failed: {}", e)))?;
43    encoder
44        .finish()
45        .map_err(|e| Error::Decompression(format!("Zlib compression failed: {}", e)))
46}
47
48fn compress_lz4(data: &[u8]) -> Result<Vec<u8>> {
49    // LZ4 in RocksDB includes a 4-byte uncompressed size header
50    let compressed_block = lz4::block::compress(data, None, false)
51        .map_err(|e| Error::Decompression(format!("LZ4 compression failed: {}", e)))?;
52
53    let mut result = Vec::new();
54    result.extend_from_slice(&(data.len() as u32).to_le_bytes());
55    result.extend_from_slice(&compressed_block);
56    Ok(result)
57}
58
59fn compress_zstd(data: &[u8]) -> Result<Vec<u8>> {
60    zstd::stream::encode_all(data, 0)
61        .map_err(|e| Error::Decompression(format!("ZSTD compression failed: {}", e)))
62}
63
64fn decompress_snappy(data: &[u8]) -> Result<Vec<u8>> {
65    snap::raw::Decoder::new()
66        .decompress_vec(data)
67        .map_err(|e| Error::Decompression(format!("Snappy decompression failed: {}", e)))
68}
69
70fn decompress_zlib(data: &[u8]) -> Result<Vec<u8>> {
71    use flate2::read::ZlibDecoder;
72    use std::io::Read;
73
74    let mut decoder = ZlibDecoder::new(data);
75    let mut decompressed = Vec::new();
76    decoder
77        .read_to_end(&mut decompressed)
78        .map_err(|e| Error::Decompression(format!("Zlib decompression failed: {}", e)))?;
79
80    Ok(decompressed)
81}
82
83fn decompress_lz4(data: &[u8]) -> Result<Vec<u8>> {
84    // LZ4 in RocksDB includes a 4-byte uncompressed size header
85    if data.len() < 4 {
86        return Err(Error::Decompression("LZ4 data too short".to_string()));
87    }
88
89    let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
90    let compressed_data = &data[4..];
91
92    lz4::block::decompress(compressed_data, Some(uncompressed_size as i32))
93        .map_err(|e| Error::Decompression(format!("LZ4 decompression failed: {}", e)))
94}
95
96fn decompress_zstd(data: &[u8]) -> Result<Vec<u8>> {
97    zstd::stream::decode_all(data)
98        .map_err(|e| Error::Decompression(format!("ZSTD decompression failed: {}", e)))
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Repetitive input shared by the codec tests, so every real codec can shrink it.
    const SAMPLE: &[u8] = b"hello world hello world hello world";

    // Round-trips `SAMPLE` through `compress` then `decompress` using the given
    // `CompressionType` expression, asserts the bytes come back unchanged, and evaluates
    // to the compressed bytes. A macro rather than a helper fn so `$kind` is evaluated
    // separately for each call and the tests need not assume `CompressionType: Copy`.
    // Only usable inside a function returning `Result<()>`, since it uses `?`.
    macro_rules! round_trip {
        ($kind:expr) => {{
            let packed = compress(SAMPLE, $kind)?;
            let unpacked = decompress(&packed, $kind)?;
            assert_eq!(unpacked, SAMPLE);
            packed
        }};
    }

    #[test]
    fn test_no_compression() -> Result<()> {
        let input = b"hello world";
        assert_eq!(decompress(input, CompressionType::None)?, input);
        Ok(())
    }

    #[test]
    fn test_snappy_compression() -> Result<()> {
        let mut encoder = snap::raw::Encoder::new();
        let compressed = encoder
            .compress_vec(SAMPLE)
            .map_err(|e| Error::Decompression(format!("Snappy compression failed: {}", e)))?;
        assert_eq!(decompress(&compressed, CompressionType::Snappy)?, SAMPLE);
        Ok(())
    }

    #[test]
    fn test_zlib_compression() -> Result<()> {
        use flate2::write::ZlibEncoder;
        use flate2::Compression;
        use std::io::Write;

        let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
        encoder
            .write_all(SAMPLE)
            .map_err(|e| Error::Decompression(format!("Zlib write failed: {}", e)))?;
        let compressed = encoder
            .finish()
            .map_err(|e| Error::Decompression(format!("Zlib finish failed: {}", e)))?;

        assert_eq!(decompress(&compressed, CompressionType::Zlib)?, SAMPLE);
        Ok(())
    }

    #[test]
    fn test_lz4_compression() -> Result<()> {
        let block = lz4::block::compress(SAMPLE, None, false)
            .map_err(|e| Error::Decompression(format!("LZ4 compression failed: {}", e)))?;

        // Hand-build the framing `decompress` expects: 4-byte little-endian
        // uncompressed length, then the raw LZ4 block.
        let framed = [&(SAMPLE.len() as u32).to_le_bytes()[..], &block[..]].concat();

        assert_eq!(decompress(&framed, CompressionType::LZ4)?, SAMPLE);
        Ok(())
    }

    #[test]
    fn test_zstd_compression() -> Result<()> {
        let compressed = zstd::stream::encode_all(SAMPLE, 0)
            .map_err(|e| Error::Decompression(format!("ZSTD compression failed: {}", e)))?;

        assert_eq!(decompress(&compressed, CompressionType::ZSTD)?, SAMPLE);
        Ok(())
    }

    #[test]
    fn test_unsupported_compression() -> Result<()> {
        let outcome = decompress(b"hello world", CompressionType::BZip2);
        assert!(matches!(outcome, Err(Error::UnsupportedCompressionType(_))));
        Ok(())
    }

    #[test]
    fn test_round_trip_no_compression() -> Result<()> {
        round_trip!(CompressionType::None);
        Ok(())
    }

    #[test]
    fn test_round_trip_snappy() -> Result<()> {
        let packed = round_trip!(CompressionType::Snappy);
        assert!(packed.len() < SAMPLE.len());
        Ok(())
    }

    #[test]
    fn test_round_trip_zlib() -> Result<()> {
        let packed = round_trip!(CompressionType::Zlib);
        assert!(packed.len() < SAMPLE.len());
        Ok(())
    }

    #[test]
    fn test_round_trip_lz4() -> Result<()> {
        let packed = round_trip!(CompressionType::LZ4);
        assert!(packed.len() < SAMPLE.len());
        Ok(())
    }

    #[test]
    fn test_round_trip_zstd() -> Result<()> {
        let packed = round_trip!(CompressionType::ZSTD);
        assert!(packed.len() < SAMPLE.len());
        Ok(())
    }
}