griddle/external_trait_impls/rayon/
set.rs

1//! Rayon extensions for `HashSet`.
2
3use crate::hash_set::HashSet;
4use core::hash::{BuildHasher, Hash};
5use rayon_::iter::plumbing::UnindexedConsumer;
6use rayon_::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
7
8/// Parallel iterator over shared references to elements in a set.
9///
10/// This iterator is created by the [`par_iter`] method on [`HashSet`]
11/// (provided by the [`IntoParallelRefIterator`] trait).
12/// See its documentation for more.
13///
14/// [`par_iter`]: /hashbrown/struct.HashSet.html#method.par_iter
15/// [`HashSet`]: /hashbrown/struct.HashSet.html
16/// [`IntoParallelRefIterator`]: https://docs.rs/rayon/1.0/rayon/iter/trait.IntoParallelRefIterator.html
17pub struct ParIter<'a, T, S> {
18    set: &'a HashSet<T, S>,
19}
20
21impl<'a, T: Sync, S: Sync> ParallelIterator for ParIter<'a, T, S> {
22    type Item = &'a T;
23
24    fn drive_unindexed<C>(self, consumer: C) -> C::Result
25    where
26        C: UnindexedConsumer<Self::Item>,
27    {
28        self.set.map.par_keys().drive_unindexed(consumer)
29    }
30}
31
32/// Parallel iterator over shared references to elements in the difference of
33/// sets.
34///
35/// This iterator is created by the [`par_difference`] method on [`HashSet`].
36/// See its documentation for more.
37///
38/// [`par_difference`]: /hashbrown/struct.HashSet.html#method.par_difference
39/// [`HashSet`]: /hashbrown/struct.HashSet.html
40pub struct ParDifference<'a, T, S> {
41    a: &'a HashSet<T, S>,
42    b: &'a HashSet<T, S>,
43}
44
45impl<'a, T, S> ParallelIterator for ParDifference<'a, T, S>
46where
47    T: Eq + Hash + Sync,
48    S: BuildHasher + Sync,
49{
50    type Item = &'a T;
51
52    fn drive_unindexed<C>(self, consumer: C) -> C::Result
53    where
54        C: UnindexedConsumer<Self::Item>,
55    {
56        self.a
57            .into_par_iter()
58            .filter(|&x| !self.b.contains(x))
59            .drive_unindexed(consumer)
60    }
61}
62
63/// Parallel iterator over shared references to elements in the symmetric
64/// difference of sets.
65///
66/// This iterator is created by the [`par_symmetric_difference`] method on
67/// [`HashSet`].
68/// See its documentation for more.
69///
70/// [`par_symmetric_difference`]: /hashbrown/struct.HashSet.html#method.par_symmetric_difference
71/// [`HashSet`]: /hashbrown/struct.HashSet.html
72pub struct ParSymmetricDifference<'a, T, S> {
73    a: &'a HashSet<T, S>,
74    b: &'a HashSet<T, S>,
75}
76
77impl<'a, T, S> ParallelIterator for ParSymmetricDifference<'a, T, S>
78where
79    T: Eq + Hash + Sync,
80    S: BuildHasher + Sync,
81{
82    type Item = &'a T;
83
84    fn drive_unindexed<C>(self, consumer: C) -> C::Result
85    where
86        C: UnindexedConsumer<Self::Item>,
87    {
88        self.a
89            .par_difference(self.b)
90            .chain(self.b.par_difference(self.a))
91            .drive_unindexed(consumer)
92    }
93}
94
95/// Parallel iterator over shared references to elements in the intersection of
96/// sets.
97///
98/// This iterator is created by the [`par_intersection`] method on [`HashSet`].
99/// See its documentation for more.
100///
101/// [`par_intersection`]: /hashbrown/struct.HashSet.html#method.par_intersection
102/// [`HashSet`]: /hashbrown/struct.HashSet.html
103pub struct ParIntersection<'a, T, S> {
104    a: &'a HashSet<T, S>,
105    b: &'a HashSet<T, S>,
106}
107
108impl<'a, T, S> ParallelIterator for ParIntersection<'a, T, S>
109where
110    T: Eq + Hash + Sync,
111    S: BuildHasher + Sync,
112{
113    type Item = &'a T;
114
115    fn drive_unindexed<C>(self, consumer: C) -> C::Result
116    where
117        C: UnindexedConsumer<Self::Item>,
118    {
119        self.a
120            .into_par_iter()
121            .filter(|&x| self.b.contains(x))
122            .drive_unindexed(consumer)
123    }
124}
125
126/// Parallel iterator over shared references to elements in the union of sets.
127///
128/// This iterator is created by the [`par_union`] method on [`HashSet`].
129/// See its documentation for more.
130///
131/// [`par_union`]: /hashbrown/struct.HashSet.html#method.par_union
132/// [`HashSet`]: /hashbrown/struct.HashSet.html
133pub struct ParUnion<'a, T, S> {
134    a: &'a HashSet<T, S>,
135    b: &'a HashSet<T, S>,
136}
137
138impl<'a, T, S> ParallelIterator for ParUnion<'a, T, S>
139where
140    T: Eq + Hash + Sync,
141    S: BuildHasher + Sync,
142{
143    type Item = &'a T;
144
145    fn drive_unindexed<C>(self, consumer: C) -> C::Result
146    where
147        C: UnindexedConsumer<Self::Item>,
148    {
149        self.a
150            .into_par_iter()
151            .chain(self.b.par_difference(self.a))
152            .drive_unindexed(consumer)
153    }
154}
155
156impl<T, S> HashSet<T, S>
157where
158    T: Eq + Hash + Sync,
159    S: BuildHasher + Sync,
160{
161    /// Visits (potentially in parallel) the values representing the difference,
162    /// i.e. the values that are in `self` but not in `other`.
163    #[cfg_attr(feature = "inline-more", inline)]
164    pub fn par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S> {
165        ParDifference { a: self, b: other }
166    }
167
168    /// Visits (potentially in parallel) the values representing the symmetric
169    /// difference, i.e. the values that are in `self` or in `other` but not in both.
170    #[cfg_attr(feature = "inline-more", inline)]
171    pub fn par_symmetric_difference<'a>(
172        &'a self,
173        other: &'a Self,
174    ) -> ParSymmetricDifference<'a, T, S> {
175        ParSymmetricDifference { a: self, b: other }
176    }
177
178    /// Visits (potentially in parallel) the values representing the
179    /// intersection, i.e. the values that are both in `self` and `other`.
180    #[cfg_attr(feature = "inline-more", inline)]
181    pub fn par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S> {
182        ParIntersection { a: self, b: other }
183    }
184
185    /// Visits (potentially in parallel) the values representing the union,
186    /// i.e. all the values in `self` or `other`, without duplicates.
187    #[cfg_attr(feature = "inline-more", inline)]
188    pub fn par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S> {
189        ParUnion { a: self, b: other }
190    }
191
192    /// Returns `true` if `self` has no elements in common with `other`.
193    /// This is equivalent to checking for an empty intersection.
194    ///
195    /// This method runs in a potentially parallel fashion.
196    pub fn par_is_disjoint(&self, other: &Self) -> bool {
197        self.into_par_iter().all(|x| !other.contains(x))
198    }
199
200    /// Returns `true` if the set is a subset of another,
201    /// i.e. `other` contains at least all the values in `self`.
202    ///
203    /// This method runs in a potentially parallel fashion.
204    pub fn par_is_subset(&self, other: &Self) -> bool {
205        if self.len() <= other.len() {
206            self.into_par_iter().all(|x| other.contains(x))
207        } else {
208            false
209        }
210    }
211
212    /// Returns `true` if the set is a superset of another,
213    /// i.e. `self` contains at least all the values in `other`.
214    ///
215    /// This method runs in a potentially parallel fashion.
216    pub fn par_is_superset(&self, other: &Self) -> bool {
217        other.par_is_subset(self)
218    }
219
220    /// Returns `true` if the set is equal to another,
221    /// i.e. both sets contain the same values.
222    ///
223    /// This method runs in a potentially parallel fashion.
224    pub fn par_eq(&self, other: &Self) -> bool {
225        self.len() == other.len() && self.par_is_subset(other)
226    }
227}
228
229impl<'a, T: Sync, S: Sync> IntoParallelIterator for &'a HashSet<T, S> {
230    type Item = &'a T;
231    type Iter = ParIter<'a, T, S>;
232
233    #[cfg_attr(feature = "inline-more", inline)]
234    fn into_par_iter(self) -> Self::Iter {
235        ParIter { set: self }
236    }
237}
238
239/// Collect values from a parallel iterator into a hashset.
240impl<T, S> FromParallelIterator<T> for HashSet<T, S>
241where
242    T: Eq + Hash + Send,
243    S: BuildHasher + Default,
244{
245    fn from_par_iter<P>(par_iter: P) -> Self
246    where
247        P: IntoParallelIterator<Item = T>,
248    {
249        let mut set = HashSet::default();
250        set.par_extend(par_iter);
251        set
252    }
253}
254
255/// Extend a hash set with items from a parallel iterator.
256impl<T, S> ParallelExtend<T> for HashSet<T, S>
257where
258    T: Eq + Hash + Send,
259    S: BuildHasher,
260{
261    fn par_extend<I>(&mut self, par_iter: I)
262    where
263        I: IntoParallelIterator<Item = T>,
264    {
265        extend(self, par_iter);
266    }
267}
268
269/// Extend a hash set with copied items from a parallel iterator.
270impl<'a, T, S> ParallelExtend<&'a T> for HashSet<T, S>
271where
272    T: 'a + Copy + Eq + Hash + Sync,
273    S: BuildHasher,
274{
275    fn par_extend<I>(&mut self, par_iter: I)
276    where
277        I: IntoParallelIterator<Item = &'a T>,
278    {
279        extend(self, par_iter);
280    }
281}
282
283// This is equal to the normal `HashSet` -- no custom advantage.
284fn extend<T, S, I>(set: &mut HashSet<T, S>, par_iter: I)
285where
286    T: Eq + Hash,
287    S: BuildHasher,
288    I: IntoParallelIterator,
289    HashSet<T, S>: Extend<I::Item>,
290{
291    let (list, len) = super::helpers::collect(par_iter);
292
293    // Values may be already present or show multiple times in the iterator.
294    // Reserve the entire length if the set is empty.
295    // Otherwise reserve half the length (rounded up), so the set
296    // will only resize twice in the worst case.
297    let reserve = if set.is_empty() { len } else { (len + 1) / 2 };
298    set.reserve(reserve);
299    for vec in list {
300        set.extend(vec);
301    }
302}
303
#[cfg(test)]
mod test_par_set {
    use alloc::vec::Vec;
    use core::sync::atomic::{AtomicUsize, Ordering};

    use rayon_::prelude::*;

    use crate::hash_set::HashSet;

    /// Builds a set from `values`, asserting that every value is newly inserted.
    fn set_of(values: &[i32]) -> HashSet<i32> {
        let mut set = HashSet::new();
        for &value in values {
            assert!(set.insert(value));
        }
        set
    }

    /// Asserts that `par_is_disjoint` returns `disjoint` in both directions.
    fn assert_disjoint(xs: &HashSet<i32>, ys: &HashSet<i32>, disjoint: bool) {
        assert_eq!(xs.par_is_disjoint(ys), disjoint);
        assert_eq!(ys.par_is_disjoint(xs), disjoint);
    }

    /// Counts the items of `iter`, asserting that each one appears in `expected`.
    fn count_members<'a, I>(iter: I, expected: &[i32]) -> usize
    where
        I: ParallelIterator<Item = &'a i32>,
    {
        iter.map(|x| {
            assert!(expected.contains(x));
            1
        })
        .sum::<usize>()
    }

    #[test]
    fn test_disjoint() {
        let mut xs = HashSet::new();
        let mut ys = HashSet::new();
        assert_disjoint(&xs, &ys, true);

        assert!(xs.insert(5));
        assert!(ys.insert(11));
        assert_disjoint(&xs, &ys, true);

        for x in [7, 19, 4] {
            assert!(xs.insert(x));
        }
        for y in [2, -11] {
            assert!(ys.insert(y));
        }
        assert_disjoint(&xs, &ys, true);

        assert!(ys.insert(7));
        assert_disjoint(&xs, &ys, false);
    }

    #[test]
    fn test_subset_and_superset() {
        let a = set_of(&[0, 5, 11, 7]);
        let mut b = set_of(&[0, 7, 19, 250, 11, 200]);

        assert!(!a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(!b.par_is_superset(&a));

        assert!(b.insert(5));

        assert!(a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(b.par_is_superset(&a));
    }

    #[test]
    fn test_iterate() {
        let a = set_of(&(0..32).collect::<Vec<_>>());
        let seen = AtomicUsize::new(0);
        a.par_iter().for_each(|&k| {
            seen.fetch_or(1 << k, Ordering::Relaxed);
        });
        assert_eq!(seen.into_inner(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_intersection() {
        let a = set_of(&[11, 1, 3, 77, 103, 5, -5]);
        let b = set_of(&[2, 11, 77, -9, -42, 5, 3]);

        let expected = [3, 5, 11, 77];
        let count = count_members(a.par_intersection(&b), &expected);
        assert_eq!(count, expected.len());
    }

    #[test]
    fn test_difference() {
        let a = set_of(&[1, 3, 5, 9, 11]);
        let b = set_of(&[3, 9]);

        let expected = [1, 5, 11];
        let count = count_members(a.par_difference(&b), &expected);
        assert_eq!(count, expected.len());
    }

    #[test]
    fn test_symmetric_difference() {
        let a = set_of(&[1, 3, 5, 9, 11]);
        let b = set_of(&[-2, 3, 9, 14, 22]);

        let expected = [-2, 1, 5, 11, 14, 22];
        let count = count_members(a.par_symmetric_difference(&b), &expected);
        assert_eq!(count, expected.len());
    }

    #[test]
    fn test_union() {
        let a = set_of(&[1, 3, 5, 9, 11, 16, 19, 24]);
        let b = set_of(&[-2, 1, 5, 9, 13, 19]);

        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        let count = count_members(a.par_union(&b), &expected);
        assert_eq!(count, expected.len());
    }

    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 3, 4, 5, 6, 7, 8, 9];

        let set = xs.par_iter().copied().collect::<HashSet<_>>();

        assert!(xs.iter().all(|x| set.contains(x)));
    }

    #[test]
    fn test_move_iter() {
        let mut hs = HashSet::new();
        hs.insert('a');
        hs.insert('b');

        let v: Vec<char> = (&hs).into_par_iter().copied().collect();
        assert!(v == ['a', 'b'] || v == ['b', 'a']);
    }

    #[test]
    fn test_eq() {
        // These constants once happened to expose a bug in insert().
        // They are kept around to prevent a regression.
        let s1 = set_of(&[1, 2, 3]);
        let mut s2 = set_of(&[1, 2]);

        assert!(!s1.par_eq(&s2));

        s2.insert(3);

        assert!(s1.par_eq(&s2));
    }

    #[test]
    fn test_extend_ref() {
        let mut a = HashSet::new();
        a.insert(1);

        a.par_extend(&[2, 3, 4][..]);

        assert_eq!(a.len(), 4);
        assert!([1, 2, 3, 4].iter().all(|x| a.contains(x)));

        let mut b = HashSet::new();
        b.insert(5);
        b.insert(6);

        a.par_extend(&b);

        assert_eq!(a.len(), 6);
        assert!([1, 2, 3, 4, 5, 6].iter().all(|x| a.contains(x)));
    }
}