rten_tensor/
index_iterator.rs

1use std::iter::FusedIterator;
2use std::ops::Range;
3
4use smallvec::{SmallVec, smallvec};
5
6pub trait IndexArray: AsMut<[usize]> + AsRef<[usize]> + Clone {}
7impl<const N: usize> IndexArray for SmallVec<[usize; N]> {}
8impl<const N: usize> IndexArray for [usize; N] {}
9
10/// The index type used for dynamic-rank tensors.
11pub type DynIndex = SmallVec<[usize; 5]>;
12
13/// Iterator over a range of N-dimensional indices, where N may be known at
14/// compile time (see [`NdIndices`]) or only at runtime ([`DynIndices`]).
15///
16/// The number of dimensions may be zero, in which case the iterator will yield
17/// a single empty index. This is consistent with eg. `ndindex` in NumPy.
18pub struct Indices<Index: IndexArray>
19where
20    Index: IndexArray,
21{
22    /// Start index along each dimension.
23    start: Index,
24
25    /// End index (exclusive) along each dimension.
26    end: Index,
27
28    next: Option<Index>,
29
30    /// Remaining iteration steps.
31    steps: usize,
32}
33
/// Return the number of steps for an index iterator over the range of indices
/// from `from` to `to`.
///
/// The result is the product of the per-dimension extents `to[i] - from[i]`.
/// For zero dimensions this is the empty product, ie. one.
///
/// If any index in `from` is greater than or equal to the corresponding index
/// in `to`, the range is empty and this returns zero, regardless of how large
/// the other dimensions are. If the product does not fit in a `usize` the
/// result saturates at `usize::MAX` instead of overflowing.
///
/// # Panics
///
/// Panics if `from` and `to` have different lengths.
fn steps(from: &[usize], to: &[usize]) -> usize {
    assert!(from.len() == to.len());
    let mut product: usize = 1;
    for (&lo, &hi) in from.iter().zip(to) {
        let extent = hi.saturating_sub(lo);
        if extent == 0 {
            // An empty dimension makes the whole range empty. Returning here
            // also prevents overflow in the remaining dimensions from
            // affecting the result.
            return 0;
        }
        product = product.saturating_mul(extent);
    }
    product
}
48
49impl<Index: IndexArray> Indices<Index> {
50    fn from_start_and_end(start: Index, end: Index) -> Indices<Index> {
51        let steps = steps(start.as_ref(), end.as_ref());
52        Indices {
53            // Note that if the index is empty, `start == end` but the iterator
54            // should yield a single empty element in that case.
55            next: if steps > 0 || start.as_ref().is_empty() {
56                Some(start.clone())
57            } else {
58                None
59            },
60            start,
61            end,
62            steps,
63        }
64    }
65}
66
67impl<const N: usize> Indices<SmallVec<[usize; N]>> {
68    /// Return an iterator over all the indices where each dimension lies
69    /// within the corresponding range in `ranges`.
70    pub fn from_ranges(ranges: &[Range<usize>]) -> Indices<SmallVec<[usize; N]>> {
71        let start: SmallVec<[usize; N]> = ranges.iter().map(|r| r.start).collect();
72        let end = ranges.iter().map(|r| r.end).collect();
73        Self::from_start_and_end(start, end)
74    }
75
76    /// Return an iterator over all the indices where each dimension is between
77    /// `0` and `shape[dim]`.
78    pub fn from_shape(shape: &[usize]) -> Indices<SmallVec<[usize; N]>> {
79        let start = smallvec![0; shape.len()];
80        let end = shape.iter().copied().collect();
81        Self::from_start_and_end(start, end)
82    }
83}
84
85impl<const N: usize> Indices<[usize; N]> {
86    /// Return an iterator over all the indices where each dimension lies
87    /// within the corresponding range in `ranges`.
88    pub fn from_ranges(ranges: [Range<usize>; N]) -> Indices<[usize; N]> {
89        let start = ranges.clone().map(|r| r.start);
90        let end = ranges.map(|r| r.end);
91        Self::from_start_and_end(start, end)
92    }
93
94    /// Return an iterator over all the indices where each dimension is between
95    /// `0` and `shape[dim]`.
96    pub fn from_shape(shape: [usize; N]) -> Indices<[usize; N]> {
97        Self::from_ranges(shape.map(|size| 0..size))
98    }
99}
100
101impl<Index: IndexArray> Iterator for Indices<Index> {
102    type Item = Index;
103
104    /// Return the next index in the sequence, or `None` after all indices
105    /// have been returned.
106    fn next(&mut self) -> Option<Self::Item> {
107        let current = self.next.clone()?;
108
109        let mut next = current.clone();
110        let mut has_next = false;
111        for ((&dim_end, &dim_start), index) in self
112            .end
113            .as_ref()
114            .iter()
115            .zip(self.start.as_ref())
116            .zip(next.as_mut().iter_mut())
117            .rev()
118        {
119            *index += 1;
120            if *index == dim_end {
121                *index = dim_start;
122            } else {
123                has_next = true;
124                break;
125            }
126        }
127
128        self.next = has_next.then_some(next);
129
130        Some(current)
131    }
132
133    #[inline]
134    fn size_hint(&self) -> (usize, Option<usize>) {
135        (self.steps, Some(self.steps))
136    }
137}
138
139impl<Index: IndexArray> ExactSizeIterator for Indices<Index> {}
140
141impl<Index: IndexArray> FusedIterator for Indices<Index> {}
142
143/// Iterator over a range of N-dimensional indices, where N is known at compile
144/// time.
145pub struct NdIndices<const N: usize> {
146    inner: Indices<[usize; N]>,
147}
148
149impl<const N: usize> NdIndices<N> {
150    pub fn from_ranges(ranges: [Range<usize>; N]) -> NdIndices<N> {
151        NdIndices {
152            inner: Indices::<[usize; N]>::from_ranges(ranges),
153        }
154    }
155
156    pub fn from_shape(shape: [usize; N]) -> NdIndices<N> {
157        NdIndices {
158            inner: Indices::<[usize; N]>::from_shape(shape),
159        }
160    }
161}
162
163impl<const N: usize> Iterator for NdIndices<N> {
164    type Item = [usize; N];
165
166    fn next(&mut self) -> Option<Self::Item> {
167        self.inner.next()
168    }
169
170    fn size_hint(&self) -> (usize, Option<usize>) {
171        self.inner.size_hint()
172    }
173}
174
175impl<const N: usize> ExactSizeIterator for NdIndices<N> {}
176impl<const N: usize> FusedIterator for NdIndices<N> {}
177
178/// Max tensor rank supported by the variant of [`DynIndices`] that is optimized
179/// for small-rank tensors.
180const DYN_SMALL_LEN: usize = 4;
181
182enum DynIndicesInner {
183    Small {
184        iter: NdIndices<DYN_SMALL_LEN>,
185        pad: usize,
186    },
187    Large(Indices<DynIndex>),
188}
189
190/// Iterator over a range of N-dimensional indices, where N is not known at
191/// compile time.
192pub struct DynIndices {
193    inner: DynIndicesInner,
194}
195
/// Left-pad a shape with 1s to size N (eg. [32, 32] => [1, 1, 32, 32]).
///
/// Returns the number of padding dimensions that were added, together with
/// the padded shape.
///
/// # Panics
///
/// Panics if `shape` has more than `N` dimensions.
fn left_pad_shape<const N: usize>(shape: &[usize]) -> (usize, [usize; N]) {
    assert!(shape.len() <= N);
    let pad = N - shape.len();

    // Start from all-ones, then overwrite the trailing entries with `shape`.
    let mut padded = [1; N];
    padded[pad..].copy_from_slice(shape);
    (pad, padded)
}
209
/// Left-pad ranges with `[0..1]` to size N.
///
/// Returns the number of padding ranges that were added, together with the
/// padded ranges.
///
/// # Panics
///
/// Panics if `ranges` has more than `N` entries.
fn left_pad_ranges<const N: usize>(ranges: &[Range<usize>]) -> (usize, [Range<usize>; N]) {
    assert!(ranges.len() <= N);
    let pad = N - ranges.len();

    // `Range` is not `Copy`, so `[0..1; N]` is not allowed. `array::from_fn`
    // builds each element individually instead.
    let padded: [Range<usize>; N] = std::array::from_fn(|dim| match dim.checked_sub(pad) {
        // Positions at or after `pad` map onto the input ranges.
        Some(src) => ranges[src].clone(),
        // Leading positions are padding.
        None => 0..1,
    });
    (pad, padded)
}
226
227impl DynIndices {
228    pub fn from_shape(shape: &[usize]) -> DynIndices {
229        let inner = if shape.len() <= DYN_SMALL_LEN {
230            let (pad, padded) = left_pad_shape(shape);
231            DynIndicesInner::Small {
232                iter: NdIndices::from_shape(padded),
233                pad,
234            }
235        } else {
236            DynIndicesInner::Large(Indices::<DynIndex>::from_shape(shape))
237        };
238        DynIndices { inner }
239    }
240
241    pub fn from_ranges(ranges: &[Range<usize>]) -> DynIndices {
242        let inner = if ranges.len() <= DYN_SMALL_LEN {
243            let (pad, padded) = left_pad_ranges(ranges);
244            DynIndicesInner::Small {
245                iter: NdIndices::from_ranges(padded),
246                pad,
247            }
248        } else {
249            DynIndicesInner::Large(Indices::<DynIndex>::from_ranges(ranges))
250        };
251        DynIndices { inner }
252    }
253}
254
255impl Iterator for DynIndices {
256    type Item = DynIndex;
257
258    #[inline]
259    fn next(&mut self) -> Option<Self::Item> {
260        match self.inner {
261            DynIndicesInner::Small { ref mut iter, pad } => {
262                iter.next().map(|idx| SmallVec::from_slice(&idx[pad..]))
263            }
264            DynIndicesInner::Large(ref mut inner) => inner.next(),
265        }
266    }
267
268    fn size_hint(&self) -> (usize, Option<usize>) {
269        match self.inner {
270            DynIndicesInner::Small { ref iter, .. } => iter.size_hint(),
271            DynIndicesInner::Large(ref inner) => inner.size_hint(),
272        }
273    }
274}
275
276impl ExactSizeIterator for DynIndices {}
277impl FusedIterator for DynIndices {}
278
#[cfg(test)]
mod tests {
    use super::{DynIndices, NdIndices};

    /// Checks the fixed-rank iterator for empty, scalar, 1D and 2D ranges.
    #[test]
    fn test_nd_indices() {
        // An empty range yields nothing, and keeps yielding nothing.
        let mut empty = NdIndices::from_ranges([0..0]);
        assert_eq!(empty.next(), None);
        assert_eq!(empty.next(), None);

        // A zero-dimensional range yields exactly one empty index.
        let mut scalar = NdIndices::from_ranges([]);
        assert_eq!(scalar.next(), Some([]));
        assert_eq!(scalar.next(), None);

        // One dimension.
        let one_dim: Vec<_> = NdIndices::from_ranges([0..5]).collect();
        assert_eq!(one_dim, &[[0], [1], [2], [3], [4]]);

        // Two dimensions; the last dimension varies fastest.
        let two_dim: Vec<_> = NdIndices::from_ranges([2..4, 2..4]).collect();
        assert_eq!(two_dim, &[[2, 2], [2, 3], [3, 2], [3, 3]]);
    }

    /// Checks the dynamic-rank iterator on both of its internal code paths.
    #[test]
    fn test_dyn_indices() {
        type Index = <DynIndices as Iterator>::Item;

        /// Drain `iter` into plain vectors so results are easy to compare.
        fn collect_indices(iter: DynIndices) -> Vec<Vec<usize>> {
            iter.map(|index| index.to_vec()).collect()
        }

        // An empty range yields nothing, and keeps yielding nothing.
        let mut empty = DynIndices::from_ranges(&[0..0]);
        assert_eq!(empty.next(), None);
        assert_eq!(empty.next(), None);

        // A zero-dimensional range yields exactly one empty index.
        let mut scalar = DynIndices::from_ranges(&[]);
        assert_eq!(scalar.next(), Some(Index::new()));
        assert_eq!(scalar.next(), None);

        // One dimension.
        let one_dim = collect_indices(DynIndices::from_ranges(&[0..5]));
        assert_eq!(one_dim, vec![vec![0], vec![1], vec![2], vec![3], vec![4]]);

        // Two dimensions.
        let two_dim = collect_indices(DynIndices::from_ranges(&[2..4, 2..4]));
        assert_eq!(
            two_dim,
            vec![vec![2, 2], vec![2, 3], vec![3, 2], vec![3, 3]]
        );

        // Five dimensions. This exercises the path for tensors with more
        // than 4 dims.
        let five_dim = collect_indices(DynIndices::from_shape(&[2, 1, 1, 2, 2]));
        assert_eq!(
            five_dim,
            vec![
                vec![0, 0, 0, 0, 0],
                vec![0, 0, 0, 0, 1],
                vec![0, 0, 0, 1, 0],
                vec![0, 0, 0, 1, 1],
                //
                vec![1, 0, 0, 0, 0],
                vec![1, 0, 0, 0, 1],
                vec![1, 0, 0, 1, 0],
                vec![1, 0, 0, 1, 1],
            ]
        );
    }

    /// Ad-hoc benchmark comparing dynamic-rank and static-rank iteration.
    /// Ignored by default.
    #[test]
    #[ignore]
    fn bench_indices() {
        use std::time::Instant;

        // Shape taken from GatherElements usage in
        // https://huggingface.co/microsoft/deberta-v3-large.
        //
        // `black_box` is not necessary for the current implementations, but in
        // an experiment with some less branch-y implementations of NdIndices,
        // Rust was able to precompute the iteration count (!).
        let shape = std::hint::black_box([16, 128, 128]);

        // Dynamic rank
        let dyn_timer = Instant::now();
        let mut dyn_count = 0;
        for _ in 0..100 {
            for _ in DynIndices::from_shape(&shape) {
                dyn_count += 1;
            }
        }
        let dyn_ms = dyn_timer.elapsed().as_millis();
        println!("DynIndices stepped {} times in {} ms", dyn_count, dyn_ms);

        // Same shape, static rank
        let nd_timer = Instant::now();
        let mut nd_count = 0;
        for _ in 0..100 {
            for _ in NdIndices::from_shape(shape) {
                nd_count += 1;
            }
        }
        let nd_ms = nd_timer.elapsed().as_millis();
        println!("NdIndices stepped {} times in {} ms", nd_count, nd_ms);
    }
}