1use std::hash::{Hash, Hasher};
11use std::marker::PhantomData;
12use std::mem::size_of_val;
13
14use num::{Integer, PrimInt, Unsigned};
15use wyhash::WyHash;
16
17use crate::mphf::MphfError::*;
18use crate::rank::{RankedBits, RankedBitsAccess};
19
/// Minimal perfect hash function over a fixed key set.
///
/// Type parameters:
/// - `B`: bits per group (validated at compile time to be in `1..=64`)
/// - `S`: bits per group seed (validated at compile time to be `<= 16`)
/// - `ST`: unsigned integer type used to store one seed per group
/// - `H`: hasher used to hash keys (must be deterministic per process)
#[derive(Default)]
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct Mphf<const B: usize = 32, const S: usize = 8, ST: PrimInt + Unsigned = u8, H: Hasher + Default = WyHash> {
    // Concatenated per-level group bits with rank support; a set bit at index i
    // means "some key resolved to slot i", and rank(i) is the key's final index.
    ranked_bits: RankedBits,
    // Number of groups on each level, in level order.
    level_groups: Box<[u32]>,
    // Winning seed for every group, concatenated across levels.
    group_seeds: Box<[ST]>,
    // H is only used at call time, never stored.
    _phantom_hasher: PhantomData<H>,
}
40
/// Hard cap on the number of levels tried before construction gives up.
const MAX_LEVELS: usize = 64;
43
/// Errors returned by [`Mphf::from_slice`].
#[derive(Debug)]
pub enum MphfError {
    /// Keys were still unplaced after `MAX_LEVELS` levels were built.
    MaxLevelsExceeded,
    /// The seed type `ST` cannot represent the maximum `S`-bit seed value.
    InvalidSeedType,
    /// `gamma` was below the minimum supported value of `1.0`.
    InvalidGammaParameter,
}
54
/// Default space/speed trade-off parameter (bits allocated per remaining key and level).
pub const DEFAULT_GAMMA: f32 = 2.0;
57
58impl<const B: usize, const S: usize, ST: PrimInt + Unsigned, H: Hasher + Default> Mphf<B, S, ST, H> {
59 const B: usize = {
61 assert!(B >= 1 && B <= 64);
62 B
63 };
64 const S: usize = {
66 assert!(S <= 16);
67 S
68 };
69
70 pub fn from_slice<K: Hash>(keys: &[K], gamma: f32) -> Result<Self, MphfError> {
72 if gamma < 1.0 {
73 return Err(InvalidGammaParameter);
74 }
75
76 if ST::from((1 << Self::S) - 1).is_none() {
77 return Err(InvalidSeedType);
78 }
79
80 let mut hashes: Vec<u64> = keys.iter().map(|key| hash_key::<H, _>(key)).collect();
81 let mut group_bits = vec![];
82 let mut group_seeds = vec![];
83 let mut level_groups = vec![];
84
85 while !hashes.is_empty() {
86 let level = level_groups.len() as u32;
87 let (level_group_bits, level_group_seeds) = Self::build_level(level, &mut hashes, gamma);
88
89 group_bits.extend_from_slice(&level_group_bits);
90 group_seeds.extend_from_slice(&level_group_seeds);
91 level_groups.push(level_group_seeds.len() as u32);
92
93 if level_groups.len() == MAX_LEVELS && !hashes.is_empty() {
94 return Err(MaxLevelsExceeded);
95 }
96 }
97
98 Ok(Mphf {
99 ranked_bits: RankedBits::new(group_bits.into_boxed_slice()),
100 level_groups: level_groups.into_boxed_slice(),
101 group_seeds: group_seeds.into_boxed_slice(),
102 _phantom_hasher: PhantomData,
103 })
104 }
105
106 fn build_level(level: u32, hashes: &mut Vec<u64>, gamma: f32) -> (Vec<u64>, Vec<ST>) {
108 let level_size = ((hashes.len() as f32) * gamma).ceil() as usize;
110 let (groups, segments) = Self::level_size_groups_segments(level_size);
111 let max_group_seed = 1 << S;
112
113 let mut group_bits = vec![0u64; 3 * segments + 3];
119 let mut best_group_seeds = vec![ST::zero(); groups];
120
121 for group_seed in 0..max_group_seed {
123 Self::update_group_bits_with_seed(
124 level,
125 groups,
126 group_seed,
127 hashes,
128 &mut group_bits,
129 &mut best_group_seeds,
130 );
131 }
132
133 let best_group_bits: Vec<u64> = group_bits[..group_bits.len() - 3]
135 .chunks_exact(3)
136 .map(|group_bits| group_bits[2])
137 .collect();
138
139 hashes.retain(|&hash| {
141 let level_hash = hash_with_seed(hash, level);
142 let group_idx = fastmod32(level_hash as u32, groups as u32);
143 let group_seed = best_group_seeds[group_idx].to_u32().unwrap();
144 let bit_idx = bit_index_for_seed::<B>(level_hash, group_seed, group_idx);
145 *unsafe { best_group_bits.get_unchecked(bit_idx / 64) } & (1 << (bit_idx % 64)) == 0
147 });
148
149 (best_group_bits, best_group_seeds)
150 }
151
152 #[inline]
154 fn level_size_groups_segments(size: usize) -> (usize, usize) {
155 let lcm_value = Self::B.lcm(&64);
157
158 let adjusted_size = size.div_ceil(lcm_value) * lcm_value;
160
161 (adjusted_size / Self::B, adjusted_size / 64)
162 }
163
164 #[inline]
166 fn update_group_bits_with_seed(
167 level: u32,
168 groups: usize,
169 group_seed: u32,
170 hashes: &[u64],
171 group_bits: &mut [u64],
172 best_group_seeds: &mut [ST],
173 ) {
174 let group_bits_len = group_bits.len();
176 for bits in group_bits[..group_bits_len - 3].chunks_exact_mut(3) {
177 bits[0] = 0;
178 bits[1] = 0;
179 }
180
181 for &hash in hashes {
183 let level_hash = hash_with_seed(hash, level);
184 let group_idx = fastmod32(level_hash as u32, groups as u32);
185 let bit_idx = bit_index_for_seed::<B>(level_hash, group_seed, group_idx);
186 let mask = 1 << (bit_idx % 64);
187 let idx = (bit_idx / 64) * 3;
188
189 let bits = unsafe { group_bits.get_unchecked_mut(idx..idx + 2) };
191
192 bits[1] |= bits[0] & mask;
193 bits[0] |= mask;
194 }
195
196 for bits in group_bits.chunks_exact_mut(3) {
198 bits[0] &= !bits[1];
199 }
200
201 for (group_idx, best_group_seed) in best_group_seeds.iter_mut().enumerate() {
203 let bit_idx = group_idx * Self::B;
204 let bit_pos = bit_idx % 64;
205 let idx = (bit_idx / 64) * 3;
206
207 let bits = unsafe { group_bits.get_unchecked_mut(idx..idx + 6) };
209
210 let bits_1 = Self::B.min(64 - bit_pos);
211 let bits_2 = Self::B - bits_1;
212 let mask_1 = u64::MAX >> (64 - bits_1);
213 let mask_2 = (1 << bits_2) - 1;
214
215 let new_bits_1 = (bits[0] >> bit_pos) & mask_1;
216 let new_bits_2 = bits[3] & mask_2;
217 let new_ones = new_bits_1.count_ones() + new_bits_2.count_ones();
218
219 let best_bits_1 = (bits[2] >> bit_pos) & mask_1;
220 let best_bits_2 = bits[5] & mask_2;
221 let best_ones = best_bits_1.count_ones() + best_bits_2.count_ones();
222
223 if new_ones > best_ones {
224 bits[2] &= !(mask_1 << bit_pos);
225 bits[2] |= new_bits_1 << bit_pos;
226
227 bits[5] &= !mask_2;
228 bits[5] |= new_bits_2;
229
230 *best_group_seed = ST::from(group_seed).unwrap();
231 }
232 }
233 }
234
235 #[inline]
238 pub fn get<K: Hash + ?Sized>(&self, key: &K) -> Option<usize> {
239 Self::get_impl(key, &self.level_groups, &self.group_seeds, &self.ranked_bits)
240 }
241
242 #[inline]
245 fn get_impl<K: Hash + ?Sized>(
246 key: &K,
247 level_groups: &[u32],
248 group_seeds: &[ST],
249 ranked_bits: &impl RankedBitsAccess,
250 ) -> Option<usize> {
251 let mut groups_before = 0;
252 for (level, &groups) in level_groups.iter().enumerate() {
253 let level_hash = hash_with_seed(hash_key::<H, _>(key), level as u32);
254 let group_idx = groups_before + fastmod32(level_hash as u32, groups);
255 let group_seed = unsafe { group_seeds.get_unchecked(group_idx).to_u32().unwrap() };
257 let bit_idx = bit_index_for_seed::<B>(level_hash, group_seed, group_idx);
258 if let Some(rank) = ranked_bits.rank(bit_idx) {
259 return Some(rank);
260 }
261 groups_before += groups as usize;
262 }
263
264 None
265 }
266
267 pub fn size(&self) -> usize {
269 size_of_val(self)
270 + size_of_val(self.level_groups.as_ref())
271 + size_of_val(self.group_seeds.as_ref())
272 + self.ranked_bits.size()
273 }
274}
275
#[inline]
/// Hashes `key` to a `u64` using a freshly constructed hasher of type `H`.
fn hash_key<H: Hasher + Default, T: Hash + ?Sized>(key: &T) -> u64 {
    let mut state = H::default();
    key.hash(&mut state);
    state.finish()
}
283
#[inline]
/// Maps `(hash, group_seed)` to a global bit index inside group `groups_before`:
/// avalanche the low 32 hash bits xor the seed, then reduce onto `0..B` and
/// offset by the group's base position.
fn bit_index_for_seed<const B: usize>(hash: u64, group_seed: u32, groups_before: usize) -> usize {
    let mut mixed = (hash as u32) ^ group_seed;

    // Integer avalanche steps (murmur3 fmix32 constants).
    mixed ^= mixed >> 16;
    mixed = mixed.wrapping_mul(0x85ebca6b);
    mixed ^= mixed >> 13;
    mixed = mixed.wrapping_mul(0xc2b2ae35);
    mixed ^= mixed >> 16;

    // Lemire multiply-shift reduction of `mixed` onto 0..B, inlined.
    groups_before * B + (((mixed as u64) * (B as u64)) >> 32) as usize
}
297
#[inline]
/// Derives a level-specific 64-bit hash: widen to 128 bits, multiply by a
/// fixed odd constant (PCG's multiplier), and fold the halves together.
fn hash_with_seed(hash: u64, seed: u32) -> u64 {
    let product = (u128::from(hash) ^ u128::from(seed)).wrapping_mul(0x5851f42d4c957f2d);
    (product as u64) ^ ((product >> 64) as u64)
}
304
#[inline]
/// Lemire's multiply-shift reduction: maps `x` uniformly onto `[0, n)`
/// without a division.
fn fastmod32(x: u32, n: u32) -> usize {
    let wide = u64::from(x) * u64::from(n);
    (wide >> 32) as usize
}
311
#[cfg(feature = "rkyv_derive")]
impl<const B: usize, const S: usize, ST, H> ArchivedMphf<B, S, ST, H>
where
    ST: PrimInt + Unsigned + rkyv::Archive<Archived = ST>,
    H: Hasher + Default,
{
    /// Looks up `key` in the zero-copy archived MPHF, delegating to the same
    /// routine as [`Mphf::get`], so results match the in-memory structure.
    #[inline]
    pub fn get<K: Hash + ?Sized>(&self, key: &K) -> Option<usize> {
        Mphf::<B, S, ST, H>::get_impl(key, &self.level_groups, &self.group_seeds, &self.ranked_bits)
    }
}
324
#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;
    use std::collections::HashSet;
    use test_case::test_case;

    // Builds an MPHF over the keys 0..n, verifies it is minimal and perfect,
    // and returns a summary string (space usage and level statistics) that the
    // generated test cases match exactly as a regression check.
    fn test_mphfs_impl<const B: usize, const S: usize>(n: usize, gamma: f32) -> String {
        let keys = (0..n as u64).collect::<Vec<u64>>();
        let mphf = Mphf::<B, S>::from_slice(&keys, gamma).expect("failed to create mphf");

        // Perfection/minimality: every key maps to a distinct index below n.
        let mut set = HashSet::with_capacity(n);
        for key in &keys {
            let idx = mphf.get(key).unwrap();
            assert!(idx < n, "idx = {} n = {}", idx, n);
            if !set.insert(idx) {
                panic!("duplicate idx = {} for key {}", idx, key);
            }
        }
        assert_eq!(set.len(), n);

        // Average number of levels a lookup visits, weighted by groups per level.
        let mut avg_levels = 0f32;
        let total_groups: u32 = mphf.level_groups.iter().sum();
        for (i, &groups) in mphf.level_groups.iter().enumerate() {
            avg_levels += ((i + 1) as f32 * groups as f32) / (total_groups as f32);
        }
        // Bits of storage per key.
        let bits = mphf.size() as f32 * (8.0 / n as f32);

        format!(
            "bits: {:.2} total_levels: {} avg_levels: {:.2}",
            bits,
            mphf.level_groups.len(),
            avg_levels
        )
    }

    // Expands to one #[test_case] per (B, S, n, gamma, expected-summary) tuple.
    // gamma arrives scaled by 100 so it can appear in the generated test name.
    macro_rules! generate_tests {
        ($(($b:expr, $s:expr, $n: expr, $gamma:expr, $expected:expr)),* $(,)?) => {
            $(
                paste! {
                    #[test_case($n, $gamma => $expected)]
                    fn [<test_mphfs_ $b _ $s _ $n _ $gamma>](n: usize, gamma_scaled: usize) -> String {
                        let gamma = (gamma_scaled as f32) / 100.0;
                        test_mphfs_impl::<$b, $s>(n, gamma)
                    }
                }
            )*
        };
    }

    // Regression matrix over group width B, seed width S, key count n and gamma.
    generate_tests!(
        (1, 8, 10000, 100, "bits: 26.64 total_levels: 42 avg_levels: 4.34"),
        (2, 8, 10000, 100, "bits: 9.00 total_levels: 8 avg_levels: 1.76"),
        (4, 8, 10000, 100, "bits: 4.39 total_levels: 6 avg_levels: 1.42"),
        (7, 8, 10000, 100, "bits: 3.12 total_levels: 4 avg_levels: 1.39"),
        (8, 8, 10000, 100, "bits: 2.80 total_levels: 6 avg_levels: 1.34"),
        (15, 8, 10000, 100, "bits: 2.50 total_levels: 4 avg_levels: 1.50"),
        (16, 8, 10000, 100, "bits: 2.30 total_levels: 6 avg_levels: 1.43"),
        (23, 8, 10000, 100, "bits: 2.53 total_levels: 4 avg_levels: 1.67"),
        (24, 8, 10000, 100, "bits: 2.25 total_levels: 6 avg_levels: 1.57"),
        (31, 8, 10000, 100, "bits: 2.40 total_levels: 3 avg_levels: 1.44"),
        (32, 8, 10000, 100, "bits: 2.20 total_levels: 7 avg_levels: 1.63"),
        (33, 8, 10000, 100, "bits: 2.52 total_levels: 4 avg_levels: 1.78"),
        (48, 8, 10000, 100, "bits: 2.25 total_levels: 7 avg_levels: 1.78"),
        (53, 8, 10000, 100, "bits: 2.90 total_levels: 4 avg_levels: 2.00"),
        (61, 8, 10000, 100, "bits: 2.82 total_levels: 4 avg_levels: 2.00"),
        (63, 8, 10000, 100, "bits: 2.89 total_levels: 4 avg_levels: 2.00"),
        (64, 8, 10000, 100, "bits: 2.25 total_levels: 8 avg_levels: 1.84"),
        (32, 7, 10000, 100, "bits: 2.29 total_levels: 7 avg_levels: 1.70"),
        (32, 5, 10000, 100, "bits: 2.47 total_levels: 8 avg_levels: 1.84"),
        (32, 4, 10000, 100, "bits: 2.58 total_levels: 9 avg_levels: 1.92"),
        (32, 3, 10000, 100, "bits: 2.75 total_levels: 10 avg_levels: 2.05"),
        (32, 1, 10000, 100, "bits: 3.22 total_levels: 11 avg_levels: 2.39"),
        (32, 0, 10000, 100, "bits: 3.65 total_levels: 14 avg_levels: 2.73"),
        (32, 8, 100000, 100, "bits: 2.11 total_levels: 10 avg_levels: 1.64"),
        (32, 8, 100000, 200, "bits: 2.73 total_levels: 4 avg_levels: 1.06"),
        (32, 6, 100000, 200, "bits: 2.84 total_levels: 5 avg_levels: 1.11"),
    );

    // Round-trips an MPHF through rkyv serialization and checks that the
    // archived (zero-copy) form answers every query identically.
    #[cfg(feature = "rkyv_derive")]
    #[test]
    fn test_rkyv() {
        let n = 10000;
        let keys = (0..n as u64).collect::<Vec<u64>>();
        let mphf = Mphf::<32, 4>::from_slice(&keys, DEFAULT_GAMMA).expect("failed to create mphf");
        let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&mphf).unwrap();

        // Serialized size is a regression value for the archive layout.
        assert_eq!(rkyv_bytes.len(), 3804);

        let rkyv_mphf = rkyv::check_archived_root::<Mphf<32, 4>>(&rkyv_bytes).unwrap();

        // Archived lookups must agree with in-memory lookups and stay perfect.
        let mut set = HashSet::with_capacity(n);
        for key in &keys {
            let idx = mphf.get(key).unwrap();
            let rkyv_idx = rkyv_mphf.get(key).unwrap();

            assert_eq!(idx, rkyv_idx);
            assert!(idx < n, "idx = {} n = {}", idx, n);
            if !set.insert(idx) {
                panic!("duplicate idx = {} for key {}", idx, key);
            }
        }
        assert_eq!(set.len(), n);
    }
}