Skip to main content

fsst/
builder.rs

1//! Functions and types used for building a [`Compressor`] from a corpus of text.
2//!
3//! This module implements the logic from Algorithm 3 of the [FSST Paper].
4//!
5//! [FSST Paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
6
7use crate::{
8    Code, Compressor, FSST_CODE_BASE, FSST_CODE_MASK, Symbol, advance_8byte_word, compare_masked,
9    lossy_pht::LossyPHT,
10};
11use rustc_hash::{FxBuildHasher, FxHashMap};
12use std::cmp::Ordering;
13use std::collections::BinaryHeap;
14
/// Fixed-size bitmap holding one bit per code.
///
/// The backing storage is eight 64-bit words, so only indices in `0..512` can be
/// represented.
#[derive(Clone, Copy, Debug, Default)]
struct CodesBitmap {
    /// Bit `index % 64` of word `index >> 6` records whether `index` is present.
    codes: [u64; 8],
}

// Size check: eight `u64` words are expected to occupy 64 bytes.
assert_sizeof!(CodesBitmap => 64);
22
23impl CodesBitmap {
24    /// Set the indicated bit. Must be between 0 and [`FSST_CODE_MASK`][crate::FSST_CODE_MASK].
25    pub(crate) fn set(&mut self, index: usize) {
26        debug_assert!(
27            index <= FSST_CODE_MASK as usize,
28            "code cannot exceed {FSST_CODE_MASK}"
29        );
30
31        let map = index >> 6;
32        self.codes[map] |= 1 << (index % 64);
33    }
34
35    /// Check if `index` is present in the bitmap
36    pub(crate) fn is_set(&self, index: usize) -> bool {
37        debug_assert!(
38            index <= FSST_CODE_MASK as usize,
39            "code cannot exceed {FSST_CODE_MASK}"
40        );
41
42        let map = index >> 6;
43        self.codes[map] & (1 << (index % 64)) != 0
44    }
45
46    /// Get all codes set in this bitmap
47    pub(crate) fn codes(&self) -> CodesIterator<'_> {
48        CodesIterator {
49            inner: self,
50            index: 0,
51            block: self.codes[0],
52            reference: 0,
53        }
54    }
55
56    /// Clear the bitmap of all entries.
57    pub(crate) fn clear(&mut self) {
58        self.codes[0] = 0;
59        self.codes[1] = 0;
60        self.codes[2] = 0;
61        self.codes[3] = 0;
62        self.codes[4] = 0;
63        self.codes[5] = 0;
64        self.codes[6] = 0;
65        self.codes[7] = 0;
66    }
67}
68
/// Iterator over the codes that are set in a [`CodesBitmap`], in increasing order.
struct CodesIterator<'a> {
    /// The bitmap being iterated.
    inner: &'a CodesBitmap,
    /// Index of the 64-bit word of `inner` that is currently being scanned.
    index: usize,
    /// The not-yet-yielded bits of the current word, shifted so that bit 0
    /// corresponds to the code `reference`.
    block: u64,
    /// The code that bit 0 of `block` stands for.
    reference: usize,
}
75
76impl Iterator for CodesIterator<'_> {
77    type Item = u16;
78
79    fn next(&mut self) -> Option<Self::Item> {
80        // If current is zero, advance to next non-zero block
81        while self.block == 0 {
82            self.index += 1;
83            if self.index >= 8 {
84                return None;
85            }
86            self.block = self.inner.codes[self.index];
87            self.reference = self.index * 64;
88        }
89
90        // Find the next set bit in the current block.
91        let position = self.block.trailing_zeros() as usize;
92        let code = self.reference + position;
93
94        if code >= 511 {
95            return None;
96        }
97
98        // The next iteration will calculate with reference to the returned code + 1
99        self.reference = code + 1;
100        self.block = if position == 63 {
101            0
102        } else {
103            self.block >> (1 + position)
104        };
105
106        Some(code as u16)
107    }
108}
109
/// Frequency counters for single codes and for code pairs.
///
/// A slot of `counts1` / `counts2` is only meaningful while the matching bit is set in
/// `code1_index` / `pair_index`; clearing the counter resets the bitmaps only.
#[derive(Debug, Clone)]
struct Counter {
    /// Frequency count for each code.
    counts1: Vec<usize>,

    /// Frequency count for each code-pair, flattened into a single vector.
    counts2: Vec<usize>,

    /// Bitmap index for codes that appear in counts1
    code1_index: CodesBitmap,

    /// Bitmap index of pairs that have been set.
    ///
    /// `pair_index[code1].codes()` yields an iterator that can
    /// be used to find all possible codes that follow `code1`.
    pair_index: Vec<CodesBitmap>,
}
127
/// Number of slots in the single-code counters: one for every value in `0..=FSST_CODE_MASK`.
const COUNTS1_SIZE: usize = (FSST_CODE_MASK + 1) as usize;

// NOTE: in Rust, creating a 1D vector of length N^2 is ~4x faster than creating a 2-D vector,
//  because `vec!` has a specialization for zero.
//
// The pair counters form a flattened `COUNTS1_SIZE x COUNTS1_SIZE` table. Its last row
// (`code1 == FSST_CODE_MASK`) is the "no previous code" row: `compress_count` seeds its
// previous code with `FSST_CODE_MASK`, so the first code of a sample can be recorded as a
// pair without a branch for the first iteration.
const COUNTS2_SIZE: usize = COUNTS1_SIZE * COUNTS1_SIZE;
136
137impl Counter {
138    fn new() -> Self {
139        let mut counts1 = Vec::with_capacity(COUNTS1_SIZE);
140        let mut counts2 = Vec::with_capacity(COUNTS2_SIZE);
141        // SAFETY: all accesses to the vector go through the bitmap to ensure no uninitialized
142        //  data is ever read from these vectors.
143        unsafe {
144            counts1.set_len(COUNTS1_SIZE);
145            counts2.set_len(COUNTS2_SIZE);
146        }
147
148        Self {
149            counts1,
150            counts2,
151            code1_index: CodesBitmap::default(),
152            pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE],
153        }
154    }
155
156    #[inline]
157    fn record_count1(&mut self, code1: u16) {
158        // If not set, we want to start at one.
159        let base = if self.code1_index.is_set(code1 as usize) {
160            self.counts1[code1 as usize]
161        } else {
162            0
163        };
164
165        self.counts1[code1 as usize] = base + 1;
166        self.code1_index.set(code1 as usize);
167    }
168
169    #[inline]
170    fn record_count2(&mut self, code1: u16, code2: u16) {
171        debug_assert!(code1 == FSST_CODE_MASK || self.code1_index.is_set(code1 as usize));
172        debug_assert!(self.code1_index.is_set(code2 as usize));
173
174        let idx = (code1 as usize) * COUNTS1_SIZE + (code2 as usize);
175        if self.pair_index[code1 as usize].is_set(code2 as usize) {
176            self.counts2[idx] += 1;
177        } else {
178            self.counts2[idx] = 1;
179        }
180        self.pair_index[code1 as usize].set(code2 as usize);
181    }
182
183    #[inline]
184    fn count1(&self, code1: u16) -> usize {
185        debug_assert!(self.code1_index.is_set(code1 as usize));
186
187        self.counts1[code1 as usize]
188    }
189
190    #[inline]
191    fn count2(&self, code1: u16, code2: u16) -> usize {
192        debug_assert!(self.code1_index.is_set(code1 as usize));
193        debug_assert!(self.code1_index.is_set(code2 as usize));
194        debug_assert!(self.pair_index[code1 as usize].is_set(code2 as usize));
195
196        let idx = (code1 as usize) * 512 + (code2 as usize);
197        self.counts2[idx]
198    }
199
200    /// Returns an ordered iterator over the codes that were observed
201    /// in a call to [`Self::count1`].
202    fn first_codes(&self) -> CodesIterator<'_> {
203        self.code1_index.codes()
204    }
205
206    /// Returns an iterator over the codes that have been observed
207    /// to follow `code1`.
208    ///
209    /// This is the set of all values `code2` where there was
210    /// previously a call to `self.record_count2(code1, code2)`.
211    fn second_codes(&self, code1: u16) -> CodesIterator<'_> {
212        self.pair_index[code1 as usize].codes()
213    }
214
215    /// Clear the counters.
216    /// Note that this just touches the bitmaps and sets them all to invalid.
217    fn clear(&mut self) {
218        self.code1_index.clear();
219        for index in &mut self.pair_index {
220            index.clear();
221        }
222    }
223}
224
/// Entrypoint for building a new `Compressor`.
pub struct CompressorBuilder {
    /// Table mapping codes to symbols.
    ///
    /// While building, entries 0-255 hold the single-byte symbol for each byte value and
    /// inserted symbols are stored from index 256 onwards, in insertion order.
    symbols: Vec<Symbol>,

    /// The number of entries in the symbol table that have been populated, not counting
    /// the escape values.
    n_symbols: u8,

    /// Counts for number of symbols of each length.
    ///
    /// `len_histogram[len-1]` = count of the symbols of length `len`.
    len_histogram: [u8; 8],

    /// Inverted index mapping 1-byte symbols to codes.
    ///
    /// Indexed by the byte value. This is only used for building, not used by the final
    /// `Compressor`.
    codes_one_byte: Vec<Code>,

    /// Inverted index mapping 2-byte symbols to codes, indexed by the symbol's first two
    /// bytes interpreted as a `u16`.
    codes_two_byte: Vec<Code>,

    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
    lossy_pht: LossyPHT,
}
252
253impl CompressorBuilder {
254    /// Create a new builder.
255    pub fn new() -> Self {
256        // NOTE: `vec!` has a specialization for building a new vector of `0u64`. Because Symbol and u64
257        //  have the same bit pattern, we can allocate as u64 and transmute. If we do `vec![Symbol::EMPTY; N]`,
258        // that will create a new Vec and call `Symbol::EMPTY.clone()` `N` times which is considerably slower.
259        let symbols = vec![0u64; 511];
260
261        // SAFETY: transmute safety assured by the compiler.
262        let symbols: Vec<Symbol> = unsafe { std::mem::transmute(symbols) };
263
264        let mut table = Self {
265            symbols,
266            n_symbols: 0,
267            len_histogram: [0; 8],
268            codes_two_byte: Vec::with_capacity(65_536),
269            codes_one_byte: Vec::with_capacity(512),
270            lossy_pht: LossyPHT::new(),
271        };
272
273        // Populate the escape byte entries.
274        for byte in 0..=255 {
275            let symbol = Symbol::from_u8(byte);
276            table.symbols[byte as usize] = symbol;
277        }
278
279        // Fill codes_one_byte with pseudocodes for each byte.
280        for byte in 0..=255 {
281            // Push pseudocode for single-byte escape.
282            table.codes_one_byte.push(Code::new_escape(byte));
283        }
284
285        // Fill codes_two_byte with pseudocode of first byte
286        for idx in 0..=65_535 {
287            table.codes_two_byte.push(Code::new_escape(idx as u8));
288        }
289
290        table
291    }
292}
293
294impl Default for CompressorBuilder {
295    fn default() -> Self {
296        Self::new()
297    }
298}
299
300impl CompressorBuilder {
301    /// Attempt to insert a new symbol at the end of the table.
302    ///
303    /// # Panics
304    ///
305    /// Panics if the table is already full.
306    ///
307    /// # Returns
308    ///
309    /// Returns true if the symbol was inserted successfully, or false if it conflicted
310    /// with an existing symbol.
311    pub fn insert(&mut self, symbol: Symbol, len: usize) -> bool {
312        assert!(self.n_symbols < 255, "cannot insert into full symbol table");
313        assert_eq!(len, symbol.len(), "provided len must equal symbol.len()");
314
315        if len == 2 {
316            // shortCodes
317            self.codes_two_byte[symbol.first2() as usize] =
318                Code::new_symbol_building(self.n_symbols, 2);
319        } else if len == 1 {
320            // byteCodes
321            self.codes_one_byte[symbol.first_byte() as usize] =
322                Code::new_symbol_building(self.n_symbols, 1);
323        } else {
324            // Symbols of 3 or more bytes go into the hash table
325            if !self.lossy_pht.insert(symbol, len, self.n_symbols) {
326                return false;
327            }
328        }
329
330        // Increment length histogram.
331        self.len_histogram[len - 1] += 1;
332
333        // Insert successfully stored symbol at end of the symbol table
334        // Note the rescaling from range [0-254] -> [256, 510].
335        self.symbols[256 + (self.n_symbols as usize)] = symbol;
336        self.n_symbols += 1;
337        true
338    }
339
340    /// Clear all set items from the compressor.
341    ///
342    /// This is considerably faster than building a new Compressor from scratch for each
343    /// iteration of the `train` loop.
344    fn clear(&mut self) {
345        // Eliminate every observed code from the table.
346        for code in 0..(256 + self.n_symbols as usize) {
347            let symbol = self.symbols[code];
348            if symbol.len() == 1 {
349                // Reset the entry from the codes_one_byte array.
350                self.codes_one_byte[symbol.first_byte() as usize] =
351                    Code::new_escape(symbol.first_byte());
352            } else if symbol.len() == 2 {
353                // Reset the entry from the codes_two_byte array.
354                self.codes_two_byte[symbol.first2() as usize] =
355                    Code::new_escape(symbol.first_byte());
356            } else {
357                // Clear the hashtable entry
358                self.lossy_pht.remove(symbol);
359            }
360        }
361
362        // Reset len histogram
363        for i in 0..=7 {
364            self.len_histogram[i] = 0;
365        }
366
367        self.n_symbols = 0;
368    }
369
370    /// Finalizing the table is done once building is complete to prepare for efficient
371    /// compression.
372    ///
373    /// When we finalize the table, the following modifications are made in-place:
374    ///
375    /// 1. The codes are renumbered so that all symbols are ordered by length (order 23456781).
376    ///    During this process, the two byte symbols are separated into a byte_lim and a suffix_lim,
377    ///    so we know that we don't need to check the suffix limitations instead.
378    /// 2. The 1-byte symbols index is merged into the 2-byte symbols index to allow for use of only
379    ///    a single index in front of the hash table.
380    ///
381    /// # Returns
382    ///
383    /// Returns the `suffix_lim`, which is the index of the two-byte code before where we know
384    /// there are no longer suffixies in the symbol table.
385    ///
386    /// Also returns the lengths vector, which is of length `n_symbols` and contains the
387    /// length for each of the values.
388    fn finalize(&mut self) -> (u8, Vec<u8>) {
389        // Create a cumulative sum of each of the elements of the input line numbers.
390        // Do a map that includes the previously seen value as well.
391        // Regroup symbols based on their lengths.
392        // Space at the end of the symbol table reserved for the one-byte codes.
393        let byte_lim = self.n_symbols - self.len_histogram[0];
394
395        // Start code for each length.
396        // Length 1: at the end of symbol table.
397        // Length 2: starts at 0. Split into before/after suffixLim.
398        let mut codes_by_length = [0u8; 8];
399        codes_by_length[0] = byte_lim;
400        codes_by_length[1] = 0;
401
402        // codes for lengths 3..=8 start where the previous ones end.
403        for i in 1..7 {
404            codes_by_length[i + 1] = codes_by_length[i] + self.len_histogram[i];
405        }
406
407        // no_suffix_code is the lowest code for a symbol that does not have a longer 3+ byte
408        // suffix in the table.
409        // This value starts at 0 and extends up.
410        let mut no_suffix_code = 0;
411
412        // The codes that do not have a suffix begin just before the range of the 3-byte codes.
413        let mut has_suffix_code = codes_by_length[2];
414
415        // Assign each symbol a new code ordered by lengths, in the order
416        // 2(no suffix) | 2 (suffix) | 3 | 4 | 5 | 6 | 7 | 8 | 1
417        let mut new_codes = [0u8; FSST_CODE_BASE as usize];
418
419        let mut symbol_lens = [0u8; FSST_CODE_BASE as usize];
420
421        for i in 0..(self.n_symbols as usize) {
422            let symbol = self.symbols[256 + i];
423            let len = symbol.len();
424            if len == 2 {
425                let has_suffix = self
426                    .symbols
427                    .iter()
428                    .skip(FSST_CODE_BASE as usize)
429                    .enumerate()
430                    .any(|(k, other)| i != k && symbol.first2() == other.first2());
431
432                if has_suffix {
433                    // Symbols that have a longer suffix are inserted at the end of the 2-byte range
434                    has_suffix_code -= 1;
435                    new_codes[i] = has_suffix_code;
436                } else {
437                    // Symbols that do not have a longer suffix are inserted at the start of
438                    // the 2-byte range.
439                    new_codes[i] = no_suffix_code;
440                    no_suffix_code += 1;
441                }
442            } else {
443                // Assign new code based on the next code available for the given length symbol
444                new_codes[i] = codes_by_length[len - 1];
445                codes_by_length[len - 1] += 1;
446            }
447
448            // Write the symbol into the front half of the symbol table.
449            // We are reusing the space that was previously occupied by escapes.
450            self.symbols[new_codes[i] as usize] = symbol;
451            symbol_lens[new_codes[i] as usize] = len as u8;
452        }
453
454        // Truncate the symbol table to only include the "true" symbols.
455        self.symbols.truncate(self.n_symbols as usize);
456
457        // Rewrite the codes_one_byte table to point at the new code values.
458        // Replace pseudocodes with escapes.
459        for byte in 0..=255 {
460            let one_byte = self.codes_one_byte[byte];
461            if one_byte.extended_code() >= FSST_CODE_BASE {
462                let new_code = new_codes[one_byte.code() as usize];
463                self.codes_one_byte[byte] = Code::new_symbol(new_code, 1);
464            } else {
465                // After finalize: codes_one_byte contains the unused value
466                self.codes_one_byte[byte] = Code::UNUSED;
467            }
468        }
469
470        // Rewrite the codes_two_byte table to point at the new code values.
471        // Replace pseudocodes with escapes.
472        for two_bytes in 0..=65_535 {
473            let two_byte = self.codes_two_byte[two_bytes];
474            if two_byte.extended_code() >= FSST_CODE_BASE {
475                let new_code = new_codes[two_byte.code() as usize];
476                self.codes_two_byte[two_bytes] = Code::new_symbol(new_code, 2);
477            } else {
478                // The one-byte code for the given code number here...
479                self.codes_two_byte[two_bytes] = self.codes_one_byte[two_bytes & 0xFF];
480            }
481        }
482
483        // Reset values in the hash table as well.
484        self.lossy_pht.renumber(&new_codes);
485
486        // Pre-compute the lengths
487        let mut lengths = Vec::with_capacity(self.n_symbols as usize);
488        for symbol in &self.symbols {
489            lengths.push(symbol.len() as u8);
490        }
491
492        (has_suffix_code, lengths)
493    }
494
495    /// Build into the final hash table.
496    pub fn build(mut self) -> Compressor {
497        // finalize the symbol table by inserting the codes_twobyte values into
498        // the relevant parts of the `codes_onebyte` set.
499
500        let (has_suffix_code, lengths) = self.finalize();
501
502        Compressor {
503            symbols: self.symbols,
504            lengths,
505            n_symbols: self.n_symbols,
506            has_suffix_code,
507            codes_two_byte: self.codes_two_byte,
508            lossy_pht: self.lossy_pht,
509        }
510    }
511}
512
/// The number of generations used for training. This is taken from the [FSST paper].
///
/// Each entry is the sample fraction (out of 128) used by that generation: `train` skips
/// a sample line when the low 7 bits of the hash of its index exceed the fraction, and
/// uses every line once the fraction reaches 128.
///
/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
#[cfg(not(miri))]
const GENERATIONS: [usize; 5] = [8usize, 38, 68, 98, 128];
// Shorter schedule used when running under miri.
#[cfg(miri)]
const GENERATIONS: [usize; 3] = [8usize, 38, 128];

/// Number of sampled bytes to gather; inputs smaller than this are used without sampling.
const FSST_SAMPLETARGET: usize = 1 << 14;
/// Minimum capacity required of the buffer that sampled chunks are copied into.
const FSST_SAMPLEMAX: usize = 1 << 15;
/// Maximum length of a single sampled chunk.
const FSST_SAMPLELINE: usize = 512;
524
525/// Create a sample from a set of strings in the input.
526///
527/// Sample is constructing by copying "chunks" from the `str_in`s into the `sample_buf`, the
528/// returned slices are pointers into the `sample_buf`.
529///
530/// SAFETY: sample_buf must be >= FSST_SAMPLEMAX bytes long. Providing something less may cause unexpected failures.
531#[allow(clippy::ptr_arg)]
532fn make_sample<'a, 'b: 'a>(
533    sample_buf: &'a mut Vec<u8>,
534    str_in: &Vec<&'b [u8]>,
535    tot_size: usize,
536) -> Vec<&'a [u8]> {
537    assert!(
538        sample_buf.capacity() >= FSST_SAMPLEMAX,
539        "sample_buf.len() < FSST_SAMPLEMAX"
540    );
541
542    let mut sample: Vec<&[u8]> = Vec::new();
543
544    if tot_size < FSST_SAMPLETARGET {
545        return str_in.clone();
546    }
547
548    let mut sample_rnd = fsst_hash(4637947);
549    let sample_lim = FSST_SAMPLETARGET;
550    let mut sample_buf_offset: usize = 0;
551
552    while sample_buf_offset < sample_lim {
553        sample_rnd = fsst_hash(sample_rnd);
554        let line_nr = (sample_rnd as usize) % str_in.len();
555
556        // Find the first non-empty chunk starting at line_nr, wrapping around if
557        // necessary.
558        let Some(line) = (line_nr..str_in.len())
559            .chain(0..line_nr)
560            .map(|line_nr| str_in[line_nr])
561            .find(|line| !line.is_empty())
562        else {
563            return sample;
564        };
565
566        let chunks = 1 + ((line.len() - 1) / FSST_SAMPLELINE);
567        sample_rnd = fsst_hash(sample_rnd);
568        let chunk = FSST_SAMPLELINE * ((sample_rnd as usize) % chunks);
569
570        let len = FSST_SAMPLELINE.min(line.len() - chunk);
571
572        sample_buf.extend_from_slice(&line[chunk..chunk + len]);
573
574        // SAFETY: this is the data we just placed into `sample_buf` in the line above.
575        let slice =
576            unsafe { std::slice::from_raw_parts(sample_buf.as_ptr().add(sample_buf_offset), len) };
577
578        sample.push(slice);
579
580        sample_buf_offset += len;
581    }
582
583    sample
584}
585
/// Hash function used in various components of the library.
///
/// Computes `value * 2971215073` (wrapping) XORed with `value >> 15`.
///
/// This is meant as the counterpart of the FSST_HASH macro from the C++ implementation.
/// NOTE(review): the XOR uses the shifted *input* rather than the shifted product --
/// confirm that this matches the C++ macro.
#[inline]
pub(crate) fn fsst_hash(value: u64) -> u64 {
    const PRIME: u64 = 2971215073;
    const SHIFT: u32 = 15;

    let product = value.wrapping_mul(PRIME);
    product ^ (value >> SHIFT)
}
593
594impl Compressor {
595    /// Build and train a `Compressor` from a sample corpus of text.
596    ///
597    /// This function implements the generational algorithm described in the [FSST paper] Section
598    /// 4.3. Starting with an empty symbol table, it iteratively compresses the corpus, then attempts
599    /// to merge symbols when doing so would yield better compression than leaving them unmerged. The
600    /// resulting table will have at most 255 symbols (the 256th symbol is reserved for the escape
601    /// code).
602    ///
603    /// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
604    pub fn train(values: &Vec<&[u8]>) -> Self {
605        let mut builder = CompressorBuilder::new();
606
607        if values.is_empty() {
608            return builder.build();
609        }
610
611        let mut counters = Counter::new();
612        let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX);
613        let mut pqueue = BinaryHeap::with_capacity(65_536);
614
615        let tot_size: usize = values.iter().map(|s| s.len()).sum();
616        let sampled = tot_size >= FSST_SAMPLETARGET;
617        let sample = make_sample(&mut sample_memory, values, tot_size);
618        for sample_frac in GENERATIONS {
619            for (i, line) in sample.iter().enumerate() {
620                if sample_frac < 128 && ((fsst_hash(i as u64) & 127) as usize) > sample_frac {
621                    continue;
622                }
623
624                builder.compress_count(line, &mut counters);
625            }
626
627            // Clear the heap before we use it again
628            pqueue.clear();
629            let prune = sample_frac >= 128 && !sampled;
630            builder.optimize(&counters, sample_frac, &mut pqueue, prune);
631            counters.clear();
632        }
633
634        builder.build()
635    }
636}
637
638impl CompressorBuilder {
639    /// Find the longest symbol using the hash table and the codes_one_byte and codes_two_byte indexes.
640    fn find_longest_symbol(&self, word: u64) -> Code {
641        // Probe the hash table first to see if we have a long match
642        let entry = self.lossy_pht.lookup(word);
643        let ignored_bits = entry.ignored_bits;
644
645        // If the entry is valid, return the code
646        if !entry.is_unused() && compare_masked(word, entry.symbol.to_u64(), ignored_bits) {
647            return entry.code;
648        }
649
650        // Try and match first two bytes
651        let twobyte = self.codes_two_byte[word as u16 as usize];
652        if twobyte.extended_code() >= FSST_CODE_BASE {
653            return twobyte;
654        }
655
656        // Fall back to single-byte match
657        self.codes_one_byte[word as u8 as usize]
658    }
659
660    /// Compress the text using the current symbol table. Count the code occurrences
661    /// and code-pair occurrences, calculating total gain using the current compressor.
662    ///
663    /// NOTE: this is largely an unfortunate amount of copy-paste from `compress`, just to make sure
664    /// we can do all the counting in a single pass.
665    fn compress_count(&self, sample: &[u8], counter: &mut Counter) -> usize {
666        let mut gain = 0;
667        if sample.is_empty() {
668            return gain;
669        }
670
671        let mut in_ptr = sample.as_ptr();
672
673        // SAFETY: `end` will point just after the end of the `plaintext` slice.
674        let in_end = unsafe { in_ptr.byte_add(sample.len()) };
675        let in_end_sub8 = in_end as usize - 8;
676
677        let mut prev_code: u16 = FSST_CODE_MASK;
678
679        while (in_ptr as usize) < (in_end_sub8) {
680            // SAFETY: ensured in-bounds by loop condition.
681            let word: u64 = unsafe { std::ptr::read_unaligned(in_ptr as *const u64) };
682            let code = self.find_longest_symbol(word);
683            let code_u16 = code.extended_code();
684
685            // Gain increases by the symbol length if a symbol matches, or 0
686            // if an escape is emitted.
687            gain += (code.len() as usize) - ((code_u16 < 256) as usize);
688
689            // Record the single and pair counts
690            counter.record_count1(code_u16);
691            counter.record_count2(prev_code, code_u16);
692
693            // Also record the count for just extending by a single byte, but only if
694            // the symbol is not itself a single byte.
695            if code.len() > 1 {
696                let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
697                counter.record_count1(code_first_byte);
698                counter.record_count2(prev_code, code_first_byte);
699            }
700
701            // SAFETY: pointer bound is checked in loop condition before any access is made.
702            in_ptr = unsafe { in_ptr.byte_add(code.len() as usize) };
703
704            prev_code = code_u16;
705        }
706
707        let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
708        assert!(
709            remaining_bytes.is_positive(),
710            "in_ptr exceeded in_end, should not be possible"
711        );
712        let remaining_bytes = remaining_bytes as usize;
713
714        // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
715        // but shift data out of this word rather than advancing an input pointer and potentially reading
716        // unowned memory
717        let mut bytes = [0u8; 8];
718        unsafe {
719            // SAFETY: it is safe to read up to remaining_bytes from in_ptr, and remaining_bytes
720            //  will be <= 8 bytes.
721            std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
722        }
723        let mut last_word = u64::from_le_bytes(bytes);
724
725        let mut remaining_bytes = remaining_bytes;
726
727        while remaining_bytes > 0 {
728            // SAFETY: ensured in-bounds by loop condition.
729            let code = self.find_longest_symbol(last_word);
730            let code_u16 = code.extended_code();
731
732            // Gain increases by the symbol length if a symbol matches, or 0
733            // if an escape is emitted.
734            gain += (code.len() as usize) - ((code_u16 < 256) as usize);
735
736            // Record the single and pair counts
737            counter.record_count1(code_u16);
738            counter.record_count2(prev_code, code_u16);
739
740            // Also record the count for just extending by a single byte, but only if
741            // the symbol is not itself a single byte.
742            if code.len() > 1 {
743                let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
744                counter.record_count1(code_first_byte);
745                counter.record_count2(prev_code, code_first_byte);
746            }
747
748            // Advance our last_word "input pointer" by shifting off the covered values.
749            let advance = code.len() as usize;
750            remaining_bytes -= advance;
751            last_word = advance_8byte_word(last_word, advance);
752
753            prev_code = code_u16;
754        }
755
756        gain
757    }
758
759    /// Using a set of counters and the existing set of symbols, build a new
760    /// set of symbols/codes that optimizes the gain over the distribution in `counter`.
761    fn optimize(
762        &mut self,
763        counters: &Counter,
764        sample_frac: usize,
765        pqueue: &mut BinaryHeap<Candidate>,
766        prune: bool,
767    ) {
768        // Use a HashMap to deduplicate candidates by symbol content, combining gains
769        // when the same symbol is encountered via different codes.
770        // This matches the C++ implementation's use of unordered_set<QSymbol> with addOrInc.
771        // NOTE: we use fxhash since that is the best Rust hasher for 64-bit ints.
772        let mut candidates = FxHashMap::with_capacity_and_hasher(256, FxBuildHasher);
773
774        for code1 in counters.first_codes() {
775            let symbol1 = self.symbols[code1 as usize];
776            let symbol1_len = symbol1.len();
777            let count = counters.count1(code1);
778
779            // From the c++ impl:
780            // "improves both compression speed (less candidates), but also quality!!"
781            // When pruning (final pass, exact counts), lower the threshold to 1
782            // so the pruning check can decide based on cost/benefit.
783            let min_count = if prune { 1 } else { 5 * sample_frac / 128 };
784            if count < min_count {
785                continue;
786            }
787
788            let mut gain = count * symbol1_len;
789            // NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols.
790            // This helps to reduce exception counts.
791            if symbol1_len == 1 {
792                gain *= 8;
793            }
794
795            // Add or combine gain for this symbol
796            *candidates.entry(symbol1).or_insert(0) += gain;
797
798            // Skip merges on last round, or when symbol cannot be extended.
799            if sample_frac >= 128 || symbol1_len == 8 {
800                continue;
801            }
802
803            for code2 in counters.second_codes(code1) {
804                let symbol2 = self.symbols[code2 as usize];
805
806                // If merging would yield a symbol of length greater than 8, skip.
807                if symbol1_len + symbol2.len() > 8 {
808                    continue;
809                }
810                let new_symbol = symbol1.concat(symbol2);
811                let gain = counters.count2(code1, code2) * new_symbol.len();
812
813                // Add or combine gain for this merged symbol
814                *candidates.entry(new_symbol).or_insert(0) += gain;
815            }
816        }
817
818        // Transfer deduplicated candidates to the priority queue
819        for (symbol, gain) in candidates {
820            pqueue.push(Candidate { symbol, gain });
821        }
822
823        // clear self in advance of inserting the symbols.
824        self.clear();
825
826        // Pop the 255 best symbols.
827        let mut n_symbols = 0;
828        while !pqueue.is_empty() && n_symbols < 255 {
829            let candidate = pqueue.pop().unwrap();
830            if prune {
831                let symbol_len = candidate.symbol.len();
832                let saves = if symbol_len == 1 {
833                    candidate.gain / 8 // undo the 8x single-byte boost
834                } else {
835                    candidate.gain
836                };
837                if saves <= symbol_len + 1 {
838                    continue;
839                }
840            }
841            if self.insert(candidate.symbol, candidate.symbol.len()) {
842                n_symbols += 1;
843            }
844        }
845    }
846}
847
848/// A candidate for inclusion in a symbol table.
849///
850/// This is really only useful for the `optimize` step of training.
851#[derive(Copy, Clone, Debug)]
852struct Candidate {
853    gain: usize,
854    symbol: Symbol,
855}
856
857impl Candidate {
858    fn comparable_form(&self) -> (usize, usize) {
859        (self.gain, self.symbol.len())
860    }
861}
862
863impl Eq for Candidate {}
864
865impl PartialEq<Self> for Candidate {
866    fn eq(&self, other: &Self) -> bool {
867        self.comparable_form().eq(&other.comparable_form())
868    }
869}
870
871impl PartialOrd<Self> for Candidate {
872    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
873        Some(self.cmp(other))
874    }
875}
876
877impl Ord for Candidate {
878    fn cmp(&self, other: &Self) -> Ordering {
879        let self_ord = (self.gain, self.symbol.len());
880        let other_ord = (other.gain, other.symbol.len());
881
882        self_ord.cmp(&other_ord)
883    }
884}
885
#[cfg(test)]
mod test {
    use crate::{Compressor, ESCAPE_CODE, builder::CodesBitmap};

    /// Trains on a tiny repetitive corpus, then checks that the training text compresses
    /// without escapes while unseen bytes round-trip through escape codes.
    #[test]
    fn test_builder() {
        // Train a Compressor on the toy string
        let text = b"hello hello hello hello hello";

        // count of 5 is the cutoff for including a symbol in the table.
        let table = Compressor::train(&vec![text, text, text, text, text]);

        // Use the table to compress a string, see the values
        let compressed = table.compress(text);

        // Ensure that the compressed string has no escape bytes
        assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));

        // Ensure that we can compress a string with no values seen at training time, with escape bytes
        let compressed = table.compress("xyz123".as_bytes());
        let decompressed = table.decompressor().decompress(&compressed);
        assert_eq!(&decompressed, b"xyz123");
        // Every unseen byte is expected to be emitted as an (escape, literal) pair.
        assert_eq!(
            compressed,
            vec![
                ESCAPE_CODE,
                b'x',
                ESCAPE_CODE,
                b'y',
                ESCAPE_CODE,
                b'z',
                ESCAPE_CODE,
                b'1',
                ESCAPE_CODE,
                b'2',
                ESCAPE_CODE,
                b'3',
            ]
        );
    }

    /// Exercises `CodesBitmap::set` and `CodesBitmap::codes` on sparse, empty,
    /// block-boundary and fully populated bitmaps.
    #[test]
    fn test_bitmap() {
        let mut map = CodesBitmap::default();
        map.set(10);
        map.set(100);
        map.set(500);

        let codes: Vec<u16> = map.codes().collect();
        assert_eq!(codes, vec![10u16, 100, 500]);

        // empty case
        let map = CodesBitmap::default();
        assert!(map.codes().collect::<Vec<_>>().is_empty());

        // edge case: first bit in each block is set
        let mut map = CodesBitmap::default();
        (0..8).for_each(|i| map.set(64 * i));
        assert_eq!(
            map.codes().collect::<Vec<_>>(),
            (0u16..8).map(|i| 64 * i).collect::<Vec<_>>(),
        );

        // Full bitmap case. There are only 512 values, so test them all
        let mut map = CodesBitmap::default();
        for i in 0..512 {
            map.set(i);
        }
        // NOTE(review): all 512 bits are set, yet only 0..=510 are expected back.
        // Presumably the iterator never yields the top code (511) - confirm against
        // the `CodesIterator` implementation before changing this range.
        assert_eq!(
            map.codes().collect::<Vec<_>>(),
            (0u16..511u16).collect::<Vec<_>>()
        );
    }

    /// Setting a bit past the supported range must trip the bounds assertion in `set`.
    ///
    /// That check is a `debug_assert!`, which is compiled out when debug assertions are
    /// disabled (e.g. `cargo test --release`). There, `set(512)` would instead index past
    /// the end of the 8-word backing array and panic with a different message, failing the
    /// `expected` match. The test is therefore only built when debug assertions are on.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "code cannot exceed")]
    fn test_bitmap_invalid() {
        let mut map = CodesBitmap::default();
        map.set(512);
    }

    /// The trained symbol table must not contain the same 1-byte or 2-byte symbol twice.
    #[test]
    fn test_no_duplicate_symbols() {
        // Train on data that is likely to produce duplicate 1-byte and 2-byte candidates.
        let text = b"aababcabcdabcde";
        let corpus: Vec<&[u8]> = std::iter::repeat_n(text.as_slice(), 100).collect();
        let compressor = Compressor::train(&corpus);

        let symbols = compressor.symbol_table();
        let lengths = compressor.symbol_lengths();

        // Collect all 1-byte symbols and check for duplicates.
        let one_byte: Vec<u8> = symbols
            .iter()
            .zip(lengths.iter())
            .filter(|&(_, &len)| len == 1)
            .map(|(sym, _)| sym.first_byte())
            .collect();
        // Sorting + dedup drops repeats, so any length difference means a duplicate existed.
        let mut one_byte_sorted = one_byte.clone();
        one_byte_sorted.sort_unstable();
        one_byte_sorted.dedup();
        assert_eq!(
            one_byte.len(),
            one_byte_sorted.len(),
            "duplicate 1-byte symbols found"
        );

        // Collect all 2-byte symbols and check for duplicates.
        let two_byte: Vec<u16> = symbols
            .iter()
            .zip(lengths.iter())
            .filter(|&(_, &len)| len == 2)
            .map(|(sym, _)| sym.first2())
            .collect();
        let mut two_byte_sorted = two_byte.clone();
        two_byte_sorted.sort_unstable();
        two_byte_sorted.dedup();
        assert_eq!(
            two_byte.len(),
            two_byte_sorted.len(),
            "duplicate 2-byte symbols found"
        );
    }
}