Skip to main content

fsst/
builder.rs

1//! Functions and types used for building a [`Compressor`] from a corpus of text.
2//!
3//! This module implements the logic from Algorithm 3 of the [FSST Paper].
4//!
5//! [FSST Paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
6
7use crate::{
8    Code, Compressor, FSST_CODE_BASE, FSST_CODE_MASK, Symbol, advance_8byte_word, compare_masked,
9    lossy_pht::LossyPHT,
10};
11use rustc_hash::{FxBuildHasher, FxHashMap};
12use std::cmp::Ordering;
13use std::collections::BinaryHeap;
14
15/// Bitmap that only works for values up to 512
16#[derive(Clone, Copy, Debug, Default)]
17struct CodesBitmap {
18    codes: [u64; 8],
19}
20
21assert_sizeof!(CodesBitmap => 64);
22
23impl CodesBitmap {
24    /// Set the indicated bit. Must be between 0 and [`FSST_CODE_MASK`][crate::FSST_CODE_MASK].
25    pub(crate) fn set(&mut self, index: usize) {
26        debug_assert!(
27            index <= FSST_CODE_MASK as usize,
28            "code cannot exceed {FSST_CODE_MASK}"
29        );
30
31        let map = index >> 6;
32        self.codes[map] |= 1 << (index % 64);
33    }
34
35    /// Check if `index` is present in the bitmap
36    pub(crate) fn is_set(&self, index: usize) -> bool {
37        debug_assert!(
38            index <= FSST_CODE_MASK as usize,
39            "code cannot exceed {FSST_CODE_MASK}"
40        );
41
42        let map = index >> 6;
43        self.codes[map] & (1 << (index % 64)) != 0
44    }
45
46    /// Get all codes set in this bitmap
47    pub(crate) fn codes(&self) -> CodesIterator<'_> {
48        CodesIterator {
49            inner: self,
50            index: 0,
51            block: self.codes[0],
52            reference: 0,
53        }
54    }
55
56    /// Clear the bitmap of all entries.
57    pub(crate) fn clear(&mut self) {
58        self.codes[0] = 0;
59        self.codes[1] = 0;
60        self.codes[2] = 0;
61        self.codes[3] = 0;
62        self.codes[4] = 0;
63        self.codes[5] = 0;
64        self.codes[6] = 0;
65        self.codes[7] = 0;
66    }
67}
68
69struct CodesIterator<'a> {
70    inner: &'a CodesBitmap,
71    index: usize,
72    block: u64,
73    reference: usize,
74}
75
76impl Iterator for CodesIterator<'_> {
77    type Item = u16;
78
79    fn next(&mut self) -> Option<Self::Item> {
80        // If current is zero, advance to next non-zero block
81        while self.block == 0 {
82            self.index += 1;
83            if self.index >= 8 {
84                return None;
85            }
86            self.block = self.inner.codes[self.index];
87            self.reference = self.index * 64;
88        }
89
90        // Find the next set bit in the current block.
91        let position = self.block.trailing_zeros() as usize;
92        let code = self.reference + position;
93
94        if code >= 511 {
95            return None;
96        }
97
98        // The next iteration will calculate with reference to the returned code + 1
99        self.reference = code + 1;
100        self.block = if position == 63 {
101            0
102        } else {
103            self.block >> (1 + position)
104        };
105
106        Some(code as u16)
107    }
108}
109
110#[derive(Debug, Clone)]
111struct Counter {
112    /// Frequency count for each code.
113    counts1: Vec<usize>,
114
115    /// Frequency count for each code-pair.
116    counts2: Vec<usize>,
117
118    /// Bitmap index for codes that appear in counts1
119    code1_index: CodesBitmap,
120
121    /// Bitmap index of pairs that have been set.
122    ///
123    /// `pair_index[code1].codes()` yields an iterator that can
124    /// be used to find all possible codes that follow `codes1`.
125    pair_index: Vec<CodesBitmap>,
126}
127
128const COUNTS1_SIZE: usize = (FSST_CODE_MASK + 1) as usize;
129
130// NOTE: in Rust, creating a 1D vector of length N^2 is ~4x faster than creating a 2-D vector,
131//  because `vec!` has a specialization for zero.
132//
133// We also include +1 extra row at the end so that we can do writes into the counters without a branch
134// for the first iteration.
135const COUNTS2_SIZE: usize = COUNTS1_SIZE * COUNTS1_SIZE;
136
137impl Counter {
138    fn new() -> Self {
139        let mut counts1 = Vec::with_capacity(COUNTS1_SIZE);
140        let mut counts2 = Vec::with_capacity(COUNTS2_SIZE);
141        // SAFETY: all accesses to the vector go through the bitmap to ensure no uninitialized
142        //  data is ever read from these vectors.
143        unsafe {
144            counts1.set_len(COUNTS1_SIZE);
145            counts2.set_len(COUNTS2_SIZE);
146        }
147
148        Self {
149            counts1,
150            counts2,
151            code1_index: CodesBitmap::default(),
152            pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE],
153        }
154    }
155
156    #[inline]
157    fn record_count1(&mut self, code1: u16) {
158        // If not set, we want to start at one.
159        let base = if self.code1_index.is_set(code1 as usize) {
160            self.counts1[code1 as usize]
161        } else {
162            0
163        };
164
165        self.counts1[code1 as usize] = base + 1;
166        self.code1_index.set(code1 as usize);
167    }
168
169    #[inline]
170    fn record_count2(&mut self, code1: u16, code2: u16) {
171        debug_assert!(code1 == FSST_CODE_MASK || self.code1_index.is_set(code1 as usize));
172        debug_assert!(self.code1_index.is_set(code2 as usize));
173
174        let idx = (code1 as usize) * COUNTS1_SIZE + (code2 as usize);
175        if self.pair_index[code1 as usize].is_set(code2 as usize) {
176            self.counts2[idx] += 1;
177        } else {
178            self.counts2[idx] = 1;
179        }
180        self.pair_index[code1 as usize].set(code2 as usize);
181    }
182
183    #[inline]
184    fn count1(&self, code1: u16) -> usize {
185        debug_assert!(self.code1_index.is_set(code1 as usize));
186
187        self.counts1[code1 as usize]
188    }
189
190    #[inline]
191    fn count2(&self, code1: u16, code2: u16) -> usize {
192        debug_assert!(self.code1_index.is_set(code1 as usize));
193        debug_assert!(self.code1_index.is_set(code2 as usize));
194        debug_assert!(self.pair_index[code1 as usize].is_set(code2 as usize));
195
196        let idx = (code1 as usize) * 512 + (code2 as usize);
197        self.counts2[idx]
198    }
199
200    /// Returns an ordered iterator over the codes that were observed
201    /// in a call to [`Self::count1`].
202    fn first_codes(&self) -> CodesIterator<'_> {
203        self.code1_index.codes()
204    }
205
206    /// Returns an iterator over the codes that have been observed
207    /// to follow `code1`.
208    ///
209    /// This is the set of all values `code2` where there was
210    /// previously a call to `self.record_count2(code1, code2)`.
211    fn second_codes(&self, code1: u16) -> CodesIterator<'_> {
212        self.pair_index[code1 as usize].codes()
213    }
214
215    /// Clear the counters.
216    /// Note that this just touches the bitmaps and sets them all to invalid.
217    fn clear(&mut self) {
218        self.code1_index.clear();
219        for index in &mut self.pair_index {
220            index.clear();
221        }
222    }
223}
224
225/// Entrypoint for building a new `Compressor`.
226pub struct CompressorBuilder {
227    /// Table mapping codes to symbols.
228    ///
229    /// The entries 0-255 are setup in some other way here
230    symbols: Vec<Symbol>,
231
232    /// The number of entries in the symbol table that have been populated, not counting
233    /// the escape values.
234    n_symbols: u8,
235
236    /// Counts for number of symbols of each length.
237    ///
238    /// `len_histogram[len-1]` = count of the symbols of length `len`.
239    len_histogram: [u8; 8],
240
241    /// Inverted index mapping 1-byte symbols to codes.
242    ///
243    /// This is only used for building, not used by the final `Compressor`.
244    codes_one_byte: Vec<Code>,
245
246    /// Inverted index mapping 2-byte symbols to codes
247    codes_two_byte: Vec<Code>,
248
249    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
250    lossy_pht: LossyPHT,
251}
252
253impl CompressorBuilder {
254    /// Create a new builder.
255    pub fn new() -> Self {
256        // NOTE: `vec!` has a specialization for building a new vector of `0u64`. Because Symbol and u64
257        //  have the same bit pattern, we can allocate as u64 and transmute. If we do `vec![Symbol::EMPTY; N]`,
258        // that will create a new Vec and call `Symbol::EMPTY.clone()` `N` times which is considerably slower.
259        let symbols = vec![0u64; 511];
260
261        // SAFETY: transmute safety assured by the compiler.
262        let symbols: Vec<Symbol> = unsafe { std::mem::transmute(symbols) };
263
264        let mut table = Self {
265            symbols,
266            n_symbols: 0,
267            len_histogram: [0; 8],
268            codes_two_byte: Vec::with_capacity(65_536),
269            codes_one_byte: Vec::with_capacity(512),
270            lossy_pht: LossyPHT::new(),
271        };
272
273        // Populate the escape byte entries.
274        for byte in 0..=255 {
275            let symbol = Symbol::from_u8(byte);
276            table.symbols[byte as usize] = symbol;
277        }
278
279        // Fill codes_one_byte with pseudocodes for each byte.
280        for byte in 0..=255 {
281            // Push pseudocode for single-byte escape.
282            table.codes_one_byte.push(Code::new_escape(byte));
283        }
284
285        // Fill codes_two_byte with pseudocode of first byte
286        for idx in 0..=65_535 {
287            table.codes_two_byte.push(Code::new_escape(idx as u8));
288        }
289
290        table
291    }
292}
293
294impl Default for CompressorBuilder {
295    fn default() -> Self {
296        Self::new()
297    }
298}
299
300impl CompressorBuilder {
301    /// Attempt to insert a new symbol at the end of the table.
302    ///
303    /// # Panics
304    ///
305    /// Panics if the table is already full.
306    ///
307    /// # Returns
308    ///
309    /// Returns true if the symbol was inserted successfully, or false if it conflicted
310    /// with an existing symbol.
311    pub fn insert(&mut self, symbol: Symbol, len: usize) -> bool {
312        assert!(self.n_symbols < 255, "cannot insert into full symbol table");
313        assert_eq!(len, symbol.len(), "provided len must equal symbol.len()");
314
315        if len == 2 {
316            // shortCodes
317            self.codes_two_byte[symbol.first2() as usize] =
318                Code::new_symbol_building(self.n_symbols, 2);
319        } else if len == 1 {
320            // byteCodes
321            self.codes_one_byte[symbol.first_byte() as usize] =
322                Code::new_symbol_building(self.n_symbols, 1);
323        } else {
324            // Symbols of 3 or more bytes go into the hash table
325            if !self.lossy_pht.insert(symbol, len, self.n_symbols) {
326                return false;
327            }
328        }
329
330        // Increment length histogram.
331        self.len_histogram[len - 1] += 1;
332
333        // Insert successfully stored symbol at end of the symbol table
334        // Note the rescaling from range [0-254] -> [256, 510].
335        self.symbols[256 + (self.n_symbols as usize)] = symbol;
336        self.n_symbols += 1;
337        true
338    }
339
340    /// Clear all set items from the compressor.
341    ///
342    /// This is considerably faster than building a new Compressor from scratch for each
343    /// iteration of the `train` loop.
344    fn clear(&mut self) {
345        // Eliminate every observed code from the table.
346        for code in 0..(256 + self.n_symbols as usize) {
347            let symbol = self.symbols[code];
348            if symbol.len() == 1 {
349                // Reset the entry from the codes_one_byte array.
350                self.codes_one_byte[symbol.first_byte() as usize] =
351                    Code::new_escape(symbol.first_byte());
352            } else if symbol.len() == 2 {
353                // Reset the entry from the codes_two_byte array.
354                self.codes_two_byte[symbol.first2() as usize] =
355                    Code::new_escape(symbol.first_byte());
356            } else {
357                // Clear the hashtable entry
358                self.lossy_pht.remove(symbol);
359            }
360        }
361
362        // Reset len histogram
363        for i in 0..=7 {
364            self.len_histogram[i] = 0;
365        }
366
367        self.n_symbols = 0;
368    }
369
370    /// Finalizing the table is done once building is complete to prepare for efficient
371    /// compression.
372    ///
373    /// When we finalize the table, the following modifications are made in-place:
374    ///
375    /// 1. The codes are renumbered so that all symbols are ordered by length (order 23456781).
376    ///    During this process, the two byte symbols are separated into a byte_lim and a suffix_lim,
377    ///    so we know that we don't need to check the suffix limitations instead.
378    /// 2. The 1-byte symbols index is merged into the 2-byte symbols index to allow for use of only
379    ///    a single index in front of the hash table.
380    ///
381    /// # Returns
382    ///
383    /// Returns the `suffix_lim`, which is the index of the two-byte code before where we know
384    /// there are no longer suffixies in the symbol table.
385    ///
386    /// Also returns the lengths vector, which is of length `n_symbols` and contains the
387    /// length for each of the values.
388    fn finalize(&mut self) -> (u8, Vec<u8>) {
389        // Create a cumulative sum of each of the elements of the input line numbers.
390        // Do a map that includes the previously seen value as well.
391        // Regroup symbols based on their lengths.
392        // Space at the end of the symbol table reserved for the one-byte codes.
393        let byte_lim = self.n_symbols - self.len_histogram[0];
394
395        // Start code for each length.
396        // Length 1: at the end of symbol table.
397        // Length 2: starts at 0. Split into before/after suffixLim.
398        let mut codes_by_length = [0u8; 8];
399        codes_by_length[0] = byte_lim;
400        codes_by_length[1] = 0;
401
402        // codes for lengths 3..=8 start where the previous ones end.
403        for i in 1..7 {
404            codes_by_length[i + 1] = codes_by_length[i] + self.len_histogram[i];
405        }
406
407        // no_suffix_code is the lowest code for a symbol that does not have a longer 3+ byte
408        // suffix in the table.
409        // This value starts at 0 and extends up.
410        let mut no_suffix_code = 0;
411
412        // The codes that do not have a suffix begin just before the range of the 3-byte codes.
413        let mut has_suffix_code = codes_by_length[2];
414
415        // Assign each symbol a new code ordered by lengths, in the order
416        // 2(no suffix) | 2 (suffix) | 3 | 4 | 5 | 6 | 7 | 8 | 1
417        let mut new_codes = [0u8; FSST_CODE_BASE as usize];
418
419        let mut symbol_lens = [0u8; FSST_CODE_BASE as usize];
420
421        for i in 0..(self.n_symbols as usize) {
422            let symbol = self.symbols[256 + i];
423            let len = symbol.len();
424            if len == 2 {
425                let has_suffix = self
426                    .symbols
427                    .iter()
428                    .skip(FSST_CODE_BASE as usize)
429                    .enumerate()
430                    .any(|(k, other)| i != k && symbol.first2() == other.first2());
431
432                if has_suffix {
433                    // Symbols that have a longer suffix are inserted at the end of the 2-byte range
434                    has_suffix_code -= 1;
435                    new_codes[i] = has_suffix_code;
436                } else {
437                    // Symbols that do not have a longer suffix are inserted at the start of
438                    // the 2-byte range.
439                    new_codes[i] = no_suffix_code;
440                    no_suffix_code += 1;
441                }
442            } else {
443                // Assign new code based on the next code available for the given length symbol
444                new_codes[i] = codes_by_length[len - 1];
445                codes_by_length[len - 1] += 1;
446            }
447
448            // Write the symbol into the front half of the symbol table.
449            // We are reusing the space that was previously occupied by escapes.
450            self.symbols[new_codes[i] as usize] = symbol;
451            symbol_lens[new_codes[i] as usize] = len as u8;
452        }
453
454        // Truncate the symbol table to only include the "true" symbols.
455        self.symbols.truncate(self.n_symbols as usize);
456
457        // Rewrite the codes_one_byte table to point at the new code values.
458        // Replace pseudocodes with escapes.
459        for byte in 0..=255 {
460            let one_byte = self.codes_one_byte[byte];
461            if one_byte.extended_code() >= FSST_CODE_BASE {
462                let new_code = new_codes[one_byte.code() as usize];
463                self.codes_one_byte[byte] = Code::new_symbol(new_code, 1);
464            } else {
465                // After finalize: codes_one_byte contains the unused value
466                self.codes_one_byte[byte] = Code::UNUSED;
467            }
468        }
469
470        // Rewrite the codes_two_byte table to point at the new code values.
471        // Replace pseudocodes with escapes.
472        for two_bytes in 0..=65_535 {
473            let two_byte = self.codes_two_byte[two_bytes];
474            if two_byte.extended_code() >= FSST_CODE_BASE {
475                let new_code = new_codes[two_byte.code() as usize];
476                self.codes_two_byte[two_bytes] = Code::new_symbol(new_code, 2);
477            } else {
478                // The one-byte code for the given code number here...
479                self.codes_two_byte[two_bytes] = self.codes_one_byte[two_bytes & 0xFF];
480            }
481        }
482
483        // Reset values in the hash table as well.
484        self.lossy_pht.renumber(&new_codes);
485
486        // Pre-compute the lengths
487        let mut lengths = Vec::with_capacity(self.n_symbols as usize);
488        for symbol in &self.symbols {
489            lengths.push(symbol.len() as u8);
490        }
491
492        (has_suffix_code, lengths)
493    }
494
495    /// Build into the final hash table.
496    pub fn build(mut self) -> Compressor {
497        // finalize the symbol table by inserting the codes_twobyte values into
498        // the relevant parts of the `codes_onebyte` set.
499
500        let (has_suffix_code, lengths) = self.finalize();
501
502        Compressor {
503            symbols: self.symbols,
504            lengths,
505            n_symbols: self.n_symbols,
506            has_suffix_code,
507            codes_two_byte: self.codes_two_byte,
508            lossy_pht: self.lossy_pht,
509        }
510    }
511}
512
/// Sampling fractions (out of 128) for each training generation: roughly `frac / 128` of the
/// sample lines are compressed in a generation, and the final value of 128 uses every line.
/// This is taken from the [FSST paper].
///
/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
#[cfg(not(miri))]
const GENERATIONS: [usize; 5] = [8, 38, 68, 98, 128];
/// Fewer generations under Miri (presumably to keep interpreted runs short).
#[cfg(miri)]
const GENERATIONS: [usize; 3] = [8, 38, 128];

/// Target number of bytes gathered into a training sample.
const FSST_SAMPLETARGET: usize = 16 * 1024;
/// Minimum capacity required of the buffer that holds the training sample.
const FSST_SAMPLEMAX: usize = 32 * 1024;
/// Maximum number of bytes copied from an input line per sampled chunk.
const FSST_SAMPLELINE: usize = 512;
524
525/// Create a sample from a set of strings in the input.
526///
527/// Sample is constructing by copying "chunks" from the `str_in`s into the `sample_buf`, the
528/// returned slices are pointers into the `sample_buf`.
529///
530/// SAFETY: sample_buf must be >= FSST_SAMPLEMAX bytes long. Providing something less may cause unexpected failures.
531#[allow(clippy::ptr_arg)]
532fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>) -> Vec<&'a [u8]> {
533    assert!(
534        sample_buf.capacity() >= FSST_SAMPLEMAX,
535        "sample_buf.len() < FSST_SAMPLEMAX"
536    );
537
538    let mut sample: Vec<&[u8]> = Vec::new();
539
540    let tot_size: usize = str_in.iter().map(|s| s.len()).sum();
541    if tot_size < FSST_SAMPLETARGET {
542        return str_in.clone();
543    }
544
545    let mut sample_rnd = fsst_hash(4637947);
546    let sample_lim = FSST_SAMPLETARGET;
547    let mut sample_buf_offset: usize = 0;
548
549    while sample_buf_offset < sample_lim {
550        sample_rnd = fsst_hash(sample_rnd);
551        let line_nr = (sample_rnd as usize) % str_in.len();
552
553        // Find the first non-empty chunk starting at line_nr, wrapping around if
554        // necessary.
555        let Some(line) = (line_nr..str_in.len())
556            .chain(0..line_nr)
557            .map(|line_nr| str_in[line_nr])
558            .find(|line| !line.is_empty())
559        else {
560            return sample;
561        };
562
563        let chunks = 1 + ((line.len() - 1) / FSST_SAMPLELINE);
564        sample_rnd = fsst_hash(sample_rnd);
565        let chunk = FSST_SAMPLELINE * ((sample_rnd as usize) % chunks);
566
567        let len = FSST_SAMPLELINE.min(line.len() - chunk);
568
569        sample_buf.extend_from_slice(&line[chunk..chunk + len]);
570
571        // SAFETY: this is the data we just placed into `sample_buf` in the line above.
572        let slice =
573            unsafe { std::slice::from_raw_parts(sample_buf.as_ptr().add(sample_buf_offset), len) };
574
575        sample.push(slice);
576
577        sample_buf_offset += len;
578    }
579
580    sample
581}
582
/// Hash function used in various components of the library.
///
/// Multiplies the input by a fixed odd constant (wrapping on overflow) and XORs in the input
/// shifted right by 15 bits. This is equivalent to the FSST_HASH macro from the C++
/// implementation.
#[inline]
pub(crate) fn fsst_hash(value: u64) -> u64 {
    const MULTIPLIER: u64 = 2971215073;
    let mixed = value.wrapping_mul(MULTIPLIER);
    mixed ^ (value >> 15)
}
590
591impl Compressor {
592    /// Build and train a `Compressor` from a sample corpus of text.
593    ///
594    /// This function implements the generational algorithm described in the [FSST paper] Section
595    /// 4.3. Starting with an empty symbol table, it iteratively compresses the corpus, then attempts
596    /// to merge symbols when doing so would yield better compression than leaving them unmerged. The
597    /// resulting table will have at most 255 symbols (the 256th symbol is reserved for the escape
598    /// code).
599    ///
600    /// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
601    pub fn train(values: &Vec<&[u8]>) -> Self {
602        let mut builder = CompressorBuilder::new();
603
604        if values.is_empty() {
605            return builder.build();
606        }
607
608        let mut counters = Counter::new();
609        let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX);
610        let mut pqueue = BinaryHeap::with_capacity(65_536);
611
612        let sample = make_sample(&mut sample_memory, values);
613        for sample_frac in GENERATIONS {
614            for (i, line) in sample.iter().enumerate() {
615                if sample_frac < 128 && ((fsst_hash(i as u64) & 127) as usize) > sample_frac {
616                    continue;
617                }
618
619                builder.compress_count(line, &mut counters);
620            }
621
622            // Clear the heap before we use it again
623            pqueue.clear();
624            builder.optimize(&counters, sample_frac, &mut pqueue);
625            counters.clear();
626        }
627
628        builder.build()
629    }
630}
631
632impl CompressorBuilder {
633    /// Find the longest symbol using the hash table and the codes_one_byte and codes_two_byte indexes.
634    fn find_longest_symbol(&self, word: u64) -> Code {
635        // Probe the hash table first to see if we have a long match
636        let entry = self.lossy_pht.lookup(word);
637        let ignored_bits = entry.ignored_bits;
638
639        // If the entry is valid, return the code
640        if !entry.is_unused() && compare_masked(word, entry.symbol.to_u64(), ignored_bits) {
641            return entry.code;
642        }
643
644        // Try and match first two bytes
645        let twobyte = self.codes_two_byte[word as u16 as usize];
646        if twobyte.extended_code() >= FSST_CODE_BASE {
647            return twobyte;
648        }
649
650        // Fall back to single-byte match
651        self.codes_one_byte[word as u8 as usize]
652    }
653
654    /// Compress the text using the current symbol table. Count the code occurrences
655    /// and code-pair occurrences, calculating total gain using the current compressor.
656    ///
657    /// NOTE: this is largely an unfortunate amount of copy-paste from `compress`, just to make sure
658    /// we can do all the counting in a single pass.
659    fn compress_count(&self, sample: &[u8], counter: &mut Counter) -> usize {
660        let mut gain = 0;
661        if sample.is_empty() {
662            return gain;
663        }
664
665        let mut in_ptr = sample.as_ptr();
666
667        // SAFETY: `end` will point just after the end of the `plaintext` slice.
668        let in_end = unsafe { in_ptr.byte_add(sample.len()) };
669        let in_end_sub8 = in_end as usize - 8;
670
671        let mut prev_code: u16 = FSST_CODE_MASK;
672
673        while (in_ptr as usize) < (in_end_sub8) {
674            // SAFETY: ensured in-bounds by loop condition.
675            let word: u64 = unsafe { std::ptr::read_unaligned(in_ptr as *const u64) };
676            let code = self.find_longest_symbol(word);
677            let code_u16 = code.extended_code();
678
679            // Gain increases by the symbol length if a symbol matches, or 0
680            // if an escape is emitted.
681            gain += (code.len() as usize) - ((code_u16 < 256) as usize);
682
683            // Record the single and pair counts
684            counter.record_count1(code_u16);
685            counter.record_count2(prev_code, code_u16);
686
687            // Also record the count for just extending by a single byte, but only if
688            // the symbol is not itself a single byte.
689            if code.len() > 1 {
690                let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
691                counter.record_count1(code_first_byte);
692                counter.record_count2(prev_code, code_first_byte);
693            }
694
695            // SAFETY: pointer bound is checked in loop condition before any access is made.
696            in_ptr = unsafe { in_ptr.byte_add(code.len() as usize) };
697
698            prev_code = code_u16;
699        }
700
701        let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
702        assert!(
703            remaining_bytes.is_positive(),
704            "in_ptr exceeded in_end, should not be possible"
705        );
706        let remaining_bytes = remaining_bytes as usize;
707
708        // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
709        // but shift data out of this word rather than advancing an input pointer and potentially reading
710        // unowned memory
711        let mut bytes = [0u8; 8];
712        unsafe {
713            // SAFETY: it is safe to read up to remaining_bytes from in_ptr, and remaining_bytes
714            //  will be <= 8 bytes.
715            std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
716        }
717        let mut last_word = u64::from_le_bytes(bytes);
718
719        let mut remaining_bytes = remaining_bytes;
720
721        while remaining_bytes > 0 {
722            // SAFETY: ensured in-bounds by loop condition.
723            let code = self.find_longest_symbol(last_word);
724            let code_u16 = code.extended_code();
725
726            // Gain increases by the symbol length if a symbol matches, or 0
727            // if an escape is emitted.
728            gain += (code.len() as usize) - ((code_u16 < 256) as usize);
729
730            // Record the single and pair counts
731            counter.record_count1(code_u16);
732            counter.record_count2(prev_code, code_u16);
733
734            // Also record the count for just extending by a single byte, but only if
735            // the symbol is not itself a single byte.
736            if code.len() > 1 {
737                let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
738                counter.record_count1(code_first_byte);
739                counter.record_count2(prev_code, code_first_byte);
740            }
741
742            // Advance our last_word "input pointer" by shifting off the covered values.
743            let advance = code.len() as usize;
744            remaining_bytes -= advance;
745            last_word = advance_8byte_word(last_word, advance);
746
747            prev_code = code_u16;
748        }
749
750        gain
751    }
752
753    /// Using a set of counters and the existing set of symbols, build a new
754    /// set of symbols/codes that optimizes the gain over the distribution in `counter`.
755    fn optimize(
756        &mut self,
757        counters: &Counter,
758        sample_frac: usize,
759        pqueue: &mut BinaryHeap<Candidate>,
760    ) {
761        // Use a HashMap to deduplicate candidates by symbol content, combining gains
762        // when the same symbol is encountered via different codes.
763        // This matches the C++ implementation's use of unordered_set<QSymbol> with addOrInc.
764        // NOTE: we use fxhash since that is the best Rust hasher for 64-bit ints.
765        let mut candidates = FxHashMap::with_capacity_and_hasher(256, FxBuildHasher);
766
767        for code1 in counters.first_codes() {
768            let symbol1 = self.symbols[code1 as usize];
769            let symbol1_len = symbol1.len();
770            let count = counters.count1(code1);
771
772            // From the c++ impl:
773            // "improves both compression speed (less candidates), but also quality!!"
774            if count < (5 * sample_frac / 128) {
775                continue;
776            }
777
778            let mut gain = count * symbol1_len;
779            // NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols.
780            // This helps to reduce exception counts.
781            if symbol1_len == 1 {
782                gain *= 8;
783            }
784
785            // Add or combine gain for this symbol
786            *candidates.entry(symbol1).or_insert(0) += gain;
787
788            // Skip merges on last round, or when symbol cannot be extended.
789            if sample_frac >= 128 || symbol1_len == 8 {
790                continue;
791            }
792
793            for code2 in counters.second_codes(code1) {
794                let symbol2 = self.symbols[code2 as usize];
795
796                // If merging would yield a symbol of length greater than 8, skip.
797                if symbol1_len + symbol2.len() > 8 {
798                    continue;
799                }
800                let new_symbol = symbol1.concat(symbol2);
801                let gain = counters.count2(code1, code2) * new_symbol.len();
802
803                // Add or combine gain for this merged symbol
804                *candidates.entry(new_symbol).or_insert(0) += gain;
805            }
806        }
807
808        // Transfer deduplicated candidates to the priority queue
809        for (symbol, gain) in candidates {
810            pqueue.push(Candidate { symbol, gain });
811        }
812
813        // clear self in advance of inserting the symbols.
814        self.clear();
815
816        // Pop the 255 best symbols.
817        let mut n_symbols = 0;
818        while !pqueue.is_empty() && n_symbols < 255 {
819            let candidate = pqueue.pop().unwrap();
820            if self.insert(candidate.symbol, candidate.symbol.len()) {
821                n_symbols += 1;
822            }
823        }
824    }
825}
826
827/// A candidate for inclusion in a symbol table.
828///
829/// This is really only useful for the `optimize` step of training.
830#[derive(Copy, Clone, Debug)]
831struct Candidate {
832    gain: usize,
833    symbol: Symbol,
834}
835
836impl Candidate {
837    fn comparable_form(&self) -> (usize, usize) {
838        (self.gain, self.symbol.len())
839    }
840}
841
842impl Eq for Candidate {}
843
844impl PartialEq<Self> for Candidate {
845    fn eq(&self, other: &Self) -> bool {
846        self.comparable_form().eq(&other.comparable_form())
847    }
848}
849
850impl PartialOrd<Self> for Candidate {
851    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
852        Some(self.cmp(other))
853    }
854}
855
856impl Ord for Candidate {
857    fn cmp(&self, other: &Self) -> Ordering {
858        let self_ord = (self.gain, self.symbol.len());
859        let other_ord = (other.gain, other.symbol.len());
860
861        self_ord.cmp(&other_ord)
862    }
863}
864
#[cfg(test)]
mod test {
    use crate::{Compressor, ESCAPE_CODE, builder::CodesBitmap};

    #[test]
    fn test_builder() {
        // Train a Compressor on the toy string
        let text = b"hello hello hello hello hello";

        // count of 5 is the cutoff for including a symbol in the table.
        let table = Compressor::train(&vec![text, text, text, text, text]);

        // Use the table to compress a string, see the values
        let compressed = table.compress(text);

        // Ensure that the compressed string has no escape bytes
        assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));

        // Ensure that we can compress a string with no values seen at training time, with escape bytes
        let compressed = table.compress("xyz123".as_bytes());
        let decompressed = table.decompressor().decompress(&compressed);
        assert_eq!(&decompressed, b"xyz123");
        assert_eq!(
            compressed,
            vec![
                ESCAPE_CODE,
                b'x',
                ESCAPE_CODE,
                b'y',
                ESCAPE_CODE,
                b'z',
                ESCAPE_CODE,
                b'1',
                ESCAPE_CODE,
                b'2',
                ESCAPE_CODE,
                b'3',
            ]
        );
    }

    #[test]
    fn test_bitmap() {
        let mut map = CodesBitmap::default();
        map.set(10);
        map.set(100);
        map.set(500);

        let codes: Vec<u16> = map.codes().collect();
        assert_eq!(codes, vec![10u16, 100, 500]);

        // empty case
        let map = CodesBitmap::default();
        assert!(map.codes().collect::<Vec<_>>().is_empty());

        // edge case: first bit in each block is set
        let mut map = CodesBitmap::default();
        (0..8).for_each(|i| map.set(64 * i));
        assert_eq!(
            map.codes().collect::<Vec<_>>(),
            (0u16..8).map(|i| 64 * i).collect::<Vec<_>>(),
        );

        // Full bitmap case. There are only 512 values, so set them all. Note that `codes()` is
        // expected to yield only 0..=510: code 511 is set but not reported by the iterator.
        let mut map = CodesBitmap::default();
        for i in 0..512 {
            map.set(i);
        }
        assert_eq!(
            map.codes().collect::<Vec<_>>(),
            (0u16..511u16).collect::<Vec<_>>()
        );
    }

    // `CodesBitmap::set` checks its bound only with `debug_assert!`. Without debug assertions
    // (e.g. `cargo test --release`), `set(512)` would instead hit an out-of-bounds index on the
    // 8-word array. That panic message does not match `expected`, so only run this test when the
    // assertion is compiled in.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "code cannot exceed")]
    fn test_bitmap_invalid() {
        let mut map = CodesBitmap::default();
        map.set(512);
    }

    #[test]
    fn test_no_duplicate_symbols() {
        // Train on data that is likely to produce duplicate 1-byte and 2-byte candidates.
        let text = b"aababcabcdabcde";
        let corpus: Vec<&[u8]> = std::iter::repeat_n(text.as_slice(), 100).collect();
        let compressor = Compressor::train(&corpus);

        let symbols = compressor.symbol_table();
        let lengths = compressor.symbol_lengths();

        // Collect all 1-byte symbols and check for duplicates.
        let one_byte: Vec<u8> = symbols
            .iter()
            .zip(lengths.iter())
            .filter(|&(_, &len)| len == 1)
            .map(|(sym, _)| sym.first_byte())
            .collect();
        let mut one_byte_sorted = one_byte.clone();
        one_byte_sorted.sort();
        one_byte_sorted.dedup();
        assert_eq!(
            one_byte.len(),
            one_byte_sorted.len(),
            "duplicate 1-byte symbols found"
        );

        // Collect all 2-byte symbols and check for duplicates.
        let two_byte: Vec<u16> = symbols
            .iter()
            .zip(lengths.iter())
            .filter(|&(_, &len)| len == 2)
            .map(|(sym, _)| sym.first2())
            .collect();
        let mut two_byte_sorted = two_byte.clone();
        two_byte_sorted.sort();
        two_byte_sorted.dedup();
        assert_eq!(
            two_byte.len(),
            two_byte_sorted.len(),
            "duplicate 2-byte symbols found"
        );
    }
}