// sockudo_ws/compression.rs

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use bytes::Bytes;
use parking_lot::Mutex;

use crate::Compression;
use crate::deflate::{DeflateConfig, DeflateContext, DeflateDecoder, DeflateEncoder};
use crate::error::Result;

/// Number of encoders allocated by every [`SharedCompressorPool`].
const SHARED_POOL_SIZE: usize = 4;

21pub enum CompressionContext {
23 Disabled,
25 Dedicated(DeflateContext),
27 Shared {
29 pool: Arc<SharedCompressorPool>,
30 decoder: DeflateDecoder,
31 config: DeflateConfig,
32 },
33}
34
35impl CompressionContext {
36 pub fn server(mode: Compression) -> Self {
38 match mode {
39 Compression::Disabled => CompressionContext::Disabled,
40 Compression::Shared => {
41 let config = mode.to_deflate_config().unwrap();
42 CompressionContext::Shared {
43 pool: Arc::new(SharedCompressorPool::new(config.clone())),
44 decoder: DeflateDecoder::new(
45 config.client_max_window_bits,
46 config.client_no_context_takeover,
47 ),
48 config,
49 }
50 }
51 _ => {
52 let config = mode.to_deflate_config().unwrap();
53 CompressionContext::Dedicated(DeflateContext::server(config))
54 }
55 }
56 }
57
58 pub fn client(mode: Compression) -> Self {
60 match mode {
61 Compression::Disabled => CompressionContext::Disabled,
62 Compression::Shared => {
63 let config = mode.to_deflate_config().unwrap();
64 CompressionContext::Shared {
65 pool: Arc::new(SharedCompressorPool::new(config.clone())),
66 decoder: DeflateDecoder::new(
67 config.server_max_window_bits,
68 config.server_no_context_takeover,
69 ),
70 config,
71 }
72 }
73 _ => {
74 let config = mode.to_deflate_config().unwrap();
75 CompressionContext::Dedicated(DeflateContext::client(config))
76 }
77 }
78 }
79
80 pub fn with_shared_pool(pool: Arc<SharedCompressorPool>, is_server: bool) -> Self {
82 let config = pool.config.clone();
83 let decoder = if is_server {
84 DeflateDecoder::new(
85 config.client_max_window_bits,
86 config.client_no_context_takeover,
87 )
88 } else {
89 DeflateDecoder::new(
90 config.server_max_window_bits,
91 config.server_no_context_takeover,
92 )
93 };
94
95 CompressionContext::Shared {
96 pool,
97 decoder,
98 config,
99 }
100 }
101
102 #[inline]
104 pub fn is_enabled(&self) -> bool {
105 !matches!(self, CompressionContext::Disabled)
106 }
107
108 pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
112 match self {
113 CompressionContext::Disabled => Ok(None),
114 CompressionContext::Dedicated(ctx) => ctx.compress(data),
115 CompressionContext::Shared { pool, config, .. } => {
116 if data.len() < config.compression_threshold {
118 return Ok(None);
119 }
120 pool.compress(data)
121 }
122 }
123 }
124
125 pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
127 match self {
128 CompressionContext::Disabled => {
129 Ok(Bytes::copy_from_slice(data))
132 }
133 CompressionContext::Dedicated(ctx) => ctx.decompress(data, max_size),
134 CompressionContext::Shared { decoder, .. } => decoder.decompress(data, max_size),
135 }
136 }
137
138 pub fn config(&self) -> Option<&DeflateConfig> {
140 match self {
141 CompressionContext::Disabled => None,
142 CompressionContext::Dedicated(ctx) => Some(&ctx.config),
143 CompressionContext::Shared { config, .. } => Some(config),
144 }
145 }
146}
147
148pub struct SharedCompressorPool {
153 encoders: Vec<Mutex<DeflateEncoder>>,
155 config: DeflateConfig,
157 next_encoder: std::sync::atomic::AtomicUsize,
159}
160
161impl SharedCompressorPool {
162 pub fn new(config: DeflateConfig) -> Self {
164 let encoders = (0..SHARED_POOL_SIZE)
165 .map(|_| {
166 Mutex::new(DeflateEncoder::new(
167 config.server_max_window_bits,
168 true, config.compression_level,
170 config.compression_threshold,
171 ))
172 })
173 .collect();
174
175 Self {
176 encoders,
177 config,
178 next_encoder: std::sync::atomic::AtomicUsize::new(0),
179 }
180 }
181
182 pub fn compress(&self, data: &[u8]) -> Result<Option<Bytes>> {
184 let idx = self
186 .next_encoder
187 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
188 % SHARED_POOL_SIZE;
189
190 let mut encoder = self.encoders[idx].lock();
191 encoder.compress(data)
192 }
193
194 pub fn config(&self) -> &DeflateConfig {
196 &self.config
197 }
198}
199
200static GLOBAL_POOL: std::sync::OnceLock<Arc<SharedCompressorPool>> = std::sync::OnceLock::new();
205
206pub fn global_shared_pool() -> Arc<SharedCompressorPool> {
208 GLOBAL_POOL
209 .get_or_init(|| {
210 let config = Compression::Shared.to_deflate_config().unwrap();
211 Arc::new(SharedCompressorPool::new(config))
212 })
213 .clone()
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Highly repetitive payload that deflate is expected to shrink.
    fn sample_payload() -> Vec<u8> {
        b"Hello, World! This is a test message that should be compressed. ".repeat(10)
    }

    #[test]
    fn test_compression_context_disabled() {
        let mut ctx = CompressionContext::server(Compression::Disabled);
        assert!(!ctx.is_enabled());
        assert!(ctx.compress(b"Hello, World!").unwrap().is_none());
    }

    #[test]
    fn test_compression_context_dedicated() {
        let mut ctx = CompressionContext::server(Compression::Dedicated);
        assert!(ctx.is_enabled());

        let data = sample_payload();
        let compressed = ctx
            .compress(&data)
            .unwrap()
            .expect("dedicated context should compress the payload");
        assert!(compressed.len() < data.len());

        let decompressed = ctx.decompress(&compressed, 1024 * 1024).unwrap();
        assert_eq!(decompressed.as_ref(), data.as_slice());
    }

    #[test]
    fn test_shared_pool() {
        let config = Compression::Shared.to_deflate_config().unwrap();
        let pool = SharedCompressorPool::new(config);
        let data = sample_payload();

        // Two calls exercise the round-robin cursor moving to another encoder.
        for _ in 0..2 {
            let compressed = pool
                .compress(&data)
                .unwrap()
                .expect("shared pool should compress the payload");
            assert!(compressed.len() < data.len());
        }
    }

    #[test]
    fn test_shared_context_round_trip() {
        let mut server = CompressionContext::server(Compression::Shared);
        let mut client = CompressionContext::client(Compression::Shared);
        assert!(server.is_enabled());
        assert!(client.is_enabled());

        let data = sample_payload();

        let from_server = server
            .compress(&data)
            .unwrap()
            .expect("server should compress the payload");
        let inflated = client.decompress(&from_server, 1024 * 1024).unwrap();
        assert_eq!(inflated.as_ref(), data.as_slice());

        let from_client = client
            .compress(&data)
            .unwrap()
            .expect("client should compress the payload");
        let inflated = server.decompress(&from_client, 1024 * 1024).unwrap();
        assert_eq!(inflated.as_ref(), data.as_slice());
    }

    #[test]
    fn test_compression_modes_configs() {
        for mode in [
            Compression::Disabled,
            Compression::Dedicated,
            Compression::Shared,
            Compression::Window256B,
            Compression::Window1KB,
            Compression::Window2KB,
            Compression::Window4KB,
            Compression::Window8KB,
            Compression::Window16KB,
            Compression::Window32KB,
        ] {
            let config = mode.to_deflate_config();
            if mode == Compression::Disabled {
                assert!(config.is_none());
                continue;
            }

            assert!(config.is_some(), "Mode {:?} should have config", mode);
            let config = config.unwrap();
            assert!(config.server_max_window_bits >= 8);
            assert!(config.server_max_window_bits <= 15);
        }
    }

    #[test]
    fn test_window_sizes() {
        assert_eq!(Compression::Disabled.window_bits(), 0);
        assert_eq!(Compression::Window256B.window_bits(), 8);
        assert_eq!(Compression::Window1KB.window_bits(), 10);
        assert_eq!(Compression::Window2KB.window_bits(), 11);
        assert_eq!(Compression::Window4KB.window_bits(), 12);
        assert_eq!(Compression::Window8KB.window_bits(), 13);
        assert_eq!(Compression::Window16KB.window_bits(), 14);
        assert_eq!(Compression::Window32KB.window_bits(), 15);
        assert_eq!(Compression::Dedicated.window_bits(), 15);
        assert_eq!(Compression::Shared.window_bits(), 15);
    }
}