1use std::collections::HashMap;
2use std::fmt::{Display, Formatter};
3
4use crate::core::{CODE_BASE, CODE_MASK, CODE_MAX, fsst_hash, is_escape_code, LEN_BITS};
5use crate::core::counter::Counter;
6use crate::core::symbol::Symbol;
7use crate::util::endian::Endian;
8
9pub trait SymbolTable: SymbolTableClone + Display {
10 fn add(&mut self, s: Symbol) -> bool;
11 fn find_longest_symbol_code(&self, str_bytes: &[u8]) -> u16;
12 fn get_symbol(&self, code: u16) -> &Symbol;
13 fn encode_for(&self, target: &Symbol) -> (u8, usize, usize);
14 fn len(&self) -> usize;
15 fn clear(&mut self);
16 fn finalize(&mut self);
17 fn dump(&self) -> Vec<u8>;
18}
19
/// Cloning hook usable through a `dyn SymbolTable` trait object.
///
/// It lets callers that only hold a boxed table obtain an independent copy.
pub trait SymbolTableClone {
    /// Returns a boxed copy of `self` as a `SymbolTable` trait object.
    fn clone_box<'a>(&self) -> Box<dyn SymbolTable + 'a>
    where
        Self: 'a;
}
25
26impl<T: Clone + SymbolTable> SymbolTableClone for T {
27 fn clone_box<'a>(&self) -> Box<dyn SymbolTable + 'a>
28 where
29 Self: 'a,
30 {
31 Box::new(self.clone())
32 }
33}
34
/// Symbol table that resolves one- and two-byte symbols through direct lookup
/// arrays and longer symbols through a hash table holding at most one symbol
/// per slot.
#[derive(Clone, Copy)]
struct PerfectHashSymbolTable {
    /// Packed entry (`code | length << LEN_BITS`) for every possible first byte.
    byte_codes: [u16; CODE_BASE as usize],

    /// Packed entry for every possible two-byte prefix, indexed by `Symbol::first2`.
    short_codes: [u16; 65536],

    /// Slots for symbols of three or more bytes, indexed by the masked symbol hash.
    hash_table: [Symbol; PerfectHashSymbolTable::TABLE_SIZE],
    /// Symbols indexed by their code.
    symbols: [Symbol; CODE_MAX as usize],
    /// Number of registered symbols per length (`len_histo[len - 1]`).
    len_histo: [u8; Symbol::MAX_LEN],
    /// Number of symbols registered since construction or the last `clear`.
    symbol_num: u16,
    /// Set by `finalize` once the codes have been renumbered.
    finalized: bool,
}
49
50impl PerfectHashSymbolTable {
51 const TABLE_SIZE: usize = 4096;
52
53 pub fn new() -> PerfectHashSymbolTable {
54 let unused = Symbol::from_byte_code(0, CODE_MASK);
55 let mut symbols = [unused; CODE_MAX as usize];
56 let mut byte_codes = [0u16; CODE_BASE as usize];
57 for i in 0..CODE_BASE {
58 let byte_code = (1 << LEN_BITS) | i;
59 byte_codes[i as usize] = byte_code;
60 symbols[i as usize] = Symbol::from_byte_code(i as u8, byte_code);
61 }
62
63 let mut short_codes = [0u16; 65536];
64 for i in 0..short_codes.len() {
65 short_codes[i] = (1 << LEN_BITS) | ((i as u16) & 0xff);
66 }
67
68 let len_histo = [0u8; Symbol::MAX_LEN];
69 let hash_table = [Symbol::free(); PerfectHashSymbolTable::TABLE_SIZE];
70 PerfectHashSymbolTable {
71 byte_codes,
72 short_codes,
73 hash_table,
74 symbols,
75 len_histo,
76 symbol_num: 0,
77 finalized: false,
78 }
79 }
80
81 fn hash_insert(&mut self, s: &Symbol) -> bool {
82 let src_symbol = self.get_hash_symbol_mut(s.hash());
83 if src_symbol.taken() {
84 return false;
85 }
86
87 src_symbol.update_to(s);
88 return true;
89 }
90
91 fn get_hash_symbol_mut(&mut self, hash_value: usize) -> &mut Symbol {
92 &mut self.hash_table[Self::hash_idx(hash_value)]
93 }
94
95 fn get_hash_symbol(&self, hash_value: usize) -> &Symbol {
96 &self.hash_table[Self::hash_idx(hash_value)]
97 }
98
99 fn hash_idx(hash_value: usize) -> usize {
100 hash_value & (PerfectHashSymbolTable::TABLE_SIZE - 1)
101 }
102}
103
104impl SymbolTable for PerfectHashSymbolTable {
105 fn add(&mut self, mut s: Symbol) -> bool {
106 let len = s.length();
107 let code = CODE_BASE + self.symbol_num;
108 s.set_code_len(code, len);
109 if len == 1 {
110 self.byte_codes[s.first()] = code | (1 << LEN_BITS); } else if len == 2 {
112 self.short_codes[s.first2()] = code | (2 << LEN_BITS); } else if !self.hash_insert(&s) {
114 return false;
115 }
116
117 self.symbols[code as usize] = s;
118 self.symbol_num += 1;
119 self.len_histo[len - 1] += 1;
120 return true;
121 }
122
123 fn find_longest_symbol_code(&self, str_bytes: &[u8]) -> u16 {
124 let target_symbol = Symbol::from_str_bytes(str_bytes);
125 let src_symbol = self.get_hash_symbol(target_symbol.hash());
126 if target_symbol.prefix_match(src_symbol) {
127 return src_symbol.code();
128 }
129
130 if target_symbol.length() >= 2 {
131 let code = self.short_codes[target_symbol.first2()] & CODE_MASK;
132 if code >= CODE_BASE {
133 return code;
134 }
135 }
136
137 self.byte_codes[target_symbol.first()] & CODE_MASK
138 }
139
140 fn get_symbol(&self, code: u16) -> &Symbol {
141 &self.symbols[code as usize]
142 }
143
144 fn encode_for(&self, target: &Symbol) -> (u8, usize, usize) {
145 let src_symbol = self.get_hash_symbol(target.hash());
146 if target.prefix_match(src_symbol) {
147 return (src_symbol.code() as u8, src_symbol.length(), 1);
148 }
149
150 let code = self.short_codes[target.first2()];
151 let s_len = (code >> LEN_BITS) as usize;
152 let out_len = (1 + ((code & CODE_BASE) >> 8)) as usize;
153 (code as u8, s_len, out_len)
154 }
155
156 fn len(&self) -> usize {
157 self.symbol_num as usize
158 }
159
160 fn clear(&mut self) {
161 for i in CODE_BASE..CODE_BASE + self.symbol_num {
162 let s = self.get_symbol(i);
163 match s.length() {
164 1 => {
165 let v = s.first();
166 self.byte_codes[v] = (v as u16 & 0xff) | (1 << LEN_BITS)
167 }
168 2 => {
169 let v = s.first2();
170 self.short_codes[v] = (v as u16 & 0xff) | (1 << LEN_BITS)
171 }
172 _ => {
173 let src = self.get_hash_symbol_mut(s.hash());
174 src.reset();
175 }
176 }
177 }
178 self.len_histo.fill(0);
179 self.symbol_num = 0;
180 }
181
182 fn finalize(&mut self) {
183 let mut rsum = [0u8; Symbol::MAX_LEN];
185 for i in 0..rsum.len() - 1 {
186 rsum[i + 1] = rsum[i] + self.len_histo[i];
187 }
188
189 let mut new_codes = [0u8; CODE_BASE as usize];
190 for i in CODE_BASE..CODE_BASE + self.symbol_num {
191 let mut s = self.symbols[i as usize];
192 let len = s.length();
193 new_codes[(i - CODE_BASE) as usize] = rsum[len - 1];
194 rsum[len - 1] += 1;
195 let new_code = new_codes[(i - CODE_BASE) as usize];
196 s.set_code_len(new_code as u16, len);
197 self.symbols[new_code as usize] = s;
198 }
199
200 for i in 0..CODE_BASE as usize {
201 if (self.byte_codes[i] & CODE_MASK) >= CODE_BASE {
202 let idx = (self.byte_codes[i] & 0xff) as usize;
203 self.byte_codes[i] = new_codes[idx] as u16 | (1 << LEN_BITS);
204 } else {
205 self.byte_codes[i] = CODE_MASK | (1 << LEN_BITS);
206 }
207 }
208
209 for i in 0..self.short_codes.len() {
210 if (self.short_codes[i] & CODE_MASK) >= CODE_BASE {
211 let idx = (self.short_codes[i] & 0xff) as usize;
212 self.short_codes[i] = new_codes[idx] as u16 | (self.short_codes[i] & (0xf << LEN_BITS));
213 } else {
214 self.short_codes[i] = self.byte_codes[i & 0xff];
215 }
216 }
217
218 for i in 0..self.hash_table.len() {
219 if self.hash_table[i].taken() {
220 let idx = (self.hash_table[i].code() & 0xff) as usize;
221 self.hash_table[i] = self.symbols[new_codes[idx] as usize];
222 }
223 }
224 self.finalized = true;
225 }
226
227 fn dump(&self) -> Vec<u8> {
228 let mut total_size = 9usize;
229 for i in 0..self.len_histo.len() {
230 total_size += self.len_histo[i] as usize * (i + 1);
231 }
232 let mut buf = Vec::with_capacity(total_size);
233 buf.push(Endian::get_native_endian().into());
234 self.len_histo.iter().for_each(|l| buf.push(*l));
235 for i in 0..self.symbol_num {
236 let s = self.get_symbol(i);
237 let mut num = s.as_u64();
238 for _ in 0..s.length() {
239 buf.push(num as u8);
240 num >>= 8;
241 }
242 }
243 buf
244 }
245}
246
247impl Display for PerfectHashSymbolTable {
248 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
249 let (start, end) = if self.finalized {
250 (0usize, self.symbol_num as usize)
251 } else {
252 (CODE_BASE as usize, (CODE_BASE + self.symbol_num) as usize)
253 };
254 let symbols_str = &self.symbols[start..end].iter()
255 .map(|&x| x.to_string())
256 .collect::<Vec<String>>()
257 .join(", ");
258 write!(f, "[{}]", symbols_str)
259 }
260}
261
/// Builds a symbol table from sample text by repeatedly counting symbol
/// frequencies and re-creating the table from the best candidates.
pub struct SymbolTableBuilder {
    /// Frequency counters filled while scanning the samples.
    counter: Counter,
    /// Scales the minimum count a candidate needs to be considered
    /// (`count_frac * sample_frac / 128`).
    count_frac: u32,
}
266
267impl SymbolTableBuilder {
268 pub fn build_from(s: &str) -> Box<dyn SymbolTable> {
269 let str = String::from(s);
270 let sample = vec![&str];
271 SymbolTableBuilder {
272 counter: Counter::new(),
273 count_frac: 0,
274 }.build(&sample)
275 }
276
277 pub fn build_from_samples(samples: &Vec<&String>) -> Box<dyn SymbolTable> {
278 SymbolTableBuilder {
279 counter: Counter::new(),
280 count_frac: 5,
281 }.build(samples)
282 }
283
284 fn build(&mut self, samples: &Vec<&String>) -> Box<dyn SymbolTable> {
285 let mut symbol_table: Box<dyn SymbolTable> = Box::new(PerfectHashSymbolTable::new());
286 let mut best_table = symbol_table.clone_box();
287 let mut best_gain = i64::MIN;
288 let mut best_single = [0u8; Counter::ENTRY_SIZE * 2];
289 let mut sample_frac = 8;
290 loop {
291 let gain = self.compute_freq(samples, sample_frac, &symbol_table);
292 if gain > best_gain {
293 best_gain = gain;
294 best_single = self.counter.backup_single();
295 best_table = symbol_table.clone_box();
296 }
297 if sample_frac >= 128 {
298 break;
299 }
300 self.make_table(sample_frac, &mut symbol_table);
301 self.counter.reset();
302 sample_frac += 30;
303 }
304 self.counter.restore_single(best_single);
305 self.make_table(sample_frac, &mut best_table);
306 best_table.finalize();
307 best_table
308 }
309
310 fn compute_freq(&mut self, samples: &Vec<&String>, sample_frac: u32, symbol_table: &Box<dyn SymbolTable>) -> i64 {
311 let mut gain = 0i64;
312 for i in 0..samples.len() {
313 if samples.len() > 128 && sample_frac < 128 {
314 let rand = 1 + ((fsst_hash(1 + i) * sample_frac as usize) & 127);
315 if rand > sample_frac as usize {
316 continue;
317 }
318 }
319 gain += self.count_line(samples[i].as_bytes(), sample_frac, symbol_table);
320 }
321 gain
322 }
323
324 fn count_line(&mut self, str_bytes: &[u8], sample_frac: u32, symbol_table: &Box<dyn SymbolTable>) -> i64 {
325 let mut gain = 0i64;
326 let mut pos = 0;
327 let mut code1 = symbol_table.find_longest_symbol_code(&str_bytes);
328 let mut s1 = symbol_table.get_symbol(code1);
329 loop {
330 self.counter.inc_single(code1 as usize);
331 if s1.length() > 1 {
332 self.counter.inc_single(str_bytes[pos] as usize);
333 }
334 gain += s1.length() as i64 - (1 + is_escape_code(code1) as i64);
335 pos += s1.length();
336 if pos >= str_bytes.len() {
337 break;
338 }
339
340 let code2 = symbol_table.find_longest_symbol_code(&str_bytes[pos..]);
341 let s2 = symbol_table.get_symbol(code2);
342 if sample_frac < 128 {
343 self.counter.inc_concat(code1 as usize, code2 as usize);
344 if s2.length() > 1 {
345 self.counter.inc_concat(code1 as usize, str_bytes[pos] as usize);
346 }
347 }
348 code1 = code2;
349 s1 = s2;
350 }
351 gain
352 }
353
354 fn make_table(&mut self, sample_frac: u32, symbol_table: &mut Box<dyn SymbolTable>) {
355 let mut candidates: HashMap<Symbol, u32> = HashMap::with_capacity(CODE_MAX as usize);
356 let end = CODE_BASE as usize + symbol_table.len();
357 let mut pos1 = 0usize;
358 while pos1 < end {
359 let cnt1 = self.counter.get_single_and_forward(&mut pos1);
360 if cnt1 == 0 {
361 pos1 += 1;
362 continue;
363 }
364
365 let s1 = symbol_table.get_symbol(pos1 as u16);
366 let heuristic_cnt = match s1.length() {
367 1 => 8 * cnt1,
368 _ => cnt1
369 };
370 self.expand_candidate(&mut candidates, s1.clone(), heuristic_cnt, sample_frac);
371 if s1.length() == Symbol::MAX_LEN
372 || sample_frac >= 128 {
373 pos1 += 1;
374 continue;
375 }
376
377 let mut pos2 = 0usize;
378 while pos2 < end {
379 let cnt2 = self.counter.get_concat_and_forward(pos1, &mut pos2);
380 if cnt2 > 0 {
381 let s2 = symbol_table.get_symbol(pos2 as u16);
382 let s3 = *s1 + *s2;
383 self.expand_candidate(&mut candidates, s3, cnt2, sample_frac);
384 }
385 pos2 += 1;
386 }
387 pos1 += 1;
388 }
389
390 let mut sorted_vec: Vec<(Symbol, u32)> = candidates.iter().map(|(k, v)| (*k, *v)).collect();
391 sorted_vec.sort_by(|a, b| {
392 if a.1 == b.1 {
393 b.0.cmp(&a.0)
394 } else {
395 a.1.cmp(&b.1)
396 }
397 });
398 symbol_table.clear();
399 while symbol_table.len() < 255 && !sorted_vec.is_empty() {
400 let s = sorted_vec.pop().unwrap();
401 symbol_table.add(s.0);
402 }
403 }
404
405 fn expand_candidate(&self, candidates: &mut HashMap<Symbol, u32>, s: Symbol, cnt: u32, sample_frac: u32) {
406 if cnt >= (self.count_frac * sample_frac / 128) {
407 let gain = s.length() as u32 * cnt;
408 candidates.insert(s, candidates.get(&s).unwrap_or(&0) + gain);
409 }
410 }
411}