1use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
4use crate::dtype::Element;
5use crate::error::{FerrumError, FerrumResult};
6
7use super::owned::Array;
8use super::view::ArrayView;
9use super::view_mut::ArrayViewMut;
10
11impl<T: Element, D: Dimension> Array<T, D> {
16 pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
18 self.inner.iter()
19 }
20
21 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
23 self.inner.iter_mut()
24 }
25
26 pub fn indexed_iter(&self) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
31 let shape = self.shape().to_vec();
32 let ndim = shape.len();
33 self.inner.iter().enumerate().map(move |(flat_idx, val)| {
34 let mut idx = vec![0usize; ndim];
35 let mut rem = flat_idx;
36 for (d, s) in shape.iter().enumerate().rev() {
38 if *s > 0 {
39 idx[d] = rem % s;
40 rem /= s;
41 }
42 }
43 (idx, val)
44 })
45 }
46
47 pub fn flat(&self) -> impl Iterator<Item = &T> + '_ {
49 self.inner.iter()
50 }
51
52 pub fn lanes(
60 &self,
61 axis: Axis,
62 ) -> FerrumResult<impl Iterator<Item = ArrayView<'_, T, Ix1>> + '_> {
63 let ndim = self.ndim();
64 if axis.index() >= ndim {
65 return Err(FerrumError::axis_out_of_bounds(axis.index(), ndim));
66 }
67 let nd_axis = ndarray::Axis(axis.index());
68 Ok(self
69 .inner
70 .lanes(nd_axis)
71 .into_iter()
72 .map(|lane| ArrayView::from_ndarray(lane)))
73 }
74
75 pub fn axis_iter(
83 &self,
84 axis: Axis,
85 ) -> FerrumResult<impl Iterator<Item = ArrayView<'_, T, IxDyn>> + '_>
86 where
87 D::NdarrayDim: ndarray::RemoveAxis,
88 {
89 let ndim = self.ndim();
90 if axis.index() >= ndim {
91 return Err(FerrumError::axis_out_of_bounds(axis.index(), ndim));
92 }
93 let nd_axis = ndarray::Axis(axis.index());
94 Ok(self.inner.axis_iter(nd_axis).map(|sub| {
95 let dyn_view = sub.into_dyn();
96 ArrayView::from_ndarray(dyn_view)
97 }))
98 }
99
100 pub fn axis_iter_mut(
105 &mut self,
106 axis: Axis,
107 ) -> FerrumResult<impl Iterator<Item = ArrayViewMut<'_, T, IxDyn>> + '_>
108 where
109 D::NdarrayDim: ndarray::RemoveAxis,
110 {
111 let ndim = self.ndim();
112 if axis.index() >= ndim {
113 return Err(FerrumError::axis_out_of_bounds(axis.index(), ndim));
114 }
115 let nd_axis = ndarray::Axis(axis.index());
116 Ok(self.inner.axis_iter_mut(nd_axis).map(|sub| {
117 let dyn_view = sub.into_dyn();
118 ArrayViewMut::from_ndarray(dyn_view)
119 }))
120 }
121}
122
123impl<T: Element, D: Dimension> IntoIterator for Array<T, D> {
128 type Item = T;
129 type IntoIter = ndarray::iter::IntoIter<T, D::NdarrayDim>;
130
131 fn into_iter(self) -> Self::IntoIter {
132 self.inner.into_iter()
133 }
134}
135
136impl<'a, T: Element, D: Dimension> IntoIterator for &'a Array<T, D> {
137 type Item = &'a T;
138 type IntoIter = ndarray::iter::Iter<'a, T, D::NdarrayDim>;
139
140 fn into_iter(self) -> Self::IntoIter {
141 self.inner.iter()
142 }
143}
144
145impl<'a, T: Element, D: Dimension> IntoIterator for &'a mut Array<T, D> {
146 type Item = &'a mut T;
147 type IntoIter = ndarray::iter::IterMut<'a, T, D::NdarrayDim>;
148
149 fn into_iter(self) -> Self::IntoIter {
150 self.inner.iter_mut()
151 }
152}
153
154impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
159 pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
161 self.inner.iter()
162 }
163
164 pub fn flat(&self) -> impl Iterator<Item = &T> + '_ {
166 self.inner.iter()
167 }
168
169 pub fn indexed_iter(&self) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
171 let shape = self.shape().to_vec();
172 let ndim = shape.len();
173 self.inner.iter().enumerate().map(move |(flat_idx, val)| {
174 let mut idx = vec![0usize; ndim];
175 let mut rem = flat_idx;
176 for (d, s) in shape.iter().enumerate().rev() {
177 if *s > 0 {
178 idx[d] = rem % s;
179 rem /= s;
180 }
181 }
182 (idx, val)
183 })
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// `iter` yields every element of a 1-D array in order.
    #[test]
    fn iter_elements() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let seen = arr.iter().copied().collect::<Vec<f64>>();
        assert_eq!(seen, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Writes made through `iter_mut` are visible in the array afterwards.
    #[test]
    fn iter_mut_elements() {
        let mut arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        arr.iter_mut().for_each(|value| *value *= 2.0);
        assert_eq!(arr.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    /// Consuming `into_iter` yields the elements by value.
    #[test]
    fn into_iter_consuming() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let owned = arr.into_iter().collect::<Vec<i32>>();
        assert_eq!(owned, vec![10, 20, 30]);
    }

    /// `indexed_iter` pairs each element of a 2x3 array with its `[row, col]` index.
    #[test]
    fn indexed_iter_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let pairs = arr.indexed_iter().collect::<Vec<_>>();
        assert_eq!(pairs.len(), 6);
        assert_eq!(pairs[0], (vec![0, 0], &1));
        assert_eq!(pairs[1], (vec![0, 1], &2));
        assert_eq!(pairs[3], (vec![1, 0], &4));
    }

    /// `flat` walks a 2-D array element by element.
    #[test]
    fn flat_iterator() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let flattened = arr.flat().copied().collect::<Vec<f64>>();
        assert_eq!(flattened, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Lanes along axis 1 of a 2x3 array are its two rows.
    #[test]
    fn lanes_axis1() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let lane_iter = arr.lanes(Axis(1)).unwrap();
        let rows: Vec<Vec<f64>> = lane_iter
            .map(|lane| lane.iter().copied().collect())
            .collect();
        assert_eq!(rows, vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
    }

    /// Asking for lanes along a non-existent axis is reported as an error.
    #[test]
    fn lanes_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
        let outcome = arr.lanes(Axis(2));
        assert!(outcome.is_err());
    }

    /// Iterating along axis 0 of a 3x2 array yields its three rows.
    #[test]
    fn axis_iter_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 2]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let slices = arr.axis_iter(Axis(0)).unwrap();
        let rows: Vec<Vec<i32>> = slices
            .map(|sub| sub.iter().copied().collect())
            .collect();
        assert_eq!(rows, vec![vec![1, 2], vec![3, 4], vec![5, 6]]);
    }

    /// Asking to iterate along a non-existent axis is reported as an error.
    #[test]
    fn axis_iter_out_of_bounds() {
        let arr = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
        let outcome = arr.axis_iter(Axis(1));
        assert!(outcome.is_err());
    }

    /// Mutations made through `axis_iter_mut` sub-views land in the parent array.
    #[test]
    fn axis_iter_mut_modify() {
        let mut arr =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        for mut row in arr.axis_iter_mut(Axis(0)).unwrap() {
            if let Some(cells) = row.as_slice_mut() {
                cells.iter_mut().for_each(|cell| *cell *= 10.0);
            }
        }
        assert_eq!(
            arr.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
        );
    }

    /// `for x in &array` works through the borrowing `IntoIterator` impl.
    #[test]
    fn for_loop_borrow() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let mut total = 0;
        for &value in &arr {
            total += value;
        }
        assert_eq!(total, 60);
    }
}