use std::io::{self, Write};
use crate::{decompress::DecodeLimit, error::Error};
/// An in-memory [`Write`] sink that accepts at most `limit` bytes.
///
/// When a write would push the buffer past `limit`, the part of that write
/// that still fits is kept, `overflowed` is set, and an error is returned.
struct LimitedWriter {
    /// Bytes accepted so far; `write` never grows this beyond `limit`.
    buf: Vec<u8>,
    /// Maximum number of bytes `buf` may hold.
    limit: usize,
    /// Becomes `true` once any write has been cut short by `limit`.
    overflowed: bool,
}

impl Write for LimitedWriter {
    /// Appends `data` when it fits in the remaining budget.
    ///
    /// On overflow the prefix that fits is stored before the error is
    /// returned. This intentionally departs from the usual `io::Write`
    /// expectation that an `Err` means nothing was written, so that the
    /// caller can still use the bytes accepted before the limit was hit.
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        let remaining = self.limit.saturating_sub(self.buf.len());

        if data.len() <= remaining {
            self.buf.extend_from_slice(data);
            return Ok(data.len());
        }

        // Here `remaining < data.len()`, so this slice is always in bounds.
        if let Some(prefix) = data.get(..remaining) {
            self.buf.extend_from_slice(prefix);
        }
        self.overflowed = true;
        Err(io::Error::other("decompressed output exceeds limit"))
    }

    /// Nothing is held outside `buf`, so flushing has nothing to do.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Decompresses a raw LZMA stream made of the 5 LZMA property bytes followed
/// directly by the compressed data.
///
/// The classic `.lzma` header also carries an 8-byte little-endian
/// uncompressed size after the properties. This function synthesises that
/// field from `limit` before handing the stream to `lzma_rs`:
///
/// * [`DecodeLimit::Exact`]`(size)`: `size` is written into the header as the
///   expected length, and the result must be exactly `size` bytes long.
/// * [`DecodeLimit::Capped`] / [`DecodeLimit::Truncate`]: the size field is set
///   to all `0xFF` bytes (unknown size). Output is bounded by `limit.size()`.
///   `Capped` fails once that bound is exceeded. `Truncate` returns the bytes
///   decoded up to it.
///
/// An `lzma_rs` error whose message contains "more bytes are available" is
/// tolerated when some output was already produced. `lzma_rs` presumably
/// raises this for trailing input after an end-of-stream marker (TODO confirm
/// against the pinned `lzma_rs` version).
///
/// # Errors
///
/// * [`Error::DecompressionFailed`] if `compressed` is shorter than 5 bytes,
///   if `lzma_rs` rejects the stream, or if an `Exact` decode produces a
///   different number of bytes than requested.
/// * [`Error::OutputTooLarge`] if a `Capped` decode would exceed its budget.
pub fn decompress_lzma(compressed: &[u8], limit: DecodeLimit) -> Result<Vec<u8>, Error> {
    if compressed.len() < 5 {
        return Err(Error::DecompressionFailed {
            method: "lzma",
            detail: "LZMA stream too short (need at least 5 bytes for header)".into(),
        });
    }

    // `expected_len` is `Some` only for `Exact`. It is re-checked after
    // decoding because the tolerated trailing-data error below could otherwise
    // let a stream that ends early through as a short "success".
    let (uncompressed_size_bytes, expected_len): ([u8; 8], Option<u64>) = match limit {
        DecodeLimit::Exact(size) => {
            let size = size as u64;
            (size.to_le_bytes(), Some(size))
        }
        // An all-ones size field marks the uncompressed size as unknown.
        DecodeLimit::Capped(_) | DecodeLimit::Truncate(_) => ([0xFF; 8], None),
    };

    // Rebuild a full `.lzma` stream: properties, size field, compressed body.
    let (props, body) = compressed.split_at(5);
    let mut stream = Vec::with_capacity(compressed.len().saturating_add(8));
    stream.extend_from_slice(props);
    stream.extend_from_slice(&uncompressed_size_bytes);
    stream.extend_from_slice(body);

    let max_output = limit.size();
    // Guess a ~4x expansion for the initial allocation, but never reserve more
    // than the limit, so a generous limit does not force a huge allocation.
    let capacity = max_output.min(compressed.len().saturating_mul(4));
    let mut writer = LimitedWriter {
        buf: Vec::with_capacity(capacity),
        limit: max_output,
        overflowed: false,
    };

    // `&[u8]` already implements `BufRead`; no extra buffering layer is needed.
    let mut reader = stream.as_slice();
    if let Err(e) = lzma_rs::lzma_decompress(&mut reader, &mut writer) {
        if writer.overflowed {
            // The writer refused bytes past `max_output`; `e` only reflects that.
            match limit {
                DecodeLimit::Capped(n) => return Err(Error::OutputTooLarge { limit: n }),
                // Keep the prefix that fit; any `Exact` mismatch is caught below.
                DecodeLimit::Truncate(_) | DecodeLimit::Exact(_) => {}
            }
        } else {
            let detail = e.to_string();
            let trailing_after_eos =
                !writer.buf.is_empty() && detail.contains("more bytes are available");
            if !trailing_after_eos {
                return Err(Error::DecompressionFailed {
                    method: "lzma",
                    detail,
                });
            }
        }
    }

    if let Some(expected) = expected_len {
        let actual = writer.buf.len() as u64;
        if actual != expected {
            return Err(Error::DecompressionFailed {
                method: "lzma",
                detail: format!("expected {expected} decompressed bytes, got {actual}"),
            });
        }
    }

    Ok(writer.buf)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Input too short to hold the 5 LZMA property bytes is rejected up front
    /// with a decompression error, not some other failure.
    #[test]
    fn too_short_input() {
        let result = decompress_lzma(&[0x5D, 0x00, 0x00], DecodeLimit::Capped(1024));
        assert!(
            matches!(result, Err(Error::DecompressionFailed { method: "lzma", .. })),
            "short input should be reported as an LZMA decompression failure"
        );
    }

    /// EOS-marker-terminated LZMA stream whose payload is a 1430-byte .ico file.
    const NSIS_EOS_STREAM: &[u8] = include_bytes!("../../tests/fixtures/lzma_eos_marker_file.bin");

    #[test]
    fn eos_marker_stream_decompresses_when_capped() {
        let out = decompress_lzma(NSIS_EOS_STREAM, DecodeLimit::Capped(64 * 1024 * 1024))
            .expect("EOS-terminated stream should decode with unknown size");
        assert_eq!(out.len(), 1430, "decompressed size should match the icon");
        assert_eq!(
            out.get(..4),
            Some(&[0x00, 0x00, 0x01, 0x00][..]),
            "should be a valid .ico header"
        );
    }

    /// Supplying the true decompressed length as `Exact` yields the full payload.
    #[test]
    fn exact_size_matching_actual_decodes() {
        let out = decompress_lzma(NSIS_EOS_STREAM, DecodeLimit::Exact(1430))
            .expect("an exact size equal to the real size should decode");
        assert_eq!(out.len(), 1430);
    }

    #[test]
    fn exact_size_larger_than_actual_rejects_eos_marker() {
        let result = decompress_lzma(NSIS_EOS_STREAM, DecodeLimit::Exact(64 * 1024 * 1024));
        assert!(
            result.is_err(),
            "an over-large exact size must not silently succeed on an EOS-terminated stream"
        );
    }

    #[test]
    fn capped_rejects_when_budget_below_actual() {
        let result = decompress_lzma(NSIS_EOS_STREAM, DecodeLimit::Capped(512));
        assert!(matches!(result, Err(Error::OutputTooLarge { limit: 512 })));
    }

    #[test]
    fn truncate_caps_without_error() {
        let out = decompress_lzma(NSIS_EOS_STREAM, DecodeLimit::Truncate(512)).unwrap();
        assert_eq!(out.len(), 512);
    }
}