huffman_encoding/
lib.rs

1use std::{collections::HashMap, fmt, hash::Hash};
2
3use bit_vec::BitVec;
4
/// Errors produced while building a Huffman code or encoding data with it.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A symbol passed to the encoder has no entry in the code table.
    /// The payload is the `Debug` rendering of the offending symbol.
    #[error("No such key in encoding dictionary: {0}")]
    NoSuchKey(String),
    /// The supplied weights could not be turned into a tree.
    /// The payload is a human-readable reason.
    #[error("Invalid weight nodes: {0}")]
    InvalidWeights(String),
}
12
/// File-local shorthand for a `Result` whose error type is [`Error`].
type Result<T> = std::result::Result<T, Error>;
14
/// A node of the Huffman tree.
///
/// Leaves carry a symbol in `value`. The merged nodes created while building
/// the tree have `value == None` and both children set.
struct Node<T> {
    /// Weight of this node; for a merged node, the combined weight of its children.
    freq: u32,
    /// Symbol stored at this node, if any.
    value: Option<T>,
    /// Child followed on a `false` bit.
    left: Option<Box<Node<T>>>,
    /// Child followed on a `true` bit.
    right: Option<Box<Node<T>>>,
}

impl<T> Node<T> {
    /// Creates a childless node with the given weight and optional symbol.
    fn new(freq: u32, value: Option<T>) -> Self {
        let (left, right) = (None, None);
        Node {
            freq,
            value,
            left,
            right,
        }
    }
}
32
33pub struct Encoder<T> {
34    encoding: HashMap<T, BitVec>,
35}
36
37impl<T> Encoder<T>
38where
39    T: Eq + Hash + Clone + fmt::Debug,
40{
41    fn new(root: &Node<T>) -> Self {
42        fn assign<T>(node: &Node<T>, encoding: &mut HashMap<T, BitVec>, mut current_bits: BitVec)
43        where
44            T: Eq + Hash + Clone,
45        {
46            if let Some(ch) = node.value.as_ref() {
47                encoding.insert(ch.clone(), current_bits);
48            } else {
49                if let Some(ref l) = node.left {
50                    let mut bits = current_bits.clone();
51                    bits.push(false);
52                    assign(l, encoding, bits);
53                }
54                if let Some(ref r) = node.right {
55                    current_bits.push(true);
56                    assign(r, encoding, current_bits);
57                }
58            }
59        }
60        let mut encoding = HashMap::new();
61        let bits = BitVec::new();
62        assign(&root, &mut encoding, bits);
63        Self { encoding }
64    }
65
66    pub fn encode(&self, data: &[T]) -> Result<BitVec> {
67        let mut vec = BitVec::new();
68        for item in data {
69            let mut encoding = self
70                .encoding
71                .get(item)
72                .ok_or_else(|| Error::NoSuchKey(format!("{:?}", item)))?
73                .clone();
74            vec.append(&mut encoding);
75        }
76        Ok(vec)
77    }
78}
79
80pub struct Decoder<T> {
81    root: Node<T>,
82}
83
84impl<T> Decoder<T> {
85    fn new(root: Node<T>) -> Self {
86        Self { root }
87    }
88
89    pub fn decode_iter<'a>(&'a self, encoded: &'a BitVec) -> impl Iterator<Item = &'a T> {
90        DecoderIter {
91            input: encoded.iter(),
92            root: &self.root,
93            current_node: &self.root,
94        }
95    }
96
97    pub fn decode<'a>(&'a self, encoded: &'a BitVec) -> Vec<&'a T> {
98        self.decode_iter(encoded).collect()
99    }
100
101    pub fn decode_owned(&self, encoded: &BitVec) -> Vec<T>
102    where
103        T: Clone,
104    {
105        self.decode_iter(encoded).cloned().collect()
106    }
107}
108
109struct DecoderIter<'a, T> {
110    root: &'a Node<T>,
111    input: bit_vec::Iter<'a>,
112    current_node: &'a Node<T>,
113}
114
115impl<'a, T> Iterator for DecoderIter<'a, T> {
116    type Item = &'a T;
117    fn next(&mut self) -> Option<Self::Item> {
118        let bit = self.input.next()?;
119        if bit {
120            if let Some(ref right) = self.current_node.right {
121                self.current_node = right;
122            }
123        } else if let Some(ref left) = self.current_node.left {
124            self.current_node = left;
125        }
126        if let Some(value) = self.current_node.value.as_ref() {
127            self.current_node = &self.root;
128            Some(value)
129        } else {
130            self.next()
131        }
132    }
133}
134
135pub struct Huffman<T> {
136    encoder: Encoder<T>,
137    decoder: Decoder<T>,
138}
139
140impl<T> Huffman<T>
141where
142    T: Eq + Hash + Clone + fmt::Debug,
143{
144    pub fn new(weights: impl IntoIterator<Item = (T, u32)>) -> Result<Self> {
145        let mut nodes = weights
146            .into_iter()
147            .map(|(value, frequency)| Box::new(Node::new(frequency, Some(value))))
148            .collect::<Vec<_>>();
149
150        while nodes.len() > 1 {
151            nodes.sort_by(|a, b| (&(b.freq)).cmp(&(a.freq)));
152            let a = nodes
153                .pop()
154                .ok_or_else(|| Error::InvalidWeights("Expected at least 1 node".to_string()))?;
155            let b = nodes
156                .pop()
157                .ok_or_else(|| Error::InvalidWeights("Expected at least 1 node".to_string()))?;
158            let mut c = Node::new(a.freq + b.freq, None);
159            c.left = Some(a);
160            c.right = Some(b);
161            nodes.push(Box::new(c));
162        }
163
164        let root = *nodes
165            .pop()
166            .ok_or_else(|| Error::InvalidWeights("Expected root node".to_string()))?;
167        let encoder = Encoder::new(&root);
168        let decoder = Decoder::new(root);
169        Ok(Self { encoder, decoder })
170    }
171
172    /// Encodes data into a BitVec. Fails if any of the data is not present in the dictionary.
173    pub fn encode(&self, data: &[T]) -> Result<BitVec> {
174        self.encoder.encode(data)
175    }
176
177    pub fn decode<'a>(&'a self, encoded: &'a BitVec) -> Vec<&'a T> {
178        self.decoder.decode(encoded)
179    }
180
181    pub fn decode_iter<'a>(&'a self, encoded: &'a BitVec) -> impl Iterator<Item = &'a T> {
182        self.decoder.decode_iter(encoded)
183    }
184
185    pub fn decode_owned(&self, encoded: &BitVec) -> Vec<T> {
186        self.decoder.decode_owned(encoded)
187    }
188
189    /// Split into a Encoder and Decoder.
190    pub fn split(self) -> (Encoder<T>, Decoder<T>) {
191        (self.encoder, self.decoder)
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer symbols survive an encode/decode round trip through `Huffman`.
    #[test]
    fn test_encode_decode_i32() {
        let codec = Huffman::new(vec![(0, 10), (1, 1), (2, 5)]).unwrap();

        // Ten `0`s followed by a single `1`.
        let mut message = vec![0; 10];
        message.push(1);

        let bits = codec.encode(&message).unwrap();
        let round_tripped = codec.decode_owned(&bits);
        assert_eq!(message, round_tripped);
    }

    /// The halves returned by `split` still agree with each other.
    #[test]
    fn test_split() {
        let codec = Huffman::new(vec![(0, 10), (1, 1), (2, 5)]).unwrap();
        let (encoder, decoder) = codec.split();

        // Ten `0`s followed by a single `1`.
        let mut message = vec![0; 10];
        message.push(1);

        let bits = encoder.encode(&message).unwrap();
        let round_tripped = decoder.decode_owned(&bits);
        assert_eq!(message, round_tripped);
    }

    /// Owned `String` symbols survive an encode/decode round trip.
    #[test]
    fn test_encode_decode_string() {
        let weights: Vec<(String, u32)> = vec![
            (String::from("hello"), 2),
            (String::from("hey"), 3),
            (String::from("howdy"), 1),
        ];
        let codec = Huffman::new(weights).unwrap();

        let message: Vec<String> = ["howdy", "howdy", "hey", "hello"]
            .iter()
            .map(|word| word.to_string())
            .collect();

        let bits = codec.encode(&message).unwrap();
        let round_tripped = codec.decode_owned(&bits);
        assert_eq!(message, round_tripped);
    }
}
233}