1use flate2::{Decompress, FlushDecompress, Status};
9
10use super::predictor;
11use crate::error::DecodeError;
12
/// Upper bound, in bytes, on the data `inflate` will produce for one stream.
///
/// Once the accumulated output exceeds this value, decoding fails with
/// `DecodeError::OutputTooLarge`. The value is the core crate's default
/// decompressed-stream limit.
const MAX_DECOMPRESSED_SIZE: u64 = rpdfium_core::fx_system::DEFAULT_MAX_DECOMPRESSED_STREAM_SIZE;
15
16pub fn decode(
24 input: &[u8],
25 predictor: Option<i32>,
26 columns: Option<i32>,
27 colors: Option<i32>,
28 bits_per_component: Option<i32>,
29) -> Result<Vec<u8>, DecodeError> {
30 let decompressed = inflate(input)?;
31
32 let predictor_val = predictor.unwrap_or(1);
33 if predictor_val <= 1 {
34 return Ok(decompressed);
35 }
36
37 let columns = columns.unwrap_or(1);
38 let colors = colors.unwrap_or(1);
39 let bpc = bits_per_component.unwrap_or(8);
40
41 predictor::apply_predictor(&decompressed, predictor_val, columns, colors, bpc)
42}
43
44pub fn encode(input: &[u8]) -> Vec<u8> {
48 use flate2::Compression;
49 use flate2::write::ZlibEncoder;
50 use std::io::Write;
51
52 let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
53 encoder.write_all(input).unwrap_or_default();
54 encoder.finish().unwrap_or_default()
55}
56
57fn inflate(input: &[u8]) -> Result<Vec<u8>, DecodeError> {
59 let mut decompressor = Decompress::new(true); let mut output = Vec::with_capacity(input.len().saturating_mul(2).min(1 << 20));
61 let mut buf = [0u8; 32 * 1024]; loop {
64 let in_before = decompressor.total_in();
65 let out_before = decompressor.total_out();
66
67 let status = decompressor
68 .decompress(
69 &input[in_before as usize..],
70 &mut buf,
71 FlushDecompress::None,
72 )
73 .map_err(|e| DecodeError::Flate(e.to_string()))?;
74
75 let bytes_written = (decompressor.total_out() - out_before) as usize;
76 output.extend_from_slice(&buf[..bytes_written]);
77
78 if output.len() as u64 > MAX_DECOMPRESSED_SIZE {
79 return Err(DecodeError::OutputTooLarge {
80 limit: MAX_DECOMPRESSED_SIZE,
81 });
82 }
83
84 match status {
85 Status::StreamEnd => break,
86 Status::Ok | Status::BufError => {
87 let in_after = decompressor.total_in();
88 let out_after = decompressor.total_out();
89 if in_after == in_before && out_after == out_before {
91 break;
92 }
93 }
94 }
95 }
96
97 Ok(output)
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Compresses `data` with `encode`, then decodes it with no predictor.
    /// Returns `(compressed, decoded)` so callers can check both.
    fn roundtrip(data: &[u8]) -> (Vec<u8>, Vec<u8>) {
        let compressed = encode(data);
        let decoded = decode(&compressed, None, None, None, None).unwrap();
        (compressed, decoded)
    }

    /// Compresses pre-filtered `raw` bytes, then decodes them with the given
    /// predictor and column count. Uses one colour and 8 bits per component.
    fn decode_with(raw: &[u8], pred: i32, columns: i32) -> Vec<u8> {
        let compressed = encode(raw);
        decode(&compressed, Some(pred), Some(columns), Some(1), Some(8)).unwrap()
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let original = b"Hello, World! FlateEncode/FlateDecode round-trip.";
        let (compressed, decoded) = roundtrip(original);
        assert!(!compressed.is_empty());
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_encode_empty() {
        // A valid zlib stream is never empty, even when it encodes nothing.
        let (compressed, decoded) = roundtrip(b"");
        assert!(!compressed.is_empty());
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_encode_all_bytes() {
        let original: Vec<u8> = (0u8..=255).collect();
        let (_, decoded) = roundtrip(&original);
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_encode_large_data() {
        let original: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        let (compressed, decoded) = roundtrip(&original);
        // The repeating 0..=255 ramp is highly compressible.
        assert!(compressed.len() < original.len());
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_decode_invalid_data() {
        let result = decode(b"not valid zlib data", None, None, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_with_png_sub_predictor() {
        // One 2-byte row with PNG filter type 1 (Sub): each byte is stored
        // as the difference from its left neighbour, so 10, 10 -> 10, 20.
        assert_eq!(decode_with(&[1, 10, 10], 15, 2), vec![10, 20]);
    }

    #[test]
    fn test_decode_with_png_up_predictor() {
        // Row 1 uses filter 0 (None): 10, 20.
        // Row 2 uses filter 2 (Up): each byte adds the byte above it,
        // so 5, 5 -> 15, 25.
        assert_eq!(
            decode_with(&[0, 10, 20, 2, 5, 5], 15, 2),
            vec![10, 20, 15, 25]
        );
    }

    #[test]
    fn test_decode_with_tiff_predictor() {
        // TIFF predictor 2 (horizontal differencing): running sum along the row.
        assert_eq!(decode_with(&[10, 5, 3], 2, 3), vec![10, 15, 18]);
    }

    #[test]
    fn test_decode_no_predictor_explicit() {
        let original = b"test data";
        let compressed = encode(original);
        let decoded = decode(&compressed, Some(1), None, None, None).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_flate_roundtrip_pseudo_random_data() {
        // Deterministic pseudo-random-looking bytes from a simple linear map.
        let data: Vec<u8> = (0..1000).map(|i| ((i * 7 + 13) % 256) as u8).collect();
        let (_, decoded) = roundtrip(&data);
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_flate_roundtrip_all_same_bytes() {
        let data = vec![0xAA; 10_000];
        let (compressed, decoded) = roundtrip(&data);
        assert!(
            compressed.len() < 100,
            "identical bytes should compress well"
        );
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_flate_roundtrip_single_byte() {
        let data = vec![42u8];
        let (_, decoded) = roundtrip(&data);
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_flate_roundtrip_alternating_pattern() {
        let data: Vec<u8> = (0..2000)
            .map(|i| if i % 2 == 0 { 0x00 } else { 0xFF })
            .collect();
        let (_, decoded) = roundtrip(&data);
        assert_eq!(decoded, data);
    }
}
229}