//! embedded_png/inflate.rs
//!
//! Streaming decompression of concatenated PNG ImageData (IDAT) chunks,
//! with per-scanline filter decoding.

1use core::cmp::min;
2use miniz_oxide::inflate::core::{decompress, DecompressorOxide};
3use miniz_oxide::inflate::core::inflate_flags::{TINFL_FLAG_COMPUTE_ADLER32, TINFL_FLAG_HAS_MORE_INPUT, TINFL_FLAG_PARSE_ZLIB_HEADER};
4use miniz_oxide::inflate::TINFLStatus;
5use crate::error::DecodeError;
6use crate::png::Chunk;
7use crate::types::{ChunkType, FilterType};
8
/// The decompressor is implemented as a "Q" buffer:
/// Q because it is a circular buffer + a linear buffer, which, if you draw them, looks like a Q.
///
/// Data is always decompressed into the circular buffer, but since the whole circular buffer
/// might be needed when extracting new data, we temporarily store remaining (not yet
/// consumed) data in the linear buffer.
///
/// The linear part would not be needed if we could control the maximum data that must be decompressed
pub struct ChunkDecompressor<'src, T> {
    // internal miniz decompressor state
    decompressor: DecompressorOxide,
    // slice containing all ImageData (IDAT) chunks
    data_chunks: &'src [u8],
    // byte offset of the next chunk in this slice; None once exhausted
    next_chunk_start: Option<usize>,
    // payload slice of the chunk currently being consumed
    current_chunk: Option<&'src [u8]>,
    // true when all data have been taken from chunks
    chunk_end: bool,
    // circular buffer where decompressed data is written
    buffer: T,
    // index (into `buffer`) of the first waiting decompressed byte
    data_pos: usize,
    // number of waiting decompressed bytes in `buffer`
    buffer_count: usize,
    // non-circular buffer: data that must be kept out of the circular buffer during a decompress call
    buffer_extra: T,
    // number of waiting bytes in buffer_extra (consumed before `buffer`)
    extra_count: usize,
    // common miniz flags passed to every decompress call
    flags: u32,
    total_decompressed: usize, // TODO remove
    // TODO should we have a next_scanline_size here ?
}
43
44impl<'src, 'buf> ChunkDecompressor<'src, &'buf mut [u8]> {
45    // buffer size must be >= min(decompression_window(32k), total_output_size)
46    // buffer extra size must be >= max scanline bytes
47    pub fn new_ref(data_chunks: &'src [u8], buffer: &'buf mut [u8], buffer_extra: &'buf mut [u8], check_crc: bool) -> Self {
48        let decompressor = DecompressorOxide::new();
49        // png has zlib header, always pass has more input, it doesn't matter if it's false
50        let mut flags = TINFL_FLAG_PARSE_ZLIB_HEADER | TINFL_FLAG_HAS_MORE_INPUT;
51        if check_crc {
52            flags |= TINFL_FLAG_COMPUTE_ADLER32;
53        }
54        ChunkDecompressor {
55            decompressor,
56            data_chunks,
57            next_chunk_start: Some(0),
58            current_chunk: None,
59            chunk_end: false,
60            buffer,
61            data_pos: 0,
62            buffer_count: 0,
63            buffer_extra,
64            extra_count: 0,
65            flags,
66            total_decompressed: 0,
67        }
68    }
69}
70
71// TODO alloc only
72impl<'src> ChunkDecompressor<'src, Vec<u8>> {
73    // buffer size must be >= min(decompression_window(32k), total_output_size)
74    // buffer extra size must be >= max scanline bytes
75    pub fn new_vec(data_chunks: &'src [u8], check_crc: bool) -> Self {
76        let decompressor = DecompressorOxide::new();
77        // png has zlib header, always pass has more input, it doesn't matter if it's false
78        let mut flags = TINFL_FLAG_PARSE_ZLIB_HEADER | TINFL_FLAG_HAS_MORE_INPUT;
79        if check_crc {
80            flags |= TINFL_FLAG_COMPUTE_ADLER32;
81        }
82        ChunkDecompressor {
83            decompressor,
84            data_chunks,
85            next_chunk_start: Some(0),
86            current_chunk: None,
87            chunk_end: false,
88            buffer: vec![0_u8; 1024<<5],
89            data_pos: 0,
90            buffer_count: 0,
91            buffer_extra: vec![0_u8; 1024<<5], // TODO
92            extra_count: 0,
93            flags,
94            total_decompressed: 0,
95        }
96    }
97    pub fn update_buffers(&mut self) {}
98}
99
100impl<'src> ChunkDecompressor<'src, [u8; 1024<<5]> {
101    // buffer size must be >= min(decompression_window(32k), total_output_size)
102    // buffer extra size must be >= max scanline bytes
103    pub fn new_static(data_chunks: &'src [u8], check_crc: bool) -> Self {
104        let decompressor = DecompressorOxide::new();
105        // png has zlib header, always pass has more input, it doesn't matter if it's false
106        let mut flags = TINFL_FLAG_PARSE_ZLIB_HEADER | TINFL_FLAG_HAS_MORE_INPUT;
107        if check_crc {
108            flags |= TINFL_FLAG_COMPUTE_ADLER32;
109        }
110        ChunkDecompressor {
111            decompressor,
112            data_chunks,
113            next_chunk_start: Some(0),
114            current_chunk: None,
115            chunk_end: false,
116            buffer: [0_u8; 1024<<5],
117            data_pos: 0,
118            buffer_count: 0,
119            buffer_extra: [0_u8; 1024<<5], // TODO
120            extra_count: 0,
121            flags,
122            total_decompressed: 0,
123        }
124    }
125}
126
127
128impl<'src, T> ChunkDecompressor<'src, T>
129where T: AsRef<[u8]> + AsMut<[u8]>
130{
131    // advance current chunk by one, result in self.current_chunk
132    fn check_chunk_data(&mut self) {
133        // we already have some data
134        if let Some(chunk) = self.current_chunk && !chunk.is_empty() {
135            return;
136        }
137        // loop just in case there are empty chunks
138        loop {
139            if let Some(next_start) = self.next_chunk_start && next_start < self.data_chunks.len() {
140                // it was already checked during first parse, so we can unwrap, and avoid crc check
141                let next_chunk = Chunk::from_bytes(self.data_chunks, next_start, false).unwrap();
142                if next_chunk.end < self.data_chunks.len() {
143                    self.next_chunk_start = Some(next_chunk.end);
144                } else {
145                    // TODO smelly, we assign none at 2 different steps
146                    self.next_chunk_start = None;
147                }
148                if next_chunk.chunk_type == ChunkType::ImageData {
149                    if !next_chunk.data.is_empty() {
150                        self.current_chunk = Some(next_chunk.data);
151                        return;
152                    }
153                }
154            } else {
155                // this is the end my friend
156                self.current_chunk = None;
157                self.next_chunk_start = None;
158                self.chunk_end = true;
159                return;
160            }
161        }
162    }
163
164    // get the next scanline, extracting data with the decompressor if needed
165    fn get_enough_data(&mut self, size: usize) -> Result<(), DecodeError> {
166        debug_assert!(size <= self.buffer.as_ref().len(), "Decompression buffer too small (need {})", size);
167        debug_assert!(size >= self.extra_count, "Decompression extra buffer too big (need {})", size);
168        // we already have enough data
169        if self.buffer_count + self.extra_count >= size {
170            return Ok(());
171        }
172
173        // now we need to decompress some bytes
174        // position where to start decompressing
175        let mut buffer_pos = self.buffer_count + self.data_pos;
176
177        // since decompress() does not cross the buffer wrap but can grow through the end
178        // we must save any data that is after buffer_pos and before buffer.len()
179        if buffer_pos >= self.buffer.as_ref().len() {
180            buffer_pos -= self.buffer.as_ref().len();
181            // save what's after the new buffer_pos (between data_pos and buffer.len())
182            let data_count = self.buffer.as_ref().len() - self.data_pos;
183            let new_extra_count = self.extra_count + data_count;
184            // there is no overflow since buffer_extra has enough data for a single scanline
185            self.buffer_extra.as_mut()[self.extra_count..new_extra_count].copy_from_slice(&self.buffer.as_ref()[self.data_pos..self.buffer.as_ref().len()]);
186            // update info accordingly
187            self.data_pos = 0;
188            self.extra_count = new_extra_count;
189            self.buffer_count = buffer_pos;
190        }
191
192        // get some bytes to uncompress
193        self.check_chunk_data();
194        let next_data = match self.current_chunk {
195            None => &[], // continue, we might have more output pending
196            Some(x) => x,
197        };
198
199        // run decompress
200        let (status, in_count, out_count) =
201            decompress(&mut self.decompressor,
202                       next_data,
203                       self.buffer.as_mut(),
204                       buffer_pos,
205                       self.flags
206            );
207        println!("decompressed {}", out_count);
208
209        // account for byte read
210        if let Some(chunk) = &mut self.current_chunk {
211            *chunk = &(*chunk)[in_count..];
212            if chunk.is_empty() && self.next_chunk_start.is_none() {
213                self.chunk_end = true;
214            }
215        }
216
217        // account for bytes written
218        self.buffer_count += out_count;
219        self.total_decompressed += out_count;
220        debug_assert!(buffer_pos + out_count <= self.buffer.as_ref().len(), "decompress wrapped around");
221
222        // account for errors
223        if (status as i32) < 0 {
224            return Err(DecodeError::Decompress(status));
225        }
226        match status {
227            TINFLStatus::Done => if !self.chunk_end {
228                return Err(DecodeError::InvalidChunk);
229            }
230            TINFLStatus::NeedsMoreInput => if self.chunk_end {
231                return Err(DecodeError::InvalidChunk);
232            }
233            // TINFLStatus::HasMoreOutput is handled gracefully by decompress on next run
234            _ => {}
235        }
236
237        // rerun to avoid duplicating logic
238        self.get_enough_data(size)
239    }
240
241    // remove size bytes from buffer
242    fn remove_data(&mut self, size: usize) {
243        debug_assert!(size >= self.extra_count, "Decompression extra buffer too big to remove (need {})", size);
244        let main_buffer_count = size - self.extra_count;
245        self.data_pos += main_buffer_count;
246        if self.data_pos >= self.buffer.as_ref().len() {
247            self.data_pos -= self.buffer.as_ref().len();
248        }
249        self.buffer_count -= main_buffer_count;
250        self.extra_count = 0;
251    }
252
253    // extract a filter type from first data byte
254    fn filter_type(&mut self) -> Result<FilterType, DecodeError> {
255        let byte = if self.extra_count > 0 {
256            self.buffer_extra.as_ref()[0]
257        } else {
258            self.buffer.as_ref()[self.data_pos]
259        };
260        FilterType::try_from(byte).map_err(|_| DecodeError::InvalidFilterType)
261    }
262
263    // copy the whole scanline_data to slice,
264    // starting at decompressed pos + 1 : because we do not copy the filter type
265    // we copy target len bytes
266    fn copy_to_slice(&self, target: &mut [u8]) {
267        let count = target.len();
268        debug_assert!(count + 1 <= self.buffer_count + self.extra_count, "copy_to_slice, error slice too big {} > {}", count + 1, self.buffer_count + self.extra_count);
269        debug_assert!(count + 1 >= self.extra_count, "copy_to_slice error, extra too big {} < {}", count + 1, self.extra_count);
270        if self.extra_count > 1 {
271            // first extra_buffer
272            let next_count = self.extra_count-1;
273            target[..next_count].copy_from_slice(&self.buffer_extra.as_ref()[1..self.extra_count]);
274            // then first half of circular buffer
275            let count = count - next_count;
276            if count > 0 {
277                let next_pos = next_count;
278                let buffer_end = min(self.data_pos + count, self.buffer.as_ref().len());
279                let next_count = buffer_end - self.data_pos;
280                target[next_pos..next_pos + next_count].copy_from_slice(&self.buffer.as_ref()[self.data_pos..buffer_end]);
281                // finally second half if needed
282                let count = count - next_count;
283                if count > 0 {
284                    let next_pos = next_pos + next_count;
285                    target[next_pos..].copy_from_slice(&self.buffer.as_ref()[..count]);
286                }
287            }
288        } else {
289            // first half of circular buffer
290            let buffer_end = min(self.data_pos + 1 + count, self.buffer.as_ref().len());
291            let next_count = buffer_end - self.data_pos - 1;
292            target[..next_count].copy_from_slice(&self.buffer.as_ref()[self.data_pos+1..buffer_end]);
293            // finally second half if needed
294            let count = count - next_count;
295            if count > 0 {
296                let next_pos = next_count;
297                target[next_pos..].copy_from_slice(&self.buffer.as_ref()[..count]);
298            }
299        }
300    }
301
302    fn enumerate(&self, count: usize) -> impl Iterator<Item = (usize, u8)> {
303        let main_count = count + 1 - self.extra_count;
304        let end = min(self.data_pos + main_count, self.buffer.as_ref().len());
305        self.buffer_extra.as_ref()[..self.extra_count].iter()
306            .chain(self.buffer.as_ref()[self.data_pos..end].iter())
307            .chain(if end == self.buffer.as_ref().len() {
308                self.buffer.as_ref()[0..main_count-(self.buffer.as_ref().len()-self.data_pos)].iter()
309            } else {
310                [].iter()
311            })
312            .skip(1)
313            .copied()
314            .enumerate()
315    }
316
317    pub fn decode_next_scanline(&mut self, last_scanline: &mut [u8], bytes_per_pixel: usize) -> Result<(), DecodeError> {
318        self.get_enough_data(last_scanline.len()+1)?;
319        let filter_type = self.filter_type()?;
320
321        // decode scanline directly into last scanline
322        match filter_type {
323            FilterType::None => self.copy_to_slice(last_scanline),
324            FilterType::Sub => {
325                let mut left_pixel = [0_u8; 8];
326                self.enumerate(last_scanline.len())
327                    .fold(0, |byte, (i,value)| {
328                        let left = left_pixel[byte];
329                        last_scanline[i] = value.wrapping_add(left);
330                        left_pixel[byte] = last_scanline[i];
331                        (byte+1) % bytes_per_pixel
332                    });
333            }
334            FilterType::Up => for (i,value) in self.enumerate(last_scanline.len()) {
335                last_scanline[i] = value.wrapping_add(last_scanline[i]);
336            }
337            FilterType::Average => {
338                let mut left_pixel = [0_u8; 8];
339                self.enumerate(last_scanline.len())
340                    .fold(0, |byte, (i,value)| {
341                        let left = left_pixel[byte];
342                        let top = last_scanline[i];
343                        // we can either work wit u16 or with u8 and a carry
344                        // let's choose u16
345                        let average = (left as u16 + top as u16) / 2;
346                        last_scanline[i] = value.wrapping_add(average as u8);
347                        left_pixel[byte] = last_scanline[i];
348                        (byte+1) % bytes_per_pixel
349                    });
350            }
351            FilterType::Paeth => {
352                let mut top_left_pixel = [0_u8; 8];
353                let mut left_pixel = [0_u8; 8];
354                self.enumerate(last_scanline.len())
355                    .fold(0, |byte, (i,value)| {
356                        let a = left_pixel[byte] as i16;
357                        let b = last_scanline[i] as i16;
358                        let c = top_left_pixel[byte] as i16;
359                        let p = a + b - c;      // initial estimate
360                        let pa = (p - a).abs(); // distances to a, b, c
361                        let pb = (p - b).abs();
362                        let pc = (p - c).abs();
363                        // return nearest of a,b,c,
364                        // breaking ties in order a,b,c.
365                        let predictor = if pa <= pb && pa <= pc {
366                            left_pixel[byte]
367                        } else if pb <= pc {
368                            last_scanline[i]
369                        } else {
370                            top_left_pixel[byte]
371                        };
372                        top_left_pixel[byte] = last_scanline[i];
373                        last_scanline[i] = value.wrapping_add(predictor);
374                        left_pixel[byte] = last_scanline[i];
375                        (byte+1) % bytes_per_pixel
376                    });
377            }
378        }
379        // accounting
380        self.remove_data(last_scanline.len() + 1);
381        Ok(())
382    }
383
384    pub fn reset(&mut self) {
385        todo!()
386    }
387}
388
#[cfg(test)]
mod tests {
    use std::fs;
    use png_decoder::*;
    use crate::colors::AlphaColor;
    use crate::ParsedPng;
    use super::*;

    // Walks every IDAT chunk of the fixture and checks sizes and end detection.
    // NOTE(review): requires `sekiro.png` in the working directory — presumably
    // a 1280x720 RGBA fixture (35 chunks of 32 KiB + one of 7 663 bytes).
    #[test]
    fn list_chunks() {
        let bytes = fs::read("sekiro.png").unwrap();
        let png = ParsedPng::from_bytes(&bytes, true, AlphaColor).unwrap();

        let mut decompressor = ChunkDecompressor::new_vec(png.data_chunks, true);

        for _ in 0..35 {
            // force advancing to the next chunk
            decompressor.current_chunk = None;
            decompressor.check_chunk_data();
            assert!(decompressor.current_chunk.is_some(), "Missing chunk");
            assert!(!decompressor.chunk_end, "Decompression ended early");
            assert_eq!(decompressor.current_chunk.unwrap().len(), 32_768, "Incorrect chunk size");
        }
        decompressor.current_chunk = None;
        decompressor.check_chunk_data();
        assert!(decompressor.current_chunk.is_some(), "Missing chunk");
        assert!(!decompressor.chunk_end, "Decompression ended early");
        assert_eq!(decompressor.current_chunk.unwrap().len(), 7_663, "Incorrect chunk size");
        decompressor.current_chunk = None;
        decompressor.check_chunk_data();
        assert!(decompressor.chunk_end, "Decompression ended late");
    }

    // Decompresses raw (still filtered) scanlines and checks that enumerate()
    // and copy_to_slice() agree byte for byte on every scanline.
    #[test]
    fn read_chunks() {
        let bytes = fs::read("sekiro.png").unwrap();
        let png = ParsedPng::from_bytes(&bytes, true, AlphaColor).unwrap();

        //let mut undecoded = pre_decode(&bytes).unwrap();

        let mut decompressor = ChunkDecompressor::new_vec(png.data_chunks,true);
        let mut scanline = vec![0_u8; 5120];

        for i in 0..720 {
            // 5121 = 5120 scanline bytes + 1 filter-type byte
            let r = decompressor.get_enough_data(5121);
            assert!(r.is_ok(), "Get data Error");
            decompressor.copy_to_slice(&mut scanline);
            assert_eq!(decompressor.enumerate(5120).count(), 5120, "Enumerate can't count");
            let enumeration: Vec<u8> = decompressor.enumerate(5120).map(|(_,x)| x).collect();
            assert_eq!(enumeration, scanline, "Enumerate misaligned with copy to slice");
            //assert_eq!(scanline, &undecoded.scanline_data[5121*i+1..5121*(i+1)], "Incorrect data at {}", i);
            decompressor.remove_data(5121);
        }
        // after 720 scanlines, both buffers and the input must be exhausted
        assert_eq!(decompressor.buffer_count, 0, "Main buffer left");
        assert!(decompressor.chunk_end, "Decompression left some data");
        assert_eq!(decompressor.extra_count, 0, "Extra buffer left");
    }

    // Runs the full filtered decode path for every scanline of the fixture.
    #[test]
    fn decode() {
        let bytes = fs::read("sekiro.png").unwrap();
        let png = ParsedPng::from_bytes(&bytes, true, AlphaColor).unwrap();

        /*let mut undecoded = pre_decode(&bytes).unwrap();
        let mut image = vec![0_u8; 1280*720*4];
        undecoded.process_scanlines(
            |scanline_iter,xy_calculator,y| {
                for (idx, (r, g, b, a)) in scanline_iter.enumerate() {
                    let (x, y) = xy_calculator.get_xy(idx, y);
                    let idx = (x + y * 1280)*4;
                    image[idx] = r;
                    image[idx+1] = g;
                    image[idx+2] = b;
                    image[idx+3] = a;
                }
            }).unwrap();
*/
        let mut decompressor = ChunkDecompressor::new_vec(png.data_chunks, true);
        let mut scanline = vec![0_u8; 5120];
        for i in 0..720 {
            // 4 = bytes per pixel (RGBA)
            let r = decompressor.decode_next_scanline(&mut scanline, 4);
            assert!(r.is_ok(), "Get data Error");
            //assert_eq!(scanline, &image[i*5120..i*5120 + 5120], "Incorrect image at {}", i);
        }
        assert_eq!(decompressor.buffer_count, 0, "Main buffer left");
        assert!(decompressor.chunk_end, "Decompression left some data");
        assert_eq!(decompressor.extra_count, 0, "Extra buffer left");
    }
    
    // TODO test buffer limits (max size +-1)
}