redstone_ml/ndarray/iterator/
buffer_iterator.rs1use crate::dtype::RawDataType;
2use crate::flat_index_generator::FlatIndexGenerator;
3use crate::NdArray;
4
5
6pub struct BufferIterator<T: RawDataType> {
7 ptr: *mut T,
8 indices: FlatIndexGenerator,
9}
10
11impl<T: RawDataType> BufferIterator<T> {
12 pub(crate) fn from(tensor: &NdArray<T>) -> Self {
13 Self {
14 ptr: unsafe { tensor.mut_ptr() },
15 indices: FlatIndexGenerator::from(&tensor.shape, &tensor.stride),
16 }
17 }
18
19 pub(crate) unsafe fn from_reshaped_view(tensor: &NdArray<T>, shape: &[usize], stride: &[usize]) -> Self {
20 Self {
21 ptr: tensor.mut_ptr(),
22 indices: FlatIndexGenerator::from(shape, stride),
23 }
24 }
25
26 #[inline(always)]
27 fn advance_by(&mut self, n: usize) {
28 self.indices.advance_by(n);
29 }
30}
31
32impl<T: RawDataType> Iterator for BufferIterator<T> {
33 type Item = *mut T;
34
35 fn next(&mut self) -> Option<Self::Item> {
36 match self.indices.next() {
37 None => None,
38 Some(i) => Some(unsafe { self.ptr.add(i) })
39 }
40 }
41}