1use rayon::iter::plumbing::UnindexedConsumer;
2use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
4use std::hash::{BuildHasher, Hash};
5
6use super::map;
7use crate::HashSet;
8
9pub struct ParIntoIter<T: Send> {
10 inner: map::ParIntoIter<T, ()>,
11}
12
13pub struct ParIter<'a, T: Sync + 'a> {
14 inner: map::ParKeys<'a, T, ()>,
15}
16
17pub struct ParDifference<'a, T: Sync + 'a, S: Sync + 'a> {
18 a: &'a HashSet<T, S>,
19 b: &'a HashSet<T, S>,
20}
21
22pub struct ParSymmetricDifference<'a, T: Sync + 'a, S: Sync + 'a> {
23 a: &'a HashSet<T, S>,
24 b: &'a HashSet<T, S>,
25}
26
27pub struct ParIntersection<'a, T: Sync + 'a, S: Sync + 'a> {
28 a: &'a HashSet<T, S>,
29 b: &'a HashSet<T, S>,
30}
31
32pub struct ParUnion<'a, T: Sync + 'a, S: Sync + 'a> {
33 a: &'a HashSet<T, S>,
34 b: &'a HashSet<T, S>,
35}
36
37impl<T, S> HashSet<T, S>
38where
39 T: Eq + Hash + Sync,
40 S: BuildHasher + Sync,
41{
42 pub fn par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S> {
43 ParDifference { a: self, b: other }
44 }
45
46 pub fn par_symmetric_difference<'a>(
47 &'a self,
48 other: &'a Self,
49 ) -> ParSymmetricDifference<'a, T, S> {
50 ParSymmetricDifference { a: self, b: other }
51 }
52
53 pub fn par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S> {
54 ParIntersection { a: self, b: other }
55 }
56
57 pub fn par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S> {
58 ParUnion { a: self, b: other }
59 }
60
61 pub fn par_is_disjoint(&self, other: &Self) -> bool {
62 self.into_par_iter().all(|x| !other.contains(x))
63 }
64
65 pub fn par_is_subset(&self, other: &Self) -> bool {
66 self.into_par_iter().all(|x| other.contains(x))
67 }
68
69 pub fn par_is_superset(&self, other: &Self) -> bool {
70 other.is_subset(self)
71 }
72
73 pub fn par_eq(&self, other: &Self) -> bool {
74 self.len() == other.len() && self.par_is_subset(other)
75 }
76}
77
78impl<T: Send, S> IntoParallelIterator for HashSet<T, S> {
79 type Item = T;
80 type Iter = ParIntoIter<T>;
81
82 fn into_par_iter(self) -> Self::Iter {
83 ParIntoIter {
84 inner: self.map.into_par_iter(),
85 }
86 }
87}
88
89impl<'a, T: Sync, S> IntoParallelIterator for &'a HashSet<T, S> {
90 type Item = &'a T;
91 type Iter = ParIter<'a, T>;
92
93 fn into_par_iter(self) -> Self::Iter {
94 ParIter {
95 inner: self.map.par_keys(),
96 }
97 }
98}
99
100impl<T, S> FromParallelIterator<T> for HashSet<T, S>
102where
103 T: Eq + Hash + Send,
104 S: BuildHasher + Default + Send,
105{
106 fn from_par_iter<P>(par_iter: P) -> Self
107 where
108 P: IntoParallelIterator<Item = T>,
109 {
110 let mut set = HashSet::default();
111 set.par_extend(par_iter);
112 set
113 }
114}
115
116impl<T, S> ParallelExtend<T> for HashSet<T, S>
118where
119 T: Eq + Hash + Send,
120 S: BuildHasher + Send,
121{
122 fn par_extend<I>(&mut self, par_iter: I)
123 where
124 I: IntoParallelIterator<Item = T>,
125 {
126 extend(self, par_iter);
127 }
128}
129
130impl<'a, T, S> ParallelExtend<&'a T> for HashSet<T, S>
132where
133 T: 'a + Copy + Eq + Hash + Send + Sync,
134 S: BuildHasher + Send,
135{
136 fn par_extend<I>(&mut self, par_iter: I)
137 where
138 I: IntoParallelIterator<Item = &'a T>,
139 {
140 extend(self, par_iter);
141 }
142}
143
144fn extend<T, S, I>(set: &mut HashSet<T, S>, par_iter: I)
146where
147 T: Eq + Hash,
148 S: BuildHasher,
149 I: IntoParallelIterator,
150 HashSet<T, S>: Extend<I::Item>,
151{
152 let (list, len) = super::collect(par_iter);
153
154 let reserve = if set.is_empty() { len } else { (len + 1) / 2 };
159 set.reserve(reserve);
160 for vec in list {
161 set.extend(vec);
162 }
163}
164
165impl<T: Send> ParallelIterator for ParIntoIter<T> {
166 type Item = T;
167
168 fn drive_unindexed<C>(self, consumer: C) -> C::Result
169 where
170 C: UnindexedConsumer<Self::Item>,
171 {
172 self.inner.map(|(k, _)| k).drive_unindexed(consumer)
173 }
174}
175
176impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
177 type Item = &'a T;
178
179 fn drive_unindexed<C>(self, consumer: C) -> C::Result
180 where
181 C: UnindexedConsumer<Self::Item>,
182 {
183 self.inner.drive_unindexed(consumer)
184 }
185}
186
187impl<'a, T, S> ParallelIterator for ParDifference<'a, T, S>
188where
189 T: Eq + Hash + Sync,
190 S: BuildHasher + Sync,
191{
192 type Item = &'a T;
193
194 fn drive_unindexed<C>(self, consumer: C) -> C::Result
195 where
196 C: UnindexedConsumer<Self::Item>,
197 {
198 self.a
199 .into_par_iter()
200 .filter(|&x| !self.b.contains(x))
201 .drive_unindexed(consumer)
202 }
203}
204
205impl<'a, T, S> ParallelIterator for ParSymmetricDifference<'a, T, S>
206where
207 T: Eq + Hash + Sync,
208 S: BuildHasher + Sync,
209{
210 type Item = &'a T;
211
212 fn drive_unindexed<C>(self, consumer: C) -> C::Result
213 where
214 C: UnindexedConsumer<Self::Item>,
215 {
216 self.a
217 .par_difference(self.b)
218 .chain(self.b.par_difference(self.a))
219 .drive_unindexed(consumer)
220 }
221}
222
223impl<'a, T, S> ParallelIterator for ParIntersection<'a, T, S>
224where
225 T: Eq + Hash + Sync,
226 S: BuildHasher + Sync,
227{
228 type Item = &'a T;
229
230 fn drive_unindexed<C>(self, consumer: C) -> C::Result
231 where
232 C: UnindexedConsumer<Self::Item>,
233 {
234 self.a
235 .into_par_iter()
236 .filter(|&x| self.b.contains(x))
237 .drive_unindexed(consumer)
238 }
239}
240
241impl<'a, T, S> ParallelIterator for ParUnion<'a, T, S>
242where
243 T: Eq + Hash + Sync,
244 S: BuildHasher + Sync,
245{
246 type Item = &'a T;
247
248 fn drive_unindexed<C>(self, consumer: C) -> C::Result
249 where
250 C: UnindexedConsumer<Self::Item>,
251 {
252 self.a
253 .into_par_iter()
254 .chain(self.b.par_difference(self.a))
255 .drive_unindexed(consumer)
256 }
257}
258
#[cfg(test)]
mod test_par_set {
    use super::HashSet;
    use rayon::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_disjoint() {
        let mut xs = HashSet::new();
        let mut ys = HashSet::new();
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));
        assert!(xs.insert(5));
        assert!(ys.insert(11));
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));
        assert!(xs.insert(7));
        assert!(xs.insert(19));
        assert!(xs.insert(4));
        assert!(ys.insert(2));
        assert!(ys.insert(-11));
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));
        assert!(ys.insert(7));
        assert!(!xs.par_is_disjoint(&ys));
        assert!(!ys.par_is_disjoint(&xs));
    }

    /// Disjointness must give the same answer regardless of which operand is
    /// the larger set.
    #[test]
    fn test_disjoint_uneven_sizes() {
        let mut small = HashSet::new();
        let mut large = HashSet::new();
        for i in 0..100 {
            assert!(large.insert(i));
        }
        assert!(small.insert(-1));
        assert!(small.par_is_disjoint(&large));
        assert!(large.par_is_disjoint(&small));
        assert!(small.insert(50));
        assert!(!small.par_is_disjoint(&large));
        assert!(!large.par_is_disjoint(&small));
    }

    #[test]
    fn test_subset_and_superset() {
        let mut a = HashSet::new();
        assert!(a.insert(0));
        assert!(a.insert(5));
        assert!(a.insert(11));
        assert!(a.insert(7));

        let mut b = HashSet::new();
        assert!(b.insert(0));
        assert!(b.insert(7));
        assert!(b.insert(19));
        assert!(b.insert(250));
        assert!(b.insert(11));
        assert!(b.insert(200));

        assert!(!a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(!b.par_is_superset(&a));

        assert!(b.insert(5));

        assert!(a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(b.par_is_superset(&a));
    }

    /// Subset/superset checks where the operands differ in size, plus the
    /// reflexive and empty-set cases.
    #[test]
    fn test_subset_uneven_sizes() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();
        for i in 0..4 {
            assert!(a.insert(i));
        }
        for i in 0..3 {
            assert!(b.insert(i));
        }

        assert!(!a.par_is_subset(&b));
        assert!(b.par_is_subset(&a));
        assert!(a.par_is_superset(&b));
        assert!(!b.par_is_superset(&a));

        // Every set is both a subset and a superset of itself.
        assert!(a.par_is_subset(&a));
        assert!(a.par_is_superset(&a));

        let empty: HashSet<i32> = HashSet::new();
        assert!(empty.par_is_subset(&a));
        assert!(a.par_is_superset(&empty));
        assert!(!a.par_is_subset(&empty));
        assert!(!empty.par_is_superset(&a));
    }

    #[test]
    fn test_iterate() {
        let mut a = HashSet::new();
        for i in 0..32 {
            assert!(a.insert(i));
        }
        let observed = AtomicUsize::new(0);
        a.par_iter().for_each(|k| {
            observed.fetch_or(1 << *k, Ordering::Relaxed);
        });
        assert_eq!(observed.into_inner(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_intersection() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(11));
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(77));
        assert!(a.insert(103));
        assert!(a.insert(5));
        assert!(a.insert(-5));

        assert!(b.insert(2));
        assert!(b.insert(11));
        assert!(b.insert(77));
        assert!(b.insert(-9));
        assert!(b.insert(-42));
        assert!(b.insert(5));
        assert!(b.insert(3));

        let expected = [3, 5, 11, 77];
        let i = a
            .par_intersection(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));

        assert!(b.insert(3));
        assert!(b.insert(9));

        let expected = [1, 5, 11];
        let i = a
            .par_difference(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_symmetric_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));

        assert!(b.insert(-2));
        assert!(b.insert(3));
        assert!(b.insert(9));
        assert!(b.insert(14));
        assert!(b.insert(22));

        let expected = [-2, 1, 5, 11, 14, 22];
        let i = a
            .par_symmetric_difference(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_union() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));
        assert!(a.insert(16));
        assert!(a.insert(19));
        assert!(a.insert(24));

        assert!(b.insert(-2));
        assert!(b.insert(1));
        assert!(b.insert(5));
        assert!(b.insert(9));
        assert!(b.insert(13));
        assert!(b.insert(19));

        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        let i = a
            .par_union(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 3, 4, 5, 6, 7, 8, 9];

        let set: HashSet<_> = xs.par_iter().cloned().collect();

        for x in &xs {
            assert!(set.contains(x));
        }
    }

    #[test]
    fn test_move_iter() {
        let hs = {
            let mut hs = HashSet::new();

            hs.insert('a');
            hs.insert('b');

            hs
        };

        let v = hs.into_par_iter().collect::<Vec<char>>();
        assert!(v == ['a', 'b'] || v == ['b', 'a']);
    }

    #[test]
    fn test_eq() {
        let mut s1 = HashSet::new();

        s1.insert(1);
        s1.insert(2);
        s1.insert(3);

        let mut s2 = HashSet::new();

        s2.insert(1);
        s2.insert(2);

        assert!(!s1.par_eq(&s2));

        s2.insert(3);

        assert!(s1.par_eq(&s2));
    }

    /// Equal lengths alone must not make two sets compare equal; empty sets
    /// are equal to each other.
    #[test]
    fn test_eq_same_len_different_elements() {
        let mut s1 = HashSet::new();
        let mut s2 = HashSet::new();

        s1.insert(1);
        s1.insert(2);
        s2.insert(1);
        s2.insert(3);

        assert!(!s1.par_eq(&s2));
        assert!(!s2.par_eq(&s1));

        let e1: HashSet<i32> = HashSet::new();
        let e2: HashSet<i32> = HashSet::new();
        assert!(e1.par_eq(&e2));
    }

    #[test]
    fn test_extend_ref() {
        let mut a = HashSet::new();
        a.insert(1);

        a.par_extend(&[2, 3, 4][..]);

        assert_eq!(a.len(), 4);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));

        let mut b = HashSet::new();
        b.insert(5);
        b.insert(6);

        a.par_extend(&b);

        assert_eq!(a.len(), 6);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));
        assert!(a.contains(&5));
        assert!(a.contains(&6));
    }
}
517}