1use crate::{
8 Code, Compressor, FSST_CODE_BASE, FSST_CODE_MASK, Symbol, advance_8byte_word, compare_masked,
9 lossy_pht::LossyPHT,
10};
11use rustc_hash::{FxBuildHasher, FxHashMap};
12use std::cmp::Ordering;
13use std::collections::BinaryHeap;
14
15#[derive(Clone, Copy, Debug, Default)]
17struct CodesBitmap {
18 codes: [u64; 8],
19}
20
21assert_sizeof!(CodesBitmap => 64);
22
23impl CodesBitmap {
24 pub(crate) fn set(&mut self, index: usize) {
26 debug_assert!(
27 index <= FSST_CODE_MASK as usize,
28 "code cannot exceed {FSST_CODE_MASK}"
29 );
30
31 let map = index >> 6;
32 self.codes[map] |= 1 << (index % 64);
33 }
34
35 pub(crate) fn is_set(&self, index: usize) -> bool {
37 debug_assert!(
38 index <= FSST_CODE_MASK as usize,
39 "code cannot exceed {FSST_CODE_MASK}"
40 );
41
42 let map = index >> 6;
43 self.codes[map] & (1 << (index % 64)) != 0
44 }
45
46 pub(crate) fn codes(&self) -> CodesIterator<'_> {
48 CodesIterator {
49 inner: self,
50 index: 0,
51 block: self.codes[0],
52 reference: 0,
53 }
54 }
55
56 pub(crate) fn clear(&mut self) {
58 self.codes[0] = 0;
59 self.codes[1] = 0;
60 self.codes[2] = 0;
61 self.codes[3] = 0;
62 self.codes[4] = 0;
63 self.codes[5] = 0;
64 self.codes[6] = 0;
65 self.codes[7] = 0;
66 }
67}
68
69struct CodesIterator<'a> {
70 inner: &'a CodesBitmap,
71 index: usize,
72 block: u64,
73 reference: usize,
74}
75
76impl Iterator for CodesIterator<'_> {
77 type Item = u16;
78
79 fn next(&mut self) -> Option<Self::Item> {
80 while self.block == 0 {
82 self.index += 1;
83 if self.index >= 8 {
84 return None;
85 }
86 self.block = self.inner.codes[self.index];
87 self.reference = self.index * 64;
88 }
89
90 let position = self.block.trailing_zeros() as usize;
92 let code = self.reference + position;
93
94 if code >= 511 {
95 return None;
96 }
97
98 self.reference = code + 1;
100 self.block = if position == 63 {
101 0
102 } else {
103 self.block >> (1 + position)
104 };
105
106 Some(code as u16)
107 }
108}
109
110#[derive(Debug, Clone)]
111struct Counter {
112 counts1: Vec<usize>,
114
115 counts2: Vec<usize>,
117
118 code1_index: CodesBitmap,
120
121 pair_index: Vec<CodesBitmap>,
126}
127
128const COUNTS1_SIZE: usize = (FSST_CODE_MASK + 1) as usize;
129
130const COUNTS2_SIZE: usize = COUNTS1_SIZE * COUNTS1_SIZE;
136
137impl Counter {
138 fn new() -> Self {
139 let mut counts1 = Vec::with_capacity(COUNTS1_SIZE);
140 let mut counts2 = Vec::with_capacity(COUNTS2_SIZE);
141 unsafe {
144 counts1.set_len(COUNTS1_SIZE);
145 counts2.set_len(COUNTS2_SIZE);
146 }
147
148 Self {
149 counts1,
150 counts2,
151 code1_index: CodesBitmap::default(),
152 pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE],
153 }
154 }
155
156 #[inline]
157 fn record_count1(&mut self, code1: u16) {
158 let base = if self.code1_index.is_set(code1 as usize) {
160 self.counts1[code1 as usize]
161 } else {
162 0
163 };
164
165 self.counts1[code1 as usize] = base + 1;
166 self.code1_index.set(code1 as usize);
167 }
168
169 #[inline]
170 fn record_count2(&mut self, code1: u16, code2: u16) {
171 debug_assert!(code1 == FSST_CODE_MASK || self.code1_index.is_set(code1 as usize));
172 debug_assert!(self.code1_index.is_set(code2 as usize));
173
174 let idx = (code1 as usize) * COUNTS1_SIZE + (code2 as usize);
175 if self.pair_index[code1 as usize].is_set(code2 as usize) {
176 self.counts2[idx] += 1;
177 } else {
178 self.counts2[idx] = 1;
179 }
180 self.pair_index[code1 as usize].set(code2 as usize);
181 }
182
183 #[inline]
184 fn count1(&self, code1: u16) -> usize {
185 debug_assert!(self.code1_index.is_set(code1 as usize));
186
187 self.counts1[code1 as usize]
188 }
189
190 #[inline]
191 fn count2(&self, code1: u16, code2: u16) -> usize {
192 debug_assert!(self.code1_index.is_set(code1 as usize));
193 debug_assert!(self.code1_index.is_set(code2 as usize));
194 debug_assert!(self.pair_index[code1 as usize].is_set(code2 as usize));
195
196 let idx = (code1 as usize) * 512 + (code2 as usize);
197 self.counts2[idx]
198 }
199
200 fn first_codes(&self) -> CodesIterator<'_> {
203 self.code1_index.codes()
204 }
205
206 fn second_codes(&self, code1: u16) -> CodesIterator<'_> {
212 self.pair_index[code1 as usize].codes()
213 }
214
215 fn clear(&mut self) {
218 self.code1_index.clear();
219 for index in &mut self.pair_index {
220 index.clear();
221 }
222 }
223}
224
225pub struct CompressorBuilder {
227 symbols: Vec<Symbol>,
231
232 n_symbols: u8,
235
236 len_histogram: [u8; 8],
240
241 codes_one_byte: Vec<Code>,
245
246 codes_two_byte: Vec<Code>,
248
249 lossy_pht: LossyPHT,
251}
252
253impl CompressorBuilder {
254 pub fn new() -> Self {
256 let symbols = vec![0u64; 511];
260
261 let symbols: Vec<Symbol> = unsafe { std::mem::transmute(symbols) };
263
264 let mut table = Self {
265 symbols,
266 n_symbols: 0,
267 len_histogram: [0; 8],
268 codes_two_byte: Vec::with_capacity(65_536),
269 codes_one_byte: Vec::with_capacity(512),
270 lossy_pht: LossyPHT::new(),
271 };
272
273 for byte in 0..=255 {
275 let symbol = Symbol::from_u8(byte);
276 table.symbols[byte as usize] = symbol;
277 }
278
279 for byte in 0..=255 {
281 table.codes_one_byte.push(Code::new_escape(byte));
283 }
284
285 for idx in 0..=65_535 {
287 table.codes_two_byte.push(Code::new_escape(idx as u8));
288 }
289
290 table
291 }
292}
293
294impl Default for CompressorBuilder {
295 fn default() -> Self {
296 Self::new()
297 }
298}
299
300impl CompressorBuilder {
301 pub fn insert(&mut self, symbol: Symbol, len: usize) -> bool {
312 assert!(self.n_symbols < 255, "cannot insert into full symbol table");
313 assert_eq!(len, symbol.len(), "provided len must equal symbol.len()");
314
315 if len == 2 {
316 self.codes_two_byte[symbol.first2() as usize] =
318 Code::new_symbol_building(self.n_symbols, 2);
319 } else if len == 1 {
320 self.codes_one_byte[symbol.first_byte() as usize] =
322 Code::new_symbol_building(self.n_symbols, 1);
323 } else {
324 if !self.lossy_pht.insert(symbol, len, self.n_symbols) {
326 return false;
327 }
328 }
329
330 self.len_histogram[len - 1] += 1;
332
333 self.symbols[256 + (self.n_symbols as usize)] = symbol;
336 self.n_symbols += 1;
337 true
338 }
339
340 fn clear(&mut self) {
345 for code in 0..(256 + self.n_symbols as usize) {
347 let symbol = self.symbols[code];
348 if symbol.len() == 1 {
349 self.codes_one_byte[symbol.first_byte() as usize] =
351 Code::new_escape(symbol.first_byte());
352 } else if symbol.len() == 2 {
353 self.codes_two_byte[symbol.first2() as usize] =
355 Code::new_escape(symbol.first_byte());
356 } else {
357 self.lossy_pht.remove(symbol);
359 }
360 }
361
362 for i in 0..=7 {
364 self.len_histogram[i] = 0;
365 }
366
367 self.n_symbols = 0;
368 }
369
370 fn finalize(&mut self) -> (u8, Vec<u8>) {
389 let byte_lim = self.n_symbols - self.len_histogram[0];
394
395 let mut codes_by_length = [0u8; 8];
399 codes_by_length[0] = byte_lim;
400 codes_by_length[1] = 0;
401
402 for i in 1..7 {
404 codes_by_length[i + 1] = codes_by_length[i] + self.len_histogram[i];
405 }
406
407 let mut no_suffix_code = 0;
411
412 let mut has_suffix_code = codes_by_length[2];
414
415 let mut new_codes = [0u8; FSST_CODE_BASE as usize];
418
419 let mut symbol_lens = [0u8; FSST_CODE_BASE as usize];
420
421 for i in 0..(self.n_symbols as usize) {
422 let symbol = self.symbols[256 + i];
423 let len = symbol.len();
424 if len == 2 {
425 let has_suffix = self
426 .symbols
427 .iter()
428 .skip(FSST_CODE_BASE as usize)
429 .enumerate()
430 .any(|(k, other)| i != k && symbol.first2() == other.first2());
431
432 if has_suffix {
433 has_suffix_code -= 1;
435 new_codes[i] = has_suffix_code;
436 } else {
437 new_codes[i] = no_suffix_code;
440 no_suffix_code += 1;
441 }
442 } else {
443 new_codes[i] = codes_by_length[len - 1];
445 codes_by_length[len - 1] += 1;
446 }
447
448 self.symbols[new_codes[i] as usize] = symbol;
451 symbol_lens[new_codes[i] as usize] = len as u8;
452 }
453
454 self.symbols.truncate(self.n_symbols as usize);
456
457 for byte in 0..=255 {
460 let one_byte = self.codes_one_byte[byte];
461 if one_byte.extended_code() >= FSST_CODE_BASE {
462 let new_code = new_codes[one_byte.code() as usize];
463 self.codes_one_byte[byte] = Code::new_symbol(new_code, 1);
464 } else {
465 self.codes_one_byte[byte] = Code::UNUSED;
467 }
468 }
469
470 for two_bytes in 0..=65_535 {
473 let two_byte = self.codes_two_byte[two_bytes];
474 if two_byte.extended_code() >= FSST_CODE_BASE {
475 let new_code = new_codes[two_byte.code() as usize];
476 self.codes_two_byte[two_bytes] = Code::new_symbol(new_code, 2);
477 } else {
478 self.codes_two_byte[two_bytes] = self.codes_one_byte[two_bytes & 0xFF];
480 }
481 }
482
483 self.lossy_pht.renumber(&new_codes);
485
486 let mut lengths = Vec::with_capacity(self.n_symbols as usize);
488 for symbol in &self.symbols {
489 lengths.push(symbol.len() as u8);
490 }
491
492 (has_suffix_code, lengths)
493 }
494
495 pub fn build(mut self) -> Compressor {
497 let (has_suffix_code, lengths) = self.finalize();
501
502 Compressor {
503 symbols: self.symbols,
504 lengths,
505 n_symbols: self.n_symbols,
506 has_suffix_code,
507 codes_two_byte: self.codes_two_byte,
508 lossy_pht: self.lossy_pht,
509 }
510 }
511}
512
/// Sampling fractions (out of 128) used by the successive training generations.
#[cfg(not(miri))]
const GENERATIONS: [usize; 5] = [8usize, 38, 68, 98, 128];
/// Fewer generations under Miri, presumably to keep interpreted runs short.
#[cfg(miri)]
const GENERATIONS: [usize; 3] = [8usize, 38, 128];

/// Inputs totalling fewer bytes than this are used for training verbatim; otherwise
/// sampling stops once at least this many bytes have been collected.
const FSST_SAMPLETARGET: usize = 1 << 14;
/// Minimum capacity required of the buffer passed to `make_sample`.
const FSST_SAMPLEMAX: usize = 1 << 15;
/// Maximum length of a single chunk copied into the sample.
const FSST_SAMPLELINE: usize = 512;

/// Builds the training sample from `str_in`, whose lengths sum to `tot_size`.
///
/// Inputs totalling fewer than [`FSST_SAMPLETARGET`] bytes are returned as-is. Otherwise
/// `sample_buf` is cleared and pseudo-random chunks of up to [`FSST_SAMPLELINE`] bytes
/// (starting at a multiple of `FSST_SAMPLELINE` within their line) are copied into it until
/// at least `FSST_SAMPLETARGET` bytes are collected; the returned slices point into it.
///
/// # Panics
///
/// Panics if `sample_buf.capacity() < FSST_SAMPLEMAX`.
fn make_sample<'a, 'b: 'a>(
    sample_buf: &'a mut Vec<u8>,
    str_in: &[&'b [u8]],
    tot_size: usize,
) -> Vec<&'a [u8]> {
    assert!(
        sample_buf.capacity() >= FSST_SAMPLEMAX,
        "sample_buf.capacity() < FSST_SAMPLEMAX"
    );

    if tot_size < FSST_SAMPLETARGET {
        return str_in.to_vec();
    }

    // The spans below are offsets into `sample_buf`, so it must start empty.
    sample_buf.clear();

    // `(start, len)` of every chunk copied into `sample_buf`. Slices are only created once
    // the buffer is complete, so no borrow of it has to survive later appends.
    let mut spans: Vec<(usize, usize)> = Vec::new();
    let mut sample_rnd = fsst_hash(4637947);

    while sample_buf.len() < FSST_SAMPLETARGET {
        sample_rnd = fsst_hash(sample_rnd);
        let line_nr = (sample_rnd as usize) % str_in.len();

        // Starting at `line_nr` and wrapping around, take the first non-empty line.
        let Some(line) = (line_nr..str_in.len())
            .chain(0..line_nr)
            .map(|nr| str_in[nr])
            .find(|line| !line.is_empty())
        else {
            break;
        };

        let n_chunks = 1 + ((line.len() - 1) / FSST_SAMPLELINE);
        sample_rnd = fsst_hash(sample_rnd);
        let chunk_start = FSST_SAMPLELINE * ((sample_rnd as usize) % n_chunks);
        let len = FSST_SAMPLELINE.min(line.len() - chunk_start);

        spans.push((sample_buf.len(), len));
        sample_buf.extend_from_slice(&line[chunk_start..chunk_start + len]);
    }

    let buf: &'a [u8] = sample_buf;
    spans
        .into_iter()
        .map(|(start, len)| &buf[start..start + len])
        .collect()
}

/// Cheap multiplicative hash; used in this file to pick pseudo-random sample lines,
/// chunks, and per-generation sample subsets.
#[inline]
pub(crate) fn fsst_hash(value: u64) -> u64 {
    value.wrapping_mul(2971215073) ^ value.wrapping_shr(15)
}
593
594impl Compressor {
595 pub fn train(values: &Vec<&[u8]>) -> Self {
605 let mut builder = CompressorBuilder::new();
606
607 if values.is_empty() {
608 return builder.build();
609 }
610
611 let mut counters = Counter::new();
612 let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX);
613 let mut pqueue = BinaryHeap::with_capacity(65_536);
614
615 let tot_size: usize = values.iter().map(|s| s.len()).sum();
616 let sampled = tot_size >= FSST_SAMPLETARGET;
617 let sample = make_sample(&mut sample_memory, values, tot_size);
618 for sample_frac in GENERATIONS {
619 for (i, line) in sample.iter().enumerate() {
620 if sample_frac < 128 && ((fsst_hash(i as u64) & 127) as usize) > sample_frac {
621 continue;
622 }
623
624 builder.compress_count(line, &mut counters);
625 }
626
627 pqueue.clear();
629 let prune = sample_frac >= 128 && !sampled;
630 builder.optimize(&counters, sample_frac, &mut pqueue, prune);
631 counters.clear();
632 }
633
634 builder.build()
635 }
636}
637
638impl CompressorBuilder {
639 fn find_longest_symbol(&self, word: u64) -> Code {
641 let entry = self.lossy_pht.lookup(word);
643 let ignored_bits = entry.ignored_bits;
644
645 if !entry.is_unused() && compare_masked(word, entry.symbol.to_u64(), ignored_bits) {
647 return entry.code;
648 }
649
650 let twobyte = self.codes_two_byte[word as u16 as usize];
652 if twobyte.extended_code() >= FSST_CODE_BASE {
653 return twobyte;
654 }
655
656 self.codes_one_byte[word as u8 as usize]
658 }
659
660 fn compress_count(&self, sample: &[u8], counter: &mut Counter) -> usize {
666 let mut gain = 0;
667 if sample.is_empty() {
668 return gain;
669 }
670
671 let mut in_ptr = sample.as_ptr();
672
673 let in_end = unsafe { in_ptr.byte_add(sample.len()) };
675 let in_end_sub8 = in_end as usize - 8;
676
677 let mut prev_code: u16 = FSST_CODE_MASK;
678
679 while (in_ptr as usize) < (in_end_sub8) {
680 let word: u64 = unsafe { std::ptr::read_unaligned(in_ptr as *const u64) };
682 let code = self.find_longest_symbol(word);
683 let code_u16 = code.extended_code();
684
685 gain += (code.len() as usize) - ((code_u16 < 256) as usize);
688
689 counter.record_count1(code_u16);
691 counter.record_count2(prev_code, code_u16);
692
693 if code.len() > 1 {
696 let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
697 counter.record_count1(code_first_byte);
698 counter.record_count2(prev_code, code_first_byte);
699 }
700
701 in_ptr = unsafe { in_ptr.byte_add(code.len() as usize) };
703
704 prev_code = code_u16;
705 }
706
707 let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
708 assert!(
709 remaining_bytes.is_positive(),
710 "in_ptr exceeded in_end, should not be possible"
711 );
712 let remaining_bytes = remaining_bytes as usize;
713
714 let mut bytes = [0u8; 8];
718 unsafe {
719 std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
722 }
723 let mut last_word = u64::from_le_bytes(bytes);
724
725 let mut remaining_bytes = remaining_bytes;
726
727 while remaining_bytes > 0 {
728 let code = self.find_longest_symbol(last_word);
730 let code_u16 = code.extended_code();
731
732 gain += (code.len() as usize) - ((code_u16 < 256) as usize);
735
736 counter.record_count1(code_u16);
738 counter.record_count2(prev_code, code_u16);
739
740 if code.len() > 1 {
743 let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
744 counter.record_count1(code_first_byte);
745 counter.record_count2(prev_code, code_first_byte);
746 }
747
748 let advance = code.len() as usize;
750 remaining_bytes -= advance;
751 last_word = advance_8byte_word(last_word, advance);
752
753 prev_code = code_u16;
754 }
755
756 gain
757 }
758
759 fn optimize(
762 &mut self,
763 counters: &Counter,
764 sample_frac: usize,
765 pqueue: &mut BinaryHeap<Candidate>,
766 prune: bool,
767 ) {
768 let mut candidates = FxHashMap::with_capacity_and_hasher(256, FxBuildHasher);
773
774 for code1 in counters.first_codes() {
775 let symbol1 = self.symbols[code1 as usize];
776 let symbol1_len = symbol1.len();
777 let count = counters.count1(code1);
778
779 let min_count = if prune { 1 } else { 5 * sample_frac / 128 };
784 if count < min_count {
785 continue;
786 }
787
788 let mut gain = count * symbol1_len;
789 if symbol1_len == 1 {
792 gain *= 8;
793 }
794
795 *candidates.entry(symbol1).or_insert(0) += gain;
797
798 if sample_frac >= 128 || symbol1_len == 8 {
800 continue;
801 }
802
803 for code2 in counters.second_codes(code1) {
804 let symbol2 = self.symbols[code2 as usize];
805
806 if symbol1_len + symbol2.len() > 8 {
808 continue;
809 }
810 let new_symbol = symbol1.concat(symbol2);
811 let gain = counters.count2(code1, code2) * new_symbol.len();
812
813 *candidates.entry(new_symbol).or_insert(0) += gain;
815 }
816 }
817
818 for (symbol, gain) in candidates {
820 pqueue.push(Candidate { symbol, gain });
821 }
822
823 self.clear();
825
826 let mut n_symbols = 0;
828 while !pqueue.is_empty() && n_symbols < 255 {
829 let candidate = pqueue.pop().unwrap();
830 if prune {
831 let symbol_len = candidate.symbol.len();
832 let saves = if symbol_len == 1 {
833 candidate.gain / 8 } else {
835 candidate.gain
836 };
837 if saves <= symbol_len + 1 {
838 continue;
839 }
840 }
841 if self.insert(candidate.symbol, candidate.symbol.len()) {
842 n_symbols += 1;
843 }
844 }
845 }
846}
847
848#[derive(Copy, Clone, Debug)]
852struct Candidate {
853 gain: usize,
854 symbol: Symbol,
855}
856
857impl Candidate {
858 fn comparable_form(&self) -> (usize, usize) {
859 (self.gain, self.symbol.len())
860 }
861}
862
863impl Eq for Candidate {}
864
865impl PartialEq<Self> for Candidate {
866 fn eq(&self, other: &Self) -> bool {
867 self.comparable_form().eq(&other.comparable_form())
868 }
869}
870
871impl PartialOrd<Self> for Candidate {
872 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
873 Some(self.cmp(other))
874 }
875}
876
877impl Ord for Candidate {
878 fn cmp(&self, other: &Self) -> Ordering {
879 let self_ord = (self.gain, self.symbol.len());
880 let other_ord = (other.gain, other.symbol.len());
881
882 self_ord.cmp(&other_ord)
883 }
884}
885
#[cfg(test)]
mod test {
    use crate::{Compressor, ESCAPE_CODE, builder::CodesBitmap};

    /// Trained text compresses without escapes; unseen text is escaped byte by byte.
    #[test]
    fn test_builder() {
        let text = b"hello hello hello hello hello";
        let table = Compressor::train(&vec![text, text, text, text, text]);

        let trained = table.compress(text);
        assert!(!trained.contains(&ESCAPE_CODE));

        let compressed = table.compress("xyz123".as_bytes());
        let decompressed = table.decompressor().decompress(&compressed);
        assert_eq!(&decompressed, b"xyz123");

        let expected: Vec<u8> = b"xyz123"
            .iter()
            .flat_map(|&byte| [ESCAPE_CODE, byte])
            .collect();
        assert_eq!(compressed, expected);
    }

    #[test]
    fn test_bitmap() {
        // Sparse codes come back in ascending order.
        let mut map = CodesBitmap::default();
        for code in [10, 100, 500] {
            map.set(code);
        }
        assert_eq!(map.codes().collect::<Vec<u16>>(), vec![10u16, 100, 500]);

        // An empty bitmap yields nothing.
        assert_eq!(CodesBitmap::default().codes().next(), None);

        // The lowest bit of every 64-bit word.
        let mut map = CodesBitmap::default();
        for word in 0..8 {
            map.set(64 * word);
        }
        let expected: Vec<u16> = (0..8).map(|word| 64 * word).collect();
        assert_eq!(map.codes().collect::<Vec<_>>(), expected);

        // Every code set: iteration stops before 511.
        let mut map = CodesBitmap::default();
        (0..512).for_each(|code| map.set(code));
        let expected: Vec<u16> = (0..511).collect();
        assert_eq!(map.codes().collect::<Vec<_>>(), expected);
    }

    #[test]
    #[should_panic(expected = "code cannot exceed")]
    fn test_bitmap_invalid() {
        let mut map = CodesBitmap::default();
        map.set(512);
    }

    #[test]
    fn test_no_duplicate_symbols() {
        let text = b"aababcabcdabcde";
        let corpus: Vec<&[u8]> = std::iter::repeat_n(text.as_slice(), 100).collect();
        let compressor = Compressor::train(&corpus);

        let symbols = compressor.symbol_table();
        let lengths = compressor.symbol_lengths();

        let one_byte: Vec<u8> = symbols
            .iter()
            .zip(lengths.iter())
            .filter_map(|(sym, &len)| (len == 1).then(|| sym.first_byte()))
            .collect();
        let mut unique_one_byte = one_byte.clone();
        unique_one_byte.sort_unstable();
        unique_one_byte.dedup();
        assert_eq!(
            one_byte.len(),
            unique_one_byte.len(),
            "duplicate 1-byte symbols found"
        );

        let two_byte: Vec<u16> = symbols
            .iter()
            .zip(lengths.iter())
            .filter_map(|(sym, &len)| (len == 2).then(|| sym.first2()))
            .collect();
        let mut unique_two_byte = two_byte.clone();
        unique_two_byte.sort_unstable();
        unique_two_byte.dedup();
        assert_eq!(
            two_byte.len(),
            unique_two_byte.len(),
            "duplicate 2-byte symbols found"
        );
    }
}