1use std::{collections::HashMap, convert::TryInto};
2use tiny_bitstream::{BitDstream, BitEstream, BitReader, BitWriter};
3
/// Builds the cumulative frequency table for `normalized_counter`.
///
/// The result has `normalized_counter.len() + 1` entries: entry `i` is the sum
/// of the first `i` frequencies, so the first entry is always `0` and the last
/// entry is the total of all frequencies.
pub fn build_cumulative_symbol_frequency(normalized_counter: &[usize]) -> Vec<usize> {
    let mut cumulative = Vec::with_capacity(normalized_counter.len() + 1);
    let mut running_total = 0usize;

    // Leading zero: nothing has been accumulated before the first symbol.
    cumulative.push(running_total);
    for &frequency in normalized_counter {
        running_total += frequency;
        cumulative.push(running_total);
    }
    cumulative
}
24
/// Folds one symbol into `state` and returns the new state.
///
/// The new state is `((state / frequency) << table_log) + (state % frequency) + cumul`,
/// where `frequency` and `cumul` are the symbol's frequency and cumulative
/// frequency.
///
/// # Panics
///
/// Panics on division by zero when `frequency` is `0`. With the `checks`
/// feature enabled the panic carries an explicit message instead.
pub fn compress(state: usize, table_log: usize, frequency: usize, cumul: usize) -> usize {
    #[cfg(feature = "checks")]
    if frequency == 0 {
        panic!("attemp division by zero because of an unexpected null frequency")
    }

    let quotient = state / frequency;
    let remainder = state % frequency;
    (quotient << table_log) + remainder + cumul
}
39
/// Rescales a cumulative table to the range `1 << table_log` and rebuilds the
/// per-symbol frequencies from it, both in place.
///
/// Every entry of `cumul` except the first is replaced by
/// `(1 << table_log) * entry / last_entry`, where `last_entry` is the final
/// value of `cumul` on entry. `histogram[i - 1]` is then set to the difference
/// between the rescaled `cumul[i]` and the previous rescaled value (starting
/// from `0`).
///
/// # Panics
///
/// * if `cumul` is empty;
/// * with `"table log too low"` if a rescaled entry is not strictly greater
///   than the previous one (a symbol would end up with a zero frequency);
/// * if the last entry of `cumul` is `0` while `cumul` has more than one entry
///   (division by zero);
/// * if `histogram` has fewer than `cumul.len() - 1` entries.
pub fn simple_normalization(histogram: &mut [usize], cumul: &mut [usize], table_log: usize) {
    let actual_range = *cumul.last().unwrap();
    let target_range = 1usize << table_log;

    let mut last_scaled = 0;
    for position in 1..cumul.len() {
        let scaled = (target_range * cumul[position]) / actual_range;
        cumul[position] = scaled;
        if scaled <= last_scaled {
            panic!("table log too low");
        }

        histogram[position - 1] = scaled - last_scaled;
        last_scaled = scaled;
    }
}
72
73pub fn encode(
76 hist: &mut [usize],
77 symbol_index: &HashMap<u16, usize>,
78 table_log: usize, src: &[u16],
80) -> (usize, Vec<u32>, Vec<u8>) {
81 let mut cs = build_cumulative_symbol_frequency(hist);
82
83 simple_normalization(hist, &mut cs, table_log);
84
85 let mut state = 0;
86
87 let d = 32 - table_log;
88 let msk = 2usize.pow(16) - 1;
89
90 let mut estream = BitEstream::new();
91
92 let mut nb_bits_table = vec![];
93
94 src.iter().for_each(|symbol| {
95 let index = *symbol_index.get(symbol).unwrap();
96 let fs = *hist.get(index).unwrap();
97 if state >= (fs << d) {
98 let bits = state & msk;
99 let nb_bits = u64::BITS - bits.leading_zeros();
100 estream.unchecked_write(bits, nb_bits.try_into().unwrap());
101 nb_bits_table.push(nb_bits);
102 state >>= 16;
103 };
104
105 state = compress(state, table_log, fs, *cs.get(index).unwrap());
106 });
107 (state, nb_bits_table, estream.try_into().unwrap())
108}
109
110pub fn encode_u8(
113 hist: &mut [usize],
114 table_log: usize, src: &[u8],
116) -> (usize, Vec<u32>, Vec<u8>) {
117 let mut cs = build_cumulative_symbol_frequency(hist);
118
119 simple_normalization(hist, &mut cs, table_log);
120
121 let mut state = 0;
122
123 let d = 32 - table_log;
128 let msk = 2usize.pow(16) - 1;
129
130 let mut estream = BitEstream::new();
131
132 let mut nb_bits_table = vec![];
133
134 src.iter().for_each(|symbol| {
135 let index = *symbol as usize;
136 let fs = hist[index];
137 if state >= (fs << d) {
152 let bits = state & msk;
157 let nb_bits = u64::BITS - bits.leading_zeros();
158 estream.unchecked_write(bits, nb_bits.try_into().unwrap());
159 nb_bits_table.push(nb_bits);
160 state >>= 16;
161 };
162
163 state = compress(state, table_log, fs, cs[index]);
164 });
165 (state, nb_bits_table, estream.try_into().unwrap())
167}
168
/// Looks up which slot of the cumulative table `cs` the value `state` falls in.
///
/// `cs` is scanned front to back for the first entry that is greater than or
/// equal to `state`. If that entry equals `state` its index is returned;
/// otherwise the index just before it is returned. When no entry qualifies the
/// function returns `0`.
///
/// # Panics
///
/// In builds with overflow checks, panics if the very first entry of `cs` is
/// already greater than `state` (the "previous index" would be `-1`).
pub fn find_s(state: usize, cs: &[usize]) -> usize {
    match cs.iter().position(|&bound| bound >= state) {
        Some(hit) if cs[hit] == state => hit,
        Some(hit) => hit - 1,
        None => 0,
    }
}
182
/// Removes one symbol from `state` and returns the previous state.
///
/// Computes `frequency * (state >> table_log) + (state & mask) - cumul`, where
/// `mask` is `2^table_log - 1` and `frequency` / `cumul` are the frequency and
/// cumulative frequency of the symbol being removed.
pub fn decompress(state: usize, frequency: usize, table_log: usize, cumul: usize) -> usize {
    let slot_mask = 2usize.pow(table_log as u32) - 1;
    let quotient = state >> table_log;
    let slot = state & slot_mask;
    frequency * quotient + slot - cumul
}
187
188pub fn decode(
190 mut state: usize,
191 mut bits: Vec<u32>,
192 str: Vec<u8>,
193 normalized_counter: &[usize],
194 symbols: &[u16],
195 table_log: usize,
196) -> Vec<u16> {
197 let mask = 2usize.pow(table_log as u32) - 1;
198
199 let mut dstream: BitDstream = str.try_into().unwrap();
200 dstream.read(1).unwrap(); let cs = build_cumulative_symbol_frequency(normalized_counter);
203 let mut ret = vec![];
204 while state > 0 {
205 let symbol_index = find_s(state & mask, &cs);
208 ret.push(*symbols.get(symbol_index).expect("symbol not found"));
209 state = decompress(
210 state,
211 *normalized_counter
212 .get(symbol_index)
213 .expect("symbol frequency not found"),
214 table_log,
215 *cs.get(symbol_index).expect("symbol cumul not found"),
216 );
217 if state < 2usize.pow(16) {
218 if let Some(nb_bits) = bits.pop() {
219 state = (state << 16) + dstream.read(nb_bits as u8).unwrap() as usize;
220 }
221 }
222 }
223 ret.reverse();
224 ret
225}
226
227pub fn decode_u8(
228 mut state: usize,
229 mut bits: Vec<u32>,
230 str: Vec<u8>,
231 normalized_counter: &[usize],
232 table_log: usize,
233) -> Vec<u8> {
234 let mask = 2usize.pow(table_log as u32) - 1;
235
236 let mut dstream: BitDstream = str.try_into().unwrap();
237 dstream.read(1).unwrap(); let cs = build_cumulative_symbol_frequency(normalized_counter);
240 let mut ret = vec![];
241 while state > 0 {
242 let symbol_index = find_s(state & mask, &cs);
245 ret.push(symbol_index.try_into().expect("symbol overflow"));
246 state = decompress(
247 state,
248 *normalized_counter
249 .get(symbol_index)
250 .expect("symbol frequency not found"),
251 table_log,
252 *cs.get(symbol_index).expect("symbol cumul not found"),
253 );
254 if state < 2usize.pow(16) {
255 if let Some(nb_bits) = bits.pop() {
261 state = (state << 16) + dstream.read(nb_bits as u8).unwrap() as usize;
262 }
263 }
264 }
265 ret.reverse();
266 ret
267}