use std::borrow::Cow;
use std::fmt::Display;
use std::io::{Cursor, Read, Write, copy};
use std::str::FromStr;
use std::time::Instant;
use lz4_flex::frame::{FrameDecoder, FrameEncoder};
use super::byte_grouping::BG4Predictor;
use super::byte_grouping::bg4::{bg4_regroup, bg4_split};
use crate::error::{CoreError, Result};
pub static mut BG4_SPLIT_RUNTIME: f64 = 0.;
pub static mut BG4_REGROUP_RUNTIME: f64 = 0.;
pub static mut BG4_LZ4_COMPRESS_RUNTIME: f64 = 0.;
pub static mut BG4_LZ4_DECOMPRESS_RUNTIME: f64 = 0.;
#[repr(u8)]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub enum CompressionScheme {
None = 0,
LZ4 = 1,
ByteGrouping4LZ4 = 2, #[default]
Auto = 99,
}
impl Display for CompressionScheme {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", Into::<&str>::into(self))
}
}
impl From<&CompressionScheme> for &'static str {
fn from(value: &CompressionScheme) -> Self {
match value {
CompressionScheme::Auto => "auto",
CompressionScheme::LZ4 => "lz4",
CompressionScheme::ByteGrouping4LZ4 => "bg4-lz4",
CompressionScheme::None => "none",
}
}
}
impl From<CompressionScheme> for &'static str {
fn from(value: CompressionScheme) -> Self {
From::from(&value)
}
}
impl TryFrom<u8> for CompressionScheme {
type Error = CoreError;
fn try_from(value: u8) -> Result<Self> {
match value {
0 => Ok(CompressionScheme::None),
1 => Ok(CompressionScheme::LZ4),
2 => Ok(CompressionScheme::ByteGrouping4LZ4),
99 => Ok(CompressionScheme::Auto),
_ => Err(CoreError::MalformedData(format!("cannot convert value {value} to CompressionScheme"))),
}
}
}
impl FromStr for CompressionScheme {
type Err = CoreError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.trim().to_lowercase().as_str() {
"" | "auto" => Ok(CompressionScheme::Auto),
"none" => Ok(CompressionScheme::None),
"lz4" => Ok(CompressionScheme::LZ4),
"bg4-lz4" => Ok(CompressionScheme::ByteGrouping4LZ4),
_ => Err(CoreError::MalformedData(format!(
"Invalid compression scheme '{s}'. Valid values are: auto, none, lz4, bg4-lz4."
))),
}
}
}
impl CompressionScheme {
pub fn resolve_for_data(&self, data: &[u8]) -> Self {
if *self == CompressionScheme::Auto {
CompressionScheme::choose_from_data(data)
} else {
*self
}
}
pub fn compress_from_slice<'a>(&self, data: &'a [u8]) -> Result<Cow<'a, [u8]>> {
Ok(match self {
CompressionScheme::Auto => return self.resolve_for_data(data).compress_from_slice(data),
CompressionScheme::None => data.into(),
CompressionScheme::LZ4 => lz4_compress_from_slice(data).map(Cow::from)?,
CompressionScheme::ByteGrouping4LZ4 => bg4_lz4_compress_from_slice(data).map(Cow::from)?,
})
}
pub fn decompress_from_slice<'a>(&self, data: &'a [u8]) -> Result<Cow<'a, [u8]>> {
Ok(match self {
CompressionScheme::Auto => {
return Err(CoreError::MalformedData("Cannot decompress with Auto scheme".to_string()));
},
CompressionScheme::None => data.into(),
CompressionScheme::LZ4 => lz4_decompress_from_slice(data).map(Cow::from)?,
CompressionScheme::ByteGrouping4LZ4 => bg4_lz4_decompress_from_slice(data).map(Cow::from)?,
})
}
pub fn decompress_from_reader<R: Read, W: Write>(&self, reader: &mut R, writer: &mut W) -> Result<u64> {
Ok(match self {
CompressionScheme::Auto => {
return Err(CoreError::MalformedData("Cannot decompress with Auto scheme".to_string()));
},
CompressionScheme::None => copy(reader, writer)?,
CompressionScheme::LZ4 => lz4_decompress_from_reader(reader, writer)?,
CompressionScheme::ByteGrouping4LZ4 => bg4_lz4_decompress_from_reader(reader, writer)?,
})
}
pub fn choose_from_data(data: &[u8]) -> Self {
let mut bg4_predictor = BG4Predictor::default();
bg4_predictor.add_data(0, data);
if bg4_predictor.bg4_recommended() {
CompressionScheme::ByteGrouping4LZ4
} else {
CompressionScheme::LZ4
}
}
}
pub fn lz4_compress_from_slice(data: &[u8]) -> Result<Vec<u8>> {
let mut enc = FrameEncoder::new(Vec::new());
enc.write_all(data)?;
Ok(enc.finish()?)
}
pub fn lz4_decompress_from_slice(data: &[u8]) -> Result<Vec<u8>> {
let mut dest = vec![];
lz4_decompress_from_reader(&mut Cursor::new(data), &mut dest)?;
Ok(dest)
}
fn lz4_decompress_from_reader<R: Read, W: Write>(reader: &mut R, writer: &mut W) -> Result<u64> {
let mut dec = FrameDecoder::new(reader);
Ok(copy(&mut dec, writer)?)
}
pub fn bg4_lz4_compress_from_slice(data: &[u8]) -> Result<Vec<u8>> {
let s = Instant::now();
let groups = bg4_split(data);
unsafe {
BG4_SPLIT_RUNTIME += s.elapsed().as_secs_f64();
}
let s = Instant::now();
let mut dest = vec![];
let mut enc = FrameEncoder::new(&mut dest);
enc.write_all(&groups)?;
enc.finish()?;
unsafe {
BG4_LZ4_COMPRESS_RUNTIME += s.elapsed().as_secs_f64();
}
Ok(dest)
}
pub fn bg4_lz4_decompress_from_slice(data: &[u8]) -> Result<Vec<u8>> {
let mut dest = vec![];
bg4_lz4_decompress_from_reader(&mut Cursor::new(data), &mut dest)?;
Ok(dest)
}
fn bg4_lz4_decompress_from_reader<R: Read, W: Write>(reader: &mut R, writer: &mut W) -> Result<u64> {
let s = Instant::now();
let mut g = vec![];
FrameDecoder::new(reader).read_to_end(&mut g)?;
unsafe {
BG4_LZ4_DECOMPRESS_RUNTIME += s.elapsed().as_secs_f64();
}
let s = Instant::now();
let regrouped = bg4_regroup(&g);
unsafe {
BG4_REGROUP_RUNTIME += s.elapsed().as_secs_f64();
}
writer.write_all(®rouped)?;
Ok(regrouped.len() as u64)
}
#[cfg(test)]
mod tests {
use std::mem::size_of;
use half::prelude::*;
use rand::RngExt;
use super::*;
#[test]
fn test_default_is_auto() {
assert_eq!(CompressionScheme::default(), CompressionScheme::Auto);
}
#[test]
fn test_to_str() {
assert_eq!(Into::<&str>::into(CompressionScheme::Auto), "auto");
assert_eq!(Into::<&str>::into(CompressionScheme::None), "none");
assert_eq!(Into::<&str>::into(CompressionScheme::LZ4), "lz4");
assert_eq!(Into::<&str>::into(CompressionScheme::ByteGrouping4LZ4), "bg4-lz4");
}
#[test]
fn test_from_str() {
assert_eq!("".parse::<CompressionScheme>().unwrap(), CompressionScheme::Auto);
assert_eq!("auto".parse::<CompressionScheme>().unwrap(), CompressionScheme::Auto);
assert_eq!("none".parse::<CompressionScheme>().unwrap(), CompressionScheme::None);
assert_eq!("lz4".parse::<CompressionScheme>().unwrap(), CompressionScheme::LZ4);
assert_eq!("bg4-lz4".parse::<CompressionScheme>().unwrap(), CompressionScheme::ByteGrouping4LZ4);
assert_eq!(" LZ4 ".parse::<CompressionScheme>().unwrap(), CompressionScheme::LZ4);
assert!("zstd".parse::<CompressionScheme>().is_err());
}
#[test]
fn test_display() {
assert_eq!(format!("{}", CompressionScheme::Auto), "auto");
assert_eq!(format!("{}", CompressionScheme::None), "none");
assert_eq!(format!("{}", CompressionScheme::LZ4), "lz4");
assert_eq!(format!("{}", CompressionScheme::ByteGrouping4LZ4), "bg4-lz4");
}
#[test]
fn test_parse_with_config_enum() {
use xet_runtime::utils::ConfigEnum;
let ce = ConfigEnum::new("auto", &["", "auto", "none", "lz4", "bg4-lz4"]);
let scheme: CompressionScheme = ce.parse().unwrap();
assert_eq!(scheme, CompressionScheme::Auto);
let ce = ConfigEnum::new("lz4", &["", "auto", "none", "lz4", "bg4-lz4"]);
let scheme: CompressionScheme = ce.parse().unwrap();
assert_eq!(scheme, CompressionScheme::LZ4);
let ce = ConfigEnum::new("none", &["", "auto", "none", "lz4", "bg4-lz4"]);
let scheme: CompressionScheme = ce.parse().unwrap();
assert_eq!(scheme, CompressionScheme::None);
}
#[test]
fn test_from_u8() {
assert_eq!(CompressionScheme::try_from(0u8), Ok(CompressionScheme::None));
assert_eq!(CompressionScheme::try_from(1u8), Ok(CompressionScheme::LZ4));
assert_eq!(CompressionScheme::try_from(2u8), Ok(CompressionScheme::ByteGrouping4LZ4));
assert_eq!(CompressionScheme::try_from(99u8), Ok(CompressionScheme::Auto));
assert!(CompressionScheme::try_from(3u8).is_err());
assert!(CompressionScheme::try_from(4u8).is_err());
}
#[test]
fn test_resolve_for_data() {
let data = vec![0u8; 1024];
let resolved = CompressionScheme::Auto.resolve_for_data(&data);
assert_ne!(resolved, CompressionScheme::Auto);
assert_eq!(CompressionScheme::LZ4.resolve_for_data(&data), CompressionScheme::LZ4);
assert_eq!(CompressionScheme::None.resolve_for_data(&data), CompressionScheme::None);
}
#[test]
fn test_bg4_lz4() {
let mut rng = rand::rng();
for i in 0..4 {
let n = 64 * 1024 + i * 23;
let all_zeros = vec![0u8; n];
let all_ones = vec![1u8; n];
let all_0xff = vec![0xFF; n];
let random_u8s: Vec<_> = (0..n).map(|_| rng.random_range(0..255)).collect();
let random_f32s_ng1_1: Vec<_> = (0..n / size_of::<f32>())
.map(|_| rng.random_range(-1.0f32..=1.0))
.flat_map(|f| f.to_le_bytes())
.collect();
let random_f32s_0_2: Vec<_> = (0..n / size_of::<f32>())
.map(|_| rng.random_range(0f32..=2.0))
.flat_map(|f| f.to_le_bytes())
.collect();
let random_f64s_ng1_1: Vec<_> = (0..n / size_of::<f64>())
.map(|_| rng.random_range(-1.0f64..=1.0))
.flat_map(|f| f.to_le_bytes())
.collect();
let random_f64s_0_2: Vec<_> = (0..n / size_of::<f64>())
.map(|_| rng.random_range(0f64..=2.0))
.flat_map(|f| f.to_le_bytes())
.collect();
let random_f16s_ng1_1: Vec<_> = (0..n / size_of::<f16>())
.map(|_| f16::from_f32(rng.random_range(-1.0f32..=1.0)))
.flat_map(|f| f.to_le_bytes())
.collect();
let random_f16s_0_2: Vec<_> = (0..n / size_of::<f16>())
.map(|_| f16::from_f32(rng.random_range(0f32..=2.0)))
.flat_map(|f| f.to_le_bytes())
.collect();
let random_bf16s_ng1_1: Vec<_> = (0..n / size_of::<bf16>())
.map(|_| bf16::from_f32(rng.random_range(-1.0f32..=1.0)))
.flat_map(|f| f.to_le_bytes())
.collect();
let random_bf16s_0_2: Vec<_> = (0..n / size_of::<bf16>())
.map(|_| bf16::from_f32(rng.random_range(0f32..=2.0)))
.flat_map(|f| f.to_le_bytes())
.collect();
let dataset = [
all_zeros, all_ones, all_0xff, random_u8s, random_f32s_ng1_1, random_f32s_0_2, random_f64s_ng1_1, random_f64s_0_2, random_f16s_ng1_1, random_f16s_0_2, random_bf16s_ng1_1, random_bf16s_0_2, ];
for data in dataset {
let bg4_lz4_compressed = bg4_lz4_compress_from_slice(&data).unwrap();
let bg4_lz4_uncompressed = bg4_lz4_decompress_from_slice(&bg4_lz4_compressed).unwrap();
assert_eq!(data.len(), bg4_lz4_uncompressed.len());
assert_eq!(data, bg4_lz4_uncompressed);
let lz4_compressed = lz4_compress_from_slice(&data).unwrap();
let lz4_uncompressed = lz4_decompress_from_slice(&lz4_compressed).unwrap();
assert_eq!(data, lz4_uncompressed);
let compression_scheme_predictor = CompressionScheme::choose_from_data(&data);
println!(
"Compression ratio: {:.2}, {:.2} (KL predicted = {compression_scheme_predictor:?}",
data.len() as f32 / bg4_lz4_compressed.len() as f32,
data.len() as f32 / lz4_compressed.len() as f32
);
}
}
}
}