1use flate2::{
2 write::{ZlibDecoder, ZlibEncoder},
3 Compression,
4};
5use std::io::Write;
6
7use crate::{
8 client::{Compressor, Error},
9 protocol::Packet,
10};
11
12pub const DEFAULT_MIN_BYTES: usize = 128;
15
16#[derive(Debug, Clone, Copy)]
18pub struct ZlibCompressor {
19 compression: Compression,
20 min_bytes: usize,
21}
22
23impl ZlibCompressor {
24 pub fn new(compression: Compression, min_bytes: usize) -> Self {
28 ZlibCompressor {
29 compression,
30 min_bytes,
31 }
32 }
33}
34
35impl Default for ZlibCompressor {
36 fn default() -> Self {
37 ZlibCompressor::new(Compression::default(), DEFAULT_MIN_BYTES)
38 }
39}
40
41impl Compressor for ZlibCompressor {
42 fn compress(&self, mut packet: Packet) -> Result<Packet, Error> {
43 if packet.value.len() < self.min_bytes {
44 return Ok(packet);
45 }
46
47 let mut out = vec![];
48 let mut enc = ZlibEncoder::new(&mut out, self.compression);
49 enc.write_all(&packet.value)?;
50 enc.finish()?;
51
52 let key_len = packet.header.key_length as u32;
54 let ext_len = packet.header.extras_length as u32;
55 let val_len = out.len() as u32;
56 packet.header.body_len = key_len + ext_len + val_len;
57 packet.extras[0] = 1;
60 packet.value = out;
61 Ok(packet)
62 }
63
64 fn decompress(&self, mut packet: Packet) -> Result<Packet, Error> {
65 if packet.extras.get(0) != Some(&1) {
66 return Ok(packet);
68 }
69
70 let mut out = vec![];
71 let mut dec = ZlibDecoder::new(&mut out);
72 dec.write_all(&packet.value)?;
73 dec.finish()?;
74
75 let key_len = packet.header.key_length as u32;
77 let ext_len = packet.header.extras_length as u32;
78 let val_len = out.len() as u32;
79 packet.header.body_len = key_len + ext_len + val_len;
80 packet.extras[0] = 0;
82 packet.value = out;
83 Ok(packet)
84 }
85}
86
#[cfg(test)]
mod tests {
    use flate2::Compression;

    use crate::{
        client::Compressor,
        protocol::{Packet, SetExtras},
    };

    use super::ZlibCompressor;

    /// Builds a `set` packet with zero flags and a 300s expiry for `value`.
    fn set_packet(value: &[u8]) -> Packet {
        Packet::set(&b"my_test_key"[..], value, SetExtras::new(0, 300)).unwrap()
    }

    #[test]
    fn test_zlib() {
        let compressor = ZlibCompressor::new(Compression::new(9), 1);
        let packet = set_packet(&b"0000000000000000000000000000000000000000000000"[..]);

        let compressed = compressor.compress(packet.clone()).unwrap();
        let uncompressed = compressor.decompress(compressed.clone()).unwrap();

        assert!(compressed.header.body_len < packet.header.body_len);
        assert_eq!(packet, uncompressed);
    }

    #[test]
    fn value_below_threshold_is_untouched() {
        // The default threshold (128 bytes) exceeds this 46-byte value.
        let compressor = ZlibCompressor::default();
        let packet = set_packet(&b"0000000000000000000000000000000000000000000000"[..]);

        let compressed = compressor.compress(packet.clone()).unwrap();

        assert_eq!(packet, compressed);
    }

    #[test]
    fn unflagged_packet_passes_through_decompress() {
        let compressor = ZlibCompressor::new(Compression::new(9), 1);
        let packet = set_packet(&b"plain value"[..]);

        let out = compressor.decompress(packet.clone()).unwrap();

        assert_eq!(packet, out);
    }

    #[test]
    fn short_value_round_trips() {
        // Too short for zlib to shrink; whatever compress does, the
        // compress/decompress pair must restore the original packet.
        let compressor = ZlibCompressor::new(Compression::new(9), 1);
        let packet = set_packet(&b"abc"[..]);

        let compressed = compressor.compress(packet.clone()).unwrap();
        let restored = compressor.decompress(compressed).unwrap();

        assert_eq!(packet, restored);
    }
}
112}