1use std::iter::FusedIterator;
2use std::ops::Range;
3
4use smallvec::{SmallVec, smallvec};
5
6pub trait IndexArray: AsMut<[usize]> + AsRef<[usize]> + Clone {}
7impl<const N: usize> IndexArray for SmallVec<[usize; N]> {}
8impl<const N: usize> IndexArray for [usize; N] {}
9
10pub type DynIndex = SmallVec<[usize; 5]>;
12
13pub struct Indices<Index: IndexArray>
19where
20 Index: IndexArray,
21{
22 start: Index,
24
25 end: Index,
27
28 next: Option<Index>,
29
30 steps: usize,
32}
33
34fn steps(from: &[usize], to: &[usize]) -> usize {
40 assert!(from.len() == to.len());
41 let mut product = 1;
42 for (&from, &to) in from.iter().zip(to.iter()).rev() {
43 let size = to.saturating_sub(from);
44 product *= size;
45 }
46 product
47}
48
49impl<Index: IndexArray> Indices<Index> {
50 fn from_start_and_end(start: Index, end: Index) -> Indices<Index> {
51 let steps = steps(start.as_ref(), end.as_ref());
52 Indices {
53 next: if steps > 0 || start.as_ref().is_empty() {
56 Some(start.clone())
57 } else {
58 None
59 },
60 start,
61 end,
62 steps,
63 }
64 }
65}
66
67impl<const N: usize> Indices<SmallVec<[usize; N]>> {
68 pub fn from_ranges(ranges: &[Range<usize>]) -> Indices<SmallVec<[usize; N]>> {
71 let start: SmallVec<[usize; N]> = ranges.iter().map(|r| r.start).collect();
72 let end = ranges.iter().map(|r| r.end).collect();
73 Self::from_start_and_end(start, end)
74 }
75
76 pub fn from_shape(shape: &[usize]) -> Indices<SmallVec<[usize; N]>> {
79 let start = smallvec![0; shape.len()];
80 let end = shape.iter().copied().collect();
81 Self::from_start_and_end(start, end)
82 }
83}
84
85impl<const N: usize> Indices<[usize; N]> {
86 pub fn from_ranges(ranges: [Range<usize>; N]) -> Indices<[usize; N]> {
89 let start = ranges.clone().map(|r| r.start);
90 let end = ranges.map(|r| r.end);
91 Self::from_start_and_end(start, end)
92 }
93
94 pub fn from_shape(shape: [usize; N]) -> Indices<[usize; N]> {
97 Self::from_ranges(shape.map(|size| 0..size))
98 }
99}
100
101impl<Index: IndexArray> Iterator for Indices<Index> {
102 type Item = Index;
103
104 fn next(&mut self) -> Option<Self::Item> {
107 let current = self.next.clone()?;
108
109 let mut next = current.clone();
110 let mut has_next = false;
111 for ((&dim_end, &dim_start), index) in self
112 .end
113 .as_ref()
114 .iter()
115 .zip(self.start.as_ref())
116 .zip(next.as_mut().iter_mut())
117 .rev()
118 {
119 *index += 1;
120 if *index == dim_end {
121 *index = dim_start;
122 } else {
123 has_next = true;
124 break;
125 }
126 }
127
128 self.next = has_next.then_some(next);
129
130 Some(current)
131 }
132
133 #[inline]
134 fn size_hint(&self) -> (usize, Option<usize>) {
135 (self.steps, Some(self.steps))
136 }
137}
138
139impl<Index: IndexArray> ExactSizeIterator for Indices<Index> {}
140
141impl<Index: IndexArray> FusedIterator for Indices<Index> {}
142
143pub struct NdIndices<const N: usize> {
146 inner: Indices<[usize; N]>,
147}
148
149impl<const N: usize> NdIndices<N> {
150 pub fn from_ranges(ranges: [Range<usize>; N]) -> NdIndices<N> {
151 NdIndices {
152 inner: Indices::<[usize; N]>::from_ranges(ranges),
153 }
154 }
155
156 pub fn from_shape(shape: [usize; N]) -> NdIndices<N> {
157 NdIndices {
158 inner: Indices::<[usize; N]>::from_shape(shape),
159 }
160 }
161}
162
163impl<const N: usize> Iterator for NdIndices<N> {
164 type Item = [usize; N];
165
166 fn next(&mut self) -> Option<Self::Item> {
167 self.inner.next()
168 }
169
170 fn size_hint(&self) -> (usize, Option<usize>) {
171 self.inner.size_hint()
172 }
173}
174
175impl<const N: usize> ExactSizeIterator for NdIndices<N> {}
176impl<const N: usize> FusedIterator for NdIndices<N> {}
177
178const DYN_SMALL_LEN: usize = 4;
181
182enum DynIndicesInner {
183 Small {
184 iter: NdIndices<DYN_SMALL_LEN>,
185 pad: usize,
186 },
187 Large(Indices<DynIndex>),
188}
189
190pub struct DynIndices {
193 inner: DynIndicesInner,
194}
195
/// Left-pads `shape` with size-1 dimensions so that it has exactly `N`
/// dimensions.
///
/// Returns `(pad, padded)`, where `pad = N - shape.len()` is the number of
/// leading dimensions added. Size-1 padding does not change how many indices
/// the shape contains.
///
/// # Panics
///
/// Panics if `shape` has more than `N` dimensions.
fn left_pad_shape<const N: usize>(shape: &[usize]) -> (usize, [usize; N]) {
    assert!(shape.len() <= N, "shape has more than {} dimensions", N);
    let pad = N - shape.len();
    // Start with every dimension at size 1, then copy the real shape into the
    // trailing slots.
    let mut padded_shape = [1; N];
    padded_shape[pad..].copy_from_slice(shape);
    (pad, padded_shape)
}
209
/// Left-pads `ranges` with `0..1` ranges so that there are exactly `N` of
/// them.
///
/// Returns `(pad, padded)`, where `pad = N - ranges.len()` is the number of
/// leading `0..1` ranges added. Each padding range has exactly one element,
/// so padding does not change the number of indices covered.
///
/// # Panics
///
/// Panics if more than `N` ranges are given.
fn left_pad_ranges<const N: usize>(ranges: &[Range<usize>]) -> (usize, [Range<usize>; N]) {
    assert!(ranges.len() <= N, "more than {} ranges given", N);
    let pad = N - ranges.len();
    // Build the array directly. This avoids a temporary `SmallVec` plus
    // `into_inner().unwrap()`, and a second pass that rewrote padding slots
    // already holding `0..1`.
    let padded_ranges = std::array::from_fn(|i| {
        if i < pad {
            0..1
        } else {
            ranges[i - pad].clone()
        }
    });
    (pad, padded_ranges)
}
226
227impl DynIndices {
228 pub fn from_shape(shape: &[usize]) -> DynIndices {
229 let inner = if shape.len() <= DYN_SMALL_LEN {
230 let (pad, padded) = left_pad_shape(shape);
231 DynIndicesInner::Small {
232 iter: NdIndices::from_shape(padded),
233 pad,
234 }
235 } else {
236 DynIndicesInner::Large(Indices::<DynIndex>::from_shape(shape))
237 };
238 DynIndices { inner }
239 }
240
241 pub fn from_ranges(ranges: &[Range<usize>]) -> DynIndices {
242 let inner = if ranges.len() <= DYN_SMALL_LEN {
243 let (pad, padded) = left_pad_ranges(ranges);
244 DynIndicesInner::Small {
245 iter: NdIndices::from_ranges(padded),
246 pad,
247 }
248 } else {
249 DynIndicesInner::Large(Indices::<DynIndex>::from_ranges(ranges))
250 };
251 DynIndices { inner }
252 }
253}
254
255impl Iterator for DynIndices {
256 type Item = DynIndex;
257
258 #[inline]
259 fn next(&mut self) -> Option<Self::Item> {
260 match self.inner {
261 DynIndicesInner::Small { ref mut iter, pad } => {
262 iter.next().map(|idx| SmallVec::from_slice(&idx[pad..]))
263 }
264 DynIndicesInner::Large(ref mut inner) => inner.next(),
265 }
266 }
267
268 fn size_hint(&self) -> (usize, Option<usize>) {
269 match self.inner {
270 DynIndicesInner::Small { ref iter, .. } => iter.size_hint(),
271 DynIndicesInner::Large(ref inner) => inner.size_hint(),
272 }
273 }
274}
275
276impl ExactSizeIterator for DynIndices {}
277impl FusedIterator for DynIndices {}
278
#[cfg(test)]
mod tests {
    use super::{DynIndices, NdIndices};

    /// Drains a `DynIndices` iterator into plain vectors for easy comparison.
    fn collect_dyn(iter: DynIndices) -> Vec<Vec<usize>> {
        iter.map(|ix| ix.to_vec()).collect()
    }

    #[test]
    fn test_nd_indices() {
        // An empty range yields nothing, and keeps yielding nothing.
        let mut empty = NdIndices::from_ranges([0..0]);
        assert_eq!(empty.next(), None);
        assert_eq!(empty.next(), None);

        // Zero dimensions: exactly one (empty) index.
        let mut scalar = NdIndices::from_ranges([]);
        assert_eq!(scalar.next(), Some([]));
        assert_eq!(scalar.next(), None);

        let line: Vec<[usize; 1]> = NdIndices::from_ranges([0..5]).collect();
        assert_eq!(line, &[[0], [1], [2], [3], [4]]);

        let square: Vec<[usize; 2]> = NdIndices::from_ranges([2..4, 2..4]).collect();
        assert_eq!(square, &[[2, 2], [2, 3], [3, 2], [3, 3]]);
    }

    #[test]
    fn test_dyn_indices() {
        type Index = <DynIndices as Iterator>::Item;

        // An empty range yields nothing, and keeps yielding nothing.
        let mut empty = DynIndices::from_ranges(&[0..0]);
        assert_eq!(empty.next(), None);
        assert_eq!(empty.next(), None);

        // Zero dimensions: exactly one (empty) index.
        let mut scalar = DynIndices::from_ranges(&[]);
        assert_eq!(scalar.next(), Some(Index::new()));
        assert_eq!(scalar.next(), None);

        assert_eq!(
            collect_dyn(DynIndices::from_ranges(&[0..5])),
            vec![vec![0], vec![1], vec![2], vec![3], vec![4]]
        );

        assert_eq!(
            collect_dyn(DynIndices::from_ranges(&[2..4, 2..4])),
            vec![vec![2, 2], vec![2, 3], vec![3, 2], vec![3, 3]]
        );

        // Rank 5 exceeds the fixed-size fast path.
        assert_eq!(
            collect_dyn(DynIndices::from_shape(&[2, 1, 1, 2, 2])),
            vec![
                vec![0, 0, 0, 0, 0],
                vec![0, 0, 0, 0, 1],
                vec![0, 0, 0, 1, 0],
                vec![0, 0, 0, 1, 1],
                vec![1, 0, 0, 0, 0],
                vec![1, 0, 0, 0, 1],
                vec![1, 0, 0, 1, 0],
                vec![1, 0, 0, 1, 1],
            ]
        );
    }

    #[test]
    #[ignore]
    fn bench_indices() {
        use std::time::Instant;

        let shape = std::hint::black_box([16, 128, 128]);

        let dyn_timer = Instant::now();
        let mut dyn_steps = 0;
        for _ in 0..100 {
            for _ in DynIndices::from_shape(&shape) {
                dyn_steps += 1;
            }
        }
        let dyn_ms = dyn_timer.elapsed().as_millis();
        println!("DynIndices stepped {} times in {} ms", dyn_steps, dyn_ms);

        let nd_timer = Instant::now();
        let mut nd_steps = 0;
        for _ in 0..100 {
            for _ in NdIndices::from_shape(shape) {
                nd_steps += 1;
            }
        }
        let nd_ms = nd_timer.elapsed().as_millis();
        println!("NdIndices stepped {} times in {} ms", nd_steps, nd_ms);
    }
}