hashbrown/external_trait_impls/rayon/
set.rs

1//! Rayon extensions for `HashSet`.
2
3use super::map;
4use crate::hash_set::HashSet;
5use crate::raw::{Allocator, Global};
6use core::hash::{BuildHasher, Hash};
7use rayon::iter::plumbing::UnindexedConsumer;
8use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
9
10/// Parallel iterator over elements of a consumed set.
11///
12/// This iterator is created by the [`into_par_iter`] method on [`HashSet`]
13/// (provided by the [`IntoParallelIterator`] trait).
14/// See its documentation for more.
15///
16/// [`into_par_iter`]: /hashbrown/struct.HashSet.html#method.into_par_iter
17/// [`HashSet`]: /hashbrown/struct.HashSet.html
18/// [`IntoParallelIterator`]: https://docs.rs/rayon/1.0/rayon/iter/trait.IntoParallelIterator.html
19pub struct IntoParIter<T, A: Allocator = Global> {
20    inner: map::IntoParIter<T, (), A>,
21}
22
23impl<T: Send, A: Allocator + Send> ParallelIterator for IntoParIter<T, A> {
24    type Item = T;
25
26    fn drive_unindexed<C>(self, consumer: C) -> C::Result
27    where
28        C: UnindexedConsumer<Self::Item>,
29    {
30        self.inner.map(|(k, _)| k).drive_unindexed(consumer)
31    }
32}
33
34/// Parallel draining iterator over entries of a set.
35///
36/// This iterator is created by the [`par_drain`] method on [`HashSet`].
37/// See its documentation for more.
38///
39/// [`par_drain`]: /hashbrown/struct.HashSet.html#method.par_drain
40/// [`HashSet`]: /hashbrown/struct.HashSet.html
41pub struct ParDrain<'a, T, A: Allocator = Global> {
42    inner: map::ParDrain<'a, T, (), A>,
43}
44
45impl<T: Send, A: Allocator + Send + Sync> ParallelIterator for ParDrain<'_, T, A> {
46    type Item = T;
47
48    fn drive_unindexed<C>(self, consumer: C) -> C::Result
49    where
50        C: UnindexedConsumer<Self::Item>,
51    {
52        self.inner.map(|(k, _)| k).drive_unindexed(consumer)
53    }
54}
55
56/// Parallel iterator over shared references to elements in a set.
57///
58/// This iterator is created by the [`par_iter`] method on [`HashSet`]
59/// (provided by the [`IntoParallelRefIterator`] trait).
60/// See its documentation for more.
61///
62/// [`par_iter`]: /hashbrown/struct.HashSet.html#method.par_iter
63/// [`HashSet`]: /hashbrown/struct.HashSet.html
64/// [`IntoParallelRefIterator`]: https://docs.rs/rayon/1.0/rayon/iter/trait.IntoParallelRefIterator.html
65pub struct ParIter<'a, T> {
66    inner: map::ParKeys<'a, T, ()>,
67}
68
69impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
70    type Item = &'a T;
71
72    fn drive_unindexed<C>(self, consumer: C) -> C::Result
73    where
74        C: UnindexedConsumer<Self::Item>,
75    {
76        self.inner.drive_unindexed(consumer)
77    }
78}
79
80/// Parallel iterator over shared references to elements in the difference of
81/// sets.
82///
83/// This iterator is created by the [`par_difference`] method on [`HashSet`].
84/// See its documentation for more.
85///
86/// [`par_difference`]: /hashbrown/struct.HashSet.html#method.par_difference
87/// [`HashSet`]: /hashbrown/struct.HashSet.html
88pub struct ParDifference<'a, T, S, A: Allocator = Global> {
89    a: &'a HashSet<T, S, A>,
90    b: &'a HashSet<T, S, A>,
91}
92
93impl<'a, T, S, A> ParallelIterator for ParDifference<'a, T, S, A>
94where
95    T: Eq + Hash + Sync,
96    S: BuildHasher + Sync,
97    A: Allocator + Sync,
98{
99    type Item = &'a T;
100
101    fn drive_unindexed<C>(self, consumer: C) -> C::Result
102    where
103        C: UnindexedConsumer<Self::Item>,
104    {
105        self.a
106            .into_par_iter()
107            .filter(|&x| !self.b.contains(x))
108            .drive_unindexed(consumer)
109    }
110}
111
112/// Parallel iterator over shared references to elements in the symmetric
113/// difference of sets.
114///
115/// This iterator is created by the [`par_symmetric_difference`] method on
116/// [`HashSet`].
117/// See its documentation for more.
118///
119/// [`par_symmetric_difference`]: /hashbrown/struct.HashSet.html#method.par_symmetric_difference
120/// [`HashSet`]: /hashbrown/struct.HashSet.html
121pub struct ParSymmetricDifference<'a, T, S, A: Allocator = Global> {
122    a: &'a HashSet<T, S, A>,
123    b: &'a HashSet<T, S, A>,
124}
125
126impl<'a, T, S, A> ParallelIterator for ParSymmetricDifference<'a, T, S, A>
127where
128    T: Eq + Hash + Sync,
129    S: BuildHasher + Sync,
130    A: Allocator + Sync,
131{
132    type Item = &'a T;
133
134    fn drive_unindexed<C>(self, consumer: C) -> C::Result
135    where
136        C: UnindexedConsumer<Self::Item>,
137    {
138        self.a
139            .par_difference(self.b)
140            .chain(self.b.par_difference(self.a))
141            .drive_unindexed(consumer)
142    }
143}
144
145/// Parallel iterator over shared references to elements in the intersection of
146/// sets.
147///
148/// This iterator is created by the [`par_intersection`] method on [`HashSet`].
149/// See its documentation for more.
150///
151/// [`par_intersection`]: /hashbrown/struct.HashSet.html#method.par_intersection
152/// [`HashSet`]: /hashbrown/struct.HashSet.html
153pub struct ParIntersection<'a, T, S, A: Allocator = Global> {
154    a: &'a HashSet<T, S, A>,
155    b: &'a HashSet<T, S, A>,
156}
157
158impl<'a, T, S, A> ParallelIterator for ParIntersection<'a, T, S, A>
159where
160    T: Eq + Hash + Sync,
161    S: BuildHasher + Sync,
162    A: Allocator + Sync,
163{
164    type Item = &'a T;
165
166    fn drive_unindexed<C>(self, consumer: C) -> C::Result
167    where
168        C: UnindexedConsumer<Self::Item>,
169    {
170        self.a
171            .into_par_iter()
172            .filter(|&x| self.b.contains(x))
173            .drive_unindexed(consumer)
174    }
175}
176
177/// Parallel iterator over shared references to elements in the union of sets.
178///
179/// This iterator is created by the [`par_union`] method on [`HashSet`].
180/// See its documentation for more.
181///
182/// [`par_union`]: /hashbrown/struct.HashSet.html#method.par_union
183/// [`HashSet`]: /hashbrown/struct.HashSet.html
184pub struct ParUnion<'a, T, S, A: Allocator = Global> {
185    a: &'a HashSet<T, S, A>,
186    b: &'a HashSet<T, S, A>,
187}
188
189impl<'a, T, S, A> ParallelIterator for ParUnion<'a, T, S, A>
190where
191    T: Eq + Hash + Sync,
192    S: BuildHasher + Sync,
193    A: Allocator + Sync,
194{
195    type Item = &'a T;
196
197    fn drive_unindexed<C>(self, consumer: C) -> C::Result
198    where
199        C: UnindexedConsumer<Self::Item>,
200    {
201        // We'll iterate one set in full, and only the remaining difference from the other.
202        // Use the smaller set for the difference in order to reduce hash lookups.
203        let (smaller, larger) = if self.a.len() <= self.b.len() {
204            (self.a, self.b)
205        } else {
206            (self.b, self.a)
207        };
208        larger
209            .into_par_iter()
210            .chain(smaller.par_difference(larger))
211            .drive_unindexed(consumer)
212    }
213}
214
215impl<T, S, A> HashSet<T, S, A>
216where
217    T: Eq + Hash + Sync,
218    S: BuildHasher + Sync,
219    A: Allocator + Sync,
220{
221    /// Visits (potentially in parallel) the values representing the union,
222    /// i.e. all the values in `self` or `other`, without duplicates.
223    #[cfg_attr(feature = "inline-more", inline)]
224    pub fn par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S, A> {
225        ParUnion { a: self, b: other }
226    }
227
228    /// Visits (potentially in parallel) the values representing the difference,
229    /// i.e. the values that are in `self` but not in `other`.
230    #[cfg_attr(feature = "inline-more", inline)]
231    pub fn par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S, A> {
232        ParDifference { a: self, b: other }
233    }
234
235    /// Visits (potentially in parallel) the values representing the symmetric
236    /// difference, i.e. the values that are in `self` or in `other` but not in both.
237    #[cfg_attr(feature = "inline-more", inline)]
238    pub fn par_symmetric_difference<'a>(
239        &'a self,
240        other: &'a Self,
241    ) -> ParSymmetricDifference<'a, T, S, A> {
242        ParSymmetricDifference { a: self, b: other }
243    }
244
245    /// Visits (potentially in parallel) the values representing the
246    /// intersection, i.e. the values that are both in `self` and `other`.
247    #[cfg_attr(feature = "inline-more", inline)]
248    pub fn par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S, A> {
249        ParIntersection { a: self, b: other }
250    }
251
252    /// Returns `true` if `self` has no elements in common with `other`.
253    /// This is equivalent to checking for an empty intersection.
254    ///
255    /// This method runs in a potentially parallel fashion.
256    pub fn par_is_disjoint(&self, other: &Self) -> bool {
257        self.into_par_iter().all(|x| !other.contains(x))
258    }
259
260    /// Returns `true` if the set is a subset of another,
261    /// i.e. `other` contains at least all the values in `self`.
262    ///
263    /// This method runs in a potentially parallel fashion.
264    pub fn par_is_subset(&self, other: &Self) -> bool {
265        if self.len() <= other.len() {
266            self.into_par_iter().all(|x| other.contains(x))
267        } else {
268            false
269        }
270    }
271
272    /// Returns `true` if the set is a superset of another,
273    /// i.e. `self` contains at least all the values in `other`.
274    ///
275    /// This method runs in a potentially parallel fashion.
276    pub fn par_is_superset(&self, other: &Self) -> bool {
277        other.par_is_subset(self)
278    }
279
280    /// Returns `true` if the set is equal to another,
281    /// i.e. both sets contain the same values.
282    ///
283    /// This method runs in a potentially parallel fashion.
284    pub fn par_eq(&self, other: &Self) -> bool {
285        self.len() == other.len() && self.par_is_subset(other)
286    }
287}
288
289impl<T, S, A> HashSet<T, S, A>
290where
291    T: Eq + Hash + Send,
292    A: Allocator + Send,
293{
294    /// Consumes (potentially in parallel) all values in an arbitrary order,
295    /// while preserving the set's allocated memory for reuse.
296    #[cfg_attr(feature = "inline-more", inline)]
297    pub fn par_drain(&mut self) -> ParDrain<'_, T, A> {
298        ParDrain {
299            inner: self.map.par_drain(),
300        }
301    }
302}
303
304impl<T: Send, S, A: Allocator + Send> IntoParallelIterator for HashSet<T, S, A> {
305    type Item = T;
306    type Iter = IntoParIter<T, A>;
307
308    #[cfg_attr(feature = "inline-more", inline)]
309    fn into_par_iter(self) -> Self::Iter {
310        IntoParIter {
311            inner: self.map.into_par_iter(),
312        }
313    }
314}
315
316impl<'a, T: Sync, S, A: Allocator> IntoParallelIterator for &'a HashSet<T, S, A> {
317    type Item = &'a T;
318    type Iter = ParIter<'a, T>;
319
320    #[cfg_attr(feature = "inline-more", inline)]
321    fn into_par_iter(self) -> Self::Iter {
322        ParIter {
323            inner: self.map.par_keys(),
324        }
325    }
326}
327
328/// Collect values from a parallel iterator into a hashset.
329impl<T, S> FromParallelIterator<T> for HashSet<T, S, Global>
330where
331    T: Eq + Hash + Send,
332    S: BuildHasher + Default,
333{
334    fn from_par_iter<P>(par_iter: P) -> Self
335    where
336        P: IntoParallelIterator<Item = T>,
337    {
338        let mut set = HashSet::default();
339        set.par_extend(par_iter);
340        set
341    }
342}
343
344/// Extend a hash set with items from a parallel iterator.
345impl<T, S> ParallelExtend<T> for HashSet<T, S, Global>
346where
347    T: Eq + Hash + Send,
348    S: BuildHasher,
349{
350    fn par_extend<I>(&mut self, par_iter: I)
351    where
352        I: IntoParallelIterator<Item = T>,
353    {
354        extend(self, par_iter);
355    }
356}
357
358/// Extend a hash set with copied items from a parallel iterator.
359impl<'a, T, S> ParallelExtend<&'a T> for HashSet<T, S, Global>
360where
361    T: 'a + Copy + Eq + Hash + Sync,
362    S: BuildHasher,
363{
364    fn par_extend<I>(&mut self, par_iter: I)
365    where
366        I: IntoParallelIterator<Item = &'a T>,
367    {
368        extend(self, par_iter);
369    }
370}
371
372// This is equal to the normal `HashSet` -- no custom advantage.
373fn extend<T, S, I, A>(set: &mut HashSet<T, S, A>, par_iter: I)
374where
375    T: Eq + Hash,
376    S: BuildHasher,
377    A: Allocator,
378    I: IntoParallelIterator,
379    HashSet<T, S, A>: Extend<I::Item>,
380{
381    let (list, len) = super::helpers::collect(par_iter);
382
383    // Values may be already present or show multiple times in the iterator.
384    // Reserve the entire length if the set is empty.
385    // Otherwise reserve half the length (rounded up), so the set
386    // will only resize twice in the worst case.
387    let reserve = if set.is_empty() { len } else { (len + 1) / 2 };
388    set.reserve(reserve);
389    for vec in list {
390        set.extend(vec);
391    }
392}
393
#[cfg(test)]
mod test_par_set {
    use alloc::vec::Vec;
    use core::sync::atomic::{AtomicUsize, Ordering};

    use rayon::prelude::*;

    use crate::hash_set::HashSet;

    /// Inserts each of `values` into `set` in order, asserting that every one
    /// of them was not present before.
    fn insert_all(set: &mut HashSet<i32>, values: &[i32]) {
        for &value in values {
            assert!(set.insert(value));
        }
    }

    /// Drives `iter` to completion, asserting that every yielded item is
    /// listed in `expected`, and returns how many items were yielded.
    fn count_expected<'a, I>(iter: I, expected: &[i32]) -> usize
    where
        I: ParallelIterator<Item = &'a i32>,
    {
        iter.map(|x| {
            assert!(expected.contains(x));
            1_usize
        })
        .sum::<usize>()
    }

    #[test]
    fn test_disjoint() {
        let (mut xs, mut ys) = (HashSet::new(), HashSet::new());
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));

        insert_all(&mut xs, &[5]);
        insert_all(&mut ys, &[11]);
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));

        insert_all(&mut xs, &[7, 19, 4]);
        insert_all(&mut ys, &[2, -11]);
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));

        // 7 is now shared, so the sets overlap in both directions.
        insert_all(&mut ys, &[7]);
        assert!(!xs.par_is_disjoint(&ys));
        assert!(!ys.par_is_disjoint(&xs));
    }

    #[test]
    fn test_subset_and_superset() {
        let mut a = HashSet::new();
        insert_all(&mut a, &[0, 5, 11, 7]);

        let mut b = HashSet::new();
        insert_all(&mut b, &[0, 7, 19, 250, 11, 200]);

        // `b` still lacks 5, so neither set contains the other.
        assert!(!a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(!b.par_is_superset(&a));

        insert_all(&mut b, &[5]);

        assert!(a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(b.par_is_superset(&a));
    }

    #[test]
    fn test_iterate() {
        let mut a = HashSet::new();
        for bit in 0..32 {
            assert!(a.insert(bit));
        }

        // Every element sets its own bit; all 32 bits must end up set.
        let seen = AtomicUsize::new(0);
        a.par_iter().for_each(|&bit| {
            seen.fetch_or(1 << bit, Ordering::Relaxed);
        });
        assert_eq!(seen.into_inner(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_intersection() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();
        insert_all(&mut a, &[11, 1, 3, 77, 103, 5, -5]);
        insert_all(&mut b, &[2, 11, 77, -9, -42, 5, 3]);

        let expected = [3, 5, 11, 77];
        let yielded = count_expected(a.par_intersection(&b), &expected);
        assert_eq!(yielded, expected.len());
    }

    #[test]
    fn test_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();
        insert_all(&mut a, &[1, 3, 5, 9, 11]);
        insert_all(&mut b, &[3, 9]);

        let expected = [1, 5, 11];
        let yielded = count_expected(a.par_difference(&b), &expected);
        assert_eq!(yielded, expected.len());
    }

    #[test]
    fn test_symmetric_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();
        insert_all(&mut a, &[1, 3, 5, 9, 11]);
        insert_all(&mut b, &[-2, 3, 9, 14, 22]);

        let expected = [-2, 1, 5, 11, 14, 22];
        let yielded = count_expected(a.par_symmetric_difference(&b), &expected);
        assert_eq!(yielded, expected.len());
    }

    #[test]
    fn test_union() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();
        insert_all(&mut a, &[1, 3, 5, 9, 11, 16, 19, 24]);
        insert_all(&mut b, &[-2, 1, 5, 9, 13, 19]);

        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        let yielded = count_expected(a.par_union(&b), &expected);
        assert_eq!(yielded, expected.len());
    }

    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 3, 4, 5, 6, 7, 8, 9];

        let set = xs.par_iter().cloned().collect::<HashSet<_>>();

        assert!(xs.iter().all(|x| set.contains(x)));
    }

    #[test]
    fn test_move_iter() {
        let mut hs = HashSet::new();
        hs.insert('a');
        hs.insert('b');

        // The iteration order is unspecified, so accept either one.
        let v: Vec<char> = hs.into_par_iter().collect();
        assert!(v == ['a', 'b'] || v == ['b', 'a']);
    }

    #[test]
    fn test_eq() {
        // These constants once happened to expose a bug in insert().
        // I'm keeping them around to prevent a regression.
        let mut s1 = HashSet::new();
        for &value in &[1, 2, 3] {
            s1.insert(value);
        }

        let mut s2 = HashSet::new();
        for &value in &[1, 2] {
            s2.insert(value);
        }

        assert!(!s1.par_eq(&s2));

        s2.insert(3);

        assert!(s1.par_eq(&s2));
    }

    #[test]
    fn test_extend_ref() {
        let mut a = HashSet::new();
        a.insert(1);

        // Extend from a slice of references.
        a.par_extend(&[2, 3, 4][..]);

        assert_eq!(a.len(), 4);
        for value in 1..=4 {
            assert!(a.contains(&value));
        }

        let mut b = HashSet::new();
        b.insert(5);
        b.insert(6);

        // Extend from a borrowed set.
        a.par_extend(&b);

        assert_eq!(a.len(), 6);
        for value in 1..=6 {
            assert!(a.contains(&value));
        }
    }
}