rlst/dense/array/
iterators.rs

//! Various iterator implementations for dense arrays.
2
3use crate::dense::array::{views, Array, Shape, UnsafeRandomAccessByValue, UnsafeRandomAccessMut};
4use crate::dense::layout::convert_1d_nd_from_shape;
5use crate::dense::traits::AsMultiIndex;
6use crate::dense::types::RlstBase;
7
8use super::slice::ArraySlice;
9
10/// Default column major iterator.
11///
12/// This iterator returns elements of an array in standard column major order.
13pub struct ArrayDefaultIterator<
14    'a,
15    Item: RlstBase,
16    ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
17    const NDIM: usize,
18> {
19    arr: &'a Array<Item, ArrayImpl, NDIM>,
20    shape: [usize; NDIM],
21    pos: usize,
22    nelements: usize,
23}
24
25/// Mutable default iterator. Like [ArrayDefaultIterator] but with mutable access.
26pub struct ArrayDefaultIteratorMut<
27    'a,
28    Item: RlstBase,
29    ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
30        + Shape<NDIM>
31        + UnsafeRandomAccessMut<NDIM, Item = Item>,
32    const NDIM: usize,
33> {
34    arr: &'a mut Array<Item, ArrayImpl, NDIM>,
35    shape: [usize; NDIM],
36    pos: usize,
37    nelements: usize,
38}
39
/// A multi-index iterator returns the current element and the corresponding multi-index.
///
/// It wraps an iterator over `(linear_index, value)` pairs and converts each
/// linear index into an `NDIM`-dimensional multi-index for the stored `shape`
/// (via `convert_1d_nd_from_shape`). Instances are created through
/// `AsMultiIndex::multi_index`.
pub struct MultiIndexIterator<T, I: Iterator<Item = (usize, T)>, const NDIM: usize> {
    /// Shape used to convert linear indices into multi-indices.
    shape: [usize; NDIM],
    /// Underlying iterator over `(linear_index, value)` pairs.
    iter: I,
}
45
46impl<T, I: Iterator<Item = (usize, T)>, const NDIM: usize> Iterator
47    for MultiIndexIterator<T, I, NDIM>
48{
49    type Item = ([usize; NDIM], T);
50
51    fn next(&mut self) -> Option<Self::Item> {
52        if let Some((index, value)) = self.iter.next() {
53            Some((convert_1d_nd_from_shape(index, self.shape), value))
54        } else {
55            None
56        }
57    }
58}
59
60impl<T, I: Iterator<Item = (usize, T)>, const NDIM: usize> AsMultiIndex<T, I, NDIM> for I {
61    fn multi_index(self, shape: [usize; NDIM]) -> MultiIndexIterator<T, I, NDIM> {
62        MultiIndexIterator::<T, I, NDIM> { shape, iter: self }
63    }
64}
65
66impl<
67        'a,
68        Item: RlstBase,
69        ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
70        const NDIM: usize,
71    > ArrayDefaultIterator<'a, Item, ArrayImpl, NDIM>
72{
73    fn new(arr: &'a Array<Item, ArrayImpl, NDIM>) -> Self {
74        Self {
75            arr,
76            shape: arr.shape(),
77            pos: 0,
78            nelements: arr.shape().iter().product(),
79        }
80    }
81}
82
83impl<
84        'a,
85        Item: RlstBase,
86        ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
87            + Shape<NDIM>
88            + UnsafeRandomAccessMut<NDIM, Item = Item>,
89        const NDIM: usize,
90    > ArrayDefaultIteratorMut<'a, Item, ArrayImpl, NDIM>
91{
92    fn new(arr: &'a mut Array<Item, ArrayImpl, NDIM>) -> Self {
93        let shape = arr.shape();
94        Self {
95            arr,
96            shape,
97            pos: 0,
98            nelements: shape.iter().product(),
99        }
100    }
101}
102
103impl<
104        'a,
105        Item: RlstBase,
106        ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
107        const NDIM: usize,
108    > std::iter::Iterator for ArrayDefaultIterator<'a, Item, ArrayImpl, NDIM>
109{
110    type Item = Item;
111    fn next(&mut self) -> Option<Self::Item> {
112        if self.pos >= self.nelements {
113            return None;
114        }
115        let multi_index = convert_1d_nd_from_shape(self.pos, self.shape);
116        self.pos += 1;
117        unsafe { Some(self.arr.get_value_unchecked(multi_index)) }
118    }
119}
120
// The mutable iterators below need `unsafe` (a `transmute` or raw-pointer cast) to
// change the lifetime of the reference or view obtained from `get_unchecked_mut` /
// `view_mut` to the lifetime 'a of the array. Inside `next(&mut self)`, `self.arr` can
// only be reborrowed for as long as that `&mut self` borrow, so the borrow checker
// cannot tie the result to 'a. The extension relies on each iterator handing out every
// element (or row/column) at most once, so that returned references never alias.
// See also: https://stackoverflow.com/questions/62361624/lifetime-parameter-problem-in-custom-iterator-over-mutable-references
// And also: https://users.rust-lang.org/t/when-is-transmuting-lifetimes-useful/56140
127
128impl<
129        'a,
130        Item: RlstBase,
131        ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
132            + UnsafeRandomAccessMut<NDIM, Item = Item>
133            + Shape<NDIM>,
134        const NDIM: usize,
135    > std::iter::Iterator for ArrayDefaultIteratorMut<'a, Item, ArrayImpl, NDIM>
136{
137    type Item = &'a mut Item;
138    fn next(&mut self) -> Option<Self::Item> {
139        if self.pos >= self.nelements {
140            return None;
141        }
142        let multi_index = convert_1d_nd_from_shape(self.pos, self.shape);
143        self.pos += 1;
144        unsafe {
145            Some(std::mem::transmute::<&mut Item, &'a mut Item>(
146                self.arr.get_unchecked_mut(multi_index),
147            ))
148        }
149    }
150}
151
152impl<
153        Item: RlstBase,
154        ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
155        const NDIM: usize,
156    > crate::dense::traits::DefaultIterator for Array<Item, ArrayImpl, NDIM>
157{
158    type Item = Item;
159    type Iter<'a> = ArrayDefaultIterator<'a, Item, ArrayImpl, NDIM> where Self: 'a;
160
161    fn iter(&self) -> Self::Iter<'_> {
162        ArrayDefaultIterator::new(self)
163    }
164}
165
166impl<
167        Item: RlstBase,
168        ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
169            + Shape<NDIM>
170            + UnsafeRandomAccessMut<NDIM, Item = Item>,
171        const NDIM: usize,
172    > crate::dense::traits::DefaultIteratorMut for Array<Item, ArrayImpl, NDIM>
173{
174    type Item = Item;
175    type IterMut<'a> = ArrayDefaultIteratorMut<'a, Item, ArrayImpl, NDIM> where Self: 'a;
176
177    fn iter_mut(&mut self) -> Self::IterMut<'_> {
178        ArrayDefaultIteratorMut::new(self)
179    }
180}
181
182/// Row iterator
183pub struct RowIterator<
184    'a,
185    Item: RlstBase,
186    ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
187    const NDIM: usize,
188> {
189    arr: &'a Array<Item, ArrayImpl, NDIM>,
190    nrows: usize,
191    current_row: usize,
192}
193
194/// Mutable row iterator
195pub struct RowIteratorMut<
196    'a,
197    Item: RlstBase,
198    ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
199        + Shape<NDIM>
200        + UnsafeRandomAccessMut<NDIM, Item = Item>,
201    const NDIM: usize,
202> {
203    arr: &'a mut Array<Item, ArrayImpl, NDIM>,
204    nrows: usize,
205    current_row: usize,
206}
207
208impl<'a, Item: RlstBase, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>>
209    std::iter::Iterator for RowIterator<'a, Item, ArrayImpl, 2>
210{
211    type Item = Array<Item, ArraySlice<Item, views::ArrayView<'a, Item, ArrayImpl, 2>, 2, 1>, 1>;
212    fn next(&mut self) -> Option<Self::Item> {
213        if self.current_row >= self.nrows {
214            return None;
215        }
216        let slice = self.arr.view().slice(0, self.current_row);
217        self.current_row += 1;
218        Some(slice)
219    }
220}
221
222impl<
223        'a,
224        Item: RlstBase,
225        ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item>
226            + UnsafeRandomAccessMut<2, Item = Item>
227            + Shape<2>,
228    > std::iter::Iterator for RowIteratorMut<'a, Item, ArrayImpl, 2>
229{
230    type Item = Array<Item, ArraySlice<Item, views::ArrayViewMut<'a, Item, ArrayImpl, 2>, 2, 1>, 1>;
231    fn next(&mut self) -> Option<Self::Item> {
232        if self.current_row >= self.nrows {
233            return None;
234        }
235        let slice = self.arr.view_mut().slice(0, self.current_row);
236        self.current_row += 1;
237        unsafe {
238            Some(std::mem::transmute::<
239                Array<Item, ArraySlice<Item, views::ArrayViewMut<'_, Item, ArrayImpl, 2>, 2, 1>, 1>,
240                Array<Item, ArraySlice<Item, views::ArrayViewMut<'a, Item, ArrayImpl, 2>, 2, 1>, 1>,
241            >(slice))
242        }
243    }
244}
245
246impl<Item: RlstBase, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>>
247    Array<Item, ArrayImpl, 2>
248{
249    /// Return a row iterator for a two-dimensional array.
250    pub fn row_iter(&self) -> RowIterator<'_, Item, ArrayImpl, 2> {
251        RowIterator {
252            arr: self,
253            nrows: self.shape()[0],
254            current_row: 0,
255        }
256    }
257}
258
259impl<
260        Item: RlstBase,
261        ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item>
262            + Shape<2>
263            + UnsafeRandomAccessMut<2, Item = Item>,
264    > Array<Item, ArrayImpl, 2>
265{
266    /// Return a mutable row iterator for a two-dimensional array.
267    pub fn row_iter_mut(&mut self) -> RowIteratorMut<'_, Item, ArrayImpl, 2> {
268        let nrows = self.shape()[0];
269        RowIteratorMut {
270            arr: self,
271            nrows,
272            current_row: 0,
273        }
274    }
275}
276
277/// Column iterator
278pub struct ColIterator<
279    'a,
280    Item: RlstBase,
281    ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
282    const NDIM: usize,
283> {
284    arr: &'a Array<Item, ArrayImpl, NDIM>,
285    ncols: usize,
286    current_col: usize,
287}
288
289/// Mutable column iterator
290pub struct ColIteratorMut<
291    'a,
292    Item: RlstBase,
293    ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item>
294        + Shape<NDIM>
295        + UnsafeRandomAccessMut<NDIM, Item = Item>,
296    const NDIM: usize,
297> {
298    arr: &'a mut Array<Item, ArrayImpl, NDIM>,
299    ncols: usize,
300    current_col: usize,
301}
302
303impl<'a, Item: RlstBase, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>>
304    std::iter::Iterator for ColIterator<'a, Item, ArrayImpl, 2>
305{
306    type Item = Array<Item, ArraySlice<Item, views::ArrayView<'a, Item, ArrayImpl, 2>, 2, 1>, 1>;
307    fn next(&mut self) -> Option<Self::Item> {
308        if self.current_col >= self.ncols {
309            return None;
310        }
311        let slice = self.arr.view().slice(1, self.current_col);
312        self.current_col += 1;
313        Some(slice)
314    }
315}
316
317impl<
318        'a,
319        Item: RlstBase,
320        ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item>
321            + UnsafeRandomAccessMut<2, Item = Item>
322            + Shape<2>,
323    > std::iter::Iterator for ColIteratorMut<'a, Item, ArrayImpl, 2>
324{
325    type Item = Array<Item, ArraySlice<Item, views::ArrayViewMut<'a, Item, ArrayImpl, 2>, 2, 1>, 1>;
326    fn next(&mut self) -> Option<Self::Item> {
327        if self.current_col >= self.ncols {
328            return None;
329        }
330        let slice = self.arr.view_mut().slice(1, self.current_col);
331        self.current_col += 1;
332        unsafe {
333            Some(std::mem::transmute::<
334                Array<Item, ArraySlice<Item, views::ArrayViewMut<'_, Item, ArrayImpl, 2>, 2, 1>, 1>,
335                Array<Item, ArraySlice<Item, views::ArrayViewMut<'a, Item, ArrayImpl, 2>, 2, 1>, 1>,
336            >(slice))
337        }
338    }
339}
340
341impl<Item: RlstBase, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>>
342    Array<Item, ArrayImpl, 2>
343{
344    /// Return a column iterator for a two-dimensional array.
345    pub fn col_iter(&self) -> ColIterator<'_, Item, ArrayImpl, 2> {
346        ColIterator {
347            arr: self,
348            ncols: self.shape()[1],
349            current_col: 0,
350        }
351    }
352}
353
354impl<
355        Item: RlstBase,
356        ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item>
357            + Shape<2>
358            + UnsafeRandomAccessMut<2, Item = Item>,
359    > Array<Item, ArrayImpl, 2>
360{
361    /// Return a mutable column iterator for a two-dimensional array.
362    pub fn col_iter_mut(&mut self) -> ColIteratorMut<'_, Item, ArrayImpl, 2> {
363        let ncols = self.shape()[1];
364        ColIteratorMut {
365            arr: self,
366            ncols,
367            current_col: 0,
368        }
369    }
370}