sockudo_ws/
deflate.rs

//! Per-Message Deflate Extension (RFC 7692)
//!
//! This module implements the permessage-deflate WebSocket extension,
//! which compresses message payloads using the DEFLATE algorithm.
5
6use bytes::{Bytes, BytesMut};
7use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
8
9use crate::error::{Error, Result};
10
/// Four-byte tail (`00 00 ff ff`) that is stripped from compressed output
/// and re-appended to the input before decompression.
const DEFLATE_TRAILER: [u8; 4] = [0, 0, 0xff, 0xff];

/// Smallest LZ77 window exponent accepted during negotiation (2^8 = 256 bytes).
pub const MIN_WINDOW_BITS: u8 = 8;

/// Largest LZ77 window exponent accepted during negotiation (2^15 = 32 KiB).
pub const MAX_WINDOW_BITS: u8 = 15;

/// Window exponent used when nothing smaller is negotiated (2^15 = 32 KiB).
pub const DEFAULT_WINDOW_BITS: u8 = MAX_WINDOW_BITS;
22
/// Configuration for the permessage-deflate extension.
///
/// The `server_*` fields describe the server-to-client direction and the
/// `client_*` fields the client-to-server direction. Which pair drives the
/// local compressor and which the local decompressor depends on the role the
/// endpoint plays.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeflateConfig {
    /// Server's maximum LZ77 window bits (for compression when server, decompression when client)
    pub server_max_window_bits: u8,
    /// Client's maximum LZ77 window bits (for compression when client, decompression when server)
    pub client_max_window_bits: u8,
    /// If true, server must reset compression context after each message
    pub server_no_context_takeover: bool,
    /// If true, client must reset compression context after each message
    pub client_no_context_takeover: bool,
    /// Compression level (0-9, where 0 is no compression, 9 is max)
    pub compression_level: u32,
    /// Minimum payload size in bytes for compression to be attempted;
    /// shorter messages are left uncompressed.
    pub compression_threshold: usize,
}
39
40impl Default for DeflateConfig {
41    fn default() -> Self {
42        Self {
43            server_max_window_bits: DEFAULT_WINDOW_BITS,
44            client_max_window_bits: DEFAULT_WINDOW_BITS,
45            server_no_context_takeover: false,
46            client_no_context_takeover: false,
47            compression_level: 6,      // Default zlib compression level
48            compression_threshold: 32, // Don't compress tiny messages
49        }
50    }
51}
52
53impl DeflateConfig {
54    /// Create config optimized for low memory usage
55    pub fn low_memory() -> Self {
56        Self {
57            server_max_window_bits: 10, // 1KB window
58            client_max_window_bits: 10,
59            server_no_context_takeover: true,
60            client_no_context_takeover: true,
61            compression_level: 1, // Fast compression
62            compression_threshold: 64,
63        }
64    }
65
66    /// Create config optimized for best compression
67    pub fn best_compression() -> Self {
68        Self {
69            server_max_window_bits: MAX_WINDOW_BITS,
70            client_max_window_bits: MAX_WINDOW_BITS,
71            server_no_context_takeover: false,
72            client_no_context_takeover: false,
73            compression_level: 9,
74            compression_threshold: 16,
75        }
76    }
77
78    /// Parse extension parameters from handshake
79    pub fn from_params(params: &[(&str, Option<&str>)]) -> Result<Self> {
80        let mut config = Self::default();
81
82        for (name, value) in params {
83            match *name {
84                "server_no_context_takeover" => {
85                    if value.is_some() {
86                        return Err(Error::HandshakeFailed(
87                            "server_no_context_takeover must not have a value",
88                        ));
89                    }
90                    config.server_no_context_takeover = true;
91                }
92                "client_no_context_takeover" => {
93                    if value.is_some() {
94                        return Err(Error::HandshakeFailed(
95                            "client_no_context_takeover must not have a value",
96                        ));
97                    }
98                    config.client_no_context_takeover = true;
99                    // When client uses no_context_takeover, server should too
100                    // to ensure decompression works correctly on client side
101                    config.server_no_context_takeover = true;
102                }
103                "server_max_window_bits" => {
104                    if let Some(v) = value {
105                        let bits: u8 = v.parse().map_err(|_| {
106                            Error::HandshakeFailed("invalid server_max_window_bits value")
107                        })?;
108                        if !(MIN_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits) {
109                            return Err(Error::HandshakeFailed(
110                                "server_max_window_bits out of range (8-15)",
111                            ));
112                        }
113                        config.server_max_window_bits = bits;
114                    }
115                }
116                "client_max_window_bits" => {
117                    if let Some(v) = value {
118                        let bits: u8 = v.parse().map_err(|_| {
119                            Error::HandshakeFailed("invalid client_max_window_bits value")
120                        })?;
121                        if !(MIN_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits) {
122                            return Err(Error::HandshakeFailed(
123                                "client_max_window_bits out of range (8-15)",
124                            ));
125                        }
126                        config.client_max_window_bits = bits;
127                    }
128                    // If no value, client just indicates support
129                }
130                _ => {
131                    return Err(Error::HandshakeFailed(
132                        "unknown permessage-deflate parameter",
133                    ));
134                }
135            }
136        }
137
138        Ok(config)
139    }
140
141    /// Generate extension response header value for server
142    pub fn to_response_header(&self) -> String {
143        let mut parts = vec!["permessage-deflate".to_string()];
144
145        if self.server_no_context_takeover {
146            parts.push("server_no_context_takeover".to_string());
147        }
148        if self.client_no_context_takeover {
149            parts.push("client_no_context_takeover".to_string());
150        }
151        if self.server_max_window_bits < MAX_WINDOW_BITS {
152            parts.push(format!(
153                "server_max_window_bits={}",
154                self.server_max_window_bits
155            ));
156        }
157        if self.client_max_window_bits < MAX_WINDOW_BITS {
158            parts.push(format!(
159                "client_max_window_bits={}",
160                self.client_max_window_bits
161            ));
162        }
163
164        parts.join("; ")
165    }
166}
167
168/// Deflate compressor for outgoing messages
169pub struct DeflateEncoder {
170    compress: Compress,
171    no_context_takeover: bool,
172    window_bits: u8,
173    compression_level: Compression,
174    threshold: usize,
175}
176
177impl DeflateEncoder {
178    /// Create a new encoder
179    pub fn new(window_bits: u8, no_context_takeover: bool, level: u32, threshold: usize) -> Self {
180        let compression_level = Compression::new(level);
181        // Use the negotiated window_bits for compression
182        // This ensures the compressed data can be decompressed by clients with smaller windows
183        let compress = Compress::new_with_window_bits(compression_level, false, window_bits);
184
185        Self {
186            compress,
187            no_context_takeover,
188            window_bits,
189            compression_level,
190            threshold,
191        }
192    }
193
194    /// Compress a message payload
195    ///
196    /// Returns None if the message is too small to benefit from compression
197    /// or if compression would make it larger.
198    pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
199        if data.len() < self.threshold {
200            return Ok(None);
201        }
202
203        // Reset context if required
204        if self.no_context_takeover {
205            self.compress.reset();
206        }
207
208        // Estimate output size (compressed data is often smaller, but we need headroom)
209        let max_output = data.len() + 64;
210        let mut output = BytesMut::with_capacity(max_output);
211
212        // Compress the data
213        let mut total_in: usize = 0;
214        let mut iterations = 0u32;
215
216        loop {
217            iterations += 1;
218            if iterations > 100_000 {
219                return Err(Error::Compression(
220                    "compression took too many iterations".into(),
221                ));
222            }
223
224            // Ensure we have space in output buffer
225            let available = output.capacity() - output.len();
226            if available == 0 {
227                output.reserve(4096);
228            }
229
230            let input = &data[total_in..];
231            let before_out = self.compress.total_out();
232            let before_in = self.compress.total_in();
233
234            // Get writable slice
235            let out_start = output.len();
236            let out_capacity = output.capacity();
237            unsafe {
238                output.set_len(out_capacity);
239            }
240
241            let status = self
242                .compress
243                .compress(input, &mut output[out_start..], FlushCompress::Sync)
244                .map_err(|e| Error::Compression(format!("deflate error: {}", e)))?;
245
246            let consumed = (self.compress.total_in() - before_in) as usize;
247            let produced = (self.compress.total_out() - before_out) as usize;
248
249            total_in += consumed;
250
251            unsafe {
252                output.set_len(out_start + produced);
253            }
254
255            match status {
256                Status::Ok | Status::BufError => {
257                    if total_in >= data.len() {
258                        break;
259                    }
260                }
261                Status::StreamEnd => break,
262            }
263        }
264
265        // Per RFC 7692: Remove trailing 0x00 0x00 0xff 0xff
266        if output.len() >= 4 && output.ends_with(&DEFLATE_TRAILER) {
267            output.truncate(output.len() - 4);
268        }
269
270        // Only use compression if it actually reduces size
271        if output.len() >= data.len() {
272            return Ok(None);
273        }
274
275        Ok(Some(output.freeze()))
276    }
277
278    /// Reset the compression context (for no_context_takeover)
279    pub fn reset(&mut self) {
280        self.compress.reset();
281    }
282}
283
284/// Deflate decompressor for incoming messages
285pub struct DeflateDecoder {
286    decompress: Decompress,
287    no_context_takeover: bool,
288    window_bits: u8,
289}
290
291impl DeflateDecoder {
292    /// Create a new decoder
293    pub fn new(window_bits: u8, no_context_takeover: bool) -> Self {
294        // Use raw deflate (no zlib header) with the negotiated window_bits
295        let decompress = Decompress::new_with_window_bits(false, window_bits);
296
297        Self {
298            decompress,
299            no_context_takeover,
300            window_bits,
301        }
302    }
303
304    /// Decompress a message payload
305    pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
306        // Reset context if required
307        if self.no_context_takeover {
308            self.decompress.reset(false);
309        }
310
311        // Per RFC 7692: Append 0x00 0x00 0xff 0xff before decompressing
312        let mut input = BytesMut::with_capacity(data.len() + 4);
313        input.extend_from_slice(data);
314        input.extend_from_slice(&DEFLATE_TRAILER);
315
316        // Start with reasonable output buffer (at least 1KB or 4x input)
317        let initial_cap = std::cmp::max(1024, data.len() * 4);
318        let mut output = BytesMut::with_capacity(initial_cap);
319        let mut total_in: usize = 0;
320        let mut iterations = 0u32;
321
322        loop {
323            iterations += 1;
324            // Safety check to prevent infinite loops
325            if iterations > 100_000 {
326                return Err(Error::Compression(
327                    "decompression took too many iterations".into(),
328                ));
329            }
330
331            // Check size limit
332            if output.len() > max_size {
333                return Err(Error::MessageTooLarge);
334            }
335
336            // Ensure we have space in output buffer
337            let available = output.capacity() - output.len();
338            if available == 0 {
339                if output.capacity() >= max_size {
340                    return Err(Error::MessageTooLarge);
341                }
342                // At least double or add 4KB, whichever is larger
343                let additional = std::cmp::max(output.capacity(), 4096);
344                output.reserve(additional);
345            }
346
347            let before_out = self.decompress.total_out();
348            let before_in = self.decompress.total_in();
349
350            // Get writable slice
351            let out_start = output.len();
352            let out_capacity = output.capacity();
353            unsafe {
354                output.set_len(out_capacity);
355            }
356
357            let status = self
358                .decompress
359                .decompress(
360                    &input[total_in..],
361                    &mut output[out_start..],
362                    FlushDecompress::Sync,
363                )
364                .map_err(|e| Error::Compression(format!("inflate error: {}", e)))?;
365
366            let consumed = (self.decompress.total_in() - before_in) as usize;
367            let produced = (self.decompress.total_out() - before_out) as usize;
368
369            total_in += consumed;
370
371            unsafe {
372                output.set_len(out_start + produced);
373            }
374
375            match status {
376                Status::Ok => {
377                    if total_in >= input.len() {
378                        break;
379                    }
380                }
381                Status::StreamEnd => break,
382                Status::BufError => {
383                    // Need more output space - will be handled at top of loop
384                }
385            }
386        }
387
388        Ok(output.freeze())
389    }
390
391    /// Reset the decompression context (for no_context_takeover)
392    pub fn reset(&mut self) {
393        self.decompress.reset(false);
394    }
395}
396
397/// Combined compressor/decompressor context for a WebSocket connection
398pub struct DeflateContext {
399    /// Encoder for outgoing messages
400    pub encoder: DeflateEncoder,
401    /// Decoder for incoming messages
402    pub decoder: DeflateDecoder,
403    /// Configuration
404    pub config: DeflateConfig,
405}
406
407impl DeflateContext {
408    /// Create context for server role
409    pub fn server(config: DeflateConfig) -> Self {
410        let encoder = DeflateEncoder::new(
411            config.server_max_window_bits,
412            config.server_no_context_takeover,
413            config.compression_level,
414            config.compression_threshold,
415        );
416        let decoder = DeflateDecoder::new(
417            config.client_max_window_bits,
418            config.client_no_context_takeover,
419        );
420
421        Self {
422            encoder,
423            decoder,
424            config,
425        }
426    }
427
428    /// Create context for client role
429    pub fn client(config: DeflateConfig) -> Self {
430        let encoder = DeflateEncoder::new(
431            config.client_max_window_bits,
432            config.client_no_context_takeover,
433            config.compression_level,
434            config.compression_threshold,
435        );
436        let decoder = DeflateDecoder::new(
437            config.server_max_window_bits,
438            config.server_no_context_takeover,
439        );
440
441        Self {
442            encoder,
443            decoder,
444            config,
445        }
446    }
447
448    /// Compress a message if beneficial
449    pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
450        self.encoder.compress(data)
451    }
452
453    /// Decompress a message
454    pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
455        self.decoder.decompress(data, max_size)
456    }
457}
458
/// Parse permessage-deflate extension parameters from header value.
///
/// `value` is expected to hold a single extension entry such as
/// `permessage-deflate; server_no_context_takeover; client_max_window_bits=12`.
///
/// Returns `None` when the entry is not `permessage-deflate` or when the
/// text after the extension name does not start with `;`. Otherwise returns
/// the parameters in order as `(name, value)` pairs, where `value` is `None`
/// for a bare flag. Empty segments between semicolons are skipped. A value
/// wrapped in one pair of double quotes is returned without the quotes;
/// unbalanced quotes are left in place so later validation rejects them.
pub fn parse_deflate_offer(value: &str) -> Option<Vec<(&str, Option<&str>)>> {
    /// Remove exactly one pair of surrounding double quotes, if present.
    fn unquote(raw: &str) -> &str {
        raw.strip_prefix('"')
            .and_then(|inner| inner.strip_suffix('"'))
            .unwrap_or(raw)
    }

    let rest = value
        .trim()
        .strip_prefix("permessage-deflate")?
        .trim_start();

    if rest.is_empty() {
        return Some(Vec::new());
    }

    // Anything following the extension name must be introduced by ';'. This
    // also rejects longer names such as "permessage-deflate-x".
    let param_list = rest.strip_prefix(';')?;

    let params = param_list
        .split(';')
        .map(str::trim)
        .filter(|part| !part.is_empty())
        .map(|part| match part.split_once('=') {
            Some((name, raw)) => (name.trim(), Some(unquote(raw.trim()))),
            None => (part, None),
        })
        .collect();

    Some(params)
}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// A compressible message shrinks and survives a round trip.
    #[test]
    fn test_compress_decompress() {
        let config = DeflateConfig::default();
        let mut ctx = DeflateContext::server(config);

        let original = b"Hello, World! This is a test message that should be compressed.";

        // Compress
        let compressed = ctx.compress(original).unwrap();
        assert!(compressed.is_some());
        let compressed = compressed.unwrap();
        assert!(compressed.len() < original.len());

        // Decompress
        let decompressed = ctx.decompress(&compressed, 1024).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    /// Payloads below the threshold are reported as "do not compress".
    #[test]
    fn test_small_message_not_compressed() {
        let config = DeflateConfig {
            compression_threshold: 100,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let small = b"tiny";
        let result = ctx.compress(small).unwrap();
        assert!(result.is_none());
    }

    /// With context takeover a repeated message never compresses worse.
    #[test]
    fn test_context_takeover() {
        let config = DeflateConfig {
            server_no_context_takeover: false,
            compression_threshold: 0,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let msg = b"Hello, World! Hello, World! Hello, World!";

        // First compression
        let first = ctx.compress(msg).unwrap().unwrap();

        // Second compression should benefit from context
        let second = ctx.compress(msg).unwrap().unwrap();

        // With context takeover, second should be smaller or equal
        // (references previous data in LZ77 window)
        assert!(second.len() <= first.len());
    }

    /// Without context takeover each message is compressed from a fresh
    /// stream, so identical input yields identical bytes.
    #[test]
    fn test_no_context_takeover() {
        let config = DeflateConfig {
            server_no_context_takeover: true,
            compression_threshold: 0,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let msg = b"Hello, World! Hello, World! Hello, World!";

        let first = ctx.compress(msg).unwrap().unwrap();
        let second = ctx.compress(msg).unwrap().unwrap();

        // Compare the bytes themselves, not just their lengths.
        assert_eq!(first, second);
    }

    /// Offers are split into ordered `(name, value)` pairs.
    #[test]
    fn test_parse_deflate_offer() {
        // Simple offer
        let params = parse_deflate_offer("permessage-deflate").unwrap();
        assert!(params.is_empty());

        // With parameters
        let params = parse_deflate_offer(
            "permessage-deflate; server_no_context_takeover; server_max_window_bits=10",
        )
        .unwrap();
        assert_eq!(
            params,
            vec![
                ("server_no_context_takeover", None),
                ("server_max_window_bits", Some("10")),
            ]
        );

        // Quoted value
        let params =
            parse_deflate_offer("permessage-deflate; client_max_window_bits=\"12\"").unwrap();
        assert_eq!(params, vec![("client_max_window_bits", Some("12"))]);

        // Not a deflate offer
        assert!(parse_deflate_offer("some-other-extension").is_none());
    }

    /// Parsed parameters override the defaults; the rest stay untouched.
    #[test]
    fn test_config_from_params() {
        let params = vec![
            ("server_no_context_takeover", None),
            ("client_max_window_bits", Some("12")),
        ];

        let config = DeflateConfig::from_params(&params).unwrap();
        assert!(config.server_no_context_takeover);
        assert!(!config.client_no_context_takeover);
        assert_eq!(config.client_max_window_bits, 12);
        assert_eq!(config.server_max_window_bits, DEFAULT_WINDOW_BITS);
    }

    /// The response header lists exactly the non-default settings, in order.
    #[test]
    fn test_response_header() {
        let config = DeflateConfig {
            server_no_context_takeover: true,
            server_max_window_bits: 12,
            ..Default::default()
        };

        let header = config.to_response_header();
        assert_eq!(
            header,
            "permessage-deflate; server_no_context_takeover; server_max_window_bits=12"
        );
    }
}