use anyhow::{bail, Result};
use crate::cpio;
/// Container format of one segment, identified by its leading magic bytes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Compression {
    /// Uncompressed cpio archive (ASCII `070701` magic).
    None,
    /// gzip stream (`1f 8b` magic).
    Gzip,
    /// bzip2 stream (`BZh` magic).
    Bzip2,
    /// Zstandard frame(s) (`28 b5 2f fd` magic).
    Zstd,
}
impl std::fmt::Display for Compression {
    /// Writes the canonical lowercase name of the format; each name is one
    /// that `FromStr` accepts back.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Compression::None => "none",
            Compression::Gzip => "gzip",
            Compression::Bzip2 => "bzip2",
            Compression::Zstd => "zstd",
        };
        f.write_str(name)
    }
}
impl std::str::FromStr for Compression {
    type Err = anyhow::Error;

    /// Parses a compression name case-insensitively.
    ///
    /// Accepts the canonical names (`none`, `gzip`, `bzip2`, `zstd`) and the
    /// short aliases `gz`, `bz2` and `zst`.
    ///
    /// # Errors
    /// Returns an error naming the valid choices for any other input.
    fn from_str(s: &str) -> Result<Self> {
        let lowered = s.to_lowercase();
        let parsed = match lowered.as_str() {
            "none" => Compression::None,
            "gzip" | "gz" => Compression::Gzip,
            "bzip2" | "bz2" => Compression::Bzip2,
            "zstd" | "zst" => Compression::Zstd,
            _ => bail!("unknown compression: {s} (expected: none, gzip, bzip2, zstd)"),
        };
        Ok(parsed)
    }
}
/// One contiguous piece of the input, as produced by [`split_segments`].
#[derive(Debug)]
pub struct RawSegment {
    /// Byte offset in the original input at which this segment begins.
    pub offset: usize,
    /// Copy of the segment's bytes exactly as they appear in the input
    /// (still compressed unless `compression` is [`Compression::None`]).
    pub data: Vec<u8>,
    /// Format detected from the segment's leading magic bytes.
    pub compression: Compression,
}
/// Identifies a segment's format from the magic bytes at the start of `data`.
///
/// Returns `None` when no known magic matches, including when `data` is
/// shorter than every magic.
pub fn detect_compression(data: &[u8]) -> Option<Compression> {
    // Checked in order; the first matching prefix wins.
    const MAGICS: [(&[u8], Compression); 4] = [
        (b"070701", Compression::None),
        (&[0x1f, 0x8b], Compression::Gzip),
        (b"BZh", Compression::Bzip2),
        (&[0x28, 0xb5, 0x2f, 0xfd], Compression::Zstd),
    ];
    MAGICS
        .iter()
        .find(|(magic, _)| data.starts_with(magic))
        .map(|&(_, kind)| kind)
}
fn compressed_size(data: &[u8], comp: Compression) -> Result<usize> {
match comp {
Compression::None => Ok(data.len()),
Compression::Gzip => gzip_compressed_size(data),
Compression::Bzip2 => bzip2_compressed_size(data),
Compression::Zstd => zstd_compressed_size(data),
}
}
fn gzip_compressed_size(data: &[u8]) -> Result<usize> {
anyhow::ensure!(data.len() >= 10, "gzip data too short");
anyhow::ensure!(data[0] == 0x1f && data[1] == 0x8b, "not gzip");
let flg = data[3];
let mut pos: usize = 10;
if flg & 4 != 0 {
anyhow::ensure!(pos + 2 <= data.len(), "truncated gzip FEXTRA");
let xlen = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
pos += 2 + xlen;
}
if flg & 8 != 0 {
while pos < data.len() && data[pos] != 0 {
pos += 1;
}
pos += 1; }
if flg & 16 != 0 {
while pos < data.len() && data[pos] != 0 {
pos += 1;
}
pos += 1;
}
if flg & 2 != 0 {
pos += 2;
}
let mut decomp = flate2::Decompress::new(false);
let mut out_buf = vec![0u8; 64 * 1024];
loop {
let in_before = decomp.total_in() as usize;
let _ = in_before; let status = decomp.decompress(
&data[pos + decomp.total_in() as usize..],
&mut out_buf,
flate2::FlushDecompress::None,
)?;
if status == flate2::Status::StreamEnd {
break;
}
}
let total = pos + decomp.total_in() as usize + 8; Ok(total)
}
fn bzip2_compressed_size(data: &[u8]) -> Result<usize> {
use std::io::Read;
let mut decoder = bzip2::read::BzDecoder::new(data);
let mut out = Vec::new();
decoder.read_to_end(&mut out)?;
let mut decomp = bzip2::Decompress::new(false);
let mut out_buf = vec![0u8; 64 * 1024];
loop {
let status = decomp.decompress(
&data[decomp.total_in() as usize..],
&mut out_buf,
)?;
if status == bzip2::Status::MemNeeded && decomp.total_out() > 0 {
continue;
}
if (status == bzip2::Status::StreamEnd || status == bzip2::Status::Ok)
&& decomp.total_out() > 0 {
let remaining = &data[decomp.total_in() as usize..];
if remaining.is_empty() || remaining[0] == 0 || detect_compression(remaining).is_some() {
break;
}
}
if status == bzip2::Status::StreamEnd {
break;
}
}
Ok(decomp.total_in() as usize)
}
/// Returns the number of bytes taken up by the run of back-to-back Zstandard
/// frames at the start of `data`.
///
/// Frames are measured one after another for as long as the next bytes carry
/// the standard frame magic. The first non-matching byte, or the end of
/// `data`, ends the run.
///
/// # Errors
/// Propagates the zstd error code if a frame's compressed size cannot be
/// determined.
fn zstd_compressed_size(data: &[u8]) -> Result<usize> {
    const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
    let mut consumed = 0;
    while let Some(rest) = data
        .get(consumed..)
        .filter(|rest| rest.starts_with(&ZSTD_MAGIC))
    {
        consumed += zstd::zstd_safe::find_frame_compressed_size(rest)
            .map_err(|code| anyhow::anyhow!("zstd frame error: {code}"))?;
    }
    Ok(consumed)
}
/// Splits a concatenation of cpio archives and compressed streams into its
/// segments, in input order.
///
/// Runs of zero bytes between segments are skipped as padding. Each segment's
/// format is identified from its leading magic bytes by [`detect_compression`].
/// The segment's extent is found as follows:
/// - An uncompressed cpio segment ends where `cpio::scan_archive_end` reports,
///   or at the end of the input if no end is found.
/// - A compressed segment is measured by decoding it.
///
/// # Errors
/// Fails in any of these cases:
/// - a non-padding byte does not start with a known magic;
/// - a compressed stream is malformed;
/// - a measured segment would be empty or would extend past the end of `data`.
pub fn split_segments(data: &[u8]) -> Result<Vec<RawSegment>> {
    let mut segments = Vec::new();
    let mut pos = 0;
    while pos < data.len() {
        // Skip zero padding between segments; stop if only padding remains.
        match data[pos..].iter().position(|&b| b != 0) {
            Some(skip) => pos += skip,
            None => break,
        }
        let comp = detect_compression(&data[pos..])
            .ok_or_else(|| anyhow::anyhow!("unknown format at offset {pos}"))?;
        let end = match comp {
            Compression::None => cpio::scan_archive_end(&data[pos..])
                .map(|len| pos + len)
                .unwrap_or(data.len()),
            Compression::Gzip | Compression::Bzip2 | Compression::Zstd => {
                pos + compressed_size(&data[pos..], comp)?
            }
        };
        // Guard the slice below against panics, and the loop against making
        // no progress.
        anyhow::ensure!(
            end > pos && end <= data.len(),
            "{comp} segment at offset {pos} has invalid end {end} (input length {})",
            data.len()
        );
        segments.push(RawSegment {
            offset: pos,
            data: data[pos..end].to_vec(),
            compression: comp,
        });
        pos = end;
    }
    Ok(segments)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detect_cpio() {
        assert_eq!(detect_compression(b"070701"), Some(Compression::None));
    }

    #[test]
    fn detect_gzip() {
        assert_eq!(
            detect_compression(&[0x1f, 0x8b, 0x08]),
            Some(Compression::Gzip)
        );
    }

    #[test]
    fn detect_bzip2() {
        assert_eq!(detect_compression(b"BZh9"), Some(Compression::Bzip2));
    }

    #[test]
    fn detect_zstd() {
        assert_eq!(
            detect_compression(&[0x28, 0xb5, 0x2f, 0xfd]),
            Some(Compression::Zstd)
        );
    }

    /// Inputs shorter than a magic, or with no known magic, are not detected.
    #[test]
    fn detect_unknown_or_short_input() {
        assert_eq!(detect_compression(b""), None);
        assert_eq!(detect_compression(b"0707"), None);
        assert_eq!(detect_compression(&[0x1f]), None);
        assert_eq!(detect_compression(b"BZ"), None);
        assert_eq!(detect_compression(&[0x28, 0xb5, 0x2f]), None);
        assert_eq!(detect_compression(b"\x00\x00\x00\x00"), None);
    }

    #[test]
    fn parse_compression_str() {
        assert_eq!("gzip".parse::<Compression>().unwrap(), Compression::Gzip);
        assert_eq!("bz2".parse::<Compression>().unwrap(), Compression::Bzip2);
        assert_eq!("zstd".parse::<Compression>().unwrap(), Compression::Zstd);
        assert_eq!("none".parse::<Compression>().unwrap(), Compression::None);
    }

    /// Aliases are accepted regardless of case; anything else is rejected.
    #[test]
    fn parse_compression_aliases_and_errors() {
        assert_eq!("GZ".parse::<Compression>().unwrap(), Compression::Gzip);
        assert_eq!("Bzip2".parse::<Compression>().unwrap(), Compression::Bzip2);
        assert_eq!("ZST".parse::<Compression>().unwrap(), Compression::Zstd);
        assert!("xz".parse::<Compression>().is_err());
        assert!("".parse::<Compression>().is_err());
    }

    /// Every `Display` name parses back to the same variant.
    #[test]
    fn display_round_trips_through_from_str() {
        for comp in [
            Compression::None,
            Compression::Gzip,
            Compression::Bzip2,
            Compression::Zstd,
        ] {
            assert_eq!(comp.to_string().parse::<Compression>().unwrap(), comp);
        }
    }

    #[test]
    fn split_empty_and_padding_only() {
        assert!(split_segments(&[]).unwrap().is_empty());
        assert!(split_segments(&[0u8; 16]).unwrap().is_empty());
    }

    #[test]
    fn split_rejects_unknown_format_after_padding() {
        let err = split_segments(b"\0\0XYZ").unwrap_err();
        assert!(err.to_string().contains("offset 2"));
    }
}