// primitives/utils/iter.rs

/// Create an iterator running multiple iterators in lockstep, panicking if they
/// are not of the same length.
///
/// Length checks happen lazily, while the resulting iterator is consumed. The
/// inputs are combined with `itertools::zip_eq`, so elements are yielded until
/// the first step at which some, but not all, of the inputs are exhausted, and
/// that step panics. Nothing is checked when the iterator is created.
///
/// The result depends on the number of arguments:
/// - one argument is simply converted with `IntoIterator::into_iter`;
/// - two arguments yield pairs `(a, b)`;
/// - three or more yield flat tuples `(a, b, c, ...)`.
///
/// **Note:** This macro is slower than `izip_eq!`. In exchange, it does not need
/// to know the lengths up front, so it also accepts iterators of unknown or
/// unbounded size.
#[macro_export]
macro_rules! izip_eq_lazy {
    // Internal: `@closure PATTERN => TUPLE, ITER, ...` builds the closure passed to
    // `.map()`. That closure flattens the nested pairs produced by repeated
    // zipping, e.g. `izip_eq_lazy!(@closure ((a, b), c) => (a, b, c), dd, ee)`.
    // Base case: no iterators left, so emit the closure itself.
    ( @closure $pattern:pat => $tuple:expr ) => {
        |$pattern| $tuple
    };

    // Recursive case: one more zip level. Thanks to hygiene, the `b` introduced
    // here is a distinct identifier at every recursion level.
    ( @closure $pattern:pat => ( $($tuple:tt)* ) , $_consumed:expr $( , $remaining:expr )* ) => {
        $crate::izip_eq_lazy!(@closure ($pattern, b) => ( $($tuple)*, b ) $( , $remaining )*)
    };

    // A single iterable: nothing to zip.
    ($only:expr $(,)*) => {
        std::iter::IntoIterator::into_iter($only)
    };

    // Two iterables: `zip_eq` already yields flat `(a, b)` pairs.
    ($left:expr, $right:expr $(,)*) => {
        itertools::zip_eq(
            std::iter::IntoIterator::into_iter($left),
            std::iter::IntoIterator::into_iter($right),
        )
    };

    // Three or more iterables: zip them left to right, then flatten the nested
    // pairs `((a, b), c)` into `(a, b, c)`.
    ( $head:expr $( , $tail:expr )* $(,)* ) => {{
        let zipped = std::iter::IntoIterator::into_iter($head);
        $(
            let zipped = itertools::zip_eq(zipped, $tail);
        )*
        std::iter::Iterator::map(
            zipped,
            $crate::izip_eq_lazy!(@closure a => (a) $( , $tail )*),
        )
    }};
}

/// Create an iterator running multiple iterators in lockstep, panicking if they
/// are not the same length.
///
/// Length checks are performed eagerly, before the resulting iterator is
/// returned, so a mismatch panics even if the iterator is never consumed. With
/// two or more arguments, every argument must convert (via `IntoIterator`) into
/// an `ExactSizeIterator`.
///
/// The result depends on the number of arguments:
/// - one argument is simply converted with `IntoIterator::into_iter`, with no check;
/// - two arguments yield pairs `(a, b)`;
/// - three or more yield flat tuples `(a, b, c, ...)`.
///
/// # Panics
///
/// Panics with `iterator length mismatch: {left} vs {right}` when two lengths
/// differ. Here `left` is the common length of the arguments zipped so far, and
/// `right` is the length of the next argument.
///
/// **Note:** For performance reasons, this macro is preferable to
/// `izip_eq_lazy!` when the lengths of the iterators are known.
#[macro_export]
macro_rules! izip_eq {
    // Internal: assert that two `ExactSizeIterator`s have the same length.
    // `ExactSizeIterator::len` is called by its full path, so an unrelated
    // inherent `len` method on the iterator type cannot be picked up instead.
    // Each operand is evaluated exactly once, including when building the
    // panic message.
    (@assert_eq_len $first:expr, $second:expr) => {{
        let first_len = std::iter::ExactSizeIterator::len(&$first);
        let second_len = std::iter::ExactSizeIterator::len(&$second);
        assert!(
            first_len == second_len,
            "iterator length mismatch: {} vs {}",
            first_len,
            second_len
        );
    }};

    // unary
    ($first:expr $(,)*) => {
        std::iter::IntoIterator::into_iter($first)
    };

    // binary
    ($first:expr, $second:expr $(,)*) => {{
        let iter = std::iter::IntoIterator::into_iter($first);
        let second = std::iter::IntoIterator::into_iter($second);
        $crate::izip_eq!(@assert_eq_len iter, second);
        std::iter::Iterator::zip(iter, second)
    }};

    // n-ary where n > 2. Each new argument is checked against the zip built so
    // far, whose length is the common length. The nested pairs are then
    // flattened into one tuple.
    ( $first:expr $( , $rest:expr )* $(,)* ) => {{
        let iter = std::iter::IntoIterator::into_iter($first);
        $(
            let rest = std::iter::IntoIterator::into_iter($rest);
            $crate::izip_eq!(@assert_eq_len iter, rest);
            let iter = std::iter::Iterator::zip(iter, rest);
        )*
        std::iter::Iterator::map(
            iter,
            $crate::izip_eq_lazy!(@closure a => (a) $( , $rest )*)
        )
    }};
}

/// Create a parallel iterator running multiple iterators in lockstep, panicking if they
/// are not the same length.
///
/// This is the rayon parallel version of `izip_eq!`. Every argument is converted with
/// `IntoParallelIterator::into_par_iter`. With two or more arguments, each converted
/// iterator must also implement `IndexedParallelIterator`, because both the length check
/// and `zip` require it. Lengths are checked eagerly, before the parallel iterator is
/// returned.
///
/// The result depends on the number of arguments:
/// - one argument is simply `into_par_iter`;
/// - two arguments yield pairs `(a, b)`;
/// - three or more yield flat tuples `(a, b, c, ...)`.
///
/// Every rayon trait method used while *building* the iterator is called by its full
/// path, so the call site does not need any rayon trait in scope for that part. The
/// caller only needs the traits required to consume the result (e.g. `ParallelIterator`
/// for `map`/`for_each`).
///
/// # Panics
///
/// Panics with `parallel iterator length mismatch: {left} vs {right}` when two lengths
/// differ. Here `left` is the common length of the arguments zipped so far, and `right`
/// is the length of the next argument.
///
/// # Examples
///
/// ```
/// use primitives::par_izip;
/// use rayon::prelude::*;
///
/// let a = vec![1, 2, 3];
/// let b = vec![4, 5, 6];
/// let c = vec![7, 8, 9];
///
/// let results: Vec<_> = par_izip!(&a, &b, &c).map(|(a, b, c)| a + b + c).collect();
///
/// assert_eq!(results, vec![12, 15, 18]);
/// ```
#[macro_export]
macro_rules! par_izip {
    // Tuple-flattening closure builder. It is identical to `izip_eq_lazy!`'s
    // builder, so this arm delegates instead of duplicating the recursion. The arm
    // is kept so existing `par_izip!(@closure ...)` invocations keep working.
    ( @closure $($args:tt)* ) => {
        $crate::izip_eq_lazy!(@closure $($args)*)
    };

    // Length assertion helper. `IndexedParallelIterator::len` is called by its full
    // path, so the trait does not have to be imported at the call site. Each length
    // is computed exactly once.
    (@assert_eq_len $first:expr, $second:expr) => {{
        let first_len = rayon::iter::IndexedParallelIterator::len(&$first);
        let second_len = rayon::iter::IndexedParallelIterator::len(&$second);
        assert!(
            first_len == second_len,
            "parallel iterator length mismatch: {} vs {}",
            first_len,
            second_len
        );
    }};

    // unary
    ($first:expr $(,)*) => {
        rayon::iter::IntoParallelIterator::into_par_iter($first)
    };

    // binary
    ($first:expr, $second:expr $(,)*) => {{
        let iter = rayon::iter::IntoParallelIterator::into_par_iter($first);
        let second = rayon::iter::IntoParallelIterator::into_par_iter($second);
        $crate::par_izip!(@assert_eq_len iter, second);
        rayon::iter::IndexedParallelIterator::zip(iter, second)
    }};

    // n-ary where n > 2. Each new argument is checked against the zip built so far,
    // then the nested pairs are flattened into one tuple.
    ( $first:expr $( , $rest:expr )* $(,)* ) => {{
        let iter = rayon::iter::IntoParallelIterator::into_par_iter($first);
        $(
            let rest = rayon::iter::IntoParallelIterator::into_par_iter($rest);
            $crate::par_izip!(@assert_eq_len iter, rest);
            let iter = rayon::iter::IndexedParallelIterator::zip(iter, rest);
        )*
        rayon::iter::ParallelIterator::map(
            iter,
            $crate::izip_eq_lazy!(@closure a => (a) $( , $rest )*)
        )
    }};
}

/// An adaptor for the `itertools::chain!` macro that produces an
/// `ExactSizeIterator`.
///
/// Takes one or more iterables, chains them with `itertools::chain!`, collects
/// the elements into a `Vec`, and returns that vector's owning iterator.
///
/// **Note:** Collecting costs an allocation and may prevent some compiler
/// optimizations involving iterators.
#[macro_export]
macro_rules! chain_eq {
    ($first:expr $( , $rest:expr )* $(,)*) => {{
        let buffered: std::vec::Vec<_> =
            std::iter::Iterator::collect(itertools::chain!($first, $($rest, )*));
        std::iter::IntoIterator::into_iter(buffered)
    }};
}

/// Create an iterator over hashmaps with the same keys.
///
/// Takes an iterable of keys followed by one or more maps. For every key, the
/// entry is *removed* from each map by calling the map's `remove(&key)` method,
/// visiting the maps in argument order. The iterator then yields the flat tuple
/// `(key, value_from_first_map, value_from_second_map, ...)`.
///
/// The map expressions are used inside a `move` closure, so maps named directly
/// (e.g. `zip_maps!(keys, map1, map2)`) are moved into the returned iterator.
/// Keys must implement `Debug` because of the panic message. They do not need to
/// be `Copy`: every lookup borrows the key before the key is moved into the
/// yielded tuple.
///
/// # Panics
///
/// The returned iterator panics when a key is missing from one of the maps. The
/// message names the `Debug`-formatted key and the source text of the map
/// expression.
#[macro_export]
macro_rules! zip_maps {
    // Internal: every map has been visited, so emit `(key, value_1, ..., value_n)`.
    (@collect $key:ident; ( $($value:ident),* ); ) => {
        ($key, $($value),*)
    };

    // Internal: remove `$key` from the next map, bind the result, then recurse on
    // the remaining maps. Thanks to hygiene, the `value` binding introduced here is
    // a distinct identifier at every recursion level.
    (@collect $key:ident; ( $($value:ident),* ); $map:expr $( , $rest:expr )*) => {{
        // `unwrap_or_else` formats the panic message only on failure, not on every
        // successful lookup.
        let value = $map.remove(&$key).unwrap_or_else(|| {
            panic!("Key `{:?}` not found in map `{}`", $key, stringify!($map))
        });
        $crate::zip_maps!(@collect $key; ( $($value,)* value ); $($rest),*)
    }};

    ($keys:expr, $($map:expr),+ $(,)?) => {
        $keys.into_iter().map(move |key| {
            $crate::zip_maps!(@collect key; (); $($map),+)
        })
    };
}

/// Marker trait for values whose `IntoIterator::into_iter` yields an
/// `ExactSizeIterator`.
///
/// Because the bound sits on the supertrait, generic code can write
/// `x.into_iter().len()` for any `x: impl IntoExactSizeIterator`. Every qualifying
/// type is covered automatically by the blanket impl below.
pub trait IntoExactSizeIterator:
    IntoIterator<IntoIter: ExactSizeIterator<Item = <Self as IntoIterator>::Item>>
{
}

impl<T> IntoExactSizeIterator for T
where
    T: IntoIterator,
    T::IntoIter: ExactSizeIterator<Item = <T as IntoIterator>::Item>,
{
}

215/// A trait for iterators that take exactly N elements, panicking if the iterator
216/// is shorter than N.
217pub trait TakeExact: Iterator {
218    /// Takes exactly `n` elements from the iterator, panicking if the iterator
219    /// is shorter than `n`. Returns an iterator that implements `ExactSizeIterator`.
220    fn take_exact(self, n: usize) -> TakeExactIter<Self>
221    where
222        Self: Sized,
223    {
224        TakeExactIter {
225            iter: self,
226            remaining: n,
227        }
228    }
229}
230
231impl<I: Iterator> TakeExact for I {}
232
/// Iterator returned by `TakeExact::take_exact`. It yields exactly `remaining`
/// more items from the wrapped iterator.
///
/// Because it either yields the promised number of items or panics, it reports an
/// exact `size_hint` and implements `ExactSizeIterator`, even when the wrapped
/// iterator does not. Once it has returned `None`, it keeps returning `None`, so it
/// also implements `FusedIterator`.
///
/// # Panics
///
/// `next` panics with "iterator shorter than expected length" if the wrapped
/// iterator runs out before the promised number of items has been yielded.
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct TakeExactIter<I> {
    /// The wrapped iterator.
    iter: I,
    /// Number of items still to be yielded.
    remaining: usize,
}

impl<I: Iterator> Iterator for TakeExactIter<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            // All promised items were yielded; the wrapped iterator is not polled again.
            return None;
        }
        self.remaining -= 1;
        Some(
            self.iter
                .next()
                .expect("iterator shorter than expected length"),
        )
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: `next` either yields `remaining` more items or panics.
        (self.remaining, Some(self.remaining))
    }
}

impl<I: Iterator> ExactSizeIterator for TakeExactIter<I> {}

// Sound: `next` only returns `None` once `remaining` is 0, and `remaining` never
// increases.
impl<I: Iterator> std::iter::FusedIterator for TakeExactIter<I> {}

#[cfg(test)]
mod tests {
    use rayon::prelude::*;

    use crate::{izip_eq, izip_eq_lazy, par_izip, utils::TakeExact};

    #[test]
    fn test_izip_eq() {
        let a = [1, 2, 3];
        let b = [4, 5, 6];
        let c = [7, 8, 9];

        {
            let mut results = [0, 0, 0];
            for (r, aa, bb, cc) in izip_eq_lazy!(&mut results, &a, &b, &c) {
                *r = aa + bb + cc;
            }

            assert_eq!(results, [1 + 4 + 7, 2 + 5 + 8, 3 + 6 + 9]);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa, bb) in izip_eq_lazy!(&mut results, &a, &b) {
                *r = aa + bb;
            }

            assert_eq!(results, [1 + 4, 2 + 5, 3 + 6]);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa) in izip_eq_lazy!(&mut results, &a) {
                *r = *aa;
            }

            assert_eq!(results, [1, 2, 3]);
        }
        {
            let mut result = 0;
            for aa in izip_eq_lazy!(&a) {
                result += *aa;
            }

            assert_eq!(result, 1 + 2 + 3);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa, bb, cc) in izip_eq!(&mut results, &a, &b, &c) {
                *r = aa + bb + cc;
            }

            assert_eq!(results, [1 + 4 + 7, 2 + 5 + 8, 3 + 6 + 9]);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa, bb) in izip_eq!(&mut results, &a, &b) {
                *r = aa + bb;
            }

            assert_eq!(results, [1 + 4, 2 + 5, 3 + 6]);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa) in izip_eq!(&mut results, &a) {
                *r = *aa;
            }

            assert_eq!(results, [1, 2, 3]);
        }
        {
            let mut result = 0;
            for aa in izip_eq!(&a) {
                result += *aa;
            }

            assert_eq!(result, 1 + 2 + 3);
        }
    }

    #[test]
    #[should_panic(expected = "itertools: .zip_eq() reached end of one iterator before the other")]
    fn test_izip_eq_lazy_panic() {
        let a = [1, 2, 3];
        let b = [4, 5];
        let c = [7, 8, 9];

        let mut results = [0, 0, 0];
        for (r, aa, bb, cc) in izip_eq_lazy!(&mut results, &a, &b, &c) {
            *r = aa + bb + cc;
        }
        unreachable!()
    }

    #[test]
    #[should_panic(expected = "itertools: .zip_eq() reached end of one iterator before the other")]
    fn test_izip_eq_lazy_panic_2() {
        let a = [1, 2, 3];
        let b = [4, 5, 6];
        let c = [7, 8, 9];

        let mut results = [0, 0];
        for (r, aa, bb, cc) in izip_eq_lazy!(&mut results, &a, &b, &c) {
            *r = aa + bb + cc;
        }
        unreachable!()
    }

    #[test]
    #[should_panic(expected = "iterator length mismatch: 3 vs 2")]
    fn test_izip_eq_eager_panic() {
        let a = [1, 2, 3];
        let b = [4, 5];
        let c = [7, 8, 9];

        let mut results = [0, 0, 0];
        for (r, aa, bb, cc) in izip_eq!(&mut results, &a, &b, &c) {
            *r = aa + bb + cc;
        }
        unreachable!()
    }

    #[test]
    #[should_panic(expected = "iterator length mismatch: 2 vs 3")]
    fn test_izip_eq_eager_panic_2() {
        let a = [1, 2, 3];
        let b = [4, 5, 6];
        let c = [7, 8, 9];

        let mut results = [0, 0];
        for (r, aa, bb, cc) in izip_eq!(&mut results, &a, &b, &c) {
            *r = aa + bb + cc;
        }
        unreachable!()
    }

    #[test]
    #[should_panic(expected = "iterator length mismatch: 2 vs 3")]
    fn test_izip_eq_eager_panic_3() {
        let a = [1, 2, 3];

        let mut results = [0, 0];
        for (r, aa) in izip_eq!(&mut results, &a) {
            *r = *aa;
        }
        unreachable!()
    }

    #[test]
    #[should_panic(expected = "iterator length mismatch: 3 vs 2")]
    fn test_izip_eq_eager_panic_4() {
        let a = [1, 2];

        let mut results = [0, 0, 0];
        for (r, aa) in izip_eq!(&mut results, &a) {
            *r = *aa;
        }
        unreachable!()
    }

    #[test]
    fn test_zip_maps() {
        let keys = vec![1, 2, 3];
        let mut map1: std::collections::HashMap<_, _> = [(1, "a"), (2, "b"), (3, "c")].into();
        let mut map2: std::collections::HashMap<_, _> = [(1, "x"), (2, "y"), (3, "z")].into();

        let result: Vec<_> = zip_maps!(keys, map1, map2).collect();

        assert_eq!(result, vec![(1, "a", "x"), (2, "b", "y"), (3, "c", "z"),]);
    }

    #[test]
    #[should_panic(expected = "Key `4` not found in map `map1`")]
    fn test_zip_maps_missing_key_panic() {
        let keys = vec![1, 2, 3, 4];
        let mut map1: std::collections::HashMap<_, _> = [(1, "a"), (2, "b"), (3, "c")].into();
        let mut map2: std::collections::HashMap<_, _> = [(1, "x"), (2, "y"), (3, "z")].into();

        let _result: Vec<_> = zip_maps!(keys, map1, map2).collect();
    }

    #[test]
    fn test_take_exact() {
        let v = vec![1, 2, 3, 4, 5];
        let mut iter = v.into_iter().take_exact(3);

        assert_eq!(iter.len(), 3);
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.len(), 2);
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.len(), 1);
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
    }

    #[test]
    #[should_panic(expected = "iterator shorter than expected length")]
    fn test_take_exact_panic() {
        let v = vec![1, 2];
        let mut iter = v.into_iter().take_exact(3);

        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        iter.next(); // This should panic
    }

    #[test]
    fn test_take_exact_collect() {
        let v = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = v.into_iter().take_exact(3).collect();
        assert_eq!(result, vec![1, 2, 3]);
    }

    #[test]
    fn test_take_exact_size_hint() {
        let v = vec![1, 2, 3, 4, 5];
        let iter = v.into_iter().take_exact(3);
        assert_eq!(iter.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_take_exact_zero() {
        // Taking zero items never polls the underlying iterator, even if it is empty.
        let mut iter = std::iter::empty::<i32>().take_exact(0);
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_chain_eq() {
        let a = vec![1, 2];
        let b = [3];
        let it = chain_eq!(a, b.iter().copied(), 4..6);

        // `chain_eq!` reports the exact combined length before iteration.
        assert_eq!(it.len(), 5);
        assert_eq!(it.collect::<Vec<_>>(), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_chain_eq_single() {
        let it = chain_eq!(vec!['x', 'y']);
        assert_eq!(it.len(), 2);
    }

    #[test]
    fn test_into_exact_size_iterator() {
        use super::IntoExactSizeIterator;

        fn len_of<T: IntoExactSizeIterator>(items: T) -> usize {
            items.into_iter().len()
        }

        assert_eq!(len_of(vec![1, 2, 3]), 3);
        assert_eq!(len_of(&[1, 2][..]), 2);
        assert_eq!(len_of(Vec::<u8>::new()), 0);
    }

    #[test]
    fn test_par_izip_unary() {
        let a = vec![1, 2, 3];
        let result: Vec<_> = par_izip!(&a).map(|x| x * 2).collect();
        assert_eq!(result, vec![2, 4, 6]);
    }

    #[test]
    fn test_par_izip_binary() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let result: Vec<_> = par_izip!(&a, &b).map(|(x, y)| x + y).collect();
        assert_eq!(result, vec![5, 7, 9]);
    }

    #[test]
    fn test_par_izip_ternary() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let c = vec![7, 8, 9];
        let result: Vec<_> = par_izip!(&a, &b, &c).map(|(x, y, z)| x + y + z).collect();
        assert_eq!(result, vec![12, 15, 18]);
    }

    #[test]
    fn test_par_izip_quaternary() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let c = vec![7, 8, 9];
        let d = vec![10, 11, 12];
        let result: Vec<_> = par_izip!(&a, &b, &c, &d)
            .map(|(w, x, y, z)| w + x + y + z)
            .collect();
        assert_eq!(result, vec![22, 26, 30]);
    }

    #[test]
    fn test_par_izip_with_mutation() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let mut results = vec![0, 0, 0];

        par_izip!(&mut results, &a, &b).for_each(|(r, a, b)| {
            *r = a + b;
        });

        assert_eq!(results, vec![5, 7, 9]);
    }

    #[test]
    #[should_panic(expected = "parallel iterator length mismatch: 3 vs 2")]
    fn test_par_izip_length_mismatch_binary() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5];
        let _: Vec<_> = par_izip!(&a, &b).collect();
    }

    #[test]
    #[should_panic(expected = "parallel iterator length mismatch: 3 vs 2")]
    fn test_par_izip_length_mismatch_ternary_first() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5];
        let c = vec![7, 8, 9];
        let _: Vec<_> = par_izip!(&a, &b, &c).collect();
    }

    #[test]
    #[should_panic(expected = "parallel iterator length mismatch: 3 vs 2")]
    fn test_par_izip_length_mismatch_ternary_second() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let c = vec![7, 8];
        let _: Vec<_> = par_izip!(&a, &b, &c).collect();
    }

    #[test]
    fn test_par_izip_empty() {
        let a: Vec<i32> = vec![];
        let b: Vec<i32> = vec![];
        let result: Vec<_> = par_izip!(&a, &b).collect();
        assert_eq!(result, vec![]);
    }

    #[test]
    fn test_par_izip_large() {
        // Test with larger dataset to ensure parallel execution works
        let a: Vec<_> = (0..1000).collect();
        let b: Vec<_> = (1000..2000).collect();
        let c: Vec<_> = (2000..3000).collect();

        let result: Vec<_> = par_izip!(&a, &b, &c).map(|(x, y, z)| x + y + z).collect();

        assert_eq!(result.len(), 1000);
        assert_eq!(result[0], 1000 + 2000);
        assert_eq!(result[999], 999 + 1999 + 2999);
    }
}