redstone_ml/ndarray/iterator/
flat_index_generator.rs1use crate::iterator::collapse_contiguous::collapse_to_uniform_stride;
2use crate::ndarray::MAX_DIMS;
3
4#[non_exhaustive]
5pub struct FlatIndexGenerator
6{
7 ndims: usize,
8 shape: [usize; MAX_DIMS],
9 stride: [usize; MAX_DIMS],
10
11 size: usize,
12 iterator_index: usize,
13
14 indices: [usize; MAX_DIMS], flat_index: usize,
16}
17
18impl FlatIndexGenerator {
19 pub(crate) fn from(shape: &[usize], stride: &[usize]) -> Self {
20 let (shape, stride) = collapse_to_uniform_stride(shape, stride);
21 let ndims = shape.len();
22 let size = shape.iter().product();
23
24 let mut new_shape = [0; MAX_DIMS];
25 let mut new_stride = [0; MAX_DIMS];
26
27 new_shape[0..ndims].copy_from_slice(&shape);
28 new_stride[0..ndims].copy_from_slice(&stride);
29
30 Self {
31 ndims,
32 shape: new_shape,
33 stride: new_stride,
34 size,
35 iterator_index: 0,
36 indices: [0; MAX_DIMS],
37 flat_index: 0,
38 }
39 }
40
41 #[inline(always)]
42 pub fn advance_by(&mut self, mut n: usize) {
43 let remaining = self.size - self.iterator_index;
44 n = n.min(remaining);
45
46 if n == 0 {
47 return;
48 }
49 self.iterator_index += n;
50
51 let mut carry = n;
52 for i in (0..self.ndims).rev() {
53 let dim = self.shape[i];
54 let idx = &mut self.indices[i];
55
56 let total = *idx + carry;
57 *idx = total % dim;
58 carry = total / dim;
59
60 self.flat_index += self.stride[i] * (*idx - self.indices[i]);
61 }
62 }
63}
64
65impl Iterator for FlatIndexGenerator {
66 type Item = usize;
67
68 #[inline(always)]
69 fn next(&mut self) -> Option<Self::Item> {
70 if self.iterator_index == self.size {
71 return None;
72 }
73
74 let return_index = self.flat_index;
75
76 let mut i = self.ndims;
77 while i > 0 {
78 i -= 1;
79
80 unsafe {
81 let idx = self.indices.get_unchecked_mut(i);
82 *idx += 1;
83
84 if *idx < *self.shape.get_unchecked(i) {
85 self.flat_index += *self.stride.get_unchecked(i);
86 break;
87 }
88
89 *idx = 0; self.flat_index -= *self.stride.get_unchecked(i) * (*self.shape.get_unchecked(i) - 1);
91 }
92 }
93
94 self.iterator_index += 1;
95 Some(return_index)
96 }
97}
98
99impl Clone for FlatIndexGenerator {
100 fn clone(&self) -> Self {
101 Self {
102 ndims: self.ndims,
103 shape: self.shape,
104 stride: self.stride,
105
106 size: self.size,
107 iterator_index: self.iterator_index,
108
109 indices: self.indices,
110 flat_index: self.flat_index,
111 }
112 }
113}