1use bytes::{Bytes, BytesMut};
7use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
8
9use crate::error::{Error, Result};
10
/// Tail octets `00 00 ff ff` that the encoder strips from the end of its
/// output when present, and that the decoder re-appends before inflating.
const DEFLATE_TRAILER: [u8; 4] = *b"\x00\x00\xff\xff";

/// Largest LZ77 window size (log2) accepted in negotiation parameters.
pub const MAX_WINDOW_BITS: u8 = 15;

/// Smallest LZ77 window size (log2) accepted in negotiation parameters.
pub const MIN_WINDOW_BITS: u8 = 8;

/// Window size (log2) used when none is negotiated; identical to the maximum.
pub const DEFAULT_WINDOW_BITS: u8 = MAX_WINDOW_BITS;
22
/// `permessage-deflate` parameters for one connection, either negotiated
/// from the peer or chosen locally as a preset.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly,
/// e.g. to check a negotiated result against an expected preset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeflateConfig {
    /// Log2 of the LZ77 window the server's compressor may use.
    pub server_max_window_bits: u8,
    /// Log2 of the LZ77 window the client's compressor may use.
    pub client_max_window_bits: u8,
    /// When true, the server's compression context is reset between messages.
    pub server_no_context_takeover: bool,
    /// When true, the client's compression context is reset between messages.
    pub client_no_context_takeover: bool,
    /// zlib compression level used by this endpoint's encoder.
    pub compression_level: u32,
    /// Messages shorter than this many bytes are sent uncompressed.
    pub compression_threshold: usize,
}
39
40impl Default for DeflateConfig {
41 fn default() -> Self {
42 Self {
43 server_max_window_bits: DEFAULT_WINDOW_BITS,
44 client_max_window_bits: DEFAULT_WINDOW_BITS,
45 server_no_context_takeover: false,
46 client_no_context_takeover: false,
47 compression_level: 6, compression_threshold: 32, }
50 }
51}
52
53impl DeflateConfig {
54 pub fn low_memory() -> Self {
56 Self {
57 server_max_window_bits: 10, client_max_window_bits: 10,
59 server_no_context_takeover: true,
60 client_no_context_takeover: true,
61 compression_level: 1, compression_threshold: 64,
63 }
64 }
65
66 pub fn best_compression() -> Self {
68 Self {
69 server_max_window_bits: MAX_WINDOW_BITS,
70 client_max_window_bits: MAX_WINDOW_BITS,
71 server_no_context_takeover: false,
72 client_no_context_takeover: false,
73 compression_level: 9,
74 compression_threshold: 16,
75 }
76 }
77
78 pub fn from_params(params: &[(&str, Option<&str>)]) -> Result<Self> {
80 let mut config = Self::default();
81
82 for (name, value) in params {
83 match *name {
84 "server_no_context_takeover" => {
85 if value.is_some() {
86 return Err(Error::HandshakeFailed(
87 "server_no_context_takeover must not have a value",
88 ));
89 }
90 config.server_no_context_takeover = true;
91 }
92 "client_no_context_takeover" => {
93 if value.is_some() {
94 return Err(Error::HandshakeFailed(
95 "client_no_context_takeover must not have a value",
96 ));
97 }
98 config.client_no_context_takeover = true;
99 config.server_no_context_takeover = true;
102 }
103 "server_max_window_bits" => {
104 if let Some(v) = value {
105 let bits: u8 = v.parse().map_err(|_| {
106 Error::HandshakeFailed("invalid server_max_window_bits value")
107 })?;
108 if !(MIN_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits) {
109 return Err(Error::HandshakeFailed(
110 "server_max_window_bits out of range (8-15)",
111 ));
112 }
113 config.server_max_window_bits = bits;
114 }
115 }
116 "client_max_window_bits" => {
117 if let Some(v) = value {
118 let bits: u8 = v.parse().map_err(|_| {
119 Error::HandshakeFailed("invalid client_max_window_bits value")
120 })?;
121 if !(MIN_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits) {
122 return Err(Error::HandshakeFailed(
123 "client_max_window_bits out of range (8-15)",
124 ));
125 }
126 config.client_max_window_bits = bits;
127 }
128 }
130 _ => {
131 return Err(Error::HandshakeFailed(
132 "unknown permessage-deflate parameter",
133 ));
134 }
135 }
136 }
137
138 Ok(config)
139 }
140
141 pub fn to_response_header(&self) -> String {
143 let mut parts = vec!["permessage-deflate".to_string()];
144
145 if self.server_no_context_takeover {
146 parts.push("server_no_context_takeover".to_string());
147 }
148 if self.client_no_context_takeover {
149 parts.push("client_no_context_takeover".to_string());
150 }
151 if self.server_max_window_bits < MAX_WINDOW_BITS {
152 parts.push(format!(
153 "server_max_window_bits={}",
154 self.server_max_window_bits
155 ));
156 }
157 if self.client_max_window_bits < MAX_WINDOW_BITS {
158 parts.push(format!(
159 "client_max_window_bits={}",
160 self.client_max_window_bits
161 ));
162 }
163
164 parts.join("; ")
165 }
166}
167
168pub struct DeflateEncoder {
170 compress: Compress,
171 no_context_takeover: bool,
172 window_bits: u8,
173 compression_level: Compression,
174 threshold: usize,
175}
176
177impl DeflateEncoder {
178 pub fn new(window_bits: u8, no_context_takeover: bool, level: u32, threshold: usize) -> Self {
180 let compression_level = Compression::new(level);
181 let compress = Compress::new_with_window_bits(compression_level, false, window_bits);
184
185 Self {
186 compress,
187 no_context_takeover,
188 window_bits,
189 compression_level,
190 threshold,
191 }
192 }
193
194 pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
199 if data.len() < self.threshold {
200 return Ok(None);
201 }
202
203 if self.no_context_takeover {
205 self.compress.reset();
206 }
207
208 let max_output = data.len() + 64;
210 let mut output = BytesMut::with_capacity(max_output);
211
212 let mut total_in: usize = 0;
214 let mut iterations = 0u32;
215
216 loop {
217 iterations += 1;
218 if iterations > 100_000 {
219 return Err(Error::Compression(
220 "compression took too many iterations".into(),
221 ));
222 }
223
224 let available = output.capacity() - output.len();
226 if available == 0 {
227 output.reserve(4096);
228 }
229
230 let input = &data[total_in..];
231 let before_out = self.compress.total_out();
232 let before_in = self.compress.total_in();
233
234 let out_start = output.len();
236 let out_capacity = output.capacity();
237 unsafe {
238 output.set_len(out_capacity);
239 }
240
241 let status = self
242 .compress
243 .compress(input, &mut output[out_start..], FlushCompress::Sync)
244 .map_err(|e| Error::Compression(format!("deflate error: {}", e)))?;
245
246 let consumed = (self.compress.total_in() - before_in) as usize;
247 let produced = (self.compress.total_out() - before_out) as usize;
248
249 total_in += consumed;
250
251 unsafe {
252 output.set_len(out_start + produced);
253 }
254
255 match status {
256 Status::Ok | Status::BufError => {
257 if total_in >= data.len() {
258 break;
259 }
260 }
261 Status::StreamEnd => break,
262 }
263 }
264
265 if output.len() >= 4 && output.ends_with(&DEFLATE_TRAILER) {
267 output.truncate(output.len() - 4);
268 }
269
270 if output.len() >= data.len() {
272 return Ok(None);
273 }
274
275 Ok(Some(output.freeze()))
276 }
277
278 pub fn reset(&mut self) {
280 self.compress.reset();
281 }
282}
283
284pub struct DeflateDecoder {
286 decompress: Decompress,
287 no_context_takeover: bool,
288 window_bits: u8,
289}
290
291impl DeflateDecoder {
292 pub fn new(window_bits: u8, no_context_takeover: bool) -> Self {
294 let decompress = Decompress::new_with_window_bits(false, window_bits);
296
297 Self {
298 decompress,
299 no_context_takeover,
300 window_bits,
301 }
302 }
303
304 pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
306 if self.no_context_takeover {
308 self.decompress.reset(false);
309 }
310
311 let mut input = BytesMut::with_capacity(data.len() + 4);
313 input.extend_from_slice(data);
314 input.extend_from_slice(&DEFLATE_TRAILER);
315
316 let initial_cap = std::cmp::max(1024, data.len() * 4);
318 let mut output = BytesMut::with_capacity(initial_cap);
319 let mut total_in: usize = 0;
320 let mut iterations = 0u32;
321
322 loop {
323 iterations += 1;
324 if iterations > 100_000 {
326 return Err(Error::Compression(
327 "decompression took too many iterations".into(),
328 ));
329 }
330
331 if output.len() > max_size {
333 return Err(Error::MessageTooLarge);
334 }
335
336 let available = output.capacity() - output.len();
338 if available == 0 {
339 if output.capacity() >= max_size {
340 return Err(Error::MessageTooLarge);
341 }
342 let additional = std::cmp::max(output.capacity(), 4096);
344 output.reserve(additional);
345 }
346
347 let before_out = self.decompress.total_out();
348 let before_in = self.decompress.total_in();
349
350 let out_start = output.len();
352 let out_capacity = output.capacity();
353 unsafe {
354 output.set_len(out_capacity);
355 }
356
357 let status = self
358 .decompress
359 .decompress(
360 &input[total_in..],
361 &mut output[out_start..],
362 FlushDecompress::Sync,
363 )
364 .map_err(|e| Error::Compression(format!("inflate error: {}", e)))?;
365
366 let consumed = (self.decompress.total_in() - before_in) as usize;
367 let produced = (self.decompress.total_out() - before_out) as usize;
368
369 total_in += consumed;
370
371 unsafe {
372 output.set_len(out_start + produced);
373 }
374
375 match status {
376 Status::Ok => {
377 if total_in >= input.len() {
378 break;
379 }
380 }
381 Status::StreamEnd => break,
382 Status::BufError => {
383 }
385 }
386 }
387
388 Ok(output.freeze())
389 }
390
391 pub fn reset(&mut self) {
393 self.decompress.reset(false);
394 }
395}
396
397pub struct DeflateContext {
399 pub encoder: DeflateEncoder,
401 pub decoder: DeflateDecoder,
403 pub config: DeflateConfig,
405}
406
407impl DeflateContext {
408 pub fn server(config: DeflateConfig) -> Self {
410 let encoder = DeflateEncoder::new(
411 config.server_max_window_bits,
412 config.server_no_context_takeover,
413 config.compression_level,
414 config.compression_threshold,
415 );
416 let decoder = DeflateDecoder::new(
417 config.client_max_window_bits,
418 config.client_no_context_takeover,
419 );
420
421 Self {
422 encoder,
423 decoder,
424 config,
425 }
426 }
427
428 pub fn client(config: DeflateConfig) -> Self {
430 let encoder = DeflateEncoder::new(
431 config.client_max_window_bits,
432 config.client_no_context_takeover,
433 config.compression_level,
434 config.compression_threshold,
435 );
436 let decoder = DeflateDecoder::new(
437 config.server_max_window_bits,
438 config.server_no_context_takeover,
439 );
440
441 Self {
442 encoder,
443 decoder,
444 config,
445 }
446 }
447
448 pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
450 self.encoder.compress(data)
451 }
452
453 pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
455 self.decoder.decompress(data, max_size)
456 }
457}
458
/// Splits a single `permessage-deflate` extension offer into `(name, value)` pairs.
///
/// Returns `None` when `value` does not start with `permessage-deflate`, or
/// when the token is followed by something other than end-of-input or `;`.
///
/// Parsing rules:
///
/// * Surrounding whitespace is trimmed from names and values.
/// * Double quotes are trimmed from values.
/// * Empty segments (e.g. `;;`) are skipped.
/// * Parameter names are returned as-is and are not validated here.
pub fn parse_deflate_offer(value: &str) -> Option<Vec<(&str, Option<&str>)>> {
    let after_token = value.trim().strip_prefix("permessage-deflate")?.trim_start();

    if after_token.is_empty() {
        return Some(Vec::new());
    }

    let param_list = after_token.strip_prefix(';')?;

    let params = param_list
        .split(';')
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .map(|segment| match segment.split_once('=') {
            Some((key, raw)) => (key.trim(), Some(raw.trim().trim_matches('"'))),
            None => (segment, None),
        })
        .collect();

    Some(params)
}
498
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compress_decompress() {
        let config = DeflateConfig::default();
        let mut ctx = DeflateContext::server(config);

        let original = b"Hello, World! This is a test message that should be compressed.";

        let compressed = ctx.compress(original).unwrap();
        assert!(compressed.is_some());
        let compressed = compressed.unwrap();
        assert!(compressed.len() < original.len());

        let decompressed = ctx.decompress(&compressed, 1024).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    #[test]
    fn test_small_message_not_compressed() {
        let config = DeflateConfig {
            compression_threshold: 100,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let small = b"tiny";
        let result = ctx.compress(small).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_context_takeover() {
        let config = DeflateConfig {
            server_no_context_takeover: false,
            compression_threshold: 0,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let msg = b"Hello, World! Hello, World! Hello, World!";

        let first = ctx.compress(msg).unwrap().unwrap();
        // With context takeover the second copy can reference the first.
        let second = ctx.compress(msg).unwrap().unwrap();

        assert!(second.len() <= first.len());
    }

    #[test]
    fn test_no_context_takeover() {
        let config = DeflateConfig {
            server_no_context_takeover: true,
            compression_threshold: 0,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let msg = b"Hello, World! Hello, World! Hello, World!";

        let first = ctx.compress(msg).unwrap().unwrap();
        let second = ctx.compress(msg).unwrap().unwrap();

        assert_eq!(first.len(), second.len());
    }

    #[test]
    fn test_server_to_client_roundtrip_across_messages() {
        // The client's decoder must track the server's encoder context
        // over several messages.
        let config = DeflateConfig {
            compression_threshold: 0,
            ..Default::default()
        };
        let mut server = DeflateContext::server(config.clone());
        let mut client = DeflateContext::client(config);

        for round in 0..5 {
            let msg = format!("message {}: {}", round, "hello websocket ".repeat(8));
            let compressed = server
                .compress(msg.as_bytes())
                .unwrap()
                .expect("repetitive payload should compress");
            let plain = client.decompress(&compressed, 4096).unwrap();
            assert_eq!(&plain[..], msg.as_bytes());
        }
    }

    #[test]
    fn test_low_memory_roundtrip() {
        let mut server = DeflateContext::server(DeflateConfig::low_memory());
        let mut client = DeflateContext::client(DeflateConfig::low_memory());

        let msg = "hello websocket ".repeat(8);
        for _ in 0..3 {
            let compressed = server
                .compress(msg.as_bytes())
                .unwrap()
                .expect("repetitive payload should compress");
            let plain = client.decompress(&compressed, 1024).unwrap();
            assert_eq!(&plain[..], msg.as_bytes());
        }
    }

    #[test]
    fn test_parse_deflate_offer() {
        let params = parse_deflate_offer("permessage-deflate").unwrap();
        assert!(params.is_empty());

        let params = parse_deflate_offer(
            "permessage-deflate; server_no_context_takeover; server_max_window_bits=10",
        )
        .unwrap();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], ("server_no_context_takeover", None));
        assert_eq!(params[1], ("server_max_window_bits", Some("10")));

        let params =
            parse_deflate_offer("permessage-deflate; client_max_window_bits=\"12\"").unwrap();
        assert_eq!(params, vec![("client_max_window_bits", Some("12"))]);

        assert!(parse_deflate_offer("some-other-extension").is_none());
    }

    #[test]
    fn test_config_from_params() {
        let params = vec![
            ("server_no_context_takeover", None),
            ("client_max_window_bits", Some("12")),
        ];

        let config = DeflateConfig::from_params(&params).unwrap();
        assert!(config.server_no_context_takeover);
        assert!(!config.client_no_context_takeover);
        assert_eq!(config.client_max_window_bits, 12);
        assert_eq!(config.server_max_window_bits, DEFAULT_WINDOW_BITS);
    }

    #[test]
    fn test_response_header() {
        let config = DeflateConfig {
            server_no_context_takeover: true,
            server_max_window_bits: 12,
            ..Default::default()
        };

        let header = config.to_response_header();
        assert!(header.contains("permessage-deflate"));
        assert!(header.contains("server_no_context_takeover"));
        assert!(header.contains("server_max_window_bits=12"));
    }
}