Skip to main content

ferray_core/array/
iter.rs

1// ferray-core: Iterator implementations (REQ-37)
2
3use ndarray::{Dimension as NdarrayDimension, IntoDimension};
4
5use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
6use crate::dtype::Element;
7use crate::error::{FerrayError, FerrayResult};
8
9use super::owned::Array;
10use super::view::ArrayView;
11use super::view_mut::ArrayViewMut;
12
13// ---------------------------------------------------------------------------
14// Element iteration for Array<T, D>
15// ---------------------------------------------------------------------------
16
17impl<T: Element, D: Dimension> Array<T, D> {
18    /// Iterate over all elements in logical (row-major) order.
19    pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
20        self.inner.iter()
21    }
22
23    /// Mutably iterate over all elements in logical order.
24    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
25        self.inner.iter_mut()
26    }
27
28    /// Iterate with multi-dimensional indices.
29    ///
30    /// Yields `(Vec<usize>, &T)` pairs in logical order. The index vector
31    /// has one entry per dimension.
32    ///
33    /// Delegates to ndarray's cached `indexed_iter`, which carries a
34    /// typed index through the walk (no per-element divmod on a flat
35    /// index as the previous hand-rolled implementation did, see
36    /// issue #80). The public `Vec<usize>` return type forces a single
37    /// allocation per yielded element; any further win requires an API
38    /// change to a streaming iterator or `&[usize]` signature.
39    pub fn indexed_iter(&self) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
40        self.inner.indexed_iter().map(|(pat, val)| {
41            let dim = pat.into_dimension();
42            (dim.slice().to_vec(), val)
43        })
44    }
45
46    /// Flat iterator — same as `iter()` but emphasises logical-order traversal.
47    pub fn flat(&self) -> impl Iterator<Item = &T> + '_ {
48        self.inner.iter()
49    }
50
51    /// Iterate over lanes (1-D slices) along the given axis.
52    ///
53    /// For a 2-D array with `axis=1`, this yields each row.
54    /// For `axis=0`, this yields each column.
55    ///
56    /// # Errors
57    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
58    pub fn lanes(
59        &self,
60        axis: Axis,
61    ) -> FerrayResult<impl Iterator<Item = ArrayView<'_, T, Ix1>> + '_> {
62        let ndim = self.ndim();
63        if axis.index() >= ndim {
64            return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
65        }
66        let nd_axis = ndarray::Axis(axis.index());
67        Ok(self
68            .inner
69            .lanes(nd_axis)
70            .into_iter()
71            .map(|lane| ArrayView::from_ndarray(lane)))
72    }
73
74    /// Iterate over sub-arrays along the given axis.
75    ///
76    /// For a 3-D array with shape `[2,3,4]` and `axis=0`, this yields
77    /// two 2-D views each of shape `[3,4]`.
78    ///
79    /// # Errors
80    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
81    pub fn axis_iter(
82        &self,
83        axis: Axis,
84    ) -> FerrayResult<impl Iterator<Item = ArrayView<'_, T, IxDyn>> + '_>
85    where
86        D::NdarrayDim: ndarray::RemoveAxis,
87    {
88        let ndim = self.ndim();
89        if axis.index() >= ndim {
90            return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
91        }
92        let nd_axis = ndarray::Axis(axis.index());
93        Ok(self.inner.axis_iter(nd_axis).map(|sub| {
94            let dyn_view = sub.into_dyn();
95            ArrayView::from_ndarray(dyn_view)
96        }))
97    }
98
99    /// Mutably iterate over sub-arrays along the given axis.
100    ///
101    /// # Errors
102    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
103    pub fn axis_iter_mut(
104        &mut self,
105        axis: Axis,
106    ) -> FerrayResult<impl Iterator<Item = ArrayViewMut<'_, T, IxDyn>> + '_>
107    where
108        D::NdarrayDim: ndarray::RemoveAxis,
109    {
110        let ndim = self.ndim();
111        if axis.index() >= ndim {
112            return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
113        }
114        let nd_axis = ndarray::Axis(axis.index());
115        Ok(self.inner.axis_iter_mut(nd_axis).map(|sub| {
116            let dyn_view = sub.into_dyn();
117            ArrayViewMut::from_ndarray(dyn_view)
118        }))
119    }
120}
121
122// ---------------------------------------------------------------------------
123// Consuming iterator
124// ---------------------------------------------------------------------------
125
126impl<T: Element, D: Dimension> IntoIterator for Array<T, D> {
127    type Item = T;
128    type IntoIter = ndarray::iter::IntoIter<T, D::NdarrayDim>;
129
130    fn into_iter(self) -> Self::IntoIter {
131        self.inner.into_iter()
132    }
133}
134
135impl<'a, T: Element, D: Dimension> IntoIterator for &'a Array<T, D> {
136    type Item = &'a T;
137    type IntoIter = ndarray::iter::Iter<'a, T, D::NdarrayDim>;
138
139    fn into_iter(self) -> Self::IntoIter {
140        self.inner.iter()
141    }
142}
143
144impl<'a, T: Element, D: Dimension> IntoIterator for &'a mut Array<T, D> {
145    type Item = &'a mut T;
146    type IntoIter = ndarray::iter::IterMut<'a, T, D::NdarrayDim>;
147
148    fn into_iter(self) -> Self::IntoIter {
149        self.inner.iter_mut()
150    }
151}
152
153// ---------------------------------------------------------------------------
154// ArrayView iteration
155// ---------------------------------------------------------------------------
156
157impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
158    /// Iterate over all elements in logical order.
159    pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
160        self.inner.iter()
161    }
162
163    /// Flat iterator.
164    pub fn flat(&self) -> impl Iterator<Item = &T> + '_ {
165        self.inner.iter()
166    }
167
168    /// Iterate with multi-dimensional indices. See [`Array::indexed_iter`].
169    pub fn indexed_iter(&self) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
170        self.inner.indexed_iter().map(|(pat, val)| {
171            let dim = pat.into_dimension();
172            (dim.slice().to_vec(), val)
173        })
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn iter_elements() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let collected: Vec<f64> = arr.iter().copied().collect();
        assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn iter_mut_elements() {
        let mut arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        for x in arr.iter_mut() {
            *x *= 2.0;
        }
        assert_eq!(arr.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn into_iter_consuming() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let collected: Vec<i32> = arr.into_iter().collect();
        assert_eq!(collected, vec![10, 20, 30]);
    }

    #[test]
    fn indexed_iter_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let items: Vec<_> = arr.indexed_iter().collect();
        assert_eq!(items.len(), 6);
        assert_eq!(items[0], (vec![0, 0], &1));
        assert_eq!(items[1], (vec![0, 1], &2));
        assert_eq!(items[3], (vec![1, 0], &4));
    }

    #[test]
    fn indexed_iter_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![7, 8, 9]).unwrap();
        let items: Vec<_> = arr.indexed_iter().collect();
        assert_eq!(items, vec![(vec![0], &7), (vec![1], &8), (vec![2], &9)]);
    }

    #[test]
    fn flat_iterator() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let flat: Vec<f64> = arr.flat().copied().collect();
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn lanes_axis1() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let rows: Vec<Vec<f64>> = arr
            .lanes(Axis(1))
            .unwrap()
            .map(|lane| lane.iter().copied().collect())
            .collect();
        assert_eq!(rows, vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
    }

    #[test]
    fn lanes_axis0_yields_columns() {
        // The `lanes` doc promises columns for axis 0 on a 2-D array.
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let cols: Vec<Vec<f64>> = arr
            .lanes(Axis(0))
            .unwrap()
            .map(|lane| lane.iter().copied().collect())
            .collect();
        assert_eq!(cols, vec![vec![1.0, 4.0], vec![2.0, 5.0], vec![3.0, 6.0]]);
    }

    #[test]
    fn lanes_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
        assert!(arr.lanes(Axis(2)).is_err());
    }

    #[test]
    fn axis_iter_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 2]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let rows: Vec<Vec<i32>> = arr
            .axis_iter(Axis(0))
            .unwrap()
            .map(|sub| sub.iter().copied().collect())
            .collect();
        assert_eq!(rows, vec![vec![1, 2], vec![3, 4], vec![5, 6]]);
    }

    #[test]
    fn axis_iter_2d_axis1() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 2]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let cols: Vec<Vec<i32>> = arr
            .axis_iter(Axis(1))
            .unwrap()
            .map(|sub| sub.iter().copied().collect())
            .collect();
        assert_eq!(cols, vec![vec![1, 3, 5], vec![2, 4, 6]]);
    }

    #[test]
    fn axis_iter_out_of_bounds() {
        let arr = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
        assert!(arr.axis_iter(Axis(1)).is_err());
    }

    #[test]
    fn axis_iter_mut_out_of_bounds() {
        let mut arr = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
        assert!(arr.axis_iter_mut(Axis(1)).is_err());
    }

    #[test]
    fn axis_iter_mut_modify() {
        let mut arr =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        for mut row in arr.axis_iter_mut(Axis(0)).unwrap() {
            // Fail loudly rather than silently skipping a row.
            let s = row
                .as_slice_mut()
                .expect("rows of a freshly built C-order array are contiguous");
            for v in s.iter_mut() {
                *v *= 10.0;
            }
        }
        assert_eq!(
            arr.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
        );
    }

    #[test]
    fn for_loop_borrow() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let mut sum = 0;
        for &x in &arr {
            sum += x;
        }
        assert_eq!(sum, 60);
    }

    #[test]
    fn for_loop_mut_borrow() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        for x in &mut arr {
            *x += 1;
        }
        assert_eq!(arr.as_slice().unwrap(), &[2, 3, 4]);
    }
}