rten_tensor/
iterators.rs

use std::iter::FusedIterator;
use std::mem::transmute;
use std::ops::Range;

use rten_base::iter::SplitIterator;

use super::{
    AsView, DynLayout, MutLayout, NdTensorView, NdTensorViewMut, TensorBase, TensorViewMut,
};
use crate::layout::{Layout, NdLayout, OverlapPolicy, RemoveDim, merge_axes};
use crate::storage::{StorageMut, ViewData, ViewMutData};

mod parallel;

/// Tracks the iteration position within a single dimension.
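///
/// For example, a dimension of size 3 with stride 4 steps through the offsets
/// 0, 4 and 8 before wrapping back to 0.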
#[derive(Copy, Clone, Debug, Default)]
struct IterPos {
    /// Remaining steps along this dimension before it needs to be reset.
    remaining: usize,

    /// Current index in this dimension pre-multiplied by stride.
    offset: usize,

    /// Update to `offset` for each step.
    stride: usize,

    /// Maximum value of `self.remaining`. Used when resetting position.
    max_remaining: usize,
}

impl IterPos {
    fn from_size_stride(size: usize, stride: usize) -> Self {
        let remaining = size.saturating_sub(1);
        IterPos {
            remaining,
            offset: 0,
            stride,
            max_remaining: remaining,
        }
    }

    #[inline(always)]
    fn step(&mut self) -> bool {
        if self.remaining != 0 {
            self.remaining -= 1;
            self.offset += self.stride;
            true
        } else {
            self.remaining = self.max_remaining;
            self.offset = 0;
            false
        }
    }

    /// Return the size of this dimension.
    fn size(&self) -> usize {
        // nb. The size is always > 0 since if any dim has zero size, the
        // iterator will have a length of zero.
        self.max_remaining + 1
    }

    /// Return the current index along this dimension.
    fn index(&self) -> usize {
        self.max_remaining - self.remaining
    }

    /// Set the current index along this dimension.
    fn set_index(&mut self, index: usize) {
        self.remaining = self.max_remaining - index;
        self.offset = index * self.stride;
    }
}

const INNER_NDIM: usize = 2;

/// Iterator over offsets of a tensor's elements.
#[derive(Clone, Debug)]
struct OffsetsBase {
    /// Remaining number of elements this iterator will yield.
    ///
    /// The offsets and positions in other fields are only valid if this is
    /// non-zero.
    len: usize,

    /// Component of next element offset from innermost (fastest-changing) dims.
    inner_offset: usize,

    /// Current position in innermost dims.
    inner_pos: [IterPos; INNER_NDIM],

    /// Component of next element offset from outermost (slowest-changing) dims.
    outer_offset: usize,

    /// Current position in outermost dims.
    ///
    /// Optimization note: The number of outermost dims will usually be small,
    /// so you might be tempted to use `SmallVec`. However, this resulted in
    /// worse performance for `IndexingIterBase::step`, as the compiler was
    /// less likely/able to unroll iteration loops.
    outer_pos: Vec<IterPos>,
}

impl OffsetsBase {
    /// Create an iterator over element offsets for a tensor with the given layout.
    fn new<L: Layout>(layout: &L) -> OffsetsBase {
        // Merge axes to maximize the number of iterations that use the fast
        // path for stepping over the inner dimensions.
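        // For example, if the layout is contiguous all dims collapse into a
        // single dimension with stride 1, and stepping becomes a linear scan
        // over the data.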
        let merged = merge_axes(layout.shape().as_ref(), layout.strides().as_ref());

        let inner_pos_pad = INNER_NDIM.saturating_sub(merged.len());
        let n_outer = merged.len().saturating_sub(INNER_NDIM);

        let inner_pos = std::array::from_fn(|dim| {
            let (size, stride) = if dim < inner_pos_pad {
                (1, 0)
            } else {
                merged[n_outer + dim - inner_pos_pad]
            };
            IterPos::from_size_stride(size, stride)
        });

        let outer_pos = (0..n_outer)
            .map(|i| {
                let (size, stride) = merged[i];
                IterPos::from_size_stride(size, stride)
            })
            .collect();

        OffsetsBase {
            len: merged.iter().map(|dim| dim.0).product(),
            inner_pos,
            inner_offset: 0,
            outer_pos,
            outer_offset: 0,
        }
    }

    /// Step in the outer dimensions.
    ///
    /// Returns `true` if the position was advanced or `false` if the end was
    /// reached.
    fn step_outer_pos(&mut self) -> bool {
        let mut done = self.outer_pos.is_empty();
        for (i, dim) in self.outer_pos.iter_mut().enumerate().rev() {
            if dim.step() {
                break;
            } else if i == 0 {
                done = true;
            }
        }
        self.outer_offset = self.outer_pos.iter().map(|p| p.offset).sum();
        !done
    }

    fn pos(&self, dim: usize) -> IterPos {
        let outer_ndim = self.outer_pos.len();
        if dim >= outer_ndim {
            self.inner_pos[dim - outer_ndim]
        } else {
            self.outer_pos[dim]
        }
    }

    fn pos_mut(&mut self, dim: usize) -> &mut IterPos {
        let outer_ndim = self.outer_pos.len();
        if dim >= outer_ndim {
            &mut self.inner_pos[dim - outer_ndim]
        } else {
            &mut self.outer_pos[dim]
        }
    }

    /// Advance iterator by up to `n` indices.
    fn step_by(&mut self, n: usize) {
        let mut remaining = n.min(self.len);
        self.len -= remaining;

        for dim in (0..self.ndim()).rev() {
            if remaining == 0 {
                break;
            }

            let pos = self.pos_mut(dim);
            let size = pos.size();
            let new_index = pos.index() + remaining;
            pos.set_index(new_index % size);
            remaining = new_index / size;
        }

        // Update offset of next element.
        self.inner_offset = self.inner_pos.iter().map(|p| p.offset).sum();
        self.outer_offset = self.outer_pos.iter().map(|p| p.offset).sum();
    }

    fn ndim(&self) -> usize {
        self.outer_pos.len() + self.inner_pos.len()
    }

    /// Compute the storage offset of an element given a linear index into a
    /// tensor's element sequence.
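    ///
    /// For example, with shape `[2, 3]` and strides `[3, 1]`, linear index 4
    /// corresponds to indices `[1, 1]` and hence storage offset `1 * 3 + 1 * 1 = 4`.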
    fn offset_from_linear_index(&self, index: usize) -> usize {
        let mut offset = 0;
        let mut shape_product = 1;
        for dim in (0..self.ndim()).rev() {
            let pos = self.pos(dim);
            let dim_index = (index / shape_product) % pos.size();
            shape_product *= pos.size();
            offset += dim_index * pos.stride;
        }
        offset
    }

    /// Truncate this iterator so that it yields at most `len` elements.
    fn truncate(&mut self, len: usize) {
        // We adjust `self.len` here but not any of the iteration positions.
        // This means that methods like `next` and `fold` must always check
        // `self.len` before each step.
        self.len = self.len.min(len);
    }
}

impl Iterator for OffsetsBase {
    type Item = usize;

    #[inline(always)]
    fn next(&mut self) -> Option<usize> {
        if self.len == 0 {
            return None;
        }
        let offset = self.outer_offset + self.inner_offset;

        self.len -= 1;

        // Optimistically update offset, assuming we haven't reached the
        // end of the last dimension.
        self.inner_offset += self.inner_pos[1].stride;

        // Use a fast path to step inner dimensions and fall back to the slower
        // path to step the outer dimensions only when we reach the end.
        if !self.inner_pos[1].step() {
            if !self.inner_pos[0].step() {
                self.step_outer_pos();
            }

            // `inner_offset` is the sum of `inner_pos[i].offset`. `inner_pos`
            // only has two entries, and we know `inner_pos[1].offset` is zero
            // since `inner_pos[1].step()` returned false. Hence we can use
            // an assignment.
            self.inner_offset = self.inner_pos[0].offset;
        }

        Some(offset)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, usize) -> B,
    {
        // Iter positions are only valid if `self.len > 0`.
        if self.len == 0 {
            return init;
        }

        let mut accum = init;
        'outer: loop {
            for i0 in self.inner_pos[0].index()..self.inner_pos[0].size() {
                for i1 in self.inner_pos[1].index()..self.inner_pos[1].size() {
                    let inner_offset =
                        i0 * self.inner_pos[0].stride + i1 * self.inner_pos[1].stride;
                    accum = f(accum, self.outer_offset + inner_offset);

                    self.len -= 1;
                    if self.len == 0 {
                        break 'outer;
                    }
                }
                self.inner_pos[1].set_index(0);
            }
            self.inner_pos[0].set_index(0);

            if !self.step_outer_pos() {
                break;
            }
        }

        accum
    }
}

impl ExactSizeIterator for OffsetsBase {}

impl DoubleEndedIterator for OffsetsBase {
    fn next_back(&mut self) -> Option<usize> {
        if self.len == 0 {
            return None;
        }

        // This is inefficient compared to forward iteration, but that's OK
        // because reverse iteration is not performance critical.
        let index = self.len - 1;
        let offset = self.offset_from_linear_index(index);
        self.len -= 1;

        Some(offset)
    }
}

impl SplitIterator for OffsetsBase {
    /// Split this iterator into two. The left result visits indices before
    /// `index`, the right result visits indices from `index` onwards.
    fn split_at(mut self, index: usize) -> (Self, Self) {
        assert!(self.len >= index);

        let mut right = self.clone();
        OffsetsBase::step_by(&mut right, index);

        self.truncate(index);

        (self, right)
    }
}

/// Iterator over elements of a tensor, in their logical order.
pub struct Iter<'a, T> {
    offsets: Offsets,
    data: ViewData<'a, T>,
}

impl<'a, T> Iter<'a, T> {
    pub(super) fn new<L: Layout + Clone>(view: TensorBase<ViewData<'a, T>, L>) -> Iter<'a, T> {
        Iter {
            offsets: Offsets::new(view.layout()),
            data: view.storage(),
        }
    }
}

impl<T> Clone for Iter<'_, T> {
    fn clone(&self) -> Self {
        Iter {
            offsets: self.offsets.clone(),
            data: self.data,
        }
    }
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next()?;

        // Safety: Offset is valid for data length.
        Some(unsafe { self.data.get_unchecked(offset) })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.offsets.size_hint()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let offset = self.offsets.nth(n)?;

        // Safety: Offset is valid for data length.
        Some(unsafe { self.data.get_unchecked(offset) })
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.offsets.fold(init, |acc, offset| {
            // Safety: Offset is valid for data length.
            let item = unsafe { self.data.get_unchecked(offset) };
            f(acc, item)
        })
    }
}

impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next_back()?;

        // Safety: Offset is valid for data length.
        Some(unsafe { self.data.get_unchecked(offset) })
    }
}

impl<T> ExactSizeIterator for Iter<'_, T> {}

impl<T> FusedIterator for Iter<'_, T> {}

/// Wrapper around [`transmute`] which allows transmuting only the lifetime,
/// not the type, of a reference.
unsafe fn transmute_lifetime_mut<'a, 'b, T>(x: &'a mut T) -> &'b mut T {
    unsafe { transmute::<&'a mut T, &'b mut T>(x) }
}

/// Mutable iterator over elements of a tensor.
pub struct IterMut<'a, T> {
    offsets: Offsets,
    data: ViewMutData<'a, T>,
}

impl<'a, T> IterMut<'a, T> {
    pub(super) fn new<L: Layout + Clone>(
        view: TensorBase<ViewMutData<'a, T>, L>,
    ) -> IterMut<'a, T> {
        IterMut {
            offsets: Offsets::new(view.layout()),
            data: view.into_storage(),
        }
    }
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next()?;

        // Safety: Offset is valid for data length, `offsets.next` yields each
        // offset only once.
        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.offsets.size_hint()
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let offset = self.offsets.nth(n)?;

        // Safety: Offset is valid for data length, `offsets.next` yields each
        // offset only once.
        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.offsets.fold(init, |acc, offset| {
            // Safety: Offset is valid for data length, `offsets.fold` yields
            // each offset only once.
            let item = unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) };
            f(acc, item)
        })
    }
}

impl<T> DoubleEndedIterator for IterMut<'_, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let offset = self.offsets.next_back()?;

        // Safety: Offset is valid for data length, `offsets.next` yields each
        // offset only once.
        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
    }
}

impl<T> ExactSizeIterator for IterMut<'_, T> {}

impl<T> FusedIterator for IterMut<'_, T> {}

#[derive(Clone)]
enum OffsetsKind {
    Range(Range<usize>),
    Indexing(OffsetsBase),
}

/// Iterator over element offsets of a tensor.
///
/// `Offsets` does not hold a reference to the tensor, allowing the tensor to
/// be modified during iteration. It is the caller's responsibility not to
/// modify the tensor in ways that invalidate the offset sequence returned by
/// this iterator.
#[derive(Clone)]
struct Offsets {
    base: OffsetsKind,
}

impl Offsets {
    pub fn new<L: Layout>(layout: &L) -> Offsets {
        Offsets {
            base: if layout.is_contiguous() {
                OffsetsKind::Range(0..layout.min_data_len())
            } else {
                OffsetsKind::Indexing(OffsetsBase::new(layout))
            },
        }
    }
}

impl Iterator for Offsets {
    type Item = usize;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        match &mut self.base {
            OffsetsKind::Range(r) => r.next(),
            OffsetsKind::Indexing(base) => base.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match &self.base {
            OffsetsKind::Range(r) => r.size_hint(),
            OffsetsKind::Indexing(base) => (base.len, Some(base.len)),
        }
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        match &mut self.base {
            OffsetsKind::Range(r) => r.nth(n),
            OffsetsKind::Indexing(base) => {
                base.step_by(n);
                self.next()
            }
        }
    }

    fn fold<B, F>(self, init: B, f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        match self.base {
            OffsetsKind::Range(r) => r.fold(init, f),
            OffsetsKind::Indexing(base) => base.fold(init, f),
        }
    }
}

impl DoubleEndedIterator for Offsets {
    fn next_back(&mut self) -> Option<Self::Item> {
        match &mut self.base {
            OffsetsKind::Range(r) => r.next_back(),
            OffsetsKind::Indexing(base) => base.next_back(),
        }
    }
}

impl ExactSizeIterator for Offsets {}

impl FusedIterator for Offsets {}

/// Iterator over the ranges of a tensor's data that correspond to 1D lanes
/// along a particular dimension.
struct LaneRanges {
    /// Start offsets of each lane.
    offsets: Offsets,

    // Number of elements in each lane and gap between them.
    dim_size: usize,
    dim_stride: usize,
}

impl LaneRanges {
    fn new<L: Layout + RemoveDim>(layout: &L, dim: usize) -> LaneRanges {
        // If the layout is empty (has any zero-sized dims), we need to make
        // sure that `offsets` is empty as well.
        let offsets = if layout.is_empty() {
            Offsets::new(layout)
        } else {
            let other_dims = layout.remove_dim(dim);
            Offsets::new(&other_dims)
        };

        LaneRanges {
            offsets,
            dim_size: layout.size(dim),
            dim_stride: layout.stride(dim),
        }
    }

    /// Return the range of storage offsets for a 1D lane where the first
    /// element is at `start_offset`.
    fn lane_offset_range(&self, start_offset: usize) -> Range<usize> {
        lane_offsets(start_offset, self.dim_size, self.dim_stride)
    }
}

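/// Return the range of storage offsets spanned by a lane of `size` elements
/// spaced `stride` apart, starting at `start_offset`.
///
/// For example, a lane of 3 elements with stride 4 starting at offset 2 spans
/// the offset range `2..11`.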
fn lane_offsets(start_offset: usize, size: usize, stride: usize) -> Range<usize> {
    start_offset..start_offset + (size - 1) * stride + 1
}

impl Iterator for LaneRanges {
    type Item = Range<usize>;

    #[inline]
    fn next(&mut self) -> Option<Range<usize>> {
        self.offsets
            .next()
            .map(|offset| self.lane_offset_range(offset))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.offsets.size_hint()
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        let Self {
            offsets,
            dim_size,
            dim_stride,
        } = self;

        offsets.fold(init, |acc, offset| {
            f(acc, lane_offsets(offset, dim_size, dim_stride))
        })
    }
}

impl DoubleEndedIterator for LaneRanges {
    fn next_back(&mut self) -> Option<Range<usize>> {
        self.offsets
            .next_back()
            .map(|offset| self.lane_offset_range(offset))
    }
}

impl ExactSizeIterator for LaneRanges {}

impl FusedIterator for LaneRanges {}

/// Iterator over 1D slices of a tensor along a target dimension of size N.
///
/// Conceptually this iterator steps through every distinct slice of a tensor
/// where a target dim is varied from 0..N and other indices are held fixed.
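///
/// For a 2D matrix, lanes along dim 0 are the columns and lanes along dim 1
/// are the rows.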
pub struct Lanes<'a, T> {
    data: ViewData<'a, T>,
    ranges: LaneRanges,
    lane_layout: NdLayout<1>,
}

/// Iterator over items in a 1D slice of a tensor.
#[derive(Clone, Debug)]
pub struct Lane<'a, T> {
    view: NdTensorView<'a, T, 1>,
    index: usize,
}

impl<'a, T> Lane<'a, T> {
    /// Return the remaining part of the lane as a slice, if it is contiguous.
    pub fn as_slice(&self) -> Option<&'a [T]> {
        self.view.data().map(|data| &data[self.index..])
    }

    /// Return the item at a given index in this lane.
    pub fn get(&self, idx: usize) -> Option<&'a T> {
        self.view.get([idx])
    }

    /// Return the entire lane as a 1D tensor view.
    pub fn as_view(&self) -> NdTensorView<'a, T, 1> {
        self.view
    }
}

impl<'a, T> From<NdTensorView<'a, T, 1>> for Lane<'a, T> {
    fn from(val: NdTensorView<'a, T, 1>) -> Self {
        Lane {
            view: val,
            index: 0,
        }
    }
}

impl<'a, T> Iterator for Lane<'a, T> {
    type Item = &'a T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.view.len() {
            let index = self.index;
            self.index += 1;

            // Safety: Index is in bounds for axis 0.
            Some(unsafe { self.view.get_unchecked([index]) })
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.view.size(0) - self.index;
        (remaining, Some(remaining))
    }
}

impl<T> ExactSizeIterator for Lane<'_, T> {}

impl<T> FusedIterator for Lane<'_, T> {}

impl<T: PartialEq> PartialEq<Lane<'_, T>> for Lane<'_, T> {
    fn eq(&self, other: &Lane<'_, T>) -> bool {
        self.view.slice(self.index..) == other.view.slice(other.index..)
    }
}

impl<T: PartialEq> PartialEq<Lane<'_, T>> for LaneMut<'_, T> {
    fn eq(&self, other: &Lane<'_, T>) -> bool {
        self.view.slice(self.index..) == other.view.slice(other.index..)
    }
}

impl<'a, T> Lanes<'a, T> {
    /// Create an iterator which yields all possible slices over the `dim`
    /// dimension of `view`.
    pub(crate) fn new<L: Layout + RemoveDim + Clone>(
        view: TensorBase<ViewData<'a, T>, L>,
        dim: usize,
    ) -> Lanes<'a, T> {
        let size = view.size(dim);
        let stride = view.stride(dim);
        let lane_layout =
            NdLayout::from_shape_and_strides([size], [stride], OverlapPolicy::AllowOverlap)
                .unwrap();
        Lanes {
            data: view.storage(),
            ranges: LaneRanges::new(view.layout(), dim),
            lane_layout,
        }
    }
}

fn lane_for_offset_range<T>(
    data: ViewData<T>,
    layout: NdLayout<1>,
    offsets: Range<usize>,
) -> Lane<T> {
    let view = NdTensorView::from_storage_and_layout(data.slice(offsets), layout);
    Lane { view, index: 0 }
}

impl<'a, T> Iterator for Lanes<'a, T> {
    type Item = Lane<'a, T>;

    /// Yield the next slice over the target dimension.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.ranges
            .next()
            .map(|range| lane_for_offset_range(self.data, self.lane_layout, range))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.ranges.size_hint()
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.ranges.fold(init, |acc, offsets| {
            let lane = lane_for_offset_range(self.data, self.lane_layout, offsets);
            f(acc, lane)
        })
    }
}

impl<T> DoubleEndedIterator for Lanes<'_, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.ranges
            .next_back()
            .map(|range| lane_for_offset_range(self.data, self.lane_layout, range))
    }
}

impl<T> ExactSizeIterator for Lanes<'_, T> {}

impl<T> FusedIterator for Lanes<'_, T> {}

/// Mutable version of [`Lanes`].
///
/// Each item is a [`LaneMut`] which yields mutable references to the elements
/// of one lane.
pub struct LanesMut<'a, T> {
    data: ViewMutData<'a, T>,
    ranges: LaneRanges,
    lane_layout: NdLayout<1>,
}

impl<'a, T> LanesMut<'a, T> {
    /// Create an iterator which yields all possible slices over the `dim`
    /// dimension of `view`.
    pub(crate) fn new<L: Layout + RemoveDim + Clone>(
        view: TensorBase<ViewMutData<'a, T>, L>,
        dim: usize,
    ) -> LanesMut<'a, T> {
        // See notes in `Layout` about internal overlap.
        assert!(
            !view.is_broadcast(),
            "Cannot mutably iterate over broadcasting view"
        );

        let size = view.size(dim);
        let stride = view.stride(dim);

        // We allow overlap here to handle the case where the stride is zero
        // but the tensor is empty. If the tensor were not empty, the assert
        // above would have caught this.
        let lane_layout =
            NdLayout::from_shape_and_strides([size], [stride], OverlapPolicy::AllowOverlap)
                .unwrap();

        LanesMut {
            ranges: LaneRanges::new(view.layout(), dim),
            data: view.into_storage(),
            lane_layout,
        }
    }
}

impl<'a, T> Iterator for LanesMut<'a, T> {
    type Item = LaneMut<'a, T>;

    #[inline]
    fn next(&mut self) -> Option<LaneMut<'a, T>> {
        self.ranges.next().map(|offsets| {
            // Safety: Offsets range length is sufficient for layout, elements
            // in each lane do not overlap.
            unsafe {
                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
            }
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.ranges.size_hint()
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.ranges.fold(init, |acc, offsets| {
            // Safety: Offsets range length is sufficient for layout, elements
            // in each lane do not overlap.
            let lane = unsafe {
                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
            };
            f(acc, lane)
        })
    }
}

impl<'a, T> ExactSizeIterator for LanesMut<'a, T> {}

impl<'a, T> DoubleEndedIterator for LanesMut<'a, T> {
    fn next_back(&mut self) -> Option<LaneMut<'a, T>> {
        self.ranges.next_back().map(|offsets| {
            // Safety: Offsets range length is sufficient for layout, elements
            // in each lane do not overlap.
            unsafe {
                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
            }
        })
    }
}

/// Mutable iterator over items in a 1D slice of a tensor.
#[derive(Debug)]
pub struct LaneMut<'a, T> {
    view: NdTensorViewMut<'a, T, 1>,
    index: usize,
}

impl<'a, T> LaneMut<'a, T> {
    /// Create a new lane given the storage and layout.
    ///
    /// # Safety
    ///
    /// - Caller must ensure that no two lanes are created which overlap.
    /// - Storage length must be at least `layout.min_data_len()`.
    unsafe fn from_storage_layout(data: ViewMutData<'a, T>, layout: NdLayout<1>) -> Self {
        let view = unsafe {
            // Safety: Caller promises that each call uses the offset ranges for
            // a different lane and that the range length is sufficient for the
            // lane's size and stride.
            NdTensorViewMut::from_storage_and_layout_unchecked(data, layout)
        };
        LaneMut { view, index: 0 }
    }

    /// Return the remaining part of the lane as a slice, if it is contiguous.
    pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
        self.view.data_mut().map(|data| &mut data[self.index..])
    }

    /// Return the entire lane as a mutable 1D tensor view.
    pub fn into_view(self) -> NdTensorViewMut<'a, T, 1> {
        self.view
    }
}

impl<'a, T> Iterator for LaneMut<'a, T> {
    type Item = &'a mut T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.view.size(0) {
            let index = self.index;
            self.index += 1;
            let item = unsafe { self.view.get_unchecked_mut([index]) };

            // Transmute to preserve lifetime of data. This is safe as we
            // yield each element only once.
            Some(unsafe { transmute::<&mut T, Self::Item>(item) })
        } else {
            None
        }
    }

    #[inline]
    fn nth(&mut self, nth: usize) -> Option<Self::Item> {
        self.index = (self.index + nth).min(self.view.size(0));
        self.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.view.size(0) - self.index;
        (remaining, Some(remaining))
    }
}

impl<T> ExactSizeIterator for LaneMut<'_, T> {}

impl<T: PartialEq> PartialEq<LaneMut<'_, T>> for LaneMut<'_, T> {
    fn eq(&self, other: &LaneMut<'_, T>) -> bool {
        self.view.slice(self.index..) == other.view.slice(other.index..)
    }
}

/// Base for iterators over views of the inner dimensions of a tensor, where
/// the inner dimensions have layout `L`.
struct InnerIterBase<L: Layout> {
    // Iterator over storage start offsets for each inner view. The storage
    // range for each view is `offset..offset + inner_data_len`.
    outer_offsets: Offsets,
    inner_layout: L,
    inner_data_len: usize,
}

impl<L: Layout + Clone> InnerIterBase<L> {
    fn new_impl<PL: Layout, F: Fn(&[usize], &[usize]) -> L>(
        parent_layout: &PL,
        inner_dims: usize,
        make_inner_layout: F,
    ) -> InnerIterBase<L> {
        assert!(parent_layout.ndim() >= inner_dims);
        let outer_dims = parent_layout.ndim() - inner_dims;
        let parent_shape = parent_layout.shape();
        let parent_strides = parent_layout.strides();
        let (outer_shape, inner_shape) = parent_shape.as_ref().split_at(outer_dims);
        let (outer_strides, inner_strides) = parent_strides.as_ref().split_at(outer_dims);

        let outer_layout = DynLayout::from_shape_and_strides(
            outer_shape,
            outer_strides,
            OverlapPolicy::AllowOverlap,
        )
        .unwrap();

        let inner_layout = make_inner_layout(inner_shape, inner_strides);

        InnerIterBase {
            outer_offsets: Offsets::new(&outer_layout),
            inner_data_len: inner_layout.min_data_len(),
            inner_layout,
        }
    }
}

impl<const N: usize> InnerIterBase<NdLayout<N>> {
    pub fn new<L: Layout>(parent_layout: &L) -> Self {
        Self::new_impl(parent_layout, N, |inner_shape, inner_strides| {
            let inner_shape: [usize; N] = inner_shape.try_into().unwrap();
            let inner_strides: [usize; N] = inner_strides.try_into().unwrap();
            NdLayout::from_shape_and_strides(
                inner_shape,
                inner_strides,
                // We allow overlap here, but the view that owns `parent_layout`
                // will enforce there is no overlap if it is a mutable view.
                OverlapPolicy::AllowOverlap,
            )
            .expect("failed to create layout")
        })
    }
}

impl InnerIterBase<DynLayout> {
    pub fn new_dyn<L: Layout>(parent_layout: &L, inner_dims: usize) -> Self {
        Self::new_impl(parent_layout, inner_dims, |inner_shape, inner_strides| {
            DynLayout::from_shape_and_strides(
                inner_shape,
                inner_strides,
                // We allow overlap here, but the view that owns `parent_layout`
                // will enforce there is no overlap if it is a mutable view.
                OverlapPolicy::AllowOverlap,
            )
            .expect("failed to create layout")
        })
    }
}

impl<L: Layout> Iterator for InnerIterBase<L> {
    /// Storage offset range for the next view.
    type Item = Range<usize>;

    fn next(&mut self) -> Option<Range<usize>> {
        self.outer_offsets
            .next()
            .map(|offset| offset..offset + self.inner_data_len)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.outer_offsets.size_hint()
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        self.outer_offsets.fold(init, |acc, offset| {
            f(acc, offset..offset + self.inner_data_len)
        })
    }
}

impl<L: Layout> ExactSizeIterator for InnerIterBase<L> {}

impl<L: Layout> DoubleEndedIterator for InnerIterBase<L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.outer_offsets
            .next_back()
            .map(|offset| offset..offset + self.inner_data_len)
    }
}

/// Iterator over views of the innermost dimensions of a tensor, where the
/// tensor has element type T and the inner dimensions have layout L.
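///
/// For example, iterating over the innermost 2 dims of a tensor with shape
/// `[2, 3, 4]` yields two views of shape `[3, 4]`.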
pub struct InnerIter<'a, T, L: Layout> {
    base: InnerIterBase<L>,
    data: ViewData<'a, T>,
}

impl<'a, T, const N: usize> InnerIter<'a, T, NdLayout<N>> {
    pub fn new<L: Layout + Clone>(view: TensorBase<ViewData<'a, T>, L>) -> Self {
        let base = InnerIterBase::new(&view);
        InnerIter {
            base,
            data: view.storage(),
        }
    }
}

impl<'a, T> InnerIter<'a, T, DynLayout> {
    pub fn new_dyn<L: Layout + Clone>(
        view: TensorBase<ViewData<'a, T>, L>,
        inner_dims: usize,
    ) -> Self {
        let base = InnerIterBase::new_dyn(&view, inner_dims);
        InnerIter {
            base,
            data: view.storage(),
        }
    }
}

impl<'a, T, L: Layout + Clone> Iterator for InnerIter<'a, T, L> {
    type Item = TensorBase<ViewData<'a, T>, L>;

    fn next(&mut self) -> Option<Self::Item> {
        self.base.next().map(|offset_range| {
            TensorBase::from_storage_and_layout(
                self.data.slice(offset_range),
                self.base.inner_layout.clone(),
            )
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.base.size_hint()
    }

    fn fold<B, F>(self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        let inner_layout = self.base.inner_layout.clone();
        self.base.fold(init, |acc, offset_range| {
            let item = TensorBase::from_storage_and_layout(
                self.data.slice(offset_range),
                inner_layout.clone(),
            );
            f(acc, item)
        })
    }
}

impl<T, L: Layout + Clone> ExactSizeIterator for InnerIter<'_, T, L> {}

impl<T, L: Layout + Clone> DoubleEndedIterator for InnerIter<'_, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.base.next_back().map(|offset_range| {
            TensorBase::from_storage_and_layout(
                self.data.slice(offset_range),
                self.base.inner_layout.clone(),
            )
        })
    }
}

/// Iterator over mutable views of the innermost dimensions of a tensor, where
/// the tensor has element type T and the inner dimensions have layout L.
pub struct InnerIterMut<'a, T, L: Layout> {
    base: InnerIterBase<L>,
    data: ViewMutData<'a, T>,
}

impl<'a, T, const N: usize> InnerIterMut<'a, T, NdLayout<N>> {
    pub fn new<L: Layout>(view: TensorBase<ViewMutData<'a, T>, L>) -> Self {
        let base = InnerIterBase::new(&view);
        InnerIterMut {
            base,
            data: view.into_storage(),
        }
    }
}

impl<'a, T> InnerIterMut<'a, T, DynLayout> {
    pub fn new_dyn<L: Layout>(view: TensorBase<ViewMutData<'a, T>, L>, inner_dims: usize) -> Self {
        let base = InnerIterBase::new_dyn(&view, inner_dims);
        InnerIterMut {
            base,
            data: view.into_storage(),
        }
    }
}

impl<'a, T, L: Layout + Clone> Iterator for InnerIterMut<'a, T, L> {
    type Item = TensorBase<ViewMutData<'a, T>, L>;

    fn next(&mut self) -> Option<Self::Item> {
        self.base.next().map(|offset_range| {
            let storage = self.data.slice_mut(offset_range);
            let storage = unsafe {
                // Safety: The iterator was constructed from a tensor with a
                // non-overlapping layout, and no two views yielded by this
                // iterator overlap. Hence we can transmute the lifetime without
                // creating multiple mutable references to the same elements.
                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
            };
            TensorBase::from_storage_and_layout(storage, self.base.inner_layout.clone())
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.base.size_hint()
    }

    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        let inner_layout = self.base.inner_layout.clone();
        self.base.fold(init, |acc, offset_range| {
            let storage = self.data.slice_mut(offset_range);
            let storage = unsafe {
                // Safety: The iterator was constructed from a tensor with a
                // non-overlapping layout, and no two views yielded by this
                // iterator overlap. Hence we can transmute the lifetime without
                // creating multiple mutable references to the same elements.
                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
            };
            let item = TensorBase::from_storage_and_layout(storage, inner_layout.clone());
            f(acc, item)
        })
    }
}

impl<T, L: Layout + Clone> ExactSizeIterator for InnerIterMut<'_, T, L> {}

impl<'a, T, L: Layout + Clone> DoubleEndedIterator for InnerIterMut<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.base.next_back().map(|offset_range| {
            let storage = self.data.slice_mut(offset_range);
            let storage = unsafe {
                // Safety: Outer view is non-broadcasting, and each offset
                // range is yielded at most once, so returned views will not
                // overlap.
                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
            };
            TensorBase::from_storage_and_layout(storage, self.base.inner_layout.clone())
        })
    }
}

/// Iterator over slices of a tensor along an axis. See
/// [`TensorView::axis_iter`](crate::TensorView::axis_iter).
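///
/// Each item is a view with the iteration axis removed, so iterating a tensor
/// of shape `[2, 3]` along axis 0 yields two views of shape `[3]`.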
pub struct AxisIter<'a, T, L: Layout + RemoveDim> {
    view: TensorBase<ViewData<'a, T>, L>,
    axis: usize,
    index: usize,
    end: usize,
}

impl<'a, T, L: MutLayout + RemoveDim> AxisIter<'a, T, L> {
    pub fn new(view: &TensorBase<ViewData<'a, T>, L>, axis: usize) -> AxisIter<'a, T, L> {
        assert!(axis < view.ndim());
        AxisIter {
            view: view.clone(),
            axis,
            index: 0,
            end: view.size(axis),
        }
    }
}

impl<'a, T, L: MutLayout + RemoveDim> Iterator for AxisIter<'a, T, L> {
    type Item = TensorBase<ViewData<'a, T>, <L as RemoveDim>::Output>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let slice = self.view.index_axis(self.axis, self.index);
            self.index += 1;
            Some(slice)
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.end - self.index;
        (len, Some(len))
    }
}

impl<'a, T, L: MutLayout + RemoveDim> ExactSizeIterator for AxisIter<'a, T, L> {}

impl<'a, T, L: MutLayout + RemoveDim> DoubleEndedIterator for AxisIter<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let slice = self.view.index_axis(self.axis, self.end - 1);
            self.end -= 1;
            Some(slice)
        }
    }
}

/// Iterator over mutable slices of a tensor along an axis. See [`TensorViewMut::axis_iter_mut`].
pub struct AxisIterMut<'a, T, L: Layout + RemoveDim> {
    view: TensorBase<ViewMutData<'a, T>, L>,
    axis: usize,
    index: usize,
    end: usize,
}

impl<'a, T, L: Layout + RemoveDim + Clone> AxisIterMut<'a, T, L> {
    pub fn new(view: TensorBase<ViewMutData<'a, T>, L>, axis: usize) -> AxisIterMut<'a, T, L> {
        // See notes in `Layout` about internal overlap.
        assert!(
            !view.layout().is_broadcast(),
            "Cannot mutably iterate over broadcasting view"
        );
        assert!(axis < view.ndim());
        AxisIterMut {
            axis,
            index: 0,
            end: view.size(axis),
            view,
        }
    }
}

/// Mutable tensor view with one less dimension than `L`.
type SmallerMutView<'b, T, L> = TensorBase<ViewMutData<'b, T>, <L as RemoveDim>::Output>;

impl<'a, T, L: MutLayout + RemoveDim> Iterator for AxisIterMut<'a, T, L> {
    type Item = TensorBase<ViewMutData<'a, T>, <L as RemoveDim>::Output>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let index = self.index;
            self.index += 1;

            let slice = self.view.index_axis_mut(self.axis, index);

            // Promote lifetime from self -> 'a.
            //
            // Safety: This is a non-broadcasting view, and we increment the
            // index each time, so returned views will not overlap.
            let view = unsafe { transmute::<SmallerMutView<'_, T, L>, Self::Item>(slice) };

            Some(view)
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.end - self.index;
        (len, Some(len))
    }
}

impl<'a, T, L: MutLayout + RemoveDim> ExactSizeIterator for AxisIterMut<'a, T, L> {}

impl<'a, T, L: MutLayout + RemoveDim> DoubleEndedIterator for AxisIterMut<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            None
        } else {
            let index = self.end - 1;
            self.end -= 1;

            let slice = self.view.index_axis_mut(self.axis, index);

            // Promote lifetime from self -> 'a.
            //
            // Safety: This is a non-broadcasting view, and we decrement `end`
            // each time, so returned views will not overlap.
            let view = unsafe { transmute::<SmallerMutView<'_, T, L>, Self::Item>(slice) };

            Some(view)
        }
    }
}

/// Iterator over slices of a tensor along an axis. See
/// [`TensorView::axis_chunks`](crate::TensorView::axis_chunks).
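///
/// For example, splitting a dimension of size 5 into chunks of size 2 yields
/// chunks of sizes 2, 2 and 1.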
pub struct AxisChunks<'a, T, L: MutLayout> {
    remainder: Option<TensorBase<ViewData<'a, T>, L>>,
    axis: usize,
    chunk_size: usize,
}

impl<'a, T, L: MutLayout> AxisChunks<'a, T, L> {
    pub fn new(
        view: &TensorBase<ViewData<'a, T>, L>,
        axis: usize,
        chunk_size: usize,
    ) -> AxisChunks<'a, T, L> {
        assert!(chunk_size > 0, "chunk size must be > 0");
        AxisChunks {
            remainder: if view.size(axis) > 0 {
                Some(view.view())
            } else {
                None
            },
            axis,
            chunk_size,
        }
    }
}

impl<'a, T, L: MutLayout> Iterator for AxisChunks<'a, T, L> {
    type Item = TensorBase<ViewData<'a, T>, L>;

    fn next(&mut self) -> Option<Self::Item> {
        let remainder = self.remainder.take()?;
        let chunk_len = self.chunk_size.min(remainder.size(self.axis));
        let (current, next_remainder) = remainder.split_at(self.axis, chunk_len);
        self.remainder = if next_remainder.size(self.axis) > 0 {
            Some(next_remainder)
        } else {
            None
        };
        Some(current)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self
            .remainder
            .as_ref()
            .map(|r| r.size(self.axis))
            .unwrap_or(0)
            .div_ceil(self.chunk_size);
        (len, Some(len))
    }
}

impl<'a, T, L: MutLayout> ExactSizeIterator for AxisChunks<'a, T, L> {}

impl<'a, T, L: MutLayout> DoubleEndedIterator for AxisChunks<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let remainder = self.remainder.take()?;
        let chunk_len = self.chunk_size.min(remainder.size(self.axis));
        let (prev_remainder, current) =
            remainder.split_at(self.axis, remainder.size(self.axis) - chunk_len);
        self.remainder = if prev_remainder.size(self.axis) > 0 {
            Some(prev_remainder)
        } else {
            None
        };
        Some(current)
    }
}

/// Iterator over mutable slices of a tensor along an axis. See [`TensorViewMut::axis_chunks_mut`].
pub struct AxisChunksMut<'a, T, L: MutLayout> {
    remainder: Option<TensorBase<ViewMutData<'a, T>, L>>,
    axis: usize,
    chunk_size: usize,
}

impl<'a, T, L: MutLayout> AxisChunksMut<'a, T, L> {
    pub fn new(
        view: TensorBase<ViewMutData<'a, T>, L>,
        axis: usize,
        chunk_size: usize,
    ) -> AxisChunksMut<'a, T, L> {
        // See notes in `Layout` about internal overlap.
        assert!(
            !view.layout().is_broadcast(),
            "Cannot mutably iterate over broadcasting view"
        );
        assert!(chunk_size > 0, "chunk size must be > 0");
        AxisChunksMut {
            remainder: if view.size(axis) > 0 {
                Some(view)
            } else {
                None
            },
            axis,
            chunk_size,
        }
    }
}

impl<'a, T, L: MutLayout> Iterator for AxisChunksMut<'a, T, L> {
    type Item = TensorBase<ViewMutData<'a, T>, L>;

    fn next(&mut self) -> Option<Self::Item> {
        let remainder = self.remainder.take()?;
        let chunk_len = self.chunk_size.min(remainder.size(self.axis));
        let (current, next_remainder) = remainder.split_at_mut(self.axis, chunk_len);
        self.remainder = if next_remainder.size(self.axis) > 0 {
            Some(next_remainder)
        } else {
            None
        };
        Some(current)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self
            .remainder
            .as_ref()
            .map(|r| r.size(self.axis))
            .unwrap_or(0)
            .div_ceil(self.chunk_size);
        (len, Some(len))
    }
}

impl<'a, T, L: MutLayout> ExactSizeIterator for AxisChunksMut<'a, T, L> {}

impl<'a, T, L: MutLayout> DoubleEndedIterator for AxisChunksMut<'a, T, L> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let remainder = self.remainder.take()?;
        let remainder_size = remainder.size(self.axis);
        let chunk_len = self.chunk_size.min(remainder_size);
        let (prev_remainder, current) =
            remainder.split_at_mut(self.axis, remainder_size - chunk_len);
        self.remainder = if prev_remainder.size(self.axis) > 0 {
            Some(prev_remainder)
        } else {
            None
        };
        Some(current)
    }
}

/// Call `f` on each element of `view`.
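///
/// The view is padded with leading 1-sized axes to at least 4 dims so that
/// elements can be visited using fixed nested loops over the innermost dims.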
pub fn for_each_mut<T, F: Fn(&mut T)>(mut view: TensorViewMut<T>, f: F) {
    while view.ndim() < 4 {
        view.insert_axis(0);
    }

    // This could be improved by sorting dimensions of `view` in order of
    // decreasing stride. If the resulting view is contiguous, `f` can be
    // applied to the underlying data directly. Even if it isn't, this will
    // still make memory access as contiguous as possible.

    view.inner_iter_mut::<4>().for_each(|mut src| {
        for i0 in 0..src.size(0) {
            for i1 in 0..src.size(1) {
                for i2 in 0..src.size(2) {
                    for i3 in 0..src.size(3) {
                        // Safety: i0..i3 are in `[0, src.size(i))`.
                        let x = unsafe { src.get_unchecked_mut([i0, i1, i2, i3]) };
                        f(x);
                    }
                }
            }
        }
    });
}

// Tests for iterator internals. Most tests of iterators are currently done via
// tests on tensor methods.
#[cfg(test)]
mod tests {
    use crate::{
        AsView, AxisChunks, AxisChunksMut, Lanes, LanesMut, Layout, NdLayout, NdTensor, Tensor,
    };

    fn compare_reversed<T: PartialEq + std::fmt::Debug>(fwd: &[T], rev: &[T]) {
        assert_eq!(fwd.len(), rev.len());
        for (x, y) in fwd.iter().zip(rev.iter().rev()) {
            assert_eq!(x, y);
        }
    }

    /// Apply a standard set of tests to an iterator.
    fn test_iterator<I: Iterator + ExactSizeIterator + DoubleEndedIterator>(
        create_iter: impl Fn() -> I,
        expected: &[I::Item],
    ) where
        I::Item: PartialEq + std::fmt::Debug,
    {
        let iter = create_iter();

        let (min_len, max_len) = iter.size_hint();
        let items: Vec<_> = iter.collect();

        assert_eq!(&items, expected);

        // Test ExactSizeIterator via `size_hint`.
        assert_eq!(min_len, items.len(), "incorrect size lower bound");
        assert_eq!(max_len, Some(items.len()), "incorrect size upper bound");

        // Test DoubleEndedIterator via `rev`.
        let rev_items: Vec<_> = create_iter().rev().collect();
        compare_reversed(&items, &rev_items);

        // Test FusedIterator.
        let mut iter = create_iter();
        for _x in &mut iter { /* noop */ }
        assert_eq!(iter.next(), None);

        // Test fold.
        let mut fold_items = Vec::new();
        let mut idx = 0;
        create_iter().fold(0, |acc, item| {
            assert_eq!(acc, idx);
            fold_items.push(item);
            idx += 1;
            idx
        });
        assert_eq!(items, fold_items);
    }

    /// A collection that can be mutably iterated over multiple times.
    ///
    /// We use a different pattern for testing mutable iterators to avoid
    /// restrictions on values returned from `FnMut` closures.
    trait MutIterable {
        type Iter<'a>: Iterator + ExactSizeIterator + DoubleEndedIterator
        where
            Self: 'a;

        fn iter_mut(&mut self) -> Self::Iter<'_>;
    }

    /// Apply a standard set of tests to a mutable iterator.
    fn test_mut_iterator<M, T>(mut iterable: M, expected: &[T])
    where
        M: MutIterable,
        T: std::fmt::Debug,
        for<'a> <M::Iter<'a> as Iterator>::Item: std::fmt::Debug + PartialEq + PartialEq<T>,
    {
        // Test Iterator and ExactSizeIterator.
        {
            let iter = iterable.iter_mut();
            let (min_len, max_len) = iter.size_hint();
            let items: Vec<_> = iter.collect();

            // Test `next`
            assert_eq!(items, expected);

            // Test `size_hint`
            assert_eq!(min_len, items.len(), "incorrect size lower bound");
            assert_eq!(max_len, Some(items.len()), "incorrect size upper bound");
        }

        // Test FusedIterator.
        {
            let mut iter = iterable.iter_mut();
            for _x in &mut iter { /* noop */ }
            assert!(iter.next().is_none());
        }

        // Test DoubleEndedIterator via `rev`.
        //
        // We use `format!` here to convert mutable references into comparable
        // items that have no connection to the mutable references yielded by
        // the iterator. Ideally this should be replaced by a clone or something.
        {
            let items: Vec<_> = iterable.iter_mut().map(|x| format!("{:?}", x)).collect();
            let rev_items: Vec<_> = iterable
                .iter_mut()
                .rev()
                .map(|x| format!("{:?}", x))
                .collect();
            compare_reversed(&items, &rev_items);
        }

        // Test fold.
        {
            let items: Vec<_> = iterable.iter_mut().map(|x| format!("{:?}", x)).collect();
            let mut fold_items = Vec::new();
            let mut idx = 0;
            iterable.iter_mut().fold(0, |acc, item| {
                assert_eq!(acc, idx);
                fold_items.push(format!("{:?}", item));
                idx += 1;
                idx
            });
            assert_eq!(items, fold_items);
        }
    }

    #[test]
    fn test_axis_chunks() {
        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
        test_iterator(
            || tensor.axis_chunks(0, 1),
            &[tensor.slice(0..1), tensor.slice(1..2)],
        );
    }
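
    // Chunk sizes when the dim size is not a multiple of the chunk size: the
    // final chunk should be smaller than `chunk_size`.
    #[test]
    fn test_axis_chunks_uneven() {
        let tensor = NdTensor::from([[1, 2], [3, 4], [5, 6]]);
        let sizes: Vec<usize> = tensor.axis_chunks(0, 2).map(|c| c.size(0)).collect();
        assert_eq!(sizes, [2, 1]);
    }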
1656
1657    #[test]
1658    fn test_axis_chunks_empty() {
1659        let x = Tensor::<i32>::zeros(&[5, 0]);
1660        assert!(AxisChunks::new(&x.view(), 1, 1).next().is_none());
1661    }
1662
1663    #[test]
1664    #[should_panic(expected = "chunk size must be > 0")]
1665    fn test_axis_chunks_zero_size() {
1666        let x = Tensor::<i32>::zeros(&[5, 0]);
1667        assert!(AxisChunks::new(&x.view(), 1, 0).next().is_none());
1668    }
1669
1670    #[test]
1671    fn test_axis_chunks_mut_empty() {
1672        let mut x = Tensor::<i32>::zeros(&[5, 0]);
1673        assert!(AxisChunksMut::new(x.view_mut(), 1, 1).next().is_none());
1674    }
1675
1676    #[test]
1677    fn test_axis_chunks_mut_rev() {
1678        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1679        let fwd: Vec<_> = tensor
1680            .axis_chunks_mut(0, 1)
1681            .map(|view| view.to_vec())
1682            .collect();
1683        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1684        let rev: Vec<_> = tensor
1685            .axis_chunks_mut(0, 1)
1686            .rev()
1687            .map(|view| view.to_vec())
1688            .collect();
1689        compare_reversed(&fwd, &rev);
1690    }
1691
1692    #[test]
1693    #[should_panic(expected = "chunk size must be > 0")]
1694    fn test_axis_chunks_mut_zero_size() {
1695        let mut x = Tensor::<i32>::zeros(&[5, 0]);
1696        assert!(AxisChunksMut::new(x.view_mut(), 1, 0).next().is_none());
1697    }
1698
1699    #[test]
1700    fn test_axis_iter() {
1701        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
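            // Iterating over axis 0 yields each index along that axis with the
            // axis removed, i.e. the two `2x2` sub-views.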
1702        test_iterator(|| tensor.axis_iter(0), &[tensor.slice(0), tensor.slice(1)]);
1703    }
1704
1705    #[test]
1706    fn test_axis_iter_mut_rev() {
1707        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1708        let fwd: Vec<_> = tensor.axis_iter_mut(0).map(|view| view.to_vec()).collect();
1709        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1710        let rev: Vec<_> = tensor
1711            .axis_iter_mut(0)
1712            .rev()
1713            .map(|view| view.to_vec())
1714            .collect();
1715        compare_reversed(&fwd, &rev);
1716    }
1717
1718    #[test]
1719    fn test_inner_iter() {
1720        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
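            // `inner_iter::<2>` yields views over the innermost two dimensions,
            // here the two `2x2` sub-views.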
1721        test_iterator(
1722            || tensor.inner_iter::<2>(),
1723            &[tensor.slice(0), tensor.slice(1)],
1724        );
1725    }
1726
1727    #[test]
1728    fn test_inner_iter_mut() {
1729        struct InnerIterMutTest(NdTensor<i32, 3>);
1730
1731        impl MutIterable for InnerIterMutTest {
1732            type Iter<'a> = super::InnerIterMut<'a, i32, NdLayout<2>>;
1733
1734            fn iter_mut(&mut self) -> Self::Iter<'_> {
1735                self.0.inner_iter_mut::<2>()
1736            }
1737        }
1738
1739        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1740        test_mut_iterator(
1741            InnerIterMutTest(tensor.clone()),
1742            &[tensor.slice(0), tensor.slice(1)],
1743        );
1744    }
1745
1746    #[test]
1747    fn test_lanes() {
1748        let x = NdTensor::from([[1, 2], [3, 4]]);
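            // `lanes(0)` yields lanes that vary along dimension 0 (the columns),
            // `lanes(1)` the rows.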
1749        test_iterator(
1750            || x.lanes(0),
1751            &[x.slice((.., 0)).into(), x.slice((.., 1)).into()],
1752        );
1753        test_iterator(|| x.lanes(1), &[x.slice(0).into(), x.slice(1).into()]);
1754    }
1755
1756    #[test]
1757    fn test_lanes_empty() {
1758        let x = Tensor::<i32>::zeros(&[5, 0]);
1759        assert!(Lanes::new(x.view().view_ref(), 0).next().is_none());
1760        assert!(Lanes::new(x.view().view_ref(), 1).next().is_none());
1761    }
1762
1763    #[test]
1764    fn test_lanes_mut() {
1765        use super::Lane;
1766
1767        struct LanesMutTest(NdTensor<i32, 2>);
1768
1769        impl MutIterable for LanesMutTest {
1770            type Iter<'a> = super::LanesMut<'a, i32>;
1771
1772            fn iter_mut(&mut self) -> Self::Iter<'_> {
1773                self.0.lanes_mut(0)
1774            }
1775        }
1776
1777        let tensor = NdTensor::from([[1, 2], [3, 4]]);
1778        test_mut_iterator::<_, Lane<i32>>(
1779            LanesMutTest(tensor.clone()),
1780            &[
1781                Lane::from(tensor.slice((.., 0))),
1782                Lane::from(tensor.slice((.., 1))),
1783            ],
1784        );
1785    }
1786
1787    #[test]
1788    fn test_lane_as_slice() {
1789        // Contiguous lane
1790        let x = NdTensor::from([0, 1, 2]);
1791        let mut lane = x.lanes(0).next().unwrap();
1792        assert_eq!(lane.as_slice(), Some([0, 1, 2].as_slice()));
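            // Consuming elements from the lane should shrink the slice of the
            // remaining elements.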
1793        lane.next();
1794        assert_eq!(lane.as_slice(), Some([1, 2].as_slice()));
1795        lane.next();
1796        lane.next();
1797        assert_eq!(lane.as_slice(), Some([0i32; 0].as_slice()));
1798        lane.next();
1799        assert_eq!(lane.as_slice(), Some([0i32; 0].as_slice()));
1800
1801        // Non-contiguous lane
1802        let x = NdTensor::from([[1i32, 2], [3, 4]]);
1803        let lane = x.lanes(0).next().unwrap();
1804        assert_eq!(lane.as_slice(), None);
1805    }
1806
1807    #[test]
1808    fn test_lanes_mut_empty() {
1809        let mut x = Tensor::<i32>::zeros(&[5, 0]);
1810        assert!(LanesMut::new(x.mut_view_ref(), 0).next().is_none());
1811        assert!(LanesMut::new(x.mut_view_ref(), 1).next().is_none());
1812    }
1813
1814    #[test]
1815    fn test_iter_step_by() {
1816        let tensor = Tensor::<f32>::full(&[1, 3, 16, 8], 1.);
1817
1818        // Take a non-contiguous slice so we don't use the fast path for
1819        // contiguous tensors.
1820        let tensor = tensor.slice((.., .., 1.., ..));
1821
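            // Every element is 1, so skipping `n_skip` elements should reduce the
            // sum by exactly `n_skip`.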
1822        let sum = tensor.iter().sum::<f32>();
1823        for n_skip in 0..tensor.len() {
1824            let sum_skip = tensor.iter().skip(n_skip).sum::<f32>();
1825            assert_eq!(
1826                sum_skip,
1827                sum - n_skip as f32,
1828                "wrong sum for n_skip={}",
1829                n_skip
1830            );
1831        }
1832    }
1833
1834    #[test]
1835    fn test_iter_broadcast() {
1836        let tensor = Tensor::<f32>::full(&[1], 1.);
1837        let broadcast = tensor.broadcast([1, 3, 16, 8]);
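            // The broadcast view repeats the single element, so the iterator's
            // reported length, count and sum must all match the broadcast size.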
1838        assert_eq!(broadcast.iter().len(), broadcast.len());
1839        let count = broadcast.iter().count();
1840        assert_eq!(count, broadcast.len());
1841        let sum = broadcast.iter().sum::<f32>();
1842        assert_eq!(sum, broadcast.len() as f32);
1843    }
1844
1845    #[test]
1846    fn test_iter() {
1847        let tensor = NdTensor::from([[[1, 2], [3, 4]]]);
1848
1849        // Test iterator over contiguous tensor.
1850        test_iterator(|| tensor.iter().copied(), &[1, 2, 3, 4]);
1851
1852        // Test iterator over non-contiguous tensor.
1853        test_iterator(|| tensor.transposed().iter().copied(), &[1, 3, 2, 4]);
1854    }
1855
1856    #[test]
1857    fn test_iter_mut() {
1858        struct IterTest(NdTensor<i32, 3>);
1859
1860        impl MutIterable for IterTest {
1861            type Iter<'a> = super::IterMut<'a, i32>;
1862
1863            fn iter_mut(&mut self) -> Self::Iter<'_> {
1864                self.0.iter_mut()
1865            }
1866        }
1867
1868        let tensor = NdTensor::from([[[1, 2], [3, 4]]]);
1869        test_mut_iterator(IterTest(tensor), &[&1, &2, &3, &4]);
1870    }
1871
1872    #[test]
1873    #[ignore]
1874    fn bench_iter() {
1875        use crate::Layout;
1876        use rten_bench::run_bench;
1877
1878        type Elem = i32;
1879
1880        let tensor = std::hint::black_box(Tensor::<Elem>::full(&[1, 6, 768, 64], 1));
1881        let n_trials = 1000;
1882        let mut result = Elem::default();
1883
1884        fn reduce<'a>(iter: impl Iterator<Item = &'a Elem>) -> Elem {
1885            iter.fold(Elem::default(), |acc, x| acc.wrapping_add(*x))
1886        }
1887
1888        // Iterate directly over data slice.
1889        run_bench(n_trials, Some("slice iter"), || {
1890            result = reduce(tensor.data().unwrap().iter());
1891        });
1892        println!("sum {}", result);
1893
1894        // Use tensor iterator with contiguous tensor. This will use the fast
1895        // path which wraps a slice iterator.
1896        run_bench(n_trials, Some("contiguous iter"), || {
1897            result = reduce(tensor.iter());
1898        });
1899        println!("sum {}", result);
1900
1901        run_bench(n_trials, Some("contiguous reverse iter"), || {
1902            result = reduce(tensor.iter().rev());
1903        });
1904        println!("sum {}", result);
1905
1906        // Use tensor iterator with non-contiguous slice. This will fall back
1907        // to indexed iteration.
1908        let slice = tensor.slice((.., .., 1.., ..));
1909        assert!(!slice.is_contiguous());
1910        let n_trials = 1000;
1911        run_bench(n_trials, Some("non-contiguous iter"), || {
1912            result = reduce(slice.iter());
1913        });
1914        println!("sum {}", result);
1915
1916        // Reverse iteration with non-contiguous slice. This is much slower
1917        // because it translates linear indexes into offsets using division.
1918        let n_trials = 100;
1919        run_bench(n_trials, Some("non-contiguous reverse iter"), || {
1920            result = reduce(slice.iter().rev());
1921        });
1922        println!("sum {}", result);
1923    }
1924
1925    #[test]
1926    #[ignore]
1927    fn bench_inner_iter() {
1928        use crate::rng::XorShiftRng;
1929        use rten_bench::run_bench;
1930
1931        let n_trials = 100;
1932        let mut rng = XorShiftRng::new(1234);
1933
1934        // Tensor with many steps along the outer two dimensions relative to the
1935        // steps along the inner two dimensions. This emphasizes the overhead of
1936        // stepping `inner_iter`.
1937        let tensor = Tensor::<f32>::rand(&[512, 512, 12, 1], &mut rng);
1938
1939        let mut sum = 0.;
1940        run_bench(n_trials, Some("inner iter"), || {
1941            for inner in tensor.inner_iter::<2>() {
1942                for i0 in 0..inner.size(0) {
1943                    for i1 in 0..inner.size(1) {
1944                        sum += inner[[i0, i1]];
1945                    }
1946                }
1947            }
1948        });
1949        println!("sum {}", sum);
1950    }
1951}