sockudo_ws/
compression.rs

//! Compression management for WebSocket connections.
//!
//! This module provides compression support with multiple modes:
//! - **Disabled**: no compression.
//! - **Dedicated**: each connection has its own compressor.
//! - **Shared**: connections share a pool of compressors.
//! - **Window sizes** (`Window256B` … `Window32KB`): per-connection compression with a fixed
//!   window of 2^8 to 2^15 bytes, trading memory for compression ratio.
8
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use bytes::Bytes;
use parking_lot::Mutex;

use crate::Compression;
use crate::deflate::{DeflateConfig, DeflateContext, DeflateDecoder, DeflateEncoder};
use crate::error::Result;
17
/// Number of encoders held by each [`SharedCompressorPool`].
///
/// Must be non-zero: the pool picks an encoder by taking a counter modulo the pool size,
/// which would divide by zero on an empty pool. The assertion below turns that runtime
/// panic into a compile error.
const SHARED_POOL_SIZE: usize = 4;
const _: () = assert!(SHARED_POOL_SIZE > 0, "SHARED_POOL_SIZE must be non-zero");
20
21/// A compression context that can be either dedicated or shared
22pub enum CompressionContext {
23    /// No compression
24    Disabled,
25    /// Dedicated per-connection compressor
26    Dedicated(DeflateContext),
27    /// Shared compressor from pool (encoder only, decoder is per-connection)
28    Shared {
29        pool: Arc<SharedCompressorPool>,
30        decoder: DeflateDecoder,
31        config: DeflateConfig,
32    },
33}
34
35impl CompressionContext {
36    /// Create a new compression context for the given mode (server role)
37    pub fn server(mode: Compression) -> Self {
38        match mode {
39            Compression::Disabled => CompressionContext::Disabled,
40            Compression::Shared => {
41                let config = mode.to_deflate_config().unwrap();
42                CompressionContext::Shared {
43                    pool: Arc::new(SharedCompressorPool::new(config.clone())),
44                    decoder: DeflateDecoder::new(
45                        config.client_max_window_bits,
46                        config.client_no_context_takeover,
47                    ),
48                    config,
49                }
50            }
51            _ => {
52                let config = mode.to_deflate_config().unwrap();
53                CompressionContext::Dedicated(DeflateContext::server(config))
54            }
55        }
56    }
57
58    /// Create a new compression context for the given mode (client role)
59    pub fn client(mode: Compression) -> Self {
60        match mode {
61            Compression::Disabled => CompressionContext::Disabled,
62            Compression::Shared => {
63                let config = mode.to_deflate_config().unwrap();
64                CompressionContext::Shared {
65                    pool: Arc::new(SharedCompressorPool::new(config.clone())),
66                    decoder: DeflateDecoder::new(
67                        config.server_max_window_bits,
68                        config.server_no_context_takeover,
69                    ),
70                    config,
71                }
72            }
73            _ => {
74                let config = mode.to_deflate_config().unwrap();
75                CompressionContext::Dedicated(DeflateContext::client(config))
76            }
77        }
78    }
79
80    /// Create a shared context that uses an existing pool
81    pub fn with_shared_pool(pool: Arc<SharedCompressorPool>, is_server: bool) -> Self {
82        let config = pool.config.clone();
83        let decoder = if is_server {
84            DeflateDecoder::new(
85                config.client_max_window_bits,
86                config.client_no_context_takeover,
87            )
88        } else {
89            DeflateDecoder::new(
90                config.server_max_window_bits,
91                config.server_no_context_takeover,
92            )
93        };
94
95        CompressionContext::Shared {
96            pool,
97            decoder,
98            config,
99        }
100    }
101
102    /// Check if compression is enabled
103    #[inline]
104    pub fn is_enabled(&self) -> bool {
105        !matches!(self, CompressionContext::Disabled)
106    }
107
108    /// Compress a message payload
109    ///
110    /// Returns `None` if compression is disabled or if compression wouldn't reduce size.
111    pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
112        match self {
113            CompressionContext::Disabled => Ok(None),
114            CompressionContext::Dedicated(ctx) => ctx.compress(data),
115            CompressionContext::Shared { pool, config, .. } => {
116                // Check threshold before acquiring encoder
117                if data.len() < config.compression_threshold {
118                    return Ok(None);
119                }
120                pool.compress(data)
121            }
122        }
123    }
124
125    /// Decompress a message payload
126    pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
127        match self {
128            CompressionContext::Disabled => {
129                // This shouldn't happen - protocol layer should not call decompress
130                // if compression is disabled
131                Ok(Bytes::copy_from_slice(data))
132            }
133            CompressionContext::Dedicated(ctx) => ctx.decompress(data, max_size),
134            CompressionContext::Shared { decoder, .. } => decoder.decompress(data, max_size),
135        }
136    }
137
138    /// Get the DeflateConfig for this context
139    pub fn config(&self) -> Option<&DeflateConfig> {
140        match self {
141            CompressionContext::Disabled => None,
142            CompressionContext::Dedicated(ctx) => Some(&ctx.config),
143            CompressionContext::Shared { config, .. } => Some(config),
144        }
145    }
146}
147
148/// A pool of shared compressors for the `Shared` compression mode
149///
150/// This pool allows multiple connections to share compressor instances,
151/// reducing memory usage when you have many connections.
152pub struct SharedCompressorPool {
153    /// Pool of encoders
154    encoders: Vec<Mutex<DeflateEncoder>>,
155    /// Configuration used for the pool
156    config: DeflateConfig,
157    /// Current encoder index (simple round-robin)
158    next_encoder: std::sync::atomic::AtomicUsize,
159}
160
161impl SharedCompressorPool {
162    /// Create a new shared compressor pool
163    pub fn new(config: DeflateConfig) -> Self {
164        let encoders = (0..SHARED_POOL_SIZE)
165            .map(|_| {
166                Mutex::new(DeflateEncoder::new(
167                    config.server_max_window_bits,
168                    true, // Always reset for shared mode (no context takeover)
169                    config.compression_level,
170                    config.compression_threshold,
171                ))
172            })
173            .collect();
174
175        Self {
176            encoders,
177            config,
178            next_encoder: std::sync::atomic::AtomicUsize::new(0),
179        }
180    }
181
182    /// Compress data using a pooled encoder
183    pub fn compress(&self, data: &[u8]) -> Result<Option<Bytes>> {
184        // Round-robin selection
185        let idx = self
186            .next_encoder
187            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
188            % SHARED_POOL_SIZE;
189
190        let mut encoder = self.encoders[idx].lock();
191        encoder.compress(data)
192    }
193
194    /// Get the pool's configuration
195    pub fn config(&self) -> &DeflateConfig {
196        &self.config
197    }
198}
199
200/// Global shared compressor pool for the default `Shared` mode
201///
202/// This is initialized lazily and provides a singleton pool for
203/// all connections using `Compression::Shared`.
204static GLOBAL_POOL: std::sync::OnceLock<Arc<SharedCompressorPool>> = std::sync::OnceLock::new();
205
206/// Get the global shared compressor pool
207pub fn global_shared_pool() -> Arc<SharedCompressorPool> {
208    GLOBAL_POOL
209        .get_or_init(|| {
210            let config = Compression::Shared.to_deflate_config().unwrap();
211            Arc::new(SharedCompressorPool::new(config))
212        })
213        .clone()
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload that is long and repetitive enough for deflate to shrink it.
    fn compressible_payload() -> Vec<u8> {
        b"Hello, World! This is a test message that should be compressed. ".repeat(10)
    }

    #[test]
    fn test_compression_context_disabled() {
        let mut ctx = CompressionContext::server(Compression::Disabled);
        assert!(!ctx.is_enabled());

        let result = ctx.compress(b"Hello, World!").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_compression_context_dedicated() {
        let mut ctx = CompressionContext::server(Compression::Dedicated);
        assert!(ctx.is_enabled());

        let data = compressible_payload();
        let compressed = ctx.compress(&data).unwrap();
        assert!(compressed.is_some());

        let compressed = compressed.unwrap();
        assert!(compressed.len() < data.len());

        let decompressed = ctx.decompress(&compressed, 1024 * 1024).unwrap();
        assert_eq!(decompressed.as_ref(), data.as_slice());
    }

    #[test]
    fn test_shared_context_roundtrip() {
        // A server-side shared context compresses; a client-side one must inflate it.
        let mut server = CompressionContext::server(Compression::Shared);
        let mut client = CompressionContext::client(Compression::Shared);
        assert!(server.is_enabled());
        assert!(client.is_enabled());

        let data = compressible_payload();
        let compressed = server
            .compress(&data)
            .unwrap()
            .expect("payload should be compressed");
        assert!(compressed.len() < data.len());

        let decompressed = client.decompress(&compressed, 1024 * 1024).unwrap();
        assert_eq!(decompressed.as_ref(), data.as_slice());
    }

    #[test]
    fn test_shared_pool() {
        let config = Compression::Shared.to_deflate_config().unwrap();
        let pool = SharedCompressorPool::new(config);
        let data = compressible_payload();

        let c1 = pool.compress(&data).unwrap().expect("first call should compress");
        let c2 = pool.compress(&data).unwrap().expect("second call should compress");

        assert!(c1.len() < data.len());
        // Pooled encoders reset between messages, so the output must not depend on which
        // encoder served each call.
        assert_eq!(c1, c2);
    }

    #[test]
    fn test_shared_pool_concurrent_use() {
        let pool = SharedCompressorPool::new(Compression::Shared.to_deflate_config().unwrap());
        let data = compressible_payload();
        let expected = pool.compress(&data).unwrap().expect("payload should compress");

        std::thread::scope(|s| {
            for _ in 0..SHARED_POOL_SIZE * 2 {
                s.spawn(|| {
                    for _ in 0..8 {
                        let out = pool.compress(&data).unwrap().expect("payload should compress");
                        assert_eq!(out, expected);
                    }
                });
            }
        });
    }

    #[test]
    fn test_compression_modes_configs() {
        // Every mode except Disabled must produce a config with a valid window size.
        for mode in [
            Compression::Disabled,
            Compression::Dedicated,
            Compression::Shared,
            Compression::Window256B,
            Compression::Window1KB,
            Compression::Window2KB,
            Compression::Window4KB,
            Compression::Window8KB,
            Compression::Window16KB,
            Compression::Window32KB,
        ] {
            if mode == Compression::Disabled {
                assert!(mode.to_deflate_config().is_none());
            } else {
                let config = mode.to_deflate_config();
                assert!(config.is_some(), "Mode {:?} should have config", mode);

                let config = config.unwrap();
                assert!(config.server_max_window_bits >= 8);
                assert!(config.server_max_window_bits <= 15);
            }
        }
    }

    #[test]
    fn test_window_sizes() {
        assert_eq!(Compression::Disabled.window_bits(), 0);
        assert_eq!(Compression::Window256B.window_bits(), 8);
        assert_eq!(Compression::Window1KB.window_bits(), 10);
        assert_eq!(Compression::Window2KB.window_bits(), 11);
        assert_eq!(Compression::Window4KB.window_bits(), 12);
        assert_eq!(Compression::Window8KB.window_bits(), 13);
        assert_eq!(Compression::Window16KB.window_bits(), 14);
        assert_eq!(Compression::Window32KB.window_bits(), 15);
        assert_eq!(Compression::Dedicated.window_bits(), 15);
        assert_eq!(Compression::Shared.window_bits(), 15);
    }
}
309}