rten_tensor/
layout.rs

1//! Layouts which describe the shape and strides of a tensor.
2
3use std::iter::repeat;
4use std::ops::Range;
5
6use smallvec::{SmallVec, smallvec};
7
8use crate::errors::{DimensionError, ExpandError, FromDataError, ReshapeError, SliceError};
9use crate::index_iterator::{DynIndices, NdIndices};
10use crate::overlap::{is_contiguous, may_have_internal_overlap};
11use crate::slice_range::{IntoSliceItems, SliceItem};
12use crate::type_num::{OptionalUInt, U0, U1, U2, U3, U4, U5, Unknown};
13
/// Check whether `permutation` reorders the dimensions of a rank-`ndim`
/// tensor.
///
/// A valid permutation has exactly `ndim` entries, and every dimension index
/// in `0..ndim` appears in it exactly once.
pub fn is_valid_permutation(ndim: usize, permutation: &[usize]) -> bool {
    if permutation.len() != ndim {
        return false;
    }
    (0..ndim).all(|axis| {
        let occurrences = permutation.iter().filter(|&&d| d == axis).count();
        occurrences == 1
    })
}
20
21/// Merge dimensions of a layout where possible.
22///
23/// Two dimensions with sizes N1, N2 and strides S1, S2 can be merged into a
24/// single dimension with size N1 * N2 and stride S2 if S1 = N1 * S2;
25///
26/// Returns a vector of `(size, stride)` tuples for the merged dimensions.
27pub(crate) fn merge_axes(shape: &[usize], strides: &[usize]) -> SmallVec<[(usize, usize); 4]> {
28    let (Some(prev_size), Some(prev_stride)) = (shape.last(), strides.last()) else {
29        return SmallVec::new();
30    };
31
32    let mut merged: SmallVec<[(usize, usize); 4]> = SmallVec::with_capacity(shape.len());
33    merged.push((*prev_size, *prev_stride));
34
35    for (&outer_size, &outer_stride) in shape.iter().zip(strides.iter()).rev().skip(1) {
36        let (inner_size, inner_stride) = merged.last_mut().unwrap();
37        let can_merge = outer_size == 1 || (outer_stride == *inner_stride * *inner_size);
38        if can_merge {
39            *inner_size *= outer_size;
40        } else {
41            merged.push((outer_size, outer_stride));
42        }
43    }
44
45    merged.reverse();
46
47    merged
48}
49
/// Emit a `debug_assert!` checking that `$axis` is a valid dimension index
/// for `$subject`, i.e. that it is less than `$subject.ndim()`.
///
/// Like any `debug_assert!`, the check and the evaluation of its arguments are
/// skipped when debug assertions are disabled.
macro_rules! debug_assert_dim_valid {
    ($subject:ident, $axis:expr) => {
        debug_assert!(
            $axis < $subject.ndim(),
            "dim {} out of bounds for tensor with {} dims",
            $axis,
            $subject.ndim()
        )
    };
}
61
62/// Describes the shape and strides of a tensor.
63///
64/// The `Layout` trait provides methods to query the shape of a tensor, i.e. the
65/// number of dimensions and size of each, and the strides which determine the
66/// mapping between logical indices in the tensor and offsets in the data storage.
67///
68/// This trait is implemented for tensor types
69/// ([`TensorBase`](crate::TensorBase)), as well as the underlying layout types
70/// such as [`NdLayout`] (for tensors with a static number of dimensions) and
71/// [`DynLayout`] (for tensors with a dynamic number of dimensions).
72pub trait Layout {
73    /// Type used to represent indices.
74    ///
75    /// It is assumed that this type can also represent the shape and strides
76    /// of the tensor.
77    type Index<'a>: AsRef<[usize]> + Clone + std::fmt::Debug + PartialEq<Self::Index<'a>>;
78
79    /// Iterator over indices in this tensor.
80    type Indices;
81
82    /// Map an index to a storage offset, without checking if it is valid for
83    /// the tensor's shape.
84    ///
85    /// This method is not itself unsafe, because it only computes a storage
86    /// offset but does not access any data. Using the offset to index into
87    /// storage without a bounds check is unsafe however.
88    fn offset_unchecked(&self, index: Self::Index<'_>) -> usize {
89        index
90            .as_ref()
91            .iter()
92            .zip(self.strides().as_ref())
93            .map(|(idx, stride)| *idx * *stride)
94            .sum()
95    }
96
97    /// Map an index to a storage offset, or return `None` if the index is out
98    /// of bounds along any dimension.
99    ///
100    /// Offsets returned by this method must be less than the layout's minimum
101    /// storage length reported by [`min_data_len`](Layout::min_data_len).
102    /// If a layout also implements [`TrustedLayout`] then callers can rely
103    /// on this to avoid subsequent bounds checks.
104    fn offset(&self, index: Self::Index<'_>) -> Option<usize>;
105
106    /// Return the number of dimensions.
107    fn ndim(&self) -> usize;
108
109    /// Returns the number of elements in the array.
110    fn len(&self) -> usize;
111
112    /// Return true if this layout describes a contiguous tensor, where the
113    /// logical order of elements matches the order in which they are stored.
114    fn is_contiguous(&self) -> bool {
115        is_contiguous(self.shape(), self.strides())
116    }
117
118    /// Return true if iterating over elements in this layout will visit
119    /// elements multiple times.
120    fn is_broadcast(&self) -> bool {
121        !self.is_empty() && self.strides().as_ref().contains(&0)
122    }
123
124    /// Returns true if the array has no elements.
125    fn is_empty(&self) -> bool {
126        self.len() == 0
127    }
128
129    /// Returns an array of the sizes of each dimension.
130    fn shape(&self) -> Self::Index<'_>;
131
132    /// Returns the size of the dimension `dim`.
133    fn size(&self, dim: usize) -> usize {
134        debug_assert_dim_valid!(self, dim);
135        self.shape().as_ref()[dim]
136    }
137
138    /// Returns an array of the strides of each dimension.
139    fn strides(&self) -> Self::Index<'_>;
140
141    /// Returns the offset between adjacent indices along dimension `dim`.
142    fn stride(&self, dim: usize) -> usize {
143        debug_assert_dim_valid!(self, dim);
144        self.strides().as_ref()[dim]
145    }
146
147    /// Return an iterator over all valid indices in this tensor.
148    fn indices(&self) -> Self::Indices;
149
150    /// Return true if this layout's shape can be broadcast to the given shape.
151    fn can_broadcast_to(&self, target_shape: &[usize]) -> bool {
152        if self.shape().as_ref() == target_shape {
153            return true;
154        } else if self.ndim() > target_shape.len() {
155            return false;
156        }
157
158        // For two shapes to be compatible for broadcasting, each dimension must
159        // either be the same or be 1.
160        //
161        // If the tensor has fewer dimensions, pretend that it was prefixed with
162        // 1-length dimensions to make the dimension counts equal.
163        let target_dims = target_shape[target_shape.len() - self.shape().as_ref().len()..]
164            .iter()
165            .copied();
166
167        self.shape()
168            .as_ref()
169            .iter()
170            .copied()
171            .zip(target_dims)
172            .all(|(a, b)| a == b || a == 1)
173    }
174
175    /// Return true if the tensor/view can be broadcast with another tensor or
176    /// view with a given `shape` as part of a binary operation.
177    ///
178    /// The shape of the result may be larger than either the current shape
179    /// or `shape`. eg. If a tensor of shape `[1, 5]` is broadcast with one
180    /// of size `[2, 1, 1]` the result has shape `[2, 1, 5]`.
181    ///
182    /// See <https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md> for
183    /// conditions in which broadcasting is allowed.
184    fn can_broadcast_with(&self, shape: &[usize]) -> bool {
185        if self.shape().as_ref() == shape {
186            return true;
187        }
188
189        // For two shapes to be compatible for broadcasting, each dimension must
190        // either be the same or be 1.
191        //
192        // If the tensor has fewer dimensions, pretend that it was prefixed with
193        // 1-length dimensions to make the dimension counts equal.
194
195        let current_shape = self.shape();
196        let a = current_shape.as_ref();
197        let b = shape;
198
199        let a_pad = b.len().saturating_sub(a.len());
200        let b_pad = a.len().saturating_sub(b.len());
201
202        let a_iter = a.iter().copied().rev().chain(repeat(1).take(a_pad));
203        let b_iter = b.iter().copied().rev().chain(repeat(1).take(b_pad));
204
205        a_iter.zip(b_iter).all(|(a, b)| a == b || a == 1 || b == 1)
206    }
207
208    /// Return the minimum length required for the element data buffer used
209    /// with this layout.
210    fn min_data_len(&self) -> usize {
211        if self.shape().as_ref().contains(&0) {
212            return 0;
213        }
214        let max_offset: usize = self
215            .shape()
216            .as_ref()
217            .iter()
218            .zip(self.strides().as_ref())
219            .map(|(size, stride)| (size - 1) * stride)
220            .sum();
221        max_offset + 1
222    }
223}
224
225/// A layout which upholds guarantees on returned storage offsets.
226///
227/// # Safety
228///
229/// Layouts which implement this trait promise that any offsets returned by
230/// [`offset`](Layout::offset) and
231/// [`offset_unchecked`](Layout::offset_unchecked) are less than the than the
232/// minimum required storage length reported by
233/// [`min_data_len`](Layout::min_data_len). This promise means that the offsets
234/// can be used to access elements in a buffer without a bounds check.
235pub unsafe trait TrustedLayout: Layout {}
236
237/// Extension methods for layouts.
238///
239/// These are separate from the [`Layout`] trait to prevent them from being
240/// overridden.
241pub(crate) trait LayoutExt: Layout {
242    /// Return the offset for an index or panic if invalid.
243    #[inline]
244    fn must_offset(&self, index: Self::Index<'_>) -> usize {
245        self.offset(index.clone()).unwrap_or_else(|| {
246            panic!(
247                "index {:?} out of bounds for shape {:?}",
248                index.as_ref(),
249                self.shape().as_ref()
250            )
251        })
252    }
253}
254
255impl<L: Layout> LayoutExt for L {}
256
/// Provides convenience methods for querying the shape and strides of a matrix.
pub trait MatrixLayout {
    /// Return the number of rows in the matrix.
    fn rows(&self) -> usize;

    /// Return the number of columns in the matrix.
    fn cols(&self) -> usize;

    /// Return the storage offset between the starts of consecutive rows.
    fn row_stride(&self) -> usize;

    /// Return the storage offset between consecutive columns within a row.
    fn col_stride(&self) -> usize;
}
264
/// Specifies whether a tensor or view may have an overlapping layout.
///
/// An overlapping layout is one in which multiple valid indices map to the same
/// offset in storage. To comply with Rust's rules for mutable aliases, mutable
/// tensors/views must disallow overlap.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OverlapPolicy {
    /// Accept layouts that may map several indices to the same offset. Only
    /// suitable for immutable tensors/views.
    AllowOverlap,
    /// Reject layouts that may map several indices to the same offset.
    /// Required for mutable tensors/views.
    DisallowOverlap,
}
274
/// Defines the valid indices for an N-dimensional array and how to map them
/// to offsets in a linear buffer, where N is known at compile time.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct NdLayout<const N: usize> {
    /// Size of each dimension.
    shape: [usize; N],
    /// Offset between consecutive indices along each dimension.
    strides: [usize; N],
}
282
283impl<const N: usize> Layout for NdLayout<N> {
284    type Index<'a> = [usize; N];
285    type Indices = NdIndices<N>;
286
287    fn ndim(&self) -> usize {
288        N
289    }
290
291    fn len(&self) -> usize {
292        self.shape.iter().product()
293    }
294
295    #[inline]
296    fn offset(&self, index: [usize; N]) -> Option<usize> {
297        if !self.index_valid(index) {
298            return None;
299        }
300        Some(self.offset_unchecked(index))
301    }
302
303    #[inline]
304    fn offset_unchecked(&self, index: [usize; N]) -> usize {
305        let mut offset = 0;
306        for i in 0..N {
307            offset += index[i] * self.strides[i];
308        }
309        offset
310    }
311
312    #[inline]
313    fn shape(&self) -> Self::Index<'_> {
314        self.shape
315    }
316
317    #[inline]
318    fn strides(&self) -> Self::Index<'_> {
319        self.strides
320    }
321
322    fn indices(&self) -> Self::Indices {
323        NdIndices::from_shape(self.shape)
324    }
325}
326
327unsafe impl<const N: usize> TrustedLayout for NdLayout<N> {}
328
329impl<L: Layout> Layout for &L {
330    type Index<'b> = L::Index<'b>;
331    type Indices = L::Indices;
332
333    fn ndim(&self) -> usize {
334        (*self).ndim()
335    }
336
337    fn len(&self) -> usize {
338        (*self).len()
339    }
340
341    fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
342        (*self).offset(index)
343    }
344
345    fn offset_unchecked(&self, index: Self::Index<'_>) -> usize {
346        (*self).offset_unchecked(index)
347    }
348
349    fn shape(&self) -> Self::Index<'_> {
350        (*self).shape()
351    }
352
353    fn strides(&self) -> Self::Index<'_> {
354        (*self).strides()
355    }
356
357    fn indices(&self) -> Self::Indices {
358        (*self).indices()
359    }
360}
361
362// The `Layout` impl for references proxies to the target, so upholds invariants
363// if the target does.
364unsafe impl<L: TrustedLayout> TrustedLayout for &L {}
365
366impl MatrixLayout for NdLayout<2> {
367    #[inline]
368    fn rows(&self) -> usize {
369        self.size(0)
370    }
371
372    #[inline]
373    fn cols(&self) -> usize {
374        self.size(1)
375    }
376
377    #[inline]
378    fn row_stride(&self) -> usize {
379        self.stride(0)
380    }
381
382    #[inline]
383    fn col_stride(&self) -> usize {
384        self.stride(1)
385    }
386}
387
388/// Compute the shape and strides of a layout after slicing with `range`.
389///
390/// Returns an `(ndim, offset)` tuple for the number of dimensions in the
391/// slice and the offset of the first element in the parent view's data.
392///
393/// This function is generic to allow for specialized variants to be generated
394/// when slicing with statically known input or output shape sizes.
395fn slice_layout<I: AsRef<[usize]>, O: AsMut<[usize]>>(
396    in_shape: I,
397    in_strides: I,
398    mut out_shape: O,
399    mut out_strides: O,
400    range: &[SliceItem],
401) -> Result<(usize, usize), SliceError> {
402    let in_shape = in_shape.as_ref();
403    let in_strides = in_strides.as_ref();
404    let out_shape = out_shape.as_mut();
405    let out_strides = out_strides.as_mut();
406
407    let mut ndim = 0;
408    let mut offset = 0;
409
410    for (in_dim, (&size, &stride)) in in_shape.iter().zip(in_strides.iter()).enumerate() {
411        let (offset_adjust, new_size_stride) = match range.get(in_dim) {
412            Some(&SliceItem::Index(idx)) => {
413                let pos_idx = if idx >= 0 { idx } else { idx + size as isize };
414                if pos_idx < 0 || pos_idx >= size as isize {
415                    return Err(SliceError::InvalidIndex {
416                        axis: in_dim,
417                        index: idx,
418                        size,
419                    });
420                }
421                (stride * pos_idx as usize, None)
422            }
423            Some(SliceItem::Range(range)) => {
424                let resolved = range.resolve(size).ok_or(SliceError::InvalidRange {
425                    axis: in_dim,
426                    range: *range,
427                    size,
428                })?;
429                let step: usize = range
430                    .step()
431                    .try_into()
432                    .map_err(|_| SliceError::InvalidStep {
433                        axis: in_dim,
434                        step: range.step(),
435                    })?;
436                let new_size = if step == 1 {
437                    // Fast path when no custom step is used.
438                    resolved.end - resolved.start
439                } else {
440                    range.index_range(size).steps()
441                };
442                let new_stride = stride * step;
443                (stride * resolved.start, Some((new_size, new_stride)))
444            }
445            None => (0, Some((size, stride))),
446        };
447
448        offset += offset_adjust;
449        if let Some((new_size, new_stride)) = new_size_stride {
450            out_shape[ndim] = new_size;
451            out_strides[ndim] = new_stride;
452            ndim += 1;
453        }
454    }
455
456    if out_shape.contains(&0) {
457        offset = 0;
458    }
459
460    Ok((ndim, offset))
461}
462
/// Return an iterator over the strides of a layout that broadcasts a view
/// with shape `from_shape` and strides `from_strides` to `to_shape`.
///
/// `from_shape` is aligned with the trailing dimensions of `to_shape`.
/// Leading dimensions that `from_shape` lacks get stride 0. So does any
/// size-1 dimension that is stretched to a size greater than one. All other
/// dimensions keep their original stride.
///
/// Panics if `from_shape` has more dimensions than `to_shape`.
fn broadcast_strides<'a>(
    from_shape: &'a [usize],
    from_strides: &'a [usize],
    to_shape: &'a [usize],
) -> impl Iterator<Item = usize> + 'a {
    let lead = to_shape.len() - from_shape.len();
    let (padded_dims, aligned_target) = to_shape.split_at(lead);

    let leading = padded_dims.iter().map(|_| 0);
    let trailing = from_shape
        .iter()
        .zip(from_strides)
        .zip(aligned_target)
        .map(|((&size, &stride), &target)| {
            if size == 1 && target > 1 {
                0
            } else {
                stride
            }
        });

    leading.chain(trailing)
}
483
484impl<const N: usize> NdLayout<N> {
485    /// Convert this layout to one with a dynamic rank.
486    pub fn as_dyn(&self) -> DynLayout {
487        self.into()
488    }
489
490    /// Return true if all components of `index` are in-bounds.
491    fn index_valid(&self, index: [usize; N]) -> bool {
492        let mut valid = true;
493        for i in 0..N {
494            valid = valid && index[i] < self.shape[i]
495        }
496        valid
497    }
498
499    /// Return the strides that a contiguous layout with a given shape would
500    /// have.
501    fn contiguous_strides(shape: [usize; N]) -> [usize; N] {
502        let mut strides = [0; N];
503        for i in 0..N {
504            strides[i] = shape[i + 1..].iter().product();
505        }
506        strides
507    }
508}
509
510impl<'a, const N: usize> TryFrom<&'a DynLayout> for NdLayout<N> {
511    type Error = DimensionError;
512
513    /// Convert a dynamic layout into a static layout with N dims. Fails if
514    /// `value.ndim() != N`.
515    fn try_from(value: &'a DynLayout) -> Result<NdLayout<N>, DimensionError> {
516        let shape = value.shape();
517        let shape: [usize; N] = shape.try_into().map_err(|_| DimensionError {
518            actual: shape.len(),
519            expected: N,
520        })?;
521        let strides = value.strides();
522        let strides: [usize; N] = strides.try_into().map_err(|_| DimensionError {
523            actual: strides.len(),
524            expected: N,
525        })?;
526        Ok(NdLayout { shape, strides })
527    }
528}
529
530/// Defines the valid indices for an N-dimensional array and how to map them
531/// to offsets in a linear buffer, where N can be varied at runtime.
532///
533/// The layout specifies the size of each dimension of the tensor (the _shape_)
534/// and the stride (gap) between offsets in each dimension.
535#[derive(Debug, PartialEq)]
536pub struct DynLayout {
537    /// Array of dimension sizes followed by the corresponding dimension strides.
538    ///
539    /// Since we always have the same number of stride and shape dims, these
540    /// are combined into one array to avoid redundantly storing separate
541    /// lengths for each.
542    shape_and_strides: SmallVec<[usize; 8]>,
543}
544
545impl Clone for DynLayout {
546    fn clone(&self) -> DynLayout {
547        DynLayout {
548            // We implement `Clone` manually here so we can clone
549            // `shape_and_strides` using `SmallVec::from_slice` instead of
550            // `SmallVec::from`. This is faster for `Copy` types.
551            shape_and_strides: SmallVec::from_slice(self.shape_and_strides.as_slice()),
552        }
553    }
554}
555
556impl Layout for DynLayout {
557    type Index<'a> = &'a [usize];
558    type Indices = DynIndices;
559
560    /// Return the number of elements in the tensor shape described by this layout.
561    fn len(&self) -> usize {
562        self.shape().iter().product()
563    }
564
565    #[inline]
566    fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
567        let shape = self.shape();
568        let strides = self.strides();
569        let mut valid = index.as_ref().len() == shape.len();
570        let mut offset = 0;
571        for (idx, (size, stride)) in index.as_ref().iter().zip(shape.iter().zip(strides.iter())) {
572            valid = valid && idx < size;
573            offset += idx * stride;
574        }
575        valid.then_some(offset)
576    }
577
578    fn is_empty(&self) -> bool {
579        self.len() == 0
580    }
581
582    /// Return the number of dimensions.
583    #[inline]
584    fn ndim(&self) -> usize {
585        self.shape_and_strides.len() / 2
586    }
587
588    /// Return the sizes of each dimension.
589    #[inline]
590    fn shape(&self) -> &[usize] {
591        &self.shape_and_strides[0..self.ndim()]
592    }
593
594    /// Returns the size of the dimension `dim`.
595    #[inline]
596    fn size(&self, dim: usize) -> usize {
597        debug_assert_dim_valid!(self, dim);
598        self.shape_and_strides[dim]
599    }
600
601    /// Return the stride (offset between elements) in the tensor's element array.
602    #[inline]
603    fn strides(&self) -> &[usize] {
604        &self.shape_and_strides[self.ndim()..]
605    }
606
607    /// Return the stride for a specific dimension.
608    #[inline]
609    fn stride(&self, dim: usize) -> usize {
610        debug_assert_dim_valid!(self, dim);
611        self.shape_and_strides[self.ndim() + dim]
612    }
613
614    fn indices(&self) -> DynIndices {
615        DynIndices::from_shape(self.shape())
616    }
617}
618
619unsafe impl TrustedLayout for DynLayout {}
620
621impl DynLayout {
622    pub fn make_contiguous(&mut self) {
623        self.shape_and_strides = Self::contiguous_shape_and_strides(self.shape());
624    }
625
626    fn permute_iter<I: Clone + Iterator<Item = usize>>(&mut self, dims: I) {
627        let strides = self.strides();
628        let shape = self.shape();
629        let shape_iter = dims.clone().map(|dim| shape[dim]);
630        let stride_iter = dims.map(|dim| strides[dim]);
631        self.shape_and_strides = shape_iter.chain(stride_iter).collect();
632    }
633
634    /// Swap the order of dimensions in this layout to the order described by
635    /// `dims`.
636    fn permute(&mut self, dims: &[usize]) {
637        assert!(
638            is_valid_permutation(self.ndim(), dims),
639            "permutation is invalid"
640        );
641        self.permute_iter(dims.iter().copied());
642    }
643
644    /// Reverse the order of dimensions in this layout.
645    fn transpose(&mut self) {
646        self.permute_iter((0..self.ndim()).rev());
647    }
648
649    /// Create a shape-and-strides array for a contiguous layout.
650    fn contiguous_shape_and_strides(shape: &[usize]) -> SmallVec<[usize; 8]> {
651        let mut strides_and_shape: SmallVec<[usize; 8]> = SmallVec::from_slice(shape);
652        strides_and_shape.resize(shape.len() * 2, 0);
653        let mut stride = 1;
654        for i in (0..shape.len()).rev() {
655            strides_and_shape[shape.len() + i] = stride;
656            stride *= shape[i];
657        }
658        strides_and_shape
659    }
660}
661
662impl<L: Layout> From<&L> for DynLayout {
663    fn from(layout: &L) -> DynLayout {
664        DynLayout::from_shape_and_strides(
665            layout.shape().as_ref(),
666            layout.strides().as_ref(),
667            OverlapPolicy::AllowOverlap,
668        )
669        .expect("invalid layout")
670    }
671}
672
673impl<const N: usize> From<NdLayout<N>> for DynLayout {
674    fn from(value: NdLayout<N>) -> DynLayout {
675        Self::from(&value)
676    }
677}
678
679/// MutLayout extends [`Layout`] with methods for creating, modifying and
680/// transforming layouts.
681///
682/// ## Strides and internal overlap
683///
684/// Rust requires that only one mutable reference can exist for any value. When
685/// creating mutable tensor views or iterators, it is therefore important to
686/// know whether multiple elements in the layout may map to the same offset.
687///
688/// Accurately checking this for arbitrary shape and strides is non-trivial. See
689/// notes in `mem_overlap.c` in the NumPy source. RTen handles this by using
690/// a conservative check for internal overlap when constructing a layout from
691/// arbitrary strides. Specifically it sorts dimensions by decreasing stride and
692/// then verifies that each dimension fully "steps over" the next one. This
693/// allows for layouts which are transposed or have been sliced, but disallows
694/// some more complex non-overlapping constructions.
695///
696/// When constructing a layout via
697/// [`from_shape_and_strides`](MutLayout::from_shape_and_strides) the intended
698/// usage is specified via an [`OverlapPolicy`].
699pub trait MutLayout: Layout + Clone {
700    /// Create a new contiguous layout with a given shape.
701    fn from_shape(shape: Self::Index<'_>) -> Self;
702
703    /// Create a layout with custom strides.
704    ///
705    /// The strides specify the offset gap between successive entries along a
706    /// given axis. `overlap` controls whether the layout is allowed to map
707    /// multiple indices to the same element. This can be true for immutable
708    /// views, but must be false for tensors or views that are mutable.
709    fn from_shape_and_strides(
710        shape: Self::Index<'_>,
711        strides: Self::Index<'_>,
712        overlap: OverlapPolicy,
713    ) -> Result<Self, FromDataError>;
714
715    /// Slice a layout by selecting a single entry from a given axis.
716    ///
717    /// Returns an `(offset_range, layout)` tuple for the sliced layout.
718    fn index_axis(&self, axis: usize, index: usize) -> (Range<usize>, <Self as RemoveDim>::Output)
719    where
720        Self: RemoveDim,
721    {
722        assert!(axis < self.ndim());
723        assert!(index < self.size(axis));
724
725        let layout = self.remove_dim(axis);
726        let start_offset = self.stride(axis) * index;
727
728        (start_offset..start_offset + layout.min_data_len(), layout)
729    }
730
731    /// Move the axis at position `from` to `to` by swapping their strides.
732    fn move_axis(&mut self, from: usize, to: usize);
733
734    /// Return a layout with the axes permuted according to the given order.
735    fn permuted(&self, order: Self::Index<'_>) -> Self;
736
737    /// Return a new layout formed by reshaping this one to `shape`.
738    ///
739    /// This has the same requirements as
740    /// [`reshaped_for_copy`](MutLayout::reshaped_for_copy) but also requires
741    /// that the layout is contiguous.
742    fn reshaped_for_view<S: IntoLayout>(&self, shape: S) -> Result<S::Layout, ReshapeError> {
743        if !self.is_contiguous() {
744            return Err(ReshapeError::NotContiguous);
745        }
746        self.reshaped_for_copy(shape)
747    }
748
749    /// Return a new layout formed by reshaping this one to `shape`.
750    fn reshaped_for_copy<S: IntoLayout>(&self, shape: S) -> Result<S::Layout, ReshapeError> {
751        let layout = shape.into_layout();
752        if layout.len() != self.len() {
753            return Err(ReshapeError::LengthMismatch);
754        }
755        Ok(layout)
756    }
757
758    // Modify the size of a dimension. This does not alter the strides.
759    fn resize_dim(&mut self, dim: usize, size: usize);
760
761    /// Reverse the order of dimensions. This is equivalent to
762    /// `self.permuted([N-1, N-2, ... 0])`.
763    fn transposed(&self) -> Self;
764
765    /// Slice the layout and return a static-rank layout.
766    ///
767    /// Returns a tuple of `(offset_range, sliced_layout)`.
768    fn slice<const M: usize>(
769        &self,
770        range: &[SliceItem],
771    ) -> Result<(Range<usize>, NdLayout<M>), SliceError>;
772
773    /// Slice the layout and return a dynamic rank layout.
774    ///
775    /// Returns a tuple of `(offset_range, sliced_layout)`.
776    fn slice_dyn(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError>;
777
778    /// Slice the layout along a given axis.
779    ///
780    /// Returns a tuple of `(offset_range, sliced_layout)`.
781    fn slice_axis(
782        &self,
783        axis: usize,
784        range: Range<usize>,
785    ) -> Result<(Range<usize>, Self), SliceError> {
786        if axis >= self.ndim() {
787            return Err(SliceError::InvalidAxis { axis });
788        }
789        if range.end < range.start || range.end > self.size(axis) {
790            return Err(SliceError::InvalidRange {
791                axis,
792                range: range.into(),
793                size: self.size(axis),
794            });
795        }
796
797        let mut sliced_layout = self.clone();
798        sliced_layout.resize_dim(axis, range.len());
799        let range = if sliced_layout.is_empty() {
800            0..0
801        } else {
802            let start_offset = range.start * sliced_layout.stride(axis);
803            let end_offset = start_offset + sliced_layout.min_data_len();
804            start_offset..end_offset
805        };
806        Ok((range, sliced_layout))
807    }
808
809    /// Return a layout with all size-one dimensions removed.
810    fn squeezed(&self) -> DynLayout;
811
812    /// Split the layout along the given axis into two.
813    ///
814    /// Returns a tuple of `(left, right)` where each item is an `(offset_range,
815    /// layout)` tuple.
816    fn split(&self, axis: usize, mid: usize) -> ((Range<usize>, Self), (Range<usize>, Self));
817}
818
819/// Trait for broadcasting a layout from one shape to another.
820pub trait BroadcastLayout<L: Layout> {
821    /// Broadcast the `self` layout to a given shape.
822    fn broadcast<S: IntoLayout<Layout = L>>(&self, shape: S) -> Result<L, ExpandError>;
823}
824
825impl<const N: usize, const M: usize> BroadcastLayout<NdLayout<M>> for NdLayout<N> {
826    fn broadcast<S: IntoLayout<Layout = NdLayout<M>>>(
827        &self,
828        shape: S,
829    ) -> Result<NdLayout<M>, ExpandError> {
830        let shape: [usize; M] = shape.as_ref().try_into().unwrap();
831        if !self.can_broadcast_to(&shape) {
832            return Err(ExpandError::ShapeMismatch);
833        }
834        let mut strides = [0usize; M];
835        for (i, stride) in broadcast_strides(&self.shape(), &self.strides(), &shape).enumerate() {
836            strides[i] = stride;
837        }
838
839        Ok(NdLayout { shape, strides })
840    }
841}
842
843impl<const N: usize> BroadcastLayout<DynLayout> for NdLayout<N> {
844    fn broadcast<S: IntoLayout<Layout = DynLayout>>(
845        &self,
846        shape: S,
847    ) -> Result<DynLayout, ExpandError> {
848        let dyn_layout: DynLayout = self.into();
849        dyn_layout.broadcast(shape.as_ref())
850    }
851}
852
853impl BroadcastLayout<DynLayout> for DynLayout {
854    fn broadcast<S: IntoLayout<Layout = DynLayout>>(
855        &self,
856        shape: S,
857    ) -> Result<DynLayout, ExpandError> {
858        let to_shape = shape.as_ref();
859
860        if !self.can_broadcast_to(to_shape) {
861            return Err(ExpandError::ShapeMismatch);
862        }
863
864        let mut shape_and_strides = SmallVec::with_capacity(to_shape.len() * 2);
865        shape_and_strides.extend(to_shape.iter().copied());
866        shape_and_strides.extend(broadcast_strides(self.shape(), self.strides(), to_shape));
867
868        Ok(DynLayout { shape_and_strides })
869    }
870}
871
872impl<const N: usize> BroadcastLayout<NdLayout<N>> for DynLayout {
873    fn broadcast<S: IntoLayout<Layout = NdLayout<N>>>(
874        &self,
875        shape: S,
876    ) -> Result<NdLayout<N>, ExpandError> {
877        let dyn_broadcast = self.broadcast(shape.as_ref())?;
878        let layout = (&dyn_broadcast)
879            .try_into()
880            .map_err(|_| ExpandError::ShapeMismatch)?;
881        Ok(layout)
882    }
883}
884
885impl<const N: usize> MutLayout for NdLayout<N> {
886    fn from_shape(shape: [usize; N]) -> Self {
887        Self {
888            shape,
889            strides: Self::contiguous_strides(shape),
890        }
891    }
892
893    fn from_shape_and_strides(
894        shape: Self::Index<'_>,
895        strides: Self::Index<'_>,
896        overlap: OverlapPolicy,
897    ) -> Result<Self, FromDataError> {
898        let layout = NdLayout { shape, strides };
899
900        match overlap {
901            OverlapPolicy::DisallowOverlap => {
902                if may_have_internal_overlap(&layout.shape, &layout.strides) {
903                    return Err(FromDataError::MayOverlap);
904                }
905            }
906            OverlapPolicy::AllowOverlap => {}
907        }
908
909        Ok(layout)
910    }
911
912    fn move_axis(&mut self, from: usize, to: usize) {
913        assert!(from < N && to < N);
914        let mut dyn_layout = self.as_dyn();
915        dyn_layout.move_axis(from, to);
916        *self = NdLayout::try_from(&dyn_layout).unwrap();
917    }
918
919    fn permuted(&self, dims: [usize; N]) -> NdLayout<N> {
920        assert!(is_valid_permutation(N, &dims), "permutation is invalid");
921        let mut shape = [0; N];
922        let mut strides = [0; N];
923        for i in 0..N {
924            shape[i] = self.shape[dims[i]];
925            strides[i] = self.strides[dims[i]];
926        }
927        NdLayout { shape, strides }
928    }
929
930    fn resize_dim(&mut self, dim: usize, size: usize) {
931        self.shape[dim] = size;
932    }
933
934    fn transposed(&self) -> NdLayout<N> {
935        let dims = std::array::from_fn(|i| N - i - 1);
936        self.permuted(dims)
937    }
938
939    fn slice<const M: usize>(
940        &self,
941        range: &[SliceItem],
942    ) -> Result<(Range<usize>, NdLayout<M>), SliceError> {
943        if self.ndim() < range.len() {
944            return Err(SliceError::TooManyDims {
945                ndim: self.ndim(),
946                range_ndim: range.len(),
947            });
948        }
949
950        let mut shape: [usize; M] = [0; M];
951        let mut strides: [usize; M] = [0; M];
952
953        let (ndim, offset) =
954            slice_layout(&self.shape, &self.strides, &mut shape, &mut strides, range)?;
955
956        if ndim != M {
957            return Err(SliceError::OutputDimsMismatch {
958                actual: ndim,
959                expected: M,
960            });
961        }
962
963        let layout = NdLayout { shape, strides };
964        Ok((offset..offset + layout.min_data_len(), layout))
965    }
966
967    fn slice_dyn(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError> {
968        self.as_dyn().slice_dyn(range)
969    }
970
971    fn squeezed(&self) -> DynLayout {
972        self.as_dyn().squeezed()
973    }
974
975    fn split(&self, axis: usize, mid: usize) -> ((Range<usize>, Self), (Range<usize>, Self)) {
976        assert!(axis < self.ndim());
977        assert!(mid <= self.size(axis));
978
979        let left_shape = std::array::from_fn(|i| if i == axis { mid } else { self.shape[i] });
980        let right_shape = std::array::from_fn(|i| {
981            if i == axis {
982                self.size(axis) - mid
983            } else {
984                self.shape[i]
985            }
986        });
987
988        let left = NdLayout {
989            shape: left_shape,
990            strides: self.strides,
991        };
992        let right = NdLayout {
993            shape: right_shape,
994            strides: self.strides,
995        };
996
997        let mid_offset = mid * self.strides[axis];
998        let left_offsets = 0..left.min_data_len();
999        let end_offset = self.min_data_len();
1000
1001        let right_offsets = if right.is_empty() {
1002            end_offset..end_offset
1003        } else {
1004            mid_offset..end_offset
1005        };
1006
1007        ((left_offsets, left), (right_offsets, right))
1008    }
1009}
1010
1011impl MutLayout for DynLayout {
1012    fn from_shape(shape: &[usize]) -> Self {
1013        DynLayout {
1014            shape_and_strides: Self::contiguous_shape_and_strides(shape),
1015        }
1016    }
1017
1018    fn from_shape_and_strides(
1019        shape: &[usize],
1020        strides: &[usize],
1021        overlap: OverlapPolicy,
1022    ) -> Result<Self, FromDataError> {
1023        let mut shape_and_strides = SmallVec::with_capacity(shape.len() + strides.len());
1024        shape_and_strides.extend_from_slice(shape);
1025        shape_and_strides.extend_from_slice(strides);
1026        let layout = DynLayout { shape_and_strides };
1027
1028        match overlap {
1029            OverlapPolicy::DisallowOverlap => {
1030                if may_have_internal_overlap(layout.shape(), layout.strides()) {
1031                    return Err(FromDataError::MayOverlap);
1032                }
1033            }
1034            OverlapPolicy::AllowOverlap => {}
1035        }
1036
1037        Ok(layout)
1038    }
1039
1040    fn move_axis(&mut self, from: usize, to: usize) {
1041        let ndim = self.ndim();
1042        assert!(from < ndim && to < ndim);
1043
1044        let size = self.shape_and_strides.remove(from);
1045        let stride = self.shape_and_strides.remove(ndim - 1 + from);
1046        self.shape_and_strides.insert(to, size);
1047        self.shape_and_strides.insert(ndim + to, stride);
1048    }
1049
1050    fn permuted(&self, order: &[usize]) -> DynLayout {
1051        let mut permuted = self.clone();
1052        permuted.permute(order);
1053        permuted
1054    }
1055
1056    fn resize_dim(&mut self, dim: usize, size: usize) {
1057        self.shape_and_strides[dim] = size;
1058    }
1059
1060    fn transposed(&self) -> DynLayout {
1061        let mut transposed = self.clone();
1062        transposed.transpose();
1063        transposed
1064    }
1065
1066    fn slice<const M: usize>(
1067        &self,
1068        range: &[SliceItem],
1069    ) -> Result<(Range<usize>, NdLayout<M>), SliceError> {
1070        let (offset_range, dyn_layout) = self.slice_dyn(range)?;
1071        let nd_layout =
1072            NdLayout::try_from(&dyn_layout).map_err(|_| SliceError::OutputDimsMismatch {
1073                actual: dyn_layout.ndim(),
1074                expected: M,
1075            })?;
1076        Ok((offset_range, nd_layout))
1077    }
1078
1079    fn slice_dyn(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError> {
1080        if self.ndim() < range.len() {
1081            return Err(SliceError::TooManyDims {
1082                ndim: self.ndim(),
1083                range_ndim: range.len(),
1084            });
1085        }
1086
1087        let out_dims = self.ndim()
1088            - range
1089                .iter()
1090                .filter(|item| matches!(item, SliceItem::Index(_)))
1091                .count();
1092        let mut shape_and_strides = smallvec![0; out_dims * 2];
1093        let (out_shape, out_strides) = shape_and_strides.as_mut_slice().split_at_mut(out_dims);
1094
1095        let (_ndim, offset) =
1096            slice_layout(self.shape(), self.strides(), out_shape, out_strides, range)?;
1097
1098        let layout = Self { shape_and_strides };
1099        Ok((offset..offset + layout.min_data_len(), layout))
1100    }
1101
1102    fn squeezed(&self) -> DynLayout {
1103        let shape = self.shape().iter().copied().filter(|&size| size != 1);
1104        let strides = self
1105            .shape()
1106            .iter()
1107            .zip(self.strides())
1108            .filter_map(|(&size, &stride)| if size != 1 { Some(stride) } else { None });
1109        DynLayout {
1110            shape_and_strides: shape.chain(strides).collect(),
1111        }
1112    }
1113
1114    fn split(&self, axis: usize, mid: usize) -> ((Range<usize>, Self), (Range<usize>, Self)) {
1115        assert!(axis < self.ndim());
1116        assert!(mid <= self.size(axis));
1117
1118        let mut left_shape_strides: SmallVec<[usize; 8]> = (0..self.ndim())
1119            .map(|i| if i == axis { mid } else { self.size(i) })
1120            .collect();
1121        left_shape_strides.extend(self.strides().iter().copied());
1122
1123        let mut right_shape_strides: SmallVec<[usize; 8]> = (0..self.ndim())
1124            .map(|i| {
1125                if i == axis {
1126                    self.size(axis) - mid
1127                } else {
1128                    self.size(i)
1129                }
1130            })
1131            .collect();
1132        right_shape_strides.extend(self.strides().iter().copied());
1133
1134        let left = DynLayout {
1135            shape_and_strides: left_shape_strides,
1136        };
1137        let right = DynLayout {
1138            shape_and_strides: right_shape_strides,
1139        };
1140
1141        let mid_offset = mid * self.stride(axis);
1142        let left_offsets = 0..left.min_data_len();
1143        let end_offset = self.min_data_len();
1144
1145        let right_offsets = if right.is_empty() {
1146            end_offset..end_offset
1147        } else {
1148            mid_offset..end_offset
1149        };
1150
1151        ((left_offsets, left), (right_offsets, right))
1152    }
1153}
1154
1155/// Trait for shapes which can be used to create a contiguous layout.
1156///
1157/// This is implemented for `[usize; N]` for creating static-rank layouts from
1158/// arrays, and `&[usize]` for creating dynamic-rank layouts from slices.
1159pub trait IntoLayout: AsRef<[usize]> + std::fmt::Debug {
1160    /// The type of layout produced from this shape.
1161    type Layout: MutLayout;
1162
1163    /// Convert this shape into a contiguous layout.
1164    fn into_layout(self) -> Self::Layout;
1165}
1166
1167impl<const N: usize> IntoLayout for [usize; N] {
1168    type Layout = NdLayout<N>;
1169
1170    #[inline]
1171    fn into_layout(self) -> NdLayout<N> {
1172        NdLayout::from_shape(self)
1173    }
1174}
1175
1176impl IntoLayout for &[usize] {
1177    type Layout = DynLayout;
1178
1179    #[inline]
1180    fn into_layout(self) -> DynLayout {
1181        DynLayout::from_shape(self)
1182    }
1183}
1184
1185/// Trait which extends [`MutLayout`] with support for changing the number of
1186/// dimensions in-place.
1187///
1188/// This is only implemented for [`DynLayout`], since layouts that have a static
1189/// rank cannot change their dimension count at runtime.
1190pub trait ResizeLayout: MutLayout {
1191    /// Insert a size-one axis at the given index in the shape. This will have
1192    /// the same stride as the dimension that follows it.
1193    fn insert_axis(&mut self, index: usize);
1194
1195    /// Remove a size-1 axis at the given index.
1196    ///
1197    /// Since the axis has size one, this does not alter the number of elements
1198    /// in the layout or the order in which they are visited.
1199    ///
1200    /// Panics if the axis does not have a size of 1.
1201    #[track_caller]
1202    fn remove_axis(&mut self, index: usize) {
1203        assert!(
1204            self.size(index) == 1,
1205            "cannot remove axis of size {}",
1206            self.size(index)
1207        );
1208        self.remove_axis_of_any_size(index)
1209    }
1210
1211    /// Remove an axis that may have any size.
1212    ///
1213    /// If the size of the axis is not one, this will "remove" elements from
1214    /// the layout.
1215    fn remove_axis_of_any_size(&mut self, index: usize);
1216
1217    /// Merge consecutive axes where possible.
1218    ///
1219    /// This "simplifies" the layout by minimizing the number of dimensions
1220    /// while preserving the iteration order.
1221    fn merge_axes(&mut self);
1222}
1223
1224impl ResizeLayout for DynLayout {
1225    fn insert_axis(&mut self, index: usize) {
1226        let ndim = self.ndim();
1227        let new_size = 1;
1228
1229        // Choose stride for new dimension as if we were inserting it at the
1230        // beginning. If `dim != 0` then the result is as if we inserted the
1231        // dim at the start and then permuted the layout.
1232        let (max_stride, size_for_max_stride) = self
1233            .strides()
1234            .iter()
1235            .copied()
1236            .zip(self.shape().iter().copied())
1237            .max_by_key(|(stride, _size)| *stride)
1238            .unwrap_or((1, 1));
1239        let new_stride = max_stride * size_for_max_stride;
1240
1241        self.shape_and_strides.insert(index, new_size);
1242        self.shape_and_strides.insert(ndim + 1 + index, new_stride);
1243    }
1244
1245    fn remove_axis_of_any_size(&mut self, index: usize) {
1246        self.shape_and_strides.remove(index);
1247        self.shape_and_strides.remove(self.ndim() + index);
1248    }
1249
1250    fn merge_axes(&mut self) {
1251        let merged = merge_axes(self.shape(), self.strides());
1252        self.shape_and_strides = merged
1253            .iter()
1254            .map(|dim| dim.0)
1255            .chain(merged.iter().map(|dim| dim.1))
1256            .collect();
1257    }
1258}
1259
1260/// Trait for converting types into indices for use with a given layout.
1261///
1262/// Static-rank tensors can be indexed with `[usize; N]` arrays. Dynamic-rank
1263/// tensors can be indexed with any type that can be converted to an `&[usize]`
1264/// slice.
1265pub trait AsIndex<L: Layout> {
1266    /// Convert `self` into an index for use the layout `L`.
1267    fn as_index(&self) -> L::Index<'_>;
1268}
1269
1270impl<T: AsRef<[usize]>> AsIndex<DynLayout> for T {
1271    fn as_index(&self) -> &[usize] {
1272        self.as_ref()
1273    }
1274}
1275
1276impl<const N: usize> AsIndex<NdLayout<N>> for [usize; N] {
1277    fn as_index(&self) -> [usize; N] {
1278        *self
1279    }
1280}
1281
1282impl AsIndex<NdLayout<1>> for usize {
1283    fn as_index(&self) -> [usize; 1] {
1284        [*self]
1285    }
1286}
1287
1288/// Trait that removes one dimension from a layout.
1289pub trait RemoveDim {
1290    type Output: MutLayout;
1291
1292    /// Return a copy of this layout with the dimension at index `dim` removed.
1293    fn remove_dim(&self, dim: usize) -> Self::Output;
1294}
1295
1296impl<R: RemoveDim> RemoveDim for &R {
1297    type Output = R::Output;
1298
1299    fn remove_dim(&self, dim: usize) -> Self::Output {
1300        (*self).remove_dim(dim)
1301    }
1302}
1303
1304impl RemoveDim for DynLayout {
1305    type Output = DynLayout;
1306
1307    fn remove_dim(&self, dim: usize) -> DynLayout {
1308        let ndim = self.ndim();
1309        assert!(ndim > 0, "cannot remove axis from tensor with 0 dims");
1310
1311        let shape = (0..ndim - 1).map(|i| {
1312            if i < dim {
1313                self.size(i)
1314            } else {
1315                self.size(i + 1)
1316            }
1317        });
1318        let strides = (0..ndim - 1).map(|i| {
1319            if i < dim {
1320                self.stride(i)
1321            } else {
1322                self.stride(i + 1)
1323            }
1324        });
1325        DynLayout {
1326            shape_and_strides: shape.chain(strides).collect(),
1327        }
1328    }
1329}
1330
1331macro_rules! impl_remove_dim {
1332    ($in_dims:expr, $out_dims:expr) => {
1333        impl RemoveDim for NdLayout<$in_dims> {
1334            type Output = NdLayout<$out_dims>;
1335
1336            fn remove_dim(&self, dim: usize) -> Self::Output {
1337                let shape = std::array::from_fn(|i| {
1338                    if i < dim {
1339                        self.shape[i]
1340                    } else {
1341                        self.shape[i + 1]
1342                    }
1343                });
1344                let strides = std::array::from_fn(|i| {
1345                    if i < dim {
1346                        self.strides[i]
1347                    } else {
1348                        self.strides[i + 1]
1349                    }
1350                });
1351                NdLayout { shape, strides }
1352            }
1353        }
1354    };
1355}
1356
1357impl_remove_dim!(1, 0);
1358impl_remove_dim!(2, 1);
1359impl_remove_dim!(3, 2);
1360impl_remove_dim!(4, 3);
1361impl_remove_dim!(5, 4);
1362
1363/// Trait for slicing a layout with a range.
1364///
1365/// `R` is the type of the slice range. `IdxCount` is a marker type indicating
1366/// the number of items in `R` that are indices, as opposed to ranges.
1367pub trait SliceWith<R: IntoSliceItems, IdxCount: OptionalUInt> {
1368    /// The layout produced after slicing.
1369    type Layout: MutLayout;
1370
1371    /// Slice the layout with a range.
1372    ///
1373    /// Returns a tuple of `(offset_range, sliced_layout)` where `offset_range`
1374    /// is the range of data from the original view that is used by the slice
1375    /// and `sliced_layout` is the layout of the sliced view.
1376    fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError>;
1377}
1378
1379impl<R: IntoSliceItems, L: MutLayout> SliceWith<R, Unknown> for L {
1380    type Layout = DynLayout;
1381
1382    fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
1383        self.slice_dyn(range.into_slice_items().as_ref())
1384    }
1385}
1386
1387impl<R: IntoSliceItems, const N: usize> SliceWith<R, U0> for NdLayout<N> {
1388    type Layout = NdLayout<N>;
1389
1390    fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
1391        self.slice(range.into_slice_items().as_ref())
1392    }
1393}
1394
1395macro_rules! impl_slice_with_dynlayout {
1396    ($range_ndim:ty) => {
1397        impl<R: IntoSliceItems> SliceWith<R, $range_ndim> for DynLayout {
1398            type Layout = DynLayout;
1399
1400            fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
1401                self.slice_dyn(range.into_slice_items().as_ref())
1402            }
1403        }
1404    };
1405}
1406
1407impl_slice_with_dynlayout!(U0);
1408impl_slice_with_dynlayout!(U1);
1409impl_slice_with_dynlayout!(U2);
1410impl_slice_with_dynlayout!(U3);
1411impl_slice_with_dynlayout!(U4);
1412impl_slice_with_dynlayout!(U5);
1413
1414macro_rules! impl_slice_with {
1415    ($ndim:literal, $range_ndim:ty, $out_ndim:literal) => {
1416        impl<R: IntoSliceItems> SliceWith<R, $range_ndim> for NdLayout<$ndim> {
1417            type Layout = NdLayout<$out_ndim>;
1418
1419            fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
1420                self.slice(range.into_slice_items().as_ref())
1421            }
1422        }
1423    };
1424}
1425
1426impl_slice_with!(1, U1, 0);
1427impl_slice_with!(2, U1, 1);
1428impl_slice_with!(2, U2, 0);
1429impl_slice_with!(3, U1, 2);
1430impl_slice_with!(3, U2, 1);
1431impl_slice_with!(3, U3, 0);
1432impl_slice_with!(4, U1, 3);
1433impl_slice_with!(4, U2, 2);
1434impl_slice_with!(4, U3, 1);
1435impl_slice_with!(4, U4, 0);
1436impl_slice_with!(5, U1, 4);
1437impl_slice_with!(5, U2, 3);
1438impl_slice_with!(5, U3, 2);
1439impl_slice_with!(5, U4, 1);
1440impl_slice_with!(5, U5, 0);
1441
1442#[cfg(test)]
1443mod tests {
1444    use rten_testing::TestCases;
1445
1446    use std::ops::Range;
1447
1448    use super::OverlapPolicy;
1449    use crate::SliceItem;
1450    use crate::errors::{ReshapeError, SliceError};
1451    use crate::layout::{DynLayout, Layout, MutLayout, NdLayout, ResizeLayout};
1452
1453    fn layout_with_strides<const N: usize>(shape: [usize; N], strides: [usize; N]) -> NdLayout<N> {
1454        NdLayout::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap).unwrap()
1455    }
1456
1457    #[test]
1458    fn test_is_broadcast() {
1459        // Non-empty, contiguous layout
1460        let layout = DynLayout::from_shape(&[5, 5]);
1461        assert!(!layout.is_broadcast());
1462
1463        // Empty layout
1464        let layout = DynLayout::from_shape(&[5, 0]);
1465        assert!(!layout.is_broadcast());
1466
1467        // Broadcasting layout
1468        let layout =
1469            DynLayout::from_shape_and_strides(&[5, 5], &[0, 0], OverlapPolicy::AllowOverlap)
1470                .unwrap();
1471        assert!(layout.is_broadcast());
1472    }
1473
1474    #[test]
1475    fn test_from_shape_and_strides() {
1476        #[derive(Debug)]
1477        struct Case<'a> {
1478            shape: &'a [usize],
1479            strides: &'a [usize],
1480        }
1481
1482        let cases = [
1483            // Contiguous layout
1484            Case {
1485                shape: &[10, 10],
1486                strides: &[10, 1],
1487            },
1488            // Broadcasting layout
1489            Case {
1490                shape: &[10, 10],
1491                strides: &[10, 0],
1492            },
1493        ];
1494
1495        cases.test_each(|case| {
1496            let layout = DynLayout::from_shape_and_strides(
1497                case.shape,
1498                case.strides,
1499                OverlapPolicy::AllowOverlap,
1500            )
1501            .unwrap();
1502            assert_eq!(layout.shape(), case.shape);
1503            assert_eq!(layout.strides(), case.strides);
1504        })
1505    }
1506
1507    #[test]
1508    fn test_index_axis() {
1509        #[derive(Debug)]
1510        struct Case {
1511            layout: NdLayout<2>,
1512            axis: usize,
1513            index: usize,
1514            expected: (usize, NdLayout<1>), // (start offset, sliced layout)
1515        }
1516
1517        let cases = [
1518            Case {
1519                layout: NdLayout::from_shape([3, 4]),
1520                axis: 0,
1521                index: 1,
1522                expected: (4, layout_with_strides([4], [1])),
1523            },
1524            Case {
1525                layout: NdLayout::from_shape([3, 4]),
1526                axis: 1,
1527                index: 2,
1528                expected: (2, layout_with_strides([3], [4])),
1529            },
1530        ];
1531
1532        cases.test_each(|case| {
1533            let Case {
1534                layout,
1535                axis,
1536                index,
1537                expected,
1538            } = case;
1539
1540            let (expected_start, expected_layout) = expected;
1541
1542            let (offsets, sliced_layout) = layout.index_axis(*axis, *index);
1543            assert_eq!(sliced_layout, *expected_layout);
1544            assert_eq!(offsets.start, *expected_start);
1545            assert_eq!(offsets.len(), expected_layout.min_data_len());
1546
1547            let (_, sliced_layout_dyn) = layout.as_dyn().index_axis(*axis, *index);
1548            assert_eq!(sliced_layout_dyn, expected_layout.as_dyn());
1549        })
1550    }
1551
1552    #[test]
1553    #[should_panic(expected = "axis < self.ndim()")]
1554    fn test_index_axis_invalid_axis() {
1555        NdLayout::from_shape([2, 3]).index_axis(2, 0);
1556    }
1557
1558    #[test]
1559    #[should_panic(expected = "index < self.size(axis)")]
1560    fn test_index_axis_invalid_index() {
1561        NdLayout::from_shape([2, 3]).index_axis(0, 3);
1562    }
1563
1564    #[test]
1565    fn test_move_axis() {
1566        let mut layout = DynLayout::from_shape(&[2, 4, 8]);
1567        assert_eq!(layout.strides(), [32, 8, 1]);
1568
1569        layout.move_axis(1, 0);
1570        assert_eq!(layout.shape(), [4, 2, 8]);
1571        assert_eq!(layout.strides(), [8, 32, 1]);
1572
1573        layout.move_axis(0, 1);
1574        assert_eq!(layout.shape(), [2, 4, 8]);
1575        assert_eq!(layout.strides(), [32, 8, 1]);
1576
1577        layout.move_axis(2, 1);
1578        assert_eq!(layout.shape(), [2, 8, 4]);
1579        assert_eq!(layout.strides(), [32, 1, 8]);
1580    }
1581
1582    #[test]
1583    #[should_panic]
1584    fn test_move_axis_invalid_from() {
1585        let mut layout = DynLayout::from_shape(&[2, 4, 8]);
1586        layout.move_axis(3, 0);
1587    }
1588
1589    #[test]
1590    #[should_panic]
1591    fn test_move_axis_invalid_to() {
1592        let mut layout = DynLayout::from_shape(&[2, 4, 8]);
1593        layout.move_axis(0, 3);
1594    }
1595
1596    #[test]
1597    #[should_panic(expected = "permutation is invalid")]
1598    fn test_permute_invalid_len() {
1599        let mut layout = DynLayout::from_shape(&[5, 5]);
1600        layout.permute(&[1, 0, 3]);
1601    }
1602
1603    #[test]
1604    #[should_panic(expected = "permutation is invalid")]
1605    fn test_permute_too_few_dims() {
1606        let mut layout = DynLayout::from_shape(&[5, 5]);
1607        layout.permute(&[1]);
1608    }
1609
1610    #[test]
1611    #[should_panic(expected = "permutation is invalid")]
1612    fn test_permute_repeated_dims() {
1613        let mut layout = DynLayout::from_shape(&[5, 5]);
1614        layout.permute(&[1, 1]);
1615    }
1616
1617    #[test]
1618    fn test_remove_axis_of_any_size() {
1619        let shape = [1, 2, 3, 4];
1620        for d in 0..shape.len() {
1621            let mut layout = DynLayout::from_shape(&shape);
1622            let (expected_shape, expected_strides): (Vec<usize>, Vec<usize>) = layout
1623                .shape()
1624                .iter()
1625                .zip(layout.strides())
1626                .enumerate()
1627                .filter_map(|(i, (size, stride))| if i != d { Some((size, stride)) } else { None })
1628                .unzip();
1629
1630            layout.remove_axis_of_any_size(d);
1631
1632            assert_eq!(layout.shape(), expected_shape);
1633            assert_eq!(layout.strides(), expected_strides);
1634        }
1635    }
1636
1637    #[test]
1638    fn test_reshaped() {
1639        #[derive(Debug)]
1640        struct Case<'a> {
1641            layout: DynLayout,
1642            new_shape: &'a [usize],
1643            for_copy: bool,
1644            error: Option<ReshapeError>,
1645        }
1646
1647        let cases = [
1648            // Reshapes that don't allow copying.
1649            Case {
1650                layout: DynLayout::from_shape(&[2, 2]),
1651                new_shape: &[4],
1652                for_copy: false,
1653                error: None,
1654            },
1655            Case {
1656                layout: DynLayout::from_shape(&[2, 2]).transposed(),
1657                new_shape: &[4],
1658                for_copy: false,
1659                error: Some(ReshapeError::NotContiguous),
1660            },
1661            Case {
1662                layout: DynLayout::from_shape(&[2, 2]),
1663                new_shape: &[3],
1664                for_copy: false,
1665                error: Some(ReshapeError::LengthMismatch),
1666            },
1667            // Reshapes that do allow copying.
1668            Case {
1669                layout: DynLayout::from_shape(&[2, 2]).transposed(),
1670                new_shape: &[4],
1671                for_copy: true,
1672                error: None,
1673            },
1674            Case {
1675                layout: DynLayout::from_shape(&[2, 2]),
1676                new_shape: &[3],
1677                for_copy: false,
1678                error: Some(ReshapeError::LengthMismatch),
1679            },
1680        ];
1681
1682        cases.test_each(|case| {
1683            let Case {
1684                layout,
1685                new_shape,
1686                for_copy,
1687                error,
1688            } = case;
1689
1690            let reshaped = if *for_copy {
1691                layout.reshaped_for_copy(*new_shape)
1692            } else {
1693                layout.reshaped_for_view(*new_shape)
1694            };
1695
1696            assert_eq!(reshaped.as_ref().err(), error.as_ref());
1697            if let Ok(new_layout) = reshaped {
1698                assert_eq!(new_layout.shape(), *new_shape);
1699            }
1700        })
1701    }
1702
1703    #[test]
1704    fn test_squeezed() {
1705        let layout = DynLayout::from_shape(&[1, 1, 10, 20]);
1706        let squeezed = layout.squeezed();
1707        assert_eq!(squeezed.shape(), &[10, 20]);
1708        assert_eq!(squeezed.strides(), &[20, 1]);
1709    }
1710
1711    #[test]
1712    fn test_slice_axis() {
1713        #[derive(Clone, Debug)]
1714        struct Case<'a> {
1715            shape: &'a [usize],
1716            axis: usize,
1717            range: Range<usize>,
1718            sliced_shape: &'a [usize],
1719            offsets: Range<usize>,
1720        }
1721
1722        let cases = [Case {
1723            shape: &[3, 5],
1724            axis: 1,
1725            range: 2..4,
1726            sliced_shape: &[3, 2],
1727            offsets: 2..14,
1728        }];
1729
1730        cases.test_each_clone(|case| {
1731            let Case {
1732                shape,
1733                axis,
1734                range,
1735                sliced_shape,
1736                offsets,
1737            } = case;
1738
1739            let layout = DynLayout::from_shape(shape);
1740            let (offset_range, sliced_layout) = layout.slice_axis(axis, range).unwrap();
1741            assert_eq!(sliced_layout.shape(), sliced_shape);
1742            assert_eq!(sliced_layout.strides(), layout.strides());
1743            assert_eq!(offset_range, offsets);
1744        })
1745    }
1746
1747    #[test]
1748    fn test_slice_axis_invalid() {
1749        #[derive(Debug)]
1750        struct Case<'a> {
1751            shape: &'a [usize],
1752            axis: usize,
1753            range: Range<usize>,
1754            expected: SliceError,
1755        }
1756
1757        let cases = [
1758            Case {
1759                shape: &[1, 2, 3],
1760                axis: 4,
1761                range: 0..1,
1762                expected: SliceError::InvalidAxis { axis: 4 },
1763            },
1764            Case {
1765                shape: &[1, 2, 3],
1766                axis: 0,
1767                range: 0..2,
1768                expected: SliceError::InvalidRange {
1769                    axis: 0,
1770                    range: (0..2).into(),
1771                    size: 1,
1772                },
1773            },
1774        ];
1775
1776        cases.test_each(|case| {
1777            let layout = DynLayout::from_shape(case.shape);
1778            let result = layout.slice_axis(case.axis, case.range.clone());
1779            assert_eq!(result, Err(case.expected.clone()));
1780        })
1781    }
1782
1783    #[test]
1784    fn test_slice_invalid() {
1785        #[derive(Debug)]
1786        struct Case<'a> {
1787            layout: DynLayout,
1788            ranges: &'a [SliceItem],
1789            expected: SliceError,
1790        }
1791
1792        let cases = [
1793            Case {
1794                layout: DynLayout::from_shape(&[3, 5]),
1795                ranges: &[SliceItem::Index(4), SliceItem::Index(0)],
1796                expected: SliceError::InvalidIndex {
1797                    axis: 0,
1798                    index: 4,
1799                    size: 3,
1800                },
1801            },
1802            Case {
1803                layout: DynLayout::from_shape(&[3, 5]),
1804                ranges: &[SliceItem::Range((1..4).into()), SliceItem::Index(0)],
1805                expected: SliceError::InvalidRange {
1806                    axis: 0,
1807                    range: (1..4).into(),
1808                    size: 3,
1809                },
1810            },
1811            Case {
1812                layout: DynLayout::from_shape(&[3, 5]),
1813                ranges: &[SliceItem::Index(-4)],
1814                expected: SliceError::InvalidIndex {
1815                    axis: 0,
1816                    index: -4,
1817                    size: 3,
1818                },
1819            },
1820            Case {
1821                layout: DynLayout::from_shape(&[3, 5]),
1822                ranges: &[SliceItem::Range((4..).into()), SliceItem::Index(0)],
1823                expected: SliceError::InvalidRange {
1824                    axis: 0,
1825                    range: (4..).into(),
1826                    size: 3,
1827                },
1828            },
1829            Case {
1830                layout: DynLayout::from_shape(&[3, 5]),
1831                ranges: &[SliceItem::full_range(), SliceItem::range(0, None, -1)],
1832                expected: SliceError::InvalidStep { axis: 1, step: -1 },
1833            },
1834        ];
1835
1836        cases.test_each(|case| {
1837            let result = case.layout.slice_dyn(case.ranges);
1838            assert_eq!(result, Err(case.expected.clone()));
1839        })
1840    }
1841
1842    #[test]
1843    fn test_size_stride() {
1844        let layout = DynLayout::from_shape(&[10, 20, 30]);
1845        for (dim, (&size, &stride)) in layout.shape().iter().zip(layout.strides()).enumerate() {
1846            assert_eq!(layout.size(dim), size);
1847            assert_eq!(layout.stride(dim), stride);
1848        }
1849    }
1850
1851    #[test]
1852    fn test_split() {
1853        #[derive(Debug)]
1854        struct Case {
1855            shape: [usize; 2],
1856            strides: Option<[usize; 2]>,
1857            axis: usize,
1858            mid: usize,
1859        }
1860
1861        let mut cases = Vec::new();
1862
1863        // All combinations of (axis, mid) for a small shape.
1864        let shape = [4, 2];
1865        for axis in 0..shape.len() {
1866            for mid in 0..shape[axis] {
1867                cases.push(Case {
1868                    shape,
1869                    axis,
1870                    mid,
1871                    strides: None,
1872                });
1873            }
1874        }
1875
1876        // Empty layout
1877        cases.push(Case {
1878            shape: [0, 0],
1879            strides: None,
1880            axis: 0,
1881            mid: 0,
1882        });
1883
1884        // Case where we are splitting a 1-sized dimension with `mid=1` and
1885        // the stride is larger than the minimum storage length for the layout.
1886        cases.push(Case {
1887            shape: [1, 4],
1888            strides: Some([10, 0]),
1889            axis: 0,
1890            mid: 1,
1891        });
1892
1893        fn check_split<L: MutLayout>(layout: L, axis: usize, mid: usize) {
1894            let (left, right) = layout.split(axis, mid);
1895            let (left_offsets, left_layout) = left;
1896            let (right_offsets, right_layout) = right;
1897
1898            assert_eq!(left_layout.strides(), layout.strides());
1899            assert_eq!(right_layout.strides(), layout.strides());
1900
1901            assert_eq!(left_offsets.len(), left_layout.min_data_len());
1902            assert_eq!(right_offsets.len(), right_layout.min_data_len());
1903
1904            let orig_len = layout.min_data_len();
1905            assert!(left_offsets.start <= orig_len && left_offsets.end <= orig_len);
1906            assert!(right_offsets.start <= orig_len && right_offsets.end <= orig_len);
1907
1908            for i in 0..layout.ndim() {
1909                assert_eq!(
1910                    left_layout.size(i),
1911                    if i == axis { mid } else { layout.size(i) }
1912                );
1913                assert_eq!(
1914                    right_layout.size(i),
1915                    if i == axis {
1916                        layout.size(i) - mid
1917                    } else {
1918                        layout.size(i)
1919                    }
1920                );
1921            }
1922        }
1923
1924        cases.test_each(|case| {
1925            let Case {
1926                shape,
1927                strides,
1928                axis,
1929                mid,
1930            } = case;
1931
1932            let layout = if let Some(strides) = strides {
1933                NdLayout::from_shape_and_strides(*shape, *strides, OverlapPolicy::AllowOverlap)
1934                    .unwrap()
1935            } else {
1936                NdLayout::from_shape(*shape)
1937            };
1938            let dyn_layout = if let Some(strides) = strides {
1939                DynLayout::from_shape_and_strides(
1940                    shape.as_slice(),
1941                    strides.as_slice(),
1942                    OverlapPolicy::AllowOverlap,
1943                )
1944                .unwrap()
1945            } else {
1946                DynLayout::from_shape(shape.as_slice())
1947            };
1948
1949            check_split(layout, *axis, *mid);
1950            check_split(dyn_layout, *axis, *mid);
1951        })
1952    }
1953
1954    #[test]
1955    fn test_merge_axes() {
1956        #[derive(Debug)]
1957        struct Case<'a> {
1958            shape: &'a [usize],
1959            strides: &'a [usize],
1960            merged_shape: &'a [usize],
1961            merged_strides: &'a [usize],
1962        }
1963
1964        let cases = [
1965            // Empty shape
1966            Case {
1967                shape: &[],
1968                strides: &[],
1969                merged_shape: &[],
1970                merged_strides: &[],
1971            },
1972            // Vector
1973            Case {
1974                shape: &[10],
1975                strides: &[2],
1976                merged_shape: &[10],
1977                merged_strides: &[2],
1978            },
1979            // Simple contiguous layout
1980            Case {
1981                shape: &[10, 10],
1982                strides: &[10, 1],
1983                merged_shape: &[100],
1984                merged_strides: &[1],
1985            },
1986            // Transposed matrix
1987            Case {
1988                shape: &[10, 10],
1989                strides: &[1, 10],
1990                merged_shape: &[10, 10],
1991                merged_strides: &[1, 10],
1992            },
1993            // Leading 1-sized dims
1994            Case {
1995                shape: &[1, 10, 10],
1996                strides: &[10, 1, 10],
1997                merged_shape: &[10, 10],
1998                merged_strides: &[1, 10],
1999            },
2000            // Inner 1-sized dims
2001            Case {
2002                shape: &[2, 1, 1, 2],
2003                strides: &[2, 2, 2, 1],
2004                merged_shape: &[4],
2005                merged_strides: &[1],
2006            },
2007            // Inner 1-sized dims that have been shifted over from the left,
2008            // ie. where the 1-sized dims where inserted at the left and then
2009            // shifted over to the middle.
2010            Case {
2011                shape: &[2, 1, 1, 2],
2012                strides: &[2, 4, 4, 1],
2013                merged_shape: &[4],
2014                merged_strides: &[1],
2015            },
2016        ];
2017
2018        cases.test_each(|case| {
2019            let mut layout = DynLayout::from_shape_and_strides(
2020                case.shape,
2021                case.strides,
2022                OverlapPolicy::AllowOverlap,
2023            )
2024            .unwrap();
2025            layout.merge_axes();
2026            assert_eq!(layout.shape(), case.merged_shape);
2027            assert_eq!(layout.strides(), case.merged_strides);
2028        })
2029    }
2030}