use crate::error::{Error, Result};
use crate::parser_config::ParserOptions;
mod ascii85;
mod ascii_hex;
mod brotli;
mod ccitt;
mod dct;
mod flate;
mod jbig2;
mod lzw;
mod predictor;
mod runlength;
pub use ascii85::Ascii85Decoder;
pub use ascii_hex::AsciiHexDecoder;
pub use brotli::BrotliDecoder;
pub use ccitt::CcittFaxDecoder;
pub use dct::DctDecoder;
pub use flate::FlateDecoder;
pub use jbig2::Jbig2Decoder;
pub use lzw::LzwDecoder;
pub use predictor::{decode_predictor, CcittParams, DecodeParams, PngPredictor};
pub use runlength::RunLengthDecoder;
/// Fallback limit on the decompressed-to-compressed size ratio (100:1).
///
/// `decode_stream_with_options` uses this when it is called without
/// `ParserOptions`.
const DEFAULT_MAX_DECOMPRESSION_RATIO: u32 = 100;
/// Fallback limit on decoded output size in bytes (100 * 1024 * 1024).
///
/// `decode_stream_with_options` uses this when it is called without
/// `ParserOptions`.
const DEFAULT_MAX_DECOMPRESSED_SIZE: usize = 100 * 1024 * 1024;
/// Identifiers for the stream filters this module has decoders for.
///
/// Each variant is spelled exactly like one of the canonical filter names
/// returned by `normalize_filter_name`.
///
/// NOTE(review): nothing in this part of the file constructs or matches on
/// `Filter`; the decode functions dispatch on filter-name strings instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Filter {
/// Counterpart of the `"FlateDecode"` filter name.
FlateDecode,
/// Counterpart of the `"ASCIIHexDecode"` filter name.
ASCIIHexDecode,
/// Counterpart of the `"ASCII85Decode"` filter name.
ASCII85Decode,
/// Counterpart of the `"LZWDecode"` filter name.
LZWDecode,
/// Counterpart of the `"RunLengthDecode"` filter name.
RunLengthDecode,
/// Counterpart of the `"DCTDecode"` filter name.
DCTDecode,
/// Counterpart of the `"CCITTFaxDecode"` filter name.
CCITTFaxDecode,
/// Counterpart of the `"JBIG2Decode"` filter name.
JBIG2Decode,
/// Counterpart of the `"BrotliDecode"` filter name.
BrotliDecode,
}
/// Common interface of the per-filter decoders.
///
/// `create_decoder` returns implementations boxed as `dyn StreamDecoder`, and
/// the `decode_stream*` functions feed the output of one decoder into the next.
pub trait StreamDecoder {
/// Decodes `input` and returns the decoded bytes, or an error if decoding fails.
fn decode(&self, input: &[u8]) -> Result<Vec<u8>>;
/// Returns a name for this decoder.
///
/// NOTE(review): presumably the filter name the decoder handles; no
/// implementation is visible here, so confirm against the decoder modules.
fn name(&self) -> &str;
}
fn normalize_filter_name(name: &str) -> Result<&'static str> {
match name {
"FlateDecode" => return Ok("FlateDecode"),
"ASCIIHexDecode" => return Ok("ASCIIHexDecode"),
"ASCII85Decode" => return Ok("ASCII85Decode"),
"LZWDecode" => return Ok("LZWDecode"),
"RunLengthDecode" => return Ok("RunLengthDecode"),
"DCTDecode" => return Ok("DCTDecode"),
"CCITTFaxDecode" => return Ok("CCITTFaxDecode"),
"JBIG2Decode" => return Ok("JBIG2Decode"),
"BrotliDecode" => return Ok("BrotliDecode"),
_ => {},
}
match name {
"Fl" => return Ok("FlateDecode"),
"AHx" => return Ok("ASCIIHexDecode"),
"A85" => return Ok("ASCII85Decode"),
"LZW" => return Ok("LZWDecode"),
"RL" => return Ok("RunLengthDecode"),
"DCT" => return Ok("DCTDecode"),
"CCF" => return Ok("CCITTFaxDecode"),
_ => {},
}
let lower = name.to_ascii_lowercase();
match lower.as_str() {
"flatedecode" => Ok("FlateDecode"),
"asciihexdecode" => Ok("ASCIIHexDecode"),
"ascii85decode" => Ok("ASCII85Decode"),
"lzwdecode" => Ok("LZWDecode"),
"runlengthdecode" => Ok("RunLengthDecode"),
"dctdecode" => Ok("DCTDecode"),
"ccittfaxdecode" => Ok("CCITTFaxDecode"),
"jbig2decode" => Ok("JBIG2Decode"),
"brotlidecode" => Ok("BrotliDecode"),
_ => Err(Error::UnsupportedFilter(name.to_string())),
}
}
/// Builds the decoder responsible for `filter_name`.
///
/// The name is first canonicalised with `normalize_filter_name`, so
/// abbreviations and differently-cased full names select the same decoder as
/// the canonical spelling.
///
/// # Errors
///
/// Propagates the error from `normalize_filter_name` when the name is not
/// recognised.
fn create_decoder(filter_name: &str) -> Result<Box<dyn StreamDecoder>> {
    let decoder: Box<dyn StreamDecoder> = match normalize_filter_name(filter_name)? {
        "FlateDecode" => Box::new(FlateDecoder::default()),
        "ASCIIHexDecode" => Box::new(AsciiHexDecoder),
        "ASCII85Decode" => Box::new(Ascii85Decoder),
        "LZWDecode" => Box::new(LzwDecoder),
        "RunLengthDecode" => Box::new(RunLengthDecoder),
        "DCTDecode" => Box::new(DctDecoder),
        "CCITTFaxDecode" => Box::new(CcittFaxDecoder),
        "JBIG2Decode" => Box::new(Jbig2Decoder),
        "BrotliDecode" => Box::new(BrotliDecoder),
        // `normalize_filter_name` only ever yields one of the nine names above.
        _ => unreachable!(),
    };
    Ok(decoder)
}
/// Decodes `data` by applying each filter in `filters`, in order.
///
/// Convenience wrapper that calls `decode_stream_with_params` with no decode
/// parameters. With an empty `filters` slice the input bytes are returned
/// unchanged.
///
/// # Errors
///
/// Returns whatever error `decode_stream_with_params` reports, for example
/// `Error::UnsupportedFilter` for a filter name that is not recognised.
pub fn decode_stream(data: &[u8], filters: &[String]) -> Result<Vec<u8>> {
decode_stream_with_params(data, filters, None)
}
/// Decodes `data` through `filters`, enforcing decompression-bomb limits after
/// every filter stage.
///
/// # Parameters
///
/// * `data` - the encoded stream bytes.
/// * `filters` - filter names, applied in order; each is resolved by
///   `create_decoder`.
/// * `params` - optional decode parameters. When present and
///   `params.predictor != 1`, `decode_predictor` is applied to the output of
///   the last filter.
/// * `options` - optional limits. When `None`,
///   `DEFAULT_MAX_DECOMPRESSION_RATIO` and `DEFAULT_MAX_DECOMPRESSED_SIZE` are
///   used. A limit of `0` disables the corresponding check.
///
/// # Limits
///
/// After each stage the intermediate buffer is compared with the length of the
/// original `data` (not with the previous stage's output). The ratio check is
/// skipped when `data` is empty. The output of the predictor step is not
/// checked against the limits.
///
/// # Errors
///
/// * Any error from `create_decoder`, from a decoder's `decode`, or from
///   `decode_predictor` is propagated.
/// * `Error::Decode` is returned when a stage's output exceeds the ratio or
///   size limit.
pub fn decode_stream_with_options(
    data: &[u8],
    filters: &[String],
    params: Option<&DecodeParams>,
    options: Option<&ParserOptions>,
) -> Result<Vec<u8>> {
    let max_ratio = options.map_or(DEFAULT_MAX_DECOMPRESSION_RATIO, |o| {
        o.max_decompression_ratio
    });
    let max_size = options.map_or(DEFAULT_MAX_DECOMPRESSED_SIZE, |o| {
        o.max_decompressed_size
    });
    let compressed_size = data.len();

    let mut current = data.to_vec();
    for filter_name in filters {
        let decoder = create_decoder(filter_name)?;
        current = decoder.decode(&current)?;
        // Check after every stage so a bomb is rejected before the next
        // filter gets to expand it further.
        check_decompression_limits(compressed_size, current.len(), max_ratio, max_size)?;
    }

    // A predictor value of 1 means there is nothing to undo.
    if let Some(params) = params.filter(|p| p.predictor != 1) {
        current = decode_predictor(&current, params)?;
    }
    Ok(current)
}

/// Returns `Error::Decode` when `decompressed_size` violates either limit.
///
/// The ratio limit is checked first, then the absolute size limit. A limit of
/// `0` disables its check, and the ratio check is also skipped when
/// `compressed_size` is `0`.
fn check_decompression_limits(
    compressed_size: usize,
    decompressed_size: usize,
    max_ratio: u32,
    max_size: usize,
) -> Result<()> {
    if max_ratio > 0 && compressed_size > 0 {
        // Integer division: the ratio is truncated, so e.g. 100.9:1 is
        // treated as 100:1 and still passes a limit of 100.
        let ratio = decompressed_size as u64 / compressed_size as u64;
        if ratio > u64::from(max_ratio) {
            return Err(Error::Decode(format!(
                "Decompression bomb detected: ratio {}:1 exceeds limit {}:1 (compressed: {} bytes, decompressed: {} bytes)",
                ratio, max_ratio, compressed_size, decompressed_size
            )));
        }
    }

    if max_size > 0 && decompressed_size > max_size {
        return Err(Error::Decode(format!(
            "Decompression bomb detected: decompressed size {} bytes exceeds limit {} bytes",
            decompressed_size, max_size
        )));
    }

    Ok(())
}
/// Decodes `data` by applying each filter in `filters`, in order, then the
/// optional predictor step.
///
/// Unlike `decode_stream_with_options`, this function enforces no
/// decompression ratio or size limits.
///
/// # Parameters
///
/// * `data` - the encoded stream bytes.
/// * `filters` - filter names, applied in order; each is resolved by
///   `create_decoder`. With an empty slice the input is returned unchanged
///   (before the predictor step).
/// * `params` - optional decode parameters. When present and
///   `params.predictor != 1`, `decode_predictor` is applied to the output of
///   the last filter.
///
/// # Errors
///
/// Propagates any error from `create_decoder`, from a decoder's `decode`, or
/// from `decode_predictor`.
pub fn decode_stream_with_params(
    data: &[u8],
    filters: &[String],
    params: Option<&DecodeParams>,
) -> Result<Vec<u8>> {
    let mut current = data.to_vec();
    for filter_name in filters {
        let decoder = create_decoder(filter_name)?;
        // Each stage consumes the previous stage's output.
        current = decoder.decode(&current)?;
    }

    // A predictor value of 1 means there is nothing to undo.
    if let Some(params) = params.filter(|p| p.predictor != 1) {
        current = decode_predictor(&current, params)?;
    }
    Ok(current)
}
// Unit tests for filter-name normalisation and the `decode_stream` pipeline.
#[cfg(test)]
mod tests {
use super::*;
// With no filters the input bytes come back unchanged.
#[test]
fn test_decode_stream_no_filters() {
let data = b"Hello, World!";
let result = decode_stream(data, &[]).unwrap();
assert_eq!(result, data);
}
// An unrecognised filter name surfaces as `UnsupportedFilter` carrying that name.
#[test]
fn test_decode_stream_unsupported_filter() {
let data = b"test";
let filters = vec!["UnsupportedFilter".to_string()];
let result = decode_stream(data, &filters);
assert!(result.is_err());
match result {
Err(crate::error::Error::UnsupportedFilter(name)) => {
assert_eq!(name, "UnsupportedFilter");
},
_ => panic!("Expected UnsupportedFilter error"),
}
}
// Single-filter pipeline: "48656C6C6F" is the hex spelling of "Hello".
#[test]
fn test_decode_stream_pipeline() {
let data = b"48656C6C6F"; let filters = vec!["ASCIIHexDecode".to_string()];
let result = decode_stream(data, &filters).unwrap();
assert_eq!(result, b"Hello");
}
// Every supported abbreviation resolves to its canonical filter name.
#[test]
fn test_normalize_filter_abbreviations() {
assert_eq!(normalize_filter_name("A85").unwrap(), "ASCII85Decode");
assert_eq!(normalize_filter_name("AHx").unwrap(), "ASCIIHexDecode");
assert_eq!(normalize_filter_name("LZW").unwrap(), "LZWDecode");
assert_eq!(normalize_filter_name("Fl").unwrap(), "FlateDecode");
assert_eq!(normalize_filter_name("RL").unwrap(), "RunLengthDecode");
assert_eq!(normalize_filter_name("CCF").unwrap(), "CCITTFaxDecode");
assert_eq!(normalize_filter_name("DCT").unwrap(), "DCTDecode");
}
// Full filter names are accepted regardless of ASCII case.
#[test]
fn test_normalize_filter_case_insensitive() {
assert_eq!(normalize_filter_name("Flatedecode").unwrap(), "FlateDecode");
assert_eq!(normalize_filter_name("FLATEDECODE").unwrap(), "FlateDecode");
assert_eq!(normalize_filter_name("flatedecode").unwrap(), "FlateDecode");
assert_eq!(normalize_filter_name("ascii85decode").unwrap(), "ASCII85Decode");
assert_eq!(normalize_filter_name("ASCIIHEXDECODE").unwrap(), "ASCIIHexDecode");
}
// Unknown names are rejected, and the error keeps the name as given.
#[test]
fn test_normalize_filter_unknown() {
let result = normalize_filter_name("BogusFilter");
assert!(result.is_err());
match result {
Err(crate::error::Error::UnsupportedFilter(name)) => {
assert_eq!(name, "BogusFilter");
},
_ => panic!("Expected UnsupportedFilter error"),
}
}
// The pipeline accepts an abbreviated filter name ("AHx") as well.
#[test]
fn test_decode_stream_with_abbreviation() {
let data = b"48656C6C6F"; let filters = vec!["AHx".to_string()];
let result = decode_stream(data, &filters).unwrap();
assert_eq!(result, b"Hello");
}
}