cassandra_protocol/
compression.rs1use derive_more::Display;
11use snap::raw::{Decoder, Encoder};
12use std::convert::{From, TryInto};
13use std::error::Error;
14use std::fmt;
15use std::io;
16use std::result;
17
18type Result<T> = result::Result<T, CompressionError>;
19
/// Lowercase name of the LZ4 algorithm, as returned by `Compression::as_str`.
pub const LZ4: &str = "lz4";
/// Lowercase name of the Snappy algorithm, as returned by `Compression::as_str`.
pub const SNAPPY: &str = "snappy";
22
23#[derive(Debug)]
26pub enum CompressionError {
27 Snappy(snap::Error),
29 Lz4(io::Error),
31}
32
33impl fmt::Display for CompressionError {
34 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35 match *self {
36 CompressionError::Snappy(ref err) => write!(f, "Snappy Error: {err:?}"),
37 CompressionError::Lz4(ref err) => write!(f, "Lz4 Error: {err:?}"),
38 }
39 }
40}
41
42impl Error for CompressionError {
43 fn source(&self) -> Option<&(dyn Error + 'static)> {
44 match *self {
45 CompressionError::Snappy(ref err) => Some(err),
46 CompressionError::Lz4(ref err) => Some(err),
47 }
48 }
49}
50
51impl Clone for CompressionError {
52 fn clone(&self) -> Self {
53 match self {
54 CompressionError::Snappy(error) => CompressionError::Snappy(error.clone()),
55 CompressionError::Lz4(error) => CompressionError::Lz4(io::Error::new(
56 error.kind(),
57 error
58 .get_ref()
59 .map(|error| error.to_string())
60 .unwrap_or_default(),
61 )),
62 }
63 }
64}
65
66#[derive(Debug, PartialEq, Clone, Copy, Eq, Ord, PartialOrd, Hash, Display)]
68pub enum Compression {
69 Lz4,
71 Snappy,
73 None,
75}
76
77impl Compression {
78 pub fn encode(&self, bytes: &[u8]) -> Result<Vec<u8>> {
92 match *self {
93 Compression::Lz4 => Compression::encode_lz4(bytes),
94 Compression::Snappy => Compression::encode_snappy(bytes),
95 Compression::None => Ok(bytes.into()),
96 }
97 }
98
99 #[inline]
101 pub fn is_compressed(self) -> bool {
102 self != Compression::None
103 }
104
105 pub fn decode(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
107 match *self {
108 Compression::Lz4 => Compression::decode_lz4(bytes),
109 Compression::Snappy => Compression::decode_snappy(bytes),
110 Compression::None => Ok(bytes),
111 }
112 }
113
114 pub fn as_str(&self) -> Option<&'static str> {
116 match *self {
117 Compression::Lz4 => Some(LZ4),
118 Compression::Snappy => Some(SNAPPY),
119 Compression::None => None,
120 }
121 }
122
123 fn encode_snappy(bytes: &[u8]) -> Result<Vec<u8>> {
124 let mut encoder = Encoder::new();
125 encoder
126 .compress_vec(bytes)
127 .map_err(CompressionError::Snappy)
128 }
129
130 fn decode_snappy(bytes: Vec<u8>) -> Result<Vec<u8>> {
131 let mut decoder = Decoder::new();
132 decoder
133 .decompress_vec(bytes.as_slice())
134 .map_err(CompressionError::Snappy)
135 }
136
137 fn encode_lz4(bytes: &[u8]) -> Result<Vec<u8>> {
138 let len = 4 + lz4_flex::block::get_maximum_output_size(bytes.len());
139 assert!(len <= i32::MAX as usize);
140
141 let mut result = vec![0; len];
142
143 let len = bytes.len() as i32;
144 result[..4].copy_from_slice(&len.to_be_bytes());
145
146 let compressed_len = lz4_flex::compress_into(bytes, &mut result[4..])
147 .map_err(|error| CompressionError::Lz4(io::Error::new(io::ErrorKind::Other, error)))?;
148
149 result.truncate(4 + compressed_len);
150 Ok(result)
151 }
152
153 fn decode_lz4(bytes: Vec<u8>) -> Result<Vec<u8>> {
154 let uncompressed_size =
155 i32::from_be_bytes(bytes[..4].try_into().map_err(|error| {
156 CompressionError::Lz4(io::Error::new(io::ErrorKind::Other, error))
157 })?);
158
159 lz4_flex::decompress(&bytes[4..], uncompressed_size as usize)
160 .map_err(|error| CompressionError::Lz4(io::Error::new(io::ErrorKind::Other, error)))
161 }
162}
163
164impl From<String> for Compression {
165 fn from(compression_string: String) -> Compression {
168 Compression::from(compression_string.as_str())
169 }
170}
171
172impl Compression {
173 pub fn to_protocol_string(self) -> String {
176 match self {
177 Compression::Lz4 => "LZ4".to_string(),
178 Compression::Snappy => "SNAPPY".to_string(),
179 Compression::None => "NONE".to_string(),
180 }
181 }
182
183 pub fn from_protocol_string(protocol_string: &str) -> std::result::Result<Self, String> {
184 match protocol_string {
185 "lz4" | "LZ4" => Ok(Compression::Lz4),
186 "snappy" | "SNAPPY" => Ok(Compression::Snappy),
187 "none" | "NONE" => Ok(Compression::None),
188 _ => Err("Unknown compression".to_string()),
189 }
190 }
191}
192
193impl<'a> From<&'a str> for Compression {
194 fn from(compression_str: &'a str) -> Compression {
197 match compression_str {
198 LZ4 => Compression::Lz4,
199 SNAPPY => Compression::Snappy,
200 _ => Compression::None,
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compression_to_protocol_string() {
        let lz4 = Compression::Lz4;
        assert_eq!("LZ4", lz4.to_protocol_string());

        let snappy = Compression::Snappy;
        assert_eq!("SNAPPY", snappy.to_protocol_string());

        let none = Compression::None;
        assert_eq!("NONE", none.to_protocol_string());
    }

    #[test]
    fn test_compression_from_protocol_str() {
        let lz4 = "lz4";
        assert_eq!(
            Compression::from_protocol_string(lz4).unwrap(),
            Compression::Lz4
        );

        let lz4 = "LZ4";
        assert_eq!(
            Compression::from_protocol_string(lz4).unwrap(),
            Compression::Lz4
        );

        let snappy = "snappy";
        assert_eq!(
            Compression::from_protocol_string(snappy).unwrap(),
            Compression::Snappy
        );

        let snappy = "SNAPPY";
        assert_eq!(
            Compression::from_protocol_string(snappy).unwrap(),
            Compression::Snappy
        );

        let none = "none";
        assert_eq!(
            Compression::from_protocol_string(none).unwrap(),
            Compression::None
        );

        let none = "NONE";
        assert_eq!(
            Compression::from_protocol_string(none).unwrap(),
            Compression::None
        );
    }

    #[test]
    fn test_compression_from_protocol_str_rejects_unknown() {
        assert!(Compression::from_protocol_string("zstd").is_err());
        assert!(Compression::from_protocol_string("Lz4").is_err());
    }

    #[test]
    fn test_compression_from_string() {
        let lz4 = "lz4".to_string();
        assert_eq!(Compression::from(lz4), Compression::Lz4);
        let snappy = "snappy".to_string();
        assert_eq!(Compression::from(snappy), Compression::Snappy);
        let none = "x".to_string();
        assert_eq!(Compression::from(none), Compression::None);
    }

    #[test]
    fn test_compression_as_str() {
        assert_eq!(Compression::Lz4.as_str(), Some(LZ4));
        assert_eq!(Compression::Snappy.as_str(), Some(SNAPPY));
        assert_eq!(Compression::None.as_str(), None);
    }

    #[test]
    fn test_compression_is_compressed() {
        assert!(Compression::Lz4.is_compressed());
        assert!(Compression::Snappy.is_compressed());
        assert!(!Compression::None.is_compressed());
    }

    #[test]
    fn test_compression_encode_snappy() {
        let snappy_compression = Compression::Snappy;
        let bytes = String::from("Hello World").into_bytes();
        snappy_compression
            .encode(&bytes)
            .expect("Should work without exceptions");
    }

    #[test]
    fn test_compression_decode_snappy() {
        let snappy_compression = Compression::Snappy;
        let bytes = String::from("Hello World").into_bytes();
        let encoded = snappy_compression.encode(&bytes).unwrap();
        assert_eq!(snappy_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_snappy_round_trip_repetitive() {
        let snappy_compression = Compression::Snappy;
        let bytes = "abc".repeat(1000).into_bytes();
        let encoded = snappy_compression.encode(&bytes).unwrap();
        assert_eq!(snappy_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_encode_lz4() {
        let lz4_compression = Compression::Lz4;
        let bytes = String::from("Hello World").into_bytes();
        lz4_compression
            .encode(&bytes)
            .expect("Should work without exceptions");
    }

    #[test]
    fn test_compression_decode_lz4() {
        let lz4_compression = Compression::Lz4;
        let bytes = String::from("Hello World").into_bytes();
        let encoded = lz4_compression.encode(&bytes).unwrap();
        assert_eq!(lz4_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_lz4_round_trip_repetitive() {
        let lz4_compression = Compression::Lz4;
        let bytes = "abc".repeat(1000).into_bytes();
        let encoded = lz4_compression.encode(&bytes).unwrap();
        // The big-endian length prefix records the uncompressed size.
        assert_eq!(&encoded[..4], &(bytes.len() as i32).to_be_bytes());
        assert_eq!(lz4_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_encode_none() {
        let none_compression = Compression::None;
        let bytes = String::from("Hello World").into_bytes();
        none_compression
            .encode(&bytes)
            .expect("Should work without exceptions");
    }

    #[test]
    fn test_compression_decode_none() {
        let none_compression = Compression::None;
        let bytes = String::from("Hello World").into_bytes();
        let encoded = none_compression.encode(&bytes).unwrap();
        assert_eq!(none_compression.decode(encoded).unwrap(), bytes);
    }

    #[test]
    fn test_compression_decode_lz4_with_invalid_input() {
        let lz4_compression = Compression::Lz4;
        let decode = lz4_compression.decode(vec![0, 0, 0, 0x7f]);
        assert!(decode.is_err());
    }

    #[test]
    fn test_compression_encode_snappy_with_non_utf8() {
        let snappy_compression = Compression::Snappy;
        let v = vec![0xff, 0xff];
        let encoded = snappy_compression
            .encode(&v)
            .expect("Should work without exceptions");
        assert_eq!(snappy_compression.decode(encoded).unwrap(), v);
    }
}