Skip to main content

ferray_core/array/
iter.rs

1// ferray-core: Iterator implementations (REQ-37)
2
3use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
4use crate::dtype::Element;
5use crate::error::{FerrumError, FerrumResult};
6
7use super::owned::Array;
8use super::view::ArrayView;
9use super::view_mut::ArrayViewMut;
10
11// ---------------------------------------------------------------------------
12// Element iteration for Array<T, D>
13// ---------------------------------------------------------------------------
14
15impl<T: Element, D: Dimension> Array<T, D> {
16    /// Iterate over all elements in logical (row-major) order.
17    pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
18        self.inner.iter()
19    }
20
21    /// Mutably iterate over all elements in logical order.
22    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
23        self.inner.iter_mut()
24    }
25
26    /// Iterate with multi-dimensional indices.
27    ///
28    /// Yields `(Vec<usize>, &T)` pairs in logical order. The index vector
29    /// has one entry per dimension.
30    pub fn indexed_iter(&self) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
31        let shape = self.shape().to_vec();
32        let ndim = shape.len();
33        self.inner.iter().enumerate().map(move |(flat_idx, val)| {
34            let mut idx = vec![0usize; ndim];
35            let mut rem = flat_idx;
36            // Convert flat index to multi-dimensional (assuming C-order iteration)
37            for (d, s) in shape.iter().enumerate().rev() {
38                if *s > 0 {
39                    idx[d] = rem % s;
40                    rem /= s;
41                }
42            }
43            (idx, val)
44        })
45    }
46
47    /// Flat iterator — same as `iter()` but emphasises logical-order traversal.
48    pub fn flat(&self) -> impl Iterator<Item = &T> + '_ {
49        self.inner.iter()
50    }
51
52    /// Iterate over lanes (1-D slices) along the given axis.
53    ///
54    /// For a 2-D array with `axis=1`, this yields each row.
55    /// For `axis=0`, this yields each column.
56    ///
57    /// # Errors
58    /// Returns `FerrumError::AxisOutOfBounds` if `axis >= ndim`.
59    pub fn lanes(
60        &self,
61        axis: Axis,
62    ) -> FerrumResult<impl Iterator<Item = ArrayView<'_, T, Ix1>> + '_> {
63        let ndim = self.ndim();
64        if axis.index() >= ndim {
65            return Err(FerrumError::axis_out_of_bounds(axis.index(), ndim));
66        }
67        let nd_axis = ndarray::Axis(axis.index());
68        Ok(self
69            .inner
70            .lanes(nd_axis)
71            .into_iter()
72            .map(|lane| ArrayView::from_ndarray(lane)))
73    }
74
75    /// Iterate over sub-arrays along the given axis.
76    ///
77    /// For a 3-D array with shape `[2,3,4]` and `axis=0`, this yields
78    /// two 2-D views each of shape `[3,4]`.
79    ///
80    /// # Errors
81    /// Returns `FerrumError::AxisOutOfBounds` if `axis >= ndim`.
82    pub fn axis_iter(
83        &self,
84        axis: Axis,
85    ) -> FerrumResult<impl Iterator<Item = ArrayView<'_, T, IxDyn>> + '_>
86    where
87        D::NdarrayDim: ndarray::RemoveAxis,
88    {
89        let ndim = self.ndim();
90        if axis.index() >= ndim {
91            return Err(FerrumError::axis_out_of_bounds(axis.index(), ndim));
92        }
93        let nd_axis = ndarray::Axis(axis.index());
94        Ok(self.inner.axis_iter(nd_axis).map(|sub| {
95            let dyn_view = sub.into_dyn();
96            ArrayView::from_ndarray(dyn_view)
97        }))
98    }
99
100    /// Mutably iterate over sub-arrays along the given axis.
101    ///
102    /// # Errors
103    /// Returns `FerrumError::AxisOutOfBounds` if `axis >= ndim`.
104    pub fn axis_iter_mut(
105        &mut self,
106        axis: Axis,
107    ) -> FerrumResult<impl Iterator<Item = ArrayViewMut<'_, T, IxDyn>> + '_>
108    where
109        D::NdarrayDim: ndarray::RemoveAxis,
110    {
111        let ndim = self.ndim();
112        if axis.index() >= ndim {
113            return Err(FerrumError::axis_out_of_bounds(axis.index(), ndim));
114        }
115        let nd_axis = ndarray::Axis(axis.index());
116        Ok(self.inner.axis_iter_mut(nd_axis).map(|sub| {
117            let dyn_view = sub.into_dyn();
118            ArrayViewMut::from_ndarray(dyn_view)
119        }))
120    }
121}
122
123// ---------------------------------------------------------------------------
124// Consuming iterator
125// ---------------------------------------------------------------------------
126
127impl<T: Element, D: Dimension> IntoIterator for Array<T, D> {
128    type Item = T;
129    type IntoIter = ndarray::iter::IntoIter<T, D::NdarrayDim>;
130
131    fn into_iter(self) -> Self::IntoIter {
132        self.inner.into_iter()
133    }
134}
135
136impl<'a, T: Element, D: Dimension> IntoIterator for &'a Array<T, D> {
137    type Item = &'a T;
138    type IntoIter = ndarray::iter::Iter<'a, T, D::NdarrayDim>;
139
140    fn into_iter(self) -> Self::IntoIter {
141        self.inner.iter()
142    }
143}
144
145impl<'a, T: Element, D: Dimension> IntoIterator for &'a mut Array<T, D> {
146    type Item = &'a mut T;
147    type IntoIter = ndarray::iter::IterMut<'a, T, D::NdarrayDim>;
148
149    fn into_iter(self) -> Self::IntoIter {
150        self.inner.iter_mut()
151    }
152}
153
154// ---------------------------------------------------------------------------
155// ArrayView iteration
156// ---------------------------------------------------------------------------
157
158impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
159    /// Iterate over all elements in logical order.
160    pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
161        self.inner.iter()
162    }
163
164    /// Flat iterator.
165    pub fn flat(&self) -> impl Iterator<Item = &T> + '_ {
166        self.inner.iter()
167    }
168
169    /// Iterate with multi-dimensional indices.
170    pub fn indexed_iter(&self) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
171        let shape = self.shape().to_vec();
172        let ndim = shape.len();
173        self.inner.iter().enumerate().map(move |(flat_idx, val)| {
174            let mut idx = vec![0usize; ndim];
175            let mut rem = flat_idx;
176            for (d, s) in shape.iter().enumerate().rev() {
177                if *s > 0 {
178                    idx[d] = rem % s;
179                    rem /= s;
180                }
181            }
182            (idx, val)
183        })
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn iter_elements() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let collected: Vec<f64> = arr.iter().copied().collect();
        assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn iter_mut_elements() {
        let mut arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        for x in arr.iter_mut() {
            *x *= 2.0;
        }
        assert_eq!(arr.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn into_iter_consuming() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let collected: Vec<i32> = arr.into_iter().collect();
        assert_eq!(collected, vec![10, 20, 30]);
    }

    #[test]
    fn indexed_iter_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let items: Vec<_> = arr.indexed_iter().collect();
        assert_eq!(items.len(), 6);
        assert_eq!(items[0], (vec![0, 0], &1));
        assert_eq!(items[1], (vec![0, 1], &2));
        assert_eq!(items[3], (vec![1, 0], &4));
    }

    #[test]
    fn indexed_iter_visits_every_index_in_row_major_order() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let indices: Vec<Vec<usize>> = arr.indexed_iter().map(|(idx, _)| idx).collect();
        assert_eq!(
            indices,
            vec![
                vec![0, 0],
                vec![0, 1],
                vec![0, 2],
                vec![1, 0],
                vec![1, 1],
                vec![1, 2],
            ]
        );
    }

    #[test]
    fn flat_iterator() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let flat: Vec<f64> = arr.flat().copied().collect();
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn lanes_axis1() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let rows: Vec<Vec<f64>> = arr
            .lanes(Axis(1))
            .unwrap()
            .map(|lane| lane.iter().copied().collect())
            .collect();
        assert_eq!(rows, vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
    }

    #[test]
    fn lanes_axis0_yields_columns() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let cols: Vec<Vec<f64>> = arr
            .lanes(Axis(0))
            .unwrap()
            .map(|lane| lane.iter().copied().collect())
            .collect();
        assert_eq!(cols, vec![vec![1.0, 4.0], vec![2.0, 5.0], vec![3.0, 6.0]]);
    }

    #[test]
    fn lanes_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
        assert!(arr.lanes(Axis(2)).is_err());
    }

    #[test]
    fn axis_iter_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 2]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let rows: Vec<Vec<i32>> = arr
            .axis_iter(Axis(0))
            .unwrap()
            .map(|sub| sub.iter().copied().collect())
            .collect();
        assert_eq!(rows, vec![vec![1, 2], vec![3, 4], vec![5, 6]]);
    }

    #[test]
    fn axis_iter_along_last_axis() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 2]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let cols: Vec<Vec<i32>> = arr
            .axis_iter(Axis(1))
            .unwrap()
            .map(|sub| sub.iter().copied().collect())
            .collect();
        assert_eq!(cols, vec![vec![1, 3, 5], vec![2, 4, 6]]);
    }

    #[test]
    fn axis_iter_out_of_bounds() {
        let arr = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
        assert!(arr.axis_iter(Axis(1)).is_err());
    }

    #[test]
    fn axis_iter_mut_modify() {
        let mut arr =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        for mut row in arr.axis_iter_mut(Axis(0)).unwrap() {
            if let Some(s) = row.as_slice_mut() {
                for v in s.iter_mut() {
                    *v *= 10.0;
                }
            }
        }
        assert_eq!(
            arr.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
        );
    }

    #[test]
    fn axis_iter_mut_out_of_bounds() {
        let mut arr = Array::<f64, Ix2>::zeros(Ix2::new([2, 2])).unwrap();
        assert!(arr.axis_iter_mut(Axis(2)).is_err());
    }

    #[test]
    fn view_indexed_iter_and_flat() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        // `axis_iter` hands out `ArrayView`s, which exercises the view iterators.
        let second_row = arr.axis_iter(Axis(0)).unwrap().nth(1).unwrap();
        let indexed: Vec<(Vec<usize>, i32)> = second_row
            .indexed_iter()
            .map(|(idx, &v)| (idx, v))
            .collect();
        assert_eq!(indexed, vec![(vec![0], 4), (vec![1], 5), (vec![2], 6)]);
        let flat: Vec<i32> = second_row.flat().copied().collect();
        assert_eq!(flat, vec![4, 5, 6]);
    }

    #[test]
    fn for_loop_borrow() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let mut sum = 0;
        for &x in &arr {
            sum += x;
        }
        assert_eq!(sum, 60);
    }

    #[test]
    fn for_loop_mut_borrow() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        for x in &mut arr {
            *x += 1;
        }
        assert_eq!(arr.as_slice().unwrap(), &[11, 21, 31]);
    }
}