1use ndarray::{Dimension as NdarrayDimension, IntoDimension};
4
5use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
6use crate::dtype::Element;
7use crate::error::{FerrayError, FerrayResult};
8
9use super::owned::Array;
10use super::view::ArrayView;
11use super::view_mut::ArrayViewMut;
12
13impl<T: Element, D: Dimension> Array<T, D> {
18 pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
20 self.inner.iter()
21 }
22
23 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
25 self.inner.iter_mut()
26 }
27
28 pub fn indexed_iter(&self) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
40 self.inner.indexed_iter().map(|(pat, val)| {
41 let dim = pat.into_dimension();
42 (dim.slice().to_vec(), val)
43 })
44 }
45
46 pub fn flat(&self) -> impl Iterator<Item = &T> + '_ {
48 self.inner.iter()
49 }
50
51 pub fn lanes(
59 &self,
60 axis: Axis,
61 ) -> FerrayResult<impl Iterator<Item = ArrayView<'_, T, Ix1>> + '_> {
62 let ndim = self.ndim();
63 if axis.index() >= ndim {
64 return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
65 }
66 let nd_axis = ndarray::Axis(axis.index());
67 Ok(self
68 .inner
69 .lanes(nd_axis)
70 .into_iter()
71 .map(|lane| ArrayView::from_ndarray(lane)))
72 }
73
74 pub fn axis_iter(
82 &self,
83 axis: Axis,
84 ) -> FerrayResult<impl Iterator<Item = ArrayView<'_, T, IxDyn>> + '_>
85 where
86 D::NdarrayDim: ndarray::RemoveAxis,
87 {
88 let ndim = self.ndim();
89 if axis.index() >= ndim {
90 return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
91 }
92 let nd_axis = ndarray::Axis(axis.index());
93 Ok(self.inner.axis_iter(nd_axis).map(|sub| {
94 let dyn_view = sub.into_dyn();
95 ArrayView::from_ndarray(dyn_view)
96 }))
97 }
98
99 pub fn axis_iter_mut(
104 &mut self,
105 axis: Axis,
106 ) -> FerrayResult<impl Iterator<Item = ArrayViewMut<'_, T, IxDyn>> + '_>
107 where
108 D::NdarrayDim: ndarray::RemoveAxis,
109 {
110 let ndim = self.ndim();
111 if axis.index() >= ndim {
112 return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
113 }
114 let nd_axis = ndarray::Axis(axis.index());
115 Ok(self.inner.axis_iter_mut(nd_axis).map(|sub| {
116 let dyn_view = sub.into_dyn();
117 ArrayViewMut::from_ndarray(dyn_view)
118 }))
119 }
120}
121
122impl<T: Element, D: Dimension> IntoIterator for Array<T, D> {
127 type Item = T;
128 type IntoIter = ndarray::iter::IntoIter<T, D::NdarrayDim>;
129
130 fn into_iter(self) -> Self::IntoIter {
131 self.inner.into_iter()
132 }
133}
134
135impl<'a, T: Element, D: Dimension> IntoIterator for &'a Array<T, D> {
136 type Item = &'a T;
137 type IntoIter = ndarray::iter::Iter<'a, T, D::NdarrayDim>;
138
139 fn into_iter(self) -> Self::IntoIter {
140 self.inner.iter()
141 }
142}
143
144impl<'a, T: Element, D: Dimension> IntoIterator for &'a mut Array<T, D> {
145 type Item = &'a mut T;
146 type IntoIter = ndarray::iter::IterMut<'a, T, D::NdarrayDim>;
147
148 fn into_iter(self) -> Self::IntoIter {
149 self.inner.iter_mut()
150 }
151}
152
153impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
158 pub fn iter(&self) -> impl Iterator<Item = &T> + '_ {
160 self.inner.iter()
161 }
162
163 pub fn flat(&self) -> impl Iterator<Item = &T> + '_ {
165 self.inner.iter()
166 }
167
168 pub fn indexed_iter(&self) -> impl Iterator<Item = (Vec<usize>, &T)> + '_ {
170 self.inner.indexed_iter().map(|(pat, val)| {
171 let dim = pat.into_dimension();
172 (dim.slice().to_vec(), val)
173 })
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    #[test]
    fn iter_elements() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let collected: Vec<f64> = arr.iter().copied().collect();
        assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn iter_mut_elements() {
        let mut arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        for x in arr.iter_mut() {
            *x *= 2.0;
        }
        assert_eq!(arr.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn into_iter_consuming() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let collected: Vec<i32> = arr.into_iter().collect();
        assert_eq!(collected, vec![10, 20, 30]);
    }

    #[test]
    fn into_iter_consuming_2d_is_row_major() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let collected: Vec<i32> = arr.into_iter().collect();
        assert_eq!(collected, vec![1, 2, 3, 4]);
    }

    #[test]
    fn indexed_iter_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![7, 8, 9]).unwrap();
        let items: Vec<_> = arr.indexed_iter().collect();
        assert_eq!(items, vec![(vec![0], &7), (vec![1], &8), (vec![2], &9)]);
    }

    #[test]
    fn indexed_iter_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let items: Vec<_> = arr.indexed_iter().collect();
        assert_eq!(items.len(), 6);
        assert_eq!(items[0], (vec![0, 0], &1));
        assert_eq!(items[1], (vec![0, 1], &2));
        assert_eq!(items[3], (vec![1, 0], &4));
    }

    #[test]
    fn flat_iterator() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let flat: Vec<f64> = arr.flat().copied().collect();
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn lanes_axis1() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let rows: Vec<Vec<f64>> = arr
            .lanes(Axis(1))
            .unwrap()
            .map(|lane| lane.iter().copied().collect())
            .collect();
        assert_eq!(rows, vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
    }

    #[test]
    fn lanes_axis0() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let cols: Vec<Vec<f64>> = arr
            .lanes(Axis(0))
            .unwrap()
            .map(|lane| lane.iter().copied().collect())
            .collect();
        assert_eq!(cols, vec![vec![1.0, 4.0], vec![2.0, 5.0], vec![3.0, 6.0]]);
    }

    #[test]
    fn lanes_out_of_bounds() {
        let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
        assert!(arr.lanes(Axis(2)).is_err());
    }

    #[test]
    fn axis_iter_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 2]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let rows: Vec<Vec<i32>> = arr
            .axis_iter(Axis(0))
            .unwrap()
            .map(|sub| sub.iter().copied().collect())
            .collect();
        assert_eq!(rows, vec![vec![1, 2], vec![3, 4], vec![5, 6]]);
    }

    #[test]
    fn axis_iter_2d_axis1() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 2]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let cols: Vec<Vec<i32>> = arr
            .axis_iter(Axis(1))
            .unwrap()
            .map(|sub| sub.iter().copied().collect())
            .collect();
        assert_eq!(cols, vec![vec![1, 3, 5], vec![2, 4, 6]]);
    }

    #[test]
    fn axis_iter_out_of_bounds() {
        let arr = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
        assert!(arr.axis_iter(Axis(1)).is_err());
    }

    #[test]
    fn axis_iter_mut_modify() {
        let mut arr =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        for mut row in arr.axis_iter_mut(Axis(0)).unwrap() {
            // Rows of a freshly built row-major array are contiguous; fail
            // loudly instead of silently skipping the mutation if not.
            let s = row
                .as_slice_mut()
                .expect("row of a row-major array should be contiguous");
            for v in s.iter_mut() {
                *v *= 10.0;
            }
        }
        assert_eq!(
            arr.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
        );
    }

    #[test]
    fn axis_iter_mut_out_of_bounds() {
        let mut arr = Array::<f64, Ix2>::zeros(Ix2::new([2, 2])).unwrap();
        assert!(arr.axis_iter_mut(Axis(2)).is_err());
    }

    #[test]
    fn for_loop_borrow() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let mut sum = 0;
        for &x in &arr {
            sum += x;
        }
        assert_eq!(sum, 60);
    }

    #[test]
    fn for_loop_borrow_mut() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        for x in &mut arr {
            *x += 1;
        }
        assert_eq!(arr.as_slice().unwrap(), &[11, 21, 31]);
    }
}