acme_tensor/actions/iter/
iterator.rs

1/*
2    Appellation: iterator <mod>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::LayoutIter;
6use crate::TensorBase;
7use core::marker::PhantomData;
8use core::ptr;
9
10/// An immutable iterator of the elements of a [tensor](crate::tensor::TensorBase)
11/// Elements are visited in order, matching the layout of the tensor.
12pub struct Iter<'a, T> {
13    inner: LayoutIter,
14    ptr: *const T,
15    tensor: TensorBase<&'a T>,
16}
17
18impl<'a, T> Iter<'a, T> {
19    pub(crate) fn new(tensor: TensorBase<&'a T>) -> Self {
20        Self {
21            inner: tensor.layout().iter(),
22            ptr: unsafe { *tensor.as_ptr() },
23            tensor,
24        }
25    }
26}
27
28impl<'a, T> Iterator for Iter<'a, T> {
29    type Item = &'a T;
30
31    fn next(&mut self) -> Option<Self::Item> {
32        let pos = self.inner.next()?;
33        let item = self.tensor.get_by_index(pos.index())?;
34        self.ptr = ptr::from_ref(item);
35        unsafe { self.ptr.as_ref() }
36    }
37}
38
39impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
40    fn next_back(&mut self) -> Option<Self::Item> {
41        let pos = self.inner.next_back()?;
42        let item = self.tensor.get_by_index(pos.index())?;
43        self.ptr = ptr::from_ref(item);
44        unsafe { self.ptr.as_ref() }
45    }
46}
47
48impl<'a, T> ExactSizeIterator for Iter<'a, T> {
49    fn len(&self) -> usize {
50        self.tensor.size()
51    }
52}
53
54impl<'a, T> From<TensorBase<&'a T>> for Iter<'a, T> {
55    fn from(tensor: TensorBase<&'a T>) -> Self {
56        Self::new(tensor)
57    }
58}
59
60impl<'a, T> From<&'a TensorBase<T>> for Iter<'a, T> {
61    fn from(tensor: &'a TensorBase<T>) -> Self {
62        Self::new(tensor.view())
63    }
64}
65
66pub struct IterMut<'a, T: 'a> {
67    inner: LayoutIter,
68    ptr: *mut T,
69    tensor: TensorBase<&'a mut T>,
70    _marker: PhantomData<&'a mut T>,
71}
72
73impl<'a, T> IterMut<'a, T> {
74    pub(crate) fn new(tensor: &'a mut TensorBase<T>) -> Self {
75        Self {
76            inner: tensor.layout().iter(),
77            ptr: tensor.as_mut_ptr(),
78            tensor: tensor.view_mut(),
79            _marker: PhantomData,
80        }
81    }
82}
83
84impl<'a, T> Iterator for IterMut<'a, T> {
85    type Item = &'a mut T;
86
87    fn next(&mut self) -> Option<Self::Item> {
88        let pos = self.inner.next()?;
89        let elem = self.tensor.get_mut_by_index(pos.index())?;
90        self.ptr = ptr::from_mut(elem);
91        unsafe { self.ptr.as_mut() }
92    }
93}
94
95impl<'a, T> DoubleEndedIterator for IterMut<'a, T> {
96    fn next_back(&mut self) -> Option<Self::Item> {
97        let pos = self.inner.next_back()?;
98        let elem = self.tensor.get_mut_by_index(pos.index())?;
99
100        self.ptr = ptr::from_mut(elem);
101        unsafe { self.ptr.as_mut() }
102    }
103}
104
105impl<'a, T> ExactSizeIterator for IterMut<'a, T> {
106    fn len(&self) -> usize {
107        self.tensor.size()
108    }
109}