Skip to main content

fsst/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Static assertion: compilation fails unless `$typ` occupies exactly `$size_in_bytes` bytes.
///
/// Expands to an unnamed constant, which the compiler always evaluates, so a mismatch becomes a
/// compile-time error rather than a runtime check.
macro_rules! assert_sizeof {
    ($typ:ty => $size_in_bytes:expr) => {
        const _: () = assert!(std::mem::size_of::<$typ>() == $size_in_bytes);
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
20/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and
21/// identified by an 8-bit code.
22#[derive(Copy, Clone, PartialEq, Eq, Hash)]
23pub struct Symbol(u64);
24
25assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28    /// Zero value for `Symbol`.
29    pub const ZERO: Self = Self::zero();
30
31    /// Constructor for a `Symbol` from an 8-element byte slice.
32    pub fn from_slice(slice: &[u8; 8]) -> Self {
33        let num: u64 = u64::from_le_bytes(*slice);
34
35        Self(num)
36    }
37
38    /// Return a zero symbol
39    const fn zero() -> Self {
40        Self(0)
41    }
42
43    /// Create a new single-byte symbol
44    pub fn from_u8(value: u8) -> Self {
45        Self(value as u64)
46    }
47}
48
49impl Symbol {
50    /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51    ///
52    /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53    /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54    /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55    /// but we want to interpret that as a one-byte symbol containing `0x00`.
56    #[allow(clippy::len_without_is_empty)]
57    pub fn len(self) -> usize {
58        let numeric = self.0;
59        // For little-endian platforms, this counts the number of *trailing* zeros
60        let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62        // Special case handling of a symbol with all-zeros. This is actually
63        // a 1-byte symbol containing 0x00.
64        let len = size_of::<Self>() - null_bytes;
65        if len == 0 { 1 } else { len }
66    }
67
68    /// Returns the Symbol's inner representation.
69    #[inline]
70    pub fn to_u64(self) -> u64 {
71        self.0
72    }
73
74    /// Get the first byte of the symbol as a `u8`.
75    ///
76    /// If the symbol is empty, this will return the zero byte.
77    #[inline]
78    pub fn first_byte(self) -> u8 {
79        self.0 as u8
80    }
81
82    /// Get the first two bytes of the symbol as a `u16`.
83    ///
84    /// If the Symbol is one or zero bytes, this will return `0u16`.
85    #[inline]
86    pub fn first2(self) -> u16 {
87        self.0 as u16
88    }
89
90    /// Get the first three bytes of the symbol as a `u64`.
91    ///
92    /// If the Symbol is one or zero bytes, this will return `0u64`.
93    #[inline]
94    pub fn first3(self) -> u64 {
95        self.0 & 0xFF_FF_FF
96    }
97
98    /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
99    pub fn concat(self, other: Self) -> Self {
100        assert!(
101            self.len() + other.len() <= 8,
102            "cannot build symbol with length > 8"
103        );
104
105        let self_len = self.len();
106
107        Self((other.0 << (8 * self_len)) | self.0)
108    }
109}
110
111impl Debug for Symbol {
112    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113        write!(f, "[")?;
114
115        let slice = &self.0.to_le_bytes()[0..self.len()];
116        for c in slice.iter().map(|c| *c as char) {
117            if ('!'..='~').contains(&c) {
118                write!(f, "{c}")?;
119            } else if c == '\n' {
120                write!(f, " \\n ")?;
121            } else if c == '\t' {
122                write!(f, " \\t ")?;
123            } else if c == ' ' {
124                write!(f, " SPACE ")?;
125            } else {
126                write!(f, " 0x{:X?} ", c as u8)?
127            }
128        }
129
130        write!(f, "]")
131    }
132}
133
/// A packed type containing a code value, as well as metadata about the symbol referred to by
/// the code.
///
/// Logically, codes can range from 0-255 inclusive. This type holds both the 8-bit code as well as
/// other metadata bit-packed into a `u16`.
///
/// The bottom 8 bits contain EITHER a code for a symbol stored in the table, OR a raw byte.
///
/// The interpretation depends on the 9th bit: when toggled off, the value stores a raw byte, and when
/// toggled on, it stores a code. Thus if you examine the bottom 9 bits of the `u16`, you have an extended
/// code range, where the values 0-255 are raw bytes, and the values 256-510 represent codes 0-254. 511 is
/// a placeholder for the invalid code here.
///
/// Bits 12-15 store the length of the symbol (values ranging from 0-8).
// NOTE(review): `Compressor::compress_word` emits two output bytes (escape + raw byte) when
// `extended_code() >= 256` (e.g. for `Code::UNUSED`), which reads the 9th bit the opposite way
// from the description above. The description matches `new_escape` / `new_symbol_building`.
// Confirm which convention each lookup table uses.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Code(u16);

/// Code used to indicate bytes that are not in the symbol table.
///
/// When compressing a string that cannot fully be expressed with the symbol table, the compressed
/// output will contain an `ESCAPE` byte followed by a raw byte. At decompression time, the presence
/// of `ESCAPE` indicates that the next byte should be appended directly to the result instead of
/// being looked up in the symbol table.
pub const ESCAPE_CODE: u8 = 255;

/// Number of low bits of a packed code that hold its extended (9-bit) code value.
pub const FSST_CODE_BITS: usize = 9;

/// Bit offset at which the symbol-length field of a packed code begins.
pub const FSST_LEN_BITS: usize = 12;

/// Size of the extended code range: every extended code is strictly less than this value.
pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS;

/// Maximum value for the extended code range; also the mask selecting the extended-code bits.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First extended code that corresponds to a non-escape symbol.
pub const FSST_CODE_BASE: u16 = 256;

#[allow(clippy::len_without_is_empty)]
impl Code {
    /// Code for an unused slot in a symbol table or index.
    ///
    /// This corresponds to the maximum extended code (511) with a length of 1; its low byte is
    /// [`ESCAPE_CODE`].
    pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << FSST_LEN_BITS));

    /// Create a new code for a symbol of given length.
    fn new_symbol(code: u8, len: usize) -> Self {
        Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
    }

    /// Code for a new symbol during the building phase.
    ///
    /// The code is remapped from 0..254 to 256...510.
    fn new_symbol_building(code: u8, len: usize) -> Self {
        Self(code as u16 + FSST_CODE_BASE + ((len as u16) << FSST_LEN_BITS))
    }

    /// Create a new code corresponding to an escaped byte (length 1, 9th bit clear).
    fn new_escape(byte: u8) -> Self {
        Self((byte as u16) + (1 << FSST_LEN_BITS))
    }

    /// The low 8 bits: the code byte as it appears in compressed output.
    #[inline]
    fn code(self) -> u8 {
        self.0 as u8
    }

    /// The low 9 bits: the code in the extended range.
    #[inline]
    fn extended_code(self) -> u16 {
        self.0 & FSST_CODE_MASK
    }

    /// The symbol length stored in the high bits.
    #[inline]
    fn len(self) -> u16 {
        self.0 >> FSST_LEN_BITS
    }
}

impl Debug for Code {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Uses the building-phase convention: an extended code below `FSST_CODE_BASE` is a raw
        // (escaped) byte. Only the 9 extended-code bits may be compared here; the length bits
        // above them would otherwise make every real code look like a non-escape.
        f.debug_struct("TrainingCode")
            .field("code", &self.code())
            .field("is_escape", &(self.extended_code() < FSST_CODE_BASE))
            .field("len", &self.len())
            .finish()
    }
}
225
/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string.
///
/// The symbol table is borrowed for the lifetime `'a` as two parallel slices indexed by code.
#[derive(Clone)]
pub struct Decompressor<'a> {
    /// Slice mapping codes to symbols: `symbols[code]` is the symbol emitted for `code`.
    pub(crate) symbols: &'a [Symbol],

    /// Slice containing the length of each symbol in the `symbols` slice: `lengths[code]` is the
    /// number of output bytes produced for `code`.
    pub(crate) lengths: &'a [u8],
}
235
236impl<'a> Decompressor<'a> {
237    /// Returns a new decompressor that uses the provided symbol table.
238    ///
239    /// # Panics
240    ///
241    /// If the provided symbol table has length greater than 256
242    pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
243        assert!(
244            symbols.len() < FSST_CODE_BASE as usize,
245            "symbol table cannot have size exceeding 255"
246        );
247
248        Self { symbols, lengths }
249    }
250
251    /// Returns an upper bound on the size of the decompressed data.
252    pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
253        size_of::<Symbol>() * (compressed.len() + 1)
254    }
255
256    /// Decompress a slice of codes into a provided buffer.
257    ///
258    /// The provided `decoded` buffer must be at least the size of the decoded data, plus
259    /// an additional 7 bytes.
260    ///
261    /// ## Panics
262    ///
263    /// If the caller fails to provide sufficient capacity in the decoded buffer. An upper bound
264    /// on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
265    ///
266    /// ## Example
267    ///
268    /// ```
269    /// use fsst::{Symbol, Compressor, CompressorBuilder};
270    /// let compressor = {
271    ///     let mut builder = CompressorBuilder::new();
272    ///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
273    ///     builder.build()
274    /// };
275    ///
276    /// let decompressor = compressor.decompressor();
277    ///
278    /// let mut decompressed = Vec::with_capacity(8 + 7);
279    ///
280    /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
281    /// assert_eq!(len, 8);
282    /// unsafe { decompressed.set_len(len) };
283    /// assert_eq!(&decompressed, "helloooo".as_bytes());
284    /// ```
285    pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
286        // Ensure the target buffer is at least half the size of the input buffer.
287        // This is the theortical smallest a valid target can be, and occurs when
288        // every input code is an escape.
289        assert!(
290            decoded.len() >= compressed.len() / 2,
291            "decoded is smaller than lower-bound decompressed size"
292        );
293
294        unsafe {
295            let mut in_ptr = compressed.as_ptr();
296            let _in_begin = in_ptr;
297            let in_end = in_ptr.add(compressed.len());
298
299            let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
300            let out_begin = out_ptr.cast_const();
301            let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();
302
303            macro_rules! store_next_symbol {
304                ($code:expr) => {{
305                    out_ptr
306                        .cast::<u64>()
307                        .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
308                    out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
309                }};
310            }
311
312            // First we try loading 8 bytes at a time.
313            if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
314                // Extract the loop condition since the compiler fails to do so
315                let block_out_end = out_end.sub(8 * size_of::<Symbol>());
316                let block_in_end = in_end.sub(8);
317
318                while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
319                    // Note that we load a little-endian u64 here.
320                    let next_block = in_ptr.cast::<u64>().read_unaligned();
321                    let escape_mask = (next_block & 0x8080808080808080)
322                        & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
323                            ^ 0x8080808080808080);
324
325                    // If there are no escape codes, we write each symbol one by one.
326                    if escape_mask == 0 {
327                        let code = (next_block & 0xFF) as u8;
328                        store_next_symbol!(code);
329                        let code = ((next_block >> 8) & 0xFF) as u8;
330                        store_next_symbol!(code);
331                        let code = ((next_block >> 16) & 0xFF) as u8;
332                        store_next_symbol!(code);
333                        let code = ((next_block >> 24) & 0xFF) as u8;
334                        store_next_symbol!(code);
335                        let code = ((next_block >> 32) & 0xFF) as u8;
336                        store_next_symbol!(code);
337                        let code = ((next_block >> 40) & 0xFF) as u8;
338                        store_next_symbol!(code);
339                        let code = ((next_block >> 48) & 0xFF) as u8;
340                        store_next_symbol!(code);
341                        let code = ((next_block >> 56) & 0xFF) as u8;
342                        store_next_symbol!(code);
343                        in_ptr = in_ptr.add(8);
344                    } else if (next_block & 0x00FF00FF00FF00FF) == 0x00FF00FF00FF00FF {
345                        // All 4 even-positioned bytes are ESCAPE_CODE.
346                        // Batch-extract the 4 raw bytes at odd positions.
347                        out_ptr.write(((next_block >> 8) & 0xFF) as u8);
348                        out_ptr.add(1).write(((next_block >> 24) & 0xFF) as u8);
349                        out_ptr.add(2).write(((next_block >> 40) & 0xFF) as u8);
350                        out_ptr.add(3).write(((next_block >> 56) & 0xFF) as u8);
351                        out_ptr = out_ptr.add(4);
352                        in_ptr = in_ptr.add(8);
353                    } else {
354                        // Otherwise, find the first escape code and write the symbols up to that point.
355                        let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
356                        debug_assert!(first_escape_pos < 8);
357                        match first_escape_pos {
358                            7 => {
359                                let code = (next_block & 0xFF) as u8;
360                                store_next_symbol!(code);
361                                let code = ((next_block >> 8) & 0xFF) as u8;
362                                store_next_symbol!(code);
363                                let code = ((next_block >> 16) & 0xFF) as u8;
364                                store_next_symbol!(code);
365                                let code = ((next_block >> 24) & 0xFF) as u8;
366                                store_next_symbol!(code);
367                                let code = ((next_block >> 32) & 0xFF) as u8;
368                                store_next_symbol!(code);
369                                let code = ((next_block >> 40) & 0xFF) as u8;
370                                store_next_symbol!(code);
371                                let code = ((next_block >> 48) & 0xFF) as u8;
372                                store_next_symbol!(code);
373
374                                in_ptr = in_ptr.add(7);
375                            }
376                            6 => {
377                                let code = (next_block & 0xFF) as u8;
378                                store_next_symbol!(code);
379                                let code = ((next_block >> 8) & 0xFF) as u8;
380                                store_next_symbol!(code);
381                                let code = ((next_block >> 16) & 0xFF) as u8;
382                                store_next_symbol!(code);
383                                let code = ((next_block >> 24) & 0xFF) as u8;
384                                store_next_symbol!(code);
385                                let code = ((next_block >> 32) & 0xFF) as u8;
386                                store_next_symbol!(code);
387                                let code = ((next_block >> 40) & 0xFF) as u8;
388                                store_next_symbol!(code);
389
390                                let escaped = ((next_block >> 56) & 0xFF) as u8;
391                                out_ptr.write(escaped);
392                                out_ptr = out_ptr.add(1);
393
394                                in_ptr = in_ptr.add(8);
395                            }
396                            5 => {
397                                let code = (next_block & 0xFF) as u8;
398                                store_next_symbol!(code);
399                                let code = ((next_block >> 8) & 0xFF) as u8;
400                                store_next_symbol!(code);
401                                let code = ((next_block >> 16) & 0xFF) as u8;
402                                store_next_symbol!(code);
403                                let code = ((next_block >> 24) & 0xFF) as u8;
404                                store_next_symbol!(code);
405                                let code = ((next_block >> 32) & 0xFF) as u8;
406                                store_next_symbol!(code);
407
408                                let escaped = ((next_block >> 48) & 0xFF) as u8;
409                                out_ptr.write(escaped);
410                                out_ptr = out_ptr.add(1);
411
412                                in_ptr = in_ptr.add(7);
413                            }
414                            4 => {
415                                let code = (next_block & 0xFF) as u8;
416                                store_next_symbol!(code);
417                                let code = ((next_block >> 8) & 0xFF) as u8;
418                                store_next_symbol!(code);
419                                let code = ((next_block >> 16) & 0xFF) as u8;
420                                store_next_symbol!(code);
421                                let code = ((next_block >> 24) & 0xFF) as u8;
422                                store_next_symbol!(code);
423
424                                let escaped = ((next_block >> 40) & 0xFF) as u8;
425                                out_ptr.write(escaped);
426                                out_ptr = out_ptr.add(1);
427
428                                in_ptr = in_ptr.add(6);
429                            }
430                            3 => {
431                                let code = (next_block & 0xFF) as u8;
432                                store_next_symbol!(code);
433                                let code = ((next_block >> 8) & 0xFF) as u8;
434                                store_next_symbol!(code);
435                                let code = ((next_block >> 16) & 0xFF) as u8;
436                                store_next_symbol!(code);
437
438                                let escaped = ((next_block >> 32) & 0xFF) as u8;
439                                out_ptr.write(escaped);
440                                out_ptr = out_ptr.add(1);
441
442                                in_ptr = in_ptr.add(5);
443                            }
444                            2 => {
445                                let code = (next_block & 0xFF) as u8;
446                                store_next_symbol!(code);
447                                let code = ((next_block >> 8) & 0xFF) as u8;
448                                store_next_symbol!(code);
449
450                                let escaped = ((next_block >> 24) & 0xFF) as u8;
451                                out_ptr.write(escaped);
452                                out_ptr = out_ptr.add(1);
453
454                                in_ptr = in_ptr.add(4);
455                            }
456                            1 => {
457                                let code = (next_block & 0xFF) as u8;
458                                store_next_symbol!(code);
459
460                                let escaped = ((next_block >> 16) & 0xFF) as u8;
461                                out_ptr.write(escaped);
462                                out_ptr = out_ptr.add(1);
463
464                                in_ptr = in_ptr.add(3);
465                            }
466                            0 => {
467                                // Otherwise, we actually need to decompress the next byte
468                                // Extract the second byte from the u32
469                                let escaped = ((next_block >> 8) & 0xFF) as u8;
470                                in_ptr = in_ptr.add(2);
471                                out_ptr.write(escaped);
472                                out_ptr = out_ptr.add(1);
473                            }
474                            _ => unreachable!(),
475                        }
476                    }
477                }
478            }
479
480            // Otherwise, fall back to 1-byte reads using 8-byte writes where safe.
481            while out_end.offset_from(out_ptr) >= size_of::<Symbol>() as isize && in_ptr < in_end {
482                let code = in_ptr.read();
483                in_ptr = in_ptr.add(1);
484
485                if code == ESCAPE_CODE {
486                    assert!(
487                        in_ptr < in_end,
488                        "truncated compressed string: escape code at end of input"
489                    );
490                    out_ptr.write(in_ptr.read());
491                    in_ptr = in_ptr.add(1);
492                    out_ptr = out_ptr.add(1);
493                } else {
494                    store_next_symbol!(code);
495                }
496            }
497
498            // For the last few bytes (if any) where we can't do an 8-byte unaligned write.
499            while in_ptr < in_end {
500                let code = in_ptr.read();
501                in_ptr = in_ptr.add(1);
502
503                if code == ESCAPE_CODE {
504                    assert!(
505                        in_ptr < in_end,
506                        "truncated compressed string: escape code at end of input"
507                    );
508                    assert!(
509                        out_ptr.cast_const() < out_end,
510                        "output buffer sized too small"
511                    );
512                    out_ptr.write(in_ptr.read());
513                    in_ptr = in_ptr.add(1);
514                    out_ptr = out_ptr.add(1);
515                } else {
516                    let len = *self.lengths.get_unchecked(code as usize) as usize;
517                    assert!(
518                        out_end.offset_from(out_ptr) >= len as isize,
519                        "output buffer sized too small"
520                    );
521                    let sym = self.symbols.get_unchecked(code as usize).to_u64();
522                    let sym_bytes = sym.to_le_bytes();
523                    std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len);
524                    out_ptr = out_ptr.add(len);
525                }
526            }
527
528            assert_eq!(
529                in_ptr, in_end,
530                "decompression should exhaust input before output"
531            );
532
533            out_ptr.offset_from(out_begin) as usize
534        }
535    }
536
537    /// Decompress a byte slice that was previously returned by a compressor using the same symbol
538    /// table into a new vector of bytes.
539    pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
540        let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
541
542        let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
543        // SAFETY: len bytes have now been initialized by the decompressor.
544        unsafe { decoded.set_len(len) };
545        decoded
546    }
547}
548
/// A compressor that uses a symbol table to greedily compress strings.
///
/// The `Compressor` is the central component of FSST. You can create a compressor either by
/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
///
/// Example usage:
///
/// ```
/// use fsst::{Symbol, Compressor, CompressorBuilder};
/// let compressor = {
///     let mut builder = CompressorBuilder::new();
///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
///     builder.build()
/// };
///
/// let compressed = compressor.compress("hello".as_bytes());
/// assert_eq!(compressed, vec![0u8]);
/// ```
#[derive(Clone)]
pub struct Compressor {
    /// Table mapping codes to symbols.
    pub(crate) symbols: Vec<Symbol>,

    /// Length of each symbol, values range from 1-8.
    pub(crate) lengths: Vec<u8>,

    /// The number of entries in the symbol table that have been populated, not counting
    /// the escape values.
    pub(crate) n_symbols: u8,

    /// Lookup table indexed by the next two input bytes, read as a little-endian `u16`.
    /// Each entry gives the code (2-byte symbol, 1-byte symbol, or escape) to emit for that
    /// prefix. `compress_word` indexes it unchecked and relies on it having 65,536 entries.
    codes_two_byte: Vec<Code>,

    /// Codes strictly below this value are 2-byte symbols that `compress_word` can emit
    /// immediately, without probing `lossy_pht` for a longer match.
    has_suffix_code: u8,

    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more.
    /// `compress_word` consults it when the two-byte lookup is not conclusive.
    lossy_pht: LossyPHT,
}
588
589/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
590///
591/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
592/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
593impl Compressor {
594    /// Using the symbol table, runs a single cycle of compression on an input word, writing
595    /// the output into `out_ptr`.
596    ///
597    /// # Returns
598    ///
599    /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
600    /// for the caller to advance the input and output pointers.
601    ///
602    /// `advance_in` is the number of bytes to advance the input pointer before the next call.
603    ///
604    /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
605    ///
606    /// # Safety
607    ///
608    /// `out_ptr` must never be NULL or otherwise point to invalid memory.
609    pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
610        // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
611        // if it isn't, it will be overwritten anyway.
612        //
613        // SAFETY: caller ensures out_ptr is not null
614        let first_byte = word as u8;
615        // SAFETY: out_ptr is not null
616        unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
617
618        // First, check the two_bytes table
619        // SAFETY: codes_two_byte has exactly 65536 entries and `word as u16` is always in [0, 65535].
620        let code_twobyte = unsafe { *self.codes_two_byte.get_unchecked(word as u16 as usize) };
621
622        if code_twobyte.code() < self.has_suffix_code {
623            // 2 byte code without having to worry about longer matches.
624            // SAFETY: out_ptr is not null.
625            unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
626
627            // Advance input by symbol length (2) and output by a single code byte
628            (2, 1)
629        } else {
630            // Probe the hash table
631            let entry = self.lossy_pht.lookup(word);
632
633            // Now, downshift the `word` and the `entry` to see if they align.
634            let ignored_bits = entry.ignored_bits;
635            if entry.code != Code::UNUSED
636                && compare_masked(word, entry.symbol.to_u64(), ignored_bits)
637            {
638                // Advance the input by the symbol length (variable) and the output by one code byte
639                // SAFETY: out_ptr is not null.
640                unsafe { std::ptr::write(out_ptr, entry.code.code()) };
641                (entry.code.len() as usize, 1)
642            } else {
643                // SAFETY: out_ptr is not null
644                unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
645
646                // Advance the input by the symbol length (variable) and the output by either 1
647                // byte (if was one-byte code) or two bytes (escape).
648                (
649                    code_twobyte.len() as usize,
650                    // Predicated version of:
651                    //
652                    // if entry.code >= 256 {
653                    //      2
654                    // } else {
655                    //      1
656                    // }
657                    1 + (code_twobyte.extended_code() >> 8) as usize,
658                )
659            }
660        }
661    }
662
663    /// Compress many lines in bulk.
664    pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
665        let mut res = Vec::new();
666
667        for line in lines {
668            res.push(self.compress(line));
669        }
670
671        res
672    }
673
674    /// Compress a string, writing its result into a target buffer.
675    ///
676    /// The target buffer is a byte vector that must have capacity large enough
677    /// to hold the encoded data.
678    ///
679    /// When this call returns, `values` will hold the compressed bytes and have
680    /// its length set to the length of the compressed text.
681    ///
682    /// ```
683    /// use fsst::{Compressor, CompressorBuilder, Symbol};
684    ///
685    /// let mut compressor = CompressorBuilder::new();
686    /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
687    ///
688    /// let compressor = compressor.build();
689    ///
690    /// let mut compressed_values = Vec::with_capacity(1_024);
691    ///
692    /// // SAFETY: we have over-sized compressed_values.
693    /// unsafe {
694    ///     compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
695    /// }
696    ///
697    /// assert_eq!(compressed_values, vec![0u8]);
698    /// ```
699    ///
700    /// # Safety
701    ///
702    /// It is up to the caller to ensure the provided buffer is large enough to hold
703    /// all encoded data.
704    pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
705        let mut in_ptr = plaintext.as_ptr();
706        let mut out_ptr = values.as_mut_ptr();
707
708        // SAFETY: `end` will point just after the end of the `plaintext` slice.
709        let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
710        let in_end_sub8 = in_end as usize - 8;
711        // SAFETY: `end` will point just after the end of the `values` allocation.
712        let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
713
714        while (in_ptr as usize) <= in_end_sub8 && unsafe { out_end.offset_from(out_ptr) } >= 2 {
715            // SAFETY: pointer ranges are checked in the loop condition
716            unsafe {
717                // Load a full 8-byte word of data from in_ptr.
718                // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
719                let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
720                let (advance_in, advance_out) = self.compress_word(word, out_ptr);
721                in_ptr = in_ptr.byte_add(advance_in);
722                out_ptr = out_ptr.byte_add(advance_out);
723            };
724        }
725
726        let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
727        assert!(
728            out_ptr < out_end || remaining_bytes == 0,
729            "output buffer sized too small"
730        );
731
732        let remaining_bytes = remaining_bytes as usize;
733
734        // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
735        // but shift data out of this word rather than advancing an input pointer and potentially reading
736        // unowned memory.
737        let mut bytes = [0u8; 8];
738        // SAFETY: remaining_bytes <= 8
739        unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
740        let mut last_word = u64::from_le_bytes(bytes);
741
742        while in_ptr < in_end && unsafe { out_end.offset_from(out_ptr) } >= 2 {
743            // Load a full 8-byte word of data from in_ptr.
744            // SAFETY: caller asserts in_ptr is not null
745            let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
746            // SAFETY: pointer ranges are checked in the loop condition
747            unsafe {
748                in_ptr = in_ptr.add(advance_in);
749                out_ptr = out_ptr.add(advance_out);
750            }
751
752            last_word = advance_8byte_word(last_word, advance_in);
753        }
754
755        // in_ptr should have exceeded in_end
756        assert!(
757            in_ptr >= in_end,
758            "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
759        );
760
761        assert!(out_ptr <= out_end, "output buffer sized too small");
762
763        // SAFETY: out_ptr is derived from the `values` allocation.
764        let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
765        assert!(
766            bytes_written >= 0,
767            "out_ptr ended before it started, not possible"
768        );
769
770        // SAFETY: we have initialized `bytes_written` values in the output buffer.
771        unsafe { values.set_len(bytes_written as usize) };
772    }
773
774    /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
775    pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
776        if plaintext.is_empty() {
777            return Vec::new();
778        }
779
780        let mut buffer = Vec::with_capacity(plaintext.len() * 2);
781
782        // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
783        unsafe { self.compress_into(plaintext, &mut buffer) };
784
785        buffer
786    }
787
788    /// Access the decompressor that can be used to decompress strings emitted from this
789    /// `Compressor` instance.
790    pub fn decompressor(&self) -> Decompressor<'_> {
791        Decompressor::new(self.symbol_table(), self.symbol_lengths())
792    }
793
794    /// Returns a readonly slice of the current symbol table.
795    ///
796    /// The returned slice will have length of `n_symbols`.
797    pub fn symbol_table(&self) -> &[Symbol] {
798        &self.symbols[0..self.n_symbols as usize]
799    }
800
801    /// Returns a readonly slice where index `i` contains the
802    /// length of the symbol represented by code `i`.
803    ///
804    /// Values range from 1-8.
805    pub fn symbol_lengths(&self) -> &[u8] {
806        &self.lengths[0..self.n_symbols as usize]
807    }
808
809    /// Rebuild a compressor from an existing symbol table.
810    ///
811    /// This will not attempt to optimize or re-order the codes.
812    pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
813        let symbols = symbols.as_ref();
814        let symbol_lens = symbol_lens.as_ref();
815
816        assert_eq!(
817            symbols.len(),
818            symbol_lens.len(),
819            "symbols and lengths differ"
820        );
821        assert!(
822            symbols.len() <= 255,
823            "symbol table len must be <= 255, was {}",
824            symbols.len()
825        );
826        validate_symbol_order(symbol_lens);
827
828        // Insert the symbols in their given order into the FSST lookup structures.
829        let symbols = symbols.to_vec();
830        let lengths = symbol_lens.to_vec();
831        let mut lossy_pht = LossyPHT::new();
832
833        let mut codes_one_byte = vec![Code::UNUSED; 256];
834
835        // Insert all of the one byte symbols first.
836        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
837            if len == 1 {
838                codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
839            }
840        }
841
842        // Initialize the codes_two_byte table to be all escapes
843        let mut codes_two_byte = vec![Code::UNUSED; 65_536];
844
845        // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
846        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
847            match len {
848                2 => {
849                    codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
850                }
851                3.. => {
852                    assert!(
853                        lossy_pht.insert(symbol, len as usize, code as u8),
854                        "rebuild symbol insertion into PHT must succeed"
855                    );
856                }
857                _ => { /* Covered by the 1-byte loop above. */ }
858            }
859        }
860
861        // Build the finished codes_two_byte table, subbing in unused positions with the
862        // codes_one_byte value similar to what we do in CompressBuilder::finalize.
863        for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
864            if *code == Code::UNUSED {
865                *code = codes_one_byte[symbol & 0xFF];
866            }
867        }
868
869        // Find the position of the first 2-byte code that has a suffix later in the table
870        let mut has_suffix_code = 0u8;
871        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
872            if len != 2 {
873                break;
874            }
875            let rest = &symbols[code..];
876            if rest
877                .iter()
878                .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
879            {
880                has_suffix_code = code as u8;
881                break;
882            }
883        }
884
885        Compressor {
886            n_symbols: symbols.len() as u8,
887            symbols,
888            lengths,
889            codes_two_byte,
890            lossy_pht,
891            has_suffix_code,
892        }
893    }
894}
895
/// Drop the first `bytes` bytes of a little-endian packed word.
///
/// The first character lives in the least-significant byte, so consuming bytes means
/// shifting them off the low end. Consuming all 8 bytes yields an empty word.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    // Rust lowers this two-way choice to a conditional move rather than a branch
    // (see `<https://godbolt.org/z/Pbvre65Pq>`).
    if bytes != 8 {
        word >> (8 * bytes)
    } else {
        // `word >> 64` would overflow, so the fully-consumed case returns zero directly.
        0
    }
}
905
/// Check that `symbol_lens` follows the FSST symbol-table layout.
///
/// Codes must be ordered by length as `2, 3, 4, 5, 6, 7, 8` (non-decreasing), followed by all
/// of the one-byte symbols at the end of the table, i.e. lengths read like `23456781`.
///
/// # Panics
///
/// Panics if any length is outside `1..=8`, or if the ordering above is violated.
fn validate_symbol_order(symbol_lens: &[u8]) {
    let mut expected = 2;
    for (idx, &len) in symbol_lens.iter().enumerate() {
        // A symbol holds at most 8 bytes, and zero-length symbols are meaningless.
        assert!(
            (1..=8).contains(&len),
            "symbol code={idx} has invalid length {len}, must be between 1 and 8"
        );

        if expected == 1 {
            // Once the one-byte tail has started, nothing longer may follow.
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
        } else {
            if len == 1 {
                // First one-byte symbol: switch to the tail section.
                expected = 1;
            }

            // we're in the multi-byte portion: lengths must not decrease.
            assert!(
                len >= expected,
                "symbol code={idx} violates FSST symbol table ordering"
            );
            expected = len;
        }
    }
}
929
/// Check whether `left`, with its top `ignored_bits` bits cleared, equals `right`.
///
/// Only `left` is masked: `right` must already be zero in those high bits to match.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    let low_bits_of_left = left & (u64::MAX >> ignored_bits);
    right == low_bits_of_left
}
935
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Symbol table shared by the rebuild tests, as packed little-endian `u64`s:
    /// 16 two-byte symbols, 61 three-byte symbols, then 8 one-byte symbols.
    const TABLE_U64: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Lengths matching `TABLE_U64`, in FSST code order (2s, then 3s, then 1s).
    fn table_lens() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// Convert packed `u64`s into `Symbol`s through the public constructor. `Symbol` is not
    /// `#[repr(transparent)]`, so transmuting a `&[u64]` into a `&[Symbol]` relies on a layout
    /// the language does not guarantee.
    fn to_symbols(packed: &[u64]) -> Vec<Symbol> {
        packed
            .iter()
            .map(|value| Symbol::from_slice(&value.to_le_bytes()))
            .collect()
    }

    /// Unpack `Symbol`s into their `u64` representation for comparison.
    fn to_u64s(symbols: &[Symbol]) -> Vec<u64> {
        symbols.iter().map(|symbol| symbol.0).collect()
    }

    #[test]
    fn test_stuff() {
        let compressor = {
            let mut builder = CompressorBuilder::new();
            builder.insert(Symbol::from_slice(b"helloooo"), 8);
            builder.build()
        };

        let decompressor = compressor.decompressor();

        let mut decompressed = Vec::with_capacity(8 + 7);

        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    /// `rebuild_from` must keep the codes exactly in the order they were given.
    #[test]
    fn test_symbols_good() {
        let compressor = Compressor::rebuild_from(to_symbols(TABLE_U64), table_lens());
        assert_eq!(to_u64s(compressor.symbol_table()), TABLE_U64);
    }

    /// Inserting the same table through `CompressorBuilder` is expected NOT to reproduce the
    /// given code order, so the comparison below panics.
    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let mut builder = CompressorBuilder::new();
        for (symbol, len) in to_symbols(TABLE_U64).into_iter().zip(table_lens()) {
            builder.insert(symbol, len as usize);
        }
        let compressor = builder.build();
        assert_eq!(to_u64s(compressor.symbol_table()), TABLE_U64);
    }
}