fsst/lib.rs
1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Compile-time size check: the build fails unless `$ty` occupies exactly `$expected` bytes.
///
/// The check assigns an array whose length is the *actual* size of the type to an anonymous
/// constant whose declared length is the *expected* size; any mismatch is a type error.
macro_rules! assert_sizeof {
    ($ty:ty => $expected:expr) => {
        const _: [(); $expected] = [(); ::std::mem::size_of::<$ty>()];
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and
/// identified by an 8-bit code.
///
/// The bytes are packed into a single `u64` in little-endian order: the first byte of the
/// segment occupies the least significant byte of the integer, and unused high-order bytes
/// are zero.
#[derive(Copy, Clone)]
pub struct Symbol(u64);

// Compile-time guarantee that a `Symbol` is exactly 8 bytes wide.
assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28 /// Zero value for `Symbol`.
29 pub const ZERO: Self = Self::zero();
30
31 /// Constructor for a `Symbol` from an 8-element byte slice.
32 pub fn from_slice(slice: &[u8; 8]) -> Self {
33 let num: u64 = u64::from_le_bytes(*slice);
34
35 Self(num)
36 }
37
38 /// Return a zero symbol
39 const fn zero() -> Self {
40 Self(0)
41 }
42
43 /// Create a new single-byte symbol
44 pub fn from_u8(value: u8) -> Self {
45 Self(value as u64)
46 }
47}
48
49impl Symbol {
50 /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51 ///
52 /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53 /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54 /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55 /// but we want to interpret that as a one-byte symbol containing `0x00`.
56 #[allow(clippy::len_without_is_empty)]
57 pub fn len(self) -> usize {
58 let numeric = self.0;
59 // For little-endian platforms, this counts the number of *trailing* zeros
60 let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62 // Special case handling of a symbol with all-zeros. This is actually
63 // a 1-byte symbol containing 0x00.
64 let len = size_of::<Self>() - null_bytes;
65 if len == 0 { 1 } else { len }
66 }
67
68 #[inline]
69 fn as_u64(self) -> u64 {
70 self.0
71 }
72
73 /// Get the first byte of the symbol as a `u8`.
74 ///
75 /// If the symbol is empty, this will return the zero byte.
76 #[inline]
77 pub fn first_byte(self) -> u8 {
78 self.0 as u8
79 }
80
81 /// Get the first two bytes of the symbol as a `u16`.
82 ///
83 /// If the Symbol is one or zero bytes, this will return `0u16`.
84 #[inline]
85 pub fn first2(self) -> u16 {
86 self.0 as u16
87 }
88
89 /// Get the first three bytes of the symbol as a `u64`.
90 ///
91 /// If the Symbol is one or zero bytes, this will return `0u64`.
92 #[inline]
93 pub fn first3(self) -> u64 {
94 self.0 & 0xFF_FF_FF
95 }
96
97 /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
98 pub fn concat(self, other: Self) -> Self {
99 assert!(
100 self.len() + other.len() <= 8,
101 "cannot build symbol with length > 8"
102 );
103
104 let self_len = self.len();
105
106 Self((other.0 << (8 * self_len)) | self.0)
107 }
108}
109
110impl Debug for Symbol {
111 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
112 write!(f, "[")?;
113
114 let slice = &self.0.to_le_bytes()[0..self.len()];
115 for c in slice.iter().map(|c| *c as char) {
116 if ('!'..='~').contains(&c) {
117 write!(f, "{c}")?;
118 } else if c == '\n' {
119 write!(f, " \\n ")?;
120 } else if c == '\t' {
121 write!(f, " \\t ")?;
122 } else if c == ' ' {
123 write!(f, " SPACE ")?;
124 } else {
125 write!(f, " 0x{:X?} ", c as u8)?
126 }
127 }
128
129 write!(f, "]")
130 }
131}
132
/// A packed type containing a code value, as well as metadata about the symbol referred to by
/// the code.
///
/// Logically, codes can range from 0-255 inclusive. This type holds both the 8-bit code as well as
/// other metadata bit-packed into a `u16`.
///
/// The bottom 8 bits contain EITHER a code for a symbol stored in the table, OR a raw byte.
///
/// The interpretation depends on the 9th bit: when toggled off, the value stores a raw byte, and when
/// toggled on, it stores a code. Thus if you examine the bottom 9 bits of the `u16`, you have an extended
/// code range, where the values 0-255 are raw bytes, and the values 256-510 represent codes 0-254. 511 is
/// a placeholder for the invalid code here.
///
/// Bits 12-15 store the length of the symbol (values ranging from 0-8). Bits 9-11 are not
/// written by any of the constructors of this type.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct Code(u16);
149
/// Code used to indicate bytes that are not in the symbol table.
///
/// When compressing a string that cannot fully be expressed with the symbol table, the compressed
/// output will contain an `ESCAPE` byte followed by a raw byte. At decompression time, the presence
/// of `ESCAPE` indicates that the next byte should be appended directly to the result instead of
/// being looked up in the symbol table.
pub const ESCAPE_CODE: u8 = 255;

/// Number of low bits of a packed code value that hold the extended code.
pub const FSST_CODE_BITS: usize = 9;

/// Position of the first bit of the "length" portion of a packed code value.
pub const FSST_LEN_BITS: usize = 12;

/// Number of distinct values in the extended code range (`1 << FSST_CODE_BITS`, i.e. 512).
///
/// This is one past the largest extended code.
pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS;

/// Maximum value for the extended code range (511); also the bit mask selecting the extended code.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First code in the symbol table that corresponds to a non-escape symbol.
pub const FSST_CODE_BASE: u16 = 256;
174
175#[allow(clippy::len_without_is_empty)]
176impl Code {
177 /// Code for an unused slot in a symbol table or index.
178 ///
179 /// This corresponds to the maximum code with a length of 1.
180 pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << 12));
181
182 /// Create a new code for a symbol of given length.
183 fn new_symbol(code: u8, len: usize) -> Self {
184 Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
185 }
186
187 /// Code for a new symbol during the building phase.
188 ///
189 /// The code is remapped from 0..254 to 256...510.
190 fn new_symbol_building(code: u8, len: usize) -> Self {
191 Self(code as u16 + 256 + ((len as u16) << FSST_LEN_BITS))
192 }
193
194 /// Create a new code corresponding for an escaped byte.
195 fn new_escape(byte: u8) -> Self {
196 Self((byte as u16) + (1 << FSST_LEN_BITS))
197 }
198
199 #[inline]
200 fn code(self) -> u8 {
201 self.0 as u8
202 }
203
204 #[inline]
205 fn extended_code(self) -> u16 {
206 self.0 & 0b111_111_111
207 }
208
209 #[inline]
210 fn len(self) -> u16 {
211 self.0 >> FSST_LEN_BITS
212 }
213}
214
215impl Debug for Code {
216 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
217 f.debug_struct("TrainingCode")
218 .field("code", &(self.0 as u8))
219 .field("is_escape", &(self.0 < 256))
220 .field("len", &(self.0 >> 12))
221 .finish()
222 }
223}
224
/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string.
///
/// It borrows the table, so it is cheap to create and clone.
#[derive(Clone)]
pub struct Decompressor<'a> {
    /// Slice mapping codes to symbols; a code is used directly as an index.
    pub(crate) symbols: &'a [Symbol],

    /// Slice containing the length of each symbol in the `symbols` slice.
    pub(crate) lengths: &'a [u8],
}
234
235impl<'a> Decompressor<'a> {
236 /// Returns a new decompressor that uses the provided symbol table.
237 ///
238 /// # Panics
239 ///
240 /// If the provided symbol table has length greater than 256
241 pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
242 assert!(
243 symbols.len() < FSST_CODE_BASE as usize,
244 "symbol table cannot have size exceeding 255"
245 );
246
247 Self { symbols, lengths }
248 }
249
250 /// Returns an upper bound on the size of the decompressed data.
251 pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
252 size_of::<Symbol>() * (compressed.len() + 1)
253 }
254
255 /// Decompress a slice of codes into a provided buffer.
256 ///
257 /// The provided `decoded` buffer must be at least the size of the decoded data, plus
258 /// an additional 7 bytes.
259 ///
260 /// ## Panics
261 ///
262 /// If the caller fails to provide sufficient capacity in the decoded buffer. An upper bound
263 /// on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
264 ///
265 /// ## Example
266 ///
267 /// ```
268 /// use fsst::{Symbol, Compressor, CompressorBuilder};
269 /// let compressor = {
270 /// let mut builder = CompressorBuilder::new();
271 /// builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
272 /// builder.build()
273 /// };
274 ///
275 /// let decompressor = compressor.decompressor();
276 ///
277 /// let mut decompressed = Vec::with_capacity(8 + 7);
278 ///
279 /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
280 /// assert_eq!(len, 8);
281 /// unsafe { decompressed.set_len(len) };
282 /// assert_eq!(&decompressed, "helloooo".as_bytes());
283 /// ```
284 pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
285 // Ensure the target buffer is at least half the size of the input buffer.
286 // This is the theortical smallest a valid target can be, and occurs when
287 // every input code is an escape.
288 assert!(
289 decoded.len() >= compressed.len() / 2,
290 "decoded is smaller than lower-bound decompressed size"
291 );
292
293 unsafe {
294 let mut in_ptr = compressed.as_ptr();
295 let _in_begin = in_ptr;
296 let in_end = in_ptr.add(compressed.len());
297
298 let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
299 let out_begin = out_ptr.cast_const();
300 let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();
301
302 macro_rules! store_next_symbol {
303 ($code:expr) => {{
304 out_ptr
305 .cast::<u64>()
306 .write_unaligned(self.symbols.get_unchecked($code as usize).as_u64());
307 out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
308 }};
309 }
310
311 // First we try loading 8 bytes at a time.
312 if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
313 // Extract the loop condition since the compiler fails to do so
314 let block_out_end = out_end.sub(8 * size_of::<Symbol>());
315 let block_in_end = in_end.sub(8);
316
317 while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
318 // Note that we load a little-endian u64 here.
319 let next_block = in_ptr.cast::<u64>().read_unaligned();
320 let escape_mask = (next_block & 0x8080808080808080)
321 & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
322 ^ 0x8080808080808080);
323
324 // If there are no escape codes, we write each symbol one by one.
325 if escape_mask == 0 {
326 let code = (next_block & 0xFF) as u8;
327 store_next_symbol!(code);
328 let code = ((next_block >> 8) & 0xFF) as u8;
329 store_next_symbol!(code);
330 let code = ((next_block >> 16) & 0xFF) as u8;
331 store_next_symbol!(code);
332 let code = ((next_block >> 24) & 0xFF) as u8;
333 store_next_symbol!(code);
334 let code = ((next_block >> 32) & 0xFF) as u8;
335 store_next_symbol!(code);
336 let code = ((next_block >> 40) & 0xFF) as u8;
337 store_next_symbol!(code);
338 let code = ((next_block >> 48) & 0xFF) as u8;
339 store_next_symbol!(code);
340 let code = ((next_block >> 56) & 0xFF) as u8;
341 store_next_symbol!(code);
342 in_ptr = in_ptr.add(8);
343 } else {
344 // Otherwise, find the first escape code and write the symbols up to that point.
345 let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
346 debug_assert!(first_escape_pos < 8);
347 match first_escape_pos {
348 7 => {
349 let code = (next_block & 0xFF) as u8;
350 store_next_symbol!(code);
351 let code = ((next_block >> 8) & 0xFF) as u8;
352 store_next_symbol!(code);
353 let code = ((next_block >> 16) & 0xFF) as u8;
354 store_next_symbol!(code);
355 let code = ((next_block >> 24) & 0xFF) as u8;
356 store_next_symbol!(code);
357 let code = ((next_block >> 32) & 0xFF) as u8;
358 store_next_symbol!(code);
359 let code = ((next_block >> 40) & 0xFF) as u8;
360 store_next_symbol!(code);
361 let code = ((next_block >> 48) & 0xFF) as u8;
362 store_next_symbol!(code);
363
364 in_ptr = in_ptr.add(7);
365 }
366 6 => {
367 let code = (next_block & 0xFF) as u8;
368 store_next_symbol!(code);
369 let code = ((next_block >> 8) & 0xFF) as u8;
370 store_next_symbol!(code);
371 let code = ((next_block >> 16) & 0xFF) as u8;
372 store_next_symbol!(code);
373 let code = ((next_block >> 24) & 0xFF) as u8;
374 store_next_symbol!(code);
375 let code = ((next_block >> 32) & 0xFF) as u8;
376 store_next_symbol!(code);
377 let code = ((next_block >> 40) & 0xFF) as u8;
378 store_next_symbol!(code);
379
380 let escaped = ((next_block >> 56) & 0xFF) as u8;
381 out_ptr.write(escaped);
382 out_ptr = out_ptr.add(1);
383
384 in_ptr = in_ptr.add(8);
385 }
386 5 => {
387 let code = (next_block & 0xFF) as u8;
388 store_next_symbol!(code);
389 let code = ((next_block >> 8) & 0xFF) as u8;
390 store_next_symbol!(code);
391 let code = ((next_block >> 16) & 0xFF) as u8;
392 store_next_symbol!(code);
393 let code = ((next_block >> 24) & 0xFF) as u8;
394 store_next_symbol!(code);
395 let code = ((next_block >> 32) & 0xFF) as u8;
396 store_next_symbol!(code);
397
398 let escaped = ((next_block >> 48) & 0xFF) as u8;
399 out_ptr.write(escaped);
400 out_ptr = out_ptr.add(1);
401
402 in_ptr = in_ptr.add(7);
403 }
404 4 => {
405 let code = (next_block & 0xFF) as u8;
406 store_next_symbol!(code);
407 let code = ((next_block >> 8) & 0xFF) as u8;
408 store_next_symbol!(code);
409 let code = ((next_block >> 16) & 0xFF) as u8;
410 store_next_symbol!(code);
411 let code = ((next_block >> 24) & 0xFF) as u8;
412 store_next_symbol!(code);
413
414 let escaped = ((next_block >> 40) & 0xFF) as u8;
415 out_ptr.write(escaped);
416 out_ptr = out_ptr.add(1);
417
418 in_ptr = in_ptr.add(6);
419 }
420 3 => {
421 let code = (next_block & 0xFF) as u8;
422 store_next_symbol!(code);
423 let code = ((next_block >> 8) & 0xFF) as u8;
424 store_next_symbol!(code);
425 let code = ((next_block >> 16) & 0xFF) as u8;
426 store_next_symbol!(code);
427
428 let escaped = ((next_block >> 32) & 0xFF) as u8;
429 out_ptr.write(escaped);
430 out_ptr = out_ptr.add(1);
431
432 in_ptr = in_ptr.add(5);
433 }
434 2 => {
435 let code = (next_block & 0xFF) as u8;
436 store_next_symbol!(code);
437 let code = ((next_block >> 8) & 0xFF) as u8;
438 store_next_symbol!(code);
439
440 let escaped = ((next_block >> 24) & 0xFF) as u8;
441 out_ptr.write(escaped);
442 out_ptr = out_ptr.add(1);
443
444 in_ptr = in_ptr.add(4);
445 }
446 1 => {
447 let code = (next_block & 0xFF) as u8;
448 store_next_symbol!(code);
449
450 let escaped = ((next_block >> 16) & 0xFF) as u8;
451 out_ptr.write(escaped);
452 out_ptr = out_ptr.add(1);
453
454 in_ptr = in_ptr.add(3);
455 }
456 0 => {
457 // Otherwise, we actually need to decompress the next byte
458 // Extract the second byte from the u32
459 let escaped = ((next_block >> 8) & 0xFF) as u8;
460 in_ptr = in_ptr.add(2);
461 out_ptr.write(escaped);
462 out_ptr = out_ptr.add(1);
463 }
464 _ => unreachable!(),
465 }
466 }
467 }
468 }
469
470 // Otherwise, fall back to 1-byte reads.
471 while out_end.offset_from(out_ptr) > size_of::<Symbol>() as isize && in_ptr < in_end {
472 let code = in_ptr.read();
473 in_ptr = in_ptr.add(1);
474
475 if code == ESCAPE_CODE {
476 out_ptr.write(in_ptr.read());
477 in_ptr = in_ptr.add(1);
478 out_ptr = out_ptr.add(1);
479 } else {
480 store_next_symbol!(code);
481 }
482 }
483
484 assert_eq!(
485 in_ptr, in_end,
486 "decompression should exhaust input before output"
487 );
488
489 out_ptr.offset_from(out_begin) as usize
490 }
491 }
492
493 /// Decompress a byte slice that was previously returned by a compressor using the same symbol
494 /// table into a new vector of bytes.
495 pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
496 let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
497
498 let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
499 // SAFETY: len bytes have now been initialized by the decompressor.
500 unsafe { decoded.set_len(len) };
501 decoded
502 }
503}
504
/// A compressor that uses a symbol table to greedily compress strings.
///
/// The `Compressor` is the central component of FSST. You can create a compressor either by
/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
///
/// Example usage:
///
/// ```
/// use fsst::{Symbol, Compressor, CompressorBuilder};
/// let compressor = {
///     let mut builder = CompressorBuilder::new();
///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
///     builder.build()
/// };
///
/// let compressed = compressor.compress("hello".as_bytes());
/// assert_eq!(compressed, vec![0u8]);
/// ```
#[derive(Clone)]
pub struct Compressor {
    /// Table mapping codes to symbols.
    pub(crate) symbols: Vec<Symbol>,

    /// Length of each symbol, values range from 1-8.
    pub(crate) lengths: Vec<u8>,

    /// The number of entries in the symbol table that have been populated, not counting
    /// the escape values.
    pub(crate) n_symbols: u8,

    /// Inverted index mapping 2-byte symbols to codes.
    ///
    /// Indexed by the first two input bytes read as a little-endian `u16`.
    codes_two_byte: Vec<Code>,

    /// Limit of no suffixes.
    ///
    /// A two-byte table hit whose code is strictly below this value is emitted directly,
    /// without probing the hash table for a longer match.
    has_suffix_code: u8,

    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
    lossy_pht: LossyPHT,
}
544
545/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
546///
547/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
548/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
549impl Compressor {
550 /// Using the symbol table, runs a single cycle of compression on an input word, writing
551 /// the output into `out_ptr`.
552 ///
553 /// # Returns
554 ///
555 /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
556 /// for the caller to advance the input and output pointers.
557 ///
558 /// `advance_in` is the number of bytes to advance the input pointer before the next call.
559 ///
560 /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
561 ///
562 /// # Safety
563 ///
564 /// `out_ptr` must never be NULL or otherwise point to invalid memory.
565 pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
566 // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
567 // if it isn't, it will be overwritten anyway.
568 //
569 // SAFETY: caller ensures out_ptr is not null
570 let first_byte = word as u8;
571 // SAFETY: out_ptr is not null
572 unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
573
574 // First, check the two_bytes table
575 let code_twobyte = self.codes_two_byte[word as u16 as usize];
576
577 if code_twobyte.code() < self.has_suffix_code {
578 // 2 byte code without having to worry about longer matches.
579 // SAFETY: out_ptr is not null.
580 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
581
582 // Advance input by symbol length (2) and output by a single code byte
583 (2, 1)
584 } else {
585 // Probe the hash table
586 let entry = self.lossy_pht.lookup(word);
587
588 // Now, downshift the `word` and the `entry` to see if they align.
589 let ignored_bits = entry.ignored_bits;
590 if entry.code != Code::UNUSED
591 && compare_masked(word, entry.symbol.as_u64(), ignored_bits)
592 {
593 // Advance the input by the symbol length (variable) and the output by one code byte
594 // SAFETY: out_ptr is not null.
595 unsafe { std::ptr::write(out_ptr, entry.code.code()) };
596 (entry.code.len() as usize, 1)
597 } else {
598 // SAFETY: out_ptr is not null
599 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
600
601 // Advance the input by the symbol length (variable) and the output by either 1
602 // byte (if was one-byte code) or two bytes (escape).
603 (
604 code_twobyte.len() as usize,
605 // Predicated version of:
606 //
607 // if entry.code >= 256 {
608 // 2
609 // } else {
610 // 1
611 // }
612 1 + (code_twobyte.extended_code() >> 8) as usize,
613 )
614 }
615 }
616 }
617
618 /// Compress many lines in bulk.
619 pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
620 let mut res = Vec::new();
621
622 for line in lines {
623 res.push(self.compress(line));
624 }
625
626 res
627 }
628
629 /// Compress a string, writing its result into a target buffer.
630 ///
631 /// The target buffer is a byte vector that must have capacity large enough
632 /// to hold the encoded data.
633 ///
634 /// When this call returns, `values` will hold the compressed bytes and have
635 /// its length set to the length of the compressed text.
636 ///
637 /// ```
638 /// use fsst::{Compressor, CompressorBuilder, Symbol};
639 ///
640 /// let mut compressor = CompressorBuilder::new();
641 /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
642 ///
643 /// let compressor = compressor.build();
644 ///
645 /// let mut compressed_values = Vec::with_capacity(1_024);
646 ///
647 /// // SAFETY: we have over-sized compressed_values.
648 /// unsafe {
649 /// compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
650 /// }
651 ///
652 /// assert_eq!(compressed_values, vec![0u8]);
653 /// ```
654 ///
655 /// # Safety
656 ///
657 /// It is up to the caller to ensure the provided buffer is large enough to hold
658 /// all encoded data.
659 pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
660 let mut in_ptr = plaintext.as_ptr();
661 let mut out_ptr = values.as_mut_ptr();
662
663 // SAFETY: `end` will point just after the end of the `plaintext` slice.
664 let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
665 let in_end_sub8 = in_end as usize - 8;
666 // SAFETY: `end` will point just after the end of the `values` allocation.
667 let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
668
669 while (in_ptr as usize) <= in_end_sub8 && out_ptr < out_end {
670 // SAFETY: pointer ranges are checked in the loop condition
671 unsafe {
672 // Load a full 8-byte word of data from in_ptr.
673 // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
674 let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
675 let (advance_in, advance_out) = self.compress_word(word, out_ptr);
676 in_ptr = in_ptr.byte_add(advance_in);
677 out_ptr = out_ptr.byte_add(advance_out);
678 };
679 }
680
681 let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
682 assert!(
683 out_ptr < out_end || remaining_bytes == 0,
684 "output buffer sized too small"
685 );
686
687 let remaining_bytes = remaining_bytes as usize;
688
689 // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
690 // but shift data out of this word rather than advancing an input pointer and potentially reading
691 // unowned memory.
692 let mut bytes = [0u8; 8];
693 // SAFETY: remaining_bytes <= 8
694 unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
695 let mut last_word = u64::from_le_bytes(bytes);
696
697 while in_ptr < in_end && out_ptr < out_end {
698 // Load a full 8-byte word of data from in_ptr.
699 // SAFETY: caller asserts in_ptr is not null
700 let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
701 // SAFETY: pointer ranges are checked in the loop condition
702 unsafe {
703 in_ptr = in_ptr.add(advance_in);
704 out_ptr = out_ptr.add(advance_out);
705 }
706
707 last_word = advance_8byte_word(last_word, advance_in);
708 }
709
710 // in_ptr should have exceeded in_end
711 assert!(
712 in_ptr >= in_end,
713 "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
714 );
715
716 assert!(out_ptr <= out_end, "output buffer sized too small");
717
718 // SAFETY: out_ptr is derived from the `values` allocation.
719 let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
720 assert!(
721 bytes_written >= 0,
722 "out_ptr ended before it started, not possible"
723 );
724
725 // SAFETY: we have initialized `bytes_written` values in the output buffer.
726 unsafe { values.set_len(bytes_written as usize) };
727 }
728
729 /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
730 pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
731 if plaintext.is_empty() {
732 return Vec::new();
733 }
734
735 let mut buffer = Vec::with_capacity(plaintext.len() * 2);
736
737 // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
738 unsafe { self.compress_into(plaintext, &mut buffer) };
739
740 buffer
741 }
742
743 /// Access the decompressor that can be used to decompress strings emitted from this
744 /// `Compressor` instance.
745 pub fn decompressor(&self) -> Decompressor {
746 Decompressor::new(self.symbol_table(), self.symbol_lengths())
747 }
748
749 /// Returns a readonly slice of the current symbol table.
750 ///
751 /// The returned slice will have length of `n_symbols`.
752 pub fn symbol_table(&self) -> &[Symbol] {
753 &self.symbols[0..self.n_symbols as usize]
754 }
755
756 /// Returns a readonly slice where index `i` contains the
757 /// length of the symbol represented by code `i`.
758 ///
759 /// Values range from 1-8.
760 pub fn symbol_lengths(&self) -> &[u8] {
761 &self.lengths[0..self.n_symbols as usize]
762 }
763
764 /// Rebuild a compressor from an existing symbol table.
765 ///
766 /// This will not attempt to optimize or re-order the codes.
767 pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
768 let symbols = symbols.as_ref();
769 let symbol_lens = symbol_lens.as_ref();
770
771 assert_eq!(
772 symbols.len(),
773 symbol_lens.len(),
774 "symbols and lengths differ"
775 );
776 assert!(
777 symbols.len() <= 255,
778 "symbol table len must be <= 255, was {}",
779 symbols.len()
780 );
781 validate_symbol_order(symbol_lens);
782
783 // Insert the symbols in their given order into the FSST lookup structures.
784 let symbols = symbols.to_vec();
785 let lengths = symbol_lens.to_vec();
786 let mut lossy_pht = LossyPHT::new();
787
788 let mut codes_one_byte = vec![Code::UNUSED; 256];
789
790 // Insert all of the one byte symbols first.
791 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
792 if len == 1 {
793 codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
794 }
795 }
796
797 // Initialize the codes_two_byte table to be all escapes
798 let mut codes_two_byte = vec![Code::UNUSED; 65_536];
799
800 // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
801 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
802 match len {
803 2 => {
804 codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
805 }
806 3.. => {
807 assert!(
808 lossy_pht.insert(symbol, len as usize, code as u8),
809 "rebuild symbol insertion into PHT must succeed"
810 );
811 }
812 _ => { /* Covered by the 1-byte loop above. */ }
813 }
814 }
815
816 // Build the finished codes_two_byte table, subbing in unused positions with the
817 // codes_one_byte value similar to what we do in CompressBuilder::finalize.
818 for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
819 if *code == Code::UNUSED {
820 *code = codes_one_byte[symbol & 0xFF];
821 }
822 }
823
824 // Find the position of the first 2-byte code that has a suffix later in the table
825 let mut has_suffix_code = 0u8;
826 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
827 if len != 2 {
828 break;
829 }
830 let rest = &symbols[code..];
831 if rest
832 .iter()
833 .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
834 {
835 has_suffix_code = code as u8;
836 break;
837 }
838 }
839
840 Compressor {
841 n_symbols: symbols.len() as u8,
842 symbols,
843 lengths,
844 codes_two_byte,
845 lossy_pht,
846 has_suffix_code,
847 }
848 }
849}
850
/// Drop the first `bytes` bytes from a little-endian packed word.
///
/// The first byte lives in the least significant position, so consuming input means
/// shifting towards the low end. Consuming all 8 bytes yields 0; it is handled separately
/// because a shift by the full bit width is not a valid shift amount.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    match bytes {
        8 => 0,
        consumed => word >> (8 * consumed),
    }
}
860
/// Check that symbol lengths follow the FSST code ordering `2, 3, 4, 5, 6, 7, 8, 1`.
///
/// Lengths must be non-decreasing starting from 2, and once a one-byte symbol appears every
/// remaining symbol must also be one byte long.
///
/// # Panics
///
/// If the ordering is violated.
fn validate_symbol_order(symbol_lens: &[u8]) {
    // Smallest length still acceptable for a multi-byte symbol.
    let mut min_len = 2;
    // Set once the trailing run of one-byte symbols has started.
    let mut in_one_byte_tail = false;

    for (idx, &len) in symbol_lens.iter().enumerate() {
        if in_one_byte_tail {
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
            continue;
        }

        if len == 1 {
            // First one-byte symbol: everything after it must be one byte as well.
            in_one_byte_tail = true;
            continue;
        }

        // we're in the non-zero portion.
        assert!(
            len >= min_len,
            "symbol code={idx} breaks violates FSST symbol table ordering"
        );
        min_len = len;
    }
}
884
/// Compare `left` against `right` after clearing the top `ignored_bits` bits of `left`.
///
/// `right` is compared as-is, so it only matches when its own ignored bits are zero.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    let kept = u64::MAX >> ignored_bits;
    right == left & kept
}
890
#[cfg(test)]
mod test {
    use super::*;
    use std::{iter, mem};

    /// Raw little-endian symbol values shared by the symbol-table tests: 16 two-byte symbols,
    /// then 61 three-byte symbols, then 8 one-byte symbols (which contain duplicates).
    const FIXTURE_SYMBOLS: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Lengths matching [`FIXTURE_SYMBOLS`], one entry per symbol.
    fn fixture_lengths() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// A single 8-byte symbol decompresses into a buffer of exactly 8 + 7 bytes.
    #[test]
    fn test_stuff() {
        let mut builder = CompressorBuilder::new();
        builder.insert(Symbol::from_slice(b"helloooo"), 8);
        let compressor = builder.build();

        let decompressor = compressor.decompressor();

        let mut decompressed = Vec::with_capacity(8 + 7);

        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    /// `rebuild_from` keeps the symbol table exactly as given.
    #[test]
    fn test_symbols_good() {
        let symbols: &[Symbol] = unsafe { mem::transmute(FIXTURE_SYMBOLS) };

        let compressor = Compressor::rebuild_from(symbols, fixture_lengths());
        let built_symbols: &[u64] = unsafe { mem::transmute(compressor.symbol_table()) };
        assert_eq!(built_symbols, FIXTURE_SYMBOLS);
    }

    /// Feeding the same symbols through the builder is expected to fail an equality assertion.
    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let lens = fixture_lengths();

        let mut builder = CompressorBuilder::new();
        for (raw, len) in FIXTURE_SYMBOLS.iter().zip(lens.iter()) {
            builder.insert(Symbol::from_slice(&raw.to_le_bytes()), *len as usize);
        }
        let compressor = builder.build();
        let built_symbols: &[u64] = unsafe { mem::transmute(compressor.symbol_table()) };
        assert_eq!(built_symbols, FIXTURE_SYMBOLS);
    }
}