// fsst_rust/core/symbol_table.rs

1use std::collections::HashMap;
2use std::fmt::{Display, Formatter};
3
4use crate::core::{CODE_BASE, CODE_MASK, CODE_MAX, fsst_hash, is_escape_code, LEN_BITS};
5use crate::core::counter::Counter;
6use crate::core::symbol::Symbol;
7use crate::util::endian::Endian;
8
9pub trait SymbolTable: SymbolTableClone + Display {
10    fn add(&mut self, s: Symbol) -> bool;
11    fn find_longest_symbol_code(&self, str_bytes: &[u8]) -> u16;
12    fn get_symbol(&self, code: u16) -> &Symbol;
13    fn encode_for(&self, target: &Symbol) -> (u8, usize, usize);
14    fn len(&self) -> usize;
15    fn clear(&mut self);
16    fn finalize(&mut self);
17    fn dump(&self) -> Vec<u8>;
18}
19
20pub trait SymbolTableClone {
21    fn clone_box<'a>(&self) -> Box<dyn SymbolTable + 'a>
22    where
23        Self: 'a;
24}
25
26impl<T: Clone + SymbolTable> SymbolTableClone for T {
27    fn clone_box<'a>(&self) -> Box<dyn SymbolTable + 'a>
28    where
29        Self: 'a,
30    {
31        Box::new(self.clone())
32    }
33}
34
35#[derive(Clone, Copy)]
36struct PerfectHashSymbolTable {
37    // lookup table (only used during symbolTable construction, not during normal text compression)
38    byte_codes: [u16; CODE_BASE as usize],
39
40    // lookup table using the next two bytes (65536 codes), or just the next single byte
41    short_codes: [u16; 65536],
42
43    hash_table: [Symbol; PerfectHashSymbolTable::TABLE_SIZE],
44    symbols: [Symbol; CODE_MAX as usize],
45    len_histo: [u8; Symbol::MAX_LEN],
46    symbol_num: u16,
47    finalized: bool,
48}
49
50impl PerfectHashSymbolTable {
51    const TABLE_SIZE: usize = 4096;
52
53    pub fn new() -> PerfectHashSymbolTable {
54        let unused = Symbol::from_byte_code(0, CODE_MASK);
55        let mut symbols = [unused; CODE_MAX as usize];
56        let mut byte_codes = [0u16; CODE_BASE as usize];
57        for i in 0..CODE_BASE {
58            let byte_code = (1 << LEN_BITS) | i;
59            byte_codes[i as usize] = byte_code;
60            symbols[i as usize] = Symbol::from_byte_code(i as u8, byte_code);
61        }
62
63        let mut short_codes = [0u16; 65536];
64        for i in 0..short_codes.len() {
65            short_codes[i] = (1 << LEN_BITS) | ((i as u16) & 0xff);
66        }
67
68        let len_histo = [0u8; Symbol::MAX_LEN];
69        let hash_table = [Symbol::free(); PerfectHashSymbolTable::TABLE_SIZE];
70        PerfectHashSymbolTable {
71            byte_codes,
72            short_codes,
73            hash_table,
74            symbols,
75            len_histo,
76            symbol_num: 0,
77            finalized: false,
78        }
79    }
80
81    fn hash_insert(&mut self, s: &Symbol) -> bool {
82        let src_symbol = self.get_hash_symbol_mut(s.hash());
83        if src_symbol.taken() {
84            return false;
85        }
86
87        src_symbol.update_to(s);
88        return true;
89    }
90
91    fn get_hash_symbol_mut(&mut self, hash_value: usize) -> &mut Symbol {
92        &mut self.hash_table[Self::hash_idx(hash_value)]
93    }
94
95    fn get_hash_symbol(&self, hash_value: usize) -> &Symbol {
96        &self.hash_table[Self::hash_idx(hash_value)]
97    }
98
99    fn hash_idx(hash_value: usize) -> usize {
100        hash_value & (PerfectHashSymbolTable::TABLE_SIZE - 1)
101    }
102}
103
104impl SymbolTable for PerfectHashSymbolTable {
105    fn add(&mut self, mut s: Symbol) -> bool {
106        let len = s.length();
107        let code = CODE_BASE + self.symbol_num;
108        s.set_code_len(code, len);
109        if len == 1 {
110            self.byte_codes[s.first()] = code | (1 << LEN_BITS); // len=1 (<<FSST_LEN_BITS)
111        } else if len == 2 {
112            self.short_codes[s.first2()] = code | (2 << LEN_BITS); // len=2 (<<FSST_LEN_BITS)
113        } else if !self.hash_insert(&s) {
114            return false;
115        }
116
117        self.symbols[code as usize] = s;
118        self.symbol_num += 1;
119        self.len_histo[len - 1] += 1;
120        return true;
121    }
122
123    fn find_longest_symbol_code(&self, str_bytes: &[u8]) -> u16 {
124        let target_symbol = Symbol::from_str_bytes(str_bytes);
125        let src_symbol = self.get_hash_symbol(target_symbol.hash());
126        if target_symbol.prefix_match(src_symbol) {
127            return src_symbol.code();
128        }
129
130        if target_symbol.length() >= 2 {
131            let code = self.short_codes[target_symbol.first2()] & CODE_MASK;
132            if code >= CODE_BASE {
133                return code;
134            }
135        }
136
137        self.byte_codes[target_symbol.first()] & CODE_MASK
138    }
139
140    fn get_symbol(&self, code: u16) -> &Symbol {
141        &self.symbols[code as usize]
142    }
143
144    fn encode_for(&self, target: &Symbol) -> (u8, usize, usize) {
145        let src_symbol = self.get_hash_symbol(target.hash());
146        if target.prefix_match(src_symbol) {
147            return (src_symbol.code() as u8, src_symbol.length(), 1);
148        }
149
150        let code = self.short_codes[target.first2()];
151        let s_len = (code >> LEN_BITS) as usize;
152        let out_len = (1 + ((code & CODE_BASE) >> 8)) as usize;
153        (code as u8, s_len, out_len)
154    }
155
156    fn len(&self) -> usize {
157        self.symbol_num as usize
158    }
159
160    fn clear(&mut self) {
161        for i in CODE_BASE..CODE_BASE + self.symbol_num {
162            let s = self.get_symbol(i);
163            match s.length() {
164                1 => {
165                    let v = s.first();
166                    self.byte_codes[v] = (v as u16 & 0xff) | (1 << LEN_BITS)
167                }
168                2 => {
169                    let v = s.first2();
170                    self.short_codes[v] = (v as u16 & 0xff) | (1 << LEN_BITS)
171                }
172                _ => {
173                    let src = self.get_hash_symbol_mut(s.hash());
174                    src.reset();
175                }
176            }
177        }
178        self.len_histo.fill(0);
179        self.symbol_num = 0;
180    }
181
182    fn finalize(&mut self) {
183        // compute running sum of code lengths (starting offsets for each length)
184        let mut rsum = [0u8; Symbol::MAX_LEN];
185        for i in 0..rsum.len() - 1 {
186            rsum[i + 1] = rsum[i] + self.len_histo[i];
187        }
188
189        let mut new_codes = [0u8; CODE_BASE as usize];
190        for i in CODE_BASE..CODE_BASE + self.symbol_num {
191            let mut s = self.symbols[i as usize];
192            let len = s.length();
193            new_codes[(i - CODE_BASE) as usize] = rsum[len - 1];
194            rsum[len - 1] += 1;
195            let new_code = new_codes[(i - CODE_BASE) as usize];
196            s.set_code_len(new_code as u16, len);
197            self.symbols[new_code as usize] = s;
198        }
199
200        for i in 0..CODE_BASE as usize {
201            if (self.byte_codes[i] & CODE_MASK) >= CODE_BASE {
202                let idx = (self.byte_codes[i] & 0xff) as usize;
203                self.byte_codes[i] = new_codes[idx] as u16 | (1 << LEN_BITS);
204            } else {
205                self.byte_codes[i] = CODE_MASK | (1 << LEN_BITS);
206            }
207        }
208
209        for i in 0..self.short_codes.len() {
210            if (self.short_codes[i] & CODE_MASK) >= CODE_BASE {
211                let idx = (self.short_codes[i] & 0xff) as usize;
212                self.short_codes[i] = new_codes[idx] as u16 | (self.short_codes[i] & (0xf << LEN_BITS));
213            } else {
214                self.short_codes[i] = self.byte_codes[i & 0xff];
215            }
216        }
217
218        for i in 0..self.hash_table.len() {
219            if self.hash_table[i].taken() {
220                let idx = (self.hash_table[i].code() & 0xff) as usize;
221                self.hash_table[i] = self.symbols[new_codes[idx] as usize];
222            }
223        }
224        self.finalized = true;
225    }
226
227    fn dump(&self) -> Vec<u8> {
228        let mut total_size = 9usize;
229        for i in 0..self.len_histo.len() {
230            total_size += self.len_histo[i] as usize * (i + 1);
231        }
232        let mut buf = Vec::with_capacity(total_size);
233        buf.push(Endian::get_native_endian().into());
234        self.len_histo.iter().for_each(|l| buf.push(*l));
235        for i in 0..self.symbol_num {
236            let s = self.get_symbol(i);
237            let mut num = s.as_u64();
238            for _ in 0..s.length() {
239                buf.push(num as u8);
240                num >>= 8;
241            }
242        }
243        buf
244    }
245}
246
247impl Display for PerfectHashSymbolTable {
248    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
249        let (start, end) = if self.finalized {
250            (0usize, self.symbol_num as usize)
251        } else {
252            (CODE_BASE as usize, (CODE_BASE + self.symbol_num) as usize)
253        };
254        let symbols_str = &self.symbols[start..end].iter()
255            .map(|&x| x.to_string())
256            .collect::<Vec<String>>()
257            .join(", ");
258        write!(f, "[{}]", symbols_str)
259    }
260}
261
262pub struct SymbolTableBuilder {
263    counter: Counter,
264    count_frac: u32,
265}
266
267impl SymbolTableBuilder {
268    pub fn build_from(s: &str) -> Box<dyn SymbolTable> {
269        let str = String::from(s);
270        let sample = vec![&str];
271        SymbolTableBuilder {
272            counter: Counter::new(),
273            count_frac: 0,
274        }.build(&sample)
275    }
276
277    pub fn build_from_samples(samples: &Vec<&String>) -> Box<dyn SymbolTable> {
278        SymbolTableBuilder {
279            counter: Counter::new(),
280            count_frac: 5,
281        }.build(samples)
282    }
283
284    fn build(&mut self, samples: &Vec<&String>) -> Box<dyn SymbolTable> {
285        let mut symbol_table: Box<dyn SymbolTable> = Box::new(PerfectHashSymbolTable::new());
286        let mut best_table = symbol_table.clone_box();
287        let mut best_gain = i64::MIN;
288        let mut best_single = [0u8; Counter::ENTRY_SIZE * 2];
289        let mut sample_frac = 8;
290        loop {
291            let gain = self.compute_freq(samples, sample_frac, &symbol_table);
292            if gain > best_gain {
293                best_gain = gain;
294                best_single = self.counter.backup_single();
295                best_table = symbol_table.clone_box();
296            }
297            if sample_frac >= 128 {
298                break;
299            }
300            self.make_table(sample_frac, &mut symbol_table);
301            self.counter.reset();
302            sample_frac += 30;
303        }
304        self.counter.restore_single(best_single);
305        self.make_table(sample_frac, &mut best_table);
306        best_table.finalize();
307        best_table
308    }
309
310    fn compute_freq(&mut self, samples: &Vec<&String>, sample_frac: u32, symbol_table: &Box<dyn SymbolTable>) -> i64 {
311        let mut gain = 0i64;
312        for i in 0..samples.len() {
313            if samples.len() > 128 && sample_frac < 128 {
314                let rand = 1 + ((fsst_hash(1 + i) * sample_frac as usize) & 127);
315                if rand > sample_frac as usize {
316                    continue;
317                }
318            }
319            gain += self.count_line(samples[i].as_bytes(), sample_frac, symbol_table);
320        }
321        gain
322    }
323
324    fn count_line(&mut self, str_bytes: &[u8], sample_frac: u32, symbol_table: &Box<dyn SymbolTable>) -> i64 {
325        let mut gain = 0i64;
326        let mut pos = 0;
327        let mut code1 = symbol_table.find_longest_symbol_code(&str_bytes);
328        let mut s1 = symbol_table.get_symbol(code1);
329        loop {
330            self.counter.inc_single(code1 as usize);
331            if s1.length() > 1 {
332                self.counter.inc_single(str_bytes[pos] as usize);
333            }
334            gain += s1.length() as i64 - (1 + is_escape_code(code1) as i64);
335            pos += s1.length();
336            if pos >= str_bytes.len() {
337                break;
338            }
339
340            let code2 = symbol_table.find_longest_symbol_code(&str_bytes[pos..]);
341            let s2 = symbol_table.get_symbol(code2);
342            if sample_frac < 128 {
343                self.counter.inc_concat(code1 as usize, code2 as usize);
344                if s2.length() > 1 {
345                    self.counter.inc_concat(code1 as usize, str_bytes[pos] as usize);
346                }
347            }
348            code1 = code2;
349            s1 = s2;
350        }
351        gain
352    }
353
354    fn make_table(&mut self, sample_frac: u32, symbol_table: &mut Box<dyn SymbolTable>) {
355        let mut candidates: HashMap<Symbol, u32> = HashMap::with_capacity(CODE_MAX as usize);
356        let end = CODE_BASE as usize + symbol_table.len();
357        let mut pos1 = 0usize;
358        while pos1 < end {
359            let cnt1 = self.counter.get_single_and_forward(&mut pos1);
360            if cnt1 == 0 {
361                pos1 += 1;
362                continue;
363            }
364
365            let s1 = symbol_table.get_symbol(pos1 as u16);
366            let heuristic_cnt = match s1.length() {
367                1 => 8 * cnt1,
368                _ => cnt1
369            };
370            self.expand_candidate(&mut candidates, s1.clone(), heuristic_cnt, sample_frac);
371            if s1.length() == Symbol::MAX_LEN
372                || sample_frac >= 128 {
373                pos1 += 1;
374                continue;
375            }
376
377            let mut pos2 = 0usize;
378            while pos2 < end {
379                let cnt2 = self.counter.get_concat_and_forward(pos1, &mut pos2);
380                if cnt2 > 0 {
381                    let s2 = symbol_table.get_symbol(pos2 as u16);
382                    let s3 = *s1 + *s2;
383                    self.expand_candidate(&mut candidates, s3, cnt2, sample_frac);
384                }
385                pos2 += 1;
386            }
387            pos1 += 1;
388        }
389
390        let mut sorted_vec: Vec<(Symbol, u32)> = candidates.iter().map(|(k, v)| (*k, *v)).collect();
391        sorted_vec.sort_by(|a, b| {
392            if a.1 == b.1 {
393                b.0.cmp(&a.0)
394            } else {
395                a.1.cmp(&b.1)
396            }
397        });
398        symbol_table.clear();
399        while symbol_table.len() < 255 && !sorted_vec.is_empty() {
400            let s = sorted_vec.pop().unwrap();
401            symbol_table.add(s.0);
402        }
403    }
404
405    fn expand_candidate(&self, candidates: &mut HashMap<Symbol, u32>, s: Symbol, cnt: u32, sample_frac: u32) {
406        if cnt >= (self.count_frac * sample_frac / 128) {
407            let gain = s.length() as u32 * cnt;
408            candidates.insert(s, candidates.get(&s).unwrap_or(&0) + gain);
409        }
410    }
411}