boytacean_encoding/
huffman.rs

1use boytacean_common::error::Error;
2use std::{
3    cmp::Ordering,
4    collections::BinaryHeap,
5    io::{Cursor, Read},
6    mem::size_of,
7};
8
9use crate::codec::Codec;
10
/// A node of the Huffman coding tree.
///
/// Leaves carry the byte they stand for in `character`; internal nodes have
/// `character == None` and own both subtrees. While building a tree,
/// `frequency` is a leaf's occurrence count or the sum of its children's
/// counts; trees rebuilt from the serialized form carry `0` everywhere.
#[derive(PartialEq, Eq, Debug)]
struct Node {
    /// Weight used to order nodes in the construction heap.
    frequency: u32,
    /// Byte encoded by this leaf; `None` for internal nodes.
    character: Option<u8>,
    /// Subtree selected by a `0` bit.
    left: Option<Box<Self>>,
    /// Subtree selected by a `1` bit.
    right: Option<Box<Self>>,
}
18
19impl Ord for Node {
20    fn cmp(&self, other: &Self) -> Ordering {
21        other.frequency.cmp(&self.frequency)
22    }
23}
24
25impl PartialOrd for Node {
26    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
27        Some(self.cmp(other))
28    }
29}
30
/// Huffman codec over bytes.
///
/// A stateless unit type: all functionality lives in its associated
/// functions and its `Codec` implementation.
pub struct Huffman;
32
33impl Huffman {
34    fn build_frequency(data: &[u8]) -> [u32; 256] {
35        let mut frequency_map = [0_u32; 256];
36        for &byte in data {
37            frequency_map[byte as usize] += 1;
38        }
39        frequency_map
40    }
41
42    fn build_tree(frequency_map: &[u32; 256]) -> Option<Box<Node>> {
43        let mut heap: BinaryHeap<Box<Node>> = BinaryHeap::new();
44
45        for (byte, &frequency) in frequency_map.iter().enumerate() {
46            if frequency == 0 {
47                continue;
48            }
49            heap.push(Box::new(Node {
50                frequency,
51                character: Some(byte as u8),
52                left: None,
53                right: None,
54            }));
55        }
56
57        while heap.len() > 1 {
58            let left = heap.pop().unwrap();
59            let right = heap.pop().unwrap();
60
61            let merged = Box::new(Node {
62                frequency: left.frequency + right.frequency,
63                character: None,
64                left: Some(left),
65                right: Some(right),
66            });
67
68            heap.push(merged);
69        }
70
71        heap.pop()
72    }
73
74    fn build_codes(node: &Node, prefix: Vec<u8>, codes: &mut [Vec<u8>]) {
75        if let Some(character) = node.character {
76            codes[character as usize] = prefix;
77        } else {
78            if let Some(ref left) = node.left {
79                let mut left_prefix = prefix.clone();
80                left_prefix.push(0);
81                Self::build_codes(left, left_prefix, codes);
82            }
83            if let Some(ref right) = node.right {
84                let mut right_prefix = prefix;
85                right_prefix.push(1);
86                Self::build_codes(right, right_prefix, codes);
87            }
88        }
89    }
90
91    fn encode_data(data: &[u8], codes: &[Vec<u8>]) -> Vec<u8> {
92        let mut bit_buffer = Vec::new();
93        let mut current_byte = 0u8;
94        let mut bit_count = 0;
95
96        for &byte in data {
97            let code = &codes[byte as usize];
98            for &bit in code {
99                current_byte <<= 1;
100                if bit == 1 {
101                    current_byte |= 1;
102                }
103                bit_count += 1;
104
105                if bit_count == 8 {
106                    bit_buffer.push(current_byte);
107                    current_byte = 0;
108                    bit_count = 0;
109                }
110            }
111        }
112
113        if bit_count > 0 {
114            current_byte <<= 8 - bit_count;
115            bit_buffer.push(current_byte);
116        }
117
118        bit_buffer
119    }
120
121    fn decode_data(encoded: &[u8], root: &Node, data_length: u64) -> Vec<u8> {
122        let mut decoded = Vec::new();
123        let mut current_node = root;
124        let mut bit_index = 0;
125
126        for &byte in encoded {
127            if decoded.len() as u64 == data_length {
128                break;
129            }
130
131            for bit_offset in (0..8).rev() {
132                let bit = (byte >> bit_offset) & 1;
133                current_node = if bit == 0 {
134                    current_node.left.as_deref().unwrap()
135                } else {
136                    current_node.right.as_deref().unwrap()
137                };
138
139                if let Some(character) = current_node.character {
140                    decoded.push(character);
141                    current_node = root;
142                }
143
144                if decoded.len() as u64 == data_length {
145                    break;
146                }
147
148                bit_index += 1;
149                if bit_index == encoded.len() * 8 {
150                    break;
151                }
152            }
153        }
154
155        decoded
156    }
157
158    fn encode_tree(node: &Node) -> Vec<u8> {
159        let mut result = Vec::new();
160        if let Some(character) = node.character {
161            result.push(1);
162            result.push(character);
163        } else {
164            result.push(0);
165            if let Some(ref left) = node.left {
166                result.extend(Self::encode_tree(left));
167            }
168            if let Some(ref right) = node.right {
169                result.extend(Self::encode_tree(right));
170            }
171        }
172        result
173    }
174
175    fn decode_tree(data: &mut &[u8]) -> Box<Node> {
176        let mut node = Box::new(Node {
177            frequency: 0,
178            character: None,
179            left: None,
180            right: None,
181        });
182
183        if data[0] == 1 {
184            node.character = Some(data[1]);
185            *data = &data[2..];
186        } else {
187            *data = &data[1..];
188            node.left = Some(Self::decode_tree(data));
189            node.right = Some(Self::decode_tree(data));
190        }
191        node
192    }
193}
194
195impl Codec for Huffman {
196    type EncodeOptions = ();
197    type DecodeOptions = ();
198
199    fn encode(data: &[u8], _options: &Self::EncodeOptions) -> Result<Vec<u8>, Error> {
200        let frequency_map = Self::build_frequency(data);
201        let tree = Self::build_tree(&frequency_map)
202            .ok_or(Error::CustomError(String::from("Failed to build tree")))?;
203
204        let mut codes = vec![Vec::new(); 256];
205        Self::build_codes(&tree, Vec::new(), &mut codes);
206
207        let encoded_tree = Self::encode_tree(&tree);
208        let encoded_data = Self::encode_data(data, &codes);
209        let tree_length = encoded_tree.len() as u32;
210        let data_length = data.len() as u64;
211
212        let mut result = Vec::new();
213        result.extend(tree_length.to_be_bytes());
214        result.extend(encoded_tree);
215        result.extend(data_length.to_be_bytes());
216        result.extend(encoded_data);
217
218        Ok(result)
219    }
220
221    fn decode(data: &[u8], _options: &Self::DecodeOptions) -> Result<Vec<u8>, Error> {
222        let mut reader = Cursor::new(data);
223
224        let mut buffer = [0x00; size_of::<u32>()];
225        reader.read_exact(&mut buffer)?;
226        let tree_length = u32::from_be_bytes(buffer);
227
228        let mut buffer = vec![0; tree_length as usize];
229        reader.read_exact(&mut buffer)?;
230        let tree = Self::decode_tree(&mut buffer.as_slice());
231
232        let mut buffer = [0x00; size_of::<u64>()];
233        reader.read_exact(&mut buffer)?;
234        let data_length = u64::from_be_bytes(buffer);
235
236        let mut buffer =
237            vec![0; data.len() - size_of::<u32>() - tree_length as usize - size_of::<u64>()];
238        reader.read_exact(&mut buffer)?;
239
240        let result = Self::decode_data(&buffer, &tree, data_length);
241
242        Ok(result)
243    }
244}
245
246pub fn encode_huffman(data: &[u8]) -> Result<Vec<u8>, Error> {
247    Huffman::encode(data, &())
248}
249
250pub fn decode_huffman(data: &[u8]) -> Result<Vec<u8>, Error> {
251    Huffman::decode(data, &())
252}
253
#[cfg(test)]
mod tests {
    use super::{decode_huffman, encode_huffman};

    /// Round-trips a sample sentence through the public helpers and pins both
    /// the compressed and the decompressed sizes.
    #[test]
    fn test_huffman_encoding() {
        let input: &[u8] =
            b"this is an example for huffman encoding, huffman encoding, huffman encoding";
        let compressed = encode_huffman(input).unwrap();
        let restored = decode_huffman(&compressed).unwrap();

        assert_eq!(restored.as_slice(), input);
        assert_eq!(compressed.len(), 109);
        assert_eq!(restored.len(), 75);
    }
}