Skip to main content

oxigdal_websocket/protocol/
compression.rs

1//! Compression support for WebSocket protocol
2
3use crate::error::{Error, Result};
4use bytes::{Bytes, BytesMut};
5use std::io::{Read, Write};
6
/// Algorithm used to compress WebSocket payloads.
///
/// `None` passes payloads through untouched, while `Gzip` and `Zstd` select the
/// matching backend inside [`CompressionCodec`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionType {
    /// Leave payloads uncompressed.
    None,
    /// Compress with gzip (via `flate2`).
    Gzip,
    /// Compress with Zstandard (via `zstd`).
    Zstd,
}
17
/// Speed-versus-ratio trade-off requested from the compression backend.
///
/// Each variant is translated to a backend-specific setting by
/// `CompressionLevel::zstd_level` and `CompressionLevel::gzip_level`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionLevel {
    /// Favour speed over output size.
    Fast,
    /// The backend's balanced, general-purpose setting.
    Default,
    /// Favour the smallest output over speed.
    Best,
}
28
29impl CompressionLevel {
30    /// Get zstd compression level
31    pub fn zstd_level(&self) -> i32 {
32        match self {
33            CompressionLevel::Fast => 1,
34            CompressionLevel::Default => 3,
35            CompressionLevel::Best => 19,
36        }
37    }
38
39    /// Get gzip compression level
40    pub fn gzip_level(&self) -> flate2::Compression {
41        match self {
42            CompressionLevel::Fast => flate2::Compression::fast(),
43            CompressionLevel::Default => flate2::Compression::default(),
44            CompressionLevel::Best => flate2::Compression::best(),
45        }
46    }
47}
48
49/// Compression codec
50pub struct CompressionCodec {
51    compression_type: CompressionType,
52    level: CompressionLevel,
53}
54
55impl CompressionCodec {
56    /// Create a new compression codec
57    pub fn new(compression_type: CompressionType, level: CompressionLevel) -> Self {
58        Self {
59            compression_type,
60            level,
61        }
62    }
63
64    /// Compress data
65    pub fn compress(&self, data: &[u8]) -> Result<BytesMut> {
66        match self.compression_type {
67            CompressionType::None => Ok(BytesMut::from(data)),
68            CompressionType::Gzip => self.compress_gzip(data),
69            CompressionType::Zstd => self.compress_zstd(data),
70        }
71    }
72
73    /// Decompress data
74    pub fn decompress(&self, data: &[u8]) -> Result<Bytes> {
75        match self.compression_type {
76            CompressionType::None => Ok(Bytes::copy_from_slice(data)),
77            CompressionType::Gzip => self.decompress_gzip(data),
78            CompressionType::Zstd => self.decompress_zstd(data),
79        }
80    }
81
82    /// Compress data using gzip
83    fn compress_gzip(&self, data: &[u8]) -> Result<BytesMut> {
84        use flate2::write::GzEncoder;
85
86        let mut encoder = GzEncoder::new(Vec::new(), self.level.gzip_level());
87        encoder
88            .write_all(data)
89            .map_err(|e| Error::Compression(format!("Gzip compression failed: {}", e)))?;
90
91        let compressed = encoder
92            .finish()
93            .map_err(|e| Error::Compression(format!("Gzip finish failed: {}", e)))?;
94
95        Ok(BytesMut::from(&compressed[..]))
96    }
97
98    /// Decompress data using gzip
99    fn decompress_gzip(&self, data: &[u8]) -> Result<Bytes> {
100        use flate2::read::GzDecoder;
101
102        let mut decoder = GzDecoder::new(data);
103        let mut decompressed = Vec::new();
104
105        decoder
106            .read_to_end(&mut decompressed)
107            .map_err(|e| Error::Compression(format!("Gzip decompression failed: {}", e)))?;
108
109        Ok(Bytes::from(decompressed))
110    }
111
112    /// Compress data using zstd
113    fn compress_zstd(&self, data: &[u8]) -> Result<BytesMut> {
114        let compressed = zstd::encode_all(data, self.level.zstd_level())
115            .map_err(|e| Error::Compression(format!("Zstd compression failed: {}", e)))?;
116
117        Ok(BytesMut::from(&compressed[..]))
118    }
119
120    /// Decompress data using zstd
121    fn decompress_zstd(&self, data: &[u8]) -> Result<Bytes> {
122        let decompressed = zstd::decode_all(data)
123            .map_err(|e| Error::Compression(format!("Zstd decompression failed: {}", e)))?;
124
125        Ok(Bytes::from(decompressed))
126    }
127
128    /// Get compression type
129    pub fn compression_type(&self) -> CompressionType {
130        self.compression_type
131    }
132
133    /// Get compression level
134    pub fn level(&self) -> CompressionLevel {
135        self.level
136    }
137}
138
/// Estimates how compressible `data` is, as a score in `[0.0, 1.0)`.
///
/// The heuristic looks only at byte-value diversity. The score is
/// `1 - distinct_byte_values / 256`, so:
/// - a payload built from a handful of distinct bytes scores close to 1.0;
/// - a payload that uses all 256 byte values scores 0.0.
///
/// Byte order and run lengths are ignored, so this is a cheap pre-filter, not a
/// prediction of the actual compressed size.
///
/// An empty slice scores 0.0. There is nothing to gain from compressing it, and any
/// codec framing would only make it larger.
pub fn estimate_compression_ratio(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    let mut seen = [false; 256];
    let mut unique_count: usize = 0;

    for &byte in data {
        let slot = &mut seen[usize::from(byte)];
        if !*slot {
            *slot = true;
            unique_count += 1;
            // Every byte value has appeared; the score cannot drop any further.
            if unique_count == seen.len() {
                break;
            }
        }
    }

    // Fewer distinct byte values suggests more redundancy, hence a higher score.
    1.0 - unique_count as f64 / 256.0
}
156
157/// Determine if data should be compressed based on size and content
158pub fn should_compress(data: &[u8], min_size: usize) -> bool {
159    if data.len() < min_size {
160        return false;
161    }
162
163    // Check compression potential
164    estimate_compression_ratio(data) > 0.3
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Highly repetitive text that every real codec must shrink.
    fn repetitive_payload(repeats: usize) -> Vec<u8> {
        b"Hello, World! This is a test message.".repeat(repeats)
    }

    /// 1000 bytes counting up modulo 256, so every byte value appears at least once.
    fn all_byte_values_payload() -> Vec<u8> {
        (0..1000).map(|i| (i % 256) as u8).collect()
    }

    /// Compresses then decompresses `data` with `codec`, asserting the round trip is
    /// lossless, and returns the compressed length.
    fn round_trip(codec: &CompressionCodec, data: &[u8]) -> Result<usize> {
        let compressed = codec.compress(data)?;
        let decompressed = codec.decompress(&compressed)?;

        assert_eq!(data, decompressed.as_ref());
        Ok(compressed.len())
    }

    #[test]
    fn test_gzip_compression() -> Result<()> {
        let codec = CompressionCodec::new(CompressionType::Gzip, CompressionLevel::Default);
        let data = repetitive_payload(10);

        assert!(round_trip(&codec, &data)? < data.len());
        Ok(())
    }

    #[test]
    fn test_zstd_compression() -> Result<()> {
        let codec = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Default);
        let data = repetitive_payload(10);

        assert!(round_trip(&codec, &data)? < data.len());
        Ok(())
    }

    #[test]
    fn test_no_compression() -> Result<()> {
        let codec = CompressionCodec::new(CompressionType::None, CompressionLevel::Default);
        let data = b"Hello, World!";

        let compressed = codec.compress(data)?;
        let decompressed = codec.decompress(&compressed)?;

        assert_eq!(data, compressed.as_ref());
        assert_eq!(data, decompressed.as_ref());
        Ok(())
    }

    #[test]
    fn test_empty_round_trip() -> Result<()> {
        for compression_type in [
            CompressionType::None,
            CompressionType::Gzip,
            CompressionType::Zstd,
        ] {
            let codec = CompressionCodec::new(compression_type, CompressionLevel::Default);
            round_trip(&codec, &[])?;
        }
        Ok(())
    }

    #[test]
    fn test_invalid_input_is_rejected() {
        let garbage = b"definitely not a compressed stream";
        for compression_type in [CompressionType::Gzip, CompressionType::Zstd] {
            let codec = CompressionCodec::new(compression_type, CompressionLevel::Default);
            assert!(
                codec.decompress(garbage).is_err(),
                "{:?} decoder accepted garbage input",
                compression_type
            );
        }
    }

    #[test]
    fn test_compression_levels() -> Result<()> {
        let data = repetitive_payload(100);

        let fast = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Fast);
        let default = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Default);
        let best = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Best);

        let fast_len = fast.compress(&data)?.len();
        let default_len = default.compress(&data)?.len();
        let best_len = best.compress(&data)?.len();

        // Higher levels should do no worse than lower ones on this highly repetitive input.
        assert!(best_len <= default_len);
        assert!(default_len <= fast_len);
        Ok(())
    }

    #[test]
    fn test_estimate_compression_ratio() {
        // A single distinct byte value: the score is 1 - 1/256.
        let repetitive = vec![0u8; 1000];
        let ratio1 = estimate_compression_ratio(&repetitive);
        assert!(ratio1 > 0.9);

        // Every byte value present: the heuristic sees no redundancy at all.
        let ratio2 = estimate_compression_ratio(&all_byte_values_payload());
        assert_eq!(ratio2, 0.0);
        assert!(ratio2 < ratio1);
    }

    #[test]
    fn test_should_compress() {
        // Too small.
        let small = vec![0u8; 10];
        assert!(!should_compress(&small, 100));

        // Large and repetitive.
        let large_repetitive = vec![0u8; 1000];
        assert!(should_compress(&large_repetitive, 100));

        // Large but using every byte value: the estimated ratio (0.0) is below the
        // threshold, so compression is skipped.
        assert!(!should_compress(&all_byte_values_payload(), 100));
    }
}