Skip to main content

fsst/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Fail compilation unless a type occupies exactly the given number of bytes.
///
/// Usage: `assert_sizeof!(Type => size_in_bytes);`
macro_rules! assert_sizeof {
    ($typ:ty => $size_in_bytes:expr) => {
        // Evaluated at compile time: a size mismatch aborts the build.
        const _: () = assert!(
            std::mem::size_of::<$typ>() == $size_in_bytes,
            "type does not have the expected size in bytes"
        );
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
/// A short fragment of a string, at most 8 bytes long, packed into a single `u64`.
///
/// `Symbol`s are stored in a [`Compressor`][`crate::Compressor`], where each one is
/// identified by an 8-bit code.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub struct Symbol(u64);
24
25assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28    /// Zero value for `Symbol`.
29    pub const ZERO: Self = Self::zero();
30
31    /// Constructor for a `Symbol` from an 8-element byte slice.
32    pub fn from_slice(slice: &[u8; 8]) -> Self {
33        let num: u64 = u64::from_le_bytes(*slice);
34
35        Self(num)
36    }
37
38    /// Return a zero symbol
39    const fn zero() -> Self {
40        Self(0)
41    }
42
43    /// Create a new single-byte symbol
44    pub fn from_u8(value: u8) -> Self {
45        Self(value as u64)
46    }
47}
48
49impl Symbol {
50    /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51    ///
52    /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53    /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54    /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55    /// but we want to interpret that as a one-byte symbol containing `0x00`.
56    #[allow(clippy::len_without_is_empty)]
57    pub fn len(self) -> usize {
58        let numeric = self.0;
59        // For little-endian platforms, this counts the number of *trailing* zeros
60        let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62        // Special case handling of a symbol with all-zeros. This is actually
63        // a 1-byte symbol containing 0x00.
64        let len = size_of::<Self>() - null_bytes;
65        if len == 0 { 1 } else { len }
66    }
67
68    /// Returns the Symbol's inner representation.
69    #[inline]
70    pub fn to_u64(self) -> u64 {
71        self.0
72    }
73
74    /// Get the first byte of the symbol as a `u8`.
75    ///
76    /// If the symbol is empty, this will return the zero byte.
77    #[inline]
78    pub fn first_byte(self) -> u8 {
79        self.0 as u8
80    }
81
82    /// Get the first two bytes of the symbol as a `u16`.
83    ///
84    /// If the Symbol is one or zero bytes, this will return `0u16`.
85    #[inline]
86    pub fn first2(self) -> u16 {
87        self.0 as u16
88    }
89
90    /// Get the first three bytes of the symbol as a `u64`.
91    ///
92    /// If the Symbol is one or zero bytes, this will return `0u64`.
93    #[inline]
94    pub fn first3(self) -> u64 {
95        self.0 & 0xFF_FF_FF
96    }
97
98    /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
99    pub fn concat(self, other: Self) -> Self {
100        assert!(
101            self.len() + other.len() <= 8,
102            "cannot build symbol with length > 8"
103        );
104
105        let self_len = self.len();
106
107        Self((other.0 << (8 * self_len)) | self.0)
108    }
109}
110
111impl Debug for Symbol {
112    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113        write!(f, "[")?;
114
115        let slice = &self.0.to_le_bytes()[0..self.len()];
116        for c in slice.iter().map(|c| *c as char) {
117            if ('!'..='~').contains(&c) {
118                write!(f, "{c}")?;
119            } else if c == '\n' {
120                write!(f, " \\n ")?;
121            } else if c == '\t' {
122                write!(f, " \\t ")?;
123            } else if c == ' ' {
124                write!(f, " SPACE ")?;
125            } else {
126                write!(f, " 0x{:X?} ", c as u8)?
127            }
128        }
129
130        write!(f, "]")
131    }
132}
133
/// A packed type containing a code value, as well as metadata about the symbol referred to by
/// the code.
///
/// Logically, codes can range from 0-255 inclusive. This type holds both the 8-bit code as well as
/// other metadata bit-packed into a `u16`.
///
/// The bottom 8 bits contain EITHER a code for a symbol stored in the table, OR a raw byte.
///
/// The interpretation depends on the 9th bit: when toggled off, the value stores a raw byte, and when
/// toggled on, it stores a code. Thus if you examine the bottom 9 bits of the `u16`, you have an extended
/// code range, where the values 0-255 are raw bytes, and the values 256-510 represent codes 0-254. 511 is
/// a placeholder for the invalid code here.
///
/// Bits 12-15 store the length of the symbol (values ranging from 0-8).
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Code(u16);

/// Code used to indicate bytes that are not in the symbol table.
///
/// When compressing a string that cannot fully be expressed with the symbol table, the compressed
/// output will contain an `ESCAPE` byte followed by a raw byte. At decompression time, the presence
/// of `ESCAPE` indicates that the next byte should be appended directly to the result instead of
/// being looked up in the symbol table.
pub const ESCAPE_CODE: u8 = 255;

/// Number of bits in the `ExtendedCode` that are used to dictate a code value.
pub const FSST_CODE_BITS: usize = 9;

/// First bit of the "length" portion of an extended code.
pub const FSST_LEN_BITS: usize = 12;

/// Number of values in the extended code range, i.e. one more than the largest extended code.
pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS;

/// Maximum value for the extended code range, which doubles as the bit mask selecting
/// the extended code from a packed value.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First code in the symbol table that corresponds to a non-escape symbol.
pub const FSST_CODE_BASE: u16 = 256;

#[allow(clippy::len_without_is_empty)]
impl Code {
    /// Code for an unused slot in a symbol table or index.
    ///
    /// This corresponds to the maximum code with a length of 1.
    pub const UNUSED: Self = Code(FSST_CODE_MASK + (1u16 << FSST_LEN_BITS));

    /// Create a new code for a symbol of given length.
    ///
    /// The code is stored as-is in the low 8 bits, without the remapping applied by
    /// [`Self::new_symbol_building`].
    fn new_symbol(code: u8, len: usize) -> Self {
        Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
    }

    /// Code for a new symbol during the building phase.
    ///
    /// The code is remapped from 0..254 to 256...510.
    fn new_symbol_building(code: u8, len: usize) -> Self {
        Self(code as u16 + FSST_CODE_BASE + ((len as u16) << FSST_LEN_BITS))
    }

    /// Create a new code corresponding to an escaped byte.
    ///
    /// The raw byte occupies the low 8 bits and the length is fixed at 1.
    fn new_escape(byte: u8) -> Self {
        Self((byte as u16) + (1u16 << FSST_LEN_BITS))
    }

    /// The low 8 bits: either a symbol code or a raw byte.
    #[inline]
    fn code(self) -> u8 {
        self.0 as u8
    }

    /// The low 9 bits: the value in the extended code range.
    #[inline]
    fn extended_code(self) -> u16 {
        self.0 & FSST_CODE_MASK
    }

    /// The symbol length stored in the top bits.
    #[inline]
    fn len(self) -> u16 {
        self.0 >> FSST_LEN_BITS
    }
}

impl Debug for Code {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TrainingCode")
            .field("code", &self.code())
            // Only the 9-bit extended code decides this. The raw `u16` also carries the
            // length bits, so comparing it against 256 directly would never be true.
            .field("is_escape", &(self.extended_code() < FSST_CODE_BASE))
            .field("len", &self.len())
            .finish()
    }
}
225
226/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string.
227#[derive(Clone)]
228pub struct Decompressor<'a> {
229    /// Slice mapping codes to symbols.
230    pub(crate) symbols: &'a [Symbol],
231
232    /// Slice containing the length of each symbol in the `symbols` slice.
233    pub(crate) lengths: &'a [u8],
234}
235
236impl<'a> Decompressor<'a> {
237    /// Returns a new decompressor that uses the provided symbol table.
238    ///
239    /// # Panics
240    ///
241    /// If the provided symbol table has length greater than 256
242    pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
243        assert!(
244            symbols.len() < FSST_CODE_BASE as usize,
245            "symbol table cannot have size exceeding 255"
246        );
247
248        Self { symbols, lengths }
249    }
250
251    /// Returns an upper bound on the size of the decompressed data.
252    pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
253        size_of::<Symbol>() * (compressed.len() + 1)
254    }
255
256    /// Decompress a slice of codes into a provided buffer.
257    ///
258    /// The provided `decoded` buffer must be at least the size of the decoded data, plus
259    /// an additional 7 bytes.
260    ///
261    /// ## Panics
262    ///
263    /// If the caller fails to provide sufficient capacity in the decoded buffer. An upper bound
264    /// on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
265    ///
266    /// ## Example
267    ///
268    /// ```
269    /// use fsst::{Symbol, Compressor, CompressorBuilder};
270    /// let compressor = {
271    ///     let mut builder = CompressorBuilder::new();
272    ///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
273    ///     builder.build()
274    /// };
275    ///
276    /// let decompressor = compressor.decompressor();
277    ///
278    /// let mut decompressed = Vec::with_capacity(8 + 7);
279    ///
280    /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
281    /// assert_eq!(len, 8);
282    /// unsafe { decompressed.set_len(len) };
283    /// assert_eq!(&decompressed, "helloooo".as_bytes());
284    /// ```
285    pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
286        // Ensure the target buffer is at least half the size of the input buffer.
287        // This is the theortical smallest a valid target can be, and occurs when
288        // every input code is an escape.
289        assert!(
290            decoded.len() >= compressed.len() / 2,
291            "decoded is smaller than lower-bound decompressed size"
292        );
293
294        unsafe {
295            let mut in_ptr = compressed.as_ptr();
296            let _in_begin = in_ptr;
297            let in_end = in_ptr.add(compressed.len());
298
299            let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
300            let out_begin = out_ptr.cast_const();
301            let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();
302
303            macro_rules! store_next_symbol {
304                ($code:expr) => {{
305                    out_ptr
306                        .cast::<u64>()
307                        .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
308                    out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
309                }};
310            }
311
312            // First we try loading 8 bytes at a time.
313            if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
314                // Extract the loop condition since the compiler fails to do so
315                let block_out_end = out_end.sub(8 * size_of::<Symbol>());
316                let block_in_end = in_end.sub(8);
317
318                while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
319                    // Note that we load a little-endian u64 here.
320                    let next_block = in_ptr.cast::<u64>().read_unaligned();
321                    let escape_mask = (next_block & 0x8080808080808080)
322                        & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
323                            ^ 0x8080808080808080);
324
325                    // If there are no escape codes, we write each symbol one by one.
326                    if escape_mask == 0 {
327                        let code = (next_block & 0xFF) as u8;
328                        store_next_symbol!(code);
329                        let code = ((next_block >> 8) & 0xFF) as u8;
330                        store_next_symbol!(code);
331                        let code = ((next_block >> 16) & 0xFF) as u8;
332                        store_next_symbol!(code);
333                        let code = ((next_block >> 24) & 0xFF) as u8;
334                        store_next_symbol!(code);
335                        let code = ((next_block >> 32) & 0xFF) as u8;
336                        store_next_symbol!(code);
337                        let code = ((next_block >> 40) & 0xFF) as u8;
338                        store_next_symbol!(code);
339                        let code = ((next_block >> 48) & 0xFF) as u8;
340                        store_next_symbol!(code);
341                        let code = ((next_block >> 56) & 0xFF) as u8;
342                        store_next_symbol!(code);
343                        in_ptr = in_ptr.add(8);
344                    } else {
345                        // Otherwise, find the first escape code and write the symbols up to that point.
346                        let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
347                        debug_assert!(first_escape_pos < 8);
348                        match first_escape_pos {
349                            7 => {
350                                let code = (next_block & 0xFF) as u8;
351                                store_next_symbol!(code);
352                                let code = ((next_block >> 8) & 0xFF) as u8;
353                                store_next_symbol!(code);
354                                let code = ((next_block >> 16) & 0xFF) as u8;
355                                store_next_symbol!(code);
356                                let code = ((next_block >> 24) & 0xFF) as u8;
357                                store_next_symbol!(code);
358                                let code = ((next_block >> 32) & 0xFF) as u8;
359                                store_next_symbol!(code);
360                                let code = ((next_block >> 40) & 0xFF) as u8;
361                                store_next_symbol!(code);
362                                let code = ((next_block >> 48) & 0xFF) as u8;
363                                store_next_symbol!(code);
364
365                                in_ptr = in_ptr.add(7);
366                            }
367                            6 => {
368                                let code = (next_block & 0xFF) as u8;
369                                store_next_symbol!(code);
370                                let code = ((next_block >> 8) & 0xFF) as u8;
371                                store_next_symbol!(code);
372                                let code = ((next_block >> 16) & 0xFF) as u8;
373                                store_next_symbol!(code);
374                                let code = ((next_block >> 24) & 0xFF) as u8;
375                                store_next_symbol!(code);
376                                let code = ((next_block >> 32) & 0xFF) as u8;
377                                store_next_symbol!(code);
378                                let code = ((next_block >> 40) & 0xFF) as u8;
379                                store_next_symbol!(code);
380
381                                let escaped = ((next_block >> 56) & 0xFF) as u8;
382                                out_ptr.write(escaped);
383                                out_ptr = out_ptr.add(1);
384
385                                in_ptr = in_ptr.add(8);
386                            }
387                            5 => {
388                                let code = (next_block & 0xFF) as u8;
389                                store_next_symbol!(code);
390                                let code = ((next_block >> 8) & 0xFF) as u8;
391                                store_next_symbol!(code);
392                                let code = ((next_block >> 16) & 0xFF) as u8;
393                                store_next_symbol!(code);
394                                let code = ((next_block >> 24) & 0xFF) as u8;
395                                store_next_symbol!(code);
396                                let code = ((next_block >> 32) & 0xFF) as u8;
397                                store_next_symbol!(code);
398
399                                let escaped = ((next_block >> 48) & 0xFF) as u8;
400                                out_ptr.write(escaped);
401                                out_ptr = out_ptr.add(1);
402
403                                in_ptr = in_ptr.add(7);
404                            }
405                            4 => {
406                                let code = (next_block & 0xFF) as u8;
407                                store_next_symbol!(code);
408                                let code = ((next_block >> 8) & 0xFF) as u8;
409                                store_next_symbol!(code);
410                                let code = ((next_block >> 16) & 0xFF) as u8;
411                                store_next_symbol!(code);
412                                let code = ((next_block >> 24) & 0xFF) as u8;
413                                store_next_symbol!(code);
414
415                                let escaped = ((next_block >> 40) & 0xFF) as u8;
416                                out_ptr.write(escaped);
417                                out_ptr = out_ptr.add(1);
418
419                                in_ptr = in_ptr.add(6);
420                            }
421                            3 => {
422                                let code = (next_block & 0xFF) as u8;
423                                store_next_symbol!(code);
424                                let code = ((next_block >> 8) & 0xFF) as u8;
425                                store_next_symbol!(code);
426                                let code = ((next_block >> 16) & 0xFF) as u8;
427                                store_next_symbol!(code);
428
429                                let escaped = ((next_block >> 32) & 0xFF) as u8;
430                                out_ptr.write(escaped);
431                                out_ptr = out_ptr.add(1);
432
433                                in_ptr = in_ptr.add(5);
434                            }
435                            2 => {
436                                let code = (next_block & 0xFF) as u8;
437                                store_next_symbol!(code);
438                                let code = ((next_block >> 8) & 0xFF) as u8;
439                                store_next_symbol!(code);
440
441                                let escaped = ((next_block >> 24) & 0xFF) as u8;
442                                out_ptr.write(escaped);
443                                out_ptr = out_ptr.add(1);
444
445                                in_ptr = in_ptr.add(4);
446                            }
447                            1 => {
448                                let code = (next_block & 0xFF) as u8;
449                                store_next_symbol!(code);
450
451                                let escaped = ((next_block >> 16) & 0xFF) as u8;
452                                out_ptr.write(escaped);
453                                out_ptr = out_ptr.add(1);
454
455                                in_ptr = in_ptr.add(3);
456                            }
457                            0 => {
458                                // Otherwise, we actually need to decompress the next byte
459                                // Extract the second byte from the u32
460                                let escaped = ((next_block >> 8) & 0xFF) as u8;
461                                in_ptr = in_ptr.add(2);
462                                out_ptr.write(escaped);
463                                out_ptr = out_ptr.add(1);
464                            }
465                            _ => unreachable!(),
466                        }
467                    }
468                }
469            }
470
471            // Otherwise, fall back to 1-byte reads using 8-byte writes where safe.
472            while out_end.offset_from(out_ptr) >= size_of::<Symbol>() as isize && in_ptr < in_end {
473                let code = in_ptr.read();
474                in_ptr = in_ptr.add(1);
475
476                if code == ESCAPE_CODE {
477                    assert!(
478                        in_ptr < in_end,
479                        "truncated compressed string: escape code at end of input"
480                    );
481                    out_ptr.write(in_ptr.read());
482                    in_ptr = in_ptr.add(1);
483                    out_ptr = out_ptr.add(1);
484                } else {
485                    store_next_symbol!(code);
486                }
487            }
488
489            // For the last few bytes (if any) where we can't do an 8-byte unaligned write.
490            while in_ptr < in_end {
491                let code = in_ptr.read();
492                in_ptr = in_ptr.add(1);
493
494                if code == ESCAPE_CODE {
495                    assert!(
496                        in_ptr < in_end,
497                        "truncated compressed string: escape code at end of input"
498                    );
499                    assert!(
500                        out_ptr.cast_const() < out_end,
501                        "output buffer sized too small"
502                    );
503                    out_ptr.write(in_ptr.read());
504                    in_ptr = in_ptr.add(1);
505                    out_ptr = out_ptr.add(1);
506                } else {
507                    let len = *self.lengths.get_unchecked(code as usize) as usize;
508                    assert!(
509                        out_end.offset_from(out_ptr) >= len as isize,
510                        "output buffer sized too small"
511                    );
512                    let sym = self.symbols.get_unchecked(code as usize).to_u64();
513                    let sym_bytes = sym.to_le_bytes();
514                    std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len);
515                    out_ptr = out_ptr.add(len);
516                }
517            }
518
519            assert_eq!(
520                in_ptr, in_end,
521                "decompression should exhaust input before output"
522            );
523
524            out_ptr.offset_from(out_begin) as usize
525        }
526    }
527
528    /// Decompress a byte slice that was previously returned by a compressor using the same symbol
529    /// table into a new vector of bytes.
530    pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
531        let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
532
533        let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
534        // SAFETY: len bytes have now been initialized by the decompressor.
535        unsafe { decoded.set_len(len) };
536        decoded
537    }
538}
539
540/// A compressor that uses a symbol table to greedily compress strings.
541///
542/// The `Compressor` is the central component of FSST. You can create a compressor either by
543/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
544///
545/// Example usage:
546///
547/// ```
548/// use fsst::{Symbol, Compressor, CompressorBuilder};
549/// let compressor = {
550///     let mut builder = CompressorBuilder::new();
551///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
552///     builder.build()
553/// };
554///
555/// let compressed = compressor.compress("hello".as_bytes());
556/// assert_eq!(compressed, vec![0u8]);
557/// ```
558#[derive(Clone)]
559pub struct Compressor {
560    /// Table mapping codes to symbols.
561    pub(crate) symbols: Vec<Symbol>,
562
563    /// Length of each symbol, values range from 1-8.
564    pub(crate) lengths: Vec<u8>,
565
566    /// The number of entries in the symbol table that have been populated, not counting
567    /// the escape values.
568    pub(crate) n_symbols: u8,
569
570    /// Inverted index mapping 2-byte symbols to codes
571    codes_two_byte: Vec<Code>,
572
573    /// Limit of no suffixes.
574    has_suffix_code: u8,
575
576    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
577    lossy_pht: LossyPHT,
578}
579
580/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
581///
582/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
583/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
584impl Compressor {
585    /// Using the symbol table, runs a single cycle of compression on an input word, writing
586    /// the output into `out_ptr`.
587    ///
588    /// # Returns
589    ///
590    /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
591    /// for the caller to advance the input and output pointers.
592    ///
593    /// `advance_in` is the number of bytes to advance the input pointer before the next call.
594    ///
595    /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
596    ///
597    /// # Safety
598    ///
599    /// `out_ptr` must never be NULL or otherwise point to invalid memory.
600    pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
601        // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
602        // if it isn't, it will be overwritten anyway.
603        //
604        // SAFETY: caller ensures out_ptr is not null
605        let first_byte = word as u8;
606        // SAFETY: out_ptr is not null
607        unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
608
609        // First, check the two_bytes table
610        let code_twobyte = self.codes_two_byte[word as u16 as usize];
611
612        if code_twobyte.code() < self.has_suffix_code {
613            // 2 byte code without having to worry about longer matches.
614            // SAFETY: out_ptr is not null.
615            unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
616
617            // Advance input by symbol length (2) and output by a single code byte
618            (2, 1)
619        } else {
620            // Probe the hash table
621            let entry = self.lossy_pht.lookup(word);
622
623            // Now, downshift the `word` and the `entry` to see if they align.
624            let ignored_bits = entry.ignored_bits;
625            if entry.code != Code::UNUSED
626                && compare_masked(word, entry.symbol.to_u64(), ignored_bits)
627            {
628                // Advance the input by the symbol length (variable) and the output by one code byte
629                // SAFETY: out_ptr is not null.
630                unsafe { std::ptr::write(out_ptr, entry.code.code()) };
631                (entry.code.len() as usize, 1)
632            } else {
633                // SAFETY: out_ptr is not null
634                unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
635
636                // Advance the input by the symbol length (variable) and the output by either 1
637                // byte (if was one-byte code) or two bytes (escape).
638                (
639                    code_twobyte.len() as usize,
640                    // Predicated version of:
641                    //
642                    // if entry.code >= 256 {
643                    //      2
644                    // } else {
645                    //      1
646                    // }
647                    1 + (code_twobyte.extended_code() >> 8) as usize,
648                )
649            }
650        }
651    }
652
653    /// Compress many lines in bulk.
654    pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
655        let mut res = Vec::new();
656
657        for line in lines {
658            res.push(self.compress(line));
659        }
660
661        res
662    }
663
664    /// Compress a string, writing its result into a target buffer.
665    ///
666    /// The target buffer is a byte vector that must have capacity large enough
667    /// to hold the encoded data.
668    ///
669    /// When this call returns, `values` will hold the compressed bytes and have
670    /// its length set to the length of the compressed text.
671    ///
672    /// ```
673    /// use fsst::{Compressor, CompressorBuilder, Symbol};
674    ///
675    /// let mut compressor = CompressorBuilder::new();
676    /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
677    ///
678    /// let compressor = compressor.build();
679    ///
680    /// let mut compressed_values = Vec::with_capacity(1_024);
681    ///
682    /// // SAFETY: we have over-sized compressed_values.
683    /// unsafe {
684    ///     compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
685    /// }
686    ///
687    /// assert_eq!(compressed_values, vec![0u8]);
688    /// ```
689    ///
690    /// # Safety
691    ///
692    /// It is up to the caller to ensure the provided buffer is large enough to hold
693    /// all encoded data.
694    pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
695        let mut in_ptr = plaintext.as_ptr();
696        let mut out_ptr = values.as_mut_ptr();
697
698        // SAFETY: `end` will point just after the end of the `plaintext` slice.
699        let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
700        let in_end_sub8 = in_end as usize - 8;
701        // SAFETY: `end` will point just after the end of the `values` allocation.
702        let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
703
704        while (in_ptr as usize) <= in_end_sub8 && unsafe { out_end.offset_from(out_ptr) } >= 2 {
705            // SAFETY: pointer ranges are checked in the loop condition
706            unsafe {
707                // Load a full 8-byte word of data from in_ptr.
708                // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
709                let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
710                let (advance_in, advance_out) = self.compress_word(word, out_ptr);
711                in_ptr = in_ptr.byte_add(advance_in);
712                out_ptr = out_ptr.byte_add(advance_out);
713            };
714        }
715
716        let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
717        assert!(
718            out_ptr < out_end || remaining_bytes == 0,
719            "output buffer sized too small"
720        );
721
722        let remaining_bytes = remaining_bytes as usize;
723
724        // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
725        // but shift data out of this word rather than advancing an input pointer and potentially reading
726        // unowned memory.
727        let mut bytes = [0u8; 8];
728        // SAFETY: remaining_bytes <= 8
729        unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
730        let mut last_word = u64::from_le_bytes(bytes);
731
732        while in_ptr < in_end && unsafe { out_end.offset_from(out_ptr) } >= 2 {
733            // Load a full 8-byte word of data from in_ptr.
734            // SAFETY: caller asserts in_ptr is not null
735            let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
736            // SAFETY: pointer ranges are checked in the loop condition
737            unsafe {
738                in_ptr = in_ptr.add(advance_in);
739                out_ptr = out_ptr.add(advance_out);
740            }
741
742            last_word = advance_8byte_word(last_word, advance_in);
743        }
744
745        // in_ptr should have exceeded in_end
746        assert!(
747            in_ptr >= in_end,
748            "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
749        );
750
751        assert!(out_ptr <= out_end, "output buffer sized too small");
752
753        // SAFETY: out_ptr is derived from the `values` allocation.
754        let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
755        assert!(
756            bytes_written >= 0,
757            "out_ptr ended before it started, not possible"
758        );
759
760        // SAFETY: we have initialized `bytes_written` values in the output buffer.
761        unsafe { values.set_len(bytes_written as usize) };
762    }
763
764    /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
765    pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
766        if plaintext.is_empty() {
767            return Vec::new();
768        }
769
770        let mut buffer = Vec::with_capacity(plaintext.len() * 2);
771
772        // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
773        unsafe { self.compress_into(plaintext, &mut buffer) };
774
775        buffer
776    }
777
778    /// Access the decompressor that can be used to decompress strings emitted from this
779    /// `Compressor` instance.
780    pub fn decompressor(&self) -> Decompressor<'_> {
781        Decompressor::new(self.symbol_table(), self.symbol_lengths())
782    }
783
784    /// Returns a readonly slice of the current symbol table.
785    ///
786    /// The returned slice will have length of `n_symbols`.
787    pub fn symbol_table(&self) -> &[Symbol] {
788        &self.symbols[0..self.n_symbols as usize]
789    }
790
791    /// Returns a readonly slice where index `i` contains the
792    /// length of the symbol represented by code `i`.
793    ///
794    /// Values range from 1-8.
795    pub fn symbol_lengths(&self) -> &[u8] {
796        &self.lengths[0..self.n_symbols as usize]
797    }
798
799    /// Rebuild a compressor from an existing symbol table.
800    ///
801    /// This will not attempt to optimize or re-order the codes.
802    pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
803        let symbols = symbols.as_ref();
804        let symbol_lens = symbol_lens.as_ref();
805
806        assert_eq!(
807            symbols.len(),
808            symbol_lens.len(),
809            "symbols and lengths differ"
810        );
811        assert!(
812            symbols.len() <= 255,
813            "symbol table len must be <= 255, was {}",
814            symbols.len()
815        );
816        validate_symbol_order(symbol_lens);
817
818        // Insert the symbols in their given order into the FSST lookup structures.
819        let symbols = symbols.to_vec();
820        let lengths = symbol_lens.to_vec();
821        let mut lossy_pht = LossyPHT::new();
822
823        let mut codes_one_byte = vec![Code::UNUSED; 256];
824
825        // Insert all of the one byte symbols first.
826        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
827            if len == 1 {
828                codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
829            }
830        }
831
832        // Initialize the codes_two_byte table to be all escapes
833        let mut codes_two_byte = vec![Code::UNUSED; 65_536];
834
835        // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
836        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
837            match len {
838                2 => {
839                    codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
840                }
841                3.. => {
842                    assert!(
843                        lossy_pht.insert(symbol, len as usize, code as u8),
844                        "rebuild symbol insertion into PHT must succeed"
845                    );
846                }
847                _ => { /* Covered by the 1-byte loop above. */ }
848            }
849        }
850
851        // Build the finished codes_two_byte table, subbing in unused positions with the
852        // codes_one_byte value similar to what we do in CompressBuilder::finalize.
853        for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
854            if *code == Code::UNUSED {
855                *code = codes_one_byte[symbol & 0xFF];
856            }
857        }
858
859        // Find the position of the first 2-byte code that has a suffix later in the table
860        let mut has_suffix_code = 0u8;
861        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
862            if len != 2 {
863                break;
864            }
865            let rest = &symbols[code..];
866            if rest
867                .iter()
868                .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
869            {
870                has_suffix_code = code as u8;
871                break;
872            }
873        }
874
875        Compressor {
876            n_symbols: symbols.len() as u8,
877            symbols,
878            lengths,
879            codes_two_byte,
880            lossy_pht,
881            has_suffix_code,
882        }
883    }
884}
885
/// Drop the first `bytes` bytes from a little-endian 8-byte `word`.
///
/// Because the word is little-endian, its first byte is stored in the least
/// significant position, so consumed bytes are shifted off the low end and the
/// vacated high bytes become zero.
///
/// Advancing by 8 or more bytes consumes the entire word and returns 0. A plain
/// `word >> (8 * bytes)` cannot express that: shifting a `u64` by 64 or more bits
/// overflows, which panics in debug builds and wraps the shift amount otherwise.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    // NOTE(review): the earlier `bytes == 8` form of this guard was reported to compile
    // to a conditional move rather than a branch (<https://godbolt.org/z/Pbvre65Pq>);
    // `>=` is expected to behave the same — confirm if this path is performance critical.
    if bytes >= 8 { 0 } else { word >> (8 * bytes) }
}
895
/// Check that a symbol table's lengths follow the FSST code ordering.
///
/// The expected layout is all multi-byte symbols first, in non-decreasing length
/// (2, 3, ..., 8), followed by all one-byte symbols. `symbol_lens[i]` is the length
/// of the symbol for code `i`.
///
/// # Panics
///
/// Panics if any length exceeds 8 bytes (a symbol holds at most 8 bytes), if a
/// multi-byte length is smaller than the one before it, or if a multi-byte symbol
/// appears after a one-byte symbol.
fn validate_symbol_order(symbol_lens: &[u8]) {
    // Ensure that the symbol table is ordered by length, 23456781
    let mut expected = 2;
    for (idx, &len) in symbol_lens.iter().enumerate() {
        // Reject lengths no symbol can have before they reach the lookup structures.
        assert!(
            len <= 8,
            "symbol code={idx} has length {len}, but symbols hold at most 8 bytes"
        );

        if expected == 1 {
            // Once the one-byte section has started, nothing longer may follow.
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
        } else {
            if len == 1 {
                // First one-byte symbol: switch to the trailing one-byte section.
                expected = 1;
            }

            // we're in the non-zero portion.
            assert!(
                len >= expected,
                "symbol code={idx} breaks violates FSST symbol table ordering"
            );
            expected = len;
        }
    }
}
919
/// Test whether `left`, with its top `ignored_bits` bits cleared, equals `right`.
///
/// Only `left` is masked; `right` is compared as given.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    let keep = u64::MAX >> ignored_bits;
    right == left & keep
}
925
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Raw little-endian values of the symbol table shared by the tests below:
    /// 16 two-byte symbols, then 61 three-byte symbols, then 8 one-byte symbols.
    const SYMBOLS_U64: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Symbol lengths matching `SYMBOLS_U64`, entry for entry.
    fn symbol_lens() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// Read back the raw value of every symbol in `symbols`.
    ///
    /// Done field by field rather than by transmuting the slice: `Symbol` is not
    /// `#[repr(transparent)]`, so reinterpreting `&[Symbol]` as `&[u64]` would rely on
    /// a layout the language does not guarantee.
    fn raw_values(symbols: &[Symbol]) -> Vec<u64> {
        symbols.iter().map(|symbol| symbol.0).collect()
    }

    /// A single 8-byte symbol gets code 0 and decompresses back to its bytes.
    #[test]
    fn test_stuff() {
        let compressor = {
            let mut builder = CompressorBuilder::new();
            builder.insert(Symbol::from_slice(b"helloooo"), 8);
            builder.build()
        };

        let decompressor = compressor.decompressor();

        let mut decompressed = Vec::with_capacity(8 + 7);

        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        // SAFETY: `len` is the byte count reported by `decompress_into` for the spare
        // capacity it was handed, and was just checked to be 8, within the capacity.
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    /// `rebuild_from` keeps the symbols exactly in the order they were supplied.
    #[test]
    fn test_symbols_good() {
        let symbols: Vec<Symbol> = SYMBOLS_U64
            .iter()
            .map(|value| Symbol::from_slice(&value.to_le_bytes()))
            .collect();

        let compressor = Compressor::rebuild_from(&symbols, symbol_lens());
        let built_symbols = raw_values(compressor.symbol_table());
        assert_eq!(built_symbols.as_slice(), SYMBOLS_U64);
    }

    /// Feeding the same symbols through `CompressorBuilder` does not reproduce the
    /// input table, so the final equality assertion is expected to fail.
    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let lens = symbol_lens();

        let mut builder = CompressorBuilder::new();
        for (symbol, len) in SYMBOLS_U64.iter().zip(lens.iter()) {
            let symbol = Symbol::from_slice(&symbol.to_le_bytes());
            builder.insert(symbol, *len as usize);
        }
        let compressor = builder.build();
        let built_symbols = raw_values(compressor.symbol_table());
        assert_eq!(built_symbols.as_slice(), SYMBOLS_U64);
    }
}