rstsr_core/tensor/
iterator_elem.rs

1use crate::prelude_dev::*;
2use core::mem::transmute;
3
4/* #region elem view iterator */
5
/// Element-wise iterator over an immutable tensor data buffer.
///
/// Walks the memory offsets produced by `layout_iter` and resolves each
/// offset into a reference inside the borrowed buffer `view`.
pub struct IterVecView<'a, T, D>
where
    D: DimDevAPI,
{
    /// Produces memory offsets in the requested iteration order.
    layout_iter: IterLayout<D>,
    /// Borrowed underlying data buffer of the tensor.
    view: &'a [T],
}
13
14impl<'a, T, D> Iterator for IterVecView<'a, T, D>
15where
16    D: DimDevAPI,
17{
18    type Item = &'a T;
19
20    fn next(&mut self) -> Option<Self::Item> {
21        self.layout_iter.next().map(|offset| &self.view[offset])
22    }
23}
24
25impl<T, D> DoubleEndedIterator for IterVecView<'_, T, D>
26where
27    D: DimDevAPI,
28{
29    fn next_back(&mut self) -> Option<Self::Item> {
30        self.layout_iter.next_back().map(|offset| &self.view[offset])
31    }
32}
33
impl<T, D> ExactSizeIterator for IterVecView<'_, T, D>
where
    D: DimDevAPI,
{
    /// Number of elements remaining, delegated to the layout iterator.
    fn len(&self) -> usize {
        self.layout_iter.len()
    }
}
42
43impl<T, D> IterSplitAtAPI for IterVecView<'_, T, D>
44where
45    D: DimDevAPI,
46{
47    fn split_at(self, mid: usize) -> (Self, Self) {
48        let (lhs, rhs) = self.layout_iter.split_at(mid);
49        let lhs = IterVecView { layout_iter: lhs, view: self.view };
50        let rhs = IterVecView { layout_iter: rhs, view: self.view };
51        (lhs, rhs)
52    }
53}
54
impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
    /// Element iterator over this tensor in the given iteration `order`
    /// (fallible variant).
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the layout iterator for
    /// `order` on this tensor's layout.
    pub fn iter_with_order_f(&self, order: TensorIterOrder) -> Result<IterVecView<'a, T, D>> {
        let layout_iter = IterLayout::new(self.layout(), order)?;
        let raw = self.raw().as_ref();

        // SAFETY: both transmute type arguments are `IterVecView<'_, T, D>`,
        // so this only rewrites the elided lifetime — from the local `&self`
        // borrow to `'a` — with no type conversion.
        // NOTE(review): this detaches the iterator from the `&self` borrow;
        // soundness relies on the underlying `Vec<T>` buffer outliving `'a`
        // (the `B: 'a` bound) — confirm no caller can drop or reallocate the
        // raw buffer while the iterator is alive.
        let iter = IterVecView { layout_iter, view: raw };
        Ok(unsafe { transmute::<IterVecView<'_, T, D>, IterVecView<'_, T, D>>(iter) })
    }

    /// Panicking counterpart of [`iter_with_order_f`](Self::iter_with_order_f).
    pub fn iter_with_order(&self, order: TensorIterOrder) -> IterVecView<'a, T, D> {
        self.iter_with_order_f(order).rstsr_unwrap()
    }

    /// Element iterator in the device's default order (fallible variant):
    /// C order for row-major devices, F order for column-major devices.
    pub fn iter_f(&self) -> Result<IterVecView<'a, T, D>> {
        let default_order = self.device().default_order();
        let order = match default_order {
            RowMajor => TensorIterOrder::C,
            ColMajor => TensorIterOrder::F,
        };
        self.iter_with_order_f(order)
    }

    /// Panicking counterpart of [`iter_f`](Self::iter_f).
    pub fn iter(&self) -> IterVecView<'a, T, D> {
        self.iter_f().rstsr_unwrap()
    }
}
88
89/* #endregion */
90
91/* #region elem mut iterator */
92
/// Element-wise mutable iterator over a tensor data buffer.
///
/// Walks the memory offsets produced by `layout_iter` and resolves each
/// offset into a mutable reference inside the borrowed buffer `view`.
pub struct IterVecMut<'a, T, D>
where
    D: DimDevAPI,
{
    /// Produces memory offsets in the requested iteration order.
    layout_iter: IterLayout<D>,
    /// Exclusively borrowed underlying data buffer of the tensor.
    view: &'a mut [T],
}
100
101impl<'a, T, D> Iterator for IterVecMut<'a, T, D>
102where
103    D: DimDevAPI,
104{
105    type Item = &'a mut T;
106
107    fn next(&mut self) -> Option<Self::Item> {
108        self.layout_iter.next().map(|offset| unsafe { transmute(&mut self.view[offset]) })
109    }
110}
111
112impl<T, D> DoubleEndedIterator for IterVecMut<'_, T, D>
113where
114    D: DimDevAPI,
115{
116    fn next_back(&mut self) -> Option<Self::Item> {
117        self.layout_iter.next_back().map(|offset| unsafe { transmute(&mut self.view[offset]) })
118    }
119}
120
impl<T, D> ExactSizeIterator for IterVecMut<'_, T, D>
where
    D: DimDevAPI,
{
    /// Number of elements remaining, delegated to the layout iterator.
    fn len(&self) -> usize {
        self.layout_iter.len()
    }
}
129
impl<T, D> IterSplitAtAPI for IterVecMut<'_, T, D>
where
    D: DimDevAPI,
{
    /// Splits the iterator at `mid` (e.g. for parallel consumption).
    fn split_at(self, mid: usize) -> (Self, Self) {
        // we do not split &mut [T], but split the layout iterator
        // so we use unsafe code to generate two same &mut [T] views
        let (lhs, rhs) = self.layout_iter.split_at(mid);
        // SAFETY: both halves receive a mutable view over the *whole*
        // buffer. This is sound only if the two layout iterators produce
        // disjoint offset sets, so the aliasing `&mut [T]` slices never
        // touch the same element — TODO confirm `split_at` guarantees
        // disjointness. NOTE(review): duplicated `&mut` slices are a
        // gray area under stacked borrows; worth checking with miri.
        let cloned_view = unsafe {
            let len = self.view.len();
            let ptr = self.view.as_mut_ptr();
            core::slice::from_raw_parts_mut(ptr, len)
        };
        let lhs = IterVecMut { layout_iter: lhs, view: cloned_view };
        let rhs = IterVecMut { layout_iter: rhs, view: self.view };
        (lhs, rhs)
    }
}
148
impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataMutAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
    /// Mutable element iterator in the given `order` (fallible variant).
    ///
    /// Takes `&'a mut self`, so the tensor stays exclusively borrowed for
    /// the lifetime of the returned iterator; unlike the shared-view
    /// counterpart, no lifetime transmute is needed here.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the layout iterator.
    pub fn iter_mut_with_order_f(&'a mut self, order: TensorIterOrder) -> Result<IterVecMut<'a, T, D>> {
        let layout_iter = IterLayout::new(self.layout(), order)?;
        let raw = self.raw_mut().as_mut();
        let iter = IterVecMut { layout_iter, view: raw };
        Ok(iter)
    }

    /// Panicking counterpart of
    /// [`iter_mut_with_order_f`](Self::iter_mut_with_order_f).
    pub fn iter_mut_with_order(&'a mut self, order: TensorIterOrder) -> IterVecMut<'a, T, D> {
        self.iter_mut_with_order_f(order).rstsr_unwrap()
    }

    /// Mutable element iterator in the device's default order (fallible
    /// variant): C order for row-major devices, F order for column-major.
    pub fn iter_mut_f(&'a mut self) -> Result<IterVecMut<'a, T, D>> {
        let default_order = self.device().default_order();
        let order = match default_order {
            RowMajor => TensorIterOrder::C,
            ColMajor => TensorIterOrder::F,
        };
        self.iter_mut_with_order_f(order)
    }

    /// Panicking counterpart of [`iter_mut_f`](Self::iter_mut_f).
    pub fn iter_mut(&'a mut self) -> IterVecMut<'a, T, D> {
        self.iter_mut_f().rstsr_unwrap()
    }
}
179
180/* #endregion */
181
182/* #region elem view indexed iterator */
183
/// Indexed element-wise iterator over an immutable tensor data buffer.
///
/// Yields `(index, &elem)` pairs, where `index` is the multi-dimensional
/// index corresponding to the offset produced by `layout_iter`.
pub struct IndexedIterVecView<'a, T, D>
where
    D: DimDevAPI,
{
    /// Produces memory offsets in the requested iteration order.
    layout_iter: IterLayout<D>,
    /// Borrowed underlying data buffer of the tensor.
    view: &'a [T],
}
191
192impl<'a, T, D> Iterator for IndexedIterVecView<'a, T, D>
193where
194    D: DimDevAPI,
195{
196    type Item = (D, &'a T);
197
198    fn next(&mut self) -> Option<Self::Item> {
199        let index = match &self.layout_iter {
200            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
201            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
202        };
203        self.layout_iter.next().map(|offset| (index, &self.view[offset]))
204    }
205}
206
impl<T, D> DoubleEndedIterator for IndexedIterVecView<'_, T, D>
where
    D: DimDevAPI,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        // Capture the current multi-dimensional index before the layout
        // iterator is advanced.
        // NOTE(review): this reads `index_start()` — the same accessor the
        // forward `next` uses — while consuming from the *back*; confirm it
        // really yields the index of the back element and not the front.
        let index = match &self.layout_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        self.layout_iter.next_back().map(|offset| (index, &self.view[offset]))
    }
}
219
impl<T, D> ExactSizeIterator for IndexedIterVecView<'_, T, D>
where
    D: DimDevAPI,
{
    /// Number of elements remaining, delegated to the layout iterator.
    fn len(&self) -> usize {
        self.layout_iter.len()
    }
}
228
229impl<T, D> IterSplitAtAPI for IndexedIterVecView<'_, T, D>
230where
231    D: DimDevAPI,
232{
233    fn split_at(self, mid: usize) -> (Self, Self) {
234        let (lhs, rhs) = self.layout_iter.split_at(mid);
235        let lhs = IndexedIterVecView { layout_iter: lhs, view: self.view };
236        let rhs = IndexedIterVecView { layout_iter: rhs, view: self.view };
237        (lhs, rhs)
238    }
239}
240
impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
    /// Indexed element iterator `(index, &elem)` in the given `order`
    /// (fallible variant).
    ///
    /// # Errors
    ///
    /// Returns an error when `order` is not `TensorIterOrder::C` or
    /// `TensorIterOrder::F`, or when the layout iterator cannot be built.
    pub fn indexed_iter_with_order_f(&self, order: TensorIterOrder) -> Result<IndexedIterVecView<'a, T, D>> {
        use TensorIterOrder::*;
        // this function only accepts c/f iter order currently
        match order {
            C | F => (),
            _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.",)?,
        };
        let layout_iter = IterLayout::<D>::new(self.layout(), order)?;
        let raw = self.raw().as_ref();

        // SAFETY: both transmute type arguments are
        // `IndexedIterVecView<'_, T, D>`, so this only rewrites the elided
        // lifetime — from the local `&self` borrow to `'a` — no type cast.
        // NOTE(review): soundness relies on the underlying buffer outliving
        // `'a` (`B: 'a` bound); confirm as for the non-indexed variant.
        let iter = IndexedIterVecView { layout_iter, view: raw };
        Ok(unsafe { transmute::<IndexedIterVecView<'_, T, D>, IndexedIterVecView<'_, T, D>>(iter) })
    }

    /// Panicking counterpart of
    /// [`indexed_iter_with_order_f`](Self::indexed_iter_with_order_f).
    pub fn indexed_iter_with_order(&self, order: TensorIterOrder) -> IndexedIterVecView<'a, T, D> {
        self.indexed_iter_with_order_f(order).rstsr_unwrap()
    }

    /// Indexed iterator in the device's default order (fallible variant):
    /// C order for row-major devices, F order for column-major devices.
    pub fn indexed_iter_f(&self) -> Result<IndexedIterVecView<'a, T, D>> {
        let default_order = self.device().default_order();
        let order = match default_order {
            RowMajor => TensorIterOrder::C,
            ColMajor => TensorIterOrder::F,
        };
        self.indexed_iter_with_order_f(order)
    }

    /// Panicking counterpart of [`indexed_iter_f`](Self::indexed_iter_f).
    pub fn indexed_iter(&self) -> IndexedIterVecView<'a, T, D> {
        self.indexed_iter_f().rstsr_unwrap()
    }
}
280
281/* #endregion */
282
/* #region elem indexed mut iterator */
/// Indexed element-wise mutable iterator over a tensor data buffer.
///
/// Yields `(index, &mut elem)` pairs, where `index` is the
/// multi-dimensional index corresponding to the offset from `layout_iter`.
pub struct IndexedIterVecMut<'a, T, D>
where
    D: DimDevAPI,
{
    /// Produces memory offsets in the requested iteration order.
    layout_iter: IterLayout<D>,
    /// Exclusively borrowed underlying data buffer of the tensor.
    view: &'a mut [T],
}
291
292impl<'a, T, D> Iterator for IndexedIterVecMut<'a, T, D>
293where
294    D: DimDevAPI,
295{
296    type Item = (D, &'a mut T);
297
298    fn next(&mut self) -> Option<Self::Item> {
299        let index = match &self.layout_iter {
300            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
301            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
302        };
303        self.layout_iter.next().map(|offset| (index, unsafe { transmute::<&mut T, &mut T>(&mut self.view[offset]) }))
304    }
305}
306
impl<T, D> DoubleEndedIterator for IndexedIterVecMut<'_, T, D>
where
    D: DimDevAPI,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        // Capture the current multi-dimensional index before the layout
        // iterator is advanced.
        // NOTE(review): this reads `index_start()` — the same accessor the
        // forward `next` uses — while consuming from the *back*; confirm it
        // really yields the index of the back element and not the front.
        let index = match &self.layout_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        // SAFETY: the transmute only extends the element borrow to the item
        // lifetime; sound provided offsets are never repeated.
        self.layout_iter
            .next_back()
            .map(|offset| (index, unsafe { transmute::<&mut T, &mut T>(&mut self.view[offset]) }))
    }
}
321
impl<T, D> ExactSizeIterator for IndexedIterVecMut<'_, T, D>
where
    D: DimDevAPI,
{
    /// Number of elements remaining, delegated to the layout iterator.
    fn len(&self) -> usize {
        self.layout_iter.len()
    }
}
330
impl<T, D> IterSplitAtAPI for IndexedIterVecMut<'_, T, D>
where
    D: DimDevAPI,
{
    /// Splits the iterator at `mid` (e.g. for parallel consumption).
    fn split_at(self, mid: usize) -> (Self, Self) {
        let (lhs, rhs) = self.layout_iter.split_at(mid);
        // SAFETY: both halves receive a mutable view over the *whole*
        // buffer. This is sound only if the two layout iterators produce
        // disjoint offset sets, so the aliasing `&mut [T]` slices never
        // touch the same element — TODO confirm `split_at` guarantees
        // disjointness. NOTE(review): duplicated `&mut` slices are a
        // gray area under stacked borrows; worth checking with miri.
        let cloned_view = unsafe {
            let len = self.view.len();
            let ptr = self.view.as_mut_ptr();
            core::slice::from_raw_parts_mut(ptr, len)
        };
        let lhs = IndexedIterVecMut { layout_iter: lhs, view: cloned_view };
        let rhs = IndexedIterVecMut { layout_iter: rhs, view: self.view };
        (lhs, rhs)
    }
}
347
impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataMutAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
    /// Indexed mutable element iterator `(index, &mut elem)` in the given
    /// `order` (fallible variant).
    ///
    /// Takes `&'a mut self`, so the tensor stays exclusively borrowed for
    /// the lifetime of the returned iterator; no lifetime transmute needed.
    ///
    /// # Errors
    ///
    /// Returns an error when `order` is not `TensorIterOrder::C` or
    /// `TensorIterOrder::F`, or when the layout iterator cannot be built.
    pub fn indexed_iter_mut_with_order_f(&'a mut self, order: TensorIterOrder) -> Result<IndexedIterVecMut<'a, T, D>> {
        use TensorIterOrder::*;
        // this function only accepts c/f iter order currently
        match order {
            C | F => (),
            _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.",)?,
        };
        let layout_iter = IterLayout::<D>::new(self.layout(), order)?;
        let raw = self.raw_mut().as_mut();

        let iter = IndexedIterVecMut { layout_iter, view: raw };
        Ok(iter)
    }

    /// Panicking counterpart of
    /// [`indexed_iter_mut_with_order_f`](Self::indexed_iter_mut_with_order_f).
    pub fn indexed_iter_mut_with_order(&'a mut self, order: TensorIterOrder) -> IndexedIterVecMut<'a, T, D> {
        self.indexed_iter_mut_with_order_f(order).rstsr_unwrap()
    }

    /// Indexed mutable iterator in the device's default order (fallible
    /// variant): C order for row-major devices, F for column-major.
    pub fn indexed_iter_mut_f(&'a mut self) -> Result<IndexedIterVecMut<'a, T, D>> {
        let default_order = self.device().default_order();
        let order = match default_order {
            RowMajor => TensorIterOrder::C,
            ColMajor => TensorIterOrder::F,
        };
        self.indexed_iter_mut_with_order_f(order)
    }

    /// Panicking counterpart of
    /// [`indexed_iter_mut_f`](Self::indexed_iter_mut_f).
    pub fn indexed_iter_mut(&'a mut self) -> IndexedIterVecMut<'a, T, D> {
        self.indexed_iter_mut_f().rstsr_unwrap()
    }
}
385
386/* #endregion */
387
#[cfg(test)]
mod tests_serial {
    use super::*;

    /// `iter` traverses a contiguous tensor in order; a transposed view is
    /// traversed through the same buffer with swapped strides.
    #[test]
    fn test_iter() {
        let a = arange(6).into_shape([3, 2]);
        let iter = a.iter();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec, vec![&0, &1, &2, &3, &4, &5]);

        let iter_t = a.t().iter();
        let vec_t = iter_t.collect::<Vec<_>>();
        #[cfg(not(feature = "col_major"))]
        {
            // NumPy equivalent:
            // a = np.arange(6).reshape(3, 2)
            // a.T.reshape(-1)
            assert_eq!(vec_t, vec![&0, &2, &4, &1, &3, &5]);
        }
        #[cfg(feature = "col_major")]
        {
            // Julia equivalent:
            // a = reshape(range(0, 5), (3, 2));
            // reshape(a', 6)
            assert_eq!(vec_t, vec![&0, &3, &1, &4, &2, &5]);
        }
    }

    /// Writes through `iter_mut` must land in the tensor's buffer.
    #[test]
    fn test_mut_iter() {
        let mut a = arange(6usize).into_shape([3, 2]);
        let iter = a.iter_mut();
        iter.for_each(|x| *x = 0);
        let a = a.reshape(-1).to_vec();
        assert_eq!(a, vec![0, 0, 0, 0, 0, 0]);
    }

    /// `indexed_iter_with_order(C)` yields (index, value) pairs in C order;
    /// the values depend on the buffer's storage order (feature-gated).
    #[test]
    fn test_indexed_c_iter() {
        let a = arange(6).into_layout([3, 2].c());
        let iter = a.indexed_iter_with_order(TensorIterOrder::C);
        let vec = iter.collect::<Vec<_>>();
        #[cfg(not(feature = "col_major"))]
        assert_eq!(vec, vec![([0, 0], &0), ([0, 1], &1), ([1, 0], &2), ([1, 1], &3), ([2, 0], &4), ([2, 1], &5)]);
        #[cfg(feature = "col_major")]
        assert_eq!(vec, vec![([0, 0], &0), ([0, 1], &3), ([1, 0], &1), ([1, 1], &4), ([2, 0], &2), ([2, 1], &5)]);

        let iter_t = a.t().indexed_iter_with_order(TensorIterOrder::C);
        let vec_t = iter_t.collect::<Vec<_>>();
        #[cfg(not(feature = "col_major"))]
        assert_eq!(vec_t, vec![([0, 0], &0), ([0, 1], &2), ([0, 2], &4), ([1, 0], &1), ([1, 1], &3), ([1, 2], &5)]);
        #[cfg(feature = "col_major")]
        assert_eq!(vec_t, vec![([0, 0], &0), ([0, 1], &1), ([0, 2], &2), ([1, 0], &3), ([1, 1], &4), ([1, 2], &5)]);
    }
}
442
#[cfg(test)]
#[cfg(feature = "rayon")]
mod tests_parallel {
    use super::*;
    use rayon::prelude::*;

    /// The serial iterator converts into a rayon parallel iterator (via
    /// `IterSplitAtAPI::split_at`) and still yields the same elements.
    #[test]
    fn test_iter() {
        let a = arange(16384).into_shape([128, 128]);
        let iter = a.iter().into_par_iter();
        let vec = iter.collect::<Vec<_>>();
        assert_eq!(vec[..6], vec![&0, &1, &2, &3, &4, &5]);

        let iter_t = a.t().iter().into_par_iter();
        let vec_t = iter_t.collect::<Vec<_>>();
        // Only the first 6 collected elements are checked; for these the
        // expected values agree between row- and col-major defaults.
        // Checking further elements would expose order differences.
        assert_eq!(vec_t[..6], vec![&0, &128, &256, &384, &512, &640]);
    }

    /// Parallel mutation through `iter_mut` must be equivalent to the
    /// elementwise arithmetic result `a + 1`.
    #[test]
    fn test_mut_iter() {
        let mut a = arange(16384).into_shape([128, 128]);
        let b = &a + 1;

        let iter = a.iter_mut().into_par_iter();
        iter.for_each(|x| *x += 1);

        assert_eq!(a.reshape(-1).to_vec(), b.reshape(-1).to_vec());
    }
}