1use std::collections::HashMap;
2
3use crate::{bitstream::{BitReader, BitWriter}, Error};
4
5#[derive(Debug, Clone)]
6pub struct HuffmanCodec {
7 enc_map: HashMap<String, (u64, u8)>,
9 root: Box<Node>,
11}
12
/// Node of the Huffman decode tree.
///
/// Leaves carry the decoded key. A leaf holding the empty string is also used
/// as a placeholder for code space that no key has been assigned to.
#[derive(Clone, Debug)]
enum Node {
    /// Branch point: a 0 bit continues into `left`, a 1 bit into `right`.
    Internal { left: Box<Node>, right: Box<Node> },
    /// Terminal node holding a key (or the empty-string placeholder).
    Leaf(String),
}
18
19impl HuffmanCodec {
20 pub fn from_frequencies(freq_map: &HashMap<String, u64>) -> Result<Self, Error> {
21 let mut symbols: Vec<(String, u64)> = freq_map
23 .iter()
24 .map(|(k, &f)| (k.clone(), f))
25 .collect();
26 symbols.sort_by(|a, b| a.0.cmp(&b.0));
27
28 if symbols.is_empty() {
29 return Ok(HuffmanCodec {
32 enc_map: HashMap::new(),
33 root: Box::new(Node::Leaf(String::new())),
34 });
35 }
36
37 if symbols.len() == 1 {
39 let key = symbols[0].0.clone();
40 let mut enc_map = HashMap::new();
41 enc_map.insert(key.clone(), (0, 1));
43 let root = Box::new(Node::Internal {
45 left: Box::new(Node::Leaf(key)),
46 right: Box::new(Node::Leaf(String::new())),
47 });
48 return Ok(HuffmanCodec { enc_map, root });
49 }
50
51 let code_lengths = build_code_lengths(&symbols);
53
54 let mut by_len: Vec<(usize, &str)> = symbols
56 .iter()
57 .enumerate()
58 .map(|(i, (k, _))| (code_lengths[i], k.as_str()))
59 .collect();
60 by_len.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(b.1)));
61
62 let max_len = by_len.iter().map(|(l, _)| *l).max().unwrap_or(1);
63 let mut bl_count = vec![0usize; max_len + 1];
64 for (l, _) in &by_len { bl_count[*l] += 1; }
65
66 let mut next_code = vec![0u32; max_len + 1];
68 let mut code: u32 = 0;
69 for bits in 1..=max_len {
70 code = (code + bl_count[bits - 1] as u32) << 1;
71 next_code[bits] = code;
72 }
73
74 let mut enc_map: HashMap<String, (u64, u8)> = HashMap::with_capacity(by_len.len());
76 let mut root = Node::Internal { left: Box::new(Node::Leaf(String::new())), right: Box::new(Node::Leaf(String::new())) };
77
78 for (len, key) in by_len {
79 if len == 0 { return Err(Error::HuffmanError); }
80 let code_msb = next_code[len];
81 next_code[len] += 1;
82
83 let code_lsb = reverse_low_bits(code_msb as u64, len as u8);
85 enc_map.insert(key.to_string(), (code_lsb, len as u8));
86
87 insert_codeword(&mut root, key, code_msb, len as u8)?;
89 }
90
91 Ok(HuffmanCodec { enc_map, root: Box::new(root) })
92 }
93
94 pub fn write_key_code(&self, key: &str, writer: &mut BitWriter) -> Result<(), Error> {
95 let (code_lsb, len) = self
96 .enc_map
97 .get(key)
98 .copied()
99 .ok_or(Error::HuffmanError)?;
100 writer.write_bits(code_lsb, len as u32);
101 Ok(())
102 }
103
104 pub fn decode_key(&self, reader: &mut BitReader) -> Result<String, Error> {
105 let mut node = self.root.as_ref();
107 loop {
108 match node {
109 Node::Leaf(key) => return Ok(key.clone()),
110 Node::Internal { left, right } => {
111 let bit = reader.read_bits(1)? as u8;
112 node = if bit == 0 { left.as_ref() } else { right.as_ref() };
113 }
114 }
115 }
116 }
117
118 pub fn try_get_code(&self, key: &str) -> Option<(u64, u8)> { self.enc_map.get(key).copied() }
119}
120
/// Returns the lowest `bits` bits of `v` in reversed order.
///
/// Bit 0 of `v` ends up at bit `bits - 1` of the result, bit 1 at `bits - 2`,
/// and so on; higher bits of `v` are ignored. `bits` is expected to be at most
/// 64. A `bits` of 0 yields 0.
fn reverse_low_bits(v: u64, bits: u8) -> u64 {
    let (reversed, _) = (0..bits).fold((0u64, v), |(acc, rest), _| {
        // Shift the next low bit of `rest` into the bottom of `acc`.
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
129
130#[derive(Debug, Clone)]
131struct HeapNode {
132 freq: u64,
133 min_sym_idx: usize,
135 node: Box<TreeNode>,
136}
137
/// Shape of the Huffman tree assembled by `build_code_lengths`; the depth of
/// each leaf becomes that symbol's code length.
#[derive(Clone, Debug)]
enum TreeNode {
    /// One symbol, identified by its index in the `symbols` slice given to
    /// `build_code_lengths`.
    Leaf { sym_idx: usize },
    /// Two merged subtrees.
    Internal { left: Box<TreeNode>, right: Box<TreeNode> },
}
143
144fn build_code_lengths(symbols: &[(String, u64)]) -> Vec<usize> {
145 use std::cmp::Ordering;
146 use std::collections::BinaryHeap;
147
148 #[derive(Debug)]
150 struct OrdNode(HeapNode);
151 impl PartialEq for OrdNode { fn eq(&self, other: &Self) -> bool { self.0.freq == other.0.freq && self.0.min_sym_idx == other.0.min_sym_idx } }
152 impl Eq for OrdNode {}
153 impl PartialOrd for OrdNode {
154 fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
155 }
156 impl Ord for OrdNode {
157 fn cmp(&self, other: &Self) -> Ordering {
158 other.0.freq.cmp(&self.0.freq).then(other.0.min_sym_idx.cmp(&self.0.min_sym_idx))
160 }
161 }
162
163 let mut heap: BinaryHeap<OrdNode> = BinaryHeap::new();
164 for (i, (_k, f)) in symbols.iter().enumerate() {
165 heap.push(OrdNode(HeapNode { freq: *f, min_sym_idx: i, node: Box::new(TreeNode::Leaf { sym_idx: i }) }));
166 }
167
168 if heap.len() == 1 {
169 return vec![1];
170 }
171
172 while heap.len() > 1 {
173 let OrdNode(a) = heap.pop().unwrap();
174 let OrdNode(b) = heap.pop().unwrap();
175 let min_sym_idx = a.min_sym_idx.min(b.min_sym_idx);
176 let merged = HeapNode {
177 freq: a.freq + b.freq,
178 min_sym_idx,
179 node: Box::new(TreeNode::Internal { left: a.node, right: b.node }),
180 };
181 heap.push(OrdNode(merged));
182 }
183
184 let root = heap.pop().unwrap().0.node;
185 let mut code_lengths = vec![0usize; symbols.len()];
187 fn walk(node: &TreeNode, depth: usize, lens: &mut [usize]) {
188 match node {
189 TreeNode::Leaf { sym_idx } => lens[*sym_idx] = depth.max(1),
190 TreeNode::Internal { left, right } => {
191 walk(left, depth + 1, lens);
192 walk(right, depth + 1, lens);
193 }
194 }
195 }
196 walk(&root, 0, &mut code_lengths);
197 code_lengths
198}
199
200fn insert_codeword(root: &mut Node, key: &str, code_msb: u32, len: u8) -> Result<(), Error> {
201 let mut node = root;
202 for i in (0..len).rev() { let bit = ((code_msb >> i) & 1) as u8;
204 match node {
205 Node::Internal { left, right } => {
206 if bit == 0 {
207 if matches!(left.as_ref(), Node::Leaf(s) if s.is_empty()) {
208 } else if matches!(left.as_ref(), Node::Internal { .. }) {
210 } else if let Node::Leaf(_) = left.as_ref() {
212 *left = Box::new(Node::Internal { left: Box::new(Node::Leaf(String::new())), right: Box::new(Node::Leaf(String::new())) });
214 }
215 node = left.as_mut();
216 } else {
217 if matches!(right.as_ref(), Node::Leaf(s) if s.is_empty()) {
218 } else if matches!(right.as_ref(), Node::Internal { .. }) {
220 } else if let Node::Leaf(_) = right.as_ref() {
222 *right = Box::new(Node::Internal { left: Box::new(Node::Leaf(String::new())), right: Box::new(Node::Leaf(String::new())) });
223 }
224 node = right.as_mut();
225 }
226 }
227 Node::Leaf(_) => {
228 *node = Node::Internal { left: Box::new(Node::Leaf(String::new())), right: Box::new(Node::Leaf(String::new())) };
230 if let Node::Internal { left, right } = node {
231 node = if bit == 0 { left.as_mut() } else { right.as_mut() };
232 }
233 }
234 }
235 }
236 *node = Node::Leaf(key.to_string());
238 Ok(())
239}