entropy_map/
set.rs

//! A module providing `Set`, an immutable set implementation backed by a minimal perfect hash
//! function (MPHF).
//!
//! This implementation is optimized for efficient membership checks: the MPHF maps a queried item
//! to a single candidate slot, and the key stored in that slot is compared with the item. Keys are
//! stored in the set itself so that queries for an item not in the set always fail.
//!
//! # When to use?
//! Use this set implementation when you have a pre-defined set of keys and want to check
//! membership in that set efficiently. Because this set is immutable, it is not possible to
//! dynamically update membership. However, when the `rkyv_derive` feature is enabled, you can use
//! [`rkyv`](https://rkyv.org/) to perform zero-copy deserialization of a new set.
12
13use std::borrow::Borrow;
14use std::collections::HashSet;
15use std::hash::{Hash, Hasher};
16use std::mem::size_of_val;
17
18use num::{PrimInt, Unsigned};
19use wyhash::WyHash;
20
21use crate::mphf::{Mphf, MphfError, DEFAULT_GAMMA};
22
23/// An efficient, immutable set.
24#[derive(Default)]
25#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
26#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
27pub struct Set<K, const B: usize = 32, const S: usize = 8, ST = u8, H = WyHash>
28where
29    ST: PrimInt + Unsigned,
30    H: Hasher + Default,
31{
32    /// Minimally Perfect Hash Function for keys indices retrieval
33    mphf: Mphf<B, S, ST, H>,
34    /// Set keys
35    keys: Box<[K]>,
36}
37
38impl<K, const B: usize, const S: usize, ST, H> Set<K, B, S, ST, H>
39where
40    K: Eq + Hash,
41    ST: PrimInt + Unsigned,
42    H: Hasher + Default,
43{
44    /// Constructs a `Set` from an iterator of keys and MPHF function parameters.
45    ///
46    /// # Examples
47    /// ```
48    /// use entropy_map::{Set, DEFAULT_GAMMA};
49    ///
50    /// let set: Set<u32> = Set::from_iter_with_params([1, 2, 3], DEFAULT_GAMMA).unwrap();
51    /// assert!(set.contains(&1));
52    /// ```
53    pub fn from_iter_with_params<I>(iter: I, gamma: f32) -> Result<Self, MphfError>
54    where
55        I: IntoIterator<Item = K>,
56    {
57        let mut keys: Vec<K> = iter.into_iter().collect();
58
59        let mphf = Mphf::from_slice(&keys, gamma)?;
60
61        // Re-order `keys` and according to `mphf`
62        for i in 0..keys.len() {
63            loop {
64                let idx: usize = mphf.get(&keys[i]).unwrap();
65                if idx == i {
66                    break;
67                }
68                keys.swap(i, idx);
69            }
70        }
71
72        Ok(Set { mphf, keys: keys.into_boxed_slice() })
73    }
74
75    /// Returns `true` if the set contains the value.
76    ///
77    /// # Examples
78    /// ```
79    /// # use std::collections::HashSet;
80    /// # use entropy_map::Set;
81    /// let set = Set::try_from(HashSet::from([1, 2, 3])).unwrap();
82    /// assert_eq!(set.contains(&1), true);
83    /// assert_eq!(set.contains(&4), false);
84    /// ```
85    #[inline]
86    pub fn contains<Q>(&self, key: &Q) -> bool
87    where
88        K: Borrow<Q> + PartialEq<Q>,
89        Q: Hash + Eq + ?Sized,
90    {
91        // SAFETY: `idx` is always within array bounds (ensured during construction)
92        self.mphf
93            .get(key)
94            .map(|idx| unsafe { self.keys.get_unchecked(idx) == key })
95            .unwrap_or_default()
96    }
97
98    /// Returns the number of elements in the set.
99    ///
100    /// # Examples
101    /// ```
102    /// # use std::collections::HashSet;
103    /// # use entropy_map::Set;
104    /// let set = Set::try_from(HashSet::from([1, 2, 3])).unwrap();
105    /// assert_eq!(set.len(), 3);
106    /// ```
107    #[inline]
108    pub fn len(&self) -> usize {
109        self.keys.len()
110    }
111
112    /// Returns `true` if the set contains no elements.
113    ///
114    /// # Examples
115    /// ```
116    /// # use std::collections::HashSet;
117    /// # use entropy_map::Set;
118    /// let set = Set::try_from(HashSet::from([0u32; 0])).unwrap();
119    /// assert_eq!(set.is_empty(), true);
120    /// let set = Set::try_from(HashSet::from([1, 2, 3])).unwrap();
121    /// assert_eq!(set.is_empty(), false);
122    /// ```
123    #[inline]
124    pub fn is_empty(&self) -> bool {
125        self.keys.is_empty()
126    }
127
128    /// Returns an iterator visiting set elements in arbitrary order.
129    ///
130    /// # Examples
131    /// ```
132    /// # use std::collections::HashSet;
133    /// # use entropy_map::Set;
134    /// let set = Set::try_from(HashSet::from([1, 2, 3])).unwrap();
135    /// for x in set.iter() {
136    ///     println!("{x}");
137    /// }
138    /// ```
139    #[inline]
140    pub fn iter(&self) -> impl Iterator<Item = &K> {
141        self.keys.iter()
142    }
143
144    /// Returns the total number of bytes occupied by `Set`.
145    ///
146    /// # Examples
147    /// ```
148    /// # use std::collections::HashSet;
149    /// # use entropy_map::Set;
150    /// let set = Set::try_from(HashSet::from([1, 2, 3])).unwrap();
151    /// assert_eq!(set.size(), 218);
152    /// ```
153    #[inline]
154    pub fn size(&self) -> usize {
155        size_of_val(self) + self.mphf.size() + size_of_val(self.keys.as_ref())
156    }
157}
158
159/// Creates a `Set` from a `HashSet`.
160impl<K> TryFrom<HashSet<K>> for Set<K>
161where
162    K: Eq + Hash,
163{
164    type Error = MphfError;
165
166    #[inline]
167    fn try_from(value: HashSet<K>) -> Result<Self, Self::Error> {
168        Set::from_iter_with_params(value, DEFAULT_GAMMA)
169    }
170}
171
/// Membership queries on the zero-copy archived form of `Set` (requires the `rkyv_derive`
/// feature).
#[cfg(feature = "rkyv_derive")]
impl<K, const B: usize, const S: usize, ST, H> ArchivedSet<K, B, S, ST, H>
where
    K: Eq + Hash + rkyv::Archive,
    K::Archived: PartialEq<K>,
    ST: PrimInt + Unsigned + rkyv::Archive<Archived = ST>,
    H: Hasher + Default,
{
    /// Returns `true` if the archived set contains the value.
    ///
    /// Archived sets are usually read directly from bytes, e.g. via
    /// `rkyv::check_archived_root`. That validation checks the byte structure but does not
    /// prove that the archived MPHF and key array agree. For that reason the index produced by
    /// the MPHF is bounds-checked rather than trusted.
    ///
    /// # Examples
    /// ```
    /// # use std::collections::HashSet;
    /// # use entropy_map::Set;
    /// let set: Set<u32> = Set::try_from(HashSet::from([1, 2, 3])).unwrap();
    /// let bytes = rkyv::to_bytes::<_, 1024>(&set).unwrap();
    /// let archived_set = rkyv::check_archived_root::<Set<u32>>(&bytes).unwrap();
    /// assert_eq!(archived_set.contains(&1), true);
    /// assert_eq!(archived_set.contains(&4), false);
    /// ```
    #[inline]
    pub fn contains<Q: ?Sized>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        <K as rkyv::Archive>::Archived: PartialEq<Q>,
        Q: Hash + Eq,
    {
        // Checked indexing: archived bytes may be inconsistent even after validation, and an
        // out-of-range index must not become undefined behavior in safe code.
        self.mphf
            .get(key)
            .and_then(|idx| self.keys.get(idx))
            .map(|stored| stored == key)
            .unwrap_or_default()
    }
}
208
#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;
    use proptest::prelude::*;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    /// Generates a deterministic (fixed-seed) set of random `u64` keys from `items_num` draws.
    /// Repeated draws collapse, so the result may hold fewer than `items_num` keys.
    fn gen_set(items_num: usize) -> HashSet<u64> {
        let mut rng = ChaCha8Rng::seed_from_u64(123);

        (0..items_num).map(|_| rng.gen::<u64>()).collect()
    }

    #[test]
    fn test_set_with_hashset() {
        // Generate the reference keys as a standard `HashSet`
        let original_set = gen_set(1000);

        // Build the MPHF-backed set from a copy of the reference set
        let set = Set::try_from(original_set.clone()).unwrap();

        // Test len
        assert_eq!(set.len(), original_set.len());

        // Test is_empty
        assert_eq!(set.is_empty(), original_set.is_empty());

        // Test contains for every stored key
        for key in &original_set {
            assert!(set.contains(key));
        }

        // Test contains for a key that is not stored
        let absent = (0u64..).find(|k| !original_set.contains(k)).unwrap();
        assert!(!set.contains(&absent));

        // Test iter: yields exactly the reference keys
        assert_eq!(set.iter().count(), original_set.len());
        for &k in set.iter() {
            assert!(original_set.contains(&k));
        }

        // Test size
        assert_eq!(set.size(), 8540);
    }

    /// Assert that we can call `.contains()` with `K::borrow()`.
    #[test]
    fn test_contains_borrow() {
        let set = Set::try_from(HashSet::from(["a".to_string(), "b".to_string()])).unwrap();

        assert!(set.contains("a"));
        assert!(set.contains("b"));
        assert!(!set.contains("c"));
    }

    #[cfg(feature = "rkyv_derive")]
    #[test]
    fn test_rkyv() {
        // create regular `HashSet`, then `Set`, then serialize to `rkyv` bytes.
        let original_set = gen_set(1000);
        let set = Set::try_from(original_set.clone()).unwrap();
        let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&set).unwrap();

        assert_eq!(rkyv_bytes.len(), 8408);

        let rkyv_set = rkyv::check_archived_root::<Set<u64>>(&rkyv_bytes).unwrap();

        // Test contains on `Archived` version
        for k in original_set.iter() {
            assert!(rkyv_set.contains(k));
        }

        // Test contains on `Archived` version for a key that is not stored
        let absent = (0u64..).find(|k| !original_set.contains(k)).unwrap();
        assert!(!rkyv_set.contains(&absent));
    }

    #[cfg(feature = "rkyv_derive")]
    #[test]
    fn test_rkyv_contains_borrow() {
        let set = Set::try_from(HashSet::from(["a".to_string(), "b".to_string()])).unwrap();
        let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&set).unwrap();
        let rkyv_set = rkyv::check_archived_root::<Set<String>>(&rkyv_bytes).unwrap();

        assert!(rkyv_set.contains("a"));
        assert!(rkyv_set.contains("b"));
        assert!(!rkyv_set.contains("c"));
    }

    macro_rules! proptest_set_model {
        ($(($b:expr, $s:expr, $gamma:expr)),* $(,)?) => {
            $(
                paste! {
                    proptest! {
                        #[test]
                        fn [<proptest_set_model_ $b _ $s _ $gamma>](model: HashSet<u64>, arbitrary: HashSet<u64>) {
                            let entropy_set: Set<u64, $b, $s> = Set::from_iter_with_params(
                                model.clone(),
                                $gamma as f32 / 100.0
                            ).unwrap();

                            // Assert that length matches model.
                            assert_eq!(entropy_set.len(), model.len());
                            assert_eq!(entropy_set.is_empty(), model.is_empty());

                            // Assert that contains operations match model for contained elements.
                            // `elm` is already `&u64`; passing it directly keeps `Q = u64`.
                            for elm in &model {
                                assert!(entropy_set.contains(elm));
                            }

                            // Assert that contains operations match model for random elements.
                            for elm in arbitrary {
                                assert_eq!(
                                    model.contains(&elm),
                                    entropy_set.contains(&elm),
                                );
                            }
                        }
                    }
                }
            )*
        };
    }

    proptest_set_model!(
        // (1, 8, 100),
        (2, 8, 100),
        (4, 8, 100),
        (7, 8, 100),
        (8, 8, 100),
        (15, 8, 100),
        (16, 8, 100),
        (23, 8, 100),
        (24, 8, 100),
        (31, 8, 100),
        (32, 8, 100),
        (33, 8, 100),
        (48, 8, 100),
        (53, 8, 100),
        (61, 8, 100),
        (63, 8, 100),
        (64, 8, 100),
        (32, 7, 100),
        (32, 5, 100),
        (32, 4, 100),
        (32, 3, 100),
        (32, 1, 100),
        (32, 0, 100),
        (32, 8, 200),
        (32, 6, 200),
    );

    proptest! {
        #[test]
        fn test_set_contains(model: HashSet<u64>, arbitrary: HashSet<u64>) {
            let entropy_set = Set::try_from(model.clone()).unwrap();

            // `elm` is already `&u64`; passing it directly keeps `Q = u64`.
            for elm in &model {
                assert!(entropy_set.contains(elm));
            }

            for elm in arbitrary {
                assert_eq!(
                    model.contains(&elm),
                    entropy_set.contains(&elm),
                );
            }
        }
    }
}