rsmc_core/
zlib.rs

1use flate2::{
2    write::{ZlibDecoder, ZlibEncoder},
3    Compression,
4};
5use std::io::Write;
6
7use crate::{
8    client::{Compressor, Error},
9    protocol::Packet,
10};
11
/// Default size threshold, in bytes: values shorter than this are left
/// uncompressed by [`ZlibCompressor`]. This is roughly five packet headers'
/// worth of data, below which compression is not worth the effort.
pub const DEFAULT_MIN_BYTES: usize = 128;
15
16/// A compressor that implements zlib compression and decompression.
17#[derive(Debug, Clone, Copy)]
18pub struct ZlibCompressor {
19    compression: Compression,
20    min_bytes: usize,
21}
22
23impl ZlibCompressor {
24    /// Construct a new zlib compressor with the given compression
25    /// ratio and min_bytes. Packets smaller than min_bytes will not
26    /// get compressed by the Zlib compressor.
27    pub fn new(compression: Compression, min_bytes: usize) -> Self {
28        ZlibCompressor {
29            compression,
30            min_bytes,
31        }
32    }
33}
34
35impl Default for ZlibCompressor {
36    fn default() -> Self {
37        ZlibCompressor::new(Compression::default(), DEFAULT_MIN_BYTES)
38    }
39}
40
41impl Compressor for ZlibCompressor {
42    fn compress(&self, mut packet: Packet) -> Result<Packet, Error> {
43        if packet.value.len() < self.min_bytes {
44            return Ok(packet);
45        }
46
47        let mut out = vec![];
48        let mut enc = ZlibEncoder::new(&mut out, self.compression);
49        enc.write_all(&packet.value)?;
50        enc.finish()?;
51
52        // Update the header lengths to match the new value.
53        let key_len = packet.header.key_length as u32;
54        let ext_len = packet.header.extras_length as u32;
55        let val_len = out.len() as u32;
56        packet.header.body_len = key_len + ext_len + val_len;
57        // Set a flag indicating that this data is compressed with zlib.
58        // NB: extras must be non-empty to compress packets.
59        packet.extras[0] = 1;
60        packet.value = out;
61        Ok(packet)
62    }
63
64    fn decompress(&self, mut packet: Packet) -> Result<Packet, Error> {
65        if packet.extras.get(0) != Some(&1) {
66            // This packet did not have the compression flag enabled.
67            return Ok(packet);
68        }
69
70        let mut out = vec![];
71        let mut dec = ZlibDecoder::new(&mut out);
72        dec.write_all(&packet.value)?;
73        dec.finish()?;
74
75        // Update the header lengths to match the new value.
76        let key_len = packet.header.key_length as u32;
77        let ext_len = packet.header.extras_length as u32;
78        let val_len = out.len() as u32;
79        packet.header.body_len = key_len + ext_len + val_len;
80        // Unset the flag indicating that this data is compressed with zlib.
81        packet.extras[0] = 0;
82        packet.value = out;
83        Ok(packet)
84    }
85}
86
#[cfg(test)]
mod tests {
    use flate2::Compression;

    use crate::{
        client::Compressor,
        protocol::{Packet, SetExtras},
    };

    use super::ZlibCompressor;

    /// Highly compressible value shared by the tests below.
    const ZEROS: &[u8] = b"0000000000000000000000000000000000000000000000";

    /// Builds the `set` packet used by these tests around `value`.
    fn set_packet(value: &[u8]) -> Packet {
        Packet::set(&b"my_test_key"[..], value, SetExtras::new(0, 300)).unwrap()
    }

    #[test]
    fn test_zlib() {
        let compressor = ZlibCompressor::new(Compression::new(9), 1);
        let packet = set_packet(ZEROS);

        let compressed = compressor.compress(packet.clone()).unwrap();
        assert!(compressed.header.body_len < packet.header.body_len);

        let uncompressed = compressor.decompress(compressed).unwrap();
        assert_eq!(packet, uncompressed);
    }

    #[test]
    fn test_zlib_below_min_bytes_is_untouched() {
        // The threshold exceeds the value length, so compress must be a no-op.
        let compressor = ZlibCompressor::new(Compression::new(9), 1024);
        let packet = set_packet(ZEROS);

        let compressed = compressor.compress(packet.clone()).unwrap();
        assert_eq!(packet, compressed);
    }

    #[test]
    fn test_zlib_decompress_ignores_unflagged_packet() {
        // A freshly built packet carries no compression flag, so decompress
        // must hand it back unchanged.
        let compressor = ZlibCompressor::default();
        let packet = set_packet(b"not compressed");

        let out = compressor.decompress(packet.clone()).unwrap();
        assert_eq!(packet, out);
    }
}
112}