1use crate::prelude_dev::*;
2use core::mem::transmute;
3
4pub struct IterVecView<'a, T, D>
7where
8 D: DimDevAPI,
9{
10 layout_iter: IterLayout<D>,
11 view: &'a [T],
12}
13
14impl<'a, T, D> Iterator for IterVecView<'a, T, D>
15where
16 D: DimDevAPI,
17{
18 type Item = &'a T;
19
20 fn next(&mut self) -> Option<Self::Item> {
21 self.layout_iter.next().map(|offset| &self.view[offset])
22 }
23}
24
25impl<T, D> DoubleEndedIterator for IterVecView<'_, T, D>
26where
27 D: DimDevAPI,
28{
29 fn next_back(&mut self) -> Option<Self::Item> {
30 self.layout_iter.next_back().map(|offset| &self.view[offset])
31 }
32}
33
34impl<T, D> ExactSizeIterator for IterVecView<'_, T, D>
35where
36 D: DimDevAPI,
37{
38 fn len(&self) -> usize {
39 self.layout_iter.len()
40 }
41}
42
43impl<T, D> IterSplitAtAPI for IterVecView<'_, T, D>
44where
45 D: DimDevAPI,
46{
47 fn split_at(self, mid: usize) -> (Self, Self) {
48 let (lhs, rhs) = self.layout_iter.split_at(mid);
49 let lhs = IterVecView { layout_iter: lhs, view: self.view };
50 let rhs = IterVecView { layout_iter: rhs, view: self.view };
51 (lhs, rhs)
52 }
53}
54
55impl<'a, R, T, B, D> TensorAny<R, T, B, D>
56where
57 R: DataAPI<Data = B::Raw>,
58 D: DimAPI,
59 B: DeviceAPI<T, Raw = Vec<T>> + 'a,
60{
61 pub fn iter_with_order_f(&self, order: TensorIterOrder) -> Result<IterVecView<'a, T, D>> {
62 let layout_iter = IterLayout::new(self.layout(), order)?;
63 let raw = self.raw().as_ref();
64
65 let iter = IterVecView { layout_iter, view: raw };
68 Ok(unsafe { transmute::<IterVecView<'_, T, D>, IterVecView<'_, T, D>>(iter) })
69 }
70
71 pub fn iter_with_order(&self, order: TensorIterOrder) -> IterVecView<'a, T, D> {
72 self.iter_with_order_f(order).rstsr_unwrap()
73 }
74
75 pub fn iter_f(&self) -> Result<IterVecView<'a, T, D>> {
76 let default_order = self.device().default_order();
77 let order = match default_order {
78 RowMajor => TensorIterOrder::C,
79 ColMajor => TensorIterOrder::F,
80 };
81 self.iter_with_order_f(order)
82 }
83
84 pub fn iter(&self) -> IterVecView<'a, T, D> {
85 self.iter_f().rstsr_unwrap()
86 }
87}
88
89pub struct IterVecMut<'a, T, D>
94where
95 D: DimDevAPI,
96{
97 layout_iter: IterLayout<D>,
98 view: &'a mut [T],
99}
100
101impl<'a, T, D> Iterator for IterVecMut<'a, T, D>
102where
103 D: DimDevAPI,
104{
105 type Item = &'a mut T;
106
107 fn next(&mut self) -> Option<Self::Item> {
108 self.layout_iter.next().map(|offset| unsafe { transmute(&mut self.view[offset]) })
109 }
110}
111
112impl<T, D> DoubleEndedIterator for IterVecMut<'_, T, D>
113where
114 D: DimDevAPI,
115{
116 fn next_back(&mut self) -> Option<Self::Item> {
117 self.layout_iter.next_back().map(|offset| unsafe { transmute(&mut self.view[offset]) })
118 }
119}
120
121impl<T, D> ExactSizeIterator for IterVecMut<'_, T, D>
122where
123 D: DimDevAPI,
124{
125 fn len(&self) -> usize {
126 self.layout_iter.len()
127 }
128}
129
130impl<T, D> IterSplitAtAPI for IterVecMut<'_, T, D>
131where
132 D: DimDevAPI,
133{
134 fn split_at(self, mid: usize) -> (Self, Self) {
135 let (lhs, rhs) = self.layout_iter.split_at(mid);
138 let cloned_view = unsafe {
139 let len = self.view.len();
140 let ptr = self.view.as_mut_ptr();
141 core::slice::from_raw_parts_mut(ptr, len)
142 };
143 let lhs = IterVecMut { layout_iter: lhs, view: cloned_view };
144 let rhs = IterVecMut { layout_iter: rhs, view: self.view };
145 (lhs, rhs)
146 }
147}
148
149impl<'a, R, T, B, D> TensorAny<R, T, B, D>
150where
151 R: DataMutAPI<Data = B::Raw>,
152 D: DimAPI,
153 B: DeviceAPI<T, Raw = Vec<T>> + 'a,
154{
155 pub fn iter_mut_with_order_f(&'a mut self, order: TensorIterOrder) -> Result<IterVecMut<'a, T, D>> {
156 let layout_iter = IterLayout::new(self.layout(), order)?;
157 let raw = self.raw_mut().as_mut();
158 let iter = IterVecMut { layout_iter, view: raw };
159 Ok(iter)
160 }
161
162 pub fn iter_mut_with_order(&'a mut self, order: TensorIterOrder) -> IterVecMut<'a, T, D> {
163 self.iter_mut_with_order_f(order).rstsr_unwrap()
164 }
165
166 pub fn iter_mut_f(&'a mut self) -> Result<IterVecMut<'a, T, D>> {
167 let default_order = self.device().default_order();
168 let order = match default_order {
169 RowMajor => TensorIterOrder::C,
170 ColMajor => TensorIterOrder::F,
171 };
172 self.iter_mut_with_order_f(order)
173 }
174
175 pub fn iter_mut(&'a mut self) -> IterVecMut<'a, T, D> {
176 self.iter_mut_f().rstsr_unwrap()
177 }
178}
179
180pub struct IndexedIterVecView<'a, T, D>
185where
186 D: DimDevAPI,
187{
188 layout_iter: IterLayout<D>,
189 view: &'a [T],
190}
191
192impl<'a, T, D> Iterator for IndexedIterVecView<'a, T, D>
193where
194 D: DimDevAPI,
195{
196 type Item = (D, &'a T);
197
198 fn next(&mut self) -> Option<Self::Item> {
199 let index = match &self.layout_iter {
200 IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
201 IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
202 };
203 self.layout_iter.next().map(|offset| (index, &self.view[offset]))
204 }
205}
206
207impl<T, D> DoubleEndedIterator for IndexedIterVecView<'_, T, D>
208where
209 D: DimDevAPI,
210{
211 fn next_back(&mut self) -> Option<Self::Item> {
212 let index = match &self.layout_iter {
213 IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
214 IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
215 };
216 self.layout_iter.next_back().map(|offset| (index, &self.view[offset]))
217 }
218}
219
220impl<T, D> ExactSizeIterator for IndexedIterVecView<'_, T, D>
221where
222 D: DimDevAPI,
223{
224 fn len(&self) -> usize {
225 self.layout_iter.len()
226 }
227}
228
229impl<T, D> IterSplitAtAPI for IndexedIterVecView<'_, T, D>
230where
231 D: DimDevAPI,
232{
233 fn split_at(self, mid: usize) -> (Self, Self) {
234 let (lhs, rhs) = self.layout_iter.split_at(mid);
235 let lhs = IndexedIterVecView { layout_iter: lhs, view: self.view };
236 let rhs = IndexedIterVecView { layout_iter: rhs, view: self.view };
237 (lhs, rhs)
238 }
239}
240
241impl<'a, R, T, B, D> TensorAny<R, T, B, D>
242where
243 R: DataAPI<Data = B::Raw>,
244 D: DimAPI,
245 B: DeviceAPI<T, Raw = Vec<T>> + 'a,
246{
247 pub fn indexed_iter_with_order_f(&self, order: TensorIterOrder) -> Result<IndexedIterVecView<'a, T, D>> {
248 use TensorIterOrder::*;
249 match order {
251 C | F => (),
252 _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.",)?,
253 };
254 let layout_iter = IterLayout::<D>::new(self.layout(), order)?;
255 let raw = self.raw().as_ref();
256
257 let iter = IndexedIterVecView { layout_iter, view: raw };
260 Ok(unsafe { transmute::<IndexedIterVecView<'_, T, D>, IndexedIterVecView<'_, T, D>>(iter) })
261 }
262
263 pub fn indexed_iter_with_order(&self, order: TensorIterOrder) -> IndexedIterVecView<'a, T, D> {
264 self.indexed_iter_with_order_f(order).rstsr_unwrap()
265 }
266
267 pub fn indexed_iter_f(&self) -> Result<IndexedIterVecView<'a, T, D>> {
268 let default_order = self.device().default_order();
269 let order = match default_order {
270 RowMajor => TensorIterOrder::C,
271 ColMajor => TensorIterOrder::F,
272 };
273 self.indexed_iter_with_order_f(order)
274 }
275
276 pub fn indexed_iter(&self) -> IndexedIterVecView<'a, T, D> {
277 self.indexed_iter_f().rstsr_unwrap()
278 }
279}
280
281pub struct IndexedIterVecMut<'a, T, D>
285where
286 D: DimDevAPI,
287{
288 layout_iter: IterLayout<D>,
289 view: &'a mut [T],
290}
291
292impl<'a, T, D> Iterator for IndexedIterVecMut<'a, T, D>
293where
294 D: DimDevAPI,
295{
296 type Item = (D, &'a mut T);
297
298 fn next(&mut self) -> Option<Self::Item> {
299 let index = match &self.layout_iter {
300 IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
301 IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
302 };
303 self.layout_iter.next().map(|offset| (index, unsafe { transmute::<&mut T, &mut T>(&mut self.view[offset]) }))
304 }
305}
306
307impl<T, D> DoubleEndedIterator for IndexedIterVecMut<'_, T, D>
308where
309 D: DimDevAPI,
310{
311 fn next_back(&mut self) -> Option<Self::Item> {
312 let index = match &self.layout_iter {
313 IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
314 IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
315 };
316 self.layout_iter
317 .next_back()
318 .map(|offset| (index, unsafe { transmute::<&mut T, &mut T>(&mut self.view[offset]) }))
319 }
320}
321
impl<T, D> ExactSizeIterator for IndexedIterVecMut<'_, T, D>
where
    D: DimDevAPI,
{
    /// Remaining element count, as reported by the layout iterator.
    fn len(&self) -> usize {
        self.layout_iter.len()
    }
}

impl<T, D> IterSplitAtAPI for IndexedIterVecMut<'_, T, D>
where
    D: DimDevAPI,
{
    /// Splits the layout iterator at `mid` (via its own `split_at`); each half receives a
    /// mutable slice over the *whole* backing storage.
    fn split_at(self, mid: usize) -> (Self, Self) {
        let (lhs, rhs) = self.layout_iter.split_at(mid);
        // SAFETY (review): this materialises a second `&mut [T]` over the same memory as
        // `self.view`. It relies on the two layout halves never yielding a common offset
        // (assumed, not checked here). Overlapping `&mut` slices remain questionable under
        // Rust's aliasing model; storing a raw pointer plus `PhantomData<&'a mut [T]>`
        // would avoid creating them.
        let cloned_view = unsafe {
            let len = self.view.len();
            let ptr = self.view.as_mut_ptr();
            core::slice::from_raw_parts_mut(ptr, len)
        };
        let lhs = IndexedIterVecMut { layout_iter: lhs, view: cloned_view };
        let rhs = IndexedIterVecMut { layout_iter: rhs, view: self.view };
        (lhs, rhs)
    }
}
347
348impl<'a, R, T, B, D> TensorAny<R, T, B, D>
349where
350 R: DataMutAPI<Data = B::Raw>,
351 D: DimAPI,
352 B: DeviceAPI<T, Raw = Vec<T>> + 'a,
353{
354 pub fn indexed_iter_mut_with_order_f(&'a mut self, order: TensorIterOrder) -> Result<IndexedIterVecMut<'a, T, D>> {
355 use TensorIterOrder::*;
356 match order {
358 C | F => (),
359 _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.",)?,
360 };
361 let layout_iter = IterLayout::<D>::new(self.layout(), order)?;
362 let raw = self.raw_mut().as_mut();
363
364 let iter = IndexedIterVecMut { layout_iter, view: raw };
365 Ok(iter)
366 }
367
368 pub fn indexed_iter_mut_with_order(&'a mut self, order: TensorIterOrder) -> IndexedIterVecMut<'a, T, D> {
369 self.indexed_iter_mut_with_order_f(order).rstsr_unwrap()
370 }
371
372 pub fn indexed_iter_mut_f(&'a mut self) -> Result<IndexedIterVecMut<'a, T, D>> {
373 let default_order = self.device().default_order();
374 let order = match default_order {
375 RowMajor => TensorIterOrder::C,
376 ColMajor => TensorIterOrder::F,
377 };
378 self.indexed_iter_mut_with_order_f(order)
379 }
380
381 pub fn indexed_iter_mut(&'a mut self) -> IndexedIterVecMut<'a, T, D> {
382 self.indexed_iter_mut_f().rstsr_unwrap()
383 }
384}
385
#[cfg(test)]
mod tests_serial {
    use super::*;

    #[test]
    fn test_iter() {
        let a = arange(6).into_shape([3, 2]);
        let collected: Vec<_> = a.iter().collect();
        assert_eq!(collected, vec![&0, &1, &2, &3, &4, &5]);

        // Iterating the transposed view follows the device's default order.
        let collected_t: Vec<_> = a.t().iter().collect();
        #[cfg(not(feature = "col_major"))]
        assert_eq!(collected_t, vec![&0, &2, &4, &1, &3, &5]);
        #[cfg(feature = "col_major")]
        assert_eq!(collected_t, vec![&0, &3, &1, &4, &2, &5]);
    }

    #[test]
    fn test_mut_iter() {
        let mut a = arange(6usize).into_shape([3, 2]);
        a.iter_mut().for_each(|x| *x = 0);
        let flat = a.reshape(-1).to_vec();
        assert_eq!(flat, vec![0; 6]);
    }

    #[test]
    fn test_indexed_c_iter() {
        let a = arange(6).into_layout([3, 2].c());
        let pairs: Vec<_> = a.indexed_iter_with_order(TensorIterOrder::C).collect();
        #[cfg(not(feature = "col_major"))]
        assert_eq!(pairs, vec![([0, 0], &0), ([0, 1], &1), ([1, 0], &2), ([1, 1], &3), ([2, 0], &4), ([2, 1], &5)]);
        #[cfg(feature = "col_major")]
        assert_eq!(pairs, vec![([0, 0], &0), ([0, 1], &3), ([1, 0], &1), ([1, 1], &4), ([2, 0], &2), ([2, 1], &5)]);

        let pairs_t: Vec<_> = a.t().indexed_iter_with_order(TensorIterOrder::C).collect();
        #[cfg(not(feature = "col_major"))]
        assert_eq!(pairs_t, vec![([0, 0], &0), ([0, 1], &2), ([0, 2], &4), ([1, 0], &1), ([1, 1], &3), ([1, 2], &5)]);
        #[cfg(feature = "col_major")]
        assert_eq!(pairs_t, vec![([0, 0], &0), ([0, 1], &1), ([0, 2], &2), ([1, 0], &3), ([1, 1], &4), ([1, 2], &5)]);
    }
}
442
#[cfg(test)]
#[cfg(feature = "rayon")]
mod tests_parallel {
    use super::*;
    use rayon::prelude::*;

    #[test]
    fn test_iter() {
        let a = arange(16384).into_shape([128, 128]);
        let collected: Vec<_> = a.iter().into_par_iter().collect();
        assert_eq!(collected[..6], vec![&0, &1, &2, &3, &4, &5]);

        let collected_t: Vec<_> = a.t().iter().into_par_iter().collect();
        assert_eq!(collected_t[..6], vec![&0, &128, &256, &384, &512, &640]);
    }

    #[test]
    fn test_mut_iter() {
        let mut a = arange(16384).into_shape([128, 128]);
        let expected = &a + 1;

        a.iter_mut().into_par_iter().for_each(|x| *x += 1);

        assert_eq!(a.reshape(-1).to_vec(), expected.reshape(-1).to_vec());
    }
}