fsst/
builder.rs

1//! Functions and types used for building a [`Compressor`] from a corpus of text.
2//!
3//! This module implements the logic from Algorithm 3 of the [FSST Paper].
4//!
5//! [FSST Paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
6
7use crate::{
8    Code, Compressor, FSST_CODE_BASE, FSST_CODE_MASK, Symbol, advance_8byte_word, compare_masked,
9    lossy_pht::LossyPHT,
10};
11use std::cmp::Ordering;
12use std::collections::BinaryHeap;
13
14/// Bitmap that only works for values up to 512
15#[derive(Clone, Copy, Debug, Default)]
16struct CodesBitmap {
17    codes: [u64; 8],
18}
19
20assert_sizeof!(CodesBitmap => 64);
21
22impl CodesBitmap {
23    /// Set the indicated bit. Must be between 0 and [`FSST_CODE_MASK`][crate::FSST_CODE_MASK].
24    pub(crate) fn set(&mut self, index: usize) {
25        debug_assert!(
26            index <= FSST_CODE_MASK as usize,
27            "code cannot exceed {FSST_CODE_MASK}"
28        );
29
30        let map = index >> 6;
31        self.codes[map] |= 1 << (index % 64);
32    }
33
34    /// Check if `index` is present in the bitmap
35    pub(crate) fn is_set(&self, index: usize) -> bool {
36        debug_assert!(
37            index <= FSST_CODE_MASK as usize,
38            "code cannot exceed {FSST_CODE_MASK}"
39        );
40
41        let map = index >> 6;
42        self.codes[map] & (1 << (index % 64)) != 0
43    }
44
45    /// Get all codes set in this bitmap
46    pub(crate) fn codes(&self) -> CodesIterator<'_> {
47        CodesIterator {
48            inner: self,
49            index: 0,
50            block: self.codes[0],
51            reference: 0,
52        }
53    }
54
55    /// Clear the bitmap of all entries.
56    pub(crate) fn clear(&mut self) {
57        self.codes[0] = 0;
58        self.codes[1] = 0;
59        self.codes[2] = 0;
60        self.codes[3] = 0;
61        self.codes[4] = 0;
62        self.codes[5] = 0;
63        self.codes[6] = 0;
64        self.codes[7] = 0;
65    }
66}
67
68struct CodesIterator<'a> {
69    inner: &'a CodesBitmap,
70    index: usize,
71    block: u64,
72    reference: usize,
73}
74
75impl Iterator for CodesIterator<'_> {
76    type Item = u16;
77
78    fn next(&mut self) -> Option<Self::Item> {
79        // If current is zero, advance to next non-zero block
80        while self.block == 0 {
81            self.index += 1;
82            if self.index >= 8 {
83                return None;
84            }
85            self.block = self.inner.codes[self.index];
86            self.reference = self.index * 64;
87        }
88
89        // Find the next set bit in the current block.
90        let position = self.block.trailing_zeros() as usize;
91        let code = self.reference + position;
92
93        if code >= 511 {
94            return None;
95        }
96
97        // The next iteration will calculate with reference to the returned code + 1
98        self.reference = code + 1;
99        self.block = if position == 63 {
100            0
101        } else {
102            self.block >> (1 + position)
103        };
104
105        Some(code as u16)
106    }
107}
108
109#[derive(Debug, Clone)]
110struct Counter {
111    /// Frequency count for each code.
112    counts1: Vec<usize>,
113
114    /// Frequency count for each code-pair.
115    counts2: Vec<usize>,
116
117    /// Bitmap index for codes that appear in counts1
118    code1_index: CodesBitmap,
119
120    /// Bitmap index of pairs that have been set.
121    ///
122    /// `pair_index[code1].codes()` yields an iterator that can
123    /// be used to find all possible codes that follow `codes1`.
124    pair_index: Vec<CodesBitmap>,
125}
126
127const COUNTS1_SIZE: usize = (FSST_CODE_MASK + 1) as usize;
128
129// NOTE: in Rust, creating a 1D vector of length N^2 is ~4x faster than creating a 2-D vector,
130//  because `vec!` has a specialization for zero.
131//
132// We also include +1 extra row at the end so that we can do writes into the counters without a branch
133// for the first iteration.
134const COUNTS2_SIZE: usize = COUNTS1_SIZE * COUNTS1_SIZE;
135
136impl Counter {
137    fn new() -> Self {
138        let mut counts1 = Vec::with_capacity(COUNTS1_SIZE);
139        let mut counts2 = Vec::with_capacity(COUNTS2_SIZE);
140        // SAFETY: all accesses to the vector go through the bitmap to ensure no uninitialized
141        //  data is ever read from these vectors.
142        unsafe {
143            counts1.set_len(COUNTS1_SIZE);
144            counts2.set_len(COUNTS2_SIZE);
145        }
146
147        Self {
148            counts1,
149            counts2,
150            code1_index: CodesBitmap::default(),
151            pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE],
152        }
153    }
154
155    #[inline]
156    fn record_count1(&mut self, code1: u16) {
157        // If not set, we want to start at one.
158        let base = if self.code1_index.is_set(code1 as usize) {
159            self.counts1[code1 as usize]
160        } else {
161            0
162        };
163
164        self.counts1[code1 as usize] = base + 1;
165        self.code1_index.set(code1 as usize);
166    }
167
168    #[inline]
169    fn record_count2(&mut self, code1: u16, code2: u16) {
170        debug_assert!(code1 == FSST_CODE_MASK || self.code1_index.is_set(code1 as usize));
171        debug_assert!(self.code1_index.is_set(code2 as usize));
172
173        let idx = (code1 as usize) * COUNTS1_SIZE + (code2 as usize);
174        if self.pair_index[code1 as usize].is_set(code2 as usize) {
175            self.counts2[idx] += 1;
176        } else {
177            self.counts2[idx] = 1;
178        }
179        self.pair_index[code1 as usize].set(code2 as usize);
180    }
181
182    #[inline]
183    fn count1(&self, code1: u16) -> usize {
184        debug_assert!(self.code1_index.is_set(code1 as usize));
185
186        self.counts1[code1 as usize]
187    }
188
189    #[inline]
190    fn count2(&self, code1: u16, code2: u16) -> usize {
191        debug_assert!(self.code1_index.is_set(code1 as usize));
192        debug_assert!(self.code1_index.is_set(code2 as usize));
193        debug_assert!(self.pair_index[code1 as usize].is_set(code2 as usize));
194
195        let idx = (code1 as usize) * 512 + (code2 as usize);
196        self.counts2[idx]
197    }
198
199    /// Returns an ordered iterator over the codes that were observed
200    /// in a call to [`Self::count1`].
201    fn first_codes(&self) -> CodesIterator<'_> {
202        self.code1_index.codes()
203    }
204
205    /// Returns an iterator over the codes that have been observed
206    /// to follow `code1`.
207    ///
208    /// This is the set of all values `code2` where there was
209    /// previously a call to `self.record_count2(code1, code2)`.
210    fn second_codes(&self, code1: u16) -> CodesIterator<'_> {
211        self.pair_index[code1 as usize].codes()
212    }
213
214    /// Clear the counters.
215    /// Note that this just touches the bitmaps and sets them all to invalid.
216    fn clear(&mut self) {
217        self.code1_index.clear();
218        for index in &mut self.pair_index {
219            index.clear();
220        }
221    }
222}
223
224/// Entrypoint for building a new `Compressor`.
225pub struct CompressorBuilder {
226    /// Table mapping codes to symbols.
227    ///
228    /// The entries 0-255 are setup in some other way here
229    symbols: Vec<Symbol>,
230
231    /// The number of entries in the symbol table that have been populated, not counting
232    /// the escape values.
233    n_symbols: u8,
234
235    /// Counts for number of symbols of each length.
236    ///
237    /// `len_histogram[len-1]` = count of the symbols of length `len`.
238    len_histogram: [u8; 8],
239
240    /// Inverted index mapping 1-byte symbols to codes.
241    ///
242    /// This is only used for building, not used by the final `Compressor`.
243    codes_one_byte: Vec<Code>,
244
245    /// Inverted index mapping 2-byte symbols to codes
246    codes_two_byte: Vec<Code>,
247
248    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
249    lossy_pht: LossyPHT,
250}
251
252impl CompressorBuilder {
253    /// Create a new builder.
254    pub fn new() -> Self {
255        // NOTE: `vec!` has a specialization for building a new vector of `0u64`. Because Symbol and u64
256        //  have the same bit pattern, we can allocate as u64 and transmute. If we do `vec![Symbol::EMPTY; N]`,
257        // that will create a new Vec and call `Symbol::EMPTY.clone()` `N` times which is considerably slower.
258        let symbols = vec![0u64; 511];
259
260        // SAFETY: transmute safety assured by the compiler.
261        let symbols: Vec<Symbol> = unsafe { std::mem::transmute(symbols) };
262
263        let mut table = Self {
264            symbols,
265            n_symbols: 0,
266            len_histogram: [0; 8],
267            codes_two_byte: Vec::with_capacity(65_536),
268            codes_one_byte: Vec::with_capacity(512),
269            lossy_pht: LossyPHT::new(),
270        };
271
272        // Populate the escape byte entries.
273        for byte in 0..=255 {
274            let symbol = Symbol::from_u8(byte);
275            table.symbols[byte as usize] = symbol;
276        }
277
278        // Fill codes_one_byte with pseudocodes for each byte.
279        for byte in 0..=255 {
280            // Push pseudocode for single-byte escape.
281            table.codes_one_byte.push(Code::new_escape(byte));
282        }
283
284        // Fill codes_two_byte with pseudocode of first byte
285        for idx in 0..=65_535 {
286            table.codes_two_byte.push(Code::new_escape(idx as u8));
287        }
288
289        table
290    }
291}
292
293impl Default for CompressorBuilder {
294    fn default() -> Self {
295        Self::new()
296    }
297}
298
299impl CompressorBuilder {
300    /// Attempt to insert a new symbol at the end of the table.
301    ///
302    /// # Panics
303    ///
304    /// Panics if the table is already full.
305    ///
306    /// # Returns
307    ///
308    /// Returns true if the symbol was inserted successfully, or false if it conflicted
309    /// with an existing symbol.
310    pub fn insert(&mut self, symbol: Symbol, len: usize) -> bool {
311        assert!(self.n_symbols < 255, "cannot insert into full symbol table");
312        assert_eq!(len, symbol.len(), "provided len must equal symbol.len()");
313
314        if len == 2 {
315            // shortCodes
316            self.codes_two_byte[symbol.first2() as usize] =
317                Code::new_symbol_building(self.n_symbols, 2);
318        } else if len == 1 {
319            // byteCodes
320            self.codes_one_byte[symbol.first_byte() as usize] =
321                Code::new_symbol_building(self.n_symbols, 1);
322        } else {
323            // Symbols of 3 or more bytes go into the hash table
324            if !self.lossy_pht.insert(symbol, len, self.n_symbols) {
325                return false;
326            }
327        }
328
329        // Increment length histogram.
330        self.len_histogram[len - 1] += 1;
331
332        // Insert successfully stored symbol at end of the symbol table
333        // Note the rescaling from range [0-254] -> [256, 510].
334        self.symbols[256 + (self.n_symbols as usize)] = symbol;
335        self.n_symbols += 1;
336        true
337    }
338
339    /// Clear all set items from the compressor.
340    ///
341    /// This is considerably faster than building a new Compressor from scratch for each
342    /// iteration of the `train` loop.
343    fn clear(&mut self) {
344        // Eliminate every observed code from the table.
345        for code in 0..(256 + self.n_symbols as usize) {
346            let symbol = self.symbols[code];
347            if symbol.len() == 1 {
348                // Reset the entry from the codes_one_byte array.
349                self.codes_one_byte[symbol.first_byte() as usize] =
350                    Code::new_escape(symbol.first_byte());
351            } else if symbol.len() == 2 {
352                // Reset the entry from the codes_two_byte array.
353                self.codes_two_byte[symbol.first2() as usize] =
354                    Code::new_escape(symbol.first_byte());
355            } else {
356                // Clear the hashtable entry
357                self.lossy_pht.remove(symbol);
358            }
359        }
360
361        // Reset len histogram
362        for i in 0..=7 {
363            self.len_histogram[i] = 0;
364        }
365
366        self.n_symbols = 0;
367    }
368
369    /// Finalizing the table is done once building is complete to prepare for efficient
370    /// compression.
371    ///
372    /// When we finalize the table, the following modifications are made in-place:
373    ///
374    /// 1. The codes are renumbered so that all symbols are ordered by length (order 23456781).
375    ///    During this process, the two byte symbols are separated into a byte_lim and a suffix_lim,
376    ///    so we know that we don't need to check the suffix limitations instead.
377    /// 2. The 1-byte symbols index is merged into the 2-byte symbols index to allow for use of only
378    ///    a single index in front of the hash table.
379    ///
380    /// # Returns
381    ///
382    /// Returns the `suffix_lim`, which is the index of the two-byte code before where we know
383    /// there are no longer suffixies in the symbol table.
384    ///
385    /// Also returns the lengths vector, which is of length `n_symbols` and contains the
386    /// length for each of the values.
387    fn finalize(&mut self) -> (u8, Vec<u8>) {
388        // Create a cumulative sum of each of the elements of the input line numbers.
389        // Do a map that includes the previously seen value as well.
390        // Regroup symbols based on their lengths.
391        // Space at the end of the symbol table reserved for the one-byte codes.
392        let byte_lim = self.n_symbols - self.len_histogram[0];
393
394        // Start code for each length.
395        // Length 1: at the end of symbol table.
396        // Length 2: starts at 0. Split into before/after suffixLim.
397        let mut codes_by_length = [0u8; 8];
398        codes_by_length[0] = byte_lim;
399        codes_by_length[1] = 0;
400
401        // codes for lengths 3..=8 start where the previous ones end.
402        for i in 1..7 {
403            codes_by_length[i + 1] = codes_by_length[i] + self.len_histogram[i];
404        }
405
406        // no_suffix_code is the lowest code for a symbol that does not have a longer 3+ byte
407        // suffix in the table.
408        // This value starts at 0 and extends up.
409        let mut no_suffix_code = 0;
410
411        // The codes that do not have a suffix begin just before the range of the 3-byte codes.
412        let mut has_suffix_code = codes_by_length[2];
413
414        // Assign each symbol a new code ordered by lengths, in the order
415        // 2(no suffix) | 2 (suffix) | 3 | 4 | 5 | 6 | 7 | 8 | 1
416        let mut new_codes = [0u8; FSST_CODE_BASE as usize];
417
418        let mut symbol_lens = [0u8; FSST_CODE_BASE as usize];
419
420        for i in 0..(self.n_symbols as usize) {
421            let symbol = self.symbols[256 + i];
422            let len = symbol.len();
423            if len == 2 {
424                let has_suffix = self
425                    .symbols
426                    .iter()
427                    .skip(FSST_CODE_BASE as usize)
428                    .enumerate()
429                    .any(|(k, other)| i != k && symbol.first2() == other.first2());
430
431                if has_suffix {
432                    // Symbols that have a longer suffix are inserted at the end of the 2-byte range
433                    has_suffix_code -= 1;
434                    new_codes[i] = has_suffix_code;
435                } else {
436                    // Symbols that do not have a longer suffix are inserted at the start of
437                    // the 2-byte range.
438                    new_codes[i] = no_suffix_code;
439                    no_suffix_code += 1;
440                }
441            } else {
442                // Assign new code based on the next code available for the given length symbol
443                new_codes[i] = codes_by_length[len - 1];
444                codes_by_length[len - 1] += 1;
445            }
446
447            // Write the symbol into the front half of the symbol table.
448            // We are reusing the space that was previously occupied by escapes.
449            self.symbols[new_codes[i] as usize] = symbol;
450            symbol_lens[new_codes[i] as usize] = len as u8;
451        }
452
453        // Truncate the symbol table to only include the "true" symbols.
454        self.symbols.truncate(self.n_symbols as usize);
455
456        // Rewrite the codes_one_byte table to point at the new code values.
457        // Replace pseudocodes with escapes.
458        for byte in 0..=255 {
459            let one_byte = self.codes_one_byte[byte];
460            if one_byte.extended_code() >= FSST_CODE_BASE {
461                let new_code = new_codes[one_byte.code() as usize];
462                self.codes_one_byte[byte] = Code::new_symbol(new_code, 1);
463            } else {
464                // After finalize: codes_one_byte contains the unused value
465                self.codes_one_byte[byte] = Code::UNUSED;
466            }
467        }
468
469        // Rewrite the codes_two_byte table to point at the new code values.
470        // Replace pseudocodes with escapes.
471        for two_bytes in 0..=65_535 {
472            let two_byte = self.codes_two_byte[two_bytes];
473            if two_byte.extended_code() >= FSST_CODE_BASE {
474                let new_code = new_codes[two_byte.code() as usize];
475                self.codes_two_byte[two_bytes] = Code::new_symbol(new_code, 2);
476            } else {
477                // The one-byte code for the given code number here...
478                self.codes_two_byte[two_bytes] = self.codes_one_byte[two_bytes & 0xFF];
479            }
480        }
481
482        // Reset values in the hash table as well.
483        self.lossy_pht.renumber(&new_codes);
484
485        // Pre-compute the lengths
486        let mut lengths = Vec::with_capacity(self.n_symbols as usize);
487        for symbol in &self.symbols {
488            lengths.push(symbol.len() as u8);
489        }
490
491        (has_suffix_code, lengths)
492    }
493
494    /// Build into the final hash table.
495    pub fn build(mut self) -> Compressor {
496        // finalize the symbol table by inserting the codes_twobyte values into
497        // the relevant parts of the `codes_onebyte` set.
498
499        let (has_suffix_code, lengths) = self.finalize();
500
501        Compressor {
502            symbols: self.symbols,
503            lengths,
504            n_symbols: self.n_symbols,
505            has_suffix_code,
506            codes_two_byte: self.codes_two_byte,
507            lossy_pht: self.lossy_pht,
508        }
509    }
510}
511
512/// The number of generations used for training. This is taken from the [FSST paper].
513///
514/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
515#[cfg(not(miri))]
516const GENERATIONS: [usize; 5] = [8usize, 38, 68, 98, 128];
517#[cfg(miri)]
518const GENERATIONS: [usize; 3] = [8usize, 38, 128];
519
520const FSST_SAMPLETARGET: usize = 1 << 14;
521const FSST_SAMPLEMAX: usize = 1 << 15;
522const FSST_SAMPLELINE: usize = 512;
523
524/// Create a sample from a set of strings in the input.
525///
526/// Sample is constructing by copying "chunks" from the `str_in`s into the `sample_buf`, the
527/// returned slices are pointers into the `sample_buf`.
528///
529/// SAFETY: sample_buf must be >= FSST_SAMPLEMAX bytes long. Providing something less may cause unexpected failures.
530#[allow(clippy::ptr_arg)]
531fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>) -> Vec<&'a [u8]> {
532    assert!(
533        sample_buf.capacity() >= FSST_SAMPLEMAX,
534        "sample_buf.len() < FSST_SAMPLEMAX"
535    );
536
537    let mut sample: Vec<&[u8]> = Vec::new();
538
539    let tot_size: usize = str_in.iter().map(|s| s.len()).sum();
540    if tot_size < FSST_SAMPLETARGET {
541        return str_in.clone();
542    }
543
544    let mut sample_rnd = fsst_hash(4637947);
545    let sample_lim = FSST_SAMPLETARGET;
546    let mut sample_buf_offset: usize = 0;
547
548    while sample_buf_offset < sample_lim {
549        sample_rnd = fsst_hash(sample_rnd);
550        let line_nr = (sample_rnd as usize) % str_in.len();
551
552        // Find the first non-empty chunk starting at line_nr, wrapping around if
553        // necessary.
554        let Some(line) = (line_nr..str_in.len())
555            .chain(0..line_nr)
556            .map(|line_nr| str_in[line_nr])
557            .find(|line| !line.is_empty())
558        else {
559            return sample;
560        };
561
562        let chunks = 1 + ((line.len() - 1) / FSST_SAMPLELINE);
563        sample_rnd = fsst_hash(sample_rnd);
564        let chunk = FSST_SAMPLELINE * ((sample_rnd as usize) % chunks);
565
566        let len = FSST_SAMPLELINE.min(line.len() - chunk);
567
568        sample_buf.extend_from_slice(&line[chunk..chunk + len]);
569
570        // SAFETY: this is the data we just placed into `sample_buf` in the line above.
571        let slice =
572            unsafe { std::slice::from_raw_parts(sample_buf.as_ptr().add(sample_buf_offset), len) };
573
574        sample.push(slice);
575
576        sample_buf_offset += len;
577    }
578
579    sample
580}
581
582/// Hash function used in various components of the library.
583///
584/// This is equivalent to the FSST_HASH macro from the C++ implementation.
585#[inline]
586pub(crate) fn fsst_hash(value: u64) -> u64 {
587    value.wrapping_mul(2971215073) ^ value.wrapping_shr(15)
588}
589
590impl Compressor {
591    /// Build and train a `Compressor` from a sample corpus of text.
592    ///
593    /// This function implements the generational algorithm described in the [FSST paper] Section
594    /// 4.3. Starting with an empty symbol table, it iteratively compresses the corpus, then attempts
595    /// to merge symbols when doing so would yield better compression than leaving them unmerged. The
596    /// resulting table will have at most 255 symbols (the 256th symbol is reserved for the escape
597    /// code).
598    ///
599    /// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
600    pub fn train(values: &Vec<&[u8]>) -> Self {
601        let mut builder = CompressorBuilder::new();
602
603        if values.is_empty() {
604            return builder.build();
605        }
606
607        let mut counters = Counter::new();
608        let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX);
609        let mut pqueue = BinaryHeap::with_capacity(65_536);
610
611        let sample = make_sample(&mut sample_memory, values);
612        for sample_frac in GENERATIONS {
613            for (i, line) in sample.iter().enumerate() {
614                if sample_frac < 128 && ((fsst_hash(i as u64) & 127) as usize) > sample_frac {
615                    continue;
616                }
617
618                builder.compress_count(line, &mut counters);
619            }
620
621            // Clear the heap before we use it again
622            pqueue.clear();
623            builder.optimize(&counters, sample_frac, &mut pqueue);
624            counters.clear();
625        }
626
627        builder.build()
628    }
629}
630
631impl CompressorBuilder {
632    /// Find the longest symbol using the hash table and the codes_one_byte and codes_two_byte indexes.
633    fn find_longest_symbol(&self, word: u64) -> Code {
634        // Probe the hash table first to see if we have a long match
635        let entry = self.lossy_pht.lookup(word);
636        let ignored_bits = entry.ignored_bits;
637
638        // If the entry is valid, return the code
639        if !entry.is_unused() && compare_masked(word, entry.symbol.to_u64(), ignored_bits) {
640            return entry.code;
641        }
642
643        // Try and match first two bytes
644        let twobyte = self.codes_two_byte[word as u16 as usize];
645        if twobyte.extended_code() >= FSST_CODE_BASE {
646            return twobyte;
647        }
648
649        // Fall back to single-byte match
650        self.codes_one_byte[word as u8 as usize]
651    }
652
653    /// Compress the text using the current symbol table. Count the code occurrences
654    /// and code-pair occurrences, calculating total gain using the current compressor.
655    ///
656    /// NOTE: this is largely an unfortunate amount of copy-paste from `compress`, just to make sure
657    /// we can do all the counting in a single pass.
658    fn compress_count(&self, sample: &[u8], counter: &mut Counter) -> usize {
659        let mut gain = 0;
660        if sample.is_empty() {
661            return gain;
662        }
663
664        let mut in_ptr = sample.as_ptr();
665
666        // SAFETY: `end` will point just after the end of the `plaintext` slice.
667        let in_end = unsafe { in_ptr.byte_add(sample.len()) };
668        let in_end_sub8 = in_end as usize - 8;
669
670        let mut prev_code: u16 = FSST_CODE_MASK;
671
672        while (in_ptr as usize) < (in_end_sub8) {
673            // SAFETY: ensured in-bounds by loop condition.
674            let word: u64 = unsafe { std::ptr::read_unaligned(in_ptr as *const u64) };
675            let code = self.find_longest_symbol(word);
676            let code_u16 = code.extended_code();
677
678            // Gain increases by the symbol length if a symbol matches, or 0
679            // if an escape is emitted.
680            gain += (code.len() as usize) - ((code_u16 < 256) as usize);
681
682            // Record the single and pair counts
683            counter.record_count1(code_u16);
684            counter.record_count2(prev_code, code_u16);
685
686            // Also record the count for just extending by a single byte, but only if
687            // the symbol is not itself a single byte.
688            if code.len() > 1 {
689                let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
690                counter.record_count1(code_first_byte);
691                counter.record_count2(prev_code, code_first_byte);
692            }
693
694            // SAFETY: pointer bound is checked in loop condition before any access is made.
695            in_ptr = unsafe { in_ptr.byte_add(code.len() as usize) };
696
697            prev_code = code_u16;
698        }
699
700        let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
701        assert!(
702            remaining_bytes.is_positive(),
703            "in_ptr exceeded in_end, should not be possible"
704        );
705        let remaining_bytes = remaining_bytes as usize;
706
707        // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
708        // but shift data out of this word rather than advancing an input pointer and potentially reading
709        // unowned memory
710        let mut bytes = [0u8; 8];
711        unsafe {
712            // SAFETY: it is safe to read up to remaining_bytes from in_ptr, and remaining_bytes
713            //  will be <= 8 bytes.
714            std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
715        }
716        let mut last_word = u64::from_le_bytes(bytes);
717
718        let mut remaining_bytes = remaining_bytes;
719
720        while remaining_bytes > 0 {
721            // SAFETY: ensured in-bounds by loop condition.
722            let code = self.find_longest_symbol(last_word);
723            let code_u16 = code.extended_code();
724
725            // Gain increases by the symbol length if a symbol matches, or 0
726            // if an escape is emitted.
727            gain += (code.len() as usize) - ((code_u16 < 256) as usize);
728
729            // Record the single and pair counts
730            counter.record_count1(code_u16);
731            counter.record_count2(prev_code, code_u16);
732
733            // Also record the count for just extending by a single byte, but only if
734            // the symbol is not itself a single byte.
735            if code.len() > 1 {
736                let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
737                counter.record_count1(code_first_byte);
738                counter.record_count2(prev_code, code_first_byte);
739            }
740
741            // Advance our last_word "input pointer" by shifting off the covered values.
742            let advance = code.len() as usize;
743            remaining_bytes -= advance;
744            last_word = advance_8byte_word(last_word, advance);
745
746            prev_code = code_u16;
747        }
748
749        gain
750    }
751
752    /// Using a set of counters and the existing set of symbols, build a new
753    /// set of symbols/codes that optimizes the gain over the distribution in `counter`.
754    fn optimize(
755        &mut self,
756        counters: &Counter,
757        sample_frac: usize,
758        pqueue: &mut BinaryHeap<Candidate>,
759    ) {
760        for code1 in counters.first_codes() {
761            let symbol1 = self.symbols[code1 as usize];
762            let symbol1_len = symbol1.len();
763            let count = counters.count1(code1);
764
765            // From the c++ impl:
766            // "improves both compression speed (less candidates), but also quality!!"
767            if count < (5 * sample_frac / 128) {
768                continue;
769            }
770
771            let mut gain = count * symbol1_len;
772            // NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols.
773            // This helps to reduce exception counts.
774            if code1 < 256 {
775                gain *= 8;
776            }
777
778            pqueue.push(Candidate {
779                symbol: symbol1,
780                gain,
781            });
782
783            // Skip merges on last round, or when symbol cannot be extended.
784            if sample_frac >= 128 || symbol1_len == 8 {
785                continue;
786            }
787
788            for code2 in counters.second_codes(code1) {
789                let symbol2 = self.symbols[code2 as usize];
790
791                // If merging would yield a symbol of length greater than 8, skip.
792                if symbol1_len + symbol2.len() > 8 {
793                    continue;
794                }
795                let new_symbol = symbol1.concat(symbol2);
796                let gain = counters.count2(code1, code2) * new_symbol.len();
797
798                pqueue.push(Candidate {
799                    symbol: new_symbol,
800                    gain,
801                })
802            }
803        }
804
805        // clear self in advance of inserting the symbols.
806        self.clear();
807
808        // Pop the 255 best symbols.
809        let mut n_symbols = 0;
810        while !pqueue.is_empty() && n_symbols < 255 {
811            let candidate = pqueue.pop().unwrap();
812            if self.insert(candidate.symbol, candidate.symbol.len()) {
813                n_symbols += 1;
814            }
815        }
816    }
817}
818
819/// A candidate for inclusion in a symbol table.
820///
821/// This is really only useful for the `optimize` step of training.
822#[derive(Copy, Clone, Debug)]
823struct Candidate {
824    gain: usize,
825    symbol: Symbol,
826}
827
828impl Candidate {
829    fn comparable_form(&self) -> (usize, usize) {
830        (self.gain, self.symbol.len())
831    }
832}
833
834impl Eq for Candidate {}
835
836impl PartialEq<Self> for Candidate {
837    fn eq(&self, other: &Self) -> bool {
838        self.comparable_form().eq(&other.comparable_form())
839    }
840}
841
842impl PartialOrd<Self> for Candidate {
843    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
844        Some(self.cmp(other))
845    }
846}
847
848impl Ord for Candidate {
849    fn cmp(&self, other: &Self) -> Ordering {
850        let self_ord = (self.gain, self.symbol.len());
851        let other_ord = (other.gain, other.symbol.len());
852
853        self_ord.cmp(&other_ord)
854    }
855}
856
857#[cfg(test)]
858mod test {
859    use crate::{Compressor, ESCAPE_CODE, builder::CodesBitmap};
860
861    #[test]
862    fn test_builder() {
863        // Train a Compressor on the toy string
864        let text = b"hello hello hello hello hello";
865
866        // count of 5 is the cutoff for including a symbol in the table.
867        let table = Compressor::train(&vec![text, text, text, text, text]);
868
869        // Use the table to compress a string, see the values
870        let compressed = table.compress(text);
871
872        // Ensure that the compressed string has no escape bytes
873        assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));
874
875        // Ensure that we can compress a string with no values seen at training time, with escape bytes
876        let compressed = table.compress("xyz123".as_bytes());
877        let decompressed = table.decompressor().decompress(&compressed);
878        assert_eq!(&decompressed, b"xyz123");
879        assert_eq!(
880            compressed,
881            vec![
882                ESCAPE_CODE,
883                b'x',
884                ESCAPE_CODE,
885                b'y',
886                ESCAPE_CODE,
887                b'z',
888                ESCAPE_CODE,
889                b'1',
890                ESCAPE_CODE,
891                b'2',
892                ESCAPE_CODE,
893                b'3',
894            ]
895        );
896    }
897
898    #[test]
899    fn test_bitmap() {
900        let mut map = CodesBitmap::default();
901        map.set(10);
902        map.set(100);
903        map.set(500);
904
905        let codes: Vec<u16> = map.codes().collect();
906        assert_eq!(codes, vec![10u16, 100, 500]);
907
908        // empty case
909        let map = CodesBitmap::default();
910        assert!(map.codes().collect::<Vec<_>>().is_empty());
911
912        // edge case: first bit in each block is set
913        let mut map = CodesBitmap::default();
914        (0..8).for_each(|i| map.set(64 * i));
915        assert_eq!(
916            map.codes().collect::<Vec<_>>(),
917            (0u16..8).map(|i| 64 * i).collect::<Vec<_>>(),
918        );
919
920        // Full bitmap case. There are only 512 values, so test them all
921        let mut map = CodesBitmap::default();
922        for i in 0..512 {
923            map.set(i);
924        }
925        assert_eq!(
926            map.codes().collect::<Vec<_>>(),
927            (0u16..511u16).collect::<Vec<_>>()
928        );
929    }
930
931    #[test]
932    #[should_panic(expected = "code cannot exceed")]
933    fn test_bitmap_invalid() {
934        let mut map = CodesBitmap::default();
935        map.set(512);
936    }
937}