1use crate::{Transform, TransformError, TransformerCategory};
2use super::base64_encode;
4
/// DEFLATE length-code table.
///
/// Each entry is `(code, base_length, extra_bits)`: the literal/length
/// alphabet symbol, the smallest match length that symbol represents, and
/// the number of extra bits that follow the symbol to select a length
/// within its range. Entries are ordered by ascending code.
#[rustfmt::skip]
pub(crate) const LENGTH_CODES: [(u16, u16, u8); 29] = [
    // 0 extra bits: one symbol per length.
    (257, 3, 0), (258, 4, 0), (259, 5, 0), (260, 6, 0),
    (261, 7, 0), (262, 8, 0), (263, 9, 0), (264, 10, 0),
    // 1 extra bit.
    (265, 11, 1), (266, 13, 1), (267, 15, 1), (268, 17, 1),
    // 2 extra bits.
    (269, 19, 2), (270, 23, 2), (271, 27, 2), (272, 31, 2),
    // 3 extra bits.
    (273, 35, 3), (274, 43, 3), (275, 51, 3), (276, 59, 3),
    // 4 extra bits.
    (277, 67, 4), (278, 83, 4), (279, 99, 4), (280, 115, 4),
    // 5 extra bits.
    (281, 131, 5), (282, 163, 5), (283, 195, 5), (284, 227, 5),
    // Dedicated symbol for the maximum length; carries no extra bits.
    (285, 258, 0),
];
37
/// DEFLATE distance-code table.
///
/// Each entry is `(code, base_distance, extra_bits)`: the distance alphabet
/// symbol, the smallest distance that symbol represents, and the number of
/// extra bits that follow the symbol to select a distance within its range.
/// Entries are ordered by ascending code.
#[rustfmt::skip]
pub(crate) const DISTANCE_CODES: [(u16, u16, u8); 30] = [
    // 0 extra bits.
    (0, 1, 0), (1, 2, 0), (2, 3, 0), (3, 4, 0),
    // 1 and 2 extra bits.
    (4, 5, 1), (5, 7, 1),
    (6, 9, 2), (7, 13, 2),
    // 3 and 4 extra bits.
    (8, 17, 3), (9, 25, 3),
    (10, 33, 4), (11, 49, 4),
    // 5 and 6 extra bits.
    (12, 65, 5), (13, 97, 5),
    (14, 129, 6), (15, 193, 6),
    // 7 and 8 extra bits.
    (16, 257, 7), (17, 385, 7),
    (18, 513, 8), (19, 769, 8),
    // 9 and 10 extra bits.
    (20, 1025, 9), (21, 1537, 9),
    (22, 2049, 10), (23, 3073, 10),
    // 11 and 12 extra bits.
    (24, 4097, 11), (25, 6145, 11),
    (26, 8193, 12), (27, 12289, 12),
    // 13 extra bits.
    (28, 16385, 13), (29, 24577, 13),
];
70
71fn get_length_code(length: u16) -> (u16, u32, u8) {
73 assert!(
74 (3..=258).contains(&length),
75 "Length must be between 3 and 258 inclusive"
76 );
77 if length == 258 {
78 return (285, 0, 0);
79 }
80 for i in 0..LENGTH_CODES.len() - 1 {
81 let (code, base_len, num_extra_bits) = LENGTH_CODES[i];
82 let next_base_len = if i + 1 < LENGTH_CODES.len() - 1 {
83 LENGTH_CODES[i + 1].1
84 } else {
85 258
86 };
87 let range_limit = base_len + (1 << num_extra_bits) - 1;
88 if length >= base_len && length <= range_limit {
89 let extra_val = length - base_len;
90 return (code, extra_val as u32, num_extra_bits);
91 }
92 if length > range_limit && length < next_base_len {
93 panic!("Length {} falls between code ranges", length);
94 }
95 }
96 panic!("Length code not found for {}", length);
97}
98
99fn get_distance_code(distance: u16) -> (u16, u32, u8) {
101 assert!(
102 (1..=32768).contains(&distance),
103 "Distance must be between 1 and 32768 inclusive"
104 );
105 for i in 0..DISTANCE_CODES.len() {
106 let (code, base_dist, num_extra_bits) = DISTANCE_CODES[i];
107 let range_limit = base_dist + (1 << num_extra_bits) - 1;
108 if distance >= base_dist && distance <= range_limit {
109 let extra_val = distance - base_dist;
110 return (code, extra_val as u32, num_extra_bits);
111 }
112 if i + 1 < DISTANCE_CODES.len() {
113 let next_base_dist = DISTANCE_CODES[i + 1].1;
114 if distance > range_limit && distance < next_base_dist {
115 panic!("Distance {} falls between code ranges", distance);
116 }
117 } else if distance > range_limit {
118 panic!("Distance {} is out of bounds (> 32768?)", distance);
119 }
120 }
121 panic!("Distance code not found for {}", distance);
122}
123
124pub(crate) fn get_length_info(code: u16) -> (u16, u8) {
126 assert!((257..=285).contains(&code));
128 for &(c, base, extra) in LENGTH_CODES.iter() {
129 if c == code {
130 return (base, extra);
131 }
132 }
133 unreachable!(); }
135
136pub(crate) fn get_distance_info(code: u16) -> (u16, u8) {
138 assert!(code <= 29);
140 for &(c, base, extra) in DISTANCE_CODES.iter() {
141 if c == code {
142 return (base, extra);
143 }
144 }
145 unreachable!(); }
147
/// Reverses the order of the lowest `num_bits` bits of `value`.
///
/// Bit 0 of `value` becomes bit `num_bits - 1` of the result and so on;
/// bits of `value` at or above `num_bits` are ignored. With `num_bits == 0`
/// the result is 0.
pub(crate) fn reverse_bits(value: u16, num_bits: u8) -> u16 {
    // Fold state is (bits reversed so far, bits still to consume).
    let (reversed, _) = (0..num_bits).fold((0u16, value), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
161
162fn get_fixed_literal_length_huffman_code(code: u16) -> (u16, u8) {
164 let (pattern, num_bits) = match code {
165 0..=143 => (0b00110000 + code, 8),
166 144..=255 => (0b110010000 + (code - 144u16), 9),
167 256..=279 => (code - 256u16, 7),
168 280..=285 => (0b11000000 + (code - 280u16), 8),
169 _ => panic!("Invalid literal/length code for fixed Huffman: {}", code),
170 };
171 (reverse_bits(pattern, num_bits), num_bits)
172}
173
174fn get_fixed_distance_huffman_code(distance_code: u16) -> (u16, u8) {
176 let num_bits = 5;
177 if distance_code <= 29 {
178 (reverse_bits(distance_code, num_bits), num_bits)
179 } else {
180 panic!("Invalid distance code for fixed Huffman: {}", distance_code);
181 }
182}
183
/// Accumulates a bit stream into bytes, filling each byte from its least
/// significant bit upward.
struct BitWriter {
    /// Completed bytes.
    bytes: Vec<u8>,
    /// Partially filled byte that has not been pushed to `bytes` yet.
    current_byte: u8,
    /// Number of bits already occupied in `current_byte` (0..=7 between calls).
    bit_position: u8,
}

impl BitWriter {
    /// Creates an empty writer positioned at a byte boundary.
    fn new() -> Self {
        Self {
            bytes: Vec::new(),
            current_byte: 0,
            bit_position: 0,
        }
    }

    /// Appends the lowest `num_bits` bits of `value`, least significant bit
    /// first. A call with `num_bits == 0` writes nothing.
    fn write_bits(&mut self, mut value: u32, mut num_bits: u8) {
        while num_bits > 0 {
            // Take as many bits as still fit into the pending byte.
            let free_bits = 8 - self.bit_position;
            let take = num_bits.min(free_bits);
            let chunk = (value & ((1u32 << take) - 1)) as u8;

            self.current_byte |= chunk << self.bit_position;
            self.bit_position += take;
            value >>= take;
            num_bits -= take;

            if self.bit_position == 8 {
                // The pending byte is full: emit it and start a new one.
                self.flush_byte();
            }
        }
    }

    /// Emits the pending byte if it holds any bits; unused high bits stay zero.
    fn flush_byte(&mut self) {
        if self.bit_position == 0 {
            return;
        }
        self.bytes.push(self.current_byte);
        self.current_byte = 0;
        self.bit_position = 0;
    }

    /// Consumes the writer and returns every byte written, including a
    /// final partially filled byte.
    fn get_bytes(mut self) -> Vec<u8> {
        self.flush_byte();
        self.bytes
    }

    /// Pads with zero bits up to the next byte boundary; does nothing when
    /// already aligned.
    fn align_to_byte(&mut self) {
        self.flush_byte();
    }

    /// Appends whole bytes verbatim.
    ///
    /// # Panics
    ///
    /// Panics if the writer is not at a byte boundary.
    fn write_bytes_raw(&mut self, bytes: &[u8]) {
        assert!(self.bit_position == 0, "Writer must be byte-aligned");
        self.bytes.extend_from_slice(bytes);
    }
}
244
/// Maximum back-reference distance searched by the LZ77 matcher; also the
/// number of slots in its `prev` chain table.
const MAX_WINDOW_SIZE: usize = 1 << 15;
/// Shortest match worth emitting as a back-reference.
const MIN_MATCH_LEN: usize = 3;
/// Longest match a single back-reference may cover.
const MAX_MATCH_LEN: usize = 258;
/// Number of buckets in the LZ77 hash-head table.
const HASH_TABLE_SIZE: usize = 32 * 1024;
250
/// One unit of LZ77 output.
#[derive(Clone, Debug, PartialEq)]
enum Lz77Token {
    /// A single input byte emitted verbatim.
    Literal(u8),
    /// A back-reference, stored as `(length, distance)`.
    Match(u16, u16),
}
256
257fn lz77_compress(input: &[u8]) -> Vec<Lz77Token> {
258 if input.is_empty() {
259 return Vec::new();
260 }
261 let mut tokens = Vec::new();
262 let mut head: Vec<Option<usize>> = vec![None; HASH_TABLE_SIZE];
263 let mut prev: Vec<Option<usize>> = vec![None; MAX_WINDOW_SIZE];
264 let mut current_pos = 0;
265 while current_pos < input.len() {
266 let window_start = if current_pos > MAX_WINDOW_SIZE {
267 current_pos - MAX_WINDOW_SIZE
268 } else {
269 0
270 };
271 if current_pos + MIN_MATCH_LEN > input.len() {
272 tokens.extend(input[current_pos..].iter().map(|&b| Lz77Token::Literal(b)));
273 break;
274 }
275 let hash = calculate_hash(&input[current_pos..current_pos + MIN_MATCH_LEN]);
276 let mut best_match_len = 0;
277 let mut best_match_dist = 0;
278 let mut match_pos_opt = head[hash];
279 while let Some(match_pos) = match_pos_opt {
280 if match_pos < window_start {
281 break;
282 }
283 let current_match_len =
284 calculate_match_length(input, match_pos, current_pos, MAX_MATCH_LEN);
285 if current_match_len >= MIN_MATCH_LEN && current_match_len > best_match_len {
286 best_match_len = current_match_len;
287 best_match_dist = (current_pos - match_pos) as u16;
288 if best_match_len == MAX_MATCH_LEN {
289 break;
290 }
291 }
292 match_pos_opt = prev[match_pos % MAX_WINDOW_SIZE];
293 }
294 prev[current_pos % MAX_WINDOW_SIZE] = head[hash];
295 head[hash] = Some(current_pos);
296 if best_match_len >= MIN_MATCH_LEN {
297 tokens.push(Lz77Token::Match(best_match_len as u16, best_match_dist));
298 for i in 1..best_match_len {
300 let pos_to_update = current_pos + i;
301 if pos_to_update + MIN_MATCH_LEN <= input.len() {
302 let next_hash =
303 calculate_hash(&input[pos_to_update..pos_to_update + MIN_MATCH_LEN]);
304 prev[pos_to_update % MAX_WINDOW_SIZE] = head[next_hash];
305 head[next_hash] = Some(pos_to_update);
306 }
307 }
308 current_pos += best_match_len;
309 } else {
310 tokens.push(Lz77Token::Literal(input[current_pos]));
311 current_pos += 1;
312 }
313 }
314 tokens
315}
316
317#[inline]
318fn calculate_hash(bytes: &[u8]) -> usize {
319 (((bytes[0] as usize) << 8) | ((bytes[1] as usize) << 4) | (bytes[2] as usize))
320 % HASH_TABLE_SIZE
321}
322
/// Counts how many consecutive bytes starting at `pos1` equal the bytes
/// starting at `pos2`.
///
/// Counting stops at the first mismatch, when `max_len` bytes have matched,
/// or when `pos2 + count` reaches the end of `input`. The two ranges may
/// overlap.
#[inline]
fn calculate_match_length(input: &[u8], pos1: usize, pos2: usize, max_len: usize) -> usize {
    (0..max_len)
        .take_while(|&offset| {
            // The bounds test comes first so neither index is read past the
            // end when `pos2` is the later position.
            pos2 + offset < input.len() && input[pos1 + offset] == input[pos2 + offset]
        })
        .count()
}
332
333pub(crate) fn deflate_bytes(input_bytes: &[u8]) -> Result<Vec<u8>, TransformError> {
335 let mut writer = BitWriter::new();
336
337 if input_bytes.is_empty() {
338 writer.write_bits(1, 1); writer.write_bits(1, 2); let (reversed_eob_huff, eob_bits) = get_fixed_literal_length_huffman_code(256); writer.write_bits(reversed_eob_huff as u32, eob_bits);
343 return Ok(writer.get_bytes());
344 }
345
346 let lz77_tokens = lz77_compress(input_bytes);
347
348 let mut estimated_bits = 0;
350 for token in &lz77_tokens {
351 match token {
352 Lz77Token::Literal(byte) => {
353 let (_, bits) = get_fixed_literal_length_huffman_code(*byte as u16);
354 estimated_bits += bits as usize;
355 }
356 Lz77Token::Match(length, distance) => {
357 let (len_code, _, len_extra_bits) = get_length_code(*length);
358 let (_, len_huff_bits) = get_fixed_literal_length_huffman_code(len_code);
359 estimated_bits += len_huff_bits as usize + len_extra_bits as usize;
360
361 let (dist_code, _, dist_extra_bits) = get_distance_code(*distance);
362 let (_, dist_huff_bits) = get_fixed_distance_huffman_code(dist_code);
363 estimated_bits += dist_huff_bits as usize + dist_extra_bits as usize;
364 }
365 }
366 }
367 let (_, eob_bits) = get_fixed_literal_length_huffman_code(256); estimated_bits += eob_bits as usize;
369 estimated_bits += 3; let uncompressed_size_bytes = input_bytes.len() + 5;
372 let uncompressed_size_bits = uncompressed_size_bytes * 8;
373
374 writer.write_bits(1, 1); if estimated_bits >= uncompressed_size_bits {
378 writer.write_bits(0, 2); writer.align_to_byte();
381 let len: u16 = input_bytes.len().try_into().map_err(|_| {
382 TransformError::CompressionError(
383 "Input too large for uncompressed block length (max 65535)".into(),
384 )
385 })?;
386 let nlen = !len;
387 writer.write_bytes_raw(&len.to_le_bytes());
388 writer.write_bytes_raw(&nlen.to_le_bytes());
389 writer.write_bytes_raw(input_bytes);
390 } else {
391 writer.write_bits(1, 2); for token in lz77_tokens {
394 match token {
395 Lz77Token::Match(length, distance) => {
396 let (len_code, len_extra_val, len_extra_bits) = get_length_code(length);
397 let (reversed_len_huff, len_huff_bits) =
398 get_fixed_literal_length_huffman_code(len_code);
399 writer.write_bits(reversed_len_huff as u32, len_huff_bits);
400 if len_extra_bits > 0 {
401 writer.write_bits(len_extra_val, len_extra_bits);
402 }
403
404 let (dist_code, dist_extra_val, dist_extra_bits) = get_distance_code(distance);
405 let (reversed_dist_huff, dist_huff_bits) =
406 get_fixed_distance_huffman_code(dist_code);
407 writer.write_bits(reversed_dist_huff as u32, dist_huff_bits);
408 if dist_extra_bits > 0 {
409 writer.write_bits(dist_extra_val, dist_extra_bits);
410 }
411 }
412 Lz77Token::Literal(byte) => {
413 let (reversed_huff, huff_bits) =
414 get_fixed_literal_length_huffman_code(byte as u16);
415 writer.write_bits(reversed_huff as u32, huff_bits);
416 }
417 }
418 }
419 let (reversed_eob_huff, eob_bits) = get_fixed_literal_length_huffman_code(256);
421 writer.write_bits(reversed_eob_huff as u32, eob_bits);
422 }
423
424 Ok(writer.get_bytes())
425}
426
427#[derive(Debug, Clone, Copy, PartialEq, Eq)]
429pub struct DeflateCompress;
430
431impl Transform for DeflateCompress {
432 fn name(&self) -> &'static str {
433 "DEFLATE Compress"
434 }
435
436 fn id(&self) -> &'static str {
437 "deflatecompress"
438 }
439
440 fn category(&self) -> TransformerCategory {
441 TransformerCategory::Compression
442 }
443
444 fn description(&self) -> &'static str {
445 "Compresses input using the DEFLATE algorithm (RFC 1951) and encodes the output as Base64."
446 }
447
448 fn transform(&self, input: &str) -> Result<String, TransformError> {
450 let input_bytes = input.as_bytes();
451 let compressed_data = deflate_bytes(input_bytes)?; Ok(base64_encode::base64_encode(&compressed_data))
453 }
454
455 fn default_test_input(&self) -> &'static str {
456 "Hello, Deflate World!"
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformers::deflate_decompress::DeflateDecompress;
    use crate::Transform;

    /// Compresses `input` and asserts that the Base64 output is exactly
    /// `expected_base64`.
    fn assert_compresses_to(input: &str, expected_base64: &str) {
        match DeflateCompress.transform(input) {
            Ok(actual_base64) => assert_eq!(actual_base64, expected_base64),
            Err(e) => panic!("transform failed for input '{}': {:?}", input, e),
        }
    }

    /// Compresses `input`, decompresses the result and asserts that the
    /// original text comes back.
    fn assert_round_trips(input: &str) {
        let compressed_b64 = DeflateCompress.transform(input).unwrap();
        let decompressed = DeflateDecompress.transform(&compressed_b64).unwrap();
        assert_eq!(decompressed, input);
    }

    #[test]
    fn test_deflate_empty() {
        let result = DeflateCompress.transform("");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "AwA=");
    }

    #[test]
    fn test_deflate_simple() {
        assert_round_trips(DeflateCompress.default_test_input());
        assert_round_trips("Hi");
    }

    #[test]
    fn test_deflate_repeated() {
        assert_compresses_to(
            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "SyQZAAA=",
        );
    }

    #[test]
    fn test_deflate_longer_text() {
        assert_compresses_to(
            "This is a slightly longer test string to see how DEFLATE compression handles it.",
            "C8nILFYAokSF4pzM9IySnEqFnPy89NQihZLU4hKF4pKizLx0hZJ8heLUVIWM/HIFF1c3H8cQV4Xk/NyCotTi4sz8PIWMxLyUnFSgOSV6AA==",
        );
    }
}