use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::mem::size_of_val;
use num::{PrimInt, Unsigned};
use wyhash::WyHash;
use crate::mphf::{Mphf, MphfError, DEFAULT_GAMMA};
/// Immutable map that locates keys through an `Mphf` and stores every distinct value only once
/// in a shared dictionary.
///
/// Layout, as built by `from_iter_with_params` and read by `get`:
/// - `keys[i]` is the key for which `mphf.get` returns `i`;
/// - `values_index[i]` is the position inside `values_dict` of the value belonging to `keys[i]`;
/// - `values_dict` holds the distinct values in the order they were first encountered.
///
/// The `B`, `S`, `ST` and `H` parameters are forwarded unchanged to `Mphf`; their meaning is
/// defined by that type (not visible in this file).
#[derive(Default)]
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct MapWithDict<K, V, const B: usize = 32, const S: usize = 8, ST = u8, H = WyHash>
where
ST: PrimInt + Unsigned,
H: Hasher + Default,
{
/// Hash function mapping a key to its slot in `keys` / `values_index`.
mphf: Mphf<B, S, ST, H>,
/// Keys stored at the slot reported by `mphf`; used to reject keys that were never inserted.
keys: Box<[K]>,
/// For each key slot, the index of its value inside `values_dict`.
values_index: Box<[usize]>,
/// Dictionary of distinct values.
values_dict: Box<[V]>,
}
impl<K, V, const B: usize, const S: usize, ST, H> MapWithDict<K, V, B, S, ST, H>
where
    K: Eq + Hash + Clone,
    V: Eq + Clone + Hash,
    ST: PrimInt + Unsigned,
    H: Hasher + Default,
{
    /// Builds a map from `(key, value)` pairs, passing `gamma` through to `Mphf::from_slice`.
    ///
    /// Equal values are stored once: the first occurrence is appended to the value dictionary
    /// and every later occurrence re-uses its offset.
    ///
    /// NOTE(review): how duplicate keys in `iter` are handled depends on `Mphf::from_slice`,
    /// which is not visible in this file.
    ///
    /// # Errors
    ///
    /// Propagates the `MphfError` returned by `Mphf::from_slice` for the collected keys.
    ///
    /// # Panics
    ///
    /// Panics if the freshly built `Mphf` returns `None` for one of the keys it was built from.
    pub fn from_iter_with_params<I>(iter: I, gamma: f32) -> Result<Self, MphfError>
    where
        I: IntoIterator<Item = (K, V)>,
    {
        let iter = iter.into_iter();
        // One key and one index entry are pushed per item, so the lower size hint is a safe
        // starting capacity for both vectors.
        let (expected_len, _) = iter.size_hint();
        let mut keys = Vec::with_capacity(expected_len);
        let mut values_index = Vec::with_capacity(expected_len);
        let mut values_dict = vec![];
        let mut offsets_cache = HashMap::new();

        for (k, v) in iter {
            // `k` is already owned, so it is moved rather than cloned.
            keys.push(k);

            // A single hash lookup either finds the dictionary offset of an already seen value
            // or registers the value under the next free offset.
            let offset = match offsets_cache.entry(v) {
                std::collections::hash_map::Entry::Occupied(seen) => *seen.get(),
                std::collections::hash_map::Entry::Vacant(slot) => {
                    let offset = values_dict.len();
                    // The only clone of the value: one copy lives in the dictionary, the
                    // original stays in the cache as its key.
                    values_dict.push(slot.key().clone());
                    slot.insert(offset);
                    offset
                }
            };
            values_index.push(offset);
        }

        let mphf = Mphf::from_slice(&keys, gamma)?;

        // Permute `keys` and `values_index` in lockstep until every key sits at the index the
        // MPHF reports for it. Each swap moves the key currently at `i` to its final position.
        for i in 0..keys.len() {
            loop {
                let idx = mphf
                    .get(&keys[i])
                    .expect("MPHF must resolve every key it was built from");
                if idx == i {
                    break;
                }
                keys.swap(i, idx);
                values_index.swap(i, idx);
            }
        }

        Ok(MapWithDict {
            mphf,
            keys: keys.into_boxed_slice(),
            values_index: values_index.into_boxed_slice(),
            values_dict: values_dict.into_boxed_slice(),
        })
    }

    /// Returns a reference to the value stored for `key`, or `None` if the key is absent.
    ///
    /// The MPHF may return an index for keys that were never inserted, so the key stored at
    /// that index is compared with `key` before the value is returned.
    #[inline]
    pub fn get<Q>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q> + PartialEq<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let idx = self.mphf.get(key)?;

        // SAFETY: relies on `Mphf::get` only returning indices below `keys.len()`; that
        // guarantee lives in `Mphf`, which is not visible here (NOTE(review): confirm).
        // `values_index` is built with the same length as `keys`, and each of its entries was
        // created in `from_iter_with_params` as an offset into `values_dict`.
        unsafe {
            if self.keys.get_unchecked(idx) == key {
                let value_idx = *self.values_index.get_unchecked(idx);
                Some(self.values_dict.get_unchecked(value_idx))
            } else {
                None
            }
        }
    }

    /// Returns the number of keys in the map.
    #[inline]
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// Returns `true` if the map holds no keys.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.keys.is_empty()
    }

    /// Returns `true` if `key` is present in the map.
    #[inline]
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q> + PartialEq<Q>,
        Q: Hash + Eq + ?Sized,
    {
        match self.mphf.get(key) {
            // SAFETY: relies on `Mphf::get` only returning indices below `keys.len()`
            // (guarantee provided by `Mphf`, not visible here — NOTE(review): confirm).
            Some(idx) => unsafe { self.keys.get_unchecked(idx) == key },
            None => false,
        }
    }

    /// Iterates over `(key, value)` pairs in the order the keys are stored.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
        self.keys
            .iter()
            .zip(self.values_index.iter())
            .map(move |(key, &value_idx)| {
                // SAFETY: entries of `values_index` are created in `from_iter_with_params` as
                // offsets into `values_dict`.
                let value = unsafe { self.values_dict.get_unchecked(value_idx) };
                (key, value)
            })
    }

    /// Iterates over the keys in the order they are stored.
    #[inline]
    pub fn keys(&self) -> impl Iterator<Item = &K> {
        self.keys.iter()
    }

    /// Iterates over the values, one per key; a value shared by several keys is yielded once
    /// for each of them.
    #[inline]
    pub fn values(&self) -> impl Iterator<Item = &V> {
        self.values_index.iter().map(move |&value_idx| {
            // SAFETY: entries of `values_index` are created in `from_iter_with_params` as
            // offsets into `values_dict`.
            unsafe { self.values_dict.get_unchecked(value_idx) }
        })
    }

    /// Returns the size of the map in bytes: the struct itself, `Mphf::size()` and the three
    /// boxed slices.
    ///
    /// The slices are measured with `size_of_val`, so heap memory owned by individual keys or
    /// values (e.g. `String` contents) is not included.
    #[inline]
    pub fn size(&self) -> usize {
        size_of_val(self)
            + self.mphf.size()
            + size_of_val(self.keys.as_ref())
            + size_of_val(self.values_index.as_ref())
            + size_of_val(self.values_dict.as_ref())
    }
}
/// Converts a standard `HashMap` into a `MapWithDict` with the default type parameters,
/// building it with `DEFAULT_GAMMA`.
impl<K, V> TryFrom<HashMap<K, V>> for MapWithDict<K, V>
where
    K: Eq + Hash + Clone,
    V: Eq + Clone + Hash,
{
    type Error = MphfError;

    /// Consumes `source` and feeds its entries to `from_iter_with_params`, returning any
    /// `MphfError` that construction reports.
    #[inline]
    fn try_from(source: HashMap<K, V>) -> Result<Self, Self::Error> {
        let entries = source;
        Self::from_iter_with_params(entries, DEFAULT_GAMMA)
    }
}
#[cfg(feature = "rkyv_derive")]
impl<K, V, const B: usize, const S: usize, ST, H> ArchivedMapWithDict<K, V, B, S, ST, H>
where
    K: PartialEq + Hash + rkyv::Archive,
    K::Archived: PartialEq<K>,
    V: rkyv::Archive,
    ST: PrimInt + Unsigned + rkyv::Archive<Archived = ST>,
    H: Hasher + Default,
{
    /// Reports whether `key` is present in the archived map.
    ///
    /// The archived key stored at the index returned by the archived MPHF is compared with
    /// `key`, so keys that were never inserted yield `false`.
    #[inline]
    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        <K as rkyv::Archive>::Archived: PartialEq<Q>,
        Q: Hash + Eq,
    {
        match self.mphf.get(key) {
            // SAFETY: relies on the archived MPHF only returning indices below `keys.len()`
            // (guarantee provided outside this file — NOTE(review): confirm).
            Some(slot) => unsafe { self.keys.get_unchecked(slot) == key },
            None => false,
        }
    }

    /// Looks up the archived value stored for `key`; returns `None` when the key is absent.
    #[inline]
    pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V::Archived>
    where
        K: Borrow<Q>,
        <K as rkyv::Archive>::Archived: PartialEq<Q>,
        Q: Hash + Eq,
    {
        let slot = self.mphf.get(key)?;

        // SAFETY: relies on the archived MPHF only returning indices below `keys.len()`, on
        // `values_index` having the same length as `keys`, and on every stored index pointing
        // inside `values_dict` (NOTE(review): these hold for archives produced from a
        // `MapWithDict` built by `from_iter_with_params`; confirm for other sources).
        unsafe {
            if self.keys.get_unchecked(slot) != key {
                return None;
            }
            let dict_pos = *self.values_index.get_unchecked(slot) as usize;
            Some(self.values_dict.get_unchecked(dict_pos))
        }
    }

    /// Iterates over archived `(key, value)` pairs in the order the keys are stored.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = (&K::Archived, &V::Archived)> {
        let stored_keys = self.keys.iter();
        let dict_positions = self.values_index.iter();

        stored_keys
            .zip(dict_positions)
            .map(move |(archived_key, dict_pos)| {
                // SAFETY: relies on every stored index pointing inside `values_dict`
                // (NOTE(review): holds for archives of maps built by `from_iter_with_params`).
                let archived_value =
                    unsafe { self.values_dict.get_unchecked(*dict_pos as usize) };
                (archived_key, archived_value)
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;
    use proptest::prelude::*;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;
    use std::collections::{hash_map::RandomState, HashSet};

    /// Generates a reproducible map (fixed RNG seed) from `items_num` random `u64` keys to
    /// values in `1..=10`, so that many keys share the same value.
    fn gen_map(items_num: usize) -> HashMap<u64, u32> {
        let mut rng = ChaCha8Rng::seed_from_u64(123);

        (0..items_num)
            .map(|_| {
                let key = rng.gen::<u64>();
                let value = rng.gen_range(1..=10);
                (key, value)
            })
            .collect()
    }

    /// Checks every public accessor of `MapWithDict` against the `HashMap` it was built from.
    #[test]
    fn test_map_with_dict() {
        let original_map = gen_map(1000);
        let map = MapWithDict::try_from(original_map.clone()).unwrap();

        assert_eq!(map.len(), original_map.len());
        assert_eq!(map.is_empty(), original_map.is_empty());

        for (key, value) in &original_map {
            assert_eq!(map.get(key), Some(value));
            assert!(map.contains_key(key));
        }

        for (&k, &v) in map.iter() {
            assert_eq!(original_map.get(&k), Some(&v));
        }

        for k in map.keys() {
            assert!(original_map.contains_key(k));
        }

        for &v in map.values() {
            assert!(original_map.values().any(|&val| val == v));
        }

        // Pins the reported memory footprint for this fixed input.
        assert_eq!(map.size(), 16626);
    }

    /// Looks up `String` keys through borrowed `&str` arguments.
    #[test]
    fn test_get_borrow() {
        let original_map = HashMap::from_iter([("a".to_string(), ()), ("b".to_string(), ())]);
        let map = MapWithDict::try_from(original_map).unwrap();

        assert_eq!(map.get("a"), Some(&()));
        assert!(map.contains_key("a"));
        assert_eq!(map.get("b"), Some(&()));
        assert!(map.contains_key("b"));
        assert_eq!(map.get("c"), None);
        assert!(!map.contains_key("c"));
    }

    /// Serializes a map with rkyv and checks lookups and iteration on the archived form.
    #[cfg(feature = "rkyv_derive")]
    #[test]
    fn test_rkyv() {
        let original_map = gen_map(1000);
        let map = MapWithDict::try_from(original_map.clone()).unwrap();
        let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap();

        // Pins the serialized size for this fixed input.
        assert_eq!(rkyv_bytes.len(), 12464);

        let rkyv_map = rkyv::check_archived_root::<MapWithDict<u64, u32>>(&rkyv_bytes).unwrap();

        for (k, v) in original_map.iter() {
            assert_eq!(v, rkyv_map.get(k).unwrap());
        }

        for (&k, &v) in rkyv_map.iter() {
            assert_eq!(original_map.get(&k), Some(&v));
        }
    }

    /// Looks up `String` keys through borrowed `&str` arguments on the archived map.
    #[cfg(feature = "rkyv_derive")]
    #[test]
    fn test_rkyv_get_borrow() {
        let original_map = HashMap::from_iter([("a".to_string(), ()), ("b".to_string(), ())]);
        let map = MapWithDict::try_from(original_map).unwrap();
        let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap();
        let rkyv_map = rkyv::check_archived_root::<MapWithDict<String, ()>>(&rkyv_bytes).unwrap();

        // Both `get` and `contains_key` are exercised on the archived map, not on `map`.
        assert_eq!(rkyv_map.get("a"), Some(&()));
        assert!(rkyv_map.contains_key("a"));
        assert_eq!(rkyv_map.get("b"), Some(&()));
        assert!(rkyv_map.contains_key("b"));
        assert_eq!(rkyv_map.get("c"), None);
        assert!(!rkyv_map.contains_key("c"));
    }

    // Generates one property test per `(B, S, gamma * 100)` triple. Each test builds a
    // `MapWithDict` from an arbitrary `HashMap` model and checks that both agree on size, key
    // set, value set and lookups, including lookups of arbitrary keys that may be absent.
    macro_rules! proptest_map_with_dict_model {
        ($(($b:expr, $s:expr, $gamma:expr)),* $(,)?) => {
            $(
                paste! {
                    proptest! {
                        #[test]
                        fn [<proptest_map_with_dict_model_ $b _ $s _ $gamma>](model: HashMap<u64, u64>, arbitrary: HashSet<u64>) {
                            let entropy_map: MapWithDict<u64, u64, $b, $s> = MapWithDict::from_iter_with_params(
                                model.clone(),
                                $gamma as f32 / 100.0
                            ).unwrap();

                            assert_eq!(entropy_map.len(), model.len());
                            assert_eq!(entropy_map.is_empty(), model.is_empty());

                            assert_eq!(
                                HashSet::<_, RandomState>::from_iter(entropy_map.keys()),
                                HashSet::from_iter(model.keys())
                            );

                            assert_eq!(
                                HashSet::<_, RandomState>::from_iter(entropy_map.values()),
                                HashSet::from_iter(model.values())
                            );

                            for (k, v) in &model {
                                assert!(entropy_map.contains_key(&k));
                                assert_eq!(entropy_map.get(&k), Some(v));
                            }

                            for k in arbitrary {
                                assert_eq!(
                                    model.contains_key(&k),
                                    entropy_map.contains_key(&k),
                                );
                                assert_eq!(entropy_map.get(&k), model.get(&k));
                            }
                        }
                    }
                }
            )*
        };
    }

    proptest_map_with_dict_model!(
        // (1) vary B
        (2, 8, 100),
        (4, 8, 100),
        (7, 8, 100),
        (8, 8, 100),
        (15, 8, 100),
        (16, 8, 100),
        (23, 8, 100),
        (24, 8, 100),
        (31, 8, 100),
        (32, 8, 100),
        (33, 8, 100),
        (48, 8, 100),
        (53, 8, 100),
        (61, 8, 100),
        (63, 8, 100),
        (64, 8, 100),
        // (2) vary S
        (32, 7, 100),
        (32, 5, 100),
        (32, 4, 100),
        (32, 3, 100),
        (32, 1, 100),
        (32, 0, 100),
        // (3) vary gamma
        (32, 8, 200),
        (32, 6, 200),
    );
}