1use super::collect;
9use rayon::iter::plumbing::{Consumer, ProducerCallback, UnindexedConsumer};
10use rayon::prelude::*;
11
12use crate::vec::Vec;
13use core::cmp::Ordering;
14use core::fmt;
15use core::hash::{BuildHasher, Hash};
16
17use crate::Entries;
18use crate::IndexSet;
19
// Set entries reuse the crate's generic `Bucket` with `()` as the value type,
// so only the key (the set's value) carries data.
type Bucket<T> = crate::Bucket<T, ()>;
21
22impl<T, S> IntoParallelIterator for IndexSet<T, S>
24where
25 T: Send,
26{
27 type Item = T;
28 type Iter = IntoParIter<T>;
29
30 fn into_par_iter(self) -> Self::Iter {
31 IntoParIter {
32 entries: self.into_entries(),
33 }
34 }
35}
36
/// A parallel owning iterator over the values of an `IndexSet`.
///
/// This `struct` is created by the `into_par_iter` method on `IndexSet`
/// (provided by rayon's `IntoParallelIterator` trait) and is also returned by
/// `IndexSet::par_sorted_by`.
pub struct IntoParIter<T> {
    // The set's entries, taken by value; each bucket's key is the yielded item.
    entries: Vec<Bucket<T>>,
}
47
48impl<T: fmt::Debug> fmt::Debug for IntoParIter<T> {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 let iter = self.entries.iter().map(Bucket::key_ref);
51 f.debug_list().entries(iter).finish()
52 }
53}
54
// The required iterator methods are generated by `parallel_iterator_methods!`
// (defined outside this file). `Bucket::key` is the projection applied to each
// bucket, which here yields the owned value (`Item = T`).
impl<T: Send> ParallelIterator for IntoParIter<T> {
    type Item = T;

    parallel_iterator_methods!(Bucket::key);
}
60
// Indexed operations (e.g. `zip`, `len`) come from
// `indexed_parallel_iterator_methods!` (defined outside this file), which
// projects each bucket to its owned value with `Bucket::key`.
impl<T: Send> IndexedParallelIterator for IntoParIter<T> {
    indexed_parallel_iterator_methods!(Bucket::key);
}
64
65impl<'a, T, S> IntoParallelIterator for &'a IndexSet<T, S>
67where
68 T: Sync,
69{
70 type Item = &'a T;
71 type Iter = ParIter<'a, T>;
72
73 fn into_par_iter(self) -> Self::Iter {
74 ParIter {
75 entries: self.as_entries(),
76 }
77 }
78}
79
/// A parallel iterator over shared references to the values of an `IndexSet`.
///
/// This `struct` is created by `into_par_iter` on `&IndexSet` (rayon's
/// `IntoParallelIterator` trait), which is what `par_iter()` goes through.
pub struct ParIter<'a, T> {
    // Borrowed view of the set's entries; each bucket's key is yielded by reference.
    entries: &'a [Bucket<T>],
}
90
91impl<T> Clone for ParIter<'_, T> {
92 fn clone(&self) -> Self {
93 ParIter { ..*self }
94 }
95}
96
97impl<T: fmt::Debug> fmt::Debug for ParIter<'_, T> {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 let iter = self.entries.iter().map(Bucket::key_ref);
100 f.debug_list().entries(iter).finish()
101 }
102}
103
// The required iterator methods are generated by `parallel_iterator_methods!`
// (defined outside this file). `Bucket::key_ref` projects each bucket to a
// shared reference to its value (`Item = &'a T`).
impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
    type Item = &'a T;

    parallel_iterator_methods!(Bucket::key_ref);
}
109
// Indexed operations (e.g. `zip`, `len`) come from
// `indexed_parallel_iterator_methods!` (defined outside this file), projecting
// each bucket with `Bucket::key_ref`.
impl<T: Sync> IndexedParallelIterator for ParIter<'_, T> {
    indexed_parallel_iterator_methods!(Bucket::key_ref);
}
113
114impl<T, S> IndexSet<T, S>
120where
121 T: Hash + Eq + Sync,
122 S: BuildHasher + Sync,
123{
124 pub fn par_difference<'a, S2>(
129 &'a self,
130 other: &'a IndexSet<T, S2>,
131 ) -> ParDifference<'a, T, S, S2>
132 where
133 S2: BuildHasher + Sync,
134 {
135 ParDifference {
136 set1: self,
137 set2: other,
138 }
139 }
140
141 pub fn par_symmetric_difference<'a, S2>(
149 &'a self,
150 other: &'a IndexSet<T, S2>,
151 ) -> ParSymmetricDifference<'a, T, S, S2>
152 where
153 S2: BuildHasher + Sync,
154 {
155 ParSymmetricDifference {
156 set1: self,
157 set2: other,
158 }
159 }
160
161 pub fn par_intersection<'a, S2>(
166 &'a self,
167 other: &'a IndexSet<T, S2>,
168 ) -> ParIntersection<'a, T, S, S2>
169 where
170 S2: BuildHasher + Sync,
171 {
172 ParIntersection {
173 set1: self,
174 set2: other,
175 }
176 }
177
178 pub fn par_union<'a, S2>(&'a self, other: &'a IndexSet<T, S2>) -> ParUnion<'a, T, S, S2>
185 where
186 S2: BuildHasher + Sync,
187 {
188 ParUnion {
189 set1: self,
190 set2: other,
191 }
192 }
193
194 pub fn par_eq<S2>(&self, other: &IndexSet<T, S2>) -> bool
197 where
198 S2: BuildHasher + Sync,
199 {
200 self.len() == other.len() && self.par_is_subset(other)
201 }
202
203 pub fn par_is_disjoint<S2>(&self, other: &IndexSet<T, S2>) -> bool
206 where
207 S2: BuildHasher + Sync,
208 {
209 if self.len() <= other.len() {
210 self.par_iter().all(move |value| !other.contains(value))
211 } else {
212 other.par_iter().all(move |value| !self.contains(value))
213 }
214 }
215
216 pub fn par_is_superset<S2>(&self, other: &IndexSet<T, S2>) -> bool
219 where
220 S2: BuildHasher + Sync,
221 {
222 other.par_is_subset(self)
223 }
224
225 pub fn par_is_subset<S2>(&self, other: &IndexSet<T, S2>) -> bool
228 where
229 S2: BuildHasher + Sync,
230 {
231 self.len() <= other.len() && self.par_iter().all(move |value| other.contains(value))
232 }
233}
234
/// A parallel iterator producing the values of one `IndexSet` that are absent
/// from another.
///
/// This `struct` is created by `IndexSet::par_difference`; no work happens
/// until it is driven.
pub struct ParDifference<'a, T, S1, S2> {
    // Values are drawn from `set1` ...
    set1: &'a IndexSet<T, S1>,
    // ... and kept only when `set2` does not contain them.
    set2: &'a IndexSet<T, S2>,
}
246
247impl<T, S1, S2> Clone for ParDifference<'_, T, S1, S2> {
248 fn clone(&self) -> Self {
249 ParDifference { ..*self }
250 }
251}
252
253impl<T, S1, S2> fmt::Debug for ParDifference<'_, T, S1, S2>
254where
255 T: fmt::Debug + Eq + Hash,
256 S1: BuildHasher,
257 S2: BuildHasher,
258{
259 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260 f.debug_list()
261 .entries(self.set1.difference(&self.set2))
262 .finish()
263 }
264}
265
266impl<'a, T, S1, S2> ParallelIterator for ParDifference<'a, T, S1, S2>
267where
268 T: Hash + Eq + Sync,
269 S1: BuildHasher + Sync,
270 S2: BuildHasher + Sync,
271{
272 type Item = &'a T;
273
274 fn drive_unindexed<C>(self, consumer: C) -> C::Result
275 where
276 C: UnindexedConsumer<Self::Item>,
277 {
278 let Self { set1, set2 } = self;
279
280 set1.par_iter()
281 .filter(move |&item| !set2.contains(item))
282 .drive_unindexed(consumer)
283 }
284}
285
/// A parallel iterator producing the values shared by two `IndexSet`s.
///
/// This `struct` is created by `IndexSet::par_intersection`; no work happens
/// until it is driven.
pub struct ParIntersection<'a, T, S1, S2> {
    // Values are drawn from `set1` (and so follow its order) ...
    set1: &'a IndexSet<T, S1>,
    // ... and kept only when `set2` also contains them.
    set2: &'a IndexSet<T, S2>,
}
297
298impl<T, S1, S2> Clone for ParIntersection<'_, T, S1, S2> {
299 fn clone(&self) -> Self {
300 ParIntersection { ..*self }
301 }
302}
303
304impl<T, S1, S2> fmt::Debug for ParIntersection<'_, T, S1, S2>
305where
306 T: fmt::Debug + Eq + Hash,
307 S1: BuildHasher,
308 S2: BuildHasher,
309{
310 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311 f.debug_list()
312 .entries(self.set1.intersection(&self.set2))
313 .finish()
314 }
315}
316
317impl<'a, T, S1, S2> ParallelIterator for ParIntersection<'a, T, S1, S2>
318where
319 T: Hash + Eq + Sync,
320 S1: BuildHasher + Sync,
321 S2: BuildHasher + Sync,
322{
323 type Item = &'a T;
324
325 fn drive_unindexed<C>(self, consumer: C) -> C::Result
326 where
327 C: UnindexedConsumer<Self::Item>,
328 {
329 let Self { set1, set2 } = self;
330
331 set1.par_iter()
332 .filter(move |&item| set2.contains(item))
333 .drive_unindexed(consumer)
334 }
335}
336
/// A parallel iterator producing the values that belong to exactly one of two
/// `IndexSet`s.
///
/// This `struct` is created by `IndexSet::par_symmetric_difference`; no work
/// happens until it is driven.
pub struct ParSymmetricDifference<'a, T, S1, S2> {
    // Values unique to `set1` are produced first ...
    set1: &'a IndexSet<T, S1>,
    // ... followed by values unique to `set2`.
    set2: &'a IndexSet<T, S2>,
}
348
349impl<T, S1, S2> Clone for ParSymmetricDifference<'_, T, S1, S2> {
350 fn clone(&self) -> Self {
351 ParSymmetricDifference { ..*self }
352 }
353}
354
355impl<T, S1, S2> fmt::Debug for ParSymmetricDifference<'_, T, S1, S2>
356where
357 T: fmt::Debug + Eq + Hash,
358 S1: BuildHasher,
359 S2: BuildHasher,
360{
361 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362 f.debug_list()
363 .entries(self.set1.symmetric_difference(&self.set2))
364 .finish()
365 }
366}
367
368impl<'a, T, S1, S2> ParallelIterator for ParSymmetricDifference<'a, T, S1, S2>
369where
370 T: Hash + Eq + Sync,
371 S1: BuildHasher + Sync,
372 S2: BuildHasher + Sync,
373{
374 type Item = &'a T;
375
376 fn drive_unindexed<C>(self, consumer: C) -> C::Result
377 where
378 C: UnindexedConsumer<Self::Item>,
379 {
380 let Self { set1, set2 } = self;
381
382 set1.par_difference(set2)
383 .chain(set2.par_difference(set1))
384 .drive_unindexed(consumer)
385 }
386}
387
/// A parallel iterator producing every value contained in either of two
/// `IndexSet`s, each exactly once.
///
/// This `struct` is created by `IndexSet::par_union`; no work happens until it
/// is driven.
pub struct ParUnion<'a, T, S1, S2> {
    // All values of `set1` are produced first ...
    set1: &'a IndexSet<T, S1>,
    // ... followed by the values of `set2` that `set1` lacks.
    set2: &'a IndexSet<T, S2>,
}
399
400impl<T, S1, S2> Clone for ParUnion<'_, T, S1, S2> {
401 fn clone(&self) -> Self {
402 ParUnion { ..*self }
403 }
404}
405
406impl<T, S1, S2> fmt::Debug for ParUnion<'_, T, S1, S2>
407where
408 T: fmt::Debug + Eq + Hash,
409 S1: BuildHasher,
410 S2: BuildHasher,
411{
412 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413 f.debug_list().entries(self.set1.union(&self.set2)).finish()
414 }
415}
416
417impl<'a, T, S1, S2> ParallelIterator for ParUnion<'a, T, S1, S2>
418where
419 T: Hash + Eq + Sync,
420 S1: BuildHasher + Sync,
421 S2: BuildHasher + Sync,
422{
423 type Item = &'a T;
424
425 fn drive_unindexed<C>(self, consumer: C) -> C::Result
426 where
427 C: UnindexedConsumer<Self::Item>,
428 {
429 let Self { set1, set2 } = self;
430
431 set1.par_iter()
432 .chain(set2.par_difference(set1))
433 .drive_unindexed(consumer)
434 }
435}
436
437impl<T, S> IndexSet<T, S>
441where
442 T: Hash + Eq + Send,
443 S: BuildHasher + Send,
444{
445 pub fn par_sort(&mut self)
447 where
448 T: Ord,
449 {
450 self.with_entries(|entries| {
451 entries.par_sort_by(|a, b| T::cmp(&a.key, &b.key));
452 });
453 }
454
455 pub fn par_sort_by<F>(&mut self, cmp: F)
457 where
458 F: Fn(&T, &T) -> Ordering + Sync,
459 {
460 self.with_entries(|entries| {
461 entries.par_sort_by(move |a, b| cmp(&a.key, &b.key));
462 });
463 }
464
465 pub fn par_sorted_by<F>(self, cmp: F) -> IntoParIter<T>
468 where
469 F: Fn(&T, &T) -> Ordering + Sync,
470 {
471 let mut entries = self.into_entries();
472 entries.par_sort_by(move |a, b| cmp(&a.key, &b.key));
473 IntoParIter { entries }
474 }
475}
476
477impl<T, S> FromParallelIterator<T> for IndexSet<T, S>
479where
480 T: Eq + Hash + Send,
481 S: BuildHasher + Default + Send,
482{
483 fn from_par_iter<I>(iter: I) -> Self
484 where
485 I: IntoParallelIterator<Item = T>,
486 {
487 let list = collect(iter);
488 let len = list.iter().map(Vec::len).sum();
489 let mut set = Self::with_capacity_and_hasher(len, S::default());
490 for vec in list {
491 set.extend(vec);
492 }
493 set
494 }
495}
496
497impl<T, S> ParallelExtend<T> for IndexSet<T, S>
499where
500 T: Eq + Hash + Send,
501 S: BuildHasher + Send,
502{
503 fn par_extend<I>(&mut self, iter: I)
504 where
505 I: IntoParallelIterator<Item = T>,
506 {
507 for vec in collect(iter) {
508 self.extend(vec);
509 }
510 }
511}
512
513impl<'a, T: 'a, S> ParallelExtend<&'a T> for IndexSet<T, S>
515where
516 T: Copy + Eq + Hash + Send + Sync,
517 S: BuildHasher + Send,
518{
519 fn par_extend<I>(&mut self, iter: I)
520 where
521 I: IntoParallelIterator<Item = &'a T>,
522 {
523 for vec in collect(iter) {
524 self.extend(vec);
525 }
526 }
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    // Parallel iteration must visit values in insertion order: checked by
    // counting, by zipping with the source array, and against `get_index`.
    #[test]
    fn insert_order() {
        let insert = [0, 4, 2, 12, 8, 7, 11, 5, 3, 17, 19, 22, 23];
        let mut set = IndexSet::new();

        for &elt in &insert {
            set.insert(elt);
        }

        assert_eq!(set.par_iter().count(), set.len());
        assert_eq!(set.par_iter().count(), insert.len());
        insert.par_iter().zip(&set).for_each(|(a, b)| {
            assert_eq!(a, b);
        });
        (0..insert.len())
            .into_par_iter()
            .zip(&set)
            .for_each(|(i, v)| {
                assert_eq!(set.get_index(i).unwrap(), v);
            });
    }

    // `par_eq` must reject both a size mismatch ({1, 2} vs {2}) and an
    // equal-size set with different contents ({1, 2} vs {2, 3}), including
    // after a round trip through `into_par_iter().collect()`.
    #[test]
    fn partial_eq_and_eq() {
        let mut set_a = IndexSet::new();
        set_a.insert(1);
        set_a.insert(2);
        let mut set_b = set_a.clone();
        assert!(set_a.par_eq(&set_b));
        set_b.swap_remove(&1);
        assert!(!set_a.par_eq(&set_b));
        set_b.insert(3);
        assert!(!set_a.par_eq(&set_b));

        let set_c: IndexSet<_> = set_b.into_par_iter().collect();
        assert!(!set_a.par_eq(&set_c));
        assert!(!set_c.par_eq(&set_a));
    }

    // Exercises both `ParallelExtend` impls: by reference (`&T`, copied) and
    // by value (`T`); the appended values keep their order.
    #[test]
    fn extend() {
        let mut set = IndexSet::new();
        set.par_extend(vec![&1, &2, &3, &4]);
        set.par_extend(vec![5, 6]);
        assert_eq!(
            set.into_par_iter().collect::<Vec<_>>(),
            vec![1, 2, 3, 4, 5, 6]
        );
    }

    // Fixtures: a = {0..3}, b = {3..6} (disjoint from a), c = {0..6}
    // (superset of a), d = {3..9} (overlaps c without containing it).
    #[test]
    fn comparisons() {
        let set_a: IndexSet<_> = (0..3).collect();
        let set_b: IndexSet<_> = (3..6).collect();
        let set_c: IndexSet<_> = (0..6).collect();
        let set_d: IndexSet<_> = (3..9).collect();

        assert!(!set_a.par_is_disjoint(&set_a));
        assert!(set_a.par_is_subset(&set_a));
        assert!(set_a.par_is_superset(&set_a));

        assert!(set_a.par_is_disjoint(&set_b));
        assert!(set_b.par_is_disjoint(&set_a));
        assert!(!set_a.par_is_subset(&set_b));
        assert!(!set_b.par_is_subset(&set_a));
        assert!(!set_a.par_is_superset(&set_b));
        assert!(!set_b.par_is_superset(&set_a));

        assert!(!set_a.par_is_disjoint(&set_c));
        assert!(!set_c.par_is_disjoint(&set_a));
        assert!(set_a.par_is_subset(&set_c));
        assert!(!set_c.par_is_subset(&set_a));
        assert!(!set_a.par_is_superset(&set_c));
        assert!(set_c.par_is_superset(&set_a));

        assert!(!set_c.par_is_disjoint(&set_d));
        assert!(!set_d.par_is_disjoint(&set_c));
        assert!(!set_c.par_is_subset(&set_d));
        assert!(!set_d.par_is_subset(&set_c));
        assert!(!set_c.par_is_superset(&set_d));
        assert!(!set_d.par_is_superset(&set_c));
    }

    // Checks both the contents and the exact output order of the parallel
    // set-operation iterators. `set_d` is built in reverse, so its results
    // show that each operation follows the order of the set it draws from.
    #[test]
    fn iter_comparisons() {
        use std::iter::empty;

        // Collects the parallel iterator (order-preserving) and compares it
        // element-wise with the expected sequential sequence.
        fn check<'a, I1, I2>(iter1: I1, iter2: I2)
        where
            I1: ParallelIterator<Item = &'a i32>,
            I2: Iterator<Item = i32>,
        {
            let v1: Vec<_> = iter1.cloned().collect();
            let v2: Vec<_> = iter2.collect();
            assert_eq!(v1, v2);
        }

        let set_a: IndexSet<_> = (0..3).collect();
        let set_b: IndexSet<_> = (3..6).collect();
        let set_c: IndexSet<_> = (0..6).collect();
        let set_d: IndexSet<_> = (3..9).rev().collect();

        check(set_a.par_difference(&set_a), empty());
        check(set_a.par_symmetric_difference(&set_a), empty());
        check(set_a.par_intersection(&set_a), 0..3);
        check(set_a.par_union(&set_a), 0..3);

        check(set_a.par_difference(&set_b), 0..3);
        check(set_b.par_difference(&set_a), 3..6);
        check(set_a.par_symmetric_difference(&set_b), 0..6);
        check(set_b.par_symmetric_difference(&set_a), (3..6).chain(0..3));
        check(set_a.par_intersection(&set_b), empty());
        check(set_b.par_intersection(&set_a), empty());
        check(set_a.par_union(&set_b), 0..6);
        check(set_b.par_union(&set_a), (3..6).chain(0..3));

        check(set_a.par_difference(&set_c), empty());
        check(set_c.par_difference(&set_a), 3..6);
        check(set_a.par_symmetric_difference(&set_c), 3..6);
        check(set_c.par_symmetric_difference(&set_a), 3..6);
        check(set_a.par_intersection(&set_c), 0..3);
        check(set_c.par_intersection(&set_a), 0..3);
        check(set_a.par_union(&set_c), 0..6);
        check(set_c.par_union(&set_a), 0..6);

        check(set_c.par_difference(&set_d), 0..3);
        check(set_d.par_difference(&set_c), (6..9).rev());
        check(
            set_c.par_symmetric_difference(&set_d),
            (0..3).chain((6..9).rev()),
        );
        check(
            set_d.par_symmetric_difference(&set_c),
            (6..9).rev().chain(0..3),
        );
        check(set_c.par_intersection(&set_d), 3..6);
        check(set_d.par_intersection(&set_c), (3..6).rev());
        check(set_c.par_union(&set_d), (0..6).chain((6..9).rev()));
        check(set_d.par_union(&set_c), (3..9).rev().chain(0..3));
    }
}