1use std::iter::FusedIterator;
4use std::mem::transmute;
5use std::ops::Range;
6
7use rten_base::iter::SplitIterator;
8
9use super::{AsView, DynLayout, NdTensorView, NdTensorViewMut, TensorBase, TensorViewMut};
10use crate::layout::{Layout, MutLayout, NdLayout, OverlapPolicy, RemoveDim, merge_axes};
11use crate::storage::{StorageMut, ViewData, ViewMutData};
12
13mod parallel;
14
/// Iteration state for a single dimension of a tensor layout.
#[derive(Copy, Clone, Debug, Default)]
struct IterPos {
    // Steps remaining along this dimension before it wraps back to index 0.
    remaining: usize,

    // Storage offset contributed by this dimension, ie. current index
    // multiplied by `stride`.
    offset: usize,

    // Storage stride of this dimension.
    stride: usize,

    // Value `remaining` is reset to after wrapping. Equals
    // `size.saturating_sub(1)`.
    max_remaining: usize,
}
30
31impl IterPos {
32 fn from_size_stride(size: usize, stride: usize) -> Self {
33 let remaining = size.saturating_sub(1);
34 IterPos {
35 remaining,
36 offset: 0,
37 stride,
38 max_remaining: remaining,
39 }
40 }
41
42 #[inline(always)]
43 fn step(&mut self) -> bool {
44 if self.remaining != 0 {
45 self.remaining -= 1;
46 self.offset += self.stride;
47 true
48 } else {
49 self.remaining = self.max_remaining;
50 self.offset = 0;
51 false
52 }
53 }
54
55 fn size(&self) -> usize {
57 self.max_remaining + 1
60 }
61
62 fn index(&self) -> usize {
64 self.max_remaining - self.remaining
65 }
66
67 fn set_index(&mut self, index: usize) {
69 self.remaining = self.max_remaining - index;
70 self.offset = index * self.stride;
71 }
72}
73
74const INNER_NDIM: usize = 2;
75
/// Iterator state yielding the storage offset of each element of a
/// (possibly non-contiguous) tensor layout, in logical order.
///
/// The innermost `INNER_NDIM` dims have fixed-size state (`inner_pos`) so
/// the hot path avoids touching the heap-allocated `outer_pos`.
#[derive(Clone, Debug)]
struct OffsetsBase {
    // Number of offsets still to be yielded.
    len: usize,

    // Cached sum of the `offset` fields of `inner_pos`.
    inner_offset: usize,

    // Iteration state for the innermost `INNER_NDIM` dims. Layouts with
    // fewer dims are padded at the front with size-1 / stride-0 dims.
    inner_pos: [IterPos; INNER_NDIM],

    // Cached sum of the `offset` fields of `outer_pos`.
    outer_offset: usize,

    // Iteration state for all dims before the innermost `INNER_NDIM`,
    // ordered outermost-first.
    outer_pos: Vec<IterPos>,
}

impl OffsetsBase {
    /// Create iteration state covering all elements of `layout`.
    fn new<L: Layout>(layout: &L) -> OffsetsBase {
        // Merge adjacent dims that can be traversed with a single stride.
        // This reduces per-step bookkeeping.
        let merged = merge_axes(layout.shape().as_ref(), layout.strides().as_ref());

        // If the merged layout has fewer than `INNER_NDIM` dims, pad the
        // front of `inner_pos` with dummy size-1 dims.
        let inner_pos_pad = INNER_NDIM.saturating_sub(merged.len());
        let n_outer = merged.len().saturating_sub(INNER_NDIM);

        let inner_pos = std::array::from_fn(|dim| {
            let (size, stride) = if dim < inner_pos_pad {
                (1, 0)
            } else {
                merged[n_outer + dim - inner_pos_pad]
            };
            IterPos::from_size_stride(size, stride)
        });

        let outer_pos = (0..n_outer)
            .map(|i| {
                let (size, stride) = merged[i];
                IterPos::from_size_stride(size, stride)
            })
            .collect();

        OffsetsBase {
            // Total element count is the product of merged dim sizes.
            len: merged.iter().map(|dim| dim.0).product(),
            inner_pos,
            inner_offset: 0,
            outer_pos,
            outer_offset: 0,
        }
    }

    /// Advance the outer dims by one step, odometer-style (the last outer
    /// dim steps fastest).
    ///
    /// Returns false if stepping wrapped past the final outer position.
    fn step_outer_pos(&mut self) -> bool {
        let mut done = self.outer_pos.is_empty();
        for (i, dim) in self.outer_pos.iter_mut().enumerate().rev() {
            if dim.step() {
                break;
            } else if i == 0 {
                // Outermost dim wrapped: iteration is complete.
                done = true;
            }
        }
        self.outer_offset = self.outer_pos.iter().map(|p| p.offset).sum();
        !done
    }

    /// Return iteration state for `dim`, where dims are numbered
    /// outermost-first across `outer_pos` then `inner_pos`.
    fn pos(&self, dim: usize) -> IterPos {
        let outer_ndim = self.outer_pos.len();
        if dim >= outer_ndim {
            self.inner_pos[dim - outer_ndim]
        } else {
            self.outer_pos[dim]
        }
    }

    /// Mutable variant of [`pos`](Self::pos).
    fn pos_mut(&mut self, dim: usize) -> &mut IterPos {
        let outer_ndim = self.outer_pos.len();
        if dim >= outer_ndim {
            &mut self.inner_pos[dim - outer_ndim]
        } else {
            &mut self.outer_pos[dim]
        }
    }

    /// Advance the iterator by `n` positions, or to the end if fewer than
    /// `n` positions remain.
    fn step_by(&mut self, n: usize) {
        let mut remaining = n.min(self.len);
        self.len -= remaining;

        // Propagate the step from the innermost dim outwards, carrying
        // overflow into the next dim like incrementing a multi-digit number.
        for dim in (0..self.ndim()).rev() {
            if remaining == 0 {
                break;
            }

            let pos = self.pos_mut(dim);
            let size = pos.size();
            let new_index = pos.index() + remaining;
            pos.set_index(new_index % size);
            remaining = new_index / size;
        }

        // Re-derive the cached offset sums from the updated positions.
        self.inner_offset = self.inner_pos.iter().map(|p| p.offset).sum();
        self.outer_offset = self.outer_pos.iter().map(|p| p.offset).sum();
    }

    /// Number of dims used for iteration (after merging and padding).
    fn ndim(&self) -> usize {
        self.outer_pos.len() + self.inner_pos.len()
    }

    /// Map a linear index in the iteration sequence to a storage offset.
    ///
    /// Note: the index is interpreted relative to the start of the layout,
    /// not to the iterator's current position.
    fn offset_from_linear_index(&self, index: usize) -> usize {
        let mut offset = 0;
        let mut shape_product = 1;
        for dim in (0..self.ndim()).rev() {
            let pos = self.pos(dim);
            let dim_index = (index / shape_product) % pos.size();
            shape_product *= pos.size();
            offset += dim_index * pos.stride;
        }
        offset
    }

    /// Clamp the number of remaining offsets to at most `len`.
    fn truncate(&mut self, len: usize) {
        self.len = self.len.min(len);
    }
}

impl Iterator for OffsetsBase {
    type Item = usize;

    #[inline(always)]
    fn next(&mut self) -> Option<usize> {
        if self.len == 0 {
            return None;
        }
        let offset = self.outer_offset + self.inner_offset;

        self.len -= 1;

        // Speculatively advance along the last dim; if that wraps, the
        // cached `inner_offset` is recomputed below.
        self.inner_offset += self.inner_pos[1].stride;

        if !self.inner_pos[1].step() {
            if !self.inner_pos[0].step() {
                self.step_outer_pos();
            }

            // After a wrap `inner_pos[1]` is back at offset 0, so the inner
            // offset is just that of `inner_pos[0]`.
            self.inner_offset = self.inner_pos[0].offset;
        }

        Some(offset)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, usize) -> B,
    {
        if self.len == 0 {
            return init;
        }

        // Iterate the inner dims with plain index loops; offsets are
        // recomputed from indices rather than stepped incrementally.
        let mut accum = init;
        'outer: loop {
            for i0 in self.inner_pos[0].index()..self.inner_pos[0].size() {
                for i1 in self.inner_pos[1].index()..self.inner_pos[1].size() {
                    let inner_offset =
                        i0 * self.inner_pos[0].stride + i1 * self.inner_pos[1].stride;
                    accum = f(accum, self.outer_offset + inner_offset);

                    self.len -= 1;
                    if self.len == 0 {
                        break 'outer;
                    }
                }
                // Later passes over the inner dims start from index 0.
                self.inner_pos[1].set_index(0);
            }
            self.inner_pos[0].set_index(0);

            if !self.step_outer_pos() {
                break;
            }
        }

        accum
    }
}

// `len` tracks the exact number of remaining offsets.
impl ExactSizeIterator for OffsetsBase {}

impl DoubleEndedIterator for OffsetsBase {
    fn next_back(&mut self) -> Option<usize> {
        if self.len == 0 {
            return None;
        }

        // NOTE(review): `self.len - 1` is treated as a linear index relative
        // to the start of the layout. That is the index of the last
        // remaining element only if nothing has been consumed from the
        // front (or skipped via `step_by`, as `split_at` does). Confirm
        // that callers do not mix front and back iteration on an advanced
        // iterator, or fold the front position into the index here.
        let index = self.len - 1;
        let offset = self.offset_from_linear_index(index);
        self.len -= 1;

        Some(offset)
    }
}

impl SplitIterator for OffsetsBase {
    /// Split into iterators over the first `index` offsets and the rest.
    fn split_at(mut self, index: usize) -> (Self, Self) {
        assert!(self.len >= index);

        // Right half: same state, advanced `index` steps.
        let mut right = self.clone();
        OffsetsBase::step_by(&mut right, index);

        // Left half: current position, but only `index` items.
        self.truncate(index);

        (self, right)
    }
}
327
/// Iterator over references to the elements of a tensor view, in logical
/// order.
pub struct Iter<'a, T> {
    offsets: Offsets,
    data: ViewData<'a, T>,
}
333
334impl<'a, T> Iter<'a, T> {
335 pub(super) fn new<L: Layout + Clone>(view: TensorBase<ViewData<'a, T>, L>) -> Iter<'a, T> {
336 Iter {
337 offsets: Offsets::new(view.layout()),
338 data: view.storage(),
339 }
340 }
341}
342
343impl<T> Clone for Iter<'_, T> {
344 fn clone(&self) -> Self {
345 Iter {
346 offsets: self.offsets.clone(),
347 data: self.data,
348 }
349 }
350}
351
impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next()?;

        // SAFETY: `offsets` yields only in-bounds offsets for `data`.
        Some(unsafe { self.data.get_unchecked(offset) })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.offsets.size_hint()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let offset = self.offsets.nth(n)?;

        // SAFETY: `offsets` yields only in-bounds offsets for `data`.
        Some(unsafe { self.data.get_unchecked(offset) })
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.offsets.fold(init, |acc, offset| {
            // SAFETY: `offsets` yields only in-bounds offsets for `data`.
            let item = unsafe { self.data.get_unchecked(offset) };
            f(acc, item)
        })
    }
}

impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next_back()?;

        // SAFETY: `offsets` yields only in-bounds offsets for `data`.
        Some(unsafe { self.data.get_unchecked(offset) })
    }
}

// `Offsets` reports an exact remaining length.
impl<T> ExactSizeIterator for Iter<'_, T> {}

// `Offsets` keeps returning `None` once exhausted.
impl<T> FusedIterator for Iter<'_, T> {}
399
/// Extend the lifetime of a mutable reference.
///
/// # Safety
///
/// The caller must ensure the referenced data lives for `'b` and that no
/// other reference to it is created or used while the returned reference is
/// live.
unsafe fn transmute_lifetime_mut<'a, 'b, T>(x: &'a mut T) -> &'b mut T {
    unsafe { transmute::<&'a mut T, &'b mut T>(x) }
}

/// Iterator over mutable references to the elements of a tensor view, in
/// logical order.
pub struct IterMut<'a, T> {
    offsets: Offsets,
    data: ViewMutData<'a, T>,
}
411
412impl<'a, T> IterMut<'a, T> {
413 pub(super) fn new<L: Layout + Clone>(
414 view: TensorBase<ViewMutData<'a, T>, L>,
415 ) -> IterMut<'a, T> {
416 IterMut {
417 offsets: Offsets::new(view.layout()),
418 data: view.into_storage(),
419 }
420 }
421}
422
impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next()?;

        // SAFETY: The offset is in-bounds for `data` and each offset is
        // yielded at most once, so lifetime-extended references do not
        // alias — presumes the layout is non-overlapping; TODO confirm this
        // is guaranteed for all mutable views.
        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.offsets.size_hint()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let offset = self.offsets.nth(n)?;

        // SAFETY: As in `next`.
        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.offsets.fold(init, |acc, offset| {
            // SAFETY: As in `next`.
            let item = unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) };
            f(acc, item)
        })
    }
}

impl<T> DoubleEndedIterator for IterMut<'_, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next_back()?;

        // SAFETY: As in `next`.
        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
    }
}

// `Offsets` reports an exact remaining length.
impl<T> ExactSizeIterator for IterMut<'_, T> {}

// `Offsets` keeps returning `None` once exhausted.
impl<T> FusedIterator for IterMut<'_, T> {}
475
/// Strategy used to generate element offsets.
#[derive(Clone)]
enum OffsetsKind {
    /// Fast path for contiguous layouts, where offsets are just `0..len`.
    Range(Range<usize>),
    /// General path for non-contiguous layouts.
    Indexing(OffsetsBase),
}

/// Iterator over the storage offsets of a tensor view's elements, in
/// logical order.
#[derive(Clone)]
struct Offsets {
    base: OffsetsKind,
}
492
493impl Offsets {
494 fn new<L: Layout>(layout: &L) -> Offsets {
495 Offsets {
496 base: if layout.is_contiguous() {
497 OffsetsKind::Range(0..layout.min_data_len())
498 } else {
499 OffsetsKind::Indexing(OffsetsBase::new(layout))
500 },
501 }
502 }
503}
504
505impl Iterator for Offsets {
506 type Item = usize;
507
508 #[inline]
509 fn next(&mut self) -> Option<Self::Item> {
510 match &mut self.base {
511 OffsetsKind::Range(r) => r.next(),
512 OffsetsKind::Indexing(base) => base.next(),
513 }
514 }
515
516 fn size_hint(&self) -> (usize, Option<usize>) {
517 match &self.base {
518 OffsetsKind::Range(r) => r.size_hint(),
519 OffsetsKind::Indexing(base) => (base.len, Some(base.len)),
520 }
521 }
522
523 fn nth(&mut self, n: usize) -> Option<Self::Item> {
524 match &mut self.base {
525 OffsetsKind::Range(r) => r.nth(n),
526 OffsetsKind::Indexing(base) => {
527 base.step_by(n);
528 self.next()
529 }
530 }
531 }
532
533 fn fold<B, F>(self, init: B, f: F) -> B
534 where
535 Self: Sized,
536 F: FnMut(B, Self::Item) -> B,
537 {
538 match self.base {
539 OffsetsKind::Range(r) => r.fold(init, f),
540 OffsetsKind::Indexing(base) => base.fold(init, f),
541 }
542 }
543}
544
545impl DoubleEndedIterator for Offsets {
546 fn next_back(&mut self) -> Option<Self::Item> {
547 match &mut self.base {
548 OffsetsKind::Range(r) => r.next_back(),
549 OffsetsKind::Indexing(base) => base.next_back(),
550 }
551 }
552}
553
// Both variants report exact lengths.
impl ExactSizeIterator for Offsets {}

// Both variants keep returning `None` once exhausted.
impl FusedIterator for Offsets {}
557
/// Iterator over the storage offset range covered by each 1D lane of a
/// tensor view, where a lane runs along a chosen dimension.
struct LaneRanges {
    // Storage offsets of each lane's first element.
    offsets: Offsets,

    // Size of the lane dimension.
    dim_size: usize,
    // Stride of the lane dimension.
    dim_stride: usize,
}
568
569impl LaneRanges {
570 fn new<L: Layout + RemoveDim>(layout: &L, dim: usize) -> LaneRanges {
571 let offsets = if layout.is_empty() {
574 Offsets::new(layout)
575 } else {
576 let other_dims = layout.remove_dim(dim);
577 Offsets::new(&other_dims)
578 };
579
580 LaneRanges {
581 offsets,
582 dim_size: layout.size(dim),
583 dim_stride: layout.stride(dim),
584 }
585 }
586
587 fn lane_offset_range(&self, start_offset: usize) -> Range<usize> {
590 lane_offsets(start_offset, self.dim_size, self.dim_stride)
591 }
592}
593
/// Return the storage offset range spanned by a lane of `size` elements
/// starting at `start_offset`, with consecutive elements `stride` apart.
///
/// The range ends just past the last element. `size` must be at least 1.
fn lane_offsets(start_offset: usize, size: usize, stride: usize) -> Range<usize> {
    let end = start_offset + (size - 1) * stride + 1;
    start_offset..end
}
597
598impl Iterator for LaneRanges {
599 type Item = Range<usize>;
600
601 #[inline]
602 fn next(&mut self) -> Option<Range<usize>> {
603 self.offsets
604 .next()
605 .map(|offset| self.lane_offset_range(offset))
606 }
607
608 fn size_hint(&self) -> (usize, Option<usize>) {
609 self.offsets.size_hint()
610 }
611
612 fn fold<B, F>(self, init: B, mut f: F) -> B
613 where
614 Self: Sized,
615 F: FnMut(B, Self::Item) -> B,
616 {
617 let Self {
618 offsets,
619 dim_size,
620 dim_stride,
621 } = self;
622
623 offsets.fold(init, |acc, offset| {
624 f(acc, lane_offsets(offset, dim_size, dim_stride))
625 })
626 }
627}
628
629impl DoubleEndedIterator for LaneRanges {
630 fn next_back(&mut self) -> Option<Range<usize>> {
631 self.offsets
632 .next_back()
633 .map(|offset| self.lane_offset_range(offset))
634 }
635}
636
// Length delegates to the underlying `Offsets`.
impl ExactSizeIterator for LaneRanges {}

impl FusedIterator for LaneRanges {}
640
/// Iterator over the 1D lanes of a tensor view, taken along a given
/// dimension.
pub struct Lanes<'a, T> {
    data: ViewData<'a, T>,
    ranges: LaneRanges,
    // Size/stride layout shared by every lane.
    lane_layout: NdLayout<1>,
}

/// A single 1D slice of a tensor. Acts as an iterator over its elements and
/// can be converted to a 1D view.
#[derive(Clone, Debug)]
pub struct Lane<'a, T> {
    view: NdTensorView<'a, T, 1>,
    // Iteration position within the lane.
    index: usize,
}
657
658impl<'a, T> Lane<'a, T> {
659 pub fn as_slice(&self) -> Option<&'a [T]> {
661 self.view.data().map(|data| &data[self.index..])
662 }
663
664 pub fn get(&self, idx: usize) -> Option<&'a T> {
666 self.view.get([idx])
667 }
668
669 pub fn as_view(&self) -> NdTensorView<'a, T, 1> {
671 self.view
672 }
673}
674
675impl<'a, T> From<NdTensorView<'a, T, 1>> for Lane<'a, T> {
676 fn from(val: NdTensorView<'a, T, 1>) -> Self {
677 Lane {
678 view: val,
679 index: 0,
680 }
681 }
682}
683
684impl<'a, T> Iterator for Lane<'a, T> {
685 type Item = &'a T;
686
687 #[inline]
688 fn next(&mut self) -> Option<Self::Item> {
689 if self.index < self.view.len() {
690 let index = self.index;
691 self.index += 1;
692
693 Some(unsafe { self.view.get_unchecked([index]) })
695 } else {
696 None
697 }
698 }
699
700 fn size_hint(&self) -> (usize, Option<usize>) {
701 let size = self.view.size(0);
702 (size, Some(size))
703 }
704}
705
// `Lane` reports an exact `size_hint`.
impl<T> ExactSizeIterator for Lane<'_, T> {}

// `next` keeps returning `None` once the end is reached.
impl<T> FusedIterator for Lane<'_, T> {}
709
impl<T: PartialEq> PartialEq<Lane<'_, T>> for Lane<'_, T> {
    /// Lanes compare equal if their remaining (un-iterated) elements match.
    fn eq(&self, other: &Lane<'_, T>) -> bool {
        self.view.slice(self.index..) == other.view.slice(other.index..)
    }
}

impl<T: PartialEq> PartialEq<Lane<'_, T>> for LaneMut<'_, T> {
    /// A mutable lane equals an immutable lane if their remaining elements
    /// match.
    fn eq(&self, other: &Lane<'_, T>) -> bool {
        self.view.slice(self.index..) == other.view.slice(other.index..)
    }
}
721
722impl<'a, T> Lanes<'a, T> {
723 pub(crate) fn new<L: Layout + RemoveDim + Clone>(
726 view: TensorBase<ViewData<'a, T>, L>,
727 dim: usize,
728 ) -> Lanes<'a, T> {
729 let size = view.size(dim);
730 let stride = view.stride(dim);
731 let lane_layout =
732 NdLayout::from_shape_and_strides([size], [stride], OverlapPolicy::AllowOverlap)
733 .unwrap();
734 Lanes {
735 data: view.storage(),
736 ranges: LaneRanges::new(view.layout(), dim),
737 lane_layout,
738 }
739 }
740}
741
742fn lane_for_offset_range<T>(
743 data: ViewData<T>,
744 layout: NdLayout<1>,
745 offsets: Range<usize>,
746) -> Lane<T> {
747 let view = NdTensorView::from_storage_and_layout(data.slice(offsets), layout);
748 Lane { view, index: 0 }
749}
750
751impl<'a, T> Iterator for Lanes<'a, T> {
752 type Item = Lane<'a, T>;
753
754 #[inline]
756 fn next(&mut self) -> Option<Self::Item> {
757 self.ranges
758 .next()
759 .map(|range| lane_for_offset_range(self.data, self.lane_layout, range))
760 }
761
762 fn size_hint(&self) -> (usize, Option<usize>) {
763 self.ranges.size_hint()
764 }
765
766 fn fold<B, F>(self, init: B, mut f: F) -> B
767 where
768 Self: Sized,
769 F: FnMut(B, Self::Item) -> B,
770 {
771 self.ranges.fold(init, |acc, offsets| {
772 let lane = lane_for_offset_range(self.data, self.lane_layout, offsets);
773 f(acc, lane)
774 })
775 }
776}
777
778impl<T> DoubleEndedIterator for Lanes<'_, T> {
779 fn next_back(&mut self) -> Option<Self::Item> {
780 self.ranges
781 .next_back()
782 .map(|range| lane_for_offset_range(self.data, self.lane_layout, range))
783 }
784}
785
// Length delegates to `LaneRanges`.
impl<T> ExactSizeIterator for Lanes<'_, T> {}

impl<T> FusedIterator for Lanes<'_, T> {}
789
/// Iterator over mutable 1D lanes of a tensor view, taken along a given
/// dimension.
pub struct LanesMut<'a, T> {
    data: ViewMutData<'a, T>,
    ranges: LaneRanges,
    // Size/stride layout shared by every lane.
    lane_layout: NdLayout<1>,
}

impl<'a, T> LanesMut<'a, T> {
    /// Create a mutable iterator over the 1D lanes of `view` along `dim`.
    ///
    /// Panics if `view` is a broadcasting view, since its lanes would yield
    /// aliasing mutable references.
    pub(crate) fn new<L: Layout + RemoveDim + Clone>(
        view: TensorBase<ViewMutData<'a, T>, L>,
        dim: usize,
    ) -> LanesMut<'a, T> {
        assert!(
            !view.is_broadcast(),
            "Cannot mutably iterate over broadcasting view"
        );

        let size = view.size(dim);
        let stride = view.stride(dim);

        // `AllowOverlap`: storage *ranges* of distinct lanes may interleave
        // even though the elements each lane visits are disjoint.
        let lane_layout =
            NdLayout::from_shape_and_strides([size], [stride], OverlapPolicy::AllowOverlap)
                .unwrap();

        LanesMut {
            // Ranges are computed from the layout before the view's storage
            // is consumed on the next line.
            ranges: LaneRanges::new(view.layout(), dim),
            data: view.into_storage(),
            lane_layout,
        }
    }
}

impl<'a, T> Iterator for LanesMut<'a, T> {
    type Item = LaneMut<'a, T>;

    #[inline]
    fn next(&mut self) -> Option<LaneMut<'a, T>> {
        self.ranges.next().map(|offsets| {
            // SAFETY: Each lane visits a disjoint set of elements (the view
            // is not broadcast, checked in `new`), so per-lane mutable
            // sub-views do not create aliasing references — TODO confirm
            // this invariant holds for all non-broadcast layouts.
            unsafe {
                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
            }
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.ranges.size_hint()
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.ranges.fold(init, |acc, offsets| {
            // SAFETY: As in `next`.
            let lane = unsafe {
                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
            };
            f(acc, lane)
        })
    }
}

// Length delegates to `LaneRanges`.
impl<'a, T> ExactSizeIterator for LanesMut<'a, T> {}

impl<'a, T> DoubleEndedIterator for LanesMut<'a, T> {
    fn next_back(&mut self) -> Option<LaneMut<'a, T>> {
        self.ranges.next_back().map(|offsets| {
            // SAFETY: As in `next`.
            unsafe {
                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
            }
        })
    }
}
879
/// Mutable version of [`Lane`].
#[derive(Debug)]
pub struct LaneMut<'a, T> {
    view: NdTensorViewMut<'a, T, 1>,
    // Iteration position within the lane.
    index: usize,
}

impl<'a, T> LaneMut<'a, T> {
    /// Construct a lane over `data` using the given 1D layout.
    ///
    /// # Safety
    ///
    /// The caller must ensure `layout` is valid for `data` and that the
    /// elements this lane addresses are not aliased by any other live
    /// mutable view.
    unsafe fn from_storage_layout(data: ViewMutData<'a, T>, layout: NdLayout<1>) -> Self {
        let view = unsafe {
            NdTensorViewMut::from_storage_and_layout_unchecked(data, layout)
        };
        LaneMut { view, index: 0 }
    }

    /// Return the un-iterated elements as a mutable slice, if the lane is
    /// contiguous.
    pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
        self.view.data_mut().map(|data| &mut data[self.index..])
    }

    /// Consume the lane, returning it as a mutable 1D view.
    pub fn into_view(self) -> NdTensorViewMut<'a, T, 1> {
        self.view
    }
}
914
915impl<'a, T> Iterator for LaneMut<'a, T> {
916 type Item = &'a mut T;
917
918 #[inline]
919 fn next(&mut self) -> Option<Self::Item> {
920 if self.index < self.view.size(0) {
921 let index = self.index;
922 self.index += 1;
923 let item = unsafe { self.view.get_unchecked_mut([index]) };
924
925 Some(unsafe { transmute::<&mut T, Self::Item>(item) })
928 } else {
929 None
930 }
931 }
932
933 #[inline]
934 fn nth(&mut self, nth: usize) -> Option<Self::Item> {
935 self.index = (self.index + nth).min(self.view.size(0));
936 self.next()
937 }
938
939 fn size_hint(&self) -> (usize, Option<usize>) {
940 let size = self.view.size(0);
941 (size, Some(size))
942 }
943}
944
// `LaneMut` reports an exact `size_hint`.
impl<T> ExactSizeIterator for LaneMut<'_, T> {}

impl<T: PartialEq> PartialEq<LaneMut<'_, T>> for LaneMut<'_, T> {
    /// Mutable lanes compare equal if their remaining elements match.
    fn eq(&self, other: &LaneMut<'_, T>) -> bool {
        self.view.slice(self.index..) == other.view.slice(other.index..)
    }
}
952
/// Shared state for iterators that yield views over the innermost dims of a
/// tensor, one per combination of outer-dim indices.
struct InnerIterBase<L: Layout> {
    // Storage offsets of each inner view's first element.
    outer_offsets: Offsets,
    // Layout shared by every inner view.
    inner_layout: L,
    // Storage span of each inner view.
    inner_data_len: usize,
}

impl<L: Layout + Clone> InnerIterBase<L> {
    /// Shared constructor: split `parent_layout` into outer dims (iterated
    /// over) and the trailing `inner_dims` dims (the yielded views),
    /// building the inner layout via `make_inner_layout`.
    fn new_impl<PL: Layout, F: Fn(&[usize], &[usize]) -> L>(
        parent_layout: &PL,
        inner_dims: usize,
        make_inner_layout: F,
    ) -> InnerIterBase<L> {
        assert!(parent_layout.ndim() >= inner_dims);
        let outer_dims = parent_layout.ndim() - inner_dims;
        let parent_shape = parent_layout.shape();
        let parent_strides = parent_layout.strides();
        let (outer_shape, inner_shape) = parent_shape.as_ref().split_at(outer_dims);
        let (outer_strides, inner_strides) = parent_strides.as_ref().split_at(outer_dims);

        // `AllowOverlap` mirrors the policy used for the inner layouts —
        // presumably to support broadcast/overlapping parent layouts; TODO
        // confirm.
        let outer_layout = DynLayout::from_shape_and_strides(
            outer_shape,
            outer_strides,
            OverlapPolicy::AllowOverlap,
        )
        .unwrap();

        let inner_layout = make_inner_layout(inner_shape, inner_strides);

        InnerIterBase {
            outer_offsets: Offsets::new(&outer_layout),
            inner_data_len: inner_layout.min_data_len(),
            inner_layout,
        }
    }
}
992
993impl<const N: usize> InnerIterBase<NdLayout<N>> {
994 pub(crate) fn new<L: Layout>(parent_layout: &L) -> Self {
995 Self::new_impl(parent_layout, N, |inner_shape, inner_strides| {
996 let inner_shape: [usize; N] = inner_shape.try_into().unwrap();
997 let inner_strides: [usize; N] = inner_strides.try_into().unwrap();
998 NdLayout::from_shape_and_strides(
999 inner_shape,
1000 inner_strides,
1001 OverlapPolicy::AllowOverlap,
1004 )
1005 .expect("failed to create layout")
1006 })
1007 }
1008}
1009
1010impl InnerIterBase<DynLayout> {
1011 pub(crate) fn new_dyn<L: Layout>(parent_layout: &L, inner_dims: usize) -> Self {
1012 Self::new_impl(parent_layout, inner_dims, |inner_shape, inner_strides| {
1013 DynLayout::from_shape_and_strides(
1014 inner_shape,
1015 inner_strides,
1016 OverlapPolicy::AllowOverlap,
1019 )
1020 .expect("failed to create layout")
1021 })
1022 }
1023}
1024
1025impl<L: Layout> Iterator for InnerIterBase<L> {
1026 type Item = Range<usize>;
1028
1029 fn next(&mut self) -> Option<Range<usize>> {
1030 self.outer_offsets
1031 .next()
1032 .map(|offset| offset..offset + self.inner_data_len)
1033 }
1034
1035 fn size_hint(&self) -> (usize, Option<usize>) {
1036 self.outer_offsets.size_hint()
1037 }
1038
1039 fn fold<B, F>(self, init: B, mut f: F) -> B
1040 where
1041 Self: Sized,
1042 F: FnMut(B, Self::Item) -> B,
1043 {
1044 self.outer_offsets.fold(init, |acc, offset| {
1045 f(acc, offset..offset + self.inner_data_len)
1046 })
1047 }
1048}
1049
1050impl<L: Layout> ExactSizeIterator for InnerIterBase<L> {}
1051
1052impl<L: Layout> DoubleEndedIterator for InnerIterBase<L> {
1053 fn next_back(&mut self) -> Option<Self::Item> {
1054 self.outer_offsets
1055 .next_back()
1056 .map(|offset| offset..offset + self.inner_data_len)
1057 }
1058}
1059
/// Iterator over views of the innermost dimensions of a tensor, one view
/// per combination of outer-dim indices.
pub struct InnerIter<'a, T, L: Layout> {
    base: InnerIterBase<L>,
    data: ViewData<'a, T>,
}
1066
1067impl<'a, T, const N: usize> InnerIter<'a, T, NdLayout<N>> {
1068 pub(crate) fn new<L: Layout + Clone>(view: TensorBase<ViewData<'a, T>, L>) -> Self {
1069 let base = InnerIterBase::new(&view);
1070 InnerIter {
1071 base,
1072 data: view.storage(),
1073 }
1074 }
1075}
1076
1077impl<'a, T> InnerIter<'a, T, DynLayout> {
1078 pub(crate) fn new_dyn<L: Layout + Clone>(
1079 view: TensorBase<ViewData<'a, T>, L>,
1080 inner_dims: usize,
1081 ) -> Self {
1082 let base = InnerIterBase::new_dyn(&view, inner_dims);
1083 InnerIter {
1084 base,
1085 data: view.storage(),
1086 }
1087 }
1088}
1089
1090impl<'a, T, L: Layout + Clone> Iterator for InnerIter<'a, T, L> {
1091 type Item = TensorBase<ViewData<'a, T>, L>;
1092
1093 fn next(&mut self) -> Option<Self::Item> {
1094 self.base.next().map(|offset_range| {
1095 TensorBase::from_storage_and_layout(
1096 self.data.slice(offset_range),
1097 self.base.inner_layout.clone(),
1098 )
1099 })
1100 }
1101
1102 fn size_hint(&self) -> (usize, Option<usize>) {
1103 self.base.size_hint()
1104 }
1105
1106 fn fold<B, F>(self, init: B, mut f: F) -> B
1107 where
1108 Self: Sized,
1109 F: FnMut(B, Self::Item) -> B,
1110 {
1111 let inner_layout = self.base.inner_layout.clone();
1112 self.base.fold(init, |acc, offset_range| {
1113 let item = TensorBase::from_storage_and_layout(
1114 self.data.slice(offset_range),
1115 inner_layout.clone(),
1116 );
1117 f(acc, item)
1118 })
1119 }
1120}
1121
1122impl<T, L: Layout + Clone> ExactSizeIterator for InnerIter<'_, T, L> {}
1123
1124impl<T, L: Layout + Clone> DoubleEndedIterator for InnerIter<'_, T, L> {
1125 fn next_back(&mut self) -> Option<Self::Item> {
1126 self.base.next_back().map(|offset_range| {
1127 TensorBase::from_storage_and_layout(
1128 self.data.slice(offset_range),
1129 self.base.inner_layout.clone(),
1130 )
1131 })
1132 }
1133}
1134
/// Mutable version of [`InnerIter`].
pub struct InnerIterMut<'a, T, L: Layout> {
    base: InnerIterBase<L>,
    data: ViewMutData<'a, T>,
}
1141
1142impl<'a, T, const N: usize> InnerIterMut<'a, T, NdLayout<N>> {
1143 pub(crate) fn new<L: Layout>(view: TensorBase<ViewMutData<'a, T>, L>) -> Self {
1144 let base = InnerIterBase::new(&view);
1145 InnerIterMut {
1146 base,
1147 data: view.into_storage(),
1148 }
1149 }
1150}
1151
1152impl<'a, T> InnerIterMut<'a, T, DynLayout> {
1153 pub(crate) fn new_dyn<L: Layout>(
1154 view: TensorBase<ViewMutData<'a, T>, L>,
1155 inner_dims: usize,
1156 ) -> Self {
1157 let base = InnerIterBase::new_dyn(&view, inner_dims);
1158 InnerIterMut {
1159 base,
1160 data: view.into_storage(),
1161 }
1162 }
1163}
1164
impl<'a, T, L: Layout + Clone> Iterator for InnerIterMut<'a, T, L> {
    type Item = TensorBase<ViewMutData<'a, T>, L>;

    fn next(&mut self) -> Option<Self::Item> {
        self.base.next().map(|offset_range| {
            let storage = self.data.slice_mut(offset_range);
            // SAFETY: Extending the storage lifetime from a borrow of
            // `self` to `'a` is sound only if yielded views never alias.
            // Each offset range is yielded once — TODO confirm that ranges
            // of distinct inner views cannot overlap for the layouts this
            // is used with.
            let storage = unsafe {
                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
            };
            TensorBase::from_storage_and_layout(storage, self.base.inner_layout.clone())
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.base.size_hint()
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        let inner_layout = self.base.inner_layout.clone();
        self.base.fold(init, |acc, offset_range| {
            let storage = self.data.slice_mut(offset_range);
            // SAFETY: As in `next`.
            let storage = unsafe {
                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
            };
            let item = TensorBase::from_storage_and_layout(storage, inner_layout.clone());
            f(acc, item)
        })
    }
}

// Length delegates to `InnerIterBase`.
impl<T, L: Layout + Clone> ExactSizeIterator for InnerIterMut<'_, T, L> {}

impl<'a, T, L: Layout + Clone> DoubleEndedIterator for InnerIterMut<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.base.next_back().map(|offset_range| {
            let storage = self.data.slice_mut(offset_range);
            // SAFETY: As in `next`.
            let storage = unsafe {
                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
            };
            TensorBase::from_storage_and_layout(storage, self.base.inner_layout.clone())
        })
    }
}
1222
/// Iterator over (N-1)-dimensional slices of a tensor along a given axis.
pub struct AxisIter<'a, T, L: Layout + RemoveDim> {
    view: TensorBase<ViewData<'a, T>, L>,
    axis: usize,
    // Index of the next slice yielded from the front.
    index: usize,
    // One past the index of the next slice yielded from the back.
    end: usize,
}
1231
1232impl<'a, T, L: MutLayout + RemoveDim> AxisIter<'a, T, L> {
1233 pub(crate) fn new(view: &TensorBase<ViewData<'a, T>, L>, axis: usize) -> AxisIter<'a, T, L> {
1234 assert!(axis < view.ndim());
1235 AxisIter {
1236 view: view.clone(),
1237 axis,
1238 index: 0,
1239 end: view.size(axis),
1240 }
1241 }
1242}
1243
1244impl<'a, T, L: MutLayout + RemoveDim> Iterator for AxisIter<'a, T, L> {
1245 type Item = TensorBase<ViewData<'a, T>, <L as RemoveDim>::Output>;
1246
1247 fn next(&mut self) -> Option<Self::Item> {
1248 if self.index >= self.end {
1249 None
1250 } else {
1251 let slice = self.view.index_axis(self.axis, self.index);
1252 self.index += 1;
1253 Some(slice)
1254 }
1255 }
1256
1257 fn size_hint(&self) -> (usize, Option<usize>) {
1258 let len = self.end - self.index;
1259 (len, Some(len))
1260 }
1261}
1262
1263impl<'a, T, L: MutLayout + RemoveDim> ExactSizeIterator for AxisIter<'a, T, L> {}
1264
1265impl<'a, T, L: MutLayout + RemoveDim> DoubleEndedIterator for AxisIter<'a, T, L> {
1266 fn next_back(&mut self) -> Option<Self::Item> {
1267 if self.index >= self.end {
1268 None
1269 } else {
1270 let slice = self.view.index_axis(self.axis, self.end - 1);
1271 self.end -= 1;
1272 Some(slice)
1273 }
1274 }
1275}
1276
/// Mutable version of [`AxisIter`].
pub struct AxisIterMut<'a, T, L: Layout + RemoveDim> {
    view: TensorBase<ViewMutData<'a, T>, L>,
    axis: usize,
    // Index of the next slice yielded from the front.
    index: usize,
    // One past the index of the next slice yielded from the back.
    end: usize,
}
1284
1285impl<'a, T, L: Layout + RemoveDim + Clone> AxisIterMut<'a, T, L> {
1286 pub(crate) fn new(
1287 view: TensorBase<ViewMutData<'a, T>, L>,
1288 axis: usize,
1289 ) -> AxisIterMut<'a, T, L> {
1290 assert!(
1292 !view.layout().is_broadcast(),
1293 "Cannot mutably iterate over broadcasting view"
1294 );
1295 assert!(axis < view.ndim());
1296 AxisIterMut {
1297 axis,
1298 index: 0,
1299 end: view.size(axis),
1300 view,
1301 }
1302 }
1303}
1304
/// Shorthand for the view type produced by removing one dim from a mutable
/// view, used in the lifetime transmutes below.
type SmallerMutView<'b, T, L> = TensorBase<ViewMutData<'b, T>, <L as RemoveDim>::Output>;

impl<'a, T, L: MutLayout + RemoveDim> Iterator for AxisIterMut<'a, T, L> {
    type Item = TensorBase<ViewMutData<'a, T>, <L as RemoveDim>::Output>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let index = self.index;
            self.index += 1;

            let slice = self.view.index_axis_mut(self.axis, index);

            // SAFETY: Extending the slice's lifetime to `'a` is sound
            // because each index along `axis` is yielded at most once and
            // the view is not broadcast (checked in `new`), so slices do
            // not alias.
            let view = unsafe { transmute::<SmallerMutView<'_, T, L>, Self::Item>(slice) };

            Some(view)
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.end - self.index;
        (len, Some(len))
    }
}

// `size_hint` is exact.
impl<'a, T, L: MutLayout + RemoveDim> ExactSizeIterator for AxisIterMut<'a, T, L> {}

impl<'a, T, L: MutLayout + RemoveDim> DoubleEndedIterator for AxisIterMut<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let index = self.end - 1;
            self.end -= 1;

            let slice = self.view.index_axis_mut(self.axis, index);

            // SAFETY: As in `next`: each index is yielded at most once and
            // the view is not broadcast.
            let view = unsafe { transmute::<SmallerMutView<'_, T, L>, Self::Item>(slice) };

            Some(view)
        }
    }
}
1358
/// Iterator over chunks of a tensor along a given axis, where each chunk
/// spans up to `chunk_size` indices.
pub struct AxisChunks<'a, T, L: MutLayout> {
    // Un-yielded portion of the tensor; `None` once exhausted.
    remainder: Option<TensorBase<ViewData<'a, T>, L>>,
    axis: usize,
    chunk_size: usize,
}
1366
1367impl<'a, T, L: MutLayout> AxisChunks<'a, T, L> {
1368 pub(crate) fn new(
1369 view: &TensorBase<ViewData<'a, T>, L>,
1370 axis: usize,
1371 chunk_size: usize,
1372 ) -> AxisChunks<'a, T, L> {
1373 assert!(chunk_size > 0, "chunk size must be > 0");
1374 AxisChunks {
1375 remainder: if view.size(axis) > 0 {
1376 Some(view.view())
1377 } else {
1378 None
1379 },
1380 axis,
1381 chunk_size,
1382 }
1383 }
1384}
1385
1386impl<'a, T, L: MutLayout> Iterator for AxisChunks<'a, T, L> {
1387 type Item = TensorBase<ViewData<'a, T>, L>;
1388
1389 fn next(&mut self) -> Option<Self::Item> {
1390 let remainder = self.remainder.take()?;
1391 let chunk_len = self.chunk_size.min(remainder.size(self.axis));
1392 let (current, next_remainder) = remainder.split_at(self.axis, chunk_len);
1393 self.remainder = if next_remainder.size(self.axis) > 0 {
1394 Some(next_remainder)
1395 } else {
1396 None
1397 };
1398 Some(current)
1399 }
1400
1401 fn size_hint(&self) -> (usize, Option<usize>) {
1402 let len = self
1403 .remainder
1404 .as_ref()
1405 .map(|r| r.size(self.axis))
1406 .unwrap_or(0)
1407 .div_ceil(self.chunk_size);
1408 (len, Some(len))
1409 }
1410}
1411
1412impl<'a, T, L: MutLayout> ExactSizeIterator for AxisChunks<'a, T, L> {}
1413
1414impl<'a, T, L: MutLayout> DoubleEndedIterator for AxisChunks<'a, T, L> {
1415 fn next_back(&mut self) -> Option<Self::Item> {
1416 let remainder = self.remainder.take()?;
1417 let chunk_len = self.chunk_size.min(remainder.size(self.axis));
1418 let (prev_remainder, current) =
1419 remainder.split_at(self.axis, remainder.size(self.axis) - chunk_len);
1420 self.remainder = if prev_remainder.size(self.axis) > 0 {
1421 Some(prev_remainder)
1422 } else {
1423 None
1424 };
1425 Some(current)
1426 }
1427}
1428
/// Mutable version of [`AxisChunks`].
pub struct AxisChunksMut<'a, T, L: MutLayout> {
    // Un-yielded portion of the tensor; `None` once exhausted.
    remainder: Option<TensorBase<ViewMutData<'a, T>, L>>,
    axis: usize,
    chunk_size: usize,
}
1435
1436impl<'a, T, L: MutLayout> AxisChunksMut<'a, T, L> {
1437 pub(crate) fn new(
1438 view: TensorBase<ViewMutData<'a, T>, L>,
1439 axis: usize,
1440 chunk_size: usize,
1441 ) -> AxisChunksMut<'a, T, L> {
1442 assert!(
1444 !view.layout().is_broadcast(),
1445 "Cannot mutably iterate over broadcasting view"
1446 );
1447 assert!(chunk_size > 0, "chunk size must be > 0");
1448 AxisChunksMut {
1449 remainder: if view.size(axis) > 0 {
1450 Some(view)
1451 } else {
1452 None
1453 },
1454 axis,
1455 chunk_size,
1456 }
1457 }
1458}
1459
1460impl<'a, T, L: MutLayout> Iterator for AxisChunksMut<'a, T, L> {
1461 type Item = TensorBase<ViewMutData<'a, T>, L>;
1462
1463 fn next(&mut self) -> Option<Self::Item> {
1464 let remainder = self.remainder.take()?;
1465 let chunk_len = self.chunk_size.min(remainder.size(self.axis));
1466 let (current, next_remainder) = remainder.split_at_mut(self.axis, chunk_len);
1467 self.remainder = if next_remainder.size(self.axis) > 0 {
1468 Some(next_remainder)
1469 } else {
1470 None
1471 };
1472 Some(current)
1473 }
1474
1475 fn size_hint(&self) -> (usize, Option<usize>) {
1476 let len = self
1477 .remainder
1478 .as_ref()
1479 .map(|r| r.size(self.axis))
1480 .unwrap_or(0)
1481 .div_ceil(self.chunk_size);
1482 (len, Some(len))
1483 }
1484}
1485
// `size_hint` returns exact, equal bounds, so the default `len` is correct.
impl<'a, T, L: MutLayout> ExactSizeIterator for AxisChunksMut<'a, T, L> {}
1487
1488impl<'a, T, L: MutLayout> DoubleEndedIterator for AxisChunksMut<'a, T, L> {
1489 fn next_back(&mut self) -> Option<Self::Item> {
1490 let remainder = self.remainder.take()?;
1491 let remainder_size = remainder.size(self.axis);
1492 let chunk_len = self.chunk_size.min(remainder_size);
1493 let (prev_remainder, current) =
1494 remainder.split_at_mut(self.axis, remainder_size - chunk_len);
1495 self.remainder = if prev_remainder.size(self.axis) > 0 {
1496 Some(prev_remainder)
1497 } else {
1498 None
1499 };
1500 Some(current)
1501 }
1502}
1503
/// Apply `f` to every element of `view` in place.
///
/// The view is first padded with leading axes until it has at least 4 dims,
/// then traversed as a sequence of inner 4D views with an explicit loop nest
/// over the four inner dimensions.
pub(crate) fn for_each_mut<T, F: Fn(&mut T)>(mut view: TensorViewMut<T>, f: F) {
    // Pad with leading axes so `inner_iter_mut::<4>` can be used below.
    while view.ndim() < 4 {
        view.insert_axis(0);
    }

    view.inner_iter_mut::<4>().for_each(|mut src| {
        for i0 in 0..src.size(0) {
            for i1 in 0..src.size(1) {
                for i2 in 0..src.size(2) {
                    for i3 in 0..src.size(3) {
                        // SAFETY: Each index is in `0..src.size(dim)` for
                        // its dimension, so the index is in bounds.
                        let x = unsafe { src.get_unchecked_mut([i0, i1, i2, i3]) };
                        f(x);
                    }
                }
            }
        }
    });
}
1529
#[cfg(test)]
mod tests {
    use super::{AxisChunks, AxisChunksMut, Lanes, LanesMut};
    use crate::{AsView, Layout, NdLayout, NdTensor, Tensor};

    // Assert that `rev` contains the same items as `fwd`, in reverse order.
    fn compare_reversed<T: PartialEq + std::fmt::Debug>(fwd: &[T], rev: &[T]) {
        assert_eq!(fwd.len(), rev.len());
        for (x, y) in fwd.iter().zip(rev.iter().rev()) {
            assert_eq!(x, y);
        }
    }

    // Exercise the Iterator contract for an immutable iterator: yielded
    // items, exact `size_hint` bounds, reverse iteration, exhaustion after
    // a full drain, and `fold`.
    fn test_iterator<I: Iterator + ExactSizeIterator + DoubleEndedIterator>(
        create_iter: impl Fn() -> I,
        expected: &[I::Item],
    ) where
        I::Item: PartialEq + std::fmt::Debug,
    {
        let iter = create_iter();

        let (min_len, max_len) = iter.size_hint();
        let items: Vec<_> = iter.collect();

        assert_eq!(&items, expected);

        assert_eq!(min_len, items.len(), "incorrect size lower bound");
        assert_eq!(max_len, Some(items.len()), "incorrect size upper bound");

        // Reverse iteration must yield the forward items in reverse order.
        let rev_items: Vec<_> = create_iter().rev().collect();
        compare_reversed(&items, &rev_items);

        // A fully-drained iterator must keep returning `None`.
        let mut iter = create_iter();
        for _x in &mut iter { }
        assert_eq!(iter.next(), None);

        // `fold` (which custom iterators often override) must visit the
        // same items as repeated `next` calls.
        let mut fold_items = Vec::new();
        let mut idx = 0;
        create_iter().fold(0, |acc, item| {
            assert_eq!(acc, idx);
            fold_items.push(item);
            idx += 1;
            idx
        });
        assert_eq!(items, fold_items);
    }

    // Adapter that lets `test_mut_iterator` re-create a mutable iterator
    // multiple times over the same underlying data.
    trait MutIterable {
        type Iter<'a>: Iterator + ExactSizeIterator + DoubleEndedIterator
        where
            Self: 'a;

        fn iter_mut(&mut self) -> Self::Iter<'_>;
    }

    // Mutable-iterator counterpart of `test_iterator`. Items are compared
    // via `Debug` formatting where the borrow checker prevents holding two
    // sets of mutable borrows at once.
    fn test_mut_iterator<M, T>(mut iterable: M, expected: &[T])
    where
        M: MutIterable,
        T: std::fmt::Debug,
        for<'a> <M::Iter<'a> as Iterator>::Item: std::fmt::Debug + PartialEq + PartialEq<T>,
    {
        {
            let iter = iterable.iter_mut();
            let (min_len, max_len) = iter.size_hint();
            let items: Vec<_> = iter.collect();

            assert_eq!(items, expected);

            assert_eq!(min_len, items.len(), "incorrect size lower bound");
            assert_eq!(max_len, Some(items.len()), "incorrect size upper bound");
        }

        {
            // A fully-drained iterator must keep returning `None`.
            let mut iter = iterable.iter_mut();
            for _x in &mut iter { }
            assert!(iter.next().is_none());
        }

        {
            // Reverse iteration must yield the forward items in reverse order.
            let items: Vec<_> = iterable.iter_mut().map(|x| format!("{:?}", x)).collect();
            let rev_items: Vec<_> = iterable
                .iter_mut()
                .rev()
                .map(|x| format!("{:?}", x))
                .collect();
            compare_reversed(&items, &rev_items);
        }

        {
            // `fold` must visit the same items as repeated `next` calls.
            let items: Vec<_> = iterable.iter_mut().map(|x| format!("{:?}", x)).collect();
            let mut fold_items = Vec::new();
            let mut idx = 0;
            iterable.iter_mut().fold(0, |acc, item| {
                assert_eq!(acc, idx);
                fold_items.push(format!("{:?}", item));
                idx += 1;
                idx
            });
            assert_eq!(items, fold_items);
        }
    }

    #[test]
    fn test_axis_chunks() {
        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        test_iterator(
            || tensor.axis_chunks(0, 1),
            &[tensor.slice(0..1), tensor.slice(1..2)],
        );
    }

    #[test]
    fn test_axis_chunks_empty() {
        let x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(AxisChunks::new(&x.view(), 1, 1).next().is_none());
    }

    #[test]
    #[should_panic(expected = "chunk size must be > 0")]
    fn test_axis_chunks_zero_size() {
        let x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(AxisChunks::new(&x.view(), 1, 0).next().is_none());
    }

    #[test]
    fn test_axis_chunks_mut_empty() {
        let mut x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(AxisChunksMut::new(x.view_mut(), 1, 1).next().is_none());
    }

    #[test]
    fn test_axis_chunks_mut_rev() {
        // A fresh tensor is created for the reverse pass because the
        // mutable iterator borrows the tensor exclusively.
        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        let fwd: Vec<_> = tensor
            .axis_chunks_mut(0, 1)
            .map(|view| view.to_vec())
            .collect();
        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        let rev: Vec<_> = tensor
            .axis_chunks_mut(0, 1)
            .rev()
            .map(|view| view.to_vec())
            .collect();
        compare_reversed(&fwd, &rev);
    }

    #[test]
    #[should_panic(expected = "chunk size must be > 0")]
    fn test_axis_chunks_mut_zero_size() {
        let mut x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(AxisChunksMut::new(x.view_mut(), 1, 0).next().is_none());
    }

    #[test]
    fn test_axis_iter() {
        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        test_iterator(|| tensor.axis_iter(0), &[tensor.slice(0), tensor.slice(1)]);
    }

    #[test]
    fn test_axis_iter_mut_rev() {
        // A fresh tensor is created for the reverse pass because the
        // mutable iterator borrows the tensor exclusively.
        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        let fwd: Vec<_> = tensor.axis_iter_mut(0).map(|view| view.to_vec()).collect();
        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        let rev: Vec<_> = tensor
            .axis_iter_mut(0)
            .rev()
            .map(|view| view.to_vec())
            .collect();
        compare_reversed(&fwd, &rev);
    }

    #[test]
    fn test_inner_iter() {
        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        test_iterator(
            || tensor.inner_iter::<2>(),
            &[tensor.slice(0), tensor.slice(1)],
        );
    }

    #[test]
    fn test_inner_iter_mut() {
        struct InnerIterMutTest(NdTensor<i32, 3>);

        impl MutIterable for InnerIterMutTest {
            type Iter<'a> = super::InnerIterMut<'a, i32, NdLayout<2>>;

            fn iter_mut(&mut self) -> Self::Iter<'_> {
                self.0.inner_iter_mut::<2>()
            }
        }

        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        test_mut_iterator(
            InnerIterMutTest(tensor.clone()),
            &[tensor.slice(0), tensor.slice(1)],
        );
    }

    #[test]
    fn test_lanes() {
        let x = NdTensor::from([[1, 2], [3, 4]]);
        test_iterator(
            || x.lanes(0),
            &[x.slice((.., 0)).into(), x.slice((.., 1)).into()],
        );
        test_iterator(|| x.lanes(1), &[x.slice(0).into(), x.slice(1).into()]);
    }

    #[test]
    fn test_lanes_empty() {
        let x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(Lanes::new(x.view().view_ref(), 0).next().is_none());
        assert!(Lanes::new(x.view().view_ref(), 1).next().is_none());
    }

    #[test]
    fn test_lanes_mut() {
        use super::Lane;

        struct LanesMutTest(NdTensor<i32, 2>);

        impl MutIterable for LanesMutTest {
            type Iter<'a> = super::LanesMut<'a, i32>;

            fn iter_mut(&mut self) -> Self::Iter<'_> {
                self.0.lanes_mut(0)
            }
        }

        let tensor = NdTensor::from([[1, 2], [3, 4]]);
        test_mut_iterator::<_, Lane<i32>>(
            LanesMutTest(tensor.clone()),
            &[
                Lane::from(tensor.slice((.., 0))),
                Lane::from(tensor.slice((.., 1))),
            ],
        );
    }

    #[test]
    fn test_lane_as_slice() {
        // Contiguous lane: `as_slice` tracks the unconsumed portion.
        let x = NdTensor::from([0, 1, 2]);
        let mut lane = x.lanes(0).next().unwrap();
        assert_eq!(lane.as_slice(), Some([0, 1, 2].as_slice()));
        lane.next();
        assert_eq!(lane.as_slice(), Some([1, 2].as_slice()));
        lane.next();
        lane.next();
        assert_eq!(lane.as_slice(), Some([0i32; 0].as_slice()));
        lane.next();
        assert_eq!(lane.as_slice(), Some([0i32; 0].as_slice()));

        // Non-contiguous lane (column of a row-major matrix): no slice view.
        let x = NdTensor::from([[1i32, 2], [3, 4]]);
        let lane = x.lanes(0).next().unwrap();
        assert_eq!(lane.as_slice(), None);
    }

    #[test]
    fn test_lanes_mut_empty() {
        let mut x = Tensor::<i32>::zeros(&[5, 0]);
        assert!(LanesMut::new(x.mut_view_ref(), 0).next().is_none());
        assert!(LanesMut::new(x.mut_view_ref(), 1).next().is_none());
    }

    #[test]
    fn test_iter_step_by() {
        let tensor = Tensor::<f32>::full(&[1, 3, 16, 8], 1.);

        // Slice to make the view non-contiguous.
        let tensor = tensor.slice((.., .., 1.., ..));

        // Every element is 1, so skipping `n` elements reduces the sum by `n`.
        let sum = tensor.iter().sum::<f32>();
        for n_skip in 0..tensor.len() {
            let sum_skip = tensor.iter().skip(n_skip).sum::<f32>();
            assert_eq!(
                sum_skip,
                sum - n_skip as f32,
                "wrong sum for n_skip={}",
                n_skip
            );
        }
    }

    #[test]
    fn test_iter_broadcast() {
        let tensor = Tensor::<f32>::full(&[1], 1.);
        let broadcast = tensor.broadcast([1, 3, 16, 8]);
        assert_eq!(broadcast.iter().len(), broadcast.len());
        let count = broadcast.iter().count();
        assert_eq!(count, broadcast.len());
        let sum = broadcast.iter().sum::<f32>();
        assert_eq!(sum, broadcast.len() as f32);
    }

    #[test]
    fn test_iter() {
        let tensor = NdTensor::from([[[1, 2], [3, 4]]]);

        // Contiguous iteration order.
        test_iterator(|| tensor.iter().copied(), &[1, 2, 3, 4]);

        // Non-contiguous (transposed) iteration order.
        test_iterator(|| tensor.transposed().iter().copied(), &[1, 3, 2, 4]);
    }

    #[test]
    fn test_iter_mut() {
        struct IterTest(NdTensor<i32, 3>);

        impl MutIterable for IterTest {
            type Iter<'a> = super::IterMut<'a, i32>;

            fn iter_mut(&mut self) -> Self::Iter<'_> {
                self.0.iter_mut()
            }
        }

        let tensor = NdTensor::from([[[1, 2], [3, 4]]]);
        test_mut_iterator(IterTest(tensor), &[&1, &2, &3, &4]);
    }

    // Benchmark, not a correctness test: run explicitly with `--ignored`.
    #[test]
    #[ignore]
    fn bench_iter() {
        use crate::Layout;
        use rten_bench::run_bench;

        type Elem = i32;

        let tensor = std::hint::black_box(Tensor::<Elem>::full(&[1, 6, 768, 64], 1));
        let n_trials = 1000;
        let mut result = Elem::default();

        fn reduce<'a>(iter: impl Iterator<Item = &'a Elem>) -> Elem {
            iter.fold(Elem::default(), |acc, x| acc.wrapping_add(*x))
        }

        run_bench(n_trials, Some("slice iter"), || {
            result = reduce(tensor.data().unwrap().iter());
        });
        println!("sum {}", result);

        run_bench(n_trials, Some("contiguous iter"), || {
            result = reduce(tensor.iter());
        });
        println!("sum {}", result);

        run_bench(n_trials, Some("contiguous reverse iter"), || {
            result = reduce(tensor.iter().rev());
        });
        println!("sum {}", result);

        // Slicing makes the view non-contiguous, forcing the slower
        // offset-computing iteration path.
        let slice = tensor.slice((.., .., 1.., ..));
        assert!(!slice.is_contiguous());
        let n_trials = 1000;
        run_bench(n_trials, Some("non-contiguous iter"), || {
            result = reduce(slice.iter());
        });
        println!("sum {}", result);

        let n_trials = 100;
        run_bench(n_trials, Some("non-contiguous reverse iter"), || {
            result = reduce(slice.iter().rev());
        });
        println!("sum {}", result);
    }

    // Benchmark, not a correctness test: run explicitly with `--ignored`.
    #[test]
    #[ignore]
    fn bench_inner_iter() {
        use crate::rng::XorShiftRng;
        use rten_bench::run_bench;

        let n_trials = 100;
        let mut rng = XorShiftRng::new(1234);

        let tensor = Tensor::<f32>::rand(&[512, 512, 12, 1], &mut rng);

        let mut sum = 0.;
        run_bench(n_trials, Some("inner iter"), || {
            for inner in tensor.inner_iter::<2>() {
                for i0 in 0..inner.size(0) {
                    for i1 in 0..inner.size(1) {
                        sum += inner[[i0, i1]];
                    }
                }
            }
        });
        println!("sum {}", sum);
    }
}