use lz4_flex::frame::{FrameDecoder, FrameEncoder};
use lzma_rs::{lzma_compress, lzma_decompress};
use serde_derive::{Serialize, Deserialize};
use std::fmt::Display;
use std::io::{Cursor, Read, Seek, Write};
use crate::binary::{ReadBytes, WriteBytes};
use crate::error::{RLibError, Result};
#[cfg(test)]
mod test;
/// Values recognised as an LZMA header when they appear as the second `u32` of the input
/// (the `u32` that follows the 4-byte uncompressed-size prefix) in `decompress`.
///
/// Every entry has `0x5D` as its lowest byte; the remaining bytes differ per entry.
/// NOTE(review): presumably `0x5D` is the LZMA properties byte and the rest is the start of the
/// dictionary size — confirm against the LZMA header specification.
const MAGIC_NUMBERS_LZMA: [u32; 9] = [
0x0100005D,
0x1000005D,
0x0800005D,
0x1000005D, // NOTE(review): same value as the second entry — confirm whether another value was intended.
0x2000005D,
0x4000005D,
0x8000005D,
0x0000005D,
0x0400005D,
];

/// Value recognised as an LZ4 header when it appears as the second `u32` of the input in `decompress`.
const MAGIC_NUMBER_LZ4: u32 = 0x184D2204;

/// Value recognised as a Zstandard header when it appears as the second `u32` of the input in `decompress`.
const MAGIC_NUMBER_ZSTD: u32 = 0xfd2fb528;
/// Implemented by types whose contents can be compressed into a byte vector.
pub trait Compressible {
/// Compresses `self` using the requested `format` and returns the resulting bytes.
///
/// # Errors
///
/// Returns an error if the data cannot be compressed with the requested format.
fn compress(&self, format: CompressionFormat) -> Result<Vec<u8>>;
}
/// Implemented by types whose contents can be decompressed into a byte vector.
pub trait Decompressible {
/// Decompresses `self` and returns the resulting bytes.
///
/// No format is passed in: implementors have to work out the format from the data itself.
///
/// # Errors
///
/// Returns an error if the data cannot be decompressed.
fn decompress(&self) -> Result<Vec<u8>>;
}
/// Compression formats understood by the `Compressible` and `Decompressible` implementations
/// for byte slices.
#[derive(Debug, Copy, Clone, Default, PartialEq, Serialize, Deserialize)]
pub enum CompressionFormat {
/// No compression: the data is passed through unchanged. This is the default variant.
#[default]None,
/// LZMA, encoded/decoded through the `lzma_rs` crate.
Lzma1,
/// LZ4, encoded/decoded through the `lz4_flex` frame encoder/decoder.
Lz4,
/// Zstandard, encoded/decoded through the `zstd` crate.
Zstd,
}
/// Compression of raw byte slices into any of the supported [`CompressionFormat`]s.
impl Compressible for [u8] {

    /// Compresses `self` with the requested `format`.
    ///
    /// Except for [`CompressionFormat::None`] (which returns a plain copy of the input), the
    /// output starts with a 4-byte field holding the uncompressed length, followed by the
    /// compressed stream:
    ///
    /// - `Lzma1`: the stream produced by `lzma_compress` with bytes `5..13` of its header
    ///   removed, so only the first 5 header bytes are kept before the payload.
    /// - `Lz4`: an LZ4 frame produced by `FrameEncoder`.
    /// - `Zstd`: a Zstandard frame at level 3, with checksum and content size enabled.
    ///
    /// NOTE: the length prefix is written through a 32-bit cast, so inputs whose length does not
    /// fit in 32 bits get a wrapped value in that field.
    ///
    /// # Errors
    ///
    /// Returns [`RLibError::DataCannotBeCompressed`] if the LZMA encoder produced less than a
    /// full 13-byte header, and propagates any error reported by the underlying encoders or
    /// writers.
    fn compress(&self, format: CompressionFormat) -> Result<Vec<u8>> {
        match format {
            CompressionFormat::None => Ok(self.to_vec()),

            CompressionFormat::Lzma1 => {
                let mut dst = vec![];
                dst.write_i32(self.len() as i32)?;

                // Encoder failures are propagated to the caller instead of panicking.
                let mut compressed_data = vec![];
                let mut src = Cursor::new(self);
                lzma_compress(&mut src, &mut compressed_data)?;

                // We need the complete 13-byte header to be able to cut bytes 5..13 out of it.
                if compressed_data.len() < 13 {
                    return Err(RLibError::DataCannotBeCompressed);
                }

                dst.extend_from_slice(&compressed_data[..5]);
                dst.extend_from_slice(&compressed_data[13..]);
                Ok(dst)
            },

            CompressionFormat::Lz4 => {
                let mut dst = vec![];
                dst.write_u32(self.len() as u32)?;

                let mut encoder = FrameEncoder::new(&mut dst);
                encoder.write_all(self)?;
                encoder.finish()?;
                Ok(dst)
            },

            CompressionFormat::Zstd => {
                let mut dst = vec![];
                dst.write_u32(self.len() as u32)?;

                let mut encoder = zstd::Encoder::new(&mut dst, 3)?;
                encoder.include_checksum(true)?;
                encoder.include_contentsize(true)?;
                encoder.set_pledged_src_size(Some(self.len() as u64))?;

                // Read straight from the borrowed slice: there is no need to clone the whole
                // input just to feed it to the encoder.
                let mut src = Cursor::new(self);
                std::io::copy(&mut src, &mut encoder)?;
                encoder.finish()?;
                Ok(dst)
            },
        }
    }
}
/// Decompression of byte slices, with the compression format detected from the data itself.
impl Decompressible for &[u8] {

    /// Detects how `self` was compressed and returns the decompressed bytes.
    ///
    /// Empty input yields an empty vector. Otherwise the first `u32` is read as the uncompressed
    /// size and the second `u32` is compared against the known magic numbers to pick the format.
    /// When no magic number matches, the input is returned as an unchanged copy.
    ///
    /// For LZMA, a full stream is rebuilt by inserting an 8-byte size field after the first 5
    /// bytes that follow the size prefix. The uncompressed size read from the prefix is tried
    /// first; if decoding fails, the stream is rebuilt with `u64::MAX` in that field and decoded
    /// again.
    ///
    /// # Errors
    ///
    /// Returns an error if the two leading `u32` values cannot be read,
    /// [`RLibError::DataCannotBeDecompressed`] if LZMA data is shorter than 9 bytes, and
    /// propagates any error reported by the underlying decoders.
    fn decompress(&self) -> Result<Vec<u8>> {
        if self.is_empty() {
            return Ok(vec![]);
        }

        let mut src = Cursor::new(self);
        let u_size = src.read_u32()?;

        let format = match src.read_u32()? {
            MAGIC_NUMBER_ZSTD => CompressionFormat::Zstd,
            MAGIC_NUMBER_LZ4 => CompressionFormat::Lz4,
            magic if MAGIC_NUMBERS_LZMA.contains(&magic) => CompressionFormat::Lzma1,
            _ => CompressionFormat::None,
        };

        // Step back over the magic number, as it's part of the compressed stream itself.
        src.seek_relative(-4)?;

        match format {
            CompressionFormat::None => Ok(self.to_vec()),

            CompressionFormat::Lzma1 => {
                if self.len() < 9 {
                    return Err(RLibError::DataCannotBeDecompressed);
                }

                // Rebuilds the stream from the byte right after the size prefix, placing
                // `unpacked_size` as an 8-byte field after the first 5 bytes.
                let mut rebuild_stream = |unpacked_size: u64| -> Result<Vec<u8>> {
                    src.set_position(4);

                    let mut stream: Vec<u8> = Vec::with_capacity(self.len() + 4);
                    stream.extend_from_slice(&src.read_slice(5, false)?);
                    stream.write_u64(unpacked_size)?;
                    src.read_to_end(&mut stream)?;
                    Ok(stream)
                };

                // First attempt: trust the uncompressed size stored in the prefix.
                let sized_stream = rebuild_stream(u64::from(u_size))?;
                let mut dst: Vec<u8> = Vec::with_capacity(u_size as usize);
                if lzma_decompress(&mut Cursor::new(sized_stream), &mut dst).is_ok() {
                    return Ok(dst);
                }

                // Second attempt: same stream, but with `u64::MAX` in the size field.
                let unsized_stream = rebuild_stream(u64::MAX)?;
                let mut dst: Vec<u8> = Vec::with_capacity(u_size as usize);
                lzma_decompress(&mut Cursor::new(unsized_stream), &mut dst)?;
                Ok(dst)
            },

            CompressionFormat::Lz4 => {
                let mut decoder = FrameDecoder::new(src);
                let mut dst = Vec::with_capacity(u_size as usize);
                std::io::copy(&mut decoder, &mut dst)?;
                Ok(dst)
            },

            CompressionFormat::Zstd => {
                let mut dst = Vec::with_capacity(u_size as usize);
                zstd::stream::copy_decode(src, &mut dst)?;
                Ok(dst)
            },
        }
    }
}
impl From<&str> for CompressionFormat {
fn from(value: &str) -> Self {
match value {
"Lzma1" => Self::Lzma1,
"Lz4" => Self::Lz4,
"Zstd" => Self::Zstd,
_ => Self::None,
}
}
}
/// Textual name of each [`CompressionFormat`].
impl Display for CompressionFormat {

    /// Writes the variant's name (`"None"`, `"Lzma1"`, `"Lz4"` or `"Zstd"`) to the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::None => "None",
            Self::Lzma1 => "Lzma1",
            Self::Lz4 => "Lz4",
            Self::Zstd => "Zstd",
        };

        f.write_str(label)
    }
}