#![allow(clippy::missing_transmute_annotations)]

use crate::prelude_dev::*;
use core::mem::transmute;

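/// Iterator that yields [`TensorView`]s over a selected set of tensor axes.
///
/// `axes_iter` walks the offsets of the selected axes, while `view` carries
/// the layout of the remaining axes; only the view's offset changes per step.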
pub struct IterAxesView<'a, T, B>
where
    B: DeviceAPI<T>,
{
    axes_iter: IterLayout<IxD>,
    view: TensorView<'a, T, B, IxD>,
}

impl<T, B> IterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    pub fn update_offset(&mut self, offset: usize) {
        // Only the offset changes between iterations; shape/strides stay fixed.
        unsafe { self.view.layout.set_offset(offset) };
    }
}

impl<'a, T, B> Iterator for IterAxesView<'a, T, B>
where
    B: DeviceAPI<T>,
{
    type Item = TensorView<'a, T, B, IxD>;

    fn next(&mut self) -> Option<Self::Item> {
        self.axes_iter.next().map(|offset| {
            self.update_offset(offset);
            // SAFETY: extends the borrowed view's lifetime to 'a; the
            // underlying buffer is borrowed for 'a by construction.
            unsafe { transmute(self.view.view()) }
        })
    }
}

impl<T, B> DoubleEndedIterator for IterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.axes_iter.next_back().map(|offset| {
            self.update_offset(offset);
            unsafe { transmute(self.view.view()) }
        })
    }
}

impl<T, B> ExactSizeIterator for IterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn len(&self) -> usize {
        self.axes_iter.len()
    }
}

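// `split_at` is what lets these iterators be driven in parallel (see the
// rayon-based tests at the end of this file).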
impl<T, B> IterSplitAtAPI for IterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn split_at(self, index: usize) -> (Self, Self) {
        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.split_at(index);
        let view_lhs = unsafe { transmute(self.view.view()) };
        let lhs = IterAxesView { axes_iter: lhs_axes_iter, view: view_lhs };
        let rhs = IterAxesView { axes_iter: rhs_axes_iter, view: self.view };
        (lhs, rhs)
    }
}

impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    T: Clone,
    R: DataCloneAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
    pub fn axes_iter_with_order_f<I>(&self, axes: I, order: TensorIterOrder) -> Result<IterAxesView<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        // Normalize negative axes, then reject out-of-bounds and duplicate entries.
        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
        let axes: Vec<isize> =
            axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<Vec<isize>>();
        let mut axes_check = axes.clone();
        axes_check.sort();
        // `axes_check` is sorted, so its first element is the minimum.
        if axes_check.first().is_some_and(|&v| v < 0) {
            return Err(rstsr_error!(InvalidValue, "Negative axis index out of bounds."));
        }
        for w in axes_check.windows(2) {
            rstsr_assert!(w[0] != w[1], InvalidValue, "Duplicate axes are not allowed.")?;
        }

        let layout = self.layout().to_dim::<IxD>()?;
        let shape_full = layout.shape();
        let stride_full = layout.stride();
        let offset = layout.offset();

        // Layout of the axes to be iterated over.
        let mut shape_axes = vec![];
        let mut stride_axes = vec![];
        for &idx in &axes {
            shape_axes.push(shape_full[idx as usize]);
            stride_axes.push(stride_full[idx as usize]);
        }
        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };

        // Layout of the remaining axes, shared by every yielded view.
        let mut shape_inner = vec![];
        let mut stride_inner = vec![];
        for idx in 0..ndim {
            if !axes.contains(&idx) {
                shape_inner.push(shape_full[idx as usize]);
                stride_inner.push(stride_full[idx as usize]);
            }
        }
        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };

        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
        let mut view = self.view().into_dyn();
        view.layout = layout_inner;
        // SAFETY: extends the view's lifetime to 'a; see `Iterator::next` above.
        let iter = IterAxesView { axes_iter, view: unsafe { transmute(view) } };
        Ok(iter)
    }

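    // Worked example of the layout split in `axes_iter_with_order_f` (cf. the
    // serial tests below): for a tensor of shape [2, 3, 4, 5] and axes [0, 2],
    // `layout_axes` has shape [2, 4] and every yielded view has shape [3, 5];
    // iteration only moves the view's offset.
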
    pub fn axes_iter_f<I>(&self, axes: I) -> Result<IterAxesView<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.axes_iter_with_order_f(axes, TensorIterOrder::default())
    }

    pub fn axes_iter_with_order<I>(&self, axes: I, order: TensorIterOrder) -> IterAxesView<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.axes_iter_with_order_f(axes, order).rstsr_unwrap()
    }

    pub fn axes_iter<I>(&self, axes: I) -> IterAxesView<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.axes_iter_f(axes).rstsr_unwrap()
    }
}

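// A minimal usage sketch for `axes_iter` (mirrors `tests_serial::test_axes_iter`):
//
//     let a = arange(120).into_shape([2, 3, 4, 5]);
//     for view in a.axes_iter([0, 2]) {
//         let _ = view[[1, 2]]; // each `view` has shape [3, 5]
//     }

/// Iterator that yields mutable [`TensorMut`] views over a selected set of axes.
///
/// Same construction as [`IterAxesView`], with a mutable inner view.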
pub struct IterAxesMut<'a, T, B>
where
    B: DeviceAPI<T>,
{
    axes_iter: IterLayout<IxD>,
    view: TensorMut<'a, T, B, IxD>,
}

impl<T, B> IterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    pub fn update_offset(&mut self, offset: usize) {
        unsafe { self.view.layout.set_offset(offset) };
    }
}

impl<'a, T, B> Iterator for IterAxesMut<'a, T, B>
where
    B: DeviceAPI<T>,
{
    type Item = TensorMut<'a, T, B, IxD>;

    fn next(&mut self) -> Option<Self::Item> {
        self.axes_iter.next().map(|offset| {
            self.update_offset(offset);
            // SAFETY: extends the mutable borrow to 'a; every step addresses a
            // different offset, so the yielded views are expected not to alias.
            unsafe { transmute(self.view.view_mut()) }
        })
    }
}

impl<T, B> DoubleEndedIterator for IterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.axes_iter.next_back().map(|offset| {
            self.update_offset(offset);
            unsafe { transmute(self.view.view_mut()) }
        })
    }
}

impl<T, B> ExactSizeIterator for IterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn len(&self) -> usize {
        self.axes_iter.len()
    }
}

impl<T, B> IterSplitAtAPI for IterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn split_at(mut self, index: usize) -> (Self, Self) {
        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.clone().split_at(index);
        let view_lhs = unsafe { transmute(self.view.view_mut()) };
        let lhs = IterAxesMut { axes_iter: lhs_axes_iter, view: view_lhs };
        let rhs = IterAxesMut { axes_iter: rhs_axes_iter, view: self.view };
        (lhs, rhs)
    }
}

impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    T: Clone,
    R: DataMutAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
    pub fn axes_iter_mut_with_order_f<I>(&'a mut self, axes: I, order: TensorIterOrder) -> Result<IterAxesMut<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        // Same axis normalization and layout split as `axes_iter_with_order_f`.
        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
        let axes: Vec<isize> =
            axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<Vec<isize>>();
        let mut axes_check = axes.clone();
        axes_check.sort();
        if axes_check.first().is_some_and(|&v| v < 0) {
            return Err(rstsr_error!(InvalidValue, "Negative axis index out of bounds."));
        }
        for w in axes_check.windows(2) {
            rstsr_assert!(w[0] != w[1], InvalidValue, "Duplicate axes are not allowed.")?;
        }

        let layout = self.layout().to_dim::<IxD>()?;
        let shape_full = layout.shape();
        let stride_full = layout.stride();
        let offset = layout.offset();

        let mut shape_axes = vec![];
        let mut stride_axes = vec![];
        for &idx in &axes {
            shape_axes.push(shape_full[idx as usize]);
            stride_axes.push(stride_full[idx as usize]);
        }
        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };

        let mut shape_inner = vec![];
        let mut stride_inner = vec![];
        for idx in 0..ndim {
            if !axes.contains(&idx) {
                shape_inner.push(shape_full[idx as usize]);
                stride_inner.push(stride_full[idx as usize]);
            }
        }
        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };

        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
        let mut view = self.view_mut().into_dyn();
        view.layout = layout_inner;
        let iter = IterAxesMut { axes_iter, view };
        Ok(iter)
    }

    pub fn axes_iter_mut_f<I>(&'a mut self, axes: I) -> Result<IterAxesMut<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.axes_iter_mut_with_order_f(axes, TensorIterOrder::default())
    }

    pub fn axes_iter_mut_with_order<I>(&'a mut self, axes: I, order: TensorIterOrder) -> IterAxesMut<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.axes_iter_mut_with_order_f(axes, order).rstsr_unwrap()
    }

    pub fn axes_iter_mut<I>(&'a mut self, axes: I) -> IterAxesMut<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.axes_iter_mut_f(axes).rstsr_unwrap()
    }
}

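// A minimal usage sketch for `axes_iter_mut` (mirrors `tests_serial::test_axes_iter_mut`):
//
//     let mut a = arange(120).into_shape([2, 3, 4, 5]);
//     for mut view in a.axes_iter_mut([0, 2]) {
//         view += 1; // mutates `a` in place through the view
//     }

/// Iterator that yields `(index, view)` pairs, where `index` is the
/// multi-dimensional position within the iterated axes.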
pub struct IndexedIterAxesView<'a, T, B>
where
    B: DeviceAPI<T>,
{
    axes_iter: IterLayout<IxD>,
    view: TensorView<'a, T, B, IxD>,
}

impl<T, B> IndexedIterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    pub fn update_offset(&mut self, offset: usize) {
        unsafe { self.view.layout.set_offset(offset) };
    }
}

impl<'a, T, B> Iterator for IndexedIterAxesView<'a, T, B>
where
    B: DeviceAPI<T>,
{
    type Item = (IxD, TensorView<'a, T, B, IxD>);

    fn next(&mut self) -> Option<Self::Item> {
        // Snapshot the multi-dimensional index before the layout iterator advances.
        let index = match &self.axes_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        self.axes_iter.next().map(|offset| {
            self.update_offset(offset);
            (index, unsafe { transmute(self.view.view()) })
        })
    }
}

impl<T, B> DoubleEndedIterator for IndexedIterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        let index = match &self.axes_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        self.axes_iter.next_back().map(|offset| {
            self.update_offset(offset);
            (index, unsafe { transmute(self.view.view()) })
        })
    }
}

impl<T, B> ExactSizeIterator for IndexedIterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn len(&self) -> usize {
        self.axes_iter.len()
    }
}

impl<T, B> IterSplitAtAPI for IndexedIterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn split_at(self, index: usize) -> (Self, Self) {
        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.split_at(index);
        let view_lhs = unsafe { transmute(self.view.view()) };
        let lhs = IndexedIterAxesView { axes_iter: lhs_axes_iter, view: view_lhs };
        let rhs = IndexedIterAxesView { axes_iter: rhs_axes_iter, view: self.view };
        (lhs, rhs)
    }
}

impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    T: Clone,
    R: DataCloneAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
    pub fn indexed_axes_iter_with_order_f<I>(
        &self,
        axes: I,
        order: TensorIterOrder,
    ) -> Result<IndexedIterAxesView<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        use TensorIterOrder::*;
        // Only C/F orders give a well-defined multi-dimensional index per step.
        match order {
            C | F => (),
            _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.")?,
        };
        // Normalize negative axes, then reject out-of-bounds and duplicate entries.
        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
        let axes: Vec<isize> =
            axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<Vec<isize>>();
        let mut axes_check = axes.clone();
        axes_check.sort();
        if axes_check.first().is_some_and(|&v| v < 0) {
            return Err(rstsr_error!(InvalidValue, "Negative axis index out of bounds."));
        }
        for w in axes_check.windows(2) {
            rstsr_assert!(w[0] != w[1], InvalidValue, "Duplicate axes are not allowed.")?;
        }

        let layout = self.layout().to_dim::<IxD>()?;
        let shape_full = layout.shape();
        let stride_full = layout.stride();
        let offset = layout.offset();

        // Layout of the axes to be iterated over.
        let mut shape_axes = vec![];
        let mut stride_axes = vec![];
        for &idx in &axes {
            shape_axes.push(shape_full[idx as usize]);
            stride_axes.push(stride_full[idx as usize]);
        }
        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };

        // Layout of the remaining axes, shared by every yielded view.
        let mut shape_inner = vec![];
        let mut stride_inner = vec![];
        for idx in 0..ndim {
            if !axes.contains(&idx) {
                shape_inner.push(shape_full[idx as usize]);
                stride_inner.push(stride_full[idx as usize]);
            }
        }
        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };

        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
        let mut view = self.view().into_dyn();
        view.layout = layout_inner;
        // SAFETY: extends the view's lifetime to 'a; see `Iterator::next` for `IterAxesView`.
        let iter = IndexedIterAxesView { axes_iter, view: unsafe { transmute(view) } };
        Ok(iter)
    }

    pub fn indexed_axes_iter_f<I>(&self, axes: I) -> Result<IndexedIterAxesView<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        let default_order = self.device().default_order();
        let order = match default_order {
            RowMajor => TensorIterOrder::C,
            ColMajor => TensorIterOrder::F,
        };
        self.indexed_axes_iter_with_order_f(axes, order)
    }

    pub fn indexed_axes_iter_with_order<I>(&self, axes: I, order: TensorIterOrder) -> IndexedIterAxesView<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.indexed_axes_iter_with_order_f(axes, order).rstsr_unwrap()
    }

    pub fn indexed_axes_iter<I>(&self, axes: I) -> IndexedIterAxesView<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.indexed_axes_iter_f(axes).rstsr_unwrap()
    }
}

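// A minimal usage sketch for `indexed_axes_iter` (mirrors
// `tests_serial::test_indexed_axes_iter`):
//
//     let a = arange(120).into_shape([2, 3, 4, 5]);
//     for (index, view) in a.indexed_axes_iter([0, 2]) {
//         let _ = (index, view[[1, 2]]); // e.g. `index == vec![1, 3]`
//     }

/// Mutable counterpart of [`IndexedIterAxesView`]: yields `(index, view)`
/// pairs with mutable views.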
pub struct IndexedIterAxesMut<'a, T, B>
where
    B: DeviceAPI<T>,
{
    axes_iter: IterLayout<IxD>,
    view: TensorMut<'a, T, B, IxD>,
}

impl<T, B> IndexedIterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    pub fn update_offset(&mut self, offset: usize) {
        unsafe { self.view.layout.set_offset(offset) };
    }
}

impl<'a, T, B> Iterator for IndexedIterAxesMut<'a, T, B>
where
    B: DeviceAPI<T>,
{
    type Item = (IxD, TensorMut<'a, T, B, IxD>);

    fn next(&mut self) -> Option<Self::Item> {
        // Snapshot the multi-dimensional index before the layout iterator advances.
        let index = match &self.axes_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        self.axes_iter.next().map(|offset| {
            self.update_offset(offset);
            (index, unsafe { transmute(self.view.view_mut()) })
        })
    }
}

impl<T, B> DoubleEndedIterator for IndexedIterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        let index = match &self.axes_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        self.axes_iter.next_back().map(|offset| {
            self.update_offset(offset);
            (index, unsafe { transmute(self.view.view_mut()) })
        })
    }
}

impl<T, B> ExactSizeIterator for IndexedIterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn len(&self) -> usize {
        self.axes_iter.len()
    }
}

impl<T, B> IterSplitAtAPI for IndexedIterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn split_at(mut self, index: usize) -> (Self, Self) {
        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.clone().split_at(index);
        let view_lhs = unsafe { transmute(self.view.view_mut()) };
        let lhs = IndexedIterAxesMut { axes_iter: lhs_axes_iter, view: view_lhs };
        let rhs = IndexedIterAxesMut { axes_iter: rhs_axes_iter, view: self.view };
        (lhs, rhs)
    }
}

impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    T: Clone,
    R: DataMutAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
    pub fn indexed_axes_iter_mut_with_order_f<I>(
        &'a mut self,
        axes: I,
        order: TensorIterOrder,
    ) -> Result<IndexedIterAxesMut<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        use TensorIterOrder::*;
        // Only C/F orders give a well-defined multi-dimensional index per step
        // (same restriction as `indexed_axes_iter_with_order_f`).
        match order {
            C | F => (),
            _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.")?,
        };
        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
        let axes: Vec<isize> =
            axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<Vec<isize>>();
        let mut axes_check = axes.clone();
        axes_check.sort();
        if axes_check.first().is_some_and(|&v| v < 0) {
            return Err(rstsr_error!(InvalidValue, "Negative axis index out of bounds."));
        }
        for w in axes_check.windows(2) {
            rstsr_assert!(w[0] != w[1], InvalidValue, "Duplicate axes are not allowed.")?;
        }

        let layout = self.layout().to_dim::<IxD>()?;
        let shape_full = layout.shape();
        let stride_full = layout.stride();
        let offset = layout.offset();

        let mut shape_axes = vec![];
        let mut stride_axes = vec![];
        for &idx in &axes {
            shape_axes.push(shape_full[idx as usize]);
            stride_axes.push(stride_full[idx as usize]);
        }
        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };

        let mut shape_inner = vec![];
        let mut stride_inner = vec![];
        for idx in 0..ndim {
            if !axes.contains(&idx) {
                shape_inner.push(shape_full[idx as usize]);
                stride_inner.push(stride_full[idx as usize]);
            }
        }
        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };

        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
        let mut view = self.view_mut().into_dyn();
        view.layout = layout_inner;
        let iter = IndexedIterAxesMut { axes_iter, view };
        Ok(iter)
    }

    pub fn indexed_axes_iter_mut_f<I>(&'a mut self, axes: I) -> Result<IndexedIterAxesMut<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        let default_order = self.device().default_order();
        let order = match default_order {
            RowMajor => TensorIterOrder::C,
            ColMajor => TensorIterOrder::F,
        };
        self.indexed_axes_iter_mut_with_order_f(axes, order)
    }

    pub fn indexed_axes_iter_mut_with_order<I>(
        &'a mut self,
        axes: I,
        order: TensorIterOrder,
    ) -> IndexedIterAxesMut<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.indexed_axes_iter_mut_with_order_f(axes, order).rstsr_unwrap()
    }

    pub fn indexed_axes_iter_mut<I>(&'a mut self, axes: I) -> IndexedIterAxesMut<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.indexed_axes_iter_mut_f(axes).rstsr_unwrap()
    }
}

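// A usage sketch for `indexed_axes_iter_mut`, by analogy with the immutable
// variant above (this file has no dedicated test for it):
//
//     let mut a = arange(120).into_shape([2, 3, 4, 5]);
//     for (_index, mut view) in a.indexed_axes_iter_mut([0, 2]) {
//         view += 1;
//     }
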
#[cfg(test)]
mod tests_serial {
    use super::*;

    #[test]
    fn test_axes_iter() {
        let a = arange(120).into_shape([2, 3, 4, 5]);
        let iter = a.axes_iter_f([0, 2]).unwrap();

        let res = iter
            .map(|view| {
                println!("{view:3}");
                view[[1, 2]]
            })
            .collect::<Vec<_>>();
        #[cfg(not(feature = "col_major"))]
        {
            assert_eq!(res, vec![22, 27, 32, 37, 82, 87, 92, 97]);
        }
        #[cfg(feature = "col_major")]
        {
            assert_eq!(res, vec![50, 51, 56, 57, 62, 63, 68, 69]);
        }
    }

    #[test]
    fn test_axes_iter_mut() {
        let mut a = arange(120).into_shape([2, 3, 4, 5]);
        let iter = a.axes_iter_mut_with_order_f([0, 2], TensorIterOrder::C).unwrap();

        let res = iter
            .map(|mut view| {
                view += 1;
                println!("{view:3}");
                view[[1, 2]]
            })
            .collect::<Vec<_>>();
        println!("{res:?}");
        #[cfg(not(feature = "col_major"))]
        {
            assert_eq!(res, vec![23, 28, 33, 38, 83, 88, 93, 98]);
        }
        #[cfg(feature = "col_major")]
        {
            assert_eq!(res, vec![51, 57, 63, 69, 52, 58, 64, 70]);
        }
    }

    #[test]
    fn test_indexed_axes_iter() {
        let a = arange(120).into_shape([2, 3, 4, 5]);
        let iter = a.indexed_axes_iter([0, 2]);

        let res = iter
            .map(|(index, view)| {
                println!("{index:?}");
                println!("{view:3}");
                (index, view[[1, 2]])
            })
            .collect::<Vec<_>>();
        #[cfg(not(feature = "col_major"))]
        {
            assert_eq!(res, vec![
                (vec![0, 0], 22),
                (vec![0, 1], 27),
                (vec![0, 2], 32),
                (vec![0, 3], 37),
                (vec![1, 0], 82),
                (vec![1, 1], 87),
                (vec![1, 2], 92),
                (vec![1, 3], 97)
            ]);
        }
        #[cfg(feature = "col_major")]
        {
            assert_eq!(res, vec![
                (vec![0, 0], 50),
                (vec![1, 0], 51),
                (vec![0, 1], 56),
                (vec![1, 1], 57),
                (vec![0, 2], 62),
                (vec![1, 2], 63),
                (vec![0, 3], 68),
                (vec![1, 3], 69)
            ]);
        }
    }
}

#[cfg(test)]
#[cfg(feature = "rayon")]
mod tests_parallel {
    use super::*;
    use rayon::prelude::*;

    #[test]
    fn test_axes_iter() {
        let mut a = arange(65536).into_shape([16, 16, 16, 16]);
        let iter = a.axes_iter_mut([0, 2]);

        let res = iter
            .into_par_iter()
            .map(|mut view| {
                view += 1;
                println!("{view:6}");
                view[[1, 2]]
            })
            .collect::<Vec<_>>();
        println!("{res:?}");
        #[cfg(not(feature = "col_major"))]
        {
            assert_eq!(res[..17], vec![
                259, 275, 291, 307, 323, 339, 355, 371, 387, 403, 419, 435, 451, 467, 483, 499, 4355
            ]);
        }
        #[cfg(feature = "col_major")]
        {
            assert_eq!(res[..17], vec![
                8209, 8210, 8211, 8212, 8213, 8214, 8215, 8216, 8217, 8218, 8219, 8220, 8221, 8222, 8223, 8224, 8465
            ]);
        }
    }
}