rayon_hash/par/
set.rs

1use rayon::iter::plumbing::UnindexedConsumer;
2/// Rayon extensions for `HashSet`
3use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
4use std::hash::{BuildHasher, Hash};
5
6use super::map;
7use crate::HashSet;
8
9pub struct ParIntoIter<T: Send> {
10    inner: map::ParIntoIter<T, ()>,
11}
12
13pub struct ParIter<'a, T: Sync + 'a> {
14    inner: map::ParKeys<'a, T, ()>,
15}
16
17pub struct ParDifference<'a, T: Sync + 'a, S: Sync + 'a> {
18    a: &'a HashSet<T, S>,
19    b: &'a HashSet<T, S>,
20}
21
22pub struct ParSymmetricDifference<'a, T: Sync + 'a, S: Sync + 'a> {
23    a: &'a HashSet<T, S>,
24    b: &'a HashSet<T, S>,
25}
26
27pub struct ParIntersection<'a, T: Sync + 'a, S: Sync + 'a> {
28    a: &'a HashSet<T, S>,
29    b: &'a HashSet<T, S>,
30}
31
32pub struct ParUnion<'a, T: Sync + 'a, S: Sync + 'a> {
33    a: &'a HashSet<T, S>,
34    b: &'a HashSet<T, S>,
35}
36
37impl<T, S> HashSet<T, S>
38where
39    T: Eq + Hash + Sync,
40    S: BuildHasher + Sync,
41{
42    pub fn par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S> {
43        ParDifference { a: self, b: other }
44    }
45
46    pub fn par_symmetric_difference<'a>(
47        &'a self,
48        other: &'a Self,
49    ) -> ParSymmetricDifference<'a, T, S> {
50        ParSymmetricDifference { a: self, b: other }
51    }
52
53    pub fn par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S> {
54        ParIntersection { a: self, b: other }
55    }
56
57    pub fn par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S> {
58        ParUnion { a: self, b: other }
59    }
60
61    pub fn par_is_disjoint(&self, other: &Self) -> bool {
62        self.into_par_iter().all(|x| !other.contains(x))
63    }
64
65    pub fn par_is_subset(&self, other: &Self) -> bool {
66        self.into_par_iter().all(|x| other.contains(x))
67    }
68
69    pub fn par_is_superset(&self, other: &Self) -> bool {
70        other.is_subset(self)
71    }
72
73    pub fn par_eq(&self, other: &Self) -> bool {
74        self.len() == other.len() && self.par_is_subset(other)
75    }
76}
77
78impl<T: Send, S> IntoParallelIterator for HashSet<T, S> {
79    type Item = T;
80    type Iter = ParIntoIter<T>;
81
82    fn into_par_iter(self) -> Self::Iter {
83        ParIntoIter {
84            inner: self.map.into_par_iter(),
85        }
86    }
87}
88
89impl<'a, T: Sync, S> IntoParallelIterator for &'a HashSet<T, S> {
90    type Item = &'a T;
91    type Iter = ParIter<'a, T>;
92
93    fn into_par_iter(self) -> Self::Iter {
94        ParIter {
95            inner: self.map.par_keys(),
96        }
97    }
98}
99
100/// Collect values from a parallel iterator into a hashset.
101impl<T, S> FromParallelIterator<T> for HashSet<T, S>
102where
103    T: Eq + Hash + Send,
104    S: BuildHasher + Default + Send,
105{
106    fn from_par_iter<P>(par_iter: P) -> Self
107    where
108        P: IntoParallelIterator<Item = T>,
109    {
110        let mut set = HashSet::default();
111        set.par_extend(par_iter);
112        set
113    }
114}
115
116/// Extend a hash set with items from a parallel iterator.
117impl<T, S> ParallelExtend<T> for HashSet<T, S>
118where
119    T: Eq + Hash + Send,
120    S: BuildHasher + Send,
121{
122    fn par_extend<I>(&mut self, par_iter: I)
123    where
124        I: IntoParallelIterator<Item = T>,
125    {
126        extend(self, par_iter);
127    }
128}
129
130/// Extend a hash set with copied items from a parallel iterator.
131impl<'a, T, S> ParallelExtend<&'a T> for HashSet<T, S>
132where
133    T: 'a + Copy + Eq + Hash + Send + Sync,
134    S: BuildHasher + Send,
135{
136    fn par_extend<I>(&mut self, par_iter: I)
137    where
138        I: IntoParallelIterator<Item = &'a T>,
139    {
140        extend(self, par_iter);
141    }
142}
143
144// This is equal to the normal `HashSet` -- no custom advantage.
145fn extend<T, S, I>(set: &mut HashSet<T, S>, par_iter: I)
146where
147    T: Eq + Hash,
148    S: BuildHasher,
149    I: IntoParallelIterator,
150    HashSet<T, S>: Extend<I::Item>,
151{
152    let (list, len) = super::collect(par_iter);
153
154    // Values may be already present or show multiple times in the iterator.
155    // Reserve the entire length if the set is empty.
156    // Otherwise reserve half the length (rounded up), so the set
157    // will only resize twice in the worst case.
158    let reserve = if set.is_empty() { len } else { (len + 1) / 2 };
159    set.reserve(reserve);
160    for vec in list {
161        set.extend(vec);
162    }
163}
164
165impl<T: Send> ParallelIterator for ParIntoIter<T> {
166    type Item = T;
167
168    fn drive_unindexed<C>(self, consumer: C) -> C::Result
169    where
170        C: UnindexedConsumer<Self::Item>,
171    {
172        self.inner.map(|(k, _)| k).drive_unindexed(consumer)
173    }
174}
175
176impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
177    type Item = &'a T;
178
179    fn drive_unindexed<C>(self, consumer: C) -> C::Result
180    where
181        C: UnindexedConsumer<Self::Item>,
182    {
183        self.inner.drive_unindexed(consumer)
184    }
185}
186
187impl<'a, T, S> ParallelIterator for ParDifference<'a, T, S>
188where
189    T: Eq + Hash + Sync,
190    S: BuildHasher + Sync,
191{
192    type Item = &'a T;
193
194    fn drive_unindexed<C>(self, consumer: C) -> C::Result
195    where
196        C: UnindexedConsumer<Self::Item>,
197    {
198        self.a
199            .into_par_iter()
200            .filter(|&x| !self.b.contains(x))
201            .drive_unindexed(consumer)
202    }
203}
204
205impl<'a, T, S> ParallelIterator for ParSymmetricDifference<'a, T, S>
206where
207    T: Eq + Hash + Sync,
208    S: BuildHasher + Sync,
209{
210    type Item = &'a T;
211
212    fn drive_unindexed<C>(self, consumer: C) -> C::Result
213    where
214        C: UnindexedConsumer<Self::Item>,
215    {
216        self.a
217            .par_difference(self.b)
218            .chain(self.b.par_difference(self.a))
219            .drive_unindexed(consumer)
220    }
221}
222
223impl<'a, T, S> ParallelIterator for ParIntersection<'a, T, S>
224where
225    T: Eq + Hash + Sync,
226    S: BuildHasher + Sync,
227{
228    type Item = &'a T;
229
230    fn drive_unindexed<C>(self, consumer: C) -> C::Result
231    where
232        C: UnindexedConsumer<Self::Item>,
233    {
234        self.a
235            .into_par_iter()
236            .filter(|&x| self.b.contains(x))
237            .drive_unindexed(consumer)
238    }
239}
240
241impl<'a, T, S> ParallelIterator for ParUnion<'a, T, S>
242where
243    T: Eq + Hash + Sync,
244    S: BuildHasher + Sync,
245{
246    type Item = &'a T;
247
248    fn drive_unindexed<C>(self, consumer: C) -> C::Result
249    where
250        C: UnindexedConsumer<Self::Item>,
251    {
252        self.a
253            .into_par_iter()
254            .chain(self.b.par_difference(self.a))
255            .drive_unindexed(consumer)
256    }
257}
258
#[cfg(test)]
mod test_par_set {
    use super::HashSet;
    use rayon::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a set from `items`, asserting that each insertion adds a new
    /// element (i.e. `items` holds no duplicates).
    fn set_of(items: &[i32]) -> HashSet<i32> {
        let mut set = HashSet::new();
        for &item in items {
            assert!(set.insert(item));
        }
        set
    }

    // Disjointness is checked in both directions after every batch of
    // inserts; it must only fail once the sets share an element.
    #[test]
    fn test_disjoint() {
        let mut xs = HashSet::new();
        let mut ys = HashSet::new();
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));

        assert!(xs.insert(5));
        assert!(ys.insert(11));
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));

        for &v in &[7, 19, 4] {
            assert!(xs.insert(v));
        }
        for &v in &[2, -11] {
            assert!(ys.insert(v));
        }
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));

        // 7 is now present in both sets.
        assert!(ys.insert(7));
        assert!(!xs.par_is_disjoint(&ys));
        assert!(!ys.par_is_disjoint(&xs));
    }

    #[test]
    fn test_subset_and_superset() {
        let a = set_of(&[0, 5, 11, 7]);
        let mut b = set_of(&[0, 7, 19, 250, 11, 200]);

        // `b` lacks 5, so neither set contains the other.
        assert!(!a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(!b.par_is_superset(&a));

        assert!(b.insert(5));

        // Now every element of `a` is in `b`, but not the other way round.
        assert!(a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(b.par_is_superset(&a));
    }

    #[test]
    fn test_iterate() {
        let mut a = HashSet::new();
        for bit in 0..32 {
            assert!(a.insert(bit));
        }

        // Each visited element sets its own bit; all 32 bits must end up set.
        let seen = AtomicUsize::new(0);
        a.par_iter().for_each(|k| {
            seen.fetch_or(1 << *k, Ordering::Relaxed);
        });
        assert_eq!(seen.into_inner(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_intersection() {
        let a = set_of(&[11, 1, 3, 77, 103, 5, -5]);
        let b = set_of(&[2, 11, 77, -9, -42, 5, 3]);

        let expected = [3, 5, 11, 77];
        let hits: usize = a
            .par_intersection(&b)
            .map(|elem| {
                assert!(expected.contains(elem));
                1
            })
            .sum();
        assert_eq!(hits, expected.len());
    }

    #[test]
    fn test_difference() {
        let a = set_of(&[1, 3, 5, 9, 11]);
        let b = set_of(&[3, 9]);

        let expected = [1, 5, 11];
        let hits: usize = a
            .par_difference(&b)
            .map(|elem| {
                assert!(expected.contains(elem));
                1
            })
            .sum();
        assert_eq!(hits, expected.len());
    }

    #[test]
    fn test_symmetric_difference() {
        let a = set_of(&[1, 3, 5, 9, 11]);
        let b = set_of(&[-2, 3, 9, 14, 22]);

        let expected = [-2, 1, 5, 11, 14, 22];
        let hits: usize = a
            .par_symmetric_difference(&b)
            .map(|elem| {
                assert!(expected.contains(elem));
                1
            })
            .sum();
        assert_eq!(hits, expected.len());
    }

    #[test]
    fn test_union() {
        let a = set_of(&[1, 3, 5, 9, 11, 16, 19, 24]);
        let b = set_of(&[-2, 1, 5, 9, 13, 19]);

        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        let hits: usize = a
            .par_union(&b)
            .map(|elem| {
                assert!(expected.contains(elem));
                1
            })
            .sum();
        assert_eq!(hits, expected.len());
    }

    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 3, 4, 5, 6, 7, 8, 9];

        let set: HashSet<_> = xs.par_iter().cloned().collect();

        assert!(xs.iter().all(|x| set.contains(x)));
    }

    #[test]
    fn test_move_iter() {
        let mut hs = HashSet::new();
        hs.insert('a');
        hs.insert('b');

        // Iteration order is unspecified, so accept either ordering.
        let v: Vec<char> = hs.into_par_iter().collect();
        assert!(v == ['a', 'b'] || v == ['b', 'a']);
    }

    #[test]
    fn test_eq() {
        // These constants once happened to expose a bug in insert().
        // They are kept to guard against a regression.
        let mut s1 = HashSet::new();
        for &v in &[1, 2, 3] {
            s1.insert(v);
        }

        let mut s2 = HashSet::new();
        for &v in &[1, 2] {
            s2.insert(v);
        }

        assert!(!s1.par_eq(&s2));

        s2.insert(3);

        assert!(s1.par_eq(&s2));
    }

    #[test]
    fn test_extend_ref() {
        let mut a = HashSet::new();
        a.insert(1);

        // Extend from a slice of references.
        a.par_extend(&[2, 3, 4][..]);

        assert_eq!(a.len(), 4);
        for v in 1..5 {
            assert!(a.contains(&v));
        }

        let mut b = HashSet::new();
        b.insert(5);
        b.insert(6);

        // Extend from a borrowed set.
        a.par_extend(&b);

        assert_eq!(a.len(), 6);
        for v in 1..7 {
            assert!(a.contains(&v));
        }
    }
}
515        assert!(a.contains(&6));
516    }
517}