use crate::error::IoError;
/// Compression codec applied to a byte buffer.
///
/// Variants other than [`Compression::None`] exist only when the matching
/// Cargo feature (`snappy` / `zstd`) is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// No compression: the bytes are copied through unchanged.
    None,
    /// Snappy compression (requires the `snappy` feature).
    #[cfg(feature = "snappy")]
    Snappy,
    /// Zstandard compression (requires the `zstd` feature).
    #[cfg(feature = "zstd")]
    Zstd,
}
/// Compresses `input` using the selected `codec`.
///
/// With [`Compression::None`] an owned copy of `input` is returned as-is.
///
/// # Errors
///
/// Propagates the `IoError::Compression` produced by the selected codec's
/// encoder when compression fails.
pub fn compress(input: &[u8], codec: Compression) -> Result<Vec<u8>, IoError> {
    match codec {
        #[cfg(feature = "zstd")]
        Compression::Zstd => zstd_compress(input),
        #[cfg(feature = "snappy")]
        Compression::Snappy => snappy_compress(input),
        Compression::None => Ok(Vec::from(input)),
    }
}
/// Compresses `input` into the raw (unframed) Snappy block format from `snap::raw`.
///
/// # Errors
///
/// Returns `IoError::Compression` carrying the encoder's error if compression fails.
#[cfg(feature = "snappy")]
fn snappy_compress(input: &[u8]) -> Result<Vec<u8>, IoError> {
    use snap::raw::Encoder;
    // `compress_vec` allocates a `max_compress_len`-sized buffer, compresses into
    // it and truncates to the written length: the same steps this function used
    // to perform by hand.
    Encoder::new()
        .compress_vec(input)
        .map_err(|e| IoError::Compression(format!("Snappy compression failed: {:?}", e)))
}
/// Compresses `input` into a Zstandard stream at a fixed compression level of 1.
///
/// # Errors
///
/// Returns `IoError::Compression` wrapping the encoder's I/O error on failure.
#[cfg(feature = "zstd")]
fn zstd_compress(input: &[u8]) -> Result<Vec<u8>, IoError> {
    use zstd::stream::encode_all;

    // Fixed compression level passed to the zstd encoder.
    const LEVEL: i32 = 1;

    encode_all(input, LEVEL)
        .map_err(|err| IoError::Compression(format!("Zstd compression failed: {err}")))
}
/// Decompresses `input`, which must have been produced with `codec`.
///
/// With [`Compression::None`] an owned copy of `input` is returned as-is.
///
/// # Errors
///
/// Propagates the `IoError::Compression` produced by the selected codec's
/// decoder when decompression fails.
pub fn decompress(input: &[u8], codec: Compression) -> Result<Vec<u8>, IoError> {
    match codec {
        #[cfg(feature = "zstd")]
        Compression::Zstd => zstd_decompress(input),
        #[cfg(feature = "snappy")]
        Compression::Snappy => snappy_decompress(input),
        Compression::None => Ok(Vec::from(input)),
    }
}
/// Decompresses a buffer in the raw (unframed) Snappy block format from `snap::raw`.
///
/// # Errors
///
/// Returns `IoError::Compression` carrying the decoder's error if `input` cannot be decoded.
#[cfg(feature = "snappy")]
fn snappy_decompress(input: &[u8]) -> Result<Vec<u8>, IoError> {
    snap::raw::Decoder::new()
        .decompress_vec(input)
        .map_err(|err| IoError::Compression(format!("Snappy decompression failed: {:?}", err)))
}
/// Decodes a complete Zstandard stream and returns all decompressed bytes in memory.
///
/// NOTE(review): the decompressed size is not bounded here; if `input` can come
/// from an untrusted source, consider enforcing an output limit.
///
/// # Errors
///
/// Returns `IoError::Compression` wrapping the decoder's I/O error on failure.
#[cfg(feature = "zstd")]
fn zstd_decompress(input: &[u8]) -> Result<Vec<u8>, IoError> {
    use zstd::stream::decode_all;

    decode_all(input).map_err(|err| {
        let message = format!("Zstd decompression failed: {err}");
        IoError::Compression(message)
    })
}
/// Returns the Parquet codec name string for `codec`
/// (`"UNCOMPRESSED"`, `"SNAPPY"` or `"ZSTD"`).
pub fn parquet_codec_name(codec: Compression) -> &'static str {
    match codec {
        #[cfg(feature = "zstd")]
        Compression::Zstd => "ZSTD",
        #[cfg(feature = "snappy")]
        Compression::Snappy => "SNAPPY",
        Compression::None => "UNCOMPRESSED",
    }
}