// redstone_ml/ndarray/iterator/buffer_iterator.rs

1use crate::dtype::RawDataType;
2use crate::flat_index_generator::FlatIndexGenerator;
3use crate::NdArray;
4
5
6pub struct BufferIterator<T: RawDataType> {
7    ptr: *mut T,
8    indices: FlatIndexGenerator,
9}
10
11impl<T: RawDataType> BufferIterator<T> {
12    pub(crate) fn from(tensor: &NdArray<T>) -> Self {
13        Self {
14            ptr: unsafe { tensor.mut_ptr() },
15            indices: FlatIndexGenerator::from(&tensor.shape, &tensor.stride),
16        }
17    }
18
19    pub(crate) unsafe fn from_reshaped_view(tensor: &NdArray<T>, shape: &[usize], stride: &[usize]) -> Self {
20        Self {
21            ptr: tensor.mut_ptr(),
22            indices: FlatIndexGenerator::from(shape, stride),
23        }
24    }
25
26    #[inline(always)]
27    fn advance_by(&mut self, n: usize) {
28        self.indices.advance_by(n);
29    }
30}
31
32impl<T: RawDataType> Iterator for BufferIterator<T> {
33    type Item = *mut T;
34
35    fn next(&mut self) -> Option<Self::Item> {
36        match self.indices.next() {
37            None => None,
38            Some(i) => Some(unsafe { self.ptr.add(i) })
39        }
40    }
41}