1use super::map;
4use crate::hash_set::HashSet;
5use crate::raw::{Allocator, Global};
6use core::hash::{BuildHasher, Hash};
7use rayon::iter::plumbing::UnindexedConsumer;
8use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
9
10pub struct IntoParIter<T, A: Allocator = Global> {
20 inner: map::IntoParIter<T, (), A>,
21}
22
23impl<T: Send, A: Allocator + Send> ParallelIterator for IntoParIter<T, A> {
24 type Item = T;
25
26 fn drive_unindexed<C>(self, consumer: C) -> C::Result
27 where
28 C: UnindexedConsumer<Self::Item>,
29 {
30 self.inner.map(|(k, _)| k).drive_unindexed(consumer)
31 }
32}
33
34pub struct ParDrain<'a, T, A: Allocator = Global> {
42 inner: map::ParDrain<'a, T, (), A>,
43}
44
45impl<T: Send, A: Allocator + Send + Sync> ParallelIterator for ParDrain<'_, T, A> {
46 type Item = T;
47
48 fn drive_unindexed<C>(self, consumer: C) -> C::Result
49 where
50 C: UnindexedConsumer<Self::Item>,
51 {
52 self.inner.map(|(k, _)| k).drive_unindexed(consumer)
53 }
54}
55
56pub struct ParIter<'a, T> {
66 inner: map::ParKeys<'a, T, ()>,
67}
68
69impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
70 type Item = &'a T;
71
72 fn drive_unindexed<C>(self, consumer: C) -> C::Result
73 where
74 C: UnindexedConsumer<Self::Item>,
75 {
76 self.inner.drive_unindexed(consumer)
77 }
78}
79
80pub struct ParDifference<'a, T, S, A: Allocator = Global> {
89 a: &'a HashSet<T, S, A>,
90 b: &'a HashSet<T, S, A>,
91}
92
93impl<'a, T, S, A> ParallelIterator for ParDifference<'a, T, S, A>
94where
95 T: Eq + Hash + Sync,
96 S: BuildHasher + Sync,
97 A: Allocator + Sync,
98{
99 type Item = &'a T;
100
101 fn drive_unindexed<C>(self, consumer: C) -> C::Result
102 where
103 C: UnindexedConsumer<Self::Item>,
104 {
105 self.a
106 .into_par_iter()
107 .filter(|&x| !self.b.contains(x))
108 .drive_unindexed(consumer)
109 }
110}
111
112pub struct ParSymmetricDifference<'a, T, S, A: Allocator = Global> {
122 a: &'a HashSet<T, S, A>,
123 b: &'a HashSet<T, S, A>,
124}
125
126impl<'a, T, S, A> ParallelIterator for ParSymmetricDifference<'a, T, S, A>
127where
128 T: Eq + Hash + Sync,
129 S: BuildHasher + Sync,
130 A: Allocator + Sync,
131{
132 type Item = &'a T;
133
134 fn drive_unindexed<C>(self, consumer: C) -> C::Result
135 where
136 C: UnindexedConsumer<Self::Item>,
137 {
138 self.a
139 .par_difference(self.b)
140 .chain(self.b.par_difference(self.a))
141 .drive_unindexed(consumer)
142 }
143}
144
145pub struct ParIntersection<'a, T, S, A: Allocator = Global> {
154 a: &'a HashSet<T, S, A>,
155 b: &'a HashSet<T, S, A>,
156}
157
158impl<'a, T, S, A> ParallelIterator for ParIntersection<'a, T, S, A>
159where
160 T: Eq + Hash + Sync,
161 S: BuildHasher + Sync,
162 A: Allocator + Sync,
163{
164 type Item = &'a T;
165
166 fn drive_unindexed<C>(self, consumer: C) -> C::Result
167 where
168 C: UnindexedConsumer<Self::Item>,
169 {
170 self.a
171 .into_par_iter()
172 .filter(|&x| self.b.contains(x))
173 .drive_unindexed(consumer)
174 }
175}
176
177pub struct ParUnion<'a, T, S, A: Allocator = Global> {
185 a: &'a HashSet<T, S, A>,
186 b: &'a HashSet<T, S, A>,
187}
188
189impl<'a, T, S, A> ParallelIterator for ParUnion<'a, T, S, A>
190where
191 T: Eq + Hash + Sync,
192 S: BuildHasher + Sync,
193 A: Allocator + Sync,
194{
195 type Item = &'a T;
196
197 fn drive_unindexed<C>(self, consumer: C) -> C::Result
198 where
199 C: UnindexedConsumer<Self::Item>,
200 {
201 let (smaller, larger) = if self.a.len() <= self.b.len() {
204 (self.a, self.b)
205 } else {
206 (self.b, self.a)
207 };
208 larger
209 .into_par_iter()
210 .chain(smaller.par_difference(larger))
211 .drive_unindexed(consumer)
212 }
213}
214
215impl<T, S, A> HashSet<T, S, A>
216where
217 T: Eq + Hash + Sync,
218 S: BuildHasher + Sync,
219 A: Allocator + Sync,
220{
221 #[cfg_attr(feature = "inline-more", inline)]
224 pub fn par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S, A> {
225 ParUnion { a: self, b: other }
226 }
227
228 #[cfg_attr(feature = "inline-more", inline)]
231 pub fn par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S, A> {
232 ParDifference { a: self, b: other }
233 }
234
235 #[cfg_attr(feature = "inline-more", inline)]
238 pub fn par_symmetric_difference<'a>(
239 &'a self,
240 other: &'a Self,
241 ) -> ParSymmetricDifference<'a, T, S, A> {
242 ParSymmetricDifference { a: self, b: other }
243 }
244
245 #[cfg_attr(feature = "inline-more", inline)]
248 pub fn par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S, A> {
249 ParIntersection { a: self, b: other }
250 }
251
252 pub fn par_is_disjoint(&self, other: &Self) -> bool {
257 self.into_par_iter().all(|x| !other.contains(x))
258 }
259
260 pub fn par_is_subset(&self, other: &Self) -> bool {
265 if self.len() <= other.len() {
266 self.into_par_iter().all(|x| other.contains(x))
267 } else {
268 false
269 }
270 }
271
272 pub fn par_is_superset(&self, other: &Self) -> bool {
277 other.par_is_subset(self)
278 }
279
280 pub fn par_eq(&self, other: &Self) -> bool {
285 self.len() == other.len() && self.par_is_subset(other)
286 }
287}
288
289impl<T, S, A> HashSet<T, S, A>
290where
291 T: Eq + Hash + Send,
292 A: Allocator + Send,
293{
294 #[cfg_attr(feature = "inline-more", inline)]
297 pub fn par_drain(&mut self) -> ParDrain<'_, T, A> {
298 ParDrain {
299 inner: self.map.par_drain(),
300 }
301 }
302}
303
304impl<T: Send, S, A: Allocator + Send> IntoParallelIterator for HashSet<T, S, A> {
305 type Item = T;
306 type Iter = IntoParIter<T, A>;
307
308 #[cfg_attr(feature = "inline-more", inline)]
309 fn into_par_iter(self) -> Self::Iter {
310 IntoParIter {
311 inner: self.map.into_par_iter(),
312 }
313 }
314}
315
316impl<'a, T: Sync, S, A: Allocator> IntoParallelIterator for &'a HashSet<T, S, A> {
317 type Item = &'a T;
318 type Iter = ParIter<'a, T>;
319
320 #[cfg_attr(feature = "inline-more", inline)]
321 fn into_par_iter(self) -> Self::Iter {
322 ParIter {
323 inner: self.map.par_keys(),
324 }
325 }
326}
327
328impl<T, S> FromParallelIterator<T> for HashSet<T, S, Global>
330where
331 T: Eq + Hash + Send,
332 S: BuildHasher + Default,
333{
334 fn from_par_iter<P>(par_iter: P) -> Self
335 where
336 P: IntoParallelIterator<Item = T>,
337 {
338 let mut set = HashSet::default();
339 set.par_extend(par_iter);
340 set
341 }
342}
343
344impl<T, S> ParallelExtend<T> for HashSet<T, S, Global>
346where
347 T: Eq + Hash + Send,
348 S: BuildHasher,
349{
350 fn par_extend<I>(&mut self, par_iter: I)
351 where
352 I: IntoParallelIterator<Item = T>,
353 {
354 extend(self, par_iter);
355 }
356}
357
358impl<'a, T, S> ParallelExtend<&'a T> for HashSet<T, S, Global>
360where
361 T: 'a + Copy + Eq + Hash + Sync,
362 S: BuildHasher,
363{
364 fn par_extend<I>(&mut self, par_iter: I)
365 where
366 I: IntoParallelIterator<Item = &'a T>,
367 {
368 extend(self, par_iter);
369 }
370}
371
372fn extend<T, S, I, A>(set: &mut HashSet<T, S, A>, par_iter: I)
374where
375 T: Eq + Hash,
376 S: BuildHasher,
377 A: Allocator,
378 I: IntoParallelIterator,
379 HashSet<T, S, A>: Extend<I::Item>,
380{
381 let (list, len) = super::helpers::collect(par_iter);
382
383 let reserve = if set.is_empty() { len } else { (len + 1) / 2 };
388 set.reserve(reserve);
389 for vec in list {
390 set.extend(vec);
391 }
392}
393
#[cfg(test)]
mod test_par_set {
    use alloc::vec::Vec;
    use core::sync::atomic::{AtomicUsize, Ordering};

    use rayon::prelude::*;

    use crate::hash_set::HashSet;

    /// Inserts each value into `set`, asserting that none was already present.
    fn insert_all(set: &mut HashSet<i32>, values: &[i32]) {
        for &value in values {
            assert!(set.insert(value));
        }
    }

    /// Counts the items of `iter`, asserting that each one appears in `expected`.
    fn count_checked<'a, I>(iter: I, expected: &[i32]) -> usize
    where
        I: ParallelIterator<Item = &'a i32>,
    {
        iter.map(|x| {
            assert!(expected.contains(x));
            1
        })
        .sum::<usize>()
    }

    #[test]
    fn test_disjoint() {
        // Disjointness must be reported identically from both sides.
        let check = |lhs: &HashSet<i32>, rhs: &HashSet<i32>, disjoint: bool| {
            assert_eq!(lhs.par_is_disjoint(rhs), disjoint);
            assert_eq!(rhs.par_is_disjoint(lhs), disjoint);
        };

        let mut xs = HashSet::new();
        let mut ys = HashSet::new();
        check(&xs, &ys, true);

        insert_all(&mut xs, &[5]);
        insert_all(&mut ys, &[11]);
        check(&xs, &ys, true);

        insert_all(&mut xs, &[7, 19, 4]);
        insert_all(&mut ys, &[2, -11]);
        check(&xs, &ys, true);

        insert_all(&mut ys, &[7]);
        check(&xs, &ys, false);
    }

    #[test]
    fn test_subset_and_superset() {
        let mut a = HashSet::new();
        insert_all(&mut a, &[0, 5, 11, 7]);

        let mut b = HashSet::new();
        insert_all(&mut b, &[0, 7, 19, 250, 11, 200]);

        assert!(!a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(!b.par_is_superset(&a));

        insert_all(&mut b, &[5]);

        assert!(a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(b.par_is_superset(&a));
    }

    #[test]
    fn test_iterate() {
        let mut a = HashSet::new();
        for bit in 0..32 {
            assert!(a.insert(bit));
        }

        // Every element sets its own bit; all 32 low bits must end up set.
        let seen = AtomicUsize::new(0);
        a.par_iter().for_each(|&bit| {
            seen.fetch_or(1 << bit, Ordering::Relaxed);
        });
        assert_eq!(seen.into_inner(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_intersection() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();
        insert_all(&mut a, &[11, 1, 3, 77, 103, 5, -5]);
        insert_all(&mut b, &[2, 11, 77, -9, -42, 5, 3]);

        let expected = [3, 5, 11, 77];
        let count = count_checked(a.par_intersection(&b), &expected);
        assert_eq!(count, expected.len());
    }

    #[test]
    fn test_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();
        insert_all(&mut a, &[1, 3, 5, 9, 11]);
        insert_all(&mut b, &[3, 9]);

        let expected = [1, 5, 11];
        let count = count_checked(a.par_difference(&b), &expected);
        assert_eq!(count, expected.len());
    }

    #[test]
    fn test_symmetric_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();
        insert_all(&mut a, &[1, 3, 5, 9, 11]);
        insert_all(&mut b, &[-2, 3, 9, 14, 22]);

        let expected = [-2, 1, 5, 11, 14, 22];
        let count = count_checked(a.par_symmetric_difference(&b), &expected);
        assert_eq!(count, expected.len());
    }

    #[test]
    fn test_union() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();
        insert_all(&mut a, &[1, 3, 5, 9, 11, 16, 19, 24]);
        insert_all(&mut b, &[-2, 1, 5, 9, 13, 19]);

        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        let count = count_checked(a.par_union(&b), &expected);
        assert_eq!(count, expected.len());
    }

    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 3, 4, 5, 6, 7, 8, 9];

        let set: HashSet<_> = xs.par_iter().cloned().collect();

        assert!(xs.iter().all(|x| set.contains(x)));
    }

    #[test]
    fn test_move_iter() {
        let mut hs = HashSet::new();
        hs.insert('a');
        hs.insert('b');

        // Iteration order is unspecified, so normalize before comparing.
        let mut v = hs.into_par_iter().collect::<Vec<char>>();
        v.sort_unstable();
        assert_eq!(v, ['a', 'b']);
    }

    #[test]
    fn test_eq() {
        let mut s1 = HashSet::new();
        for n in 1..=3 {
            s1.insert(n);
        }

        let mut s2 = HashSet::new();
        s2.insert(1);
        s2.insert(2);
        assert!(!s1.par_eq(&s2));

        s2.insert(3);
        assert!(s1.par_eq(&s2));
    }

    #[test]
    fn test_extend_ref() {
        let mut a = HashSet::new();
        a.insert(1);

        // Extend from a slice of references (uses the `&T` impl).
        a.par_extend(&[2, 3, 4][..]);
        assert_eq!(a.len(), 4);
        assert!((1..=4).all(|n| a.contains(&n)));

        let mut b = HashSet::new();
        b.insert(5);
        b.insert(6);

        // Extend from a borrowed set, which also yields `&T`.
        a.par_extend(&b);
        assert_eq!(a.len(), 6);
        assert!((1..=6).all(|n| a.contains(&n)));
    }
}