use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::mem::size_of_val;
use num::{Integer, PrimInt, Unsigned};
use wyhash::WyHash;
use crate::mphf::MphfError::*;
use crate::rank::{RankedBits, RankedBitsAccess};
/// Minimal perfect hash function: maps every key of the set it was built from to a
/// distinct index in `0..n`.
///
/// Generic parameters:
/// * `B` – number of bits per group (must be in `1..=64`, enforced by a const assertion).
/// * `S` – number of seed bits per group; all `2^S` seeds are tried per group (`S <= 16`).
/// * `ST` – unsigned integer type storing each group's chosen seed; must hold `2^S - 1`.
/// * `H` – hasher used to hash keys (default-constructed for every key).
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct Mphf<const B: usize = 32, const S: usize = 8, ST: PrimInt + Unsigned = u8, H: Hasher + Default = WyHash> {
/// Placement bits of all levels concatenated, with rank support; the rank of a key's bit
/// is the index returned by `get`.
ranked_bits: RankedBits,
/// Number of groups in each level, in level order.
level_groups: Box<[u32]>,
/// Chosen seed of every group, concatenated across levels in level order.
group_seeds: Box<[ST]>,
/// Ties the hasher type `H` to the struct without storing a hasher value.
_phantom_hasher: PhantomData<H>,
}
/// Maximum number of levels `Mphf::from_slice` builds before giving up with
/// [`MphfError::MaxLevelsExceeded`].
const MAX_LEVELS: usize = 64;

/// Errors returned when building an `Mphf`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MphfError {
    /// Some keys were still unplaced after `MAX_LEVELS` levels were built.
    MaxLevelsExceeded,
    /// The seed storage type `ST` cannot represent the largest seed value `2^S - 1`.
    InvalidSeedType,
    /// The `gamma` parameter is invalid (for example, less than `1.0`).
    InvalidGammaParameter,
}

impl std::fmt::Display for MphfError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::MaxLevelsExceeded => "maximum number of levels exceeded while building the MPHF",
            Self::InvalidSeedType => "seed type is too small to hold 2^S - 1",
            Self::InvalidGammaParameter => "gamma parameter must be at least 1.0",
        };
        f.write_str(msg)
    }
}

// Lets `MphfError` flow through `?` into `Box<dyn Error>` / error-reporting crates.
impl std::error::Error for MphfError {}

/// Default `gamma` for `Mphf::from_slice`: each level allocates about `gamma` bits per
/// key that is still unplaced.
pub const DEFAULT_GAMMA: f32 = 2.0;
impl<const B: usize, const S: usize, ST: PrimInt + Unsigned, H: Hasher + Default> Mphf<B, S, ST, H> {
    /// Group size `B` in bits. Using this constant with `B` outside `1..=64` fails const
    /// evaluation, i.e. it is a compile-time error for that instantiation.
    const B: usize = {
        assert!(B >= 1 && B <= 64);
        B
    };

    /// Seed width `S` in bits. Using this constant with `S > 16` fails const evaluation,
    /// i.e. it is a compile-time error for that instantiation.
    const S: usize = {
        assert!(S <= 16);
        S
    };

    /// Builds a minimal perfect hash function over `keys`.
    ///
    /// Keys are hashed once with `H` and then placed level by level. Each level has about
    /// `gamma` bits per still-unplaced key (rounded up to whole `B`-bit groups and whole
    /// 64-bit words). A key is placed when no other key hits its bit under its group's
    /// chosen seed; all other keys move on to the next level.
    ///
    /// # Errors
    ///
    /// * [`MphfError::InvalidGammaParameter`] if `gamma` is NaN, infinite, or below `1.0`.
    /// * [`MphfError::InvalidSeedType`] if `ST` cannot represent `2^S - 1`.
    /// * [`MphfError::MaxLevelsExceeded`] if keys remain unplaced after `MAX_LEVELS`
    ///   levels, e.g. because `keys` contains duplicates (equal keys always collide).
    pub fn from_slice<K: Hash>(keys: &[K], gamma: f32) -> Result<Self, MphfError> {
        // `gamma < 1.0` alone is false for NaN; a NaN or infinite gamma would otherwise
        // produce a meaningless level size and panic inside `build_level`.
        if !gamma.is_finite() || gamma < 1.0 {
            return Err(InvalidGammaParameter);
        }
        // The largest seed tried per group is `2^S - 1`; `ST` must be able to store it.
        if ST::from((1 << Self::S) - 1).is_none() {
            return Err(InvalidSeedType);
        }

        let mut hashes: Vec<u64> = keys.iter().map(|key| hash_key::<H, _>(key)).collect();
        let mut group_bits = vec![];
        let mut group_seeds = vec![];
        let mut level_groups = vec![];

        // Each iteration builds one level and drops the hashes it placed from `hashes`.
        while !hashes.is_empty() {
            let level = level_groups.len() as u32;
            let (level_group_bits, level_group_seeds) = Self::build_level(level, &mut hashes, gamma);
            group_bits.extend_from_slice(&level_group_bits);
            group_seeds.extend_from_slice(&level_group_seeds);
            level_groups.push(level_group_seeds.len() as u32);
            if level_groups.len() == MAX_LEVELS && !hashes.is_empty() {
                return Err(MaxLevelsExceeded);
            }
        }

        Ok(Mphf {
            ranked_bits: RankedBits::new(group_bits.into_boxed_slice()),
            level_groups: level_groups.into_boxed_slice(),
            group_seeds: group_seeds.into_boxed_slice(),
            _phantom_hasher: PhantomData,
        })
    }

    /// Builds one level over `hashes` and removes from `hashes` every hash it places.
    ///
    /// Every seed in `0..2^S` is tried for each group. The seed that leaves the most bits
    /// hit exactly once is kept; on a tie, the earlier seed stays.
    ///
    /// Returns the level's placement bits packed into `u64` words, and the chosen seed of
    /// each group.
    fn build_level(level: u32, hashes: &mut Vec<u64>, gamma: f32) -> (Vec<u64>, Vec<ST>) {
        let level_size = ((hashes.len() as f32) * gamma).ceil() as usize;
        let (groups, segments) = Self::level_size_groups_segments(level_size);
        let max_group_seed = 1 << S;
        // Scratch space, three words per 64-bit segment of the level:
        //   [0] bits hit by the seed under test (only singly-hit bits after a pass),
        //   [1] bits hit more than once by the seed under test,
        //   [2] best bits chosen so far.
        // The three extra trailing words keep the `idx..idx + 6` window, read for groups
        // in the last segment, in bounds.
        let mut group_bits = vec![0u64; 3 * segments + 3];
        let mut best_group_seeds = vec![ST::zero(); groups];
        for group_seed in 0..max_group_seed {
            Self::update_group_bits_with_seed(
                level,
                groups,
                group_seed,
                hashes,
                &mut group_bits,
                &mut best_group_seeds,
            );
        }
        // Keep only the best-bits word of every real segment; the padding is dropped.
        let best_group_bits: Vec<u64> = group_bits[..group_bits.len() - 3]
            .chunks_exact(3)
            .map(|group_bits| group_bits[2])
            .collect();
        // A hash is placed when its bit, under its group's chosen seed, survived as a
        // singly-hit best bit. Only unplaced hashes are kept for the next level.
        hashes.retain(|&hash| {
            let level_hash = hash_with_seed(hash, level);
            let group_idx = fastmod32(level_hash as u32, groups as u32);
            let group_seed = best_group_seeds[group_idx].to_u32().unwrap();
            let bit_idx = bit_index_for_seed::<B>(level_hash, group_seed, group_idx);
            // SAFETY: `bit_idx < groups * B == segments * 64`, so
            // `bit_idx / 64 < segments == best_group_bits.len()`.
            *unsafe { best_group_bits.get_unchecked(bit_idx / 64) } & (1 << (bit_idx % 64)) == 0
        });
        (best_group_bits, best_group_seeds)
    }

    /// Rounds a level size in bits up to a multiple of `lcm(B, 64)`, so that the level is
    /// made of whole `B`-bit groups and whole 64-bit words.
    ///
    /// Returns `(groups, segments)`: the number of groups and the number of 64-bit words.
    #[inline]
    fn level_size_groups_segments(size: usize) -> (usize, usize) {
        let lcm_value = Self::B.lcm(&64);
        let adjusted_size = size.div_ceil(lcm_value) * lcm_value;
        (adjusted_size / Self::B, adjusted_size / 64)
    }

    /// Evaluates `group_seed` for every group of a level.
    ///
    /// For each group, records the seed as the group's best wherever it yields more
    /// singly-hit bits than the current best.
    ///
    /// `group_bits` must use the layout built by `build_level`: three words per segment
    /// (current bits, collision bits, best bits), followed by three padding words.
    #[inline]
    fn update_group_bits_with_seed(
        level: u32,
        groups: usize,
        group_seed: u32,
        hashes: &[u64],
        group_bits: &mut [u64],
        best_group_seeds: &mut [ST],
    ) {
        let group_bits_len = group_bits.len();
        // Clear the current and collision words; the best words persist across seeds.
        for bits in group_bits[..group_bits_len - 3].chunks_exact_mut(3) {
            bits[0] = 0;
            bits[1] = 0;
        }
        // Set each hash's bit for this seed, recording bits that are hit more than once.
        for &hash in hashes {
            let level_hash = hash_with_seed(hash, level);
            let group_idx = fastmod32(level_hash as u32, groups as u32);
            let bit_idx = bit_index_for_seed::<B>(level_hash, group_seed, group_idx);
            let mask = 1 << (bit_idx % 64);
            let idx = (bit_idx / 64) * 3;
            // SAFETY: `bit_idx < groups * B == segments * 64`, so with `group_bits` sized
            // `3 * segments + 3` by `build_level`, `idx + 2 <= 3 * segments - 1`.
            let bits = unsafe { group_bits.get_unchecked_mut(idx..idx + 2) };
            bits[1] |= bits[0] & mask;
            bits[0] |= mask;
        }
        // Keep only the bits that were hit exactly once.
        for bits in group_bits.chunks_exact_mut(3) {
            bits[0] &= !bits[1];
        }
        // Per group, compare the count of singly-hit bits against the best so far.
        // A group's `B` bits may straddle two 64-bit words. `bits_1` of them sit in the
        // word at `idx`, from `bit_pos` upwards. The remaining `bits_2` sit at the bottom
        // of the next segment's words (`idx + 3` for current, `idx + 5` for best).
        for (group_idx, best_group_seed) in best_group_seeds.iter_mut().enumerate() {
            let bit_idx = group_idx * Self::B;
            let bit_pos = bit_idx % 64;
            let idx = (bit_idx / 64) * 3;
            // SAFETY: `group_idx < groups`, so `bit_idx / 64 < segments` and
            // `idx + 6 <= 3 * segments + 3 == group_bits.len()` thanks to the padding.
            let bits = unsafe { group_bits.get_unchecked_mut(idx..idx + 6) };
            let bits_1 = Self::B.min(64 - bit_pos);
            let bits_2 = Self::B - bits_1;
            let mask_1 = u64::MAX >> (64 - bits_1);
            let mask_2 = (1 << bits_2) - 1;
            let new_bits_1 = (bits[0] >> bit_pos) & mask_1;
            let new_bits_2 = bits[3] & mask_2;
            let new_ones = new_bits_1.count_ones() + new_bits_2.count_ones();
            let best_bits_1 = (bits[2] >> bit_pos) & mask_1;
            let best_bits_2 = bits[5] & mask_2;
            let best_ones = best_bits_1.count_ones() + best_bits_2.count_ones();
            if new_ones > best_ones {
                // Overwrite only this group's slice of the best words.
                bits[2] &= !(mask_1 << bit_pos);
                bits[2] |= new_bits_1 << bit_pos;
                bits[5] &= !mask_2;
                bits[5] |= new_bits_2;
                *best_group_seed = ST::from(group_seed).unwrap();
            }
        }
    }

    /// Returns the index assigned to `key`, or `None` if no level places it.
    ///
    /// Keys from the build set map to distinct indices in `0..n`.
    ///
    /// NOTE: there is no membership check, so a key outside the build set may still
    /// return `Some`.
    #[inline]
    pub fn get<K: Hash + ?Sized>(&self, key: &K) -> Option<usize> {
        Self::get_impl(key, &self.level_groups, &self.group_seeds, &self.ranked_bits)
    }

    /// Shared lookup used by both the in-memory and the archived representation.
    ///
    /// Walks the levels in order. For each level it finds the key's group and that
    /// group's seed, derives the key's bit, and returns the bit's rank if `ranked_bits`
    /// reports one.
    #[inline]
    fn get_impl<K: Hash + ?Sized>(
        key: &K,
        level_groups: &[u32],
        group_seeds: &[ST],
        ranked_bits: &impl RankedBitsAccess,
    ) -> Option<usize> {
        // The key hash does not depend on the level, so compute it only once.
        let hash = hash_key::<H, _>(key);
        let mut groups_before = 0;
        for (level, &groups) in level_groups.iter().enumerate() {
            let level_hash = hash_with_seed(hash, level as u32);
            let group_idx = groups_before + fastmod32(level_hash as u32, groups);
            // Checked lookup: `group_idx` is always in range for an `Mphf` built by
            // `from_slice`. However, these slices may also come from an archived value,
            // and nothing here verifies that `level_groups` and `group_seeds` agree.
            // An inconsistent pair must yield `None`, not out-of-bounds reads.
            let group_seed = group_seeds.get(group_idx)?.to_u32()?;
            let bit_idx = bit_index_for_seed::<B>(level_hash, group_seed, group_idx);
            if let Some(rank) = ranked_bits.rank(bit_idx) {
                return Some(rank);
            }
            groups_before += groups as usize;
        }
        None
    }

    /// Returns the approximate memory footprint in bytes.
    ///
    /// This covers the struct itself, the `level_groups` and `group_seeds` buffers, and
    /// whatever `RankedBits::size` reports.
    pub fn size(&self) -> usize {
        size_of_val(self)
            + size_of_val(self.level_groups.as_ref())
            + size_of_val(self.group_seeds.as_ref())
            + self.ranked_bits.size()
    }
}
/// Hashes `key` with a freshly default-constructed `H` and returns the 64-bit digest.
#[inline]
fn hash_key<H: Hasher + Default, T: Hash + ?Sized>(key: &T) -> u64 {
    let mut state = H::default();
    Hash::hash(key, &mut state);
    state.finish()
}
/// Maps a level hash and a group seed to a bit position inside group `group_idx`.
///
/// The low 32 bits of `hash` are XORed with `group_seed` and passed through the
/// MurmurHash3 32-bit finalizer. The mixed value is reduced into `0..B` and offset by the
/// group's first bit, so the result lies in `group_idx * B..(group_idx + 1) * B`.
#[inline]
fn bit_index_for_seed<const B: usize>(hash: u64, group_seed: u32, group_idx: usize) -> usize {
    let mut h = group_seed ^ hash as u32;
    h ^= h >> 16;
    h = h.wrapping_mul(0x85eb_ca6b);
    h ^= h >> 13;
    h = h.wrapping_mul(0xc2b2_ae35);
    h ^= h >> 16;
    let offset_in_group = fastmod32(h, B as u32);
    group_idx * B + offset_in_group
}
/// Mixes `seed` into a 64-bit `hash` (callers pass the level number as the seed).
///
/// XORs the seed in, multiplies by an odd 64-bit constant in 128-bit arithmetic, then
/// folds the high half of the product into the low half.
#[inline]
fn hash_with_seed(hash: u64, seed: u32) -> u64 {
    const MULTIPLIER: u128 = 0x5851_f42d_4c95_7f2d;
    let wide = (u128::from(hash) ^ u128::from(seed)).wrapping_mul(MULTIPLIER);
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    low ^ high
}
/// Maps `x` into `0..n` with Lemire's multiply-shift reduction, `(x * n) >> 32`.
///
/// This avoids a division. Note that the result is not `x % n`, and it is `0` when
/// `n == 0`.
#[inline]
fn fastmod32(x: u32, n: u32) -> usize {
    let wide = u64::from(x) * u64::from(n);
    (wide >> 32) as usize
}
#[cfg(feature = "rkyv_derive")]
impl<const B: usize, const S: usize, ST, H> ArchivedMphf<B, S, ST, H>
where
    ST: PrimInt + Unsigned + rkyv::Archive<Archived = ST>,
    H: Hasher + Default,
{
    /// Looks up `key` directly in the archived (zero-copy) representation.
    ///
    /// Delegates to the same lookup routine as `Mphf::get`, so both return the same index
    /// for the same key.
    #[inline]
    pub fn get<K: Hash + ?Sized>(&self, key: &K) -> Option<usize> {
        let Self { ranked_bits, level_groups, group_seeds, .. } = self;
        Mphf::<B, S, ST, H>::get_impl(key, level_groups, group_seeds, ranked_bits)
    }
}
// Regression tests: correctness (every key gets a distinct index below n) plus exact
// size/level summaries for a range of B, S, n and gamma settings.
#[cfg(test)]
mod tests {
use super::*;
use paste::paste;
use std::collections::HashSet;
use test_case::test_case;
/// Builds an `Mphf<B, S>` over the keys `0..n` with the given `gamma`, checks that every
/// key maps to a distinct index below `n`, and returns a summary string for regression
/// comparison: bits per key, number of levels, and group-weighted average level.
fn test_mphfs_impl<const B: usize, const S: usize>(n: usize, gamma: f32) -> String {
let keys = (0..n as u64).collect::<Vec<u64>>();
let mphf = Mphf::<B, S>::from_slice(&keys, gamma).expect("failed to create mphf");
let mut set = HashSet::with_capacity(n);
for key in &keys {
let idx = mphf.get(key).unwrap();
assert!(idx < n, "idx = {} n = {}", idx, n);
if !set.insert(idx) {
panic!("duplicate idx = {} for key {}", idx, key);
}
}
assert_eq!(set.len(), n);
// Average 1-based level number, weighted by the number of groups per level (not by keys).
let mut avg_levels = 0f32;
let total_groups: u32 = mphf.level_groups.iter().sum();
for (i, &groups) in mphf.level_groups.iter().enumerate() {
avg_levels += ((i + 1) as f32 * groups as f32) / (total_groups as f32);
}
// `size()` is in bytes; convert to bits per key.
let bits = mphf.size() as f32 * (8.0 / n as f32);
format!(
"bits: {:.2} total_levels: {} avg_levels: {:.2}",
bits,
mphf.level_groups.len(),
avg_levels
)
}
// Generates one `test_case` per tuple `(B, S, n, gamma * 100, expected summary)`.
// Gamma is given as an integer scaled by 100 so it can appear in the generated test name.
macro_rules! generate_tests {
($(($b:expr, $s:expr, $n: expr, $gamma:expr, $expected:expr)),* $(,)?) => {
$(
paste! {
#[test_case($n, $gamma => $expected)]
fn [<test_mphfs_ $b _ $s _ $n _ $gamma>](n: usize, gamma_scaled: usize) -> String {
let gamma = (gamma_scaled as f32) / 100.0;
test_mphfs_impl::<$b, $s>(n, gamma)
}
}
)*
};
}
// Expected summaries are exact regression values: any change to hashing or level
// construction will change them.
generate_tests!(
(1, 8, 10000, 100, "bits: 26.64 total_levels: 42 avg_levels: 4.34"),
(2, 8, 10000, 100, "bits: 9.00 total_levels: 8 avg_levels: 1.76"),
(4, 8, 10000, 100, "bits: 4.39 total_levels: 6 avg_levels: 1.42"),
(7, 8, 10000, 100, "bits: 3.12 total_levels: 4 avg_levels: 1.39"),
(8, 8, 10000, 100, "bits: 2.80 total_levels: 6 avg_levels: 1.34"),
(15, 8, 10000, 100, "bits: 2.50 total_levels: 4 avg_levels: 1.50"),
(16, 8, 10000, 100, "bits: 2.30 total_levels: 6 avg_levels: 1.43"),
(23, 8, 10000, 100, "bits: 2.53 total_levels: 4 avg_levels: 1.67"),
(24, 8, 10000, 100, "bits: 2.25 total_levels: 6 avg_levels: 1.57"),
(31, 8, 10000, 100, "bits: 2.40 total_levels: 3 avg_levels: 1.44"),
(32, 8, 10000, 100, "bits: 2.20 total_levels: 7 avg_levels: 1.63"),
(33, 8, 10000, 100, "bits: 2.52 total_levels: 4 avg_levels: 1.78"),
(48, 8, 10000, 100, "bits: 2.25 total_levels: 7 avg_levels: 1.78"),
(53, 8, 10000, 100, "bits: 2.90 total_levels: 4 avg_levels: 2.00"),
(61, 8, 10000, 100, "bits: 2.82 total_levels: 4 avg_levels: 2.00"),
(63, 8, 10000, 100, "bits: 2.89 total_levels: 4 avg_levels: 2.00"),
(64, 8, 10000, 100, "bits: 2.25 total_levels: 8 avg_levels: 1.84"),
(32, 7, 10000, 100, "bits: 2.29 total_levels: 7 avg_levels: 1.70"),
(32, 5, 10000, 100, "bits: 2.47 total_levels: 8 avg_levels: 1.84"),
(32, 4, 10000, 100, "bits: 2.58 total_levels: 9 avg_levels: 1.92"),
(32, 3, 10000, 100, "bits: 2.75 total_levels: 10 avg_levels: 2.05"),
(32, 1, 10000, 100, "bits: 3.22 total_levels: 11 avg_levels: 2.39"),
(32, 0, 10000, 100, "bits: 3.65 total_levels: 14 avg_levels: 2.73"),
(32, 8, 100000, 100, "bits: 2.11 total_levels: 10 avg_levels: 1.64"),
(32, 8, 100000, 200, "bits: 2.73 total_levels: 4 avg_levels: 1.06"),
(32, 6, 100000, 200, "bits: 2.84 total_levels: 5 avg_levels: 1.11"),
);
// Serializes an `Mphf` with rkyv, checks the archived byte length, validates the bytes,
// and verifies the archived `get` agrees with the in-memory `get` for every key.
#[cfg(feature = "rkyv_derive")]
#[test]
fn test_rkyv() {
let n = 10000;
let keys = (0..n as u64).collect::<Vec<u64>>();
let mphf = Mphf::<32, 4>::from_slice(&keys, DEFAULT_GAMMA).expect("failed to create mphf");
let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&mphf).unwrap();
assert_eq!(rkyv_bytes.len(), 3804);
let rkyv_mphf = rkyv::check_archived_root::<Mphf<32, 4>>(&rkyv_bytes).unwrap();
let mut set = HashSet::with_capacity(n);
for key in &keys {
let idx = mphf.get(key).unwrap();
let rkyv_idx = rkyv_mphf.get(key).unwrap();
assert_eq!(idx, rkyv_idx);
assert!(idx < n, "idx = {} n = {}", idx, n);
if !set.insert(idx) {
panic!("duplicate idx = {} for key {}", idx, key);
}
}
assert_eq!(set.len(), n);
}
}