json_packer/
huffman.rs

1use std::collections::HashMap;
2
3use crate::{bitstream::{BitReader, BitWriter}, Error};
4
5#[derive(Debug, Clone)]
6pub struct HuffmanCodec {
7    // 编码映射:key -> (LSB-first code, bit_len)
8    enc_map: HashMap<String, (u64, u8)>,
9    // 解码用二叉树
10    root: Box<Node>,
11}
12
/// Node of the decoding tree.
///
/// `Internal` branches on one bit (0 -> `left`, 1 -> `right`); `Leaf` carries a
/// decoded key. Leaves holding an empty string are also used as unused
/// placeholder slots while the tree is being built.
#[derive(Clone, Debug)]
enum Node {
    Internal {
        left: Box<Node>,
        right: Box<Node>,
    },
    Leaf(String),
}
18
19impl HuffmanCodec {
20    pub fn from_frequencies(freq_map: &HashMap<String, u64>) -> Result<Self, Error> {
21        // 收集符号并排序(字典序)确保确定性
22        let mut symbols: Vec<(String, u64)> = freq_map
23            .iter()
24            .map(|(k, &f)| (k.clone(), f))
25            .collect();
26        symbols.sort_by(|a, b| a.0.cmp(&b.0));
27
28        if symbols.is_empty() {
29            // 空字典:允许构建一个空的解码器(解码时会失败)
30            // 空字典:构造一个不可用的解码器,但避免崩溃
31            return Ok(HuffmanCodec {
32                enc_map: HashMap::new(),
33                root: Box::new(Node::Leaf(String::new())),
34            });
35        }
36
37        // 特殊情况:只有一个符号,分配长度1的码字 "0"
38        if symbols.len() == 1 {
39            let key = symbols[0].0.clone();
40            let mut enc_map = HashMap::new();
41            // LSB-first: 单比特0
42            enc_map.insert(key.clone(), (0, 1));
43            // 为保证解码消费1位(对齐),构造一个内部节点,左分支为 key,右分支为哑叶
44            let root = Box::new(Node::Internal {
45                left: Box::new(Node::Leaf(key)),
46                right: Box::new(Node::Leaf(String::new())),
47            });
48            return Ok(HuffmanCodec { enc_map, root });
49        }
50
51        // 1) 通过普通 Huffman 构建 code lengths(叶子深度)
52        let code_lengths = build_code_lengths(&symbols);
53
54        // 2) Canonical 编码:按 (len, key lex) 排序,生成 MSB-first 码字
55        let mut by_len: Vec<(usize, &str)> = symbols
56            .iter()
57            .enumerate()
58            .map(|(i, (k, _))| (code_lengths[i], k.as_str()))
59            .collect();
60        by_len.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(b.1)));
61
62        let max_len = by_len.iter().map(|(l, _)| *l).max().unwrap_or(1);
63        let mut bl_count = vec![0usize; max_len + 1];
64        for (l, _) in &by_len { bl_count[*l] += 1; }
65
66        // 计算每个长度的起始码(MSB-first)
67        let mut next_code = vec![0u32; max_len + 1];
68        let mut code: u32 = 0;
69        for bits in 1..=max_len {
70            code = (code + bl_count[bits - 1] as u32) << 1;
71            next_code[bits] = code;
72        }
73
74        // 3) 构建编码映射与解码树
75        let mut enc_map: HashMap<String, (u64, u8)> = HashMap::with_capacity(by_len.len());
76        let mut root = Node::Internal { left: Box::new(Node::Leaf(String::new())), right: Box::new(Node::Leaf(String::new())) };
77
78        for (len, key) in by_len {
79            if len == 0 { return Err(Error::HuffmanError); }
80            let code_msb = next_code[len];
81            next_code[len] += 1;
82
83            // 将 MSB-first 码字反转成 LSB-first 存储,便于 BitWriter 低位优先写入
84            let code_lsb = reverse_low_bits(code_msb as u64, len as u8);
85            enc_map.insert(key.to_string(), (code_lsb, len as u8));
86
87            // 在解码树中插入(按照 MSB-first 路径)
88            insert_codeword(&mut root, key, code_msb, len as u8)?;
89        }
90
91        Ok(HuffmanCodec { enc_map, root: Box::new(root) })
92    }
93
94    pub fn write_key_code(&self, key: &str, writer: &mut BitWriter) -> Result<(), Error> {
95        let (code_lsb, len) = self
96            .enc_map
97            .get(key)
98            .copied()
99            .ok_or(Error::HuffmanError)?;
100        writer.write_bits(code_lsb, len as u32);
101        Ok(())
102    }
103
104    pub fn decode_key(&self, reader: &mut BitReader) -> Result<String, Error> {
105        // 逐位读取并下行
106        let mut node = self.root.as_ref();
107        loop {
108            match node {
109                Node::Leaf(key) => return Ok(key.clone()),
110                Node::Internal { left, right } => {
111                    let bit = reader.read_bits(1)? as u8;
112                    node = if bit == 0 { left.as_ref() } else { right.as_ref() };
113                }
114            }
115        }
116    }
117
118    pub fn try_get_code(&self, key: &str) -> Option<(u64, u8)> { self.enc_map.get(key).copied() }
119}
120
/// Returns the lowest `bits` bits of `v` in reverse order.
///
/// Bit 0 of `v` becomes bit `bits - 1` of the result. For `bits <= 64`, higher
/// bits of `v` are ignored.
fn reverse_low_bits(v: u64, bits: u8) -> u64 {
    let (reversed, _) = (0..bits).fold((0u64, v), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
129
130#[derive(Debug, Clone)]
131struct HeapNode {
132    freq: u64,
133    // 为了确定性,包含字典序最小的叶子索引作为 tie-breaker
134    min_sym_idx: usize,
135    node: Box<TreeNode>,
136}
137
/// Shape-only Huffman tree used to derive code lengths. Leaves store an index
/// into the symbol slice rather than the key itself.
#[derive(Clone, Debug)]
enum TreeNode {
    Leaf {
        sym_idx: usize,
    },
    Internal {
        left: Box<TreeNode>,
        right: Box<TreeNode>,
    },
}
143
144fn build_code_lengths(symbols: &[(String, u64)]) -> Vec<usize> {
145    use std::cmp::Ordering;
146    use std::collections::BinaryHeap;
147
148    // 小根堆:通过 Ord 反转实现
149    #[derive(Debug)]
150    struct OrdNode(HeapNode);
151    impl PartialEq for OrdNode { fn eq(&self, other: &Self) -> bool { self.0.freq == other.0.freq && self.0.min_sym_idx == other.0.min_sym_idx } }
152    impl Eq for OrdNode {}
153    impl PartialOrd for OrdNode {
154        fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
155    }
156    impl Ord for OrdNode {
157        fn cmp(&self, other: &Self) -> Ordering {
158            // 反转为小根堆:频率小的优先,其次最小符号索引优先
159            other.0.freq.cmp(&self.0.freq).then(other.0.min_sym_idx.cmp(&self.0.min_sym_idx))
160        }
161    }
162
163    let mut heap: BinaryHeap<OrdNode> = BinaryHeap::new();
164    for (i, (_k, f)) in symbols.iter().enumerate() {
165        heap.push(OrdNode(HeapNode { freq: *f, min_sym_idx: i, node: Box::new(TreeNode::Leaf { sym_idx: i }) }));
166    }
167
168    if heap.len() == 1 {
169        return vec![1];
170    }
171
172    while heap.len() > 1 {
173        let OrdNode(a) = heap.pop().unwrap();
174        let OrdNode(b) = heap.pop().unwrap();
175        let min_sym_idx = a.min_sym_idx.min(b.min_sym_idx);
176        let merged = HeapNode {
177            freq: a.freq + b.freq,
178            min_sym_idx,
179            node: Box::new(TreeNode::Internal { left: a.node, right: b.node }),
180        };
181        heap.push(OrdNode(merged));
182    }
183
184    let root = heap.pop().unwrap().0.node;
185    // 计算叶子深度
186    let mut code_lengths = vec![0usize; symbols.len()];
187    fn walk(node: &TreeNode, depth: usize, lens: &mut [usize]) {
188        match node {
189            TreeNode::Leaf { sym_idx } => lens[*sym_idx] = depth.max(1),
190            TreeNode::Internal { left, right } => {
191                walk(left, depth + 1, lens);
192                walk(right, depth + 1, lens);
193            }
194        }
195    }
196    walk(&root, 0, &mut code_lengths);
197    code_lengths
198}
199
200fn insert_codeword(root: &mut Node, key: &str, code_msb: u32, len: u8) -> Result<(), Error> {
201    let mut node = root;
202    for i in (0..len).rev() { // 从 MSB 到 LSB
203        let bit = ((code_msb >> i) & 1) as u8;
204        match node {
205            Node::Internal { left, right } => {
206                if bit == 0 {
207                    if matches!(left.as_ref(), Node::Leaf(s) if s.is_empty()) {
208                        // 继续向下
209                    } else if matches!(left.as_ref(), Node::Internal { .. }) {
210                        // ok
211                    } else if let Node::Leaf(_) = left.as_ref() {
212                        // 将叶子展开为内部节点
213                        *left = Box::new(Node::Internal { left: Box::new(Node::Leaf(String::new())), right: Box::new(Node::Leaf(String::new())) });
214                    }
215                    node = left.as_mut();
216                } else {
217                    if matches!(right.as_ref(), Node::Leaf(s) if s.is_empty()) {
218                        // 继续向下
219                    } else if matches!(right.as_ref(), Node::Internal { .. }) {
220                        // ok
221                    } else if let Node::Leaf(_) = right.as_ref() {
222                        *right = Box::new(Node::Internal { left: Box::new(Node::Leaf(String::new())), right: Box::new(Node::Leaf(String::new())) });
223                    }
224                    node = right.as_mut();
225                }
226            }
227            Node::Leaf(_) => {
228                // 展开叶子为内部节点
229                *node = Node::Internal { left: Box::new(Node::Leaf(String::new())), right: Box::new(Node::Leaf(String::new())) };
230                if let Node::Internal { left, right } = node {
231                    node = if bit == 0 { left.as_mut() } else { right.as_mut() };
232                }
233            }
234        }
235    }
236    // 最后位置写入叶子
237    *node = Node::Leaf(key.to_string());
238    Ok(())
239}