1use crate::{
8 Code, Compressor, FSST_CODE_BASE, FSST_CODE_MASK, Symbol, advance_8byte_word, compare_masked,
9 lossy_pht::LossyPHT,
10};
11use std::cmp::Ordering;
12use std::collections::BinaryHeap;
13
/// Fixed-size bitmap with one bit per code, for codes `0..=FSST_CODE_MASK`.
///
/// Eight 64-bit words provide 512 bits; bit `i % 64` of word `i / 64` represents code `i`.
#[derive(Clone, Copy, Debug, Default)]
struct CodesBitmap {
    codes: [u64; 8],
}

// Presumably a compile-time check that `CodesBitmap` occupies exactly 64 bytes
// (8 x u64) — see the `assert_sizeof!` macro definition to confirm.
assert_sizeof!(CodesBitmap => 64);
21
22impl CodesBitmap {
23 pub(crate) fn set(&mut self, index: usize) {
25 debug_assert!(
26 index <= FSST_CODE_MASK as usize,
27 "code cannot exceed {FSST_CODE_MASK}"
28 );
29
30 let map = index >> 6;
31 self.codes[map] |= 1 << (index % 64);
32 }
33
34 pub(crate) fn is_set(&self, index: usize) -> bool {
36 debug_assert!(
37 index <= FSST_CODE_MASK as usize,
38 "code cannot exceed {FSST_CODE_MASK}"
39 );
40
41 let map = index >> 6;
42 self.codes[map] & (1 << (index % 64)) != 0
43 }
44
45 pub(crate) fn codes(&self) -> CodesIterator<'_> {
47 CodesIterator {
48 inner: self,
49 index: 0,
50 block: self.codes[0],
51 reference: 0,
52 }
53 }
54
55 pub(crate) fn clear(&mut self) {
57 self.codes[0] = 0;
58 self.codes[1] = 0;
59 self.codes[2] = 0;
60 self.codes[3] = 0;
61 self.codes[4] = 0;
62 self.codes[5] = 0;
63 self.codes[6] = 0;
64 self.codes[7] = 0;
65 }
66}
67
68struct CodesIterator<'a> {
69 inner: &'a CodesBitmap,
70 index: usize,
71 block: u64,
72 reference: usize,
73}
74
75impl Iterator for CodesIterator<'_> {
76 type Item = u16;
77
78 fn next(&mut self) -> Option<Self::Item> {
79 while self.block == 0 {
81 self.index += 1;
82 if self.index >= 8 {
83 return None;
84 }
85 self.block = self.inner.codes[self.index];
86 self.reference = self.index * 64;
87 }
88
89 let position = self.block.trailing_zeros() as usize;
91 let code = self.reference + position;
92
93 if code >= 511 {
94 return None;
95 }
96
97 self.reference = code + 1;
99 self.block = if position == 63 {
100 0
101 } else {
102 self.block >> (1 + position)
103 };
104
105 Some(code as u16)
106 }
107}
108
109#[derive(Debug, Clone)]
110struct Counter {
111 counts1: Vec<usize>,
113
114 counts2: Vec<usize>,
116
117 code1_index: CodesBitmap,
119
120 pair_index: Vec<CodesBitmap>,
125}
126
127const COUNTS1_SIZE: usize = (FSST_CODE_MASK + 1) as usize;
128
129const COUNTS2_SIZE: usize = COUNTS1_SIZE * COUNTS1_SIZE;
135
136impl Counter {
137 fn new() -> Self {
138 let mut counts1 = Vec::with_capacity(COUNTS1_SIZE);
139 let mut counts2 = Vec::with_capacity(COUNTS2_SIZE);
140 unsafe {
143 counts1.set_len(COUNTS1_SIZE);
144 counts2.set_len(COUNTS2_SIZE);
145 }
146
147 Self {
148 counts1,
149 counts2,
150 code1_index: CodesBitmap::default(),
151 pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE],
152 }
153 }
154
155 #[inline]
156 fn record_count1(&mut self, code1: u16) {
157 let base = if self.code1_index.is_set(code1 as usize) {
159 self.counts1[code1 as usize]
160 } else {
161 0
162 };
163
164 self.counts1[code1 as usize] = base + 1;
165 self.code1_index.set(code1 as usize);
166 }
167
168 #[inline]
169 fn record_count2(&mut self, code1: u16, code2: u16) {
170 debug_assert!(code1 == FSST_CODE_MASK || self.code1_index.is_set(code1 as usize));
171 debug_assert!(self.code1_index.is_set(code2 as usize));
172
173 let idx = (code1 as usize) * COUNTS1_SIZE + (code2 as usize);
174 if self.pair_index[code1 as usize].is_set(code2 as usize) {
175 self.counts2[idx] += 1;
176 } else {
177 self.counts2[idx] = 1;
178 }
179 self.pair_index[code1 as usize].set(code2 as usize);
180 }
181
182 #[inline]
183 fn count1(&self, code1: u16) -> usize {
184 debug_assert!(self.code1_index.is_set(code1 as usize));
185
186 self.counts1[code1 as usize]
187 }
188
189 #[inline]
190 fn count2(&self, code1: u16, code2: u16) -> usize {
191 debug_assert!(self.code1_index.is_set(code1 as usize));
192 debug_assert!(self.code1_index.is_set(code2 as usize));
193 debug_assert!(self.pair_index[code1 as usize].is_set(code2 as usize));
194
195 let idx = (code1 as usize) * 512 + (code2 as usize);
196 self.counts2[idx]
197 }
198
199 fn first_codes(&self) -> CodesIterator<'_> {
202 self.code1_index.codes()
203 }
204
205 fn second_codes(&self, code1: u16) -> CodesIterator<'_> {
211 self.pair_index[code1 as usize].codes()
212 }
213
214 fn clear(&mut self) {
217 self.code1_index.clear();
218 for index in &mut self.pair_index {
219 index.clear();
220 }
221 }
222}
223
/// Incrementally assembles a symbol table; [`build`](Self::build) turns it into a
/// [`Compressor`].
pub struct CompressorBuilder {
    /// Symbol table indexed by code. Entries `0..256` start as the single-byte
    /// pseudo-symbols for every byte value, and symbols added by `insert` are stored at
    /// `256 + code`. `finalize` moves the live symbols to the front (indexed by their new
    /// code) and truncates the table to `n_symbols` entries.
    symbols: Vec<Symbol>,

    /// Number of symbols inserted so far (at most 255).
    n_symbols: u8,

    /// `len_histogram[len - 1]` counts the inserted symbols of length `len`.
    len_histogram: [u8; 8],

    /// Code lookup for single bytes, indexed by the byte value.
    codes_one_byte: Vec<Code>,

    /// Code lookup indexed by the first two bytes of a symbol (`Symbol::first2`).
    codes_two_byte: Vec<Code>,

    /// Lossy hash table holding the symbols of three or more bytes.
    lossy_pht: LossyPHT,
}
251
252impl CompressorBuilder {
253 pub fn new() -> Self {
255 let symbols = vec![0u64; 511];
259
260 let symbols: Vec<Symbol> = unsafe { std::mem::transmute(symbols) };
262
263 let mut table = Self {
264 symbols,
265 n_symbols: 0,
266 len_histogram: [0; 8],
267 codes_two_byte: Vec::with_capacity(65_536),
268 codes_one_byte: Vec::with_capacity(512),
269 lossy_pht: LossyPHT::new(),
270 };
271
272 for byte in 0..=255 {
274 let symbol = Symbol::from_u8(byte);
275 table.symbols[byte as usize] = symbol;
276 }
277
278 for byte in 0..=255 {
280 table.codes_one_byte.push(Code::new_escape(byte));
282 }
283
284 for idx in 0..=65_535 {
286 table.codes_two_byte.push(Code::new_escape(idx as u8));
287 }
288
289 table
290 }
291}
292
293impl Default for CompressorBuilder {
294 fn default() -> Self {
295 Self::new()
296 }
297}
298
299impl CompressorBuilder {
300 pub fn insert(&mut self, symbol: Symbol, len: usize) -> bool {
311 assert!(self.n_symbols < 255, "cannot insert into full symbol table");
312 assert_eq!(len, symbol.len(), "provided len must equal symbol.len()");
313
314 if len == 2 {
315 self.codes_two_byte[symbol.first2() as usize] =
317 Code::new_symbol_building(self.n_symbols, 2);
318 } else if len == 1 {
319 self.codes_one_byte[symbol.first_byte() as usize] =
321 Code::new_symbol_building(self.n_symbols, 1);
322 } else {
323 if !self.lossy_pht.insert(symbol, len, self.n_symbols) {
325 return false;
326 }
327 }
328
329 self.len_histogram[len - 1] += 1;
331
332 self.symbols[256 + (self.n_symbols as usize)] = symbol;
335 self.n_symbols += 1;
336 true
337 }
338
339 fn clear(&mut self) {
344 for code in 0..(256 + self.n_symbols as usize) {
346 let symbol = self.symbols[code];
347 if symbol.len() == 1 {
348 self.codes_one_byte[symbol.first_byte() as usize] =
350 Code::new_escape(symbol.first_byte());
351 } else if symbol.len() == 2 {
352 self.codes_two_byte[symbol.first2() as usize] =
354 Code::new_escape(symbol.first_byte());
355 } else {
356 self.lossy_pht.remove(symbol);
358 }
359 }
360
361 for i in 0..=7 {
363 self.len_histogram[i] = 0;
364 }
365
366 self.n_symbols = 0;
367 }
368
369 fn finalize(&mut self) -> (u8, Vec<u8>) {
388 let byte_lim = self.n_symbols - self.len_histogram[0];
393
394 let mut codes_by_length = [0u8; 8];
398 codes_by_length[0] = byte_lim;
399 codes_by_length[1] = 0;
400
401 for i in 1..7 {
403 codes_by_length[i + 1] = codes_by_length[i] + self.len_histogram[i];
404 }
405
406 let mut no_suffix_code = 0;
410
411 let mut has_suffix_code = codes_by_length[2];
413
414 let mut new_codes = [0u8; FSST_CODE_BASE as usize];
417
418 let mut symbol_lens = [0u8; FSST_CODE_BASE as usize];
419
420 for i in 0..(self.n_symbols as usize) {
421 let symbol = self.symbols[256 + i];
422 let len = symbol.len();
423 if len == 2 {
424 let has_suffix = self
425 .symbols
426 .iter()
427 .skip(FSST_CODE_BASE as usize)
428 .enumerate()
429 .any(|(k, other)| i != k && symbol.first2() == other.first2());
430
431 if has_suffix {
432 has_suffix_code -= 1;
434 new_codes[i] = has_suffix_code;
435 } else {
436 new_codes[i] = no_suffix_code;
439 no_suffix_code += 1;
440 }
441 } else {
442 new_codes[i] = codes_by_length[len - 1];
444 codes_by_length[len - 1] += 1;
445 }
446
447 self.symbols[new_codes[i] as usize] = symbol;
450 symbol_lens[new_codes[i] as usize] = len as u8;
451 }
452
453 self.symbols.truncate(self.n_symbols as usize);
455
456 for byte in 0..=255 {
459 let one_byte = self.codes_one_byte[byte];
460 if one_byte.extended_code() >= FSST_CODE_BASE {
461 let new_code = new_codes[one_byte.code() as usize];
462 self.codes_one_byte[byte] = Code::new_symbol(new_code, 1);
463 } else {
464 self.codes_one_byte[byte] = Code::UNUSED;
466 }
467 }
468
469 for two_bytes in 0..=65_535 {
472 let two_byte = self.codes_two_byte[two_bytes];
473 if two_byte.extended_code() >= FSST_CODE_BASE {
474 let new_code = new_codes[two_byte.code() as usize];
475 self.codes_two_byte[two_bytes] = Code::new_symbol(new_code, 2);
476 } else {
477 self.codes_two_byte[two_bytes] = self.codes_one_byte[two_bytes & 0xFF];
479 }
480 }
481
482 self.lossy_pht.renumber(&new_codes);
484
485 let mut lengths = Vec::with_capacity(self.n_symbols as usize);
487 for symbol in &self.symbols {
488 lengths.push(symbol.len() as u8);
489 }
490
491 (has_suffix_code, lengths)
492 }
493
494 pub fn build(mut self) -> Compressor {
496 let (has_suffix_code, lengths) = self.finalize();
500
501 Compressor {
502 symbols: self.symbols,
503 lengths,
504 n_symbols: self.n_symbols,
505 has_suffix_code,
506 codes_two_byte: self.codes_two_byte,
507 lossy_pht: self.lossy_pht,
508 }
509 }
510}
511
/// Sampling fraction (out of 128) for each training round. Rounds below 128 only process
/// the sample lines whose hashed index falls within the fraction. The final round (128)
/// processes every line and proposes no merged candidates.
#[cfg(not(miri))]
const GENERATIONS: [usize; 5] = [8usize, 38, 68, 98, 128];
/// Shorter schedule used under Miri, which runs fewer training rounds.
#[cfg(miri)]
const GENERATIONS: [usize; 3] = [8usize, 38, 128];

/// Number of bytes `make_sample` gathers; smaller inputs are used whole.
const FSST_SAMPLETARGET: usize = 1 << 14;
/// Minimum capacity required of the buffer passed to `make_sample`.
const FSST_SAMPLEMAX: usize = 1 << 15;
/// Maximum length of a single chunk copied into the sample.
const FSST_SAMPLELINE: usize = 512;
523
524#[allow(clippy::ptr_arg)]
531fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>) -> Vec<&'a [u8]> {
532 assert!(
533 sample_buf.capacity() >= FSST_SAMPLEMAX,
534 "sample_buf.len() < FSST_SAMPLEMAX"
535 );
536
537 let mut sample: Vec<&[u8]> = Vec::new();
538
539 let tot_size: usize = str_in.iter().map(|s| s.len()).sum();
540 if tot_size < FSST_SAMPLETARGET {
541 return str_in.clone();
542 }
543
544 let mut sample_rnd = fsst_hash(4637947);
545 let sample_lim = FSST_SAMPLETARGET;
546 let mut sample_buf_offset: usize = 0;
547
548 while sample_buf_offset < sample_lim {
549 sample_rnd = fsst_hash(sample_rnd);
550 let line_nr = (sample_rnd as usize) % str_in.len();
551
552 let Some(line) = (line_nr..str_in.len())
555 .chain(0..line_nr)
556 .map(|line_nr| str_in[line_nr])
557 .find(|line| !line.is_empty())
558 else {
559 return sample;
560 };
561
562 let chunks = 1 + ((line.len() - 1) / FSST_SAMPLELINE);
563 sample_rnd = fsst_hash(sample_rnd);
564 let chunk = FSST_SAMPLELINE * ((sample_rnd as usize) % chunks);
565
566 let len = FSST_SAMPLELINE.min(line.len() - chunk);
567
568 sample_buf.extend_from_slice(&line[chunk..chunk + len]);
569
570 let slice =
572 unsafe { std::slice::from_raw_parts(sample_buf.as_ptr().add(sample_buf_offset), len) };
573
574 sample.push(slice);
575
576 sample_buf_offset += len;
577 }
578
579 sample
580}
581
/// Cheap 64-bit mixing function. It is used to derive pseudo-random numbers for sampling
/// and to choose which sample lines each training round processes.
///
/// Multiplies by the prime `2_971_215_073` (wrapping) and XORs in the input shifted right
/// by 15 bits.
#[inline]
pub(crate) fn fsst_hash(value: u64) -> u64 {
    let scrambled = value.wrapping_mul(2_971_215_073);
    scrambled ^ (value >> 15)
}
589
590impl Compressor {
591 pub fn train(values: &Vec<&[u8]>) -> Self {
601 let mut builder = CompressorBuilder::new();
602
603 if values.is_empty() {
604 return builder.build();
605 }
606
607 let mut counters = Counter::new();
608 let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX);
609 let mut pqueue = BinaryHeap::with_capacity(65_536);
610
611 let sample = make_sample(&mut sample_memory, values);
612 for sample_frac in GENERATIONS {
613 for (i, line) in sample.iter().enumerate() {
614 if sample_frac < 128 && ((fsst_hash(i as u64) & 127) as usize) > sample_frac {
615 continue;
616 }
617
618 builder.compress_count(line, &mut counters);
619 }
620
621 pqueue.clear();
623 builder.optimize(&counters, sample_frac, &mut pqueue);
624 counters.clear();
625 }
626
627 builder.build()
628 }
629}
630
impl CompressorBuilder {
    /// Find the code of the longest symbol matching a prefix of `word`.
    ///
    /// The lowest byte of `word` is the first input byte (it is what the single-byte table is
    /// indexed with). Lookups go from longest to shortest:
    /// 1. the lossy hash table (symbols of 3+ bytes),
    /// 2. the two-byte table,
    /// 3. the single-byte table, which holds either a symbol or an escape for that byte.
    fn find_longest_symbol(&self, word: u64) -> Code {
        let entry = self.lossy_pht.lookup(word);
        let ignored_bits = entry.ignored_bits;

        // Accept the hash-table entry only if the slot is in use and `word` matches the stored
        // symbol with `ignored_bits` masked off (presumably the bits past the symbol's end).
        if !entry.is_unused() && compare_masked(word, entry.symbol.to_u64(), ignored_bits) {
            return entry.code;
        }

        // An extended code below FSST_CODE_BASE means no two-byte symbol has this prefix.
        let twobyte = self.codes_two_byte[word as u16 as usize];
        if twobyte.extended_code() >= FSST_CODE_BASE {
            return twobyte;
        }

        self.codes_one_byte[word as u8 as usize]
    }

    /// Compress `sample` with the current symbol table, recording statistics in `counter`.
    ///
    /// For every code produced, its single count is recorded, as is its pair count with the
    /// previous code. The first code of the sample is paired with `FSST_CODE_MASK`. For a
    /// multi-byte code, the first byte of its symbol is counted the same way.
    ///
    /// Returns the total gain: each code contributes its length, minus one if it is an
    /// escape (extended code < 256).
    fn compress_count(&self, sample: &[u8], counter: &mut Counter) -> usize {
        let mut gain = 0;
        if sample.is_empty() {
            return gain;
        }

        let mut in_ptr = sample.as_ptr();

        // SAFETY: `in_end` points one past the end of `sample`, which is in bounds for
        // pointer arithmetic.
        let in_end = unsafe { in_ptr.byte_add(sample.len()) };
        // Whole 8-byte words may be read while more than 8 bytes remain. (Assumes the address
        // of `in_end` is >= 8, which holds for any real allocation of a non-empty slice.)
        let in_end_sub8 = in_end as usize - 8;

        // Start-of-input marker used as the "previous code" of the first code.
        let mut prev_code: u16 = FSST_CODE_MASK;

        while (in_ptr as usize) < (in_end_sub8) {
            // SAFETY: the loop condition guarantees more than 8 bytes remain past `in_ptr`.
            let word: u64 = unsafe { std::ptr::read_unaligned(in_ptr as *const u64) };
            let code = self.find_longest_symbol(word);
            let code_u16 = code.extended_code();

            // A symbol saves its full length; an escape saves nothing (its length minus one).
            gain += (code.len() as usize) - ((code_u16 < 256) as usize);

            counter.record_count1(code_u16);
            counter.record_count2(prev_code, code_u16);

            // For multi-byte symbols, also count the first byte on its own.
            if code.len() > 1 {
                let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
                counter.record_count1(code_first_byte);
                counter.record_count2(prev_code, code_first_byte);
            }

            // SAFETY: more than 8 bytes remained and a code covers at most 8 bytes, so
            // `in_ptr` stays strictly inside `sample`.
            in_ptr = unsafe { in_ptr.byte_add(code.len() as usize) };

            prev_code = code_u16;
        }

        // Between 1 and 8 bytes remain at this point: the loop above stops once at most 8
        // remain, and it never consumes the last byte.
        // SAFETY: `in_ptr` and `in_end` both point into (or one past) `sample`.
        let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
        assert!(
            remaining_bytes.is_positive(),
            "in_ptr exceeded in_end, should not be possible"
        );
        let remaining_bytes = remaining_bytes as usize;

        // Copy the tail into a zero-padded word so that no bytes past `sample` are read.
        let mut bytes = [0u8; 8];
        unsafe {
            // SAFETY: `remaining_bytes` <= 8 == bytes.len(), and those bytes lie within
            // `sample`.
            std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
        }
        let mut last_word = u64::from_le_bytes(bytes);

        let mut remaining_bytes = remaining_bytes;

        // Same as the main loop, but consumes bytes by shifting them out of `last_word`.
        while remaining_bytes > 0 {
            let code = self.find_longest_symbol(last_word);
            let code_u16 = code.extended_code();

            gain += (code.len() as usize) - ((code_u16 < 256) as usize);

            counter.record_count1(code_u16);
            counter.record_count2(prev_code, code_u16);

            if code.len() > 1 {
                let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
                counter.record_count1(code_first_byte);
                counter.record_count2(prev_code, code_first_byte);
            }

            // NOTE(review): this subtraction assumes a matched code never extends into the
            // zero padding of `last_word`. That presumably holds because symbols never end in a
            // 0x00 byte; confirm against `Symbol::len` / `Symbol::concat`.
            let advance = code.len() as usize;
            remaining_bytes -= advance;
            last_word = advance_8byte_word(last_word, advance);

            prev_code = code_u16;
        }

        gain
    }

    /// Rebuild the symbol table from the statistics in `counters`.
    ///
    /// Candidates:
    /// - Every counted code whose count reaches `5 * sample_frac / 128` is proposed, with gain
    ///   `count * length`. That gain is multiplied by 8 for codes below 256, the single-byte
    ///   pseudo-symbols.
    /// - The concatenation of such a code with each code counted after it, if the result fits
    ///   in 8 bytes, with gain `pair count * length`. These are skipped in the final round
    ///   (`sample_frac >= 128`) and for 8-byte symbols.
    ///
    /// The table is then cleared and refilled with up to 255 of the highest-gain candidates
    /// that `insert` accepts.
    fn optimize(
        &mut self,
        counters: &Counter,
        sample_frac: usize,
        pqueue: &mut BinaryHeap<Candidate>,
    ) {
        for code1 in counters.first_codes() {
            let symbol1 = self.symbols[code1 as usize];
            let symbol1_len = symbol1.len();
            let count = counters.count1(code1);

            // Drop rare codes. The threshold grows with the fraction of the sample processed.
            if count < (5 * sample_frac / 128) {
                continue;
            }

            // Boost single-byte pseudo-symbols (codes < 256).
            let mut gain = count * symbol1_len;
            if code1 < 256 {
                gain *= 8;
            }

            pqueue.push(Candidate {
                symbol: symbol1,
                gain,
            });

            // No merges in the final round, or when the symbol is already full length.
            if sample_frac >= 128 || symbol1_len == 8 {
                continue;
            }

            for code2 in counters.second_codes(code1) {
                let symbol2 = self.symbols[code2 as usize];

                // Symbols hold at most 8 bytes.
                if symbol1_len + symbol2.len() > 8 {
                    continue;
                }
                let new_symbol = symbol1.concat(symbol2);
                let gain = counters.count2(code1, code2) * new_symbol.len();

                // NOTE(review): candidates are not deduplicated. The same symbol can be pushed
                // more than once (e.g. as a code and as a concatenation) and may then take
                // several table slots.
                pqueue.push(Candidate {
                    symbol: new_symbol,
                    gain,
                })
            }
        }

        self.clear();

        // Take candidates in descending gain order; rejected inserts do not count.
        let mut n_symbols = 0;
        while !pqueue.is_empty() && n_symbols < 255 {
            let candidate = pqueue.pop().unwrap();
            if self.insert(candidate.symbol, candidate.symbol.len()) {
                n_symbols += 1;
            }
        }
    }
}
818
819#[derive(Copy, Clone, Debug)]
823struct Candidate {
824 gain: usize,
825 symbol: Symbol,
826}
827
828impl Candidate {
829 fn comparable_form(&self) -> (usize, usize) {
830 (self.gain, self.symbol.len())
831 }
832}
833
834impl Eq for Candidate {}
835
836impl PartialEq<Self> for Candidate {
837 fn eq(&self, other: &Self) -> bool {
838 self.comparable_form().eq(&other.comparable_form())
839 }
840}
841
842impl PartialOrd<Self> for Candidate {
843 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
844 Some(self.cmp(other))
845 }
846}
847
848impl Ord for Candidate {
849 fn cmp(&self, other: &Self) -> Ordering {
850 let self_ord = (self.gain, self.symbol.len());
851 let other_ord = (other.gain, other.symbol.len());
852
853 self_ord.cmp(&other_ord)
854 }
855}
856
#[cfg(test)]
mod test {
    use crate::{Compressor, ESCAPE_CODE, builder::CodesBitmap};

    /// Training on repeated text should cover it with symbols. Unseen text should round-trip
    /// through escapes.
    #[test]
    fn test_builder() {
        let text = b"hello hello hello hello hello";

        let table = Compressor::train(&vec![text, text, text, text, text]);

        let compressed = table.compress(text);

        assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));

        let compressed = table.compress("xyz123".as_bytes());
        let decompressed = table.decompressor().decompress(&compressed);
        assert_eq!(&decompressed, b"xyz123");
        assert_eq!(
            compressed,
            vec![
                ESCAPE_CODE,
                b'x',
                ESCAPE_CODE,
                b'y',
                ESCAPE_CODE,
                b'z',
                ESCAPE_CODE,
                b'1',
                ESCAPE_CODE,
                b'2',
                ESCAPE_CODE,
                b'3',
            ]
        );
    }

    #[test]
    fn test_bitmap() {
        let mut map = CodesBitmap::default();
        map.set(10);
        map.set(100);
        map.set(500);

        let codes: Vec<u16> = map.codes().collect();
        assert_eq!(codes, vec![10u16, 100, 500]);

        // An empty bitmap yields nothing.
        let map = CodesBitmap::default();
        assert!(map.codes().collect::<Vec<_>>().is_empty());

        // The first bit of every 64-bit word.
        let mut map = CodesBitmap::default();
        (0..8).for_each(|i| map.set(64 * i));
        assert_eq!(
            map.codes().collect::<Vec<_>>(),
            (0u16..8).map(|i| 64 * i).collect::<Vec<_>>(),
        );

        // All bits set: iteration stops before code 511.
        let mut map = CodesBitmap::default();
        for i in 0..512 {
            map.set(i);
        }
        assert_eq!(
            map.codes().collect::<Vec<_>>(),
            (0u16..511u16).collect::<Vec<_>>()
        );
    }

    // The range check in `CodesBitmap::set` is a `debug_assert!`, so the expected message only
    // exists when debug assertions are enabled. In release test builds `set(512)` panics with
    // an index-out-of-bounds message instead, and this test would fail.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "code cannot exceed")]
    fn test_bitmap_invalid() {
        let mut map = CodesBitmap::default();
        map.set(512);
    }
}