rten_tensor/iterators.rs

1//! Iterators over tensor elements and sub-views.
2
3use std::iter::FusedIterator;
4use std::mem::transmute;
5use std::ops::Range;
6
7use rten_base::iter::SplitIterator;
8
9use super::{AsView, DynLayout, NdTensorView, NdTensorViewMut, TensorBase, TensorViewMut};
10use crate::layout::{Layout, MutLayout, NdLayout, OverlapPolicy, RemoveDim, merge_axes};
11use crate::storage::{StorageMut, ViewData, ViewMutData};
12
13mod parallel;
14
15/// Tracks the iteration position within a single dimension.
16#[derive(Copy, Clone, Debug, Default)]
17struct IterPos {
18    /// Remaining steps along this dimension before it needs to be reset.
19    remaining: usize,
20
21    /// Current index in this dimension pre-multiplied by stride.
22    offset: usize,
23
24    /// Update to `offset` for each step.
25    stride: usize,
26
27    /// Maximum value of `self.remaining`. Used when resetting position.
28    max_remaining: usize,
29}
30
31impl IterPos {
32    fn from_size_stride(size: usize, stride: usize) -> Self {
33        let remaining = size.saturating_sub(1);
34        IterPos {
35            remaining,
36            offset: 0,
37            stride,
38            max_remaining: remaining,
39        }
40    }
41
42    #[inline(always)]
43    fn step(&mut self) -> bool {
44        if self.remaining != 0 {
45            self.remaining -= 1;
46            self.offset += self.stride;
47            true
48        } else {
49            self.remaining = self.max_remaining;
50            self.offset = 0;
51            false
52        }
53    }
54
55    /// Return the size of this dimension.
56    fn size(&self) -> usize {
57        // nb. The size is always > 0 since if any dim has zero size, the
58        // iterator will have a length of zero.
59        self.max_remaining + 1
60    }
61
62    /// Return the current index along this dimension.
63    fn index(&self) -> usize {
64        self.max_remaining - self.remaining
65    }
66
67    /// Set the current index along this dimension.
68    fn set_index(&mut self, index: usize) {
69        self.remaining = self.max_remaining - index;
70        self.offset = index * self.stride;
71    }
72}
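// A minimal sketch of how `IterPos` tracks one dimension: for a dim of size 3
// with stride 4, the positions visited have offsets 0, 4, 8, after which the
// position wraps back to 0.
//
//     let mut pos = IterPos::from_size_stride(3, 4);
//     assert_eq!(pos.offset, 0);
//     assert!(pos.step()); // offset == 4
//     assert!(pos.step()); // offset == 8
//     assert!(!pos.step()); // end of dim reached, offset reset to 0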
73
74const INNER_NDIM: usize = 2;
75
76/// Iterator over offsets of a tensor's elements.
77#[derive(Clone, Debug)]
78struct OffsetsBase {
79    /// Remaining number of elements this iterator will yield.
80    ///
81    /// The offsets and positions in other fields are only valid if this is
82    /// non-zero.
83    len: usize,
84
85    /// Component of next element offset from innermost (fastest-changing) dims.
86    inner_offset: usize,
87
88    /// Current position in innermost dims.
89    inner_pos: [IterPos; INNER_NDIM],
90
91    /// Component of next element offset from outermost (slowest-changing) dims.
92    outer_offset: usize,
93
94    /// Current position in outermost dims.
95    ///
96    /// Optimization note: The number of outermost dims will usually be small,
97    /// so you might be tempted to use `SmallVec`. However this resulted in
98    /// worse performance when stepping the iterator, as the compiler was
99    /// less able to unroll the iteration loops.
100    outer_pos: Vec<IterPos>,
101}
102
103impl OffsetsBase {
104    /// Create an iterator over element offsets in `tensor`.
105    fn new<L: Layout>(layout: &L) -> OffsetsBase {
106        // Merge axes to maximize the number of iterations that use the fast
107        // path for stepping over the inner dimensions.
108        let merged = merge_axes(layout.shape().as_ref(), layout.strides().as_ref());
109
110        let inner_pos_pad = INNER_NDIM.saturating_sub(merged.len());
111        let n_outer = merged.len().saturating_sub(INNER_NDIM);
112
113        let inner_pos = std::array::from_fn(|dim| {
114            let (size, stride) = if dim < inner_pos_pad {
115                (1, 0)
116            } else {
117                merged[n_outer + dim - inner_pos_pad]
118            };
119            IterPos::from_size_stride(size, stride)
120        });
121
122        let outer_pos = (0..n_outer)
123            .map(|i| {
124                let (size, stride) = merged[i];
125                IterPos::from_size_stride(size, stride)
126            })
127            .collect();
128
129        OffsetsBase {
130            len: merged.iter().map(|dim| dim.0).product(),
131            inner_pos,
132            inner_offset: 0,
133            outer_pos,
134            outer_offset: 0,
135        }
136    }
137
138    /// Step in the outer dimensions.
139    ///
140    /// Returns `true` if the position was advanced or `false` if the end was
141    /// reached.
142    fn step_outer_pos(&mut self) -> bool {
143        let mut done = self.outer_pos.is_empty();
144        for (i, dim) in self.outer_pos.iter_mut().enumerate().rev() {
145            if dim.step() {
146                break;
147            } else if i == 0 {
148                done = true;
149            }
150        }
151        self.outer_offset = self.outer_pos.iter().map(|p| p.offset).sum();
152        !done
153    }
154
155    fn pos(&self, dim: usize) -> IterPos {
156        let outer_ndim = self.outer_pos.len();
157        if dim >= outer_ndim {
158            self.inner_pos[dim - outer_ndim]
159        } else {
160            self.outer_pos[dim]
161        }
162    }
163
164    fn pos_mut(&mut self, dim: usize) -> &mut IterPos {
165        let outer_ndim = self.outer_pos.len();
166        if dim >= outer_ndim {
167            &mut self.inner_pos[dim - outer_ndim]
168        } else {
169            &mut self.outer_pos[dim]
170        }
171    }
172
173    /// Advance the iterator by up to `n` indices.
174    fn step_by(&mut self, n: usize) {
175        let mut remaining = n.min(self.len);
176        self.len -= remaining;
177
178        for dim in (0..self.ndim()).rev() {
179            if remaining == 0 {
180                break;
181            }
182
183            let pos = self.pos_mut(dim);
184            let size = pos.size();
185            let new_index = pos.index() + remaining;
186            pos.set_index(new_index % size);
187            remaining = new_index / size;
188        }
189
190        // Update offset of next element.
191        self.inner_offset = self.inner_pos.iter().map(|p| p.offset).sum();
192        self.outer_offset = self.outer_pos.iter().map(|p| p.offset).sum();
193    }
194
195    fn ndim(&self) -> usize {
196        self.outer_pos.len() + self.inner_pos.len()
197    }
198
199    /// Compute the storage offset of an element given a linear index into a
200    /// tensor's element sequence.
201    fn offset_from_linear_index(&self, index: usize) -> usize {
202        let mut offset = 0;
203        let mut shape_product = 1;
204        for dim in (0..self.ndim()).rev() {
205            let pos = self.pos(dim);
206            let dim_index = (index / shape_product) % pos.size();
207            shape_product *= pos.size();
208            offset += dim_index * pos.stride;
209        }
210        offset
211    }
212
213    /// Truncate this iterator so that it yields at most `len` elements.
214    fn truncate(&mut self, len: usize) {
215        // We adjust `self.len` here but not any of the iteration positions.
216        // This means that methods like `next` and `fold` must always check
217        // `self.len` before each step.
218        self.len = self.len.min(len);
219    }
220}
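// Worked example: a 2x3 view with strides [1, 2] (the transpose of a contiguous
// [3, 2] tensor) has element offsets `offset([i, j]) = i * 1 + j * 2`, so
// `OffsetsBase` yields the storage offsets in logical (row-major) order:
//
//     0, 2, 4, 1, 3, 5
//
// Since the view has only two dims, both map to `inner_pos` and `outer_pos`
// stays empty, so stepping mostly uses the fast inner-dimension path in `next`.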
221
222impl Iterator for OffsetsBase {
223    type Item = usize;
224
225    #[inline(always)]
226    fn next(&mut self) -> Option<usize> {
227        if self.len == 0 {
228            return None;
229        }
230        let offset = self.outer_offset + self.inner_offset;
231
232        self.len -= 1;
233
234        // Optimistically update offset, assuming we haven't reached the
235        // end of the last dimension.
236        self.inner_offset += self.inner_pos[1].stride;
237
238        // Use a fast path to step inner dimensions and fall back to the slower
239        // path to step the outer dimensions only when we reach the end.
240        if !self.inner_pos[1].step() {
241            if !self.inner_pos[0].step() {
242                self.step_outer_pos();
243            }
244
245            // `inner_offset` is the sum of the two `inner_pos` offsets, and
246            // we know `inner_pos[1].offset` is zero because its `step()` call
247            // above returned false and reset it. Hence a plain assignment of
248            // `inner_pos[0].offset` suffices.
249            self.inner_offset = self.inner_pos[0].offset;
250        }
251
252        Some(offset)
253    }
254
255    fn size_hint(&self) -> (usize, Option<usize>) {
256        (self.len, Some(self.len))
257    }
258
259    fn fold<B, F>(mut self, init: B, mut f: F) -> B
260    where
261        Self: Sized,
262        F: FnMut(B, usize) -> B,
263    {
264        // Iter positions are only valid if `self.len > 0`.
265        if self.len == 0 {
266            return init;
267        }
268
269        let mut accum = init;
270        'outer: loop {
271            for i0 in self.inner_pos[0].index()..self.inner_pos[0].size() {
272                for i1 in self.inner_pos[1].index()..self.inner_pos[1].size() {
273                    let inner_offset =
274                        i0 * self.inner_pos[0].stride + i1 * self.inner_pos[1].stride;
275                    accum = f(accum, self.outer_offset + inner_offset);
276
277                    self.len -= 1;
278                    if self.len == 0 {
279                        break 'outer;
280                    }
281                }
282                self.inner_pos[1].set_index(0);
283            }
284            self.inner_pos[0].set_index(0);
285
286            if !self.step_outer_pos() {
287                break;
288            }
289        }
290
291        accum
292    }
293}
294
295impl ExactSizeIterator for OffsetsBase {}
296
297impl DoubleEndedIterator for OffsetsBase {
298    fn next_back(&mut self) -> Option<usize> {
299        if self.len == 0 {
300            return None;
301        }
302
303        // This is inefficient compared to forward iteration, but that's OK
304        // because reverse iteration is not performance critical.
305        let index = self.len - 1;
306        let offset = self.offset_from_linear_index(index);
307        self.len -= 1;
308
309        Some(offset)
310    }
311}
312
313impl SplitIterator for OffsetsBase {
314    /// Split this iterator into two. The left result visits indices before
315    /// `index`, the right result visits indices from `index` onwards.
316    fn split_at(mut self, index: usize) -> (Self, Self) {
317        assert!(self.len >= index);
318
319        let mut right = self.clone();
320        OffsetsBase::step_by(&mut right, index);
321
322        self.truncate(index);
323
324        (self, right)
325    }
326}
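// Sketch of `split_at` semantics: splitting an iterator with six remaining
// offsets at index 2 produces a left iterator that yields the first two
// offsets and a right iterator that yields the remaining four. Splitting like
// this is what allows offset iteration to be divided across threads (see the
// `parallel` submodule).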
327
328/// Iterator over elements of a tensor, in their logical order.
329pub struct Iter<'a, T> {
330    offsets: Offsets,
331    data: ViewData<'a, T>,
332}
333
334impl<'a, T> Iter<'a, T> {
335    pub(super) fn new<L: Layout + Clone>(view: TensorBase<ViewData<'a, T>, L>) -> Iter<'a, T> {
336        Iter {
337            offsets: Offsets::new(view.layout()),
338            data: view.storage(),
339        }
340    }
341}
342
343impl<T> Clone for Iter<'_, T> {
344    fn clone(&self) -> Self {
345        Iter {
346            offsets: self.offsets.clone(),
347            data: self.data,
348        }
349    }
350}
351
352impl<'a, T> Iterator for Iter<'a, T> {
353    type Item = &'a T;
354
355    #[inline(always)]
356    fn next(&mut self) -> Option<Self::Item> {
357        let offset = self.offsets.next()?;
358
359        // Safety: Offset is valid for data length.
360        Some(unsafe { self.data.get_unchecked(offset) })
361    }
362
363    fn size_hint(&self) -> (usize, Option<usize>) {
364        self.offsets.size_hint()
365    }
366
367    fn nth(&mut self, n: usize) -> Option<Self::Item> {
368        let offset = self.offsets.nth(n)?;
369
370        // Safety: Offset is valid for data length.
371        Some(unsafe { self.data.get_unchecked(offset) })
372    }
373
374    fn fold<B, F>(self, init: B, mut f: F) -> B
375    where
376        Self: Sized,
377        F: FnMut(B, Self::Item) -> B,
378    {
379        self.offsets.fold(init, |acc, offset| {
380            // Safety: Offset is valid for data length.
381            let item = unsafe { self.data.get_unchecked(offset) };
382            f(acc, item)
383        })
384    }
385}
386
387impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
388    fn next_back(&mut self) -> Option<Self::Item> {
389        let offset = self.offsets.next_back()?;
390
391        // Safety: Offset is valid for data length.
392        Some(unsafe { self.data.get_unchecked(offset) })
393    }
394}
395
396impl<T> ExactSizeIterator for Iter<'_, T> {}
397
398impl<T> FusedIterator for Iter<'_, T> {}
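// Example usage, as a sketch: iteration follows logical order even when the
// view is not contiguous. (This assumes the element-iteration method exposed
// by views, which constructs `Iter` internally.)
//
//     let x = NdTensor::from([[1, 2], [3, 4]]);
//     let elems: Vec<i32> = x.iter().copied().collect();
//     assert_eq!(elems, [1, 2, 3, 4]);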
399
400/// Wrapper around [`transmute`] which allows transmuting only the lifetime,
401/// not the type, of a mutable reference.
402unsafe fn transmute_lifetime_mut<'a, 'b, T>(x: &'a mut T) -> &'b mut T {
403    unsafe { transmute::<&'a mut T, &'b mut T>(x) }
404}
405
406/// Mutable iterator over elements of a tensor.
407pub struct IterMut<'a, T> {
408    offsets: Offsets,
409    data: ViewMutData<'a, T>,
410}
411
412impl<'a, T> IterMut<'a, T> {
413    pub(super) fn new<L: Layout + Clone>(
414        view: TensorBase<ViewMutData<'a, T>, L>,
415    ) -> IterMut<'a, T> {
416        IterMut {
417            offsets: Offsets::new(view.layout()),
418            data: view.into_storage(),
419        }
420    }
421}
422
423impl<'a, T> Iterator for IterMut<'a, T> {
424    type Item = &'a mut T;
425
426    #[inline]
427    fn next(&mut self) -> Option<Self::Item> {
428        let offset = self.offsets.next()?;
429
430        // Safety: Offset is valid for data length, `offsets.next` yields each
431        // offset only once.
432        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
433    }
434
435    #[inline]
436    fn size_hint(&self) -> (usize, Option<usize>) {
437        self.offsets.size_hint()
438    }
439
440    fn nth(&mut self, n: usize) -> Option<Self::Item> {
441        let offset = self.offsets.nth(n)?;
442
443        // Safety: Offset is valid for data length, `offsets.next` yields each
444        // offset only once.
445        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
446    }
447
448    fn fold<B, F>(mut self, init: B, mut f: F) -> B
449    where
450        Self: Sized,
451        F: FnMut(B, Self::Item) -> B,
452    {
453        self.offsets.fold(init, |acc, offset| {
454            // Safety: Offset is valid for data length, `offsets.fold` yields
455            // each offset only once.
456            let item = unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) };
457            f(acc, item)
458        })
459    }
460}
461
462impl<T> DoubleEndedIterator for IterMut<'_, T> {
463    fn next_back(&mut self) -> Option<Self::Item> {
464        let offset = self.offsets.next_back()?;
465
466        // Safety: Offset is valid for data length, `offsets.next` yields each
467        // offset only once.
468        Some(unsafe { transmute_lifetime_mut(self.data.get_unchecked_mut(offset)) })
469    }
470}
471
472impl<T> ExactSizeIterator for IterMut<'_, T> {}
473
474impl<T> FusedIterator for IterMut<'_, T> {}
475
476#[derive(Clone)]
477enum OffsetsKind {
478    Range(Range<usize>),
479    Indexing(OffsetsBase),
480}
481
482/// Iterator over element offsets of a tensor.
483///
484/// `Offsets` does not hold a reference to the tensor, allowing the tensor to
485/// be modified during iteration. It is the caller's responsibility not to modify
486/// the tensor in ways that invalidate the offset sequence returned by this
487/// iterator.
488#[derive(Clone)]
489struct Offsets {
490    base: OffsetsKind,
491}
492
493impl Offsets {
494    fn new<L: Layout>(layout: &L) -> Offsets {
495        Offsets {
496            base: if layout.is_contiguous() {
497                OffsetsKind::Range(0..layout.min_data_len())
498            } else {
499                OffsetsKind::Indexing(OffsetsBase::new(layout))
500            },
501        }
502    }
503}
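// Note: for a contiguous layout the element offsets are simply the range
// `0..min_data_len()`, so `Offsets` degenerates to a plain `Range<usize>` and
// only non-contiguous layouts pay the cost of `OffsetsBase`.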
504
505impl Iterator for Offsets {
506    type Item = usize;
507
508    #[inline]
509    fn next(&mut self) -> Option<Self::Item> {
510        match &mut self.base {
511            OffsetsKind::Range(r) => r.next(),
512            OffsetsKind::Indexing(base) => base.next(),
513        }
514    }
515
516    fn size_hint(&self) -> (usize, Option<usize>) {
517        match &self.base {
518            OffsetsKind::Range(r) => r.size_hint(),
519            OffsetsKind::Indexing(base) => (base.len, Some(base.len)),
520        }
521    }
522
523    fn nth(&mut self, n: usize) -> Option<Self::Item> {
524        match &mut self.base {
525            OffsetsKind::Range(r) => r.nth(n),
526            OffsetsKind::Indexing(base) => {
527                base.step_by(n);
528                self.next()
529            }
530        }
531    }
532
533    fn fold<B, F>(self, init: B, f: F) -> B
534    where
535        Self: Sized,
536        F: FnMut(B, Self::Item) -> B,
537    {
538        match self.base {
539            OffsetsKind::Range(r) => r.fold(init, f),
540            OffsetsKind::Indexing(base) => base.fold(init, f),
541        }
542    }
543}
544
545impl DoubleEndedIterator for Offsets {
546    fn next_back(&mut self) -> Option<Self::Item> {
547        match &mut self.base {
548            OffsetsKind::Range(r) => r.next_back(),
549            OffsetsKind::Indexing(base) => base.next_back(),
550        }
551    }
552}
553
554impl ExactSizeIterator for Offsets {}
555
556impl FusedIterator for Offsets {}
557
558/// Iterator over the ranges of a tensor's data that correspond to 1D lanes
559/// along a particular dimension.
560struct LaneRanges {
561    /// Start offsets of each lane.
562    offsets: Offsets,
563
564    // Number of elements in each lane, and the stride between successive elements within a lane.
565    dim_size: usize,
566    dim_stride: usize,
567}
568
569impl LaneRanges {
570    fn new<L: Layout + RemoveDim>(layout: &L, dim: usize) -> LaneRanges {
571        // If the layout is empty (has any zero-sized dims), we need to make
572        // sure that `offsets` is as well.
573        let offsets = if layout.is_empty() {
574            Offsets::new(layout)
575        } else {
576            let other_dims = layout.remove_dim(dim);
577            Offsets::new(&other_dims)
578        };
579
580        LaneRanges {
581            offsets,
582            dim_size: layout.size(dim),
583            dim_stride: layout.stride(dim),
584        }
585    }
586
587    /// Return the range of storage offsets for a 1D lane where the first
588    /// element is at `start_offset`.
589    fn lane_offset_range(&self, start_offset: usize) -> Range<usize> {
590        lane_offsets(start_offset, self.dim_size, self.dim_stride)
591    }
592}
593
594fn lane_offsets(start_offset: usize, size: usize, stride: usize) -> Range<usize> {
595    start_offset..start_offset + (size - 1) * stride + 1
596}
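// Example: a lane of 3 elements with stride 4 starting at offset 2 touches
// storage offsets {2, 6, 10}, so `lane_offsets(2, 3, 4)` returns `2..11`, the
// smallest range containing the whole lane.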
597
598impl Iterator for LaneRanges {
599    type Item = Range<usize>;
600
601    #[inline]
602    fn next(&mut self) -> Option<Range<usize>> {
603        self.offsets
604            .next()
605            .map(|offset| self.lane_offset_range(offset))
606    }
607
608    fn size_hint(&self) -> (usize, Option<usize>) {
609        self.offsets.size_hint()
610    }
611
612    fn fold<B, F>(self, init: B, mut f: F) -> B
613    where
614        Self: Sized,
615        F: FnMut(B, Self::Item) -> B,
616    {
617        let Self {
618            offsets,
619            dim_size,
620            dim_stride,
621        } = self;
622
623        offsets.fold(init, |acc, offset| {
624            f(acc, lane_offsets(offset, dim_size, dim_stride))
625        })
626    }
627}
628
629impl DoubleEndedIterator for LaneRanges {
630    fn next_back(&mut self) -> Option<Range<usize>> {
631        self.offsets
632            .next_back()
633            .map(|offset| self.lane_offset_range(offset))
634    }
635}
636
637impl ExactSizeIterator for LaneRanges {}
638
639impl FusedIterator for LaneRanges {}
640
641/// Iterator over 1D slices of a tensor along a target dimension of size N.
642///
643/// Conceptually this iterator steps through every distinct slice of a tensor
644/// where a target dim is varied from 0..N and other indices are held fixed.
645pub struct Lanes<'a, T> {
646    data: ViewData<'a, T>,
647    ranges: LaneRanges,
648    lane_layout: NdLayout<1>,
649}
650
651/// Iterator over items in a 1D slice of a tensor.
652#[derive(Clone, Debug)]
653pub struct Lane<'a, T> {
654    view: NdTensorView<'a, T, 1>,
655    index: usize,
656}
657
658impl<'a, T> Lane<'a, T> {
659    /// Return the remaining part of the lane as a slice, if it is contiguous.
660    pub fn as_slice(&self) -> Option<&'a [T]> {
661        self.view.data().map(|data| &data[self.index..])
662    }
663
664    /// Return the item at a given index in this lane.
665    pub fn get(&self, idx: usize) -> Option<&'a T> {
666        self.view.get([idx])
667    }
668
669    /// Return the entire lane as a 1D tensor view.
670    pub fn as_view(&self) -> NdTensorView<'a, T, 1> {
671        self.view
672    }
673}
674
675impl<'a, T> From<NdTensorView<'a, T, 1>> for Lane<'a, T> {
676    fn from(val: NdTensorView<'a, T, 1>) -> Self {
677        Lane {
678            view: val,
679            index: 0,
680        }
681    }
682}
683
684impl<'a, T> Iterator for Lane<'a, T> {
685    type Item = &'a T;
686
687    #[inline]
688    fn next(&mut self) -> Option<Self::Item> {
689        if self.index < self.view.len() {
690            let index = self.index;
691            self.index += 1;
692
693            // Safety: Index is in bounds for axis 0.
694            Some(unsafe { self.view.get_unchecked([index]) })
695        } else {
696            None
697        }
698    }
699
700    fn size_hint(&self) -> (usize, Option<usize>) {
701        let size = self.view.size(0) - self.index;
702        (size, Some(size))
703    }
704}
705
706impl<T> ExactSizeIterator for Lane<'_, T> {}
707
708impl<T> FusedIterator for Lane<'_, T> {}
709
710impl<T: PartialEq> PartialEq<Lane<'_, T>> for Lane<'_, T> {
711    fn eq(&self, other: &Lane<'_, T>) -> bool {
712        self.view.slice(self.index..) == other.view.slice(other.index..)
713    }
714}
715
716impl<T: PartialEq> PartialEq<Lane<'_, T>> for LaneMut<'_, T> {
717    fn eq(&self, other: &Lane<'_, T>) -> bool {
718        self.view.slice(self.index..) == other.view.slice(other.index..)
719    }
720}
721
722impl<'a, T> Lanes<'a, T> {
723    /// Create an iterator which yields all possible slices over the `dim`
724    /// dimension of `tensor`.
725    pub(crate) fn new<L: Layout + RemoveDim + Clone>(
726        view: TensorBase<ViewData<'a, T>, L>,
727        dim: usize,
728    ) -> Lanes<'a, T> {
729        let size = view.size(dim);
730        let stride = view.stride(dim);
731        let lane_layout =
732            NdLayout::from_shape_and_strides([size], [stride], OverlapPolicy::AllowOverlap)
733                .unwrap();
734        Lanes {
735            data: view.storage(),
736            ranges: LaneRanges::new(view.layout(), dim),
737            lane_layout,
738        }
739    }
740}
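// Example usage, as a sketch (via the `lanes` method on views, which constructs
// `Lanes` internally): lanes along dim 0 of a 2D tensor are its columns.
//
//     let x = NdTensor::from([[1, 2], [3, 4]]);
//     let cols: Vec<Vec<i32>> = x.lanes(0).map(|lane| lane.copied().collect()).collect();
//     assert_eq!(cols, [vec![1, 3], vec![2, 4]]);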
741
742fn lane_for_offset_range<T>(
743    data: ViewData<T>,
744    layout: NdLayout<1>,
745    offsets: Range<usize>,
746) -> Lane<T> {
747    let view = NdTensorView::from_storage_and_layout(data.slice(offsets), layout);
748    Lane { view, index: 0 }
749}
750
751impl<'a, T> Iterator for Lanes<'a, T> {
752    type Item = Lane<'a, T>;
753
754    /// Yield the next slice over the target dimension.
755    #[inline]
756    fn next(&mut self) -> Option<Self::Item> {
757        self.ranges
758            .next()
759            .map(|range| lane_for_offset_range(self.data, self.lane_layout, range))
760    }
761
762    fn size_hint(&self) -> (usize, Option<usize>) {
763        self.ranges.size_hint()
764    }
765
766    fn fold<B, F>(self, init: B, mut f: F) -> B
767    where
768        Self: Sized,
769        F: FnMut(B, Self::Item) -> B,
770    {
771        self.ranges.fold(init, |acc, offsets| {
772            let lane = lane_for_offset_range(self.data, self.lane_layout, offsets);
773            f(acc, lane)
774        })
775    }
776}
777
778impl<T> DoubleEndedIterator for Lanes<'_, T> {
779    fn next_back(&mut self) -> Option<Self::Item> {
780        self.ranges
781            .next_back()
782            .map(|range| lane_for_offset_range(self.data, self.lane_layout, range))
783    }
784}
785
786impl<T> ExactSizeIterator for Lanes<'_, T> {}
787
788impl<T> FusedIterator for Lanes<'_, T> {}
789
790/// Mutable version of [`Lanes`].
791///
792/// Each yielded [`LaneMut`] mutably borrows the underlying storage, so the
793/// view being iterated over must not broadcast (i.e. no element may appear
794/// in more than one lane); the constructor asserts this.
795pub struct LanesMut<'a, T> {
796    data: ViewMutData<'a, T>,
797    ranges: LaneRanges,
798    lane_layout: NdLayout<1>,
799}
800
801impl<'a, T> LanesMut<'a, T> {
802    /// Create an iterator which yields all possible slices over the `dim`
803    /// dimension of `view`.
804    pub(crate) fn new<L: Layout + RemoveDim + Clone>(
805        view: TensorBase<ViewMutData<'a, T>, L>,
806        dim: usize,
807    ) -> LanesMut<'a, T> {
808        // See notes in `Layout` about internal overlap.
809        assert!(
810            !view.is_broadcast(),
811            "Cannot mutably iterate over broadcasting view"
812        );
813
814        let size = view.size(dim);
815        let stride = view.stride(dim);
816
817        // We allow overlap here to handle the case where the stride is zero
818        // but the tensor is empty. If the tensor were not empty, the assert
819        // above would have caught this.
820        let lane_layout =
821            NdLayout::from_shape_and_strides([size], [stride], OverlapPolicy::AllowOverlap)
822                .unwrap();
823
824        LanesMut {
825            ranges: LaneRanges::new(view.layout(), dim),
826            data: view.into_storage(),
827            lane_layout,
828        }
829    }
830}
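// Example usage, as a sketch (assuming the corresponding `lanes_mut` method on
// mutable views, which constructs `LanesMut` internally):
//
//     let mut x = NdTensor::from([[1, 2], [3, 4]]);
//     for lane in x.lanes_mut(1) {
//         for item in lane {
//             *item *= 10;
//         }
//     }
//     assert_eq!(x, NdTensor::from([[10, 20], [30, 40]]));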
831
832impl<'a, T> Iterator for LanesMut<'a, T> {
833    type Item = LaneMut<'a, T>;
834
835    #[inline]
836    fn next(&mut self) -> Option<LaneMut<'a, T>> {
837        self.ranges.next().map(|offsets| {
838            // Safety: Offsets range length is sufficient for layout, elements
839            // in each lane do not overlap.
840            unsafe {
841                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
842            }
843        })
844    }
845
846    fn size_hint(&self) -> (usize, Option<usize>) {
847        self.ranges.size_hint()
848    }
849
850    fn fold<B, F>(mut self, init: B, mut f: F) -> B
851    where
852        Self: Sized,
853        F: FnMut(B, Self::Item) -> B,
854    {
855        self.ranges.fold(init, |acc, offsets| {
856            // Safety: Offsets range length is sufficient for layout, elements
857            // in each lane do not overlap.
858            let lane = unsafe {
859                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
860            };
861            f(acc, lane)
862        })
863    }
864}
865
866impl<'a, T> ExactSizeIterator for LanesMut<'a, T> {}
867
868impl<'a, T> DoubleEndedIterator for LanesMut<'a, T> {
869    fn next_back(&mut self) -> Option<LaneMut<'a, T>> {
870        self.ranges.next_back().map(|offsets| {
871            // Safety: Offsets range length is sufficient for layout, elements
872            // in each lane do not overlap.
873            unsafe {
874                LaneMut::from_storage_layout(self.data.to_view_slice_mut(offsets), self.lane_layout)
875            }
876        })
877    }
878}
879
880/// Iterator over items in a 1D slice of a tensor.
881#[derive(Debug)]
882pub struct LaneMut<'a, T> {
883    view: NdTensorViewMut<'a, T, 1>,
884    index: usize,
885}
886
887impl<'a, T> LaneMut<'a, T> {
888    /// Create a new lane given the storage and layout.
889    ///
890    /// # Safety
891    ///
892    /// - Caller must ensure that no two lanes are created which overlap.
893/// - Storage length must be at least `layout.min_data_len()`.
894    unsafe fn from_storage_layout(data: ViewMutData<'a, T>, layout: NdLayout<1>) -> Self {
895        let view = unsafe {
896            // Safety: Caller promises that each call uses the offset ranges for
897            // a different lane and that the range length is sufficient for the
898            // lane's size and stride.
899            NdTensorViewMut::from_storage_and_layout_unchecked(data, layout)
900        };
901        LaneMut { view, index: 0 }
902    }
903
904    /// Return the remaining part of the lane as a slice, if it is contiguous.
905    pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
906        self.view.data_mut().map(|data| &mut data[self.index..])
907    }
908
909    /// Return the entire lane as a mutable 1D tensor view.
910    pub fn into_view(self) -> NdTensorViewMut<'a, T, 1> {
911        self.view
912    }
913}
914
915impl<'a, T> Iterator for LaneMut<'a, T> {
916    type Item = &'a mut T;
917
918    #[inline]
919    fn next(&mut self) -> Option<Self::Item> {
920        if self.index < self.view.size(0) {
921            let index = self.index;
922            self.index += 1;
923            let item = unsafe { self.view.get_unchecked_mut([index]) };
924
925            // Transmute to preserve lifetime of data. This is safe as we
926            // yield each element only once.
927            Some(unsafe { transmute::<&mut T, Self::Item>(item) })
928        } else {
929            None
930        }
931    }
932
933    #[inline]
934    fn nth(&mut self, nth: usize) -> Option<Self::Item> {
935        self.index = (self.index + nth).min(self.view.size(0));
936        self.next()
937    }
938
939    fn size_hint(&self) -> (usize, Option<usize>) {
940        let size = self.view.size(0) - self.index;
941        (size, Some(size))
942    }
943}
944
945impl<T> ExactSizeIterator for LaneMut<'_, T> {}
946
947impl<T: PartialEq> PartialEq<LaneMut<'_, T>> for LaneMut<'_, T> {
948    fn eq(&self, other: &LaneMut<'_, T>) -> bool {
949        self.view.slice(self.index..) == other.view.slice(other.index..)
950    }
951}
952
953/// Base for iterators over views of the inner dimensions of a tensor, where
954/// the inner dimensions have layout `L`.
955struct InnerIterBase<L: Layout> {
956    // Iterator over storage start offsets for each inner view. The storage
957    // range for each view is `offset..offset + inner_data_len`.
958    outer_offsets: Offsets,
959    inner_layout: L,
960    inner_data_len: usize,
961}
962
963impl<L: Layout + Clone> InnerIterBase<L> {
964    fn new_impl<PL: Layout, F: Fn(&[usize], &[usize]) -> L>(
965        parent_layout: &PL,
966        inner_dims: usize,
967        make_inner_layout: F,
968    ) -> InnerIterBase<L> {
969        assert!(parent_layout.ndim() >= inner_dims);
970        let outer_dims = parent_layout.ndim() - inner_dims;
971        let parent_shape = parent_layout.shape();
972        let parent_strides = parent_layout.strides();
973        let (outer_shape, inner_shape) = parent_shape.as_ref().split_at(outer_dims);
974        let (outer_strides, inner_strides) = parent_strides.as_ref().split_at(outer_dims);
975
976        let outer_layout = DynLayout::from_shape_and_strides(
977            outer_shape,
978            outer_strides,
979            OverlapPolicy::AllowOverlap,
980        )
981        .unwrap();
982
983        let inner_layout = make_inner_layout(inner_shape, inner_strides);
984
985        InnerIterBase {
986            outer_offsets: Offsets::new(&outer_layout),
987            inner_data_len: inner_layout.min_data_len(),
988            inner_layout,
989        }
990    }
991}
992
993impl<const N: usize> InnerIterBase<NdLayout<N>> {
994    pub(crate) fn new<L: Layout>(parent_layout: &L) -> Self {
995        Self::new_impl(parent_layout, N, |inner_shape, inner_strides| {
996            let inner_shape: [usize; N] = inner_shape.try_into().unwrap();
997            let inner_strides: [usize; N] = inner_strides.try_into().unwrap();
998            NdLayout::from_shape_and_strides(
999                inner_shape,
1000                inner_strides,
1001                // We allow overlap here, but the view that owns `parent_layout`
1002                // will enforce there is no overlap if it is a mutable view.
1003                OverlapPolicy::AllowOverlap,
1004            )
1005            .expect("failed to create layout")
1006        })
1007    }
1008}
1009
1010impl InnerIterBase<DynLayout> {
1011    pub(crate) fn new_dyn<L: Layout>(parent_layout: &L, inner_dims: usize) -> Self {
1012        Self::new_impl(parent_layout, inner_dims, |inner_shape, inner_strides| {
1013            DynLayout::from_shape_and_strides(
1014                inner_shape,
1015                inner_strides,
1016                // We allow overlap here, but the view that owns `parent_layout`
1017                // will enforce there is no overlap if it is a mutable view.
1018                OverlapPolicy::AllowOverlap,
1019            )
1020            .expect("failed to create layout")
1021        })
1022    }
1023}
1024
1025impl<L: Layout> Iterator for InnerIterBase<L> {
1026    /// Storage offset range for the next view.
1027    type Item = Range<usize>;
1028
1029    fn next(&mut self) -> Option<Range<usize>> {
1030        self.outer_offsets
1031            .next()
1032            .map(|offset| offset..offset + self.inner_data_len)
1033    }
1034
1035    fn size_hint(&self) -> (usize, Option<usize>) {
1036        self.outer_offsets.size_hint()
1037    }
1038
1039    fn fold<B, F>(self, init: B, mut f: F) -> B
1040    where
1041        Self: Sized,
1042        F: FnMut(B, Self::Item) -> B,
1043    {
1044        self.outer_offsets.fold(init, |acc, offset| {
1045            f(acc, offset..offset + self.inner_data_len)
1046        })
1047    }
1048}
1049
1050impl<L: Layout> ExactSizeIterator for InnerIterBase<L> {}
1051
1052impl<L: Layout> DoubleEndedIterator for InnerIterBase<L> {
1053    fn next_back(&mut self) -> Option<Self::Item> {
1054        self.outer_offsets
1055            .next_back()
1056            .map(|offset| offset..offset + self.inner_data_len)
1057    }
1058}
1059
1060/// Iterator over views of the innermost dimensions of a tensor, where the
1061/// tensor has element type T and the inner dimensions have layout L.
1062pub struct InnerIter<'a, T, L: Layout> {
1063    base: InnerIterBase<L>,
1064    data: ViewData<'a, T>,
1065}
1066
1067impl<'a, T, const N: usize> InnerIter<'a, T, NdLayout<N>> {
1068    pub(crate) fn new<L: Layout + Clone>(view: TensorBase<ViewData<'a, T>, L>) -> Self {
1069        let base = InnerIterBase::new(&view);
1070        InnerIter {
1071            base,
1072            data: view.storage(),
1073        }
1074    }
1075}
1076
1077impl<'a, T> InnerIter<'a, T, DynLayout> {
1078    pub(crate) fn new_dyn<L: Layout + Clone>(
1079        view: TensorBase<ViewData<'a, T>, L>,
1080        inner_dims: usize,
1081    ) -> Self {
1082        let base = InnerIterBase::new_dyn(&view, inner_dims);
1083        InnerIter {
1084            base,
1085            data: view.storage(),
1086        }
1087    }
1088}
1089
1090impl<'a, T, L: Layout + Clone> Iterator for InnerIter<'a, T, L> {
1091    type Item = TensorBase<ViewData<'a, T>, L>;
1092
1093    fn next(&mut self) -> Option<Self::Item> {
1094        self.base.next().map(|offset_range| {
1095            TensorBase::from_storage_and_layout(
1096                self.data.slice(offset_range),
1097                self.base.inner_layout.clone(),
1098            )
1099        })
1100    }
1101
1102    fn size_hint(&self) -> (usize, Option<usize>) {
1103        self.base.size_hint()
1104    }
1105
1106    fn fold<B, F>(self, init: B, mut f: F) -> B
1107    where
1108        Self: Sized,
1109        F: FnMut(B, Self::Item) -> B,
1110    {
1111        let inner_layout = self.base.inner_layout.clone();
1112        self.base.fold(init, |acc, offset_range| {
1113            let item = TensorBase::from_storage_and_layout(
1114                self.data.slice(offset_range),
1115                inner_layout.clone(),
1116            );
1117            f(acc, item)
1118        })
1119    }
1120}
1121
1122impl<T, L: Layout + Clone> ExactSizeIterator for InnerIter<'_, T, L> {}
1123
1124impl<T, L: Layout + Clone> DoubleEndedIterator for InnerIter<'_, T, L> {
1125    fn next_back(&mut self) -> Option<Self::Item> {
1126        self.base.next_back().map(|offset_range| {
1127            TensorBase::from_storage_and_layout(
1128                self.data.slice(offset_range),
1129                self.base.inner_layout.clone(),
1130            )
1131        })
1132    }
1133}
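// Example: for a tensor of shape [2, 2, 2], `inner_iter::<2>()` yields two
// views of shape [2, 2], one per index along the leading dim, all sharing the
// parent's storage. A sketch:
//
//     let x = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
//     let views: Vec<_> = x.inner_iter::<2>().collect();
//     assert_eq!(views[0], x.slice(0));
//     assert_eq!(views[1], x.slice(1));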
1134
1135/// Iterator over mutable views of the innermost dimensions of a tensor, where
1136/// the tensor has element type T and the inner dimensions have layout L.
1137pub struct InnerIterMut<'a, T, L: Layout> {
1138    base: InnerIterBase<L>,
1139    data: ViewMutData<'a, T>,
1140}
1141
1142impl<'a, T, const N: usize> InnerIterMut<'a, T, NdLayout<N>> {
1143    pub(crate) fn new<L: Layout>(view: TensorBase<ViewMutData<'a, T>, L>) -> Self {
1144        let base = InnerIterBase::new(&view);
1145        InnerIterMut {
1146            base,
1147            data: view.into_storage(),
1148        }
1149    }
1150}
1151
1152impl<'a, T> InnerIterMut<'a, T, DynLayout> {
1153    pub(crate) fn new_dyn<L: Layout>(
1154        view: TensorBase<ViewMutData<'a, T>, L>,
1155        inner_dims: usize,
1156    ) -> Self {
1157        let base = InnerIterBase::new_dyn(&view, inner_dims);
1158        InnerIterMut {
1159            base,
1160            data: view.into_storage(),
1161        }
1162    }
1163}
1164
1165impl<'a, T, L: Layout + Clone> Iterator for InnerIterMut<'a, T, L> {
1166    type Item = TensorBase<ViewMutData<'a, T>, L>;
1167
1168    fn next(&mut self) -> Option<Self::Item> {
1169        self.base.next().map(|offset_range| {
1170            let storage = self.data.slice_mut(offset_range);
1171            let storage = unsafe {
1172                // Safety: The iterator was constructed from a tensor with a
1173                // non-overlapping layout, and no two views yielded by this
1174                // iterator overlap. Hence we can transmute the lifetime without
1175                // creating multiple mutable references to the same elements.
1176                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
1177            };
1178            TensorBase::from_storage_and_layout(storage, self.base.inner_layout.clone())
1179        })
1180    }
1181
1182    fn size_hint(&self) -> (usize, Option<usize>) {
1183        self.base.size_hint()
1184    }
1185
1186    fn fold<B, F>(mut self, init: B, mut f: F) -> B
1187    where
1188        Self: Sized,
1189        F: FnMut(B, Self::Item) -> B,
1190    {
1191        let inner_layout = self.base.inner_layout.clone();
1192        self.base.fold(init, |acc, offset_range| {
1193            let storage = self.data.slice_mut(offset_range);
1194            let storage = unsafe {
1195                // Safety: The iterator was constructed from a tensor with a
1196                // non-overlapping layout, and no two views yielded by this
1197                // iterator overlap. Hence we can transmute the lifetime without
1198                // creating multiple mutable references to the same elements.
1199                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
1200            };
1201            let item = TensorBase::from_storage_and_layout(storage, inner_layout.clone());
1202            f(acc, item)
1203        })
1204    }
1205}
1206
1207impl<T, L: Layout + Clone> ExactSizeIterator for InnerIterMut<'_, T, L> {}
1208
1209impl<'a, T, L: Layout + Clone> DoubleEndedIterator for InnerIterMut<'a, T, L> {
1210    fn next_back(&mut self) -> Option<Self::Item> {
1211        self.base.next_back().map(|offset_range| {
1212            let storage = self.data.slice_mut(offset_range);
1213            let storage = unsafe {
1214            // Safety: As in `next`, no two views yielded by this iterator
1215            // overlap, so the lifetime can be promoted safely.
1216                std::mem::transmute::<ViewMutData<'_, T>, ViewMutData<'a, T>>(storage)
1217            };
1218            TensorBase::from_storage_and_layout(storage, self.base.inner_layout.clone())
1219        })
1220    }
1221}
1222
1223/// Iterator over slices of a tensor along an axis. See
1224/// [`TensorView::axis_iter`](crate::TensorView::axis_iter).
1225pub struct AxisIter<'a, T, L: Layout + RemoveDim> {
1226    view: TensorBase<ViewData<'a, T>, L>,
1227    axis: usize,
1228    index: usize,
1229    end: usize,
1230}
1231
1232impl<'a, T, L: MutLayout + RemoveDim> AxisIter<'a, T, L> {
1233    pub(crate) fn new(view: &TensorBase<ViewData<'a, T>, L>, axis: usize) -> AxisIter<'a, T, L> {
1234        assert!(axis < view.ndim());
1235        AxisIter {
1236            view: view.clone(),
1237            axis,
1238            index: 0,
1239            end: view.size(axis),
1240        }
1241    }
1242}
1243
1244impl<'a, T, L: MutLayout + RemoveDim> Iterator for AxisIter<'a, T, L> {
1245    type Item = TensorBase<ViewData<'a, T>, <L as RemoveDim>::Output>;
1246
1247    fn next(&mut self) -> Option<Self::Item> {
1248        if self.index >= self.end {
1249            None
1250        } else {
1251            let slice = self.view.index_axis(self.axis, self.index);
1252            self.index += 1;
1253            Some(slice)
1254        }
1255    }
1256
1257    fn size_hint(&self) -> (usize, Option<usize>) {
1258        let len = self.end - self.index;
1259        (len, Some(len))
1260    }
1261}
1262
1263impl<'a, T, L: MutLayout + RemoveDim> ExactSizeIterator for AxisIter<'a, T, L> {}
1264
1265impl<'a, T, L: MutLayout + RemoveDim> DoubleEndedIterator for AxisIter<'a, T, L> {
1266    fn next_back(&mut self) -> Option<Self::Item> {
1267        if self.index >= self.end {
1268            None
1269        } else {
1270            let slice = self.view.index_axis(self.axis, self.end - 1);
1271            self.end -= 1;
1272            Some(slice)
1273        }
1274    }
1275}
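// Example: `axis_iter(0)` yields one view with the chosen axis removed per
// index along that axis. A sketch:
//
//     let x = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
//     let slices: Vec<_> = x.axis_iter(0).collect();
//     assert_eq!(slices, [x.slice(0), x.slice(1)]);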
1276
1277/// Iterator over mutable slices of a tensor along an axis. See [`TensorViewMut::axis_iter_mut`].
1278pub struct AxisIterMut<'a, T, L: Layout + RemoveDim> {
1279    view: TensorBase<ViewMutData<'a, T>, L>,
1280    axis: usize,
1281    index: usize,
1282    end: usize,
1283}
1284
1285impl<'a, T, L: Layout + RemoveDim + Clone> AxisIterMut<'a, T, L> {
1286    pub(crate) fn new(
1287        view: TensorBase<ViewMutData<'a, T>, L>,
1288        axis: usize,
1289    ) -> AxisIterMut<'a, T, L> {
1290        // See notes in `Layout` about internal overlap.
1291        assert!(
1292            !view.layout().is_broadcast(),
1293            "Cannot mutably iterate over broadcasting view"
1294        );
1295        assert!(axis < view.ndim());
1296        AxisIterMut {
1297            axis,
1298            index: 0,
1299            end: view.size(axis),
1300            view,
1301        }
1302    }
1303}
1304
1305/// Mutable tensor view with one less dimension than `L`.
1306type SmallerMutView<'b, T, L> = TensorBase<ViewMutData<'b, T>, <L as RemoveDim>::Output>;
1307
1308impl<'a, T, L: MutLayout + RemoveDim> Iterator for AxisIterMut<'a, T, L> {
1309    type Item = TensorBase<ViewMutData<'a, T>, <L as RemoveDim>::Output>;
1310
1311    fn next(&mut self) -> Option<Self::Item> {
1312        if self.index >= self.end {
1313            None
1314        } else {
1315            let index = self.index;
1316            self.index += 1;
1317
1318            let slice = self.view.index_axis_mut(self.axis, index);
1319
1320            // Promote lifetime from self -> 'a.
1321            //
1322            // Safety: This is a non-broadcasting view, and we increment the index
1323            // each time, so returned views will not overlap.
1324            let view = unsafe { transmute::<SmallerMutView<'_, T, L>, Self::Item>(slice) };
1325
1326            Some(view)
1327        }
1328    }
1329
1330    fn size_hint(&self) -> (usize, Option<usize>) {
1331        let len = self.end - self.index;
1332        (len, Some(len))
1333    }
1334}
1335
1336impl<'a, T, L: MutLayout + RemoveDim> ExactSizeIterator for AxisIterMut<'a, T, L> {}
1337
1338impl<'a, T, L: MutLayout + RemoveDim> DoubleEndedIterator for AxisIterMut<'a, T, L> {
1339    fn next_back(&mut self) -> Option<Self::Item> {
1340        if self.index >= self.end {
1341            None
1342        } else {
1343            let index = self.end - 1;
1344            self.end -= 1;
1345
1346            let slice = self.view.index_axis_mut(self.axis, index);
1347
1348            // Promote lifetime from self -> 'a.
1349            //
1350            // Safety: This is a non-broadcasting view, and we decrement `end`
1351            // each time, so returned views will not overlap.
1352            let view = unsafe { transmute::<SmallerMutView<'_, T, L>, Self::Item>(slice) };
1353
1354            Some(view)
1355        }
1356    }
1357}
1358
1359/// Iterator over non-overlapping chunks of a tensor along an axis. See
1360/// [`TensorView::axis_chunks`](crate::TensorView::axis_chunks).
1361pub struct AxisChunks<'a, T, L: MutLayout> {
1362    remainder: Option<TensorBase<ViewData<'a, T>, L>>,
1363    axis: usize,
1364    chunk_size: usize,
1365}
1366
1367impl<'a, T, L: MutLayout> AxisChunks<'a, T, L> {
1368    pub(crate) fn new(
1369        view: &TensorBase<ViewData<'a, T>, L>,
1370        axis: usize,
1371        chunk_size: usize,
1372    ) -> AxisChunks<'a, T, L> {
1373        assert!(chunk_size > 0, "chunk size must be > 0");
1374        AxisChunks {
1375            remainder: if view.size(axis) > 0 {
1376                Some(view.view())
1377            } else {
1378                None
1379            },
1380            axis,
1381            chunk_size,
1382        }
1383    }
1384}
1385
1386impl<'a, T, L: MutLayout> Iterator for AxisChunks<'a, T, L> {
1387    type Item = TensorBase<ViewData<'a, T>, L>;
1388
1389    fn next(&mut self) -> Option<Self::Item> {
1390        let remainder = self.remainder.take()?;
1391        let chunk_len = self.chunk_size.min(remainder.size(self.axis));
1392        let (current, next_remainder) = remainder.split_at(self.axis, chunk_len);
1393        self.remainder = if next_remainder.size(self.axis) > 0 {
1394            Some(next_remainder)
1395        } else {
1396            None
1397        };
1398        Some(current)
1399    }
1400
1401    fn size_hint(&self) -> (usize, Option<usize>) {
1402        let len = self
1403            .remainder
1404            .as_ref()
1405            .map(|r| r.size(self.axis))
1406            .unwrap_or(0)
1407            .div_ceil(self.chunk_size);
1408        (len, Some(len))
1409    }
1410}
1411
1412impl<'a, T, L: MutLayout> ExactSizeIterator for AxisChunks<'a, T, L> {}
1413
1414impl<'a, T, L: MutLayout> DoubleEndedIterator for AxisChunks<'a, T, L> {
1415    fn next_back(&mut self) -> Option<Self::Item> {
1416        let remainder = self.remainder.take()?;
1417        let chunk_len = self.chunk_size.min(remainder.size(self.axis));
1418        let (prev_remainder, current) =
1419            remainder.split_at(self.axis, remainder.size(self.axis) - chunk_len);
1420        self.remainder = if prev_remainder.size(self.axis) > 0 {
1421            Some(prev_remainder)
1422        } else {
1423            None
1424        };
1425        Some(current)
1426    }
1427}
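// Example: the axis size does not need to be a multiple of the chunk size; the
// final chunk is simply smaller. A sketch:
//
//     let x = Tensor::<i32>::zeros(&[5, 2]);
//     let sizes: Vec<usize> = x.axis_chunks(0, 2).map(|c| c.size(0)).collect();
//     assert_eq!(sizes, [2, 2, 1]);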
1428
1429/// Iterator over non-overlapping mutable chunks of a tensor along an axis. See [`TensorViewMut::axis_chunks_mut`].
1430pub struct AxisChunksMut<'a, T, L: MutLayout> {
1431    remainder: Option<TensorBase<ViewMutData<'a, T>, L>>,
1432    axis: usize,
1433    chunk_size: usize,
1434}
1435
1436impl<'a, T, L: MutLayout> AxisChunksMut<'a, T, L> {
1437    pub(crate) fn new(
1438        view: TensorBase<ViewMutData<'a, T>, L>,
1439        axis: usize,
1440        chunk_size: usize,
1441    ) -> AxisChunksMut<'a, T, L> {
1442        // See notes in `Layout` about internal overlap.
1443        assert!(
1444            !view.layout().is_broadcast(),
1445            "Cannot mutably iterate over broadcasting view"
1446        );
1447        assert!(chunk_size > 0, "chunk size must be > 0");
1448        AxisChunksMut {
1449            remainder: if view.size(axis) > 0 {
1450                Some(view)
1451            } else {
1452                None
1453            },
1454            axis,
1455            chunk_size,
1456        }
1457    }
1458}
1459
1460impl<'a, T, L: MutLayout> Iterator for AxisChunksMut<'a, T, L> {
1461    type Item = TensorBase<ViewMutData<'a, T>, L>;
1462
1463    fn next(&mut self) -> Option<Self::Item> {
1464        let remainder = self.remainder.take()?;
1465        let chunk_len = self.chunk_size.min(remainder.size(self.axis));
1466        let (current, next_remainder) = remainder.split_at_mut(self.axis, chunk_len);
1467        self.remainder = if next_remainder.size(self.axis) > 0 {
1468            Some(next_remainder)
1469        } else {
1470            None
1471        };
1472        Some(current)
1473    }
1474
1475    fn size_hint(&self) -> (usize, Option<usize>) {
1476        let len = self
1477            .remainder
1478            .as_ref()
1479            .map(|r| r.size(self.axis))
1480            .unwrap_or(0)
1481            .div_ceil(self.chunk_size);
1482        (len, Some(len))
1483    }
1484}
1485
1486impl<'a, T, L: MutLayout> ExactSizeIterator for AxisChunksMut<'a, T, L> {}
1487
1488impl<'a, T, L: MutLayout> DoubleEndedIterator for AxisChunksMut<'a, T, L> {
1489    fn next_back(&mut self) -> Option<Self::Item> {
1490        let remainder = self.remainder.take()?;
1491        let remainder_size = remainder.size(self.axis);
1492        let chunk_len = self.chunk_size.min(remainder_size);
1493        let (prev_remainder, current) =
1494            remainder.split_at_mut(self.axis, remainder_size - chunk_len);
1495        self.remainder = if prev_remainder.size(self.axis) > 0 {
1496            Some(prev_remainder)
1497        } else {
1498            None
1499        };
1500        Some(current)
1501    }
1502}
1503
1504/// Call `f` on each element of `view`.
1505pub(crate) fn for_each_mut<T, F: Fn(&mut T)>(mut view: TensorViewMut<T>, f: F) {
1506    while view.ndim() < 4 {
1507        view.insert_axis(0);
1508    }
1509
1510    // This could be improved by sorting dimensions of `view` in order of
1511    // decreasing stride. If the resulting view is contiguous, `f` can be
1512    // applied to the underlying data directly. Even if it isn't, this will
1513    // still make memory access as contiguous as possible.
1514
1515    view.inner_iter_mut::<4>().for_each(|mut src| {
1516        for i0 in 0..src.size(0) {
1517            for i1 in 0..src.size(1) {
1518                for i2 in 0..src.size(2) {
1519                    for i3 in 0..src.size(3) {
1520                        // Safety: i0..i3 are in `[0, src.size(i))`.
1521                        let x = unsafe { src.get_unchecked_mut([i0, i1, i2, i3]) };
1522                        f(x);
1523                    }
1524                }
1525            }
1526        }
1527    });
1528}
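// Note on the padding step above: a 2D view becomes shape [1, 1, rows, cols],
// so the fixed 4-level loop handles any rank up to 4 directly, while tensors
// of higher rank are handled by `inner_iter_mut::<4>` yielding one 4D view per
// combination of leading indices.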
1529
1530// Tests for iterator internals. Most tests of iterators are currently done via
1531// tests on tensor methods.
1532#[cfg(test)]
1533mod tests {
1534    use super::{AxisChunks, AxisChunksMut, Lanes, LanesMut};
1535    use crate::{AsView, Layout, NdLayout, NdTensor, Tensor};
1536
1537    fn compare_reversed<T: PartialEq + std::fmt::Debug>(fwd: &[T], rev: &[T]) {
1538        assert_eq!(fwd.len(), rev.len());
1539        for (x, y) in fwd.iter().zip(rev.iter().rev()) {
1540            assert_eq!(x, y);
1541        }
1542    }
1543
1544    /// Apply a standard set of tests to an iterator.
1545    fn test_iterator<I: Iterator + ExactSizeIterator + DoubleEndedIterator>(
1546        create_iter: impl Fn() -> I,
1547        expected: &[I::Item],
1548    ) where
1549        I::Item: PartialEq + std::fmt::Debug,
1550    {
1551        let iter = create_iter();
1552
1553        let (min_len, max_len) = iter.size_hint();
1554        let items: Vec<_> = iter.collect();
1555
1556        assert_eq!(&items, expected);
1557
1558        // Test ExactSizeIterator via `size_hint`.
1559        assert_eq!(min_len, items.len(), "incorrect size lower bound");
1560        assert_eq!(max_len, Some(items.len()), "incorrect size upper bound");
1561
1562        // Test DoubleEndedIterator via `rev`.
1563        let rev_items: Vec<_> = create_iter().rev().collect();
1564        compare_reversed(&items, &rev_items);
1565
1566        // Test FusedIterator.
1567        let mut iter = create_iter();
1568        for _x in &mut iter { /* noop */ }
1569        assert_eq!(iter.next(), None);
1570
1571        // Test fold.
1572        let mut fold_items = Vec::new();
1573        let mut idx = 0;
1574        create_iter().fold(0, |acc, item| {
1575            assert_eq!(acc, idx);
1576            fold_items.push(item);
1577            idx += 1;
1578            idx
1579        });
1580        assert_eq!(items, fold_items);
1581    }
1582
1583    /// A collection that can be mutably iterated over multiple times.
1584    ///
1585    /// We use a different pattern for testing mutable iterators to avoid
1586    /// restrictions on values returned from `FnMut` closures.
1587    trait MutIterable {
1588        type Iter<'a>: Iterator + ExactSizeIterator + DoubleEndedIterator
1589        where
1590            Self: 'a;
1591
1592        fn iter_mut(&mut self) -> Self::Iter<'_>;
1593    }
1594
1595    /// Apply a standard set of tests to a mutable iterator.
1596    fn test_mut_iterator<M, T>(mut iterable: M, expected: &[T])
1597    where
1598        M: MutIterable,
1599        T: std::fmt::Debug,
1600        for<'a> <M::Iter<'a> as Iterator>::Item: std::fmt::Debug + PartialEq + PartialEq<T>,
1601    {
1602        // Test Iterator and ExactSizeIterator.
1603        {
1604            let iter = iterable.iter_mut();
1605            let (min_len, max_len) = iter.size_hint();
1606            let items: Vec<_> = iter.collect();
1607
1608            // Test `next`
1609            assert_eq!(items, expected);
1610
1611            // Test `size_hint`
1612            assert_eq!(min_len, items.len(), "incorrect size lower bound");
1613            assert_eq!(max_len, Some(items.len()), "incorrect size upper bound");
1614        }
1615
1616        // Test FusedIterator.
1617        {
1618            let mut iter = iterable.iter_mut();
1619            for _x in &mut iter { /* noop */ }
1620            assert!(iter.next().is_none());
1621        }
1622
1623        // Test DoubleEndedIterator via `rev`.
1624        //
1625        // We use `format!` here to convert mutable references into comparable
1626        // items that have no connection to the mutable references yielded by
1627        // the iterator. Ideally this should be replaced by a clone or something.
1628        {
1629            let items: Vec<_> = iterable.iter_mut().map(|x| format!("{:?}", x)).collect();
1630            let rev_items: Vec<_> = iterable
1631                .iter_mut()
1632                .rev()
1633                .map(|x| format!("{:?}", x))
1634                .collect();
1635            compare_reversed(&items, &rev_items);
1636        }
1637
1638        // Test fold.
1639        {
1640            let items: Vec<_> = iterable.iter_mut().map(|x| format!("{:?}", x)).collect();
1641            let mut fold_items = Vec::new();
1642            let mut idx = 0;
1643            iterable.iter_mut().fold(0, |acc, item| {
1644                assert_eq!(acc, idx);
1645                fold_items.push(format!("{:?}", item));
1646                idx += 1;
1647                idx
1648            });
1649            assert_eq!(items, fold_items);
1650        }
1651    }
1652
1653    #[test]
1654    fn test_axis_chunks() {
1655        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1656        test_iterator(
1657            || tensor.axis_chunks(0, 1),
1658            &[tensor.slice(0..1), tensor.slice(1..2)],
1659        );
1660    }
1661
1662    #[test]
1663    fn test_axis_chunks_empty() {
1664        let x = Tensor::<i32>::zeros(&[5, 0]);
1665        assert!(AxisChunks::new(&x.view(), 1, 1).next().is_none());
1666    }
1667
1668    #[test]
1669    #[should_panic(expected = "chunk size must be > 0")]
1670    fn test_axis_chunks_zero_size() {
1671        let x = Tensor::<i32>::zeros(&[5, 0]);
1672        assert!(AxisChunks::new(&x.view(), 1, 0).next().is_none());
1673    }
1674
1675    #[test]
1676    fn test_axis_chunks_mut_empty() {
1677        let mut x = Tensor::<i32>::zeros(&[5, 0]);
1678        assert!(AxisChunksMut::new(x.view_mut(), 1, 1).next().is_none());
1679    }
1680
1681    #[test]
1682    fn test_axis_chunks_mut_rev() {
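        // Reversed mutable iteration should yield the same chunk contents as
        // forward iteration, in reverse order (verified via `compare_reversed`).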
1683        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1684        let fwd: Vec<_> = tensor
1685            .axis_chunks_mut(0, 1)
1686            .map(|view| view.to_vec())
1687            .collect();
1688        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1689        let rev: Vec<_> = tensor
1690            .axis_chunks_mut(0, 1)
1691            .rev()
1692            .map(|view| view.to_vec())
1693            .collect();
1694        compare_reversed(&fwd, &rev);
1695    }
1696
1697    #[test]
1698    #[should_panic(expected = "chunk size must be > 0")]
1699    fn test_axis_chunks_mut_zero_size() {
1700        let mut x = Tensor::<i32>::zeros(&[5, 0]);
1701        assert!(AxisChunksMut::new(x.view_mut(), 1, 0).next().is_none());
1702    }
1703
1704    #[test]
1705    fn test_axis_iter() {
1706        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
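        // `axis_iter(0)` should yield one view per index along axis 0, matching
        // `tensor.slice(0)` and `tensor.slice(1)`.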
1707        test_iterator(|| tensor.axis_iter(0), &[tensor.slice(0), tensor.slice(1)]);
1708    }
1709
1710    #[test]
1711    fn test_axis_iter_mut_rev() {
1712        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1713        let fwd: Vec<_> = tensor.axis_iter_mut(0).map(|view| view.to_vec()).collect();
1714        let mut tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1715        let rev: Vec<_> = tensor
1716            .axis_iter_mut(0)
1717            .rev()
1718            .map(|view| view.to_vec())
1719            .collect();
1720        compare_reversed(&fwd, &rev);
1721    }
1722
1723    #[test]
1724    fn test_inner_iter() {
1725        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
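        // `inner_iter::<2>()` should yield a view over the innermost two
        // dimensions for each index along the outer axis.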
1726        test_iterator(
1727            || tensor.inner_iter::<2>(),
1728            &[tensor.slice(0), tensor.slice(1)],
1729        );
1730    }
1731
1732    #[test]
1733    fn test_inner_iter_mut() {
1734        struct InnerIterMutTest(NdTensor<i32, 3>);
1735
1736        impl MutIterable for InnerIterMutTest {
1737            type Iter<'a> = super::InnerIterMut<'a, i32, NdLayout<2>>;
1738
1739            fn iter_mut(&mut self) -> Self::Iter<'_> {
1740                self.0.inner_iter_mut::<2>()
1741            }
1742        }
1743
1744        let tensor = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
1745        test_mut_iterator(
1746            InnerIterMutTest(tensor.clone()),
1747            &[tensor.slice(0), tensor.slice(1)],
1748        );
1749    }
1750
1751    #[test]
1752    fn test_lanes() {
1753        let x = NdTensor::from([[1, 2], [3, 4]]);
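        // Lanes along axis 0 are the columns of the matrix; lanes along axis 1
        // are its rows.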
1754        test_iterator(
1755            || x.lanes(0),
1756            &[x.slice((.., 0)).into(), x.slice((.., 1)).into()],
1757        );
1758        test_iterator(|| x.lanes(1), &[x.slice(0).into(), x.slice(1).into()]);
1759    }
1760
1761    #[test]
1762    fn test_lanes_empty() {
1763        let x = Tensor::<i32>::zeros(&[5, 0]);
1764        assert!(Lanes::new(x.view().view_ref(), 0).next().is_none());
1765        assert!(Lanes::new(x.view().view_ref(), 1).next().is_none());
1766    }
1767
1768    #[test]
1769    fn test_lanes_mut() {
1770        use super::Lane;
1771
1772        struct LanesMutTest(NdTensor<i32, 2>);
1773
1774        impl MutIterable for LanesMutTest {
1775            type Iter<'a> = super::LanesMut<'a, i32>;
1776
1777            fn iter_mut(&mut self) -> Self::Iter<'_> {
1778                self.0.lanes_mut(0)
1779            }
1780        }
1781
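        // Mutable lanes along axis 0 should match the columns of the matrix.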
1782        let tensor = NdTensor::from([[1, 2], [3, 4]]);
1783        test_mut_iterator::<_, Lane<i32>>(
1784            LanesMutTest(tensor.clone()),
1785            &[
1786                Lane::from(tensor.slice((.., 0))),
1787                Lane::from(tensor.slice((.., 1))),
1788            ],
1789        );
1790    }
1791
1792    #[test]
1793    fn test_lane_as_slice() {
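        // `Lane::as_slice` should return the remaining elements as a slice when
        // they are contiguous in memory, and `None` otherwise.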
1794        // Contiguous lane
1795        let x = NdTensor::from([0, 1, 2]);
1796        let mut lane = x.lanes(0).next().unwrap();
1797        assert_eq!(lane.as_slice(), Some([0, 1, 2].as_slice()));
1798        lane.next();
1799        assert_eq!(lane.as_slice(), Some([1, 2].as_slice()));
1800        lane.next();
1801        lane.next();
1802        assert_eq!(lane.as_slice(), Some([0i32; 0].as_slice()));
1803        lane.next();
1804        assert_eq!(lane.as_slice(), Some([0i32; 0].as_slice()));
1805
1806        // Non-contiguous lane
1807        let x = NdTensor::from([[1i32, 2], [3, 4]]);
1808        let lane = x.lanes(0).next().unwrap();
1809        assert_eq!(lane.as_slice(), None);
1810    }
1811
1812    #[test]
1813    fn test_lanes_mut_empty() {
1814        let mut x = Tensor::<i32>::zeros(&[5, 0]);
1815        assert!(LanesMut::new(x.mut_view_ref(), 0).next().is_none());
1816        assert!(LanesMut::new(x.mut_view_ref(), 1).next().is_none());
1817    }
1818
1819    #[test]
1820    fn test_iter_step_by() {
1821        let tensor = Tensor::<f32>::full(&[1, 3, 16, 8], 1.);
1822
1823        // Take a non-contiguous slice so we don't use the fast path for
1824        // contiguous tensors.
1825        let tensor = tensor.slice((.., .., 1.., ..));
1826
1827        let sum = tensor.iter().sum::<f32>();
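        // Every element is 1, so skipping `n_skip` elements must reduce the sum
        // by exactly `n_skip`.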
1828        for n_skip in 0..tensor.len() {
1829            let sum_skip = tensor.iter().skip(n_skip).sum::<f32>();
1830            assert_eq!(
1831                sum_skip,
1832                sum - n_skip as f32,
1833                "wrong sum for n_skip={}",
1834                n_skip
1835            );
1836        }
1837    }
1838
1839    #[test]
1840    fn test_iter_broadcast() {
1841        let tensor = Tensor::<f32>::full(&[1], 1.);
1842        let broadcast = tensor.broadcast([1, 3, 16, 8]);
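        // Broadcasting a one-element tensor repeats its single element, so the
        // iterator's length and count equal `broadcast.len()`, and summing the
        // repeated 1.0 values also gives `broadcast.len()`.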
1843        assert_eq!(broadcast.iter().len(), broadcast.len());
1844        let count = broadcast.iter().count();
1845        assert_eq!(count, broadcast.len());
1846        let sum = broadcast.iter().sum::<f32>();
1847        assert_eq!(sum, broadcast.len() as f32);
1848    }
1849
1850    #[test]
1851    fn test_iter() {
1852        let tensor = NdTensor::from([[[1, 2], [3, 4]]]);
1853
1854        // Test iterator over contiguous tensor.
1855        test_iterator(|| tensor.iter().copied(), &[1, 2, 3, 4]);
1856
1857        // Test iterator over non-contiguous tensor.
1858        test_iterator(|| tensor.transposed().iter().copied(), &[1, 3, 2, 4]);
1859    }
1860
1861    #[test]
1862    fn test_iter_mut() {
1863        struct IterTest(NdTensor<i32, 3>);
1864
1865        impl MutIterable for IterTest {
1866            type Iter<'a> = super::IterMut<'a, i32>;
1867
1868            fn iter_mut(&mut self) -> Self::Iter<'_> {
1869                self.0.iter_mut()
1870            }
1871        }
1872
1873        let tensor = NdTensor::from([[[1, 2], [3, 4]]]);
1874        test_mut_iterator(IterTest(tensor), &[&1, &2, &3, &4]);
1875    }
1876
1877    #[test]
1878    #[ignore]
1879    fn bench_iter() {
1880        use crate::Layout;
1881        use rten_bench::run_bench;
1882
1883        type Elem = i32;
1884
1885        let tensor = std::hint::black_box(Tensor::<Elem>::full(&[1, 6, 768, 64], 1));
1886        let n_trials = 1000;
1887        let mut result = Elem::default();
1888
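        // Sum via `fold` so the same reduction works for any iterator type;
        // `wrapping_add` avoids integer overflow panics in debug builds.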
1889        fn reduce<'a>(iter: impl Iterator<Item = &'a Elem>) -> Elem {
1890            iter.fold(Elem::default(), |acc, x| acc.wrapping_add(*x))
1891        }
1892
1893        // Iterate directly over data slice.
1894        run_bench(n_trials, Some("slice iter"), || {
1895            result = reduce(tensor.data().unwrap().iter());
1896        });
1897        println!("sum {}", result);
1898
1899        // Use the tensor iterator with a contiguous tensor. This uses the fast
1900        // path, which wraps a slice iterator.
1901        run_bench(n_trials, Some("contiguous iter"), || {
1902            result = reduce(tensor.iter());
1903        });
1904        println!("sum {}", result);
1905
1906        run_bench(n_trials, Some("contiguous reverse iter"), || {
1907            result = reduce(tensor.iter().rev());
1908        });
1909        println!("sum {}", result);
1910
1911        // Use the tensor iterator with a non-contiguous slice. This falls back
1912        // to indexed iteration.
1913        let slice = tensor.slice((.., .., 1.., ..));
1914        assert!(!slice.is_contiguous());
1915        let n_trials = 1000;
1916        run_bench(n_trials, Some("non-contiguous iter"), || {
1917            result = reduce(slice.iter());
1918        });
1919        println!("sum {}", result);
1920
1921        // Reverse iteration over a non-contiguous slice. This is much slower
1922        // because each linear index must be translated into an offset using division.
1923        let n_trials = 100;
1924        run_bench(n_trials, Some("non-contiguous reverse iter"), || {
1925            result = reduce(slice.iter().rev());
1926        });
1927        println!("sum {}", result);
1928    }
1929
1930    #[test]
1931    #[ignore]
1932    fn bench_inner_iter() {
1933        use crate::rng::XorShiftRng;
1934        use rten_bench::run_bench;
1935
1936        let n_trials = 100;
1937        let mut rng = XorShiftRng::new(1234);
1938
1939        // Tensor with many steps along the outer two dimensions relative to the
1940        // steps along the inner two dimensions. This emphasizes the overhead of
1941        // stepping `inner_iter`.
1942        let tensor = Tensor::<f32>::rand(&[512, 512, 12, 1], &mut rng);
1943
1944        let mut sum = 0.;
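        // Sum every element by indexing into each [12, 1] inner view.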
1945        run_bench(n_trials, Some("inner iter"), || {
1946            for inner in tensor.inner_iter::<2>() {
1947                for i0 in 0..inner.size(0) {
1948                    for i1 in 0..inner.size(1) {
1949                        sum += inner[[i0, i1]];
1950                    }
1951                }
1952            }
1953        });
1954        println!("sum {}", sum);
1955    }
1956}