Skip to main content

cassandra_protocol/
compression.rs

//! CDRS supports traffic compression as described in the [Apache Cassandra
//! protocol](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L790).
//!
//! Before compression can be used, the client and server must agree on an
//! algorithm, which is done in the STARTUP message. As a consequence, a STARTUP
//! message must never be compressed. However, once the STARTUP envelope has been
//! received by the server, messages can be compressed (including the response to
//! the STARTUP request).
10use derive_more::Display;
11use snap::raw::{Decoder, Encoder};
12use std::convert::{From, TryInto};
13use std::error::Error;
14use std::fmt;
15use std::io;
16use std::result;
17
18type Result<T> = result::Result<T, CompressionError>;
19
/// Lowercase identifier of the lz4 algorithm, as returned by `Compression::as_str`.
pub const LZ4: &str = "lz4";
/// Lowercase identifier of the snappy algorithm, as returned by `Compression::as_str`.
pub const SNAPPY: &str = "snappy";
22
23/// An error which may occur during encoding or decoding frame body. As there are only two types
24/// of compressors it contains two related enum options.
25#[derive(Debug)]
26pub enum CompressionError {
27    /// Snappy error.
28    Snappy(snap::Error),
29    /// Lz4 error.
30    Lz4(io::Error),
31}
32
33impl fmt::Display for CompressionError {
34    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35        match *self {
36            CompressionError::Snappy(ref err) => write!(f, "Snappy Error: {err:?}"),
37            CompressionError::Lz4(ref err) => write!(f, "Lz4 Error: {err:?}"),
38        }
39    }
40}
41
42impl Error for CompressionError {
43    fn source(&self) -> Option<&(dyn Error + 'static)> {
44        match *self {
45            CompressionError::Snappy(ref err) => Some(err),
46            CompressionError::Lz4(ref err) => Some(err),
47        }
48    }
49}
50
51impl Clone for CompressionError {
52    fn clone(&self) -> Self {
53        match self {
54            CompressionError::Snappy(error) => CompressionError::Snappy(error.clone()),
55            CompressionError::Lz4(error) => CompressionError::Lz4(io::Error::new(
56                error.kind(),
57                error
58                    .get_ref()
59                    .map(|error| error.to_string())
60                    .unwrap_or_default(),
61            )),
62        }
63    }
64}
65
66/// Enum which represents a type of compression. Only non-startup frame's body can be compressed.
67#[derive(Debug, PartialEq, Clone, Copy, Eq, Ord, PartialOrd, Hash, Display)]
68pub enum Compression {
69    /// [lz4](https://code.google.com/p/lz4/) compression
70    Lz4,
71    /// [snappy](https://code.google.com/p/snappy/) compression
72    Snappy,
73    /// No compression
74    None,
75}
76
77impl Compression {
78    /// It encodes `bytes` basing on type of `Compression`..
79    ///
80    /// # Examples
81    ///
82    /// ```
83    ///    use cassandra_protocol::compression::Compression;
84    ///
85    ///   let snappy_compression = Compression::Snappy;
86    ///   let bytes = String::from("Hello World").into_bytes().to_vec();
87    ///   let encoded = snappy_compression.encode(&bytes).unwrap();
88    ///   assert_eq!(snappy_compression.decode(encoded).unwrap(), bytes);
89    ///
90    /// ```
91    pub fn encode(&self, bytes: &[u8]) -> Result<Vec<u8>> {
92        match *self {
93            Compression::Lz4 => Compression::encode_lz4(bytes),
94            Compression::Snappy => Compression::encode_snappy(bytes),
95            Compression::None => Ok(bytes.into()),
96        }
97    }
98
99    /// Checks if current compression actually compresses data.
100    #[inline]
101    pub fn is_compressed(self) -> bool {
102        self != Compression::None
103    }
104
105    /// It decodes `bytes` basing on type of compression.
106    pub fn decode(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
107        match *self {
108            Compression::Lz4 => Compression::decode_lz4(bytes),
109            Compression::Snappy => Compression::decode_snappy(bytes),
110            Compression::None => Ok(bytes),
111        }
112    }
113
114    /// It transforms compression method into a `&str`.
115    pub fn as_str(&self) -> Option<&'static str> {
116        match *self {
117            Compression::Lz4 => Some(LZ4),
118            Compression::Snappy => Some(SNAPPY),
119            Compression::None => None,
120        }
121    }
122
123    fn encode_snappy(bytes: &[u8]) -> Result<Vec<u8>> {
124        let mut encoder = Encoder::new();
125        encoder
126            .compress_vec(bytes)
127            .map_err(CompressionError::Snappy)
128    }
129
130    fn decode_snappy(bytes: Vec<u8>) -> Result<Vec<u8>> {
131        let mut decoder = Decoder::new();
132        decoder
133            .decompress_vec(bytes.as_slice())
134            .map_err(CompressionError::Snappy)
135    }
136
137    fn encode_lz4(bytes: &[u8]) -> Result<Vec<u8>> {
138        let len = 4 + lz4_flex::block::get_maximum_output_size(bytes.len());
139        assert!(len <= i32::MAX as usize);
140
141        let mut result = vec![0; len];
142
143        let len = bytes.len() as i32;
144        result[..4].copy_from_slice(&len.to_be_bytes());
145
146        let compressed_len = lz4_flex::compress_into(bytes, &mut result[4..])
147            .map_err(|error| CompressionError::Lz4(io::Error::other(error)))?;
148
149        result.truncate(4 + compressed_len);
150        Ok(result)
151    }
152
153    fn decode_lz4(bytes: Vec<u8>) -> Result<Vec<u8>> {
154        // lz4 wire format prepends a 4-byte big-endian uncompressed length so
155        // the decoder knows how much memory to allocate. Validate length before
156        // slicing to avoid panics on truncated input.
157        if bytes.len() < 4 {
158            return Err(CompressionError::Lz4(io::Error::new(
159                io::ErrorKind::UnexpectedEof,
160                "lz4 payload missing 4-byte uncompressed length header",
161            )));
162        }
163
164        let uncompressed_size = i32::from_be_bytes(
165            bytes[..4]
166                .try_into()
167                .map_err(|error| CompressionError::Lz4(io::Error::other(error)))?,
168        );
169
170        // a negative size is impossible for a real payload; without this check
171        // the `as usize` cast would silently turn it into ~2 GB+ and ask
172        // lz4_flex to allocate a buffer that size before any decoding begins.
173        if uncompressed_size < 0 {
174            return Err(CompressionError::Lz4(io::Error::new(
175                io::ErrorKind::InvalidData,
176                format!("negative uncompressed size {uncompressed_size}"),
177            )));
178        }
179
180        lz4_flex::decompress(&bytes[4..], uncompressed_size as usize)
181            .map_err(|error| CompressionError::Lz4(io::Error::other(error)))
182    }
183}
184
185impl From<String> for Compression {
186    /// It converts `String` into `Compression`. If string is neither `lz4` nor `snappy` then
187    /// `Compression::None` will be returned
188    fn from(compression_string: String) -> Compression {
189        Compression::from(compression_string.as_str())
190    }
191}
192
193impl Compression {
194    /// It converts `Compression` into `String`. If compression is `None` then empty string will be
195    /// returned
196    pub fn to_protocol_string(self) -> String {
197        match self {
198            Compression::Lz4 => "LZ4".to_string(),
199            Compression::Snappy => "SNAPPY".to_string(),
200            Compression::None => "NONE".to_string(),
201        }
202    }
203
204    pub fn from_protocol_string(protocol_string: &str) -> std::result::Result<Self, String> {
205        match protocol_string {
206            "lz4" | "LZ4" => Ok(Compression::Lz4),
207            "snappy" | "SNAPPY" => Ok(Compression::Snappy),
208            "none" | "NONE" => Ok(Compression::None),
209            _ => Err("Unknown compression".to_string()),
210        }
211    }
212}
213
214impl<'a> From<&'a str> for Compression {
215    /// It converts `str` into `Compression`. If string is neither `lz4` nor `snappy` then
216    /// `Compression::None` will be returned
217    fn from(compression_str: &'a str) -> Compression {
218        match compression_str {
219            LZ4 => Compression::Lz4,
220            SNAPPY => Compression::Snappy,
221            _ => Compression::None,
222        }
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compression_to_protocol_string() {
        let lz4 = Compression::Lz4;
        assert_eq!("LZ4", lz4.to_protocol_string());

        let snappy = Compression::Snappy;
        assert_eq!("SNAPPY", snappy.to_protocol_string());

        let none = Compression::None;
        assert_eq!("NONE", none.to_protocol_string());
    }

    #[test]
    fn test_compression_from_protocol_str() {
        let lz4 = "lz4";
        assert_eq!(
            Compression::from_protocol_string(lz4).unwrap(),
            Compression::Lz4
        );

        let lz4 = "LZ4";
        assert_eq!(
            Compression::from_protocol_string(lz4).unwrap(),
            Compression::Lz4
        );

        let snappy = "snappy";
        assert_eq!(
            Compression::from_protocol_string(snappy).unwrap(),
            Compression::Snappy
        );

        let snappy = "SNAPPY";
        assert_eq!(
            Compression::from_protocol_string(snappy).unwrap(),
            Compression::Snappy
        );

        let none = "none";
        assert_eq!(
            Compression::from_protocol_string(none).unwrap(),
            Compression::None
        );

        let none = "NONE";
        assert_eq!(
            Compression::from_protocol_string(none).unwrap(),
            Compression::None
        );
    }

    #[test]
    fn test_compression_from_protocol_str_unknown_is_error() {
        assert!(Compression::from_protocol_string("zstd").is_err());
        assert!(Compression::from_protocol_string("").is_err());
    }

    #[test]
    fn test_compression_from_string() {
        let lz4 = "lz4".to_string();
        assert_eq!(Compression::from(lz4), Compression::Lz4);
        let snappy = "snappy".to_string();
        assert_eq!(Compression::from(snappy), Compression::Snappy);
        let none = "x".to_string();
        assert_eq!(Compression::from(none), Compression::None);
    }

    #[test]
    fn test_compression_as_str_and_is_compressed() {
        assert_eq!(Compression::Lz4.as_str(), Some(LZ4));
        assert_eq!(Compression::Snappy.as_str(), Some(SNAPPY));
        assert_eq!(Compression::None.as_str(), None);

        assert!(Compression::Lz4.is_compressed());
        assert!(Compression::Snappy.is_compressed());
        assert!(!Compression::None.is_compressed());
    }

    #[test]
    fn test_compression_encode_snappy() {
        let snappy_compression = Compression::Snappy;
        let bytes = b"Hello World".to_vec();
        snappy_compression
            .encode(&bytes)
            .expect("Should work without exceptions");
    }

    #[test]
    fn test_compression_decode_snappy() {
        let snappy_compression = Compression::Snappy;
        let bytes = b"Hello World".to_vec();
        let encoded = snappy_compression.encode(&bytes).unwrap();
        assert_eq!(snappy_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_encode_lz4() {
        let lz4_compression = Compression::Lz4;
        let bytes = b"Hello World".to_vec();
        lz4_compression
            .encode(&bytes)
            .expect("Should work without exceptions");
    }

    #[test]
    fn test_compression_decode_lz4() {
        let lz4_compression = Compression::Lz4;
        let bytes = b"Hello World".to_vec();
        let encoded = lz4_compression.encode(&bytes).unwrap();
        assert_eq!(lz4_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_lz4_round_trip_empty() {
        let lz4_compression = Compression::Lz4;
        let bytes: Vec<u8> = Vec::new();
        let encoded = lz4_compression.encode(&bytes).unwrap();
        assert_eq!(lz4_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_lz4_round_trip_highly_compressible() {
        // a long run of one byte compresses extremely well; it must still
        // decode, i.e. any sanity bound on the claimed size must not reject it.
        let lz4_compression = Compression::Lz4;
        let bytes = vec![b'a'; 10_000];
        let encoded = lz4_compression.encode(&bytes).unwrap();
        assert!(encoded.len() < bytes.len());
        assert_eq!(lz4_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_encode_none() {
        let none_compression = Compression::None;
        let bytes = b"Hello World".to_vec();
        none_compression
            .encode(&bytes)
            .expect("Should work without exceptions");
    }

    #[test]
    fn test_compression_decode_none() {
        let none_compression = Compression::None;
        let bytes = b"Hello World".to_vec();
        let encoded = none_compression.encode(&bytes).unwrap();
        assert_eq!(none_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_decode_lz4_with_invalid_input() {
        let lz4_compression = Compression::Lz4;
        let decode = lz4_compression.decode(vec![0, 0, 0, 0x7f]);
        assert!(decode.is_err());
    }

    #[test]
    fn test_compression_decode_lz4_short_input_is_error_not_panic() {
        // the lz4 wire format prepends a 4-byte big-endian uncompressed size;
        // a payload shorter than that header must surface as an error rather
        // than panicking on slicing.
        let lz4_compression = Compression::Lz4;
        assert!(lz4_compression.decode(vec![]).is_err());
        assert!(lz4_compression.decode(vec![1, 2, 3]).is_err());
    }

    #[test]
    fn test_compression_decode_lz4_negative_size_is_error_not_oom() {
        // a negative i32 uncompressed length cast through `as usize` becomes
        // a huge value (~2 GB+) and would otherwise hand lz4_flex an absurd
        // allocation request - guard against that.
        let lz4_compression = Compression::Lz4;
        // -1 in big-endian i32 followed by a dummy compressed byte
        let bytes = vec![0xff, 0xff, 0xff, 0xff, 0];
        assert!(lz4_compression.decode(bytes).is_err());
    }

    #[test]
    fn test_compression_encode_snappy_with_non_utf8() {
        let snappy_compression = Compression::Snappy;
        let v = vec![0xff, 0xff];
        let encoded = snappy_compression
            .encode(&v)
            .expect("Should work without exceptions");
        assert_eq!(snappy_compression.decode(encoded).unwrap(), v);
    }

    #[test]
    fn test_compression_error_clone_keeps_lz4_kind_and_message() {
        let error = CompressionError::Lz4(io::Error::new(io::ErrorKind::InvalidData, "bad block"));
        let CompressionError::Lz4(cloned) = error.clone() else {
            panic!("clone changed the error variant");
        };
        assert_eq!(cloned.kind(), io::ErrorKind::InvalidData);
        assert_eq!(cloned.to_string(), "bad block");
    }
}