// primitives/utils/iter.rs

#[macro_export]
/// Create an iterator running multiple iterators in lockstep, panicking if they
/// are not of the same length.
///
/// Length checks are done lazily, while the resulting iterator is consumed:
/// the `izip_eq_lazy!` iterator yields elements until a sub-iterator returns
/// `None`, and panics at that point if the sub-iterators did not all end together.
///
/// * One argument: expands to `IntoIterator::into_iter` of that argument.
/// * Two arguments: expands to `itertools::zip_eq`, yielding pairs.
/// * More arguments: nests `itertools::zip_eq` and flattens the nested pairs
///   into a single tuple `(a, b, c, ...)`.
///
/// The expansion names `itertools::zip_eq` directly, so `itertools` must be
/// resolvable at the call site.
///
/// **Note:** This macro is slower than `izip_eq!`, however it allows zipping iterators of
/// unknown or unbounded size.
///
/// # Panics
///
/// The returned iterator panics during iteration (inside `itertools::zip_eq`)
/// when one sub-iterator is exhausted before another.
macro_rules! izip_eq_lazy {
    // @closure creates a tuple-flattening closure for the `.map()` call. Usage:
    //   @closure partial_pattern => partial_tuple , rest , of , iterators
    // e.g. izip_eq_lazy!( @closure ((a, b), c) => (a, b, c) , dd , ee )
    //
    // Base case: no iterators left, emit the finished closure.
    ( @closure $p:pat => $tup:expr ) => {
        |$p| $tup
    };

    // Recursive case: consume one iterator expression (its value is never
    // evaluated here) and grow both the nested pattern and the flat tuple.
    // The "b" identifier is a different identifier on each recursion level thanks to hygiene.
    ( @closure $p:pat => ( $($tup:tt)* ) , $_iter:expr $( , $tail:expr )* ) => {
        $crate::izip_eq_lazy!(@closure ($p, b) => ( $($tup)*, b ) $( , $tail )*)
    };

    // unary: nothing to zip, just turn the argument into an iterator
    ($first:expr $(,)*) => {
        std::iter::IntoIterator::into_iter($first)
    };

    // binary: a single `zip_eq` already yields flat pairs
    ($first:expr, $second:expr $(,)*) => {
        itertools::zip_eq(
            std::iter::IntoIterator::into_iter($first),
            $second,
        )
    };

    // n-ary where n > 2: left-nested `zip_eq` calls yield `(((a, b), c), ...)`,
    // which the generated closure flattens to `(a, b, c, ...)`.
    ( $first:expr $( , $rest:expr )* $(,)* ) => {
        {
            let iter = std::iter::IntoIterator::into_iter($first);
            $(
                let iter = itertools::zip_eq(iter, $rest);
            )*
            std::iter::Iterator::map(
                iter,
                $crate::izip_eq_lazy!(@closure a => (a) $( , $rest )*)
            )
        }
    };
}

#[macro_export]
/// Create an iterator running multiple iterators in lockstep, panicking if they
/// are not the same length.
///
/// Length checks are eagerly performed before the resulting iterator is built,
/// by comparing the `len()` of the sub-iterators. Every argument (except in the
/// single-argument form) must therefore convert into an iterator that provides
/// `len()`.
///
/// * One argument: expands to `IntoIterator::into_iter` of that argument
///   (no length check is involved).
/// * Two arguments: yields pairs.
/// * More arguments: yields flat tuples `(a, b, c, ...)`.
///
/// **Note:** For performance reasons, this macro is preferable to `izip_eq_lazy!` when the lengths
/// of the iterators are known.
///
/// # Panics
///
/// Panics at construction time with `iterator length mismatch: {left} vs {right}`
/// as soon as an argument's length differs from the length of the arguments
/// zipped so far.
macro_rules! izip_eq {
    // Internal helper: assert that two already-built iterators report the same
    // length. Each length is read once and reused in the panic message.
    (@assert_eq_len $first:expr, $second:expr) => {{
        let (lhs, rhs) = ($first.len(), $second.len());
        assert!(lhs == rhs, "iterator length mismatch: {} vs {}", lhs, rhs);
    }};

    // unary
    ($first:expr $(,)*) => {
        std::iter::IntoIterator::into_iter($first)
    };

    // binary: a plain `zip` is enough once the lengths are known to match
    ($first:expr, $second:expr $(,)*) => {
        {
            let first = std::iter::IntoIterator::into_iter($first);
            let second = std::iter::IntoIterator::into_iter($second);
            $crate::izip_eq!(@assert_eq_len first, second);
            std::iter::Iterator::zip(first, second)
        }
    };

    // n-ary where n > 2: check and zip one argument at a time, then flatten the
    // left-nested pairs with the closure builder shared with `izip_eq_lazy!`.
    ( $first:expr $( , $rest:expr )* $(,)* ) => {
        {
            let iter = std::iter::IntoIterator::into_iter($first);
            $(
                let rest = std::iter::IntoIterator::into_iter($rest);
                $crate::izip_eq!(@assert_eq_len iter, rest);
                let iter = std::iter::Iterator::zip(iter, rest);
            )*
            std::iter::Iterator::map(
                iter,
                $crate::izip_eq_lazy!(@closure a => (a) $( , $rest )*)
            )
        }
    };
}

#[macro_export]
/// Create a parallel iterator running multiple iterators in lockstep, panicking if they
/// are not the same length.
///
/// This is the rayon parallel version of `izip_eq!`. It uses `into_par_iter()` and performs
/// length checks eagerly before building the parallel iterator.
///
/// **Note:** All input iterators must implement `IntoParallelIterator`; when more than one
/// input is given they are combined with `IndexedParallelIterator::zip`. The expansion names
/// `rayon::iter::...` items directly, so `rayon` must be resolvable at the call site.
///
/// # Panics
///
/// Panics at construction time with `parallel iterator length mismatch: {left} vs {right}`
/// as soon as an argument's length differs from the length of the arguments zipped so far.
///
/// # Examples
///
/// ```
/// use primitives::par_izip;
/// use rayon::prelude::*;
///
/// let a = vec![1, 2, 3];
/// let b = vec![4, 5, 6];
/// let c = vec![7, 8, 9];
///
/// let results: Vec<_> = par_izip!(&a, &b, &c).map(|(a, b, c)| a + b + c).collect();
///
/// assert_eq!(results, vec![12, 15, 18]);
/// ```
macro_rules! par_izip {
    // Helper to create the tuple-flattening closure. These arms mirror the
    // `@closure` arms of `izip_eq_lazy!` but are defined here as well.
    // Base case: no iterators left, emit the finished closure.
    ( @closure $p:pat => $tup:expr ) => {
        |$p| $tup
    };

    // Recursive case: one more iterator means one more nesting level in the
    // pattern and one more element in the flat tuple ("b" is fresh per level
    // thanks to hygiene).
    ( @closure $p:pat => ( $($tup:tt)* ) , $_iter:expr $( , $tail:expr )* ) => {
        $crate::par_izip!(@closure ($p, b) => ( $($tup)*, b ) $( , $tail )*)
    };

    // Length assertion helper. `len` is called through its fully-qualified
    // trait path so the expansion does not depend on the caller having
    // `IndexedParallelIterator` in scope; each length is read once.
    (@assert_eq_len $first:expr, $second:expr) => {{
        let lhs = rayon::iter::IndexedParallelIterator::len(&$first);
        let rhs = rayon::iter::IndexedParallelIterator::len(&$second);
        assert!(lhs == rhs, "parallel iterator length mismatch: {} vs {}", lhs, rhs);
    }};

    // unary
    ($first:expr $(,)*) => {
        rayon::iter::IntoParallelIterator::into_par_iter($first)
    };

    // binary
    ($first:expr, $second:expr $(,)*) => {
        {
            let first = rayon::iter::IntoParallelIterator::into_par_iter($first);
            let second = rayon::iter::IntoParallelIterator::into_par_iter($second);
            $crate::par_izip!(@assert_eq_len first, second);
            rayon::iter::IndexedParallelIterator::zip(first, second)
        }
    };

    // n-ary where n > 2: check and zip one argument at a time, then flatten the
    // left-nested pairs into a single tuple.
    ( $first:expr $( , $rest:expr )* $(,)* ) => {
        {
            let iter = rayon::iter::IntoParallelIterator::into_par_iter($first);
            $(
                let rest = rayon::iter::IntoParallelIterator::into_par_iter($rest);
                $crate::par_izip!(@assert_eq_len iter, rest);
                let iter = rayon::iter::IndexedParallelIterator::zip(iter, rest);
            )*
            rayon::iter::ParallelIterator::map(
                iter,
                $crate::par_izip!(@closure a => (a) $( , $rest )*)
            )
        }
    };
}

#[macro_export]
/// Chain several iterables one after another and return an `ExactSizeIterator`
/// over all of their items.
///
/// The arguments are chained in order, exactly like `itertools::chain!` does
/// (`into_iter` on the first argument, then `Iterator::chain` for each further
/// one), but only standard-library items are named in the expansion.
///
/// **Note:** This macro collects the chained output into a vector and returns
/// that vector's `into_iter()`. All items are therefore produced eagerly, and
/// this may prevent some compiler optimizations involving iterators.
macro_rules! chain_eq {
    ($first:expr $( , $rest:expr )* $(,)*) => {{
        let chained = std::iter::IntoIterator::into_iter($first);
        $(
            let chained = std::iter::Iterator::chain(chained, $rest);
        )*
        // Buffering is what makes the total length known up front.
        let buffered = <std::vec::Vec<_> as std::iter::FromIterator<_>>::from_iter(chained);
        <std::vec::Vec<_> as std::iter::IntoIterator>::into_iter(buffered)
    }};
}

#[macro_export]
/// Create an iterator over hashmaps with the same keys. Accepts a list of keys, and the maps
/// to be iterated over.
///
/// For every key (in the order produced by `$keys.into_iter()`), the iterator
/// yields the tuple `(key, value_from_map_1, value_from_map_2, ...)`.
///
/// Things to be aware of:
///
/// * Each value is obtained with `remove`, so the looked-up entries are taken
///   out of the maps as the iterator advances.
/// * The closure is a `move` closure, so every `$map` is captured by value.
/// * The key is placed in the output tuple before it is borrowed for the
///   lookups, so the key type has to be `Copy`.
///
/// # Panics
///
/// The returned iterator panics during iteration with
/// ``Key `{key:?}` not found in map `{map expression}` `` if a key is missing
/// from one of the maps.
macro_rules! zip_maps {
    ($keys:expr, $($map:expr),+ $(,)?) => {
        $keys.into_iter().map(move |key| {
            (
                key,
                $(
                    // The panic message is only formatted when the key is
                    // actually missing, not on every successful lookup.
                    $map.remove(&key).unwrap_or_else(|| {
                        panic!(
                            "Key `{:?}` not found in map `{}`",
                            key,
                            stringify!($map)
                        )
                    }),
                )+
            )
        })
    };
}

/// Marker trait for values that can be turned into an iterator whose length is
/// known up front, i.e. whose `IntoIter` is an [`ExactSizeIterator`].
///
/// It has no methods of its own; it exists so that a single bound
/// `T: IntoExactSizeIterator` can stand for
/// `T: IntoIterator` with `T::IntoIter: ExactSizeIterator`.
pub trait IntoExactSizeIterator:
    IntoIterator<IntoIter: ExactSizeIterator<Item = <Self as IntoIterator>::Item>>
{
}

/// Blanket implementation: every `IntoIterator` whose iterator is an
/// `ExactSizeIterator` is an `IntoExactSizeIterator`.
impl<T> IntoExactSizeIterator for T
where
    T: IntoIterator,
    T::IntoIter: ExactSizeIterator<Item = T::Item>,
{
}

/// Extension trait that adds [`take_exact`](TakeExact::take_exact) to every
/// iterator.
pub trait TakeExact: Iterator {
    /// Wraps the iterator so that it yields exactly `n` elements and implements
    /// `ExactSizeIterator`.
    ///
    /// Nothing is consumed or checked when this method is called. The check is
    /// lazy: the returned iterator panics from `next()` if the underlying
    /// iterator runs out before `n` elements have been yielded. Elements of the
    /// underlying iterator beyond the first `n` are never requested.
    fn take_exact(self, n: usize) -> TakeExactIter<Self>
    where
        Self: Sized,
    {
        TakeExactIter {
            iter: self,
            remaining: n,
        }
    }
}

impl<I: Iterator> TakeExact for I {}

/// An iterator that takes exactly N elements from the underlying iterator.
///
/// Created by [`TakeExact::take_exact`].
#[derive(Clone, Debug)]
pub struct TakeExactIter<I> {
    /// The wrapped iterator.
    iter: I,
    /// Number of elements that still have to be yielded.
    remaining: usize,
}

impl<I: Iterator> Iterator for TakeExactIter<I> {
    type Item = I::Item;

    /// Yields the next element, or `None` once `n` elements have been yielded.
    ///
    /// # Panics
    ///
    /// Panics with `iterator shorter than expected length` if the underlying
    /// iterator is exhausted while elements are still owed.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            // Quota reached: the underlying iterator is not polled any more.
            return None;
        }
        self.remaining -= 1;
        Some(
            self.iter
                .next()
                .expect("iterator shorter than expected length"),
        )
    }

    /// Always reports exactly the number of elements still owed, regardless of
    /// what the underlying iterator would report.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl<I: Iterator> ExactSizeIterator for TakeExactIter<I> {}

#[cfg(test)]
mod tests {
    use rayon::prelude::*;

    use crate::{izip_eq, izip_eq_lazy, par_izip, utils::TakeExact};

    /// Happy path for `izip_eq_lazy!` and `izip_eq!` with 4, 3, 2 and 1 iterators.
    #[test]
    fn test_izip_eq() {
        let a = [1, 2, 3];
        let b = [4, 5, 6];
        let c = [7, 8, 9];

        {
            let mut results = [0, 0, 0];
            for (r, aa, bb, cc) in izip_eq_lazy!(&mut results, &a, &b, &c) {
                *r = aa + bb + cc;
            }

            assert_eq!(results, [1 + 4 + 7, 2 + 5 + 8, 3 + 6 + 9]);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa, bb) in izip_eq_lazy!(&mut results, &a, &b) {
                *r = aa + bb;
            }

            assert_eq!(results, [1 + 4, 2 + 5, 3 + 6]);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa) in izip_eq_lazy!(&mut results, &a) {
                *r = *aa;
            }

            assert_eq!(results, [1, 2, 3]);
        }
        {
            let mut result = 0;
            for aa in izip_eq_lazy!(&a) {
                result += *aa;
            }

            assert_eq!(result, 1 + 2 + 3);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa, bb, cc) in izip_eq!(&mut results, &a, &b, &c) {
                *r = aa + bb + cc;
            }

            assert_eq!(results, [1 + 4 + 7, 2 + 5 + 8, 3 + 6 + 9]);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa, bb) in izip_eq!(&mut results, &a, &b) {
                *r = aa + bb;
            }

            assert_eq!(results, [1 + 4, 2 + 5, 3 + 6]);
        }
        {
            let mut results = [0, 0, 0];
            for (r, aa) in izip_eq!(&mut results, &a) {
                *r = *aa;
            }

            assert_eq!(results, [1, 2, 3]);
        }
        {
            let mut result = 0;
            for aa in izip_eq!(&a) {
                result += *aa;
            }

            assert_eq!(result, 1 + 2 + 3);
        }
    }

    /// The lazy variant panics only while iterating, when `b` runs out first.
    #[test]
    #[should_panic(expected = "itertools: .zip_eq() reached end of one iterator before the other")]
    fn test_izip_eq_lazy_panic() {
        let a = [1, 2, 3];
        let b = [4, 5];
        let c = [7, 8, 9];

        let mut results = [0, 0, 0];
        for (r, aa, bb, cc) in izip_eq_lazy!(&mut results, &a, &b, &c) {
            *r = aa + bb + cc;
        }
        unreachable!()
    }

    /// Same as above, but the first iterator is the short one.
    #[test]
    #[should_panic(expected = "itertools: .zip_eq() reached end of one iterator before the other")]
    fn test_izip_eq_lazy_panic_2() {
        let a = [1, 2, 3];
        let b = [4, 5, 6];
        let c = [7, 8, 9];

        let mut results = [0, 0];
        for (r, aa, bb, cc) in izip_eq_lazy!(&mut results, &a, &b, &c) {
            *r = aa + bb + cc;
        }
        unreachable!()
    }

    /// The eager variant reports the mismatch (zipped-so-far length vs new length).
    #[test]
    #[should_panic(expected = "iterator length mismatch: 3 vs 2")]
    fn test_izip_eq_eager_panic() {
        let a = [1, 2, 3];
        let b = [4, 5];
        let c = [7, 8, 9];

        let mut results = [0, 0, 0];
        for (r, aa, bb, cc) in izip_eq!(&mut results, &a, &b, &c) {
            *r = aa + bb + cc;
        }
        unreachable!()
    }

    #[test]
    #[should_panic(expected = "iterator length mismatch: 2 vs 3")]
    fn test_izip_eq_eager_panic_2() {
        let a = [1, 2, 3];
        let b = [4, 5, 6];
        let c = [7, 8, 9];

        let mut results = [0, 0];
        for (r, aa, bb, cc) in izip_eq!(&mut results, &a, &b, &c) {
            *r = aa + bb + cc;
        }
        unreachable!()
    }

    /// Binary form, first iterator shorter.
    #[test]
    #[should_panic(expected = "iterator length mismatch: 2 vs 3")]
    fn test_izip_eq_eager_panic_3() {
        let a = [1, 2, 3];

        let mut results = [0, 0];
        for (r, aa) in izip_eq!(&mut results, &a) {
            *r = *aa;
        }
        unreachable!()
    }

    /// Binary form, second iterator shorter.
    #[test]
    #[should_panic(expected = "iterator length mismatch: 3 vs 2")]
    fn test_izip_eq_eager_panic_4() {
        let a = [1, 2];

        let mut results = [0, 0, 0];
        for (r, aa) in izip_eq!(&mut results, &a) {
            *r = *aa;
        }
        unreachable!()
    }

    #[test]
    fn test_chain_eq() {
        let a = vec![1, 2];
        let b = vec![3];
        let c = vec![4, 5, 6];

        let chained = chain_eq!(a, b, c);

        assert_eq!(chained.len(), 6);
        assert_eq!(chained.collect::<Vec<_>>(), vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_zip_maps() {
        let keys = vec![1, 2, 3];
        let mut map1: std::collections::HashMap<_, _> = [(1, "a"), (2, "b"), (3, "c")].into();
        let mut map2: std::collections::HashMap<_, _> = [(1, "x"), (2, "y"), (3, "z")].into();

        let result: Vec<_> = zip_maps!(keys, map1, map2).collect();

        assert_eq!(result, vec![(1, "a", "x"), (2, "b", "y"), (3, "c", "z"),]);
    }

    #[test]
    #[should_panic(expected = "Key `4` not found in map `map1`")]
    fn test_zip_maps_missing_key_panic() {
        let keys = vec![1, 2, 3, 4];
        let mut map1: std::collections::HashMap<_, _> = [(1, "a"), (2, "b"), (3, "c")].into();
        let mut map2: std::collections::HashMap<_, _> = [(1, "x"), (2, "y"), (3, "z")].into();

        let _result: Vec<_> = zip_maps!(keys, map1, map2).collect();
    }

    #[test]
    fn test_take_exact() {
        let v = vec![1, 2, 3, 4, 5];
        let mut iter = v.into_iter().take_exact(3);

        assert_eq!(iter.len(), 3);
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.len(), 2);
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.len(), 1);
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
    }

    #[test]
    #[should_panic(expected = "iterator shorter than expected length")]
    fn test_take_exact_panic() {
        let v = vec![1, 2];
        let mut iter = v.into_iter().take_exact(3);

        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        iter.next(); // This should panic
    }

    #[test]
    fn test_take_exact_collect() {
        let v = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = v.into_iter().take_exact(3).collect();
        assert_eq!(result, vec![1, 2, 3]);
    }

    #[test]
    fn test_take_exact_size_hint() {
        let v = vec![1, 2, 3, 4, 5];
        let iter = v.into_iter().take_exact(3);
        assert_eq!(iter.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_par_izip_unary() {
        let a = vec![1, 2, 3];
        let result: Vec<_> = par_izip!(&a).map(|x| x * 2).collect();
        assert_eq!(result, vec![2, 4, 6]);
    }

    #[test]
    fn test_par_izip_binary() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let result: Vec<_> = par_izip!(&a, &b).map(|(x, y)| x + y).collect();
        assert_eq!(result, vec![5, 7, 9]);
    }

    #[test]
    fn test_par_izip_ternary() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let c = vec![7, 8, 9];
        let result: Vec<_> = par_izip!(&a, &b, &c).map(|(x, y, z)| x + y + z).collect();
        assert_eq!(result, vec![12, 15, 18]);
    }

    #[test]
    fn test_par_izip_quaternary() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let c = vec![7, 8, 9];
        let d = vec![10, 11, 12];
        let result: Vec<_> = par_izip!(&a, &b, &c, &d)
            .map(|(w, x, y, z)| w + x + y + z)
            .collect();
        assert_eq!(result, vec![22, 26, 30]);
    }

    #[test]
    fn test_par_izip_with_mutation() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let mut results = vec![0, 0, 0];

        par_izip!(&mut results, &a, &b).for_each(|(r, a, b)| {
            *r = a + b;
        });

        assert_eq!(results, vec![5, 7, 9]);
    }

    #[test]
    #[should_panic(expected = "parallel iterator length mismatch: 3 vs 2")]
    fn test_par_izip_length_mismatch_binary() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5];
        let _: Vec<_> = par_izip!(&a, &b).collect();
    }

    #[test]
    #[should_panic(expected = "parallel iterator length mismatch: 3 vs 2")]
    fn test_par_izip_length_mismatch_ternary_first() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5];
        let c = vec![7, 8, 9];
        let _: Vec<_> = par_izip!(&a, &b, &c).collect();
    }

    #[test]
    #[should_panic(expected = "parallel iterator length mismatch: 3 vs 2")]
    fn test_par_izip_length_mismatch_ternary_second() {
        let a = vec![1, 2, 3];
        let b = vec![4, 5, 6];
        let c = vec![7, 8];
        let _: Vec<_> = par_izip!(&a, &b, &c).collect();
    }

    #[test]
    fn test_par_izip_empty() {
        let a: Vec<i32> = vec![];
        let b: Vec<i32> = vec![];
        let result: Vec<_> = par_izip!(&a, &b).collect();
        assert_eq!(result, vec![]);
    }

    #[test]
    fn test_par_izip_large() {
        // Test with larger dataset to ensure parallel execution works
        let a: Vec<_> = (0..1000).collect();
        let b: Vec<_> = (1000..2000).collect();
        let c: Vec<_> = (2000..3000).collect();

        let result: Vec<_> = par_izip!(&a, &b, &c).map(|(x, y, z)| x + y + z).collect();

        assert_eq!(result.len(), 1000);
        assert_eq!(result[0], 1000 + 2000);
        assert_eq!(result[999], 999 + 1999 + 2999);
    }
}