1use core::fmt::{Debug, Formatter, Result};
2use core::iter::FusedIterator;
3
4use crate::dim::Dims;
5use crate::expr::expression::Expression;
6use crate::shape::Shape;
7
8#[derive(Clone)]
10pub struct Iter<E: Expression> {
11 expr: E,
12 inner_index: usize,
13 inner_limit: usize,
14 outer_index: <E::Shape as Shape>::Dims<usize>,
15 outer_limit: <E::Shape as Shape>::Dims<usize>,
16}
17
18impl<E: Expression> Iter<E> {
19 pub(crate) fn new(expr: E) -> Self {
20 let outer_rank = expr.rank().saturating_sub(expr.inner_rank());
21
22 let inner_index = 0;
23 let inner_limit = expr.shape().with_dims(|dims| dims[outer_rank..].iter().product());
24
25 let mut outer_index = Default::default();
26 let mut outer_limit = Default::default();
27
28 if outer_rank > 0 {
29 outer_index = Dims::new(expr.rank());
30 outer_limit =
31 expr.shape().with_dims(|dims| TryFrom::try_from(dims).expect("invalid rank"));
32 }
33
34 Self { expr, inner_index, inner_limit, outer_index, outer_limit }
35 }
36
37 unsafe fn step_outer(&mut self) -> bool {
38 let outer_rank = self.expr.rank().saturating_sub(self.expr.inner_rank());
39
40 unsafe {
41 if outer_rank < self.expr.rank() {
44 self.expr.reset_dim(self.expr.rank() - 1, 0);
45 }
46
47 for i in (0..outer_rank).rev() {
48 if self.outer_index.as_ref()[i] + 1 < self.outer_limit.as_ref()[i] {
49 self.expr.step_dim(i);
50 self.outer_index.as_mut()[i] += 1;
51
52 return true;
53 }
54
55 self.expr.reset_dim(i, self.outer_index.as_ref()[i]);
56 self.outer_index.as_mut()[i] = 0;
57 }
58 }
59
60 self.outer_index.as_mut().fill(0); false
63 }
64}
65
66impl<E: Expression + Debug> Debug for Iter<E> {
67 fn fmt(&self, f: &mut Formatter<'_>) -> Result {
68 assert!(self.inner_index == 0, "iterator in use");
69
70 f.debug_tuple("Iter").field(&self.expr).finish()
71 }
72}
73
74impl<E: Expression> ExactSizeIterator for Iter<E> {}
75impl<E: Expression> FusedIterator for Iter<E> {}
76
77impl<E: Expression> Iterator for Iter<E> {
78 type Item = E::Item;
79
80 fn fold<T, F: FnMut(T, Self::Item) -> T>(mut self, init: T, mut f: F) -> T {
81 let mut accum = init;
82
83 loop {
84 for i in self.inner_index..self.inner_limit {
85 accum = f(accum, unsafe { self.expr.get_unchecked(i) });
86 }
87
88 if unsafe { !self.step_outer() } {
89 return accum;
90 }
91
92 self.inner_index = 0;
93 }
94 }
95
96 fn next(&mut self) -> Option<Self::Item> {
97 if self.inner_index == self.inner_limit {
98 if unsafe { !self.step_outer() } {
99 return None;
100 }
101
102 self.inner_index = 0;
103 }
104
105 self.inner_index += 1;
106
107 unsafe { Some(self.expr.get_unchecked(self.inner_index - 1)) }
108 }
109
110 fn size_hint(&self) -> (usize, Option<usize>) {
111 let outer_rank = self.expr.rank().saturating_sub(self.expr.inner_rank());
112 let mut len = 1;
113
114 for i in 0..outer_rank {
115 len = len * self.outer_limit.as_ref()[i] - self.outer_index.as_ref()[i];
116 }
117
118 len = len * self.inner_limit - self.inner_index;
119
120 (len, Some(len))
121 }
122}