fsst/lib.rs
1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Fail compilation unless a type occupies exactly the given number of bytes.
///
/// The check is phrased as an array-length mismatch inside an anonymous constant, so a
/// violation is reported as a type error at the invocation site and nothing is emitted
/// at runtime.
macro_rules! assert_sizeof {
    ($checked:ty => $expected_bytes:expr) => {
        const _: [(); $expected_bytes] = [(); ::std::mem::size_of::<$checked>()];
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and
/// identified by an 8-bit code.
///
/// The bytes are packed into a single `u64`: the first byte of the segment occupies the
/// least-significant byte, and the unused high bytes of a short symbol are zero.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct Symbol(u64);

// Compile-time guard: a `Symbol` must stay exactly 8 bytes wide.
assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28 /// Zero value for `Symbol`.
29 pub const ZERO: Self = Self::zero();
30
31 /// Constructor for a `Symbol` from an 8-element byte slice.
32 pub fn from_slice(slice: &[u8; 8]) -> Self {
33 let num: u64 = u64::from_le_bytes(*slice);
34
35 Self(num)
36 }
37
38 /// Return a zero symbol
39 const fn zero() -> Self {
40 Self(0)
41 }
42
43 /// Create a new single-byte symbol
44 pub fn from_u8(value: u8) -> Self {
45 Self(value as u64)
46 }
47}
48
49impl Symbol {
50 /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51 ///
52 /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53 /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54 /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55 /// but we want to interpret that as a one-byte symbol containing `0x00`.
56 #[allow(clippy::len_without_is_empty)]
57 pub fn len(self) -> usize {
58 let numeric = self.0;
59 // For little-endian platforms, this counts the number of *trailing* zeros
60 let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62 // Special case handling of a symbol with all-zeros. This is actually
63 // a 1-byte symbol containing 0x00.
64 let len = size_of::<Self>() - null_bytes;
65 if len == 0 { 1 } else { len }
66 }
67
68 /// Returns the Symbol's inner representation.
69 #[inline]
70 pub fn to_u64(self) -> u64 {
71 self.0
72 }
73
74 /// Get the first byte of the symbol as a `u8`.
75 ///
76 /// If the symbol is empty, this will return the zero byte.
77 #[inline]
78 pub fn first_byte(self) -> u8 {
79 self.0 as u8
80 }
81
82 /// Get the first two bytes of the symbol as a `u16`.
83 ///
84 /// If the Symbol is one or zero bytes, this will return `0u16`.
85 #[inline]
86 pub fn first2(self) -> u16 {
87 self.0 as u16
88 }
89
90 /// Get the first three bytes of the symbol as a `u64`.
91 ///
92 /// If the Symbol is one or zero bytes, this will return `0u64`.
93 #[inline]
94 pub fn first3(self) -> u64 {
95 self.0 & 0xFF_FF_FF
96 }
97
98 /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
99 pub fn concat(self, other: Self) -> Self {
100 assert!(
101 self.len() + other.len() <= 8,
102 "cannot build symbol with length > 8"
103 );
104
105 let self_len = self.len();
106
107 Self((other.0 << (8 * self_len)) | self.0)
108 }
109}
110
111impl Debug for Symbol {
112 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113 write!(f, "[")?;
114
115 let slice = &self.0.to_le_bytes()[0..self.len()];
116 for c in slice.iter().map(|c| *c as char) {
117 if ('!'..='~').contains(&c) {
118 write!(f, "{c}")?;
119 } else if c == '\n' {
120 write!(f, " \\n ")?;
121 } else if c == '\t' {
122 write!(f, " \\t ")?;
123 } else if c == ' ' {
124 write!(f, " SPACE ")?;
125 } else {
126 write!(f, " 0x{:X?} ", c as u8)?
127 }
128 }
129
130 write!(f, "]")
131 }
132}
133
/// A packed type containing a code value, as well as metadata about the symbol referred to by
/// the code.
///
/// Logically, codes can range from 0-255 inclusive. This type holds both the 8-bit code as well as
/// other metadata bit-packed into a `u16`.
///
/// The bottom 8 bits contain EITHER a code for a symbol stored in the table, OR a raw byte.
///
/// The interpretation depends on the 9th bit: when toggled off, the value stores a raw byte, and when
/// toggled on, it stores a code. Thus if you examine the bottom 9 bits of the `u16`, you have an extended
/// code range, where the values 0-255 are raw bytes, and the values 256-510 represent codes 0-254. 511 is
/// a placeholder for the invalid code here.
///
/// Bits 12-15 store the length of the symbol (values ranging from 0-8).
///
/// NOTE(review): none of the constructors visible in this file set bits 9-11; they look
/// unused, but confirm against the builder module before relying on that.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Code(u16);
150
/// Code used to indicate bytes that are not in the symbol table.
///
/// When compressing a string that cannot fully be expressed with the symbol table, the compressed
/// output will contain an `ESCAPE` byte followed by a raw byte. At decompression time, the presence
/// of `ESCAPE` indicates that the next byte should be appended directly to the result instead of
/// being looked up in the symbol table.
pub const ESCAPE_CODE: u8 = 255;

/// Number of low bits of the packed `u16` code representation that hold the (extended) code
/// value.
pub const FSST_CODE_BITS: usize = 9;

/// Bit offset at which the "length" portion of an extended code starts.
pub const FSST_LEN_BITS: usize = 12;

/// Exclusive upper bound of the extended code range: `1 << FSST_CODE_BITS`, i.e. 512.
pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS;

/// Largest value in the extended code range (511); doubles as the bit mask that selects the
/// extended code from a packed value.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First code in the symbol table that corresponds to a non-escape symbol.
pub const FSST_CODE_BASE: u16 = 256;
175
176#[allow(clippy::len_without_is_empty)]
177impl Code {
178 /// Code for an unused slot in a symbol table or index.
179 ///
180 /// This corresponds to the maximum code with a length of 1.
181 pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << 12));
182
183 /// Create a new code for a symbol of given length.
184 fn new_symbol(code: u8, len: usize) -> Self {
185 Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
186 }
187
188 /// Code for a new symbol during the building phase.
189 ///
190 /// The code is remapped from 0..254 to 256...510.
191 fn new_symbol_building(code: u8, len: usize) -> Self {
192 Self(code as u16 + 256 + ((len as u16) << FSST_LEN_BITS))
193 }
194
195 /// Create a new code corresponding for an escaped byte.
196 fn new_escape(byte: u8) -> Self {
197 Self((byte as u16) + (1 << FSST_LEN_BITS))
198 }
199
200 #[inline]
201 fn code(self) -> u8 {
202 self.0 as u8
203 }
204
205 #[inline]
206 fn extended_code(self) -> u16 {
207 self.0 & 0b111_111_111
208 }
209
210 #[inline]
211 fn len(self) -> u16 {
212 self.0 >> FSST_LEN_BITS
213 }
214}
215
216impl Debug for Code {
217 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("TrainingCode")
219 .field("code", &(self.0 as u8))
220 .field("is_escape", &(self.0 < 256))
221 .field("len", &(self.0 >> 12))
222 .finish()
223 }
224}
225
/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string.
///
/// It borrows both tables for the lifetime `'a`; it does not own or copy them.
#[derive(Clone)]
pub struct Decompressor<'a> {
    /// Slice mapping codes to symbols.
    pub(crate) symbols: &'a [Symbol],

    /// Slice containing the length of each symbol in the `symbols` slice.
    pub(crate) lengths: &'a [u8],
}
235
236impl<'a> Decompressor<'a> {
237 /// Returns a new decompressor that uses the provided symbol table.
238 ///
239 /// # Panics
240 ///
241 /// If the provided symbol table has length greater than or equal to [`FSST_CODE_BASE`]
242 pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
243 assert!(
244 symbols.len() < FSST_CODE_BASE as usize,
245 "symbol table cannot have size exceeding 255"
246 );
247
248 Self { symbols, lengths }
249 }
250
251 /// Returns an upper bound on the size of the decompressed data.
252 pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
253 size_of::<Symbol>() * (compressed.len() + 1)
254 }
255
256 /// Decompress a slice of codes into a provided buffer.
257 ///
258 /// The provided `decoded` buffer must be at least the size of the decoded data.
259 ///
260 /// ## Panics
261 ///
262 /// If the caller fails to provide sufficient capacity in the decoded buffer.
263 /// An upper bound on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
264 ///
265 /// ## Example
266 ///
267 /// ```
268 /// use fsst::{Symbol, Compressor, CompressorBuilder};
269 /// let compressor = {
270 /// let mut builder = CompressorBuilder::new();
271 /// builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
272 /// builder.build()
273 /// };
274 ///
275 /// let decompressor = compressor.decompressor();
276 ///
277 /// let mut decompressed = Vec::with_capacity(8);
278 ///
279 /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
280 /// assert_eq!(len, 8);
281 /// unsafe { decompressed.set_len(len) };
282 /// assert_eq!(&decompressed, "helloooo".as_bytes());
283 /// ```
284 pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
285 // Ensure the target buffer is at least half the size of the input buffer.
286 // This is the theortical smallest a valid target can be, and occurs when
287 // every input code is an escape.
288 assert!(
289 decoded.len() >= compressed.len() / 2,
290 "decoded is smaller than lower-bound decompressed size"
291 );
292
293 unsafe {
294 let mut in_ptr = compressed.as_ptr();
295 let _in_begin = in_ptr;
296 let in_end = in_ptr.add(compressed.len());
297
298 let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
299 let out_begin = out_ptr.cast_const();
300 let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();
301
302 macro_rules! store_next_symbol {
303 ($code:expr) => {{
304 out_ptr
305 .cast::<u64>()
306 .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
307 out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
308 }};
309 }
310
311 // First we try loading 8 bytes at a time.
312 if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
313 // Extract the loop condition since the compiler fails to do so
314 let block_out_end = out_end.sub(8 * size_of::<Symbol>());
315 let block_in_end = in_end.sub(8);
316
317 while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
318 // Note that we load a little-endian u64 here.
319 let next_block = in_ptr.cast::<u64>().read_unaligned();
320 let escape_mask = (next_block & 0x8080808080808080)
321 & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
322 ^ 0x8080808080808080);
323
324 // If there are no escape codes, we write each symbol one by one.
325 if escape_mask == 0 {
326 let code = (next_block & 0xFF) as u8;
327 store_next_symbol!(code);
328 let code = ((next_block >> 8) & 0xFF) as u8;
329 store_next_symbol!(code);
330 let code = ((next_block >> 16) & 0xFF) as u8;
331 store_next_symbol!(code);
332 let code = ((next_block >> 24) & 0xFF) as u8;
333 store_next_symbol!(code);
334 let code = ((next_block >> 32) & 0xFF) as u8;
335 store_next_symbol!(code);
336 let code = ((next_block >> 40) & 0xFF) as u8;
337 store_next_symbol!(code);
338 let code = ((next_block >> 48) & 0xFF) as u8;
339 store_next_symbol!(code);
340 let code = ((next_block >> 56) & 0xFF) as u8;
341 store_next_symbol!(code);
342 in_ptr = in_ptr.add(8);
343 } else if (next_block & 0x00FF00FF00FF00FF) == 0x00FF00FF00FF00FF {
344 // All 4 even-positioned bytes are ESCAPE_CODE.
345 // Batch-extract the 4 raw bytes at odd positions.
346 out_ptr.write(((next_block >> 8) & 0xFF) as u8);
347 out_ptr.add(1).write(((next_block >> 24) & 0xFF) as u8);
348 out_ptr.add(2).write(((next_block >> 40) & 0xFF) as u8);
349 out_ptr.add(3).write(((next_block >> 56) & 0xFF) as u8);
350 out_ptr = out_ptr.add(4);
351 in_ptr = in_ptr.add(8);
352 } else {
353 // Otherwise, find the first escape code and write the symbols up to that point.
354 let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
355 debug_assert!(first_escape_pos < 8);
356 match first_escape_pos {
357 7 => {
358 let code = (next_block & 0xFF) as u8;
359 store_next_symbol!(code);
360 let code = ((next_block >> 8) & 0xFF) as u8;
361 store_next_symbol!(code);
362 let code = ((next_block >> 16) & 0xFF) as u8;
363 store_next_symbol!(code);
364 let code = ((next_block >> 24) & 0xFF) as u8;
365 store_next_symbol!(code);
366 let code = ((next_block >> 32) & 0xFF) as u8;
367 store_next_symbol!(code);
368 let code = ((next_block >> 40) & 0xFF) as u8;
369 store_next_symbol!(code);
370 let code = ((next_block >> 48) & 0xFF) as u8;
371 store_next_symbol!(code);
372
373 in_ptr = in_ptr.add(7);
374 }
375 6 => {
376 let code = (next_block & 0xFF) as u8;
377 store_next_symbol!(code);
378 let code = ((next_block >> 8) & 0xFF) as u8;
379 store_next_symbol!(code);
380 let code = ((next_block >> 16) & 0xFF) as u8;
381 store_next_symbol!(code);
382 let code = ((next_block >> 24) & 0xFF) as u8;
383 store_next_symbol!(code);
384 let code = ((next_block >> 32) & 0xFF) as u8;
385 store_next_symbol!(code);
386 let code = ((next_block >> 40) & 0xFF) as u8;
387 store_next_symbol!(code);
388
389 let escaped = ((next_block >> 56) & 0xFF) as u8;
390 out_ptr.write(escaped);
391 out_ptr = out_ptr.add(1);
392
393 in_ptr = in_ptr.add(8);
394 }
395 5 => {
396 let code = (next_block & 0xFF) as u8;
397 store_next_symbol!(code);
398 let code = ((next_block >> 8) & 0xFF) as u8;
399 store_next_symbol!(code);
400 let code = ((next_block >> 16) & 0xFF) as u8;
401 store_next_symbol!(code);
402 let code = ((next_block >> 24) & 0xFF) as u8;
403 store_next_symbol!(code);
404 let code = ((next_block >> 32) & 0xFF) as u8;
405 store_next_symbol!(code);
406
407 let escaped = ((next_block >> 48) & 0xFF) as u8;
408 out_ptr.write(escaped);
409 out_ptr = out_ptr.add(1);
410
411 in_ptr = in_ptr.add(7);
412 }
413 4 => {
414 let code = (next_block & 0xFF) as u8;
415 store_next_symbol!(code);
416 let code = ((next_block >> 8) & 0xFF) as u8;
417 store_next_symbol!(code);
418 let code = ((next_block >> 16) & 0xFF) as u8;
419 store_next_symbol!(code);
420 let code = ((next_block >> 24) & 0xFF) as u8;
421 store_next_symbol!(code);
422
423 let escaped = ((next_block >> 40) & 0xFF) as u8;
424 out_ptr.write(escaped);
425 out_ptr = out_ptr.add(1);
426
427 in_ptr = in_ptr.add(6);
428 }
429 3 => {
430 let code = (next_block & 0xFF) as u8;
431 store_next_symbol!(code);
432 let code = ((next_block >> 8) & 0xFF) as u8;
433 store_next_symbol!(code);
434 let code = ((next_block >> 16) & 0xFF) as u8;
435 store_next_symbol!(code);
436
437 let escaped = ((next_block >> 32) & 0xFF) as u8;
438 out_ptr.write(escaped);
439 out_ptr = out_ptr.add(1);
440
441 in_ptr = in_ptr.add(5);
442 }
443 2 => {
444 let code = (next_block & 0xFF) as u8;
445 store_next_symbol!(code);
446 let code = ((next_block >> 8) & 0xFF) as u8;
447 store_next_symbol!(code);
448
449 let escaped = ((next_block >> 24) & 0xFF) as u8;
450 out_ptr.write(escaped);
451 out_ptr = out_ptr.add(1);
452
453 in_ptr = in_ptr.add(4);
454 }
455 1 => {
456 let code = (next_block & 0xFF) as u8;
457 store_next_symbol!(code);
458
459 let escaped = ((next_block >> 16) & 0xFF) as u8;
460 out_ptr.write(escaped);
461 out_ptr = out_ptr.add(1);
462
463 in_ptr = in_ptr.add(3);
464 }
465 0 => {
466 // Otherwise, we actually need to decompress the next byte
467 // Extract the second byte from the u32
468 let escaped = ((next_block >> 8) & 0xFF) as u8;
469 in_ptr = in_ptr.add(2);
470 out_ptr.write(escaped);
471 out_ptr = out_ptr.add(1);
472 }
473 _ => unreachable!(),
474 }
475 }
476 }
477 }
478
479 // Otherwise, fall back to 1-byte reads using 8-byte writes where safe.
480 while out_end.offset_from(out_ptr) >= size_of::<Symbol>() as isize && in_ptr < in_end {
481 let code = in_ptr.read();
482 in_ptr = in_ptr.add(1);
483
484 if code == ESCAPE_CODE {
485 assert!(
486 in_ptr < in_end,
487 "truncated compressed string: escape code at end of input"
488 );
489 out_ptr.write(in_ptr.read());
490 in_ptr = in_ptr.add(1);
491 out_ptr = out_ptr.add(1);
492 } else {
493 store_next_symbol!(code);
494 }
495 }
496
497 // For the last few bytes (if any) where we can't do an 8-byte unaligned write.
498 while in_ptr < in_end {
499 let code = in_ptr.read();
500 in_ptr = in_ptr.add(1);
501
502 if code == ESCAPE_CODE {
503 assert!(
504 in_ptr < in_end,
505 "truncated compressed string: escape code at end of input"
506 );
507 assert!(
508 out_ptr.cast_const() < out_end,
509 "output buffer sized too small"
510 );
511 out_ptr.write(in_ptr.read());
512 in_ptr = in_ptr.add(1);
513 out_ptr = out_ptr.add(1);
514 } else {
515 let len = *self.lengths.get_unchecked(code as usize) as usize;
516 assert!(
517 out_end.offset_from(out_ptr) >= len as isize,
518 "output buffer sized too small"
519 );
520 let sym = self.symbols.get_unchecked(code as usize).to_u64();
521 let sym_bytes = sym.to_le_bytes();
522 std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len);
523 out_ptr = out_ptr.add(len);
524 }
525 }
526
527 assert_eq!(
528 in_ptr, in_end,
529 "decompression should exhaust input before output"
530 );
531
532 out_ptr.offset_from(out_begin) as usize
533 }
534 }
535
536 /// Decompress a byte slice that was previously returned by a compressor using the same symbol
537 /// table into a new vector of bytes.
538 pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
539 let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
540
541 let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
542 // SAFETY: len bytes have now been initialized by the decompressor.
543 unsafe { decoded.set_len(len) };
544 decoded
545 }
546}
547
/// A compressor that uses a symbol table to greedily compress strings.
///
/// The `Compressor` is the central component of FSST. You can create a compressor either by
/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
///
/// Example usage:
///
/// ```
/// use fsst::{Symbol, Compressor, CompressorBuilder};
/// let compressor = {
///     let mut builder = CompressorBuilder::new();
///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
///     builder.build()
/// };
///
/// let compressed = compressor.compress("hello".as_bytes());
/// assert_eq!(compressed, vec![0u8]);
/// ```
#[derive(Clone)]
pub struct Compressor {
    /// Table mapping codes to symbols.
    pub(crate) symbols: Vec<Symbol>,

    /// Length of each symbol, values range from 1-8.
    pub(crate) lengths: Vec<u8>,

    /// The number of entries in the symbol table that have been populated, not counting
    /// the escape values.
    pub(crate) n_symbols: u8,

    /// Inverted index mapping 2-byte symbols to codes.
    ///
    /// Indexed by the first two input bytes read as a little-endian `u16`.
    codes_two_byte: Vec<Code>,

    /// Limit of no suffixes: a two-byte lookup whose code is strictly below this value is
    /// emitted directly, without probing the hash table for a longer match.
    has_suffix_code: u8,

    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
    lossy_pht: LossyPHT,
}
587
588/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
589///
590/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
591/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
592impl Compressor {
593 /// Using the symbol table, runs a single cycle of compression on an input word, writing
594 /// the output into `out_ptr`.
595 ///
596 /// # Returns
597 ///
598 /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
599 /// for the caller to advance the input and output pointers.
600 ///
601 /// `advance_in` is the number of bytes to advance the input pointer before the next call.
602 ///
603 /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
604 ///
605 /// # Safety
606 ///
607 /// `out_ptr` must never be NULL or otherwise point to invalid memory.
608 pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
609 // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
610 // if it isn't, it will be overwritten anyway.
611 //
612 // SAFETY: caller ensures out_ptr is not null
613 let first_byte = word as u8;
614 // SAFETY: out_ptr is not null
615 unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
616
617 // First, check the two_bytes table
618 // SAFETY: codes_two_byte has exactly 65536 entries and `word as u16` is always in [0, 65535].
619 let code_twobyte = unsafe { *self.codes_two_byte.get_unchecked(word as u16 as usize) };
620
621 if code_twobyte.code() < self.has_suffix_code {
622 // 2 byte code without having to worry about longer matches.
623 // SAFETY: out_ptr is not null.
624 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
625
626 // Advance input by symbol length (2) and output by a single code byte
627 (2, 1)
628 } else {
629 // Probe the hash table
630 let entry = self.lossy_pht.lookup(word);
631
632 // Now, downshift the `word` and the `entry` to see if they align.
633 let ignored_bits = entry.ignored_bits;
634 if entry.code != Code::UNUSED
635 && compare_masked(word, entry.symbol.to_u64(), ignored_bits)
636 {
637 // Advance the input by the symbol length (variable) and the output by one code byte
638 // SAFETY: out_ptr is not null.
639 unsafe { std::ptr::write(out_ptr, entry.code.code()) };
640 (entry.code.len() as usize, 1)
641 } else {
642 // SAFETY: out_ptr is not null
643 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
644
645 // Advance the input by the symbol length (variable) and the output by either 1
646 // byte (if was one-byte code) or two bytes (escape).
647 (
648 code_twobyte.len() as usize,
649 // Predicated version of:
650 //
651 // if entry.code >= 256 {
652 // 2
653 // } else {
654 // 1
655 // }
656 1 + (code_twobyte.extended_code() >> 8) as usize,
657 )
658 }
659 }
660 }
661
662 /// Compress many lines in bulk.
663 pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
664 let mut res = Vec::new();
665
666 for line in lines {
667 res.push(self.compress(line));
668 }
669
670 res
671 }
672
673 /// Compress a string, writing its result into a target buffer.
674 ///
675 /// The target buffer is a byte vector that must have capacity large enough
676 /// to hold the encoded data.
677 ///
678 /// When this call returns, `values` will hold the compressed bytes and have
679 /// its length set to the length of the compressed text.
680 ///
681 /// ```
682 /// use fsst::{Compressor, CompressorBuilder, Symbol};
683 ///
684 /// let mut compressor = CompressorBuilder::new();
685 /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
686 ///
687 /// let compressor = compressor.build();
688 ///
689 /// let mut compressed_values = Vec::with_capacity(1_024);
690 ///
691 /// // SAFETY: we have over-sized compressed_values.
692 /// unsafe {
693 /// compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
694 /// }
695 ///
696 /// assert_eq!(compressed_values, vec![0u8]);
697 /// ```
698 ///
699 /// # Safety
700 ///
701 /// It is up to the caller to ensure the provided buffer is large enough to hold
702 /// all encoded data.
703 pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
704 let mut in_ptr = plaintext.as_ptr();
705 let mut out_ptr = values.as_mut_ptr();
706
707 // SAFETY: `end` will point just after the end of the `plaintext` slice.
708 let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
709 let in_end_sub8 = in_end as usize - 8;
710 // SAFETY: `end` will point just after the end of the `values` allocation.
711 let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
712
713 while (in_ptr as usize) <= in_end_sub8 && unsafe { out_end.offset_from(out_ptr) } >= 2 {
714 // SAFETY: pointer ranges are checked in the loop condition
715 unsafe {
716 // Load a full 8-byte word of data from in_ptr.
717 // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
718 let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
719 let (advance_in, advance_out) = self.compress_word(word, out_ptr);
720 in_ptr = in_ptr.byte_add(advance_in);
721 out_ptr = out_ptr.byte_add(advance_out);
722 };
723 }
724
725 let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
726 assert!(
727 out_ptr < out_end || remaining_bytes == 0,
728 "output buffer sized too small"
729 );
730
731 let remaining_bytes = remaining_bytes as usize;
732
733 // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
734 // but shift data out of this word rather than advancing an input pointer and potentially reading
735 // unowned memory.
736 let mut bytes = [0u8; 8];
737 // SAFETY: remaining_bytes <= 8
738 unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
739 let mut last_word = u64::from_le_bytes(bytes);
740
741 while in_ptr < in_end && unsafe { out_end.offset_from(out_ptr) } >= 2 {
742 // Load a full 8-byte word of data from in_ptr.
743 // SAFETY: caller asserts in_ptr is not null
744 let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
745 // SAFETY: pointer ranges are checked in the loop condition
746 unsafe {
747 in_ptr = in_ptr.add(advance_in);
748 out_ptr = out_ptr.add(advance_out);
749 }
750
751 last_word = advance_8byte_word(last_word, advance_in);
752 }
753
754 // in_ptr should have exceeded in_end
755 assert!(
756 in_ptr >= in_end,
757 "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
758 );
759
760 assert!(out_ptr <= out_end, "output buffer sized too small");
761
762 // SAFETY: out_ptr is derived from the `values` allocation.
763 let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
764 assert!(
765 bytes_written >= 0,
766 "out_ptr ended before it started, not possible"
767 );
768
769 // SAFETY: we have initialized `bytes_written` values in the output buffer.
770 unsafe { values.set_len(bytes_written as usize) };
771 }
772
773 /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
774 pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
775 if plaintext.is_empty() {
776 return Vec::new();
777 }
778
779 let mut buffer = Vec::with_capacity(plaintext.len() * 2);
780
781 // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
782 unsafe { self.compress_into(plaintext, &mut buffer) };
783
784 buffer
785 }
786
787 /// Access the decompressor that can be used to decompress strings emitted from this
788 /// `Compressor` instance.
789 pub fn decompressor(&self) -> Decompressor<'_> {
790 Decompressor::new(self.symbol_table(), self.symbol_lengths())
791 }
792
793 /// Returns a readonly slice of the current symbol table.
794 ///
795 /// The returned slice will have length of `n_symbols`.
796 pub fn symbol_table(&self) -> &[Symbol] {
797 &self.symbols[0..self.n_symbols as usize]
798 }
799
800 /// Returns a readonly slice where index `i` contains the
801 /// length of the symbol represented by code `i`.
802 ///
803 /// Values range from 1-8.
804 pub fn symbol_lengths(&self) -> &[u8] {
805 &self.lengths[0..self.n_symbols as usize]
806 }
807
808 /// Rebuild a compressor from an existing symbol table.
809 ///
810 /// This will not attempt to optimize or re-order the codes.
811 pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
812 let symbols = symbols.as_ref();
813 let symbol_lens = symbol_lens.as_ref();
814
815 assert_eq!(
816 symbols.len(),
817 symbol_lens.len(),
818 "symbols and lengths differ"
819 );
820 assert!(
821 symbols.len() <= 255,
822 "symbol table len must be <= 255, was {}",
823 symbols.len()
824 );
825 validate_symbol_order(symbol_lens);
826
827 // Insert the symbols in their given order into the FSST lookup structures.
828 let symbols = symbols.to_vec();
829 let lengths = symbol_lens.to_vec();
830 let mut lossy_pht = LossyPHT::new();
831
832 let mut codes_one_byte = vec![Code::UNUSED; 256];
833
834 // Insert all of the one byte symbols first.
835 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
836 if len == 1 {
837 codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
838 }
839 }
840
841 // Initialize the codes_two_byte table to be all escapes
842 let mut codes_two_byte = vec![Code::UNUSED; 65_536];
843
844 // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
845 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
846 match len {
847 2 => {
848 codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
849 }
850 3.. => {
851 assert!(
852 lossy_pht.insert(symbol, len as usize, code as u8),
853 "rebuild symbol insertion into PHT must succeed"
854 );
855 }
856 _ => { /* Covered by the 1-byte loop above. */ }
857 }
858 }
859
860 // Build the finished codes_two_byte table, subbing in unused positions with the
861 // codes_one_byte value similar to what we do in CompressBuilder::finalize.
862 for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
863 if *code == Code::UNUSED {
864 *code = codes_one_byte[symbol & 0xFF];
865 }
866 }
867
868 // Find the position of the first 2-byte code that has a suffix later in the table
869 let mut has_suffix_code = 0u8;
870 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
871 if len != 2 {
872 break;
873 }
874 let rest = &symbols[code..];
875 if rest
876 .iter()
877 .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
878 {
879 has_suffix_code = code as u8;
880 break;
881 }
882 }
883
884 Compressor {
885 n_symbols: symbols.len() as u8,
886 symbols,
887 lengths,
888 codes_two_byte,
889 lossy_pht,
890 has_suffix_code,
891 }
892 }
893}
894
/// Drop the first `bytes` bytes from an 8-byte little-endian word.
///
/// The first character is stored in the least-significant byte, so consuming input means
/// shifting towards the low end. Consuming all 8 bytes is handled separately because a
/// 64-bit shift of a `u64` is not allowed; the result is then simply 0.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    match bytes {
        8 => 0,
        consumed => word >> (8 * consumed),
    }
}
904
/// Check that a table of symbol lengths follows the FSST code ordering.
///
/// Multi-byte symbols must come first in non-decreasing length order (2, 3, ..., 8), followed
/// by all one-byte symbols.
///
/// # Panics
///
/// - If any length is outside `1..=8`.
/// - If a multi-byte symbol is shorter than its predecessor.
/// - If a multi-byte symbol appears after the first one-byte symbol.
fn validate_symbol_order(symbol_lens: &[u8]) {
    // Ensure that the symbol table is ordered by length, 23456781
    let mut expected = 2;
    for (idx, &len) in symbol_lens.iter().enumerate() {
        // A symbol holds between 1 and 8 bytes; the ordering checks below would otherwise
        // let an oversized length slip through.
        assert!(
            (1..=8).contains(&len),
            "symbol code={idx} has invalid length {len}, must be between 1 and 8 bytes"
        );

        if expected == 1 {
            // Once the one-byte tail has started, nothing longer may follow.
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
        } else {
            if len == 1 {
                expected = 1;
            }

            // we're in the multi-byte portion (or just entered the one-byte tail).
            assert!(
                len >= expected,
                "symbol code={idx} breaks violates FSST symbol table ordering"
            );
            expected = len;
        }
    }
}
928
/// Compare `left` against `right` after clearing the top `ignored_bits` bits of `left`.
///
/// `right` is compared as-is, so it only matches when its own ignored bits are already zero.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    let kept_bits = u64::MAX >> ignored_bits;
    right == left & kept_bits
}
934
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Symbol table fixture as raw little-endian `u64` values: 16 two-byte symbols,
    /// 61 three-byte symbols, then 8 one-byte symbols.
    const SYMBOLS_U64: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Lengths matching [`SYMBOLS_U64`] entry by entry.
    fn fixture_lengths() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// Convert raw values to symbols without relying on the memory layout of `Symbol`.
    fn to_symbols(raw: &[u64]) -> Vec<Symbol> {
        raw.iter()
            .map(|value| Symbol::from_slice(&value.to_le_bytes()))
            .collect()
    }

    /// Convert symbols back to their raw values without relying on memory layout.
    fn to_raw(symbols: &[Symbol]) -> Vec<u64> {
        symbols.iter().map(|symbol| symbol.to_u64()).collect()
    }

    #[test]
    fn test_stuff() {
        let compressor = {
            let mut builder = CompressorBuilder::new();
            builder.insert(Symbol::from_slice(b"helloooo"), 8);
            builder.build()
        };

        let decompressor = compressor.decompressor();

        let mut decompressed = Vec::with_capacity(8 + 7);

        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        // SAFETY: `decompress_into` reported `len` initialized bytes.
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    /// `rebuild_from` must keep the symbols exactly in the order they were given.
    #[test]
    fn test_symbols_good() {
        let symbols = to_symbols(SYMBOLS_U64);
        let lens = fixture_lengths();

        let compressor = Compressor::rebuild_from(&symbols, lens);
        let built_symbols = to_raw(compressor.symbol_table());
        assert_eq!(built_symbols, SYMBOLS_U64);
    }

    /// Going through the builder does not reproduce the given table, so the final
    /// equality assertion is expected to fail.
    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let lens = fixture_lengths();

        let mut builder = CompressorBuilder::new();
        for (symbol, len) in SYMBOLS_U64.iter().zip(lens.iter()) {
            let symbol = Symbol::from_slice(&symbol.to_le_bytes());
            builder.insert(symbol, *len as usize);
        }
        let compressor = builder.build();
        let built_symbols = to_raw(compressor.symbol_table());
        assert_eq!(built_symbols, SYMBOLS_U64);
    }
}
1009}