Skip to main content

oxihuman_core/
huffman_stub.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Real Huffman coding implementation with tree construction, canonical codes,
5//! bit-level encoding/decoding, and length-limited codes (max 15 bits).
6
7#![allow(dead_code)]
8
9use std::collections::BinaryHeap;
10
11/// Maximum allowed code length (like DEFLATE).
12const MAX_CODE_LEN: u8 = 15;
13
14// ---------------------------------------------------------------------------
15// Data structures
16// ---------------------------------------------------------------------------
17
/// A node in the Huffman tree.
///
/// Nodes are stored in the flat `nodes` vector of [`HuffmanTree`]; the child
/// links are indices into that vector rather than pointers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HuffNode {
    /// Byte value carried by a leaf; `None` for internal (merged) nodes.
    pub symbol: Option<u8>,
    /// Frequency of the leaf symbol, or the combined frequency of the children.
    pub freq: u64,
    /// Index of the left child in the node vector, if any.
    pub left: Option<usize>,
    /// Index of the right child in the node vector, if any.
    pub right: Option<usize>,
}
26
/// The full Huffman tree stored as a flat node array.
///
/// Children are referenced by index into `nodes` (see [`HuffNode`]).
#[derive(Debug, Clone)]
pub struct HuffmanTree {
    /// All nodes of the tree. `HuffmanTree::build` pushes one leaf per used
    /// symbol first, followed by the internal nodes in merge order.
    pub nodes: Vec<HuffNode>,
    /// Index of the root node.
    root: Option<usize>,
}
34
/// Lookup table: for each symbol 0..=255, `(code_bits, code_length)`.
/// Symbols that do not appear have `code_length == 0`.
///
/// `code_bits` holds the code in its `code_length` least-significant bits;
/// `huffman_encode` emits those bits most-significant first.
#[derive(Debug, Clone)]
pub struct HuffmanCodeTable {
    /// Indexed by symbol value (0..=255). `(code_bits, code_length)`.
    /// Tables built by this module always contain exactly 256 entries.
    pub codes: Vec<(u32, u8)>,
}
42
/// A symbol with frequency and assigned code length (legacy public API).
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HuffmanSymbol {
    /// The byte value this entry describes.
    pub byte: u8,
    /// Number of occurrences of `byte` in the analysed data.
    pub frequency: usize,
    /// Huffman code length in bits assigned to `byte` (0 if none was assigned).
    pub code_len: u8,
}
51
/// A frequency table mapping bytes to HuffmanSymbol entries (legacy API).
///
/// The position of an entry in `symbols` doubles as its legacy "code": it is
/// what `encode_symbol` returns and what `decode_symbol` accepts.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct HuffmanTable {
    /// One entry per distinct byte; `build_frequency_table` orders them by
    /// descending frequency.
    pub symbols: Vec<HuffmanSymbol>,
}
58
/// Bit-level writer: packs bits into a `Vec<u8>`.
///
/// Bits fill each byte starting from its most-significant bit.
#[derive(Debug, Clone)]
pub struct BitWriter {
    /// Packed output bytes; unused low bits of the last byte stay zero.
    pub buffer: Vec<u8>,
    /// Number of bits written so far, i.e. the position of the next bit.
    pub bit_pos: usize,
}
65
/// Bit-level reader over a byte slice.
///
/// Bits are consumed from the most-significant bit of each byte downwards,
/// mirroring the layout produced by `BitWriter`.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// The bytes being read.
    pub data: &'a [u8],
    /// Index of the next bit to read (bit 0 is the top bit of `data[0]`).
    pub bit_pos: usize,
}
72
73// ---------------------------------------------------------------------------
74// Errors
75// ---------------------------------------------------------------------------
76
/// Failure modes of the Huffman encoder and decoder.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HuffmanError {
    /// There was nothing to encode: the input slice was empty.
    EmptyInput,
    /// The given byte has no code in the table used for encoding.
    SymbolNotFound(u8),
    /// The bit stream ended before a complete symbol could be read.
    UnexpectedEndOfStream,
    /// The bits read do not correspond to any code in the table.
    InvalidCode,
}

impl std::fmt::Display for HuffmanError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only one variant carries data; every other message is a fixed string.
        let fixed_message = match self {
            Self::SymbolNotFound(s) => return write!(f, "symbol {s} not in table"),
            Self::EmptyInput => "empty input",
            Self::UnexpectedEndOfStream => "unexpected end of bit stream",
            Self::InvalidCode => "invalid huffman code in stream",
        };
        f.write_str(fixed_message)
    }
}

impl std::error::Error for HuffmanError {}
102
103// ---------------------------------------------------------------------------
104// BitWriter
105// ---------------------------------------------------------------------------
106
107impl BitWriter {
108    /// Create a new, empty bit writer.
109    pub fn new() -> Self {
110        Self {
111            buffer: Vec::new(),
112            bit_pos: 0,
113        }
114    }
115
116    /// Create a new writer with pre-allocated capacity (in bytes).
117    pub fn with_capacity(bytes: usize) -> Self {
118        Self {
119            buffer: Vec::with_capacity(bytes),
120            bit_pos: 0,
121        }
122    }
123
124    /// Write `num_bits` least-significant bits of `value` to the stream,
125    /// MSB first (big-endian bit order within each code word).
126    pub fn write_bits(&mut self, value: u32, num_bits: u8) {
127        for i in (0..num_bits).rev() {
128            let bit = (value >> i) & 1;
129            let byte_idx = self.bit_pos / 8;
130            let bit_idx = 7 - (self.bit_pos % 8);
131            if byte_idx >= self.buffer.len() {
132                self.buffer.push(0);
133            }
134            if bit == 1 {
135                self.buffer[byte_idx] |= 1 << bit_idx;
136            }
137            self.bit_pos += 1;
138        }
139    }
140
141    /// Total number of bits written so far.
142    pub fn total_bits(&self) -> usize {
143        self.bit_pos
144    }
145
146    /// Consume the writer, returning `(byte_buffer, total_bit_count)`.
147    pub fn finish(self) -> (Vec<u8>, usize) {
148        (self.buffer, self.bit_pos)
149    }
150}
151
152impl Default for BitWriter {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158// ---------------------------------------------------------------------------
159// BitReader
160// ---------------------------------------------------------------------------
161
162impl<'a> BitReader<'a> {
163    /// Create a new reader over the given byte slice.
164    pub fn new(data: &'a [u8]) -> Self {
165        Self { data, bit_pos: 0 }
166    }
167
168    /// Read `num_bits` from the stream (MSB first), returning the value.
169    /// Returns `None` if not enough bits remain.
170    pub fn read_bits(&mut self, num_bits: u8) -> Option<u32> {
171        let total_bits = self.data.len() * 8;
172        if self.bit_pos + num_bits as usize > total_bits {
173            return None;
174        }
175        let mut value: u32 = 0;
176        for _ in 0..num_bits {
177            let byte_idx = self.bit_pos / 8;
178            let bit_idx = 7 - (self.bit_pos % 8);
179            let bit = (self.data[byte_idx] >> bit_idx) & 1;
180            value = (value << 1) | bit as u32;
181            self.bit_pos += 1;
182        }
183        Some(value)
184    }
185
186    /// Read a single bit, returning 0 or 1, or `None` at end.
187    pub fn read_bit(&mut self) -> Option<u32> {
188        self.read_bits(1)
189    }
190
191    /// How many bits have been consumed.
192    pub fn position(&self) -> usize {
193        self.bit_pos
194    }
195
196    /// Set the reader position (in bits).
197    pub fn set_position(&mut self, pos: usize) {
198        self.bit_pos = pos;
199    }
200}
201
202// ---------------------------------------------------------------------------
203// HuffmanTree - build from frequencies using a min-heap
204// ---------------------------------------------------------------------------
205
/// Priority-queue element used while merging nodes into a Huffman tree.
#[derive(Debug, Clone, Eq, PartialEq)]
struct HeapEntry {
    /// Frequency (weight) of the referenced node.
    freq: u64,
    /// Index of the node; on equal frequency the lower index is popped first,
    /// which keeps tree construction deterministic.
    node_idx: usize,
}

impl Ord for HeapEntry {
    /// Inverted lexicographic order on `(freq, node_idx)`, so that a max-heap
    /// such as `BinaryHeap` yields the smallest pair first.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let mine = (self.freq, self.node_idx);
        let theirs = (other.freq, other.node_idx);
        theirs.cmp(&mine)
    }
}

impl PartialOrd for HeapEntry {
    /// Total order delegated to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
229
230impl HuffmanTree {
231    /// Build a Huffman tree from a frequency array (indexed by symbol 0..=255).
232    /// Returns `None` if all frequencies are zero.
233    pub fn build(freq: &[u64; 256]) -> Option<Self> {
234        let mut nodes: Vec<HuffNode> = Vec::new();
235        let mut heap = BinaryHeap::new();
236
237        for (sym, &f) in freq.iter().enumerate() {
238            if f > 0 {
239                let idx = nodes.len();
240                nodes.push(HuffNode {
241                    symbol: Some(sym as u8),
242                    freq: f,
243                    left: None,
244                    right: None,
245                });
246                heap.push(HeapEntry {
247                    freq: f,
248                    node_idx: idx,
249                });
250            }
251        }
252
253        if heap.is_empty() {
254            return None;
255        }
256
257        // Special case: single symbol - we still need a tree with depth 1.
258        if heap.len() == 1 {
259            let entry = heap.pop()?;
260            let root_idx = nodes.len();
261            nodes.push(HuffNode {
262                symbol: None,
263                freq: entry.freq,
264                left: Some(entry.node_idx),
265                right: None,
266            });
267            return Some(Self {
268                nodes,
269                root: Some(root_idx),
270            });
271        }
272
273        while heap.len() >= 2 {
274            let a = heap.pop()?;
275            let b = heap.pop()?;
276            let combined_freq = a.freq + b.freq;
277            let parent_idx = nodes.len();
278            nodes.push(HuffNode {
279                symbol: None,
280                freq: combined_freq,
281                left: Some(a.node_idx),
282                right: Some(b.node_idx),
283            });
284            heap.push(HeapEntry {
285                freq: combined_freq,
286                node_idx: parent_idx,
287            });
288        }
289
290        let root_entry = heap.pop()?;
291        Some(Self {
292            nodes,
293            root: Some(root_entry.node_idx),
294        })
295    }
296
297    /// Extract code lengths per symbol by walking the tree.
298    /// Returns an array of 256 code lengths.
299    pub fn code_lengths(&self) -> [u8; 256] {
300        let mut lengths = [0u8; 256];
301        if let Some(root) = self.root {
302            self.walk(root, 0, &mut lengths);
303        }
304        lengths
305    }
306
307    fn walk(&self, idx: usize, depth: u8, lengths: &mut [u8; 256]) {
308        let node = &self.nodes[idx];
309        if let Some(sym) = node.symbol {
310            // Leaf node - depth is the code length (min 1).
311            lengths[sym as usize] = depth.max(1);
312            return;
313        }
314        if let Some(left) = node.left {
315            self.walk(left, depth.saturating_add(1), lengths);
316        }
317        if let Some(right) = node.right {
318            self.walk(right, depth.saturating_add(1), lengths);
319        }
320    }
321}
322
323// ---------------------------------------------------------------------------
324// Length-limited Huffman codes
325// ---------------------------------------------------------------------------
326
/// Limit code lengths to `max_len` using a heuristic:
/// clamp all to max_len, then fix the Kraft inequality by lengthening short codes.
///
/// Lengths of zero (unused symbols) are left untouched. If no length exceeds
/// `max_len` the array is returned unchanged.
///
/// While the Kraft sum `sum(2^(max_len - l_i))` exceeds `2^max_len`, the
/// shortest code that is still below `max_len` (lowest symbol on ties) is
/// lengthened by one bit. The result satisfies the Kraft inequality whenever
/// that is achievable at all, which for up to 256 symbols means
/// `max_len >= 8`; it is not guaranteed to be an optimal length-limited code.
/// If every code already sits at `max_len` and the sum is still too large,
/// the lengths are left clamped but Kraft-violating.
///
/// `max_len` must be below 64 so the sums fit in the `u64` arithmetic used.
fn limit_code_lengths(lengths: &mut [u8; 256], max_len: u8) {
    if lengths.iter().all(|&len| len <= max_len) {
        return;
    }

    // Clamp all lengths to max_len (zeros stay zero).
    for len in lengths.iter_mut() {
        *len = (*len).min(max_len);
    }

    // Kraft inequality scaled by 2^max_len: sum(2^(max_len - l_i)) <= 2^max_len.
    let kraft_limit = 1u64 << max_len;
    let mut kraft_sum: u64 = lengths
        .iter()
        .filter(|&&len| len > 0)
        .map(|&len| 1u64 << (max_len - len))
        .sum();

    while kraft_sum > kraft_limit {
        // `min_by_key` returns the first minimum, i.e. the lowest symbol among
        // the shortest codes - the same element a (length, symbol) sort would
        // put first, without re-sorting on every iteration.
        let shortest = match lengths
            .iter_mut()
            .filter(|len| **len > 0 && **len < max_len)
            .min_by_key(|len| **len)
        {
            Some(len) => len,
            // Every active code is already at max_len; nothing left to lengthen.
            None => break,
        };

        // Lengthening a code by one bit halves its Kraft weight.
        kraft_sum -= 1u64 << (max_len - *shortest - 1);
        *shortest += 1;
    }
}
380
381// ---------------------------------------------------------------------------
382// Canonical Huffman code assignment
383// ---------------------------------------------------------------------------
384
385impl HuffmanCodeTable {
386    /// Build a canonical Huffman table from code lengths.
387    /// Symbols with `length == 0` are not encoded.
388    pub fn from_lengths(lengths: &[u8; 256]) -> Self {
389        let mut codes = vec![(0u32, 0u8); 256];
390
391        // Collect (symbol, length) pairs for active symbols, then sort.
392        let mut active: Vec<(u8, u8)> = lengths
393            .iter()
394            .enumerate()
395            .filter(|(_, &l)| l > 0)
396            .map(|(s, &l)| (s as u8, l))
397            .collect();
398
399        // Canonical ordering: by (length, symbol).
400        active.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
401
402        if active.is_empty() {
403            return Self { codes };
404        }
405
406        let mut code: u32 = 0;
407        let mut prev_len = active[0].1;
408
409        for (i, &(sym, len)) in active.iter().enumerate() {
410            if i > 0 {
411                code += 1;
412                if len > prev_len {
413                    code <<= len - prev_len;
414                }
415            }
416            codes[sym as usize] = (code, len);
417            prev_len = len;
418        }
419
420        Self { codes }
421    }
422
423    /// Build a Huffman table directly from raw data bytes.
424    /// Returns `None` if `data` is empty.
425    pub fn from_data(data: &[u8]) -> Option<Self> {
426        if data.is_empty() {
427            return None;
428        }
429        let mut freq = [0u64; 256];
430        for &b in data {
431            freq[b as usize] += 1;
432        }
433        let tree = HuffmanTree::build(&freq)?;
434        let mut lengths = tree.code_lengths();
435        limit_code_lengths(&mut lengths, MAX_CODE_LEN);
436        Some(Self::from_lengths(&lengths))
437    }
438
439    /// Look up the code for a given symbol.
440    /// Returns `None` if the symbol is not in the table.
441    pub fn lookup(&self, symbol: u8) -> Option<(u32, u8)> {
442        let (bits, len) = self.codes[symbol as usize];
443        if len == 0 {
444            None
445        } else {
446            Some((bits, len))
447        }
448    }
449}
450
451// ---------------------------------------------------------------------------
452// Encoding
453// ---------------------------------------------------------------------------
454
455/// Encode a slice of bytes into a packed bit stream using the given table.
456/// Returns `(byte_buffer, total_bit_count)`.
457pub fn huffman_encode(
458    data: &[u8],
459    table: &HuffmanCodeTable,
460) -> Result<(Vec<u8>, usize), HuffmanError> {
461    if data.is_empty() {
462        return Err(HuffmanError::EmptyInput);
463    }
464    let mut writer = BitWriter::with_capacity(data.len());
465    for &b in data {
466        let (bits, len) = table.lookup(b).ok_or(HuffmanError::SymbolNotFound(b))?;
467        writer.write_bits(bits, len);
468    }
469    Ok(writer.finish())
470}
471
472// ---------------------------------------------------------------------------
473// Decoding
474// ---------------------------------------------------------------------------
475
476/// Decode lookup structure for efficient symbol resolution.
477struct DecodeLookup {
478    /// (code_bits, code_length, symbol), sorted by (length, code).
479    entries: Vec<(u32, u8, u8)>,
480}
481
482impl DecodeLookup {
483    fn from_table(table: &HuffmanCodeTable) -> Self {
484        let mut entries: Vec<(u32, u8, u8)> = table
485            .codes
486            .iter()
487            .enumerate()
488            .filter(|(_, &(_, len))| len > 0)
489            .map(|(sym, &(bits, len))| (bits, len, sym as u8))
490            .collect();
491        entries.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
492        Self { entries }
493    }
494
495    /// Decode one symbol from the reader.
496    fn decode_one(&self, reader: &mut BitReader<'_>) -> Result<u8, HuffmanError> {
497        let start = reader.position();
498        let mut accumulated: u32 = 0;
499        let mut bits_read: u8 = 0;
500
501        for &(code, len, sym) in &self.entries {
502            while bits_read < len {
503                let bit = reader
504                    .read_bit()
505                    .ok_or(HuffmanError::UnexpectedEndOfStream)?;
506                accumulated = (accumulated << 1) | bit;
507                bits_read += 1;
508            }
509            if bits_read == len && accumulated == code {
510                return Ok(sym);
511            }
512        }
513
514        reader.set_position(start);
515        Err(HuffmanError::InvalidCode)
516    }
517}
518
519/// Decode a packed bit stream back to symbols.
520///
521/// - `data`: the byte buffer containing packed bits.
522/// - `bit_count`: total number of valid bits in the stream.
523/// - `symbol_count`: how many symbols to decode.
524/// - `table`: the Huffman table used for encoding.
525pub fn huffman_decode(
526    data: &[u8],
527    bit_count: usize,
528    symbol_count: usize,
529    table: &HuffmanCodeTable,
530) -> Result<Vec<u8>, HuffmanError> {
531    let lookup = DecodeLookup::from_table(table);
532    let mut reader = BitReader::new(data);
533    let mut output = Vec::with_capacity(symbol_count);
534
535    for _ in 0..symbol_count {
536        if reader.position() >= bit_count {
537            return Err(HuffmanError::UnexpectedEndOfStream);
538        }
539        let sym = lookup.decode_one(&mut reader)?;
540        output.push(sym);
541    }
542
543    Ok(output)
544}
545
546// ---------------------------------------------------------------------------
547// Legacy public API (kept for backward compatibility)
548// ---------------------------------------------------------------------------
549
550/// Build a frequency table from the given data slice (legacy API).
551///
552/// This now uses a real Huffman tree to assign code lengths.
553#[allow(dead_code)]
554pub fn build_frequency_table(data: &[u8]) -> HuffmanTable {
555    let mut freq = [0u64; 256];
556    for &b in data {
557        freq[b as usize] += 1;
558    }
559
560    let mut symbols: Vec<HuffmanSymbol> = freq
561        .iter()
562        .enumerate()
563        .filter(|(_, &f)| f > 0)
564        .map(|(i, &f)| HuffmanSymbol {
565            byte: i as u8,
566            frequency: f as usize,
567            code_len: 0,
568        })
569        .collect();
570
571    // Build real tree and get lengths.
572    if let Some(tree) = HuffmanTree::build(&freq) {
573        let mut lengths = tree.code_lengths();
574        limit_code_lengths(&mut lengths, MAX_CODE_LEN);
575        for sym in &mut symbols {
576            sym.code_len = lengths[sym.byte as usize];
577        }
578    }
579
580    // Sort by frequency descending (legacy behavior).
581    symbols.sort_by(|a, b| b.frequency.cmp(&a.frequency));
582
583    HuffmanTable { symbols }
584}
585
586/// Encode a byte to its stub code (index in table), or `None` if not present.
587#[allow(dead_code)]
588pub fn encode_symbol(table: &HuffmanTable, byte: u8) -> Option<u8> {
589    table
590        .symbols
591        .iter()
592        .enumerate()
593        .find(|(_, s)| s.byte == byte)
594        .map(|(i, _)| i as u8)
595}
596
597/// Decode a code back to the original byte, or `None` if code out of range.
598#[allow(dead_code)]
599pub fn decode_symbol(table: &HuffmanTable, code: u8) -> Option<u8> {
600    table.symbols.get(code as usize).map(|s| s.byte)
601}
602
603/// Return the number of symbols in the table.
604#[allow(dead_code)]
605pub fn table_size(table: &HuffmanTable) -> usize {
606    table.symbols.len()
607}
608
609// ---------------------------------------------------------------------------
610// Tests
611// ---------------------------------------------------------------------------
612
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Legacy API tests (kept from original stub)
    // -----------------------------------------------------------------------

    // No input bytes -> no table entries.
    #[test]
    fn test_empty_data_gives_empty_table() {
        let t = build_frequency_table(&[]);
        assert_eq!(table_size(&t), 0);
    }

    // A run of one byte value yields a single entry carrying the full count.
    #[test]
    fn test_single_byte_table() {
        let t = build_frequency_table(&[42u8; 10]);
        assert_eq!(table_size(&t), 1);
        assert_eq!(t.symbols[0].byte, 42);
        assert_eq!(t.symbols[0].frequency, 10);
    }

    // Entries come out ordered by descending frequency.
    #[test]
    fn test_multiple_bytes_sorted_by_frequency() {
        let data = [1u8, 1, 1, 2, 2, 3];
        let t = build_frequency_table(&data);
        assert!(t.symbols[0].frequency >= t.symbols[1].frequency);
    }

    // The legacy code of a byte is its index in the frequency-sorted table.
    #[test]
    fn test_encode_symbol_found() {
        let data = [5u8, 5, 5, 10, 10];
        let t = build_frequency_table(&data);
        // byte 5 has highest freq -> code 0
        assert_eq!(encode_symbol(&t, 5), Some(0));
    }

    // A byte that never occurred has no legacy code.
    #[test]
    fn test_encode_symbol_not_found() {
        let data = [1u8, 2, 3];
        let t = build_frequency_table(&data);
        assert_eq!(encode_symbol(&t, 99), None);
    }

    // encode_symbol followed by decode_symbol returns the original byte.
    #[test]
    fn test_decode_symbol_roundtrip() {
        let data = [7u8, 7, 8, 9];
        let t = build_frequency_table(&data);
        let code = encode_symbol(&t, 7).expect("should succeed");
        assert_eq!(decode_symbol(&t, code), Some(7));
    }

    // A code past the end of the table decodes to None.
    #[test]
    fn test_decode_out_of_range() {
        let t = build_frequency_table(&[1u8, 2]);
        assert_eq!(decode_symbol(&t, 200), None);
    }

    // Every symbol present in the data receives a non-zero code length.
    #[test]
    fn test_code_len_assigned() {
        let data = [0u8, 0, 1, 2];
        let t = build_frequency_table(&data);
        // All symbols get code_len >= 1
        for sym in &t.symbols {
            assert!(sym.code_len >= 1);
        }
    }

    // The table holds one entry per distinct byte.
    #[test]
    fn test_table_size_matches_unique_bytes() {
        let data = [10u8, 20, 30, 10, 20];
        let t = build_frequency_table(&data);
        assert_eq!(table_size(&t), 3);
    }

    // -----------------------------------------------------------------------
    // BitWriter / BitReader tests
    // -----------------------------------------------------------------------

    // Eight bits fill exactly one byte, MSB first.
    #[test]
    fn test_bit_writer_single_byte() {
        let mut w = BitWriter::new();
        w.write_bits(0b10110011, 8);
        assert_eq!(w.total_bits(), 8);
        let (buf, bits) = w.finish();
        assert_eq!(bits, 8);
        assert_eq!(buf, vec![0b10110011]);
    }

    // Fewer than eight bits occupy the top of the byte; the rest stays zero.
    #[test]
    fn test_bit_writer_partial_byte() {
        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        assert_eq!(w.total_bits(), 3);
        let (buf, bits) = w.finish();
        assert_eq!(bits, 3);
        // 101 written to top 3 bits of byte => 10100000
        assert_eq!(buf, vec![0b10100000]);
    }

    // Values written with mixed widths (crossing a byte boundary) read back
    // unchanged with the same widths.
    #[test]
    fn test_bit_roundtrip() {
        let mut w = BitWriter::new();
        w.write_bits(0b110, 3);
        w.write_bits(0b01011, 5);
        w.write_bits(0b1, 1);
        let (buf, total) = w.finish();
        assert_eq!(total, 9);

        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(3), Some(0b110));
        assert_eq!(r.read_bits(5), Some(0b01011));
        assert_eq!(r.read_bits(1), Some(0b1));
    }

    // Reading past the last bit yields None.
    #[test]
    fn test_bit_reader_out_of_bounds() {
        let data = [0xFF];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits(8), Some(0xFF));
        assert_eq!(r.read_bits(1), None);
    }

    // -----------------------------------------------------------------------
    // Huffman tree tests
    // -----------------------------------------------------------------------

    // All-zero frequencies produce no tree.
    #[test]
    fn test_tree_build_empty() {
        let freq = [0u64; 256];
        assert!(HuffmanTree::build(&freq).is_none());
    }

    // A lone symbol gets a one-bit code; every other symbol stays at length 0.
    #[test]
    fn test_tree_single_symbol() {
        let mut freq = [0u64; 256];
        freq[65] = 100; // 'A'
        let tree = HuffmanTree::build(&freq).expect("should succeed");
        let lengths = tree.code_lengths();
        assert_eq!(lengths[65], 1);
        for (i, &l) in lengths.iter().enumerate() {
            if i != 65 {
                assert_eq!(l, 0);
            }
        }
    }

    // Two symbols share the two one-bit codes regardless of their frequencies.
    #[test]
    fn test_tree_two_symbols() {
        let mut freq = [0u64; 256];
        freq[0] = 10;
        freq[1] = 5;
        let tree = HuffmanTree::build(&freq).expect("should succeed");
        let lengths = tree.code_lengths();
        assert_eq!(lengths[0], 1);
        assert_eq!(lengths[1], 1);
    }

    // Code lengths derived from a tree satisfy sum(2^-len) <= 1.
    #[test]
    fn test_tree_multiple_symbols_kraft_inequality() {
        let mut freq = [0u64; 256];
        freq[0] = 100;
        freq[1] = 50;
        freq[2] = 25;
        freq[3] = 12;
        let tree = HuffmanTree::build(&freq).expect("should succeed");
        let lengths = tree.code_lengths();

        let kraft: f64 = lengths
            .iter()
            .filter(|&&l| l > 0)
            .map(|&l| 2.0f64.powi(-(l as i32)))
            .sum();
        assert!(kraft <= 1.0 + 1e-10, "Kraft inequality violated: {kraft}");
    }

    // -----------------------------------------------------------------------
    // Canonical code tests
    // -----------------------------------------------------------------------

    // Lengths {A:1, B:2, C:2} map to the canonical codes 0, 10, 11.
    #[test]
    fn test_canonical_codes_simple() {
        let mut lengths = [0u8; 256];
        lengths[b'A' as usize] = 1;
        lengths[b'B' as usize] = 2;
        lengths[b'C' as usize] = 2;

        let table = HuffmanCodeTable::from_lengths(&lengths);

        let (a_bits, a_len) = table.codes[b'A' as usize];
        let (b_bits, b_len) = table.codes[b'B' as usize];
        let (c_bits, c_len) = table.codes[b'C' as usize];

        assert_eq!(a_len, 1);
        assert_eq!(b_len, 2);
        assert_eq!(c_len, 2);

        // Canonical assignment: A=0, B=10, C=11
        assert_eq!(a_bits, 0b0);
        assert_eq!(b_bits, 0b10);
        assert_eq!(c_bits, 0b11);
    }

    // -----------------------------------------------------------------------
    // Length limiting tests
    // -----------------------------------------------------------------------

    // Over-long lengths are brought down to at most MAX_CODE_LEN.
    #[test]
    fn test_length_limiting() {
        let mut lengths = [0u8; 256];
        lengths[..32].fill(20);
        limit_code_lengths(&mut lengths, MAX_CODE_LEN);
        for (i, &len) in lengths[..32].iter().enumerate() {
            assert!(
                len <= MAX_CODE_LEN,
                "symbol {i} has length {} > {MAX_CODE_LEN}",
                len
            );
        }
    }

    // The limited lengths still satisfy the Kraft inequality.
    #[test]
    fn test_length_limiting_preserves_kraft() {
        let mut lengths = [0u8; 256];
        lengths[..16].fill(18);
        limit_code_lengths(&mut lengths, MAX_CODE_LEN);

        let kraft: f64 = lengths
            .iter()
            .filter(|&&l| l > 0)
            .map(|&l| 2.0f64.powi(-(l as i32)))
            .sum();
        assert!(
            kraft <= 1.0 + 1e-10,
            "Kraft inequality violated after limiting: {kraft}"
        );
    }

    // -----------------------------------------------------------------------
    // Encode / Decode roundtrip tests
    // -----------------------------------------------------------------------

    // Small three-symbol input survives an encode/decode round trip.
    #[test]
    fn test_encode_decode_roundtrip_simple() {
        let data = b"aabbbc";
        let table = HuffmanCodeTable::from_data(data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    // Degenerate single-symbol alphabet still round-trips.
    #[test]
    fn test_encode_decode_roundtrip_single_symbol() {
        let data = vec![42u8; 100];
        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    // All 256 byte values with skewed counts: every code respects
    // MAX_CODE_LEN and the data round-trips.
    #[test]
    fn test_encode_decode_roundtrip_all_bytes() {
        let mut data: Vec<u8> = (0..=255u8).collect();
        data.extend(std::iter::repeat_n(0u8, 50));
        data.extend(std::iter::repeat_n(1u8, 30));
        data.extend(std::iter::repeat_n(255u8, 20));

        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");

        for &(_, len) in &table.codes {
            if len > 0 {
                assert!(len <= MAX_CODE_LEN);
            }
        }

        let (encoded, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    // Two-symbol alphabet (one bit per symbol) round-trips.
    #[test]
    fn test_encode_decode_roundtrip_two_symbols() {
        let data = vec![0u8, 0, 0, 1, 1, 0, 1, 0, 0, 1];
        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    // Fifty symbols with counts of 1000 / (sym + 1) round-trip.
    #[test]
    fn test_encode_decode_large_data() {
        let mut data = Vec::new();
        for sym in 0u8..50 {
            let count = 1000 / (sym as usize + 1);
            for _ in 0..count {
                data.push(sym);
            }
        }
        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let decoded =
            huffman_decode(&encoded, bit_count, data.len(), &table).expect("should succeed");
        assert_eq!(decoded, data);
    }

    // Encoding nothing is reported as EmptyInput.
    #[test]
    fn test_encode_empty_data_error() {
        let table = HuffmanCodeTable {
            codes: vec![(0, 0); 256],
        };
        assert_eq!(huffman_encode(&[], &table), Err(HuffmanError::EmptyInput));
    }

    // A byte without a code in the table aborts encoding with SymbolNotFound.
    #[test]
    fn test_encode_symbol_not_in_table_error() {
        let mut lengths = [0u8; 256];
        lengths[0] = 1;
        let table = HuffmanCodeTable::from_lengths(&lengths);
        let result = huffman_encode(&[0, 1], &table);
        assert_eq!(result, Err(HuffmanError::SymbolNotFound(1)));
    }

    // Asking for more symbols than the stream contains is an error.
    #[test]
    fn test_decode_unexpected_end() {
        let data = b"ab";
        let table = HuffmanCodeTable::from_data(data).expect("should succeed");
        let (encoded, bit_count) = huffman_encode(data, &table).expect("should succeed");
        let result = huffman_decode(&encoded, bit_count, 100, &table);
        assert!(result.is_err());
    }

    // Heavily skewed data encodes to fewer bits than 8 per symbol.
    #[test]
    fn test_huffman_compression_ratio() {
        let data: Vec<u8> = std::iter::repeat_n(0u8, 1000)
            .chain(std::iter::repeat_n(1u8, 10))
            .chain(std::iter::once(2u8))
            .collect();

        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let (_, bit_count) = huffman_encode(&data, &table).expect("should succeed");
        let original_bits = data.len() * 8;
        assert!(
            bit_count < original_bits,
            "Expected compression: {bit_count} bits < {original_bits} bits"
        );
    }

    // No assigned code is a prefix of another: for every pair, the longer
    // code truncated to the shorter length must differ from the shorter code.
    #[test]
    fn test_canonical_codes_no_prefix_conflict() {
        let data: Vec<u8> = (0..10)
            .flat_map(|i| vec![i; (i as usize + 1) * 10])
            .collect();
        let table = HuffmanCodeTable::from_data(&data).expect("should succeed");
        let active: Vec<(u32, u8)> = table
            .codes
            .iter()
            .filter(|(_, len)| *len > 0)
            .copied()
            .collect();

        for (i, &(code_a, len_a)) in active.iter().enumerate() {
            for &(code_b, len_b) in &active[i + 1..] {
                if len_a <= len_b {
                    let shifted = code_b >> (len_b - len_a);
                    assert_ne!(
                        shifted, code_a,
                        "Prefix conflict: ({code_a:#b}, {len_a}) is prefix of ({code_b:#b}, {len_b})"
                    );
                } else {
                    let shifted = code_a >> (len_a - len_b);
                    assert_ne!(
                        shifted, code_b,
                        "Prefix conflict: ({code_b:#b}, {len_b}) is prefix of ({code_a:#b}, {len_a})"
                    );
                }
            }
        }
    }

    // from_data refuses empty input.
    #[test]
    fn test_from_data_none_on_empty() {
        assert!(HuffmanCodeTable::from_data(&[]).is_none());
    }

    // lookup finds bytes that occurred and reports absent ones as None.
    #[test]
    fn test_table_lookup() {
        let data = b"aaabbc";
        let table = HuffmanCodeTable::from_data(data).expect("should succeed");
        assert!(table.lookup(b'a').is_some());
        assert!(table.lookup(b'z').is_none());
    }
}