fsst/lib.rs
1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Fail compilation unless a type occupies exactly the given number of bytes.
///
/// The check is encoded in the type system: an array whose length is the type's real
/// size is assigned to a constant whose declared length is the expected size, so any
/// mismatch is reported by the compiler as a type error naming both lengths.
macro_rules! assert_sizeof {
    ($typ:ty => $size_in_bytes:expr) => {
        const _: [(); $size_in_bytes] = [(); ::std::mem::size_of::<$typ>()];
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and
/// identified by an 8-bit code.
///
/// The bytes of a symbol are packed into a single `u64` in little-endian order: the first
/// byte of the symbol lives in the least significant byte of the integer, and unused
/// positions are zero.
///
/// The type is `#[repr(transparent)]`, which guarantees it has exactly the size, alignment
/// and ABI of `u64`. Code that reinterprets between `u64` and `Symbol` storage relies on
/// that guarantee.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Symbol(u64);
24
25assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28 /// Zero value for `Symbol`.
29 pub const ZERO: Self = Self::zero();
30
31 /// Constructor for a `Symbol` from an 8-element byte slice.
32 pub fn from_slice(slice: &[u8; 8]) -> Self {
33 let num: u64 = u64::from_le_bytes(*slice);
34
35 Self(num)
36 }
37
38 /// Return a zero symbol
39 const fn zero() -> Self {
40 Self(0)
41 }
42
43 /// Create a new single-byte symbol
44 pub fn from_u8(value: u8) -> Self {
45 Self(value as u64)
46 }
47}
48
49impl Symbol {
50 /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51 ///
52 /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53 /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54 /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55 /// but we want to interpret that as a one-byte symbol containing `0x00`.
56 #[allow(clippy::len_without_is_empty)]
57 pub fn len(self) -> usize {
58 let numeric = self.0;
59 // For little-endian platforms, this counts the number of *trailing* zeros
60 let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62 // Special case handling of a symbol with all-zeros. This is actually
63 // a 1-byte symbol containing 0x00.
64 let len = size_of::<Self>() - null_bytes;
65 if len == 0 { 1 } else { len }
66 }
67
68 /// Returns the Symbol's inner representation.
69 #[inline]
70 pub fn to_u64(self) -> u64 {
71 self.0
72 }
73
74 /// Get the first byte of the symbol as a `u8`.
75 ///
76 /// If the symbol is empty, this will return the zero byte.
77 #[inline]
78 pub fn first_byte(self) -> u8 {
79 self.0 as u8
80 }
81
82 /// Get the first two bytes of the symbol as a `u16`.
83 ///
84 /// If the Symbol is one or zero bytes, this will return `0u16`.
85 #[inline]
86 pub fn first2(self) -> u16 {
87 self.0 as u16
88 }
89
90 /// Get the first three bytes of the symbol as a `u64`.
91 ///
92 /// If the Symbol is one or zero bytes, this will return `0u64`.
93 #[inline]
94 pub fn first3(self) -> u64 {
95 self.0 & 0xFF_FF_FF
96 }
97
98 /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
99 pub fn concat(self, other: Self) -> Self {
100 assert!(
101 self.len() + other.len() <= 8,
102 "cannot build symbol with length > 8"
103 );
104
105 let self_len = self.len();
106
107 Self((other.0 << (8 * self_len)) | self.0)
108 }
109}
110
111impl Debug for Symbol {
112 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113 write!(f, "[")?;
114
115 let slice = &self.0.to_le_bytes()[0..self.len()];
116 for c in slice.iter().map(|c| *c as char) {
117 if ('!'..='~').contains(&c) {
118 write!(f, "{c}")?;
119 } else if c == '\n' {
120 write!(f, " \\n ")?;
121 } else if c == '\t' {
122 write!(f, " \\t ")?;
123 } else if c == ' ' {
124 write!(f, " SPACE ")?;
125 } else {
126 write!(f, " 0x{:X?} ", c as u8)?
127 }
128 }
129
130 write!(f, "]")
131 }
132}
133
/// A code value bit-packed into a `u16` together with metadata about the symbol it
/// refers to.
///
/// Logically, codes range over 0-255 inclusive. The packed layout is:
///
/// * Bits 0-7 hold EITHER the code of a symbol stored in the table, OR a raw byte.
/// * Bit 8 (the 9th bit) selects the interpretation: cleared means the low byte is a raw
///   byte, set means it is a code. Reading the low 9 bits as one number therefore gives an
///   extended code range in which 0-255 are raw bytes and 256-510 stand for codes 0-254;
///   511 is the placeholder for the invalid code.
/// * Bits 12-15 hold the length of the symbol (values ranging from 0-8).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Code(u16);
150
/// Code that marks a byte which is not covered by the symbol table.
///
/// When part of a string cannot be expressed with the symbol table, the compressed output
/// contains an `ESCAPE` byte followed by the raw byte. On decompression, `ESCAPE` means the
/// byte after it is appended to the result verbatim instead of being looked up in the
/// symbol table.
pub const ESCAPE_CODE: u8 = u8::MAX;

/// Number of bits of an extended code that carry the code value.
pub const FSST_CODE_BITS: usize = 9;

/// Bit position at which the "length" portion of an extended code starts.
pub const FSST_LEN_BITS: usize = 12;

/// Exclusive upper bound of the extended code range (`2^FSST_CODE_BITS`).
pub const FSST_CODE_MAX: u16 = 1u16 << FSST_CODE_BITS;

/// Largest value in the extended code range; also the bit mask selecting the code bits.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First extended code that corresponds to a non-escape symbol.
pub const FSST_CODE_BASE: u16 = 0x100;
175
176#[allow(clippy::len_without_is_empty)]
177impl Code {
178 /// Code for an unused slot in a symbol table or index.
179 ///
180 /// This corresponds to the maximum code with a length of 1.
181 pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << 12));
182
183 /// Create a new code for a symbol of given length.
184 fn new_symbol(code: u8, len: usize) -> Self {
185 Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
186 }
187
188 /// Code for a new symbol during the building phase.
189 ///
190 /// The code is remapped from 0..254 to 256...510.
191 fn new_symbol_building(code: u8, len: usize) -> Self {
192 Self(code as u16 + 256 + ((len as u16) << FSST_LEN_BITS))
193 }
194
195 /// Create a new code corresponding for an escaped byte.
196 fn new_escape(byte: u8) -> Self {
197 Self((byte as u16) + (1 << FSST_LEN_BITS))
198 }
199
200 #[inline]
201 fn code(self) -> u8 {
202 self.0 as u8
203 }
204
205 #[inline]
206 fn extended_code(self) -> u16 {
207 self.0 & 0b111_111_111
208 }
209
210 #[inline]
211 fn len(self) -> u16 {
212 self.0 >> FSST_LEN_BITS
213 }
214}
215
216impl Debug for Code {
217 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("TrainingCode")
219 .field("code", &(self.0 as u8))
220 .field("is_escape", &(self.0 < 256))
221 .field("len", &(self.0 >> 12))
222 .finish()
223 }
224}
225
/// Decompressor uses a symbol table to turn a stream of 8-bit codes back into a string.
///
/// It borrows the table rather than owning it: `symbols` and `lengths` are parallel
/// slices indexed by code.
#[derive(Clone)]
pub struct Decompressor<'a> {
    /// Slice mapping codes to symbols: entry `i` is the symbol for code `i`.
    pub(crate) symbols: &'a [Symbol],

    /// Slice containing the length of each symbol in the `symbols` slice, at the same index.
    pub(crate) lengths: &'a [u8],
}
235
236impl<'a> Decompressor<'a> {
237 /// Returns a new decompressor that uses the provided symbol table.
238 ///
239 /// # Panics
240 ///
241 /// If the provided symbol table has length greater than 256
242 pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
243 assert!(
244 symbols.len() < FSST_CODE_BASE as usize,
245 "symbol table cannot have size exceeding 255"
246 );
247
248 Self { symbols, lengths }
249 }
250
251 /// Returns an upper bound on the size of the decompressed data.
252 pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
253 size_of::<Symbol>() * (compressed.len() + 1)
254 }
255
256 /// Decompress a slice of codes into a provided buffer.
257 ///
258 /// The provided `decoded` buffer must be at least the size of the decoded data, plus
259 /// an additional 7 bytes.
260 ///
261 /// ## Panics
262 ///
263 /// If the caller fails to provide sufficient capacity in the decoded buffer. An upper bound
264 /// on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
265 ///
266 /// ## Example
267 ///
268 /// ```
269 /// use fsst::{Symbol, Compressor, CompressorBuilder};
270 /// let compressor = {
271 /// let mut builder = CompressorBuilder::new();
272 /// builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
273 /// builder.build()
274 /// };
275 ///
276 /// let decompressor = compressor.decompressor();
277 ///
278 /// let mut decompressed = Vec::with_capacity(8 + 7);
279 ///
280 /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
281 /// assert_eq!(len, 8);
282 /// unsafe { decompressed.set_len(len) };
283 /// assert_eq!(&decompressed, "helloooo".as_bytes());
284 /// ```
285 pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
286 // Ensure the target buffer is at least half the size of the input buffer.
287 // This is the theortical smallest a valid target can be, and occurs when
288 // every input code is an escape.
289 assert!(
290 decoded.len() >= compressed.len() / 2,
291 "decoded is smaller than lower-bound decompressed size"
292 );
293
294 unsafe {
295 let mut in_ptr = compressed.as_ptr();
296 let _in_begin = in_ptr;
297 let in_end = in_ptr.add(compressed.len());
298
299 let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
300 let out_begin = out_ptr.cast_const();
301 let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();
302
303 macro_rules! store_next_symbol {
304 ($code:expr) => {{
305 out_ptr
306 .cast::<u64>()
307 .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
308 out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
309 }};
310 }
311
312 // First we try loading 8 bytes at a time.
313 if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
314 // Extract the loop condition since the compiler fails to do so
315 let block_out_end = out_end.sub(8 * size_of::<Symbol>());
316 let block_in_end = in_end.sub(8);
317
318 while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
319 // Note that we load a little-endian u64 here.
320 let next_block = in_ptr.cast::<u64>().read_unaligned();
321 let escape_mask = (next_block & 0x8080808080808080)
322 & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
323 ^ 0x8080808080808080);
324
325 // If there are no escape codes, we write each symbol one by one.
326 if escape_mask == 0 {
327 let code = (next_block & 0xFF) as u8;
328 store_next_symbol!(code);
329 let code = ((next_block >> 8) & 0xFF) as u8;
330 store_next_symbol!(code);
331 let code = ((next_block >> 16) & 0xFF) as u8;
332 store_next_symbol!(code);
333 let code = ((next_block >> 24) & 0xFF) as u8;
334 store_next_symbol!(code);
335 let code = ((next_block >> 32) & 0xFF) as u8;
336 store_next_symbol!(code);
337 let code = ((next_block >> 40) & 0xFF) as u8;
338 store_next_symbol!(code);
339 let code = ((next_block >> 48) & 0xFF) as u8;
340 store_next_symbol!(code);
341 let code = ((next_block >> 56) & 0xFF) as u8;
342 store_next_symbol!(code);
343 in_ptr = in_ptr.add(8);
344 } else {
345 // Otherwise, find the first escape code and write the symbols up to that point.
346 let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
347 debug_assert!(first_escape_pos < 8);
348 match first_escape_pos {
349 7 => {
350 let code = (next_block & 0xFF) as u8;
351 store_next_symbol!(code);
352 let code = ((next_block >> 8) & 0xFF) as u8;
353 store_next_symbol!(code);
354 let code = ((next_block >> 16) & 0xFF) as u8;
355 store_next_symbol!(code);
356 let code = ((next_block >> 24) & 0xFF) as u8;
357 store_next_symbol!(code);
358 let code = ((next_block >> 32) & 0xFF) as u8;
359 store_next_symbol!(code);
360 let code = ((next_block >> 40) & 0xFF) as u8;
361 store_next_symbol!(code);
362 let code = ((next_block >> 48) & 0xFF) as u8;
363 store_next_symbol!(code);
364
365 in_ptr = in_ptr.add(7);
366 }
367 6 => {
368 let code = (next_block & 0xFF) as u8;
369 store_next_symbol!(code);
370 let code = ((next_block >> 8) & 0xFF) as u8;
371 store_next_symbol!(code);
372 let code = ((next_block >> 16) & 0xFF) as u8;
373 store_next_symbol!(code);
374 let code = ((next_block >> 24) & 0xFF) as u8;
375 store_next_symbol!(code);
376 let code = ((next_block >> 32) & 0xFF) as u8;
377 store_next_symbol!(code);
378 let code = ((next_block >> 40) & 0xFF) as u8;
379 store_next_symbol!(code);
380
381 let escaped = ((next_block >> 56) & 0xFF) as u8;
382 out_ptr.write(escaped);
383 out_ptr = out_ptr.add(1);
384
385 in_ptr = in_ptr.add(8);
386 }
387 5 => {
388 let code = (next_block & 0xFF) as u8;
389 store_next_symbol!(code);
390 let code = ((next_block >> 8) & 0xFF) as u8;
391 store_next_symbol!(code);
392 let code = ((next_block >> 16) & 0xFF) as u8;
393 store_next_symbol!(code);
394 let code = ((next_block >> 24) & 0xFF) as u8;
395 store_next_symbol!(code);
396 let code = ((next_block >> 32) & 0xFF) as u8;
397 store_next_symbol!(code);
398
399 let escaped = ((next_block >> 48) & 0xFF) as u8;
400 out_ptr.write(escaped);
401 out_ptr = out_ptr.add(1);
402
403 in_ptr = in_ptr.add(7);
404 }
405 4 => {
406 let code = (next_block & 0xFF) as u8;
407 store_next_symbol!(code);
408 let code = ((next_block >> 8) & 0xFF) as u8;
409 store_next_symbol!(code);
410 let code = ((next_block >> 16) & 0xFF) as u8;
411 store_next_symbol!(code);
412 let code = ((next_block >> 24) & 0xFF) as u8;
413 store_next_symbol!(code);
414
415 let escaped = ((next_block >> 40) & 0xFF) as u8;
416 out_ptr.write(escaped);
417 out_ptr = out_ptr.add(1);
418
419 in_ptr = in_ptr.add(6);
420 }
421 3 => {
422 let code = (next_block & 0xFF) as u8;
423 store_next_symbol!(code);
424 let code = ((next_block >> 8) & 0xFF) as u8;
425 store_next_symbol!(code);
426 let code = ((next_block >> 16) & 0xFF) as u8;
427 store_next_symbol!(code);
428
429 let escaped = ((next_block >> 32) & 0xFF) as u8;
430 out_ptr.write(escaped);
431 out_ptr = out_ptr.add(1);
432
433 in_ptr = in_ptr.add(5);
434 }
435 2 => {
436 let code = (next_block & 0xFF) as u8;
437 store_next_symbol!(code);
438 let code = ((next_block >> 8) & 0xFF) as u8;
439 store_next_symbol!(code);
440
441 let escaped = ((next_block >> 24) & 0xFF) as u8;
442 out_ptr.write(escaped);
443 out_ptr = out_ptr.add(1);
444
445 in_ptr = in_ptr.add(4);
446 }
447 1 => {
448 let code = (next_block & 0xFF) as u8;
449 store_next_symbol!(code);
450
451 let escaped = ((next_block >> 16) & 0xFF) as u8;
452 out_ptr.write(escaped);
453 out_ptr = out_ptr.add(1);
454
455 in_ptr = in_ptr.add(3);
456 }
457 0 => {
458 // Otherwise, we actually need to decompress the next byte
459 // Extract the second byte from the u32
460 let escaped = ((next_block >> 8) & 0xFF) as u8;
461 in_ptr = in_ptr.add(2);
462 out_ptr.write(escaped);
463 out_ptr = out_ptr.add(1);
464 }
465 _ => unreachable!(),
466 }
467 }
468 }
469 }
470
471 // Otherwise, fall back to 1-byte reads.
472 while out_end.offset_from(out_ptr) > size_of::<Symbol>() as isize && in_ptr < in_end {
473 let code = in_ptr.read();
474 in_ptr = in_ptr.add(1);
475
476 if code == ESCAPE_CODE {
477 out_ptr.write(in_ptr.read());
478 in_ptr = in_ptr.add(1);
479 out_ptr = out_ptr.add(1);
480 } else {
481 store_next_symbol!(code);
482 }
483 }
484
485 assert_eq!(
486 in_ptr, in_end,
487 "decompression should exhaust input before output"
488 );
489
490 out_ptr.offset_from(out_begin) as usize
491 }
492 }
493
494 /// Decompress a byte slice that was previously returned by a compressor using the same symbol
495 /// table into a new vector of bytes.
496 pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
497 let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
498
499 let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
500 // SAFETY: len bytes have now been initialized by the decompressor.
501 unsafe { decoded.set_len(len) };
502 decoded
503 }
504}
505
/// A compressor that uses a symbol table to greedily compress strings.
///
/// The `Compressor` is the central component of FSST. You can create a compressor either by
/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
///
/// Example usage:
///
/// ```
/// use fsst::{Symbol, Compressor, CompressorBuilder};
/// let compressor = {
///     let mut builder = CompressorBuilder::new();
///     builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
///     builder.build()
/// };
///
/// let compressed = compressor.compress("hello".as_bytes());
/// assert_eq!(compressed, vec![0u8]);
/// ```
#[derive(Clone)]
pub struct Compressor {
    /// Table mapping codes to symbols.
    pub(crate) symbols: Vec<Symbol>,

    /// Length of each symbol, values range from 1-8.
    pub(crate) lengths: Vec<u8>,

    /// The number of entries in the symbol table that have been populated, not counting
    /// the escape values.
    pub(crate) n_symbols: u8,

    /// Inverted index mapping 2-byte symbols to codes.
    ///
    /// It is indexed by the first two bytes of the input word (read as a little-endian
    /// `u16`).
    codes_two_byte: Vec<Code>,

    /// Limit of no suffixes.
    ///
    /// A code found in `codes_two_byte` that is strictly below this value is emitted
    /// directly as a two-byte match, without probing `lossy_pht` for a longer symbol.
    has_suffix_code: u8,

    /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
    lossy_pht: LossyPHT,
}
545
546/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
547///
548/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
549/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
550impl Compressor {
551 /// Using the symbol table, runs a single cycle of compression on an input word, writing
552 /// the output into `out_ptr`.
553 ///
554 /// # Returns
555 ///
556 /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
557 /// for the caller to advance the input and output pointers.
558 ///
559 /// `advance_in` is the number of bytes to advance the input pointer before the next call.
560 ///
561 /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
562 ///
563 /// # Safety
564 ///
565 /// `out_ptr` must never be NULL or otherwise point to invalid memory.
566 pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
567 // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
568 // if it isn't, it will be overwritten anyway.
569 //
570 // SAFETY: caller ensures out_ptr is not null
571 let first_byte = word as u8;
572 // SAFETY: out_ptr is not null
573 unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
574
575 // First, check the two_bytes table
576 let code_twobyte = self.codes_two_byte[word as u16 as usize];
577
578 if code_twobyte.code() < self.has_suffix_code {
579 // 2 byte code without having to worry about longer matches.
580 // SAFETY: out_ptr is not null.
581 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
582
583 // Advance input by symbol length (2) and output by a single code byte
584 (2, 1)
585 } else {
586 // Probe the hash table
587 let entry = self.lossy_pht.lookup(word);
588
589 // Now, downshift the `word` and the `entry` to see if they align.
590 let ignored_bits = entry.ignored_bits;
591 if entry.code != Code::UNUSED
592 && compare_masked(word, entry.symbol.to_u64(), ignored_bits)
593 {
594 // Advance the input by the symbol length (variable) and the output by one code byte
595 // SAFETY: out_ptr is not null.
596 unsafe { std::ptr::write(out_ptr, entry.code.code()) };
597 (entry.code.len() as usize, 1)
598 } else {
599 // SAFETY: out_ptr is not null
600 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
601
602 // Advance the input by the symbol length (variable) and the output by either 1
603 // byte (if was one-byte code) or two bytes (escape).
604 (
605 code_twobyte.len() as usize,
606 // Predicated version of:
607 //
608 // if entry.code >= 256 {
609 // 2
610 // } else {
611 // 1
612 // }
613 1 + (code_twobyte.extended_code() >> 8) as usize,
614 )
615 }
616 }
617 }
618
619 /// Compress many lines in bulk.
620 pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
621 let mut res = Vec::new();
622
623 for line in lines {
624 res.push(self.compress(line));
625 }
626
627 res
628 }
629
630 /// Compress a string, writing its result into a target buffer.
631 ///
632 /// The target buffer is a byte vector that must have capacity large enough
633 /// to hold the encoded data.
634 ///
635 /// When this call returns, `values` will hold the compressed bytes and have
636 /// its length set to the length of the compressed text.
637 ///
638 /// ```
639 /// use fsst::{Compressor, CompressorBuilder, Symbol};
640 ///
641 /// let mut compressor = CompressorBuilder::new();
642 /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
643 ///
644 /// let compressor = compressor.build();
645 ///
646 /// let mut compressed_values = Vec::with_capacity(1_024);
647 ///
648 /// // SAFETY: we have over-sized compressed_values.
649 /// unsafe {
650 /// compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
651 /// }
652 ///
653 /// assert_eq!(compressed_values, vec![0u8]);
654 /// ```
655 ///
656 /// # Safety
657 ///
658 /// It is up to the caller to ensure the provided buffer is large enough to hold
659 /// all encoded data.
660 pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
661 let mut in_ptr = plaintext.as_ptr();
662 let mut out_ptr = values.as_mut_ptr();
663
664 // SAFETY: `end` will point just after the end of the `plaintext` slice.
665 let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
666 let in_end_sub8 = in_end as usize - 8;
667 // SAFETY: `end` will point just after the end of the `values` allocation.
668 let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
669
670 while (in_ptr as usize) <= in_end_sub8 && out_ptr < out_end {
671 // SAFETY: pointer ranges are checked in the loop condition
672 unsafe {
673 // Load a full 8-byte word of data from in_ptr.
674 // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
675 let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
676 let (advance_in, advance_out) = self.compress_word(word, out_ptr);
677 in_ptr = in_ptr.byte_add(advance_in);
678 out_ptr = out_ptr.byte_add(advance_out);
679 };
680 }
681
682 let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
683 assert!(
684 out_ptr < out_end || remaining_bytes == 0,
685 "output buffer sized too small"
686 );
687
688 let remaining_bytes = remaining_bytes as usize;
689
690 // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
691 // but shift data out of this word rather than advancing an input pointer and potentially reading
692 // unowned memory.
693 let mut bytes = [0u8; 8];
694 // SAFETY: remaining_bytes <= 8
695 unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
696 let mut last_word = u64::from_le_bytes(bytes);
697
698 while in_ptr < in_end && out_ptr < out_end {
699 // Load a full 8-byte word of data from in_ptr.
700 // SAFETY: caller asserts in_ptr is not null
701 let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
702 // SAFETY: pointer ranges are checked in the loop condition
703 unsafe {
704 in_ptr = in_ptr.add(advance_in);
705 out_ptr = out_ptr.add(advance_out);
706 }
707
708 last_word = advance_8byte_word(last_word, advance_in);
709 }
710
711 // in_ptr should have exceeded in_end
712 assert!(
713 in_ptr >= in_end,
714 "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
715 );
716
717 assert!(out_ptr <= out_end, "output buffer sized too small");
718
719 // SAFETY: out_ptr is derived from the `values` allocation.
720 let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
721 assert!(
722 bytes_written >= 0,
723 "out_ptr ended before it started, not possible"
724 );
725
726 // SAFETY: we have initialized `bytes_written` values in the output buffer.
727 unsafe { values.set_len(bytes_written as usize) };
728 }
729
730 /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
731 pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
732 if plaintext.is_empty() {
733 return Vec::new();
734 }
735
736 let mut buffer = Vec::with_capacity(plaintext.len() * 2);
737
738 // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
739 unsafe { self.compress_into(plaintext, &mut buffer) };
740
741 buffer
742 }
743
744 /// Access the decompressor that can be used to decompress strings emitted from this
745 /// `Compressor` instance.
746 pub fn decompressor(&self) -> Decompressor<'_> {
747 Decompressor::new(self.symbol_table(), self.symbol_lengths())
748 }
749
750 /// Returns a readonly slice of the current symbol table.
751 ///
752 /// The returned slice will have length of `n_symbols`.
753 pub fn symbol_table(&self) -> &[Symbol] {
754 &self.symbols[0..self.n_symbols as usize]
755 }
756
757 /// Returns a readonly slice where index `i` contains the
758 /// length of the symbol represented by code `i`.
759 ///
760 /// Values range from 1-8.
761 pub fn symbol_lengths(&self) -> &[u8] {
762 &self.lengths[0..self.n_symbols as usize]
763 }
764
765 /// Rebuild a compressor from an existing symbol table.
766 ///
767 /// This will not attempt to optimize or re-order the codes.
768 pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
769 let symbols = symbols.as_ref();
770 let symbol_lens = symbol_lens.as_ref();
771
772 assert_eq!(
773 symbols.len(),
774 symbol_lens.len(),
775 "symbols and lengths differ"
776 );
777 assert!(
778 symbols.len() <= 255,
779 "symbol table len must be <= 255, was {}",
780 symbols.len()
781 );
782 validate_symbol_order(symbol_lens);
783
784 // Insert the symbols in their given order into the FSST lookup structures.
785 let symbols = symbols.to_vec();
786 let lengths = symbol_lens.to_vec();
787 let mut lossy_pht = LossyPHT::new();
788
789 let mut codes_one_byte = vec![Code::UNUSED; 256];
790
791 // Insert all of the one byte symbols first.
792 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
793 if len == 1 {
794 codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
795 }
796 }
797
798 // Initialize the codes_two_byte table to be all escapes
799 let mut codes_two_byte = vec![Code::UNUSED; 65_536];
800
801 // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
802 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
803 match len {
804 2 => {
805 codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
806 }
807 3.. => {
808 assert!(
809 lossy_pht.insert(symbol, len as usize, code as u8),
810 "rebuild symbol insertion into PHT must succeed"
811 );
812 }
813 _ => { /* Covered by the 1-byte loop above. */ }
814 }
815 }
816
817 // Build the finished codes_two_byte table, subbing in unused positions with the
818 // codes_one_byte value similar to what we do in CompressBuilder::finalize.
819 for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
820 if *code == Code::UNUSED {
821 *code = codes_one_byte[symbol & 0xFF];
822 }
823 }
824
825 // Find the position of the first 2-byte code that has a suffix later in the table
826 let mut has_suffix_code = 0u8;
827 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
828 if len != 2 {
829 break;
830 }
831 let rest = &symbols[code..];
832 if rest
833 .iter()
834 .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
835 {
836 has_suffix_code = code as u8;
837 break;
838 }
839 }
840
841 Compressor {
842 n_symbols: symbols.len() as u8,
843 symbols,
844 lengths,
845 codes_two_byte,
846 lossy_pht,
847 has_suffix_code,
848 }
849 }
850}
851
/// Drop the first `bytes` bytes from a little-endian packed word.
///
/// Little endian means the first character is stored in the least significant byte, so
/// consumed bytes are shifted off the low end and zeros are shifted in at the top.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    match bytes {
        // A `u64` cannot be shifted by its full width, so consuming all eight bytes is
        // answered directly.
        8 => 0,
        consumed => word >> (consumed * 8),
    }
}
861
/// Panics unless the symbol lengths follow the FSST table order "23456781".
///
/// That order is: symbols of two or more bytes first, in non-decreasing length, followed
/// by all one-byte symbols. Once the first one-byte symbol has been seen, every later
/// symbol must be one byte as well.
fn validate_symbol_order(symbol_lens: &[u8]) {
    let mut in_one_byte_tail = false;
    // Smallest length the next multi-byte symbol may have.
    let mut min_len = 2u8;

    for (idx, &len) in symbol_lens.iter().enumerate() {
        if in_one_byte_tail {
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
            continue;
        }

        if len == 1 {
            // The first one-byte symbol starts the tail of the table.
            in_one_byte_tail = true;
            continue;
        }

        // we're in the non-zero portion.
        assert!(
            len >= min_len,
            "symbol code={idx} breaks violates FSST symbol table ordering"
        );
        min_len = len;
    }
}
885
/// Returns whether `left`, with its `ignored_bits` most significant bits cleared,
/// equals `right`.
///
/// `right` is compared as-is, so it can only match if its own ignored bits are zero.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    let kept_bits = u64::MAX >> ignored_bits;
    right == (left & kept_bits)
}
891
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Packed symbols of a table in FSST order: 16 two-byte symbols, then 61 three-byte
    /// symbols, then 8 one-byte symbols.
    const SYMBOLS_U64: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Lengths matching [`SYMBOLS_U64`] entry for entry.
    fn symbol_lens() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// Convert a packed value into a `Symbol` without relying on the type's layout.
    fn symbol_from_u64(value: u64) -> Symbol {
        Symbol::from_slice(&value.to_le_bytes())
    }

    /// Read a compressor's symbol table back as packed values.
    fn packed_symbol_table(compressor: &Compressor) -> Vec<u64> {
        compressor
            .symbol_table()
            .iter()
            .map(|symbol| symbol.to_u64())
            .collect()
    }

    #[test]
    fn test_stuff() {
        let compressor = {
            let mut builder = CompressorBuilder::new();
            builder.insert(Symbol::from_slice(b"helloooo"), 8);
            builder.build()
        };

        let decompressor = compressor.decompressor();

        let mut decompressed = Vec::with_capacity(8 + 7);

        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        // SAFETY: decompress_into reported that the first `len` bytes were written.
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    /// Rebuilding from a table must keep the symbols exactly as given.
    #[test]
    fn test_symbols_good() {
        let symbols: Vec<Symbol> = SYMBOLS_U64.iter().copied().map(symbol_from_u64).collect();

        let compressor = Compressor::rebuild_from(&symbols, symbol_lens());
        assert_eq!(packed_symbol_table(&compressor), SYMBOLS_U64);
    }

    /// Inserting the same symbols through the builder does not reproduce the given
    /// table, so the final comparison is expected to fail.
    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let mut builder = CompressorBuilder::new();
        for (&value, len) in SYMBOLS_U64.iter().zip(symbol_lens()) {
            builder.insert(symbol_from_u64(value), usize::from(len));
        }
        let compressor = builder.build();
        assert_eq!(packed_symbol_table(&compressor), SYMBOLS_U64);
    }
}