fsst/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Fail compilation unless the type `$typ` occupies exactly `$size_in_bytes` bytes.
///
/// Expands to an anonymous `const` item, so it may be used at module level or inside a
/// function body.
macro_rules! assert_sizeof {
    ($typ:ty => $size_in_bytes:expr) => {
        const _: () = assert!(
            std::mem::size_of::<$typ>() == $size_in_bytes,
            "unexpected type size"
        );
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
20/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and
21/// identified by an 8-bit code.
22#[derive(Copy, Clone)]
23pub struct Symbol(u64);
24
25assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28    /// Zero value for `Symbol`.
29    pub const ZERO: Self = Self::zero();
30
31    /// Constructor for a `Symbol` from an 8-element byte slice.
32    pub fn from_slice(slice: &[u8; 8]) -> Self {
33        let num: u64 = u64::from_le_bytes(*slice);
34
35        Self(num)
36    }
37
38    /// Return a zero symbol
39    const fn zero() -> Self {
40        Self(0)
41    }
42
43    /// Create a new single-byte symbol
44    pub fn from_u8(value: u8) -> Self {
45        Self(value as u64)
46    }
47}
48
49impl Symbol {
50    /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51    ///
52    /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53    /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54    /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55    /// but we want to interpret that as a one-byte symbol containing `0x00`.
56    #[allow(clippy::len_without_is_empty)]
57    pub fn len(self) -> usize {
58        let numeric = self.0;
59        // For little-endian platforms, this counts the number of *trailing* zeros
60        let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62        // Special case handling of a symbol with all-zeros. This is actually
63        // a 1-byte symbol containing 0x00.
64        let len = size_of::<Self>() - null_bytes;
65        if len == 0 { 1 } else { len }
66    }
67
68    #[inline]
69    fn as_u64(self) -> u64 {
70        self.0
71    }
72
73    /// Get the first byte of the symbol as a `u8`.
74    ///
75    /// If the symbol is empty, this will return the zero byte.
76    #[inline]
77    pub fn first_byte(self) -> u8 {
78        self.0 as u8
79    }
80
81    /// Get the first two bytes of the symbol as a `u16`.
82    ///
83    /// If the Symbol is one or zero bytes, this will return `0u16`.
84    #[inline]
85    pub fn first2(self) -> u16 {
86        self.0 as u16
87    }
88
89    /// Get the first three bytes of the symbol as a `u64`.
90    ///
91    /// If the Symbol is one or zero bytes, this will return `0u64`.
92    #[inline]
93    pub fn first3(self) -> u64 {
94        self.0 & 0xFF_FF_FF
95    }
96
97    /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
98    pub fn concat(self, other: Self) -> Self {
99        assert!(
100            self.len() + other.len() <= 8,
101            "cannot build symbol with length > 8"
102        );
103
104        let self_len = self.len();
105
106        Self((other.0 << (8 * self_len)) | self.0)
107    }
108}
109
110impl Debug for Symbol {
111    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
112        write!(f, "[")?;
113
114        let slice = &self.0.to_le_bytes()[0..self.len()];
115        for c in slice.iter().map(|c| *c as char) {
116            if ('!'..='~').contains(&c) {
117                write!(f, "{c}")?;
118            } else if c == '\n' {
119                write!(f, " \\n ")?;
120            } else if c == '\t' {
121                write!(f, " \\t ")?;
122            } else if c == ' ' {
123                write!(f, " SPACE ")?;
124            } else {
125                write!(f, " 0x{:X?} ", c as u8)?
126            }
127        }
128
129        write!(f, "]")
130    }
131}
132
/// A packed type containing a code value, as well as metadata about the symbol referred to by
/// the code.
///
/// Logically, codes can range from 0-255 inclusive. This type holds both the 8-bit code as well as
/// other metadata bit-packed into a `u16`.
///
/// The bottom 8 bits contain EITHER a code for a symbol stored in the table, OR a raw byte.
///
/// The interpretation depends on the 9th bit: when toggled off, the value stores a raw byte, and when
/// toggled on, it stores a code. Thus if you examine the bottom 9 bits of the `u16`, you have an extended
/// code range, where the values 0-255 are raw bytes, and the values 256-510 represent codes 0-254. 511 is
/// a placeholder for the invalid code here.
///
/// Bits 12-15 store the length of the symbol (values ranging from 0-8).
///
/// Note: `Code::new_symbol` stores a symbol code with the 9th bit *cleared*; the
/// extended-code interpretation above applies to codes built with
/// `Code::new_symbol_building` and `Code::new_escape`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct Code(u16);

/// Code used to indicate bytes that are not in the symbol table.
///
/// When compressing a string that cannot fully be expressed with the symbol table, the compressed
/// output will contain an `ESCAPE` byte followed by a raw byte. At decompression time, the presence
/// of `ESCAPE` indicates that the next byte should be appended directly to the result instead of
/// being looked up in the symbol table.
pub const ESCAPE_CODE: u8 = 255;

/// Number of bits in the extended code that are used to dictate a code value.
pub const FSST_CODE_BITS: usize = 9;

/// Bit offset at which the "length" portion of a packed code begins.
pub const FSST_LEN_BITS: usize = 12;

/// Exclusive upper bound of the extended code range (`1 << FSST_CODE_BITS`, i.e. 512).
pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS;

/// Maximum value for the extended code range; also the mask that extracts it.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First extended code that corresponds to a non-escape symbol (symbol code 0).
pub const FSST_CODE_BASE: u16 = 256;

#[allow(clippy::len_without_is_empty)]
impl Code {
    /// Code for an unused slot in a symbol table or index.
    ///
    /// This corresponds to the maximum code with a length of 1.
    pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << FSST_LEN_BITS));

    /// Create a new code for a symbol of given length.
    fn new_symbol(code: u8, len: usize) -> Self {
        Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
    }

    /// Code for a new symbol during the building phase.
    ///
    /// The code is remapped from 0..254 to 256...510.
    fn new_symbol_building(code: u8, len: usize) -> Self {
        Self(code as u16 + FSST_CODE_BASE + ((len as u16) << FSST_LEN_BITS))
    }

    /// Create a new code corresponding to an escaped byte (length 1, 9th bit cleared).
    fn new_escape(byte: u8) -> Self {
        Self((byte as u16) + (1 << FSST_LEN_BITS))
    }

    /// The low 8 bits: a symbol code or a raw byte.
    #[inline]
    fn code(self) -> u8 {
        self.0 as u8
    }

    /// The low 9 bits: the code in the extended range `0..FSST_CODE_MAX`.
    #[inline]
    fn extended_code(self) -> u16 {
        self.0 & FSST_CODE_MASK
    }

    /// The symbol length stored in the bits starting at `FSST_LEN_BITS`.
    #[inline]
    fn len(self) -> u16 {
        self.0 >> FSST_LEN_BITS
    }
}

impl Debug for Code {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TrainingCode")
            .field("code", &self.code())
            // Training-phase convention: a cleared 9th bit means the low byte is a raw
            // (escaped) byte. The length bits must be masked off first; comparing the
            // whole packed value against 256 is never true once a length is stored.
            .field("is_escape", &(self.extended_code() < FSST_CODE_BASE))
            .field("len", &self.len())
            .finish()
    }
}
224
225/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string.
226#[derive(Clone)]
227pub struct Decompressor<'a> {
228    /// Slice mapping codes to symbols.
229    pub(crate) symbols: &'a [Symbol],
230
231    /// Slice containing the length of each symbol in the `symbols` slice.
232    pub(crate) lengths: &'a [u8],
233}
234
235impl<'a> Decompressor<'a> {
236    /// Returns a new decompressor that uses the provided symbol table.
237    ///
238    /// # Panics
239    ///
240    /// If the provided symbol table has length greater than 256
241    pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
242        assert!(
243            symbols.len() < FSST_CODE_BASE as usize,
244            "symbol table cannot have size exceeding 255"
245        );
246
247        Self { symbols, lengths }
248    }
249
250    /// Returns an upper bound on the size of the decompressed data.
251    pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
252        size_of::<Symbol>() * (compressed.len() + 1)
253    }
254
255    /// Decompress a slice of codes into a provided buffer.
256    ///
257    /// The provided `decoded` buffer must be at least the size of the decoded data, plus
258    /// an additional 7 bytes.
259    ///
260    /// ## Panics
261    ///
262    /// If the caller fails to provide sufficient capacity in the decoded buffer. An upper bound
263    /// on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
264    ///
265    /// ## Example
266    ///
267    /// ```
268    /// use fsst::{Symbol, Compressor, CompressorBuilder};
269    /// let compressor = {
270    ///     let mut builder = CompressorBuilder::new();
271    ///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
272    ///     builder.build()
273    /// };
274    ///
275    /// let decompressor = compressor.decompressor();
276    ///
277    /// let mut decompressed = Vec::with_capacity(8 + 7);
278    ///
279    /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
280    /// assert_eq!(len, 8);
281    /// unsafe { decompressed.set_len(len) };
282    /// assert_eq!(&decompressed, "helloooo".as_bytes());
283    /// ```
284    pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
285        // Ensure the target buffer is at least half the size of the input buffer.
286        // This is the theortical smallest a valid target can be, and occurs when
287        // every input code is an escape.
288        assert!(
289            decoded.len() >= compressed.len() / 2,
290            "decoded is smaller than lower-bound decompressed size"
291        );
292
293        unsafe {
294            let mut in_ptr = compressed.as_ptr();
295            let _in_begin = in_ptr;
296            let in_end = in_ptr.add(compressed.len());
297
298            let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
299            let out_begin = out_ptr.cast_const();
300            let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();
301
302            macro_rules! store_next_symbol {
303                ($code:expr) => {{
304                    out_ptr
305                        .cast::<u64>()
306                        .write_unaligned(self.symbols.get_unchecked($code as usize).as_u64());
307                    out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
308                }};
309            }
310
311            // First we try loading 8 bytes at a time.
312            if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
313                // Extract the loop condition since the compiler fails to do so
314                let block_out_end = out_end.sub(8 * size_of::<Symbol>());
315                let block_in_end = in_end.sub(8);
316
317                while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
318                    // Note that we load a little-endian u64 here.
319                    let next_block = in_ptr.cast::<u64>().read_unaligned();
320                    let escape_mask = (next_block & 0x8080808080808080)
321                        & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
322                            ^ 0x8080808080808080);
323
324                    // If there are no escape codes, we write each symbol one by one.
325                    if escape_mask == 0 {
326                        let code = (next_block & 0xFF) as u8;
327                        store_next_symbol!(code);
328                        let code = ((next_block >> 8) & 0xFF) as u8;
329                        store_next_symbol!(code);
330                        let code = ((next_block >> 16) & 0xFF) as u8;
331                        store_next_symbol!(code);
332                        let code = ((next_block >> 24) & 0xFF) as u8;
333                        store_next_symbol!(code);
334                        let code = ((next_block >> 32) & 0xFF) as u8;
335                        store_next_symbol!(code);
336                        let code = ((next_block >> 40) & 0xFF) as u8;
337                        store_next_symbol!(code);
338                        let code = ((next_block >> 48) & 0xFF) as u8;
339                        store_next_symbol!(code);
340                        let code = ((next_block >> 56) & 0xFF) as u8;
341                        store_next_symbol!(code);
342                        in_ptr = in_ptr.add(8);
343                    } else {
344                        // Otherwise, find the first escape code and write the symbols up to that point.
345                        let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
346                        debug_assert!(first_escape_pos < 8);
347                        match first_escape_pos {
348                            7 => {
349                                let code = (next_block & 0xFF) as u8;
350                                store_next_symbol!(code);
351                                let code = ((next_block >> 8) & 0xFF) as u8;
352                                store_next_symbol!(code);
353                                let code = ((next_block >> 16) & 0xFF) as u8;
354                                store_next_symbol!(code);
355                                let code = ((next_block >> 24) & 0xFF) as u8;
356                                store_next_symbol!(code);
357                                let code = ((next_block >> 32) & 0xFF) as u8;
358                                store_next_symbol!(code);
359                                let code = ((next_block >> 40) & 0xFF) as u8;
360                                store_next_symbol!(code);
361                                let code = ((next_block >> 48) & 0xFF) as u8;
362                                store_next_symbol!(code);
363
364                                in_ptr = in_ptr.add(7);
365                            }
366                            6 => {
367                                let code = (next_block & 0xFF) as u8;
368                                store_next_symbol!(code);
369                                let code = ((next_block >> 8) & 0xFF) as u8;
370                                store_next_symbol!(code);
371                                let code = ((next_block >> 16) & 0xFF) as u8;
372                                store_next_symbol!(code);
373                                let code = ((next_block >> 24) & 0xFF) as u8;
374                                store_next_symbol!(code);
375                                let code = ((next_block >> 32) & 0xFF) as u8;
376                                store_next_symbol!(code);
377                                let code = ((next_block >> 40) & 0xFF) as u8;
378                                store_next_symbol!(code);
379
380                                let escaped = ((next_block >> 56) & 0xFF) as u8;
381                                out_ptr.write(escaped);
382                                out_ptr = out_ptr.add(1);
383
384                                in_ptr = in_ptr.add(8);
385                            }
386                            5 => {
387                                let code = (next_block & 0xFF) as u8;
388                                store_next_symbol!(code);
389                                let code = ((next_block >> 8) & 0xFF) as u8;
390                                store_next_symbol!(code);
391                                let code = ((next_block >> 16) & 0xFF) as u8;
392                                store_next_symbol!(code);
393                                let code = ((next_block >> 24) & 0xFF) as u8;
394                                store_next_symbol!(code);
395                                let code = ((next_block >> 32) & 0xFF) as u8;
396                                store_next_symbol!(code);
397
398                                let escaped = ((next_block >> 48) & 0xFF) as u8;
399                                out_ptr.write(escaped);
400                                out_ptr = out_ptr.add(1);
401
402                                in_ptr = in_ptr.add(7);
403                            }
404                            4 => {
405                                let code = (next_block & 0xFF) as u8;
406                                store_next_symbol!(code);
407                                let code = ((next_block >> 8) & 0xFF) as u8;
408                                store_next_symbol!(code);
409                                let code = ((next_block >> 16) & 0xFF) as u8;
410                                store_next_symbol!(code);
411                                let code = ((next_block >> 24) & 0xFF) as u8;
412                                store_next_symbol!(code);
413
414                                let escaped = ((next_block >> 40) & 0xFF) as u8;
415                                out_ptr.write(escaped);
416                                out_ptr = out_ptr.add(1);
417
418                                in_ptr = in_ptr.add(6);
419                            }
420                            3 => {
421                                let code = (next_block & 0xFF) as u8;
422                                store_next_symbol!(code);
423                                let code = ((next_block >> 8) & 0xFF) as u8;
424                                store_next_symbol!(code);
425                                let code = ((next_block >> 16) & 0xFF) as u8;
426                                store_next_symbol!(code);
427
428                                let escaped = ((next_block >> 32) & 0xFF) as u8;
429                                out_ptr.write(escaped);
430                                out_ptr = out_ptr.add(1);
431
432                                in_ptr = in_ptr.add(5);
433                            }
434                            2 => {
435                                let code = (next_block & 0xFF) as u8;
436                                store_next_symbol!(code);
437                                let code = ((next_block >> 8) & 0xFF) as u8;
438                                store_next_symbol!(code);
439
440                                let escaped = ((next_block >> 24) & 0xFF) as u8;
441                                out_ptr.write(escaped);
442                                out_ptr = out_ptr.add(1);
443
444                                in_ptr = in_ptr.add(4);
445                            }
446                            1 => {
447                                let code = (next_block & 0xFF) as u8;
448                                store_next_symbol!(code);
449
450                                let escaped = ((next_block >> 16) & 0xFF) as u8;
451                                out_ptr.write(escaped);
452                                out_ptr = out_ptr.add(1);
453
454                                in_ptr = in_ptr.add(3);
455                            }
456                            0 => {
457                                // Otherwise, we actually need to decompress the next byte
458                                // Extract the second byte from the u32
459                                let escaped = ((next_block >> 8) & 0xFF) as u8;
460                                in_ptr = in_ptr.add(2);
461                                out_ptr.write(escaped);
462                                out_ptr = out_ptr.add(1);
463                            }
464                            _ => unreachable!(),
465                        }
466                    }
467                }
468            }
469
470            // Otherwise, fall back to 1-byte reads.
471            while out_end.offset_from(out_ptr) > size_of::<Symbol>() as isize && in_ptr < in_end {
472                let code = in_ptr.read();
473                in_ptr = in_ptr.add(1);
474
475                if code == ESCAPE_CODE {
476                    out_ptr.write(in_ptr.read());
477                    in_ptr = in_ptr.add(1);
478                    out_ptr = out_ptr.add(1);
479                } else {
480                    store_next_symbol!(code);
481                }
482            }
483
484            assert_eq!(
485                in_ptr, in_end,
486                "decompression should exhaust input before output"
487            );
488
489            out_ptr.offset_from(out_begin) as usize
490        }
491    }
492
493    /// Decompress a byte slice that was previously returned by a compressor using the same symbol
494    /// table into a new vector of bytes.
495    pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
496        let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
497
498        let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
499        // SAFETY: len bytes have now been initialized by the decompressor.
500        unsafe { decoded.set_len(len) };
501        decoded
502    }
503}
504
505/// A compressor that uses a symbol table to greedily compress strings.
506///
507/// The `Compressor` is the central component of FSST. You can create a compressor either by
508/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
509///
510/// Example usage:
511///
512/// ```
513/// use fsst::{Symbol, Compressor, CompressorBuilder};
514/// let compressor = {
515///     let mut builder = CompressorBuilder::new();
516///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
517///     builder.build()
518/// };
519///
520/// let compressed = compressor.compress("hello".as_bytes());
521/// assert_eq!(compressed, vec![0u8]);
522/// ```
523#[derive(Clone)]
524pub struct Compressor {
525    /// Table mapping codes to symbols.
526    pub(crate) symbols: Vec<Symbol>,
527
528    /// Length of each symbol, values range from 1-8.
529    pub(crate) lengths: Vec<u8>,
530
531    /// The number of entries in the symbol table that have been populated, not counting
532    /// the escape values.
533    pub(crate) n_symbols: u8,
534
535    /// Inverted index mapping 2-byte symbols to codes
536    codes_two_byte: Vec<Code>,
537
538    /// Limit of no suffixes.
539    has_suffix_code: u8,
540
541    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
542    lossy_pht: LossyPHT,
543}
544
545/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
546///
547/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
548/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
549impl Compressor {
550    /// Using the symbol table, runs a single cycle of compression on an input word, writing
551    /// the output into `out_ptr`.
552    ///
553    /// # Returns
554    ///
555    /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
556    /// for the caller to advance the input and output pointers.
557    ///
558    /// `advance_in` is the number of bytes to advance the input pointer before the next call.
559    ///
560    /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
561    ///
562    /// # Safety
563    ///
564    /// `out_ptr` must never be NULL or otherwise point to invalid memory.
565    pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
566        // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
567        // if it isn't, it will be overwritten anyway.
568        //
569        // SAFETY: caller ensures out_ptr is not null
570        let first_byte = word as u8;
571        // SAFETY: out_ptr is not null
572        unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
573
574        // First, check the two_bytes table
575        let code_twobyte = self.codes_two_byte[word as u16 as usize];
576
577        if code_twobyte.code() < self.has_suffix_code {
578            // 2 byte code without having to worry about longer matches.
579            // SAFETY: out_ptr is not null.
580            unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
581
582            // Advance input by symbol length (2) and output by a single code byte
583            (2, 1)
584        } else {
585            // Probe the hash table
586            let entry = self.lossy_pht.lookup(word);
587
588            // Now, downshift the `word` and the `entry` to see if they align.
589            let ignored_bits = entry.ignored_bits;
590            if entry.code != Code::UNUSED
591                && compare_masked(word, entry.symbol.as_u64(), ignored_bits)
592            {
593                // Advance the input by the symbol length (variable) and the output by one code byte
594                // SAFETY: out_ptr is not null.
595                unsafe { std::ptr::write(out_ptr, entry.code.code()) };
596                (entry.code.len() as usize, 1)
597            } else {
598                // SAFETY: out_ptr is not null
599                unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
600
601                // Advance the input by the symbol length (variable) and the output by either 1
602                // byte (if was one-byte code) or two bytes (escape).
603                (
604                    code_twobyte.len() as usize,
605                    // Predicated version of:
606                    //
607                    // if entry.code >= 256 {
608                    //      2
609                    // } else {
610                    //      1
611                    // }
612                    1 + (code_twobyte.extended_code() >> 8) as usize,
613                )
614            }
615        }
616    }
617
618    /// Compress many lines in bulk.
619    pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
620        let mut res = Vec::new();
621
622        for line in lines {
623            res.push(self.compress(line));
624        }
625
626        res
627    }
628
629    /// Compress a string, writing its result into a target buffer.
630    ///
631    /// The target buffer is a byte vector that must have capacity large enough
632    /// to hold the encoded data.
633    ///
634    /// When this call returns, `values` will hold the compressed bytes and have
635    /// its length set to the length of the compressed text.
636    ///
637    /// ```
638    /// use fsst::{Compressor, CompressorBuilder, Symbol};
639    ///
640    /// let mut compressor = CompressorBuilder::new();
641    /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
642    ///
643    /// let compressor = compressor.build();
644    ///
645    /// let mut compressed_values = Vec::with_capacity(1_024);
646    ///
647    /// // SAFETY: we have over-sized compressed_values.
648    /// unsafe {
649    ///     compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
650    /// }
651    ///
652    /// assert_eq!(compressed_values, vec![0u8]);
653    /// ```
654    ///
655    /// # Safety
656    ///
657    /// It is up to the caller to ensure the provided buffer is large enough to hold
658    /// all encoded data.
659    pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
660        let mut in_ptr = plaintext.as_ptr();
661        let mut out_ptr = values.as_mut_ptr();
662
663        // SAFETY: `end` will point just after the end of the `plaintext` slice.
664        let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
665        let in_end_sub8 = in_end as usize - 8;
666        // SAFETY: `end` will point just after the end of the `values` allocation.
667        let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
668
669        while (in_ptr as usize) <= in_end_sub8 && out_ptr < out_end {
670            // SAFETY: pointer ranges are checked in the loop condition
671            unsafe {
672                // Load a full 8-byte word of data from in_ptr.
673                // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
674                let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
675                let (advance_in, advance_out) = self.compress_word(word, out_ptr);
676                in_ptr = in_ptr.byte_add(advance_in);
677                out_ptr = out_ptr.byte_add(advance_out);
678            };
679        }
680
681        let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
682        assert!(
683            out_ptr < out_end || remaining_bytes == 0,
684            "output buffer sized too small"
685        );
686
687        let remaining_bytes = remaining_bytes as usize;
688
689        // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
690        // but shift data out of this word rather than advancing an input pointer and potentially reading
691        // unowned memory.
692        let mut bytes = [0u8; 8];
693        // SAFETY: remaining_bytes <= 8
694        unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
695        let mut last_word = u64::from_le_bytes(bytes);
696
697        while in_ptr < in_end && out_ptr < out_end {
698            // Load a full 8-byte word of data from in_ptr.
699            // SAFETY: caller asserts in_ptr is not null
700            let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
701            // SAFETY: pointer ranges are checked in the loop condition
702            unsafe {
703                in_ptr = in_ptr.add(advance_in);
704                out_ptr = out_ptr.add(advance_out);
705            }
706
707            last_word = advance_8byte_word(last_word, advance_in);
708        }
709
710        // in_ptr should have exceeded in_end
711        assert!(
712            in_ptr >= in_end,
713            "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
714        );
715
716        assert!(out_ptr <= out_end, "output buffer sized too small");
717
718        // SAFETY: out_ptr is derived from the `values` allocation.
719        let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
720        assert!(
721            bytes_written >= 0,
722            "out_ptr ended before it started, not possible"
723        );
724
725        // SAFETY: we have initialized `bytes_written` values in the output buffer.
726        unsafe { values.set_len(bytes_written as usize) };
727    }
728
729    /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
730    pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
731        if plaintext.is_empty() {
732            return Vec::new();
733        }
734
735        let mut buffer = Vec::with_capacity(plaintext.len() * 2);
736
737        // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
738        unsafe { self.compress_into(plaintext, &mut buffer) };
739
740        buffer
741    }
742
743    /// Access the decompressor that can be used to decompress strings emitted from this
744    /// `Compressor` instance.
745    pub fn decompressor(&self) -> Decompressor {
746        Decompressor::new(self.symbol_table(), self.symbol_lengths())
747    }
748
749    /// Returns a readonly slice of the current symbol table.
750    ///
751    /// The returned slice will have length of `n_symbols`.
752    pub fn symbol_table(&self) -> &[Symbol] {
753        &self.symbols[0..self.n_symbols as usize]
754    }
755
756    /// Returns a readonly slice where index `i` contains the
757    /// length of the symbol represented by code `i`.
758    ///
759    /// Values range from 1-8.
760    pub fn symbol_lengths(&self) -> &[u8] {
761        &self.lengths[0..self.n_symbols as usize]
762    }
763
764    /// Rebuild a compressor from an existing symbol table.
765    ///
766    /// This will not attempt to optimize or re-order the codes.
767    pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
768        let symbols = symbols.as_ref();
769        let symbol_lens = symbol_lens.as_ref();
770
771        assert_eq!(
772            symbols.len(),
773            symbol_lens.len(),
774            "symbols and lengths differ"
775        );
776        assert!(
777            symbols.len() <= 255,
778            "symbol table len must be <= 255, was {}",
779            symbols.len()
780        );
781        validate_symbol_order(symbol_lens);
782
783        // Insert the symbols in their given order into the FSST lookup structures.
784        let symbols = symbols.to_vec();
785        let lengths = symbol_lens.to_vec();
786        let mut lossy_pht = LossyPHT::new();
787
788        let mut codes_one_byte = vec![Code::UNUSED; 256];
789
790        // Insert all of the one byte symbols first.
791        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
792            if len == 1 {
793                codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
794            }
795        }
796
797        // Initialize the codes_two_byte table to be all escapes
798        let mut codes_two_byte = vec![Code::UNUSED; 65_536];
799
800        // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
801        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
802            match len {
803                2 => {
804                    codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
805                }
806                3.. => {
807                    assert!(
808                        lossy_pht.insert(symbol, len as usize, code as u8),
809                        "rebuild symbol insertion into PHT must succeed"
810                    );
811                }
812                _ => { /* Covered by the 1-byte loop above. */ }
813            }
814        }
815
816        // Build the finished codes_two_byte table, subbing in unused positions with the
817        // codes_one_byte value similar to what we do in CompressBuilder::finalize.
818        for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
819            if *code == Code::UNUSED {
820                *code = codes_one_byte[symbol & 0xFF];
821            }
822        }
823
824        // Find the position of the first 2-byte code that has a suffix later in the table
825        let mut has_suffix_code = 0u8;
826        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
827            if len != 2 {
828                break;
829            }
830            let rest = &symbols[code..];
831            if rest
832                .iter()
833                .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
834            {
835                has_suffix_code = code as u8;
836                break;
837            }
838        }
839
840        Compressor {
841            n_symbols: symbols.len() as u8,
842            symbols,
843            lengths,
844            codes_two_byte,
845            lossy_pht,
846            has_suffix_code,
847        }
848    }
849}
850
/// Drop the first `bytes` input bytes from an 8-byte little-endian word.
///
/// On little-endian platforms the earliest input byte sits in the least-significant byte of
/// `word`, so consuming bytes means shifting them off the low end. Consuming all 8 bytes
/// yields 0. That case is matched explicitly because `u64 >> 64` is an overflowing shift.
/// Per `<https://godbolt.org/z/Pbvre65Pq>`, this special case is expected to lower to a
/// conditional move rather than a branch.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    match bytes {
        8 => 0,
        consumed => word >> (8 * consumed),
    }
}
860
/// Assert that `symbol_lens` describes a correctly ordered FSST symbol table.
///
/// The required order is 2, 3, 4, 5, 6, 7, 8, 1: multi-byte symbols in non-decreasing
/// length order, followed by all one-byte symbols. Any length group may be empty.
///
/// # Panics
///
/// Panics if any length is outside `1..=8`, or if the lengths are not in the order above.
fn validate_symbol_order(symbol_lens: &[u8]) {
    // Smallest length allowed at the current position. It becomes 1 once the one-byte tail
    // starts, after which only one-byte symbols are accepted.
    let mut expected = 2;
    for (idx, &len) in symbol_lens.iter().enumerate() {
        // Symbols hold between 1 and 8 bytes. Reject anything else up front rather than
        // reporting it as an ordering problem (len 0) or accepting it silently (len > 8).
        assert!(
            (1..=8).contains(&len),
            "symbol code={idx} has invalid length {len}, must be between 1 and 8"
        );

        if expected == 1 {
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
        } else {
            if len == 1 {
                expected = 1;
            }

            // Still in the multi-byte portion: lengths must not decrease.
            assert!(
                len >= expected,
                "symbol code={idx} violates FSST symbol table ordering"
            );
            expected = len;
        }
    }
}
884
/// Check whether `left`, with its top `ignored_bits` bits cleared, equals `right`.
///
/// Only the low `64 - ignored_bits` bits of `left` take part. `right` is compared as-is, so
/// any bit it has set above that range makes the comparison fail.
///
/// NOTE(review): `ignored_bits` must be below 64. A larger shift overflows and panics in
/// debug builds.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    let kept_bits = u64::MAX >> ignored_bits;
    right == (left & kept_bits)
}
890
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// A symbol table in FSST length order, stored as the little-endian `u64` value of each
    /// symbol: 16 two-byte symbols, then 61 three-byte symbols, then 8 one-byte symbols.
    const SYMBOLS_U64: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Symbol lengths matching `SYMBOLS_U64` entry by entry.
    fn symbol_lens() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// Build `Symbol`s from raw little-endian words through the public constructor.
    /// This avoids reinterpreting memory: `Symbol` is not `#[repr(transparent)]`, so its
    /// layout is not guaranteed to match `u64`.
    fn to_symbols(raw: &[u64]) -> Vec<Symbol> {
        raw.iter()
            .map(|word| Symbol::from_slice(&word.to_le_bytes()))
            .collect()
    }

    /// Extract the raw `u64` value of every symbol.
    fn to_raw(symbols: &[Symbol]) -> Vec<u64> {
        symbols.iter().map(|symbol| symbol.0).collect()
    }

    #[test]
    fn test_stuff() {
        let compressor = {
            let mut builder = CompressorBuilder::new();
            builder.insert(Symbol::from_slice(b"helloooo"), 8);
            builder.build()
        };

        let decompressor = compressor.decompressor();

        // Room for the 8 decoded bytes plus 7 bytes of slack.
        // NOTE(review): assumes `decompress_into` may write up to 7 bytes past the decoded
        // output. Confirm against the decompressor.
        let mut decompressed = Vec::with_capacity(8 + 7);

        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        // SAFETY: `decompress_into` returns the number of bytes it wrote to the front of the
        // spare capacity, and `len` (8) fits within the capacity reserved above.
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    #[test]
    fn test_symbols_good() {
        let symbols = to_symbols(SYMBOLS_U64);

        let compressor = Compressor::rebuild_from(symbols, symbol_lens());
        assert_eq!(to_raw(compressor.symbol_table()), SYMBOLS_U64);
    }

    /// Feed the same symbols through `CompressorBuilder` instead of `rebuild_from`.
    /// The built table is expected to differ from the input: the equality assertion below
    /// is expected to fail.
    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let mut builder = CompressorBuilder::new();
        for (symbol, len) in to_symbols(SYMBOLS_U64).into_iter().zip(symbol_lens()) {
            builder.insert(symbol, len as usize);
        }
        let compressor = builder.build();
        assert_eq!(to_raw(compressor.symbol_table()), SYMBOLS_U64);
    }
}