// entropy_map/mphf.rs

//! # Minimal Perfect Hash Function (MPHF) Module
//!
//! This module implements a Minimal Perfect Hash Function (MPHF) based on fingerprinting techniques,
//! as detailed in [Fingerprinting-based minimal perfect hashing revisited](https://doi.org/10.1145/3596453).
//!
//! This implementation is inspired by the existing Rust crate [ph](https://github.com/beling/bsuccinct-rs/tree/main/ph),
//! but prioritizes code simplicity and portability, with a special focus on optimizing the rank
//! storage mechanism and reducing the construction time and querying latency of MPHF.

use std::fmt;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::mem::size_of_val;

use num::{Integer, PrimInt, Unsigned};
use wyhash::WyHash;

use crate::mphf::MphfError::*;
use crate::rank::{RankedBits, RankedBitsAccess};
19
/// A Minimal Perfect Hash Function (MPHF).
///
/// Maps each key of the construction set to a unique index in `0..n` using a
/// multi-level fingerprinting scheme: keys that collide at one level are retried
/// at the next with a fresh hash, and per-group seeds minimize collisions.
///
/// Template parameters:
/// - `B`: group size in bits in [1..64] range, default 32 bits.
/// - `S`: defines maximum seed value to try (2^S) in [0..16] range, default 8.
/// - `ST`: seed type (unsigned integer), default `u8`.
/// - `H`: hasher used to hash keys, default `WyHash`.
#[derive(Default)]
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct Mphf<const B: usize = 32, const S: usize = 8, ST: PrimInt + Unsigned = u8, H: Hasher + Default = WyHash> {
    /// Concatenated per-level bitmaps with rank support; a set bit marks a key
    /// placed without collision, and its rank is the key's final index.
    ranked_bits: RankedBits,
    /// Number of groups at each level (one entry per level).
    level_groups: Box<[u32]>,
    /// Combined group seeds from all levels, concatenated in level order.
    group_seeds: Box<[ST]>,
    /// Phantom field tying the hasher type `H` to the struct without storing it.
    _phantom_hasher: PhantomData<H>,
}
40
/// Maximum number of levels to build for MPHF.
/// Construction aborts with `MaxLevelsExceeded` if keys still collide after
/// this many levels.
const MAX_LEVELS: usize = 64;
43
/// Errors that can occur when initializing `Mphf`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MphfError {
    /// Error when the maximum number of levels is exceeded during initialization.
    MaxLevelsExceeded,
    /// Error when the seed type `ST` is too small to store `S` bits.
    InvalidSeedType,
    /// Error when the `gamma` parameter is less than 1.0.
    InvalidGammaParameter,
}

impl fmt::Display for MphfError {
    /// Formats a human-readable description of the construction failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MphfError::MaxLevelsExceeded => {
                write!(f, "maximum number of levels exceeded during MPHF construction")
            }
            MphfError::InvalidSeedType => {
                write!(f, "seed type `ST` is too small to store `S` bits")
            }
            MphfError::InvalidGammaParameter => {
                write!(f, "`gamma` parameter must be >= 1.0")
            }
        }
    }
}

// Implementing `Error` lets callers use `?`, `Box<dyn Error>`, etc.
impl std::error::Error for MphfError {}
54
/// Default `gamma` parameter for MPHF: bits reserved per remaining key at each
/// level. Must be >= 1.0; larger values trade space for fewer levels.
pub const DEFAULT_GAMMA: f32 = 2.0;
57
impl<const B: usize, const S: usize, ST: PrimInt + Unsigned, H: Hasher + Default> Mphf<B, S, ST, H> {
    /// Ensure that `B` is in [1..64] range.
    // NOTE: referencing `Self::B` (rather than the raw `B`) anywhere in the impl
    // forces this compile-time assertion to be evaluated.
    const B: usize = {
        assert!(B >= 1 && B <= 64);
        B
    };
    /// Ensure that `S` is in [0..16] range
    const S: usize = {
        assert!(S <= 16);
        S
    };

    /// Initializes `Mphf` using slice of `keys` and parameter `gamma`.
    ///
    /// `gamma` (>= 1.0) controls the space/time trade-off: each level reserves
    /// `ceil(gamma * remaining_keys)` bits, so larger values reduce collisions
    /// (and level count) at the cost of a larger structure.
    ///
    /// # Errors
    /// - `InvalidGammaParameter` if `gamma < 1.0`.
    /// - `InvalidSeedType` if `ST` cannot represent the maximum seed `2^S - 1`.
    /// - `MaxLevelsExceeded` if keys still collide after `MAX_LEVELS` levels.
    pub fn from_slice<K: Hash>(keys: &[K], gamma: f32) -> Result<Self, MphfError> {
        if gamma < 1.0 {
            return Err(InvalidGammaParameter);
        }

        // `ST` must be wide enough to hold every seed value in `0..2^S`.
        if ST::from((1 << Self::S) - 1).is_none() {
            return Err(InvalidSeedType);
        }

        let mut hashes: Vec<u64> = keys.iter().map(|key| hash_key::<H, _>(key)).collect();
        let mut group_bits = vec![];
        let mut group_seeds = vec![];
        let mut level_groups = vec![];

        // Build one level at a time; `build_level` drains the hashes it managed
        // to place and leaves the collided ones for the next level.
        while !hashes.is_empty() {
            let level = level_groups.len() as u32;
            let (level_group_bits, level_group_seeds) = Self::build_level(level, &mut hashes, gamma);

            group_bits.extend_from_slice(&level_group_bits);
            group_seeds.extend_from_slice(&level_group_seeds);
            level_groups.push(level_group_seeds.len() as u32);

            if level_groups.len() == MAX_LEVELS && !hashes.is_empty() {
                return Err(MaxLevelsExceeded);
            }
        }

        Ok(Mphf {
            ranked_bits: RankedBits::new(group_bits.into_boxed_slice()),
            level_groups: level_groups.into_boxed_slice(),
            group_seeds: group_seeds.into_boxed_slice(),
            _phantom_hasher: PhantomData,
        })
    }

    /// Builds specified `level` using provided `hashes` and returns level group bits and seeds.
    ///
    /// On return, `hashes` retains only the keys that collided at this level
    /// (those whose bit is not set in the returned group bits).
    fn build_level(level: u32, hashes: &mut Vec<u64>, gamma: f32) -> (Vec<u64>, Vec<ST>) {
        // compute level size (#bits storing non-collided hashes), number of groups and segments
        let level_size = ((hashes.len() as f32) * gamma).ceil() as usize;
        let (groups, segments) = Self::level_size_groups_segments(level_size);
        let max_group_seed = 1 << S;

        // Reserve x3 bits for all segments to reduce cache misses when updating/fetching group bits.
        // Every 3 consecutive elements represent:
        // - 0: hashes bits set for current seed
        // - 1: hashes collision bits set for current seed
        // - 2: hashes bits set for best seed
        // The extra `+ 3` padding lets `update_group_bits_with_seed` read two
        // consecutive triplets (`idx..idx + 6`) for the last group in bounds.
        let mut group_bits = vec![0u64; 3 * segments + 3];
        let mut best_group_seeds = vec![ST::zero(); groups];

        // For each seed compute `group_bits` and then update those groups where seed produced less collisions
        for group_seed in 0..max_group_seed {
            Self::update_group_bits_with_seed(
                level,
                groups,
                group_seed,
                hashes,
                &mut group_bits,
                &mut best_group_seeds,
            );
        }

        // finalize best group bits to be returned: element 2 of every triplet,
        // skipping the trailing padding triplet
        let best_group_bits: Vec<u64> = group_bits[..group_bits.len() - 3]
            .chunks_exact(3)
            .map(|group_bits| group_bits[2])
            .collect();

        // filter out hashes which are already stored in `best_group_bits`
        hashes.retain(|&hash| {
            let level_hash = hash_with_seed(hash, level);
            let group_idx = fastmod32(level_hash as u32, groups as u32);
            let group_seed = best_group_seeds[group_idx].to_u32().unwrap();
            let bit_idx = bit_index_for_seed::<B>(level_hash, group_seed, group_idx);
            // SAFETY: `bit_idx` is always within bounds (ensured during calculation)
            *unsafe { best_group_bits.get_unchecked(bit_idx / 64) } & (1 << (bit_idx % 64)) == 0
        });

        (best_group_bits, best_group_seeds)
    }

    /// Returns number of groups and 64-bit segments for given `size`.
    ///
    /// The size is rounded up to a multiple of `lcm(B, 64)` so that both the
    /// `B`-bit groups and the 64-bit segments tile the level exactly.
    #[inline]
    fn level_size_groups_segments(size: usize) -> (usize, usize) {
        // Calculate the least common multiple of 64 and B
        let lcm_value = Self::B.lcm(&64);

        // Adjust size to the nearest value that is a multiple of the LCM
        let adjusted_size = size.div_ceil(lcm_value) * lcm_value;

        (adjusted_size / Self::B, adjusted_size / 64)
    }

    /// Computes group bits for given seed and then updates those groups where seed produced least collisions.
    #[inline]
    fn update_group_bits_with_seed(
        level: u32,
        groups: usize,
        group_seed: u32,
        hashes: &[u64],
        group_bits: &mut [u64],
        best_group_seeds: &mut [ST],
    ) {
        // Reset all group bits except best group bits (element 2 of each triplet)
        let group_bits_len = group_bits.len();
        for bits in group_bits[..group_bits_len - 3].chunks_exact_mut(3) {
            bits[0] = 0;
            bits[1] = 0;
        }

        // For each hash compute group bits and collision bits
        for &hash in hashes {
            let level_hash = hash_with_seed(hash, level);
            let group_idx = fastmod32(level_hash as u32, groups as u32);
            let bit_idx = bit_index_for_seed::<B>(level_hash, group_seed, group_idx);
            let mask = 1 << (bit_idx % 64);
            let idx = (bit_idx / 64) * 3;

            // SAFETY: `idx` is always within bounds (ensured during calculation)
            let bits = unsafe { group_bits.get_unchecked_mut(idx..idx + 2) };

            // A bit already set in `bits[0]` means a second hash landed on the
            // same position: record it as a collision in `bits[1]`.
            bits[1] |= bits[0] & mask;
            bits[0] |= mask;
        }

        // Filter out collided bits from group bits
        for bits in group_bits.chunks_exact_mut(3) {
            bits[0] &= !bits[1];
        }

        // Update best group bits and seeds
        for (group_idx, best_group_seed) in best_group_seeds.iter_mut().enumerate() {
            let bit_idx = group_idx * Self::B;
            let bit_pos = bit_idx % 64;
            let idx = (bit_idx / 64) * 3;

            // SAFETY: `idx` is always within bounds (ensured during calculation)
            let bits = unsafe { group_bits.get_unchecked_mut(idx..idx + 6) };

            // A `B`-bit group may straddle a 64-bit segment boundary: `bits_1`
            // bits live in this segment, the remaining `bits_2` in the next
            // triplet (`bits[3..6]`).
            let bits_1 = Self::B.min(64 - bit_pos);
            let bits_2 = Self::B - bits_1;
            let mask_1 = u64::MAX >> (64 - bits_1);
            let mask_2 = (1 << bits_2) - 1;

            let new_bits_1 = (bits[0] >> bit_pos) & mask_1;
            let new_bits_2 = bits[3] & mask_2;
            let new_ones = new_bits_1.count_ones() + new_bits_2.count_ones();

            let best_bits_1 = (bits[2] >> bit_pos) & mask_1;
            let best_bits_2 = bits[5] & mask_2;
            let best_ones = best_bits_1.count_ones() + best_bits_2.count_ones();

            // More set bits == more keys placed without collision, so this
            // seed becomes the group's new best.
            if new_ones > best_ones {
                bits[2] &= !(mask_1 << bit_pos);
                bits[2] |= new_bits_1 << bit_pos;

                bits[5] &= !mask_2;
                bits[5] |= new_bits_2;

                *best_group_seed = ST::from(group_seed).unwrap();
            }
        }
    }

    /// Returns the index associated with `key`, within 0 to the key collection size (exclusive).
    /// If `key` was not in the initial collection, returns `None` or an arbitrary value from the range.
    #[inline]
    pub fn get<K: Hash + ?Sized>(&self, key: &K) -> Option<usize> {
        Self::get_impl(key, &self.level_groups, &self.group_seeds, &self.ranked_bits)
    }

    /// Inner implementation of `get` with `level_groups`, `group_seeds` and `ranked_bits` passed
    /// from standard and `Archived` version of `Mphf`.
    #[inline]
    fn get_impl<K: Hash + ?Sized>(
        key: &K,
        level_groups: &[u32],
        group_seeds: &[ST],
        ranked_bits: &impl RankedBitsAccess,
    ) -> Option<usize> {
        let mut groups_before = 0;
        // Probe levels in order; the first level where the key's bit is set
        // yields its rank, which is the final index.
        for (level, &groups) in level_groups.iter().enumerate() {
            let level_hash = hash_with_seed(hash_key::<H, _>(key), level as u32);
            let group_idx = groups_before + fastmod32(level_hash as u32, groups);
            // SAFETY: `group_idx` is always within bounds (ensured during calculation)
            let group_seed = unsafe { group_seeds.get_unchecked(group_idx).to_u32().unwrap() };
            let bit_idx = bit_index_for_seed::<B>(level_hash, group_seed, group_idx);
            if let Some(rank) = ranked_bits.rank(bit_idx) {
                return Some(rank);
            }
            groups_before += groups as usize;
        }

        None
    }

    /// Returns the total number of bytes occupied by `Mphf`
    pub fn size(&self) -> usize {
        size_of_val(self)
            + size_of_val(self.level_groups.as_ref())
            + size_of_val(self.group_seeds.as_ref())
            + self.ranked_bits.size()
    }
}
275
/// Computes a 64-bit hash for the given key using a freshly-constructed
/// instance of the hasher type `H`.
#[inline]
fn hash_key<H: Hasher + Default, T: Hash + ?Sized>(key: &T) -> u64 {
    let mut state = H::default();
    key.hash(&mut state);
    state.finish()
}
283
/// Computes bit index based on `hash`, `group_seed`, `groups_before` and const `B`.
///
/// The result lies in `[groups_before * B, groups_before * B + B)`, i.e. inside
/// the `B`-bit window of the group that follows `groups_before` groups.
#[inline]
fn bit_index_for_seed<const B: usize>(hash: u64, group_seed: u32, groups_before: usize) -> usize {
    // Mix the low 32 bits of the hash with the group seed.
    let mut v = (hash as u32) ^ group_seed;

    // MurmurHash3's finalizer step to avalanche the bits.
    v ^= v >> 16;
    v = v.wrapping_mul(0x85ebca6b);
    v ^= v >> 13;
    v = v.wrapping_mul(0xc2b2ae35);
    v ^= v >> 16;

    // Reduce into [0, B) via Lemire's multiply-shift trick and offset by the
    // bits occupied by the preceding groups.
    groups_before * B + (((v as u64) * (B as u64)) >> 32) as usize
}
297
/// Combines a 64-bit hash with a 32-bit seed, then multiplies by a prime constant to enhance hash uniformity and reduces the result back to 64 bits.
#[inline]
fn hash_with_seed(hash: u64, seed: u32) -> u64 {
    // Widen to 128 bits, fold in the seed and multiply by a 64-bit prime.
    let mixed = ((hash as u128) ^ (seed as u128)).wrapping_mul(0x5851f42d4c957f2d);
    // Fold the high half into the low half to keep entropy from both.
    (mixed as u64) ^ ((mixed >> 64) as u64)
}
304
/// A fast alternative to the modulo reduction: maps `x` into `[0, n)` without
/// a division instruction.
/// More details: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
#[inline]
fn fastmod32(x: u32, n: u32) -> usize {
    // The high 32 bits of the 64-bit product `x * n` fall in [0, n).
    let wide = u64::from(x) * u64::from(n);
    (wide >> 32) as usize
}
311
/// Implement `get` for `Archived` version of `Mphf` if feature is enabled.
///
/// The `ST: rkyv::Archive<Archived = ST>` bound requires the seed type to
/// archive to itself, so the archived seeds slice can be passed to `get_impl`
/// unchanged.
#[cfg(feature = "rkyv_derive")]
impl<const B: usize, const S: usize, ST, H> ArchivedMphf<B, S, ST, H>
where
    ST: PrimInt + Unsigned + rkyv::Archive<Archived = ST>,
    H: Hasher + Default,
{
    /// Returns the index associated with `key`, delegating to
    /// `Mphf::get_impl` so the archived (zero-copy) form answers queries
    /// identically to the owned `Mphf`.
    #[inline]
    pub fn get<K: Hash + ?Sized>(&self, key: &K) -> Option<usize> {
        Mphf::<B, S, ST, H>::get_impl(key, &self.level_groups, &self.group_seeds, &self.ranked_bits)
    }
}
324
#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;
    use std::collections::HashSet;
    use test_case::test_case;

    /// Helper function that contains the test logic: builds an MPHF over the
    /// keys `0..n`, verifies it is minimal and perfect (every key maps to a
    /// unique index < n), and returns a stats string (bits per key, total
    /// levels, average levels probed) that generated cases compare against.
    fn test_mphfs_impl<const B: usize, const S: usize>(n: usize, gamma: f32) -> String {
        let keys = (0..n as u64).collect::<Vec<u64>>();
        let mphf = Mphf::<B, S>::from_slice(&keys, gamma).expect("failed to create mphf");

        // Ensure that all keys are assigned unique index which is less than `n`
        let mut set = HashSet::with_capacity(n);
        for key in &keys {
            let idx = mphf.get(key).unwrap();
            assert!(idx < n, "idx = {} n = {}", idx, n);
            if !set.insert(idx) {
                panic!("duplicate idx = {} for key {}", idx, key);
            }
        }
        assert_eq!(set.len(), n);

        // Compute average number of levels which needed to be accessed during `get`
        // (weighted by the fraction of groups living at each level).
        let mut avg_levels = 0f32;
        let total_groups: u32 = mphf.level_groups.iter().sum();
        for (i, &groups) in mphf.level_groups.iter().enumerate() {
            avg_levels += ((i + 1) as f32 * groups as f32) / (total_groups as f32);
        }
        let bits = mphf.size() as f32 * (8.0 / n as f32);

        format!(
            "bits: {:.2} total_levels: {} avg_levels: {:.2}",
            bits,
            mphf.level_groups.len(),
            avg_levels
        )
    }

    /// Macro to generate test functions for various B and S constants.
    /// `gamma` is passed scaled by 100 (e.g. 200 => 2.0) so that it can be
    /// embedded in the generated test function name.
    macro_rules! generate_tests {
        ($(($b:expr, $s:expr, $n: expr, $gamma:expr, $expected:expr)),* $(,)?) => {
            $(
                paste! {
                    #[test_case($n, $gamma => $expected)]
                    fn [<test_mphfs_ $b _ $s _ $n _ $gamma>](n: usize, gamma_scaled: usize) -> String {
                        let gamma = (gamma_scaled as f32) / 100.0;
                        test_mphfs_impl::<$b, $s>(n, gamma)
                    }
                }
            )*
        };
    }

    // Generate test functions for different combinations of B and S
    generate_tests!(
        (1, 8, 10000, 100, "bits: 26.64 total_levels: 42 avg_levels: 4.34"),
        (2, 8, 10000, 100, "bits: 9.00 total_levels: 8 avg_levels: 1.76"),
        (4, 8, 10000, 100, "bits: 4.39 total_levels: 6 avg_levels: 1.42"),
        (7, 8, 10000, 100, "bits: 3.12 total_levels: 4 avg_levels: 1.39"),
        (8, 8, 10000, 100, "bits: 2.80 total_levels: 6 avg_levels: 1.34"),
        (15, 8, 10000, 100, "bits: 2.50 total_levels: 4 avg_levels: 1.50"),
        (16, 8, 10000, 100, "bits: 2.30 total_levels: 6 avg_levels: 1.43"),
        (23, 8, 10000, 100, "bits: 2.53 total_levels: 4 avg_levels: 1.67"),
        (24, 8, 10000, 100, "bits: 2.25 total_levels: 6 avg_levels: 1.57"),
        (31, 8, 10000, 100, "bits: 2.40 total_levels: 3 avg_levels: 1.44"),
        (32, 8, 10000, 100, "bits: 2.20 total_levels: 7 avg_levels: 1.63"),
        (33, 8, 10000, 100, "bits: 2.52 total_levels: 4 avg_levels: 1.78"),
        (48, 8, 10000, 100, "bits: 2.25 total_levels: 7 avg_levels: 1.78"),
        (53, 8, 10000, 100, "bits: 2.90 total_levels: 4 avg_levels: 2.00"),
        (61, 8, 10000, 100, "bits: 2.82 total_levels: 4 avg_levels: 2.00"),
        (63, 8, 10000, 100, "bits: 2.89 total_levels: 4 avg_levels: 2.00"),
        (64, 8, 10000, 100, "bits: 2.25 total_levels: 8 avg_levels: 1.84"),
        (32, 7, 10000, 100, "bits: 2.29 total_levels: 7 avg_levels: 1.70"),
        (32, 5, 10000, 100, "bits: 2.47 total_levels: 8 avg_levels: 1.84"),
        (32, 4, 10000, 100, "bits: 2.58 total_levels: 9 avg_levels: 1.92"),
        (32, 3, 10000, 100, "bits: 2.75 total_levels: 10 avg_levels: 2.05"),
        (32, 1, 10000, 100, "bits: 3.22 total_levels: 11 avg_levels: 2.39"),
        (32, 0, 10000, 100, "bits: 3.65 total_levels: 14 avg_levels: 2.73"),
        (32, 8, 100000, 100, "bits: 2.11 total_levels: 10 avg_levels: 1.64"),
        (32, 8, 100000, 200, "bits: 2.73 total_levels: 4 avg_levels: 1.06"),
        (32, 6, 100000, 200, "bits: 2.84 total_levels: 5 avg_levels: 1.11"),
    );

    /// Round-trips an MPHF through rkyv serialization and verifies that the
    /// archived (zero-copy) form answers every query identically.
    #[cfg(feature = "rkyv_derive")]
    #[test]
    fn test_rkyv() {
        let n = 10000;
        let keys = (0..n as u64).collect::<Vec<u64>>();
        let mphf = Mphf::<32, 4>::from_slice(&keys, DEFAULT_GAMMA).expect("failed to create mphf");
        let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&mphf).unwrap();

        // Pins the serialized size; changes to the struct layout will fail here.
        assert_eq!(rkyv_bytes.len(), 3804);

        let rkyv_mphf = rkyv::check_archived_root::<Mphf<32, 4>>(&rkyv_bytes).unwrap();

        // Ensure that all keys are assigned unique index which is less than `n`
        let mut set = HashSet::with_capacity(n);
        for key in &keys {
            let idx = mphf.get(key).unwrap();
            let rkyv_idx = rkyv_mphf.get(key).unwrap();

            assert_eq!(idx, rkyv_idx);
            assert!(idx < n, "idx = {} n = {}", idx, n);
            if !set.insert(idx) {
                panic!("duplicate idx = {} for key {}", idx, key);
            }
        }
        assert_eq!(set.len(), n);
    }
}