fsst/lib.rs
1#![doc = include_str!("../README.md")]
2#![cfg(target_endian = "little")]
3
/// Compile-time size check: the build fails unless `$typ` occupies exactly
/// `$size_in_bytes` bytes.
macro_rules! assert_sizeof {
    ($typ:ty => $size_in_bytes:expr) => {
        // Evaluated during constant evaluation, so a mismatch is a compile error
        // rather than a runtime panic.
        const _: () = assert!(::std::mem::size_of::<$typ>() == $size_in_bytes);
    };
}
10
11use lossy_pht::LossyPHT;
12use std::fmt::{Debug, Formatter};
13use std::mem::MaybeUninit;
14
15mod builder;
16mod lossy_pht;
17
18pub use builder::*;
19
/// A short run of plaintext bytes (at most 8) packed into a single `u64`.
///
/// `Symbol`s live in a [`Compressor`][`crate::Compressor`], where each one is
/// identified by an 8-bit code.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub struct Symbol(u64);
24
25assert_sizeof!(Symbol => 8);
26
27impl Symbol {
28 /// Zero value for `Symbol`.
29 pub const ZERO: Self = Self::zero();
30
31 /// Constructor for a `Symbol` from an 8-element byte slice.
32 pub fn from_slice(slice: &[u8; 8]) -> Self {
33 let num: u64 = u64::from_le_bytes(*slice);
34
35 Self(num)
36 }
37
38 /// Return a zero symbol
39 const fn zero() -> Self {
40 Self(0)
41 }
42
43 /// Create a new single-byte symbol
44 pub fn from_u8(value: u8) -> Self {
45 Self(value as u64)
46 }
47}
48
49impl Symbol {
50 /// Calculate the length of the symbol in bytes. Always a value between 1 and 8.
51 ///
52 /// Each symbol has the capacity to hold up to 8 bytes of data, but the symbols
53 /// can contain fewer bytes, padded with 0x00. There is a special case of a symbol
54 /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000`
55 /// but we want to interpret that as a one-byte symbol containing `0x00`.
56 #[allow(clippy::len_without_is_empty)]
57 pub fn len(self) -> usize {
58 let numeric = self.0;
59 // For little-endian platforms, this counts the number of *trailing* zeros
60 let null_bytes = (numeric.leading_zeros() >> 3) as usize;
61
62 // Special case handling of a symbol with all-zeros. This is actually
63 // a 1-byte symbol containing 0x00.
64 let len = size_of::<Self>() - null_bytes;
65 if len == 0 { 1 } else { len }
66 }
67
68 /// Returns the Symbol's inner representation.
69 #[inline]
70 pub fn to_u64(self) -> u64 {
71 self.0
72 }
73
74 /// Get the first byte of the symbol as a `u8`.
75 ///
76 /// If the symbol is empty, this will return the zero byte.
77 #[inline]
78 pub fn first_byte(self) -> u8 {
79 self.0 as u8
80 }
81
82 /// Get the first two bytes of the symbol as a `u16`.
83 ///
84 /// If the Symbol is one or zero bytes, this will return `0u16`.
85 #[inline]
86 pub fn first2(self) -> u16 {
87 self.0 as u16
88 }
89
90 /// Get the first three bytes of the symbol as a `u64`.
91 ///
92 /// If the Symbol is one or zero bytes, this will return `0u64`.
93 #[inline]
94 pub fn first3(self) -> u64 {
95 self.0 & 0xFF_FF_FF
96 }
97
98 /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`.
99 pub fn concat(self, other: Self) -> Self {
100 assert!(
101 self.len() + other.len() <= 8,
102 "cannot build symbol with length > 8"
103 );
104
105 let self_len = self.len();
106
107 Self((other.0 << (8 * self_len)) | self.0)
108 }
109}
110
111impl Debug for Symbol {
112 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113 write!(f, "[")?;
114
115 let slice = &self.0.to_le_bytes()[0..self.len()];
116 for c in slice.iter().map(|c| *c as char) {
117 if ('!'..='~').contains(&c) {
118 write!(f, "{c}")?;
119 } else if c == '\n' {
120 write!(f, " \\n ")?;
121 } else if c == '\t' {
122 write!(f, " \\t ")?;
123 } else if c == ' ' {
124 write!(f, " SPACE ")?;
125 } else {
126 write!(f, " 0x{:X?} ", c as u8)?
127 }
128 }
129
130 write!(f, "]")
131 }
132}
133
/// A code value bit-packed into a `u16` together with metadata about the symbol it
/// refers to.
///
/// Logical codes range over 0-255 inclusive. The packing is:
///
/// * Bits 0-7: EITHER the code of a symbol stored in the table, OR a raw byte.
/// * Bit 8 (the 9th bit): selects the interpretation. Off means the low byte is a raw
///   byte, on means it is a code. Reading the low 9 bits together therefore gives an
///   extended code range in which 0-255 are raw bytes and 256-510 stand for codes
///   0-254; 511 is the placeholder for the invalid code.
/// * Bits 12-15: the length of the symbol (values ranging from 0-8).
#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct Code(u16);
150
/// Code that marks a byte which is not covered by the symbol table.
///
/// When part of a string cannot be expressed with the symbol table, the compressed
/// output contains an `ESCAPE` byte followed by the raw byte. During decompression,
/// seeing `ESCAPE` means the next byte is appended to the result as-is instead of
/// being looked up in the symbol table.
pub const ESCAPE_CODE: u8 = u8::MAX;

/// Number of bits of an extended code that carry the code value.
pub const FSST_CODE_BITS: usize = 9;

/// Position of the first bit of the "length" portion of an extended code.
pub const FSST_LEN_BITS: usize = 12;

/// Exclusive upper bound of the extended code range (`1 << FSST_CODE_BITS`, i.e. 512).
pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS;

/// Largest value in the extended code range; also the bit mask for its 9 bits.
///
/// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`].
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;

/// First extended code that corresponds to a non-escape symbol in the symbol table.
pub const FSST_CODE_BASE: u16 = 1 << 8;
175
176#[allow(clippy::len_without_is_empty)]
177impl Code {
178 /// Code for an unused slot in a symbol table or index.
179 ///
180 /// This corresponds to the maximum code with a length of 1.
181 pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << 12));
182
183 /// Create a new code for a symbol of given length.
184 fn new_symbol(code: u8, len: usize) -> Self {
185 Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
186 }
187
188 /// Code for a new symbol during the building phase.
189 ///
190 /// The code is remapped from 0..254 to 256...510.
191 fn new_symbol_building(code: u8, len: usize) -> Self {
192 Self(code as u16 + 256 + ((len as u16) << FSST_LEN_BITS))
193 }
194
195 /// Create a new code corresponding for an escaped byte.
196 fn new_escape(byte: u8) -> Self {
197 Self((byte as u16) + (1 << FSST_LEN_BITS))
198 }
199
200 #[inline]
201 fn code(self) -> u8 {
202 self.0 as u8
203 }
204
205 #[inline]
206 fn extended_code(self) -> u16 {
207 self.0 & 0b111_111_111
208 }
209
210 #[inline]
211 fn len(self) -> u16 {
212 self.0 >> FSST_LEN_BITS
213 }
214}
215
216impl Debug for Code {
217 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("TrainingCode")
219 .field("code", &(self.0 as u8))
220 .field("is_escape", &(self.0 < 256))
221 .field("len", &(self.0 >> 12))
222 .finish()
223 }
224}
225
226/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string.
227#[derive(Clone)]
228pub struct Decompressor<'a> {
229 /// Slice mapping codes to symbols.
230 pub(crate) symbols: &'a [Symbol],
231
232 /// Slice containing the length of each symbol in the `symbols` slice.
233 pub(crate) lengths: &'a [u8],
234}
235
236impl<'a> Decompressor<'a> {
237 /// Returns a new decompressor that uses the provided symbol table.
238 ///
239 /// # Panics
240 ///
241 /// If the provided symbol table has length greater than 256
242 pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
243 assert!(
244 symbols.len() < FSST_CODE_BASE as usize,
245 "symbol table cannot have size exceeding 255"
246 );
247
248 Self { symbols, lengths }
249 }
250
251 /// Returns an upper bound on the size of the decompressed data.
252 pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
253 size_of::<Symbol>() * (compressed.len() + 1)
254 }
255
256 /// Decompress a slice of codes into a provided buffer.
257 ///
258 /// The provided `decoded` buffer must be at least the size of the decoded data, plus
259 /// an additional 7 bytes.
260 ///
261 /// ## Panics
262 ///
263 /// If the caller fails to provide sufficient capacity in the decoded buffer. An upper bound
264 /// on the required capacity can be obtained by calling [`Self::max_decompression_capacity`].
265 ///
266 /// ## Example
267 ///
268 /// ```
269 /// use fsst::{Symbol, Compressor, CompressorBuilder};
270 /// let compressor = {
271 /// let mut builder = CompressorBuilder::new();
272 /// builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', b'o', b'o', b'o']), 8);
273 /// builder.build()
274 /// };
275 ///
276 /// let decompressor = compressor.decompressor();
277 ///
278 /// let mut decompressed = Vec::with_capacity(8 + 7);
279 ///
280 /// let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
281 /// assert_eq!(len, 8);
282 /// unsafe { decompressed.set_len(len) };
283 /// assert_eq!(&decompressed, "helloooo".as_bytes());
284 /// ```
285 pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
286 // Ensure the target buffer is at least half the size of the input buffer.
287 // This is the theortical smallest a valid target can be, and occurs when
288 // every input code is an escape.
289 assert!(
290 decoded.len() >= compressed.len() / 2,
291 "decoded is smaller than lower-bound decompressed size"
292 );
293
294 unsafe {
295 let mut in_ptr = compressed.as_ptr();
296 let _in_begin = in_ptr;
297 let in_end = in_ptr.add(compressed.len());
298
299 let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
300 let out_begin = out_ptr.cast_const();
301 let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();
302
303 macro_rules! store_next_symbol {
304 ($code:expr) => {{
305 out_ptr
306 .cast::<u64>()
307 .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
308 out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
309 }};
310 }
311
312 // First we try loading 8 bytes at a time.
313 if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
314 // Extract the loop condition since the compiler fails to do so
315 let block_out_end = out_end.sub(8 * size_of::<Symbol>());
316 let block_in_end = in_end.sub(8);
317
318 while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
319 // Note that we load a little-endian u64 here.
320 let next_block = in_ptr.cast::<u64>().read_unaligned();
321 let escape_mask = (next_block & 0x8080808080808080)
322 & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
323 ^ 0x8080808080808080);
324
325 // If there are no escape codes, we write each symbol one by one.
326 if escape_mask == 0 {
327 let code = (next_block & 0xFF) as u8;
328 store_next_symbol!(code);
329 let code = ((next_block >> 8) & 0xFF) as u8;
330 store_next_symbol!(code);
331 let code = ((next_block >> 16) & 0xFF) as u8;
332 store_next_symbol!(code);
333 let code = ((next_block >> 24) & 0xFF) as u8;
334 store_next_symbol!(code);
335 let code = ((next_block >> 32) & 0xFF) as u8;
336 store_next_symbol!(code);
337 let code = ((next_block >> 40) & 0xFF) as u8;
338 store_next_symbol!(code);
339 let code = ((next_block >> 48) & 0xFF) as u8;
340 store_next_symbol!(code);
341 let code = ((next_block >> 56) & 0xFF) as u8;
342 store_next_symbol!(code);
343 in_ptr = in_ptr.add(8);
344 } else {
345 // Otherwise, find the first escape code and write the symbols up to that point.
346 let first_escape_pos = escape_mask.trailing_zeros() >> 3; // Divide bits to bytes
347 debug_assert!(first_escape_pos < 8);
348 match first_escape_pos {
349 7 => {
350 let code = (next_block & 0xFF) as u8;
351 store_next_symbol!(code);
352 let code = ((next_block >> 8) & 0xFF) as u8;
353 store_next_symbol!(code);
354 let code = ((next_block >> 16) & 0xFF) as u8;
355 store_next_symbol!(code);
356 let code = ((next_block >> 24) & 0xFF) as u8;
357 store_next_symbol!(code);
358 let code = ((next_block >> 32) & 0xFF) as u8;
359 store_next_symbol!(code);
360 let code = ((next_block >> 40) & 0xFF) as u8;
361 store_next_symbol!(code);
362 let code = ((next_block >> 48) & 0xFF) as u8;
363 store_next_symbol!(code);
364
365 in_ptr = in_ptr.add(7);
366 }
367 6 => {
368 let code = (next_block & 0xFF) as u8;
369 store_next_symbol!(code);
370 let code = ((next_block >> 8) & 0xFF) as u8;
371 store_next_symbol!(code);
372 let code = ((next_block >> 16) & 0xFF) as u8;
373 store_next_symbol!(code);
374 let code = ((next_block >> 24) & 0xFF) as u8;
375 store_next_symbol!(code);
376 let code = ((next_block >> 32) & 0xFF) as u8;
377 store_next_symbol!(code);
378 let code = ((next_block >> 40) & 0xFF) as u8;
379 store_next_symbol!(code);
380
381 let escaped = ((next_block >> 56) & 0xFF) as u8;
382 out_ptr.write(escaped);
383 out_ptr = out_ptr.add(1);
384
385 in_ptr = in_ptr.add(8);
386 }
387 5 => {
388 let code = (next_block & 0xFF) as u8;
389 store_next_symbol!(code);
390 let code = ((next_block >> 8) & 0xFF) as u8;
391 store_next_symbol!(code);
392 let code = ((next_block >> 16) & 0xFF) as u8;
393 store_next_symbol!(code);
394 let code = ((next_block >> 24) & 0xFF) as u8;
395 store_next_symbol!(code);
396 let code = ((next_block >> 32) & 0xFF) as u8;
397 store_next_symbol!(code);
398
399 let escaped = ((next_block >> 48) & 0xFF) as u8;
400 out_ptr.write(escaped);
401 out_ptr = out_ptr.add(1);
402
403 in_ptr = in_ptr.add(7);
404 }
405 4 => {
406 let code = (next_block & 0xFF) as u8;
407 store_next_symbol!(code);
408 let code = ((next_block >> 8) & 0xFF) as u8;
409 store_next_symbol!(code);
410 let code = ((next_block >> 16) & 0xFF) as u8;
411 store_next_symbol!(code);
412 let code = ((next_block >> 24) & 0xFF) as u8;
413 store_next_symbol!(code);
414
415 let escaped = ((next_block >> 40) & 0xFF) as u8;
416 out_ptr.write(escaped);
417 out_ptr = out_ptr.add(1);
418
419 in_ptr = in_ptr.add(6);
420 }
421 3 => {
422 let code = (next_block & 0xFF) as u8;
423 store_next_symbol!(code);
424 let code = ((next_block >> 8) & 0xFF) as u8;
425 store_next_symbol!(code);
426 let code = ((next_block >> 16) & 0xFF) as u8;
427 store_next_symbol!(code);
428
429 let escaped = ((next_block >> 32) & 0xFF) as u8;
430 out_ptr.write(escaped);
431 out_ptr = out_ptr.add(1);
432
433 in_ptr = in_ptr.add(5);
434 }
435 2 => {
436 let code = (next_block & 0xFF) as u8;
437 store_next_symbol!(code);
438 let code = ((next_block >> 8) & 0xFF) as u8;
439 store_next_symbol!(code);
440
441 let escaped = ((next_block >> 24) & 0xFF) as u8;
442 out_ptr.write(escaped);
443 out_ptr = out_ptr.add(1);
444
445 in_ptr = in_ptr.add(4);
446 }
447 1 => {
448 let code = (next_block & 0xFF) as u8;
449 store_next_symbol!(code);
450
451 let escaped = ((next_block >> 16) & 0xFF) as u8;
452 out_ptr.write(escaped);
453 out_ptr = out_ptr.add(1);
454
455 in_ptr = in_ptr.add(3);
456 }
457 0 => {
458 // Otherwise, we actually need to decompress the next byte
459 // Extract the second byte from the u32
460 let escaped = ((next_block >> 8) & 0xFF) as u8;
461 in_ptr = in_ptr.add(2);
462 out_ptr.write(escaped);
463 out_ptr = out_ptr.add(1);
464 }
465 _ => unreachable!(),
466 }
467 }
468 }
469 }
470
471 // Otherwise, fall back to 1-byte reads using 8-byte writes where safe.
472 while out_end.offset_from(out_ptr) >= size_of::<Symbol>() as isize && in_ptr < in_end {
473 let code = in_ptr.read();
474 in_ptr = in_ptr.add(1);
475
476 if code == ESCAPE_CODE {
477 assert!(
478 in_ptr < in_end,
479 "truncated compressed string: escape code at end of input"
480 );
481 out_ptr.write(in_ptr.read());
482 in_ptr = in_ptr.add(1);
483 out_ptr = out_ptr.add(1);
484 } else {
485 store_next_symbol!(code);
486 }
487 }
488
489 // For the last few bytes (if any) where we can't do an 8-byte unaligned write.
490 while in_ptr < in_end {
491 let code = in_ptr.read();
492 in_ptr = in_ptr.add(1);
493
494 if code == ESCAPE_CODE {
495 assert!(
496 in_ptr < in_end,
497 "truncated compressed string: escape code at end of input"
498 );
499 assert!(
500 out_ptr.cast_const() < out_end,
501 "output buffer sized too small"
502 );
503 out_ptr.write(in_ptr.read());
504 in_ptr = in_ptr.add(1);
505 out_ptr = out_ptr.add(1);
506 } else {
507 let len = *self.lengths.get_unchecked(code as usize) as usize;
508 assert!(
509 out_end.offset_from(out_ptr) >= len as isize,
510 "output buffer sized too small"
511 );
512 let sym = self.symbols.get_unchecked(code as usize).to_u64();
513 let sym_bytes = sym.to_le_bytes();
514 std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len);
515 out_ptr = out_ptr.add(len);
516 }
517 }
518
519 assert_eq!(
520 in_ptr, in_end,
521 "decompression should exhaust input before output"
522 );
523
524 out_ptr.offset_from(out_begin) as usize
525 }
526 }
527
528 /// Decompress a byte slice that was previously returned by a compressor using the same symbol
529 /// table into a new vector of bytes.
530 pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
531 let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
532
533 let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
534 // SAFETY: len bytes have now been initialized by the decompressor.
535 unsafe { decoded.set_len(len) };
536 decoded
537 }
538}
539
540/// A compressor that uses a symbol table to greedily compress strings.
541///
542/// The `Compressor` is the central component of FSST. You can create a compressor either by
543/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text.
544///
545/// Example usage:
546///
547/// ```
548/// use fsst::{Symbol, Compressor, CompressorBuilder};
549/// let compressor = {
550/// let mut builder = CompressorBuilder::new();
551/// builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5);
552/// builder.build()
553/// };
554///
555/// let compressed = compressor.compress("hello".as_bytes());
556/// assert_eq!(compressed, vec![0u8]);
557/// ```
558#[derive(Clone)]
559pub struct Compressor {
560 /// Table mapping codes to symbols.
561 pub(crate) symbols: Vec<Symbol>,
562
563 /// Length of each symbol, values range from 1-8.
564 pub(crate) lengths: Vec<u8>,
565
566 /// The number of entries in the symbol table that have been populated, not counting
567 /// the escape values.
568 pub(crate) n_symbols: u8,
569
570 /// Inverted index mapping 2-byte symbols to codes
571 codes_two_byte: Vec<Code>,
572
573 /// Limit of no suffixes.
574 has_suffix_code: u8,
575
576 /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more
577 lossy_pht: LossyPHT,
578}
579
580/// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s.
581///
582/// The symbol table is trained on a corpus of data in the form of a single byte array, building up
583/// a mapping of 1-byte "codes" to sequences of up to 8 plaintext bytes, or "symbols".
584impl Compressor {
585 /// Using the symbol table, runs a single cycle of compression on an input word, writing
586 /// the output into `out_ptr`.
587 ///
588 /// # Returns
589 ///
590 /// This function returns a tuple of (advance_in, advance_out) with the number of bytes
591 /// for the caller to advance the input and output pointers.
592 ///
593 /// `advance_in` is the number of bytes to advance the input pointer before the next call.
594 ///
595 /// `advance_out` is the number of bytes to advance `out_ptr` before the next call.
596 ///
597 /// # Safety
598 ///
599 /// `out_ptr` must never be NULL or otherwise point to invalid memory.
600 pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
601 // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
602 // if it isn't, it will be overwritten anyway.
603 //
604 // SAFETY: caller ensures out_ptr is not null
605 let first_byte = word as u8;
606 // SAFETY: out_ptr is not null
607 unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
608
609 // First, check the two_bytes table
610 let code_twobyte = self.codes_two_byte[word as u16 as usize];
611
612 if code_twobyte.code() < self.has_suffix_code {
613 // 2 byte code without having to worry about longer matches.
614 // SAFETY: out_ptr is not null.
615 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
616
617 // Advance input by symbol length (2) and output by a single code byte
618 (2, 1)
619 } else {
620 // Probe the hash table
621 let entry = self.lossy_pht.lookup(word);
622
623 // Now, downshift the `word` and the `entry` to see if they align.
624 let ignored_bits = entry.ignored_bits;
625 if entry.code != Code::UNUSED
626 && compare_masked(word, entry.symbol.to_u64(), ignored_bits)
627 {
628 // Advance the input by the symbol length (variable) and the output by one code byte
629 // SAFETY: out_ptr is not null.
630 unsafe { std::ptr::write(out_ptr, entry.code.code()) };
631 (entry.code.len() as usize, 1)
632 } else {
633 // SAFETY: out_ptr is not null
634 unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
635
636 // Advance the input by the symbol length (variable) and the output by either 1
637 // byte (if was one-byte code) or two bytes (escape).
638 (
639 code_twobyte.len() as usize,
640 // Predicated version of:
641 //
642 // if entry.code >= 256 {
643 // 2
644 // } else {
645 // 1
646 // }
647 1 + (code_twobyte.extended_code() >> 8) as usize,
648 )
649 }
650 }
651 }
652
653 /// Compress many lines in bulk.
654 pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> {
655 let mut res = Vec::new();
656
657 for line in lines {
658 res.push(self.compress(line));
659 }
660
661 res
662 }
663
664 /// Compress a string, writing its result into a target buffer.
665 ///
666 /// The target buffer is a byte vector that must have capacity large enough
667 /// to hold the encoded data.
668 ///
669 /// When this call returns, `values` will hold the compressed bytes and have
670 /// its length set to the length of the compressed text.
671 ///
672 /// ```
673 /// use fsst::{Compressor, CompressorBuilder, Symbol};
674 ///
675 /// let mut compressor = CompressorBuilder::new();
676 /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8));
677 ///
678 /// let compressor = compressor.build();
679 ///
680 /// let mut compressed_values = Vec::with_capacity(1_024);
681 ///
682 /// // SAFETY: we have over-sized compressed_values.
683 /// unsafe {
684 /// compressor.compress_into(b"aaaaaaaa", &mut compressed_values);
685 /// }
686 ///
687 /// assert_eq!(compressed_values, vec![0u8]);
688 /// ```
689 ///
690 /// # Safety
691 ///
692 /// It is up to the caller to ensure the provided buffer is large enough to hold
693 /// all encoded data.
694 pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
695 let mut in_ptr = plaintext.as_ptr();
696 let mut out_ptr = values.as_mut_ptr();
697
698 // SAFETY: `end` will point just after the end of the `plaintext` slice.
699 let in_end = unsafe { in_ptr.byte_add(plaintext.len()) };
700 let in_end_sub8 = in_end as usize - 8;
701 // SAFETY: `end` will point just after the end of the `values` allocation.
702 let out_end = unsafe { out_ptr.byte_add(values.capacity()) };
703
704 while (in_ptr as usize) <= in_end_sub8 && unsafe { out_end.offset_from(out_ptr) } >= 2 {
705 // SAFETY: pointer ranges are checked in the loop condition
706 unsafe {
707 // Load a full 8-byte word of data from in_ptr.
708 // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
709 let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
710 let (advance_in, advance_out) = self.compress_word(word, out_ptr);
711 in_ptr = in_ptr.byte_add(advance_in);
712 out_ptr = out_ptr.byte_add(advance_out);
713 };
714 }
715
716 let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
717 assert!(
718 out_ptr < out_end || remaining_bytes == 0,
719 "output buffer sized too small"
720 );
721
722 let remaining_bytes = remaining_bytes as usize;
723
724 // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
725 // but shift data out of this word rather than advancing an input pointer and potentially reading
726 // unowned memory.
727 let mut bytes = [0u8; 8];
728 // SAFETY: remaining_bytes <= 8
729 unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
730 let mut last_word = u64::from_le_bytes(bytes);
731
732 while in_ptr < in_end && unsafe { out_end.offset_from(out_ptr) } >= 2 {
733 // Load a full 8-byte word of data from in_ptr.
734 // SAFETY: caller asserts in_ptr is not null
735 let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
736 // SAFETY: pointer ranges are checked in the loop condition
737 unsafe {
738 in_ptr = in_ptr.add(advance_in);
739 out_ptr = out_ptr.add(advance_out);
740 }
741
742 last_word = advance_8byte_word(last_word, advance_in);
743 }
744
745 // in_ptr should have exceeded in_end
746 assert!(
747 in_ptr >= in_end,
748 "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
749 );
750
751 assert!(out_ptr <= out_end, "output buffer sized too small");
752
753 // SAFETY: out_ptr is derived from the `values` allocation.
754 let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
755 assert!(
756 bytes_written >= 0,
757 "out_ptr ended before it started, not possible"
758 );
759
760 // SAFETY: we have initialized `bytes_written` values in the output buffer.
761 unsafe { values.set_len(bytes_written as usize) };
762 }
763
764 /// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
765 pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
766 if plaintext.is_empty() {
767 return Vec::new();
768 }
769
770 let mut buffer = Vec::with_capacity(plaintext.len() * 2);
771
772 // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len
773 unsafe { self.compress_into(plaintext, &mut buffer) };
774
775 buffer
776 }
777
778 /// Access the decompressor that can be used to decompress strings emitted from this
779 /// `Compressor` instance.
780 pub fn decompressor(&self) -> Decompressor<'_> {
781 Decompressor::new(self.symbol_table(), self.symbol_lengths())
782 }
783
784 /// Returns a readonly slice of the current symbol table.
785 ///
786 /// The returned slice will have length of `n_symbols`.
787 pub fn symbol_table(&self) -> &[Symbol] {
788 &self.symbols[0..self.n_symbols as usize]
789 }
790
791 /// Returns a readonly slice where index `i` contains the
792 /// length of the symbol represented by code `i`.
793 ///
794 /// Values range from 1-8.
795 pub fn symbol_lengths(&self) -> &[u8] {
796 &self.lengths[0..self.n_symbols as usize]
797 }
798
799 /// Rebuild a compressor from an existing symbol table.
800 ///
801 /// This will not attempt to optimize or re-order the codes.
802 pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
803 let symbols = symbols.as_ref();
804 let symbol_lens = symbol_lens.as_ref();
805
806 assert_eq!(
807 symbols.len(),
808 symbol_lens.len(),
809 "symbols and lengths differ"
810 );
811 assert!(
812 symbols.len() <= 255,
813 "symbol table len must be <= 255, was {}",
814 symbols.len()
815 );
816 validate_symbol_order(symbol_lens);
817
818 // Insert the symbols in their given order into the FSST lookup structures.
819 let symbols = symbols.to_vec();
820 let lengths = symbol_lens.to_vec();
821 let mut lossy_pht = LossyPHT::new();
822
823 let mut codes_one_byte = vec![Code::UNUSED; 256];
824
825 // Insert all of the one byte symbols first.
826 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
827 if len == 1 {
828 codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code as u8, 1);
829 }
830 }
831
832 // Initialize the codes_two_byte table to be all escapes
833 let mut codes_two_byte = vec![Code::UNUSED; 65_536];
834
835 // Insert the two byte symbols, possibly overwriting slots for one-byte symbols and escapes.
836 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
837 match len {
838 2 => {
839 codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code as u8, 2);
840 }
841 3.. => {
842 assert!(
843 lossy_pht.insert(symbol, len as usize, code as u8),
844 "rebuild symbol insertion into PHT must succeed"
845 );
846 }
847 _ => { /* Covered by the 1-byte loop above. */ }
848 }
849 }
850
851 // Build the finished codes_two_byte table, subbing in unused positions with the
852 // codes_one_byte value similar to what we do in CompressBuilder::finalize.
853 for (symbol, code) in codes_two_byte.iter_mut().enumerate() {
854 if *code == Code::UNUSED {
855 *code = codes_one_byte[symbol & 0xFF];
856 }
857 }
858
859 // Find the position of the first 2-byte code that has a suffix later in the table
860 let mut has_suffix_code = 0u8;
861 for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
862 if len != 2 {
863 break;
864 }
865 let rest = &symbols[code..];
866 if rest
867 .iter()
868 .any(|&other| other.len() > 2 && symbol.first2() == other.first2())
869 {
870 has_suffix_code = code as u8;
871 break;
872 }
873 }
874
875 Compressor {
876 n_symbols: symbols.len() as u8,
877 symbols,
878 lengths,
879 codes_two_byte,
880 lossy_pht,
881 has_suffix_code,
882 }
883 }
884}
885
/// Drop the first `bytes` bytes from a little-endian packed word.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    // Little endian means the first char is stored in the LSB, so consuming input
    // shifts the word towards the low end. Dropping all 8 bytes is handled
    // separately because shifting a u64 by 64 bits overflows.
    match bytes {
        8 => 0,
        _ => word >> (8 * bytes),
    }
}
895
/// Panics unless `symbol_lens` follows the FSST table order "23456781": multi-byte
/// symbols in non-decreasing length first, then all one-byte symbols.
fn validate_symbol_order(symbol_lens: &[u8]) {
    // Smallest length the next multi-byte symbol may have; becomes 1 once the
    // one-byte tail of the table has started.
    let mut expected = 2;
    for (idx, &len) in symbol_lens.iter().enumerate() {
        if expected == 1 {
            // Inside the one-byte tail nothing longer may follow.
            assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            );
            continue;
        }

        // Still in the multi-byte portion: either the one-byte tail starts here,
        // or the length must not shrink.
        assert!(
            len == 1 || len >= expected,
            "symbol code={idx} breaks violates FSST symbol table ordering"
        );
        expected = len;
    }
}
919
/// Returns whether `left`, with its top `ignored_bits` bits cleared, equals `right`.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    // Keep only the low `64 - ignored_bits` bits of `left`; `right` is compared as-is.
    let kept_bits = u64::MAX >> ignored_bits;
    right == (left & kept_bits)
}
925
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Symbol table shared by the ordering tests, as raw little-endian words:
    /// 16 two-byte symbols, then 61 three-byte symbols, then 8 one-byte symbols.
    const SYMBOLS_U64: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Lengths matching `SYMBOLS_U64` entry by entry.
    fn symbol_lengths() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// Convert a raw word into a `Symbol` without relying on the struct's layout.
    fn symbol_from_u64(word: u64) -> Symbol {
        Symbol::from_slice(&word.to_le_bytes())
    }

    /// Raw words of a compressor's symbol table, in code order.
    fn table_as_u64(compressor: &Compressor) -> Vec<u64> {
        compressor
            .symbol_table()
            .iter()
            .map(|symbol| symbol.to_u64())
            .collect()
    }

    #[test]
    fn test_stuff() {
        let compressor = {
            let mut builder = CompressorBuilder::new();
            builder.insert(Symbol::from_slice(b"helloooo"), 8);
            builder.build()
        };

        let decompressor = compressor.decompressor();

        let mut decompressed = Vec::with_capacity(8 + 7);

        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        // SAFETY: `len` bytes were initialized by `decompress_into`.
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    /// `rebuild_from` must keep the symbols in exactly the order they were given.
    #[test]
    fn test_symbols_good() {
        let symbols: Vec<Symbol> = SYMBOLS_U64.iter().copied().map(symbol_from_u64).collect();

        let compressor = Compressor::rebuild_from(&symbols, symbol_lengths());
        assert_eq!(table_as_u64(&compressor), SYMBOLS_U64);
    }

    /// Going through the builder does not preserve the insertion order, so the
    /// final comparison is expected to fail.
    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let lens = symbol_lengths();

        let mut builder = CompressorBuilder::new();
        for (&word, &len) in SYMBOLS_U64.iter().zip(lens.iter()) {
            builder.insert(symbol_from_u64(word), len as usize);
        }
        let compressor = builder.build();
        assert_eq!(table_as_u64(&compressor), SYMBOLS_U64);
    }
}