1#![no_std]
2
3#[macro_use]
4extern crate alloc;
5
6#[cfg(test)]
7extern crate std;
8
9use alloc::vec::Vec;
10use core::convert::{TryFrom, TryInto};
11use crc32fast::Hasher;
12use miniz_oxide::inflate::TINFLStatus;
13use num_enum::TryFromPrimitive;
14
/// The eight-byte signature every PNG file must start with (`\x89PNG\r\n\x1a\n`).
const PNG_MAGIC_BYTES: &[u8] = b"\x89PNG\r\n\x1a\n";
16
17#[repr(u8)]
18#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
19pub enum BitDepth {
20 One = 1,
21 Two = 2,
22 Four = 4,
23 Eight = 8,
24 Sixteen = 16,
25}
26
27#[repr(u8)]
28#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
29pub enum ColorType {
30 Grayscale = 0,
31 Rgb = 2,
32 Palette = 3,
33 GrayscaleAlpha = 4,
34 RgbAlpha = 6,
35}
36
37impl ColorType {
38 pub fn sample_multiplier(&self) -> usize {
39 match self {
40 ColorType::Grayscale => 1,
41 ColorType::Rgb => 3,
42 ColorType::Palette => 1,
43 ColorType::GrayscaleAlpha => 2,
44 ColorType::RgbAlpha => 4,
45 }
46 }
47}
48
/// A concrete (colour type, bit depth) pairing that the decoder knows how to expand to RGBA.
///
/// The numeric suffix is the bit depth of each sample.
#[derive(Clone, Copy, Debug)]
enum PixelType {
    // Single gray sample.
    Grayscale1,
    Grayscale2,
    Grayscale4,
    Grayscale8,
    Grayscale16,

    // Red, green, blue samples.
    Rgb8,
    Rgb16,

    // Index into the PLTE chunk.
    Palette1,
    Palette2,
    Palette4,
    Palette8,

    // Gray sample followed by an alpha sample.
    GrayscaleAlpha8,
    GrayscaleAlpha16,

    // Red, green, blue, alpha samples.
    RgbAlpha8,
    RgbAlpha16,
}
71
72impl PixelType {
73 fn new(color_type: ColorType, bit_depth: BitDepth) -> Result<Self, DecodeError> {
74 let result = match color_type {
75 ColorType::Grayscale => match bit_depth {
76 BitDepth::One => PixelType::Grayscale1,
77 BitDepth::Two => PixelType::Grayscale2,
78 BitDepth::Four => PixelType::Grayscale4,
79 BitDepth::Eight => PixelType::Grayscale8,
80 BitDepth::Sixteen => PixelType::Grayscale16,
81 },
82 ColorType::Rgb => match bit_depth {
83 BitDepth::Eight => PixelType::Rgb8,
84 BitDepth::Sixteen => PixelType::Rgb16,
85 _ => return Err(DecodeError::InvalidColorTypeBitDepthCombination),
86 },
87 ColorType::Palette => match bit_depth {
88 BitDepth::One => PixelType::Palette1,
89 BitDepth::Two => PixelType::Palette2,
90 BitDepth::Four => PixelType::Palette4,
91 BitDepth::Eight => PixelType::Palette8,
92 _ => return Err(DecodeError::InvalidColorTypeBitDepthCombination),
93 },
94 ColorType::GrayscaleAlpha => match bit_depth {
95 BitDepth::Eight => PixelType::GrayscaleAlpha8,
96 BitDepth::Sixteen => PixelType::GrayscaleAlpha16,
97 _ => return Err(DecodeError::InvalidColorTypeBitDepthCombination),
98 },
99 ColorType::RgbAlpha => match bit_depth {
100 BitDepth::Eight => PixelType::RgbAlpha8,
101 BitDepth::Sixteen => PixelType::RgbAlpha16,
102 _ => return Err(DecodeError::InvalidColorTypeBitDepthCombination),
103 },
104 };
105
106 Ok(result)
107 }
108}
109
/// Reduces a 16-bit sample to 8 bits by keeping only its most significant byte.
#[inline(always)]
fn u16_to_u8(val: u16) -> u8 {
    let [high, _low] = val.to_be_bytes();
    high
}
114
115#[derive(Default)]
116struct AncillaryChunks<'a> {
117 palette: Option<&'a [u8]>,
118 transparency: Option<TransparencyChunk<'a>>,
119 background: Option<&'a [u8]>,
120}
121
122struct ScanlineIterator<'a> {
123 image_width: usize, pixel_cursor: usize,
125 pixel_type: PixelType,
126 scanline: &'a [u8],
127 extra_chunks: &'a AncillaryChunks<'a>,
128}
129
130impl<'a> ScanlineIterator<'a> {
131 fn new(
132 image_width: u32,
133 pixel_type: PixelType,
134 scanline: &'a [u8],
135 extra_chunks: &'a AncillaryChunks<'a>,
136 ) -> Self {
137 Self {
138 image_width: image_width as usize,
139 pixel_cursor: 0,
140 pixel_type,
141 scanline,
142 extra_chunks,
143 }
144 }
145}
146
147impl<'a> Iterator for ScanlineIterator<'a> {
148 type Item = (u8, u8, u8, u8);
149
150 fn next(&mut self) -> Option<Self::Item> {
151 if self.pixel_cursor >= self.image_width {
152 return None;
153 }
154
155 let pixel = match self.pixel_type {
156 PixelType::Grayscale1 => {
157 let byte = self.scanline[self.pixel_cursor / 8];
158 let bit_offset = 7 - self.pixel_cursor % 8;
159 let grayscale_val = (byte >> bit_offset) & 1;
160
161 let alpha = match self.extra_chunks.transparency {
162 Some(TransparencyChunk::Grayscale(transparent_val))
163 if grayscale_val == transparent_val =>
164 {
165 0
166 },
167 _ => 255,
168 };
169
170 let pixel_val = grayscale_val * 255;
171
172 Some((pixel_val, pixel_val, pixel_val, alpha))
173 },
174 PixelType::Grayscale2 => {
175 let byte = self.scanline[self.pixel_cursor / 4];
176 let bit_offset = 6 - ((self.pixel_cursor % 4) * 2);
177 let grayscale_val = (byte >> bit_offset) & 0b11;
178
179 let alpha = match self.extra_chunks.transparency {
180 Some(TransparencyChunk::Grayscale(transparent_val))
181 if grayscale_val == transparent_val =>
182 {
183 0
184 },
185 _ => 255,
186 };
187
188 let pixel_val = ((grayscale_val as f32 / 3.0) * 255.0) as u8;
190
191 Some((pixel_val, pixel_val, pixel_val, alpha))
192 },
193 PixelType::Grayscale4 => {
194 let byte = self.scanline[self.pixel_cursor / 2];
195 let bit_offset = 4 - ((self.pixel_cursor % 2) * 4);
196 let grayscale_val = (byte >> bit_offset) & 0b1111;
197
198 let alpha = match self.extra_chunks.transparency {
199 Some(TransparencyChunk::Grayscale(transparent_val))
200 if grayscale_val == transparent_val =>
201 {
202 0
203 },
204 _ => 255,
205 };
206
207 let pixel_val = ((grayscale_val as f32 / 15.0) * 255.0) as u8;
209 Some((pixel_val, pixel_val, pixel_val, alpha))
210 },
211 PixelType::Grayscale8 => {
212 let byte = self.scanline[self.pixel_cursor];
213
214 let alpha = match self.extra_chunks.transparency {
215 Some(TransparencyChunk::Grayscale(transparent_val))
216 if byte == transparent_val =>
217 {
218 0
219 },
220 _ => 255,
221 };
222 Some((byte, byte, byte, alpha))
223 },
224 PixelType::Grayscale16 => {
225 let offset = self.pixel_cursor * 2;
226 let grayscale_val =
227 u16::from_be_bytes([self.scanline[offset], self.scanline[offset + 1]]);
228
229 let pixel_val = u16_to_u8(grayscale_val);
230
231 let alpha = match self.extra_chunks.transparency {
235 Some(TransparencyChunk::Grayscale(transparent_val))
236 if pixel_val == transparent_val =>
237 {
238 0
239 },
240 _ => 255,
241 };
242
243 Some((pixel_val, pixel_val, pixel_val, alpha))
244 },
245 PixelType::Rgb8 => {
246 let offset = self.pixel_cursor * 3;
247 let r = self.scanline[offset];
248 let g = self.scanline[offset + 1];
249 let b = self.scanline[offset + 2];
250
251 let alpha = match self.extra_chunks.transparency {
252 Some(TransparencyChunk::Rgb(t_r, t_g, t_b))
253 if r == t_r && g == t_g && b == t_b =>
254 {
255 0
256 },
257 _ => 255,
258 };
259
260 Some((r, g, b, alpha))
261 },
262 PixelType::Rgb16 => {
263 let offset = self.pixel_cursor * 6;
264 let r = u16::from_be_bytes([self.scanline[offset], self.scanline[offset + 1]]);
265 let g = u16::from_be_bytes([self.scanline[offset + 2], self.scanline[offset + 3]]);
266 let b = u16::from_be_bytes([self.scanline[offset + 4], self.scanline[offset + 5]]);
267
268 let r = u16_to_u8(r);
269 let g = u16_to_u8(g);
270 let b = u16_to_u8(b);
271
272 let alpha = match self.extra_chunks.transparency {
273 Some(TransparencyChunk::Rgb(t_r, t_g, t_b))
274 if r == t_r && g == t_g && b == t_b =>
275 {
276 0
277 },
278 _ => 255,
279 };
280
281 Some((r, g, b, alpha))
282 },
283 PixelType::Palette1 => {
284 let byte = self.scanline[self.pixel_cursor / 8];
285 let bit_offset = 7 - self.pixel_cursor % 8;
286 let palette_idx = ((byte >> bit_offset) & 1) as usize;
287
288 let offset = palette_idx * 3;
289
290 let palette = self.extra_chunks.palette.unwrap();
291 let r = palette[offset];
292 let g = palette[offset + 1];
293 let b = palette[offset + 2];
294
295 let alpha: u8 = match self.extra_chunks.transparency {
296 Some(TransparencyChunk::Palette(data)) => {
297 *data.get(palette_idx).unwrap_or(&255)
298 },
299 Some(_) | None => 255,
300 };
301
302 Some((r, g, b, alpha))
303 },
304 PixelType::Palette2 => {
305 let byte = self.scanline[self.pixel_cursor / 4];
306 let bit_offset = 6 - ((self.pixel_cursor % 4) * 2);
307 let palette_idx = ((byte >> bit_offset) & 0b11) as usize;
308
309 let offset = palette_idx * 3;
310
311 let palette = self.extra_chunks.palette.unwrap();
312 let r = palette[offset];
313 let g = palette[offset + 1];
314 let b = palette[offset + 2];
315
316 let alpha: u8 = match self.extra_chunks.transparency {
317 Some(TransparencyChunk::Palette(data)) => {
318 *data.get(palette_idx).unwrap_or(&255)
319 },
320 Some(_) | None => 255,
321 };
322
323 Some((r, g, b, alpha))
324 },
325 PixelType::Palette4 => {
326 let byte = self.scanline[self.pixel_cursor / 2];
327 let bit_offset = 4 - ((self.pixel_cursor % 2) * 4);
328 let palette_idx = ((byte >> bit_offset) & 0b1111) as usize;
329
330 let offset = palette_idx * 3;
331
332 let palette = self.extra_chunks.palette.unwrap();
333 let r = palette[offset];
334 let g = palette[offset + 1];
335 let b = palette[offset + 2];
336
337 let alpha: u8 = match self.extra_chunks.transparency {
338 Some(TransparencyChunk::Palette(data)) => {
339 *data.get(palette_idx).unwrap_or(&255)
340 },
341 Some(_) | None => 255,
342 };
343
344 Some((r, g, b, alpha))
345 },
346 PixelType::Palette8 => {
347 let offset = self.scanline[self.pixel_cursor] as usize * 3;
348
349 let palette = self.extra_chunks.palette.unwrap();
350 let r = palette[offset];
351 let g = palette[offset + 1];
352 let b = palette[offset + 2];
353
354 let alpha: u8 = match self.extra_chunks.transparency {
355 Some(TransparencyChunk::Palette(data)) => *data.get(offset).unwrap_or(&255),
356 Some(_) | None => 255,
357 };
358
359 Some((r, g, b, alpha))
360 },
361 PixelType::GrayscaleAlpha8 => {
362 let offset = self.pixel_cursor * 2;
363 let grayscale_val = self.scanline[offset];
364 let alpha = self.scanline[offset + 1];
365
366 Some((grayscale_val, grayscale_val, grayscale_val, alpha))
367 },
368 PixelType::GrayscaleAlpha16 => {
369 let offset = self.pixel_cursor * 4;
370 let grayscale_val =
371 u16::from_be_bytes([self.scanline[offset], self.scanline[offset + 1]]);
372 let alpha =
373 u16::from_be_bytes([self.scanline[offset + 2], self.scanline[offset + 3]]);
374
375 let grayscale_val = u16_to_u8(grayscale_val);
376 let alpha = u16_to_u8(alpha);
377
378 Some((grayscale_val, grayscale_val, grayscale_val, alpha))
379 },
380 PixelType::RgbAlpha8 => {
381 let offset = self.pixel_cursor * 4;
382 let r = self.scanline[offset];
383 let g = self.scanline[offset + 1];
384 let b = self.scanline[offset + 2];
385 let a = self.scanline[offset + 3];
386
387 Some((r, g, b, a))
388 },
389 PixelType::RgbAlpha16 => {
390 let offset = self.pixel_cursor * 8;
391 let r = u16::from_be_bytes([self.scanline[offset], self.scanline[offset + 1]]);
392 let g = u16::from_be_bytes([self.scanline[offset + 2], self.scanline[offset + 3]]);
393 let b = u16::from_be_bytes([self.scanline[offset + 4], self.scanline[offset + 5]]);
394 let a = u16::from_be_bytes([self.scanline[offset + 6], self.scanline[offset + 7]]);
395
396 let r = u16_to_u8(r);
397 let g = u16_to_u8(g);
398 let b = u16_to_u8(b);
399 let a = u16_to_u8(a);
400
401 Some((r, g, b, a))
402 },
403 };
404
405 self.pixel_cursor += 1;
406 pixel
407 }
408}
409
410#[repr(u8)]
411#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
412pub enum CompressionMethod {
413 Deflate = 0,
414}
415
416#[repr(u8)]
417#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
418pub enum FilterMethod {
419 Adaptive = 0,
420}
421
422#[repr(u8)]
423#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
424pub enum FilterType {
425 None = 0,
426 Sub = 1,
427 Up = 2,
428 Average = 3,
429 Paeth = 4,
430}
431
432#[repr(u8)]
433#[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive)]
434pub enum InterlaceMethod {
435 None = 0,
436 Adam7 = 1,
437}
438
439#[derive(Debug, Clone, PartialEq, Eq)]
440pub struct PngHeader {
441 pub width: u32,
442 pub height: u32,
443 pub bit_depth: BitDepth,
444 pub color_type: ColorType,
445 pub compression_method: CompressionMethod,
446 pub filter_method: FilterMethod,
447 pub interlace_method: InterlaceMethod,
448}
449
450impl PngHeader {
451 fn from_chunk(chunk: &Chunk) -> Result<Self, DecodeError> {
452 if chunk.chunk_type != ChunkType::ImageHeader {
453 return Err(DecodeError::InvalidChunkType);
454 }
455
456 if chunk.data.len() < 13 {
457 return Err(DecodeError::MissingBytes);
458 }
459
460 let width = read_u32(chunk.data, 0);
461 let height = read_u32(chunk.data, 4);
462 let bit_depth = chunk.data[8];
463 let color_type = chunk.data[9];
464 let compression_method = chunk.data[10];
465 let filter_method = chunk.data[11];
466 let interlace_method = chunk.data[12];
467
468 Ok(PngHeader {
469 width,
470 height,
471 bit_depth: TryFrom::try_from(bit_depth).map_err(|_| DecodeError::InvalidBitDepth)?,
472 color_type: TryFrom::try_from(color_type).map_err(|_| DecodeError::InvalidColorType)?,
473 compression_method: TryFrom::try_from(compression_method)
474 .map_err(|_| DecodeError::InvalidCompressionMethod)?,
475 filter_method: TryFrom::try_from(filter_method)
476 .map_err(|_| DecodeError::InvalidFilterMethod)?,
477 interlace_method: TryFrom::try_from(interlace_method)
478 .map_err(|_| DecodeError::InvalidInterlaceMethod)?,
479 })
480 }
481}
482
483#[derive(Debug, Clone, PartialEq, Eq)]
484pub enum DecodeError {
485 InvalidMagicBytes,
486 MissingBytes,
487 HeaderChunkNotFirst,
488 EndChunkNotLast,
489 InvalidChunkType,
490 InvalidChunk,
491 Decompress(TINFLStatus),
492
493 IncorrectChunkCrc,
494 InvalidBitDepth,
495 InvalidColorType,
496 InvalidColorTypeBitDepthCombination,
497 InvalidCompressionMethod,
498 InvalidFilterMethod,
499 InvalidFilterType,
500 InvalidInterlaceMethod,
501
502 IntegerOverflow,
505}
506
/// Chunk kinds the decoder distinguishes, keyed by their four-byte tag.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChunkType {
    /// `IHDR`
    ImageHeader,
    /// `PLTE`
    Palette,
    /// `tRNS`
    Transparency,
    /// `bKGD`
    Background,
    /// `sRGB`
    Srgb,
    /// `IDAT`
    ImageData,
    /// `IEND`
    ImageEnd,
    /// `gAMA`
    Gamma,
    /// Any other tag, kept verbatim.
    Unknown([u8; 4]),
}
519
520impl ChunkType {
521 fn from_bytes(bytes: &[u8; 4]) -> Self {
522 match bytes {
523 b"IHDR" => ChunkType::ImageHeader,
524 b"PLTE" => ChunkType::Palette,
525 b"tRNS" => ChunkType::Transparency,
526 b"bKGD" => ChunkType::Background,
527 b"sRGB" => ChunkType::Srgb,
528 b"IDAT" => ChunkType::ImageData,
529 b"IEND" => ChunkType::ImageEnd,
530 b"gAMA" => ChunkType::Gamma,
531 unknown_chunk_type => {
532 ChunkType::Unknown(*unknown_chunk_type)
534 },
535 }
536 }
537}
538
539#[derive(Debug)]
540struct Chunk<'a> {
541 chunk_type: ChunkType,
542 data: &'a [u8],
543 _crc: u32,
544}
545
546impl<'a> Chunk<'a> {
547 fn byte_size(&self) -> usize {
548 4 + 4 + self.data.len() + 4
550 }
551}
552
/// Decoded tRNS data, shaped according to the image's pixel type.
enum TransparencyChunk<'a> {
    /// Alpha value per palette entry, in palette order.
    Palette(&'a [u8]),
    /// Gray value that should be rendered fully transparent.
    Grayscale(u8),
    /// RGB colour that should be rendered fully transparent.
    Rgb(u8, u8, u8),
}
558
559impl<'a> TransparencyChunk<'a> {
560 fn from_chunk(chunk: &Chunk<'a>, pixel_type: PixelType) -> Option<Self> {
561 match pixel_type {
562 PixelType::Grayscale1 => Some(TransparencyChunk::Grayscale(chunk.data[1] & 0b1)),
563 PixelType::Grayscale2 => Some(TransparencyChunk::Grayscale(chunk.data[1] & 0b11)),
564 PixelType::Grayscale4 => Some(TransparencyChunk::Grayscale(chunk.data[1] & 0b1111)),
565 PixelType::Grayscale8 => Some(TransparencyChunk::Grayscale(chunk.data[1])),
566 PixelType::Grayscale16 => {
567 let val = u16::from_be_bytes([chunk.data[0], chunk.data[1]]);
568 Some(TransparencyChunk::Grayscale(u16_to_u8(val)))
569 },
570 PixelType::Rgb8 => {
571 let r = chunk.data[1];
572 let g = chunk.data[3];
573 let b = chunk.data[5];
574 Some(TransparencyChunk::Rgb(r, g, b))
575 },
576 PixelType::Rgb16 => {
577 let r = u16::from_be_bytes([chunk.data[0], chunk.data[1]]);
578 let g = u16::from_be_bytes([chunk.data[2], chunk.data[3]]);
579 let b = u16::from_be_bytes([chunk.data[4], chunk.data[5]]);
580 Some(TransparencyChunk::Rgb(u16_to_u8(r), u16_to_u8(g), u16_to_u8(b)))
581 },
582 PixelType::Palette1 => Some(TransparencyChunk::Palette(chunk.data)),
583 PixelType::Palette2 => Some(TransparencyChunk::Palette(chunk.data)),
584 PixelType::Palette4 => Some(TransparencyChunk::Palette(chunk.data)),
585 PixelType::Palette8 => Some(TransparencyChunk::Palette(chunk.data)),
586 PixelType::GrayscaleAlpha8 => None,
587 PixelType::GrayscaleAlpha16 => None,
588 PixelType::RgbAlpha8 => None,
589 PixelType::RgbAlpha16 => None,
590 }
591 }
592}
593
/// Reads a big-endian `u32` from the four bytes of `bytes` starting at `offset`.
///
/// Panics if fewer than four bytes are available at `offset`.
fn read_u32(bytes: &[u8], offset: usize) -> u32 {
    let word: [u8; 4] = bytes[offset..offset + 4]
        .try_into()
        .expect("a four-byte slice always converts to [u8; 4]");
    u32::from_be_bytes(word)
}
597
598fn read_chunk(bytes: &[u8]) -> Result<Chunk<'_>, DecodeError> {
599 if bytes.len() < 4 {
600 return Err(DecodeError::MissingBytes);
601 }
602
603 let length = read_u32(bytes, 0) as usize;
604 let bytes = &bytes[4..];
605
606 if bytes.len() < (4 + length + 4) {
607 return Err(DecodeError::MissingBytes);
608 }
609
610 let chunk_type = ChunkType::from_bytes(&[bytes[0], bytes[1], bytes[2], bytes[3]]);
611
612 let crc_offset = 4 + length;
613 let crc = read_u32(bytes, crc_offset);
614
615 let data_for_crc = &bytes[..crc_offset];
617
618 let mut hasher = Hasher::new();
619 hasher.reset();
620 hasher.update(data_for_crc);
621
622 if crc != hasher.finalize() {
623 return Err(DecodeError::IncorrectChunkCrc);
624 }
625
626 Ok(Chunk { chunk_type, data: &data_for_crc[4..], _crc: crc })
627}
628
629fn defilter(
630 filter_type: FilterType,
631 bytes_per_pixel: usize,
632 x: usize,
633 current_scanline: &[u8],
634 last_scanline: &[u8],
635) -> u8 {
636 match filter_type {
637 FilterType::None => current_scanline[x],
638 FilterType::Sub => {
639 if let Some(idx) = x.checked_sub(bytes_per_pixel) {
640 current_scanline[x].wrapping_add(current_scanline[idx])
641 } else {
642 current_scanline[x]
643 }
644 },
645 FilterType::Up => current_scanline[x].wrapping_add(last_scanline[x]),
646 FilterType::Average => {
647 let raw_val = if let Some(idx) = x.checked_sub(bytes_per_pixel) {
648 current_scanline[idx]
649 } else {
650 0
651 };
652
653 (current_scanline[x] as u16 + ((raw_val as u16 + last_scanline[x] as u16) / 2)) as u8
654 },
655 FilterType::Paeth => {
656 if let Some(idx) = x.checked_sub(bytes_per_pixel) {
657 let left = current_scanline[idx];
658 let above = last_scanline[x];
659 let upper_left = last_scanline[idx];
660
661 let predictor = paeth_predictor(left as i16, above as i16, upper_left as i16);
662
663 current_scanline[x].wrapping_add(predictor)
664 } else {
665 let left = 0;
666 let above = last_scanline[x];
667 let upper_left = 0;
668
669 let predictor = paeth_predictor(left as i16, above as i16, upper_left as i16);
670
671 current_scanline[x].wrapping_add(predictor)
672 }
673 },
674 }
675}
676
677fn process_scanlines(
678 header: &PngHeader,
679 scanline_data: &mut [u8],
680 output_rgba: &mut [[u8; 4]],
681 ancillary_chunks: &AncillaryChunks,
682 pixel_type: PixelType,
683) -> Result<(), DecodeError> {
684 let mut cursor = 0;
685 let bytes_per_pixel: usize =
686 (header.bit_depth as usize * header.color_type.sample_multiplier()).div_ceil(8);
687
688 match header.interlace_method {
689 InterlaceMethod::None => {
690 let bytes_per_scanline = (header.width as u64
692 * header.bit_depth as u64
693 * header.color_type.sample_multiplier() as u64)
694 .div_ceil(8);
695 let bytes_per_scanline: usize =
696 bytes_per_scanline.try_into().map_err(|_| DecodeError::IntegerOverflow)?;
697
698 let mut last_scanline = vec![0u8; bytes_per_scanline];
699
700 for y in 0..header.height {
701 let filter_type = FilterType::try_from(scanline_data[cursor])
702 .map_err(|_| DecodeError::InvalidFilterType)?;
703 cursor += 1;
704
705 let current_scanline = &mut scanline_data[cursor..(cursor + bytes_per_scanline)];
706
707 for x in 0..(bytes_per_scanline) {
708 let unfiltered_byte =
709 defilter(filter_type, bytes_per_pixel, x, current_scanline, &last_scanline);
710 current_scanline[x] = unfiltered_byte;
711 }
712
713 let scanline_iter = ScanlineIterator::new(
714 header.width,
715 pixel_type,
716 current_scanline,
717 ancillary_chunks,
718 );
719
720 for (idx, (r, g, b, a)) in scanline_iter.enumerate() {
721 let (output_x, output_y) = (idx, y);
722
723 let output_idx = (output_y as u64 * header.width as u64) + (output_x as u64);
724 let output_idx: usize =
725 output_idx.try_into().map_err(|_| DecodeError::IntegerOverflow)?;
726
727 output_rgba[output_idx] = [r, g, b, a];
728 }
729
730 last_scanline.copy_from_slice(current_scanline);
731 cursor += bytes_per_scanline;
732 }
733 },
734 InterlaceMethod::Adam7 => {
735 let max_bytes_per_scanline = header.width as usize * bytes_per_pixel;
736 let mut last_scanline = vec![0u8; max_bytes_per_scanline];
737
738 for pass in 1..=7 {
749 let (pass_width, pass_height) = match pass {
750 1 => {
751 let pass_width = header.width.div_ceil(8);
752 let pass_height = header.height.div_ceil(8);
753 (pass_width, pass_height)
754 },
755 2 => {
756 let pass_width = (header.width / 8) + ((header.width % 8) / 5);
757 let pass_height = header.height.div_ceil(8);
758 (pass_width, pass_height)
759 },
760 3 => {
761 let pass_width = ((header.width / 8) * 2) + (header.width % 8).div_ceil(4);
762 let pass_height = (header.height / 8) + ((header.height % 8) / 5);
763 (pass_width, pass_height)
764 },
765 4 => {
766 let pass_width = ((header.width / 8) * 2) + (header.width % 8 + 1) / 4;
767 let pass_height = header.height.div_ceil(4);
768 (pass_width, pass_height)
769 },
770 5 => {
771 let pass_width = (header.width / 2) + (header.width % 2);
772 let pass_height = ((header.height / 8) * 2) + (header.height % 8 + 1) / 4;
773 (pass_width, pass_height)
774 },
775 6 => {
776 let pass_width = header.width / 2;
777 let pass_height = (header.height / 2) + (header.height % 2);
778 (pass_width, pass_height)
779 },
780 7 => {
781 let pass_width = header.width;
782 let pass_height = header.height / 2;
783 (pass_width, pass_height)
784 },
785 _ => (0, 0),
786 };
787
788 if pass_width == 0 || pass_height == 0 {
790 continue;
791 }
792
793 let bytes_per_scanline = (pass_width as u64
794 * header.bit_depth as u64
795 * header.color_type.sample_multiplier() as u64)
796 .div_ceil(8);
797 let bytes_per_scanline: usize =
798 bytes_per_scanline.try_into().expect("bytes_per_scanline overflowed a usize");
799
800 let last_scanline = &mut last_scanline[..(bytes_per_scanline)];
801 for byte in last_scanline.iter_mut() {
802 *byte = 0;
803 }
804
805 for y in 0..pass_height {
806 let filter_type = FilterType::try_from(scanline_data[cursor])
807 .map_err(|_| DecodeError::InvalidFilterType)?;
808 cursor += 1;
809
810 let current_scanline =
811 &mut scanline_data[cursor..(cursor + bytes_per_scanline)];
812
813 for x in 0..(bytes_per_scanline) {
814 let unfiltered_byte = defilter(
815 filter_type,
816 bytes_per_pixel,
817 x,
818 current_scanline,
819 last_scanline,
820 );
821 current_scanline[x] = unfiltered_byte;
822 }
823
824 let scanline_iter = ScanlineIterator::new(
825 pass_width,
826 pixel_type,
827 current_scanline,
828 ancillary_chunks,
829 );
830
831 for (idx, (r, g, b, a)) in scanline_iter.enumerate() {
832 let (output_x, output_y) = match pass {
834 1 => (idx * 8, y * 8),
835 2 => (idx * 8 + 4, y * 8),
836 3 => (idx * 4, y * 8 + 4),
837 4 => (idx * 4 + 2, y * 4),
838 5 => (idx * 2, y * 4 + 2),
839 6 => (idx * 2 + 1, y * 2),
840 7 => (idx, y * 2 + 1),
841 _ => (0, 0),
842 };
843
844 let output_idx =
845 (output_y as u64 * header.width as u64) + (output_x as u64);
846 let output_idx: usize =
847 output_idx.try_into().map_err(|_| DecodeError::IntegerOverflow)?;
848
849 output_rgba[output_idx] = [r, g, b, a];
850 }
851
852 last_scanline.copy_from_slice(current_scanline);
853
854 cursor += bytes_per_scanline;
855 }
856 }
857 },
858 }
859
860 Ok(())
861}
862
/// The PNG Paeth predictor.
///
/// Returns whichever of `a` (left), `b` (above) or `c` (upper left) is closest to
/// `a + b - c`, preferring `a`, then `b`, on ties.
fn paeth_predictor(a: i16, b: i16, c: i16) -> u8 {
    let estimate = a + b - c;
    let distance = |neighbour: i16| (estimate - neighbour).abs();
    let (dist_a, dist_b, dist_c) = (distance(a), distance(b), distance(c));

    let nearest = if dist_a <= dist_b && dist_a <= dist_c {
        a
    } else if dist_b <= dist_c {
        b
    } else {
        c
    };

    nearest as u8
}
881
882pub fn decode(bytes: &[u8]) -> Result<(PngHeader, Vec<[u8; 4]>), DecodeError> {
892 if bytes.len() < PNG_MAGIC_BYTES.len() {
893 return Err(DecodeError::MissingBytes);
894 }
895
896 if &bytes[0..PNG_MAGIC_BYTES.len()] != PNG_MAGIC_BYTES {
897 return Err(DecodeError::InvalidMagicBytes);
898 }
899
900 let bytes = &bytes[PNG_MAGIC_BYTES.len()..];
901
902 let header_chunk = read_chunk(bytes)?;
903 let header = PngHeader::from_chunk(&header_chunk)?;
904
905 let mut bytes = &bytes[header_chunk.byte_size()..];
906
907 let mut compressed_data: Vec<u8> =
908 Vec::with_capacity(header.width as usize * header.height as usize * 3);
909
910 let pixel_type = PixelType::new(header.color_type, header.bit_depth)?;
911 let mut ancillary_chunks = AncillaryChunks::default();
912
913 while !bytes.is_empty() {
914 let chunk = read_chunk(bytes)?;
915
916 match chunk.chunk_type {
917 ChunkType::ImageData => compressed_data.extend_from_slice(chunk.data),
918 ChunkType::Palette => ancillary_chunks.palette = Some(chunk.data),
919 ChunkType::Transparency => {
920 ancillary_chunks.transparency = TransparencyChunk::from_chunk(&chunk, pixel_type)
921 },
922 ChunkType::Background => ancillary_chunks.background = Some(chunk.data),
923 ChunkType::ImageEnd => break,
924 _ => {},
925 }
926
927 bytes = &bytes[chunk.byte_size()..];
928 }
929
930 let mut scanline_data = miniz_oxide::inflate::decompress_to_vec_zlib(&compressed_data)
931 .map_err(|miniz_oxide::inflate::DecompressError { status, output: _ }| {
932 DecodeError::Decompress(status)
933 })?;
934
935 let mut output_rgba = vec![[0u8; 4]; header.width as usize * header.height as usize];
937
938 process_scanlines(
939 &header,
940 &mut scanline_data,
941 &mut output_rgba,
942 &ancillary_chunks,
943 pixel_type,
944 )?;
945
946 Ok((header, output_rgba))
947}
948
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes every image in the PngSuite corpus.
    ///
    /// Files whose names start with `x` are deliberately corrupt and must be rejected. Every
    /// other file must decode to exactly the RGBA8 bytes that the `image` crate produces.
    #[test]
    fn png_suite_test() {
        use image::EncodableLayout;

        for entry in
            std::fs::read_dir("test_pngs/png_suite").expect("PngSuite directory should exist")
        {
            let path = entry.unwrap().path();

            let is_png = path
                .extension()
                .and_then(|os_str| os_str.to_str())
                .is_some_and(|extension| extension.eq_ignore_ascii_case("png"));
            if !is_png {
                continue;
            }

            let png_bytes = std::fs::read(&path).unwrap();

            let is_corrupt = path
                .file_stem()
                .expect("expected png path to be a file")
                .to_string_lossy()
                .starts_with('x');
            if is_corrupt {
                assert!(decode(&png_bytes).is_err(), "corrupt file {:?} should not decode", path);
                continue;
            }

            let (_header, decoded): (PngHeader, Vec<[u8; 4]>) = decode(&png_bytes).unwrap();
            let decoded: Vec<u8> = decoded.into_flattened();

            let comparison_rgba8 = image::open(&path).unwrap().to_rgba8();
            let comparison_bytes = comparison_rgba8.as_bytes();
            assert_eq!(decoded.len(), comparison_bytes.len());

            for (idx, (test_byte, comparison_byte)) in
                decoded.iter().zip(comparison_bytes.iter()).enumerate()
            {
                // Show a small window around the first mismatch to ease debugging.
                let start_idx = idx.saturating_sub(16);
                let end_idx = (idx + 16).min(decoded.len());
                assert_eq!(
                    test_byte,
                    comparison_byte,
                    "incorrect byte at index {}, decoded slice: {:?}, comparison_slice: {:?}",
                    idx,
                    &decoded[start_idx..end_idx],
                    &comparison_bytes[start_idx..end_idx]
                );
            }
        }
    }

    /// Bytes after the IEND chunk must be ignored rather than parsed as another chunk.
    #[test]
    fn test_trailing_zero() {
        let path = "test_pngs/trailing_zero.png";
        let png_bytes = std::fs::read(path).unwrap();
        let (_header, _decoded) = decode(&png_bytes)
            .expect("A PNG with trailing zeroes after the ImageEnd chunk should be readable");
    }
}
1016}