1use bytes::{Bytes, BytesMut};
7use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
8
9use crate::error::{Error, Result};
10
/// Final four bytes of every sync-flushed deflate block.
///
/// `DeflateEncoder` strips them from the end of compressed payloads and
/// `DeflateDecoder` re-appends them before inflating (RFC 7692 §7.2).
const DEFLATE_TRAILER: [u8; 4] = [0x00, 0x00, 0xFF, 0xFF];

/// Largest LZ77 window size (base-2 logarithm, i.e. 32 KiB) accepted in negotiation.
pub const MAX_WINDOW_BITS: u8 = 15;

/// Smallest LZ77 window size (base-2 logarithm, i.e. 256 bytes) accepted in negotiation.
pub const MIN_WINDOW_BITS: u8 = 8;

/// Window size used when negotiation does not lower it: the maximum.
pub const DEFAULT_WINDOW_BITS: u8 = MAX_WINDOW_BITS;
22
/// Negotiated permessage-deflate (RFC 7692) parameters for one connection.
///
/// "Server" fields govern server-to-client messages and "client" fields govern
/// client-to-server messages; `DeflateContext::server` / `DeflateContext::client`
/// map them onto the local encoder and decoder accordingly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeflateConfig {
    /// Maximum LZ77 window (base-2 log of bytes) for server-to-client messages.
    pub server_max_window_bits: u8,
    /// Maximum LZ77 window (base-2 log of bytes) for client-to-server messages.
    pub client_max_window_bits: u8,
    /// When set, server-to-client messages share no history: the sending compressor and
    /// the receiving decompressor are both reset before every message.
    pub server_no_context_takeover: bool,
    /// When set, client-to-server messages share no history (same mechanism as above).
    pub client_no_context_takeover: bool,
    /// Compression level handed to the encoder (zlib scale, typically 0-9).
    pub compression_level: u32,
    /// Payloads shorter than this many bytes are never compressed by the encoder.
    pub compression_threshold: usize,
}
39
40impl Default for DeflateConfig {
41 fn default() -> Self {
42 Self {
43 server_max_window_bits: DEFAULT_WINDOW_BITS,
44 client_max_window_bits: DEFAULT_WINDOW_BITS,
45 server_no_context_takeover: false,
46 client_no_context_takeover: false,
47 compression_level: 6, compression_threshold: 32, }
50 }
51}
52
53impl DeflateConfig {
54 pub fn low_memory() -> Self {
56 Self {
57 server_max_window_bits: 10, client_max_window_bits: 10,
59 server_no_context_takeover: true,
60 client_no_context_takeover: true,
61 compression_level: 1, compression_threshold: 64,
63 }
64 }
65
66 pub fn best_compression() -> Self {
68 Self {
69 server_max_window_bits: MAX_WINDOW_BITS,
70 client_max_window_bits: MAX_WINDOW_BITS,
71 server_no_context_takeover: false,
72 client_no_context_takeover: false,
73 compression_level: 9,
74 compression_threshold: 16,
75 }
76 }
77
78 pub fn from_params(params: &[(&str, Option<&str>)]) -> Result<Self> {
80 let mut config = Self::default();
81
82 for (name, value) in params {
83 match *name {
84 "server_no_context_takeover" => {
85 if value.is_some() {
86 return Err(Error::HandshakeFailed(
87 "server_no_context_takeover must not have a value",
88 ));
89 }
90 config.server_no_context_takeover = true;
91 }
92 "client_no_context_takeover" => {
93 if value.is_some() {
94 return Err(Error::HandshakeFailed(
95 "client_no_context_takeover must not have a value",
96 ));
97 }
98 config.client_no_context_takeover = true;
99 config.server_no_context_takeover = true;
102 }
103 "server_max_window_bits" => {
104 if let Some(v) = value {
105 let bits: u8 = v.parse().map_err(|_| {
106 Error::HandshakeFailed("invalid server_max_window_bits value")
107 })?;
108 if !(MIN_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits) {
109 return Err(Error::HandshakeFailed(
110 "server_max_window_bits out of range (8-15)",
111 ));
112 }
113 config.server_max_window_bits = bits;
114 }
115 }
116 "client_max_window_bits" => {
117 if let Some(v) = value {
118 let bits: u8 = v.parse().map_err(|_| {
119 Error::HandshakeFailed("invalid client_max_window_bits value")
120 })?;
121 if !(MIN_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits) {
122 return Err(Error::HandshakeFailed(
123 "client_max_window_bits out of range (8-15)",
124 ));
125 }
126 config.client_max_window_bits = bits;
127 }
128 }
130 _ => {
131 return Err(Error::HandshakeFailed(
132 "unknown permessage-deflate parameter",
133 ));
134 }
135 }
136 }
137
138 Ok(config)
139 }
140
141 pub fn to_response_header(&self) -> String {
143 let mut parts = vec!["permessage-deflate".to_string()];
144
145 if self.server_no_context_takeover {
146 parts.push("server_no_context_takeover".to_string());
147 }
148 if self.client_no_context_takeover {
149 parts.push("client_no_context_takeover".to_string());
150 }
151 if self.server_max_window_bits < MAX_WINDOW_BITS {
152 parts.push(format!(
153 "server_max_window_bits={}",
154 self.server_max_window_bits
155 ));
156 }
157 if self.client_max_window_bits < MAX_WINDOW_BITS {
158 parts.push(format!(
159 "client_max_window_bits={}",
160 self.client_max_window_bits
161 ));
162 }
163
164 parts.join("; ")
165 }
166}
167
168pub struct DeflateEncoder {
170 compress: Compress,
171 no_context_takeover: bool,
172 #[allow(dead_code)]
173 window_bits: u8,
174 #[allow(dead_code)]
175 compression_level: Compression,
176 threshold: usize,
177}
178
179impl DeflateEncoder {
180 pub fn new(window_bits: u8, no_context_takeover: bool, level: u32, threshold: usize) -> Self {
182 let compression_level = Compression::new(level);
183 let compress = Compress::new_with_window_bits(compression_level, false, window_bits);
186
187 Self {
188 compress,
189 no_context_takeover,
190 window_bits,
191 compression_level,
192 threshold,
193 }
194 }
195
196 pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
201 if data.len() < self.threshold {
202 return Ok(None);
203 }
204
205 if self.no_context_takeover {
207 self.compress.reset();
208 }
209
210 let max_output = data.len() + 64;
212 let mut output = BytesMut::with_capacity(max_output);
213
214 let mut total_in: usize = 0;
216 let mut iterations = 0u32;
217
218 loop {
219 iterations += 1;
220 if iterations > 100_000 {
221 return Err(Error::Compression(
222 "compression took too many iterations".into(),
223 ));
224 }
225
226 let available = output.capacity() - output.len();
228 if available == 0 {
229 output.reserve(4096);
230 }
231
232 let input = &data[total_in..];
233 let before_out = self.compress.total_out();
234 let before_in = self.compress.total_in();
235
236 let out_start = output.len();
239 let spare = output.spare_capacity_mut();
240
241 let spare_slice = unsafe {
245 std::slice::from_raw_parts_mut(spare.as_mut_ptr() as *mut u8, spare.len())
246 };
247
248 let status = self
249 .compress
250 .compress(input, spare_slice, FlushCompress::Sync)
251 .map_err(|e| Error::Compression(format!("deflate error: {}", e)))?;
252
253 let consumed = (self.compress.total_in() - before_in) as usize;
254 let produced = (self.compress.total_out() - before_out) as usize;
255
256 total_in += consumed;
257
258 unsafe {
261 output.set_len(out_start + produced);
262 }
263
264 match status {
265 Status::Ok | Status::BufError => {
266 if total_in >= data.len() {
267 break;
268 }
269 }
270 Status::StreamEnd => break,
271 }
272 }
273
274 if output.len() >= 4 && output.ends_with(&DEFLATE_TRAILER) {
276 output.truncate(output.len() - 4);
277 }
278
279 if output.len() >= data.len() {
281 return Ok(None);
282 }
283
284 Ok(Some(output.freeze()))
285 }
286
287 pub fn reset(&mut self) {
289 self.compress.reset();
290 }
291}
292
293pub struct DeflateDecoder {
295 decompress: Decompress,
296 no_context_takeover: bool,
297 #[allow(dead_code)]
298 window_bits: u8,
299}
300
301impl DeflateDecoder {
302 pub fn new(window_bits: u8, no_context_takeover: bool) -> Self {
304 let decompress = Decompress::new_with_window_bits(false, window_bits);
306
307 Self {
308 decompress,
309 no_context_takeover,
310 window_bits,
311 }
312 }
313
314 pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
316 if self.no_context_takeover {
318 self.decompress.reset(false);
319 }
320
321 let mut input = BytesMut::with_capacity(data.len() + 4);
323 input.extend_from_slice(data);
324 input.extend_from_slice(&DEFLATE_TRAILER);
325
326 let initial_cap = std::cmp::max(1024, data.len() * 4);
328 let mut output = BytesMut::with_capacity(initial_cap);
329 let mut total_in: usize = 0;
330 let mut iterations = 0u32;
331
332 loop {
333 iterations += 1;
334 if iterations > 100_000 {
336 return Err(Error::Compression(
337 "decompression took too many iterations".into(),
338 ));
339 }
340
341 if output.len() > max_size {
343 return Err(Error::MessageTooLarge);
344 }
345
346 let available = output.capacity() - output.len();
348 if available == 0 {
349 if output.capacity() >= max_size {
350 return Err(Error::MessageTooLarge);
351 }
352 let additional = std::cmp::max(output.capacity(), 4096);
354 output.reserve(additional);
355 }
356
357 let before_out = self.decompress.total_out();
358 let before_in = self.decompress.total_in();
359
360 let out_start = output.len();
362 let spare = output.spare_capacity_mut();
363
364 let spare_slice = unsafe {
368 std::slice::from_raw_parts_mut(spare.as_mut_ptr() as *mut u8, spare.len())
369 };
370
371 let status = self
372 .decompress
373 .decompress(&input[total_in..], spare_slice, FlushDecompress::Sync)
374 .map_err(|e| Error::Compression(format!("inflate error: {}", e)))?;
375
376 let consumed = (self.decompress.total_in() - before_in) as usize;
377 let produced = (self.decompress.total_out() - before_out) as usize;
378
379 total_in += consumed;
380
381 unsafe {
384 output.set_len(out_start + produced);
385 }
386
387 match status {
388 Status::Ok => {
389 if total_in >= input.len() {
390 break;
391 }
392 }
393 Status::StreamEnd => break,
394 Status::BufError => {
395 }
397 }
398 }
399
400 Ok(output.freeze())
401 }
402
403 pub fn reset(&mut self) {
405 self.decompress.reset(false);
406 }
407}
408
409pub struct DeflateContext {
411 pub encoder: DeflateEncoder,
413 pub decoder: DeflateDecoder,
415 pub config: DeflateConfig,
417}
418
419impl DeflateContext {
420 pub fn server(config: DeflateConfig) -> Self {
422 let encoder = DeflateEncoder::new(
423 config.server_max_window_bits,
424 config.server_no_context_takeover,
425 config.compression_level,
426 config.compression_threshold,
427 );
428 let decoder = DeflateDecoder::new(
429 config.client_max_window_bits,
430 config.client_no_context_takeover,
431 );
432
433 Self {
434 encoder,
435 decoder,
436 config,
437 }
438 }
439
440 pub fn client(config: DeflateConfig) -> Self {
442 let encoder = DeflateEncoder::new(
443 config.client_max_window_bits,
444 config.client_no_context_takeover,
445 config.compression_level,
446 config.compression_threshold,
447 );
448 let decoder = DeflateDecoder::new(
449 config.server_max_window_bits,
450 config.server_no_context_takeover,
451 );
452
453 Self {
454 encoder,
455 decoder,
456 config,
457 }
458 }
459
460 pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
462 self.encoder.compress(data)
463 }
464
465 pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
467 self.decoder.decompress(data, max_size)
468 }
469}
470
/// Splits one `permessage-deflate` extension offer into `(name, value)` parameters.
///
/// Returns `None` if `value` (after trimming) is not `permessage-deflate`, optionally
/// followed by `;`-separated parameters. Parameter names and values are trimmed,
/// surrounding double quotes are removed from values, and empty segments are skipped.
/// Names are not validated here; see `DeflateConfig::from_params`.
pub fn parse_deflate_offer(value: &str) -> Option<Vec<(&str, Option<&str>)>> {
    let after_name = value
        .trim()
        .strip_prefix("permessage-deflate")?
        .trim_start();

    if after_name.is_empty() {
        return Some(Vec::new());
    }

    // Anything other than a parameter list (e.g. "permessage-deflatex") is not this offer.
    let param_list = after_name.strip_prefix(';')?;

    let params: Vec<_> = param_list
        .split(';')
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .map(|segment| match segment.split_once('=') {
            Some((key, raw)) => (key.trim(), Some(raw.trim().trim_matches('"'))),
            None => (segment, None),
        })
        .collect();

    Some(params)
}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// A compressed message round-trips through the same context's decoder.
    #[test]
    fn test_compress_decompress() {
        let config = DeflateConfig::default();
        let mut ctx = DeflateContext::server(config);

        let original = b"Hello, World! This is a test message that should be compressed.";

        let compressed = ctx.compress(original).unwrap();
        assert!(compressed.is_some());
        let compressed = compressed.unwrap();
        assert!(compressed.len() < original.len());

        let decompressed = ctx.decompress(&compressed, 1024).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    /// A larger, highly repetitive payload also round-trips.
    #[test]
    fn test_roundtrip_large_repetitive_payload() {
        let mut ctx = DeflateContext::server(DeflateConfig::default());
        let original: Vec<u8> = b"websocket "
            .iter()
            .copied()
            .cycle()
            .take(64 * 1024)
            .collect();

        let compressed = ctx
            .compress(&original)
            .unwrap()
            .expect("repetitive data should compress");
        assert!(compressed.len() < original.len());

        let decompressed = ctx.decompress(&compressed, original.len() * 2).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    /// Payloads below the threshold are left uncompressed.
    #[test]
    fn test_small_message_not_compressed() {
        let config = DeflateConfig {
            compression_threshold: 100,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let small = b"tiny";
        let result = ctx.compress(small).unwrap();
        assert!(result.is_none());
    }

    /// With context takeover, a repeated message compresses at least as well.
    #[test]
    fn test_context_takeover() {
        let config = DeflateConfig {
            server_no_context_takeover: false,
            compression_threshold: 0,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let msg = b"Hello, World! Hello, World! Hello, World!";

        let first = ctx.compress(msg).unwrap().unwrap();
        let second = ctx.compress(msg).unwrap().unwrap();

        assert!(second.len() <= first.len());
    }

    /// Without context takeover, each message is compressed independently.
    #[test]
    fn test_no_context_takeover() {
        let config = DeflateConfig {
            server_no_context_takeover: true,
            compression_threshold: 0,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let msg = b"Hello, World! Hello, World! Hello, World!";

        let first = ctx.compress(msg).unwrap().unwrap();
        let second = ctx.compress(msg).unwrap().unwrap();

        assert_eq!(first.len(), second.len());
    }

    #[test]
    fn test_parse_deflate_offer() {
        let params = parse_deflate_offer("permessage-deflate").unwrap();
        assert!(params.is_empty());

        let params = parse_deflate_offer(
            "permessage-deflate; server_no_context_takeover; server_max_window_bits=10",
        )
        .unwrap();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], ("server_no_context_takeover", None));
        assert_eq!(params[1], ("server_max_window_bits", Some("10")));

        assert!(parse_deflate_offer("some-other-extension").is_none());
    }

    #[test]
    fn test_parse_deflate_offer_edge_cases() {
        // Quoted values are unquoted and empty segments are skipped.
        let params =
            parse_deflate_offer("permessage-deflate;; client_max_window_bits=\"10\"").unwrap();
        assert_eq!(params, vec![("client_max_window_bits", Some("10"))]);

        // A token that merely starts with the extension name is not a match.
        assert!(parse_deflate_offer("permessage-deflatex").is_none());
        assert!(parse_deflate_offer("permessage-deflate x").is_none());
    }

    #[test]
    fn test_config_from_params() {
        let params = vec![
            ("server_no_context_takeover", None),
            ("client_max_window_bits", Some("12")),
        ];

        let config = DeflateConfig::from_params(&params).unwrap();
        assert!(config.server_no_context_takeover);
        assert!(!config.client_no_context_takeover);
        assert_eq!(config.client_max_window_bits, 12);
        assert_eq!(config.server_max_window_bits, DEFAULT_WINDOW_BITS);
    }

    #[test]
    fn test_response_header() {
        let config = DeflateConfig {
            server_no_context_takeover: true,
            server_max_window_bits: 12,
            ..Default::default()
        };

        let header = config.to_response_header();
        assert!(header.contains("permessage-deflate"));
        assert!(header.contains("server_no_context_takeover"));
        assert!(header.contains("server_max_window_bits=12"));
    }
}