Skip to main content

oxigdal_websocket/protocol/
compression.rs

1//! Compression support for WebSocket protocol
2
3use crate::error::{Error, Result};
4use bytes::{Bytes, BytesMut};
5
/// Algorithm applied to WebSocket payloads.
///
/// `None` leaves the bytes untouched; `Gzip` and `Zstd` select the matching
/// encoder/decoder pair.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionType {
    /// Payload is passed through as-is.
    None,
    /// Payload is encoded with gzip.
    Gzip,
    /// Payload is encoded with Zstandard.
    Zstd,
}
16
/// Speed/size trade-off preset shared by every algorithm.
///
/// Each preset is translated to an algorithm-specific numeric level:
/// gzip 1/6/9 and zstd 1/3/19 for `Fast`/`Default`/`Best` respectively.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionLevel {
    /// Favour speed over output size.
    Fast,
    /// Balanced speed and output size.
    Default,
    /// Favour the smallest output over speed.
    Best,
}
27
28impl CompressionLevel {
29    /// Get zstd compression level
30    pub fn zstd_level(&self) -> i32 {
31        match self {
32            CompressionLevel::Fast => 1,
33            CompressionLevel::Default => 3,
34            CompressionLevel::Best => 19,
35        }
36    }
37
38    /// Get gzip compression level as u8
39    pub fn gzip_level_u8(&self) -> u8 {
40        match self {
41            CompressionLevel::Fast => 1,
42            CompressionLevel::Default => 6,
43            CompressionLevel::Best => 9,
44        }
45    }
46}
47
48/// Compression codec
49pub struct CompressionCodec {
50    compression_type: CompressionType,
51    level: CompressionLevel,
52}
53
54impl CompressionCodec {
55    /// Create a new compression codec
56    pub fn new(compression_type: CompressionType, level: CompressionLevel) -> Self {
57        Self {
58            compression_type,
59            level,
60        }
61    }
62
63    /// Compress data
64    pub fn compress(&self, data: &[u8]) -> Result<BytesMut> {
65        match self.compression_type {
66            CompressionType::None => Ok(BytesMut::from(data)),
67            CompressionType::Gzip => self.compress_gzip(data),
68            CompressionType::Zstd => self.compress_zstd(data),
69        }
70    }
71
72    /// Decompress data
73    pub fn decompress(&self, data: &[u8]) -> Result<Bytes> {
74        match self.compression_type {
75            CompressionType::None => Ok(Bytes::copy_from_slice(data)),
76            CompressionType::Gzip => self.decompress_gzip(data),
77            CompressionType::Zstd => self.decompress_zstd(data),
78        }
79    }
80
81    /// Compress data using gzip
82    fn compress_gzip(&self, data: &[u8]) -> Result<BytesMut> {
83        let compressed = oxiarc_archive::gzip::compress(data, self.level.gzip_level_u8())
84            .map_err(|e| Error::Compression(format!("Gzip compression failed: {}", e)))?;
85        Ok(BytesMut::from(&compressed[..]))
86    }
87
88    /// Decompress data using gzip
89    fn decompress_gzip(&self, data: &[u8]) -> Result<Bytes> {
90        let mut reader = std::io::Cursor::new(data);
91        let decompressed = oxiarc_archive::gzip::decompress(&mut reader)
92            .map_err(|e| Error::Compression(format!("Gzip decompression failed: {}", e)))?;
93        Ok(Bytes::from(decompressed))
94    }
95
96    /// Compress data using zstd
97    fn compress_zstd(&self, data: &[u8]) -> Result<BytesMut> {
98        let compressed = oxiarc_zstd::encode_all(data, self.level.zstd_level())
99            .map_err(|e| Error::Compression(format!("Zstd compression failed: {}", e)))?;
100        Ok(BytesMut::from(&compressed[..]))
101    }
102
103    /// Decompress data using zstd
104    fn decompress_zstd(&self, data: &[u8]) -> Result<Bytes> {
105        let decompressed = oxiarc_zstd::decode_all(data)
106            .map_err(|e| Error::Compression(format!("Zstd decompression failed: {}", e)))?;
107        Ok(Bytes::from(decompressed))
108    }
109
110    /// Get compression type
111    pub fn compression_type(&self) -> CompressionType {
112        self.compression_type
113    }
114
115    /// Get compression level
116    pub fn level(&self) -> CompressionLevel {
117        self.level
118    }
119}
120
/// Estimate how compressible `data` is, on a scale from 0.0 to 1.0.
///
/// Despite the name, this is not a predicted compressed/original size ratio.
/// It is a cheap "compression potential" score where higher values mean the
/// data is more likely to shrink. It counts the distinct byte values in `data`
/// and compares that count to the most distinct values a buffer of this length
/// could contain (`min(len, 256)`):
///
/// * every byte identical → close to 1.0 (1000 zero bytes → ~0.996)
/// * every byte distinct, or all 256 values present → 0.0
/// * empty input → 0.0 (there is nothing to gain from compressing it)
///
/// Byte frequencies and ordering are ignored, so this is only a coarse signal.
pub fn estimate_compression_ratio(data: &[u8]) -> f64 {
    const BYTE_VALUES: usize = 256;

    if data.is_empty() {
        return 0.0;
    }

    let mut seen = [false; BYTE_VALUES];
    let mut unique_count = 0usize;
    for &byte in data {
        let slot = &mut seen[usize::from(byte)];
        if !*slot {
            *slot = true;
            unique_count += 1;
            // Every possible byte value has appeared; scanning further cannot
            // change the result.
            if unique_count == BYTE_VALUES {
                break;
            }
        }
    }

    // Normalize by the largest distinct count this length allows. Dividing by a
    // fixed 256 would make every short buffer look compressible, because it can
    // never hold many distinct bytes even when none of them repeat.
    let max_unique = data.len().min(BYTE_VALUES);
    1.0 - unique_count as f64 / max_unique as f64
}
138
139/// Determine if data should be compressed based on size and content
140pub fn should_compress(data: &[u8], min_size: usize) -> bool {
141    if data.len() < min_size {
142        return false;
143    }
144
145    // Check compression potential
146    estimate_compression_ratio(data) > 0.3
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gzip_compression() -> Result<()> {
        let codec = CompressionCodec::new(CompressionType::Gzip, CompressionLevel::Default);
        let data = b"Hello, World! This is a test message.".repeat(10);

        let compressed = codec.compress(&data)?;
        let decompressed = codec.decompress(&compressed)?;

        assert_eq!(data.as_slice(), decompressed.as_ref());
        assert!(compressed.len() < data.len());
        Ok(())
    }

    #[test]
    fn test_zstd_compression() -> Result<()> {
        let codec = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Default);
        let data = b"Hello, World! This is a test message.".repeat(10);

        let compressed = codec.compress(&data)?;
        let decompressed = codec.decompress(&compressed)?;

        assert_eq!(data.as_slice(), decompressed.as_ref());
        assert!(compressed.len() < data.len());
        Ok(())
    }

    #[test]
    fn test_no_compression() -> Result<()> {
        let codec = CompressionCodec::new(CompressionType::None, CompressionLevel::Default);
        let data = b"Hello, World!";

        let compressed = codec.compress(data)?;
        let decompressed = codec.decompress(&compressed)?;

        assert_eq!(data, compressed.as_ref());
        assert_eq!(data, decompressed.as_ref());
        Ok(())
    }

    #[test]
    fn test_compression_levels() -> Result<()> {
        let data = b"Hello, World! This is a test message.".repeat(100);

        let fast = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Fast);
        let default = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Default);
        let best = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Best);

        let fast_compressed = fast.compress(&data)?;
        let default_compressed = default.compress(&data)?;
        let best_compressed = best.compress(&data)?;

        // Best should compress better than default, default better than fast
        assert!(best_compressed.len() <= default_compressed.len());
        assert!(default_compressed.len() <= fast_compressed.len());

        Ok(())
    }

    #[test]
    fn test_estimate_compression_ratio() {
        // Highly repetitive data
        let repetitive = vec![0u8; 1000];
        let ratio1 = estimate_compression_ratio(&repetitive);
        assert!(ratio1 > 0.9);

        // Random-like data
        let random: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
        let ratio2 = estimate_compression_ratio(&random);
        assert!(ratio2 < ratio1);
    }

    #[test]
    fn test_should_compress() {
        // Too small: rejected before the content heuristic runs.
        let small = vec![0u8; 10];
        assert!(!should_compress(&small, 100));

        // Large and repetitive: a single distinct byte value.
        let large_repetitive = vec![0u8; 1000];
        assert!(should_compress(&large_repetitive, 100));

        // Large and containing every byte value: the unique-byte heuristic
        // scores this 0.0, which is below the threshold, so it must be rejected.
        let large_random: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
        assert!(!should_compress(&large_random, 100));
    }
}