#![doc = include_str!("../README.md")]
#![cfg(target_endian = "little")]
/// Compile-time check that a type occupies exactly the given number of bytes.
///
/// Expands to an anonymous `const` whose evaluation fails the build when
/// `size_of::<$typ>()` differs from `$size_in_bytes`.
macro_rules! assert_sizeof {
    ($typ:ty => $size_in_bytes:expr) => {
        const _: () = assert!(
            ::std::mem::size_of::<$typ>() == $size_in_bytes,
            concat!("unexpected size for ", stringify!($typ))
        );
    };
}
use lossy_pht::LossyPHT;
use std::fmt::{Debug, Formatter};
use std::mem::MaybeUninit;
mod builder;
mod lossy_pht;
pub use builder::*;
/// Up to eight bytes packed little-endian into a single `u64`.
///
/// Byte 0 of the symbol is the least-significant byte of the integer (see
/// `Symbol::from_slice` and `Symbol::first_byte`). `#[repr(transparent)]` guarantees
/// that `Symbol` has exactly the layout of `u64`. Code that reinterprets `&[u64]` as
/// `&[Symbol]` (or the reverse) relies on this; the default representation gives no
/// such guarantee.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Symbol(u64);
// Compile-time guard: `Symbol` must stay exactly 8 bytes (one packed `u64`); any change
// to its representation that alters the size fails the build here.
assert_sizeof!(Symbol => 8);
impl Symbol {
    /// The all-zero symbol.
    pub const ZERO: Self = Self::zero();

    /// Builds a symbol from eight bytes; `slice[0]` becomes the least-significant byte.
    pub fn from_slice(slice: &[u8; 8]) -> Self {
        Self(u64::from_le_bytes(*slice))
    }

    /// `const` constructor backing [`Symbol::ZERO`].
    const fn zero() -> Self {
        Self(0)
    }

    /// Builds a symbol whose lowest byte is `value` and all other bytes are zero.
    pub fn from_u8(value: u8) -> Self {
        Self(u64::from(value))
    }
}
impl Symbol {
    /// Number of significant bytes in the symbol (1..=8).
    ///
    /// High-order zero bytes are treated as padding. The all-zero symbol still reports
    /// length 1 (a single `\0` byte).
    #[allow(clippy::len_without_is_empty)]
    pub fn len(self) -> usize {
        let padding = self.0.leading_zeros() as usize / 8;
        (size_of::<Self>() - padding).max(1)
    }

    /// The raw packed little-endian value.
    #[inline]
    pub fn to_u64(self) -> u64 {
        self.0
    }

    /// The first (lowest) byte of the symbol.
    #[inline]
    pub fn first_byte(self) -> u8 {
        (self.0 & 0xFF) as u8
    }

    /// The first two bytes as a little-endian `u16`.
    #[inline]
    pub fn first2(self) -> u16 {
        (self.0 & 0xFFFF) as u16
    }

    /// The first three bytes, kept in the low 24 bits of a `u64`.
    #[inline]
    pub fn first3(self) -> u64 {
        self.0 & 0x00FF_FFFF
    }

    /// Appends `other`'s bytes after this symbol's bytes.
    ///
    /// # Panics
    ///
    /// Panics if the combined length would exceed 8 bytes.
    pub fn concat(self, other: Self) -> Self {
        let prefix_len = self.len();
        assert!(
            prefix_len + other.len() <= 8,
            "cannot build symbol with length > 8"
        );
        let shifted = other.0 << (8 * prefix_len);
        Self(self.0 | shifted)
    }
}
impl Debug for Symbol {
    /// Renders the significant bytes inside brackets.
    ///
    /// Printable ASCII is shown verbatim. Newline, tab and space get readable labels,
    /// and every other byte is shown as uppercase hex.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let bytes = self.0.to_le_bytes();
        f.write_str("[")?;
        for &byte in &bytes[..self.len()] {
            match byte as char {
                printable @ '!'..='~' => write!(f, "{printable}")?,
                '\n' => f.write_str(" \\n ")?,
                '\t' => f.write_str(" \\t ")?,
                ' ' => f.write_str(" SPACE ")?,
                _ => write!(f, " 0x{:X?} ", byte)?,
            }
        }
        f.write_str("]")
    }
}
/// A packed code: the low 9 bits hold the (extended) code (see `Code::extended_code`)
/// and the bits from `FSST_LEN_BITS` (12) upward hold the symbol length in bytes
/// (see `Code::len`).
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Code(u16);
/// Code byte marking an escape in the compressed stream: the byte that follows it is
/// copied to the output literally.
pub const ESCAPE_CODE: u8 = 255;
/// Number of bits in an extended code (so extended codes lie in `0..FSST_CODE_MAX`).
pub const FSST_CODE_BITS: usize = 9;
/// Bit offset at which `Code` stores the symbol length.
pub const FSST_LEN_BITS: usize = 12;
/// One past the largest extended code (512).
pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS;
/// Mask selecting the extended code from a packed `Code` (511).
pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1;
/// 256. `Decompressor::new` uses it as the exclusive upper bound on the symbol count,
/// and the training-form constructor `Code::new_symbol_building` offsets symbol codes
/// by the same amount.
pub const FSST_CODE_BASE: u16 = 256;
#[allow(clippy::len_without_is_empty)]
impl Code {
pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << 12));
fn new_symbol(code: u8, len: usize) -> Self {
Self(code as u16 + ((len as u16) << FSST_LEN_BITS))
}
fn new_symbol_building(code: u8, len: usize) -> Self {
Self(code as u16 + 256 + ((len as u16) << FSST_LEN_BITS))
}
fn new_escape(byte: u8) -> Self {
Self((byte as u16) + (1 << FSST_LEN_BITS))
}
#[inline]
fn code(self) -> u8 {
self.0 as u8
}
#[inline]
fn extended_code(self) -> u16 {
self.0 & 0b111_111_111
}
#[inline]
fn len(self) -> u16 {
self.0 >> FSST_LEN_BITS
}
}
impl Debug for Code {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TrainingCode")
.field("code", &(self.0 as u8))
.field("is_escape", &(self.0 < 256))
.field("len", &(self.0 >> 12))
.finish()
}
}
/// Decodes compressed bytes using a borrowed symbol table.
///
/// For each code `i`, `symbols[i]` is the symbol written out and `lengths[i]` is how
/// many of its bytes are significant.
#[derive(Clone)]
pub struct Decompressor<'a> {
    /// Symbol for each code, indexed by the code byte.
    pub(crate) symbols: &'a [Symbol],
    /// Byte length of each symbol, parallel to `symbols`.
    pub(crate) lengths: &'a [u8],
}
impl<'a> Decompressor<'a> {
    /// Creates a decompressor over a borrowed symbol table.
    ///
    /// `symbols[i]` and `lengths[i]` are the symbol and byte length emitted for code `i`.
    ///
    /// # Panics
    ///
    /// Panics if there are 256 or more symbols (code 255 is the escape marker,
    /// [`ESCAPE_CODE`]), or if `symbols` and `lengths` have different lengths.
    pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
        assert!(
            symbols.len() < FSST_CODE_BASE as usize,
            "symbol table cannot have size exceeding 255"
        );
        // Decoding indexes both slices with the same code through `get_unchecked`, so a
        // `lengths` slice shorter than `symbols` would be read out of bounds.
        assert_eq!(
            symbols.len(),
            lengths.len(),
            "symbols and lengths differ"
        );
        Self { symbols, lengths }
    }

    /// Upper bound on the output size for `compressed`.
    ///
    /// Each input byte yields at most one symbol of at most 8 bytes; one extra symbol's
    /// worth of space is added as slack.
    pub fn max_decompression_capacity(&self, compressed: &[u8]) -> usize {
        size_of::<Symbol>() * (compressed.len() + 1)
    }

    /// Decompresses `compressed` into `decoded` and returns the number of bytes produced.
    ///
    /// Only the first `n` (returned) bytes of `decoded` are meaningful. Symbols are
    /// stored with 8-byte unaligned writes, so bytes after `n` may also have been
    /// overwritten.
    ///
    /// # Panics
    ///
    /// - if `decoded.len() < compressed.len() / 2`;
    /// - if `decoded` is too small for the decompressed output;
    /// - if the input ends with a dangling escape code.
    ///
    /// NOTE(review): non-escape code bytes index `symbols`/`lengths` via `get_unchecked`
    /// without a bounds check. A code `>= symbols.len()` in malformed or untrusted input
    /// is therefore an out-of-bounds read; confirm callers only feed output of the
    /// matching `Compressor`.
    pub fn decompress_into(&self, compressed: &[u8], decoded: &mut [MaybeUninit<u8>]) -> usize {
        assert!(
            decoded.len() >= compressed.len() / 2,
            "decoded is smaller than lower-bound decompressed size"
        );
        unsafe {
            let mut in_ptr = compressed.as_ptr();
            let in_end = in_ptr.add(compressed.len());
            let mut out_ptr: *mut u8 = decoded.as_mut_ptr().cast();
            let out_begin = out_ptr.cast_const();
            let out_end = decoded.as_ptr().add(decoded.len()).cast::<u8>();

            // Writes all 8 bytes of the symbol for `$code` (unaligned), then advances
            // `out_ptr` by the symbol's real length. Callers must ensure 8 writable bytes.
            macro_rules! store_next_symbol {
                ($code:expr) => {{
                    out_ptr
                        .cast::<u64>()
                        .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
                    out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
                }};
            }

            // Fast path: consume the input 8 code bytes at a time while at least 64
            // output bytes (8 symbols of 8 bytes) and more than 8 input bytes remain.
            if decoded.len() >= 8 * size_of::<Symbol>() && compressed.len() >= 8 {
                let block_out_end = out_end.sub(8 * size_of::<Symbol>());
                let block_in_end = in_end.sub(8);
                while out_ptr.cast_const() <= block_out_end && in_ptr < block_in_end {
                    let next_block = in_ptr.cast::<u64>().read_unaligned();
                    // SWAR: sets the high bit of every byte lane equal to 0xFF
                    // (ESCAPE_CODE) and leaves all other bits clear.
                    let escape_mask = (next_block & 0x8080808080808080)
                        & ((((!next_block) & 0x7F7F7F7F7F7F7F7F) + 0x7F7F7F7F7F7F7F7F)
                            ^ 0x8080808080808080);
                    if escape_mask == 0 {
                        // No escapes: eight plain symbol codes.
                        let code = (next_block & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 8) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 16) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 24) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 32) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 40) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 48) & 0xFF) as u8;
                        store_next_symbol!(code);
                        let code = ((next_block >> 56) & 0xFF) as u8;
                        store_next_symbol!(code);
                        in_ptr = in_ptr.add(8);
                    } else if (next_block & 0x00FF00FF00FF00FF) == 0x00FF00FF00FF00FF {
                        // Four consecutive (escape, literal) pairs.
                        out_ptr.write(((next_block >> 8) & 0xFF) as u8);
                        out_ptr.add(1).write(((next_block >> 24) & 0xFF) as u8);
                        out_ptr.add(2).write(((next_block >> 40) & 0xFF) as u8);
                        out_ptr.add(3).write(((next_block >> 56) & 0xFF) as u8);
                        out_ptr = out_ptr.add(4);
                        in_ptr = in_ptr.add(8);
                    } else {
                        // Decode the codes before the first escape, then the escape pair
                        // itself when its literal byte is inside this 8-byte block.
                        let first_escape_pos = escape_mask.trailing_zeros() >> 3;
                        debug_assert!(first_escape_pos < 8);
                        match first_escape_pos {
                            7 => {
                                // Literal lies beyond the block: stop before the escape.
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 24) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 32) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 40) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 48) & 0xFF) as u8;
                                store_next_symbol!(code);
                                in_ptr = in_ptr.add(7);
                            }
                            6 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 24) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 32) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 40) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let escaped = ((next_block >> 56) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);
                                in_ptr = in_ptr.add(8);
                            }
                            5 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 24) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 32) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let escaped = ((next_block >> 48) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);
                                in_ptr = in_ptr.add(7);
                            }
                            4 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 24) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let escaped = ((next_block >> 40) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);
                                in_ptr = in_ptr.add(6);
                            }
                            3 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 16) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let escaped = ((next_block >> 32) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);
                                in_ptr = in_ptr.add(5);
                            }
                            2 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let code = ((next_block >> 8) & 0xFF) as u8;
                                store_next_symbol!(code);
                                let escaped = ((next_block >> 24) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);
                                in_ptr = in_ptr.add(4);
                            }
                            1 => {
                                let code = (next_block & 0xFF) as u8;
                                store_next_symbol!(code);
                                let escaped = ((next_block >> 16) & 0xFF) as u8;
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);
                                in_ptr = in_ptr.add(3);
                            }
                            0 => {
                                let escaped = ((next_block >> 8) & 0xFF) as u8;
                                in_ptr = in_ptr.add(2);
                                out_ptr.write(escaped);
                                out_ptr = out_ptr.add(1);
                            }
                            _ => unreachable!(),
                        }
                    }
                }
            }

            // One code at a time while a full 8-byte unaligned store still fits.
            while out_end.offset_from(out_ptr) >= size_of::<Symbol>() as isize && in_ptr < in_end {
                let code = in_ptr.read();
                in_ptr = in_ptr.add(1);
                if code == ESCAPE_CODE {
                    assert!(
                        in_ptr < in_end,
                        "truncated compressed string: escape code at end of input"
                    );
                    out_ptr.write(in_ptr.read());
                    in_ptr = in_ptr.add(1);
                    out_ptr = out_ptr.add(1);
                } else {
                    store_next_symbol!(code);
                }
            }

            // Tail: copy exactly the symbol's length, checking remaining space each time.
            while in_ptr < in_end {
                let code = in_ptr.read();
                in_ptr = in_ptr.add(1);
                if code == ESCAPE_CODE {
                    assert!(
                        in_ptr < in_end,
                        "truncated compressed string: escape code at end of input"
                    );
                    assert!(
                        out_ptr.cast_const() < out_end,
                        "output buffer sized too small"
                    );
                    out_ptr.write(in_ptr.read());
                    in_ptr = in_ptr.add(1);
                    out_ptr = out_ptr.add(1);
                } else {
                    let len = *self.lengths.get_unchecked(code as usize) as usize;
                    assert!(
                        out_end.offset_from(out_ptr) >= len as isize,
                        "output buffer sized too small"
                    );
                    let sym = self.symbols.get_unchecked(code as usize).to_u64();
                    let sym_bytes = sym.to_le_bytes();
                    std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len);
                    out_ptr = out_ptr.add(len);
                }
            }

            assert_eq!(
                in_ptr, in_end,
                "decompression should exhaust input before output"
            );
            out_ptr.offset_from(out_begin) as usize
        }
    }

    /// Decompresses `compressed` into a freshly allocated `Vec`.
    pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
        let mut decoded = Vec::with_capacity(self.max_decompression_capacity(compressed) + 7);
        let len = self.decompress_into(compressed, decoded.spare_capacity_mut());
        // SAFETY: `decompress_into` initialized the first `len` bytes of the spare capacity.
        unsafe { decoded.set_len(len) };
        decoded
    }
}
/// Compresses byte strings using a fixed symbol table.
#[derive(Clone)]
pub struct Compressor {
    /// Symbol for each code; `symbol_table()` exposes the first `n_symbols` entries.
    pub(crate) symbols: Vec<Symbol>,
    /// Byte length of each symbol, parallel to `symbols`.
    pub(crate) lengths: Vec<u8>,
    /// Number of valid entries in `symbols` / `lengths`.
    pub(crate) n_symbols: u8,
    /// Indexed by the next two input bytes read as a little-endian `u16`. As populated
    /// by `rebuild_from`, each slot holds the matching two-byte symbol code, otherwise
    /// the one-byte code for the first byte, otherwise `Code::UNUSED` (an escape).
    codes_two_byte: Vec<Code>,
    /// Two-byte codes strictly below this value are emitted by `compress_word` without
    /// consulting `lossy_pht` for a longer match.
    has_suffix_code: u8,
    /// Hash table consulted for symbols of three or more bytes.
    lossy_pht: LossyPHT,
}
impl Compressor {
    /// Emits the code for the longest symbol matching a prefix of `word`.
    ///
    /// `word` holds the next (up to) eight input bytes packed little-endian, so
    /// `word as u8` is the next input byte. One code byte is written at `out_ptr`. The raw
    /// first input byte is always written at `out_ptr + 1`; it becomes part of the output
    /// only when the emitted code is an escape.
    ///
    /// Returns `(input_bytes_consumed, output_bytes_written)`.
    ///
    /// # Safety
    ///
    /// `out_ptr` must be valid for writes of two bytes. `codes_two_byte` is indexed by a
    /// `u16` without a bounds check (it has 65,536 entries when built by `rebuild_from`).
    pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
        let first_byte = word as u8;
        unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) };
        let code_twobyte = unsafe { *self.codes_two_byte.get_unchecked(word as u16 as usize) };
        if code_twobyte.code() < self.has_suffix_code {
            // Two-byte symbol with no longer symbol sharing its prefix: emit it directly.
            unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
            (2, 1)
        } else {
            let entry = self.lossy_pht.lookup(word);
            let ignored_bits = entry.ignored_bits;
            if entry.code != Code::UNUSED
                && compare_masked(word, entry.symbol.to_u64(), ignored_bits)
            {
                unsafe { std::ptr::write(out_ptr, entry.code.code()) };
                (entry.code.len() as usize, 1)
            } else {
                // Fall back to the two-byte / one-byte table. An extended code >= 256
                // (`Code::UNUSED`) is an escape and also emits the literal at `out_ptr + 1`.
                unsafe { std::ptr::write(out_ptr, code_twobyte.code()) };
                (
                    code_twobyte.len() as usize,
                    1 + (code_twobyte.extended_code() >> 8) as usize,
                )
            }
        }
    }

    /// Compresses every input string independently.
    pub fn compress_bulk(&self, lines: &[&[u8]]) -> Vec<Vec<u8>> {
        lines.iter().map(|line| self.compress(line)).collect()
    }

    /// Compresses `plaintext` into `values`, replacing its contents.
    ///
    /// Output is written from the start of `values`' allocation, so existing elements are
    /// overwritten. The length is then set to the number of bytes written. At most
    /// `values.capacity()` bytes are written; `2 * plaintext.len()` is always enough,
    /// because each step consumes at least one input byte and emits at most two.
    ///
    /// # Panics
    ///
    /// Panics ("output buffer sized too small") if the capacity runs out before the input
    /// is consumed.
    ///
    /// # Safety
    ///
    /// NOTE(review): every write here is bounded by `values.capacity()`; confirm what
    /// further contract (if any) callers are expected to uphold.
    pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec<u8>) {
        let mut in_ptr = plaintext.as_ptr();
        let mut out_ptr = values.as_mut_ptr();
        // SAFETY: one-past-the-end pointers of the input slice and the output allocation.
        let in_end = unsafe { in_ptr.add(plaintext.len()) };
        let out_end = unsafe { out_ptr.add(values.capacity()) };

        // Bulk loop over full 8-byte words. The byte counts are compared instead of
        // computing `in_end - 8` as an address: that subtraction underflows for empty or
        // short input at a low (e.g. dangling) address and would start reading out of
        // bounds.
        while unsafe { in_end.offset_from(in_ptr) } >= 8
            && unsafe { out_end.offset_from(out_ptr) } >= 2
        {
            unsafe {
                let word: u64 = std::ptr::read_unaligned(in_ptr.cast::<u64>());
                let (advance_in, advance_out) = self.compress_word(word, out_ptr);
                in_ptr = in_ptr.add(advance_in);
                out_ptr = out_ptr.add(advance_out);
            }
        }

        let remaining_bytes = unsafe { in_end.offset_from(in_ptr) } as usize;
        // With 8 or more bytes left, the loop stopped because the output filled up. The
        // tail below can only stage fewer than 8 bytes, so it must not be entered.
        assert!(remaining_bytes < 8, "output buffer sized too small");
        assert!(
            out_ptr < out_end || remaining_bytes == 0,
            "output buffer sized too small"
        );

        // Tail: zero-pad the last < 8 bytes into a word and keep compressing.
        let mut bytes = [0u8; 8];
        unsafe { std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes) };
        let mut last_word = u64::from_le_bytes(bytes);
        while in_ptr < in_end && unsafe { out_end.offset_from(out_ptr) } >= 2 {
            let (advance_in, advance_out) = unsafe { self.compress_word(last_word, out_ptr) };
            unsafe {
                in_ptr = in_ptr.add(advance_in);
                out_ptr = out_ptr.add(advance_out);
            }
            last_word = advance_8byte_word(last_word, advance_in);
        }

        assert!(
            in_ptr >= in_end,
            "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()"
        );
        assert!(out_ptr <= out_end, "output buffer sized too small");
        let bytes_written = unsafe { out_ptr.offset_from(values.as_ptr()) };
        assert!(
            bytes_written >= 0,
            "out_ptr ended before it started, not possible"
        );
        // SAFETY: the first `bytes_written` bytes were just initialized above.
        unsafe { values.set_len(bytes_written as usize) };
    }

    /// Compresses `plaintext` into a new buffer; empty input yields an empty output.
    pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> {
        if plaintext.is_empty() {
            return Vec::new();
        }
        // Worst case every byte is escaped: two output bytes per input byte.
        let mut buffer = Vec::with_capacity(plaintext.len() * 2);
        unsafe { self.compress_into(plaintext, &mut buffer) };
        buffer
    }

    /// A decompressor that borrows this compressor's symbol table.
    pub fn decompressor(&self) -> Decompressor<'_> {
        Decompressor::new(self.symbol_table(), self.symbol_lengths())
    }

    /// The active symbols, indexed by code.
    pub fn symbol_table(&self) -> &[Symbol] {
        &self.symbols[0..self.n_symbols as usize]
    }

    /// The byte length of each active symbol, parallel to `symbol_table()`.
    pub fn symbol_lengths(&self) -> &[u8] {
        &self.lengths[0..self.n_symbols as usize]
    }

    /// Rebuilds a compressor from a symbol table and its lengths, e.g. the output of
    /// `symbol_table()` / `symbol_lengths()`.
    ///
    /// # Panics
    ///
    /// - if `symbols` and `symbol_lens` differ in length;
    /// - if there are more than 255 symbols;
    /// - if the lengths are not in table order (see `validate_symbol_order`);
    /// - if a symbol of three or more bytes cannot be inserted into the PHT.
    pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self {
        let symbols = symbols.as_ref();
        let symbol_lens = symbol_lens.as_ref();
        assert_eq!(
            symbols.len(),
            symbol_lens.len(),
            "symbols and lengths differ"
        );
        assert!(
            symbols.len() <= 255,
            "symbol table len must be <= 255, was {}",
            symbols.len()
        );
        validate_symbol_order(symbol_lens);
        let symbols = symbols.to_vec();
        let lengths = symbol_lens.to_vec();

        let mut lossy_pht = LossyPHT::new();
        let mut codes_one_byte = vec![Code::UNUSED; 256];
        let mut codes_two_byte = vec![Code::UNUSED; 65_536];
        for (code, (&symbol, &len)) in symbols.iter().zip(lengths.iter()).enumerate() {
            let code = code as u8;
            match len {
                1 => codes_one_byte[symbol.first_byte() as usize] = Code::new_symbol(code, 1),
                2 => codes_two_byte[symbol.first2() as usize] = Code::new_symbol(code, 2),
                3.. => {
                    let inserted = lossy_pht.insert(symbol, len as usize, code);
                    assert!(inserted, "rebuild symbol insertion into PHT must succeed");
                }
                _ => {}
            }
        }

        // Two-byte prefixes with no two-byte symbol fall back to the one-byte code (or
        // `Code::UNUSED`, i.e. an escape) for their first byte.
        for (prefix, slot) in codes_two_byte.iter_mut().enumerate() {
            if *slot == Code::UNUSED {
                *slot = codes_one_byte[prefix & 0xFF];
            }
        }

        // `compress_word` emits two-byte codes below `has_suffix_code` without checking
        // for a longer match. That is only valid when no longer symbol starts with the
        // same two bytes, so stop at the first two-byte code that has such an extension.
        // If none does, every two-byte code qualifies (previously this stayed 0, which
        // disabled the fast path). The authoritative `lengths` table is used because
        // `Symbol::len` ignores trailing zero bytes.
        let n_two_byte = lengths.iter().take_while(|&&len| len == 2).count();
        let has_suffix_code = (0..n_two_byte)
            .find(|&code| {
                let prefix = symbols[code].first2();
                symbols[code..]
                    .iter()
                    .zip(&lengths[code..])
                    .any(|(other, &other_len)| other_len > 2 && other.first2() == prefix)
            })
            .unwrap_or(n_two_byte) as u8;

        Compressor {
            n_symbols: symbols.len() as u8,
            symbols,
            lengths,
            codes_two_byte,
            lossy_pht,
            has_suffix_code,
        }
    }
}
/// Drops the first `bytes` bytes of a little-endian packed word, shifting the rest down.
///
/// Consuming all eight bytes yields 0. That case is handled separately because a 64-bit
/// shift of a `u64` would overflow.
#[inline]
pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 {
    match bytes {
        8 => 0,
        consumed => word >> (8 * consumed),
    }
}
/// Checks that symbol lengths follow the required table order.
///
/// The order is: multi-byte symbols (at least two bytes) in non-decreasing length order,
/// followed by any number of one-byte symbols. Once a one-byte symbol appears, every
/// later symbol must also be one byte.
///
/// # Panics
///
/// Panics on the first symbol that violates the ordering.
fn validate_symbol_order(symbol_lens: &[u8]) {
    let mut prev_len = 2;
    for (idx, &len) in symbol_lens.iter().enumerate() {
        match prev_len {
            // Already in the one-byte tail.
            1 => assert_eq!(
                len, 1,
                "symbol code={idx} should be one byte, was {len} bytes"
            ),
            _ => {
                // A one-byte symbol may start the tail; otherwise lengths must not shrink.
                let floor = if len == 1 { 1 } else { prev_len };
                assert!(
                    len >= floor,
                    "symbol code={idx} breaks violates FSST symbol table ordering"
                );
                prev_len = len;
            }
        }
    }
}
/// Returns whether `left`, with its top `ignored_bits` bits cleared, equals `right`.
#[inline]
pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
    let kept_bits = u64::MAX >> ignored_bits;
    let masked_left = left & kept_bits;
    right == masked_left
}
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Fixture table: 16 two-byte, 61 three-byte and 8 one-byte symbols, each packed
    /// little-endian into a `u64`, listed in table order.
    const SYMBOLS_U64: &[u64] = &[
        24931, 25698, 25442, 25699, 25186, 25444, 24932, 25188, 25185, 25441, 25697, 25700,
        24929, 24930, 25443, 25187, 6513249, 6512995, 6578786, 6513761, 6513507, 6382434,
        6579042, 6512994, 6447460, 6447969, 6382178, 6579041, 6512993, 6448226, 6513250,
        6579297, 6513506, 6447459, 6513764, 6447458, 6578529, 6382180, 6513762, 6447714,
        6579299, 6513508, 6382436, 6513763, 6578532, 6381924, 6448228, 6579300, 6381921,
        6382690, 6382179, 6447713, 6447972, 6513505, 6447457, 6382692, 6513252, 6578785,
        6578787, 6578531, 6448225, 6382177, 6382433, 6578530, 6448227, 6381922, 6578788,
        6579044, 6382691, 6512996, 6579043, 6579298, 6447970, 6447716, 6447971, 6381923,
        6447715, 97, 98, 100, 99, 97, 98, 99, 100,
    ];

    /// Lengths matching `SYMBOLS_U64`.
    fn symbol_lens() -> Vec<u8> {
        iter::repeat_n(2u8, 16)
            .chain(iter::repeat_n(3u8, 61))
            .chain(iter::repeat_n(1u8, 8))
            .collect()
    }

    /// Builds `Symbol`s from raw integers without relying on `Symbol`'s memory layout.
    fn to_symbols(raw: &[u64]) -> Vec<Symbol> {
        raw.iter()
            .map(|&value| Symbol::from_slice(&value.to_le_bytes()))
            .collect()
    }

    /// Extracts raw integers from a symbol table without relying on `Symbol`'s layout.
    fn to_u64s(symbols: &[Symbol]) -> Vec<u64> {
        symbols.iter().map(|symbol| symbol.to_u64()).collect()
    }

    #[test]
    fn test_stuff() {
        let compressor = {
            let mut builder = CompressorBuilder::new();
            builder.insert(Symbol::from_slice(b"helloooo"), 8);
            builder.build()
        };
        let decompressor = compressor.decompressor();
        let mut decompressed = Vec::with_capacity(8 + 7);
        let len = decompressor.decompress_into(&[0], decompressed.spare_capacity_mut());
        assert_eq!(len, 8);
        unsafe { decompressed.set_len(len) };
        assert_eq!(&decompressed, "helloooo".as_bytes());
    }

    #[test]
    fn test_symbols_good() {
        let symbols = to_symbols(SYMBOLS_U64);
        let compressor = Compressor::rebuild_from(&symbols, symbol_lens());
        assert_eq!(to_u64s(compressor.symbol_table()), SYMBOLS_U64);
    }

    #[should_panic(expected = "assertion `left == right` failed")]
    #[test]
    fn test_symbols_bad() {
        let mut builder = CompressorBuilder::new();
        for (symbol, len) in to_symbols(SYMBOLS_U64).into_iter().zip(symbol_lens()) {
            builder.insert(symbol, len as usize);
        }
        let compressor = builder.build();
        assert_eq!(to_u64s(compressor.symbol_table()), SYMBOLS_U64);
    }
}