// redstone_ml/ndarray/iterator/tensor_iterator.rs
use crate::dtype::RawDataType;
2use crate::iterator::util::split_by_indices;
3use crate::util::haslength::HasLength;
4use crate::{NdArray, Reshape};
5
6
7#[non_exhaustive]
8pub struct NdIterator<'a, T: RawDataType> {
9 result: NdArray<'a, T>,
10
11 shape: Vec<usize>,
12 stride: Vec<usize>,
13
14 indices: Vec<usize>, iterator_index: usize,
16 size: usize,
17}
18
19impl<T: RawDataType> NdArray<'_, T> {
20 pub(crate) unsafe fn offset_ptr(&mut self, offset: isize) {
21 self.ptr = self.ptr.offset(offset);
22 }
23}
24
25impl<'a, T: RawDataType> NdIterator<'a, T> {
26 pub(super) fn from<I>(tensor: &'a NdArray<'a, T>, axes: I) -> Self
27 where
28 I: IntoIterator<Item=usize> + HasLength + Clone,
29 {
30 let ndims = axes.len();
31 let (shape, output_shape) = split_by_indices(&tensor.shape, axes.clone());
32 let (stride, output_stride) = split_by_indices(&tensor.stride, axes);
33 let size = shape.iter().product();
34
35 Self {
36 result: unsafe { tensor.reshaped_view(output_shape, output_stride) },
37 shape,
38 stride,
39 indices: vec![0; ndims],
40 iterator_index: 0,
41 size,
42 }
43 }
44}
45
46impl<'a, T: RawDataType> Iterator for NdIterator<'a, T> {
47 type Item = NdArray<'a, T>;
48
49 fn next(&mut self) -> Option<Self::Item> {
50 if self.iterator_index == self.size {
51 return None;
52 }
53
54 let return_value = self.result.clone(); self.iterator_index += 1;
56
57 for i in (0..self.shape.len()).rev() {
58 if self.indices[i] != self.shape[i] {
59 self.indices[i] += 1;
60 unsafe { self.result.offset_ptr(self.stride[i] as isize); }
61 break;
62 }
63
64 unsafe { self.result.offset_ptr(-((self.stride[i] * (self.shape[i] - 1)) as isize)); }
65 self.indices[i] = 0;
66 }
67
68 Some(return_value)
69 }
70}