1use crate::{
8 Code, Compressor, FSST_CODE_BASE, FSST_CODE_MASK, Symbol, advance_8byte_word, compare_masked,
9 lossy_pht::LossyPHT,
10};
11use std::cmp::Ordering;
12use std::collections::BinaryHeap;
13
14#[derive(Clone, Copy, Debug, Default)]
16struct CodesBitmap {
17 codes: [u64; 8],
18}
19
20assert_sizeof!(CodesBitmap => 64);
21
22impl CodesBitmap {
23 pub(crate) fn set(&mut self, index: usize) {
25 debug_assert!(
26 index <= FSST_CODE_MASK as usize,
27 "code cannot exceed {FSST_CODE_MASK}"
28 );
29
30 let map = index >> 6;
31 self.codes[map] |= 1 << (index % 64);
32 }
33
34 pub(crate) fn is_set(&self, index: usize) -> bool {
36 debug_assert!(
37 index <= FSST_CODE_MASK as usize,
38 "code cannot exceed {FSST_CODE_MASK}"
39 );
40
41 let map = index >> 6;
42 self.codes[map] & 1 << (index % 64) != 0
43 }
44
45 pub(crate) fn codes(&self) -> CodesIterator {
47 CodesIterator {
48 inner: self,
49 index: 0,
50 block: self.codes[0],
51 reference: 0,
52 }
53 }
54
55 pub(crate) fn clear(&mut self) {
57 self.codes[0] = 0;
58 self.codes[1] = 0;
59 self.codes[2] = 0;
60 self.codes[3] = 0;
61 self.codes[4] = 0;
62 self.codes[5] = 0;
63 self.codes[6] = 0;
64 self.codes[7] = 0;
65 }
66}
67
68struct CodesIterator<'a> {
69 inner: &'a CodesBitmap,
70 index: usize,
71 block: u64,
72 reference: usize,
73}
74
75impl Iterator for CodesIterator<'_> {
76 type Item = u16;
77
78 fn next(&mut self) -> Option<Self::Item> {
79 while self.block == 0 {
81 self.index += 1;
82 if self.index >= 8 {
83 return None;
84 }
85 self.block = self.inner.codes[self.index];
86 self.reference = self.index * 64;
87 }
88
89 let position = self.block.trailing_zeros() as usize;
91 let code = self.reference + position;
92
93 if code >= 511 {
94 return None;
95 }
96
97 self.reference = code + 1;
99 self.block = if position == 63 {
100 0
101 } else {
102 self.block >> (1 + position)
103 };
104
105 Some(code as u16)
106 }
107}
108
109#[derive(Debug, Clone)]
110struct Counter {
111 counts1: Vec<usize>,
113
114 counts2: Vec<usize>,
116
117 code1_index: CodesBitmap,
119
120 pair_index: Vec<CodesBitmap>,
125}
126
127const COUNTS1_SIZE: usize = (FSST_CODE_MASK + 1) as usize;
128
129const COUNTS2_SIZE: usize = COUNTS1_SIZE * COUNTS1_SIZE;
135
136impl Counter {
137 fn new() -> Self {
138 let mut counts1 = Vec::with_capacity(COUNTS1_SIZE);
139 let mut counts2 = Vec::with_capacity(COUNTS2_SIZE);
140 unsafe {
143 counts1.set_len(COUNTS1_SIZE);
144 counts2.set_len(COUNTS2_SIZE);
145 }
146
147 Self {
148 counts1,
149 counts2,
150 code1_index: CodesBitmap::default(),
151 pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE],
152 }
153 }
154
155 #[inline]
156 fn record_count1(&mut self, code1: u16) {
157 let base = if self.code1_index.is_set(code1 as usize) {
159 self.counts1[code1 as usize]
160 } else {
161 0
162 };
163
164 self.counts1[code1 as usize] = base + 1;
165 self.code1_index.set(code1 as usize);
166 }
167
168 #[inline]
169 fn record_count2(&mut self, code1: u16, code2: u16) {
170 debug_assert!(code1 == FSST_CODE_MASK || self.code1_index.is_set(code1 as usize));
171 debug_assert!(self.code1_index.is_set(code2 as usize));
172
173 let idx = (code1 as usize) * COUNTS1_SIZE + (code2 as usize);
174 if self.pair_index[code1 as usize].is_set(code2 as usize) {
175 self.counts2[idx] += 1;
176 } else {
177 self.counts2[idx] = 1;
178 }
179 self.pair_index[code1 as usize].set(code2 as usize);
180 }
181
182 #[inline]
183 fn count1(&self, code1: u16) -> usize {
184 debug_assert!(self.code1_index.is_set(code1 as usize));
185
186 self.counts1[code1 as usize]
187 }
188
189 #[inline]
190 fn count2(&self, code1: u16, code2: u16) -> usize {
191 debug_assert!(self.code1_index.is_set(code1 as usize));
192 debug_assert!(self.code1_index.is_set(code2 as usize));
193 debug_assert!(self.pair_index[code1 as usize].is_set(code2 as usize));
194
195 let idx = (code1 as usize) * 512 + (code2 as usize);
196 self.counts2[idx]
197 }
198
199 fn first_codes(&self) -> CodesIterator {
202 self.code1_index.codes()
203 }
204
205 fn second_codes(&self, code1: u16) -> CodesIterator {
211 self.pair_index[code1 as usize].codes()
212 }
213
214 fn clear(&mut self) {
217 self.code1_index.clear();
218 for index in &mut self.pair_index {
219 index.clear();
220 }
221 }
222}
223
224pub struct CompressorBuilder {
226 symbols: Vec<Symbol>,
230
231 n_symbols: u8,
234
235 len_histogram: [u8; 8],
239
240 codes_one_byte: Vec<Code>,
244
245 codes_two_byte: Vec<Code>,
247
248 lossy_pht: LossyPHT,
250}
251
252impl CompressorBuilder {
253 pub fn new() -> Self {
255 let symbols = vec![0u64; 511];
259
260 let symbols: Vec<Symbol> = unsafe { std::mem::transmute(symbols) };
262
263 let mut table = Self {
264 symbols,
265 n_symbols: 0,
266 len_histogram: [0; 8],
267 codes_two_byte: Vec::with_capacity(65_536),
268 codes_one_byte: Vec::with_capacity(512),
269 lossy_pht: LossyPHT::new(),
270 };
271
272 for byte in 0..=255 {
274 let symbol = Symbol::from_u8(byte);
275 table.symbols[byte as usize] = symbol;
276 }
277
278 for byte in 0..=255 {
280 table.codes_one_byte.push(Code::new_escape(byte));
282 }
283
284 for idx in 0..=65_535 {
286 table.codes_two_byte.push(Code::new_escape(idx as u8));
287 }
288
289 table
290 }
291}
292
293impl Default for CompressorBuilder {
294 fn default() -> Self {
295 Self::new()
296 }
297}
298
299impl CompressorBuilder {
300 pub fn insert(&mut self, symbol: Symbol, len: usize) -> bool {
311 assert!(self.n_symbols < 255, "cannot insert into full symbol table");
312 assert_eq!(len, symbol.len(), "provided len must equal symbol.len()");
313
314 if len == 2 {
315 self.codes_two_byte[symbol.first2() as usize] =
317 Code::new_symbol_building(self.n_symbols, 2);
318 } else if len == 1 {
319 self.codes_one_byte[symbol.first_byte() as usize] =
321 Code::new_symbol_building(self.n_symbols, 1);
322 } else {
323 if !self.lossy_pht.insert(symbol, len, self.n_symbols) {
325 return false;
326 }
327 }
328
329 self.len_histogram[len - 1] += 1;
331
332 self.symbols[256 + (self.n_symbols as usize)] = symbol;
335 self.n_symbols += 1;
336 true
337 }
338
339 fn clear(&mut self) {
344 for code in 0..(256 + self.n_symbols as usize) {
346 let symbol = self.symbols[code];
347 if symbol.len() == 1 {
348 self.codes_one_byte[symbol.first_byte() as usize] =
350 Code::new_escape(symbol.first_byte());
351 } else if symbol.len() == 2 {
352 self.codes_two_byte[symbol.first2() as usize] =
354 Code::new_escape(symbol.first_byte());
355 } else {
356 self.lossy_pht.remove(symbol);
358 }
359 }
360
361 for i in 0..=7 {
363 self.len_histogram[i] = 0;
364 }
365
366 self.n_symbols = 0;
367 }
368
369 fn finalize(&mut self) -> (u8, Vec<u8>) {
388 let byte_lim = self.n_symbols - self.len_histogram[0];
393
394 let mut codes_by_length = [0u8; 8];
398 codes_by_length[0] = byte_lim;
399 codes_by_length[1] = 0;
400
401 for i in 1..7 {
403 codes_by_length[i + 1] = codes_by_length[i] + self.len_histogram[i];
404 }
405
406 let mut no_suffix_code = 0;
410
411 let mut has_suffix_code = codes_by_length[2];
413
414 let mut new_codes = [0u8; FSST_CODE_BASE as usize];
417
418 let mut symbol_lens = [0u8; FSST_CODE_BASE as usize];
419
420 for i in 0..(self.n_symbols as usize) {
421 let symbol = self.symbols[256 + i];
422 let len = symbol.len();
423 if len == 2 {
424 let has_suffix = self
425 .symbols
426 .iter()
427 .skip(FSST_CODE_BASE as usize)
428 .enumerate()
429 .any(|(k, other)| i != k && symbol.first2() == other.first2());
430
431 if has_suffix {
432 has_suffix_code -= 1;
434 new_codes[i] = has_suffix_code;
435 } else {
436 new_codes[i] = no_suffix_code;
439 no_suffix_code += 1;
440 }
441 } else {
442 new_codes[i] = codes_by_length[len - 1];
444 codes_by_length[len - 1] += 1;
445 }
446
447 self.symbols[new_codes[i] as usize] = symbol;
450 symbol_lens[new_codes[i] as usize] = len as u8;
451 }
452
453 self.symbols.truncate(self.n_symbols as usize);
455
456 for byte in 0..=255 {
459 let one_byte = self.codes_one_byte[byte];
460 if one_byte.extended_code() >= FSST_CODE_BASE {
461 let new_code = new_codes[one_byte.code() as usize];
462 self.codes_one_byte[byte] = Code::new_symbol(new_code, 1);
463 } else {
464 self.codes_one_byte[byte] = Code::UNUSED;
466 }
467 }
468
469 for two_bytes in 0..=65_535 {
472 let two_byte = self.codes_two_byte[two_bytes];
473 if two_byte.extended_code() >= FSST_CODE_BASE {
474 let new_code = new_codes[two_byte.code() as usize];
475 self.codes_two_byte[two_bytes] = Code::new_symbol(new_code, 2);
476 } else {
477 self.codes_two_byte[two_bytes] = self.codes_one_byte[two_bytes & 0xFF];
479 }
480 }
481
482 self.lossy_pht.renumber(&new_codes);
484
485 let mut lengths = Vec::with_capacity(self.n_symbols as usize);
487 for symbol in &self.symbols {
488 lengths.push(symbol.len() as u8);
489 }
490
491 (has_suffix_code, lengths)
492 }
493
494 pub fn build(mut self) -> Compressor {
496 let (has_suffix_code, lengths) = self.finalize();
500
501 Compressor {
502 symbols: self.symbols,
503 lengths,
504 n_symbols: self.n_symbols,
505 has_suffix_code,
506 codes_two_byte: self.codes_two_byte,
507 lossy_pht: self.lossy_pht,
508 }
509 }
510}
511
/// Sample fractions, out of 128, used by the successive training rounds in
/// `Compressor::train`. Fewer rounds are run under Miri.
#[cfg(not(miri))]
const GENERATIONS: [usize; 5] = [8, 38, 68, 98, 128];
#[cfg(miri)]
const GENERATIONS: [usize; 3] = [8, 38, 128];

/// Number of bytes `make_sample` aims to collect; smaller inputs are used whole.
const FSST_SAMPLETARGET: usize = 16 * 1024;
/// Minimum capacity `make_sample` requires of its scratch buffer.
const FSST_SAMPLEMAX: usize = 32 * 1024;
/// Maximum number of bytes `make_sample` copies out of one line at a time.
const FSST_SAMPLELINE: usize = 512;
523
524#[allow(clippy::ptr_arg)]
531fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>) -> Vec<&'a [u8]> {
532 assert!(
533 sample_buf.capacity() >= FSST_SAMPLEMAX,
534 "sample_buf.len() < FSST_SAMPLEMAX"
535 );
536
537 let mut sample: Vec<&[u8]> = Vec::new();
538
539 let tot_size: usize = str_in.iter().map(|s| s.len()).sum();
540 if tot_size < FSST_SAMPLETARGET {
541 return str_in.clone();
542 }
543
544 let mut sample_rnd = fsst_hash(4637947);
545 let sample_lim = FSST_SAMPLETARGET;
546 let mut sample_buf_offset: usize = 0;
547
548 while sample_buf_offset < sample_lim {
549 sample_rnd = fsst_hash(sample_rnd);
550 let line_nr = (sample_rnd as usize) % str_in.len();
551
552 let Some(line) = (line_nr..str_in.len())
555 .chain(0..line_nr)
556 .map(|line_nr| str_in[line_nr])
557 .find(|line| !line.is_empty())
558 else {
559 return sample;
560 };
561
562 let chunks = 1 + ((line.len() - 1) / FSST_SAMPLELINE);
563 sample_rnd = fsst_hash(sample_rnd);
564 let chunk = FSST_SAMPLELINE * ((sample_rnd as usize) % chunks);
565
566 let len = FSST_SAMPLELINE.min(line.len() - chunk);
567
568 sample_buf.extend_from_slice(&line[chunk..chunk + len]);
569
570 let slice =
572 unsafe { std::slice::from_raw_parts(sample_buf.as_ptr().add(sample_buf_offset), len) };
573
574 sample.push(slice);
575
576 sample_buf_offset += len;
577 }
578
579 sample
580}
581
/// Mixes `value` into a pseudo-random 64-bit number: the wrapping product of
/// `value` and 2971215073, XOR-ed with `value` shifted right by 15 bits.
#[inline]
pub(crate) fn fsst_hash(value: u64) -> u64 {
    let scrambled = value.wrapping_mul(2_971_215_073);
    let folded = value >> 15;
    scrambled ^ folded
}
589
590impl Compressor {
591 pub fn train(values: &Vec<&[u8]>) -> Self {
601 let mut builder = CompressorBuilder::new();
602
603 if values.is_empty() {
604 return builder.build();
605 }
606
607 let mut counters = Counter::new();
608 let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX);
609 let sample = make_sample(&mut sample_memory, values);
610 for sample_frac in GENERATIONS {
611 for (i, line) in sample.iter().enumerate() {
612 if sample_frac < 128 && ((fsst_hash(i as u64) & 127) as usize) > sample_frac {
613 continue;
614 }
615
616 builder.compress_count(line, &mut counters);
617 }
618
619 builder.optimize(&counters, sample_frac);
620 counters.clear();
621 }
622
623 builder.build()
624 }
625}
626
627impl CompressorBuilder {
628 fn find_longest_symbol(&self, word: u64) -> Code {
630 let entry = self.lossy_pht.lookup(word);
632 let ignored_bits = entry.ignored_bits;
633
634 if !entry.is_unused() && compare_masked(word, entry.symbol.as_u64(), ignored_bits) {
636 return entry.code;
637 }
638
639 let twobyte = self.codes_two_byte[word as u16 as usize];
641 if twobyte.extended_code() >= FSST_CODE_BASE {
642 return twobyte;
643 }
644
645 self.codes_one_byte[word as u8 as usize]
647 }
648
649 fn compress_count(&self, sample: &[u8], counter: &mut Counter) -> usize {
655 let mut gain = 0;
656 if sample.is_empty() {
657 return gain;
658 }
659
660 let mut in_ptr = sample.as_ptr();
661
662 let in_end = unsafe { in_ptr.byte_add(sample.len()) };
664 let in_end_sub8 = in_end as usize - 8;
665
666 let mut prev_code: u16 = FSST_CODE_MASK;
667
668 while (in_ptr as usize) < (in_end_sub8) {
669 let word: u64 = unsafe { std::ptr::read_unaligned(in_ptr as *const u64) };
671 let code = self.find_longest_symbol(word);
672 let code_u16 = code.extended_code();
673
674 gain += (code.len() as usize) - ((code_u16 < 256) as usize);
677
678 counter.record_count1(code_u16);
680 counter.record_count2(prev_code, code_u16);
681
682 if code.len() > 1 {
685 let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
686 counter.record_count1(code_first_byte);
687 counter.record_count2(prev_code, code_first_byte);
688 }
689
690 in_ptr = unsafe { in_ptr.byte_add(code.len() as usize) };
692
693 prev_code = code_u16;
694 }
695
696 let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
697 assert!(
698 remaining_bytes.is_positive(),
699 "in_ptr exceeded in_end, should not be possible"
700 );
701 let remaining_bytes = remaining_bytes as usize;
702
703 let mut bytes = [0u8; 8];
707 unsafe {
708 std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
711 }
712 let mut last_word = u64::from_le_bytes(bytes);
713
714 let mut remaining_bytes = remaining_bytes;
715
716 while remaining_bytes > 0 {
717 let code = self.find_longest_symbol(last_word);
719 let code_u16 = code.extended_code();
720
721 gain += (code.len() as usize) - ((code_u16 < 256) as usize);
724
725 counter.record_count1(code_u16);
727 counter.record_count2(prev_code, code_u16);
728
729 if code.len() > 1 {
732 let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
733 counter.record_count1(code_first_byte);
734 counter.record_count2(prev_code, code_first_byte);
735 }
736
737 let advance = code.len() as usize;
739 remaining_bytes -= advance;
740 last_word = advance_8byte_word(last_word, advance);
741
742 prev_code = code_u16;
743 }
744
745 gain
746 }
747
748 fn optimize(&mut self, counters: &Counter, sample_frac: usize) {
751 let mut pqueue = BinaryHeap::with_capacity(65_536);
752
753 for code1 in counters.first_codes() {
754 let symbol1 = self.symbols[code1 as usize];
755 let symbol1_len = symbol1.len();
756 let count = counters.count1(code1);
757
758 if count < (5 * sample_frac / 128) {
761 continue;
762 }
763
764 let mut gain = count * symbol1_len;
765 if code1 < 256 {
768 gain *= 8;
769 }
770
771 pqueue.push(Candidate {
772 symbol: symbol1,
773 gain,
774 });
775
776 if sample_frac >= 128 || symbol1_len == 8 {
778 continue;
779 }
780
781 for code2 in counters.second_codes(code1) {
782 let symbol2 = self.symbols[code2 as usize];
783
784 if symbol1_len + symbol2.len() > 8 {
786 continue;
787 }
788 let new_symbol = symbol1.concat(symbol2);
789 let gain = counters.count2(code1, code2) * new_symbol.len();
790
791 pqueue.push(Candidate {
792 symbol: new_symbol,
793 gain,
794 })
795 }
796 }
797
798 self.clear();
800
801 let mut n_symbols = 0;
803 while !pqueue.is_empty() && n_symbols < 255 {
804 let candidate = pqueue.pop().unwrap();
805 if self.insert(candidate.symbol, candidate.symbol.len()) {
806 n_symbols += 1;
807 }
808 }
809 }
810}
811
812#[derive(Copy, Clone, Debug)]
816struct Candidate {
817 gain: usize,
818 symbol: Symbol,
819}
820
821impl Candidate {
822 fn comparable_form(&self) -> (usize, usize) {
823 (self.gain, self.symbol.len())
824 }
825}
826
827impl Eq for Candidate {}
828
829impl PartialEq<Self> for Candidate {
830 fn eq(&self, other: &Self) -> bool {
831 self.comparable_form().eq(&other.comparable_form())
832 }
833}
834
835impl PartialOrd<Self> for Candidate {
836 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
837 Some(self.cmp(other))
838 }
839}
840
841impl Ord for Candidate {
842 fn cmp(&self, other: &Self) -> Ordering {
843 let self_ord = (self.gain, self.symbol.len());
844 let other_ord = (other.gain, other.symbol.len());
845
846 self_ord.cmp(&other_ord)
847 }
848}
849
#[cfg(test)]
mod test {
    use crate::{Compressor, ESCAPE_CODE, builder::CodesBitmap};

    /// Text seen during training compresses without escapes; unseen text is
    /// escaped byte by byte and still round-trips.
    #[test]
    fn test_builder() {
        let text: &[u8] = b"hello hello hello hello hello";
        let table = Compressor::train(&vec![text; 5]);

        // Trained text must not need any escape code.
        let compressed = table.compress(text);
        assert!(!compressed.contains(&ESCAPE_CODE));

        // Unseen bytes are each emitted as an escape code followed by the byte.
        let unseen: &[u8] = b"xyz123";
        let compressed = table.compress(unseen);
        let decompressed = table.decompressor().decompress(&compressed);
        assert_eq!(&decompressed, b"xyz123");

        let expected: Vec<u8> = unseen
            .iter()
            .flat_map(|&byte| [ESCAPE_CODE, byte])
            .collect();
        assert_eq!(compressed, expected);
    }

    /// The bitmap iterator yields exactly the set codes, in ascending order.
    #[test]
    fn test_bitmap() {
        // A few scattered codes.
        let mut map = CodesBitmap::default();
        for code in [10, 100, 500] {
            map.set(code);
        }
        assert_eq!(map.codes().collect::<Vec<u16>>(), vec![10u16, 100, 500]);

        // An empty bitmap yields nothing.
        let map = CodesBitmap::default();
        assert_eq!(map.codes().count(), 0);

        // The first bit of every 64-bit word.
        let mut map = CodesBitmap::default();
        for code in (0..512).step_by(64) {
            map.set(code);
        }
        assert_eq!(
            map.codes().collect::<Vec<_>>(),
            (0u16..512).step_by(64).collect::<Vec<_>>(),
        );

        // With every bit set, code 511 is still left out.
        let mut map = CodesBitmap::default();
        (0..512).for_each(|code| map.set(code));
        assert_eq!(
            map.codes().collect::<Vec<_>>(),
            (0u16..511u16).collect::<Vec<_>>()
        );
    }

    /// Setting a code beyond the valid range panics.
    #[test]
    #[should_panic(expected = "code cannot exceed")]
    fn test_bitmap_invalid() {
        let mut map = CodesBitmap::default();
        map.set(512);
    }
}