1use std::borrow::Borrow;
15use std::collections::HashMap;
16use std::hash::{Hash, Hasher};
17use std::mem::size_of_val;
18
19use bitpacking::{BitPacker, BitPacker1x};
20use num::{PrimInt, Unsigned};
21use wyhash::WyHash;
22
23use crate::mphf::{Mphf, DEFAULT_GAMMA};
24
25#[derive(Default)]
27#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
28#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
29pub struct MapWithDictBitpacked<K, const B: usize = 32, const S: usize = 8, ST = u8, H = WyHash>
30where
31 ST: PrimInt + Unsigned,
32 H: Hasher + Default,
33{
34 mphf: Mphf<B, S, ST, H>,
36 keys: Box<[K]>,
38 values_index: Box<[usize]>,
40 values_dict: Box<[u8]>,
42}
43
44#[derive(Debug)]
46pub enum Error {
47 MphfError(crate::mphf::MphfError),
49 NotEqualValuesLengths,
51}
52
53impl<K, const B: usize, const S: usize, ST, H> MapWithDictBitpacked<K, B, S, ST, H>
54where
55 K: Hash + PartialEq + Clone,
56 ST: PrimInt + Unsigned,
57 H: Hasher + Default,
58{
59 pub fn from_iter_with_params<I>(iter: I, gamma: f32) -> Result<Self, Error>
61 where
62 I: IntoIterator<Item = (K, Vec<u32>)>,
63 {
64 let mut keys = vec![];
65 let mut offsets_cache = HashMap::new();
66 let mut values_index = vec![];
67 let mut values_dict = vec![];
68
69 let mut iter = iter.into_iter().peekable();
70 let v_len = iter.peek().map_or(0, |(_, v)| v.len());
71
72 for (k, v) in iter {
73 keys.push(k.clone());
74
75 if v.len() != v_len {
76 return Err(Error::NotEqualValuesLengths);
77 }
78
79 if let Some(&offset) = offsets_cache.get(&v) {
80 values_index.push(offset);
82 } else {
83 let offset = values_dict.len();
85 offsets_cache.insert(v.clone(), offset);
86 values_index.push(offset);
87
88 pack_values(&v, &mut values_dict);
90 }
91 }
92
93 values_dict.resize(values_dict.len() + 4 * VALUES_BLOCK_LEN, 0);
95
96 let mphf = Mphf::from_slice(&keys, gamma).map_err(Error::MphfError)?;
97
98 for i in 0..keys.len() {
100 loop {
101 let idx = mphf.get(&keys[i]).unwrap();
102 if idx == i {
103 break;
104 }
105 keys.swap(i, idx);
106 values_index.swap(i, idx);
107 }
108 }
109
110 Ok(MapWithDictBitpacked {
111 mphf,
112 keys: keys.into_boxed_slice(),
113 values_index: values_index.into_boxed_slice(),
114 values_dict: values_dict.into_boxed_slice(),
115 })
116 }
117
118 #[inline]
132 pub fn get_values<Q>(&self, key: &Q, values: &mut [u32]) -> bool
133 where
134 K: Borrow<Q> + PartialEq<Q>,
135 Q: Hash + Eq + ?Sized,
136 {
137 let idx = match self.mphf.get(key) {
138 Some(idx) => idx,
139 None => return false,
140 };
141
142 unsafe {
144 if self.keys.get_unchecked(idx) != key {
145 return false;
146 }
147
148 let value_idx = *self.values_index.get_unchecked(idx);
150 let dict = self.values_dict.get_unchecked(value_idx..);
151 unpack_values(dict, values);
152 }
153
154 true
155 }
156
157 #[inline]
167 pub fn len(&self) -> usize {
168 self.keys.len()
169 }
170
171 #[inline]
183 pub fn is_empty(&self) -> bool {
184 self.keys.is_empty()
185 }
186
187 #[inline]
198 pub fn contains_key<Q>(&self, key: &Q) -> bool
199 where
200 K: Borrow<Q> + PartialEq<Q>,
201 Q: Hash + Eq + ?Sized,
202 {
203 if let Some(idx) = self.mphf.get(key) {
204 unsafe { self.keys.get_unchecked(idx) == key }
206 } else {
207 false
208 }
209 }
210
211 #[inline]
223 pub fn iter(&self, n: usize) -> impl Iterator<Item = (&K, Vec<u32>)> {
224 self.keys().zip(self.values_index.iter()).map(move |(key, &value_idx)| {
225 let mut values = vec![0; n];
226 let dict = unsafe { self.values_dict.get_unchecked(value_idx..) };
228 unpack_values(dict, &mut values);
229 (key, values)
230 })
231 }
232
233 #[inline]
245 pub fn keys(&self) -> impl Iterator<Item = &K> {
246 self.keys.iter()
247 }
248
249 #[inline]
261 pub fn values(&self, n: usize) -> impl Iterator<Item = Vec<u32>> + '_ {
262 self.values_index.iter().map(move |&value_idx| {
263 let mut values = vec![0; n];
264 let dict = unsafe { self.values_dict.get_unchecked(value_idx..) };
266 unpack_values(dict, &mut values);
267 values
268 })
269 }
270
271 pub fn size(&self) -> usize {
281 size_of_val(self)
282 + self.mphf.size()
283 + size_of_val(self.keys.as_ref())
284 + size_of_val(self.values_index.as_ref())
285 + size_of_val(self.values_dict.as_ref())
286 }
287}
288
289impl<K> TryFrom<HashMap<K, Vec<u32>>> for MapWithDictBitpacked<K>
291where
292 K: PartialEq + Hash + Clone,
293{
294 type Error = Error;
295
296 #[inline]
297 fn try_from(value: HashMap<K, Vec<u32>>) -> Result<Self, Self::Error> {
298 MapWithDictBitpacked::from_iter_with_params(value, DEFAULT_GAMMA)
299 }
300}
301
302const VALUES_BLOCK_LEN: usize = BitPacker1x::BLOCK_LEN;
304
305fn pack_values(values: &[u32], dict: &mut Vec<u8>) {
308 let bitpacker = BitPacker1x::new();
310
311 for block in values.chunks(VALUES_BLOCK_LEN) {
312 let mut values_block = [0u32; VALUES_BLOCK_LEN];
313 let mut values_packed_block = [0u8; 4 * VALUES_BLOCK_LEN];
314
315 values_block[..block.len()].copy_from_slice(block);
316
317 let num_bits = bitpacker.num_bits(&values_block);
319
320 bitpacker.compress(&values_block, &mut values_packed_block, num_bits);
322
323 let size = (block.len() * (num_bits as usize)).div_ceil(8);
325 dict.push(num_bits);
326 dict.extend_from_slice(&values_packed_block[..size]);
327 }
328}
329
330fn unpack_values(dict: &[u8], res: &mut [u32]) {
333 let bitpacker = BitPacker1x::new();
334 let mut dict = dict;
335 for block in res.chunks_mut(VALUES_BLOCK_LEN) {
336 let mut values_block = [0u32; VALUES_BLOCK_LEN];
337
338 let num_bits = dict[0];
340 dict = &dict[1..];
341
342 let size = (block.len() * (num_bits as usize)).div_ceil(8);
344 bitpacker.decompress(dict, &mut values_block, num_bits);
345 dict = &dict[size..];
346
347 block.copy_from_slice(&values_block[..block.len()]);
348 }
349}
350
/// Lookup on the archived form of the map generated by the `rkyv` derive.
#[cfg(feature = "rkyv_derive")]
impl<K, const B: usize, const S: usize, ST, H> ArchivedMapWithDictBitpacked<K, B, S, ST, H>
where
    K: PartialEq + Hash + rkyv::Archive,
    K::Archived: PartialEq<K>,
    ST: PrimInt + Unsigned + rkyv::Archive<Archived = ST>,
    H: Hasher + Default,
{
    /// Decodes the values stored for `key` into `values`.
    ///
    /// Returns `true` when the key is present. Returns `false`, leaving `values` untouched,
    /// when it is not. Exactly `values.len()` numbers are decoded.
    #[inline]
    pub fn get_values(&self, key: &K, values: &mut [u32]) -> bool {
        let Some(idx) = self.mphf.get(key) else {
            return false;
        };

        // SAFETY: relies on the archived `Mphf` only returning indices below `keys.len()`
        // (not verifiable from this file — TODO confirm).
        let archived_key = unsafe { self.keys.get_unchecked(idx) };
        if archived_key != key {
            return false;
        }

        // SAFETY: `values_index` is archived from a slice as long as `keys`, and the offsets
        // it holds were within `values_dict` when the map was built.
        let dict = unsafe {
            let value_idx = *self.values_index.get_unchecked(idx) as usize;
            self.values_dict.get_unchecked(value_idx..)
        };
        unpack_values(dict, values);

        true
    }
}
398
#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;
    use proptest::prelude::*;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;
    use std::collections::{hash_map::RandomState, HashSet};
    use test_case::test_case;

    // Each case packs `values`, checks the round trip, and compares the packed bytes
    // (width byte followed by payload, per block) with the expected vector.
    #[test_case(
        &[] => Vec::<u8>::new();
        "empty values"
    )]
    #[test_case(
        &[0] => vec![0];
        "single 0-bit value"
    )]
    #[test_case(
        &[0; 10] => vec![0];
        "10 0-bit value"
    )]
    #[test_case(
        &[0; 77] => vec![0, 0, 0];
        "77 0-bit values (3 blocks)"
    )]
    #[test_case(
        &[1] => vec![1, 1];
        "single 1-bit value"
    )]
    #[test_case(
        &[1; 10] => vec![1, 0b11111111, 0b00000011];
        "10 1-bit value"
    )]
    #[test_case(
        &[1; 32] => vec![1, 0b11111111, 0b11111111, 0b11111111, 0b11111111];
        "32 1-bit value"
    )]
    #[test_case(
        &[1; 33] => vec![1, 0b11111111, 0b11111111, 0b11111111, 0b11111111, 1, 0b00000001];
        "33 1-bit value"
    )]
    #[test_case(
        &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] => vec![4, 0b0010_0001, 0b0100_0011, 0b0110_0101, 0b1000_0111, 0b1010_1001];
        "10 4-bit value"
    )]
    fn test_pack_unpack(values: &[u32]) -> Vec<u8> {
        let mut dict = vec![];
        pack_values(values, &mut dict);

        // Same trailing padding the map constructor appends before decoding.
        let mut padded_dict = dict.clone();
        padded_dict.resize(dict.len() + 4 * VALUES_BLOCK_LEN, 0);

        let mut unpacked_values = vec![0; values.len()];
        unpack_values(&padded_dict, &mut unpacked_values);

        assert_eq!(values, unpacked_values);

        dict
    }

    /// Round-trips random inputs for every length in `1..=200` and every bit width in `0..=32`.
    #[test]
    fn test_pack_unpack_random() {
        let max_n = 200;
        let mut rng = ChaCha8Rng::seed_from_u64(123);
        let mut dict = vec![];
        let mut values = vec![];
        let mut unpacked_values = vec![];

        for n in 1..=max_n {
            for num_bits in 0..=32u32 {
                // `1 << 32` overflows a `u32`, so the full-width mask is spelled out.
                // (Reducing the shift modulo 32 instead would silently turn the 32-bit
                // case into a second 0-bit case.)
                let mask = if num_bits == 32 { u32::MAX } else { (1u32 << num_bits) - 1 };

                values.truncate(0);
                values.extend((0..n).map(|_| rng.gen::<u32>() & mask));
                dict.truncate(0);

                pack_values(&values, &mut dict);
                assert!(!dict.is_empty());

                dict.resize(dict.len() + 4 * VALUES_BLOCK_LEN, 0);
                unpacked_values.resize(n, 0);
                unpack_values(&dict, &mut unpacked_values);

                assert_eq!(values, unpacked_values);
            }
        }
    }

    /// Generates a deterministic map of `items_num` random keys, each with `values_num`
    /// values drawn from `1..=10`.
    fn gen_map(items_num: usize, values_num: usize) -> HashMap<u64, Vec<u32>> {
        let mut rng = ChaCha8Rng::seed_from_u64(123);

        (0..items_num)
            .map(|_| {
                let key = rng.gen::<u64>();
                let value = (0..values_num).map(|_| rng.gen_range(1..=10)).collect();
                (key, value)
            })
            .collect()
    }

    #[test]
    fn test_map_with_dict_bitpacked() {
        let items_num = 1000;
        let values_num = 10;
        let original_map = gen_map(items_num, values_num);

        let map = MapWithDictBitpacked::try_from(original_map.clone()).unwrap();

        // `len` and `is_empty` agree with the source map.
        assert_eq!(map.len(), original_map.len());
        assert_eq!(map.is_empty(), original_map.is_empty());

        // Every original entry can be looked up and decodes to the original values.
        let mut values_buf = vec![0; values_num];
        for (key, value) in &original_map {
            assert!(map.get_values(key, &mut values_buf));
            assert_eq!(value, &values_buf);
            assert!(map.contains_key(key));
        }

        // `iter` yields only original pairs.
        for (&k, v) in map.iter(values_num) {
            assert_eq!(original_map.get(&k), Some(&v));
        }

        // `keys` yields only original keys.
        for k in map.keys() {
            assert!(original_map.contains_key(k));
        }

        // `values` yields only original value vectors.
        for v in map.values(values_num) {
            assert!(original_map.values().any(|val| val == &v));
        }

        // Pins the reported size for this fixed input.
        assert_eq!(map.size(), 22664);
    }

    #[cfg(feature = "rkyv_derive")]
    #[test]
    fn test_rkyv() {
        let items_num = 1000;
        let values_num = 10;
        let original_map = gen_map(items_num, values_num);
        let map = MapWithDictBitpacked::try_from(original_map.clone()).unwrap();
        let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap();

        // Pins the serialized size for this fixed input.
        assert_eq!(rkyv_bytes.len(), 18516);

        let rkyv_map = rkyv::check_archived_root::<MapWithDictBitpacked<u64>>(&rkyv_bytes).unwrap();

        let mut values_buf = vec![0; values_num];
        for (k, v) in original_map {
            // The lookup must succeed; otherwise `values_buf` would still hold the values
            // from the previous iteration.
            assert!(rkyv_map.get_values(&k, &mut values_buf));
            assert_eq!(v, values_buf);
        }
    }

    // Generates one property test per `(B, S, gamma-in-percent, values-per-key)` tuple,
    // checking the map against a `HashMap` model.
    macro_rules! proptest_map_with_dict_bitpacked_model {
        ($(($b:expr, $s:expr, $gamma:expr, $n:expr)),* $(,)?) => {
            $(
                paste! {
                    proptest! {
                        #[test]
                        fn [<proptest_map_with_dict_bitpacked_model_ $b _ $s _ $n _ $gamma>](model: HashMap<u64, [u32; $n]>, arbitrary: HashSet<u64>) {
                            let entropy_map: MapWithDictBitpacked<u64, $b, $s> = MapWithDictBitpacked::from_iter_with_params(
                                model.iter().map(|(&k, v)| (k, Vec::from(v))),
                                $gamma as f32 / 100.0
                            ).unwrap();

                            // Size queries agree with the model.
                            assert_eq!(entropy_map.len(), model.len());
                            assert_eq!(entropy_map.is_empty(), model.is_empty());

                            // Key and value sets agree with the model.
                            assert_eq!(
                                HashSet::<_, RandomState>::from_iter(entropy_map.keys()),
                                HashSet::from_iter(model.keys())
                            );
                            assert_eq!(
                                HashSet::<_, RandomState>::from_iter(entropy_map.values($n)),
                                HashSet::from_iter(model.values().map(Vec::from))
                            );

                            // Every model entry is found and decodes correctly.
                            for (k, v) in &model {
                                assert!(entropy_map.contains_key(&k));

                                let mut buf = [0u32; $n];
                                assert!(entropy_map.get_values(&k, &mut buf));
                                assert_eq!(&buf, v);
                            }

                            // Arbitrary keys are reported present exactly when the model has them.
                            for k in arbitrary {
                                assert_eq!(
                                    model.contains_key(&k),
                                    entropy_map.contains_key(&k),
                                );
                                let mut buf = [0u32; $n];
                                let contains = entropy_map.get_values(&k, &mut buf);
                                assert_eq!(contains, model.contains_key(&k));
                                if contains {
                                    assert_eq!(Some(&buf), model.get(&k));
                                }
                            }
                        }
                    }
                }
            )*
        };
    }

    proptest_map_with_dict_bitpacked_model!(
        (2, 8, 100, 10),
        (4, 8, 100, 10),
        (7, 8, 100, 10),
        (8, 8, 100, 10),
        (15, 8, 100, 10),
        (16, 8, 100, 10),
        (23, 8, 100, 10),
        (24, 8, 100, 10),
        (31, 8, 100, 10),
        (32, 8, 100, 10),
        (33, 8, 100, 10),
        (48, 8, 100, 10),
        (53, 8, 100, 10),
        (61, 8, 100, 10),
        (63, 8, 100, 10),
        (64, 8, 100, 10),
        (32, 7, 100, 10),
        (32, 5, 100, 10),
        (32, 4, 100, 10),
        (32, 3, 100, 10),
        (32, 1, 100, 10),
        (32, 0, 100, 10),
        (32, 8, 200, 10),
        (32, 6, 200, 10),
    );
}