use std::borrow::Borrow;
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::mem::size_of_val;
use num::{PrimInt, Unsigned};
use wyhash::WyHash;
use crate::mphf::{Mphf, MphfError, DEFAULT_GAMMA};
/// A set of keys of type `K` whose lookups go through an `Mphf` instead of a
/// conventional hash table.
///
/// The keys are stored in a boxed slice ordered so that the index returned by
/// the `Mphf` for a key is the position of that key in the slice.
///
/// The const parameters `B` and `S`, the seed type `ST` and the hasher `H` are
/// forwarded unchanged to `Mphf`; see its definition for their meaning.
/// With the `rkyv_derive` feature enabled, rkyv archive/serialize/deserialize
/// implementations are derived as well.
#[derive(Default)]
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct Set<K, const B: usize = 32, const S: usize = 8, ST = u8, H = WyHash>
where
ST: PrimInt + Unsigned,
H: Hasher + Default,
{
/// Hash function used to turn a key into its index within `keys`.
mphf: Mphf<B, S, ST, H>,
/// The stored keys, positioned at the index `mphf` assigns to each of them.
keys: Box<[K]>,
}
impl<K, const B: usize, const S: usize, ST, H> Set<K, B, S, ST, H>
where
    K: Eq + Hash,
    ST: PrimInt + Unsigned,
    H: Hasher + Default,
{
    /// Builds a `Set` from the keys yielded by `iter`.
    ///
    /// `gamma` is handed to `Mphf::from_slice` as-is.
    ///
    /// # Errors
    ///
    /// Propagates any `MphfError` reported by `Mphf::from_slice`.
    ///
    /// # Panics
    ///
    /// Panics if the freshly built `Mphf` returns `None` for one of the keys
    /// it was built from.
    pub fn from_iter_with_params<I>(iter: I, gamma: f32) -> Result<Self, MphfError>
    where
        I: IntoIterator<Item = K>,
    {
        let mut keys = iter.into_iter().collect::<Vec<K>>();
        let mphf = Mphf::from_slice(&keys, gamma)?;

        // Permute `keys` in place so that every key ends up at the index the
        // MPHF assigns to it: keep sending the occupant of `slot` to its home
        // position until the key that belongs in `slot` arrives there.
        for slot in 0..keys.len() {
            let mut target: usize = mphf.get(&keys[slot]).unwrap();
            while target != slot {
                keys.swap(slot, target);
                target = mphf.get(&keys[slot]).unwrap();
            }
        }

        Ok(Self { mphf, keys: keys.into_boxed_slice() })
    }

    /// Returns `true` when `key` is present in the set.
    ///
    /// The MPHF proposes a candidate index; the key stored at that index is
    /// then compared with `key`. A `None` from the MPHF means "absent".
    #[inline]
    pub fn contains<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q> + PartialEq<Q>,
        Q: Hash + Eq + ?Sized,
    {
        match self.mphf.get(key) {
            // SAFETY: relies on `Mphf::get` only ever returning indices below
            // `self.keys.len()` for the MPHF built over exactly these keys.
            // NOTE(review): that contract lives in `Mphf` and is not visible
            // here; confirm against `Mphf::get`.
            Some(candidate) => unsafe { self.keys.get_unchecked(candidate) == key },
            None => false,
        }
    }

    /// Returns the number of keys stored in the set.
    #[inline]
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// Returns `true` when the set holds no keys.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.keys.is_empty()
    }

    /// Returns an iterator over references to the keys, in storage order.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = &K> {
        self.keys.iter()
    }

    /// Returns the size of the set in bytes: the inline size of `self`, plus
    /// whatever `Mphf::size` reports, plus the bytes of the key slice.
    #[inline]
    pub fn size(&self) -> usize {
        let inline_bytes = size_of_val(self);
        let mphf_bytes = self.mphf.size();
        let keys_bytes = size_of_val(self.keys.as_ref());
        inline_bytes + mphf_bytes + keys_bytes
    }
}
/// Conversion from a standard `HashSet` into a `Set` with the default
/// type parameters.
impl<K> TryFrom<HashSet<K>> for Set<K>
where
    K: Eq + Hash,
{
    type Error = MphfError;

    /// Consumes `value` and builds the set from its keys using `DEFAULT_GAMMA`.
    ///
    /// Fails with whatever `MphfError` `from_iter_with_params` reports.
    #[inline]
    fn try_from(value: HashSet<K>) -> Result<Self, Self::Error> {
        let source_keys = value;
        Self::from_iter_with_params(source_keys, DEFAULT_GAMMA)
    }
}
/// Lookup support on the rkyv-archived form of `Set`, available only with the
/// `rkyv_derive` feature.
#[cfg(feature = "rkyv_derive")]
impl<K, const B: usize, const S: usize, ST, H> ArchivedSet<K, B, S, ST, H>
where
    K: Eq + Hash + rkyv::Archive,
    K::Archived: PartialEq<K>,
    ST: PrimInt + Unsigned + rkyv::Archive<Archived = ST>,
    H: Hasher + Default,
{
    /// Returns `true` when `key` is present in the archived set.
    ///
    /// The archived MPHF proposes a candidate index; the archived key stored
    /// there is then compared with `key`. A `None` from the MPHF means "absent".
    #[inline]
    pub fn contains<Q: ?Sized>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        <K as rkyv::Archive>::Archived: PartialEq<Q>,
        Q: Hash + Eq,
    {
        match self.mphf.get(key) {
            // SAFETY: relies on the archived MPHF only returning indices
            // below `self.keys.len()`.
            // NOTE(review): that contract is defined outside this file;
            // confirm against the archived `Mphf::get`.
            Some(candidate) => unsafe { self.keys.get_unchecked(candidate) == key },
            None => false,
        }
    }
}
// Unit and property tests for `Set` and, with the `rkyv_derive` feature, its
// archived form.
#[cfg(test)]
mod tests {
use super::*;
use paste::paste;
use proptest::prelude::*;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
// Draws `items_num` values from a ChaCha8 RNG with the fixed seed 123 and
// collects them into a `HashSet` (any repeated values collapse). The fixed
// seed keeps the fixture reproducible, which the pinned sizes below rely on.
fn gen_set(items_num: usize) -> HashSet<u64> {
let mut rng = ChaCha8Rng::seed_from_u64(123);
(0..items_num).map(|_| rng.gen::<u64>()).collect()
}
// Builds a `Set` from a `HashSet` and checks that both agree on length,
// emptiness and membership in both directions.
#[test]
fn test_set_with_hashset() {
let original_set = gen_set(1000);
let set = Set::try_from(original_set.clone()).unwrap();
assert_eq!(set.len(), original_set.len());
assert_eq!(set.is_empty(), original_set.is_empty());
for key in &original_set {
assert!(set.contains(key));
}
for &k in set.iter() {
assert!(original_set.contains(&k));
}
// Pinned byte size for this fixture with the default type parameters.
assert_eq!(set.size(), 8540);
}
// `String` keys can be queried with a borrowed `&str`.
#[test]
fn test_contains_borrow() {
let set = Set::try_from(HashSet::from(["a".to_string(), "b".to_string()])).unwrap();
assert!(set.contains("a"));
assert!(set.contains("b"));
assert!(!set.contains("c"));
}
// Serializes a `Set` with rkyv, validates the bytes, and checks that the
// archived set still contains every original key.
#[cfg(feature = "rkyv_derive")]
#[test]
fn test_rkyv() {
let original_set = gen_set(1000);
let set = Set::try_from(original_set.clone()).unwrap();
let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&set).unwrap();
// Pinned serialized length for this fixture.
assert_eq!(rkyv_bytes.len(), 8408);
let rkyv_set = rkyv::check_archived_root::<Set<u64>>(&rkyv_bytes).unwrap();
for k in original_set.iter() {
assert!(rkyv_set.contains(k));
}
}
// The archived set of `String` keys can also be queried with `&str`.
#[cfg(feature = "rkyv_derive")]
#[test]
fn test_rkyv_contains_borrow() {
let set = Set::try_from(HashSet::from(["a".to_string(), "b".to_string()])).unwrap();
let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&set).unwrap();
let rkyv_set = rkyv::check_archived_root::<Set<String>>(&rkyv_bytes).unwrap();
assert!(rkyv_set.contains("a"));
assert!(rkyv_set.contains("b"));
assert!(!rkyv_set.contains("c"));
}
// Generates one property test per `(B, S, gamma)` tuple. `B` and `S` become
// the const parameters of `Set`; `gamma` is written as an integer and divided
// by 100.0 before use, which also lets it appear in the generated test name.
macro_rules! proptest_set_model {
($(($b:expr, $s:expr, $gamma:expr)),* $(,)?) => {
$(
paste! {
proptest! {
#[test]
fn [<proptest_set_model_ $b _ $s _ $gamma>](model: HashSet<u64>, arbitrary: HashSet<u64>) {
let entropy_set: Set<u64, $b, $s> = Set::from_iter_with_params(
model.clone(),
$gamma as f32 / 100.0
).unwrap();
assert_eq!(entropy_set.len(), model.len());
assert_eq!(entropy_set.is_empty(), model.is_empty());
// Every key of the model must be found.
for elm in &model {
assert!(entropy_set.contains(&elm));
}
// For arbitrary keys the answer must match the model `HashSet`.
for elm in arbitrary {
assert_eq!(
model.contains(&elm),
entropy_set.contains(&elm),
);
}
}
}
}
)*
};
}
// Parameter grid: varying `B` with `S = 8`, varying `S` with `B = 32`, and two
// runs with gamma = 2.0.
proptest_set_model!(
(2, 8, 100),
(4, 8, 100),
(7, 8, 100),
(8, 8, 100),
(15, 8, 100),
(16, 8, 100),
(23, 8, 100),
(24, 8, 100),
(31, 8, 100),
(32, 8, 100),
(33, 8, 100),
(48, 8, 100),
(53, 8, 100),
(61, 8, 100),
(63, 8, 100),
(64, 8, 100),
(32, 7, 100),
(32, 5, 100),
(32, 4, 100),
(32, 3, 100),
(32, 1, 100),
(32, 0, 100),
(32, 8, 200),
(32, 6, 200),
);
// Same membership property as above, but built through `TryFrom<HashSet>` with
// the default type parameters and `DEFAULT_GAMMA`.
proptest! {
#[test]
fn test_set_contains(model: HashSet<u64>, arbitrary: HashSet<u64>) {
let entropy_set = Set::try_from(model.clone()).unwrap();
for elm in &model {
assert!(entropy_set.contains(&elm));
}
for elm in arbitrary {
assert_eq!(
model.contains(&elm),
entropy_set.contains(&elm),
);
}
}
}
}