use crate::error::{Error, Result};
/// PNG-family predictor selections, keyed by their PDF `Predictor` value.
///
/// The discriminant of each variant is the `Predictor` number (10–15).
/// For the first five, `value - 10` is the PNG per-row filter type
/// (0 = none, 1 = Sub, 2 = Up, 3 = Average, 4 = Paeth).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PngPredictor {
    /// Predictor 10 — PNG filter type 0 (bytes stored unfiltered).
    None = 10,
    /// Predictor 11 — PNG filter type 1 (difference from the left byte).
    Sub = 11,
    /// Predictor 12 — PNG filter type 2 (difference from the byte above).
    Up = 12,
    /// Predictor 13 — PNG filter type 3 (difference from mean of left/above).
    Average = 13,
    /// Predictor 14 — PNG filter type 4 (Paeth predictor).
    Paeth = 14,
    /// Predictor 15 — "optimum": the filter may differ per row.
    Optimum = 15,
}
/// Parameters controlling predictor decoding of a filtered stream.
#[derive(Debug, Clone)]
pub struct DecodeParams {
    /// Predictor algorithm: 1 = none, 2 = TIFF, 10..=15 = PNG.
    pub predictor: i64,
    /// Number of samples (pixels) in each row.
    pub columns: usize,
    /// Number of interleaved colour components per sample.
    pub colors: usize,
    /// Number of bits used by each colour component.
    pub bits_per_component: usize,
}

impl Default for DecodeParams {
    /// No prediction, one column, one colour component, 8 bits per component.
    fn default() -> Self {
        Self {
            predictor: 1,
            columns: 1,
            colors: 1,
            bits_per_component: 8,
        }
    }
}

impl DecodeParams {
    /// Size in bytes of one *encoded* row.
    ///
    /// This is [`Self::pixel_bytes_per_row`], plus one leading filter-type
    /// tag byte when PNG prediction (`predictor >= 10`) is in use.
    pub fn bytes_per_row(&self) -> usize {
        let pixel_bytes = self.pixel_bytes_per_row();
        if self.predictor >= 10 {
            // Each PNG-predicted row is prefixed by its filter tag byte.
            pixel_bytes.saturating_add(1)
        } else {
            pixel_bytes
        }
    }

    /// Size in bytes of one row of sample data, rounded up to whole bytes.
    ///
    /// The product `columns * colors * bits_per_component` is computed with
    /// saturating arithmetic. If these fields come from untrusted input, an
    /// overflowing product would otherwise panic in debug builds or wrap to a
    /// small, meaningless row size in release builds. A saturated result is
    /// larger than any real buffer, so callers' length checks fail cleanly
    /// instead.
    pub fn pixel_bytes_per_row(&self) -> usize {
        self.columns
            .saturating_mul(self.colors)
            .saturating_mul(self.bits_per_component)
            .div_ceil(8)
    }
}
/// Parameters for CCITT fax decoding.
///
/// Field names mirror the `CCITTFaxDecode` parameter keys
/// (`K`, `Columns`, `Rows`, `BlackIs1`, `EndOfLine`, `EncodedByteAlign`,
/// `EndOfBlock`).
#[derive(Debug, Clone, PartialEq)]
pub struct CcittParams {
    /// Coding scheme selector: negative = pure 2-D (Group 4),
    /// zero = pure 1-D (Group 3), positive = mixed 1-D/2-D (Group 3).
    pub k: i64,
    /// Width of the image in pixels.
    pub columns: u32,
    /// Height of the image in rows, if known.
    pub rows: Option<u32>,
    /// When true, 1 bits denote black pixels.
    pub black_is_1: bool,
    /// Whether end-of-line bit patterns are required.
    pub end_of_line: bool,
    /// Whether encoded lines are padded to byte boundaries.
    pub encoded_byte_align: bool,
    /// Whether the data is terminated by an end-of-block pattern.
    pub end_of_block: bool,
}

impl Default for CcittParams {
    // NOTE(review): the PDF specification's defaults are K = 0 and
    // Columns = 1728; these defaults (K = -1, Columns = 1) differ. Confirm
    // callers always set `k` and `columns` explicitly before relying on them.
    fn default() -> Self {
        Self {
            k: -1,
            columns: 1,
            rows: None,
            black_is_1: false,
            end_of_line: false,
            encoded_byte_align: false,
            end_of_block: true,
        }
    }
}

impl CcittParams {
    /// Returns true for pure two-dimensional (Group 4) coding, i.e. any
    /// negative `k`, not just `-1`.
    pub fn is_group_4(&self) -> bool {
        self.k < 0
    }

    /// Returns true for Group 3 coding: pure 1-D (`k == 0`) or mixed
    /// 1-D/2-D (`k > 0`).
    pub fn is_group_3(&self) -> bool {
        self.k >= 0
    }
}
pub fn decode_predictor(data: &[u8], params: &DecodeParams) -> Result<Vec<u8>> {
match params.predictor {
1 => {
Ok(data.to_vec())
},
2 => {
decode_tiff_predictor(data, params)
},
10..=15 => {
decode_png_predictor(data, params)
},
_ => Err(Error::Decode(format!("Unsupported predictor: {}", params.predictor))),
}
}
/// Reverses TIFF Predictor 2 (horizontal differencing).
///
/// Each row holds `columns * colors` samples of `bits_per_component` bits,
/// packed most-significant-bit first and padded to a whole byte. Every
/// sample after a row's first pixel was encoded as its difference from the
/// same colour component of the pixel to its left, modulo
/// `2^bits_per_component`. This function adds those differences back.
///
/// # Errors
/// Returns [`Error::Decode`] if `data` is not a whole number of rows, or if
/// `bits_per_component` is not 1, 2, 4, 8 or 16.
fn decode_tiff_predictor(data: &[u8], params: &DecodeParams) -> Result<Vec<u8>> {
    let bytes_per_row = params.pixel_bytes_per_row();
    let colors = params.colors;
    let bpc = params.bits_per_component;

    // Nothing to undo. This also avoids chunking by a zero row size when
    // `columns` or `colors` is 0.
    if data.is_empty() {
        return Ok(Vec::new());
    }
    // For non-empty data this also rejects `bytes_per_row == 0`, so below
    // every row is non-empty and `colors >= 1`.
    if !data.len().is_multiple_of(bytes_per_row) {
        return Err(Error::Decode(format!(
            "Data length {} is not a multiple of row size {}",
            data.len(),
            bytes_per_row
        )));
    }
    if !matches!(bpc, 1 | 2 | 4 | 8 | 16) {
        return Err(Error::Decode(format!(
            "Unsupported bits per component for TIFF predictor: {}",
            bpc
        )));
    }

    // Real samples per row. This is capped by the row's bit capacity so that
    // inconsistent parameters can never index past the row.
    let samples_per_row = params
        .columns
        .saturating_mul(colors)
        .min(bytes_per_row.saturating_mul(8) / bpc);

    let mut output = data.to_vec();
    for row in output.chunks_exact_mut(bytes_per_row) {
        match bpc {
            8 => undo_tiff_bytes(row, colors),
            16 => undo_tiff_u16(row, colors),
            _ => undo_tiff_packed(row, colors, bpc, samples_per_row),
        }
    }
    Ok(output)
}

/// Undoes 8-bit differencing in place: each byte adds the byte one pixel
/// (`colors` bytes) to its left.
fn undo_tiff_bytes(row: &mut [u8], colors: usize) {
    for i in colors..row.len() {
        row[i] = row[i].wrapping_add(row[i - colors]);
    }
}

/// Undoes 16-bit differencing in place on big-endian (high byte first)
/// samples; sums wrap modulo 2^16.
fn undo_tiff_u16(row: &mut [u8], colors: usize) {
    let stride = 2 * colors;
    for i in (stride..row.len().saturating_sub(1)).step_by(2) {
        let left = u16::from_be_bytes([row[i - stride], row[i - stride + 1]]);
        let delta = u16::from_be_bytes([row[i], row[i + 1]]);
        row[i..i + 2].copy_from_slice(&delta.wrapping_add(left).to_be_bytes());
    }
}

/// Undoes differencing in place for 1-, 2- or 4-bit samples packed
/// MSB-first within each byte; sums wrap modulo `2^bpc`.
fn undo_tiff_packed(row: &mut [u8], colors: usize, bpc: usize, samples: usize) {
    let mask = ((1u16 << bpc) - 1) as u8;
    // (byte index, right-shift) locating sample `j`. `bpc` divides 8, so a
    // sample never straddles two bytes.
    let locate = |j: usize| {
        let bit = j * bpc;
        (bit / 8, 8 - bpc - bit % 8)
    };
    for j in colors..samples {
        let (left_byte, left_shift) = locate(j - colors);
        let (byte, shift) = locate(j);
        let left = (row[left_byte] >> left_shift) & mask;
        let delta = (row[byte] >> shift) & mask;
        let value = delta.wrapping_add(left) & mask;
        row[byte] = (row[byte] & !(mask << shift)) | (value << shift);
    }
}
/// Reverses PNG prediction (PDF `Predictor` values 10..=15).
///
/// Each row of `data` is one filter-type tag byte followed by
/// `params.pixel_bytes_per_row()` filtered bytes. The result contains only
/// the reconstructed pixel bytes; tag bytes are stripped.
///
/// A `Predictor` value of 10 or more only signals that PNG prediction is in
/// use. The filter actually applied to each row is named by that row's tag
/// byte, and the encoder may vary it row by row. The tag is therefore
/// honoured for every value in 10..=15, not only for 15.
///
/// # Errors
/// Returns [`Error::Decode`] if `data` is not a whole number of rows, or if a
/// row carries a tag byte other than 0–4.
fn decode_png_predictor(data: &[u8], params: &DecodeParams) -> Result<Vec<u8>> {
    let bytes_per_row = params.bytes_per_row();
    let pixel_bytes = params.pixel_bytes_per_row();
    if !data.len().is_multiple_of(bytes_per_row) {
        return Err(Error::Decode(format!(
            "Data length {} is not a multiple of row size {}",
            data.len(),
            bytes_per_row
        )));
    }
    let row_count = data.len() / bytes_per_row;
    let mut output = Vec::with_capacity(row_count * pixel_bytes);

    // PNG "bpp": bytes per complete pixel, rounded up to at least one. The
    // Sub/Average/Paeth filters look this many bytes back for the left
    // neighbour, so it depends on the bit depth, not only on `colors`
    // (16-bit gray = 2, 8-bit RGB = 3, 1-bit gray = 1).
    let bpp = params
        .colors
        .saturating_mul(params.bits_per_component)
        .div_ceil(8)
        .max(1);

    for (row_idx, row_data) in data.chunks_exact(bytes_per_row).enumerate() {
        // The first byte of every row is the PNG filter type for that row.
        let predictor_tag = row_data[0];
        let encoded_pixels = &row_data[1..];
        match predictor_tag {
            0 => output.extend_from_slice(encoded_pixels),
            1 => decode_png_sub(encoded_pixels, &mut output, bpp),
            2 => decode_png_up(encoded_pixels, &mut output, row_idx, pixel_bytes),
            3 => decode_png_average(encoded_pixels, &mut output, row_idx, pixel_bytes, bpp),
            4 => decode_png_paeth(encoded_pixels, &mut output, row_idx, pixel_bytes, bpp),
            _ => {
                return Err(Error::Decode(format!(
                    "Invalid PNG predictor tag: {}",
                    predictor_tag
                )));
            },
        }
    }
    Ok(output)
}
/// Reverses PNG "Sub" filtering for one row, appending the result to `output`.
///
/// Each byte adds the already-reconstructed byte `bpp` positions earlier in
/// the same row. The first `bpp` bytes of the row have no left neighbour and
/// use 0.
fn decode_png_sub(encoded: &[u8], output: &mut Vec<u8>, bpp: usize) {
    let row_base = output.len();
    for (offset, &delta) in encoded.iter().enumerate() {
        let left = offset
            .checked_sub(bpp)
            .map_or(0, |back| output[row_base + back]);
        output.push(delta.wrapping_add(left));
    }
}
/// Reverses PNG "Up" filtering for one row, appending the result to `output`.
///
/// Each byte adds the byte at the same offset in the previous decoded row,
/// which starts at `(row_idx - 1) * pixel_bytes` in `output`. The first row
/// has no row above, so its bytes are copied unchanged.
fn decode_png_up(encoded: &[u8], output: &mut Vec<u8>, row_idx: usize, pixel_bytes: usize) {
    if row_idx == 0 {
        // Adding an implicit zero row is a plain copy.
        output.extend_from_slice(encoded);
        return;
    }
    let above_base = (row_idx - 1) * pixel_bytes;
    for (offset, &delta) in encoded.iter().enumerate() {
        let above = output[above_base + offset];
        output.push(delta.wrapping_add(above));
    }
}
/// Reverses PNG "Average" filtering for one row, appending the result to `output`.
///
/// Each byte is predicted by the floor of the mean of two neighbours:
/// - its left neighbour, `bpp` bytes back in the row being rebuilt;
/// - the byte directly above, at the same offset in row `row_idx - 1` of
///   `output`.
///
/// A missing neighbour counts as zero. The sum is formed in `u16` so it
/// cannot overflow.
fn decode_png_average(
    encoded: &[u8],
    output: &mut Vec<u8>,
    row_idx: usize,
    pixel_bytes: usize,
    bpp: usize,
) {
    let row_base = output.len();
    let above_base = row_idx.checked_sub(1).map(|prev| prev * pixel_bytes);
    for (offset, &delta) in encoded.iter().enumerate() {
        let left = offset
            .checked_sub(bpp)
            .map_or(0u16, |back| u16::from(output[row_base + back]));
        let above = above_base.map_or(0u16, |base| u16::from(output[base + offset]));
        let mean = ((left + above) >> 1) as u8;
        output.push(delta.wrapping_add(mean));
    }
}
/// Reverses PNG "Paeth" filtering for one row, appending the result to `output`.
///
/// Each byte uses three neighbours; any that fall outside the image count
/// as zero:
/// - left: `bpp` bytes back in the current row;
/// - above: the same offset in row `row_idx - 1`;
/// - upper-left: `bpp` bytes back in row `row_idx - 1`.
///
/// The Paeth function picks one of the three as the prediction, and that
/// value is added to the encoded byte.
fn decode_png_paeth(
    encoded: &[u8],
    output: &mut Vec<u8>,
    row_idx: usize,
    pixel_bytes: usize,
    bpp: usize,
) {
    let row_base = output.len();
    let above_base = row_idx.checked_sub(1).map(|prev| prev * pixel_bytes);
    for (offset, &delta) in encoded.iter().enumerate() {
        let back = offset.checked_sub(bpp);
        let left = back.map_or(0, |k| i16::from(output[row_base + k]));
        let above = above_base.map_or(0, |base| i16::from(output[base + offset]));
        let upper_left = match (above_base, back) {
            (Some(base), Some(k)) => i16::from(output[base + k]),
            _ => 0,
        };
        let prediction = paeth_predictor(left, above, upper_left) as u8;
        output.push(delta.wrapping_add(prediction));
    }
}

/// PNG Paeth predictor: returns whichever of `a` (left), `b` (above) or
/// `c` (upper-left) is closest to `a + b - c`.
///
/// Ties are broken in the order `a`, then `b`, then `c`.
fn paeth_predictor(a: i16, b: i16, c: i16) -> i16 {
    let estimate = a + b - c;
    let distance = |v: i16| (estimate - v).abs();
    let (da, db, dc) = (distance(a), distance(b), distance(c));
    if da <= db && da <= dc {
        return a;
    }
    if db <= dc {
        b
    } else {
        c
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameters for an 8-bit, single-component image with the given
    /// predictor and width.
    fn gray8(predictor: i64, columns: usize) -> DecodeParams {
        DecodeParams {
            predictor,
            columns,
            colors: 1,
            bits_per_component: 8,
        }
    }

    #[test]
    fn test_no_predictor() {
        let data = b"Hello, World!";
        let params = DecodeParams {
            predictor: 1,
            ..Default::default()
        };
        let result = decode_predictor(data, &params).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_png_up_predictor() {
        let params = gray8(12, 5);
        // Each row: tag byte 2 (Up) followed by five filtered bytes.
        let encoded = vec![2, 10, 20, 30, 40, 50, 2, 5, 5, 5, 5, 5];
        let result = decode_predictor(&encoded, &params).unwrap();
        assert_eq!(result, vec![10, 20, 30, 40, 50, 15, 25, 35, 45, 55]);
    }

    #[test]
    fn test_png_sub_predictor() {
        let params = gray8(11, 3);
        let encoded = vec![1, 10, 5, 5];
        let result = decode_predictor(&encoded, &params).unwrap();
        assert_eq!(result, vec![10, 15, 20]);
    }

    #[test]
    fn test_png_paeth_predictor() {
        let params = gray8(14, 2);
        let encoded = vec![4, 10, 5, 4, 1, 1];
        let result = decode_predictor(&encoded, &params).unwrap();
        assert_eq!(result, vec![10, 15, 11, 16]);
    }

    #[test]
    fn test_png_invalid_tag_is_error() {
        let params = gray8(15, 1);
        assert!(decode_predictor(&[9, 0], &params).is_err());
    }

    #[test]
    fn test_png_bad_length_is_error() {
        let params = gray8(12, 5);
        assert!(decode_predictor(&[0; 7], &params).is_err());
    }

    #[test]
    fn test_tiff_predictor_gray8() {
        let params = gray8(2, 3);
        let encoded = vec![1, 1, 1, 5, 250, 10];
        let result = decode_predictor(&encoded, &params).unwrap();
        assert_eq!(result, vec![1, 2, 3, 5, 255, 9]);
    }

    #[test]
    fn test_tiff_predictor_rgb8() {
        let params = DecodeParams {
            predictor: 2,
            columns: 2,
            colors: 3,
            bits_per_component: 8,
        };
        let encoded = vec![10, 20, 30, 1, 2, 3];
        let result = decode_predictor(&encoded, &params).unwrap();
        assert_eq!(result, vec![10, 20, 30, 11, 22, 33]);
    }

    #[test]
    fn test_unsupported_predictor_is_error() {
        let params = gray8(3, 1);
        assert!(decode_predictor(&[0], &params).is_err());
    }

    #[test]
    fn test_bytes_per_row_calculation() {
        let params = gray8(12, 5);
        // Pixel bytes plus one PNG tag byte.
        assert_eq!(params.bytes_per_row(), 6);
        assert_eq!(params.pixel_bytes_per_row(), 5);
    }

    #[test]
    fn test_decode_params_default() {
        let params = DecodeParams::default();
        assert_eq!(params.predictor, 1);
        assert_eq!(params.columns, 1);
        assert_eq!(params.colors, 1);
        assert_eq!(params.bits_per_component, 8);
    }
}