buup/transformers/
deflate_compress.rs

1use crate::{Transform, TransformError, TransformerCategory};
2// Import the shared base64 encoder
3use super::base64_encode;
4
/// Length codes 257-285 from RFC 1951 Section 3.2.5.
///
/// Each entry is `(code, base_length, extra_bits)`: the code covers the
/// lengths `base_length ..= base_length + 2^extra_bits - 1`, except for
/// code 285, which stands for the single length 258.
#[rustfmt::skip]
pub(crate) const LENGTH_CODES: [(u16, u16, u8); 29] = [
    // 0 extra bits
    (257, 3, 0), (258, 4, 0), (259, 5, 0), (260, 6, 0),
    (261, 7, 0), (262, 8, 0), (263, 9, 0), (264, 10, 0),
    // 1 extra bit
    (265, 11, 1), (266, 13, 1), (267, 15, 1), (268, 17, 1),
    // 2 extra bits
    (269, 19, 2), (270, 23, 2), (271, 27, 2), (272, 31, 2),
    // 3 extra bits
    (273, 35, 3), (274, 43, 3), (275, 51, 3), (276, 59, 3),
    // 4 extra bits
    (277, 67, 4), (278, 83, 4), (279, 99, 4), (280, 115, 4),
    // 5 extra bits
    (281, 131, 5), (282, 163, 5), (283, 195, 5), (284, 227, 5),
    // Maximum length, no extra bits
    (285, 258, 0),
];
37
/// Distance codes 0-29 from RFC 1951 Section 3.2.5.
///
/// Each entry is `(code, base_distance, extra_bits)`: the code covers the
/// distances `base_distance ..= base_distance + 2^extra_bits - 1`.
#[rustfmt::skip]
pub(crate) const DISTANCE_CODES: [(u16, u16, u8); 30] = [
    // 0 extra bits
    (0, 1, 0), (1, 2, 0), (2, 3, 0), (3, 4, 0),
    // 1-4 extra bits
    (4, 5, 1), (5, 7, 1),
    (6, 9, 2), (7, 13, 2),
    (8, 17, 3), (9, 25, 3),
    (10, 33, 4), (11, 49, 4),
    // 5-8 extra bits
    (12, 65, 5), (13, 97, 5),
    (14, 129, 6), (15, 193, 6),
    (16, 257, 7), (17, 385, 7),
    (18, 513, 8), (19, 769, 8),
    // 9-13 extra bits
    (20, 1025, 9), (21, 1537, 9),
    (22, 2049, 10), (23, 3073, 10),
    (24, 4097, 11), (25, 6145, 11),
    (26, 8193, 12), (27, 12289, 12),
    (28, 16385, 13), (29, 24577, 13),
];
70
71// Finds the DEFLATE length code and extra bits for a given length (3-258).
72fn get_length_code(length: u16) -> (u16, u32, u8) {
73    assert!(
74        (3..=258).contains(&length),
75        "Length must be between 3 and 258 inclusive"
76    );
77    if length == 258 {
78        return (285, 0, 0);
79    }
80    for i in 0..LENGTH_CODES.len() - 1 {
81        let (code, base_len, num_extra_bits) = LENGTH_CODES[i];
82        let next_base_len = if i + 1 < LENGTH_CODES.len() - 1 {
83            LENGTH_CODES[i + 1].1
84        } else {
85            258
86        };
87        let range_limit = base_len + (1 << num_extra_bits) - 1;
88        if length >= base_len && length <= range_limit {
89            let extra_val = length - base_len;
90            return (code, extra_val as u32, num_extra_bits);
91        }
92        if length > range_limit && length < next_base_len {
93            panic!("Length {} falls between code ranges", length);
94        }
95    }
96    panic!("Length code not found for {}", length);
97}
98
99// Finds the DEFLATE distance code and extra bits for a given distance (1-32768).
100fn get_distance_code(distance: u16) -> (u16, u32, u8) {
101    assert!(
102        (1..=32768).contains(&distance),
103        "Distance must be between 1 and 32768 inclusive"
104    );
105    for i in 0..DISTANCE_CODES.len() {
106        let (code, base_dist, num_extra_bits) = DISTANCE_CODES[i];
107        let range_limit = base_dist + (1 << num_extra_bits) - 1;
108        if distance >= base_dist && distance <= range_limit {
109            let extra_val = distance - base_dist;
110            return (code, extra_val as u32, num_extra_bits);
111        }
112        if i + 1 < DISTANCE_CODES.len() {
113            let next_base_dist = DISTANCE_CODES[i + 1].1;
114            if distance > range_limit && distance < next_base_dist {
115                panic!("Distance {} falls between code ranges", distance);
116            }
117        } else if distance > range_limit {
118            panic!("Distance {} is out of bounds (> 32768?)", distance);
119        }
120    }
121    panic!("Distance code not found for {}", distance);
122}
123
124/// Get length base and extra bits count from length code (257-285)
125pub(crate) fn get_length_info(code: u16) -> (u16, u8) {
126    // (base, extra_bits)
127    assert!((257..=285).contains(&code));
128    for &(c, base, extra) in LENGTH_CODES.iter() {
129        if c == code {
130            return (base, extra);
131        }
132    }
133    unreachable!(); // Code asserted to be in range
134}
135
136/// Get distance base and extra bits count from distance code (0-29)
137pub(crate) fn get_distance_info(code: u16) -> (u16, u8) {
138    // (base, extra_bits)
139    assert!(code <= 29);
140    for &(c, base, extra) in DISTANCE_CODES.iter() {
141        if c == code {
142            return (base, extra);
143        }
144    }
145    unreachable!(); // Code asserted to be in range
146}
147
/// Reverses the lowest `num_bits` bits of `value`.
///
/// Bits of `value` above `num_bits` are ignored; the result holds the
/// reversed bits in its low `num_bits` positions.
pub(crate) fn reverse_bits(value: u16, num_bits: u8) -> u16 {
    // Peel bits off the low end of `remaining` and push them onto the low end
    // of `reversed`, so the first bit taken ends up in the highest position.
    let (reversed, _) = (0..num_bits).fold((0u16, value), |(reversed, remaining), _| {
        ((reversed << 1) | (remaining & 1), remaining >> 1)
    });
    reversed
}
161
162// Returns the bit-reversed fixed Huffman code pattern and bit length for a given literal/length code (0-285).
163fn get_fixed_literal_length_huffman_code(code: u16) -> (u16, u8) {
164    let (pattern, num_bits) = match code {
165        0..=143 => (0b00110000 + code, 8),
166        144..=255 => (0b110010000 + (code - 144u16), 9),
167        256..=279 => (code - 256u16, 7),
168        280..=285 => (0b11000000 + (code - 280u16), 8),
169        _ => panic!("Invalid literal/length code for fixed Huffman: {}", code),
170    };
171    (reverse_bits(pattern, num_bits), num_bits)
172}
173
174// Returns the bit-reversed fixed Huffman code pattern (5 bits) and bit length for a given distance code (0-29).
175fn get_fixed_distance_huffman_code(distance_code: u16) -> (u16, u8) {
176    let num_bits = 5;
177    if distance_code <= 29 {
178        (reverse_bits(distance_code, num_bits), num_bits)
179    } else {
180        panic!("Invalid distance code for fixed Huffman: {}", distance_code);
181    }
182}
183
/// Accumulates bits LSB-first into a growing byte vector.
struct BitWriter {
    /// Completed bytes.
    bytes: Vec<u8>,
    /// Partially filled byte that has not been pushed to `bytes` yet.
    current_byte: u8,
    /// Index (0-7) of the next free bit in `current_byte`.
    bit_position: u8,
}

impl BitWriter {
    /// Creates an empty, byte-aligned writer.
    fn new() -> Self {
        Self {
            bytes: Vec::new(),
            current_byte: 0,
            bit_position: 0,
        }
    }

    /// Appends the low `num_bits` bits of `value`, least significant bit first.
    fn write_bits(&mut self, mut value: u32, mut num_bits: u8) {
        while num_bits > 0 {
            // Take as many bits as still fit into the byte being assembled.
            let free_bits = 8 - self.bit_position;
            let chunk_len = num_bits.min(free_bits);
            let chunk = (value & ((1u32 << chunk_len) - 1)) as u8;

            self.current_byte |= chunk << self.bit_position;
            self.bit_position += chunk_len;
            if self.bit_position == 8 {
                self.flush_byte();
            }

            value >>= chunk_len;
            num_bits -= chunk_len;
        }
    }

    /// Pushes the partially filled byte, if any; unused high bits stay zero.
    fn flush_byte(&mut self) {
        if self.bit_position > 0 {
            self.bytes.push(self.current_byte);
            self.current_byte = 0;
            self.bit_position = 0;
        }
    }

    /// Consumes the writer and returns all bytes, flushing a pending partial byte.
    fn get_bytes(mut self) -> Vec<u8> {
        self.flush_byte();
        self.bytes
    }

    /// Pads with zero bits up to the next byte boundary.
    fn align_to_byte(&mut self) {
        self.flush_byte();
    }

    /// Appends whole bytes verbatim.
    ///
    /// Panics if the writer is not currently byte-aligned.
    fn write_bytes_raw(&mut self, bytes: &[u8]) {
        assert!(self.bit_position == 0, "Writer must be byte-aligned");
        self.bytes.extend_from_slice(bytes);
    }
}
244
// --- LZ77 Implementation ---

/// Size of the sliding window (32 KiB): how far back a match may start, and
/// the length of the `prev` chain table in `lz77_compress`.
const MAX_WINDOW_SIZE: usize = 1 << 15;
/// Shortest match that is emitted as a `Match` token; also the number of
/// bytes fed to the hash.
const MIN_MATCH_LEN: usize = 3;
/// Longest match that is emitted as a single `Match` token.
const MAX_MATCH_LEN: usize = 258;
/// Number of buckets in the hash head table.
const HASH_TABLE_SIZE: usize = 32 * 1024;
250
/// One unit of LZ77 output.
#[derive(Clone, Debug, PartialEq)]
enum Lz77Token {
    /// A single input byte emitted as-is.
    Literal(u8),
    /// A back-reference, stored as `(length, distance)`.
    Match(u16, u16),
}
256
257fn lz77_compress(input: &[u8]) -> Vec<Lz77Token> {
258    if input.is_empty() {
259        return Vec::new();
260    }
261    let mut tokens = Vec::new();
262    let mut head: Vec<Option<usize>> = vec![None; HASH_TABLE_SIZE];
263    let mut prev: Vec<Option<usize>> = vec![None; MAX_WINDOW_SIZE];
264    let mut current_pos = 0;
265    while current_pos < input.len() {
266        let window_start = if current_pos > MAX_WINDOW_SIZE {
267            current_pos - MAX_WINDOW_SIZE
268        } else {
269            0
270        };
271        if current_pos + MIN_MATCH_LEN > input.len() {
272            tokens.extend(input[current_pos..].iter().map(|&b| Lz77Token::Literal(b)));
273            break;
274        }
275        let hash = calculate_hash(&input[current_pos..current_pos + MIN_MATCH_LEN]);
276        let mut best_match_len = 0;
277        let mut best_match_dist = 0;
278        let mut match_pos_opt = head[hash];
279        while let Some(match_pos) = match_pos_opt {
280            if match_pos < window_start {
281                break;
282            }
283            let current_match_len =
284                calculate_match_length(input, match_pos, current_pos, MAX_MATCH_LEN);
285            if current_match_len >= MIN_MATCH_LEN && current_match_len > best_match_len {
286                best_match_len = current_match_len;
287                best_match_dist = (current_pos - match_pos) as u16;
288                if best_match_len == MAX_MATCH_LEN {
289                    break;
290                }
291            }
292            match_pos_opt = prev[match_pos % MAX_WINDOW_SIZE];
293        }
294        prev[current_pos % MAX_WINDOW_SIZE] = head[hash];
295        head[hash] = Some(current_pos);
296        if best_match_len >= MIN_MATCH_LEN {
297            tokens.push(Lz77Token::Match(best_match_len as u16, best_match_dist));
298            // Lazy update hash table for skipped bytes
299            for i in 1..best_match_len {
300                let pos_to_update = current_pos + i;
301                if pos_to_update + MIN_MATCH_LEN <= input.len() {
302                    let next_hash =
303                        calculate_hash(&input[pos_to_update..pos_to_update + MIN_MATCH_LEN]);
304                    prev[pos_to_update % MAX_WINDOW_SIZE] = head[next_hash];
305                    head[next_hash] = Some(pos_to_update);
306                }
307            }
308            current_pos += best_match_len;
309        } else {
310            tokens.push(Lz77Token::Literal(input[current_pos]));
311            current_pos += 1;
312        }
313    }
314    tokens
315}
316
317#[inline]
318fn calculate_hash(bytes: &[u8]) -> usize {
319    (((bytes[0] as usize) << 8) | ((bytes[1] as usize) << 4) | (bytes[2] as usize))
320        % HASH_TABLE_SIZE
321}
322
/// Counts how many consecutive bytes starting at `pos1` equal those starting
/// at `pos2`.
///
/// The count is capped at `max_len` and never reads `pos2 + n` past the end
/// of `input`.
#[inline]
fn calculate_match_length(input: &[u8], pos1: usize, pos2: usize, max_len: usize) -> usize {
    // Upper bound imposed by both the cap and the bytes remaining after `pos2`.
    let limit = max_len.min(input.len().saturating_sub(pos2));
    (0..limit)
        .take_while(|&offset| input[pos1 + offset] == input[pos2 + offset])
        .count()
}
332
333// Extracted core DEFLATE compression logic (without Base64 encoding)
334pub(crate) fn deflate_bytes(input_bytes: &[u8]) -> Result<Vec<u8>, TransformError> {
335    let mut writer = BitWriter::new();
336
337    if input_bytes.is_empty() {
338        // Minimal fixed block for empty input.
339        writer.write_bits(1, 1); // BFINAL
340        writer.write_bits(1, 2); // BTYPE=01 (Fixed Huffman)
341        let (reversed_eob_huff, eob_bits) = get_fixed_literal_length_huffman_code(256); // EOB
342        writer.write_bits(reversed_eob_huff as u32, eob_bits);
343        return Ok(writer.get_bytes());
344    }
345
346    let lz77_tokens = lz77_compress(input_bytes);
347
348    // Estimate size to choose between fixed Huffman and uncompressed block.
349    let mut estimated_bits = 0;
350    for token in &lz77_tokens {
351        match token {
352            Lz77Token::Literal(byte) => {
353                let (_, bits) = get_fixed_literal_length_huffman_code(*byte as u16);
354                estimated_bits += bits as usize;
355            }
356            Lz77Token::Match(length, distance) => {
357                let (len_code, _, len_extra_bits) = get_length_code(*length);
358                let (_, len_huff_bits) = get_fixed_literal_length_huffman_code(len_code);
359                estimated_bits += len_huff_bits as usize + len_extra_bits as usize;
360
361                let (dist_code, _, dist_extra_bits) = get_distance_code(*distance);
362                let (_, dist_huff_bits) = get_fixed_distance_huffman_code(dist_code);
363                estimated_bits += dist_huff_bits as usize + dist_extra_bits as usize;
364            }
365        }
366    }
367    let (_, eob_bits) = get_fixed_literal_length_huffman_code(256); // EOB marker
368    estimated_bits += eob_bits as usize;
369    estimated_bits += 3; // BFINAL + BTYPE bits
370
371    let uncompressed_size_bytes = input_bytes.len() + 5;
372    let uncompressed_size_bits = uncompressed_size_bytes * 8;
373
374    // --- Write DEFLATE Stream ---
375    writer.write_bits(1, 1); // BFINAL = 1
376
377    if estimated_bits >= uncompressed_size_bits {
378        // Write uncompressed block (BTYPE=00).
379        writer.write_bits(0, 2); // BTYPE=00
380        writer.align_to_byte();
381        let len: u16 = input_bytes.len().try_into().map_err(|_| {
382            TransformError::CompressionError(
383                "Input too large for uncompressed block length (max 65535)".into(),
384            )
385        })?;
386        let nlen = !len;
387        writer.write_bytes_raw(&len.to_le_bytes());
388        writer.write_bytes_raw(&nlen.to_le_bytes());
389        writer.write_bytes_raw(input_bytes);
390    } else {
391        // Write fixed Huffman block (BTYPE=01).
392        writer.write_bits(1, 2); // BTYPE=01
393        for token in lz77_tokens {
394            match token {
395                Lz77Token::Match(length, distance) => {
396                    let (len_code, len_extra_val, len_extra_bits) = get_length_code(length);
397                    let (reversed_len_huff, len_huff_bits) =
398                        get_fixed_literal_length_huffman_code(len_code);
399                    writer.write_bits(reversed_len_huff as u32, len_huff_bits);
400                    if len_extra_bits > 0 {
401                        writer.write_bits(len_extra_val, len_extra_bits);
402                    }
403
404                    let (dist_code, dist_extra_val, dist_extra_bits) = get_distance_code(distance);
405                    let (reversed_dist_huff, dist_huff_bits) =
406                        get_fixed_distance_huffman_code(dist_code);
407                    writer.write_bits(reversed_dist_huff as u32, dist_huff_bits);
408                    if dist_extra_bits > 0 {
409                        writer.write_bits(dist_extra_val, dist_extra_bits);
410                    }
411                }
412                Lz77Token::Literal(byte) => {
413                    let (reversed_huff, huff_bits) =
414                        get_fixed_literal_length_huffman_code(byte as u16);
415                    writer.write_bits(reversed_huff as u32, huff_bits);
416                }
417            }
418        }
419        // EOB marker.
420        let (reversed_eob_huff, eob_bits) = get_fixed_literal_length_huffman_code(256);
421        writer.write_bits(reversed_eob_huff as u32, eob_bits);
422    }
423
424    Ok(writer.get_bytes())
425}
426
427/// Deflate compression transformer (RFC 1951)
428#[derive(Debug, Clone, Copy, PartialEq, Eq)]
429pub struct DeflateCompress;
430
431impl Transform for DeflateCompress {
432    fn name(&self) -> &'static str {
433        "DEFLATE Compress"
434    }
435
436    fn id(&self) -> &'static str {
437        "deflatecompress"
438    }
439
440    fn category(&self) -> TransformerCategory {
441        TransformerCategory::Compression
442    }
443
444    fn description(&self) -> &'static str {
445        "Compresses input using the DEFLATE algorithm (RFC 1951) and encodes the output as Base64."
446    }
447
448    // Updated transform method uses deflate_bytes
449    fn transform(&self, input: &str) -> Result<String, TransformError> {
450        let input_bytes = input.as_bytes();
451        let compressed_data = deflate_bytes(input_bytes)?; // Call extracted function
452        Ok(base64_encode::base64_encode(&compressed_data))
453    }
454
455    fn default_test_input(&self) -> &'static str {
456        "Hello, Deflate World!"
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformers::deflate_decompress::DeflateDecompress;
    use crate::Transform;

    /// Compresses `input` and checks the Base64 output against `expected_base64`.
    fn assert_compresses_to(input: &str, expected_base64: &str) {
        let actual_base64 = DeflateCompress
            .transform(input)
            .unwrap_or_else(|e| panic!("transform failed for input '{}': {:?}", input, e));
        assert_eq!(actual_base64, expected_base64);
    }

    #[test]
    fn test_deflate_empty() {
        let encoded = DeflateCompress.transform("");
        assert!(encoded.is_ok());
        // Expected raw DEFLATE for empty fixed block is [0x03, 0x00]
        assert_eq!(encoded.unwrap(), "AwA=");
    }

    #[test]
    fn test_deflate_simple() {
        let compressor = DeflateCompress;
        let decompressor = DeflateDecompress;
        // Round-trip the default sample and the original two-byte sample.
        for input in [compressor.default_test_input(), "Hi"] {
            let compressed_b64 = compressor.transform(input).unwrap();
            let decompressed = decompressor.transform(&compressed_b64).unwrap();
            assert_eq!(decompressed, input);
        }
    }

    #[test]
    fn test_deflate_repeated() {
        assert_compresses_to(
            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "SyQZAAA=",
        );
    }

    #[test]
    fn test_deflate_longer_text() {
        assert_compresses_to(
            "This is a slightly longer test string to see how DEFLATE compression handles it.",
            "C8nILFYAokSF4pzM9IySnEqFnPy89NQihZLU4hKF4pKizLx0hZJ8heLUVIWM/HIFF1c3H8cQV4Xk/NyCotTi4sz8PIWMxLyUnFSgOSV6AA==",
        );
    }
}