1use std::iter::FusedIterator;
2
3use crate::Layout;
4
5pub const MAX_DIMS: usize = 8;
7
8pub struct NdIter<const N: usize> {
31 pub inner_size: usize,
32 pub inner_strides: [usize; N],
34
35 outer_dims: [usize; MAX_DIMS],
36 outer_strides: [[usize; MAX_DIMS]; N],
37 outer_len: usize,
38
39 offsets: [usize; N],
40 coords: [usize; MAX_DIMS],
41 remaining: usize,
42}
43
44impl<const N: usize> NdIter<N> {
45 pub fn new(layouts: [&Layout; N]) -> NdIter<N> {
46 let dims = layouts[0].dims();
47 debug_assert!(
48 dims.len() <= MAX_DIMS,
49 "rank {} exceeds MAX_DIMS={}",
50 dims.len(),
51 MAX_DIMS
52 );
53 #[cfg(debug_assertions)]
54 for l in &layouts {
55 debug_assert_eq!(l.dims(), dims);
56 }
57
58 let rank = dims.len();
59 let mut out_dims = [0usize; MAX_DIMS];
60 let mut out_strides = [[0usize; MAX_DIMS]; N];
61 let mut out_len;
62
63 if rank == 0 {
64 out_dims[0] = 1;
65 out_len = 1;
66 } else {
67 out_dims[0] = dims[0];
68 for n in 0..N {
69 out_strides[n][0] = layouts[n].stride()[0];
70 }
71 out_len = 1;
72
73 for (i, d) in dims.iter().enumerate().take(rank).skip(1) {
74 let top = out_len - 1;
75 let last_d = out_dims[top];
76
77 let (can_merge, use_inner) = if last_d == 1 {
78 (true, true)
79 } else if *d == 1 {
80 (true, false)
81 } else {
82 let can_merge =
83 (0..N).all(|n| out_strides[n][top] == layouts[n].stride()[i] * d);
84 (can_merge, true)
85 };
86 if can_merge {
87 out_dims[top] = last_d * d;
88 if use_inner {
89 for n in 0..N {
90 out_strides[n][top] = layouts[n].stride()[i];
91 }
92 }
93 } else {
94 out_dims[out_len] = *d;
95 for n in 0..N {
96 out_strides[n][out_len] = layouts[n].stride()[i];
97 }
98 out_len += 1;
99 }
100 }
101 }
102
103 let inner_idx = out_len - 1;
105 let inner_size = out_dims[inner_idx];
106 let mut inner_strides = [0usize; N];
107 for n in 0..N {
108 inner_strides[n] = out_strides[n][inner_idx];
109 }
110
111 let outer_len = inner_idx;
113 let mut outer_dims = [0usize; MAX_DIMS];
114 outer_dims[..outer_len].copy_from_slice(&out_dims[..outer_len]);
115
116 let mut outer_strides = [[0usize; MAX_DIMS]; N];
118 for n in 0..N {
119 outer_strides[n][..outer_len].copy_from_slice(&out_strides[n][..outer_len]);
120 }
121
122 let mut offsets = [0usize; N];
124 for n in 0..N {
125 offsets[n] = layouts[n].start_offset();
126 }
127
128 let remaining = out_dims[..outer_len].iter().product();
130
131 NdIter {
132 inner_size,
133 inner_strides,
134 outer_dims,
135 outer_strides,
136 outer_len,
137 offsets,
138 coords: [0; MAX_DIMS],
139 remaining,
140 }
141 }
142}
143
144impl<const N: usize> Iterator for NdIter<N> {
145 type Item = [usize; N];
146
147 #[inline]
148 fn next(&mut self) -> Option<Self::Item> {
149 if self.remaining == 0 {
150 return None;
151 }
152 let item = self.offsets;
153 self.remaining -= 1;
154
155 for k in (0..self.outer_len).rev() {
156 self.coords[k] += 1;
157 for n in 0..N {
158 self.offsets[n] += self.outer_strides[n][k];
159 }
160 if self.coords[k] < self.outer_dims[k] {
161 break;
162 }
163 self.coords[k] = 0;
164 for n in 0..N {
165 self.offsets[n] -= self.outer_dims[k] * self.outer_strides[n][k];
166 }
167 }
168
169 Some(item)
170 }
171
172 #[inline]
173 fn size_hint(&self) -> (usize, Option<usize>) {
174 (self.remaining, Some(self.remaining))
175 }
176}
177
178impl<const N: usize> ExactSizeIterator for NdIter<N> {}
179impl<const N: usize> FusedIterator for NdIter<N> {}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape::Shape;

    /// Builds a zero-offset layout with explicit (possibly non-contiguous)
    /// strides.
    fn layout(dims: &[usize], strides: &[usize]) -> Layout {
        Layout::new(Shape::from(dims.to_vec()), strides.to_vec(), 0)
    }

    #[test]
    fn rank0_scalar() {
        let l = Layout::contiguous(());
        let mut it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 1);
        assert_eq!(it.inner_strides, [0, 0]);
        assert_eq!(it.len(), 1);
        assert_eq!(it.next(), Some([0, 0]));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn rank1_contiguous_single_block() {
        let l = Layout::contiguous(&[5]);
        let mut it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 5);
        assert_eq!(it.inner_strides, [1, 1]);
        assert_eq!(it.len(), 1);
        assert_eq!(it.next(), Some([0, 0]));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn rank2_contiguous_merges_to_one_block() {
        let l = Layout::contiguous(&[3, 4]);
        let it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 12);
        assert_eq!(it.inner_strides, [1, 1]);
        assert_eq!(it.len(), 1);
    }

    #[test]
    fn rank3_contiguous_fully_merged() {
        let l = Layout::contiguous(&[2, 3, 4]);
        let it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 24);
        assert_eq!(it.inner_strides, [1]);
        assert_eq!(it.len(), 1);
    }

    #[test]
    fn rank3_outer_gap_partial_merge() {
        // The two inner axes are contiguous (4 == 1 * 4) but the outer stride
        // (24) leaves a gap, so only the inner pair merges.
        let l = layout(&[2, 3, 4], &[24, 4, 1]);
        let it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 12);
        assert_eq!(it.inner_strides, [1]);
        assert_eq!(it.len(), 2);
        let offsets: Vec<_> = it.collect();
        assert_eq!(offsets, vec![[0], [24]]);
    }

    #[test]
    fn rank2_no_merge() {
        let l = layout(&[3, 4], &[1, 3]);
        let it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 4);
        assert_eq!(it.inner_strides, [3, 3]);
        assert_eq!(it.len(), 3);
        let offsets: Vec<_> = it.collect();
        assert_eq!(offsets, vec![[0, 0], [1, 1], [2, 2]]);
    }

    #[test]
    fn broadcast_zeros_merge() {
        let l = layout(&[3, 4], &[0, 0]);
        let it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 12);
        assert_eq!(it.inner_strides, [0, 0]);
        assert_eq!(it.len(), 1);
    }

    #[test]
    fn mixed_contiguous_and_broadcast_merge() {
        let lhs = Layout::contiguous(&[3, 4]);
        let rhs = layout(&[3, 4], &[0, 0]);
        let it = NdIter::new([&lhs, &rhs]);
        assert_eq!(it.inner_size, 12);
        assert_eq!(it.inner_strides, [1, 0]);
        assert_eq!(it.len(), 1);
    }

    #[test]
    fn offsets_lhs_contiguous_rhs_strided() {
        let lhs = Layout::contiguous(&[2, 3]);
        let rhs = layout(&[2, 3], &[1, 2]);
        let it = NdIter::new([&lhs, &rhs]);
        assert_eq!(it.inner_size, 3);
        assert_eq!(it.inner_strides, [1, 2]);
        assert_eq!(it.len(), 2);
        let offsets: Vec<_> = it.collect();
        assert_eq!(offsets, vec![[0, 0], [3, 1]]);
    }

    #[test]
    fn start_offset_reflected_in_first_iter() {
        let l = Layout::contiguous_with_offset(4, 7);
        let mut it = NdIter::new([&l]);
        assert_eq!(it.next(), Some([7]));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn start_offset_advances_with_outer_dims() {
        let l = Layout::new(Shape::from(vec![2, 3]), vec![4, 1], 10);
        let offsets: Vec<_> = NdIter::new([&l]).collect();
        assert_eq!(offsets, vec![[10], [14]]);
    }

    #[test]
    fn size_one_middle_dim_keeps_outer_stride() {
        // The size-one axis folds into the outer axis regardless of its own
        // stride (7); the outer stride (3) then still matches 1 * 3, so the
        // last axis merges as well.
        let l = layout(&[2, 1, 3], &[3, 7, 1]);
        let mut it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 6);
        assert_eq!(it.inner_strides, [1]);
        assert_eq!(it.len(), 1);
        assert_eq!(it.next(), Some([0]));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn leading_size_one_dim_adopts_inner_stride() {
        // A leading size-one axis is absorbed and its stride (9) discarded in
        // favour of the following axis' stride (2).
        let l = layout(&[1, 4], &[9, 2]);
        let mut it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 4);
        assert_eq!(it.inner_strides, [2]);
        assert_eq!(it.len(), 1);
        assert_eq!(it.next(), Some([0]));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn multiple_outer_dims_carry() {
        // No pair of axes merges (100 != 10 * 3, 10 != 1 * 4), leaving two
        // outer axes; the inner outer axis must wrap and carry into the outer.
        let l = layout(&[2, 3, 4], &[100, 10, 1]);
        let it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 4);
        assert_eq!(it.inner_strides, [1]);
        assert_eq!(it.len(), 6);
        let offsets: Vec<_> = it.collect();
        assert_eq!(offsets, vec![[0], [10], [20], [100], [110], [120]]);
    }

    #[test]
    fn exhausted_iterator_stays_exhausted() {
        let l = layout(&[3, 4], &[1, 3]);
        let mut it = NdIter::new([&l]);
        assert_eq!(it.len(), 3);
        assert_eq!(it.next(), Some([0]));
        assert_eq!(it.len(), 2);
        assert_eq!(it.next(), Some([1]));
        assert_eq!(it.next(), Some([2]));
        assert_eq!(it.len(), 0);
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
        assert_eq!(it.size_hint(), (0, Some(0)));
    }
}