mdarray/expr/
iter.rs

1use core::fmt::{Debug, Formatter, Result};
2use core::iter::FusedIterator;
3
4use crate::dim::Dims;
5use crate::expr::expression::Expression;
6use crate::shape::Shape;
7
8/// Iterator type for array expressions.
9#[derive(Clone)]
10pub struct Iter<E: Expression> {
11    expr: E,
12    inner_index: usize,
13    inner_limit: usize,
14    outer_index: <E::Shape as Shape>::Dims<usize>,
15    outer_limit: <E::Shape as Shape>::Dims<usize>,
16}
17
18impl<E: Expression> Iter<E> {
19    pub(crate) fn new(expr: E) -> Self {
20        let outer_rank = expr.rank().saturating_sub(expr.inner_rank());
21
22        let inner_index = 0;
23        let inner_limit = expr.shape().with_dims(|dims| dims[outer_rank..].iter().product());
24
25        let mut outer_index = Default::default();
26        let mut outer_limit = Default::default();
27
28        if outer_rank > 0 {
29            outer_index = Dims::new(expr.rank());
30            outer_limit =
31                expr.shape().with_dims(|dims| TryFrom::try_from(dims).expect("invalid rank"));
32        }
33
34        Self { expr, inner_index, inner_limit, outer_index, outer_limit }
35    }
36
37    unsafe fn step_outer(&mut self) -> bool {
38        let outer_rank = self.expr.rank().saturating_sub(self.expr.inner_rank());
39
40        unsafe {
41            // If the inner rank is >0, reset the last dimension when stepping outer dimensions.
42            // This is needed in the FromFn implementation.
43            if outer_rank < self.expr.rank() {
44                self.expr.reset_dim(self.expr.rank() - 1, 0);
45            }
46
47            for i in (0..outer_rank).rev() {
48                if self.outer_index.as_ref()[i] + 1 < self.outer_limit.as_ref()[i] {
49                    self.expr.step_dim(i);
50                    self.outer_index.as_mut()[i] += 1;
51
52                    return true;
53                }
54
55                self.expr.reset_dim(i, self.outer_index.as_ref()[i]);
56                self.outer_index.as_mut()[i] = 0;
57            }
58        }
59
60        self.outer_index.as_mut().fill(0); // Ensure that following calls return false.
61
62        false
63    }
64}
65
66impl<E: Expression + Debug> Debug for Iter<E> {
67    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
68        assert!(self.inner_index == 0, "iterator in use");
69
70        f.debug_tuple("Iter").field(&self.expr).finish()
71    }
72}
73
74impl<E: Expression> ExactSizeIterator for Iter<E> {}
75impl<E: Expression> FusedIterator for Iter<E> {}
76
77impl<E: Expression> Iterator for Iter<E> {
78    type Item = E::Item;
79
80    fn fold<T, F: FnMut(T, Self::Item) -> T>(mut self, init: T, mut f: F) -> T {
81        let mut accum = init;
82
83        loop {
84            for i in self.inner_index..self.inner_limit {
85                accum = f(accum, unsafe { self.expr.get_unchecked(i) });
86            }
87
88            if unsafe { !self.step_outer() } {
89                return accum;
90            }
91
92            self.inner_index = 0;
93        }
94    }
95
96    fn next(&mut self) -> Option<Self::Item> {
97        if self.inner_index == self.inner_limit {
98            if unsafe { !self.step_outer() } {
99                return None;
100            }
101
102            self.inner_index = 0;
103        }
104
105        self.inner_index += 1;
106
107        unsafe { Some(self.expr.get_unchecked(self.inner_index - 1)) }
108    }
109
110    fn size_hint(&self) -> (usize, Option<usize>) {
111        let outer_rank = self.expr.rank().saturating_sub(self.expr.inner_rank());
112        let mut len = 1;
113
114        for i in 0..outer_rank {
115            len = len * self.outer_limit.as_ref()[i] - self.outer_index.as_ref()[i];
116        }
117
118        len = len * self.inner_limit - self.inner_index;
119
120        (len, Some(len))
121    }
122}