use std::iter::FusedIterator;
use std::mem::transmute;
use std::ops::Range;

use rten_base::iter::SplitIterator;

use super::{
    AsView, DynLayout, MutLayout, NdTensorView, NdTensorViewMut, TensorBase, TensorViewMut,
};
use crate::layout::{Layout, NdLayout, OverlapPolicy, RemoveDim, merge_axes};
use crate::storage::{StorageMut, ViewData, ViewMutData};

mod parallel;

/// Iteration state for a single dimension of a tensor layout.
#[derive(Copy, Clone, Debug, Default)]
struct IterPos {
    /// Steps remaining along this dimension.
    remaining: usize,

    /// Current storage offset along this dimension (index * stride).
    offset: usize,

    /// Stride of this dimension.
    stride: usize,

    /// Value of `remaining` when the position is at the start of this
    /// dimension (ie. size - 1).
    max_remaining: usize,
}

impl IterPos {
    fn from_size_stride(size: usize, stride: usize) -> Self {
        let remaining = size.saturating_sub(1);
        IterPos {
            remaining,
            offset: 0,
            stride,
            max_remaining: remaining,
        }
    }

    /// Advance one step along this dimension. Returns `false` and resets to
    /// the start if the end had already been reached.
    #[inline(always)]
    fn step(&mut self) -> bool {
        if self.remaining != 0 {
            self.remaining -= 1;
            self.offset += self.stride;
            true
        } else {
            self.remaining = self.max_remaining;
            self.offset = 0;
            false
        }
    }

    /// Return the size of this dimension.
    fn size(&self) -> usize {
        self.max_remaining + 1
    }

    /// Return the current index along this dimension.
    fn index(&self) -> usize {
        self.max_remaining - self.remaining
    }

    /// Set the current index along this dimension.
    fn set_index(&mut self, index: usize) {
        self.remaining = self.max_remaining - index;
        self.offset = index * self.stride;
    }
}

/// Number of innermost dimensions for which iteration state is stored inline
/// in `OffsetsBase`.
const INNER_NDIM: usize = 2;

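/// Iterator over the storage offsets of a tensor view's elements, visited in
/// logical order.
///
/// Iteration maintains an [`IterPos`] per dimension, with the innermost
/// `INNER_NDIM` dimensions stepped in a fast path. As an illustration
/// (consistent with `test_iter` at the bottom of this file): a transposed
/// 2x2 view has shape `[2, 2]` and strides `[1, 2]`, and the offsets yielded
/// are `0, 2, 1, 3`.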
#[derive(Clone, Debug)]
struct OffsetsBase {
    /// Number of offsets remaining to yield.
    len: usize,

    /// Sum of offsets along the inner dimensions.
    inner_offset: usize,

    /// Iteration state for the innermost `INNER_NDIM` dimensions.
    inner_pos: [IterPos; INNER_NDIM],

    /// Sum of offsets along the outer dimensions.
    outer_offset: usize,

    /// Iteration state for the dimensions preceding the innermost ones.
    outer_pos: Vec<IterPos>,
}

impl OffsetsBase {
    fn new<L: Layout>(layout: &L) -> OffsetsBase {
        // Merge axes where possible to maximize the iteration count of the
        // innermost dimensions, which have the fastest stepping logic.
        let merged = merge_axes(layout.shape().as_ref(), layout.strides().as_ref());

        let inner_pos_pad = INNER_NDIM.saturating_sub(merged.len());
        let n_outer = merged.len().saturating_sub(INNER_NDIM);

        let inner_pos = std::array::from_fn(|dim| {
            let (size, stride) = if dim < inner_pos_pad {
                (1, 0)
            } else {
                merged[n_outer + dim - inner_pos_pad]
            };
            IterPos::from_size_stride(size, stride)
        });

        let outer_pos = (0..n_outer)
            .map(|i| {
                let (size, stride) = merged[i];
                IterPos::from_size_stride(size, stride)
            })
            .collect();

        OffsetsBase {
            len: merged.iter().map(|dim| dim.0).product(),
            inner_pos,
            inner_offset: 0,
            outer_pos,
            outer_offset: 0,
        }
    }

    /// Advance the position along the outer dimensions by one step, starting
    /// with the innermost of them. Returns `false` if iteration over the
    /// outer dimensions has finished.
    fn step_outer_pos(&mut self) -> bool {
        let mut done = self.outer_pos.is_empty();
        for (i, dim) in self.outer_pos.iter_mut().enumerate().rev() {
            if dim.step() {
                break;
            } else if i == 0 {
                done = true;
            }
        }
        self.outer_offset = self.outer_pos.iter().map(|p| p.offset).sum();
        !done
    }

    fn pos(&self, dim: usize) -> IterPos {
        let outer_ndim = self.outer_pos.len();
        if dim >= outer_ndim {
            self.inner_pos[dim - outer_ndim]
        } else {
            self.outer_pos[dim]
        }
    }

    fn pos_mut(&mut self, dim: usize) -> &mut IterPos {
        let outer_ndim = self.outer_pos.len();
        if dim >= outer_ndim {
            &mut self.inner_pos[dim - outer_ndim]
        } else {
            &mut self.outer_pos[dim]
        }
    }

    /// Advance iteration by `n` steps.
    fn step_by(&mut self, n: usize) {
        let mut remaining = n.min(self.len);
        self.len -= remaining;

        // Propagate the step from the innermost dimension outwards, carrying
        // into the next dimension whenever an index wraps around.
        for dim in (0..self.ndim()).rev() {
            if remaining == 0 {
                break;
            }

            let pos = self.pos_mut(dim);
            let size = pos.size();
            let new_index = pos.index() + remaining;
            pos.set_index(new_index % size);
            remaining = new_index / size;
        }

        // Re-compute the cached offset sums from the updated positions.
        self.inner_offset = self.inner_pos.iter().map(|p| p.offset).sum();
        self.outer_offset = self.outer_pos.iter().map(|p| p.offset).sum();
    }

    fn ndim(&self) -> usize {
        self.outer_pos.len() + self.inner_pos.len()
    }

    /// Compute the storage offset corresponding to a given linear index in
    /// the iteration order.
    fn offset_from_linear_index(&self, index: usize) -> usize {
        let mut offset = 0;
        let mut shape_product = 1;
        for dim in (0..self.ndim()).rev() {
            let pos = self.pos(dim);
            let dim_index = (index / shape_product) % pos.size();
            shape_product *= pos.size();
            offset += dim_index * pos.stride;
        }
        offset
    }

    /// Clamp the number of offsets still to be yielded to `len`.
    fn truncate(&mut self, len: usize) {
        self.len = self.len.min(len);
    }
}

impl Iterator for OffsetsBase {
    type Item = usize;

    #[inline(always)]
    fn next(&mut self) -> Option<usize> {
        if self.len == 0 {
            return None;
        }
        let offset = self.outer_offset + self.inner_offset;

        self.len -= 1;

        // Optimistically advance the offset along the innermost dimension.
        self.inner_offset += self.inner_pos[1].stride;

        // If the innermost dimension reached its end, step the dimensions
        // above it.
        if !self.inner_pos[1].step() {
            if !self.inner_pos[0].step() {
                self.step_outer_pos();
            }

            // `inner_pos[1]` has wrapped back to offset 0, so the inner
            // offset is just the offset of `inner_pos[0]`.
            self.inner_offset = self.inner_pos[0].offset;
        }

        Some(offset)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, usize) -> B,
    {
        if self.len == 0 {
            return init;
        }

        let mut accum = init;
        'outer: loop {
            for i0 in self.inner_pos[0].index()..self.inner_pos[0].size() {
                for i1 in self.inner_pos[1].index()..self.inner_pos[1].size() {
                    let inner_offset =
                        i0 * self.inner_pos[0].stride + i1 * self.inner_pos[1].stride;
                    accum = f(accum, self.outer_offset + inner_offset);

                    self.len -= 1;
                    if self.len == 0 {
                        break 'outer;
                    }
                }
                self.inner_pos[1].set_index(0);
            }
            self.inner_pos[0].set_index(0);

            if !self.step_outer_pos() {
                break;
            }
        }

        accum
    }
}

impl ExactSizeIterator for OffsetsBase {}

impl DoubleEndedIterator for OffsetsBase {
    fn next_back(&mut self) -> Option<usize> {
        if self.len == 0 {
            return None;
        }

        // The last remaining offset corresponds to the linear index
        // `len - 1`.
        let index = self.len - 1;
        let offset = self.offset_from_linear_index(index);
        self.len -= 1;

        Some(offset)
    }
}

impl SplitIterator for OffsetsBase {
    /// Split the iterator in two, where the left half yields the first
    /// `index` offsets and the right half yields the remainder.
    fn split_at(mut self, index: usize) -> (Self, Self) {
        assert!(self.len >= index);

        let mut right = self.clone();
        OffsetsBase::step_by(&mut right, index);

        self.truncate(index);

        (self, right)
    }
}

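/// Iterator over elements of a tensor, visited in logical order.
///
/// A minimal usage sketch, mirroring `test_iter` at the bottom of this file
/// (crate paths omitted; an `Iter` is normally obtained via a view's `iter`
/// method):
///
/// ```ignore
/// let tensor = NdTensor::from([[1, 2], [3, 4]]);
/// assert_eq!(tensor.iter().copied().collect::<Vec<_>>(), [1, 2, 3, 4]);
///
/// // Iteration follows the logical order of the view, not storage order.
/// let transposed: Vec<_> = tensor.transposed().iter().copied().collect();
/// assert_eq!(transposed, [1, 3, 2, 4]);
/// ```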
pub struct Iter<'a, T> {
    offsets: Offsets,
    data: ViewData<'a, T>,
}

impl<'a, T> Iter<'a, T> {
    pub(super) fn new<L: Layout + Clone>(view: TensorBase<ViewData<'a, T>, L>) -> Iter<'a, T> {
        Iter {
            offsets: Offsets::new(view.layout()),
            data: view.storage(),
        }
    }
}

impl<T> Clone for Iter<'_, T> {
    fn clone(&self) -> Self {
        Iter {
            offsets: self.offsets.clone(),
            data: self.data,
        }
    }
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next()?;

        // Safety: `offsets` only yields valid offsets into `data`.
        Some(unsafe { self.data.get_unchecked(offset) })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.offsets.size_hint()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let offset = self.offsets.nth(n)?;

        // Safety: `offsets` only yields valid offsets into `data`.
        Some(unsafe { self.data.get_unchecked(offset) })
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.offsets.fold(init, |acc, offset| {
            // Safety: `offsets` only yields valid offsets into `data`.
            let item = unsafe { self.data.get_unchecked(offset) };
            f(acc, item)
        })
    }
}

impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next_back()?;

        // Safety: `offsets` only yields valid offsets into `data`.
        Some(unsafe { self.data.get_unchecked(offset) })
    }
}

impl<T> ExactSizeIterator for Iter<'_, T> {}

impl<T> FusedIterator for Iter<'_, T> {}

/// Extend the lifetime of a mutable reference.
///
/// # Safety
///
/// The caller must ensure that no other reference to the value exists for
/// the duration of the extended lifetime `'b`.
unsafe fn transmute_lifetime_mut<'a, 'b, T>(x: &'a mut T) -> &'b mut T {
    unsafe { transmute::<&'a mut T, &'b mut T>(x) }
}

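/// Mutable iterator over elements of a tensor, visited in logical order.
///
/// A minimal usage sketch (crate paths omitted; an `IterMut` is normally
/// obtained via a view's `iter_mut` method):
///
/// ```ignore
/// let mut tensor = NdTensor::from([[1, 2], [3, 4]]);
/// for x in tensor.iter_mut() {
///     *x *= 2;
/// }
/// assert_eq!(tensor.iter().copied().collect::<Vec<_>>(), [2, 4, 6, 8]);
/// ```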
pub struct IterMut<'a, T> {
    offsets: Offsets,
    data: ViewMutData<'a, T>,
}

impl<'a, T> IterMut<'a, T> {
    pub(super) fn new<L: Layout + Clone>(
        view: TensorBase<ViewMutData<'a, T>, L>,
    ) -> IterMut<'a, T> {
        IterMut {
            offsets: Offsets::new(view.layout()),
            data: view.into_storage(),
        }
    }
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next()?;

        // Safety: `offsets` only yields valid offsets into `data`, and each
        // offset is yielded at most once for the non-broadcast layouts that
        // mutable views have, so the returned references do not alias.
        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.offsets.size_hint()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let offset = self.offsets.nth(n)?;

        // Safety: As in `next`, offsets are valid and yielded at most once.
        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.offsets.fold(init, |acc, offset| {
            // Safety: As in `next`, offsets are valid and yielded at most once.
            let item = unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) };
            f(acc, item)
        })
    }
}

impl<T> DoubleEndedIterator for IterMut<'_, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next_back()?;

        // Safety: As in `next`, offsets are valid and yielded at most once.
        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
    }
}

impl<T> ExactSizeIterator for IterMut<'_, T> {}

impl<T> FusedIterator for IterMut<'_, T> {}

#[derive(Clone)]
enum OffsetsKind {
    Range(Range<usize>),
    Indexing(OffsetsBase),
}

/// Iterator over the storage offsets of elements in a tensor.
///
/// For contiguous layouts this is a simple range, which enables faster
/// iteration. For other layouts the offsets are computed by stepping through
/// each dimension in turn.
#[derive(Clone)]
struct Offsets {
    base: OffsetsKind,
}

impl Offsets {
    pub fn new<L: Layout>(layout: &L) -> Offsets {
        Offsets {
            base: if layout.is_contiguous() {
                OffsetsKind::Range(0..layout.min_data_len())
            } else {
                OffsetsKind::Indexing(OffsetsBase::new(layout))
            },
        }
    }
}

impl Iterator for Offsets {
    type Item = usize;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        match &mut self.base {
            OffsetsKind::Range(r) => r.next(),
            OffsetsKind::Indexing(base) => base.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match &self.base {
            OffsetsKind::Range(r) => r.size_hint(),
            OffsetsKind::Indexing(base) => (base.len, Some(base.len)),
        }
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        match &mut self.base {
            OffsetsKind::Range(r) => r.nth(n),
            OffsetsKind::Indexing(base) => {
                base.step_by(n);
                self.next()
            }
        }
    }

    fn fold<B, F>(self, init: B, f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        match self.base {
            OffsetsKind::Range(r) => r.fold(init, f),
            OffsetsKind::Indexing(base) => base.fold(init, f),
        }
    }
}

impl DoubleEndedIterator for Offsets {
    fn next_back(&mut self) -> Option<Self::Item> {
        match &mut self.base {
            OffsetsKind::Range(r) => r.next_back(),
            OffsetsKind::Indexing(base) => base.next_back(),
        }
    }
}

impl ExactSizeIterator for Offsets {}

impl FusedIterator for Offsets {}

/// Iterator over the storage offset ranges of lanes running along a chosen
/// dimension of a tensor.
struct LaneRanges {
    /// Start offsets of each lane.
    offsets: Offsets,

    /// Size of the dimension the lanes run along.
    dim_size: usize,
    /// Stride of the dimension the lanes run along.
    dim_stride: usize,
}

impl LaneRanges {
    fn new<L: Layout + RemoveDim>(layout: &L, dim: usize) -> LaneRanges {
        let offsets = if layout.is_empty() {
            // If the tensor is empty there are no lanes, so create an offset
            // iterator which yields nothing.
            Offsets::new(layout)
        } else {
            let other_dims = layout.remove_dim(dim);
            Offsets::new(&other_dims)
        };

        LaneRanges {
            offsets,
            dim_size: layout.size(dim),
            dim_stride: layout.stride(dim),
        }
    }

    /// Return the range of storage offsets covered by the lane starting at
    /// `start_offset`.
    fn lane_offset_range(&self, start_offset: usize) -> Range<usize> {
        lane_offsets(start_offset, self.dim_size, self.dim_stride)
    }
}

/// Return the range of storage offsets spanned by a lane of `size` elements
/// separated by `stride`, starting at `start_offset`. The range ends just
/// after the lane's last element.
fn lane_offsets(start_offset: usize, size: usize, stride: usize) -> Range<usize> {
    start_offset..start_offset + (size - 1) * stride + 1
}

impl Iterator for LaneRanges {
    type Item = Range<usize>;

    #[inline]
    fn next(&mut self) -> Option<Range<usize>> {
        self.offsets
            .next()
            .map(|offset| self.lane_offset_range(offset))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.offsets.size_hint()
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        let Self {
            offsets,
            dim_size,
            dim_stride,
        } = self;

        offsets.fold(init, |acc, offset| {
            f(acc, lane_offsets(offset, dim_size, dim_stride))
        })
    }
}

impl DoubleEndedIterator for LaneRanges {
    fn next_back(&mut self) -> Option<Range<usize>> {
        self.offsets
            .next_back()
            .map(|offset| self.lane_offset_range(offset))
    }
}

impl ExactSizeIterator for LaneRanges {}

impl FusedIterator for LaneRanges {}

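/// Iterator over lanes of a tensor, where a lane is a 1D slice of elements
/// running along a chosen dimension.
///
/// A minimal usage sketch, mirroring `test_lanes` at the bottom of this file
/// (crate paths omitted; a `Lanes` iterator is normally obtained via a
/// view's `lanes` method):
///
/// ```ignore
/// let x = NdTensor::from([[1, 2], [3, 4]]);
/// // Lanes along dimension 0 are the columns: [1, 3] and [2, 4].
/// let mut columns = x.lanes(0);
/// let first: Vec<i32> = columns.next().unwrap().copied().collect();
/// assert_eq!(first, [1, 3]);
/// ```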
pub struct Lanes<'a, T> {
    data: ViewData<'a, T>,
    ranges: LaneRanges,
    lane_layout: NdLayout<1>,
}

/// Iterator over elements of a single lane of a tensor.
#[derive(Clone, Debug)]
pub struct Lane<'a, T> {
    view: NdTensorView<'a, T, 1>,
    index: usize,
}

impl<'a, T> Lane<'a, T> {
    /// Return the remaining elements of the lane as a slice, if the lane is
    /// contiguous.
    pub fn as_slice(&self) -> Option<&'a [T]> {
        self.view.data().map(|data| &data[self.index..])
    }

    /// Return the element at a given position within the lane.
    pub fn get(&self, idx: usize) -> Option<&'a T> {
        self.view.get([idx])
    }

    /// Return the lane as a 1D tensor view.
    pub fn as_view(&self) -> NdTensorView<'a, T, 1> {
        self.view
    }
}

impl<'a, T> From<NdTensorView<'a, T, 1>> for Lane<'a, T> {
    fn from(val: NdTensorView<'a, T, 1>) -> Self {
        Lane {
            view: val,
            index: 0,
        }
    }
}

impl<'a, T> Iterator for Lane<'a, T> {
    type Item = &'a T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.view.len() {
            let index = self.index;
            self.index += 1;

            // Safety: `index` is less than the lane length.
            Some(unsafe { self.view.get_unchecked([index]) })
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.view.size(0) - self.index;
        (remaining, Some(remaining))
    }
}

impl<T> ExactSizeIterator for Lane<'_, T> {}

impl<T> FusedIterator for Lane<'_, T> {}

impl<T: PartialEq> PartialEq<Lane<'_, T>> for Lane<'_, T> {
    fn eq(&self, other: &Lane<'_, T>) -> bool {
        self.view.slice(self.index..) == other.view.slice(other.index..)
    }
}

impl<T: PartialEq> PartialEq<Lane<'_, T>> for LaneMut<'_, T> {
    fn eq(&self, other: &Lane<'_, T>) -> bool {
        self.view.slice(self.index..) == other.view.slice(other.index..)
    }
}

impl<'a, T> Lanes<'a, T> {
    /// Create an iterator which yields the lanes of `view` running along the
    /// dimension `dim`.
    pub(crate) fn new<L: Layout + RemoveDim + Clone>(
        view: TensorBase<ViewData<'a, T>, L>,
        dim: usize,
    ) -> Lanes<'a, T> {
        let size = view.size(dim);
        let stride = view.stride(dim);
        let lane_layout =
            NdLayout::from_shape_and_strides([size], [stride], OverlapPolicy::AllowOverlap)
                .unwrap();
        Lanes {
            data: view.storage(),
            ranges: LaneRanges::new(view.layout(), dim),
            lane_layout,
        }
    }
}

fn lane_for_offset_range<T>(
    data: ViewData<T>,
    layout: NdLayout<1>,
    offsets: Range<usize>,
) -> Lane<T> {
    let view = NdTensorView::from_storage_and_layout(data.slice(offsets), layout);
    Lane { view, index: 0 }
}

impl<'a, T> Iterator for Lanes<'a, T> {
    type Item = Lane<'a, T>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.ranges
            .next()
            .map(|range| lane_for_offset_range(self.data, self.lane_layout, range))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.ranges.size_hint()
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.ranges.fold(init, |acc, offsets| {
            let lane = lane_for_offset_range(self.data, self.lane_layout, offsets);
            f(acc, lane)
        })
    }
}

impl<T> DoubleEndedIterator for Lanes<'_, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.ranges
            .next_back()
            .map(|range| lane_for_offset_range(self.data, self.lane_layout, range))
    }
}

impl<T> ExactSizeIterator for Lanes<'_, T> {}

impl<T> FusedIterator for Lanes<'_, T> {}

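/// Mutable version of [`Lanes`], yielding [`LaneMut`] items.
///
/// A minimal usage sketch, mirroring `test_lanes_mut` below (crate paths
/// omitted; a `LanesMut` is normally obtained via a view's `lanes_mut`
/// method):
///
/// ```ignore
/// let mut x = NdTensor::from([[1, 2], [3, 4]]);
/// for lane in x.lanes_mut(0) {
///     for elem in lane {
///         *elem += 10;
///     }
/// }
/// assert_eq!(x, NdTensor::from([[11, 12], [13, 14]]));
/// ```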
pub struct LanesMut<'a, T> {
    data: ViewMutData<'a, T>,
    ranges: LaneRanges,
    lane_layout: NdLayout<1>,
}

impl<'a, T> LanesMut<'a, T> {
    /// Create an iterator which yields mutable lanes of `view` running along
    /// the dimension `dim`.
    pub(crate) fn new<L: Layout + RemoveDim + Clone>(
        view: TensorBase<ViewMutData<'a, T>, L>,
        dim: usize,
    ) -> LanesMut<'a, T> {
        // A broadcast view can visit the same element more than once, which
        // would yield aliasing mutable references.
        assert!(
            !view.is_broadcast(),
            "Cannot mutably iterate over broadcasting view"
        );

        let size = view.size(dim);
        let stride = view.stride(dim);

        // `AllowOverlap` skips the layout overlap checks. Lanes of a
        // non-broadcast view reference disjoint elements, as guaranteed by
        // the assert above.
        let lane_layout =
            NdLayout::from_shape_and_strides([size], [stride], OverlapPolicy::AllowOverlap)
                .unwrap();

        LanesMut {
            ranges: LaneRanges::new(view.layout(), dim),
            data: view.into_storage(),
            lane_layout,
        }
    }
}

impl<'a, T> Iterator for LanesMut<'a, T> {
    type Item = LaneMut<'a, T>;

    #[inline]
    fn next(&mut self) -> Option<LaneMut<'a, T>> {
        self.ranges.next().map(|offsets| {
            // Safety: The storage ranges of different lanes may overlap, but
            // the elements each lane actually references (start + k * stride)
            // are disjoint for a non-broadcast view.
            unsafe {
                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
            }
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.ranges.size_hint()
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.ranges.fold(init, |acc, offsets| {
            // Safety: As in `next`, the elements referenced by each lane are
            // disjoint.
            let lane = unsafe {
                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
            };
            f(acc, lane)
        })
    }
}

impl<'a, T> ExactSizeIterator for LanesMut<'a, T> {}

impl<'a, T> DoubleEndedIterator for LanesMut<'a, T> {
    fn next_back(&mut self) -> Option<LaneMut<'a, T>> {
        self.ranges.next_back().map(|offsets| {
            // Safety: As in `next`, the elements referenced by each lane are
            // disjoint.
            unsafe {
                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
            }
        })
    }
}

/// Mutable lane of a tensor, yielding a mutable reference to each element.
#[derive(Debug)]
pub struct LaneMut<'a, T> {
    view: NdTensorViewMut<'a, T, 1>,
    index: usize,
}

impl<'a, T> LaneMut<'a, T> {
    /// Create a mutable lane over the elements of `data` selected by
    /// `layout`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that no other lane references the same
    /// elements of `data`.
    unsafe fn from_storage_layout(data: ViewMutData<'a, T>, layout: NdLayout<1>) -> Self {
        // Safety: Guaranteed by the caller.
        let view = unsafe { NdTensorViewMut::from_storage_and_layout_unchecked(data, layout) };
        LaneMut { view, index: 0 }
    }

    /// Return the remaining elements of the lane as a mutable slice, if the
    /// lane is contiguous.
    pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
        self.view.data_mut().map(|data| &mut data[self.index..])
    }

    /// Consume the lane and return the underlying 1D view.
    pub fn into_view(self) -> NdTensorViewMut<'a, T, 1> {
        self.view
    }
}

impl<'a, T> Iterator for LaneMut<'a, T> {
    type Item = &'a mut T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.view.size(0) {
            let index = self.index;
            self.index += 1;
            let item = unsafe { self.view.get_unchecked_mut([index]) };

            // Safety: Each element is yielded at most once, so extending the
            // reference's lifetime to that of the underlying storage does not
            // create aliasing mutable references.
            Some(unsafe { transmute::<&mut T, Self::Item>(item) })
        } else {
            None
        }
    }

    #[inline]
    fn nth(&mut self, nth: usize) -> Option<Self::Item> {
        self.index = (self.index + nth).min(self.view.size(0));
        self.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.view.size(0) - self.index;
        (remaining, Some(remaining))
    }
}

impl<T> ExactSizeIterator for LaneMut<'_, T> {}

impl<T: PartialEq> PartialEq<LaneMut<'_, T>> for LaneMut<'_, T> {
    fn eq(&self, other: &LaneMut<'_, T>) -> bool {
        self.view.slice(self.index..) == other.view.slice(other.index..)
    }
}

/// Base for iterators over views of the innermost dimensions of a tensor,
/// stepping over each combination of indices for the outer dimensions.
struct InnerIterBase<L: Layout> {
    /// Start offsets of each inner view.
    outer_offsets: Offsets,
    /// Layout shared by each inner view.
    inner_layout: L,
    /// Length of the data range spanned by each inner view.
    inner_data_len: usize,
}

impl<L: Layout + Clone> InnerIterBase<L> {
    fn new_impl<PL: Layout, F: Fn(&[usize], &[usize]) -> L>(
        parent_layout: &PL,
        inner_dims: usize,
        make_inner_layout: F,
    ) -> InnerIterBase<L> {
        assert!(parent_layout.ndim() >= inner_dims);
        let outer_dims = parent_layout.ndim() - inner_dims;
        let parent_shape = parent_layout.shape();
        let parent_strides = parent_layout.strides();
        let (outer_shape, inner_shape) = parent_shape.as_ref().split_at(outer_dims);
        let (outer_strides, inner_strides) = parent_strides.as_ref().split_at(outer_dims);

        let outer_layout = DynLayout::from_shape_and_strides(
            outer_shape,
            outer_strides,
            OverlapPolicy::AllowOverlap,
        )
        .unwrap();

        let inner_layout = make_inner_layout(inner_shape, inner_strides);

        InnerIterBase {
            outer_offsets: Offsets::new(&outer_layout),
            inner_data_len: inner_layout.min_data_len(),
            inner_layout,
        }
    }
}

impl<const N: usize> InnerIterBase<NdLayout<N>> {
    pub fn new<L: Layout>(parent_layout: &L) -> Self {
        Self::new_impl(parent_layout, N, |inner_shape, inner_strides| {
            let inner_shape: [usize; N] = inner_shape.try_into().unwrap();
            let inner_strides: [usize; N] = inner_strides.try_into().unwrap();
            NdLayout::from_shape_and_strides(
                inner_shape,
                inner_strides,
                OverlapPolicy::AllowOverlap,
            )
            .expect("failed to create layout")
        })
    }
}

impl InnerIterBase<DynLayout> {
    pub fn new_dyn<L: Layout>(parent_layout: &L, inner_dims: usize) -> Self {
        Self::new_impl(parent_layout, inner_dims, |inner_shape, inner_strides| {
            DynLayout::from_shape_and_strides(
                inner_shape,
                inner_strides,
                OverlapPolicy::AllowOverlap,
            )
            .expect("failed to create layout")
        })
    }
}

impl<L: Layout> Iterator for InnerIterBase<L> {
    type Item = Range<usize>;

    fn next(&mut self) -> Option<Range<usize>> {
        self.outer_offsets
            .next()
            .map(|offset| offset..offset + self.inner_data_len)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.outer_offsets.size_hint()
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.outer_offsets.fold(init, |acc, offset| {
            f(acc, offset..offset + self.inner_data_len)
        })
    }
}

impl<L: Layout> ExactSizeIterator for InnerIterBase<L> {}

impl<L: Layout> DoubleEndedIterator for InnerIterBase<L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.outer_offsets
            .next_back()
            .map(|offset| offset..offset + self.inner_data_len)
    }
}

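/// Iterator over views of the N innermost dimensions of a tensor.
///
/// A minimal usage sketch, mirroring `test_inner_iter` below (crate paths
/// omitted; an `InnerIter` is normally obtained via a view's `inner_iter`
/// method):
///
/// ```ignore
/// let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
/// // Iterate over the two innermost dimensions, yielding 2x2 views.
/// let mut inner = tensor.inner_iter::<2>();
/// assert_eq!(inner.next().unwrap(), tensor.slice(0));
/// assert_eq!(inner.next().unwrap(), tensor.slice(1));
/// assert!(inner.next().is_none());
/// ```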
pub struct InnerIter<'a, T, L: Layout> {
    base: InnerIterBase<L>,
    data: ViewData<'a, T>,
}

impl<'a, T, const N: usize> InnerIter<'a, T, NdLayout<N>> {
    pub fn new<L: Layout + Clone>(view: TensorBase<ViewData<'a, T>, L>) -> Self {
        let base = InnerIterBase::new(&view);
        InnerIter {
            base,
            data: view.storage(),
        }
    }
}

impl<'a, T> InnerIter<'a, T, DynLayout> {
    pub fn new_dyn<L: Layout + Clone>(
        view: TensorBase<ViewData<'a, T>, L>,
        inner_dims: usize,
    ) -> Self {
        let base = InnerIterBase::new_dyn(&view, inner_dims);
        InnerIter {
            base,
            data: view.storage(),
        }
    }
}

impl<'a, T, L: Layout + Clone> Iterator for InnerIter<'a, T, L> {
    type Item = TensorBase<ViewData<'a, T>, L>;

    fn next(&mut self) -> Option<Self::Item> {
        self.base.next().map(|offset_range| {
            TensorBase::from_storage_and_layout(
                self.data.slice(offset_range),
                self.base.inner_layout.clone(),
            )
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.base.size_hint()
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        let inner_layout = self.base.inner_layout.clone();
        self.base.fold(init, |acc, offset_range| {
            let item = TensorBase::from_storage_and_layout(
                self.data.slice(offset_range),
                inner_layout.clone(),
            );
            f(acc, item)
        })
    }
}

impl<T, L: Layout + Clone> ExactSizeIterator for InnerIter<'_, T, L> {}

impl<T, L: Layout + Clone> DoubleEndedIterator for InnerIter<'_, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.base.next_back().map(|offset_range| {
            TensorBase::from_storage_and_layout(
                self.data.slice(offset_range),
                self.base.inner_layout.clone(),
            )
        })
    }
}

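/// Mutable version of [`InnerIter`].
///
/// A minimal usage sketch (crate paths omitted; an `InnerIterMut` is
/// normally obtained via a view's `inner_iter_mut` method):
///
/// ```ignore
/// let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
/// for mut inner in tensor.inner_iter_mut::<2>() {
///     // `inner` is a mutable 2x2 view.
///     for x in inner.iter_mut() {
///         *x *= 2;
///     }
/// }
/// ```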
pub struct InnerIterMut<'a, T, L: Layout> {
    base: InnerIterBase<L>,
    data: ViewMutData<'a, T>,
}

impl<'a, T, const N: usize> InnerIterMut<'a, T, NdLayout<N>> {
    pub fn new<L: Layout>(view: TensorBase<ViewMutData<'a, T>, L>) -> Self {
        let base = InnerIterBase::new(&view);
        InnerIterMut {
            base,
            data: view.into_storage(),
        }
    }
}

impl<'a, T> InnerIterMut<'a, T, DynLayout> {
    pub fn new_dyn<L: Layout>(view: TensorBase<ViewMutData<'a, T>, L>, inner_dims: usize) -> Self {
        let base = InnerIterBase::new_dyn(&view, inner_dims);
        InnerIterMut {
            base,
            data: view.into_storage(),
        }
    }
}

impl<'a, T, L: Layout + Clone> Iterator for InnerIterMut<'a, T, L> {
    type Item = TensorBase<ViewMutData<'a, T>, L>;

    fn next(&mut self) -> Option<Self::Item> {
        self.base.next().map(|offset_range| {
            let storage = self.data.slice_mut(offset_range);
            // Safety: The inner views of a (non-broadcast) mutable tensor
            // reference disjoint elements, so extending the storage lifetime
            // to `'a` does not create aliasing mutable views.
            let storage = unsafe {
                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
            };
            TensorBase::from_storage_and_layout(storage, self.base.inner_layout.clone())
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.base.size_hint()
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        let inner_layout = self.base.inner_layout.clone();
        self.base.fold(init, |acc, offset_range| {
            let storage = self.data.slice_mut(offset_range);
            // Safety: As in `next`, the mutable views do not alias.
            let storage = unsafe {
                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
            };
            let item = TensorBase::from_storage_and_layout(storage, inner_layout.clone());
            f(acc, item)
        })
    }
}

impl<T, L: Layout + Clone> ExactSizeIterator for InnerIterMut<'_, T, L> {}

impl<'a, T, L: Layout + Clone> DoubleEndedIterator for InnerIterMut<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.base.next_back().map(|offset_range| {
            let storage = self.data.slice_mut(offset_range);
            // Safety: As in `next`, the mutable views do not alias.
            let storage = unsafe {
                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
            };
            TensorBase::from_storage_and_layout(storage, self.base.inner_layout.clone())
        })
    }
}

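/// Iterator over slices of a tensor along an axis, yielding a view with one
/// dimension fewer at each step.
///
/// A minimal usage sketch, mirroring `test_axis_iter` below (crate paths
/// omitted; an `AxisIter` is normally obtained via a view's `axis_iter`
/// method):
///
/// ```ignore
/// let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
/// let mut it = tensor.axis_iter(0);
/// assert_eq!(it.next().unwrap(), tensor.slice(0));
/// assert_eq!(it.next().unwrap(), tensor.slice(1));
/// assert!(it.next().is_none());
/// ```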
pub struct AxisIter<'a, T, L: Layout + RemoveDim> {
    view: TensorBase<ViewData<'a, T>, L>,
    axis: usize,
    index: usize,
    end: usize,
}

impl<'a, T, L: MutLayout + RemoveDim> AxisIter<'a, T, L> {
    pub fn new(view: &TensorBase<ViewData<'a, T>, L>, axis: usize) -> AxisIter<'a, T, L> {
        assert!(axis < view.ndim());
        AxisIter {
            view: view.clone(),
            axis,
            index: 0,
            end: view.size(axis),
        }
    }
}

impl<'a, T, L: MutLayout + RemoveDim> Iterator for AxisIter<'a, T, L> {
    type Item = TensorBase<ViewData<'a, T>, <L as RemoveDim>::Output>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let slice = self.view.index_axis(self.axis, self.index);
            self.index += 1;
            Some(slice)
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.end - self.index;
        (len, Some(len))
    }
}

impl<'a, T, L: MutLayout + RemoveDim> ExactSizeIterator for AxisIter<'a, T, L> {}

impl<'a, T, L: MutLayout + RemoveDim> DoubleEndedIterator for AxisIter<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let slice = self.view.index_axis(self.axis, self.end - 1);
            self.end -= 1;
            Some(slice)
        }
    }
}

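/// Mutable version of [`AxisIter`].
///
/// A minimal usage sketch (crate paths omitted; an `AxisIterMut` is normally
/// obtained via a view's `axis_iter_mut` method):
///
/// ```ignore
/// let mut tensor = NdTensor::from([[1, 2], [3, 4]]);
/// for mut row in tensor.axis_iter_mut(0) {
///     for x in row.iter_mut() {
///         *x += 1;
///     }
/// }
/// ```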
pub struct AxisIterMut<'a, T, L: Layout + RemoveDim> {
    view: TensorBase<ViewMutData<'a, T>, L>,
    axis: usize,
    index: usize,
    end: usize,
}

impl<'a, T, L: Layout + RemoveDim + Clone> AxisIterMut<'a, T, L> {
    pub fn new(view: TensorBase<ViewMutData<'a, T>, L>, axis: usize) -> AxisIterMut<'a, T, L> {
        // A broadcast view can visit the same element more than once, which
        // would yield aliasing mutable references.
        assert!(
            !view.layout().is_broadcast(),
            "Cannot mutably iterate over broadcasting view"
        );
        assert!(axis < view.ndim());
        AxisIterMut {
            axis,
            index: 0,
            end: view.size(axis),
            view,
        }
    }
}

type SmallerMutView<'b, T, L> = TensorBase<ViewMutData<'b, T>, <L as RemoveDim>::Output>;

impl<'a, T, L: MutLayout + RemoveDim> Iterator for AxisIterMut<'a, T, L> {
    type Item = TensorBase<ViewMutData<'a, T>, <L as RemoveDim>::Output>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let index = self.index;
            self.index += 1;

            let slice = self.view.index_axis_mut(self.axis, index);

            // Safety: This iterator yields the view for each index along the
            // axis at most once, so the mutable views are disjoint and their
            // lifetime can be extended to `'a`.
            let view = unsafe { transmute::<SmallerMutView<'_, T, L>, Self::Item>(slice) };

            Some(view)
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.end - self.index;
        (len, Some(len))
    }
}

impl<'a, T, L: MutLayout + RemoveDim> ExactSizeIterator for AxisIterMut<'a, T, L> {}

impl<'a, T, L: MutLayout + RemoveDim> DoubleEndedIterator for AxisIterMut<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let index = self.end - 1;
            self.end -= 1;

            let slice = self.view.index_axis_mut(self.axis, index);

            // Safety: As in `next`, each index is yielded at most once.
            let view = unsafe { transmute::<SmallerMutView<'_, T, L>, Self::Item>(slice) };

            Some(view)
        }
    }
}

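/// Iterator over chunks of a tensor along an axis, where each chunk covers
/// up to `chunk_size` indices along that axis.
///
/// A minimal usage sketch, mirroring `test_axis_chunks` below (crate paths
/// omitted; an `AxisChunks` is normally obtained via a view's `axis_chunks`
/// method):
///
/// ```ignore
/// let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
/// let mut chunks = tensor.axis_chunks(0, 1);
/// assert_eq!(chunks.next().unwrap(), tensor.slice(0..1));
/// assert_eq!(chunks.next().unwrap(), tensor.slice(1..2));
/// assert!(chunks.next().is_none());
/// ```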
pub struct AxisChunks<'a, T, L: MutLayout> {
    remainder: Option<TensorBase<ViewData<'a, T>, L>>,
    axis: usize,
    chunk_size: usize,
}

impl<'a, T, L: MutLayout> AxisChunks<'a, T, L> {
    pub fn new(
        view: &TensorBase<ViewData<'a, T>, L>,
        axis: usize,
        chunk_size: usize,
    ) -> AxisChunks<'a, T, L> {
        assert!(chunk_size > 0, "chunk size must be > 0");
        AxisChunks {
            remainder: if view.size(axis) > 0 {
                Some(view.view())
            } else {
                None
            },
            axis,
            chunk_size,
        }
    }
}

impl<'a, T, L: MutLayout> Iterator for AxisChunks<'a, T, L> {
    type Item = TensorBase<ViewData<'a, T>, L>;

    fn next(&mut self) -> Option<Self::Item> {
        let remainder = self.remainder.take()?;
        let chunk_len = self.chunk_size.min(remainder.size(self.axis));
        let (current, next_remainder) = remainder.split_at(self.axis, chunk_len);
        self.remainder = if next_remainder.size(self.axis) > 0 {
            Some(next_remainder)
        } else {
            None
        };
        Some(current)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self
            .remainder
            .as_ref()
            .map(|r| r.size(self.axis))
            .unwrap_or(0)
            .div_ceil(self.chunk_size);
        (len, Some(len))
    }
}

impl<'a, T, L: MutLayout> ExactSizeIterator for AxisChunks<'a, T, L> {}

impl<'a, T, L: MutLayout> DoubleEndedIterator for AxisChunks<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let remainder = self.remainder.take()?;
        let chunk_len = self.chunk_size.min(remainder.size(self.axis));
        let (prev_remainder, current) =
            remainder.split_at(self.axis, remainder.size(self.axis) - chunk_len);
        self.remainder = if prev_remainder.size(self.axis) > 0 {
            Some(prev_remainder)
        } else {
            None
        };
        Some(current)
    }
}

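/// Mutable version of [`AxisChunks`].
///
/// A minimal usage sketch (crate paths omitted; an `AxisChunksMut` is
/// normally obtained via a view's `axis_chunks_mut` method):
///
/// ```ignore
/// let mut tensor = NdTensor::from([[1, 2], [3, 4]]);
/// for mut chunk in tensor.axis_chunks_mut(0, 1) {
///     for x in chunk.iter_mut() {
///         *x += 1;
///     }
/// }
/// ```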
pub struct AxisChunksMut<'a, T, L: MutLayout> {
    remainder: Option<TensorBase<ViewMutData<'a, T>, L>>,
    axis: usize,
    chunk_size: usize,
}

impl<'a, T, L: MutLayout> AxisChunksMut<'a, T, L> {
    pub fn new(
        view: TensorBase<ViewMutData<'a, T>, L>,
        axis: usize,
        chunk_size: usize,
    ) -> AxisChunksMut<'a, T, L> {
        // A broadcast view can visit the same element more than once, which
        // would yield aliasing mutable chunks.
        assert!(
            !view.layout().is_broadcast(),
            "Cannot mutably iterate over broadcasting view"
        );
        assert!(chunk_size > 0, "chunk size must be > 0");
        AxisChunksMut {
            remainder: if view.size(axis) > 0 {
                Some(view)
            } else {
                None
            },
            axis,
            chunk_size,
        }
    }
}

impl<'a, T, L: MutLayout> Iterator for AxisChunksMut<'a, T, L> {
    type Item = TensorBase<ViewMutData<'a, T>, L>;

    fn next(&mut self) -> Option<Self::Item> {
        let remainder = self.remainder.take()?;
        let chunk_len = self.chunk_size.min(remainder.size(self.axis));
        let (current, next_remainder) = remainder.split_at_mut(self.axis, chunk_len);
        self.remainder = if next_remainder.size(self.axis) > 0 {
            Some(next_remainder)
        } else {
            None
        };
        Some(current)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self
            .remainder
            .as_ref()
            .map(|r| r.size(self.axis))
            .unwrap_or(0)
            .div_ceil(self.chunk_size);
        (len, Some(len))
    }
}

impl<'a, T, L: MutLayout> ExactSizeIterator for AxisChunksMut<'a, T, L> {}

impl<'a, T, L: MutLayout> DoubleEndedIterator for AxisChunksMut<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let remainder = self.remainder.take()?;
        let remainder_size = remainder.size(self.axis);
        let chunk_len = self.chunk_size.min(remainder_size);
        let (prev_remainder, current) =
            remainder.split_at_mut(self.axis, remainder_size - chunk_len);
        self.remainder = if prev_remainder.size(self.axis) > 0 {
            Some(prev_remainder)
        } else {
            None
        };
        Some(current)
    }
}

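/// Apply `f` to each element of `view`.
///
/// A minimal usage sketch (assuming `view_mut` yields a `TensorViewMut`
/// over a dynamic-rank tensor, as in this crate's view APIs):
///
/// ```ignore
/// let mut tensor = Tensor::<f32>::full(&[2, 3], 1.);
/// for_each_mut(tensor.view_mut(), |x| *x += 1.);
/// assert_eq!(tensor.iter().sum::<f32>(), 12.);
/// ```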
pub fn for_each_mut<T, F: Fn(&mut T)>(mut view: TensorViewMut<T>, f: F) {
    // Pad the view to 4 dims to simplify the loop below.
    while view.ndim() < 4 {
        view.insert_axis(0);
    }

    // Visit each 4D sub-view and apply `f` using nested loops, which tend to
    // optimize better than a flat iterator for non-contiguous views.
    view.inner_iter_mut::<4>().for_each(|mut src| {
        for i0 in 0..src.size(0) {
            for i1 in 0..src.size(1) {
                for i2 in 0..src.size(2) {
                    for i3 in 0..src.size(3) {
                        // Safety: Each index is in `[0, src.size(dim))`.
                        let x = unsafe { src.get_unchecked_mut([i0, i1, i2, i3]) };
                        f(x);
                    }
                }
            }
        }
    });
}

#[cfg(test)]
mod tests {
    use crate::{
        AsView, AxisChunks, AxisChunksMut, Lanes, LanesMut, Layout, NdLayout, NdTensor, Tensor,
    };

    /// Check that the elements of `fwd` match those of `rev` in reverse
    /// order.
    fn compare_reversed<T: PartialEq + std::fmt::Debug>(fwd: &[T], rev: &[T]) {
        assert_eq!(fwd.len(), rev.len());
        for (x, y) in fwd.iter().zip(rev.iter().rev()) {
            assert_eq!(x, y);
        }
    }

    /// Test the standard behaviors of an iterator: yielded items, size
    /// hints, reversed iteration, fusing and `fold`.
    fn test_iterator<I: Iterator + ExactSizeIterator + DoubleEndedIterator>(
        create_iter: impl Fn() -> I,
        expected: &[I::Item],
    ) where
        I::Item: PartialEq + std::fmt::Debug,
    {
        let iter = create_iter();

        // Check the yielded items and size hints.
        let (min_len, max_len) = iter.size_hint();
        let items: Vec<_> = iter.collect();

        assert_eq!(&items, expected);

        assert_eq!(min_len, items.len(), "incorrect size lower bound");
        assert_eq!(max_len, Some(items.len()), "incorrect size upper bound");

        // Check reversed iteration.
        let rev_items: Vec<_> = create_iter().rev().collect();
        compare_reversed(&items, &rev_items);

        // Check the iterator is fused.
        let mut iter = create_iter();
        for _x in &mut iter {}
        assert_eq!(iter.next(), None);

        // Check that `fold` visits each item in order.
        let mut fold_items = Vec::new();
        let mut idx = 0;
        create_iter().fold(0, |acc, item| {
            assert_eq!(acc, idx);
            fold_items.push(item);
            idx += 1;
            idx
        });
        assert_eq!(items, fold_items);
    }

    /// Interface for containers which can repeatedly produce a mutable
    /// iterator over their contents, used by `test_mut_iterator`.
    trait MutIterable {
        type Iter<'a>: Iterator + ExactSizeIterator + DoubleEndedIterator
        where
            Self: 'a;

        fn iter_mut(&mut self) -> Self::Iter<'_>;
    }

    /// Test the standard behaviors of a mutable iterator, as `test_iterator`
    /// does for immutable ones.
    fn test_mut_iterator<M, T>(mut iterable: M, expected: &[T])
    where
        M: MutIterable,
        T: std::fmt::Debug,
        for<'a> <M::Iter<'a> as Iterator>::Item: std::fmt::Debug + PartialEq + PartialEq<T>,
    {
        // Check the yielded items and size hints.
        {
            let iter = iterable.iter_mut();
            let (min_len, max_len) = iter.size_hint();
            let items: Vec<_> = iter.collect();

            assert_eq!(items, expected);

            assert_eq!(min_len, items.len(), "incorrect size lower bound");
            assert_eq!(max_len, Some(items.len()), "incorrect size upper bound");
        }

        // Check the iterator is fused.
        {
            let mut iter = iterable.iter_mut();
            for _x in &mut iter {}
            assert!(iter.next().is_none());
        }

        // Check reversed iteration. The items are mutable references, which
        // cannot co-exist, so they are compared via their `Debug` strings.
        {
            let items: Vec<_> = iterable.iter_mut().map(|x| format!("{:?}", x)).collect();
            let rev_items: Vec<_> = iterable
                .iter_mut()
                .rev()
                .map(|x| format!("{:?}", x))
                .collect();
            compare_reversed(&items, &rev_items);
        }

        // Check that `fold` visits each item in order.
        {
            let items: Vec<_> = iterable.iter_mut().map(|x| format!("{:?}", x)).collect();
            let mut fold_items = Vec::new();
            let mut idx = 0;
            iterable.iter_mut().fold(0, |acc, item| {
                assert_eq!(acc, idx);
                fold_items.push(format!("{:?}", item));
                idx += 1;
                idx
            });
            assert_eq!(items, fold_items);
        }
    }

    #[test]
    fn test_axis_chunks() {
        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        test_iterator(
            || tensor.axis_chunks(0, 1),
            &[tensor.slice(0..1), tensor.slice(1..2)],
        );
    }

    #[test]
    fn test_axis_chunks_empty() {
        let x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(AxisChunks::new(&x.view(), 1, 1).next().is_none());
    }

    #[test]
    #[should_panic(expected = "chunk size must be > 0")]
    fn test_axis_chunks_zero_size() {
        let x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(AxisChunks::new(&x.view(), 1, 0).next().is_none());
    }

    #[test]
    fn test_axis_chunks_mut_empty() {
        let mut x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(AxisChunksMut::new(x.view_mut(), 1, 1).next().is_none());
    }

    #[test]
    fn test_axis_chunks_mut_rev() {
        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        let fwd: Vec<_> = tensor
            .axis_chunks_mut(0, 1)
            .map(|view| view.to_vec())
            .collect();
        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        let rev: Vec<_> = tensor
            .axis_chunks_mut(0, 1)
            .rev()
            .map(|view| view.to_vec())
            .collect();
        compare_reversed(&fwd, &rev);
    }

    #[test]
    #[should_panic(expected = "chunk size must be > 0")]
    fn test_axis_chunks_mut_zero_size() {
        let mut x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(AxisChunksMut::new(x.view_mut(), 1, 0).next().is_none());
    }

    #[test]
    fn test_axis_iter() {
        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        test_iterator(|| tensor.axis_iter(0), &[tensor.slice(0), tensor.slice(1)]);
    }

    #[test]
    fn test_axis_iter_mut_rev() {
        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        let fwd: Vec<_> = tensor.axis_iter_mut(0).map(|view| view.to_vec()).collect();
        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        let rev: Vec<_> = tensor
            .axis_iter_mut(0)
            .rev()
            .map(|view| view.to_vec())
            .collect();
        compare_reversed(&fwd, &rev);
    }

    #[test]
    fn test_inner_iter() {
        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        test_iterator(
            || tensor.inner_iter::<2>(),
            &[tensor.slice(0), tensor.slice(1)],
        );
    }

    #[test]
    fn test_inner_iter_mut() {
        struct InnerIterMutTest(NdTensor<i32, 3>);

        impl MutIterable for InnerIterMutTest {
            type Iter<'a> = super::InnerIterMut<'a, i32, NdLayout<2>>;

            fn iter_mut(&mut self) -> Self::Iter<'_> {
                self.0.inner_iter_mut::<2>()
            }
        }

        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        test_mut_iterator(
            InnerIterMutTest(tensor.clone()),
            &[tensor.slice(0), tensor.slice(1)],
        );
    }

    #[test]
    fn test_lanes() {
        let x = NdTensor::from([[1, 2], [3, 4]]);
        test_iterator(
            || x.lanes(0),
            &[x.slice((.., 0)).into(), x.slice((.., 1)).into()],
        );
        test_iterator(|| x.lanes(1), &[x.slice(0).into(), x.slice(1).into()]);
    }

    #[test]
    fn test_lanes_empty() {
        let x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(Lanes::new(x.view().view_ref(), 0).next().is_none());
        assert!(Lanes::new(x.view().view_ref(), 1).next().is_none());
    }

    #[test]
    fn test_lanes_mut() {
        use super::Lane;

        struct LanesMutTest(NdTensor<i32, 2>);

        impl MutIterable for LanesMutTest {
            type Iter<'a> = super::LanesMut<'a, i32>;

            fn iter_mut(&mut self) -> Self::Iter<'_> {
                self.0.lanes_mut(0)
            }
        }

        let tensor = NdTensor::from([[1, 2], [3, 4]]);
        test_mut_iterator::<_, Lane<i32>>(
            LanesMutTest(tensor.clone()),
            &[
                Lane::from(tensor.slice((.., 0))),
                Lane::from(tensor.slice((.., 1))),
            ],
        );
    }

    #[test]
    fn test_lane_as_slice() {
        // Contiguous lane.
        let x = NdTensor::from([0, 1, 2]);
        let mut lane = x.lanes(0).next().unwrap();
        assert_eq!(lane.as_slice(), Some([0, 1, 2].as_slice()));
        lane.next();
        assert_eq!(lane.as_slice(), Some([1, 2].as_slice()));
        lane.next();
        lane.next();
        assert_eq!(lane.as_slice(), Some([0i32; 0].as_slice()));
        lane.next();
        assert_eq!(lane.as_slice(), Some([0i32; 0].as_slice()));

        // Non-contiguous lane.
        let x = NdTensor::from([[1i32, 2], [3, 4]]);
        let lane = x.lanes(0).next().unwrap();
        assert_eq!(lane.as_slice(), None);
    }

    #[test]
    fn test_lanes_mut_empty() {
        let mut x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(LanesMut::new(x.mut_view_ref(), 0).next().is_none());
        assert!(LanesMut::new(x.mut_view_ref(), 1).next().is_none());
    }

    #[test]
    fn test_iter_step_by() {
        let tensor = Tensor::<f32>::full(&[1, 3, 16, 8], 1.);

        // Take a slice so that the iterator cannot use the fast path for
        // contiguous tensors.
        let tensor = tensor.slice((.., .., 1.., ..));

        let sum = tensor.iter().sum::<f32>();
        for n_skip in 0..tensor.len() {
            let sum_skip = tensor.iter().skip(n_skip).sum::<f32>();
            assert_eq!(
                sum_skip,
                sum - n_skip as f32,
                "wrong sum for n_skip={}",
                n_skip
            );
        }
    }

    #[test]
    fn test_iter_broadcast() {
        let tensor = Tensor::<f32>::full(&[1], 1.);
        let broadcast = tensor.broadcast([1, 3, 16, 8]);
        assert_eq!(broadcast.iter().len(), broadcast.len());
        let count = broadcast.iter().count();
        assert_eq!(count, broadcast.len());
        let sum = broadcast.iter().sum::<f32>();
        assert_eq!(sum, broadcast.len() as f32);
    }

    #[test]
    fn test_iter() {
        let tensor = NdTensor::from([[[1, 2], [3, 4]]]);

        // Contiguous iteration.
        test_iterator(|| tensor.iter().copied(), &[1, 2, 3, 4]);

        // Non-contiguous iteration.
        test_iterator(|| tensor.transposed().iter().copied(), &[1, 3, 2, 4]);
    }

    #[test]
    fn test_iter_mut() {
        struct IterTest(NdTensor<i32, 3>);

        impl MutIterable for IterTest {
            type Iter<'a> = super::IterMut<'a, i32>;

            fn iter_mut(&mut self) -> Self::Iter<'_> {
                self.0.iter_mut()
            }
        }

        let tensor = NdTensor::from([[[1, 2], [3, 4]]]);
        test_mut_iterator(IterTest(tensor), &[&1, &2, &3, &4]);
    }

    #[test]
    #[ignore]
    fn bench_iter() {
        use crate::Layout;
        use rten_bench::run_bench;

        type Elem = i32;

        let tensor = std::hint::black_box(Tensor::<Elem>::full(&[1, 6, 768, 64], 1));
        let n_trials = 1000;
        let mut result = Elem::default();

        fn reduce<'a>(iter: impl Iterator<Item = &'a Elem>) -> Elem {
            iter.fold(Elem::default(), |acc, x| acc.wrapping_add(*x))
        }

        // Baseline: iterate over the underlying data slice directly.
        run_bench(n_trials, Some("slice iter"), || {
            result = reduce(tensor.data().unwrap().iter());
        });
        println!("sum {}", result);

        // Tensor iteration over a contiguous layout, which uses the fast
        // path where offsets form a simple range.
        run_bench(n_trials, Some("contiguous iter"), || {
            result = reduce(tensor.iter());
        });
        println!("sum {}", result);

        run_bench(n_trials, Some("contiguous reverse iter"), || {
            result = reduce(tensor.iter().rev());
        });
        println!("sum {}", result);

        // Tensor iteration over a non-contiguous layout, where offsets must
        // be computed by stepping each dimension.
        let slice = tensor.slice((.., .., 1.., ..));
        assert!(!slice.is_contiguous());
        let n_trials = 1000;
        run_bench(n_trials, Some("non-contiguous iter"), || {
            result = reduce(slice.iter());
        });
        println!("sum {}", result);

        // Reverse iteration over a non-contiguous layout is slower, so fewer
        // trials are used.
        let n_trials = 100;
        run_bench(n_trials, Some("non-contiguous reverse iter"), || {
            result = reduce(slice.iter().rev());
        });
        println!("sum {}", result);
    }

    #[test]
    #[ignore]
    fn bench_inner_iter() {
        use crate::rng::XorShiftRng;
        use rten_bench::run_bench;

        let n_trials = 100;
        let mut rng = XorShiftRng::new(1234);

        // Shape chosen so that there are many outer iterations over small
        // inner views.
        let tensor = Tensor::<f32>::rand(&[512, 512, 12, 1], &mut rng);

        let mut sum = 0.;
        run_bench(n_trials, Some("inner iter"), || {
            for inner in tensor.inner_iter::<2>() {
                for i0 in 0..inner.size(0) {
                    for i1 in 0..inner.size(1) {
                        sum += inner[[i0, i1]];
                    }
                }
            }
        });
        println!("sum {}", sum);
    }
}