redstone_ml/ndarray/iterator/
tensor_iterator.rs

1use crate::dtype::RawDataType;
2use crate::iterator::util::split_by_indices;
3use crate::util::haslength::HasLength;
4use crate::{NdArray, Reshape};
5
6
7#[non_exhaustive]
8pub struct NdIterator<'a, T: RawDataType> {
9    result: NdArray<'a, T>,
10
11    shape: Vec<usize>,
12    stride: Vec<usize>,
13
14    indices: Vec<usize>, // current index along each dimension
15    iterator_index: usize,
16    size: usize,
17}
18
19impl<T: RawDataType> NdArray<'_, T> {
20    pub(crate) unsafe fn offset_ptr(&mut self, offset: isize) {
21        self.ptr = self.ptr.offset(offset);
22    }
23}
24
25impl<'a, T: RawDataType> NdIterator<'a, T> {
26    pub(super) fn from<I>(tensor: &'a NdArray<'a, T>, axes: I) -> Self
27    where
28        I: IntoIterator<Item=usize> + HasLength + Clone,
29    {
30        let ndims = axes.len();
31        let (shape, output_shape) = split_by_indices(&tensor.shape, axes.clone());
32        let (stride, output_stride) = split_by_indices(&tensor.stride, axes);
33        let size = shape.iter().product();
34
35        Self {
36            result: unsafe { tensor.reshaped_view(output_shape, output_stride) },
37            shape,
38            stride,
39            indices: vec![0; ndims],
40            iterator_index: 0,
41            size,
42        }
43    }
44}
45
46impl<'a, T: RawDataType> Iterator for NdIterator<'a, T> {
47    type Item = NdArray<'a, T>;
48
49    fn next(&mut self) -> Option<Self::Item> {
50        if self.iterator_index == self.size {
51            return None;
52        }
53
54        let return_value = self.result.clone();  // TODO this shouldn't be cloned!
55        self.iterator_index += 1;
56        
57        for i in (0..self.shape.len()).rev() {
58            if self.indices[i] != self.shape[i] {
59                self.indices[i] += 1;
60                unsafe { self.result.offset_ptr(self.stride[i] as isize); }
61                break;
62            }
63        
64            unsafe { self.result.offset_ptr(-((self.stride[i] * (self.shape[i] - 1)) as isize)); }
65            self.indices[i] = 0;
66        }
67
68        Some(return_value)
69    }
70}