use alloc::vec;
use alloc::vec::Vec;
/// Largest `hash_log` accepted for Fast-strategy tables (upstream zstd's `ZSTD_HASHLOG_MAX`).
const ZSTD_HASHLOG_MAX: u32 = 30;

/// Rejects parameter combinations that the Fast-strategy hash table cannot support.
///
/// # Panics
///
/// - if `hash_log` is not in `1..=ZSTD_HASHLOG_MAX`;
/// - if `mls` (minimum match length, i.e. bytes hashed per position) is not in `4..=8`.
fn validate_params(hash_log: u32, mls: u32) {
    if !matches!(hash_log, 1..=ZSTD_HASHLOG_MAX) {
        panic!(
            "hash_log must be in 1..={ZSTD_HASHLOG_MAX} for upstream zstd-compatible Fast hashing (got {hash_log}); \
             the lower bound prevents a full-word-width shift in hash_ptr, the upper bound is upstream zstd's ZSTD_HASHLOG_MAX",
        );
    }
    if !matches!(mls, 4..=8) {
        panic!("ZSTD Fast strategy only supports mls 4..=8 (got {mls})");
    }
}
// Multiplicative hashing primes mirroring upstream zstd's `prime4bytes`..`prime8bytes`.
// For 5..=7-byte hashes the input's low N bytes are first shifted to the top of a 64-bit word,
// so the same 64-bit multiply-shift scheme covers every length.

/// 32-bit multiplier for 4-byte hashes (`0x9E37_79B1`).
const PRIME_4_BYTES: u32 = 2_654_435_761;
/// 64-bit multiplier for 5-byte hashes.
const PRIME_5_BYTES: u64 = 889_523_592_379;
/// 64-bit multiplier for 6-byte hashes.
const PRIME_6_BYTES: u64 = 227_718_039_650_203;
/// 64-bit multiplier for 7-byte hashes.
const PRIME_7_BYTES: u64 = 58_295_818_150_454_627;
/// 64-bit multiplier for 8-byte hashes.
const PRIME_8_BYTES: u64 = 0xCF1B_BCDC_B7A5_6463;
/// Hash table for zstd's Fast strategy holding one `u32` position per slot.
///
/// A slot value of `0` is the empty sentinel. Positions are stored offset by `bias`, which
/// lets `advance_epoch` usually re-base every entry without rewriting the slots.
pub(crate) struct FastHashTable {
    /// Number of hash bits; the slot vector holds `1 << hash_log` entries.
    hash_log: u32,
    /// Minimum match length: how many bytes are hashed per position (validated to `4..=8`).
    mls: u32,
    /// Offset added on `put` and subtracted (saturating) on `get`.
    bias: u32,
    /// One stored (biased) position per hash slot.
    table: Vec<u32>,
}
// Written by hand (rather than derived) so that `clone_from` forwards to `Vec::clone_from`,
// which can reuse the destination's existing slot allocation; the derived impl only
// provides `clone`, leaving `clone_from` to allocate a fresh table.
impl Clone for FastHashTable {
    fn clone(&self) -> Self {
        let Self {
            table,
            hash_log,
            mls,
            bias,
        } = self;
        Self {
            hash_log: *hash_log,
            mls: *mls,
            bias: *bias,
            table: table.clone(),
        }
    }

    fn clone_from(&mut self, source: &Self) {
        let Self {
            table,
            hash_log,
            mls,
            bias,
        } = source;
        self.hash_log = *hash_log;
        self.mls = *mls;
        self.bias = *bias;
        self.table.clone_from(table);
    }
}
impl FastHashTable {
    /// Allocates a zeroed (all-empty) table of `1 << hash_log` slots for the given minimum
    /// match length `mls`.
    ///
    /// # Panics
    ///
    /// - If `hash_log` is outside `1..=ZSTD_HASHLOG_MAX` or `mls` is outside `4..=8`
    ///   (checked by `validate_params`).
    /// - If `1 << hash_log` entries, or their byte size, cannot be represented in `usize` on
    ///   this target, or the byte size exceeds `isize::MAX` (the largest allocation Rust
    ///   permits).
    pub(crate) fn new(hash_log: u32, mls: u32) -> Self {
        validate_params(hash_log, mls);
        let entries = 1usize.checked_shl(hash_log).unwrap_or_else(|| {
            panic!(
                "FastHashTable cannot allocate 2^{hash_log} u32 entries on this target: \
                 `1usize << {hash_log}` overflows {0}-bit usize",
                usize::BITS,
            )
        });
        let bytes = entries
            .checked_mul(core::mem::size_of::<u32>())
            .unwrap_or_else(|| {
                panic!(
                    "FastHashTable cannot allocate {entries} u32 entries on this target: \
                     byte size overflows {0}-bit usize",
                    usize::BITS,
                )
            });
        // Allocations above isize::MAX bytes are rejected by `Vec` with a generic
        // "capacity overflow" panic (e.g. hash_log = 29 on a 32-bit target); report it here
        // with the table size instead.
        assert!(
            bytes <= isize::MAX as usize,
            "FastHashTable cannot allocate {entries} u32 entries on this target: \
             {bytes}-byte table exceeds isize::MAX ({max})",
            max = isize::MAX,
        );
        Self {
            table: vec![0u32; entries],
            hash_log,
            mls,
            bias: 0,
        }
    }

    /// Number of hash bits; the table holds `1 << hash_log` slots.
    #[inline(always)]
    pub(crate) fn hash_log(&self) -> u32 {
        self.hash_log
    }

    /// Minimum match length (bytes hashed per position) this table was built for.
    #[inline(always)]
    pub(crate) fn mls(&self) -> u32 {
        self.mls
    }

    /// Heap bytes owned by the slot vector (based on capacity, not length).
    pub(crate) fn heap_size(&self) -> usize {
        self.table.capacity() * core::mem::size_of::<u32>()
    }

    /// Resets every slot to the empty sentinel `0` and drops any epoch bias.
    pub(crate) fn clear(&mut self) {
        self.table.fill(0);
        self.bias = 0;
    }

    /// Re-bases stored positions so each one reads back `span` lower than before.
    ///
    /// Entries that would drop to or below zero read as the empty sentinel `0`. If
    /// `bias + span` overflows or exceeds `u32::MAX - 2^31`, the table is cleared instead
    /// (all slots empty, bias reset to 0). That limit keeps `pos + bias` in `put`
    /// representable for every `pos < 2^31`.
    pub(crate) fn advance_epoch(&mut self, span: u32) {
        const POSITION_CEILING: u32 = 1 << 31;
        match self.bias.checked_add(span) {
            Some(new_bias) if new_bias <= u32::MAX - POSITION_CEILING => self.bias = new_bias,
            _ => self.clear(),
        }
    }

    /// Hashes the `MLS` bytes at `ptr` to a slot index in `0..1 << self.hash_log()`.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for an unaligned read of 4 bytes when `MLS == 4` and of 8 bytes
    /// when `MLS` is in `5..=8` (see `hash_ptr_raw`). `MLS` should equal `self.mls()`; this is
    /// only checked in debug builds.
    #[inline(always)]
    pub(crate) unsafe fn hash_ptr<const MLS: u32>(&self, ptr: *const u8) -> u32 {
        debug_assert_eq!(MLS, self.mls, "monomorphised MLS must match table mls");
        // SAFETY: forwarded caller contract; `self.hash_log` was validated to 1..=30 in `new`.
        unsafe { hash_ptr_raw::<MLS>(ptr, self.hash_log) }
    }

    /// Exposes the raw slot slice and `hash_log` for hot loops that bypass `get`/`put`.
    ///
    /// Slot values are used without bias adjustment, so this is only meaningful while the bias
    /// is zero (a fresh or just-cleared table). That is asserted in debug builds only.
    #[inline(always)]
    pub(crate) fn hot_state(&mut self) -> (&mut [u32], u32) {
        debug_assert_eq!(self.bias, 0, "hot_state requires an unbiased table");
        (self.table.as_mut_slice(), self.hash_log)
    }

    /// Returns the position stored in slot `hash`, adjusted for the current epoch bias.
    ///
    /// `0` means "empty". It is also returned for entries that predate the current epoch
    /// by more than their stored position.
    ///
    /// # Safety
    ///
    /// `hash` must be `< 1 << self.hash_log()`; out-of-range values are only caught in debug
    /// builds. Values returned by `hash_ptr` on this table satisfy this.
    #[inline(always)]
    pub(crate) unsafe fn get(&self, hash: u32) -> u32 {
        debug_assert!((hash as usize) < self.table.len());
        // SAFETY: the caller guarantees `hash` indexes a slot of `self.table`.
        let raw = unsafe { *self.table.get_unchecked(hash as usize) };
        raw.saturating_sub(self.bias)
    }

    /// Records `pos` in slot `hash`, offset by the current epoch bias.
    ///
    /// # Safety
    ///
    /// - `hash` must be `< 1 << self.hash_log()` (only checked in debug builds).
    /// - `pos + bias` must not overflow `u32`. `advance_epoch` keeps this true for every
    ///   `pos < 2^31`. Overflow is reported by a debug assertion; in release builds it would
    ///   store a wrapped, meaningless position.
    #[inline(always)]
    pub(crate) unsafe fn put(&mut self, hash: u32, pos: u32) {
        debug_assert!((hash as usize) < self.table.len());
        debug_assert!(
            pos.checked_add(self.bias).is_some(),
            "FastHashTable::put: pos {} + bias {} overflows u32",
            pos,
            self.bias,
        );
        let biased = pos + self.bias;
        // SAFETY: the caller guarantees `hash` indexes a slot of `self.table`.
        unsafe {
            *self.table.get_unchecked_mut(hash as usize) = biased;
        }
    }
}
/// Computes upstream zstd's multiply-shift hash of the first `MLS` bytes at `ptr`, reduced to
/// `hash_log` bits.
///
/// - `MLS == 4` reads a little-endian `u32` and multiplies it by `PRIME_4_BYTES`.
/// - `MLS` in `5..=7` reads a little-endian `u64` and shifts its low `MLS` bytes to the top of
///   the word, so the remaining bytes cannot affect the result. It then multiplies by the
///   matching prime.
/// - `MLS == 8` multiplies the whole little-endian `u64` by `PRIME_8_BYTES`.
///
/// The top `hash_log` bits of the product are returned, so the result is `< 1 << hash_log`.
/// Any other `MLS` trips a debug assertion and returns `0` in release builds.
///
/// # Safety
///
/// - `ptr` must be valid for an unaligned read of 4 bytes when `MLS == 4`, and of 8 bytes when
///   `MLS` is in `5..=8`. Note that the 5..=7 variants read a full `u64`.
/// - `hash_log` should be in `1..=32`. `0` would shift by the full word width (masked in
///   release builds, yielding an out-of-range hash), and larger values do not fit the `u32`
///   result. This is only checked in debug builds.
#[inline(always)]
pub(crate) unsafe fn hash_ptr_raw<const MLS: u32>(ptr: *const u8, hash_log: u32) -> u32 {
    debug_assert!(
        (1..=32).contains(&hash_log),
        "hash_ptr_raw: hash_log must be in 1..=32 (got {})",
        hash_log,
    );
    match MLS {
        4 => {
            // SAFETY: the caller guarantees 4 readable bytes at `ptr`.
            let u = u32::from_le(unsafe { core::ptr::read_unaligned(ptr.cast::<u32>()) });
            u.wrapping_mul(PRIME_4_BYTES) >> (32 - hash_log)
        }
        5 => {
            // SAFETY: the caller guarantees 8 readable bytes at `ptr` for MLS 5..=8.
            let u = u64::from_le(unsafe { core::ptr::read_unaligned(ptr.cast::<u64>()) });
            ((u << (64 - 40)).wrapping_mul(PRIME_5_BYTES) >> (64 - hash_log)) as u32
        }
        6 => {
            // SAFETY: the caller guarantees 8 readable bytes at `ptr` for MLS 5..=8.
            let u = u64::from_le(unsafe { core::ptr::read_unaligned(ptr.cast::<u64>()) });
            ((u << (64 - 48)).wrapping_mul(PRIME_6_BYTES) >> (64 - hash_log)) as u32
        }
        7 => {
            // SAFETY: the caller guarantees 8 readable bytes at `ptr` for MLS 5..=8.
            let u = u64::from_le(unsafe { core::ptr::read_unaligned(ptr.cast::<u64>()) });
            ((u << (64 - 56)).wrapping_mul(PRIME_7_BYTES) >> (64 - hash_log)) as u32
        }
        8 => {
            // SAFETY: the caller guarantees 8 readable bytes at `ptr` for MLS 5..=8.
            let u = u64::from_le(unsafe { core::ptr::read_unaligned(ptr.cast::<u64>()) });
            (u.wrapping_mul(PRIME_8_BYTES) >> (64 - hash_log)) as u32
        }
        _ => {
            debug_assert!(false, "unsupported MLS {}", MLS);
            0
        }
    }
}
#[cfg(test)]
mod tests {
    //! Pins the hash formulas to upstream zstd and exercises the slot, epoch and clone
    //! semantics of `FastHashTable`.
    use super::*;

    #[test]
    fn hash4_matches_expected_value_on_known_input() {
        let table = FastHashTable::new(12, 4);
        let data = [0x01u8, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let h = unsafe { table.hash_ptr::<4>(data.as_ptr()) };
        let expected = 0x04030201u32.wrapping_mul(0x9E3779B1) >> 20;
        assert_eq!(
            h, expected,
            "hash4 must match upstream zstd multiply-shift formula"
        );
    }

    #[test]
    fn hash5_matches_expected_value_on_known_input() {
        let table = FastHashTable::new(13, 5);
        let data = [0x01u8, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let h = unsafe { table.hash_ptr::<5>(data.as_ptr()) };
        let u = u64::from_le_bytes(data);
        let expected = (((u << (64 - 40)).wrapping_mul(889_523_592_379u64)) >> (64 - 13)) as u32;
        assert_eq!(h, expected);
    }

    #[test]
    fn get_put_round_trip_under_known_hash() {
        let mut table = FastHashTable::new(8, 4);
        let data = [0xAAu8, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11];
        let h = unsafe { table.hash_ptr::<4>(data.as_ptr()) };
        unsafe {
            assert_eq!(table.get(h), 0, "fresh table reads sentinel");
            table.put(h, 0xCAFE_BABE);
            assert_eq!(table.get(h), 0xCAFE_BABE);
        }
    }

    #[test]
    fn clear_resets_all_entries_to_sentinel() {
        let mut table = FastHashTable::new(6, 4);
        let data = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let h = unsafe { table.hash_ptr::<4>(data.as_ptr()) };
        unsafe {
            table.put(h, 42);
        }
        table.clear();
        let read_back = unsafe { table.get(h) };
        assert_eq!(read_back, 0, "clear must zero every entry");
    }

    #[test]
    fn hash6_matches_expected_value_on_known_input() {
        let table = FastHashTable::new(14, 6);
        let data = [0x01u8, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let h = unsafe { table.hash_ptr::<6>(data.as_ptr()) };
        let u = u64::from_le_bytes(data);
        let expected =
            (((u << (64 - 48)).wrapping_mul(227_718_039_650_203u64)) >> (64 - 14)) as u32;
        assert_eq!(
            h, expected,
            "hash6 must match upstream zstd multiply-shift formula"
        );
    }

    #[test]
    fn hash7_matches_expected_value_on_known_input() {
        let table = FastHashTable::new(15, 7);
        let data = [0x01u8, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let h = unsafe { table.hash_ptr::<7>(data.as_ptr()) };
        let u = u64::from_le_bytes(data);
        let expected =
            (((u << (64 - 56)).wrapping_mul(58_295_818_150_454_627u64)) >> (64 - 15)) as u32;
        assert_eq!(
            h, expected,
            "hash7 must match upstream zstd multiply-shift formula"
        );
    }

    #[test]
    fn hash8_matches_expected_value_on_known_input() {
        let table = FastHashTable::new(16, 8);
        let data = [0x01u8, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let h = unsafe { table.hash_ptr::<8>(data.as_ptr()) };
        let u = u64::from_le_bytes(data);
        let expected = (u.wrapping_mul(0xCF1BBCDCB7A56463u64) >> (64 - 16)) as u32;
        assert_eq!(
            h, expected,
            "hash8 must match upstream zstd multiply-shift formula"
        );
    }

    #[test]
    fn hash_log_minimum_one_constructs_two_entry_table() {
        let table = FastHashTable::new(1, 4);
        let data = [0u8, 0, 0, 0];
        let h = unsafe { table.hash_ptr::<4>(data.as_ptr()) };
        assert!(h < 2, "hash_log=1 must produce values ∈ {{0, 1}} (got {h})");
    }

    #[test]
    fn hash_log_maximum_thirty_is_accepted_by_validation() {
        validate_params(ZSTD_HASHLOG_MAX, 4);
        validate_params(ZSTD_HASHLOG_MAX, 8);
    }

    #[test]
    fn hash_ptr_raw_output_fits_in_hash_log_bits() {
        // Exercise the full validated range without allocating a 2^30-slot table.
        let data = [0xFFu8, 0x7E, 0x12, 0xC3, 0x99, 0x01, 0xA5, 0x5A];
        let p = data.as_ptr();
        for hash_log in [1u32, 13, ZSTD_HASHLOG_MAX] {
            let limit = 1u64 << hash_log;
            let hashes = unsafe {
                [
                    hash_ptr_raw::<4>(p, hash_log),
                    hash_ptr_raw::<5>(p, hash_log),
                    hash_ptr_raw::<6>(p, hash_log),
                    hash_ptr_raw::<7>(p, hash_log),
                    hash_ptr_raw::<8>(p, hash_log),
                ]
            };
            for h in hashes {
                assert!(u64::from(h) < limit, "hash {h} out of range for hash_log {hash_log}");
            }
        }
    }

    #[test]
    fn hash_ignores_bytes_past_mls() {
        let a = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let b = [1u8, 2, 3, 4, 5, 6, 7, 0xEE];
        unsafe {
            assert_eq!(hash_ptr_raw::<4>(a.as_ptr(), 20), hash_ptr_raw::<4>(b.as_ptr(), 20));
            assert_eq!(hash_ptr_raw::<5>(a.as_ptr(), 20), hash_ptr_raw::<5>(b.as_ptr(), 20));
            assert_eq!(hash_ptr_raw::<6>(a.as_ptr(), 20), hash_ptr_raw::<6>(b.as_ptr(), 20));
            assert_eq!(hash_ptr_raw::<7>(a.as_ptr(), 20), hash_ptr_raw::<7>(b.as_ptr(), 20));
        }
    }

    #[test]
    fn accessors_and_heap_size_reflect_construction() {
        let table = FastHashTable::new(8, 6);
        assert_eq!(table.hash_log(), 8);
        assert_eq!(table.mls(), 6);
        assert!(table.heap_size() >= 256 * core::mem::size_of::<u32>());
    }

    #[test]
    fn advance_epoch_rebases_stored_positions() {
        let mut table = FastHashTable::new(8, 4);
        unsafe {
            table.put(3, 100);
            table.put(4, 10);
        }
        table.advance_epoch(40);
        unsafe {
            assert_eq!(table.get(3), 60, "positions read back `span` lower");
            assert_eq!(table.get(4), 0, "positions older than the epoch read as sentinel");
            table.put(5, 7);
            assert_eq!(table.get(5), 7, "new writes round-trip under a non-zero bias");
        }
    }

    #[test]
    fn advance_epoch_clears_when_bias_exceeds_headroom() {
        let mut table = FastHashTable::new(8, 4);
        unsafe {
            table.put(3, 100);
        }
        table.advance_epoch(40);
        table.advance_epoch(1 << 31);
        unsafe {
            assert_eq!(table.get(3), 0, "exceeding the bias limit must clear the table");
        }
        let (slots, _) = table.hot_state();
        assert!(slots.iter().all(|&slot| slot == 0));
    }

    #[test]
    fn advance_epoch_accepts_bias_up_to_headroom_limit() {
        let mut table = FastHashTable::new(4, 4);
        table.advance_epoch(u32::MAX - (1u32 << 31));
        unsafe {
            table.put(1, 1 << 30);
            assert_eq!(table.get(1), 1 << 30);
        }
        table.advance_epoch(1);
        unsafe {
            assert_eq!(table.get(1), 0, "one past the limit clears the table");
        }
    }

    #[test]
    fn hot_state_exposes_unbiased_slots() {
        let mut table = FastHashTable::new(5, 4);
        unsafe {
            table.put(7, 99);
        }
        let (slots, hash_log) = table.hot_state();
        assert_eq!(hash_log, 5);
        assert_eq!(slots.len(), 32);
        assert_eq!(slots[7], 99);
        slots[3] = 11;
        assert_eq!(unsafe { table.get(3) }, 11);
    }

    #[test]
    fn clone_and_clone_from_copy_slots_and_parameters() {
        let mut src = FastHashTable::new(6, 4);
        unsafe {
            src.put(5, 123);
        }
        src.advance_epoch(3);

        let copy = src.clone();
        assert_eq!((copy.hash_log(), copy.mls()), (6, 4));
        assert_eq!(unsafe { copy.get(5) }, 120);

        let mut dst = FastHashTable::new(10, 7);
        dst.clone_from(&src);
        assert_eq!((dst.hash_log(), dst.mls()), (6, 4));
        assert_eq!(unsafe { dst.get(5) }, 120);
    }

    #[test]
    #[should_panic(expected = "hash_log must be in 1..=")]
    fn panics_on_zero_hash_log() {
        let _ = FastHashTable::new(0, 4);
    }

    #[test]
    #[should_panic(expected = "hash_log must be in 1..=")]
    fn panics_on_hash_log_above_zstd_hashlog_max() {
        let _ = FastHashTable::new(31, 4);
    }

    #[test]
    #[should_panic(expected = "ZSTD Fast strategy only supports mls 4..=8")]
    fn panics_on_mls_below_four() {
        let _ = FastHashTable::new(12, 3);
    }

    #[test]
    #[should_panic(expected = "ZSTD Fast strategy only supports mls 4..=8")]
    fn panics_on_mls_above_eight() {
        let _ = FastHashTable::new(12, 9);
    }
}