fsst/lib.rs
1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Fail compilation unless `$typ` occupies exactly `$size_in_bytes` bytes.
///
/// The comparison happens inside an anonymous `const` item, so a mismatch is reported as a
/// compile-time error and the macro costs nothing at runtime.
macro_rules! assert_sizeof {
    ($typ:ty => $size_in_bytes:expr) => {
        const _: () = assert!(
            std::mem::size_of::<$typ>() == $size_in_bytes,
            "type does not have the expected size in bytes"
        );
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and
/// identified by an 8-bit code.
///
/// The bytes are packed into a `u64` in little-endian order: the first byte of the segment is
/// the least-significant byte, and the unused high bytes are zero.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct Symbol(u64);

// Symbols are moved around as raw `u64`s (e.g. written with a single 8-byte unaligned store
// during decompression), so the type must stay exactly 8 bytes.
assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28 /// Zero value for `Symbol`.
29 pub const ZERO: Self = Self::zero();
30
31 /// Constructor for a `Symbol` from an 8-element byte slice.
32 pub fn from_slice(slice: &[u8; 8]) -> Self {
33 let num: u64 = u64::from_le_bytes(*slice);
34
35 Self(num)
36 }
37
38 /// Return a zero symbol
39 const fn zero() -> Self {
40 Self(0)
41 }
42
43 /// Create a new single-byte symbol
44 pub fn from_u8(value: u8) -> Self {
45 Self(value as u64)
46 }
47}
48
49impl Symbol {
50 /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51 ///
52 /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53 /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54 /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55 /// but we want to interpret that as a one-byte symbol containing `0x00`.
56 #[allow(clippy::len_without_is_empty)]
57 pub fn len(self) -> usize {
58 let numeric = self.0;
59 // For little-endian platforms, this counts the number of *trailing* zeros
60 let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62 // Special case handling of a symbol with all-zeros. This is actually
63 // a 1-byte symbol containing 0x00.
64 let len = size_of::<Self>() - null_bytes;
65 if len == 0 { 1 } else { len }
66 }
67
68 /// Returns the Symbol's inner representation.
69 #[inline]
70 pub fn to_u64(self) -> u64 {
71 self.0
72 }
73
74 /// Get the first byte of the symbol as a `u8`.
75 ///
76 /// If the symbol is empty, this will return the zero byte.
77 #[inline]
78 pub fn first_byte(self) -> u8 {
79 self.0 as u8
80 }
81
82 /// Get the first two bytes of the symbol as a `u16`.
83 ///
84 /// If the Symbol is one or zero bytes, this will return `0u16`.
85 #[inline]
86 pub fn first2(self) -> u16 {
87 self.0 as u16
88 }
89
90 /// Get the first three bytes of the symbol as a `u64`.
91 ///
92 /// If the Symbol is one or zero bytes, this will return `0u64`.
93 #[inline]
94 pub fn first3(self) -> u64 {
95 self.0 & 0xFF_FF_FF
96 }
97
98 /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
99 pub fn concat(self, other: Self) -> Self {
100 assert!(
101 self.len() + other.len() <= 8,
102 "cannot build symbol with length > 8"
103 );
104
105 let self_len = self.len();
106
107 Self((other.0 << (8 * self_len)) | self.0)
108 }
109}
110
111impl Debug for Symbol {
112 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113 write!(f, "[")?;
114
115 let slice = &self.0.to_le_bytes()[0..self.len()];
116 for c in slice.iter().map(|c| *c as char) {
117 if ('!'..='~').contains(&c) {
118 write!(f, "{c}")?;
119 } else if c == '\n' {
120 write!(f, " \\n ")?;
121 } else if c == '\t' {
122 write!(f, " \\t ")?;
123 } else if c == ' ' {
124 write!(f, " SPACE ")?;
125 } else {
126 write!(f, " 0x{:X?} ", c as u8)?
127 }
128 }
129
130 write!(f, "]")
131 }
132}
133
/// A packed type containing a code value, as well as metadata about the symbol referred to by
/// the code.
///
/// Logically, codes can range from 0-255 inclusive. This type holds both the 8-bit code as well as
/// other metadata bit-packed into a `u16`:
///
/// * Bits 0-7 contain EITHER a code for a symbol stored in the table, OR a raw byte.
/// * Bit 8 (the 9th bit) tells the two interpretations apart.
/// * Bits 9-11 are unused.
/// * Bits 12-15 store the length of the symbol (values ranging from 0-8).
///
/// During the building phase, the 9th bit is toggled off when the value stores a raw byte, and
/// toggled on when it stores a code. Thus if you examine the bottom 9 bits of the `u16`, you have
/// an extended code range, where the values 0-255 are raw bytes, and the values 256-510 represent
/// codes 0-254. 511 is a placeholder for the invalid code here.
///
/// The lookup tables built by `Compressor::rebuild_from` use the opposite convention. Symbols are
/// stored via `Code::new_symbol` with the 9th bit off. A code with the 9th bit on (such as
/// `Code::UNUSED`, extended value 511) makes `Compressor::compress_word` emit an escape.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Code(u16);
150
/// Code used to indicate bytes that are not in the symbol table.
///
/// When compressing a string that cannot fully be expressed with the symbol table, the compressed
/// output will contain an `ESCAPE` byte followed by a raw byte. At decompression time, the presence
/// of `ESCAPE` indicates that the next byte should be appended directly to the result instead of
/// being looked up in the symbol table.
pub const ESCAPE_CODE: u8 = 255;

/// Number of low bits of a `Code` that hold the extended (9-bit) code value.
pub const FSST_CODE_BITS: usize = 9;

/// Bit offset of the "length" portion of a `Code`: the symbol length is `code >> FSST_LEN_BITS`.
pub const FSST_LEN_BITS: usize = 12;

/// Size of the extended code range (`1 << FSST_CODE_BITS`, i.e. 512).
///
/// This is one past the largest extended code; see [`FSST_CODE_MASK`] for the largest value.
pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS;

/// Maximum value for the extended code range (511). It doubles as a mask for the low
/// [`FSST_CODE_BITS`] bits.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First value in the extended code range that refers to a symbol-table entry rather than a raw
/// byte. Building-phase codes 0-254 are stored as 256-510.
pub const FSST_CODE_BASE: u16 = 256;
175
176#[allow(clippy::len_without_is_empty)]
177impl Code {
178 /// Code for an unused slot in a symbol table or index.
179 ///
180 /// This corresponds to the maximum code with a length of 1.
181 pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << 12));
182
183 /// Create a new code for a symbol of given length.
184 fn new_symbol(code: u8, len: usize) -> Self {
185 Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
186 }
187
188 /// Code for a new symbol during the building phase.
189 ///
190 /// The code is remapped from 0..254 to 256...510.
191 fn new_symbol_building(code: u8, len: usize) -> Self {
192 Self(code as u16 + 256 + ((len as u16) << FSST_LEN_BITS))
193 }
194
195 /// Create a new code corresponding for an escaped byte.
196 fn new_escape(byte: u8) -> Self {
197 Self((byte as u16) + (1 << FSST_LEN_BITS))
198 }
199
200 #[inline]
201 fn code(self) -> u8 {
202 self.0 as u8
203 }
204
205 #[inline]
206 fn extended_code(self) -> u16 {
207 self.0 & 0b111_111_111
208 }
209
210 #[inline]
211 fn len(self) -> u16 {
212 self.0 >> FSST_LEN_BITS
213 }
214}
215
216impl Debug for Code {
217 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("TrainingCode")
219 .field("code", &(self.0 as u8))
220 .field("is_escape", &(self.0 < 256))
221 .field("len", &(self.0 >> 12))
222 .finish()
223 }
224}
225
/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string.
///
/// It only borrows the symbol table, so it is cheap to create and to clone (see
/// [`Compressor::decompressor`]).
#[derive(Clone)]
pub struct Decompressor<'a> {
    /// Slice mapping codes to symbols: `symbols[code]` is the symbol for `code`.
    pub(crate) symbols: &'a [Symbol],

    /// Slice containing the length of each symbol in the `symbols` slice, indexed by the same
    /// code. Presumably the same length as `symbols`; verify against callers.
    pub(crate) lengths: &'a [u8],
}
235
236impl<'a> Decompressor<'a> {
237 /// Returns a new decompressor that uses the provided symbol table.
238 ///
239 /// # Panics
240 ///
241 /// If the provided symbol table has length greater than 256
242 pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
243 assert!(
244 symbols.len() < FSST_CODE_BASE as usize,
245 "symbol table cannot have size exceeding 255"
246 );
247
248 Self { symbols, lengths }
249 }
250
251 /// Returns an upper bound on the size of the decompressed data.
252 pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
253 size_of::<Symbol>() * (compressed.len() + 1)
254 }
255
256 /// Decompress a slice of codes into a provided buffer.
257 ///
258 /// The provided `decoded` buffer must be at least the size of the decoded data, plus
259 /// an additional 7 bytes.
260 ///
261 /// ## Panics
262 ///
263 /// If the caller fails to provide sufficient capacity in the decoded buffer. An upper bound
264 /// on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
265 ///
266 /// ## Example
267 ///
268 /// ```
269 /// use fsst::{Symbol, Compressor, CompressorBuilder};
270 /// let compressor = {
271 /// let mut builder = CompressorBuilder::new();
272 /// builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
273 /// builder.build()
274 /// };
275 ///
276 /// let decompressor = compressor.decompressor();
277 ///
278 /// let mut decompressed = Vec::with_capacity(8 + 7);
279 ///
280 /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
281 /// assert_eq!(len, 8);
282 /// unsafe { decompressed.set_len(len) };
283 /// assert_eq!(&decompressed, "helloooo".as_bytes());
284 /// ```
285 pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
286 // Ensure the target buffer is at least half the size of the input buffer.
287 // This is the theortical smallest a valid target can be, and occurs when
288 // every input code is an escape.
289 assert!(
290 decoded.len() >= compressed.len() / 2,
291 "decoded is smaller than lower-bound decompressed size"
292 );
293
294 unsafe {
295 let mut in_ptr = compressed.as_ptr();
296 let _in_begin = in_ptr;
297 let in_end = in_ptr.add(compressed.len());
298
299 let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
300 let out_begin = out_ptr.cast_const();
301 let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();
302
303 macro_rules! store_next_symbol {
304 ($code:expr) => {{
305 out_ptr
306 .cast::<u64>()
307 .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
308 out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
309 }};
310 }
311
312 // First we try loading 8 bytes at a time.
313 if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
314 // Extract the loop condition since the compiler fails to do so
315 let block_out_end = out_end.sub(8 * size_of::<Symbol>());
316 let block_in_end = in_end.sub(8);
317
318 while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
319 // Note that we load a little-endian u64 here.
320 let next_block = in_ptr.cast::<u64>().read_unaligned();
321 let escape_mask = (next_block & 0x8080808080808080)
322 & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
323 ^ 0x8080808080808080);
324
325 // If there are no escape codes, we write each symbol one by one.
326 if escape_mask == 0 {
327 let code = (next_block & 0xFF) as u8;
328 store_next_symbol!(code);
329 let code = ((next_block >> 8) & 0xFF) as u8;
330 store_next_symbol!(code);
331 let code = ((next_block >> 16) & 0xFF) as u8;
332 store_next_symbol!(code);
333 let code = ((next_block >> 24) & 0xFF) as u8;
334 store_next_symbol!(code);
335 let code = ((next_block >> 32) & 0xFF) as u8;
336 store_next_symbol!(code);
337 let code = ((next_block >> 40) & 0xFF) as u8;
338 store_next_symbol!(code);
339 let code = ((next_block >> 48) & 0xFF) as u8;
340 store_next_symbol!(code);
341 let code = ((next_block >> 56) & 0xFF) as u8;
342 store_next_symbol!(code);
343 in_ptr = in_ptr.add(8);
344 } else if (next_block & 0x00FF00FF00FF00FF) == 0x00FF00FF00FF00FF {
345 // All 4 even-positioned bytes are ESCAPE_CODE.
346 // Batch-extract the 4 raw bytes at odd positions.
347 out_ptr.write(((next_block >> 8) & 0xFF) as u8);
348 out_ptr.add(1).write(((next_block >> 24) & 0xFF) as u8);
349 out_ptr.add(2).write(((next_block >> 40) & 0xFF) as u8);
350 out_ptr.add(3).write(((next_block >> 56) & 0xFF) as u8);
351 out_ptr = out_ptr.add(4);
352 in_ptr = in_ptr.add(8);
353 } else {
354 // Otherwise, find the first escape code and write the symbols up to that point.
355 let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
356 debug_assert!(first_escape_pos < 8);
357 match first_escape_pos {
358 7 => {
359 let code = (next_block & 0xFF) as u8;
360 store_next_symbol!(code);
361 let code = ((next_block >> 8) & 0xFF) as u8;
362 store_next_symbol!(code);
363 let code = ((next_block >> 16) & 0xFF) as u8;
364 store_next_symbol!(code);
365 let code = ((next_block >> 24) & 0xFF) as u8;
366 store_next_symbol!(code);
367 let code = ((next_block >> 32) & 0xFF) as u8;
368 store_next_symbol!(code);
369 let code = ((next_block >> 40) & 0xFF) as u8;
370 store_next_symbol!(code);
371 let code = ((next_block >> 48) & 0xFF) as u8;
372 store_next_symbol!(code);
373
374 in_ptr = in_ptr.add(7);
375 }
376 6 => {
377 let code = (next_block & 0xFF) as u8;
378 store_next_symbol!(code);
379 let code = ((next_block >> 8) & 0xFF) as u8;
380 store_next_symbol!(code);
381 let code = ((next_block >> 16) & 0xFF) as u8;
382 store_next_symbol!(code);
383 let code = ((next_block >> 24) & 0xFF) as u8;
384 store_next_symbol!(code);
385 let code = ((next_block >> 32) & 0xFF) as u8;
386 store_next_symbol!(code);
387 let code = ((next_block >> 40) & 0xFF) as u8;
388 store_next_symbol!(code);
389
390 let escaped = ((next_block >> 56) & 0xFF) as u8;
391 out_ptr.write(escaped);
392 out_ptr = out_ptr.add(1);
393
394 in_ptr = in_ptr.add(8);
395 }
396 5 => {
397 let code = (next_block & 0xFF) as u8;
398 store_next_symbol!(code);
399 let code = ((next_block >> 8) & 0xFF) as u8;
400 store_next_symbol!(code);
401 let code = ((next_block >> 16) & 0xFF) as u8;
402 store_next_symbol!(code);
403 let code = ((next_block >> 24) & 0xFF) as u8;
404 store_next_symbol!(code);
405 let code = ((next_block >> 32) & 0xFF) as u8;
406 store_next_symbol!(code);
407
408 let escaped = ((next_block >> 48) & 0xFF) as u8;
409 out_ptr.write(escaped);
410 out_ptr = out_ptr.add(1);
411
412 in_ptr = in_ptr.add(7);
413 }
414 4 => {
415 let code = (next_block & 0xFF) as u8;
416 store_next_symbol!(code);
417 let code = ((next_block >> 8) & 0xFF) as u8;
418 store_next_symbol!(code);
419 let code = ((next_block >> 16) & 0xFF) as u8;
420 store_next_symbol!(code);
421 let code = ((next_block >> 24) & 0xFF) as u8;
422 store_next_symbol!(code);
423
424 let escaped = ((next_block >> 40) & 0xFF) as u8;
425 out_ptr.write(escaped);
426 out_ptr = out_ptr.add(1);
427
428 in_ptr = in_ptr.add(6);
429 }
430 3 => {
431 let code = (next_block & 0xFF) as u8;
432 store_next_symbol!(code);
433 let code = ((next_block >> 8) & 0xFF) as u8;
434 store_next_symbol!(code);
435 let code = ((next_block >> 16) & 0xFF) as u8;
436 store_next_symbol!(code);
437
438 let escaped = ((next_block >> 32) & 0xFF) as u8;
439 out_ptr.write(escaped);
440 out_ptr = out_ptr.add(1);
441
442 in_ptr = in_ptr.add(5);
443 }
444 2 => {
445 let code = (next_block & 0xFF) as u8;
446 store_next_symbol!(code);
447 let code = ((next_block >> 8) & 0xFF) as u8;
448 store_next_symbol!(code);
449
450 let escaped = ((next_block >> 24) & 0xFF) as u8;
451 out_ptr.write(escaped);
452 out_ptr = out_ptr.add(1);
453
454 in_ptr = in_ptr.add(4);
455 }
456 1 => {
457 let code = (next_block & 0xFF) as u8;
458 store_next_symbol!(code);
459
460 let escaped = ((next_block >> 16) & 0xFF) as u8;
461 out_ptr.write(escaped);
462 out_ptr = out_ptr.add(1);
463
464 in_ptr = in_ptr.add(3);
465 }
466 0 => {
467 // Otherwise, we actually need to decompress the next byte
468 // Extract the second byte from the u32
469 let escaped = ((next_block >> 8) & 0xFF) as u8;
470 in_ptr = in_ptr.add(2);
471 out_ptr.write(escaped);
472 out_ptr = out_ptr.add(1);
473 }
474 _ => unreachable!(),
475 }
476 }
477 }
478 }
479
480 // Otherwise, fall back to 1-byte reads using 8-byte writes where safe.
481 while out_end.offset_from(out_ptr) >= size_of::<Symbol>() as isize && in_ptr < in_end {
482 let code = in_ptr.read();
483 in_ptr = in_ptr.add(1);
484
485 if code == ESCAPE_CODE {
486 assert!(
487 in_ptr < in_end,
488 "truncated compressed string: escape code at end of input"
489 );
490 out_ptr.write(in_ptr.read());
491 in_ptr = in_ptr.add(1);
492 out_ptr = out_ptr.add(1);
493 } else {
494 store_next_symbol!(code);
495 }
496 }
497
498 // For the last few bytes (if any) where we can't do an 8-byte unaligned write.
499 while in_ptr < in_end {
500 let code = in_ptr.read();
501 in_ptr = in_ptr.add(1);
502
503 if code == ESCAPE_CODE {
504 assert!(
505 in_ptr < in_end,
506 "truncated compressed string: escape code at end of input"
507 );
508 assert!(
509 out_ptr.cast_const() < out_end,
510 "output buffer sized too small"
511 );
512 out_ptr.write(in_ptr.read());
513 in_ptr = in_ptr.add(1);
514 out_ptr = out_ptr.add(1);
515 } else {
516 let len = *self.lengths.get_unchecked(code as usize) as usize;
517 assert!(
518 out_end.offset_from(out_ptr) >= len as isize,
519 "output buffer sized too small"
520 );
521 let sym = self.symbols.get_unchecked(code as usize).to_u64();
522 let sym_bytes = sym.to_le_bytes();
523 std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len);
524 out_ptr = out_ptr.add(len);
525 }
526 }
527
528 assert_eq!(
529 in_ptr, in_end,
530 "decompression should exhaust input before output"
531 );
532
533 out_ptr.offset_from(out_begin) as usize
534 }
535 }
536
537 /// Decompress a byte slice that was previously returned by a compressor using the same symbol
538 /// table into a new vector of bytes.
539 pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
540 let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
541
542 let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
543 // SAFETY: len bytes have now been initialized by the decompressor.
544 unsafe { decoded.set_len(len) };
545 decoded
546 }
547}
548
/// A compressor that uses a symbol table to greedily compress strings.
///
/// The `Compressor` is the central component of FSST. You can create a compressor either by
/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
///
/// Example usage:
///
/// ```
/// use fsst::{Symbol, Compressor, CompressorBuilder};
/// let compressor = {
///     let mut builder = CompressorBuilder::new();
///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
///     builder.build()
/// };
///
/// let compressed = compressor.compress("hello".as_bytes());
/// assert_eq!(compressed, vec![0u8]);
/// ```
#[derive(Clone)]
pub struct Compressor {
    /// Table mapping codes to symbols.
    pub(crate) symbols: Vec<Symbol>,

    /// Length of each symbol, values range from 1-8.
    pub(crate) lengths: Vec<u8>,

    /// The number of entries in the symbol table that have been populated, not counting
    /// the escape values.
    pub(crate) n_symbols: u8,

    /// Inverted index mapping 2-byte symbols to codes.
    ///
    /// It is indexed by the next two input bytes read as a little-endian `u16` (65,536 entries).
    /// In `Compressor::rebuild_from`, slots without a 2-byte symbol fall back to the code of the
    /// 1-byte symbol for the first byte, or to `Code::UNUSED` (an escape) when there is none.
    codes_two_byte: Vec<Code>,

    /// Limit of no suffixes.
    ///
    /// `compress_word` emits any `codes_two_byte` code below this value directly, without
    /// probing `lossy_pht` for a longer match. It is therefore expected to bound the 2-byte
    /// symbols that are not a prefix of any longer symbol.
    has_suffix_code: u8,

    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
    lossy_pht: LossyPHT,
}
588
589/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
590///
591/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
592/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
593impl Compressor {
594 /// Using the symbol table, runs a single cycle of compression on an input word, writing
595 /// the output into `out_ptr`.
596 ///
597 /// # Returns
598 ///
599 /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
600 /// for the caller to advance the input and output pointers.
601 ///
602 /// `advance_in` is the number of bytes to advance the input pointer before the next call.
603 ///
604 /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
605 ///
606 /// # Safety
607 ///
608 /// `out_ptr` must never be NULL or otherwise point to invalid memory.
609 pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
610 // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
611 // if it isn't, it will be overwritten anyway.
612 //
613 // SAFETY: caller ensures out_ptr is not null
614 let first_byte = word as u8;
615 // SAFETY: out_ptr is not null
616 unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
617
618 // First, check the two_bytes table
619 // SAFETY: codes_two_byte has exactly 65536 entries and `word as u16` is always in [0, 65535].
620 let code_twobyte = unsafe { *self.codes_two_byte.get_unchecked(word as u16 as usize) };
621
622 if code_twobyte.code() < self.has_suffix_code {
623 // 2 byte code without having to worry about longer matches.
624 // SAFETY: out_ptr is not null.
625 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
626
627 // Advance input by symbol length (2) and output by a single code byte
628 (2, 1)
629 } else {
630 // Probe the hash table
631 let entry = self.lossy_pht.lookup(word);
632
633 // Now, downshift the `word` and the `entry` to see if they align.
634 let ignored_bits = entry.ignored_bits;
635 if entry.code != Code::UNUSED
636 && compare_masked(word, entry.symbol.to_u64(), ignored_bits)
637 {
638 // Advance the input by the symbol length (variable) and the output by one code byte
639 // SAFETY: out_ptr is not null.
640 unsafe { std::ptr::write(out_ptr, entry.code.code()) };
641 (entry.code.len() as usize, 1)
642 } else {
643 // SAFETY: out_ptr is not null
644 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
645
646 // Advance the input by the symbol length (variable) and the output by either 1
647 // byte (if was one-byte code) or two bytes (escape).
648 (
649 code_twobyte.len() as usize,
650 // Predicated version of:
651 //
652 // if entry.code >= 256 {
653 // 2
654 // } else {
655 // 1
656 // }
657 1 + (code_twobyte.extended_code() >> 8) as usize,
658 )
659 }
660 }
661 }
662
663 /// Compress many lines in bulk.
664 pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
665 let mut res = Vec::new();
666
667 for line in lines {
668 res.push(self.compress(line));
669 }
670
671 res
672 }
673
674 /// Compress a string, writing its result into a target buffer.
675 ///
676 /// The target buffer is a byte vector that must have capacity large enough
677 /// to hold the encoded data.
678 ///
679 /// When this call returns, `values` will hold the compressed bytes and have
680 /// its length set to the length of the compressed text.
681 ///
682 /// ```
683 /// use fsst::{Compressor, CompressorBuilder, Symbol};
684 ///
685 /// let mut compressor = CompressorBuilder::new();
686 /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
687 ///
688 /// let compressor = compressor.build();
689 ///
690 /// let mut compressed_values = Vec::with_capacity(1_024);
691 ///
692 /// // SAFETY: we have over-sized compressed_values.
693 /// unsafe {
694 /// compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
695 /// }
696 ///
697 /// assert_eq!(compressed_values, vec![0u8]);
698 /// ```
699 ///
700 /// # Safety
701 ///
702 /// It is up to the caller to ensure the provided buffer is large enough to hold
703 /// all encoded data.
704 pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
705 let mut in_ptr = plaintext.as_ptr();
706 let mut out_ptr = values.as_mut_ptr();
707
708 // SAFETY: `end` will point just after the end of the `plaintext` slice.
709 let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
710 let in_end_sub8 = in_end as usize - 8;
711 // SAFETY: `end` will point just after the end of the `values` allocation.
712 let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
713
714 while (in_ptr as usize) <= in_end_sub8 && unsafe { out_end.offset_from(out_ptr) } >= 2 {
715 // SAFETY: pointer ranges are checked in the loop condition
716 unsafe {
717 // Load a full 8-byte word of data from in_ptr.
718 // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
719 let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
720 let (advance_in, advance_out) = self.compress_word(word, out_ptr);
721 in_ptr = in_ptr.byte_add(advance_in);
722 out_ptr = out_ptr.byte_add(advance_out);
723 };
724 }
725
726 let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
727 assert!(
728 out_ptr < out_end || remaining_bytes == 0,
729 "output buffer sized too small"
730 );
731
732 let remaining_bytes = remaining_bytes as usize;
733
734 // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
735 // but shift data out of this word rather than advancing an input pointer and potentially reading
736 // unowned memory.
737 let mut bytes = [0u8; 8];
738 // SAFETY: remaining_bytes <= 8
739 unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
740 let mut last_word = u64::from_le_bytes(bytes);
741
742 while in_ptr < in_end && unsafe { out_end.offset_from(out_ptr) } >= 2 {
743 // Load a full 8-byte word of data from in_ptr.
744 // SAFETY: caller asserts in_ptr is not null
745 let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
746 // SAFETY: pointer ranges are checked in the loop condition
747 unsafe {
748 in_ptr = in_ptr.add(advance_in);
749 out_ptr = out_ptr.add(advance_out);
750 }
751
752 last_word = advance_8byte_word(last_word, advance_in);
753 }
754
755 // in_ptr should have exceeded in_end
756 assert!(
757 in_ptr >= in_end,
758 "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
759 );
760
761 assert!(out_ptr <= out_end, "output buffer sized too small");
762
763 // SAFETY: out_ptr is derived from the `values` allocation.
764 let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
765 assert!(
766 bytes_written >= 0,
767 "out_ptr ended before it started, not possible"
768 );
769
770 // SAFETY: we have initialized `bytes_written` values in the output buffer.
771 unsafe { values.set_len(bytes_written as usize) };
772 }
773
774 /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
775 pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
776 if plaintext.is_empty() {
777 return Vec::new();
778 }
779
780 let mut buffer = Vec::with_capacity(plaintext.len() * 2);
781
782 // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
783 unsafe { self.compress_into(plaintext, &mut buffer) };
784
785 buffer
786 }
787
788 /// Access the decompressor that can be used to decompress strings emitted from this
789 /// `Compressor` instance.
790 pub fn decompressor(&self) -> Decompressor<'_> {
791 Decompressor::new(self.symbol_table(), self.symbol_lengths())
792 }
793
794 /// Returns a readonly slice of the current symbol table.
795 ///
796 /// The returned slice will have length of `n_symbols`.
797 pub fn symbol_table(&self) -> &[Symbol] {
798 &self.symbols[0..self.n_symbols as usize]
799 }
800
801 /// Returns a readonly slice where index `i` contains the
802 /// length of the symbol represented by code `i`.
803 ///
804 /// Values range from 1-8.
805 pub fn symbol_lengths(&self) -> &[u8] {
806 &self.lengths[0..self.n_symbols as usize]
807 }
808
809 /// Rebuild a compressor from an existing symbol table.
810 ///
811 /// This will not attempt to optimize or re-order the codes.
812 pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
813 let symbols = symbols.as_ref();
814 let symbol_lens = symbol_lens.as_ref();
815
816 assert_eq!(
817 symbols.len(),
818 symbol_lens.len(),
819 "symbols and lengths differ"
820 );
821 assert!(
822 symbols.len() <= 255,
823 "symbol table len must be <= 255, was {}",
824 symbols.len()
825 );
826 validate_symbol_order(symbol_lens);
827
828 // Insert the symbols in their given order into the FSST lookup structures.
829 let symbols = symbols.to_vec();
830 let lengths = symbol_lens.to_vec();
831 let mut lossy_pht = LossyPHT::new();
832
833 let mut codes_one_byte = vec![Code::UNUSED; 256];
834
835 // Insert all of the one byte symbols first.
836 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
837 if len == 1 {
838 codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
839 }
840 }
841
842 // Initialize the codes_two_byte table to be all escapes
843 let mut codes_two_byte = vec![Code::UNUSED; 65_536];
844
845 // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
846 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
847 match len {
848 2 => {
849 codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
850 }
851 3.. => {
852 assert!(
853 lossy_pht.insert(symbol, len as usize, code as u8),
854 "rebuild symbol insertion into PHT must succeed"
855 );
856 }
857 _ => { /* Covered by the 1-byte loop above. */ }
858 }
859 }
860
861 // Build the finished codes_two_byte table, subbing in unused positions with the
862 // codes_one_byte value similar to what we do in CompressBuilder::finalize.
863 for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
864 if *code == Code::UNUSED {
865 *code = codes_one_byte[symbol & 0xFF];
866 }
867 }
868
869 // Find the position of the first 2-byte code that has a suffix later in the table
870 let mut has_suffix_code = 0u8;
871 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
872 if len != 2 {
873 break;
874 }
875 let rest = &symbols[code..];
876 if rest
877 .iter()
878 .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
879 {
880 has_suffix_code = code as u8;
881 break;
882 }
883 }
884
885 Compressor {
886 n_symbols: symbols.len() as u8,
887 symbols,
888 lengths,
889 codes_two_byte,
890 lossy_pht,
891 has_suffix_code,
892 }
893 }
894}
895
/// Drop the first `bytes` bytes from a little-endian packed `word`.
///
/// Little-endian packing puts the first character in the least-significant byte, so consuming
/// bytes is a right shift. Consuming all 8 bytes yields `0`. That case is handled explicitly
/// because shifting a `u64` by 64 bits overflows.
///
/// The match compiles to a conditional move rather than a branch
/// (see `<https://godbolt.org/z/Pbvre65Pq>`).
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    match bytes {
        8 => 0,
        consumed => word >> (8 * consumed),
    }
}
905
/// Assert that `symbol_lens` follows the FSST symbol table ordering.
///
/// Multi-byte symbols come first in non-decreasing length order (2, 3, ..., 8), followed by all
/// 1-byte symbols ("23456781"). Every length must also be at most 8, since a symbol is packed
/// into a `u64`.
///
/// # Panics
///
/// If a length exceeds 8, or if the ordering is violated (which includes a length of 0).
fn validate_symbol_order(symbol_lens: &[u8]) {
    let mut expected = 2;
    for (idx, &len) in symbol_lens.iter().enumerate() {
        assert!(
            len <= 8,
            "symbol code={idx} has length {len}, but symbols are at most 8 bytes"
        );

        if expected == 1 {
            // Once the 1-byte section has started, every remaining symbol must be 1 byte.
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
        } else {
            if len == 1 {
                expected = 1;
            }

            // Still in the multi-byte portion: lengths must not decrease.
            assert!(
                len >= expected,
                "symbol code={idx} violates FSST symbol table ordering"
            );
            expected = len;
        }
    }
}
929
/// Compare `left` with `right` after clearing the top `ignored_bits` bits of `left`.
///
/// Because symbols are little-endian, the high bits of `left` hold input bytes beyond the
/// candidate symbol. Masking them off checks whether `right` is a prefix of `left`. Only
/// `left` is masked; `right` is compared as-is.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    let kept_bits = u64::MAX >> ignored_bits;
    right == left & kept_bits
}
935
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Raw symbol table shared by the round-trip tests: 16 two-byte symbols, then 61
    /// three-byte symbols, then 8 one-byte symbols.
    const RAW_SYMBOLS: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Lengths matching `RAW_SYMBOLS`, in FSST order.
    fn raw_symbol_lengths() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// Wrap raw little-endian `u64`s as `Symbol`s.
    fn to_symbols(raw: &[u64]) -> Vec<Symbol> {
        raw.iter()
            .map(|value| Symbol::from_slice(&value.to_le_bytes()))
            .collect()
    }

    /// Unwrap `Symbol`s back into their raw `u64`s.
    fn to_raw(symbols: &[Symbol]) -> Vec<u64> {
        symbols.iter().map(|symbol| symbol.to_u64()).collect()
    }

    #[test]
    fn test_stuff() {
        let mut builder = CompressorBuilder::new();
        builder.insert(Symbol::from_slice(b"helloooo"), 8);
        let compressor = builder.build();

        let decompressor = compressor.decompressor();

        let mut output = Vec::with_capacity(8 + 7);
        let written = decompressor.decompress_into(&[0], output.spare_capacity_mut());
        assert_eq!(written, 8);
        unsafe { output.set_len(written) };
        assert_eq!(&output, "helloooo".as_bytes());
    }

    #[test]
    fn test_symbols_good() {
        let compressor = Compressor::rebuild_from(to_symbols(RAW_SYMBOLS), raw_symbol_lengths());
        assert_eq!(to_raw(compressor.symbol_table()), RAW_SYMBOLS);
    }

    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let mut builder = CompressorBuilder::new();
        for (&raw, len) in RAW_SYMBOLS.iter().zip(raw_symbol_lengths()) {
            builder.insert(Symbol::from_slice(&raw.to_le_bytes()), len as usize);
        }
        let compressor = builder.build();
        assert_eq!(to_raw(compressor.symbol_table()), RAW_SYMBOLS);
    }
}
1010}