use crate::error::{Error, Result};
use crate::object::{Dictionary, Object};
/// Runs `data` through every filter named in `filters`, in order, and returns
/// the fully decoded bytes.
///
/// `decode_parms` supplies the optional per-filter parameters: a single
/// dictionary is used for the first filter only, while an array is indexed by
/// the filter's position in `filters`.
///
/// With an empty `filters` slice the input is returned as an owned copy.
///
/// # Errors
///
/// Returns the first error produced by any individual filter; later filters
/// are not attempted.
pub fn apply_chain(
    data: &[u8],
    filters: &[String],
    decode_parms: Option<&Object>,
) -> Result<Vec<u8>> {
    let mut current = data.to_vec();
    for (i, name) in filters.iter().enumerate() {
        let params = parms_for_filter(decode_parms, i);
        // Each stage consumes the previous stage's output.
        current = apply_one(&current, name, params.as_ref())?;
    }
    Ok(current)
}
/// Picks the parameter dictionary that belongs to the filter at position `idx`.
///
/// * An array is indexed by `idx`; the element must itself be a dictionary.
/// * A bare dictionary applies to the first filter (`idx == 0`) only.
/// * Anything else, including a missing `decode_parms`, yields `None`.
///
/// The selected dictionary is returned as a clone.
fn parms_for_filter(decode_parms: Option<&Object>, idx: usize) -> Option<Dictionary> {
    let parms = decode_parms?;

    // Array form: one entry per filter.
    if let Object::Array(entries) = parms {
        return match entries.get(idx)? {
            Object::Dictionary(dict) => Some(dict.clone()),
            _ => None,
        };
    }

    // Single-dictionary form: only meaningful for the first filter.
    match parms {
        Object::Dictionary(dict) if idx == 0 => Some(dict.clone()),
        _ => None,
    }
}
fn apply_one(data: &[u8], name: &str, parms: Option<&Dictionary>) -> Result<Vec<u8>> {
match name {
"FlateDecode" | "Fl" => {
let decoded = flate_decode(data)?;
apply_png_predictor(&decoded, parms)
}
"LZWDecode" | "LZW" => {
let decoded = lzw_decode(data)?;
apply_png_predictor(&decoded, parms)
}
"ASCIIHexDecode" | "AHx" => ascii_hex_decode(data),
"ASCII85Decode" | "A85" => ascii_85_decode(data),
"DCTDecode" | "DCT" | "JPXDecode" | "JBIG2Decode" | "CCITTFaxDecode" | "RunLengthDecode"
| "RL" => Err(Error::UnsupportedFilter(name.into())),
other => Err(Error::UnsupportedFilter(other.into())),
}
}
fn flate_decode(data: &[u8]) -> Result<Vec<u8>> {
use flate2::read::ZlibDecoder;
use std::io::Read;
let mut decoder = ZlibDecoder::new(data);
let mut out = Vec::with_capacity(data.len() * 2);
decoder
.read_to_end(&mut out)
.map_err(|e| Error::Decompression(e.to_string()))?;
Ok(out)
}
/// Decodes an LZW stream with MSB-first variable-width codes (9 to 12 bits).
///
/// Code 256 resets the table and code 257 ends the data. The code width is
/// increased one entry early, i.e. as soon as the table holds
/// `2^width - 1` entries.
///
/// Decoding stops successfully at the end-of-data code or when the input runs
/// out of bits. The table is capped at 4096 entries: once it is full, further
/// codes are still decoded but no new entries are added until a clear code
/// arrives, so a stream that never clears cannot grow the table without bound.
///
/// # Errors
///
/// Returns `Error::Decompression` when a code outside the table is seen
/// before any sequence has been decoded since the last clear.
fn lzw_decode(data: &[u8]) -> Result<Vec<u8>> {
    const CLEAR: usize = 256;
    const EOD: usize = 257;
    const FIRST_FREE: usize = 258;
    const MIN_CODE_SIZE: u32 = 9;
    const MAX_CODE_SIZE: u32 = 12;
    const MAX_ENTRIES: usize = 1 << MAX_CODE_SIZE;

    // Entries 0..=255 are the single-byte literals; 256 and 257 are
    // placeholders so that table indices line up with code values.
    let mut dict: Vec<Vec<u8>> = Vec::with_capacity(MAX_ENTRIES);
    dict.extend((0..=u8::MAX).map(|byte| vec![byte]));
    dict.push(Vec::new());
    dict.push(Vec::new());

    let mut code_size = MIN_CODE_SIZE;
    let mut out: Vec<u8> = Vec::with_capacity(data.len().saturating_mul(3));
    let mut prev: Option<Vec<u8>> = None;
    let mut bit_buf: u32 = 0;
    let mut bit_count: u32 = 0;
    let mut bytes = data.iter();

    loop {
        // Refill the bit buffer until one whole code is available.
        while bit_count < code_size {
            match bytes.next() {
                Some(&b) => {
                    bit_buf = (bit_buf << 8) | u32::from(b);
                    bit_count += 8;
                }
                // Input exhausted without an explicit end-of-data code.
                None => return Ok(out),
            }
        }
        bit_count -= code_size;
        let code = ((bit_buf >> bit_count) & ((1u32 << code_size) - 1)) as usize;

        if code == EOD {
            return Ok(out);
        }
        if code == CLEAR {
            code_size = MIN_CODE_SIZE;
            dict.truncate(FIRST_FREE);
            prev = None;
            continue;
        }

        let entry: Vec<u8> = match (dict.get(code), prev.as_ref()) {
            (Some(known), _) => known.clone(),
            // Code not in the table yet: it must be the entry that is about to
            // be created, i.e. the previous sequence plus its own first byte.
            (None, Some(p)) => {
                let mut e = p.clone();
                e.push(p[0]);
                e
            }
            (None, None) => {
                return Err(Error::Decompression(format!("invalid LZW code {code}")));
            }
        };
        out.extend_from_slice(&entry);

        if let Some(mut grown) = prev.take() {
            // A full table stays frozen until the next clear code; 12-bit
            // codes cannot address anything beyond it anyway.
            if dict.len() < MAX_ENTRIES {
                grown.push(entry[0]);
                dict.push(grown);
                if dict.len() + 1 == (1usize << code_size) && code_size < MAX_CODE_SIZE {
                    code_size += 1;
                }
            }
        }
        prev = Some(entry);
    }
}
pub(crate) fn apply_png_predictor(data: &[u8], parms: Option<&Dictionary>) -> Result<Vec<u8>> {
let parms = match parms {
Some(p) => p,
None => return Ok(data.to_vec()),
};
let predictor = parms
.get_optional(b"Predictor")
.and_then(|o| o.as_i64().ok())
.unwrap_or(1);
if predictor < 10 {
return Ok(data.to_vec());
}
let columns = parms
.get_optional(b"Columns")
.and_then(|o| o.as_i64().ok())
.unwrap_or(1) as usize;
if columns == 0 {
return Ok(data.to_vec());
}
let stride = columns + 1;
let mut out: Vec<u8> = Vec::with_capacity(data.len());
let mut prev_row: Vec<u8> = vec![0; columns];
for row in data.chunks(stride) {
if row.len() < 2 {
break;
}
let filter = row[0];
let data_row = &row[1..];
let mut decoded_row: Vec<u8> = Vec::with_capacity(data_row.len());
match filter {
0 => decoded_row.extend_from_slice(data_row),
1 => {
for (i, &b) in data_row.iter().enumerate() {
let left = if i == 0 { 0 } else { decoded_row[i - 1] };
decoded_row.push(b.wrapping_add(left));
}
}
2 => {
for (i, &b) in data_row.iter().enumerate() {
decoded_row.push(b.wrapping_add(*prev_row.get(i).unwrap_or(&0)));
}
}
3 => {
for (i, &b) in data_row.iter().enumerate() {
let left = if i == 0 { 0u16 } else { decoded_row[i - 1] as u16 };
let up = *prev_row.get(i).unwrap_or(&0) as u16;
decoded_row.push(b.wrapping_add(((left + up) / 2) as u8));
}
}
4 => {
for (i, &b) in data_row.iter().enumerate() {
let left = if i == 0 { 0i16 } else { decoded_row[i - 1] as i16 };
let up = *prev_row.get(i).unwrap_or(&0) as i16;
let up_left = if i == 0 {
0i16
} else {
*prev_row.get(i - 1).unwrap_or(&0) as i16
};
let p = left + up - up_left;
let pa = (p - left).abs();
let pb = (p - up).abs();
let pc = (p - up_left).abs();
let pick = if pa <= pb && pa <= pc {
left
} else if pb <= pc {
up
} else {
up_left
};
decoded_row.push(b.wrapping_add(pick as u8));
}
}
_ => decoded_row.extend_from_slice(data_row),
}
out.extend_from_slice(&decoded_row);
prev_row = decoded_row;
}
Ok(out)
}
/// Decodes ASCII-hex text into bytes.
///
/// Pairs of hexadecimal digits (either case) form one output byte. Decoding
/// stops at the first `>`; anything after it is ignored. White-space between
/// digits is skipped, including NUL, which PDF also classes as white-space.
/// If the data ends with an unpaired digit, that digit is treated as the high
/// nybble of a final byte whose low nybble is zero.
///
/// # Errors
///
/// Returns `Error::Decompression` for any byte before the terminator that is
/// neither white-space nor a hexadecimal digit.
fn ascii_hex_decode(data: &[u8]) -> Result<Vec<u8>> {
    let mut out = Vec::with_capacity(data.len() / 2);
    // High nybble waiting for its partner, if any.
    let mut pending: Option<u8> = None;

    for &b in data {
        if b == b'>' {
            break;
        }
        if b == 0 || b.is_ascii_whitespace() {
            continue;
        }
        let v = hex_value(b).ok_or_else(|| {
            Error::Decompression(format!("invalid ASCIIHex byte: 0x{b:02x}"))
        })?;
        match pending.take() {
            Some(high) => out.push((high << 4) | v),
            None => pending = Some(v),
        }
    }

    // Odd digit count: behave as if a trailing `0` had been supplied.
    if let Some(high) = pending {
        out.push(high << 4);
    }
    Ok(out)
}
/// Maps one ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`) to its numeric
/// value in `0..=15`; every other byte yields `None`.
#[inline]
fn hex_value(b: u8) -> Option<u8> {
    // With radix 16 the result is at most 15, so narrowing to `u8` is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
/// Decodes ASCII base-85 text into bytes.
///
/// * A leading `<~` is skipped when present.
/// * Decoding stops at the first `~` (the start of the `~>` end marker). `>`
///   is NOT a terminator: it is the ordinary base-85 digit with value 29.
/// * White-space, including NUL, is ignored.
/// * `z` at a group boundary expands to four zero bytes.
/// * A final group of `n` digits (`2 <= n <= 4`) yields `n - 1` bytes; a lone
///   trailing digit yields none.
///
/// Group values are accumulated with wrapping arithmetic, so a five-digit
/// group above `2^32 - 1` is not reported as an error.
///
/// # Errors
///
/// Returns `Error::Decompression` for any byte outside `!`..=`u` that is not
/// white-space, the terminator, or a `z` at a group boundary.
fn ascii_85_decode(data: &[u8]) -> Result<Vec<u8>> {
    let body = data.strip_prefix(b"<~").unwrap_or(data);

    let mut out = Vec::with_capacity(data.len() * 4 / 5);
    let mut accum: u32 = 0;
    // Number of digits collected for the current group (0..=4 between groups).
    let mut count: usize = 0;

    for &b in body {
        if b == b'~' {
            break;
        }
        if b == 0 || b.is_ascii_whitespace() {
            continue;
        }
        if b == b'z' && count == 0 {
            out.extend_from_slice(&[0; 4]);
            continue;
        }
        if !(b'!'..=b'u').contains(&b) {
            return Err(Error::Decompression(format!(
                "invalid ASCII85 byte: 0x{b:02x}"
            )));
        }
        accum = accum.wrapping_mul(85).wrapping_add(u32::from(b - b'!'));
        count += 1;
        if count == 5 {
            out.extend_from_slice(&accum.to_be_bytes());
            accum = 0;
            count = 0;
        }
    }

    if count > 0 {
        // Pad the short group with the highest digit (`u` = 84) so that the
        // bytes which are kept decode to their original values.
        for _ in count..5 {
            accum = accum.wrapping_mul(85).wrapping_add(84);
        }
        out.extend_from_slice(&accum.to_be_bytes()[..count - 1]);
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex digits terminated by `>` decode to the bytes they spell out.
    #[test]
    fn ascii_hex_round_trip() {
        let decoded = ascii_hex_decode(b"68656c6c6f>").expect("decode");
        assert_eq!(decoded, b"hello");
    }

    /// A small zlib stream inflates to the text "hello".
    #[test]
    fn flate_decode_simple() {
        const ZLIB_HELLO: [u8; 13] = [
            0x78, 0x9c, 0xcb, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00, 0x06, 0x2c, 0x02, 0x15,
        ];
        let decoded = flate_decode(&ZLIB_HELLO).expect("decode");
        assert_eq!(decoded, b"hello");
    }
}