fsst/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Fail compilation unless `$typ` occupies exactly `$size_in_bytes` bytes.
///
/// The check is evaluated entirely at compile time and emits no runtime code.
macro_rules! assert_sizeof {
    ($typ:ty => $size_in_bytes:expr) => {
        const _: () = assert!(
            ::std::mem::size_of::<$typ>() == $size_in_bytes,
            "type does not have the expected size in bytes"
        );
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and
/// identified by an 8-bit code.
///
/// The bytes are packed into a single `u64` in little-endian order: the first byte of the
/// symbol lives in the least-significant byte, and symbols shorter than 8 bytes are padded
/// with `0x00` in the remaining positions.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct Symbol(u64);

// Compile-time guarantee that a `Symbol` is exactly 8 bytes wide.
assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28    /// Zero value for `Symbol`.
29    pub const ZERO: Self = Self::zero();
30
31    /// Constructor for a `Symbol` from an 8-element byte slice.
32    pub fn from_slice(slice: &[u8; 8]) -> Self {
33        let num: u64 = u64::from_le_bytes(*slice);
34
35        Self(num)
36    }
37
38    /// Return a zero symbol
39    const fn zero() -> Self {
40        Self(0)
41    }
42
43    /// Create a new single-byte symbol
44    pub fn from_u8(value: u8) -> Self {
45        Self(value as u64)
46    }
47}
48
49impl Symbol {
50    /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51    ///
52    /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53    /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54    /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55    /// but we want to interpret that as a one-byte symbol containing `0x00`.
56    #[allow(clippy::len_without_is_empty)]
57    pub fn len(self) -> usize {
58        let numeric = self.0;
59        // For little-endian platforms, this counts the number of *trailing* zeros
60        let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62        // Special case handling of a symbol with all-zeros. This is actually
63        // a 1-byte symbol containing 0x00.
64        let len = size_of::<Self>() - null_bytes;
65        if len == 0 { 1 } else { len }
66    }
67
68    /// Returns the Symbol's inner representation.
69    #[inline]
70    pub fn to_u64(self) -> u64 {
71        self.0
72    }
73
74    /// Get the first byte of the symbol as a `u8`.
75    ///
76    /// If the symbol is empty, this will return the zero byte.
77    #[inline]
78    pub fn first_byte(self) -> u8 {
79        self.0 as u8
80    }
81
82    /// Get the first two bytes of the symbol as a `u16`.
83    ///
84    /// If the Symbol is one or zero bytes, this will return `0u16`.
85    #[inline]
86    pub fn first2(self) -> u16 {
87        self.0 as u16
88    }
89
90    /// Get the first three bytes of the symbol as a `u64`.
91    ///
92    /// If the Symbol is one or zero bytes, this will return `0u64`.
93    #[inline]
94    pub fn first3(self) -> u64 {
95        self.0 & 0xFF_FF_FF
96    }
97
98    /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
99    pub fn concat(self, other: Self) -> Self {
100        assert!(
101            self.len() + other.len() <= 8,
102            "cannot build symbol with length > 8"
103        );
104
105        let self_len = self.len();
106
107        Self((other.0 << (8 * self_len)) | self.0)
108    }
109}
110
111impl Debug for Symbol {
112    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113        write!(f, "[")?;
114
115        let slice = &self.0.to_le_bytes()[0..self.len()];
116        for c in slice.iter().map(|c| *c as char) {
117            if ('!'..='~').contains(&c) {
118                write!(f, "{c}")?;
119            } else if c == '\n' {
120                write!(f, " \\n ")?;
121            } else if c == '\t' {
122                write!(f, " \\t ")?;
123            } else if c == ' ' {
124                write!(f, " SPACE ")?;
125            } else {
126                write!(f, " 0x{:X?} ", c as u8)?
127            }
128        }
129
130        write!(f, "]")
131    }
132}
133
/// A packed type containing a code value, as well as metadata about the symbol referred to by
/// the code.
///
/// Logically, codes can range from 0-255 inclusive. This type holds both the 8-bit code as well as
/// other metadata bit-packed into a `u16`.
///
/// The bottom 8 bits contain EITHER a code for a symbol stored in the table, OR a raw byte.
///
/// The interpretation depends on the 9th bit: when toggled off, the value stores a raw byte, and when
/// toggled on, it stores a code. Thus if you examine the bottom 9 bits of the `u16`, you have an extended
/// code range, where the values 0-255 are raw bytes, and the values 256-510 represent codes 0-254. 511 is
/// a placeholder for the invalid code here.
///
/// Bits 12-15 store the length of the symbol (values ranging from 0-8). Bits 9-11 do not appear to be
/// set by any of the constructors in this file.
///
/// The derived ordering and equality compare the raw packed `u16`, so the length bits take part
/// in comparisons.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Code(u16);
150
/// Code used to indicate bytes that are not in the symbol table.
///
/// When compressing a string that cannot fully be expressed with the symbol table, the compressed
/// output will contain an `ESCAPE` byte followed by a raw byte. At decompression time, the presence
/// of `ESCAPE` indicates that the next byte should be appended directly to the result instead of
/// being looked up in the symbol table.
pub const ESCAPE_CODE: u8 = 255;

/// Number of low bits of a packed code value that hold the extended code.
pub const FSST_CODE_BITS: usize = 9;

/// Bit position at which the "length" portion of a packed code value starts.
pub const FSST_LEN_BITS: usize = 12;

/// Number of values in the extended code range (`1 << FSST_CODE_BITS`, i.e. 512).
///
/// This is one past the largest extended code, not a valid extended code itself.
pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS;

/// Largest value in the extended code range (511), which doubles as the bit mask that selects
/// the extended code from a packed value.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First value of the extended code range that refers to a symbol-table entry rather than a
/// raw byte.
pub const FSST_CODE_BASE: u16 = 256;
175
176#[allow(clippy::len_without_is_empty)]
177impl Code {
178    /// Code for an unused slot in a symbol table or index.
179    ///
180    /// This corresponds to the maximum code with a length of 1.
181    pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << 12));
182
183    /// Create a new code for a symbol of given length.
184    fn new_symbol(code: u8, len: usize) -> Self {
185        Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
186    }
187
188    /// Code for a new symbol during the building phase.
189    ///
190    /// The code is remapped from 0..254 to 256...510.
191    fn new_symbol_building(code: u8, len: usize) -> Self {
192        Self(code as u16 + 256 + ((len as u16) << FSST_LEN_BITS))
193    }
194
195    /// Create a new code corresponding for an escaped byte.
196    fn new_escape(byte: u8) -> Self {
197        Self((byte as u16) + (1 << FSST_LEN_BITS))
198    }
199
200    #[inline]
201    fn code(self) -> u8 {
202        self.0 as u8
203    }
204
205    #[inline]
206    fn extended_code(self) -> u16 {
207        self.0 & 0b111_111_111
208    }
209
210    #[inline]
211    fn len(self) -> u16 {
212        self.0 >> FSST_LEN_BITS
213    }
214}
215
216impl Debug for Code {
217    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
218        f.debug_struct("TrainingCode")
219            .field("code", &(self.0 as u8))
220            .field("is_escape", &(self.0 < 256))
221            .field("len", &(self.0 >> 12))
222            .finish()
223    }
224}
225
/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string.
///
/// It only borrows the table: both slices must outlive the decompressor.
#[derive(Clone)]
pub struct Decompressor<'a> {
    /// Slice mapping codes to symbols. A code is used directly as the index into this slice.
    pub(crate) symbols: &'a [Symbol],

    /// Slice containing the length of each symbol in the `symbols` slice, indexed by the same code.
    pub(crate) lengths: &'a [u8],
}
235
236impl<'a> Decompressor<'a> {
237    /// Returns a new decompressor that uses the provided symbol table.
238    ///
239    /// # Panics
240    ///
241    /// If the provided symbol table has length greater than 256
242    pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
243        assert!(
244            symbols.len() < FSST_CODE_BASE as usize,
245            "symbol table cannot have size exceeding 255"
246        );
247
248        Self { symbols, lengths }
249    }
250
251    /// Returns an upper bound on the size of the decompressed data.
252    pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
253        size_of::<Symbol>() * (compressed.len() + 1)
254    }
255
256    /// Decompress a slice of codes into a provided buffer.
257    ///
258    /// The provided `decoded` buffer must be at least the size of the decoded data, plus
259    /// an additional 7 bytes.
260    ///
261    /// ## Panics
262    ///
263    /// If the caller fails to provide sufficient capacity in the decoded buffer. An upper bound
264    /// on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
265    ///
266    /// ## Example
267    ///
268    /// ```
269    /// use fsst::{Symbol, Compressor, CompressorBuilder};
270    /// let compressor = {
271    ///     let mut builder = CompressorBuilder::new();
272    ///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
273    ///     builder.build()
274    /// };
275    ///
276    /// let decompressor = compressor.decompressor();
277    ///
278    /// let mut decompressed = Vec::with_capacity(8 + 7);
279    ///
280    /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
281    /// assert_eq!(len, 8);
282    /// unsafe { decompressed.set_len(len) };
283    /// assert_eq!(&decompressed, "helloooo".as_bytes());
284    /// ```
285    pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
286        // Ensure the target buffer is at least half the size of the input buffer.
287        // This is the theortical smallest a valid target can be, and occurs when
288        // every input code is an escape.
289        assert!(
290            decoded.len() >= compressed.len() / 2,
291            "decoded is smaller than lower-bound decompressed size"
292        );
293
294        unsafe {
295            let mut in_ptr = compressed.as_ptr();
296            let _in_begin = in_ptr;
297            let in_end = in_ptr.add(compressed.len());
298
299            let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
300            let out_begin = out_ptr.cast_const();
301            let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();
302
303            macro_rules! store_next_symbol {
304                ($code:expr) => {{
305                    out_ptr
306                        .cast::<u64>()
307                        .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
308                    out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
309                }};
310            }
311
312            // First we try loading 8 bytes at a time.
313            if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
314                // Extract the loop condition since the compiler fails to do so
315                let block_out_end = out_end.sub(8 * size_of::<Symbol>());
316                let block_in_end = in_end.sub(8);
317
318                while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
319                    // Note that we load a little-endian u64 here.
320                    let next_block = in_ptr.cast::<u64>().read_unaligned();
321                    let escape_mask = (next_block & 0x8080808080808080)
322                        & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
323                            ^ 0x8080808080808080);
324
325                    // If there are no escape codes, we write each symbol one by one.
326                    if escape_mask == 0 {
327                        let code = (next_block & 0xFF) as u8;
328                        store_next_symbol!(code);
329                        let code = ((next_block >> 8) & 0xFF) as u8;
330                        store_next_symbol!(code);
331                        let code = ((next_block >> 16) & 0xFF) as u8;
332                        store_next_symbol!(code);
333                        let code = ((next_block >> 24) & 0xFF) as u8;
334                        store_next_symbol!(code);
335                        let code = ((next_block >> 32) & 0xFF) as u8;
336                        store_next_symbol!(code);
337                        let code = ((next_block >> 40) & 0xFF) as u8;
338                        store_next_symbol!(code);
339                        let code = ((next_block >> 48) & 0xFF) as u8;
340                        store_next_symbol!(code);
341                        let code = ((next_block >> 56) & 0xFF) as u8;
342                        store_next_symbol!(code);
343                        in_ptr = in_ptr.add(8);
344                    } else {
345                        // Otherwise, find the first escape code and write the symbols up to that point.
346                        let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
347                        debug_assert!(first_escape_pos < 8);
348                        match first_escape_pos {
349                            7 => {
350                                let code = (next_block & 0xFF) as u8;
351                                store_next_symbol!(code);
352                                let code = ((next_block >> 8) & 0xFF) as u8;
353                                store_next_symbol!(code);
354                                let code = ((next_block >> 16) & 0xFF) as u8;
355                                store_next_symbol!(code);
356                                let code = ((next_block >> 24) & 0xFF) as u8;
357                                store_next_symbol!(code);
358                                let code = ((next_block >> 32) & 0xFF) as u8;
359                                store_next_symbol!(code);
360                                let code = ((next_block >> 40) & 0xFF) as u8;
361                                store_next_symbol!(code);
362                                let code = ((next_block >> 48) & 0xFF) as u8;
363                                store_next_symbol!(code);
364
365                                in_ptr = in_ptr.add(7);
366                            }
367                            6 => {
368                                let code = (next_block & 0xFF) as u8;
369                                store_next_symbol!(code);
370                                let code = ((next_block >> 8) & 0xFF) as u8;
371                                store_next_symbol!(code);
372                                let code = ((next_block >> 16) & 0xFF) as u8;
373                                store_next_symbol!(code);
374                                let code = ((next_block >> 24) & 0xFF) as u8;
375                                store_next_symbol!(code);
376                                let code = ((next_block >> 32) & 0xFF) as u8;
377                                store_next_symbol!(code);
378                                let code = ((next_block >> 40) & 0xFF) as u8;
379                                store_next_symbol!(code);
380
381                                let escaped = ((next_block >> 56) & 0xFF) as u8;
382                                out_ptr.write(escaped);
383                                out_ptr = out_ptr.add(1);
384
385                                in_ptr = in_ptr.add(8);
386                            }
387                            5 => {
388                                let code = (next_block & 0xFF) as u8;
389                                store_next_symbol!(code);
390                                let code = ((next_block >> 8) & 0xFF) as u8;
391                                store_next_symbol!(code);
392                                let code = ((next_block >> 16) & 0xFF) as u8;
393                                store_next_symbol!(code);
394                                let code = ((next_block >> 24) & 0xFF) as u8;
395                                store_next_symbol!(code);
396                                let code = ((next_block >> 32) & 0xFF) as u8;
397                                store_next_symbol!(code);
398
399                                let escaped = ((next_block >> 48) & 0xFF) as u8;
400                                out_ptr.write(escaped);
401                                out_ptr = out_ptr.add(1);
402
403                                in_ptr = in_ptr.add(7);
404                            }
405                            4 => {
406                                let code = (next_block & 0xFF) as u8;
407                                store_next_symbol!(code);
408                                let code = ((next_block >> 8) & 0xFF) as u8;
409                                store_next_symbol!(code);
410                                let code = ((next_block >> 16) & 0xFF) as u8;
411                                store_next_symbol!(code);
412                                let code = ((next_block >> 24) & 0xFF) as u8;
413                                store_next_symbol!(code);
414
415                                let escaped = ((next_block >> 40) & 0xFF) as u8;
416                                out_ptr.write(escaped);
417                                out_ptr = out_ptr.add(1);
418
419                                in_ptr = in_ptr.add(6);
420                            }
421                            3 => {
422                                let code = (next_block & 0xFF) as u8;
423                                store_next_symbol!(code);
424                                let code = ((next_block >> 8) & 0xFF) as u8;
425                                store_next_symbol!(code);
426                                let code = ((next_block >> 16) & 0xFF) as u8;
427                                store_next_symbol!(code);
428
429                                let escaped = ((next_block >> 32) & 0xFF) as u8;
430                                out_ptr.write(escaped);
431                                out_ptr = out_ptr.add(1);
432
433                                in_ptr = in_ptr.add(5);
434                            }
435                            2 => {
436                                let code = (next_block & 0xFF) as u8;
437                                store_next_symbol!(code);
438                                let code = ((next_block >> 8) & 0xFF) as u8;
439                                store_next_symbol!(code);
440
441                                let escaped = ((next_block >> 24) & 0xFF) as u8;
442                                out_ptr.write(escaped);
443                                out_ptr = out_ptr.add(1);
444
445                                in_ptr = in_ptr.add(4);
446                            }
447                            1 => {
448                                let code = (next_block & 0xFF) as u8;
449                                store_next_symbol!(code);
450
451                                let escaped = ((next_block >> 16) & 0xFF) as u8;
452                                out_ptr.write(escaped);
453                                out_ptr = out_ptr.add(1);
454
455                                in_ptr = in_ptr.add(3);
456                            }
457                            0 => {
458                                // Otherwise, we actually need to decompress the next byte
459                                // Extract the second byte from the u32
460                                let escaped = ((next_block >> 8) & 0xFF) as u8;
461                                in_ptr = in_ptr.add(2);
462                                out_ptr.write(escaped);
463                                out_ptr = out_ptr.add(1);
464                            }
465                            _ => unreachable!(),
466                        }
467                    }
468                }
469            }
470
471            // Otherwise, fall back to 1-byte reads.
472            while out_end.offset_from(out_ptr) > size_of::<Symbol>() as isize && in_ptr < in_end {
473                let code = in_ptr.read();
474                in_ptr = in_ptr.add(1);
475
476                if code == ESCAPE_CODE {
477                    out_ptr.write(in_ptr.read());
478                    in_ptr = in_ptr.add(1);
479                    out_ptr = out_ptr.add(1);
480                } else {
481                    store_next_symbol!(code);
482                }
483            }
484
485            assert_eq!(
486                in_ptr, in_end,
487                "decompression should exhaust input before output"
488            );
489
490            out_ptr.offset_from(out_begin) as usize
491        }
492    }
493
494    /// Decompress a byte slice that was previously returned by a compressor using the same symbol
495    /// table into a new vector of bytes.
496    pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
497        let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
498
499        let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
500        // SAFETY: len bytes have now been initialized by the decompressor.
501        unsafe { decoded.set_len(len) };
502        decoded
503    }
504}
505
/// A compressor that uses a symbol table to greedily compress strings.
///
/// The `Compressor` is the central component of FSST. You can create a compressor either by
/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
///
/// Example usage:
///
/// ```
/// use fsst::{Symbol, Compressor, CompressorBuilder};
/// let compressor = {
///     let mut builder = CompressorBuilder::new();
///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
///     builder.build()
/// };
///
/// let compressed = compressor.compress("hello".as_bytes());
/// assert_eq!(compressed, vec![0u8]);
/// ```
#[derive(Clone)]
pub struct Compressor {
    /// Table mapping codes to symbols. Only the first `n_symbols` entries are exposed
    /// through `symbol_table`.
    pub(crate) symbols: Vec<Symbol>,

    /// Length of each symbol, values range from 1-8. Parallel to `symbols`.
    pub(crate) lengths: Vec<u8>,

    /// The number of entries in the symbol table that have been populated, not counting
    /// the escape values.
    pub(crate) n_symbols: u8,

    /// Inverted index mapping 2-byte symbols to codes. It is indexed by the first two input
    /// bytes read as a little-endian `u16`.
    codes_two_byte: Vec<Code>,

    /// Threshold for two-byte lookups: a hit in `codes_two_byte` whose code is strictly below
    /// this value is emitted immediately as a two-byte match, without probing `lossy_pht`
    /// for a longer symbol.
    has_suffix_code: u8,

    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
    lossy_pht: LossyPHT,
}
545
546/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
547///
548/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
549/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
550impl Compressor {
551    /// Using the symbol table, runs a single cycle of compression on an input word, writing
552    /// the output into `out_ptr`.
553    ///
554    /// # Returns
555    ///
556    /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
557    /// for the caller to advance the input and output pointers.
558    ///
559    /// `advance_in` is the number of bytes to advance the input pointer before the next call.
560    ///
561    /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
562    ///
563    /// # Safety
564    ///
565    /// `out_ptr` must never be NULL or otherwise point to invalid memory.
566    pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
567        // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
568        // if it isn't, it will be overwritten anyway.
569        //
570        // SAFETY: caller ensures out_ptr is not null
571        let first_byte = word as u8;
572        // SAFETY: out_ptr is not null
573        unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
574
575        // First, check the two_bytes table
576        let code_twobyte = self.codes_two_byte[word as u16 as usize];
577
578        if code_twobyte.code() < self.has_suffix_code {
579            // 2 byte code without having to worry about longer matches.
580            // SAFETY: out_ptr is not null.
581            unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
582
583            // Advance input by symbol length (2) and output by a single code byte
584            (2, 1)
585        } else {
586            // Probe the hash table
587            let entry = self.lossy_pht.lookup(word);
588
589            // Now, downshift the `word` and the `entry` to see if they align.
590            let ignored_bits = entry.ignored_bits;
591            if entry.code != Code::UNUSED
592                && compare_masked(word, entry.symbol.to_u64(), ignored_bits)
593            {
594                // Advance the input by the symbol length (variable) and the output by one code byte
595                // SAFETY: out_ptr is not null.
596                unsafe { std::ptr::write(out_ptr, entry.code.code()) };
597                (entry.code.len() as usize, 1)
598            } else {
599                // SAFETY: out_ptr is not null
600                unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
601
602                // Advance the input by the symbol length (variable) and the output by either 1
603                // byte (if was one-byte code) or two bytes (escape).
604                (
605                    code_twobyte.len() as usize,
606                    // Predicated version of:
607                    //
608                    // if entry.code >= 256 {
609                    //      2
610                    // } else {
611                    //      1
612                    // }
613                    1 + (code_twobyte.extended_code() >> 8) as usize,
614                )
615            }
616        }
617    }
618
619    /// Compress many lines in bulk.
620    pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
621        let mut res = Vec::new();
622
623        for line in lines {
624            res.push(self.compress(line));
625        }
626
627        res
628    }
629
630    /// Compress a string, writing its result into a target buffer.
631    ///
632    /// The target buffer is a byte vector that must have capacity large enough
633    /// to hold the encoded data.
634    ///
635    /// When this call returns, `values` will hold the compressed bytes and have
636    /// its length set to the length of the compressed text.
637    ///
638    /// ```
639    /// use fsst::{Compressor, CompressorBuilder, Symbol};
640    ///
641    /// let mut compressor = CompressorBuilder::new();
642    /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
643    ///
644    /// let compressor = compressor.build();
645    ///
646    /// let mut compressed_values = Vec::with_capacity(1_024);
647    ///
648    /// // SAFETY: we have over-sized compressed_values.
649    /// unsafe {
650    ///     compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
651    /// }
652    ///
653    /// assert_eq!(compressed_values, vec![0u8]);
654    /// ```
655    ///
656    /// # Safety
657    ///
658    /// It is up to the caller to ensure the provided buffer is large enough to hold
659    /// all encoded data.
660    pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
661        let mut in_ptr = plaintext.as_ptr();
662        let mut out_ptr = values.as_mut_ptr();
663
664        // SAFETY: `end` will point just after the end of the `plaintext` slice.
665        let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
666        let in_end_sub8 = in_end as usize - 8;
667        // SAFETY: `end` will point just after the end of the `values` allocation.
668        let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
669
670        while (in_ptr as usize) <= in_end_sub8 && out_ptr < out_end {
671            // SAFETY: pointer ranges are checked in the loop condition
672            unsafe {
673                // Load a full 8-byte word of data from in_ptr.
674                // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
675                let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
676                let (advance_in, advance_out) = self.compress_word(word, out_ptr);
677                in_ptr = in_ptr.byte_add(advance_in);
678                out_ptr = out_ptr.byte_add(advance_out);
679            };
680        }
681
682        let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
683        assert!(
684            out_ptr < out_end || remaining_bytes == 0,
685            "output buffer sized too small"
686        );
687
688        let remaining_bytes = remaining_bytes as usize;
689
690        // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
691        // but shift data out of this word rather than advancing an input pointer and potentially reading
692        // unowned memory.
693        let mut bytes = [0u8; 8];
694        // SAFETY: remaining_bytes <= 8
695        unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
696        let mut last_word = u64::from_le_bytes(bytes);
697
698        while in_ptr < in_end && out_ptr < out_end {
699            // Load a full 8-byte word of data from in_ptr.
700            // SAFETY: caller asserts in_ptr is not null
701            let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
702            // SAFETY: pointer ranges are checked in the loop condition
703            unsafe {
704                in_ptr = in_ptr.add(advance_in);
705                out_ptr = out_ptr.add(advance_out);
706            }
707
708            last_word = advance_8byte_word(last_word, advance_in);
709        }
710
711        // in_ptr should have exceeded in_end
712        assert!(
713            in_ptr >= in_end,
714            "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
715        );
716
717        assert!(out_ptr <= out_end, "output buffer sized too small");
718
719        // SAFETY: out_ptr is derived from the `values` allocation.
720        let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
721        assert!(
722            bytes_written >= 0,
723            "out_ptr ended before it started, not possible"
724        );
725
726        // SAFETY: we have initialized `bytes_written` values in the output buffer.
727        unsafe { values.set_len(bytes_written as usize) };
728    }
729
730    /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
731    pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
732        if plaintext.is_empty() {
733            return Vec::new();
734        }
735
736        let mut buffer = Vec::with_capacity(plaintext.len() * 2);
737
738        // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
739        unsafe { self.compress_into(plaintext, &mut buffer) };
740
741        buffer
742    }
743
744    /// Access the decompressor that can be used to decompress strings emitted from this
745    /// `Compressor` instance.
746    pub fn decompressor(&self) -> Decompressor<'_> {
747        Decompressor::new(self.symbol_table(), self.symbol_lengths())
748    }
749
750    /// Returns a readonly slice of the current symbol table.
751    ///
752    /// The returned slice will have length of `n_symbols`.
753    pub fn symbol_table(&self) -> &[Symbol] {
754        &self.symbols[0..self.n_symbols as usize]
755    }
756
757    /// Returns a readonly slice where index `i` contains the
758    /// length of the symbol represented by code `i`.
759    ///
760    /// Values range from 1-8.
761    pub fn symbol_lengths(&self) -> &[u8] {
762        &self.lengths[0..self.n_symbols as usize]
763    }
764
765    /// Rebuild a compressor from an existing symbol table.
766    ///
767    /// This will not attempt to optimize or re-order the codes.
768    pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
769        let symbols = symbols.as_ref();
770        let symbol_lens = symbol_lens.as_ref();
771
772        assert_eq!(
773            symbols.len(),
774            symbol_lens.len(),
775            "symbols and lengths differ"
776        );
777        assert!(
778            symbols.len() <= 255,
779            "symbol table len must be <= 255, was {}",
780            symbols.len()
781        );
782        validate_symbol_order(symbol_lens);
783
784        // Insert the symbols in their given order into the FSST lookup structures.
785        let symbols = symbols.to_vec();
786        let lengths = symbol_lens.to_vec();
787        let mut lossy_pht = LossyPHT::new();
788
789        let mut codes_one_byte = vec![Code::UNUSED; 256];
790
791        // Insert all of the one byte symbols first.
792        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
793            if len == 1 {
794                codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
795            }
796        }
797
798        // Initialize the codes_two_byte table to be all escapes
799        let mut codes_two_byte = vec![Code::UNUSED; 65_536];
800
801        // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
802        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
803            match len {
804                2 => {
805                    codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
806                }
807                3.. => {
808                    assert!(
809                        lossy_pht.insert(symbol, len as usize, code as u8),
810                        "rebuild symbol insertion into PHT must succeed"
811                    );
812                }
813                _ => { /* Covered by the 1-byte loop above. */ }
814            }
815        }
816
817        // Build the finished codes_two_byte table, subbing in unused positions with the
818        // codes_one_byte value similar to what we do in CompressBuilder::finalize.
819        for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
820            if *code == Code::UNUSED {
821                *code = codes_one_byte[symbol & 0xFF];
822            }
823        }
824
825        // Find the position of the first 2-byte code that has a suffix later in the table
826        let mut has_suffix_code = 0u8;
827        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
828            if len != 2 {
829                break;
830            }
831            let rest = &symbols[code..];
832            if rest
833                .iter()
834                .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
835            {
836                has_suffix_code = code as u8;
837                break;
838            }
839        }
840
841        Compressor {
842            n_symbols: symbols.len() as u8,
843            symbols,
844            lengths,
845            codes_two_byte,
846            lossy_pht,
847            has_suffix_code,
848        }
849    }
850}
851
/// Discard the first `bytes` bytes of a packed 8-byte `word`.
///
/// The word is little-endian, so the first character sits in the least-significant byte
/// and consuming input means shifting towards the low end. Consuming all 8 bytes yields 0.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    // A shift by the full 64 bits would overflow, so the whole-word case is answered
    // directly. The original form of this branch was observed to compile to a
    // conditional move; see `<https://godbolt.org/z/Pbvre65Pq>`.
    if bytes != 8 {
        word >> (8 * bytes)
    } else {
        0
    }
}
861
/// Check that `symbol_lens` describes a symbol table in FSST code order.
///
/// The required order is all multi-byte symbols first, in non-decreasing length
/// (2, 3, ..., 8), followed by all of the one-byte symbols. An empty table is accepted.
///
/// # Panics
///
/// Panics if a length is zero, if a multi-byte symbol is shorter than the one before it,
/// if a multi-byte symbol appears after a one-byte symbol, or if a length is larger than
/// the 8 bytes a `Symbol` can hold.
fn validate_symbol_order(symbol_lens: &[u8]) {
    // A `Symbol` is a single 8-byte word, so no symbol can be longer than this.
    const MAX_SYMBOL_LEN: u8 = 8;

    // Ensure that the symbol table is ordered by length, 23456781
    let mut expected = 2;
    for (idx, &len) in symbol_lens.iter().enumerate() {
        if expected == 1 {
            // Once the one-byte tail has started, nothing else may follow.
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
        } else {
            // The first one-byte symbol switches us into the tail; lowering `expected`
            // here lets the ordering check below accept it.
            if len == 1 {
                expected = 1;
            }

            // we're in the non-zero portion.
            assert!(
                len >= expected,
                "symbol code={idx} breaks violates FSST symbol table ordering"
            );
            // Checked after the ordering assertion so inputs that were already rejected
            // keep their original panic message.
            assert!(
                len <= MAX_SYMBOL_LEN,
                "symbol code={idx} is {len} bytes long, which exceeds the 8-byte symbol capacity"
            );
            expected = len;
        }
    }
}
885
/// Test whether `left`, with its `ignored_bits` most-significant bits cleared, equals
/// `right`.
///
/// `right` is compared as-is; only `left` is masked.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    let kept_low_bits = u64::MAX >> ignored_bits;
    right == (left & kept_low_bits)
}
891
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Fixture symbol table, one packed little-endian `u64` per symbol: 16 two-byte
    /// symbols, then 61 three-byte symbols, then 8 one-byte symbols.
    const SYMBOLS_U64: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Lengths matching `SYMBOLS_U64`, code by code.
    fn fixture_lens() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// Build symbols from packed values through the public constructor, so the tests do
    /// not depend on the in-memory layout of `Symbol`.
    fn to_symbols(raw: &[u64]) -> Vec<Symbol> {
        raw.iter()
            .map(|value| Symbol::from_slice(&value.to_le_bytes()))
            .collect()
    }

    /// Extract the packed value of each symbol.
    fn to_raw(symbols: &[Symbol]) -> Vec<u64> {
        symbols.iter().map(|symbol| symbol.0).collect()
    }

    #[test]
    fn test_stuff() {
        let compressor = {
            let mut builder = CompressorBuilder::new();
            builder.insert(Symbol::from_slice(b"helloooo"), 8);
            builder.build()
        };

        let decompressor = compressor.decompressor();

        let mut decompressed = Vec::with_capacity(8 + 7);

        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        // SAFETY: `decompress_into` reported writing `len` bytes into the spare capacity,
        // and `len == 8` fits within the 15 bytes reserved above.
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    /// A table that is already in FSST code order must survive `rebuild_from` unchanged.
    #[test]
    fn test_symbols_good() {
        let symbols = to_symbols(SYMBOLS_U64);

        let compressor = Compressor::rebuild_from(symbols.as_slice(), fixture_lens());
        let built_symbols = to_raw(compressor.symbol_table());
        assert_eq!(built_symbols.as_slice(), SYMBOLS_U64);
    }

    /// Feeding the same symbols through `CompressorBuilder` is expected to fail the final
    /// equality check, presumably because the builder assigns its own code order.
    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let lens = fixture_lens();

        let mut builder = CompressorBuilder::new();
        for (symbol, len) in SYMBOLS_U64.iter().zip(lens.iter()) {
            let symbol = Symbol::from_slice(&symbol.to_le_bytes());
            builder.insert(symbol, *len as usize);
        }
        let compressor = builder.build();
        let built_symbols = to_raw(compressor.symbol_table());
        assert_eq!(built_symbols.as_slice(), SYMBOLS_U64);
    }
}