pub mod bzip2;
pub mod deflate;
pub mod lzma;
use core::fmt;
use crate::error::Error;
/// Compression algorithm used for a block of data.
///
/// Each variant selects the decoder that [`decompress_block`] dispatches to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionMethod {
    /// Decoded by `deflate::decompress_deflate`.
    Deflate,
    /// Decoded by `bzip2::decompress_bzip2`.
    Bzip2,
    /// Decoded by `lzma::decompress_lzma`.
    Lzma,
    /// No compression: the bytes are copied through unchanged.
    None,
}
/// How a header's length prefix relates to the compressed stream.
///
/// Reported by [`decompress_header`] alongside the decoded bytes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionMode {
    /// The whole input was decoded as a single stream; any length prefix sits
    /// inside the decoded output rather than in front of the compressed bytes.
    Solid,
    /// A plain 4-byte length prefix precedes the (possibly compressed) payload.
    NonSolid,
}
impl fmt::Display for CompressionMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
CompressionMethod::Deflate => "deflate",
CompressionMethod::Bzip2 => "bzip2",
CompressionMethod::Lzma => "lzma",
CompressionMethod::None => "none",
};
f.write_str(s)
}
}
impl fmt::Display for CompressionMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
CompressionMode::Solid => "solid",
CompressionMode::NonSolid => "non-solid",
};
f.write_str(s)
}
}
/// Decodes the 4-byte length prefix at the start of `data`.
///
/// The prefix is read as a little-endian `u32` at offset 0. Its top bit is
/// the "compressed" flag and the remaining 31 bits are the size.
///
/// Returns `(is_compressed, size)`.
///
/// # Errors
///
/// Returns [`Error::TooShort`] when `data` holds fewer than four bytes.
pub fn read_length_prefix(data: &[u8]) -> Result<(bool, u32), Error> {
    /// Top bit of the prefix word: set when the payload is compressed.
    const COMPRESSED_FLAG: u32 = 0x8000_0000;

    let available = data.len();
    if available < 4 {
        return Err(Error::TooShort {
            expected: 4,
            actual: available,
            context: "length prefix",
        });
    }

    let word = crate::util::read_u32_le(data, 0);
    let flagged = word & COMPRESSED_FLAG != 0;
    // Everything below the flag bit is the size.
    let size = word & !COMPRESSED_FLAG;
    Ok((flagged, size))
}
/// Guesses the compression method from the first byte and length of `data`.
///
/// * empty input gives [`CompressionMethod::None`]
/// * first byte `0x5D` with at least five bytes gives [`CompressionMethod::Lzma`]
/// * first byte `0x31` with at least four bytes gives [`CompressionMethod::Bzip2`]
/// * anything else gives [`CompressionMethod::Deflate`]
///
/// This is only a heuristic: no further bytes are inspected, so callers
/// should be prepared for the guessed decoder to fail.
pub fn detect_compression(data: &[u8]) -> CompressionMethod {
    match data {
        [] => CompressionMethod::None,
        // 0x5D followed by at least four more bytes.
        [0x5D, _, _, _, _, ..] => CompressionMethod::Lzma,
        // 0x31 followed by at least three more bytes.
        [0x31, _, _, _, ..] => CompressionMethod::Bzip2,
        _ => CompressionMethod::Deflate,
    }
}
/// Output-size policy handed to the decoders, each variant carrying a byte
/// count.
///
/// NOTE(review): the exact meaning of each variant is implemented in the
/// codec submodules, which are not visible here; the per-variant notes below
/// are inferred from the names and should be confirmed against the decoders.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecodeLimit {
    /// Presumably: the output must be exactly this many bytes.
    Exact(usize),
    /// Presumably: the output may not exceed this many bytes.
    Capped(usize),
    /// Presumably: the output is cut off at this many bytes.
    Truncate(usize),
}
impl DecodeLimit {
#[inline]
pub fn size(self) -> usize {
match self {
DecodeLimit::Exact(n) | DecodeLimit::Capped(n) | DecodeLimit::Truncate(n) => n,
}
}
}
/// Decompresses `data` with the decoder selected by `method`.
///
/// `limit` is forwarded unchanged to the chosen decoder. For
/// [`CompressionMethod::None`] the input is copied as-is and `limit` is not
/// consulted.
///
/// # Errors
///
/// Propagates whatever error the selected decoder returns.
pub fn decompress_block(
    data: &[u8],
    method: CompressionMethod,
    limit: DecodeLimit,
) -> Result<Vec<u8>, Error> {
    match method {
        // Stored data: nothing to decode, just hand back an owned copy.
        CompressionMethod::None => Ok(data.to_owned()),
        CompressionMethod::Lzma => lzma::decompress_lzma(data, limit),
        CompressionMethod::Bzip2 => bzip2::decompress_bzip2(data, limit),
        CompressionMethod::Deflate => deflate::decompress_deflate(data, limit),
    }
}
pub fn decompress_header(
data: &[u8],
expected_size: usize,
) -> Result<(Vec<u8>, CompressionMethod, CompressionMode, usize), Error> {
let (is_compressed, size) = read_length_prefix(data)?;
let size_usize = size as usize;
let payload_end = 4_usize.checked_add(size_usize);
if !is_compressed && payload_end.is_some_and(|end| end <= data.len()) {
let bytes = data.get(4..).and_then(|s| s.get(..size_usize));
if let Some(bytes) = bytes {
return Ok((
bytes.to_vec(),
CompressionMethod::None,
CompressionMode::NonSolid,
payload_end.unwrap_or(0),
));
}
}
let compressed_size = size_usize;
let non_solid_consumed = 4_usize.saturating_add(compressed_size);
let non_solid_viable = is_compressed && data.len() >= non_solid_consumed;
let compressed_data: &[u8] = if non_solid_viable {
data.get(4..non_solid_consumed).unwrap_or(&[])
} else {
&[]
};
let method = detect_compression(compressed_data);
if let Ok(decompressed) =
decompress_block(compressed_data, method, DecodeLimit::Exact(expected_size))
&& !decompressed.is_empty()
{
return Ok((
decompressed,
method,
CompressionMode::NonSolid,
non_solid_consumed,
));
}
let methods = [
CompressionMethod::Lzma,
CompressionMethod::Deflate,
CompressionMethod::Bzip2,
];
for &m in &methods {
if m == method {
continue;
}
if let Ok(decompressed) =
decompress_block(compressed_data, m, DecodeLimit::Exact(expected_size))
&& !decompressed.is_empty()
{
return Ok((
decompressed,
m,
CompressionMode::NonSolid,
non_solid_consumed,
));
}
}
let solid_expected = expected_size.saturating_add(4); let solid_method = detect_compression(data);
if let Ok(decompressed) =
decompress_block(data, solid_method, DecodeLimit::Exact(solid_expected))
{
let stripped = strip_solid_prefix(decompressed)?;
return Ok((stripped, solid_method, CompressionMode::Solid, 0));
}
for &m in &methods {
if m == solid_method {
continue;
}
if let Ok(decompressed) = decompress_block(data, m, DecodeLimit::Exact(solid_expected)) {
let stripped = strip_solid_prefix(decompressed)?;
return Ok((stripped, m, CompressionMode::Solid, 0));
}
}
Err(Error::UnsupportedCompression)
}
fn strip_solid_prefix(data: Vec<u8>) -> Result<Vec<u8>, Error> {
if data.len() < 4 {
return Err(Error::TooShort {
expected: 4,
actual: data.len(),
context: "solid stream length prefix",
});
}
let prefix = crate::util::read_u32_le(&data, 0) as usize;
if prefix == data.len().saturating_sub(4) {
Ok(data.get(4..).unwrap_or(&[]).to_vec())
} else {
Ok(data)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- read_length_prefix ---

    #[test]
    fn read_length_prefix_compressed() {
        // Top bit set marks the payload as compressed; low bits carry the size.
        let header = (0x8000_0000u32 | 1000).to_le_bytes();
        assert_eq!(read_length_prefix(&header).unwrap(), (true, 1000));
    }

    #[test]
    fn read_length_prefix_uncompressed() {
        let header = 2048u32.to_le_bytes();
        assert_eq!(read_length_prefix(&header).unwrap(), (false, 2048));
    }

    #[test]
    fn read_length_prefix_too_short() {
        assert!(read_length_prefix(&[0u8; 3]).is_err());
    }

    // --- detect_compression ---

    #[test]
    fn detect_compression_lzma() {
        let sample = [0x5D, 0x00, 0x00, 0x01, 0x00, 0xFF];
        assert_eq!(detect_compression(&sample), CompressionMethod::Lzma);
    }

    #[test]
    fn detect_compression_bzip2() {
        let sample = [0x31, 0x41, 0x59, 0x26];
        assert_eq!(detect_compression(&sample), CompressionMethod::Bzip2);
    }

    #[test]
    fn detect_compression_deflate_fallback() {
        let sample = [0x78, 0x9C, 0x01, 0x02];
        assert_eq!(detect_compression(&sample), CompressionMethod::Deflate);
    }

    #[test]
    fn detect_compression_empty() {
        assert_eq!(detect_compression(&[]), CompressionMethod::None);
    }

    // --- decompress_header ---

    #[test]
    fn decompress_header_uncompressed() {
        let payload: &[u8] = b"hello world test data";

        // Frame: little-endian length prefix (no compressed flag) + raw payload.
        let mut framed = (payload.len() as u32).to_le_bytes().to_vec();
        framed.extend_from_slice(payload);

        let (bytes, method, mode, consumed) =
            decompress_header(&framed, payload.len()).unwrap();

        assert_eq!(bytes.as_slice(), payload);
        assert_eq!(method, CompressionMethod::None);
        assert_eq!(mode, CompressionMode::NonSolid);
        assert_eq!(consumed, 4 + payload.len());
    }
}