final_state_rs/
fse16.rs

1use std::{collections::HashMap, convert::TryInto};
2use tiny_bitstream::{BitDstream, BitEstream, BitReader, BitWriter};
3
/// Builds the cumulative frequency table `cs = [0, f0, f0 + f1, ..., f0 + ... + f(s-1)]`.
///
/// Entry `i` holds the sum of the frequencies of every symbol strictly before
/// `i`, so `cs[0]` is always `0`. One extra trailing entry holds the grand
/// total, which makes the result `normalized_counter.len() + 1` entries long.
///
/// # normalized_counter
///
/// `normalized_counter[symbol_index]` is the frequency of that symbol and
/// `normalized_counter.len()` is the number of symbols. When the counter has
/// been normalized, the total is `2^R`, where `R` is named `table_log` in this
/// code. The function itself only sums the values it is given and does not
/// check that they are normalized.
pub fn build_cumulative_symbol_frequency(normalized_counter: &[usize]) -> Vec<usize> {
    let mut cumulative = Vec::with_capacity(normalized_counter.len() + 1);
    let mut running_total = 0usize;

    for &frequency in normalized_counter {
        // Record the sum of everything seen *before* this symbol.
        cumulative.push(running_total);
        running_total += frequency;
    }

    // Trailing entry: sum of all frequencies.
    cumulative.push(running_total);
    cumulative
}
24
/// Computes the next encoder state after absorbing one symbol.
///
/// The result is
/// `((state / frequency) << table_log) + (state % frequency) + cumul`,
/// where `frequency` and `cumul` are the frequency and cumulative frequency
/// the caller passes for the symbol being encoded.
///
/// Note that the parameter order (`table_log` before `frequency`) differs
/// from the one used by `decompress`.
///
/// # Panics
///
/// Panics when `frequency` is zero: through the explicit check when the
/// `checks` feature is enabled, through the integer division otherwise.
pub fn compress(state: usize, table_log: usize, frequency: usize, cumul: usize) -> usize {
    // todo: add some natural checks behind a compilation feature; in some
    // cases this test has no reason to ever be true.
    #[cfg(feature = "checks")]
    if frequency == 0 {
        panic!("attemp division by zero because of an unexpected null frequency")
    }

    // Integer division on `usize` already floors, so no floating point
    // round-trip is needed here.
    let quotient = state / frequency;
    let remainder = state % frequency;

    (quotient << table_log) + remainder + cumul
}
39
/// Rescales a cumulative table so its total becomes `1 << table_log`, and
/// rewrites the histogram with the matching frequencies.
///
/// This is a naive linear interpolation on the cumulative function: every
/// entry of `cumul` after the first is replaced in place by
/// `(1 << table_log) * c / total`, where `total` is the last entry of `cumul`
/// on entry. `histogram[i - 1]` is then overwritten with the difference
/// between rescaled entries `i` and `i - 1` (the first difference is taken
/// against `0`). `cumul[0]` is left untouched.
///
/// # Panics
///
/// - With `"table log too low"` when a rescaled entry is not strictly greater
///   than the previously rescaled one, i.e. a symbol would end up with a
///   frequency of zero.
/// - When `cumul` is empty, when `total` is zero while there is at least one
///   entry to rescale (division by zero), or when `histogram` has fewer than
///   `cumul.len() - 1` entries.
pub fn simple_normalization(histogram: &mut [usize], cumul: &mut [usize], table_log: usize) {
    let actual_total = *cumul.last().unwrap(); // B - A
    let target_total = 1 << table_log; // D - C

    let mut lower_bound = 0;
    for (symbol, upper_bound) in cumul.iter_mut().skip(1).enumerate() {
        *upper_bound = (target_total * (*upper_bound)) / actual_total;

        if *upper_bound <= lower_bound {
            // todo: values are never forced for now; the caller is expected
            // to increase table_log instead.
            //
            // note: we could force the value to `previous + 1` and carry a
            //       debt that gets subtracted from the following values. If a
            //       debt > 0 remains at the end we should panic; otherwise
            //       just tell the user the normalized counter was forced to
            //       fit.
            //
            // Other ideas:
            // 1. Fix up afterwards: once the cdf is computed, if a debt
            //    remains, check whether a few counts can be removed to make
            //    it fit in table_log.
            // 2. Panic, then retry with a doubled table.
            // 3. When hitting a problem, swap the last two values.
            panic!("table log too low");
        }

        histogram[symbol] = *upper_bound - lower_bound;
        lower_bound = *upper_bound;
    }
}
72
73/// Meme chose que encode_u8 mais avec un tbleau de u16 comme source. Generalement
74/// l'histogramme est plus coûteux à réaliser sur cette taille là.
75pub fn encode(
76    hist: &mut [usize],
77    symbol_index: &HashMap<u16, usize>,
78    table_log: usize, // R
79    src: &[u16],
80) -> (usize, Vec<u32>, Vec<u8>) {
81    let mut cs = build_cumulative_symbol_frequency(hist);
82
83    simple_normalization(hist, &mut cs, table_log);
84
85    let mut state = 0;
86
87    let d = 32 - table_log;
88    let msk = 2usize.pow(16) - 1;
89
90    let mut estream = BitEstream::new();
91
92    let mut nb_bits_table = vec![];
93
94    src.iter().for_each(|symbol| {
95        let index = *symbol_index.get(symbol).unwrap();
96        let fs = *hist.get(index).unwrap();
97        if state >= (fs << d) {
98            let bits = state & msk;
99            let nb_bits = u64::BITS - bits.leading_zeros();
100            estream.unchecked_write(bits, nb_bits.try_into().unwrap());
101            nb_bits_table.push(nb_bits);
102            state >>= 16;
103        };
104
105        state = compress(state, table_log, fs, *cs.get(index).unwrap());
106    });
107    (state, nb_bits_table, estream.try_into().unwrap())
108}
109
110/// Compresse une source de u8, on a besoin d'un histogramme ainsi que d'une
111/// table des symbole ("a" est à la position i dans l'histogramme)
112pub fn encode_u8(
113    hist: &mut [usize],
114    table_log: usize, // R
115    src: &[u8],
116) -> (usize, Vec<u32>, Vec<u8>) {
117    let mut cs = build_cumulative_symbol_frequency(hist);
118
119    simple_normalization(hist, &mut cs, table_log);
120
121    let mut state = 0;
122
123    // Une table superieure a 32 fera crasher ce programne,
124    // mais en general, il est deconseille d'utiliser
125    // une table superieure a 13. Pour des questions de
126    // performances, pas par superstition...
127    let d = 32 - table_log;
128    let msk = 2usize.pow(16) - 1;
129
130    let mut estream = BitEstream::new();
131
132    let mut nb_bits_table = vec![];
133
134    src.iter().for_each(|symbol| {
135        let index = *symbol as usize;
136        let fs = hist[index];
137        // On fait attention de ne le faire que si
138        // l'etat est plus grand que la probabilitee << d.
139        //
140        // Ca nous permet de tenir un etat entre 2^16 et 2^32 une
141        // fois 2^16 depasse. Et de laisser l'etat tranquil si
142        // on est encore en dessous de 2^16.
143        //
144        // Ce shift nous permet surtout de ne pas avoir un etat qui
145        // tend vers l'infini, et ne nous empeche pas de trouver le
146        // prochain etat de notre state machine.
147        //
148        // A cause de la normalisation, le max des probabilites
149        // devrait tenir sur table_log bits.
150        // Comme d = 32 - table_log, max(fs << d) = 2^32.
151        if state >= (fs << d) {
152            // On recupere les 16 premier bits
153            // de l'etat actuelle et ont la stoque dans un
154            // stream. On shift l'etat de 16 pour guarder
155            // seulement les 16 bit plus grands.
156            let bits = state & msk;
157            let nb_bits = u64::BITS - bits.leading_zeros();
158            estream.unchecked_write(bits, nb_bits.try_into().unwrap());
159            nb_bits_table.push(nb_bits);
160            state >>= 16;
161        };
162
163        state = compress(state, table_log, fs, cs[index]);
164    });
165    //println!("state {state}");
166    (state, nb_bits_table, estream.try_into().unwrap())
167}
168
/// Finds the symbol whose cumulative interval contains `state`.
///
/// `cs` must be a cumulative frequency table, i.e. sorted in non-decreasing
/// order (as produced by `build_cumulative_symbol_frequency`). The result is
/// the largest index `i` such that `cs[i] <= state`, which is the symbol for
/// which `cs[i] <= state < cs[i + 1]`. Taking the *largest* such index means
/// a symbol with a zero frequency (two equal consecutive entries) is never
/// returned in place of the symbol that actually owns the slot.
///
/// The lookup is a binary search, so it runs in `O(log n)`.
///
/// Fallbacks: returns `0` when `cs` is empty, when `state` is greater than
/// the last entry of `cs`, or when `state` is smaller than the first entry.
pub fn find_s(state: usize, cs: &[usize]) -> usize {
    match cs.last() {
        Some(&total) if state <= total => {
            // Number of entries `<= state`; the wanted symbol is the last of
            // them. `saturating_sub` covers `cs[0] > state`, where no entry
            // qualifies and a plain `- 1` would underflow.
            cs.partition_point(|&c| c <= state).saturating_sub(1)
        }
        // Empty table or state past the total: keep the historical fallback.
        _ => 0,
    }
}
182
/// Computes the previous state from `state`, undoing one `compress` step.
///
/// The result is
/// `frequency * (state >> table_log) + (state & (2^table_log - 1)) - cumul`,
/// where `frequency` and `cumul` are the frequency and cumulative frequency
/// the caller passes for the symbol found in `state`.
///
/// Note that the parameter order (`frequency` before `table_log`) differs
/// from the one used by `compress`.
pub fn decompress(state: usize, frequency: usize, table_log: usize, cumul: usize) -> usize {
    let slot_mask = 2usize.pow(table_log as u32) - 1;

    let scaled_quotient = frequency * (state >> table_log);
    let slot = state & slot_mask;

    // Add before subtracting: `slot` alone may be smaller than `cumul`.
    scaled_quotient + slot - cumul
}
187
188/// Décompression de la source u16, pareil que u8
189pub fn decode(
190    mut state: usize,
191    mut bits: Vec<u32>,
192    str: Vec<u8>,
193    normalized_counter: &[usize],
194    symbols: &[u16],
195    table_log: usize,
196) -> Vec<u16> {
197    let mask = 2usize.pow(table_log as u32) - 1;
198
199    let mut dstream: BitDstream = str.try_into().unwrap();
200    dstream.read(1).unwrap(); // read mark
201
202    let cs = build_cumulative_symbol_frequency(normalized_counter);
203    let mut ret = vec![];
204    while state > 0 {
205        //println!("reverse state {state}");
206        // todo add a security timing to auto kill loop
207        let symbol_index = find_s(state & mask, &cs);
208        ret.push(*symbols.get(symbol_index).expect("symbol not found"));
209        state = decompress(
210            state,
211            *normalized_counter
212                .get(symbol_index)
213                .expect("symbol frequency not found"),
214            table_log,
215            *cs.get(symbol_index).expect("symbol cumul not found"),
216        );
217        if state < 2usize.pow(16) {
218            if let Some(nb_bits) = bits.pop() {
219                state = (state << 16) + dstream.read(nb_bits as u8).unwrap() as usize;
220            }
221        }
222    }
223    ret.reverse();
224    ret
225}
226
227pub fn decode_u8(
228    mut state: usize,
229    mut bits: Vec<u32>,
230    str: Vec<u8>,
231    normalized_counter: &[usize],
232    table_log: usize,
233) -> Vec<u8> {
234    let mask = 2usize.pow(table_log as u32) - 1;
235
236    let mut dstream: BitDstream = str.try_into().unwrap();
237    dstream.read(1).unwrap(); // read mark
238
239    let cs = build_cumulative_symbol_frequency(normalized_counter);
240    let mut ret = vec![];
241    while state > 0 {
242        //println!("reverse state {state}");
243        // todo add a security timing to auto kill loop
244        let symbol_index = find_s(state & mask, &cs);
245        ret.push(symbol_index.try_into().expect("symbol overflow"));
246        state = decompress(
247            state,
248            *normalized_counter
249                .get(symbol_index)
250                .expect("symbol frequency not found"),
251            table_log,
252            *cs.get(symbol_index).expect("symbol cumul not found"),
253        );
254        if state < 2usize.pow(16) {
255            // Si on a un etat < 16, on essaye de lire le stream.
256            // Dans le cas ou on avait shifte, le stream contient
257            // forcement des bits. Si on ne trouve pas de bits,
258            // ca veut dire qu'on arrive a la fin de la decompression
259            // et que l'etat a une valeur attendue.
260            if let Some(nb_bits) = bits.pop() {
261                state = (state << 16) + dstream.read(nb_bits as u8).unwrap() as usize;
262            }
263        }
264    }
265    ret.reverse();
266    ret
267}