1use rayon::iter::plumbing::{Folder, Reducer, UnindexedConsumer};
20use rayon::iter::ParallelIterator;
21use rayon::{current_num_threads, join_context};
22use std::iter::Iterator;
23
/// An [`Iterator`] whose remaining items can be divided between two iterators.
///
/// This is the sequential building block that [`ParallelSplittableIterator`] uses to hand
/// work to other rayon threads.
pub trait SplittableIterator
where
    Self: Iterator + Sized,
{
    /// Attempts to move part of the remaining items into a new iterator.
    ///
    /// Returns `None` when no split is possible; `self` then keeps all of its items.
    /// When results are reduced by [`ParallelSplittableIterator`], items yielded by the
    /// returned iterator are ordered after those still left in `self`.
    fn split(&mut self) -> Option<Self>;
}
38
39pub trait IntoParallelIterator: Sized {
41 fn into_par_iter(self) -> ParallelSplittableIterator<Self>;
46}
47
48impl<T> IntoParallelIterator for T
49where
50 T: SplittableIterator + Send,
51 T::Item: Send,
52{
53 fn into_par_iter(self) -> ParallelSplittableIterator<Self> {
54 ParallelSplittableIterator::new(self)
55 }
56}
57
/// Rayon `ParallelIterator` adapter over a [`SplittableIterator`].
///
/// Work is divided by repeatedly calling [`SplittableIterator::split`] on the wrapped
/// iterator, bounded by a split budget.
pub struct ParallelSplittableIterator<Iter> {
    /// The underlying sequential iterator, drained item by item or split into more jobs.
    iter: Iter,
    /// Remaining split budget. It starts at the rayon thread count, is halved on every
    /// successful split, and no further split is attempted once it reaches zero.
    splits: usize,
}
63
64impl<Iter> ParallelSplittableIterator<Iter>
65where
66 Iter: SplittableIterator,
67{
68 pub fn new(iter: Iter) -> Self {
70 Self {
71 iter,
72 splits: current_num_threads(),
73 }
74 }
75
76 fn split(&mut self) -> Option<Self> {
78 if self.splits == 0 {
79 return None;
80 }
81
82 if let Some(split) = self.iter.split() {
83 self.splits /= 2;
84
85 Some(Self {
86 iter: split,
87 splits: self.splits,
88 })
89 } else {
90 None
91 }
92 }
93
94 fn bridge<C>(&mut self, stolen: bool, consumer: C) -> C::Result
98 where
99 Iter: Send,
100 C: UnindexedConsumer<Iter::Item>,
101 {
102 if stolen {
105 self.splits = current_num_threads();
106 }
107
108 let mut folder = consumer.split_off_left().into_folder();
109
110 if self.splits == 0 {
111 return folder.consume_iter(&mut self.iter).complete();
112 }
113
114 while !folder.full() {
115 if let Some(mut split) = self.split() {
117 let (r1, r2) = (consumer.to_reducer(), consumer.to_reducer());
118 let left_consumer = consumer.split_off_left();
119
120 let (left, right) = join_context(
121 |ctx| self.bridge(ctx.migrated(), left_consumer),
122 |ctx| split.bridge(ctx.migrated(), consumer),
123 );
124 return r1.reduce(folder.complete(), r2.reduce(left, right));
125 }
126
127 if let Some(next) = self.iter.next() {
129 folder = folder.consume(next);
130 } else {
131 break;
132 }
133 }
134
135 folder.complete()
136 }
137}
138
139impl<Iter> ParallelIterator for ParallelSplittableIterator<Iter>
140where
141 Iter: SplittableIterator + Send,
142 Iter::Item: Send,
143{
144 type Item = Iter::Item;
145
146 fn drive_unindexed<C>(mut self, consumer: C) -> C::Result
147 where
148 C: UnindexedConsumer<Self::Item>,
149 {
150 self.bridge(false, consumer)
151 }
152}
153
/// Makes a traversal iterator type `$iter<N>` (with `N: $node`) usable with rayon.
///
/// Generates two impls:
/// - `SplittableIterator`, which splits the iterator's `queue` in half;
/// - `rayon::iter::IntoParallelIterator`, which wraps the iterator in a
///   `ParallelSplittableIterator`.
///
/// The target type must be constructible from exactly the fields `queue` (a type
/// implementing `$crate::sync::Queue`) and `max_depth`.
macro_rules! parallel_iterator {
    ($iter:ident<$node:ident>) => {
        impl<N> $crate::sync::par::SplittableIterator for $iter<N>
        where
            N: $node,
        {
            fn split(&mut self) -> Option<Self> {
                use $crate::sync::Queue;

                let len = self.queue.len();
                // A single pending entry (or none) cannot be shared.
                if len < 2 {
                    return None;
                }

                // Presumably, like `VecDeque::split_off`, this moves the entries from the
                // midpoint onward into the new iterator. TODO confirm against `Queue`.
                let queue = self.queue.split_off(len / 2);
                Some(Self {
                    queue,
                    max_depth: self.max_depth,
                })
            }
        }

        impl<N> rayon::iter::IntoParallelIterator for $iter<N>
        where
            N: $node + Sync + Send,
            N::Error: Send,
        {
            type Iter = $crate::sync::par::ParallelSplittableIterator<Self>;
            type Item = <Self as Iterator>::Item;

            fn into_par_iter(self) -> Self::Iter {
                $crate::sync::par::ParallelSplittableIterator::new(self)
            }
        }
    };
}
191pub(crate) use parallel_iterator;