1use crate::dense::array::{views, Array, Shape, UnsafeRandomAccessByValue, UnsafeRandomAccessMut};
4use crate::dense::layout::convert_1d_nd_from_shape;
5use crate::dense::traits::AsMultiIndex;
6use crate::dense::types::RlstBase;
7
8use super::slice::ArraySlice;
9
10pub struct ArrayDefaultIterator<
14 'a,
15 Item: RlstBase,
16 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
17 const NDIM: usize,
18> {
19 arr: &'a Array<Item, ArrayImpl, NDIM>,
20 shape: [usize; NDIM],
21 pos: usize,
22 nelements: usize,
23}
24
25pub struct ArrayDefaultIteratorMut<
27 'a,
28 Item: RlstBase,
29 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
30 + Shape<NDIM>
31 + UnsafeRandomAccessMut<NDIM, Item = Item>,
32 const NDIM: usize,
33> {
34 arr: &'a mut Array<Item, ArrayImpl, NDIM>,
35 shape: [usize; NDIM],
36 pos: usize,
37 nelements: usize,
38}
39
/// Adapter that turns `(linear_index, value)` pairs into `(multi_index, value)` pairs.
///
/// `shape` is the array shape used to translate each linear index.
pub struct MultiIndexIterator<T, I, const NDIM: usize>
where
    I: Iterator<Item = (usize, T)>,
{
    /// Shape used for the linear-to-multi-index conversion.
    shape: [usize; NDIM],
    /// Underlying iterator producing `(linear_index, value)` pairs.
    iter: I,
}
45
46impl<T, I: Iterator<Item = (usize, T)>, const NDIM: usize> Iterator
47 for MultiIndexIterator<T, I, NDIM>
48{
49 type Item = ([usize; NDIM], T);
50
51 fn next(&mut self) -> Option<Self::Item> {
52 if let Some((index, value)) = self.iter.next() {
53 Some((convert_1d_nd_from_shape(index, self.shape), value))
54 } else {
55 None
56 }
57 }
58}
59
60impl<T, I: Iterator<Item = (usize, T)>, const NDIM: usize> AsMultiIndex<T, I, NDIM> for I {
61 fn multi_index(self, shape: [usize; NDIM]) -> MultiIndexIterator<T, I, NDIM> {
62 MultiIndexIterator::<T, I, NDIM> { shape, iter: self }
63 }
64}
65
66impl<
67 'a,
68 Item: RlstBase,
69 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
70 const NDIM: usize,
71 > ArrayDefaultIterator<'a, Item, ArrayImpl, NDIM>
72{
73 fn new(arr: &'a Array<Item, ArrayImpl, NDIM>) -> Self {
74 Self {
75 arr,
76 shape: arr.shape(),
77 pos: 0,
78 nelements: arr.shape().iter().product(),
79 }
80 }
81}
82
83impl<
84 'a,
85 Item: RlstBase,
86 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
87 + Shape<NDIM>
88 + UnsafeRandomAccessMut<NDIM, Item = Item>,
89 const NDIM: usize,
90 > ArrayDefaultIteratorMut<'a, Item, ArrayImpl, NDIM>
91{
92 fn new(arr: &'a mut Array<Item, ArrayImpl, NDIM>) -> Self {
93 let shape = arr.shape();
94 Self {
95 arr,
96 shape,
97 pos: 0,
98 nelements: shape.iter().product(),
99 }
100 }
101}
102
103impl<
104 'a,
105 Item: RlstBase,
106 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
107 const NDIM: usize,
108 > std::iter::Iterator for ArrayDefaultIterator<'a, Item, ArrayImpl, NDIM>
109{
110 type Item = Item;
111 fn next(&mut self) -> Option<Self::Item> {
112 if self.pos >= self.nelements {
113 return None;
114 }
115 let multi_index = convert_1d_nd_from_shape(self.pos, self.shape);
116 self.pos += 1;
117 unsafe { Some(self.arr.get_value_unchecked(multi_index)) }
118 }
119}
120
121impl<
129 'a,
130 Item: RlstBase,
131 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
132 + UnsafeRandomAccessMut<NDIM, Item = Item>
133 + Shape<NDIM>,
134 const NDIM: usize,
135 > std::iter::Iterator for ArrayDefaultIteratorMut<'a, Item, ArrayImpl, NDIM>
136{
137 type Item = &'a mut Item;
138 fn next(&mut self) -> Option<Self::Item> {
139 if self.pos >= self.nelements {
140 return None;
141 }
142 let multi_index = convert_1d_nd_from_shape(self.pos, self.shape);
143 self.pos += 1;
144 unsafe {
145 Some(std::mem::transmute::<&mut Item, &'a mut Item>(
146 self.arr.get_unchecked_mut(multi_index),
147 ))
148 }
149 }
150}
151
152impl<
153 Item: RlstBase,
154 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
155 const NDIM: usize,
156 > crate::dense::traits::DefaultIterator for Array<Item, ArrayImpl, NDIM>
157{
158 type Item = Item;
159 type Iter<'a> = ArrayDefaultIterator<'a, Item, ArrayImpl, NDIM> where Self: 'a;
160
161 fn iter(&self) -> Self::Iter<'_> {
162 ArrayDefaultIterator::new(self)
163 }
164}
165
166impl<
167 Item: RlstBase,
168 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
169 + Shape<NDIM>
170 + UnsafeRandomAccessMut<NDIM, Item = Item>,
171 const NDIM: usize,
172 > crate::dense::traits::DefaultIteratorMut for Array<Item, ArrayImpl, NDIM>
173{
174 type Item = Item;
175 type IterMut<'a> = ArrayDefaultIteratorMut<'a, Item, ArrayImpl, NDIM> where Self: 'a;
176
177 fn iter_mut(&mut self) -> Self::IterMut<'_> {
178 ArrayDefaultIteratorMut::new(self)
179 }
180}
181
182pub struct RowIterator<
184 'a,
185 Item: RlstBase,
186 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
187 const NDIM: usize,
188> {
189 arr: &'a Array<Item, ArrayImpl, NDIM>,
190 nrows: usize,
191 current_row: usize,
192}
193
194pub struct RowIteratorMut<
196 'a,
197 Item: RlstBase,
198 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
199 + Shape<NDIM>
200 + UnsafeRandomAccessMut<NDIM, Item = Item>,
201 const NDIM: usize,
202> {
203 arr: &'a mut Array<Item, ArrayImpl, NDIM>,
204 nrows: usize,
205 current_row: usize,
206}
207
208impl<'a, Item: RlstBase, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>>
209 std::iter::Iterator for RowIterator<'a, Item, ArrayImpl, 2>
210{
211 type Item = Array<Item, ArraySlice<Item, views::ArrayView<'a, Item, ArrayImpl, 2>, 2, 1>, 1>;
212 fn next(&mut self) -> Option<Self::Item> {
213 if self.current_row >= self.nrows {
214 return None;
215 }
216 let slice = self.arr.view().slice(0, self.current_row);
217 self.current_row += 1;
218 Some(slice)
219 }
220}
221
222impl<
223 'a,
224 Item: RlstBase,
225 ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item>
226 + UnsafeRandomAccessMut<2, Item = Item>
227 + Shape<2>,
228 > std::iter::Iterator for RowIteratorMut<'a, Item, ArrayImpl, 2>
229{
230 type Item = Array<Item, ArraySlice<Item, views::ArrayViewMut<'a, Item, ArrayImpl, 2>, 2, 1>, 1>;
231 fn next(&mut self) -> Option<Self::Item> {
232 if self.current_row >= self.nrows {
233 return None;
234 }
235 let slice = self.arr.view_mut().slice(0, self.current_row);
236 self.current_row += 1;
237 unsafe {
238 Some(std::mem::transmute::<
239 Array<Item, ArraySlice<Item, views::ArrayViewMut<'_, Item, ArrayImpl, 2>, 2, 1>, 1>,
240 Array<Item, ArraySlice<Item, views::ArrayViewMut<'a, Item, ArrayImpl, 2>, 2, 1>, 1>,
241 >(slice))
242 }
243 }
244}
245
246impl<Item: RlstBase, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>>
247 Array<Item, ArrayImpl, 2>
248{
249 pub fn row_iter(&self) -> RowIterator<'_, Item, ArrayImpl, 2> {
251 RowIterator {
252 arr: self,
253 nrows: self.shape()[0],
254 current_row: 0,
255 }
256 }
257}
258
259impl<
260 Item: RlstBase,
261 ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item>
262 + Shape<2>
263 + UnsafeRandomAccessMut<2, Item = Item>,
264 > Array<Item, ArrayImpl, 2>
265{
266 pub fn row_iter_mut(&mut self) -> RowIteratorMut<'_, Item, ArrayImpl, 2> {
268 let nrows = self.shape()[0];
269 RowIteratorMut {
270 arr: self,
271 nrows,
272 current_row: 0,
273 }
274 }
275}
276
277pub struct ColIterator<
279 'a,
280 Item: RlstBase,
281 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
282 const NDIM: usize,
283> {
284 arr: &'a Array<Item, ArrayImpl, NDIM>,
285 ncols: usize,
286 current_col: usize,
287}
288
289pub struct ColIteratorMut<
291 'a,
292 Item: RlstBase,
293 ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
294 + Shape<NDIM>
295 + UnsafeRandomAccessMut<NDIM, Item = Item>,
296 const NDIM: usize,
297> {
298 arr: &'a mut Array<Item, ArrayImpl, NDIM>,
299 ncols: usize,
300 current_col: usize,
301}
302
303impl<'a, Item: RlstBase, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>>
304 std::iter::Iterator for ColIterator<'a, Item, ArrayImpl, 2>
305{
306 type Item = Array<Item, ArraySlice<Item, views::ArrayView<'a, Item, ArrayImpl, 2>, 2, 1>, 1>;
307 fn next(&mut self) -> Option<Self::Item> {
308 if self.current_col >= self.ncols {
309 return None;
310 }
311 let slice = self.arr.view().slice(1, self.current_col);
312 self.current_col += 1;
313 Some(slice)
314 }
315}
316
317impl<
318 'a,
319 Item: RlstBase,
320 ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item>
321 + UnsafeRandomAccessMut<2, Item = Item>
322 + Shape<2>,
323 > std::iter::Iterator for ColIteratorMut<'a, Item, ArrayImpl, 2>
324{
325 type Item = Array<Item, ArraySlice<Item, views::ArrayViewMut<'a, Item, ArrayImpl, 2>, 2, 1>, 1>;
326 fn next(&mut self) -> Option<Self::Item> {
327 if self.current_col >= self.ncols {
328 return None;
329 }
330 let slice = self.arr.view_mut().slice(1, self.current_col);
331 self.current_col += 1;
332 unsafe {
333 Some(std::mem::transmute::<
334 Array<Item, ArraySlice<Item, views::ArrayViewMut<'_, Item, ArrayImpl, 2>, 2, 1>, 1>,
335 Array<Item, ArraySlice<Item, views::ArrayViewMut<'a, Item, ArrayImpl, 2>, 2, 1>, 1>,
336 >(slice))
337 }
338 }
339}
340
341impl<Item: RlstBase, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>>
342 Array<Item, ArrayImpl, 2>
343{
344 pub fn col_iter(&self) -> ColIterator<'_, Item, ArrayImpl, 2> {
346 ColIterator {
347 arr: self,
348 ncols: self.shape()[1],
349 current_col: 0,
350 }
351 }
352}
353
354impl<
355 Item: RlstBase,
356 ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item>
357 + Shape<2>
358 + UnsafeRandomAccessMut<2, Item = Item>,
359 > Array<Item, ArrayImpl, 2>
360{
361 pub fn col_iter_mut(&mut self) -> ColIteratorMut<'_, Item, ArrayImpl, 2> {
363 let ncols = self.shape()[1];
364 ColIteratorMut {
365 arr: self,
366 ncols,
367 current_col: 0,
368 }
369 }
370}