1use crate::data::{BitVec, Padded, PaddedBits, UnPadded, UnPaddedBits};
2use crate::encoding_map::EncodingMap;
3use crate::encoding_stats::EncodingStats;
4use crate::error::Result;
5use crate::frequency_map::{FrequencyMap, FrequencyMapping};
6use crate::huffman_tree::{self, Node};
7
8use std::collections::HashMap;
9
/// Output of Huffman-encoding a byte slice, bundled with what is needed to
/// decode it again.
///
/// Created by `HuffmanData::new`; `HuffmanData::decode` reads `encoded_data`
/// and `encoding_map` to reconstruct the original bytes.
#[derive(Debug)]
pub struct HuffmanData {
    /// Encoded payload: the encoded bit stream after padding, converted to bytes.
    pub encoded_data: Vec<u8>,
    /// Code assigned to each input byte, stored as a string per byte value
    /// (presumably a sequence of '0'/'1' characters, as in this file's test
    /// fixtures — confirm against `EncodingMap::extract`).
    pub encoding_map: HashMap<u8, String>,
    /// Statistics built from the original input and the encoded bytes.
    pub stats: EncodingStats,
}
20
21impl HuffmanData {
22 pub fn new(data: &[u8]) -> Result<HuffmanData> {
41 let frequency_map: FrequencyMap = FrequencyMap::build(data);
42 let huffman_tree: Node = huffman_tree::build(&frequency_map)?;
43 let encoding_map: EncodingMap = EncodingMap::new(&huffman_tree)?;
44
45 let encoded_data: UnPaddedBits = Self::huffman_encode(data, &encoding_map);
46 let encoded_data: PaddedBits = encoded_data.pad();
47 let encoded_data = encoded_data.to_vec_u8()?;
48 let stats: EncodingStats = EncodingStats::new(data, &encoded_data);
49
50 let huffman_encoded_data = HuffmanData {
51 encoded_data,
52 encoding_map: encoding_map.extract().0,
53 stats,
54 };
55 Ok(huffman_encoded_data)
56 }
57
58 pub fn decode(&self) -> Result<Vec<u8>> {
77 let encoded_data: PaddedBits = PaddedBits::from_vec_u8(&self.encoded_data);
78 let encoded_data: UnPaddedBits = encoded_data.unpad();
79 let encoding_map: EncodingMap = EncodingMap::from(self.encoding_map.clone());
80 let decoded_data = Self::huffman_decode(&encoded_data, &encoding_map);
81
82 Ok(decoded_data)
83 }
84
85 fn huffman_decode(encoded_data: &UnPaddedBits, encoding_map: &EncodingMap) -> Vec<u8> {
86 let mut data: Vec<u8> = Vec::with_capacity(encoded_data.len());
87 let mut code = BitVec::with_capacity(encoding_map.get_longest_code());
88 let min_len = encoding_map.get_shortest_code();
89
90 for code_bit in encoded_data {
91 code.push(*code_bit);
92 if code.len() < min_len {
93 continue;
94 }
95
96 if let Some(&byte) = encoding_map.get_inverse(&code) {
97 code.clear();
98 data.push(byte);
99 }
100 }
101 data
102 }
103
104 fn huffman_encode(data: &[u8], encoding_map: &EncodingMap) -> UnPaddedBits {
105 let mut encoded_data = UnPaddedBits::new();
106 for c in data {
107 if let Some(code) = encoding_map.get(c) {
108 encoded_data.extend_from_slice(code);
109 }
110 }
111 encoded_data
112 }
113}
114
#[cfg(test)]
mod tests {
    use crate::data::BitVector;

    use super::*;

    /// Plain text shared by the encode and decode tests.
    const PLAIN_TEXT: &str = "this is a test string!";

    /// `PLAIN_TEXT` encoded with the codes from `sample_encoding_map`.
    const ENCODED_BITS: &str =
        "11110010101110011011100100110111100001101110111011110001011001100010010";

    /// Builds the fixed code table used by both tests, so the fixture is
    /// defined in exactly one place.
    fn sample_encoding_map() -> EncodingMap {
        let codes: HashMap<u8, String> = [
            (b'h', "10010"),
            (b'a', "0011"),
            (b' ', "01"),
            (b'g', "0001"),
            (b'i', "101"),
            (b's', "110"),
            (b'!', "0010"),
            (b'n', "10011"),
            (b'r', "1000"),
            (b't', "111"),
            (b'e', "0000"),
        ]
        .iter()
        .map(|(byte, code)| (*byte, code.to_string()))
        .collect();
        EncodingMap::from(codes)
    }

    #[test]
    fn test_huffman_encode() {
        let input_data: Vec<u8> = Vec::from(PLAIN_TEXT);
        let input_encoding_map = sample_encoding_map();
        let expected_data = UnPaddedBits::from_string(ENCODED_BITS);

        let test_output = HuffmanData::huffman_encode(&input_data, &input_encoding_map);

        assert_eq!(expected_data, test_output);
    }

    #[test]
    fn test_huffman_decode() {
        let input_data = UnPaddedBits::from_string(ENCODED_BITS);
        let input_encoding_map = sample_encoding_map();
        let expected_data: Vec<u8> = Vec::from(PLAIN_TEXT);

        let test_output = HuffmanData::huffman_decode(&input_data, &input_encoding_map);

        // One byte-level comparison is sufficient; the lossy text rendering
        // only makes a failure readable and cannot itself panic.
        assert_eq!(
            expected_data,
            test_output,
            "decoded text was {:?}",
            String::from_utf8_lossy(&test_output)
        );
    }
}