// contrie/set.rs

//! The [`ConSet`] concurrent set and other related structures.
2
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt::{Debug, Formatter, Result as FmtResult};
6use std::hash::{BuildHasher, Hash};
7use std::iter::FromIterator;
8
9use crossbeam_epoch;
10
11#[cfg(feature = "rayon")]
12use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
13
14use crate::raw::config::Trivial as TrivialConfig;
15use crate::raw::{self, Raw};
16
17/// A concurrent lock-free set.
18///
19/// Note that due to the limitations described in the crate level docs, values returned by looking
20/// up (or misplacing or removing) are always copied using the `Clone` trait. Therefore, the set is
21/// more suitable for types that are cheap to copy (eg. `u64` or `IpAddr`).
22///
23/// If you intend to store types that are more expensive to make copies of or are not `Clone`, you
24/// can wrap them in an `Arc` (eg. `Arc<str>`).
25///
26/// ```rust
27/// use contrie::ConSet;
28/// use crossbeam_utils::thread;
29///
30/// let set = ConSet::new();
31///
32/// thread::scope(|s| {
33///     s.spawn(|_| {
34///         set.insert("hello");
35///     });
36///     s.spawn(|_| {
37///         set.insert("world");
38///     });
39/// }).unwrap();
40///
41/// assert_eq!(Some("hello"), set.get("hello"));
42/// assert_eq!(Some("world"), set.get("world"));
43/// assert_eq!(None, set.get("universe"));
44/// set.remove("world");
45/// assert_eq!(None, set.get("world"));
46/// ```
47///
48/// ```rust
49/// use contrie::set::{ConSet};
50/// let set: ConSet<usize> = ConSet::new();
51///
52/// set.insert(0);
53/// set.insert(1);
54///
55/// assert!(set.contains(&1));
56///
57/// set.remove(&1);
58/// assert!(!set.contains(&1));
59///
60/// set.remove(&0);
61/// assert!(set.is_empty());
62/// ```
63pub struct ConSet<T, S = RandomState>
64where
65    T: Clone + Hash + Eq + 'static,
66{
67    raw: Raw<TrivialConfig<T>, S>,
68}
69
70impl<T> ConSet<T, RandomState>
71where
72    T: Clone + Hash + Eq + 'static,
73{
74    /// Creates a new empty set.
75    pub fn new() -> Self {
76        Self::with_hasher(RandomState::default())
77    }
78}
79
80impl<T, S> ConSet<T, S>
81where
82    T: Clone + Hash + Eq + 'static,
83    S: BuildHasher,
84{
85    /// Creates a new empty set with the given hasher.
86    pub fn with_hasher(hasher: S) -> Self {
87        Self {
88            raw: Raw::with_hasher(hasher),
89        }
90    }
91
92    /// Inserts a new value into the set.
93    ///
94    /// It returns the previous value, if any was present.
95    pub fn insert(&self, value: T) -> Option<T> {
96        let pin = crossbeam_epoch::pin();
97        self.raw.insert(value, &pin).cloned()
98    }
99
100    /// Looks up a value in the set.
101    ///
102    /// This creates a copy of the original value.
103    pub fn get<Q>(&self, key: &Q) -> Option<T>
104    where
105        Q: ?Sized + Eq + Hash,
106        T: Borrow<Q>,
107    {
108        let pin = crossbeam_epoch::pin();
109        self.raw.get(key, &pin).cloned()
110    }
111
112    /// Checks if a value identified by the given key is present in the set.
113    ///
114    /// Note that by the time you can act on it, the presence of the value can change (eg. other
115    /// thread can add or remove it in the meantime).
116    pub fn contains<Q>(&self, key: &Q) -> bool
117    where
118        Q: ?Sized + Eq + Hash,
119        T: Borrow<Q>,
120    {
121        let pin = crossbeam_epoch::pin();
122        self.raw.get(key, &pin).is_some()
123    }
124
125    /// Removes an element identified by the given key, returning it.
126    pub fn remove<Q>(&self, key: &Q) -> Option<T>
127    where
128        Q: ?Sized + Eq + Hash,
129        T: Borrow<Q>,
130    {
131        let pin = crossbeam_epoch::pin();
132        self.raw.remove(key, &pin).cloned()
133    }
134
135    /// Checks if the set is currently empty.
136    ///
137    /// Note that due to being concurrent, the use-case of this method is mostly for debugging
138    /// purposes, because the state can change between reading the value and acting on it.
139    pub fn is_empty(&self) -> bool {
140        self.raw.is_empty()
141    }
142}
143
144impl<T> Default for ConSet<T, RandomState>
145where
146    T: Clone + Hash + Eq + 'static,
147{
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153impl<T, S> Debug for ConSet<T, S>
154where
155    T: Debug + Clone + Hash + Eq + 'static,
156{
157    fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
158        fmt.debug_set().entries(self.iter()).finish()
159    }
160}
161
162impl<T, S> ConSet<T, S>
163where
164    T: Clone + Hash + Eq + 'static,
165{
166    /// Returns an iterator through the elements of the set.
167    pub fn iter(&self) -> Iter<T, S> {
168        Iter {
169            inner: raw::iterator::Iter::new(&self.raw),
170        }
171    }
172}
173
174/// The iterator of the [`ConSet`].
175///
176/// See the [`iter`][ConSet::iter] method for details.
177pub struct Iter<'a, T, S>
178where
179    T: Clone + Hash + Eq + 'static,
180{
181    inner: raw::iterator::Iter<'a, TrivialConfig<T>, S>,
182}
183
184impl<'a, T, S> Iterator for Iter<'a, T, S>
185where
186    T: Clone + Hash + Eq + 'static,
187{
188    type Item = T;
189
190    fn next(&mut self) -> Option<T> {
191        self.inner.next().cloned()
192    }
193}
194
195impl<'a, T, S> IntoIterator for &'a ConSet<T, S>
196where
197    T: Clone + Hash + Eq + 'static,
198{
199    type Item = T;
200    type IntoIter = Iter<'a, T, S>;
201
202    fn into_iter(self) -> Self::IntoIter {
203        self.iter()
204    }
205}
206
207impl<'a, T, S> Extend<T> for &'a ConSet<T, S>
208where
209    T: Clone + Hash + Eq + 'static,
210    S: BuildHasher,
211{
212    fn extend<I>(&mut self, iter: I)
213    where
214        I: IntoIterator<Item = T>,
215    {
216        for n in iter {
217            self.insert(n);
218        }
219    }
220}
221
222impl<T, S> Extend<T> for ConSet<T, S>
223where
224    T: Clone + Hash + Eq + 'static,
225    S: BuildHasher,
226{
227    fn extend<I>(&mut self, iter: I)
228    where
229        I: IntoIterator<Item = T>,
230    {
231        let mut me: &ConSet<_, _> = self;
232        me.extend(iter);
233    }
234}
235
236impl<T> FromIterator<T> for ConSet<T>
237where
238    T: Clone + Hash + Eq + 'static,
239{
240    fn from_iter<I>(iter: I) -> Self
241    where
242        I: IntoIterator<Item = T>,
243    {
244        let mut me = ConSet::new();
245        me.extend(iter);
246        me
247    }
248}
249
#[cfg(feature = "rayon")]
impl<'a, T, S> ParallelExtend<T> for &'a ConSet<T, S>
where
    T: Clone + Hash + Eq + Send + Sync,
    S: BuildHasher + Sync,
{
    /// Inserts every value of `par_iter`, with the insertions driven by rayon's `for_each`.
    ///
    /// Whatever `insert` returns for each value is dropped.
    fn par_extend<I>(&mut self, par_iter: I)
    where
        I: IntoParallelIterator<Item = T>,
    {
        let set: &ConSet<T, S> = *self;
        par_iter.into_par_iter().for_each(move |value| {
            set.insert(value);
        });
    }
}
265
#[cfg(feature = "rayon")]
impl<T, S> ParallelExtend<T> for ConSet<T, S>
where
    T: Clone + Hash + Eq + Send + Sync,
    S: BuildHasher + Sync,
{
    /// Inserts every value of `par_iter` in parallel.
    fn par_extend<I>(&mut self, par_iter: I)
    where
        I: IntoParallelIterator<Item = T>,
    {
        // Only shared access is needed, so forward to the `&ConSet` implementation.
        let mut shared: &ConSet<T, S> = self;
        shared.par_extend(par_iter);
    }
}
280
#[cfg(feature = "rayon")]
impl<T> FromParallelIterator<T> for ConSet<T>
where
    T: Clone + Hash + Eq + Send + Sync,
{
    /// Builds a new set (default [`RandomState`] hasher) from a rayon parallel iterator.
    fn from_par_iter<I>(iter: I) -> Self
    where
        I: IntoParallelIterator<Item = T>,
    {
        let set = Self::new();
        let mut shared = &set;
        shared.par_extend(iter);
        set
    }
}
295
#[cfg(test)]
mod tests {
    use crossbeam_utils::thread;
    #[cfg(feature = "rayon")]
    use rayon::prelude::*;

    use super::*;
    use crate::raw::tests::NoHasher;
    use crate::raw::LEVEL_CELLS;

    const TEST_THREADS: usize = 4;
    const TEST_BATCH: usize = 10000;
    const TEST_BATCH_SMALL: usize = 100;
    const TEST_REP: usize = 20;

    #[test]
    fn debug_when_empty() {
        let set: ConSet<String> = ConSet::new();
        assert_eq!("{}", &format!("{:?}", set));
    }

    #[test]
    fn debug_when_has_elements() {
        let set: ConSet<&str> = ConSet::new();
        assert!(set.insert("hello").is_none());
        assert!(set.insert("world").is_none());
        let actual = format!("{:?}", set);
        // Element order follows the hash-dependent iteration order, so accept exactly the two
        // valid orderings (comparing sorted characters would also accept garbled output).
        assert!(
            actual == "{\"hello\", \"world\"}" || actual == "{\"world\", \"hello\"}",
            "unexpected debug output: {}",
            actual
        );
    }

    #[test]
    fn debug_when_elements_are_added_and_removed() {
        let set: ConSet<&str> = ConSet::new();
        assert_eq!("{}", &format!("{:?}", set));
        assert!(set.insert("hello").is_none());
        assert!(set.insert("hello").is_some());
        assert!(set.insert("hello").is_some());
        assert_eq!("{\"hello\"}", &format!("{:?}", set));
        assert!(set.remove("hello").is_some());
        assert_eq!("{}", &format!("{:?}", set));
    }

    #[test]
    fn create_destroy() {
        let set: ConSet<String> = ConSet::new();
        drop(set);
    }

    #[test]
    fn lookup_empty() {
        let set: ConSet<String> = ConSet::new();
        assert!(set.get("hello").is_none());
    }

    #[test]
    fn insert_lookup() {
        let set = ConSet::new();
        assert!(set.insert("hello").is_none());
        assert!(set.get("world").is_none());
        let found = set.get("hello").unwrap();
        assert_eq!("hello", found);
    }

    // Insert a lot of things, to make sure we have multiple levels.
    #[test]
    fn insert_many() {
        let set = ConSet::new();
        for i in 0..TEST_BATCH * LEVEL_CELLS {
            assert!(set.insert(i).is_none());
        }

        for i in 0..TEST_BATCH * LEVEL_CELLS {
            assert_eq!(i, set.get(&i).unwrap());
        }
    }

    #[test]
    fn par_insert_many() {
        for _ in 0..TEST_REP {
            let set: ConSet<usize> = ConSet::new();
            thread::scope(|s| {
                for t in 0..TEST_THREADS {
                    let set = &set;
                    s.spawn(move |_| {
                        for i in 0..TEST_BATCH {
                            let num = t * TEST_BATCH + i;
                            assert!(set.insert(num).is_none());
                        }
                    });
                }
            })
            .unwrap();

            for i in 0..TEST_BATCH * TEST_THREADS {
                assert_eq!(set.get(&i).unwrap(), i);
            }
        }
    }

    #[test]
    fn par_get_many() {
        for _ in 0..TEST_REP {
            let set = ConSet::new();
            for i in 0..TEST_BATCH * TEST_THREADS {
                assert!(set.insert(i).is_none());
            }
            thread::scope(|s| {
                for t in 0..TEST_THREADS {
                    let set = &set;
                    s.spawn(move |_| {
                        for i in 0..TEST_BATCH {
                            let num = t * TEST_BATCH + i;
                            assert_eq!(set.get(&num).unwrap(), num);
                        }
                    });
                }
            })
            .unwrap();
        }
    }

    #[test]
    fn no_collisions() {
        let set = ConSet::with_hasher(NoHasher);
        // While their hash is the same under the hasher, they don't kick each other out.
        for i in 0..TEST_BATCH_SMALL {
            assert!(set.insert(i).is_none());
        }
        // And all are present.
        for i in 0..TEST_BATCH_SMALL {
            assert_eq!(i, set.get(&i).unwrap());
        }
        // No key kicks another one out.
        for i in 0..TEST_BATCH_SMALL {
            assert_eq!(i, set.insert(i).unwrap());
        }
    }

    #[test]
    fn simple_remove() {
        let set = ConSet::new();
        assert!(set.remove(&42).is_none());
        assert!(set.insert(42).is_none());
        assert_eq!(42, set.get(&42).unwrap());
        assert_eq!(42, set.remove(&42).unwrap());
        assert!(set.get(&42).is_none());
        assert!(set.is_empty());
        assert!(set.remove(&42).is_none());
        assert!(set.is_empty());
    }

    fn remove_many_inner<H: BuildHasher>(mut set: ConSet<usize, H>, len: usize) {
        for i in 0..len {
            assert!(set.insert(i).is_none());
        }
        for i in 0..len {
            assert_eq!(i, set.get(&i).unwrap());
            assert_eq!(i, set.remove(&i).unwrap());
            assert!(set.get(&i).is_none());
            set.raw.assert_pruned();
        }

        assert!(set.is_empty());
    }

    #[test]
    fn remove_many() {
        remove_many_inner(ConSet::new(), TEST_BATCH);
    }

    #[test]
    fn remove_many_collision() {
        remove_many_inner(ConSet::with_hasher(NoHasher), TEST_BATCH_SMALL);
    }

    #[test]
    fn collision_remove_one_left() {
        let mut set = ConSet::with_hasher(NoHasher);
        set.insert(1);
        set.insert(2);

        set.raw.assert_pruned();

        assert!(set.remove(&2).is_some());
        set.raw.assert_pruned();

        assert!(set.remove(&1).is_some());

        set.raw.assert_pruned();
        assert!(set.is_empty());
    }

    #[test]
    fn collision_remove_one_left_with_str() {
        let mut set = ConSet::with_hasher(NoHasher);
        set.insert("hello");
        set.insert("world");

        set.raw.assert_pruned();

        assert!(set.remove("world").is_some());
        set.raw.assert_pruned();

        assert!(set.remove("hello").is_some());

        set.raw.assert_pruned();
        assert!(set.is_empty());
    }

    #[test]
    fn remove_par() {
        let mut set = ConSet::new();
        for i in 0..TEST_THREADS * TEST_BATCH {
            set.insert(i);
        }

        thread::scope(|s| {
            for t in 0..TEST_THREADS {
                let set = &set;
                s.spawn(move |_| {
                    for i in 0..TEST_BATCH {
                        let num = t * TEST_BATCH + i;
                        let val = set.remove(&num).unwrap();
                        assert_eq!(num, val);
                        // Keys are partitioned between threads and never re-inserted, so the
                        // removal must already be visible to the removing thread.
                        assert!(set.get(&num).is_none());
                    }
                });
            }
        })
        .unwrap();

        set.raw.assert_pruned();
        assert!(set.is_empty());
    }

    fn iter_test_inner<S: BuildHasher>(set: ConSet<usize, S>) {
        for i in 0..TEST_BATCH_SMALL {
            assert!(set.insert(i).is_none());
        }

        let mut extracted = set.iter().collect::<Vec<_>>();

        extracted.sort();
        let expected = (0..TEST_BATCH_SMALL).collect::<Vec<_>>();
        assert_eq!(expected, extracted);
    }

    #[test]
    fn iter() {
        let set = ConSet::new();
        iter_test_inner(set);
    }

    #[test]
    fn iter_collision() {
        let set = ConSet::with_hasher(NoHasher);
        iter_test_inner(set);
    }

    #[test]
    fn collect() {
        let set = (0..TEST_BATCH_SMALL).collect::<ConSet<_>>();

        let mut extracted = set.iter().collect::<Vec<_>>();
        extracted.sort();
        let expected = (0..TEST_BATCH_SMALL).collect::<Vec<_>>();
        assert_eq!(expected, extracted);
    }

    // Exercises `Extend for &ConSet` from several threads at once (not the rayon impl).
    #[test]
    fn par_extend() {
        let set = ConSet::new();

        thread::scope(|s| {
            for t in 0..TEST_THREADS {
                let mut set = &set;
                s.spawn(move |_| {
                    let start = t * TEST_BATCH_SMALL;
                    let iter = start..start + TEST_BATCH_SMALL;
                    set.extend(iter);
                });
            }
        })
        .unwrap();

        let mut extracted = set.iter().collect::<Vec<_>>();

        extracted.sort();
        let expected = (0..TEST_THREADS * TEST_BATCH_SMALL).collect::<Vec<_>>();

        assert_eq!(expected, extracted);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn rayon_extend() {
        let mut set = ConSet::new();
        set.par_extend((0..TEST_BATCH_SMALL).into_par_iter());

        let mut extracted = set.iter().collect::<Vec<_>>();
        extracted.par_sort();

        let expected = (0..TEST_BATCH_SMALL).collect::<Vec<_>>();
        assert_eq!(expected, extracted);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn rayon_from_par_iter() {
        let set = ConSet::from_par_iter((0..TEST_BATCH_SMALL).into_par_iter());

        let mut extracted = set.iter().collect::<Vec<_>>();
        extracted.sort();

        let expected = (0..TEST_BATCH_SMALL).collect::<Vec<_>>();
        assert_eq!(expected, extracted);
    }
}