png_decoder/
lib.rs

1#![no_std]
2
3#[macro_use]
4extern crate alloc;
5
6#[cfg(test)]
7extern crate std;
8
9use alloc::vec::Vec;
10use core::convert::{TryFrom, TryInto};
11use crc32fast::Hasher;
12use miniz_oxide::inflate::TINFLStatus;
13use num_enum::TryFromPrimitive;
14
/// The 8-byte signature every PNG file starts with (`\x89 P N G \r \n \x1a \n`).
const PNG_MAGIC_BYTES: &[u8] = b"\x89PNG\r\n\x1a\n";
16
17#[repr(u8)]
18#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
19pub enum BitDepth {
20    One = 1,
21    Two = 2,
22    Four = 4,
23    Eight = 8,
24    Sixteen = 16,
25}
26
27#[repr(u8)]
28#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
29pub enum ColorType {
30    Grayscale = 0,
31    Rgb = 2,
32    Palette = 3,
33    GrayscaleAlpha = 4,
34    RgbAlpha = 6,
35}
36
37impl ColorType {
38    pub fn sample_multiplier(&self) -> usize {
39        match self {
40            ColorType::Grayscale => 1,
41            ColorType::Rgb => 3,
42            ColorType::Palette => 1,
43            ColorType::GrayscaleAlpha => 2,
44            ColorType::RgbAlpha => 4,
45        }
46    }
47}
48
49#[derive(Debug, Copy, Clone)]
50enum PixelType {
51    Grayscale1,
52    Grayscale2,
53    Grayscale4,
54    Grayscale8,
55    Grayscale16,
56
57    Rgb8,
58    Rgb16,
59
60    Palette1,
61    Palette2,
62    Palette4,
63    Palette8,
64
65    GrayscaleAlpha8,
66    GrayscaleAlpha16,
67
68    RgbAlpha8,
69    RgbAlpha16,
70}
71
72impl PixelType {
73    fn new(color_type: ColorType, bit_depth: BitDepth) -> Result<Self, DecodeError> {
74        let result = match color_type {
75            ColorType::Grayscale => match bit_depth {
76                BitDepth::One => PixelType::Grayscale1,
77                BitDepth::Two => PixelType::Grayscale2,
78                BitDepth::Four => PixelType::Grayscale4,
79                BitDepth::Eight => PixelType::Grayscale8,
80                BitDepth::Sixteen => PixelType::Grayscale16,
81            },
82            ColorType::Rgb => match bit_depth {
83                BitDepth::Eight => PixelType::Rgb8,
84                BitDepth::Sixteen => PixelType::Rgb16,
85                _ => return Err(DecodeError::InvalidColorTypeBitDepthCombination),
86            },
87            ColorType::Palette => match bit_depth {
88                BitDepth::One => PixelType::Palette1,
89                BitDepth::Two => PixelType::Palette2,
90                BitDepth::Four => PixelType::Palette4,
91                BitDepth::Eight => PixelType::Palette8,
92                _ => return Err(DecodeError::InvalidColorTypeBitDepthCombination),
93            },
94            ColorType::GrayscaleAlpha => match bit_depth {
95                BitDepth::Eight => PixelType::GrayscaleAlpha8,
96                BitDepth::Sixteen => PixelType::GrayscaleAlpha16,
97                _ => return Err(DecodeError::InvalidColorTypeBitDepthCombination),
98            },
99            ColorType::RgbAlpha => match bit_depth {
100                BitDepth::Eight => PixelType::RgbAlpha8,
101                BitDepth::Sixteen => PixelType::RgbAlpha16,
102                _ => return Err(DecodeError::InvalidColorTypeBitDepthCombination),
103            },
104        };
105
106        Ok(result)
107    }
108}
109
/// Reduces a 16-bit sample to 8 bits by keeping only its most significant byte.
#[inline(always)]
fn u16_to_u8(val: u16) -> u8 {
    let [high, _low] = val.to_be_bytes();
    high
}
114
115#[derive(Default)]
116struct AncillaryChunks<'a> {
117    palette: Option<&'a [u8]>,
118    transparency: Option<TransparencyChunk<'a>>,
119    background: Option<&'a [u8]>,
120}
121
122struct ScanlineIterator<'a> {
123    image_width: usize, // Width in pixels
124    pixel_cursor: usize,
125    pixel_type: PixelType,
126    scanline: &'a [u8],
127    extra_chunks: &'a AncillaryChunks<'a>,
128}
129
130impl<'a> ScanlineIterator<'a> {
131    fn new(
132        image_width: u32,
133        pixel_type: PixelType,
134        scanline: &'a [u8],
135        extra_chunks: &'a AncillaryChunks<'a>,
136    ) -> Self {
137        Self {
138            image_width: image_width as usize,
139            pixel_cursor: 0,
140            pixel_type,
141            scanline,
142            extra_chunks,
143        }
144    }
145}
146
147impl<'a> Iterator for ScanlineIterator<'a> {
148    type Item = (u8, u8, u8, u8);
149
150    fn next(&mut self) -> Option<Self::Item> {
151        if self.pixel_cursor >= self.image_width {
152            return None;
153        }
154
155        let pixel = match self.pixel_type {
156            PixelType::Grayscale1 => {
157                let byte = self.scanline[self.pixel_cursor / 8];
158                let bit_offset = 7 - self.pixel_cursor % 8;
159                let grayscale_val = (byte >> bit_offset) & 1;
160
161                let alpha = match self.extra_chunks.transparency {
162                    Some(TransparencyChunk::Grayscale(transparent_val))
163                        if grayscale_val == transparent_val =>
164                    {
165                        0
166                    },
167                    _ => 255,
168                };
169
170                let pixel_val = grayscale_val * 255;
171
172                Some((pixel_val, pixel_val, pixel_val, alpha))
173            },
174            PixelType::Grayscale2 => {
175                let byte = self.scanline[self.pixel_cursor / 4];
176                let bit_offset = 6 - ((self.pixel_cursor % 4) * 2);
177                let grayscale_val = (byte >> bit_offset) & 0b11;
178
179                let alpha = match self.extra_chunks.transparency {
180                    Some(TransparencyChunk::Grayscale(transparent_val))
181                        if grayscale_val == transparent_val =>
182                    {
183                        0
184                    },
185                    _ => 255,
186                };
187
188                // TODO - use a lookup table
189                let pixel_val = ((grayscale_val as f32 / 3.0) * 255.0) as u8;
190
191                Some((pixel_val, pixel_val, pixel_val, alpha))
192            },
193            PixelType::Grayscale4 => {
194                let byte = self.scanline[self.pixel_cursor / 2];
195                let bit_offset = 4 - ((self.pixel_cursor % 2) * 4);
196                let grayscale_val = (byte >> bit_offset) & 0b1111;
197
198                let alpha = match self.extra_chunks.transparency {
199                    Some(TransparencyChunk::Grayscale(transparent_val))
200                        if grayscale_val == transparent_val =>
201                    {
202                        0
203                    },
204                    _ => 255,
205                };
206
207                // TODO - use a lookup table
208                let pixel_val = ((grayscale_val as f32 / 15.0) * 255.0) as u8;
209                Some((pixel_val, pixel_val, pixel_val, alpha))
210            },
211            PixelType::Grayscale8 => {
212                let byte = self.scanline[self.pixel_cursor];
213
214                let alpha = match self.extra_chunks.transparency {
215                    Some(TransparencyChunk::Grayscale(transparent_val))
216                        if byte == transparent_val =>
217                    {
218                        0
219                    },
220                    _ => 255,
221                };
222                Some((byte, byte, byte, alpha))
223            },
224            PixelType::Grayscale16 => {
225                let offset = self.pixel_cursor * 2;
226                let grayscale_val =
227                    u16::from_be_bytes([self.scanline[offset], self.scanline[offset + 1]]);
228
229                let pixel_val = u16_to_u8(grayscale_val);
230
231                // TODO(bschwind) - This may need to be compared to the original
232                //                  16-bit transparency value, instead of the transformed
233                //                  8-bit value.
234                let alpha = match self.extra_chunks.transparency {
235                    Some(TransparencyChunk::Grayscale(transparent_val))
236                        if pixel_val == transparent_val =>
237                    {
238                        0
239                    },
240                    _ => 255,
241                };
242
243                Some((pixel_val, pixel_val, pixel_val, alpha))
244            },
245            PixelType::Rgb8 => {
246                let offset = self.pixel_cursor * 3;
247                let r = self.scanline[offset];
248                let g = self.scanline[offset + 1];
249                let b = self.scanline[offset + 2];
250
251                let alpha = match self.extra_chunks.transparency {
252                    Some(TransparencyChunk::Rgb(t_r, t_g, t_b))
253                        if r == t_r && g == t_g && b == t_b =>
254                    {
255                        0
256                    },
257                    _ => 255,
258                };
259
260                Some((r, g, b, alpha))
261            },
262            PixelType::Rgb16 => {
263                let offset = self.pixel_cursor * 6;
264                let r = u16::from_be_bytes([self.scanline[offset], self.scanline[offset + 1]]);
265                let g = u16::from_be_bytes([self.scanline[offset + 2], self.scanline[offset + 3]]);
266                let b = u16::from_be_bytes([self.scanline[offset + 4], self.scanline[offset + 5]]);
267
268                let r = u16_to_u8(r);
269                let g = u16_to_u8(g);
270                let b = u16_to_u8(b);
271
272                let alpha = match self.extra_chunks.transparency {
273                    Some(TransparencyChunk::Rgb(t_r, t_g, t_b))
274                        if r == t_r && g == t_g && b == t_b =>
275                    {
276                        0
277                    },
278                    _ => 255,
279                };
280
281                Some((r, g, b, alpha))
282            },
283            PixelType::Palette1 => {
284                let byte = self.scanline[self.pixel_cursor / 8];
285                let bit_offset = 7 - self.pixel_cursor % 8;
286                let palette_idx = ((byte >> bit_offset) & 1) as usize;
287
288                let offset = palette_idx * 3;
289
290                let palette = self.extra_chunks.palette.unwrap();
291                let r = palette[offset];
292                let g = palette[offset + 1];
293                let b = palette[offset + 2];
294
295                let alpha: u8 = match self.extra_chunks.transparency {
296                    Some(TransparencyChunk::Palette(data)) => {
297                        *data.get(palette_idx).unwrap_or(&255)
298                    },
299                    Some(_) | None => 255,
300                };
301
302                Some((r, g, b, alpha))
303            },
304            PixelType::Palette2 => {
305                let byte = self.scanline[self.pixel_cursor / 4];
306                let bit_offset = 6 - ((self.pixel_cursor % 4) * 2);
307                let palette_idx = ((byte >> bit_offset) & 0b11) as usize;
308
309                let offset = palette_idx * 3;
310
311                let palette = self.extra_chunks.palette.unwrap();
312                let r = palette[offset];
313                let g = palette[offset + 1];
314                let b = palette[offset + 2];
315
316                let alpha: u8 = match self.extra_chunks.transparency {
317                    Some(TransparencyChunk::Palette(data)) => {
318                        *data.get(palette_idx).unwrap_or(&255)
319                    },
320                    Some(_) | None => 255,
321                };
322
323                Some((r, g, b, alpha))
324            },
325            PixelType::Palette4 => {
326                let byte = self.scanline[self.pixel_cursor / 2];
327                let bit_offset = 4 - ((self.pixel_cursor % 2) * 4);
328                let palette_idx = ((byte >> bit_offset) & 0b1111) as usize;
329
330                let offset = palette_idx * 3;
331
332                let palette = self.extra_chunks.palette.unwrap();
333                let r = palette[offset];
334                let g = palette[offset + 1];
335                let b = palette[offset + 2];
336
337                let alpha: u8 = match self.extra_chunks.transparency {
338                    Some(TransparencyChunk::Palette(data)) => {
339                        *data.get(palette_idx).unwrap_or(&255)
340                    },
341                    Some(_) | None => 255,
342                };
343
344                Some((r, g, b, alpha))
345            },
346            PixelType::Palette8 => {
347                let offset = self.scanline[self.pixel_cursor] as usize * 3;
348
349                let palette = self.extra_chunks.palette.unwrap();
350                let r = palette[offset];
351                let g = palette[offset + 1];
352                let b = palette[offset + 2];
353
354                let alpha: u8 = match self.extra_chunks.transparency {
355                    Some(TransparencyChunk::Palette(data)) => *data.get(offset).unwrap_or(&255),
356                    Some(_) | None => 255,
357                };
358
359                Some((r, g, b, alpha))
360            },
361            PixelType::GrayscaleAlpha8 => {
362                let offset = self.pixel_cursor * 2;
363                let grayscale_val = self.scanline[offset];
364                let alpha = self.scanline[offset + 1];
365
366                Some((grayscale_val, grayscale_val, grayscale_val, alpha))
367            },
368            PixelType::GrayscaleAlpha16 => {
369                let offset = self.pixel_cursor * 4;
370                let grayscale_val =
371                    u16::from_be_bytes([self.scanline[offset], self.scanline[offset + 1]]);
372                let alpha =
373                    u16::from_be_bytes([self.scanline[offset + 2], self.scanline[offset + 3]]);
374
375                let grayscale_val = u16_to_u8(grayscale_val);
376                let alpha = u16_to_u8(alpha);
377
378                Some((grayscale_val, grayscale_val, grayscale_val, alpha))
379            },
380            PixelType::RgbAlpha8 => {
381                let offset = self.pixel_cursor * 4;
382                let r = self.scanline[offset];
383                let g = self.scanline[offset + 1];
384                let b = self.scanline[offset + 2];
385                let a = self.scanline[offset + 3];
386
387                Some((r, g, b, a))
388            },
389            PixelType::RgbAlpha16 => {
390                let offset = self.pixel_cursor * 8;
391                let r = u16::from_be_bytes([self.scanline[offset], self.scanline[offset + 1]]);
392                let g = u16::from_be_bytes([self.scanline[offset + 2], self.scanline[offset + 3]]);
393                let b = u16::from_be_bytes([self.scanline[offset + 4], self.scanline[offset + 5]]);
394                let a = u16::from_be_bytes([self.scanline[offset + 6], self.scanline[offset + 7]]);
395
396                let r = u16_to_u8(r);
397                let g = u16_to_u8(g);
398                let b = u16_to_u8(b);
399                let a = u16_to_u8(a);
400
401                Some((r, g, b, a))
402            },
403        };
404
405        self.pixel_cursor += 1;
406        pixel
407    }
408}
409
410#[repr(u8)]
411#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
412pub enum CompressionMethod {
413    Deflate = 0,
414}
415
416#[repr(u8)]
417#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
418pub enum FilterMethod {
419    Adaptive = 0,
420}
421
422#[repr(u8)]
423#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
424pub enum FilterType {
425    None = 0,
426    Sub = 1,
427    Up = 2,
428    Average = 3,
429    Paeth = 4,
430}
431
432#[repr(u8)]
433#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
434pub enum InterlaceMethod {
435    None = 0,
436    Adam7 = 1,
437}
438
439#[derive(Debug, Clone, PartialEq, Eq)]
440pub struct PngHeader {
441    pub width: u32,
442    pub height: u32,
443    pub bit_depth: BitDepth,
444    pub color_type: ColorType,
445    pub compression_method: CompressionMethod,
446    pub filter_method: FilterMethod,
447    pub interlace_method: InterlaceMethod,
448}
449
450impl PngHeader {
451    fn from_chunk(chunk: &Chunk) -> Result<Self, DecodeError> {
452        if chunk.chunk_type != ChunkType::ImageHeader {
453            return Err(DecodeError::InvalidChunkType);
454        }
455
456        if chunk.data.len() < 13 {
457            return Err(DecodeError::MissingBytes);
458        }
459
460        let width = read_u32(chunk.data, 0);
461        let height = read_u32(chunk.data, 4);
462        let bit_depth = chunk.data[8];
463        let color_type = chunk.data[9];
464        let compression_method = chunk.data[10];
465        let filter_method = chunk.data[11];
466        let interlace_method = chunk.data[12];
467
468        Ok(PngHeader {
469            width,
470            height,
471            bit_depth: TryFrom::try_from(bit_depth).map_err(|_| DecodeError::InvalidBitDepth)?,
472            color_type: TryFrom::try_from(color_type).map_err(|_| DecodeError::InvalidColorType)?,
473            compression_method: TryFrom::try_from(compression_method)
474                .map_err(|_| DecodeError::InvalidCompressionMethod)?,
475            filter_method: TryFrom::try_from(filter_method)
476                .map_err(|_| DecodeError::InvalidFilterMethod)?,
477            interlace_method: TryFrom::try_from(interlace_method)
478                .map_err(|_| DecodeError::InvalidInterlaceMethod)?,
479        })
480    }
481}
482
483#[derive(Debug, Clone, PartialEq, Eq)]
484pub enum DecodeError {
485    InvalidMagicBytes,
486    MissingBytes,
487    HeaderChunkNotFirst,
488    EndChunkNotLast,
489    InvalidChunkType,
490    InvalidChunk,
491    Decompress(TINFLStatus),
492
493    IncorrectChunkCrc,
494    InvalidBitDepth,
495    InvalidColorType,
496    InvalidColorTypeBitDepthCombination,
497    InvalidCompressionMethod,
498    InvalidFilterMethod,
499    InvalidFilterType,
500    InvalidInterlaceMethod,
501
502    // The width/height specified in the image contains too many
503    // bytes to address with a usize on this platform.
504    IntegerOverflow,
505}
506
/// The type of a PNG chunk, decoded from its 4-byte type code.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ChunkType {
    /// `IHDR`
    ImageHeader,
    /// `PLTE`
    Palette,
    /// `tRNS`
    Transparency,
    /// `bKGD`
    Background,
    /// `sRGB`
    Srgb,
    /// `IDAT`
    ImageData,
    /// `IEND`
    ImageEnd,
    /// `gAMA`
    Gamma,
    /// Any other chunk, carrying its raw 4-byte type code.
    Unknown([u8; 4]),
}

impl ChunkType {
    /// Maps a 4-byte chunk type code to its [`ChunkType`], falling back to
    /// [`ChunkType::Unknown`] for codes that are not listed here.
    fn from_bytes(bytes: &[u8; 4]) -> Self {
        const KNOWN_CHUNKS: [(&[u8; 4], ChunkType); 8] = [
            (b"IHDR", ChunkType::ImageHeader),
            (b"PLTE", ChunkType::Palette),
            (b"tRNS", ChunkType::Transparency),
            (b"bKGD", ChunkType::Background),
            (b"sRGB", ChunkType::Srgb),
            (b"IDAT", ChunkType::ImageData),
            (b"IEND", ChunkType::ImageEnd),
            (b"gAMA", ChunkType::Gamma),
        ];

        KNOWN_CHUNKS
            .iter()
            .find(|(code, _)| *code == bytes)
            .map_or(ChunkType::Unknown(*bytes), |&(_, chunk_type)| chunk_type)
    }
}
538
539#[derive(Debug)]
540struct Chunk<'a> {
541    chunk_type: ChunkType,
542    data: &'a [u8],
543    _crc: u32,
544}
545
546impl<'a> Chunk<'a> {
547    fn byte_size(&self) -> usize {
548        // length bytes + chunk type bytes + data bytes + crc bytes
549        4 + 4 + self.data.len() + 4
550    }
551}
552
553enum TransparencyChunk<'a> {
554    Palette(&'a [u8]),
555    Grayscale(u8),
556    Rgb(u8, u8, u8),
557}
558
559impl<'a> TransparencyChunk<'a> {
560    fn from_chunk(chunk: &Chunk<'a>, pixel_type: PixelType) -> Option<Self> {
561        match pixel_type {
562            PixelType::Grayscale1 => Some(TransparencyChunk::Grayscale(chunk.data[1] & 0b1)),
563            PixelType::Grayscale2 => Some(TransparencyChunk::Grayscale(chunk.data[1] & 0b11)),
564            PixelType::Grayscale4 => Some(TransparencyChunk::Grayscale(chunk.data[1] & 0b1111)),
565            PixelType::Grayscale8 => Some(TransparencyChunk::Grayscale(chunk.data[1])),
566            PixelType::Grayscale16 => {
567                let val = u16::from_be_bytes([chunk.data[0], chunk.data[1]]);
568                Some(TransparencyChunk::Grayscale(u16_to_u8(val)))
569            },
570            PixelType::Rgb8 => {
571                let r = chunk.data[1];
572                let g = chunk.data[3];
573                let b = chunk.data[5];
574                Some(TransparencyChunk::Rgb(r, g, b))
575            },
576            PixelType::Rgb16 => {
577                let r = u16::from_be_bytes([chunk.data[0], chunk.data[1]]);
578                let g = u16::from_be_bytes([chunk.data[2], chunk.data[3]]);
579                let b = u16::from_be_bytes([chunk.data[4], chunk.data[5]]);
580                Some(TransparencyChunk::Rgb(u16_to_u8(r), u16_to_u8(g), u16_to_u8(b)))
581            },
582            PixelType::Palette1 => Some(TransparencyChunk::Palette(chunk.data)),
583            PixelType::Palette2 => Some(TransparencyChunk::Palette(chunk.data)),
584            PixelType::Palette4 => Some(TransparencyChunk::Palette(chunk.data)),
585            PixelType::Palette8 => Some(TransparencyChunk::Palette(chunk.data)),
586            PixelType::GrayscaleAlpha8 => None,
587            PixelType::GrayscaleAlpha16 => None,
588            PixelType::RgbAlpha8 => None,
589            PixelType::RgbAlpha16 => None,
590        }
591    }
592}
593
/// Reads a big-endian `u32` from `bytes` starting at `offset`.
///
/// # Panics
/// Panics if fewer than four bytes are available at `offset`.
fn read_u32(bytes: &[u8], offset: usize) -> u32 {
    let field = &bytes[offset..offset + 4];
    u32::from_be_bytes([field[0], field[1], field[2], field[3]])
}
597
598fn read_chunk(bytes: &[u8]) -> Result<Chunk<'_>, DecodeError> {
599    if bytes.len() < 4 {
600        return Err(DecodeError::MissingBytes);
601    }
602
603    let length = read_u32(bytes, 0) as usize;
604    let bytes = &bytes[4..];
605
606    if bytes.len() < (4 + length + 4) {
607        return Err(DecodeError::MissingBytes);
608    }
609
610    let chunk_type = ChunkType::from_bytes(&[bytes[0], bytes[1], bytes[2], bytes[3]]);
611
612    let crc_offset = 4 + length;
613    let crc = read_u32(bytes, crc_offset);
614
615    // Offset by 4 to not include the chunk type.
616    let data_for_crc = &bytes[..crc_offset];
617
618    let mut hasher = Hasher::new();
619    hasher.reset();
620    hasher.update(data_for_crc);
621
622    if crc != hasher.finalize() {
623        return Err(DecodeError::IncorrectChunkCrc);
624    }
625
626    Ok(Chunk { chunk_type, data: &data_for_crc[4..], _crc: crc })
627}
628
629fn defilter(
630    filter_type: FilterType,
631    bytes_per_pixel: usize,
632    x: usize,
633    current_scanline: &[u8],
634    last_scanline: &[u8],
635) -> u8 {
636    match filter_type {
637        FilterType::None => current_scanline[x],
638        FilterType::Sub => {
639            if let Some(idx) = x.checked_sub(bytes_per_pixel) {
640                current_scanline[x].wrapping_add(current_scanline[idx])
641            } else {
642                current_scanline[x]
643            }
644        },
645        FilterType::Up => current_scanline[x].wrapping_add(last_scanline[x]),
646        FilterType::Average => {
647            let raw_val = if let Some(idx) = x.checked_sub(bytes_per_pixel) {
648                current_scanline[idx]
649            } else {
650                0
651            };
652
653            (current_scanline[x] as u16 + ((raw_val as u16 + last_scanline[x] as u16) / 2)) as u8
654        },
655        FilterType::Paeth => {
656            if let Some(idx) = x.checked_sub(bytes_per_pixel) {
657                let left = current_scanline[idx];
658                let above = last_scanline[x];
659                let upper_left = last_scanline[idx];
660
661                let predictor = paeth_predictor(left as i16, above as i16, upper_left as i16);
662
663                current_scanline[x].wrapping_add(predictor)
664            } else {
665                let left = 0;
666                let above = last_scanline[x];
667                let upper_left = 0;
668
669                let predictor = paeth_predictor(left as i16, above as i16, upper_left as i16);
670
671                current_scanline[x].wrapping_add(predictor)
672            }
673        },
674    }
675}
676
677fn process_scanlines(
678    header: &PngHeader,
679    scanline_data: &mut [u8],
680    output_rgba: &mut [[u8; 4]],
681    ancillary_chunks: &AncillaryChunks,
682    pixel_type: PixelType,
683) -> Result<(), DecodeError> {
684    let mut cursor = 0;
685    let bytes_per_pixel: usize =
686        (header.bit_depth as usize * header.color_type.sample_multiplier()).div_ceil(8);
687
688    match header.interlace_method {
689        InterlaceMethod::None => {
690            // TODO(bschwind) - Deduplicate this logic.
691            let bytes_per_scanline = (header.width as u64
692                * header.bit_depth as u64
693                * header.color_type.sample_multiplier() as u64)
694                .div_ceil(8);
695            let bytes_per_scanline: usize =
696                bytes_per_scanline.try_into().map_err(|_| DecodeError::IntegerOverflow)?;
697
698            let mut last_scanline = vec![0u8; bytes_per_scanline];
699
700            for y in 0..header.height {
701                let filter_type = FilterType::try_from(scanline_data[cursor])
702                    .map_err(|_| DecodeError::InvalidFilterType)?;
703                cursor += 1;
704
705                let current_scanline = &mut scanline_data[cursor..(cursor + bytes_per_scanline)];
706
707                for x in 0..(bytes_per_scanline) {
708                    let unfiltered_byte =
709                        defilter(filter_type, bytes_per_pixel, x, current_scanline, &last_scanline);
710                    current_scanline[x] = unfiltered_byte;
711                }
712
713                let scanline_iter = ScanlineIterator::new(
714                    header.width,
715                    pixel_type,
716                    current_scanline,
717                    ancillary_chunks,
718                );
719
720                for (idx, (r, g, b, a)) in scanline_iter.enumerate() {
721                    let (output_x, output_y) = (idx, y);
722
723                    let output_idx = (output_y as u64 * header.width as u64) + (output_x as u64);
724                    let output_idx: usize =
725                        output_idx.try_into().map_err(|_| DecodeError::IntegerOverflow)?;
726
727                    output_rgba[output_idx] = [r, g, b, a];
728                }
729
730                last_scanline.copy_from_slice(current_scanline);
731                cursor += bytes_per_scanline;
732            }
733        },
734        InterlaceMethod::Adam7 => {
735            let max_bytes_per_scanline = header.width as usize * bytes_per_pixel;
736            let mut last_scanline = vec![0u8; max_bytes_per_scanline];
737
738            // Adam7 Interlacing Pattern
739            // 1 6 4 6 2 6 4 6
740            // 7 7 7 7 7 7 7 7
741            // 5 6 5 6 5 6 5 6
742            // 7 7 7 7 7 7 7 7
743            // 3 6 4 6 3 6 4 6
744            // 7 7 7 7 7 7 7 7
745            // 5 6 5 6 5 6 5 6
746            // 7 7 7 7 7 7 7 7
747
748            for pass in 1..=7 {
749                let (pass_width, pass_height) = match pass {
750                    1 => {
751                        let pass_width = header.width.div_ceil(8);
752                        let pass_height = header.height.div_ceil(8);
753                        (pass_width, pass_height)
754                    },
755                    2 => {
756                        let pass_width = (header.width / 8) + ((header.width % 8) / 5);
757                        let pass_height = header.height.div_ceil(8);
758                        (pass_width, pass_height)
759                    },
760                    3 => {
761                        let pass_width = ((header.width / 8) * 2) + (header.width % 8).div_ceil(4);
762                        let pass_height = (header.height / 8) + ((header.height % 8) / 5);
763                        (pass_width, pass_height)
764                    },
765                    4 => {
766                        let pass_width = ((header.width / 8) * 2) + (header.width % 8 + 1) / 4;
767                        let pass_height = header.height.div_ceil(4);
768                        (pass_width, pass_height)
769                    },
770                    5 => {
771                        let pass_width = (header.width / 2) + (header.width % 2);
772                        let pass_height = ((header.height / 8) * 2) + (header.height % 8 + 1) / 4;
773                        (pass_width, pass_height)
774                    },
775                    6 => {
776                        let pass_width = header.width / 2;
777                        let pass_height = (header.height / 2) + (header.height % 2);
778                        (pass_width, pass_height)
779                    },
780                    7 => {
781                        let pass_width = header.width;
782                        let pass_height = header.height / 2;
783                        (pass_width, pass_height)
784                    },
785                    _ => (0, 0),
786                };
787
788                // Skip empty passes.
789                if pass_width == 0 || pass_height == 0 {
790                    continue;
791                }
792
793                let bytes_per_scanline = (pass_width as u64
794                    * header.bit_depth as u64
795                    * header.color_type.sample_multiplier() as u64)
796                    .div_ceil(8);
797                let bytes_per_scanline: usize =
798                    bytes_per_scanline.try_into().expect("bytes_per_scanline overflowed a usize");
799
800                let last_scanline = &mut last_scanline[..(bytes_per_scanline)];
801                for byte in last_scanline.iter_mut() {
802                    *byte = 0;
803                }
804
805                for y in 0..pass_height {
806                    let filter_type = FilterType::try_from(scanline_data[cursor])
807                        .map_err(|_| DecodeError::InvalidFilterType)?;
808                    cursor += 1;
809
810                    let current_scanline =
811                        &mut scanline_data[cursor..(cursor + bytes_per_scanline)];
812
813                    for x in 0..(bytes_per_scanline) {
814                        let unfiltered_byte = defilter(
815                            filter_type,
816                            bytes_per_pixel,
817                            x,
818                            current_scanline,
819                            last_scanline,
820                        );
821                        current_scanline[x] = unfiltered_byte;
822                    }
823
824                    let scanline_iter = ScanlineIterator::new(
825                        pass_width,
826                        pixel_type,
827                        current_scanline,
828                        ancillary_chunks,
829                    );
830
831                    for (idx, (r, g, b, a)) in scanline_iter.enumerate() {
832                        // Put rgba in output_rgba
833                        let (output_x, output_y) = match pass {
834                            1 => (idx * 8, y * 8),
835                            2 => (idx * 8 + 4, y * 8),
836                            3 => (idx * 4, y * 8 + 4),
837                            4 => (idx * 4 + 2, y * 4),
838                            5 => (idx * 2, y * 4 + 2),
839                            6 => (idx * 2 + 1, y * 2),
840                            7 => (idx, y * 2 + 1),
841                            _ => (0, 0),
842                        };
843
844                        let output_idx =
845                            (output_y as u64 * header.width as u64) + (output_x as u64);
846                        let output_idx: usize =
847                            output_idx.try_into().map_err(|_| DecodeError::IntegerOverflow)?;
848
849                        output_rgba[output_idx] = [r, g, b, a];
850                    }
851
852                    last_scanline.copy_from_slice(current_scanline);
853
854                    cursor += bytes_per_scanline;
855                }
856            }
857        },
858    }
859
860    Ok(())
861}
862
/// Paeth predictor used by the PNG Paeth filter.
///
/// * `a` – the byte to the left
/// * `b` – the byte above
/// * `c` – the byte above and to the left
///
/// Forms the linear estimate `a + b - c` and returns whichever neighbour is closest to it.
/// Ties are broken in the order `a`, then `b`, then `c`. The chosen neighbour is truncated
/// back to `u8`, so callers are expected to pass byte values widened to `i16`.
fn paeth_predictor(a: i16, b: i16, c: i16) -> u8 {
    let estimate = a + b - c;
    let [dist_a, dist_b, dist_c] = [a, b, c].map(|neighbour| (estimate - neighbour).abs());

    let nearest = if dist_a <= dist_b && dist_a <= dist_c {
        a
    } else if dist_b <= dist_c {
        b
    } else {
        c
    };

    nearest as u8
}
881
882/// Decodes the provided PNG into RGBA pixels.
883///
884/// The returned [`PngHeader`] contains the image’s size, and other PNG metadata which is not
885/// necessary to make use of the pixels (the returned format is always 8-bit-per-component RGBA).
886///
887/// The returned [`Vec`] contains the pixels, represented as `[r, g, b, a]` arrays.
888/// Its length will be equal to `header.width * header.height`.
889/// (If you need a `Vec<u8>` of length `header.width * header.height * 4` instead, you can use
890/// [`Vec::into_flattened()`] to convert it.)
891pub fn decode(bytes: &[u8]) -> Result<(PngHeader, Vec<[u8; 4]>), DecodeError> {
892    if bytes.len() < PNG_MAGIC_BYTES.len() {
893        return Err(DecodeError::MissingBytes);
894    }
895
896    if &bytes[0..PNG_MAGIC_BYTES.len()] != PNG_MAGIC_BYTES {
897        return Err(DecodeError::InvalidMagicBytes);
898    }
899
900    let bytes = &bytes[PNG_MAGIC_BYTES.len()..];
901
902    let header_chunk = read_chunk(bytes)?;
903    let header = PngHeader::from_chunk(&header_chunk)?;
904
905    let mut bytes = &bytes[header_chunk.byte_size()..];
906
907    let mut compressed_data: Vec<u8> =
908        Vec::with_capacity(header.width as usize * header.height as usize * 3);
909
910    let pixel_type = PixelType::new(header.color_type, header.bit_depth)?;
911    let mut ancillary_chunks = AncillaryChunks::default();
912
913    while !bytes.is_empty() {
914        let chunk = read_chunk(bytes)?;
915
916        match chunk.chunk_type {
917            ChunkType::ImageData => compressed_data.extend_from_slice(chunk.data),
918            ChunkType::Palette => ancillary_chunks.palette = Some(chunk.data),
919            ChunkType::Transparency => {
920                ancillary_chunks.transparency = TransparencyChunk::from_chunk(&chunk, pixel_type)
921            },
922            ChunkType::Background => ancillary_chunks.background = Some(chunk.data),
923            ChunkType::ImageEnd => break,
924            _ => {},
925        }
926
927        bytes = &bytes[chunk.byte_size()..];
928    }
929
930    let mut scanline_data = miniz_oxide::inflate::decompress_to_vec_zlib(&compressed_data)
931        .map_err(|miniz_oxide::inflate::DecompressError { status, output: _ }| {
932            DecodeError::Decompress(status)
933        })?;
934
935    // For now, output data is always RGBA, 1 byte per channel.
936    let mut output_rgba = vec![[0u8; 4]; header.width as usize * header.height as usize];
937
938    process_scanlines(
939        &header,
940        &mut scanline_data,
941        &mut output_rgba,
942        &ancillary_chunks,
943        pixel_type,
944    )?;
945
946    Ok((header, output_rgba))
947}
948
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes every `.png` in `test_pngs/png_suite` and checks the result.
    ///
    /// Files whose stem starts with `x` must fail to decode. Every other file must decode to
    /// exactly the same RGBA8 bytes that the `image` crate produces for it.
    #[test]
    fn png_suite_test() {
        use image::EncodableLayout;

        for entry in
            std::fs::read_dir("test_pngs/png_suite").expect("PNG suite directory should exist")
        {
            let path = entry.unwrap().path();

            let is_png = path
                .extension()
                .and_then(|os_str| os_str.to_str())
                .is_some_and(|extension| extension.eq_ignore_ascii_case("png"));
            if !is_png {
                continue;
            }

            let png_bytes = std::fs::read(&path).unwrap();

            // The suite prefixes intentionally invalid images with `x`.
            let expect_error = path
                .file_stem()
                .expect("expected png path to be a file")
                .to_string_lossy()
                .starts_with('x');

            if expect_error {
                assert!(
                    decode(&png_bytes).is_err(),
                    "expected {} to fail to decode",
                    path.display()
                );
                continue;
            }

            let (_header, decoded) = decode(&png_bytes)
                .unwrap_or_else(|err| panic!("failed to decode {}: {:?}", path.display(), err));
            let decoded: Vec<u8> = decoded.into_flattened();

            // Uncomment to inspect output.png for debugging.
            // let image_buf: image::ImageBuffer<image::Rgba<u8>, _> =
            //     image::ImageBuffer::from_vec(
            //         _header.width,
            //         _header.height,
            //         decoded.clone(),
            //     )
            //     .unwrap();

            // image_buf.save("output.png").unwrap();

            let comparison_image = image::open(&path).unwrap();
            let comparison_rgba8 = comparison_image.to_rgba8();
            let comparison_bytes = comparison_rgba8.as_bytes();

            assert_eq!(
                decoded.len(),
                comparison_bytes.len(),
                "decoded length mismatch for {}",
                path.display()
            );

            for (idx, (test_byte, comparison_byte)) in
                decoded.iter().zip(comparison_bytes.iter()).enumerate()
            {
                // Show up to 16 bytes on either side of a mismatch for context.
                let start_idx = idx.saturating_sub(16);
                let end_idx = (idx + 16).min(decoded.len());
                assert_eq!(
                    test_byte,
                    comparison_byte,
                    "incorrect byte at index {} of {}, decoded slice: {:?}, comparison_slice: {:?}",
                    idx,
                    path.display(),
                    &decoded[start_idx..end_idx],
                    &comparison_bytes[start_idx..end_idx]
                );
            }
        }
    }

    /// Bytes after the image-end chunk must not prevent decoding.
    #[test]
    fn test_trailing_zero() {
        let path = "test_pngs/trailing_zero.png";
        let png_bytes = std::fs::read(path).unwrap();
        let (_header, _decoded) = decode(&png_bytes)
            .expect("A PNG with trailing zeroes after the ImageEnd chunk should be readable");
    }
}