use crate::{
Code, Compressor, FSST_CODE_BASE, FSST_CODE_MASK, Symbol, advance_8byte_word, compare_masked,
lossy_pht::LossyPHT,
};
use rustc_hash::{FxBuildHasher, FxHashMap};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// A fixed-size bitmap with one bit per extended code (8 words × 64 bits = 512 bits).
///
/// The training counters use it to track which count slots hold live data in the current
/// generation, so the count arrays themselves do not have to be reset between generations.
#[derive(Clone, Copy, Debug, Default)]
struct CodesBitmap {
    /// Bit `i` lives in word `codes[i / 64]` at bit position `i % 64`.
    codes: [u64; 8],
}

assert_sizeof!(CodesBitmap => 64);

impl CodesBitmap {
    /// Mark `index` as present.
    ///
    /// `index` must not exceed `FSST_CODE_MASK`; this is checked by a `debug_assert!`.
    pub(crate) fn set(&mut self, index: usize) {
        debug_assert!(
            index <= FSST_CODE_MASK as usize,
            "code cannot exceed {FSST_CODE_MASK}"
        );
        let word = index / 64;
        let mask = 1u64 << (index & 63);
        self.codes[word] |= mask;
    }

    /// Return whether `index` has been marked present.
    pub(crate) fn is_set(&self, index: usize) -> bool {
        debug_assert!(
            index <= FSST_CODE_MASK as usize,
            "code cannot exceed {FSST_CODE_MASK}"
        );
        let word = self.codes[index / 64];
        (word >> (index & 63)) & 1 != 0
    }

    /// Iterate over the set codes in ascending order.
    pub(crate) fn codes(&self) -> CodesIterator<'_> {
        CodesIterator {
            block: self.codes[0],
            reference: 0,
            index: 0,
            inner: self,
        }
    }

    /// Unset every bit.
    pub(crate) fn clear(&mut self) {
        self.codes = [0; 8];
    }
}
/// Iterator over the set bits of a `CodesBitmap`, yielding codes in ascending order.
///
/// Iteration ends as soon as a set bit at code 511 or above is reached.
struct CodesIterator<'a> {
    /// The bitmap being walked.
    inner: &'a CodesBitmap,
    /// Index of the 64-bit word currently being consumed.
    index: usize,
    /// Not-yet-yielded bits of the current word; bit 0 corresponds to code `reference`.
    block: u64,
    /// Code value represented by bit 0 of `block`.
    reference: usize,
}

impl Iterator for CodesIterator<'_> {
    type Item = u16;

    fn next(&mut self) -> Option<Self::Item> {
        // Move past exhausted words; running off the last word ends the iteration.
        while self.block == 0 {
            self.index += 1;
            let next_word = *self.inner.codes.get(self.index)?;
            self.block = next_word;
            self.reference = self.index * 64;
        }

        let offset = self.block.trailing_zeros() as usize;
        let code = self.reference + offset;
        if code >= 511 {
            return None;
        }

        // Drop the bit just yielded (and everything below it). A shift by 64 would overflow,
        // so `checked_shr` maps that case to an empty block.
        self.block = self.block.checked_shr(offset as u32 + 1).unwrap_or(0);
        self.reference = code + 1;

        Some(code as u16)
    }
}
/// Per-generation frequency counters for single codes and for adjacent code pairs.
///
/// An entry of `counts1`/`counts2` is only meaningful while the matching bit is set in
/// `code1_index`/`pair_index`. [`Counter::clear`] resets just the bitmaps, so counts left over
/// from a previous generation are ignored rather than re-zeroed.
#[derive(Debug, Clone)]
struct Counter {
    /// Count per extended code, indexed by code.
    counts1: Vec<usize>,
    /// Count per ordered code pair, flattened as `code1 * COUNTS1_SIZE + code2`.
    counts2: Vec<usize>,
    /// Which entries of `counts1` are live in the current generation.
    code1_index: CodesBitmap,
    /// For each `code1`, which `code2` entries of `counts2` are live.
    pair_index: Vec<CodesBitmap>,
}

/// Number of distinct extended code values (`0..=FSST_CODE_MASK`).
const COUNTS1_SIZE: usize = (FSST_CODE_MASK + 1) as usize;
/// Number of ordered pairs of extended code values.
const COUNTS2_SIZE: usize = COUNTS1_SIZE * COUNTS1_SIZE;

impl Counter {
    /// Create a counter with no live entries.
    fn new() -> Self {
        // The count vectors are zero-initialised. Growing them with `set_len` over
        // uninitialised memory violates `Vec::set_len`'s contract (the new elements must
        // already be initialised) even though the bitmaps prevent such slots from being read.
        // `vec![0; n]` typically lowers to a zeroed allocation, so this stays cheap even for
        // the large pair table.
        Self {
            counts1: vec![0; COUNTS1_SIZE],
            counts2: vec![0; COUNTS2_SIZE],
            code1_index: CodesBitmap::default(),
            pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE],
        }
    }

    /// Increment the count for `code1`, starting from 1 if it is not yet live.
    #[inline]
    fn record_count1(&mut self, code1: u16) {
        // A slot that is not live may hold a stale value from an earlier generation.
        let base = if self.code1_index.is_set(code1 as usize) {
            self.counts1[code1 as usize]
        } else {
            0
        };

        self.counts1[code1 as usize] = base + 1;
        self.code1_index.set(code1 as usize);
    }

    /// Increment the count for the ordered pair `(code1, code2)`, starting from 1 if it is
    /// not yet live.
    ///
    /// `code1` may be `FSST_CODE_MASK`, the "no previous code" marker at the start of a line.
    #[inline]
    fn record_count2(&mut self, code1: u16, code2: u16) {
        debug_assert!(code1 == FSST_CODE_MASK || self.code1_index.is_set(code1 as usize));
        debug_assert!(self.code1_index.is_set(code2 as usize));

        let idx = (code1 as usize) * COUNTS1_SIZE + (code2 as usize);
        if self.pair_index[code1 as usize].is_set(code2 as usize) {
            self.counts2[idx] += 1;
        } else {
            self.counts2[idx] = 1;
        }
        self.pair_index[code1 as usize].set(code2 as usize);
    }

    /// Count recorded for `code1` in this generation (it must be live).
    #[inline]
    fn count1(&self, code1: u16) -> usize {
        debug_assert!(self.code1_index.is_set(code1 as usize));
        self.counts1[code1 as usize]
    }

    /// Count recorded for the pair `(code1, code2)` in this generation (it must be live).
    #[inline]
    fn count2(&self, code1: u16, code2: u16) -> usize {
        debug_assert!(self.code1_index.is_set(code1 as usize));
        debug_assert!(self.code1_index.is_set(code2 as usize));
        debug_assert!(self.pair_index[code1 as usize].is_set(code2 as usize));

        // Same flattening as `record_count2`.
        let idx = (code1 as usize) * COUNTS1_SIZE + (code2 as usize);
        self.counts2[idx]
    }

    /// Iterate over the codes that have a live single count.
    fn first_codes(&self) -> CodesIterator<'_> {
        self.code1_index.codes()
    }

    /// Iterate over the codes that have a live pair count following `code1`.
    fn second_codes(&self, code1: u16) -> CodesIterator<'_> {
        self.pair_index[code1 as usize].codes()
    }

    /// Forget all counts by clearing the liveness bitmaps.
    fn clear(&mut self) {
        self.code1_index.clear();
        for index in &mut self.pair_index {
            index.clear();
        }
    }
}
/// Mutable symbol table used while training a [`Compressor`].
///
/// While training, inserted symbols live at `symbols[256 + i]`, behind the 256 single-byte
/// escape pseudo-symbols stored at `symbols[0..256]`.
pub struct CompressorBuilder {
    /// Escape pseudo-symbols in `0..256`; inserted symbols from index 256 onward.
    symbols: Vec<Symbol>,
    /// Number of inserted symbols (not counting the escape pseudo-symbols).
    n_symbols: u8,
    /// `len_histogram[k]` counts inserted symbols of length `k + 1`.
    len_histogram: [u8; 8],
    /// Code lookup keyed by the first input byte (single-byte symbols).
    codes_one_byte: Vec<Code>,
    /// Code lookup keyed by the first two input bytes (two-byte symbols).
    codes_two_byte: Vec<Code>,
    /// Hash table holding symbols of three or more bytes.
    lossy_pht: LossyPHT,
}

impl CompressorBuilder {
    /// Create an empty builder whose lookup tables map every input byte to its escape code.
    pub fn new() -> Self {
        // Allocated as `u64` so `vec!` can take its fast zero-filled path. The transmute
        // assumes `Symbol` is layout-compatible with `u64`.
        let zeroed = vec![0u64; 511];
        // SAFETY: relies on `Symbol` having the same size, alignment and valid bit patterns
        // as `u64`.
        let mut symbols: Vec<Symbol> = unsafe { std::mem::transmute(zeroed) };

        // Slots 0..=255 hold the escape pseudo-symbol for each byte value.
        for (slot, byte) in symbols.iter_mut().zip(0..=u8::MAX) {
            *slot = Symbol::from_u8(byte);
        }

        let mut codes_one_byte = Vec::with_capacity(512);
        codes_one_byte.extend((0..=u8::MAX).map(Code::new_escape));

        // Every two-byte entry starts out escaping its first (low) byte.
        let mut codes_two_byte = Vec::with_capacity(65_536);
        codes_two_byte.extend((0..=u16::MAX).map(|pair| Code::new_escape(pair as u8)));

        Self {
            symbols,
            n_symbols: 0,
            len_histogram: [0; 8],
            codes_one_byte,
            codes_two_byte,
            lossy_pht: LossyPHT::new(),
        }
    }
}

impl Default for CompressorBuilder {
    /// Equivalent to [`CompressorBuilder::new`].
    fn default() -> Self {
        CompressorBuilder::new()
    }
}
impl CompressorBuilder {
    /// Attempt to append `symbol` (of length `len`) to the table.
    ///
    /// One- and two-byte symbols are recorded in the direct lookup tables; longer symbols go
    /// into the lossy hash table, which may reject them.
    ///
    /// Returns `true` if the symbol was stored and `false` if the hash table rejected it.
    ///
    /// # Panics
    ///
    /// Panics if the table already holds 255 symbols, or if `len != symbol.len()`.
    pub fn insert(&mut self, symbol: Symbol, len: usize) -> bool {
        assert!(self.n_symbols < 255, "cannot insert into full symbol table");
        assert_eq!(len, symbol.len(), "provided len must equal symbol.len()");

        if len == 2 {
            self.codes_two_byte[symbol.first2() as usize] =
                Code::new_symbol_building(self.n_symbols, 2);
        } else if len == 1 {
            self.codes_one_byte[symbol.first_byte() as usize] =
                Code::new_symbol_building(self.n_symbols, 1);
        } else if !self.lossy_pht.insert(symbol, len, self.n_symbols) {
            return false;
        }

        self.len_histogram[len - 1] += 1;

        // Training-time codes are offset by FSST_CODE_BASE, behind the escape pseudo-symbols.
        self.symbols[FSST_CODE_BASE as usize + (self.n_symbols as usize)] = symbol;
        self.n_symbols += 1;
        true
    }

    /// Remove every inserted symbol from the lookup tables and reset the counts.
    ///
    /// Only the live symbols (`FSST_CODE_BASE..FSST_CODE_BASE + n_symbols`) are visited. The
    /// escape pseudo-symbols below `FSST_CODE_BASE` never modify the lookup tables, so they
    /// need no reset. The `symbols` slots themselves are left as-is and simply become stale.
    fn clear(&mut self) {
        let base = FSST_CODE_BASE as usize;
        for code in base..base + self.n_symbols as usize {
            let symbol = self.symbols[code];
            if symbol.len() == 1 {
                self.codes_one_byte[symbol.first_byte() as usize] =
                    Code::new_escape(symbol.first_byte());
            } else if symbol.len() == 2 {
                self.codes_two_byte[symbol.first2() as usize] =
                    Code::new_escape(symbol.first_byte());
            } else {
                self.lossy_pht.remove(symbol);
            }
        }

        self.len_histogram = [0; 8];
        self.n_symbols = 0;
    }

    /// Renumber the inserted symbols into their final codes and rewrite the lookup tables.
    ///
    /// Final code layout: 2-byte symbols with no longer symbol sharing their first two bytes,
    /// then 2-byte symbols that do have one, then lengths 3..=8, then 1-byte symbols last.
    ///
    /// Returns the first code of the "has suffix" 2-byte region, together with the length of
    /// each final symbol.
    fn finalize(&mut self) -> (u8, Vec<u8>) {
        let base = FSST_CODE_BASE as usize;
        let n_symbols = self.n_symbols as usize;

        // Single-byte symbols occupy the top `len_histogram[0]` codes.
        let byte_lim = self.n_symbols - self.len_histogram[0];

        // Start code for each symbol length (index = length - 1).
        let mut codes_by_length = [0u8; 8];
        codes_by_length[0] = byte_lim;
        codes_by_length[1] = 0;
        for i in 1..7 {
            codes_by_length[i + 1] = codes_by_length[i] + self.len_histogram[i];
        }

        // 2-byte symbols grow upward from 0 (no suffix) and downward from the start of the
        // 3-byte range (has suffix).
        let mut no_suffix_code = 0;
        let mut has_suffix_code = codes_by_length[2];

        // Maps a training-time code (offset by FSST_CODE_BASE) to its final code.
        let mut new_codes = [0u8; FSST_CODE_BASE as usize];

        for i in 0..n_symbols {
            let symbol = self.symbols[base + i];
            let len = symbol.len();
            if len == 2 {
                // Only scan the live symbols. Slots past `n_symbols` can hold stale symbols
                // from earlier generations, which would misclassify this symbol.
                let has_suffix = self.symbols[base..base + n_symbols]
                    .iter()
                    .enumerate()
                    .any(|(k, other)| i != k && symbol.first2() == other.first2());

                if has_suffix {
                    has_suffix_code -= 1;
                    new_codes[i] = has_suffix_code;
                } else {
                    new_codes[i] = no_suffix_code;
                    no_suffix_code += 1;
                }
            } else {
                new_codes[i] = codes_by_length[len - 1];
                codes_by_length[len - 1] += 1;
            }

            // Final codes are < FSST_CODE_BASE, so this reuses the escape slots and never
            // overwrites the training-time region still being read above.
            self.symbols[new_codes[i] as usize] = symbol;
        }

        self.symbols.truncate(n_symbols);

        // Point single-byte entries at their final codes; bytes without a symbol become UNUSED.
        for byte in 0..=255 {
            let one_byte = self.codes_one_byte[byte];
            if one_byte.extended_code() >= FSST_CODE_BASE {
                let new_code = new_codes[one_byte.code() as usize];
                self.codes_one_byte[byte] = Code::new_symbol(new_code, 1);
            } else {
                self.codes_one_byte[byte] = Code::UNUSED;
            }
        }

        // Point two-byte entries at their final codes; others fall back to the first byte's entry.
        for two_bytes in 0..=65_535 {
            let two_byte = self.codes_two_byte[two_bytes];
            if two_byte.extended_code() >= FSST_CODE_BASE {
                let new_code = new_codes[two_byte.code() as usize];
                self.codes_two_byte[two_bytes] = Code::new_symbol(new_code, 2);
            } else {
                self.codes_two_byte[two_bytes] = self.codes_one_byte[two_bytes & 0xFF];
            }
        }

        self.lossy_pht.renumber(&new_codes);

        let lengths: Vec<u8> = self
            .symbols
            .iter()
            .map(|symbol| symbol.len() as u8)
            .collect();

        (has_suffix_code, lengths)
    }

    /// Finalize the table and produce a [`Compressor`].
    pub fn build(mut self) -> Compressor {
        let (has_suffix_code, lengths) = self.finalize();

        Compressor {
            symbols: self.symbols,
            lengths,
            n_symbols: self.n_symbols,
            has_suffix_code,
            codes_two_byte: self.codes_two_byte,
            lossy_pht: self.lossy_pht,
        }
    }
}
/// Sample fraction (out of 128) for each training generation. In a generation with
/// fraction `f < 128`, sample line `i` is skipped when `fsst_hash(i) & 127 > f`; the final
/// generation (128) uses every line.
#[cfg(not(miri))]
const GENERATIONS: [usize; 5] = [8usize, 38, 68, 98, 128];
/// Shorter generation schedule under Miri (NOTE(review): presumably to keep interpreted
/// test runs fast).
#[cfg(miri)]
const GENERATIONS: [usize; 3] = [8usize, 38, 128];
/// Sampling target in bytes: `make_sample` stops once this many bytes are collected, and
/// inputs totalling fewer bytes are used verbatim instead of being sampled.
const FSST_SAMPLETARGET: usize = 1 << 14;
/// Minimum capacity required of the buffer handed to `make_sample`.
const FSST_SAMPLEMAX: usize = 1 << 15;
/// Maximum length of one chunk copied from a single input line while sampling.
const FSST_SAMPLELINE: usize = 512;
/// Build a training sample from `str_in`.
///
/// If the input totals fewer than `FSST_SAMPLETARGET` bytes, `str_in` is returned unchanged.
/// Otherwise, pseudo-random chunks of at most `FSST_SAMPLELINE` bytes are copied from
/// non-empty lines into `sample_buf` until at least `FSST_SAMPLETARGET` bytes have been
/// gathered. Each returned slice points into `sample_buf`.
///
/// # Panics
///
/// Panics if `sample_buf` has capacity below `FSST_SAMPLEMAX`. That capacity guarantees the
/// buffer never reallocates while slices into it are being handed out.
// The `&Vec` parameter is kept so that `str_in.clone()` can be returned directly.
#[allow(clippy::ptr_arg)]
fn make_sample<'a, 'b: 'a>(
    sample_buf: &'a mut Vec<u8>,
    str_in: &Vec<&'b [u8]>,
    tot_size: usize,
) -> Vec<&'a [u8]> {
    assert!(
        sample_buf.capacity() >= FSST_SAMPLEMAX,
        "sample_buf.len() < FSST_SAMPLEMAX"
    );

    // Small inputs are trained on directly.
    if tot_size < FSST_SAMPLETARGET {
        return str_in.clone();
    }

    let n_lines = str_in.len();
    let mut chunks_out: Vec<&[u8]> = Vec::new();
    let mut rng = fsst_hash(4637947);
    let mut filled: usize = 0;

    while filled < FSST_SAMPLETARGET {
        rng = fsst_hash(rng);
        let start = (rng as usize) % n_lines;

        // First non-empty line at or after `start`, wrapping around to the beginning.
        let candidate = str_in[start..]
            .iter()
            .chain(&str_in[..start])
            .copied()
            .find(|line| !line.is_empty());
        let line = match candidate {
            Some(line) => line,
            None => return chunks_out,
        };

        let n_chunks = (line.len() - 1) / FSST_SAMPLELINE + 1;
        rng = fsst_hash(rng);
        let offset = ((rng as usize) % n_chunks) * FSST_SAMPLELINE;
        let take = (line.len() - offset).min(FSST_SAMPLELINE);

        sample_buf.extend_from_slice(&line[offset..offset + take]);

        // SAFETY: bytes `filled..filled + take` of `sample_buf` were written just above. The
        // buffer never reallocates: the loop stops below FSST_SAMPLETARGET + FSST_SAMPLELINE
        // bytes, which fits within the asserted capacity.
        let slice = unsafe { std::slice::from_raw_parts(sample_buf.as_ptr().add(filled), take) };
        chunks_out.push(slice);

        filled += take;
    }

    chunks_out
}
/// Cheap multiplicative hash used as the pseudo-random source for sampling.
///
/// Computes `value * 2971215073` (wrapping) XOR-ed with `value >> 15`.
#[inline]
pub(crate) fn fsst_hash(value: u64) -> u64 {
    let mixed = value.wrapping_mul(2_971_215_073);
    let high = value >> 15;
    mixed ^ high
}
impl Compressor {
    /// Train a compressor on the corpus `values`.
    ///
    /// Runs one generation per entry of `GENERATIONS`. Each generation:
    /// 1. compresses a hash-selected fraction of the sample with the current table, counting
    ///    the emitted codes and code pairs;
    /// 2. rebuilds the table from the highest-gain candidate symbols.
    ///
    /// An empty corpus yields a compressor with no symbols.
    pub fn train(values: &Vec<&[u8]>) -> Self {
        let mut builder = CompressorBuilder::new();
        if values.is_empty() {
            return builder.build();
        }

        let mut counters = Counter::new();
        let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX);
        let mut pqueue = BinaryHeap::with_capacity(65_536);

        let tot_size: usize = values.iter().map(|s| s.len()).sum();
        // `make_sample` only samples inputs at or above the target size. Pruning in the final
        // generation is only enabled when the whole corpus was used.
        let sampled = tot_size >= FSST_SAMPLETARGET;
        let sample = make_sample(&mut sample_memory, values, tot_size);

        for sample_frac in GENERATIONS {
            let is_final = sample_frac >= 128;
            let selected = sample.iter().enumerate().filter(|&(idx, _)| {
                is_final || ((fsst_hash(idx as u64) & 127) as usize) <= sample_frac
            });
            for (_, line) in selected {
                builder.compress_count(line, &mut counters);
            }

            // The heap is reused across generations; empty it before refilling.
            pqueue.clear();
            builder.optimize(&counters, sample_frac, &mut pqueue, is_final && !sampled);
            counters.clear();
        }

        builder.build()
    }
}
impl CompressorBuilder {
fn find_longest_symbol(&self, word: u64) -> Code {
let entry = self.lossy_pht.lookup(word);
let ignored_bits = entry.ignored_bits;
if !entry.is_unused() && compare_masked(word, entry.symbol.to_u64(), ignored_bits) {
return entry.code;
}
let twobyte = self.codes_two_byte[word as u16 as usize];
if twobyte.extended_code() >= FSST_CODE_BASE {
return twobyte;
}
self.codes_one_byte[word as u8 as usize]
}
fn compress_count(&self, sample: &[u8], counter: &mut Counter) -> usize {
let mut gain = 0;
if sample.is_empty() {
return gain;
}
let mut in_ptr = sample.as_ptr();
let in_end = unsafe { in_ptr.byte_add(sample.len()) };
let in_end_sub8 = in_end as usize - 8;
let mut prev_code: u16 = FSST_CODE_MASK;
while (in_ptr as usize) < (in_end_sub8) {
let word: u64 = unsafe { std::ptr::read_unaligned(in_ptr as *const u64) };
let code = self.find_longest_symbol(word);
let code_u16 = code.extended_code();
gain += (code.len() as usize) - ((code_u16 < 256) as usize);
counter.record_count1(code_u16);
counter.record_count2(prev_code, code_u16);
if code.len() > 1 {
let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
counter.record_count1(code_first_byte);
counter.record_count2(prev_code, code_first_byte);
}
in_ptr = unsafe { in_ptr.byte_add(code.len() as usize) };
prev_code = code_u16;
}
let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
assert!(
remaining_bytes.is_positive(),
"in_ptr exceeded in_end, should not be possible"
);
let remaining_bytes = remaining_bytes as usize;
let mut bytes = [0u8; 8];
unsafe {
std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
}
let mut last_word = u64::from_le_bytes(bytes);
let mut remaining_bytes = remaining_bytes;
while remaining_bytes > 0 {
let code = self.find_longest_symbol(last_word);
let code_u16 = code.extended_code();
gain += (code.len() as usize) - ((code_u16 < 256) as usize);
counter.record_count1(code_u16);
counter.record_count2(prev_code, code_u16);
if code.len() > 1 {
let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16;
counter.record_count1(code_first_byte);
counter.record_count2(prev_code, code_first_byte);
}
let advance = code.len() as usize;
remaining_bytes -= advance;
last_word = advance_8byte_word(last_word, advance);
prev_code = code_u16;
}
gain
}
fn optimize(
&mut self,
counters: &Counter,
sample_frac: usize,
pqueue: &mut BinaryHeap<Candidate>,
prune: bool,
) {
let mut candidates = FxHashMap::with_capacity_and_hasher(256, FxBuildHasher);
for code1 in counters.first_codes() {
let symbol1 = self.symbols[code1 as usize];
let symbol1_len = symbol1.len();
let count = counters.count1(code1);
let min_count = if prune { 1 } else { 5 * sample_frac / 128 };
if count < min_count {
continue;
}
let mut gain = count * symbol1_len;
if symbol1_len == 1 {
gain *= 8;
}
*candidates.entry(symbol1).or_insert(0) += gain;
if sample_frac >= 128 || symbol1_len == 8 {
continue;
}
for code2 in counters.second_codes(code1) {
let symbol2 = self.symbols[code2 as usize];
if symbol1_len + symbol2.len() > 8 {
continue;
}
let new_symbol = symbol1.concat(symbol2);
let gain = counters.count2(code1, code2) * new_symbol.len();
*candidates.entry(new_symbol).or_insert(0) += gain;
}
}
for (symbol, gain) in candidates {
pqueue.push(Candidate { symbol, gain });
}
self.clear();
let mut n_symbols = 0;
while !pqueue.is_empty() && n_symbols < 255 {
let candidate = pqueue.pop().unwrap();
if prune {
let symbol_len = candidate.symbol.len();
let saves = if symbol_len == 1 {
candidate.gain / 8 } else {
candidate.gain
};
if saves <= symbol_len + 1 {
continue;
}
}
if self.insert(candidate.symbol, candidate.symbol.len()) {
n_symbols += 1;
}
}
}
}
/// A symbol proposed during table optimisation, together with its estimated gain.
///
/// Candidates are ordered by `(gain, symbol length)`, so a max-heap pops the highest gain
/// first and prefers longer symbols on ties. Candidates with equal gain and length compare
/// equal regardless of their symbol bytes.
#[derive(Copy, Clone, Debug)]
struct Candidate {
    gain: usize,
    symbol: Symbol,
}

impl Candidate {
    /// The key used for both equality and ordering.
    fn comparable_form(&self) -> (usize, usize) {
        let Self { gain, symbol } = *self;
        (gain, symbol.len())
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.comparable_form().cmp(&other.comparable_form())
    }
}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other).is_eq()
    }
}

impl Eq for Candidate {}
#[cfg(test)]
mod test {
    use crate::{Compressor, ESCAPE_CODE, builder::CodesBitmap};

    /// Training on a repetitive corpus should let that corpus compress without escapes,
    /// while bytes never seen in training round-trip through escape codes.
    #[test]
    fn test_builder() {
        let text = b"hello hello hello hello hello";
        let table = Compressor::train(&vec![text, text, text, text, text]);

        let compressed = table.compress(text);
        assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));

        let unseen = "xyz123".as_bytes();
        let compressed = table.compress(unseen);
        let decompressed = table.decompressor().decompress(&compressed);
        assert_eq!(&decompressed, b"xyz123");

        let expected: Vec<u8> = b"xyz123"
            .iter()
            .flat_map(|&byte| [ESCAPE_CODE, byte])
            .collect();
        assert_eq!(compressed, expected);
    }

    /// The bitmap iterator yields set codes in ascending order and stops before code 511.
    #[test]
    fn test_bitmap() {
        let mut sparse = CodesBitmap::default();
        for code in [10, 100, 500] {
            sparse.set(code);
        }
        assert_eq!(sparse.codes().collect::<Vec<u16>>(), vec![10u16, 100, 500]);

        assert_eq!(CodesBitmap::default().codes().count(), 0);

        // The first bit of every 64-bit word.
        let mut word_starts = CodesBitmap::default();
        (0..8).for_each(|word| word_starts.set(64 * word));
        let expected: Vec<u16> = (0u16..8).map(|word| 64 * word).collect();
        assert_eq!(word_starts.codes().collect::<Vec<_>>(), expected);

        // Every bit set: iteration ends before code 511.
        let mut full = CodesBitmap::default();
        (0..512).for_each(|code| full.set(code));
        assert_eq!(
            full.codes().collect::<Vec<_>>(),
            (0u16..511u16).collect::<Vec<_>>()
        );
    }

    /// Setting a code beyond the mask trips the debug assertion.
    #[test]
    #[should_panic(expected = "code cannot exceed")]
    fn test_bitmap_invalid() {
        let mut map = CodesBitmap::default();
        map.set(512);
    }

    /// The trained table should never contain the same 1-byte or 2-byte symbol twice.
    #[test]
    fn test_no_duplicate_symbols() {
        let text = b"aababcabcdabcde";
        let corpus: Vec<&[u8]> = std::iter::repeat_n(text.as_slice(), 100).collect();
        let compressor = Compressor::train(&corpus);
        let symbols = compressor.symbol_table();
        let lengths = compressor.symbol_lengths();

        let mut one_byte: Vec<u8> = Vec::new();
        let mut two_byte: Vec<u16> = Vec::new();
        for (sym, &len) in symbols.iter().zip(lengths.iter()) {
            match len {
                1 => one_byte.push(sym.first_byte()),
                2 => two_byte.push(sym.first2()),
                _ => {}
            }
        }

        let mut unique_one_byte = one_byte.clone();
        unique_one_byte.sort();
        unique_one_byte.dedup();
        assert_eq!(
            one_byte.len(),
            unique_one_byte.len(),
            "duplicate 1-byte symbols found"
        );

        let mut unique_two_byte = two_byte.clone();
        unique_two_byte.sort();
        unique_two_byte.dedup();
        assert_eq!(
            two_byte.len(),
            unique_two_byte.len(),
            "duplicate 2-byte symbols found"
        );
    }
}