1#![allow(dead_code)]
8
9use std::collections::BinaryHeap;
10
/// Upper bound, in bits, passed to `limit_code_lengths` when code tables are built from data.
const MAX_CODE_LEN: u8 = 15;
13
/// One node of a Huffman tree kept in a flat `Vec`; children are referenced by index.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HuffNode {
    /// Byte carried by a leaf; `None` on nodes created as parents during tree construction.
    pub symbol: Option<u8>,
    /// Weight of the node: the symbol frequency for a leaf, the combined weight for a parent.
    pub freq: u64,
    /// Index of the left child in the owning node list, if any.
    pub left: Option<usize>,
    /// Index of the right child in the owning node list, if any.
    pub right: Option<usize>,
}
26
27#[derive(Debug, Clone)]
29pub struct HuffmanTree {
30 pub nodes: Vec<HuffNode>,
31 root: Option<usize>,
33}
34
/// Per-byte code table used for encoding and decoding.
#[derive(Clone, Debug)]
pub struct HuffmanCodeTable {
    /// `(code bits, code length)` indexed by byte value; a length of 0 marks a byte without a code.
    pub codes: Vec<(u32, u8)>,
}
42
/// One byte value together with its occurrence count and assigned code length.
#[allow(dead_code)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HuffmanSymbol {
    /// The byte value this entry describes.
    pub byte: u8,
    /// How many times `byte` occurred in the analysed data.
    pub frequency: usize,
    /// Code length in bits; 0 when no length has been assigned.
    pub code_len: u8,
}
51
52#[allow(dead_code)]
54#[derive(Debug, Clone)]
55pub struct HuffmanTable {
56 pub symbols: Vec<HuffmanSymbol>,
57}
58
/// Accumulates individual bits into a byte buffer.
#[derive(Clone, Debug)]
pub struct BitWriter {
    /// Bytes written so far; the final byte may be only partially filled.
    pub buffer: Vec<u8>,
    /// Number of bits written so far, which is also the position of the next bit.
    pub bit_pos: usize,
}
65
/// Cursor that reads individual bits out of a borrowed byte slice.
#[derive(Clone, Debug)]
pub struct BitReader<'a> {
    /// The bytes being read.
    pub data: &'a [u8],
    /// Position of the next bit to read, counted from the start of `data`.
    pub bit_pos: usize,
}
72
/// Failures reported by the Huffman encode / decode routines.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HuffmanError {
    /// The data to encode was empty.
    EmptyInput,
    /// The contained byte has no code in the table being used.
    SymbolNotFound(u8),
    /// The bit stream ran out before the requested symbols were decoded.
    UnexpectedEndOfStream,
    /// The bits read from the stream match no code in the table.
    InvalidCode,
}
89
90impl std::fmt::Display for HuffmanError {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 match self {
93 Self::EmptyInput => write!(f, "empty input"),
94 Self::SymbolNotFound(s) => write!(f, "symbol {s} not in table"),
95 Self::UnexpectedEndOfStream => write!(f, "unexpected end of bit stream"),
96 Self::InvalidCode => write!(f, "invalid huffman code in stream"),
97 }
98 }
99}
100
101impl std::error::Error for HuffmanError {}
102
103impl BitWriter {
108 pub fn new() -> Self {
110 Self {
111 buffer: Vec::new(),
112 bit_pos: 0,
113 }
114 }
115
116 pub fn with_capacity(bytes: usize) -> Self {
118 Self {
119 buffer: Vec::with_capacity(bytes),
120 bit_pos: 0,
121 }
122 }
123
124 pub fn write_bits(&mut self, value: u32, num_bits: u8) {
127 for i in (0..num_bits).rev() {
128 let bit = (value >> i) & 1;
129 let byte_idx = self.bit_pos / 8;
130 let bit_idx = 7 - (self.bit_pos % 8);
131 if byte_idx >= self.buffer.len() {
132 self.buffer.push(0);
133 }
134 if bit == 1 {
135 self.buffer[byte_idx] |= 1 << bit_idx;
136 }
137 self.bit_pos += 1;
138 }
139 }
140
141 pub fn total_bits(&self) -> usize {
143 self.bit_pos
144 }
145
146 pub fn finish(self) -> (Vec<u8>, usize) {
148 (self.buffer, self.bit_pos)
149 }
150}
151
152impl Default for BitWriter {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
158impl<'a> BitReader<'a> {
163 pub fn new(data: &'a [u8]) -> Self {
165 Self { data, bit_pos: 0 }
166 }
167
168 pub fn read_bits(&mut self, num_bits: u8) -> Option<u32> {
171 let total_bits = self.data.len() * 8;
172 if self.bit_pos + num_bits as usize > total_bits {
173 return None;
174 }
175 let mut value: u32 = 0;
176 for _ in 0..num_bits {
177 let byte_idx = self.bit_pos / 8;
178 let bit_idx = 7 - (self.bit_pos % 8);
179 let bit = (self.data[byte_idx] >> bit_idx) & 1;
180 value = (value << 1) | bit as u32;
181 self.bit_pos += 1;
182 }
183 Some(value)
184 }
185
186 pub fn read_bit(&mut self) -> Option<u32> {
188 self.read_bits(1)
189 }
190
191 pub fn position(&self) -> usize {
193 self.bit_pos
194 }
195
196 pub fn set_position(&mut self, pos: usize) {
198 self.bit_pos = pos;
199 }
200}
201
/// Priority-queue entry used while building the tree: a weight plus the node it belongs to.
#[derive(Clone, Debug, PartialEq, Eq)]
struct HeapEntry {
    /// Weight of the referenced node.
    freq: u64,
    /// Index of the referenced node in the node list under construction.
    node_idx: usize,
}
213
214impl Ord for HeapEntry {
215 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
216 other
218 .freq
219 .cmp(&self.freq)
220 .then_with(|| other.node_idx.cmp(&self.node_idx))
221 }
222}
223
224impl PartialOrd for HeapEntry {
225 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
226 Some(self.cmp(other))
227 }
228}
229
230impl HuffmanTree {
231 pub fn build(freq: &[u64; 256]) -> Option<Self> {
234 let mut nodes: Vec<HuffNode> = Vec::new();
235 let mut heap = BinaryHeap::new();
236
237 for (sym, &f) in freq.iter().enumerate() {
238 if f > 0 {
239 let idx = nodes.len();
240 nodes.push(HuffNode {
241 symbol: Some(sym as u8),
242 freq: f,
243 left: None,
244 right: None,
245 });
246 heap.push(HeapEntry {
247 freq: f,
248 node_idx: idx,
249 });
250 }
251 }
252
253 if heap.is_empty() {
254 return None;
255 }
256
257 if heap.len() == 1 {
259 let entry = heap.pop()?;
260 let root_idx = nodes.len();
261 nodes.push(HuffNode {
262 symbol: None,
263 freq: entry.freq,
264 left: Some(entry.node_idx),
265 right: None,
266 });
267 return Some(Self {
268 nodes,
269 root: Some(root_idx),
270 });
271 }
272
273 while heap.len() >= 2 {
274 let a = heap.pop()?;
275 let b = heap.pop()?;
276 let combined_freq = a.freq + b.freq;
277 let parent_idx = nodes.len();
278 nodes.push(HuffNode {
279 symbol: None,
280 freq: combined_freq,
281 left: Some(a.node_idx),
282 right: Some(b.node_idx),
283 });
284 heap.push(HeapEntry {
285 freq: combined_freq,
286 node_idx: parent_idx,
287 });
288 }
289
290 let root_entry = heap.pop()?;
291 Some(Self {
292 nodes,
293 root: Some(root_entry.node_idx),
294 })
295 }
296
297 pub fn code_lengths(&self) -> [u8; 256] {
300 let mut lengths = [0u8; 256];
301 if let Some(root) = self.root {
302 self.walk(root, 0, &mut lengths);
303 }
304 lengths
305 }
306
307 fn walk(&self, idx: usize, depth: u8, lengths: &mut [u8; 256]) {
308 let node = &self.nodes[idx];
309 if let Some(sym) = node.symbol {
310 lengths[sym as usize] = depth.max(1);
312 return;
313 }
314 if let Some(left) = node.left {
315 self.walk(left, depth.saturating_add(1), lengths);
316 }
317 if let Some(right) = node.right {
318 self.walk(right, depth.saturating_add(1), lengths);
319 }
320 }
321}
322
/// Caps every code length at `max_len` and then lengthens codes until the Kraft sum fits.
///
/// Nothing is touched when no length exceeds `max_len`. Otherwise all used lengths are
/// clamped, and while the Kraft sum is too large the shortest code (lowest symbol on ties)
/// grows by one bit. If every code already has length `max_len` and the sum still does not
/// fit, the clamped lengths are stored as they are.
fn limit_code_lengths(lengths: &mut [u8; 256], max_len: u8) {
    if lengths.iter().all(|&l| l <= max_len) {
        return;
    }

    // (symbol, clamped length) for every symbol that has a code.
    let mut syms: Vec<(usize, u8)> = Vec::new();
    for (sym, &len) in lengths.iter().enumerate() {
        if len > 0 {
            syms.push((sym, len.min(max_len)));
        }
    }

    loop {
        // Kraft sum scaled by 2^max_len so it stays in integers.
        let kraft_sum: u64 = syms.iter().map(|&(_, l)| 1u64 << (max_len - l)).sum();
        let kraft_limit = 1u64 << max_len;
        if kraft_sum <= kraft_limit {
            break;
        }

        // The entry a (length, symbol) ascending sort would put first.
        let Some(shortest) = syms.iter_mut().min_by_key(|entry| (entry.1, entry.0)) else {
            break;
        };
        if shortest.1 >= max_len {
            // Every code is already at the cap; nothing can be lengthened further.
            break;
        }
        shortest.1 += 1;
    }

    for &(sym, len) in &syms {
        lengths[sym] = len;
    }
}
380
381impl HuffmanCodeTable {
386 pub fn from_lengths(lengths: &[u8; 256]) -> Self {
389 let mut codes = vec![(0u32, 0u8); 256];
390
391 let mut active: Vec<(u8, u8)> = lengths
393 .iter()
394 .enumerate()
395 .filter(|(_, &l)| l > 0)
396 .map(|(s, &l)| (s as u8, l))
397 .collect();
398
399 active.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
401
402 if active.is_empty() {
403 return Self { codes };
404 }
405
406 let mut code: u32 = 0;
407 let mut prev_len = active[0].1;
408
409 for (i, &(sym, len)) in active.iter().enumerate() {
410 if i > 0 {
411 code += 1;
412 if len > prev_len {
413 code <<= len - prev_len;
414 }
415 }
416 codes[sym as usize] = (code, len);
417 prev_len = len;
418 }
419
420 Self { codes }
421 }
422
423 pub fn from_data(data: &[u8]) -> Option<Self> {
426 if data.is_empty() {
427 return None;
428 }
429 let mut freq = [0u64; 256];
430 for &b in data {
431 freq[b as usize] += 1;
432 }
433 let tree = HuffmanTree::build(&freq)?;
434 let mut lengths = tree.code_lengths();
435 limit_code_lengths(&mut lengths, MAX_CODE_LEN);
436 Some(Self::from_lengths(&lengths))
437 }
438
439 pub fn lookup(&self, symbol: u8) -> Option<(u32, u8)> {
442 let (bits, len) = self.codes[symbol as usize];
443 if len == 0 {
444 None
445 } else {
446 Some((bits, len))
447 }
448 }
449}
450
451pub fn huffman_encode(
458 data: &[u8],
459 table: &HuffmanCodeTable,
460) -> Result<(Vec<u8>, usize), HuffmanError> {
461 if data.is_empty() {
462 return Err(HuffmanError::EmptyInput);
463 }
464 let mut writer = BitWriter::with_capacity(data.len());
465 for &b in data {
466 let (bits, len) = table.lookup(b).ok_or(HuffmanError::SymbolNotFound(b))?;
467 writer.write_bits(bits, len);
468 }
469 Ok(writer.finish())
470}
471
/// Decoder-side view of a code table.
struct DecodeLookup {
    /// `(code bits, code length, symbol)` triples for every symbol that has a code.
    entries: Vec<(u32, u8, u8)>,
}
481
482impl DecodeLookup {
483 fn from_table(table: &HuffmanCodeTable) -> Self {
484 let mut entries: Vec<(u32, u8, u8)> = table
485 .codes
486 .iter()
487 .enumerate()
488 .filter(|(_, &(_, len))| len > 0)
489 .map(|(sym, &(bits, len))| (bits, len, sym as u8))
490 .collect();
491 entries.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
492 Self { entries }
493 }
494
495 fn decode_one(&self, reader: &mut BitReader<'_>) -> Result<u8, HuffmanError> {
497 let start = reader.position();
498 let mut accumulated: u32 = 0;
499 let mut bits_read: u8 = 0;
500
501 for &(code, len, sym) in &self.entries {
502 while bits_read < len {
503 let bit = reader
504 .read_bit()
505 .ok_or(HuffmanError::UnexpectedEndOfStream)?;
506 accumulated = (accumulated << 1) | bit;
507 bits_read += 1;
508 }
509 if bits_read == len && accumulated == code {
510 return Ok(sym);
511 }
512 }
513
514 reader.set_position(start);
515 Err(HuffmanError::InvalidCode)
516 }
517}
518
519pub fn huffman_decode(
526 data: &[u8],
527 bit_count: usize,
528 symbol_count: usize,
529 table: &HuffmanCodeTable,
530) -> Result<Vec<u8>, HuffmanError> {
531 let lookup = DecodeLookup::from_table(table);
532 let mut reader = BitReader::new(data);
533 let mut output = Vec::with_capacity(symbol_count);
534
535 for _ in 0..symbol_count {
536 if reader.position() >= bit_count {
537 return Err(HuffmanError::UnexpectedEndOfStream);
538 }
539 let sym = lookup.decode_one(&mut reader)?;
540 output.push(sym);
541 }
542
543 Ok(output)
544}
545
546#[allow(dead_code)]
554pub fn build_frequency_table(data: &[u8]) -> HuffmanTable {
555 let mut freq = [0u64; 256];
556 for &b in data {
557 freq[b as usize] += 1;
558 }
559
560 let mut symbols: Vec<HuffmanSymbol> = freq
561 .iter()
562 .enumerate()
563 .filter(|(_, &f)| f > 0)
564 .map(|(i, &f)| HuffmanSymbol {
565 byte: i as u8,
566 frequency: f as usize,
567 code_len: 0,
568 })
569 .collect();
570
571 if let Some(tree) = HuffmanTree::build(&freq) {
573 let mut lengths = tree.code_lengths();
574 limit_code_lengths(&mut lengths, MAX_CODE_LEN);
575 for sym in &mut symbols {
576 sym.code_len = lengths[sym.byte as usize];
577 }
578 }
579
580 symbols.sort_by(|a, b| b.frequency.cmp(&a.frequency));
582
583 HuffmanTable { symbols }
584}
585
586#[allow(dead_code)]
588pub fn encode_symbol(table: &HuffmanTable, byte: u8) -> Option<u8> {
589 table
590 .symbols
591 .iter()
592 .enumerate()
593 .find(|(_, s)| s.byte == byte)
594 .map(|(i, _)| i as u8)
595}
596
597#[allow(dead_code)]
599pub fn decode_symbol(table: &HuffmanTable, code: u8) -> Option<u8> {
600 table.symbols.get(code as usize).map(|s| s.byte)
601}
602
603#[allow(dead_code)]
605pub fn table_size(table: &HuffmanTable) -> usize {
606 table.symbols.len()
607}
608
#[cfg(test)]
mod tests {
    //! Unit tests for the frequency table helpers, the bit writer/reader, tree construction,
    //! canonical code assignment, length limiting and the encode/decode round trip.

    use super::*;

    // ---- Frequency table helpers -------------------------------------------------------

    /// Empty input produces a table with no symbols.
    #[test]
    fn test_empty_data_gives_empty_table() {
        let t = build_frequency_table(&[]);
        assert_eq!(table_size(&t), 0);
    }

    /// Ten copies of one byte give a single entry with frequency 10.
    #[test]
    fn test_single_byte_table() {
        let t = build_frequency_table(&[42u8; 10]);
        assert_eq!(table_size(&t), 1);
        assert_eq!(t.symbols[0].byte, 42);
        assert_eq!(t.symbols[0].frequency, 10);
    }

    /// Entries are ordered by descending frequency.
    #[test]
    fn test_multiple_bytes_sorted_by_frequency() {
        let data = [1u8, 1, 1, 2, 2, 3];
        let t = build_frequency_table(&data);
        assert!(t.symbols[0].frequency >= t.symbols[1].frequency);
    }

    /// The most frequent byte sits at position 0.
    #[test]
    fn test_encode_symbol_found() {
        let data = [5u8, 5, 5, 10, 10];
        let t = build_frequency_table(&data);
        assert_eq!(encode_symbol(&t, 5), Some(0));
    }

    /// A byte that never occurred has no position.
    #[test]
    fn test_encode_symbol_not_found() {
        let data = [1u8, 2, 3];
        let t = build_frequency_table(&data);
        assert_eq!(encode_symbol(&t, 99), None);
    }

    /// `decode_symbol` inverts `encode_symbol`.
    #[test]
    fn test_decode_symbol_roundtrip() {
        let data = [7u8, 7, 8, 9];
        let t = build_frequency_table(&data);
        let code = encode_symbol(&t, 7).expect("should succeed");
        assert_eq!(decode_symbol(&t, code), Some(7));
    }

    /// A position past the end of the table decodes to `None`.
    #[test]
    fn test_decode_out_of_range() {
        let t = build_frequency_table(&[1u8, 2]);
        assert_eq!(decode_symbol(&t, 200), None);
    }

    /// Every symbol in the table receives a code length of at least one bit.
    #[test]
    fn test_code_len_assigned() {
        let data = [0u8, 0, 1, 2];
        let t = build_frequency_table(&data);
        for sym in &t.symbols {
            assert!(sym.code_len >= 1);
        }
    }

    /// The table has one entry per distinct byte.
    #[test]
    fn test_table_size_matches_unique_bytes() {
        let data = [10u8, 20, 30, 10, 20];
        let t = build_frequency_table(&data);
        assert_eq!(table_size(&t), 3);
    }

    // ---- Bit writer / reader -----------------------------------------------------------

    /// Eight bits written at once land in a single byte unchanged.
    #[test]
    fn test_bit_writer_single_byte() {
        let mut w = BitWriter::new();
        w.write_bits(0b10110011, 8);
        assert_eq!(w.total_bits(), 8);
        let (buf, bits) = w.finish();
        assert_eq!(bits, 8);
        assert_eq!(buf, vec![0b10110011]);
    }

    /// A partial byte is filled from the most significant bit; the rest stays zero.
    #[test]
    fn test_bit_writer_partial_byte() {
        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        assert_eq!(w.total_bits(), 3);
        let (buf, bits) = w.finish();
        assert_eq!(bits, 3);
        assert_eq!(buf, vec![0b10100000]);
    }

    /// Values written across a byte boundary read back with the same widths.
    #[test]
    fn test_bit_roundtrip() {
        let mut w = BitWriter::new();
        w.write_bits(0b110, 3);
        w.write_bits(0b01011, 5);
        w.write_bits(0b1, 1);
        let (buf, total) = w.finish();
        assert_eq!(total, 9);

        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(3), Some(0b110));
        assert_eq!(r.read_bits(5), Some(0b01011));
        assert_eq!(r.read_bits(1), Some(0b1));
    }

    /// Reading past the end of the data yields `None`.
    #[test]
    fn test_bit_reader_out_of_bounds() {
        let data = [0xFF];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits(8), Some(0xFF));
        assert_eq!(r.read_bits(1), None);
    }

    // ---- Tree construction -------------------------------------------------------------

    /// All-zero frequencies produce no tree.
    #[test]
    fn test_tree_build_empty() {
        let freq = [0u64; 256];
        assert!(HuffmanTree::build(&freq).is_none());
    }

    /// A lone symbol still gets a one-bit code; every other byte has length 0.
    #[test]
    fn test_tree_single_symbol() {
        let mut freq = [0u64; 256];
        freq[65] = 100;
        let tree = HuffmanTree::build(&freq).expect("should succeed");
        let lengths = tree.code_lengths();
        assert_eq!(lengths[65], 1);
        for (i, &l) in lengths.iter().enumerate() {
            if i != 65 {
                assert_eq!(l, 0);
            }
        }
    }

    /// Two symbols each get a one-bit code regardless of their frequencies.
    #[test]
    fn test_tree_two_symbols() {
        let mut freq = [0u64; 256];
        freq[0] = 10;
        freq[1] = 5;
        let tree = HuffmanTree::build(&freq).expect("should succeed");
        let lengths = tree.code_lengths();
        assert_eq!(lengths[0], 1);
        assert_eq!(lengths[1], 1);
    }

    /// Code lengths from a built tree satisfy the Kraft inequality.
    #[test]
    fn test_tree_multiple_symbols_kraft_inequality() {
        let mut freq = [0u64; 256];
        freq[0] = 100;
        freq[1] = 50;
        freq[2] = 25;
        freq[3] = 12;
        let tree = HuffmanTree::build(&freq).expect("should succeed");
        let lengths = tree.code_lengths();

        let kraft: f64 = lengths
            .iter()
            .filter(|&&l| l > 0)
            .map(|&l| 2.0f64.powi(-(l as i32)))
            .sum();
        assert!(kraft <= 1.0 + 1e-10, "Kraft inequality violated: {kraft}");
    }

    // ---- Canonical codes ---------------------------------------------------------------

    /// Lengths 1, 2, 2 give the canonical codes 0, 10 and 11.
    #[test]
    fn test_canonical_codes_simple() {
        let mut lengths = [0u8; 256];
        lengths[b'A' as usize] = 1;
        lengths[b'B' as usize] = 2;
        lengths[b'C' as usize] = 2;

        let table = HuffmanCodeTable::from_lengths(&lengths);

        let (a_bits, a_len) = table.codes[b'A' as usize];
        let (b_bits, b_len) = table.codes[b'B' as usize];
        let (c_bits, c_len) = table.codes[b'C' as usize];

        assert_eq!(a_len, 1);
        assert_eq!(b_len, 2);
        assert_eq!(c_len, 2);

        assert_eq!(a_bits, 0b0);
        assert_eq!(b_bits, 0b10);
        assert_eq!(c_bits, 0b11);
    }

    // ---- Length limiting ---------------------------------------------------------------

    /// Over-long lengths are brought down to at most `MAX_CODE_LEN`.
    #[test]
    fn test_length_limiting() {
        let mut lengths = [0u8; 256];
        lengths[..32].fill(20);
        limit_code_lengths(&mut lengths, MAX_CODE_LEN);
        for (i, &len) in lengths[..32].iter().enumerate() {
            assert!(
                len <= MAX_CODE_LEN,
                "symbol {i} has length {} > {MAX_CODE_LEN}",
                len
            );
        }
    }

    /// Limited lengths still satisfy the Kraft inequality.
    #[test]
    fn test_length_limiting_preserves_kraft() {
        let mut lengths = [0u8; 256];
        lengths[..16].fill(18);
        limit_code_lengths(&mut lengths, MAX_CODE_LEN);

        let kraft: f64 = lengths
            .iter()
            .filter(|&&l| l > 0)
            .map(|&l| 2.0f64.powi(-(l as i32)))
            .sum();
        assert!(
            kraft <= 1.0 + 1e-10,
            "Kraft inequality violated after limiting: {kraft}"
        );
    }

    // ---- Encode / decode ---------------------------------------------------------------

    /// A short mixed string survives an encode/decode round trip.
    #[test]
    fn test_encode_decode_roundtrip_simple() {
        let data = b"aabbbc";
        let table = HuffmanCodeTable::from_data(data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    /// Data made of a single repeated byte round-trips.
    #[test]
    fn test_encode_decode_roundtrip_single_symbol() {
        let data = vec![42u8; 100];
        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    /// All 256 byte values with skewed frequencies round-trip, and no code exceeds the
    /// length limit.
    #[test]
    fn test_encode_decode_roundtrip_all_bytes() {
        let mut data: Vec<u8> = (0..=255u8).collect();
        data.extend(std::iter::repeat_n(0u8, 50));
        data.extend(std::iter::repeat_n(1u8, 30));
        data.extend(std::iter::repeat_n(255u8, 20));

        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");

        for &(_, len) in &table.codes {
            if len > 0 {
                assert!(len <= MAX_CODE_LEN);
            }
        }

        let (encoded, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    /// A two-symbol alphabet round-trips.
    #[test]
    fn test_encode_decode_roundtrip_two_symbols() {
        let data = vec![0u8, 0, 0, 1, 1, 0, 1, 0, 0, 1];
        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    /// Fifty symbols with frequencies of roughly 1000 / (symbol + 1) round-trip.
    #[test]
    fn test_encode_decode_large_data() {
        let mut data = Vec::new();
        for sym in 0u8..50 {
            let count = 1000 / (sym as usize + 1);
            for _ in 0..count {
                data.push(sym);
            }
        }
        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    /// Encoding empty data is rejected with `EmptyInput`.
    #[test]
    fn test_encode_empty_data_error() {
        let table = HuffmanCodeTable {
            codes: vec![(0, 0); 256],
        };
        assert_eq!(huffman_encode(&[], &table), Err(HuffmanError::EmptyInput));
    }

    /// Encoding a byte that has no code reports that byte.
    #[test]
    fn test_encode_symbol_not_in_table_error() {
        let mut lengths = [0u8; 256];
        lengths[0] = 1;
        let table = HuffmanCodeTable::from_lengths(&lengths);
        let result = huffman_encode(&[0, 1], &table);
        assert_eq!(result, Err(HuffmanError::SymbolNotFound(1)));
    }

    /// Asking for more symbols than the stream holds is an error.
    #[test]
    fn test_decode_unexpected_end() {
        let data = b"ab";
        let table = HuffmanCodeTable::from_data(data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(data, &table).expect("should succeed");
        let result = huffman_decode(&encoded, bit_count, 100, &table);
        assert!(result.is_err());
    }

    /// Heavily skewed data encodes to fewer bits than its raw size.
    #[test]
    fn test_huffman_compression_ratio() {
        let data: Vec<u8> = std::iter::repeat_n(0u8, 1000)
            .chain(std::iter::repeat_n(1u8, 10))
            .chain(std::iter::once(2u8))
            .collect();

        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let (_, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let original_bits = data.len() * 8;
        assert!(
            bit_count < original_bits,
            "Expected compression: {bit_count} bits < {original_bits} bits"
        );
    }

    /// No assigned code is a prefix of another assigned code.
    #[test]
    fn test_canonical_codes_no_prefix_conflict() {
        let data: Vec<u8> = (0..10)
            .flat_map(|i| vec![i; (i as usize + 1) * 10])
            .collect();
        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let active: Vec<(u32, u8)> = table
            .codes
            .iter()
            .filter(|(_, len)| *len > 0)
            .copied()
            .collect();

        for (i, &(code_a, len_a)) in active.iter().enumerate() {
            for &(code_b, len_b) in &active[i + 1..] {
                // Truncate the longer code to the shorter one's length and compare.
                if len_a <= len_b {
                    let shifted = code_b >> (len_b - len_a);
                    assert_ne!(
                        shifted, code_a,
                        "Prefix conflict: ({code_a:#b}, {len_a}) is prefix of ({code_b:#b}, {len_b})"
                    );
                } else {
                    let shifted = code_a >> (len_a - len_b);
                    assert_ne!(
                        shifted, code_b,
                        "Prefix conflict: ({code_b:#b}, {len_b}) is prefix of ({code_a:#b}, {len_a})"
                    );
                }
            }
        }
    }

    /// Building a code table from empty data yields `None`.
    #[test]
    fn test_from_data_none_on_empty() {
        assert!(HuffmanCodeTable::from_data(&[]).is_none());
    }

    /// `lookup` finds bytes present in the data and rejects absent ones.
    #[test]
    fn test_table_lookup() {
        let data = b"aaabbc";
        let table = HuffmanCodeTable::from_data(data).expect("should succeed");
        assert!(table.lookup(b'a').is_some());
        assert!(table.lookup(b'z').is_none());
    }
}