acme_tensor/actions/iter/
axis.rs

1/*
2    Appellation: axis <mod>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::index::{Ix, Ixs};
6use crate::shape::{Axis, Layout};
7use crate::TensorBase;
8use core::ptr;
9
10pub struct AxisIter<'a, A> {
11    index: Ix,
12    end: Ix,
13    stride: Ixs,
14    inner_layout: Layout,
15    ptr: *const A,
16    tensor: TensorBase<&'a A>,
17}
18
19impl<'a, A> AxisIter<'a, A> {
20    pub fn new(v: TensorBase<&'a A>, axis: Axis) -> Self {
21        let stride = v.strides()[axis] as isize;
22        let end = v.shape()[axis];
23        Self {
24            index: 0,
25            end,
26            stride,
27            inner_layout: v.layout().remove_axis(axis),
28            ptr: unsafe { *v.as_ptr() },
29            tensor: v,
30        }
31    }
32}
33
34impl<'a, A> Iterator for AxisIter<'a, A> {
35    type Item = &'a A;
36
37    fn next(&mut self) -> Option<Self::Item> {
38        if self.index == self.end {
39            return None;
40        }
41        let ptr = unsafe { self.ptr.add(self.index) };
42        self.index += self.stride as Ix;
43        unsafe { ptr.as_ref() }
44    }
45}