1use std::{collections::HashMap, fmt, hash::Hash};
2
3use bit_vec::BitVec;
4
5#[derive(Debug, thiserror::Error)]
6pub enum Error {
7 #[error("No such key in encoding dictionary: {0}")]
8 NoSuchKey(String),
9 #[error("Invalid weight nodes: {0}")]
10 InvalidWeights(String),
11}
12
13type Result<T> = std::result::Result<T, Error>;
14
/// One node of the Huffman tree.
///
/// Nodes created from an input weight carry their symbol in `value`; nodes
/// created by merging two others have `value == None` and own the merged
/// nodes through `left` and `right`.
struct Node<T> {
    /// Weight of the node (for merged nodes: the combined weight of the children).
    freq: u32,
    /// The symbol stored at this node, if any.
    value: Option<T>,
    /// Subtree reached by a `false` bit.
    left: Option<Box<Node<T>>>,
    /// Subtree reached by a `true` bit.
    right: Option<Box<Node<T>>>,
}

impl<T> Node<T> {
    /// Creates a node with the given weight and optional symbol and no children.
    fn new(freq: u32, value: Option<T>) -> Self {
        let (left, right) = (None, None);
        Node {
            freq,
            value,
            left,
            right,
        }
    }
}
32
33pub struct Encoder<T> {
34 encoding: HashMap<T, BitVec>,
35}
36
37impl<T> Encoder<T>
38where
39 T: Eq + Hash + Clone + fmt::Debug,
40{
41 fn new(root: &Node<T>) -> Self {
42 fn assign<T>(node: &Node<T>, encoding: &mut HashMap<T, BitVec>, mut current_bits: BitVec)
43 where
44 T: Eq + Hash + Clone,
45 {
46 if let Some(ch) = node.value.as_ref() {
47 encoding.insert(ch.clone(), current_bits);
48 } else {
49 if let Some(ref l) = node.left {
50 let mut bits = current_bits.clone();
51 bits.push(false);
52 assign(l, encoding, bits);
53 }
54 if let Some(ref r) = node.right {
55 current_bits.push(true);
56 assign(r, encoding, current_bits);
57 }
58 }
59 }
60 let mut encoding = HashMap::new();
61 let bits = BitVec::new();
62 assign(&root, &mut encoding, bits);
63 Self { encoding }
64 }
65
66 pub fn encode(&self, data: &[T]) -> Result<BitVec> {
67 let mut vec = BitVec::new();
68 for item in data {
69 let mut encoding = self
70 .encoding
71 .get(item)
72 .ok_or_else(|| Error::NoSuchKey(format!("{:?}", item)))?
73 .clone();
74 vec.append(&mut encoding);
75 }
76 Ok(vec)
77 }
78}
79
80pub struct Decoder<T> {
81 root: Node<T>,
82}
83
84impl<T> Decoder<T> {
85 fn new(root: Node<T>) -> Self {
86 Self { root }
87 }
88
89 pub fn decode_iter<'a>(&'a self, encoded: &'a BitVec) -> impl Iterator<Item = &'a T> {
90 DecoderIter {
91 input: encoded.iter(),
92 root: &self.root,
93 current_node: &self.root,
94 }
95 }
96
97 pub fn decode<'a>(&'a self, encoded: &'a BitVec) -> Vec<&'a T> {
98 self.decode_iter(encoded).collect()
99 }
100
101 pub fn decode_owned(&self, encoded: &BitVec) -> Vec<T>
102 where
103 T: Clone,
104 {
105 self.decode_iter(encoded).cloned().collect()
106 }
107}
108
109struct DecoderIter<'a, T> {
110 root: &'a Node<T>,
111 input: bit_vec::Iter<'a>,
112 current_node: &'a Node<T>,
113}
114
115impl<'a, T> Iterator for DecoderIter<'a, T> {
116 type Item = &'a T;
117 fn next(&mut self) -> Option<Self::Item> {
118 let bit = self.input.next()?;
119 if bit {
120 if let Some(ref right) = self.current_node.right {
121 self.current_node = right;
122 }
123 } else if let Some(ref left) = self.current_node.left {
124 self.current_node = left;
125 }
126 if let Some(value) = self.current_node.value.as_ref() {
127 self.current_node = &self.root;
128 Some(value)
129 } else {
130 self.next()
131 }
132 }
133}
134
135pub struct Huffman<T> {
136 encoder: Encoder<T>,
137 decoder: Decoder<T>,
138}
139
140impl<T> Huffman<T>
141where
142 T: Eq + Hash + Clone + fmt::Debug,
143{
144 pub fn new(weights: impl IntoIterator<Item = (T, u32)>) -> Result<Self> {
145 let mut nodes = weights
146 .into_iter()
147 .map(|(value, frequency)| Box::new(Node::new(frequency, Some(value))))
148 .collect::<Vec<_>>();
149
150 while nodes.len() > 1 {
151 nodes.sort_by(|a, b| (&(b.freq)).cmp(&(a.freq)));
152 let a = nodes
153 .pop()
154 .ok_or_else(|| Error::InvalidWeights("Expected at least 1 node".to_string()))?;
155 let b = nodes
156 .pop()
157 .ok_or_else(|| Error::InvalidWeights("Expected at least 1 node".to_string()))?;
158 let mut c = Node::new(a.freq + b.freq, None);
159 c.left = Some(a);
160 c.right = Some(b);
161 nodes.push(Box::new(c));
162 }
163
164 let root = *nodes
165 .pop()
166 .ok_or_else(|| Error::InvalidWeights("Expected root node".to_string()))?;
167 let encoder = Encoder::new(&root);
168 let decoder = Decoder::new(root);
169 Ok(Self { encoder, decoder })
170 }
171
172 pub fn encode(&self, data: &[T]) -> Result<BitVec> {
174 self.encoder.encode(data)
175 }
176
177 pub fn decode<'a>(&'a self, encoded: &'a BitVec) -> Vec<&'a T> {
178 self.decoder.decode(encoded)
179 }
180
181 pub fn decode_iter<'a>(&'a self, encoded: &'a BitVec) -> impl Iterator<Item = &'a T> {
182 self.decoder.decode_iter(encoded)
183 }
184
185 pub fn decode_owned(&self, encoded: &BitVec) -> Vec<T> {
186 self.decoder.decode_owned(encoded)
187 }
188
189 pub fn split(self) -> (Encoder<T>, Decoder<T>) {
191 (self.encoder, self.decoder)
192 }
193}
194
#[cfg(test)]
mod tests {

    use super::*;

    /// Round-trips integer symbols through encode / decode_owned.
    #[test]
    fn test_encode_decode_i32() {
        let weights = vec![(0, 10), (1, 1), (2, 5)];
        let huffman = Huffman::new(weights).unwrap();
        let data = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
        let encoded = huffman.encode(&data).unwrap();
        let decoded = huffman.decode_owned(&encoded);
        assert_eq!(data, decoded);
    }

    /// The halves returned by `split` still work together.
    #[test]
    fn test_split() {
        let weights = vec![(0, 10), (1, 1), (2, 5)];
        let (encoder, decoder) = Huffman::new(weights).unwrap().split();
        let data = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
        let encoded = encoder.encode(&data).unwrap();
        let decoded = decoder.decode_owned(&encoded);
        assert_eq!(data, decoded);
    }

    /// Round-trips owned string symbols.
    #[test]
    fn test_encode_decode_string() {
        let weights = vec![
            ("hello".to_string(), 2),
            ("hey".to_string(), 3),
            ("howdy".to_string(), 1),
        ];
        let huffman = Huffman::new(weights).unwrap();
        let data = vec!["howdy".into(), "howdy".into(), "hey".into(), "hello".into()];
        let encoded = huffman.encode(&data).unwrap();
        let decoded = huffman.decode_owned(&encoded);
        assert_eq!(data, decoded);
    }

    /// The borrowing decoders agree with the input as well.
    #[test]
    fn test_decode_borrowed() {
        let huffman = Huffman::new(vec![(0, 10), (1, 1), (2, 5)]).unwrap();
        let data = vec![0, 1, 2, 0];
        let encoded = huffman.encode(&data).unwrap();
        assert_eq!(huffman.decode(&encoded), data.iter().collect::<Vec<_>>());
        assert_eq!(huffman.decode_iter(&encoded).count(), data.len());
    }

    /// Building a code from no weights at all is rejected.
    #[test]
    fn test_empty_weights_is_error() {
        let result = Huffman::<i32>::new(Vec::new());
        assert!(matches!(result, Err(Error::InvalidWeights(_))));
    }

    /// Encoding a symbol that was not part of the weights is rejected and the
    /// error names the symbol.
    #[test]
    fn test_unknown_symbol_is_error() {
        let huffman = Huffman::new(vec![(0, 10), (1, 1), (2, 5)]).unwrap();
        let result = huffman.encode(&[42]);
        assert!(matches!(result, Err(Error::NoSuchKey(key)) if key == "42"));
    }
}