Skip to main content

fsst/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Compile-time guard: fail the build unless `$typ` occupies exactly `$size_in_bytes` bytes.
///
/// Expands to an anonymous constant whose initializer asserts on the type's size, so a
/// mismatch is reported by the compiler and nothing is emitted at runtime.
macro_rules! assert_sizeof {
    ($typ:ty => $size_in_bytes:expr) => {
        const _: () = assert!(std::mem::size_of::<$typ>() == $size_in_bytes);
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and
/// identified by an 8-bit code.
///
/// The bytes are packed into a single `u64` in little-endian order (see [`Symbol::from_slice`]),
/// so the first byte of the symbol is the least-significant byte of the integer.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct Symbol(u64);

// Compile-time check that a `Symbol` is exactly 8 bytes wide.
assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28    /// Zero value for `Symbol`.
29    pub const ZERO: Self = Self::zero();
30
31    /// Constructor for a `Symbol` from an 8-element byte slice.
32    pub fn from_slice(slice: &[u8; 8]) -> Self {
33        let num: u64 = u64::from_le_bytes(*slice);
34
35        Self(num)
36    }
37
38    /// Return a zero symbol
39    const fn zero() -> Self {
40        Self(0)
41    }
42
43    /// Create a new single-byte symbol
44    pub fn from_u8(value: u8) -> Self {
45        Self(value as u64)
46    }
47}
48
49impl Symbol {
50    /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51    ///
52    /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53    /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54    /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55    /// but we want to interpret that as a one-byte symbol containing `0x00`.
56    #[allow(clippy::len_without_is_empty)]
57    pub fn len(self) -> usize {
58        let numeric = self.0;
59        // For little-endian platforms, this counts the number of *trailing* zeros
60        let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62        // Special case handling of a symbol with all-zeros. This is actually
63        // a 1-byte symbol containing 0x00.
64        let len = size_of::<Self>() - null_bytes;
65        if len == 0 { 1 } else { len }
66    }
67
68    /// Returns the Symbol's inner representation.
69    #[inline]
70    pub fn to_u64(self) -> u64 {
71        self.0
72    }
73
74    /// Get the first byte of the symbol as a `u8`.
75    ///
76    /// If the symbol is empty, this will return the zero byte.
77    #[inline]
78    pub fn first_byte(self) -> u8 {
79        self.0 as u8
80    }
81
82    /// Get the first two bytes of the symbol as a `u16`.
83    ///
84    /// If the Symbol is one or zero bytes, this will return `0u16`.
85    #[inline]
86    pub fn first2(self) -> u16 {
87        self.0 as u16
88    }
89
90    /// Get the first three bytes of the symbol as a `u64`.
91    ///
92    /// If the Symbol is one or zero bytes, this will return `0u64`.
93    #[inline]
94    pub fn first3(self) -> u64 {
95        self.0 & 0xFF_FF_FF
96    }
97
98    /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
99    pub fn concat(self, other: Self) -> Self {
100        assert!(
101            self.len() + other.len() <= 8,
102            "cannot build symbol with length > 8"
103        );
104
105        let self_len = self.len();
106
107        Self((other.0 << (8 * self_len)) | self.0)
108    }
109}
110
111impl Debug for Symbol {
112    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113        write!(f, "[")?;
114
115        let slice = &self.0.to_le_bytes()[0..self.len()];
116        for c in slice.iter().map(|c| *c as char) {
117            if ('!'..='~').contains(&c) {
118                write!(f, "{c}")?;
119            } else if c == '\n' {
120                write!(f, " \\n ")?;
121            } else if c == '\t' {
122                write!(f, " \\t ")?;
123            } else if c == ' ' {
124                write!(f, " SPACE ")?;
125            } else {
126                write!(f, " 0x{:X?} ", c as u8)?
127            }
128        }
129
130        write!(f, "]")
131    }
132}
133
/// A packed type containing a code value, as well as metadata about the symbol referred to by
/// the code.
///
/// Logically, codes can range from 0-255 inclusive. This type holds both the 8-bit code as well as
/// other metadata bit-packed into a `u16`.
///
/// The bottom 8 bits contain EITHER a code for a symbol stored in the table, OR a raw byte.
///
/// The interpretation depends on the 9th bit: when toggled off, the value stores a raw byte, and when
/// toggled on, it stores a code. Thus if you examine the bottom 9 bits of the `u16`, you have an extended
/// code range, where the values 0-255 are raw bytes, and the values 256-510 represent codes 0-254. 511 is
/// a placeholder for the invalid code here.
///
/// Bits 12-15 store the length of the symbol (values ranging from 0-8).
///
/// Bits 9-11 are not set by any of the constructors on this type.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Code(u16);
150
/// Code byte marking the next byte as a literal that is not covered by the symbol table.
///
/// When a string cannot be expressed entirely with table symbols, the compressed output
/// contains an `ESCAPE` byte followed by the raw byte. On decompression, seeing `ESCAPE`
/// means the following byte is appended to the result as-is instead of being looked up
/// in the symbol table.
pub const ESCAPE_CODE: u8 = u8::MAX;

/// Width, in bits, of the code portion of an `ExtendedCode`.
pub const FSST_CODE_BITS: usize = 9;

/// Bit offset at which the "length" portion of an extended code starts.
pub const FSST_LEN_BITS: usize = 12;

/// Number of values in the extended code range (`1 << FSST_CODE_BITS`); every extended
/// code is strictly below this.
pub const FSST_CODE_MAX: u16 = 1u16 << FSST_CODE_BITS;

/// Largest value in the extended code range, which doubles as the mask selecting the
/// code bits.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First value of the extended code range that refers to a symbol-table entry rather
/// than to an escaped raw byte.
pub const FSST_CODE_BASE: u16 = 1 << 8;
175
176#[allow(clippy::len_without_is_empty)]
177impl Code {
178    /// Code for an unused slot in a symbol table or index.
179    ///
180    /// This corresponds to the maximum code with a length of 1.
181    pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << 12));
182
183    /// Create a new code for a symbol of given length.
184    fn new_symbol(code: u8, len: usize) -> Self {
185        Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
186    }
187
188    /// Code for a new symbol during the building phase.
189    ///
190    /// The code is remapped from 0..254 to 256...510.
191    fn new_symbol_building(code: u8, len: usize) -> Self {
192        Self(code as u16 + 256 + ((len as u16) << FSST_LEN_BITS))
193    }
194
195    /// Create a new code corresponding for an escaped byte.
196    fn new_escape(byte: u8) -> Self {
197        Self((byte as u16) + (1 << FSST_LEN_BITS))
198    }
199
200    #[inline]
201    fn code(self) -> u8 {
202        self.0 as u8
203    }
204
205    #[inline]
206    fn extended_code(self) -> u16 {
207        self.0 & 0b111_111_111
208    }
209
210    #[inline]
211    fn len(self) -> u16 {
212        self.0 >> FSST_LEN_BITS
213    }
214}
215
216impl Debug for Code {
217    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
218        f.debug_struct("TrainingCode")
219            .field("code", &(self.0 as u8))
220            .field("is_escape", &(self.0 < 256))
221            .field("len", &(self.0 >> 12))
222            .finish()
223    }
224}
225
/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string.
///
/// It only borrows the table: both slices must outlive the decompressor.
#[derive(Clone)]
pub struct Decompressor<'a> {
    /// Slice mapping codes to symbols. A code byte is used directly as the index.
    pub(crate) symbols: &'a [Symbol],

    /// Slice containing the length of each symbol in the `symbols` slice.
    /// It is indexed by the same code byte as `symbols`.
    pub(crate) lengths: &'a [u8],
}
235
236impl<'a> Decompressor<'a> {
237    /// Returns a new decompressor that uses the provided symbol table.
238    ///
239    /// # Panics
240    ///
241    /// If the provided symbol table has length greater than or equal to [`FSST_CODE_BASE`]
242    pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
243        assert!(
244            symbols.len() < FSST_CODE_BASE as usize,
245            "symbol table cannot have size exceeding 255"
246        );
247
248        Self { symbols, lengths }
249    }
250
251    /// Returns an upper bound on the size of the decompressed data.
252    pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
253        size_of::<Symbol>() * (compressed.len() + 1)
254    }
255
    /// Decompress a slice of codes into a provided buffer.
    ///
    /// The provided `decoded` buffer must be at least the size of the decoded data.
    ///
    /// Returns the number of bytes written, starting at the beginning of `decoded`.
    ///
    /// ## Panics
    ///
    /// If the caller fails to provide sufficient capacity in the decoded buffer.
    /// An upper bound on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
    ///
    /// Also panics if `compressed` ends with an [`ESCAPE_CODE`] that is not followed by a byte.
    ///
    /// ## Example
    ///
    /// ```
    /// use fsst::{Symbol, Compressor, CompressorBuilder};
    /// let compressor = {
    ///     let mut builder = CompressorBuilder::new();
    ///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
    ///     builder.build()
    /// };
    ///
    /// let decompressor = compressor.decompressor();
    ///
    /// let mut decompressed = Vec::with_capacity(8);
    ///
    /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
    /// assert_eq!(len, 8);
    /// unsafe { decompressed.set_len(len) };
    /// assert_eq!(&decompressed, "helloooo".as_bytes());
    /// ```
    pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
        // Ensure the target buffer is at least half the size of the input buffer.
        // This is the theoretical smallest a valid target can be, and occurs when
        // every input code is an escape.
        assert!(
            decoded.len() >= compressed.len() / 2,
            "decoded is smaller than lower-bound decompressed size"
        );

        // The body works on raw pointers. Each loop below checks the remaining input and
        // output space before reading or writing.
        //
        // NOTE(review): symbol lookups use `get_unchecked` with the code byte as index, so a
        // code >= `symbols.len()` / `lengths.len()` is not detected here. Presumably callers
        // only pass data compressed with the same symbol table — confirm.
        unsafe {
            let mut in_ptr = compressed.as_ptr();
            let _in_begin = in_ptr;
            let in_end = in_ptr.add(compressed.len());

            let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
            let out_begin = out_ptr.cast_const();
            let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();

            // Writes all 8 bytes of the symbol for `$code` at `out_ptr`, then advances
            // `out_ptr` by the symbol's real length only. The bytes past that length are
            // overwritten by whatever is written next, so this needs 8 writable bytes.
            macro_rules! store_next_symbol {
                ($code:expr) => {{
                    out_ptr
                        .cast::<u64>()
                        .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
                    out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
                }};
            }

            // First we try loading 8 bytes at a time.
            if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
                // Extract the loop condition since the compiler fails to do so.
                // One iteration reads 8 input bytes and writes at most 8 full symbols,
                // so both limits are pulled in by that much.
                let block_out_end = out_end.sub(8 * size_of::<Symbol>());
                let block_in_end = in_end.sub(8);

                while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
                    // Note that we load a little-endian u64 here.
                    let next_block = in_ptr.cast::<u64>().read_unaligned();
                    // Sets the high bit of every byte lane that equals 0xFF (ESCAPE_CODE)
                    // and leaves all other bits clear.
                    let escape_mask = (next_block & 0x8080808080808080)
                        & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
                            ^ 0x8080808080808080);

                    // If there are no escape codes, we write each symbol one by one.
                    if escape_mask == 0 {
                        let code = (next_block & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 8) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 16) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 24) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 32) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 40) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 48) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 56) & 0xFF) as u8;
                        store_next_symbol!(code);
                        in_ptr = in_ptr.add(8);
                    } else if (next_block & 0x00FF00FF00FF00FF) == 0x00FF00FF00FF00FF {
                        // All 4 even-positioned bytes are ESCAPE_CODE.
                        // Batch-extract the 4 raw bytes at odd positions.
                        out_ptr.write(((next_block >> 8) & 0xFF) as u8);
                        out_ptr.add(1).write(((next_block >> 24) & 0xFF) as u8);
                        out_ptr.add(2).write(((next_block >> 40) & 0xFF) as u8);
                        out_ptr.add(3).write(((next_block >> 56) & 0xFF) as u8);
                        out_ptr = out_ptr.add(4);
                        in_ptr = in_ptr.add(8);
                    } else {
                        // Otherwise, find the first escape code and write the symbols up to that point.
                        let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
                        debug_assert!(first_escape_pos < 8);
                        // Each arm emits the `first_escape_pos` symbols that precede the escape
                        // and then, when the escaped byte is inside this block, the raw byte
                        // that follows the escape. The input then advances past everything
                        // that was consumed.
                        match first_escape_pos {
                            7 => {
                                // The escape is the last byte of the block, so its raw byte is
                                // not loaded yet: emit the 7 symbols and leave the escape for
                                // the next iteration.
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 24) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 32) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 40) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 48) & 0xFF) as u8;
                                store_next_symbol!(code);

                                in_ptr = in_ptr.add(7);
                            }
                            6 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 24) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 32) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 40) & 0xFF) as u8;
                                store_next_symbol!(code);

                                let escaped = ((next_block >> 56) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);

                                in_ptr = in_ptr.add(8);
                            }
                            5 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 24) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 32) & 0xFF) as u8;
                                store_next_symbol!(code);

                                let escaped = ((next_block >> 48) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);

                                in_ptr = in_ptr.add(7);
                            }
                            4 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 24) & 0xFF) as u8;
                                store_next_symbol!(code);

                                let escaped = ((next_block >> 40) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);

                                in_ptr = in_ptr.add(6);
                            }
                            3 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);

                                let escaped = ((next_block >> 32) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);

                                in_ptr = in_ptr.add(5);
                            }
                            2 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);

                                let escaped = ((next_block >> 24) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);

                                in_ptr = in_ptr.add(4);
                            }
                            1 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);

                                let escaped = ((next_block >> 16) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);

                                in_ptr = in_ptr.add(3);
                            }
                            0 => {
                                // Otherwise, we actually need to decompress the next byte
                                // Extract the second byte from the u64
                                let escaped = ((next_block >> 8) & 0xFF) as u8;
                                in_ptr = in_ptr.add(2);
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);
                            }
                            _ => unreachable!(),
                        }
                    }
                }
            }

            // Otherwise, fall back to 1-byte reads using 8-byte writes where safe.
            while out_end.offset_from(out_ptr) >= size_of::<Symbol>() as isize && in_ptr < in_end {
                let code = in_ptr.read();
                in_ptr = in_ptr.add(1);

                if code == ESCAPE_CODE {
                    assert!(
                        in_ptr < in_end,
                        "truncated compressed string: escape code at end of input"
                    );
                    out_ptr.write(in_ptr.read());
                    in_ptr = in_ptr.add(1);
                    out_ptr = out_ptr.add(1);
                } else {
                    store_next_symbol!(code);
                }
            }

            // For the last few bytes (if any) where we can't do an 8-byte unaligned write.
            // Here only the symbol's real length is copied, after checking that it fits.
            while in_ptr < in_end {
                let code = in_ptr.read();
                in_ptr = in_ptr.add(1);

                if code == ESCAPE_CODE {
                    assert!(
                        in_ptr < in_end,
                        "truncated compressed string: escape code at end of input"
                    );
                    assert!(
                        out_ptr.cast_const() < out_end,
                        "output buffer sized too small"
                    );
                    out_ptr.write(in_ptr.read());
                    in_ptr = in_ptr.add(1);
                    out_ptr = out_ptr.add(1);
                } else {
                    let len = *self.lengths.get_unchecked(code as usize) as usize;
                    assert!(
                        out_end.offset_from(out_ptr) >= len as isize,
                        "output buffer sized too small"
                    );
                    let sym = self.symbols.get_unchecked(code as usize).to_u64();
                    let sym_bytes = sym.to_le_bytes();
                    std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len);
                    out_ptr = out_ptr.add(len);
                }
            }

            assert_eq!(
                in_ptr, in_end,
                "decompression should exhaust input before output"
            );

            // Number of bytes written since the start of `decoded`.
            out_ptr.offset_from(out_begin) as usize
        }
    }
535
536    /// Decompress a byte slice that was previously returned by a compressor using the same symbol
537    /// table into a new vector of bytes.
538    pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
539        let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
540
541        let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
542        // SAFETY: len bytes have now been initialized by the decompressor.
543        unsafe { decoded.set_len(len) };
544        decoded
545    }
546}
547
/// A compressor that uses a symbol table to greedily compress strings.
///
/// The `Compressor` is the central component of FSST. You can create a compressor either by
/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
///
/// Example usage:
///
/// ```
/// use fsst::{Symbol, Compressor, CompressorBuilder};
/// let compressor = {
///     let mut builder = CompressorBuilder::new();
///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
///     builder.build()
/// };
///
/// let compressed = compressor.compress("hello".as_bytes());
/// assert_eq!(compressed, vec![0u8]);
/// ```
#[derive(Clone)]
pub struct Compressor {
    /// Table mapping codes to symbols.
    pub(crate) symbols: Vec<Symbol>,

    /// Length of each symbol, values range from 1-8.
    pub(crate) lengths: Vec<u8>,

    /// The number of entries in the symbol table that have been populated, not counting
    /// the escape values.
    pub(crate) n_symbols: u8,

    /// Inverted index mapping 2-byte symbols to codes.
    ///
    /// `compress_word` indexes this with the low 16 bits of the input word (its first two
    /// bytes) without a bounds check, relying on the table having 65536 entries.
    codes_two_byte: Vec<Code>,

    /// Limit of no suffixes.
    ///
    /// `compress_word` compares the code found in `codes_two_byte` against this value: a code
    /// below it is emitted as a two-byte match right away, without probing `lossy_pht` for a
    /// longer symbol.
    has_suffix_code: u8,

    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
    lossy_pht: LossyPHT,
}
587
588/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
589///
590/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
591/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
592impl Compressor {
593    /// Using the symbol table, runs a single cycle of compression on an input word, writing
594    /// the output into `out_ptr`.
595    ///
596    /// # Returns
597    ///
598    /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
599    /// for the caller to advance the input and output pointers.
600    ///
601    /// `advance_in` is the number of bytes to advance the input pointer before the next call.
602    ///
603    /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
604    ///
605    /// # Safety
606    ///
607    /// `out_ptr` must never be NULL or otherwise point to invalid memory.
608    pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
609        // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
610        // if it isn't, it will be overwritten anyway.
611        //
612        // SAFETY: caller ensures out_ptr is not null
613        let first_byte = word as u8;
614        // SAFETY: out_ptr is not null
615        unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
616
617        // First, check the two_bytes table
618        // SAFETY: codes_two_byte has exactly 65536 entries and `word as u16` is always in [0, 65535].
619        let code_twobyte = unsafe { *self.codes_two_byte.get_unchecked(word as u16 as usize) };
620
621        if code_twobyte.code() < self.has_suffix_code {
622            // 2 byte code without having to worry about longer matches.
623            // SAFETY: out_ptr is not null.
624            unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
625
626            // Advance input by symbol length (2) and output by a single code byte
627            (2, 1)
628        } else {
629            // Probe the hash table
630            let entry = self.lossy_pht.lookup(word);
631
632            // Now, downshift the `word` and the `entry` to see if they align.
633            let ignored_bits = entry.ignored_bits;
634            if entry.code != Code::UNUSED
635                && compare_masked(word, entry.symbol.to_u64(), ignored_bits)
636            {
637                // Advance the input by the symbol length (variable) and the output by one code byte
638                // SAFETY: out_ptr is not null.
639                unsafe { std::ptr::write(out_ptr, entry.code.code()) };
640                (entry.code.len() as usize, 1)
641            } else {
642                // SAFETY: out_ptr is not null
643                unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
644
645                // Advance the input by the symbol length (variable) and the output by either 1
646                // byte (if was one-byte code) or two bytes (escape).
647                (
648                    code_twobyte.len() as usize,
649                    // Predicated version of:
650                    //
651                    // if entry.code >= 256 {
652                    //      2
653                    // } else {
654                    //      1
655                    // }
656                    1 + (code_twobyte.extended_code() >> 8) as usize,
657                )
658            }
659        }
660    }
661
662    /// Compress many lines in bulk.
663    pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
664        let mut res = Vec::new();
665
666        for line in lines {
667            res.push(self.compress(line));
668        }
669
670        res
671    }
672
673    /// Compress a string, writing its result into a target buffer.
674    ///
675    /// The target buffer is a byte vector that must have capacity large enough
676    /// to hold the encoded data.
677    ///
678    /// When this call returns, `values` will hold the compressed bytes and have
679    /// its length set to the length of the compressed text.
680    ///
681    /// ```
682    /// use fsst::{Compressor, CompressorBuilder, Symbol};
683    ///
684    /// let mut compressor = CompressorBuilder::new();
685    /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
686    ///
687    /// let compressor = compressor.build();
688    ///
689    /// let mut compressed_values = Vec::with_capacity(1_024);
690    ///
691    /// // SAFETY: we have over-sized compressed_values.
692    /// unsafe {
693    ///     compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
694    /// }
695    ///
696    /// assert_eq!(compressed_values, vec![0u8]);
697    /// ```
698    ///
699    /// # Safety
700    ///
701    /// It is up to the caller to ensure the provided buffer is large enough to hold
702    /// all encoded data.
703    pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
704        let mut in_ptr = plaintext.as_ptr();
705        let mut out_ptr = values.as_mut_ptr();
706
707        // SAFETY: `end` will point just after the end of the `plaintext` slice.
708        let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
709        let in_end_sub8 = in_end as usize - 8;
710        // SAFETY: `end` will point just after the end of the `values` allocation.
711        let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
712
713        while (in_ptr as usize) <= in_end_sub8 && unsafe { out_end.offset_from(out_ptr) } >= 2 {
714            // SAFETY: pointer ranges are checked in the loop condition
715            unsafe {
716                // Load a full 8-byte word of data from in_ptr.
717                // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
718                let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
719                let (advance_in, advance_out) = self.compress_word(word, out_ptr);
720                in_ptr = in_ptr.byte_add(advance_in);
721                out_ptr = out_ptr.byte_add(advance_out);
722            };
723        }
724
725        let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
726        assert!(
727            out_ptr < out_end || remaining_bytes == 0,
728            "output buffer sized too small"
729        );
730
731        let remaining_bytes = remaining_bytes as usize;
732
733        // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
734        // but shift data out of this word rather than advancing an input pointer and potentially reading
735        // unowned memory.
736        let mut bytes = [0u8; 8];
737        // SAFETY: remaining_bytes <= 8
738        unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
739        let mut last_word = u64::from_le_bytes(bytes);
740
741        while in_ptr < in_end && unsafe { out_end.offset_from(out_ptr) } >= 2 {
742            // Load a full 8-byte word of data from in_ptr.
743            // SAFETY: caller asserts in_ptr is not null
744            let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
745            // SAFETY: pointer ranges are checked in the loop condition
746            unsafe {
747                in_ptr = in_ptr.add(advance_in);
748                out_ptr = out_ptr.add(advance_out);
749            }
750
751            last_word = advance_8byte_word(last_word, advance_in);
752        }
753
754        // in_ptr should have exceeded in_end
755        assert!(
756            in_ptr >= in_end,
757            "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
758        );
759
760        assert!(out_ptr <= out_end, "output buffer sized too small");
761
762        // SAFETY: out_ptr is derived from the `values` allocation.
763        let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
764        assert!(
765            bytes_written >= 0,
766            "out_ptr ended before it started, not possible"
767        );
768
769        // SAFETY: we have initialized `bytes_written` values in the output buffer.
770        unsafe { values.set_len(bytes_written as usize) };
771    }
772
773    /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
774    pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
775        if plaintext.is_empty() {
776            return Vec::new();
777        }
778
779        let mut buffer = Vec::with_capacity(plaintext.len() * 2);
780
781        // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
782        unsafe { self.compress_into(plaintext, &mut buffer) };
783
784        buffer
785    }
786
787    /// Access the decompressor that can be used to decompress strings emitted from this
788    /// `Compressor` instance.
789    pub fn decompressor(&self) -> Decompressor<'_> {
790        Decompressor::new(self.symbol_table(), self.symbol_lengths())
791    }
792
793    /// Returns a readonly slice of the current symbol table.
794    ///
795    /// The returned slice will have length of `n_symbols`.
796    pub fn symbol_table(&self) -> &[Symbol] {
797        &self.symbols[0..self.n_symbols as usize]
798    }
799
800    /// Returns a readonly slice where index `i` contains the
801    /// length of the symbol represented by code `i`.
802    ///
803    /// Values range from 1-8.
804    pub fn symbol_lengths(&self) -> &[u8] {
805        &self.lengths[0..self.n_symbols as usize]
806    }
807
808    /// Rebuild a compressor from an existing symbol table.
809    ///
810    /// This will not attempt to optimize or re-order the codes.
811    pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
812        let symbols = symbols.as_ref();
813        let symbol_lens = symbol_lens.as_ref();
814
815        assert_eq!(
816            symbols.len(),
817            symbol_lens.len(),
818            "symbols and lengths differ"
819        );
820        assert!(
821            symbols.len() <= 255,
822            "symbol table len must be <= 255, was {}",
823            symbols.len()
824        );
825        validate_symbol_order(symbol_lens);
826
827        // Insert the symbols in their given order into the FSST lookup structures.
828        let symbols = symbols.to_vec();
829        let lengths = symbol_lens.to_vec();
830        let mut lossy_pht = LossyPHT::new();
831
832        let mut codes_one_byte = vec![Code::UNUSED; 256];
833
834        // Insert all of the one byte symbols first.
835        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
836            if len == 1 {
837                codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
838            }
839        }
840
841        // Initialize the codes_two_byte table to be all escapes
842        let mut codes_two_byte = vec![Code::UNUSED; 65_536];
843
844        // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
845        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
846            match len {
847                2 => {
848                    codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
849                }
850                3.. => {
851                    assert!(
852                        lossy_pht.insert(symbol, len as usize, code as u8),
853                        "rebuild symbol insertion into PHT must succeed"
854                    );
855                }
856                _ => { /* Covered by the 1-byte loop above. */ }
857            }
858        }
859
860        // Build the finished codes_two_byte table, subbing in unused positions with the
861        // codes_one_byte value similar to what we do in CompressBuilder::finalize.
862        for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
863            if *code == Code::UNUSED {
864                *code = codes_one_byte[symbol & 0xFF];
865            }
866        }
867
868        // Find the position of the first 2-byte code that has a suffix later in the table
869        let mut has_suffix_code = 0u8;
870        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
871            if len != 2 {
872                break;
873            }
874            let rest = &symbols[code..];
875            if rest
876                .iter()
877                .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
878            {
879                has_suffix_code = code as u8;
880                break;
881            }
882        }
883
884        Compressor {
885            n_symbols: symbols.len() as u8,
886            symbols,
887            lengths,
888            codes_two_byte,
889            lossy_pht,
890            has_suffix_code,
891        }
892    }
893}
894
/// Discard the first `bytes` bytes of a little-endian packed `word`.
///
/// The first character of the word lives in the least-significant byte, so consuming
/// bytes means shifting them off the low end.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    match bytes {
        // Consuming all 8 bytes leaves nothing; a 64-bit shift of a u64 would overflow,
        // so this case is answered directly. Per `<https://godbolt.org/z/Pbvre65Pq>` the
        // equivalent `if`/`else` lowers to a conditional move rather than a branch.
        8 => 0,
        consumed => word >> (8 * consumed),
    }
}
904
/// Check that a list of symbol lengths follows the FSST symbol table layout.
///
/// The table must be ordered by length as `2 3 4 5 6 7 8 1`: first all multi-byte
/// symbols in non-decreasing length order, then all one-byte symbols. An empty list is
/// accepted.
///
/// # Panics
///
/// Panics if a multi-byte symbol is shorter than its predecessor (this also rejects a
/// length of 0), if a multi-byte symbol is longer than 8 bytes, or if anything other
/// than a one-byte symbol follows a one-byte symbol.
fn validate_symbol_order(symbol_lens: &[u8]) {
    // A symbol stores at most 8 bytes, so no valid length can exceed this.
    const MAX_SYMBOL_LEN: u8 = 8;

    // Ensure that the symbol table is ordered by length, 23456781
    let mut expected = 2;
    for (idx, &len) in symbol_lens.iter().enumerate() {
        if expected == 1 {
            // Once the one-byte section has started, nothing else may follow.
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
        } else {
            // The first one-byte symbol switches to the trailing one-byte section.
            if len == 1 {
                expected = 1;
            }

            // we're in the non-zero portion.
            assert!(
                len >= expected,
                "symbol code={idx} breaks violates FSST symbol table ordering"
            );
            // Ordering alone would let lengths of 9 and above through.
            assert!(
                len <= MAX_SYMBOL_LEN,
                "symbol code={idx} has length {len}, but symbols hold at most 8 bytes"
            );
            expected = len;
        }
    }
}
928
/// Compare `left` against `right` after clearing the top `ignored_bits` bits of `left`.
///
/// `right` is compared as-is, so any bit it has set inside the ignored region makes the
/// comparison fail.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    // Shifting all-ones right keeps only the low `64 - ignored_bits` bits.
    let kept_bits = u64::MAX >> ignored_bits;
    let masked_left = left & kept_bits;
    masked_left == right
}
934
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Symbol table fixture; each entry is the little-endian `u64` form of one symbol.
    ///
    /// Layout: 16 two-byte symbols, then 61 three-byte symbols, then 8 one-byte symbols
    /// (the bytes `a`, `b`, `c`, `d`, each listed twice).
    const SYMBOL_WORDS: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Lengths matching `SYMBOL_WORDS` entry for entry.
    fn fixture_lens() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// `SYMBOL_WORDS` converted to `Symbol`s without any unsafe reinterpretation.
    fn fixture_symbols() -> Vec<Symbol> {
        SYMBOL_WORDS
            .iter()
            .map(|word| Symbol::from_slice(&word.to_le_bytes()))
            .collect()
    }

    /// The compressor's symbol table as raw `u64` words, for comparison with the fixture.
    fn table_words(compressor: &Compressor) -> Vec<u64> {
        compressor
            .symbol_table()
            .iter()
            .map(|symbol| symbol.0)
            .collect()
    }

    /// A single 8-byte symbol gets code 0 and decompresses back to its bytes.
    #[test]
    fn test_stuff() {
        let compressor = {
            let mut builder = CompressorBuilder::new();
            builder.insert(Symbol::from_slice(b"helloooo"), 8);
            builder.build()
        };

        let decompressor = compressor.decompressor();

        // 8 output bytes plus 7 bytes of slack; presumably the decompressor needs room
        // to write whole 8-byte words. TODO confirm against `decompress_into`.
        let mut decompressed = Vec::with_capacity(8 + 7);

        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        // SAFETY: `decompress_into` reported that it wrote `len` bytes into the spare capacity.
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    /// `rebuild_from` must keep the symbols exactly in the order they were supplied.
    #[test]
    fn test_symbols_good() {
        let symbols = fixture_symbols();

        let compressor = Compressor::rebuild_from(symbols.as_slice(), fixture_lens());
        assert_eq!(table_words(&compressor).as_slice(), SYMBOL_WORDS);
    }

    /// Feeding the same table through `CompressorBuilder` is expected to trip an
    /// `assert_eq!`. NOTE(review): presumably because the builder does not reproduce the
    /// input table verbatim, so the final comparison fails; confirm against the builder.
    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let lens = fixture_lens();

        let mut builder = CompressorBuilder::new();
        for (symbol, len) in fixture_symbols().into_iter().zip(lens.iter()) {
            builder.insert(symbol, *len as usize);
        }
        let compressor = builder.build();
        assert_eq!(table_words(&compressor).as_slice(), SYMBOL_WORDS);
    }
}