use alloc::borrow::Cow;
use alloc::vec;
use alloc::vec::Vec;
use zenflate::crc32;
use crate::chunk::ancillary::PngAncillary;
use crate::chunk::ihdr::Ihdr;
use crate::chunk::{ChunkIter, ChunkRef};
use crate::error::PngError;
use super::postprocess::output_bytes_per_pixel;
/// zenflate input adapter that presents the payloads of a run of consecutive
/// IDAT chunks as a single compressed byte stream.
pub(crate) struct IdatSource<'a> {
    /// The complete PNG file, either borrowed or owned.
    data: Cow<'a, [u8]>,
    /// `(start, end)` byte offsets into `data` of the not-yet-consumed part of
    /// the current IDAT payload; treated as empty when `start >= end`.
    current_range: (usize, usize),
    /// Byte offset of the next chunk header to inspect.
    chunk_pos: usize,
    /// When true, CRCs of the IDAT chunks walked by this source are not verified.
    skip_crc: bool,
    /// Set once the IDAT run has ended (non-IDAT chunk or end of data reached).
    done: bool,
    /// Byte offset of the first chunk after the IDAT run. Starts at 0 and is
    /// only filled in at the moment `done` becomes true.
    pub post_idat_pos: usize,
}
impl<'a> IdatSource<'a> {
    /// Returns the full PNG byte buffer this source reads from.
    pub fn file_data(&self) -> &[u8] {
        &self.data[..]
    }

    /// Creates a source positioned on the IDAT chunk whose header (4-byte
    /// big-endian length followed by the 4-byte type) starts at
    /// `first_idat_pos`.
    ///
    /// The first chunk's payload becomes the initial readable range, and the
    /// next chunk is expected right after its 4-byte CRC. This constructor does
    /// not verify that first chunk's CRC or bounds. It assumes the caller has
    /// already validated the chunk (e.g. via `ChunkIter`).
    ///
    /// # Panics
    ///
    /// Panics if `first_idat_pos + 4` exceeds `data.len()`.
    pub fn new(data: Cow<'a, [u8]>, first_idat_pos: usize, skip_crc: bool) -> Self {
        let len_field: [u8; 4] = data[first_idat_pos..first_idat_pos + 4].try_into().unwrap();
        let payload_len = u32::from_be_bytes(len_field) as usize;
        // Payload begins after the length and type fields.
        let payload_start = first_idat_pos + 8;
        let payload_end = payload_start + payload_len;
        Self {
            // The following chunk header sits after the 4-byte CRC.
            chunk_pos: payload_end + 4,
            current_range: (payload_start, payload_end),
            done: false,
            post_idat_pos: 0,
            skip_crc,
            data,
        }
    }
}
impl zenflate::InputSource for IdatSource<'_> {
    type Error = PngError;

    /// Returns the unconsumed remainder of the current IDAT payload, moving on
    /// to the next IDAT chunk once the current one is exhausted.
    ///
    /// An empty slice marks the end of the IDAT run. This happens when fewer
    /// than 12 bytes (a minimal chunk) remain, or when the next chunk is not an
    /// IDAT. In both cases `done` is set and `post_idat_pos` records where the
    /// run ended. Zero-length IDAT chunks are skipped.
    ///
    /// The chunk type is checked before any length validation. A truncated or
    /// corrupt chunk *after* the image data therefore just ends the stream, and
    /// is not misreported as a broken IDAT.
    ///
    /// # Errors
    ///
    /// Returns `PngError::Decode` when an IDAT chunk's length overflows, the
    /// chunk extends past the end of the data, or (unless `skip_crc`) its CRC
    /// does not match.
    fn fill_buf(&mut self) -> Result<&[u8], PngError> {
        let (start, end) = self.current_range;
        if start < end {
            return Ok(&self.data[start..end]);
        }
        if self.done {
            return Ok(&[]);
        }
        loop {
            let pos = self.chunk_pos;
            // 12 = length (4) + type (4) + CRC (4) for an empty chunk.
            if pos + 12 > self.data.len() {
                self.done = true;
                self.post_idat_pos = pos;
                return Ok(&[]);
            }
            let chunk_type: [u8; 4] = self.data[pos + 4..pos + 8].try_into().unwrap();
            // Only IDAT chunks belong to this stream. Stop before validating
            // anything about a chunk we are not going to read.
            if chunk_type != *b"IDAT" {
                self.done = true;
                self.post_idat_pos = pos;
                return Ok(&[]);
            }
            let length = u32::from_be_bytes(self.data[pos..pos + 4].try_into().unwrap()) as usize;
            let data_start = pos + 8;
            let Some(data_end) = data_start.checked_add(length) else {
                return Err(PngError::Decode("IDAT chunk length overflow".into()));
            };
            let Some(crc_end) = data_end.checked_add(4) else {
                return Err(PngError::Decode("IDAT chunk length overflow".into()));
            };
            if crc_end > self.data.len() {
                return Err(PngError::Decode("truncated IDAT chunk".into()));
            }
            if !self.skip_crc {
                // The PNG CRC covers the chunk type followed by the chunk data.
                let stored_crc =
                    u32::from_be_bytes(self.data[data_end..crc_end].try_into().unwrap());
                let computed_crc = crc32(crc32(0, &chunk_type), &self.data[data_start..data_end]);
                if stored_crc != computed_crc {
                    return Err(PngError::Decode("CRC mismatch in IDAT chunk".into()));
                }
            }
            self.current_range = (data_start, data_end);
            self.chunk_pos = crc_end;
            if data_start < data_end {
                return Ok(&self.data[data_start..data_end]);
            }
            // Empty IDAT chunk: keep scanning.
        }
    }

    /// Marks `n` bytes of the current payload as consumed.
    ///
    /// There is no bounds check. This presumably relies on zenflate never
    /// consuming more than the last `fill_buf` returned.
    fn consume(&mut self, n: usize) {
        self.current_range.0 += n;
    }
}
/// zenflate input adapter that presents the compressed frame data of a run of
/// consecutive APNG fdAT chunks as a single byte stream.
pub(crate) struct FdatSource<'a> {
    /// The complete PNG/APNG file.
    data: &'a [u8],
    /// Not-yet-consumed compressed bytes of the current fdAT chunk. This is the
    /// payload after its leading 4-byte sequence number.
    current_data: &'a [u8],
    /// Byte offset of the next chunk header to inspect.
    chunk_pos: usize,
    /// When true, CRCs of the fdAT chunks walked by this source are not verified.
    skip_crc: bool,
    /// Set once the fdAT run has ended (non-fdAT chunk or end of data reached).
    done: bool,
    /// Byte offset of the first chunk after the fdAT run. Starts at 0 and is
    /// only filled in at the moment `done` becomes true.
    pub post_fdat_pos: usize,
}
impl<'a> FdatSource<'a> {
    /// Creates a source positioned on the fdAT chunk whose header (4-byte
    /// big-endian length followed by the 4-byte type) starts at
    /// `first_fdat_pos`.
    ///
    /// The first 4 payload bytes are skipped. The rest of the payload becomes
    /// the initial readable data, or an empty slice if the payload is 4 bytes
    /// or shorter. The next chunk is expected right after the 4-byte CRC. This
    /// constructor does not verify the first chunk's CRC or bounds.
    ///
    /// # Panics
    ///
    /// Panics if the header or the payload range lies outside `data`.
    pub fn new(data: &'a [u8], first_fdat_pos: usize, skip_crc: bool) -> Self {
        let len_field: [u8; 4] = data[first_fdat_pos..first_fdat_pos + 4].try_into().unwrap();
        let payload_start = first_fdat_pos + 8;
        let payload_end = payload_start + u32::from_be_bytes(len_field) as usize;
        // Skip the sequence number that precedes the compressed frame data.
        // Clamping to `payload_end` yields an empty slice for too-short payloads.
        let zlib_start = (payload_start + 4).min(payload_end);
        Self {
            current_data: &data[zlib_start..payload_end],
            chunk_pos: payload_end + 4,
            done: false,
            post_fdat_pos: 0,
            skip_crc,
            data,
        }
    }
}
impl<'a> zenflate::InputSource for FdatSource<'a> {
    type Error = PngError;

    /// Returns the unconsumed compressed bytes of the current fdAT chunk,
    /// moving on to the next fdAT chunk once the current one is exhausted.
    ///
    /// Each fdAT payload's leading 4 bytes are skipped. Chunks whose remaining
    /// payload is empty are passed over. An empty slice marks the end of the
    /// fdAT run. This happens when fewer than 12 bytes remain or the next chunk
    /// is not an fdAT. In both cases `done` is set and `post_fdat_pos` records
    /// where the run ended.
    ///
    /// The chunk type is checked before any length validation. A truncated or
    /// corrupt chunk *after* the frame data therefore just ends the stream, and
    /// is not misreported as a broken fdAT.
    ///
    /// # Errors
    ///
    /// Returns `PngError::Decode` when an fdAT chunk's length overflows, the
    /// chunk extends past the end of the data, or (unless `skip_crc`) its CRC
    /// does not match.
    fn fill_buf(&mut self) -> Result<&[u8], PngError> {
        if !self.current_data.is_empty() {
            return Ok(self.current_data);
        }
        if self.done {
            return Ok(&[]);
        }
        loop {
            let pos = self.chunk_pos;
            // 12 = length (4) + type (4) + CRC (4) for an empty chunk.
            if pos + 12 > self.data.len() {
                self.done = true;
                self.post_fdat_pos = pos;
                return Ok(&[]);
            }
            let chunk_type: [u8; 4] = self.data[pos + 4..pos + 8].try_into().unwrap();
            // Only fdAT chunks belong to this stream. Stop before validating
            // anything about a chunk we are not going to read.
            if chunk_type != *b"fdAT" {
                self.done = true;
                self.post_fdat_pos = pos;
                return Ok(&[]);
            }
            let length = u32::from_be_bytes(self.data[pos..pos + 4].try_into().unwrap()) as usize;
            let data_start = pos + 8;
            let Some(data_end) = data_start.checked_add(length) else {
                return Err(PngError::Decode("fdAT chunk length overflow".into()));
            };
            let Some(crc_end) = data_end.checked_add(4) else {
                return Err(PngError::Decode("fdAT chunk length overflow".into()));
            };
            if crc_end > self.data.len() {
                return Err(PngError::Decode("truncated fdAT chunk".into()));
            }
            if !self.skip_crc {
                // The PNG CRC covers the chunk type followed by the chunk data.
                let stored_crc =
                    u32::from_be_bytes(self.data[data_end..crc_end].try_into().unwrap());
                let computed_crc = crc32(crc32(0, &chunk_type), &self.data[data_start..data_end]);
                if stored_crc != computed_crc {
                    return Err(PngError::Decode("CRC mismatch in fdAT chunk".into()));
                }
            }
            // Skip the 4-byte prefix (sequence number) of the payload.
            let deflate_start = data_start + 4;
            if deflate_start < data_end {
                self.current_data = &self.data[deflate_start..data_end];
            } else {
                self.current_data = &[];
            }
            self.chunk_pos = crc_end;
            if !self.current_data.is_empty() {
                return Ok(self.current_data);
            }
            // No compressed bytes in this chunk: keep scanning.
        }
    }

    /// Marks `n` bytes of the current chunk data as consumed.
    ///
    /// # Panics
    ///
    /// Panics if `n` exceeds the bytes remaining in the current chunk.
    fn consume(&mut self, n: usize) {
        self.current_data = &self.current_data[n..];
    }
}
pub(super) fn unfilter_row(
filter_type: u8,
row: &mut [u8],
prev: &[u8],
bpp: usize,
) -> Result<(), PngError> {
crate::simd::unfilter_row(filter_type, row, prev, bpp)
}
/// Streaming decoder that inflates the IDAT data of a PNG and yields one
/// unfiltered scanline at a time.
pub(crate) struct RowDecoder<'a> {
    // zlib stream decompressor fed by the file's IDAT chunks.
    decompressor: zenflate::StreamDecompressor<IdatSource<'a>>,
    // Parsed image header.
    ihdr: Ihdr,
    // Ancillary chunks collected before the first IDAT (and later, via
    // `finish_metadata`, after the IDAT run).
    ancillary: PngAncillary,
    // Byte offset of the first IDAT chunk's header within the file.
    first_idat_pos: usize,
    // Most recently reconstructed row; serves as the "previous row" for unfiltering.
    prev_row: Vec<u8>,
    // Scratch buffer for the row currently being reconstructed.
    current_row: Vec<u8>,
    // Number of rows successfully produced so far.
    rows_yielded: u32,
    // Bytes per filtered row in the decompressed stream, including the leading
    // filter-type byte.
    stride: usize,
    // Byte distance used by the unfilter step (from `Ihdr::filter_bpp`).
    bpp: usize,
    // Warnings gathered by the chunk iterator while parsing the header chunks.
    chunk_warnings: Vec<crate::decode::PngWarning>,
}
impl<'a> RowDecoder<'a> {
pub fn new(
data: Cow<'a, [u8]>,
config: &crate::decode::PngDecodeConfig,
) -> Result<Self, PngError> {
if data.len() < 8 || data[..8] != crate::chunk::PNG_SIGNATURE {
return Err(PngError::Decode("not a PNG file".into()));
}
let mut chunks = ChunkIter::new_with_config(&data, config.skip_critical_chunk_crc);
let ihdr_chunk = chunks
.next()
.ok_or_else(|| PngError::Decode("empty PNG (no chunks)".into()))??;
if ihdr_chunk.chunk_type != *b"IHDR" {
return Err(PngError::Decode("first chunk is not IHDR".into()));
}
let ihdr = Ihdr::parse(ihdr_chunk.data)?;
let mut ancillary = PngAncillary::default();
let mut first_idat_pos = None;
for chunk_result in &mut chunks {
let chunk = chunk_result?;
if chunk.chunk_type == *b"IDAT" {
first_idat_pos = Some(chunks.pos() - 12 - chunk.data.len());
break;
}
ancillary.collect(&chunk)?;
}
let chunk_warnings = chunks.warnings;
let first_idat_pos =
first_idat_pos.ok_or_else(|| PngError::Decode("no IDAT chunk found".into()))?;
if ihdr.is_indexed() && ancillary.palette.is_none() {
return Err(PngError::Decode(
"indexed color type requires PLTE chunk".into(),
));
}
let output_bpp = output_bytes_per_pixel(&ihdr, &ancillary) as u32;
config.validate(ihdr.width, ihdr.height, output_bpp)?;
let stride = ihdr.stride()?;
let raw_row_bytes = ihdr.raw_row_bytes()?;
let bpp = ihdr.filter_bpp();
let source = IdatSource::new(data, first_idat_pos, config.skip_critical_chunk_crc);
let decompressor = zenflate::StreamDecompressor::zlib(source, stride * 2)
.with_skip_checksum(config.skip_decompression_checksum);
Ok(Self {
decompressor,
ihdr,
ancillary,
first_idat_pos,
prev_row: vec![0u8; raw_row_bytes],
current_row: vec![0u8; raw_row_bytes],
rows_yielded: 0,
stride,
bpp,
chunk_warnings,
})
}
pub fn ihdr(&self) -> &Ihdr {
&self.ihdr
}
pub fn ancillary(&self) -> &PngAncillary {
&self.ancillary
}
pub fn first_idat_pos(&self) -> usize {
self.first_idat_pos
}
pub fn next_raw_row(&mut self) -> Option<Result<&[u8], PngError>> {
if self.rows_yielded >= self.ihdr.height {
return None;
}
if let Err(e) = self.fill_stride() {
return Some(Err(e));
}
if self.decompressor.peek().len() < self.stride {
return None;
}
let peeked = self.decompressor.peek();
let filter_byte = peeked[0];
let raw_row_bytes = self.stride - 1;
self.current_row[..raw_row_bytes].copy_from_slice(&peeked[1..self.stride]);
self.decompressor.advance(self.stride);
if let Err(e) = unfilter_row(
filter_byte,
&mut self.current_row[..raw_row_bytes],
&self.prev_row,
self.bpp,
) {
return Some(Err(e));
}
core::mem::swap(&mut self.current_row, &mut self.prev_row);
self.rows_yielded += 1;
Some(Ok(&self.prev_row[..raw_row_bytes]))
}
pub fn next_raw_row_direct(
&mut self,
dest: &mut [u8],
prev: &[u8],
) -> Option<Result<(), PngError>> {
if self.rows_yielded >= self.ihdr.height {
return None;
}
if let Err(e) = self.fill_stride() {
return Some(Err(e));
}
if self.decompressor.peek().len() < self.stride {
return None;
}
let peeked = self.decompressor.peek();
let filter_byte = peeked[0];
let raw_row_bytes = self.stride - 1;
dest[..raw_row_bytes].copy_from_slice(&peeked[1..self.stride]);
self.decompressor.advance(self.stride);
if let Err(e) = unfilter_row(filter_byte, &mut dest[..raw_row_bytes], prev, self.bpp) {
return Some(Err(e));
}
self.rows_yielded += 1;
Some(Ok(()))
}
fn fill_stride(&mut self) -> Result<(), PngError> {
loop {
let available = self.decompressor.peek().len();
if available >= self.stride {
return Ok(());
}
if self.decompressor.is_done() {
if available > 0 && available < self.stride {
return Err(PngError::Decode(alloc::format!(
"truncated row data: got {} bytes, expected {} (row {})",
available,
self.stride,
self.rows_yielded
)));
}
return Ok(());
}
match self.decompressor.fill() {
Ok(_) => {}
Err(e) => {
return Err(PngError::Decode(alloc::format!(
"decompression error: {e:?}"
)));
}
}
}
}
pub fn finish_metadata(&mut self) {
let data: &[u8] = self.decompressor.source_ref().file_data();
let mut pos = self.first_idat_pos;
while pos + 12 <= data.len() {
let length = u32::from_be_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
let chunk_type: [u8; 4] = data[pos + 4..pos + 8].try_into().unwrap();
let Some(crc_end) = (pos + 8).checked_add(length).and_then(|v| v.checked_add(4)) else {
return;
};
if crc_end > data.len() {
return;
}
if chunk_type != *b"IDAT" {
break;
}
pos = crc_end;
}
while pos + 12 <= data.len() {
let length = u32::from_be_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
let chunk_type: [u8; 4] = data[pos + 4..pos + 8].try_into().unwrap();
let data_start = pos + 8;
let Some(data_end) = data_start.checked_add(length) else {
break;
};
let Some(crc_end) = data_end.checked_add(4) else {
break;
};
if crc_end > data.len() {
break;
}
if chunk_type == *b"IEND" {
break;
}
let chunk_data = &data[data_start..data_end];
self.ancillary.collect_late(&ChunkRef {
chunk_type,
data: chunk_data,
});
pos = crc_end;
}
}
pub fn collect_decode_warnings(&self) -> Vec<crate::decode::PngWarning> {
let mut warnings = self.chunk_warnings.clone();
if self.decompressor.checksum_matched() == Some(false) {
warnings.push(crate::decode::PngWarning::DecompressionChecksumSkipped);
}
warnings
}
}