use std::io::{Cursor, Read as _};
use flate2::read::ZlibDecoder;
use lzma_rs::decompress::{Options, UnpackedSize};
use crate::{
error::Error,
util::{checksum::crc32, read::Reader},
version::Version,
};
/// Maximum number of payload bytes in one CRC-protected chunk of a block body.
const CHUNK_SIZE: usize = 4 * 1024;
/// Length of the little-endian CRC32 that prefixes every chunk of a block body.
const CHUNK_CRC_LEN: usize = std::mem::size_of::<u32>();
/// Compression method applied to a block body, as selected by the block header.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockCompression {
    /// The body bytes are used as-is.
    Stored,
    /// The body is a zlib stream (compressed blocks from versions before 4.1.6).
    Zlib,
    /// The body is an LZMA1 stream (compressed blocks from 4.1.6 onward).
    Lzma1,
}
#[derive(Debug)]
pub struct DecompressedBlock {
pub bytes: Box<[u8]>,
pub compression: BlockCompression,
pub consumed: usize,
}
/// Decodes the block whose header begins at byte offset `start` of `setup0`.
///
/// The header layout is chosen from `version`. The body is checked against its
/// per-chunk CRC32s and then decompressed according to the header. No decryption
/// is applied.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - reading the header fails;
/// - a header or chunk CRC32 does not match;
/// - a size does not fit the arithmetic;
/// - `setup0` ends before the body does;
/// - decompression fails.
pub fn decompress_block(
    setup0: &[u8],
    start: usize,
    version: &Version,
) -> Result<DecompressedBlock, Error> {
    let no_decryption = None;
    decompress_block_inner(setup0, start, version, no_decryption)
}
/// Like [`decompress_block`], but decrypts the body before decompressing it.
///
/// `crate::crypto::xchacha20::apply_keystream` is called with `key` and `nonce`
/// on the body bytes after the per-chunk CRCs have been verified and stripped,
/// and before decompression.
///
/// # Errors
///
/// Same conditions as [`decompress_block`].
pub(crate) fn decompress_block_with_decryption(
    setup0: &[u8],
    start: usize,
    version: &Version,
    key: &[u8; 32],
    nonce: &[u8; 24],
) -> Result<DecompressedBlock, Error> {
    let keys = Some((key, nonce));
    decompress_block_inner(setup0, start, version, keys)
}
fn decompress_block_inner(
setup0: &[u8],
start: usize,
version: &Version,
decryption: Option<(&[u8; 32], &[u8; 24])>,
) -> Result<DecompressedBlock, Error> {
let mut reader = Reader::at(setup0, start)?;
let (compression, stored_size, header_consumed) = parse_outer_header(&mut reader, version)?;
let stored_usize = usize::try_from(stored_size).map_err(|_| Error::Overflow {
what: "stored_size",
})?;
let compressed_start = reader.pos();
let compressed_end = compressed_start
.checked_add(stored_usize)
.ok_or(Error::Overflow {
what: "compressed end",
})?;
let framed = setup0
.get(compressed_start..compressed_end)
.ok_or(Error::Truncated {
what: "block compressed body",
})?;
let mut raw = unframe_chunks(framed)?;
if let Some((key, nonce)) = decryption {
crate::crypto::xchacha20::apply_keystream(key, nonce, &mut raw);
}
let bytes = match compression {
BlockCompression::Stored => raw.into_boxed_slice(),
BlockCompression::Zlib => decompress_zlib(&raw)?.into_boxed_slice(),
BlockCompression::Lzma1 => decompress_inno_lzma1(&raw)?.into_boxed_slice(),
};
let consumed = header_consumed
.checked_add(stored_usize)
.ok_or(Error::Overflow { what: "block end" })?;
Ok(DecompressedBlock {
bytes,
compression,
consumed,
})
}
fn parse_outer_header(
reader: &mut Reader<'_>,
version: &Version,
) -> Result<(BlockCompression, u32, usize), Error> {
let expected_crc = reader.u32_le("block expected_crc")?;
if version.at_least(6, 7, 0) {
let header9 = reader.array::<9>("block header (>=6.7.0)")?;
let actual_crc = crc32(&header9);
if actual_crc != expected_crc {
return Err(Error::BadChecksum {
what: "block header (>=6.7.0)",
expected: expected_crc,
actual: actual_crc,
});
}
let [s0, s1, s2, s3, s4, s5, s6, s7, flag] = header9;
let stored_size_u64 = u64::from_le_bytes([s0, s1, s2, s3, s4, s5, s6, s7]);
let stored_size = u32::try_from(stored_size_u64).map_err(|_| Error::Overflow {
what: "block stored_size > u32::MAX",
})?;
let compression = if flag == 0 {
BlockCompression::Stored
} else {
BlockCompression::Lzma1
};
Ok((compression, stored_size, 13))
} else if version.at_least(4, 0, 9) {
let header5 = reader.array::<5>("block header (>=4.0.9)")?;
let actual_crc = crc32(&header5);
if actual_crc != expected_crc {
return Err(Error::BadChecksum {
what: "block header (>=4.0.9)",
expected: expected_crc,
actual: actual_crc,
});
}
let [s0, s1, s2, s3, flag] = header5;
let stored_size = u32::from_le_bytes([s0, s1, s2, s3]);
let compression = if flag == 0 {
BlockCompression::Stored
} else if version.at_least(4, 1, 6) {
BlockCompression::Lzma1
} else {
BlockCompression::Zlib
};
Ok((compression, stored_size, 9))
} else {
let header8 = reader.array::<8>("block header (<4.0.9)")?;
let actual_crc = crc32(&header8);
if actual_crc != expected_crc {
return Err(Error::BadChecksum {
what: "block header (<4.0.9)",
expected: expected_crc,
actual: actual_crc,
});
}
let [c0, c1, c2, c3, u0, u1, u2, u3] = header8;
let compressed_size = u32::from_le_bytes([c0, c1, c2, c3]);
let uncompressed_size = u32::from_le_bytes([u0, u1, u2, u3]);
let (mut stored_size, compression) = if compressed_size == u32::MAX {
(uncompressed_size, BlockCompression::Stored)
} else {
(compressed_size, BlockCompression::Zlib)
};
const CHUNK_SIZE_U32: u32 = CHUNK_SIZE as u32;
const CHUNK_SIZE_MINUS_1: u32 = CHUNK_SIZE_U32.wrapping_sub(1);
const CHUNK_CRC_LEN_U32: u32 = CHUNK_CRC_LEN as u32;
let bumped = stored_size
.checked_add(CHUNK_SIZE_MINUS_1)
.ok_or(Error::Overflow {
what: "old block ceil",
})?;
let chunks = bumped.checked_div(CHUNK_SIZE_U32).ok_or(Error::Overflow {
what: "old block ceil-div",
})?;
let crc_overhead = chunks
.checked_mul(CHUNK_CRC_LEN_U32)
.ok_or(Error::Overflow {
what: "old block CRC overhead",
})?;
stored_size = stored_size
.checked_add(crc_overhead)
.ok_or(Error::Overflow {
what: "old stored_size",
})?;
Ok((compression, stored_size, 12))
}
}
fn unframe_chunks(framed: &[u8]) -> Result<Vec<u8>, Error> {
let mut out = Vec::with_capacity(framed.len());
let mut cursor = 0usize;
while cursor < framed.len() {
let crc_end = cursor.checked_add(CHUNK_CRC_LEN).ok_or(Error::Overflow {
what: "chunk CRC end",
})?;
let crc_bytes = framed
.get(cursor..crc_end)
.ok_or(Error::Truncated { what: "chunk CRC" })?;
let mut crc_arr = [0u8; 4];
crc_arr.copy_from_slice(crc_bytes);
let expected = u32::from_le_bytes(crc_arr);
let chunk_start = crc_end;
let remaining = framed.len().saturating_sub(chunk_start);
let chunk_len = remaining.min(CHUNK_SIZE);
let chunk_end = chunk_start
.checked_add(chunk_len)
.ok_or(Error::Overflow { what: "chunk end" })?;
let chunk = framed
.get(chunk_start..chunk_end)
.ok_or(Error::Truncated { what: "chunk body" })?;
let actual = crc32(chunk);
if actual != expected {
return Err(Error::BadChecksum {
what: "block sub-chunk",
expected,
actual,
});
}
out.extend_from_slice(chunk);
cursor = chunk_end;
}
Ok(out)
}
fn decompress_zlib(raw: &[u8]) -> Result<Vec<u8>, Error> {
let mut decoder = ZlibDecoder::new(raw);
let mut out = Vec::new();
decoder
.read_to_end(&mut out)
.map_err(|source| Error::Decompress {
stream: "block (zlib)",
source,
})?;
Ok(out)
}
fn decompress_inno_lzma1(raw: &[u8]) -> Result<Vec<u8>, Error> {
let mut input = Cursor::new(raw);
let mut out = Vec::new();
let opts = Options {
unpacked_size: UnpackedSize::UseProvided(None),
..Options::default()
};
lzma_rs::lzma_decompress_with_options(&mut input, &mut out, &opts).map_err(|e| {
Error::Decompress {
stream: "block (lzma1)",
source: std::io::Error::other(e.to_string()),
}
})?;
Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Prefixes `chunk` with its little-endian CRC32, forming one frame.
    fn frame_of(chunk: &[u8]) -> Vec<u8> {
        let mut framed = super::crc32(chunk).to_le_bytes().to_vec();
        framed.extend_from_slice(chunk);
        framed
    }

    #[test]
    fn unframe_strips_crc_prefixes() {
        let chunk_a: Vec<u8> = (0..CHUNK_SIZE).map(|i| (i & 0xFF) as u8).collect();
        let chunk_b = b"final-tail".to_vec();
        let crc_a = super::crc32(&chunk_a);
        let crc_b = super::crc32(&chunk_b);
        let mut framed = Vec::new();
        framed.extend_from_slice(&crc_a.to_le_bytes());
        framed.extend_from_slice(&chunk_a);
        framed.extend_from_slice(&crc_b.to_le_bytes());
        framed.extend_from_slice(&chunk_b);
        let raw = unframe_chunks(&framed).unwrap();
        assert_eq!(raw.len(), chunk_a.len() + chunk_b.len());
        assert_eq!(&raw[..CHUNK_SIZE], chunk_a.as_slice());
        assert_eq!(&raw[CHUNK_SIZE..], chunk_b.as_slice());
    }

    #[test]
    fn unframe_rejects_bad_crc() {
        let mut framed = vec![0u8; 4];
        framed.extend_from_slice(b"abcd");
        let err = unframe_chunks(&framed).unwrap_err();
        assert!(matches!(err, Error::BadChecksum { .. }));
    }

    #[test]
    fn unframe_empty_input_yields_empty_output() {
        let raw = unframe_chunks(&[]).unwrap();
        assert!(raw.is_empty());
    }

    #[test]
    fn unframe_rejects_truncated_crc() {
        let err = unframe_chunks(&[0u8, 1, 2]).unwrap_err();
        assert!(matches!(err, Error::Truncated { .. }));
    }

    #[test]
    fn unframe_accepts_exact_chunk_multiple() {
        let chunk: Vec<u8> = (0..CHUNK_SIZE).map(|i| (i % 251) as u8).collect();
        let mut framed = frame_of(&chunk);
        framed.extend(frame_of(&chunk));
        let raw = unframe_chunks(&framed).unwrap();
        assert_eq!(raw.len(), 2 * CHUNK_SIZE);
        assert_eq!(&raw[..CHUNK_SIZE], chunk.as_slice());
        assert_eq!(&raw[CHUNK_SIZE..], chunk.as_slice());
    }
}