1use std::iter;
2
3use par_core::join;
4
5use super::{plumbing::*, *};
6
7#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
13#[derive(Debug, Clone)]
14pub struct Chain<A, B>
15where
16 A: ParallelIterator,
17 B: ParallelIterator<Item = A::Item>,
18{
19 a: A,
20 b: B,
21}
22
23impl<A, B> Chain<A, B>
24where
25 A: ParallelIterator,
26 B: ParallelIterator<Item = A::Item>,
27{
28 pub(super) fn new(a: A, b: B) -> Self {
30 Chain { a, b }
31 }
32}
33
34impl<A, B> ParallelIterator for Chain<A, B>
35where
36 A: ParallelIterator,
37 B: ParallelIterator<Item = A::Item>,
38{
39 type Item = A::Item;
40
41 fn drive_unindexed<C>(self, consumer: C) -> C::Result
42 where
43 C: UnindexedConsumer<Self::Item>,
44 {
45 let Chain { a, b } = self;
46
47 let (left, right, reducer) = if let Some(len) = a.opt_len() {
53 consumer.split_at(len)
54 } else {
55 let reducer = consumer.to_reducer();
56 (consumer.split_off_left(), consumer, reducer)
57 };
58
59 let (a, b) = join(|| a.drive_unindexed(left), || b.drive_unindexed(right));
60 reducer.reduce(a, b)
61 }
62
63 fn opt_len(&self) -> Option<usize> {
64 self.a.opt_len()?.checked_add(self.b.opt_len()?)
65 }
66}
67
68impl<A, B> IndexedParallelIterator for Chain<A, B>
69where
70 A: IndexedParallelIterator,
71 B: IndexedParallelIterator<Item = A::Item>,
72{
73 fn drive<C>(self, consumer: C) -> C::Result
74 where
75 C: Consumer<Self::Item>,
76 {
77 let Chain { a, b } = self;
78 let (left, right, reducer) = consumer.split_at(a.len());
79 let (a, b) = join(|| a.drive(left), || b.drive(right));
80 reducer.reduce(a, b)
81 }
82
83 fn len(&self) -> usize {
84 self.a.len().checked_add(self.b.len()).expect("overflow")
85 }
86
87 fn with_producer<CB>(self, callback: CB) -> CB::Output
88 where
89 CB: ProducerCallback<Self::Item>,
90 {
91 let a_len = self.a.len();
92 return self.a.with_producer(CallbackA {
93 callback,
94 a_len,
95 b: self.b,
96 });
97
98 struct CallbackA<CB, B> {
99 callback: CB,
100 a_len: usize,
101 b: B,
102 }
103
104 impl<CB, B> ProducerCallback<B::Item> for CallbackA<CB, B>
105 where
106 B: IndexedParallelIterator,
107 CB: ProducerCallback<B::Item>,
108 {
109 type Output = CB::Output;
110
111 fn callback<A>(self, a_producer: A) -> Self::Output
112 where
113 A: Producer<Item = B::Item>,
114 {
115 self.b.with_producer(CallbackB {
116 callback: self.callback,
117 a_len: self.a_len,
118 a_producer,
119 })
120 }
121 }
122
123 struct CallbackB<CB, A> {
124 callback: CB,
125 a_len: usize,
126 a_producer: A,
127 }
128
129 impl<CB, A> ProducerCallback<A::Item> for CallbackB<CB, A>
130 where
131 A: Producer,
132 CB: ProducerCallback<A::Item>,
133 {
134 type Output = CB::Output;
135
136 fn callback<B>(self, b_producer: B) -> Self::Output
137 where
138 B: Producer<Item = A::Item>,
139 {
140 let producer = ChainProducer::new(self.a_len, self.a_producer, b_producer);
141 self.callback.callback(producer)
142 }
143 }
144 }
145}
146
147struct ChainProducer<A, B>
150where
151 A: Producer,
152 B: Producer<Item = A::Item>,
153{
154 a_len: usize,
155 a: A,
156 b: B,
157}
158
159impl<A, B> ChainProducer<A, B>
160where
161 A: Producer,
162 B: Producer<Item = A::Item>,
163{
164 fn new(a_len: usize, a: A, b: B) -> Self {
165 ChainProducer { a_len, a, b }
166 }
167}
168
169impl<A, B> Producer for ChainProducer<A, B>
170where
171 A: Producer,
172 B: Producer<Item = A::Item>,
173{
174 type IntoIter = ChainSeq<A::IntoIter, B::IntoIter>;
175 type Item = A::Item;
176
177 fn into_iter(self) -> Self::IntoIter {
178 ChainSeq::new(self.a.into_iter(), self.b.into_iter())
179 }
180
181 fn min_len(&self) -> usize {
182 Ord::max(self.a.min_len(), self.b.min_len())
183 }
184
185 fn max_len(&self) -> usize {
186 Ord::min(self.a.max_len(), self.b.max_len())
187 }
188
189 fn split_at(self, index: usize) -> (Self, Self) {
190 if index <= self.a_len {
191 let a_rem = self.a_len - index;
192 let (a_left, a_right) = self.a.split_at(index);
193 let (b_left, b_right) = self.b.split_at(0);
194 (
195 ChainProducer::new(index, a_left, b_left),
196 ChainProducer::new(a_rem, a_right, b_right),
197 )
198 } else {
199 let (a_left, a_right) = self.a.split_at(self.a_len);
200 let (b_left, b_right) = self.b.split_at(index - self.a_len);
201 (
202 ChainProducer::new(self.a_len, a_left, b_left),
203 ChainProducer::new(0, a_right, b_right),
204 )
205 }
206 }
207
208 fn fold_with<F>(self, mut folder: F) -> F
209 where
210 F: Folder<A::Item>,
211 {
212 folder = self.a.fold_with(folder);
213 if folder.full() {
214 folder
215 } else {
216 self.b.fold_with(folder)
217 }
218 }
219}
220
/// Sequential iterator over all items of `a` followed by all items of `b`.
///
/// Wraps [`iter::Chain`], which does not implement [`ExactSizeIterator`]
/// itself. The constructor only accepts exact-size halves, so this wrapper
/// can implement it using the inner chain's `size_hint`.
struct ChainSeq<A, B> {
    chain: iter::Chain<A, B>,
}

impl<A, B> ChainSeq<A, B> {
    /// Chains `a` and `b`; both must report exact lengths.
    fn new(a: A, b: B) -> ChainSeq<A, B>
    where
        A: ExactSizeIterator,
        B: ExactSizeIterator<Item = A::Item>,
    {
        ChainSeq { chain: a.chain(b) }
    }
}

impl<A, B> Iterator for ChainSeq<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    type Item = A::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.chain.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.chain.size_hint()
    }

    // Forwarded so internal iteration (`for_each`, `sum`, `collect`, ...)
    // uses `iter::Chain::fold`, which folds each half in turn, instead of the
    // default `fold` that calls `next()` per item.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.fold(init, f)
    }
}

impl<A, B> ExactSizeIterator for ChainSeq<A, B>
where
    A: ExactSizeIterator,
    B: ExactSizeIterator<Item = A::Item>,
{
}

impl<A, B> DoubleEndedIterator for ChainSeq<A, B>
where
    A: DoubleEndedIterator,
    B: DoubleEndedIterator<Item = A::Item>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.chain.next_back()
    }

    // Forwarded for the same reason as `fold`: reuse `iter::Chain::rfold`.
    fn rfold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.rfold(init, f)
    }
}