1use crate::hash_set::HashSet;
4use core::hash::{BuildHasher, Hash};
5use rayon_::iter::plumbing::UnindexedConsumer;
6use rayon_::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
7
8pub struct ParIter<'a, T, S> {
18 set: &'a HashSet<T, S>,
19}
20
21impl<'a, T: Sync, S: Sync> ParallelIterator for ParIter<'a, T, S> {
22 type Item = &'a T;
23
24 fn drive_unindexed<C>(self, consumer: C) -> C::Result
25 where
26 C: UnindexedConsumer<Self::Item>,
27 {
28 self.set.map.par_keys().drive_unindexed(consumer)
29 }
30}
31
32pub struct ParDifference<'a, T, S> {
41 a: &'a HashSet<T, S>,
42 b: &'a HashSet<T, S>,
43}
44
45impl<'a, T, S> ParallelIterator for ParDifference<'a, T, S>
46where
47 T: Eq + Hash + Sync,
48 S: BuildHasher + Sync,
49{
50 type Item = &'a T;
51
52 fn drive_unindexed<C>(self, consumer: C) -> C::Result
53 where
54 C: UnindexedConsumer<Self::Item>,
55 {
56 self.a
57 .into_par_iter()
58 .filter(|&x| !self.b.contains(x))
59 .drive_unindexed(consumer)
60 }
61}
62
63pub struct ParSymmetricDifference<'a, T, S> {
73 a: &'a HashSet<T, S>,
74 b: &'a HashSet<T, S>,
75}
76
77impl<'a, T, S> ParallelIterator for ParSymmetricDifference<'a, T, S>
78where
79 T: Eq + Hash + Sync,
80 S: BuildHasher + Sync,
81{
82 type Item = &'a T;
83
84 fn drive_unindexed<C>(self, consumer: C) -> C::Result
85 where
86 C: UnindexedConsumer<Self::Item>,
87 {
88 self.a
89 .par_difference(self.b)
90 .chain(self.b.par_difference(self.a))
91 .drive_unindexed(consumer)
92 }
93}
94
95pub struct ParIntersection<'a, T, S> {
104 a: &'a HashSet<T, S>,
105 b: &'a HashSet<T, S>,
106}
107
108impl<'a, T, S> ParallelIterator for ParIntersection<'a, T, S>
109where
110 T: Eq + Hash + Sync,
111 S: BuildHasher + Sync,
112{
113 type Item = &'a T;
114
115 fn drive_unindexed<C>(self, consumer: C) -> C::Result
116 where
117 C: UnindexedConsumer<Self::Item>,
118 {
119 self.a
120 .into_par_iter()
121 .filter(|&x| self.b.contains(x))
122 .drive_unindexed(consumer)
123 }
124}
125
126pub struct ParUnion<'a, T, S> {
134 a: &'a HashSet<T, S>,
135 b: &'a HashSet<T, S>,
136}
137
138impl<'a, T, S> ParallelIterator for ParUnion<'a, T, S>
139where
140 T: Eq + Hash + Sync,
141 S: BuildHasher + Sync,
142{
143 type Item = &'a T;
144
145 fn drive_unindexed<C>(self, consumer: C) -> C::Result
146 where
147 C: UnindexedConsumer<Self::Item>,
148 {
149 self.a
150 .into_par_iter()
151 .chain(self.b.par_difference(self.a))
152 .drive_unindexed(consumer)
153 }
154}
155
156impl<T, S> HashSet<T, S>
157where
158 T: Eq + Hash + Sync,
159 S: BuildHasher + Sync,
160{
161 #[cfg_attr(feature = "inline-more", inline)]
164 pub fn par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S> {
165 ParDifference { a: self, b: other }
166 }
167
168 #[cfg_attr(feature = "inline-more", inline)]
171 pub fn par_symmetric_difference<'a>(
172 &'a self,
173 other: &'a Self,
174 ) -> ParSymmetricDifference<'a, T, S> {
175 ParSymmetricDifference { a: self, b: other }
176 }
177
178 #[cfg_attr(feature = "inline-more", inline)]
181 pub fn par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S> {
182 ParIntersection { a: self, b: other }
183 }
184
185 #[cfg_attr(feature = "inline-more", inline)]
188 pub fn par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S> {
189 ParUnion { a: self, b: other }
190 }
191
192 pub fn par_is_disjoint(&self, other: &Self) -> bool {
197 self.into_par_iter().all(|x| !other.contains(x))
198 }
199
200 pub fn par_is_subset(&self, other: &Self) -> bool {
205 if self.len() <= other.len() {
206 self.into_par_iter().all(|x| other.contains(x))
207 } else {
208 false
209 }
210 }
211
212 pub fn par_is_superset(&self, other: &Self) -> bool {
217 other.par_is_subset(self)
218 }
219
220 pub fn par_eq(&self, other: &Self) -> bool {
225 self.len() == other.len() && self.par_is_subset(other)
226 }
227}
228
229impl<'a, T: Sync, S: Sync> IntoParallelIterator for &'a HashSet<T, S> {
230 type Item = &'a T;
231 type Iter = ParIter<'a, T, S>;
232
233 #[cfg_attr(feature = "inline-more", inline)]
234 fn into_par_iter(self) -> Self::Iter {
235 ParIter { set: self }
236 }
237}
238
239impl<T, S> FromParallelIterator<T> for HashSet<T, S>
241where
242 T: Eq + Hash + Send,
243 S: BuildHasher + Default,
244{
245 fn from_par_iter<P>(par_iter: P) -> Self
246 where
247 P: IntoParallelIterator<Item = T>,
248 {
249 let mut set = HashSet::default();
250 set.par_extend(par_iter);
251 set
252 }
253}
254
255impl<T, S> ParallelExtend<T> for HashSet<T, S>
257where
258 T: Eq + Hash + Send,
259 S: BuildHasher,
260{
261 fn par_extend<I>(&mut self, par_iter: I)
262 where
263 I: IntoParallelIterator<Item = T>,
264 {
265 extend(self, par_iter);
266 }
267}
268
269impl<'a, T, S> ParallelExtend<&'a T> for HashSet<T, S>
271where
272 T: 'a + Copy + Eq + Hash + Sync,
273 S: BuildHasher,
274{
275 fn par_extend<I>(&mut self, par_iter: I)
276 where
277 I: IntoParallelIterator<Item = &'a T>,
278 {
279 extend(self, par_iter);
280 }
281}
282
283fn extend<T, S, I>(set: &mut HashSet<T, S>, par_iter: I)
285where
286 T: Eq + Hash,
287 S: BuildHasher,
288 I: IntoParallelIterator,
289 HashSet<T, S>: Extend<I::Item>,
290{
291 let (list, len) = super::helpers::collect(par_iter);
292
293 let reserve = if set.is_empty() { len } else { (len + 1) / 2 };
298 set.reserve(reserve);
299 for vec in list {
300 set.extend(vec);
301 }
302}
303
#[cfg(test)]
mod test_par_set {
    use alloc::vec::Vec;
    use core::sync::atomic::{AtomicUsize, Ordering};

    use rayon_::prelude::*;

    use crate::hash_set::HashSet;

    /// Builds a set containing exactly `items`.
    fn set_of(items: &[i32]) -> HashSet<i32> {
        let mut set = HashSet::new();
        for &item in items {
            set.insert(item);
        }
        set
    }

    /// Collects the values yielded by a parallel set iterator into a sorted
    /// `Vec`, so results can be compared regardless of iteration order.
    fn sorted<'a>(iter: impl ParallelIterator<Item = &'a i32>) -> Vec<i32> {
        let mut values: Vec<i32> = iter.copied().collect();
        values.sort_unstable();
        values
    }

    #[test]
    fn test_disjoint() {
        let mut xs = HashSet::new();
        let mut ys = HashSet::new();
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));
        assert!(xs.insert(5));
        assert!(ys.insert(11));
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));
        assert!(xs.insert(7));
        assert!(xs.insert(19));
        assert!(xs.insert(4));
        assert!(ys.insert(2));
        assert!(ys.insert(-11));
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));
        assert!(ys.insert(7));
        assert!(!xs.par_is_disjoint(&ys));
        assert!(!ys.par_is_disjoint(&xs));
    }

    #[test]
    fn test_subset_and_superset() {
        let mut a = HashSet::new();
        assert!(a.insert(0));
        assert!(a.insert(5));
        assert!(a.insert(11));
        assert!(a.insert(7));

        let mut b = HashSet::new();
        assert!(b.insert(0));
        assert!(b.insert(7));
        assert!(b.insert(19));
        assert!(b.insert(250));
        assert!(b.insert(11));
        assert!(b.insert(200));

        assert!(!a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(!b.par_is_superset(&a));

        assert!(b.insert(5));

        assert!(a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(b.par_is_superset(&a));
    }

    #[test]
    fn test_iterate() {
        let mut a = HashSet::new();
        for i in 0..32 {
            assert!(a.insert(i));
        }
        let observed = AtomicUsize::new(0);
        a.par_iter().for_each(|k| {
            observed.fetch_or(1 << *k, Ordering::Relaxed);
        });
        assert_eq!(observed.into_inner(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_intersection() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(11));
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(77));
        assert!(a.insert(103));
        assert!(a.insert(5));
        assert!(a.insert(-5));

        assert!(b.insert(2));
        assert!(b.insert(11));
        assert!(b.insert(77));
        assert!(b.insert(-9));
        assert!(b.insert(-42));
        assert!(b.insert(5));
        assert!(b.insert(3));

        let expected = [3, 5, 11, 77];
        let i = a
            .par_intersection(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));

        assert!(b.insert(3));
        assert!(b.insert(9));

        let expected = [1, 5, 11];
        let i = a
            .par_difference(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_symmetric_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));

        assert!(b.insert(-2));
        assert!(b.insert(3));
        assert!(b.insert(9));
        assert!(b.insert(14));
        assert!(b.insert(22));

        let expected = [-2, 1, 5, 11, 14, 22];
        let i = a
            .par_symmetric_difference(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_union() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));
        assert!(a.insert(16));
        assert!(a.insert(19));
        assert!(a.insert(24));

        assert!(b.insert(-2));
        assert!(b.insert(1));
        assert!(b.insert(5));
        assert!(b.insert(9));
        assert!(b.insert(13));
        assert!(b.insert(19));

        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        let i = a
            .par_union(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    /// Exercises both argument orders so that every "larger set on the
    /// left" / "larger set on the right" code path is covered.
    #[test]
    fn test_set_ops_with_either_operand_larger() {
        let small = set_of(&[1, 2, 3]);
        let large = set_of(&[2, 3, 4, 5, 6]);
        let apart = set_of(&[7, 8]);

        for &(x, y) in &[(&small, &large), (&large, &small)] {
            assert_eq!(sorted(x.par_union(y)), [1, 2, 3, 4, 5, 6]);
            assert_eq!(sorted(x.par_intersection(y)), [2, 3]);
            assert_eq!(sorted(x.par_symmetric_difference(y)), [1, 4, 5, 6]);
            assert!(!x.par_is_disjoint(y));
            assert!(x.par_is_disjoint(&apart));
            assert!(apart.par_is_disjoint(x));
        }
        assert_eq!(sorted(small.par_difference(&large)), [1]);
        assert_eq!(sorted(large.par_difference(&small)), [4, 5, 6]);
    }

    #[test]
    fn test_set_ops_with_empty_set() {
        let empty = set_of(&[]);
        let full = set_of(&[1, 2, 3]);

        assert_eq!(sorted(empty.par_union(&full)), [1, 2, 3]);
        assert_eq!(sorted(full.par_union(&empty)), [1, 2, 3]);
        assert!(sorted(empty.par_intersection(&full)).is_empty());
        assert!(sorted(full.par_intersection(&empty)).is_empty());
        assert_eq!(sorted(full.par_difference(&empty)), [1, 2, 3]);
        assert!(sorted(empty.par_difference(&full)).is_empty());
        assert_eq!(sorted(empty.par_symmetric_difference(&full)), [1, 2, 3]);
        assert!(empty.par_is_disjoint(&full));
        assert!(full.par_is_disjoint(&empty));
        assert!(empty.par_is_subset(&full));
        assert!(full.par_is_superset(&empty));
        assert!(!full.par_eq(&empty));
    }

    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 3, 4, 5, 6, 7, 8, 9];

        let set: HashSet<_> = xs.par_iter().cloned().collect();

        for x in &xs {
            assert!(set.contains(x));
        }
    }

    #[test]
    fn test_move_iter() {
        let hs = {
            let mut hs = HashSet::new();

            hs.insert('a');
            hs.insert('b');

            hs
        };

        let v = (&hs).into_par_iter().copied().collect::<Vec<char>>();
        assert!(v == ['a', 'b'] || v == ['b', 'a']);
    }

    #[test]
    fn test_eq() {
        let mut s1 = HashSet::new();

        s1.insert(1);
        s1.insert(2);
        s1.insert(3);

        let mut s2 = HashSet::new();

        s2.insert(1);
        s2.insert(2);

        assert!(!s1.par_eq(&s2));

        s2.insert(3);

        assert!(s1.par_eq(&s2));
    }

    #[test]
    fn test_extend_ref() {
        let mut a = HashSet::new();
        a.insert(1);

        a.par_extend(&[2, 3, 4][..]);

        assert_eq!(a.len(), 4);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));

        let mut b = HashSet::new();
        b.insert(5);
        b.insert(6);

        a.par_extend(&b);

        assert_eq!(a.len(), 6);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));
        assert!(a.contains(&5));
        assert!(a.contains(&6));
    }
}