par_iter/iter/chain.rs

1use std::iter;
2
3use par_core::join;
4
5use super::{plumbing::*, *};
6
7/// `Chain` is an iterator that joins `b` after `a` in one continuous iterator.
8/// This struct is created by the [`chain()`] method on [`ParallelIterator`]
9///
10/// [`chain()`]: trait.ParallelIterator.html#method.chain
11/// [`ParallelIterator`]: trait.ParallelIterator.html
12#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
13#[derive(Debug, Clone)]
14pub struct Chain<A, B>
15where
16    A: ParallelIterator,
17    B: ParallelIterator<Item = A::Item>,
18{
19    a: A,
20    b: B,
21}
22
23impl<A, B> Chain<A, B>
24where
25    A: ParallelIterator,
26    B: ParallelIterator<Item = A::Item>,
27{
28    /// Creates a new `Chain` iterator.
29    pub(super) fn new(a: A, b: B) -> Self {
30        Chain { a, b }
31    }
32}
33
34impl<A, B> ParallelIterator for Chain<A, B>
35where
36    A: ParallelIterator,
37    B: ParallelIterator<Item = A::Item>,
38{
39    type Item = A::Item;
40
41    fn drive_unindexed<C>(self, consumer: C) -> C::Result
42    where
43        C: UnindexedConsumer<Self::Item>,
44    {
45        let Chain { a, b } = self;
46
47        // If we returned a value from our own `opt_len`, then the collect consumer in
48        // particular will balk at being treated like an actual
49        // `UnindexedConsumer`.  But when we do know the length, we can use
50        // `Consumer::split_at` instead, and this is still harmless for other
51        // truly-unindexed consumers too.
52        let (left, right, reducer) = if let Some(len) = a.opt_len() {
53            consumer.split_at(len)
54        } else {
55            let reducer = consumer.to_reducer();
56            (consumer.split_off_left(), consumer, reducer)
57        };
58
59        let (a, b) = join(|| a.drive_unindexed(left), || b.drive_unindexed(right));
60        reducer.reduce(a, b)
61    }
62
63    fn opt_len(&self) -> Option<usize> {
64        self.a.opt_len()?.checked_add(self.b.opt_len()?)
65    }
66}
67
68impl<A, B> IndexedParallelIterator for Chain<A, B>
69where
70    A: IndexedParallelIterator,
71    B: IndexedParallelIterator<Item = A::Item>,
72{
73    fn drive<C>(self, consumer: C) -> C::Result
74    where
75        C: Consumer<Self::Item>,
76    {
77        let Chain { a, b } = self;
78        let (left, right, reducer) = consumer.split_at(a.len());
79        let (a, b) = join(|| a.drive(left), || b.drive(right));
80        reducer.reduce(a, b)
81    }
82
83    fn len(&self) -> usize {
84        self.a.len().checked_add(self.b.len()).expect("overflow")
85    }
86
87    fn with_producer<CB>(self, callback: CB) -> CB::Output
88    where
89        CB: ProducerCallback<Self::Item>,
90    {
91        let a_len = self.a.len();
92        return self.a.with_producer(CallbackA {
93            callback,
94            a_len,
95            b: self.b,
96        });
97
98        struct CallbackA<CB, B> {
99            callback: CB,
100            a_len: usize,
101            b: B,
102        }
103
104        impl<CB, B> ProducerCallback<B::Item> for CallbackA<CB, B>
105        where
106            B: IndexedParallelIterator,
107            CB: ProducerCallback<B::Item>,
108        {
109            type Output = CB::Output;
110
111            fn callback<A>(self, a_producer: A) -> Self::Output
112            where
113                A: Producer<Item = B::Item>,
114            {
115                self.b.with_producer(CallbackB {
116                    callback: self.callback,
117                    a_len: self.a_len,
118                    a_producer,
119                })
120            }
121        }
122
123        struct CallbackB<CB, A> {
124            callback: CB,
125            a_len: usize,
126            a_producer: A,
127        }
128
129        impl<CB, A> ProducerCallback<A::Item> for CallbackB<CB, A>
130        where
131            A: Producer,
132            CB: ProducerCallback<A::Item>,
133        {
134            type Output = CB::Output;
135
136            fn callback<B>(self, b_producer: B) -> Self::Output
137            where
138                B: Producer<Item = A::Item>,
139            {
140                let producer = ChainProducer::new(self.a_len, self.a_producer, b_producer);
141                self.callback.callback(producer)
142            }
143        }
144    }
145}
146
147/// ////////////////////////////////////////////////////////////////////////
148
149struct ChainProducer<A, B>
150where
151    A: Producer,
152    B: Producer<Item = A::Item>,
153{
154    a_len: usize,
155    a: A,
156    b: B,
157}
158
159impl<A, B> ChainProducer<A, B>
160where
161    A: Producer,
162    B: Producer<Item = A::Item>,
163{
164    fn new(a_len: usize, a: A, b: B) -> Self {
165        ChainProducer { a_len, a, b }
166    }
167}
168
169impl<A, B> Producer for ChainProducer<A, B>
170where
171    A: Producer,
172    B: Producer<Item = A::Item>,
173{
174    type IntoIter = ChainSeq<A::IntoIter, B::IntoIter>;
175    type Item = A::Item;
176
177    fn into_iter(self) -> Self::IntoIter {
178        ChainSeq::new(self.a.into_iter(), self.b.into_iter())
179    }
180
181    fn min_len(&self) -> usize {
182        Ord::max(self.a.min_len(), self.b.min_len())
183    }
184
185    fn max_len(&self) -> usize {
186        Ord::min(self.a.max_len(), self.b.max_len())
187    }
188
189    fn split_at(self, index: usize) -> (Self, Self) {
190        if index <= self.a_len {
191            let a_rem = self.a_len - index;
192            let (a_left, a_right) = self.a.split_at(index);
193            let (b_left, b_right) = self.b.split_at(0);
194            (
195                ChainProducer::new(index, a_left, b_left),
196                ChainProducer::new(a_rem, a_right, b_right),
197            )
198        } else {
199            let (a_left, a_right) = self.a.split_at(self.a_len);
200            let (b_left, b_right) = self.b.split_at(index - self.a_len);
201            (
202                ChainProducer::new(self.a_len, a_left, b_left),
203                ChainProducer::new(0, a_right, b_right),
204            )
205        }
206    }
207
208    fn fold_with<F>(self, mut folder: F) -> F
209    where
210        F: Folder<A::Item>,
211    {
212        folder = self.a.fold_with(folder);
213        if folder.full() {
214            folder
215        } else {
216            self.b.fold_with(folder)
217        }
218    }
219}
220
// ---------------------------------------------------------------------------

/// Sequential iterator yielding all items of `a`, then all items of `b`.
///
/// Wraps `iter::Chain` so that the pair can implement `ExactSizeIterator`,
/// which `iter::Chain` itself does not provide.
struct ChainSeq<A, B> {
    chain: iter::Chain<A, B>,
}

impl<A, B> ChainSeq<A, B> {
    /// Chains `a` and `b`. Both must be exact-size so the wrapper can report
    /// an exact length.
    fn new(a: A, b: B) -> ChainSeq<A, B>
    where
        A: ExactSizeIterator,
        B: ExactSizeIterator<Item = A::Item>,
    {
        ChainSeq { chain: a.chain(b) }
    }
}

impl<A, B> Iterator for ChainSeq<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    type Item = A::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.chain.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.chain.size_hint()
    }

    // Forward internal iteration to `iter::Chain`'s own `fold` rather than
    // the default loop over `next()`. Consumers that use `fold` (directly or
    // via `for_each`/`extend`) then get the inner iterators' folds.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.fold(init, f)
    }
}

impl<A, B> ExactSizeIterator for ChainSeq<A, B>
where
    A: ExactSizeIterator,
    B: ExactSizeIterator<Item = A::Item>,
{
}

impl<A, B> DoubleEndedIterator for ChainSeq<A, B>
where
    A: DoubleEndedIterator,
    B: DoubleEndedIterator<Item = A::Item>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.chain.next_back()
    }

    // Same reasoning as `fold`: use `iter::Chain`'s `rfold` instead of the
    // default loop over `next_back()`.
    fn rfold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.rfold(init, f)
    }
}