1use super::base64_decode;
2use super::deflate_compress;
3use crate::{Transform, TransformError, TransformerCategory};
4use std::collections::HashMap;
5
/// Transformer that decodes Base64 text into raw DEFLATE (RFC 1951) bytes and inflates them.
///
/// Stateless unit type; see its `Transform` implementation for the supported block types.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DeflateDecompress;
11
12pub(crate) struct BitReader<'a> {
14 bytes: &'a [u8],
15 byte_index: usize, bit_position: u8, }
18
19impl<'a> BitReader<'a> {
20 fn new(bytes: &'a [u8]) -> Self {
21 BitReader {
22 bytes,
23 byte_index: 0,
24 bit_position: 0,
25 }
26 }
27
28 fn read_bits(&mut self, num_bits: u8) -> Result<u32, TransformError> {
30 if num_bits > 32 {
31 return Err(TransformError::CompressionError(
32 "Cannot read more than 32 bits at once".to_string(),
33 ));
34 }
35 let mut value = 0u32;
36 let mut bits_read = 0u8;
37 while bits_read < num_bits {
38 if self.byte_index >= self.bytes.len() {
39 if bits_read < num_bits {
41 if num_bits - bits_read > 7 {
43 return Err(TransformError::CompressionError(
44 "Unexpected end of DEFLATE stream (large bit request past EOF)"
45 .to_string(),
46 ));
47 }
48 break; }
51 }
53
54 let current_byte = self.bytes[self.byte_index];
55 let bits_to_read_from_byte = 8 - self.bit_position;
56 let bits_needed = num_bits - bits_read;
57 let bits_to_read = std::cmp::min(bits_needed, bits_to_read_from_byte);
58
59 let mask = (1u32 << bits_to_read) - 1;
61 let byte_part = (current_byte >> self.bit_position) & (mask as u8);
62 value |= (byte_part as u32) << bits_read;
63
64 self.bit_position += bits_to_read;
65 bits_read += bits_to_read;
66
67 if self.bit_position == 8 {
68 self.bit_position = 0;
69 self.byte_index += 1;
70 }
71 }
72 Ok(value)
73 }
74
75 fn align_to_byte(&mut self) {
77 if self.bit_position > 0 {
78 self.bit_position = 0;
79 self.byte_index += 1;
80 }
81 }
82
83 fn remaining_bytes(&self) -> usize {
85 self.bytes.len().saturating_sub(self.byte_index)
86 }
87}
88
/// Longest code, in bits, of the fixed literal/length Huffman code (codes are 7, 8 or 9 bits).
const MAX_BITS_LITLEN: u8 = 9;
/// Bit length shared by every code of the fixed distance Huffman code.
const MAX_BITS_DIST: u8 = 5;

/// A decoding-table entry: a code `length` bits long decodes to `symbol`.
#[derive(Clone)]
struct HuffmanCode {
    length: u8,
    symbol: u16,
}

/// Decoding tables for the fixed Huffman codes, keyed by the bit-reversed code value.
struct FixedHuffmanDecoder {
    /// Literal/length alphabet table.
    litlen_lookup: HashMap<u16, HuffmanCode>,
    /// Distance alphabet table.
    dist_lookup: HashMap<u16, HuffmanCode>,
}
104
105impl FixedHuffmanDecoder {
106 fn new() -> Self {
107 let (litlen_table, dist_table) = Self::build_fixed_tables();
108 FixedHuffmanDecoder {
109 litlen_lookup: litlen_table,
110 dist_lookup: dist_table,
111 }
112 }
113
114 fn build_fixed_tables() -> (HashMap<u16, HuffmanCode>, HashMap<u16, HuffmanCode>) {
116 let mut litlen_lookup = HashMap::new();
117 let mut dist_lookup = HashMap::new();
118
119 for symbol in 0..=287u16 {
121 let (code, len) = match symbol {
122 0..=143 => (0x30 + symbol, 8),
123 144..=255 => (0x190 + (symbol - 144), 9),
124 256..=279 => (symbol - 256, 7),
125 280..=285 => (0xC0 + (symbol - 280), 8),
126 _ => (0, 0), };
128 if len > 0 {
129 let reversed_code = deflate_compress::reverse_bits(code, len);
130 litlen_lookup.insert(
131 reversed_code,
132 HuffmanCode {
133 symbol,
134 length: len,
135 },
136 );
137 }
138 }
139
140 for symbol in 0..=31u16 {
142 let code = symbol;
143 let len = 5;
144 let reversed_code = deflate_compress::reverse_bits(code, len);
145 dist_lookup.insert(
146 reversed_code,
147 HuffmanCode {
148 symbol,
149 length: len,
150 },
151 );
152 }
153
154 (litlen_lookup, dist_lookup)
155 }
156
157 fn decode_literal_length(&self, reader: &mut BitReader) -> Result<u16, TransformError> {
159 let mut current_bits = 0u16;
160 let mut len = 0u8;
161 loop {
162 let bit = reader.read_bits(1)? as u16;
163 current_bits |= bit << len;
164 len += 1;
165 if let Some(code) = self.litlen_lookup.get(¤t_bits) {
166 if code.length == len {
167 return Ok(code.symbol);
168 }
169 }
170 if len > MAX_BITS_LITLEN {
171 return Err(TransformError::CompressionError(format!(
172 "Invalid Huffman code found (litlen prefix: {:b}, len: {})",
173 current_bits, len
174 )));
175 }
176 }
177 }
178
179 fn decode_distance(&self, reader: &mut BitReader) -> Result<u16, TransformError> {
181 let mut current_bits = 0u16;
182 let mut len = 0u8;
183 loop {
184 let bit = reader.read_bits(1)? as u16;
185 current_bits |= bit << len;
186 len += 1;
187 if let Some(code) = self.dist_lookup.get(¤t_bits) {
188 if code.length == len {
189 if code.symbol <= 29 {
190 return Ok(code.symbol);
192 } else {
193 return Err(TransformError::CompressionError(format!(
194 "Invalid distance symbol {} decoded",
195 code.symbol
196 )));
197 }
198 }
199 }
200 if len > MAX_BITS_DIST {
201 return Err(TransformError::CompressionError(format!(
202 "Invalid fixed Huffman distance code found (prefix: {:b}, len: {})",
203 current_bits, len
204 )));
205 }
206 }
207 }
208}
209
210pub(crate) fn deflate_decode_bytes(
213 compressed_bytes: &[u8],
214) -> Result<(Vec<u8>, usize), TransformError> {
215 if compressed_bytes.is_empty() {
216 return Ok((Vec::new(), 0)); }
218
219 let mut reader = BitReader::new(compressed_bytes);
220 let mut output: Vec<u8> = Vec::with_capacity(compressed_bytes.len() * 3);
221 let fixed_decoder = FixedHuffmanDecoder::new();
222
223 loop {
224 let bfinal = reader.read_bits(1)?;
225 let btype = reader.read_bits(2)?;
226
227 match btype {
228 0b00 => {
229 reader.align_to_byte();
231 let len = reader.read_bits(16)? as u16;
232 let nlen = reader.read_bits(16)? as u16;
233 if len != !nlen {
234 return Err(TransformError::CompressionError("LEN/NLEN mismatch".into()));
235 }
236 let len_usize = len as usize;
237 let remaining_bytes = reader.remaining_bytes();
239 let bytes_needed = if reader.bit_position == 0 {
240 len_usize
241 } else {
242 len_usize + 1
244 };
245 if remaining_bytes < bytes_needed {
246 return Err(TransformError::CompressionError(
247 "Unexpected end of stream reading uncompressed data".into(),
248 ));
249 }
250 output.reserve(len_usize);
251 for _ in 0..len_usize {
252 if reader.bit_position != 0 {
253 return Err(TransformError::CompressionError(
254 "Misaligned stream reading uncompressed data byte".into(),
255 ));
256 }
257 let byte = reader.read_bits(8)? as u8;
258 output.push(byte);
259 }
260 }
261 0b01 => {
262 loop {
264 let lit_len_code = fixed_decoder.decode_literal_length(&mut reader)?;
265 match lit_len_code {
266 0..=255 => {
267 output.push(lit_len_code as u8);
268 }
269 256 => {
270 break; }
272 257..=285 => {
273 let (len_base, len_extra_bits) =
275 deflate_compress::get_length_info(lit_len_code);
276 let len_extra_val = if len_extra_bits > 0 {
277 reader.read_bits(len_extra_bits)?
278 } else {
279 0
280 };
281 let length = len_base + len_extra_val as u16;
282
283 let dist_code = fixed_decoder.decode_distance(&mut reader)?;
284 let (dist_base, dist_extra_bits) =
285 deflate_compress::get_distance_info(dist_code);
286 let dist_extra_val = if dist_extra_bits > 0 {
287 reader.read_bits(dist_extra_bits)?
288 } else {
289 0
290 };
291 let distance = dist_base + dist_extra_val as u16;
292
293 let current_len = output.len();
294 if distance as usize > current_len {
295 return Err(TransformError::CompressionError(format!(
296 "Invalid back-reference distance {} > {}",
297 distance, current_len
298 )));
299 }
300 let start = current_len - distance as usize;
301 output.reserve(length as usize);
302 for i in 0..length {
303 let copied_byte = output[start + i as usize];
304 output.push(copied_byte);
305 }
306 }
307 _ => unreachable!(),
308 }
309 }
310 }
311 0b10 => {
312 return Err(TransformError::CompressionError(
314 "Dynamic Huffman codes (BTYPE=10) are not supported".into(),
315 ));
316 }
317 _ => {
318 return Err(TransformError::CompressionError(
320 "Invalid or reserved block type (BTYPE=11)".into(),
321 ));
322 }
323 }
324
325 if bfinal == 1 {
326 break;
327 }
328 }
329
330 let consumed_bytes = if reader.bit_position > 0 {
331 reader.byte_index + 1 } else {
333 reader.byte_index
334 };
335
336 Ok((output, consumed_bytes)) }
338
339impl Transform for DeflateDecompress {
340 fn name(&self) -> &'static str {
341 "DEFLATE Decompress"
342 }
343
344 fn id(&self) -> &'static str {
345 "deflatedecompress"
346 }
347
348 fn category(&self) -> TransformerCategory {
349 TransformerCategory::Compression
350 }
351
352 fn description(&self) -> &'static str {
353 "Decompresses DEFLATE input (RFC 1951). Expects Base64 input."
354 }
355
356 fn transform(&self, input: &str) -> Result<String, TransformError> {
357 let compressed_bytes = base64_decode::base64_decode(input).map_err(|e| {
358 TransformError::InvalidArgument(format!("Invalid Base64 input: {}", e).into())
359 })?;
360 let (output, _consumed_bytes) = deflate_decode_bytes(&compressed_bytes)?;
362 String::from_utf8(output).map_err(|_| TransformError::Utf8Error)
363 }
364
365 fn default_test_input(&self) -> &'static str {
366 "80jNycnXUSjPL8pJUQQA" }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformers::base64_encode;

    #[test]
    fn test_decompress_uncompressed_block() {
        // One final stored block: header byte 0x01 (BFINAL=1, BTYPE=00), LEN=4,
        // NLEN=!4 (little-endian), then the payload bytes of "test".
        let stored_stream = vec![0x01, 0x04, 0x00, 0xFB, 0xFF, b't', b'e', b's', b't'];
        let encoded = base64_encode::base64_encode(&stored_stream);

        let decompressed = DeflateDecompress
            .transform(&encoded)
            .unwrap_or_else(|e| panic!("Decompression failed for uncompressed block: {:?}", e));
        assert_eq!(decompressed, "test");
    }

    #[test]
    fn test_decompress_empty() {
        let decompressor = DeflateDecompress;
        // "" is empty input; "AwA=" (0x03 0x00) is a final fixed block holding only
        // the end-of-block code.
        for &encoded in &["", "AwA="] {
            assert_eq!(decompressor.transform(encoded).unwrap(), "");
        }
    }

    #[test]
    fn test_decompress_fixed_simple() {
        let decompressor = DeflateDecompress;
        let greeting = decompressor
            .transform(decompressor.default_test_input())
            .unwrap();
        assert_eq!(greeting, "Hello, world!");

        // A second short fixed-Huffman stream; only successful decoding is checked.
        assert!(decompressor.transform("80jMygUA").is_ok());
    }
}