oxigdal_websocket/protocol/
compression.rs1use crate::error::{Error, Result};
4use bytes::{Bytes, BytesMut};
5use std::io::{Read, Write};
6
/// Compression algorithm applied to message payloads by [`CompressionCodec`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionType {
    /// Payloads pass through unchanged.
    None,
    /// gzip framing around DEFLATE data, handled through `flate2`.
    Gzip,
    /// Zstandard, handled through the `zstd` crate.
    Zstd,
}
17
/// Speed/size trade-off preset, translated into an algorithm-specific level by
/// [`CompressionLevel::zstd_level`] and [`CompressionLevel::gzip_level`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionLevel {
    /// Favour encoding speed over output size.
    Fast,
    /// Middle-of-the-road preset.
    Default,
    /// Favour output size over encoding speed.
    Best,
}
28
29impl CompressionLevel {
30 pub fn zstd_level(&self) -> i32 {
32 match self {
33 CompressionLevel::Fast => 1,
34 CompressionLevel::Default => 3,
35 CompressionLevel::Best => 19,
36 }
37 }
38
39 pub fn gzip_level(&self) -> flate2::Compression {
41 match self {
42 CompressionLevel::Fast => flate2::Compression::fast(),
43 CompressionLevel::Default => flate2::Compression::default(),
44 CompressionLevel::Best => flate2::Compression::best(),
45 }
46 }
47}
48
49pub struct CompressionCodec {
51 compression_type: CompressionType,
52 level: CompressionLevel,
53}
54
55impl CompressionCodec {
56 pub fn new(compression_type: CompressionType, level: CompressionLevel) -> Self {
58 Self {
59 compression_type,
60 level,
61 }
62 }
63
64 pub fn compress(&self, data: &[u8]) -> Result<BytesMut> {
66 match self.compression_type {
67 CompressionType::None => Ok(BytesMut::from(data)),
68 CompressionType::Gzip => self.compress_gzip(data),
69 CompressionType::Zstd => self.compress_zstd(data),
70 }
71 }
72
73 pub fn decompress(&self, data: &[u8]) -> Result<Bytes> {
75 match self.compression_type {
76 CompressionType::None => Ok(Bytes::copy_from_slice(data)),
77 CompressionType::Gzip => self.decompress_gzip(data),
78 CompressionType::Zstd => self.decompress_zstd(data),
79 }
80 }
81
82 fn compress_gzip(&self, data: &[u8]) -> Result<BytesMut> {
84 use flate2::write::GzEncoder;
85
86 let mut encoder = GzEncoder::new(Vec::new(), self.level.gzip_level());
87 encoder
88 .write_all(data)
89 .map_err(|e| Error::Compression(format!("Gzip compression failed: {}", e)))?;
90
91 let compressed = encoder
92 .finish()
93 .map_err(|e| Error::Compression(format!("Gzip finish failed: {}", e)))?;
94
95 Ok(BytesMut::from(&compressed[..]))
96 }
97
98 fn decompress_gzip(&self, data: &[u8]) -> Result<Bytes> {
100 use flate2::read::GzDecoder;
101
102 let mut decoder = GzDecoder::new(data);
103 let mut decompressed = Vec::new();
104
105 decoder
106 .read_to_end(&mut decompressed)
107 .map_err(|e| Error::Compression(format!("Gzip decompression failed: {}", e)))?;
108
109 Ok(Bytes::from(decompressed))
110 }
111
112 fn compress_zstd(&self, data: &[u8]) -> Result<BytesMut> {
114 let compressed = zstd::encode_all(data, self.level.zstd_level())
115 .map_err(|e| Error::Compression(format!("Zstd compression failed: {}", e)))?;
116
117 Ok(BytesMut::from(&compressed[..]))
118 }
119
120 fn decompress_zstd(&self, data: &[u8]) -> Result<Bytes> {
122 let decompressed = zstd::decode_all(data)
123 .map_err(|e| Error::Compression(format!("Zstd decompression failed: {}", e)))?;
124
125 Ok(Bytes::from(decompressed))
126 }
127
128 pub fn compression_type(&self) -> CompressionType {
130 self.compression_type
131 }
132
133 pub fn level(&self) -> CompressionLevel {
135 self.level
136 }
137}
138
/// Cheaply estimates how compressible `data` is, on a scale from `0.0`
/// (not at all) to just under `1.0` (highly repetitive).
///
/// The score is `1 - distinct / max_distinct`:
/// - `distinct` is the number of different byte values in `data`.
/// - `max_distinct` is the most it could possibly contain, `min(data.len(), 256)`.
///
/// Only the byte alphabet is inspected, not repeated sequences, so this is a
/// heuristic. For example, a repeated pattern that uses every byte value scores
/// `0.0` even though it compresses well.
///
/// Empty input scores `0.0`: there is nothing to gain from compressing it.
pub fn estimate_compression_ratio(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    let mut seen = [false; 256];
    let mut distinct = 0usize;
    for &byte in data {
        let slot = &mut seen[usize::from(byte)];
        if !*slot {
            *slot = true;
            distinct += 1;
        }
    }

    // Normalising by the achievable maximum keeps short inputs from scoring as
    // "compressible" merely because they are too short to use many byte values.
    // Both operands are <= 256, so the float conversions are exact.
    let max_distinct = data.len().min(256);
    1.0 - distinct as f64 / max_distinct as f64
}
156
157pub fn should_compress(data: &[u8], min_size: usize) -> bool {
159 if data.len() < min_size {
160 return false;
161 }
162
163 estimate_compression_ratio(data) > 0.3
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gzip_compression() -> Result<()> {
        let codec = CompressionCodec::new(CompressionType::Gzip, CompressionLevel::Default);
        let data = b"Hello, World! This is a test message.".repeat(10);

        let compressed = codec.compress(&data)?;
        let decompressed = codec.decompress(&compressed)?;

        assert_eq!(data.as_slice(), decompressed.as_ref());
        assert!(compressed.len() < data.len());
        Ok(())
    }

    #[test]
    fn test_zstd_compression() -> Result<()> {
        let codec = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Default);
        let data = b"Hello, World! This is a test message.".repeat(10);

        let compressed = codec.compress(&data)?;
        let decompressed = codec.decompress(&compressed)?;

        assert_eq!(data.as_slice(), decompressed.as_ref());
        assert!(compressed.len() < data.len());
        Ok(())
    }

    #[test]
    fn test_no_compression() -> Result<()> {
        let codec = CompressionCodec::new(CompressionType::None, CompressionLevel::Default);
        let data = b"Hello, World!";

        let compressed = codec.compress(data)?;
        let decompressed = codec.decompress(&compressed)?;

        assert_eq!(data, compressed.as_ref());
        assert_eq!(data, decompressed.as_ref());
        Ok(())
    }

    #[test]
    fn test_compression_levels() -> Result<()> {
        let data = b"Hello, World! This is a test message.".repeat(100);

        let fast = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Fast);
        let default = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Default);
        let best = CompressionCodec::new(CompressionType::Zstd, CompressionLevel::Best);

        let fast_compressed = fast.compress(&data)?;
        let default_compressed = default.compress(&data)?;
        let best_compressed = best.compress(&data)?;

        // Higher presets should never produce larger output on this highly
        // repetitive input.
        assert!(best_compressed.len() <= default_compressed.len());
        assert!(default_compressed.len() <= fast_compressed.len());

        Ok(())
    }

    #[test]
    fn test_estimate_compression_ratio() {
        // A single repeated byte value is the most compressible case.
        let repetitive = vec![0u8; 1000];
        let ratio_repetitive = estimate_compression_ratio(&repetitive);
        assert!(ratio_repetitive > 0.9);

        // Cycles through every byte value, so the alphabet heuristic scores it low.
        let all_byte_values: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
        let ratio_all_values = estimate_compression_ratio(&all_byte_values);
        assert!(ratio_all_values < ratio_repetitive);
    }

    #[test]
    fn test_should_compress() {
        // Below the size threshold nothing is compressed, however repetitive.
        let small = vec![0u8; 10];
        assert!(!should_compress(&small, 100));

        let large_repetitive = vec![0u8; 1000];
        assert!(should_compress(&large_repetitive, 100));

        // Every byte value occurs, so the estimate is 0.0 and compression is declined.
        let large_all_values: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
        assert!(!should_compress(&large_all_values, 100));
    }
}