// redstone_ml/ndarray/iterator/flat_index_generator.rs

1use crate::iterator::collapse_contiguous::collapse_to_uniform_stride;
2use crate::ndarray::MAX_DIMS;
3
4#[non_exhaustive]
5pub struct FlatIndexGenerator
6{
7    ndims: usize,
8    shape: [usize; MAX_DIMS],
9    stride: [usize; MAX_DIMS],
10
11    size: usize,
12    iterator_index: usize,
13
14    indices: [usize; MAX_DIMS], // current index along each dimension
15    flat_index: usize,
16}
17
18impl FlatIndexGenerator {
19    pub(crate) fn from(shape: &[usize], stride: &[usize]) -> Self {
20        let (shape, stride) = collapse_to_uniform_stride(shape, stride);
21        let ndims = shape.len();
22        let size = shape.iter().product();
23
24        let mut new_shape = [0; MAX_DIMS];
25        let mut new_stride = [0; MAX_DIMS];
26
27        new_shape[0..ndims].copy_from_slice(&shape);
28        new_stride[0..ndims].copy_from_slice(&stride);
29
30        Self {
31            ndims,
32            shape: new_shape,
33            stride: new_stride,
34            size,
35            iterator_index: 0,
36            indices: [0; MAX_DIMS],
37            flat_index: 0,
38        }
39    }
40
41    #[inline(always)]
42    pub fn advance_by(&mut self, mut n: usize) {
43        let remaining = self.size - self.iterator_index;
44        n = n.min(remaining);
45
46        if n == 0 {
47            return;
48        }
49        self.iterator_index += n;
50
51        let mut carry = n;
52        for i in (0..self.ndims).rev() {
53            let dim = self.shape[i];
54            let idx = &mut self.indices[i];
55
56            let total = *idx + carry;
57            *idx = total % dim;
58            carry = total / dim;
59
60            self.flat_index += self.stride[i] * (*idx - self.indices[i]);
61        }
62    }
63}
64
65impl Iterator for FlatIndexGenerator {
66    type Item = usize;
67
68    #[inline(always)]
69    fn next(&mut self) -> Option<Self::Item> {
70        if self.iterator_index == self.size {
71            return None;
72        }
73
74        let return_index = self.flat_index;
75
76        let mut i = self.ndims;
77        while i > 0 {
78            i -= 1;
79
80            unsafe {
81                let idx = self.indices.get_unchecked_mut(i);
82                *idx += 1;
83
84                if *idx < *self.shape.get_unchecked(i) {
85                    self.flat_index += *self.stride.get_unchecked(i);
86                    break;
87                }
88
89                *idx = 0; // reset this dimension and carry over to the next
90                self.flat_index -= *self.stride.get_unchecked(i) * (*self.shape.get_unchecked(i) - 1);
91            }
92        }
93
94        self.iterator_index += 1;
95        Some(return_index)
96    }
97}
98
99impl Clone for FlatIndexGenerator {
100    fn clone(&self) -> Self {
101        Self {
102            ndims: self.ndims,
103            shape: self.shape,
104            stride: self.stride,
105
106            size: self.size,
107            iterator_index: self.iterator_index,
108
109            indices: self.indices,
110            flat_index: self.flat_index,
111        }
112    }
113}