1use std::iter::repeat;
4use std::ops::Range;
5
6use smallvec::{SmallVec, smallvec};
7
8use crate::errors::{DimensionError, ExpandError, FromDataError, ReshapeError, SliceError};
9use crate::index_iterator::{DynIndices, NdIndices};
10use crate::overlap::{is_contiguous, may_have_internal_overlap};
11use crate::slice_range::{IntoSliceItems, SliceItem};
12use crate::type_num::{OptionalUInt, U0, U1, U2, U3, U4, U5, Unknown};
13
/// Return true if `permutation` has exactly `ndim` entries and contains every
/// dimension index in `0..ndim` exactly once.
pub fn is_valid_permutation(ndim: usize, permutation: &[usize]) -> bool {
    if permutation.len() != ndim {
        return false;
    }
    for dim in 0..ndim {
        let occurrences = permutation.iter().filter(|&&d| d == dim).count();
        if occurrences != 1 {
            return false;
        }
    }
    true
}
20
21pub(crate) fn merge_axes(shape: &[usize], strides: &[usize]) -> SmallVec<[(usize, usize); 4]> {
28 let (Some(prev_size), Some(prev_stride)) = (shape.last(), strides.last()) else {
29 return SmallVec::new();
30 };
31
32 let mut merged: SmallVec<[(usize, usize); 4]> = SmallVec::with_capacity(shape.len());
33 merged.push((*prev_size, *prev_stride));
34
35 for (&outer_size, &outer_stride) in shape.iter().zip(strides.iter()).rev().skip(1) {
36 let (inner_size, inner_stride) = merged.last_mut().unwrap();
37 let can_merge = outer_size == 1 || (outer_stride == *inner_stride * *inner_size);
38 if can_merge {
39 *inner_size *= outer_size;
40 } else {
41 merged.push((outer_size, outer_stride));
42 }
43 }
44
45 merged.reverse();
46
47 merged
48}
49
/// Debug check that `$axis` is a valid dimension index for `$target`, i.e. that
/// it is less than `$target.ndim()`. Like `debug_assert!`, the check is only
/// evaluated when debug assertions are enabled.
macro_rules! debug_assert_dim_valid {
    ($target:ident, $axis:expr) => {
        debug_assert!(
            $axis < $target.ndim(),
            "dim {} out of bounds for tensor with {} dims",
            $axis,
            $target.ndim()
        )
    };
}
61
62pub trait Layout {
73 type Index<'a>: AsRef<[usize]> + Clone + std::fmt::Debug + PartialEq<Self::Index<'a>>;
78
79 type Indices;
81
82 fn offset_unchecked(&self, index: Self::Index<'_>) -> usize {
89 index
90 .as_ref()
91 .iter()
92 .zip(self.strides().as_ref())
93 .map(|(idx, stride)| *idx * *stride)
94 .sum()
95 }
96
97 fn offset(&self, index: Self::Index<'_>) -> Option<usize>;
105
106 fn ndim(&self) -> usize;
108
109 fn len(&self) -> usize;
111
112 fn is_contiguous(&self) -> bool {
115 is_contiguous(self.shape(), self.strides())
116 }
117
118 fn is_broadcast(&self) -> bool {
121 !self.is_empty() && self.strides().as_ref().contains(&0)
122 }
123
124 fn is_empty(&self) -> bool {
126 self.len() == 0
127 }
128
129 fn shape(&self) -> Self::Index<'_>;
131
132 fn size(&self, dim: usize) -> usize {
134 debug_assert_dim_valid!(self, dim);
135 self.shape().as_ref()[dim]
136 }
137
138 fn strides(&self) -> Self::Index<'_>;
140
141 fn stride(&self, dim: usize) -> usize {
143 debug_assert_dim_valid!(self, dim);
144 self.strides().as_ref()[dim]
145 }
146
147 fn indices(&self) -> Self::Indices;
149
150 fn can_broadcast_to(&self, target_shape: &[usize]) -> bool {
152 if self.shape().as_ref() == target_shape {
153 return true;
154 } else if self.ndim() > target_shape.len() {
155 return false;
156 }
157
158 let target_dims = target_shape[target_shape.len() - self.shape().as_ref().len()..]
164 .iter()
165 .copied();
166
167 self.shape()
168 .as_ref()
169 .iter()
170 .copied()
171 .zip(target_dims)
172 .all(|(a, b)| a == b || a == 1)
173 }
174
175 fn can_broadcast_with(&self, shape: &[usize]) -> bool {
185 if self.shape().as_ref() == shape {
186 return true;
187 }
188
189 let current_shape = self.shape();
196 let a = current_shape.as_ref();
197 let b = shape;
198
199 let a_pad = b.len().saturating_sub(a.len());
200 let b_pad = a.len().saturating_sub(b.len());
201
202 let a_iter = a.iter().copied().rev().chain(repeat(1).take(a_pad));
203 let b_iter = b.iter().copied().rev().chain(repeat(1).take(b_pad));
204
205 a_iter.zip(b_iter).all(|(a, b)| a == b || a == 1 || b == 1)
206 }
207
208 fn min_data_len(&self) -> usize {
211 if self.shape().as_ref().contains(&0) {
212 return 0;
213 }
214 let max_offset: usize = self
215 .shape()
216 .as_ref()
217 .iter()
218 .zip(self.strides().as_ref())
219 .map(|(size, stride)| (size - 1) * stride)
220 .sum();
221 max_offset + 1
222 }
223}
224
/// Marker trait for [`Layout`] implementations that unsafe code is allowed to
/// trust.
///
/// # Safety
///
/// NOTE(review): the precise contract is not visible in this chunk. Presumably
/// implementors must report a shape, strides and offsets that are mutually
/// consistent (e.g. every offset returned for a valid index is less than
/// `min_data_len()`), so that unchecked buffer accesses based on them are
/// sound. Confirm this before relying on it.
pub unsafe trait TrustedLayout: Layout {}
236
237pub(crate) trait LayoutExt: Layout {
242 #[inline]
244 fn must_offset(&self, index: Self::Index<'_>) -> usize {
245 self.offset(index.clone()).unwrap_or_else(|| {
246 panic!(
247 "index {:?} out of bounds for shape {:?}",
248 index.as_ref(),
249 self.shape().as_ref()
250 )
251 })
252 }
253}
254
255impl<L: Layout> LayoutExt for L {}
256
/// Accessors for two-dimensional layouts, where dimension 0 is rows and
/// dimension 1 is columns (as in the `NdLayout<2>` impl in this file).
pub trait MatrixLayout {
    /// Number of rows (size of dimension 0).
    fn rows(&self) -> usize;
    /// Number of columns (size of dimension 1).
    fn cols(&self) -> usize;
    /// Offset step between consecutive rows (stride of dimension 0).
    fn row_stride(&self) -> usize;
    /// Offset step between consecutive columns (stride of dimension 1).
    fn col_stride(&self) -> usize;
}
264
/// Whether a layout built from explicit shape and strides may map more than one
/// index to the same buffer offset.
///
/// Derives the standard value-type traits so a policy can be copied, compared
/// and printed in diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OverlapPolicy {
    /// Accept the shape and strides as given, without an overlap check (e.g.
    /// zero strides used for broadcasting).
    AllowOverlap,
    /// Reject layouts that may have internal overlap, returning
    /// `FromDataError::MayOverlap`.
    DisallowOverlap,
}
274
/// Layout whose number of dimensions `N` is fixed at compile time.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct NdLayout<const N: usize> {
    /// Size of each dimension.
    shape: [usize; N],
    /// Buffer offset step of each dimension.
    strides: [usize; N],
}
282
283impl<const N: usize> Layout for NdLayout<N> {
284 type Index<'a> = [usize; N];
285 type Indices = NdIndices<N>;
286
287 fn ndim(&self) -> usize {
288 N
289 }
290
291 fn len(&self) -> usize {
292 self.shape.iter().product()
293 }
294
295 #[inline]
296 fn offset(&self, index: [usize; N]) -> Option<usize> {
297 if !self.index_valid(index) {
298 return None;
299 }
300 Some(self.offset_unchecked(index))
301 }
302
303 #[inline]
304 fn offset_unchecked(&self, index: [usize; N]) -> usize {
305 let mut offset = 0;
306 for i in 0..N {
307 offset += index[i] * self.strides[i];
308 }
309 offset
310 }
311
312 #[inline]
313 fn shape(&self) -> Self::Index<'_> {
314 self.shape
315 }
316
317 #[inline]
318 fn strides(&self) -> Self::Index<'_> {
319 self.strides
320 }
321
322 fn indices(&self) -> Self::Indices {
323 NdIndices::from_shape(self.shape)
324 }
325}
326
// SAFETY: NOTE(review): the `TrustedLayout` contract is not visible here.
// `NdLayout`'s `Layout` impl derives `offset`, `shape` and `strides` directly
// from its private `shape`/`strides` fields, and `offset` bounds-checks via
// `index_valid` — presumably that is what makes it trustworthy; confirm
// against the trait's safety docs.
unsafe impl<const N: usize> TrustedLayout for NdLayout<N> {}
328
329impl<L: Layout> Layout for &L {
330 type Index<'b> = L::Index<'b>;
331 type Indices = L::Indices;
332
333 fn ndim(&self) -> usize {
334 (*self).ndim()
335 }
336
337 fn len(&self) -> usize {
338 (*self).len()
339 }
340
341 fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
342 (*self).offset(index)
343 }
344
345 fn offset_unchecked(&self, index: Self::Index<'_>) -> usize {
346 (*self).offset_unchecked(index)
347 }
348
349 fn shape(&self) -> Self::Index<'_> {
350 (*self).shape()
351 }
352
353 fn strides(&self) -> Self::Index<'_> {
354 (*self).strides()
355 }
356
357 fn indices(&self) -> Self::Indices {
358 (*self).indices()
359 }
360}
361
// SAFETY: the `Layout` impl for `&L` forwards every method unchanged to `L`,
// so a reference to a trusted layout reports exactly what the layout itself
// reports.
unsafe impl<L: TrustedLayout> TrustedLayout for &L {}
365
366impl MatrixLayout for NdLayout<2> {
367 #[inline]
368 fn rows(&self) -> usize {
369 self.size(0)
370 }
371
372 #[inline]
373 fn cols(&self) -> usize {
374 self.size(1)
375 }
376
377 #[inline]
378 fn row_stride(&self) -> usize {
379 self.stride(0)
380 }
381
382 #[inline]
383 fn col_stride(&self) -> usize {
384 self.stride(1)
385 }
386}
387
388fn slice_layout<I: AsRef<[usize]>, O: AsMut<[usize]>>(
396 in_shape: I,
397 in_strides: I,
398 mut out_shape: O,
399 mut out_strides: O,
400 range: &[SliceItem],
401) -> Result<(usize, usize), SliceError> {
402 let in_shape = in_shape.as_ref();
403 let in_strides = in_strides.as_ref();
404 let out_shape = out_shape.as_mut();
405 let out_strides = out_strides.as_mut();
406
407 let mut ndim = 0;
408 let mut offset = 0;
409
410 for (in_dim, (&size, &stride)) in in_shape.iter().zip(in_strides.iter()).enumerate() {
411 let (offset_adjust, new_size_stride) = match range.get(in_dim) {
412 Some(&SliceItem::Index(idx)) => {
413 let pos_idx = if idx >= 0 { idx } else { idx + size as isize };
414 if pos_idx < 0 || pos_idx >= size as isize {
415 return Err(SliceError::InvalidIndex {
416 axis: in_dim,
417 index: idx,
418 size,
419 });
420 }
421 (stride * pos_idx as usize, None)
422 }
423 Some(SliceItem::Range(range)) => {
424 let resolved = range.resolve(size).ok_or(SliceError::InvalidRange {
425 axis: in_dim,
426 range: *range,
427 size,
428 })?;
429 let step: usize = range
430 .step()
431 .try_into()
432 .map_err(|_| SliceError::InvalidStep {
433 axis: in_dim,
434 step: range.step(),
435 })?;
436 let new_size = if step == 1 {
437 resolved.end - resolved.start
439 } else {
440 range.index_range(size).steps()
441 };
442 let new_stride = stride * step;
443 (stride * resolved.start, Some((new_size, new_stride)))
444 }
445 None => (0, Some((size, stride))),
446 };
447
448 offset += offset_adjust;
449 if let Some((new_size, new_stride)) = new_size_stride {
450 out_shape[ndim] = new_size;
451 out_strides[ndim] = new_stride;
452 ndim += 1;
453 }
454 }
455
456 if out_shape.contains(&0) {
457 offset = 0;
458 }
459
460 Ok((ndim, offset))
461}
462
/// Return the strides of a layout with shape `from_shape` and strides
/// `from_strides` after broadcasting it to `to_shape`.
///
/// Leading dimensions added by the broadcast get stride 0. A size-1 dimension
/// stretched to a larger target size also gets stride 0, so every position
/// along it reads the same element. All other strides are kept.
///
/// Callers must ensure `to_shape` has at least as many dimensions as
/// `from_shape`; otherwise the padding computation underflows.
fn broadcast_strides<'a>(
    from_shape: &'a [usize],
    from_strides: &'a [usize],
    to_shape: &'a [usize],
) -> impl Iterator<Item = usize> + 'a {
    let pad = to_shape.len() - from_shape.len();
    let trailing_target = &to_shape[pad..];

    let leading = repeat(0).take(pad);
    let existing = from_shape
        .iter()
        .zip(from_strides)
        .zip(trailing_target)
        .map(|((&size, &stride), &target)| {
            if size == 1 && target > 1 {
                0
            } else {
                stride
            }
        });

    leading.chain(existing)
}
483
484impl<const N: usize> NdLayout<N> {
485 pub fn as_dyn(&self) -> DynLayout {
487 self.into()
488 }
489
490 fn index_valid(&self, index: [usize; N]) -> bool {
492 let mut valid = true;
493 for i in 0..N {
494 valid = valid && index[i] < self.shape[i]
495 }
496 valid
497 }
498
499 fn contiguous_strides(shape: [usize; N]) -> [usize; N] {
502 let mut strides = [0; N];
503 for i in 0..N {
504 strides[i] = shape[i + 1..].iter().product();
505 }
506 strides
507 }
508}
509
510impl<'a, const N: usize> TryFrom<&'a DynLayout> for NdLayout<N> {
511 type Error = DimensionError;
512
513 fn try_from(value: &'a DynLayout) -> Result<NdLayout<N>, DimensionError> {
516 let shape = value.shape();
517 let shape: [usize; N] = shape.try_into().map_err(|_| DimensionError {
518 actual: shape.len(),
519 expected: N,
520 })?;
521 let strides = value.strides();
522 let strides: [usize; N] = strides.try_into().map_err(|_| DimensionError {
523 actual: strides.len(),
524 expected: N,
525 })?;
526 Ok(NdLayout { shape, strides })
527 }
528}
529
/// Layout whose number of dimensions is determined at runtime.
#[derive(Debug, PartialEq)]
pub struct DynLayout {
    /// Shape followed by strides in a single buffer. The first `ndim` entries
    /// are the dimension sizes and the remaining `ndim` entries are the strides,
    /// so `ndim` is half the buffer length.
    shape_and_strides: SmallVec<[usize; 8]>,
}
544
545impl Clone for DynLayout {
546 fn clone(&self) -> DynLayout {
547 DynLayout {
548 shape_and_strides: SmallVec::from_slice(self.shape_and_strides.as_slice()),
552 }
553 }
554}
555
556impl Layout for DynLayout {
557 type Index<'a> = &'a [usize];
558 type Indices = DynIndices;
559
560 fn len(&self) -> usize {
562 self.shape().iter().product()
563 }
564
565 #[inline]
566 fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
567 let shape = self.shape();
568 let strides = self.strides();
569 let mut valid = index.as_ref().len() == shape.len();
570 let mut offset = 0;
571 for (idx, (size, stride)) in index.as_ref().iter().zip(shape.iter().zip(strides.iter())) {
572 valid = valid && idx < size;
573 offset += idx * stride;
574 }
575 valid.then_some(offset)
576 }
577
578 fn is_empty(&self) -> bool {
579 self.len() == 0
580 }
581
582 #[inline]
584 fn ndim(&self) -> usize {
585 self.shape_and_strides.len() / 2
586 }
587
588 #[inline]
590 fn shape(&self) -> &[usize] {
591 &self.shape_and_strides[0..self.ndim()]
592 }
593
594 #[inline]
596 fn size(&self, dim: usize) -> usize {
597 debug_assert_dim_valid!(self, dim);
598 self.shape_and_strides[dim]
599 }
600
601 #[inline]
603 fn strides(&self) -> &[usize] {
604 &self.shape_and_strides[self.ndim()..]
605 }
606
607 #[inline]
609 fn stride(&self, dim: usize) -> usize {
610 debug_assert_dim_valid!(self, dim);
611 self.shape_and_strides[self.ndim() + dim]
612 }
613
614 fn indices(&self) -> DynIndices {
615 DynIndices::from_shape(self.shape())
616 }
617}
618
// SAFETY: NOTE(review): the `TrustedLayout` contract is not visible here.
// `DynLayout`'s `shape`/`strides` are the two halves of its private
// `shape_and_strides` buffer, and `offset` bounds-checks every coordinate —
// presumably that is what makes it trustworthy; confirm against the trait's
// safety docs.
unsafe impl TrustedLayout for DynLayout {}
620
621impl DynLayout {
622 pub fn make_contiguous(&mut self) {
623 self.shape_and_strides = Self::contiguous_shape_and_strides(self.shape());
624 }
625
626 fn permute_iter<I: Clone + Iterator<Item = usize>>(&mut self, dims: I) {
627 let strides = self.strides();
628 let shape = self.shape();
629 let shape_iter = dims.clone().map(|dim| shape[dim]);
630 let stride_iter = dims.map(|dim| strides[dim]);
631 self.shape_and_strides = shape_iter.chain(stride_iter).collect();
632 }
633
634 fn permute(&mut self, dims: &[usize]) {
637 assert!(
638 is_valid_permutation(self.ndim(), dims),
639 "permutation is invalid"
640 );
641 self.permute_iter(dims.iter().copied());
642 }
643
644 fn transpose(&mut self) {
646 self.permute_iter((0..self.ndim()).rev());
647 }
648
649 fn contiguous_shape_and_strides(shape: &[usize]) -> SmallVec<[usize; 8]> {
651 let mut strides_and_shape: SmallVec<[usize; 8]> = SmallVec::from_slice(shape);
652 strides_and_shape.resize(shape.len() * 2, 0);
653 let mut stride = 1;
654 for i in (0..shape.len()).rev() {
655 strides_and_shape[shape.len() + i] = stride;
656 stride *= shape[i];
657 }
658 strides_and_shape
659 }
660}
661
662impl<L: Layout> From<&L> for DynLayout {
663 fn from(layout: &L) -> DynLayout {
664 DynLayout::from_shape_and_strides(
665 layout.shape().as_ref(),
666 layout.strides().as_ref(),
667 OverlapPolicy::AllowOverlap,
668 )
669 .expect("invalid layout")
670 }
671}
672
673impl<const N: usize> From<NdLayout<N>> for DynLayout {
674 fn from(value: NdLayout<N>) -> DynLayout {
675 Self::from(&value)
676 }
677}
678
679pub trait MutLayout: Layout + Clone {
700 fn from_shape(shape: Self::Index<'_>) -> Self;
702
703 fn from_shape_and_strides(
710 shape: Self::Index<'_>,
711 strides: Self::Index<'_>,
712 overlap: OverlapPolicy,
713 ) -> Result<Self, FromDataError>;
714
715 fn index_axis(&self, axis: usize, index: usize) -> (Range<usize>, <Self as RemoveDim>::Output)
719 where
720 Self: RemoveDim,
721 {
722 assert!(axis < self.ndim());
723 assert!(index < self.size(axis));
724
725 let layout = self.remove_dim(axis);
726 let start_offset = self.stride(axis) * index;
727
728 (start_offset..start_offset + layout.min_data_len(), layout)
729 }
730
731 fn move_axis(&mut self, from: usize, to: usize);
733
734 fn permuted(&self, order: Self::Index<'_>) -> Self;
736
737 fn reshaped_for_view<S: IntoLayout>(&self, shape: S) -> Result<S::Layout, ReshapeError> {
743 if !self.is_contiguous() {
744 return Err(ReshapeError::NotContiguous);
745 }
746 self.reshaped_for_copy(shape)
747 }
748
749 fn reshaped_for_copy<S: IntoLayout>(&self, shape: S) -> Result<S::Layout, ReshapeError> {
751 let layout = shape.into_layout();
752 if layout.len() != self.len() {
753 return Err(ReshapeError::LengthMismatch);
754 }
755 Ok(layout)
756 }
757
758 fn resize_dim(&mut self, dim: usize, size: usize);
760
761 fn transposed(&self) -> Self;
764
765 fn slice<const M: usize>(
769 &self,
770 range: &[SliceItem],
771 ) -> Result<(Range<usize>, NdLayout<M>), SliceError>;
772
773 fn slice_dyn(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError>;
777
778 fn slice_axis(
782 &self,
783 axis: usize,
784 range: Range<usize>,
785 ) -> Result<(Range<usize>, Self), SliceError> {
786 if axis >= self.ndim() {
787 return Err(SliceError::InvalidAxis { axis });
788 }
789 if range.end < range.start || range.end > self.size(axis) {
790 return Err(SliceError::InvalidRange {
791 axis,
792 range: range.into(),
793 size: self.size(axis),
794 });
795 }
796
797 let mut sliced_layout = self.clone();
798 sliced_layout.resize_dim(axis, range.len());
799 let range = if sliced_layout.is_empty() {
800 0..0
801 } else {
802 let start_offset = range.start * sliced_layout.stride(axis);
803 let end_offset = start_offset + sliced_layout.min_data_len();
804 start_offset..end_offset
805 };
806 Ok((range, sliced_layout))
807 }
808
809 fn squeezed(&self) -> DynLayout;
811
812 fn split(&self, axis: usize, mid: usize) -> ((Range<usize>, Self), (Range<usize>, Self));
817}
818
/// Layouts that can be broadcast to a target shape, producing a layout of type `L`.
pub trait BroadcastLayout<L: Layout> {
    /// Broadcast this layout to `shape`.
    ///
    /// The implementations in this file return `ExpandError::ShapeMismatch` if
    /// this layout cannot be broadcast to `shape`.
    fn broadcast<S: IntoLayout<Layout = L>>(&self, shape: S) -> Result<L, ExpandError>;
}
824
825impl<const N: usize, const M: usize> BroadcastLayout<NdLayout<M>> for NdLayout<N> {
826 fn broadcast<S: IntoLayout<Layout = NdLayout<M>>>(
827 &self,
828 shape: S,
829 ) -> Result<NdLayout<M>, ExpandError> {
830 let shape: [usize; M] = shape.as_ref().try_into().unwrap();
831 if !self.can_broadcast_to(&shape) {
832 return Err(ExpandError::ShapeMismatch);
833 }
834 let mut strides = [0usize; M];
835 for (i, stride) in broadcast_strides(&self.shape(), &self.strides(), &shape).enumerate() {
836 strides[i] = stride;
837 }
838
839 Ok(NdLayout { shape, strides })
840 }
841}
842
843impl<const N: usize> BroadcastLayout<DynLayout> for NdLayout<N> {
844 fn broadcast<S: IntoLayout<Layout = DynLayout>>(
845 &self,
846 shape: S,
847 ) -> Result<DynLayout, ExpandError> {
848 let dyn_layout: DynLayout = self.into();
849 dyn_layout.broadcast(shape.as_ref())
850 }
851}
852
853impl BroadcastLayout<DynLayout> for DynLayout {
854 fn broadcast<S: IntoLayout<Layout = DynLayout>>(
855 &self,
856 shape: S,
857 ) -> Result<DynLayout, ExpandError> {
858 let to_shape = shape.as_ref();
859
860 if !self.can_broadcast_to(to_shape) {
861 return Err(ExpandError::ShapeMismatch);
862 }
863
864 let mut shape_and_strides = SmallVec::with_capacity(to_shape.len() * 2);
865 shape_and_strides.extend(to_shape.iter().copied());
866 shape_and_strides.extend(broadcast_strides(self.shape(), self.strides(), to_shape));
867
868 Ok(DynLayout { shape_and_strides })
869 }
870}
871
872impl<const N: usize> BroadcastLayout<NdLayout<N>> for DynLayout {
873 fn broadcast<S: IntoLayout<Layout = NdLayout<N>>>(
874 &self,
875 shape: S,
876 ) -> Result<NdLayout<N>, ExpandError> {
877 let dyn_broadcast = self.broadcast(shape.as_ref())?;
878 let layout = (&dyn_broadcast)
879 .try_into()
880 .map_err(|_| ExpandError::ShapeMismatch)?;
881 Ok(layout)
882 }
883}
884
885impl<const N: usize> MutLayout for NdLayout<N> {
886 fn from_shape(shape: [usize; N]) -> Self {
887 Self {
888 shape,
889 strides: Self::contiguous_strides(shape),
890 }
891 }
892
893 fn from_shape_and_strides(
894 shape: Self::Index<'_>,
895 strides: Self::Index<'_>,
896 overlap: OverlapPolicy,
897 ) -> Result<Self, FromDataError> {
898 let layout = NdLayout { shape, strides };
899
900 match overlap {
901 OverlapPolicy::DisallowOverlap => {
902 if may_have_internal_overlap(&layout.shape, &layout.strides) {
903 return Err(FromDataError::MayOverlap);
904 }
905 }
906 OverlapPolicy::AllowOverlap => {}
907 }
908
909 Ok(layout)
910 }
911
912 fn move_axis(&mut self, from: usize, to: usize) {
913 assert!(from < N && to < N);
914 let mut dyn_layout = self.as_dyn();
915 dyn_layout.move_axis(from, to);
916 *self = NdLayout::try_from(&dyn_layout).unwrap();
917 }
918
919 fn permuted(&self, dims: [usize; N]) -> NdLayout<N> {
920 assert!(is_valid_permutation(N, &dims), "permutation is invalid");
921 let mut shape = [0; N];
922 let mut strides = [0; N];
923 for i in 0..N {
924 shape[i] = self.shape[dims[i]];
925 strides[i] = self.strides[dims[i]];
926 }
927 NdLayout { shape, strides }
928 }
929
930 fn resize_dim(&mut self, dim: usize, size: usize) {
931 self.shape[dim] = size;
932 }
933
934 fn transposed(&self) -> NdLayout<N> {
935 let dims = std::array::from_fn(|i| N - i - 1);
936 self.permuted(dims)
937 }
938
939 fn slice<const M: usize>(
940 &self,
941 range: &[SliceItem],
942 ) -> Result<(Range<usize>, NdLayout<M>), SliceError> {
943 if self.ndim() < range.len() {
944 return Err(SliceError::TooManyDims {
945 ndim: self.ndim(),
946 range_ndim: range.len(),
947 });
948 }
949
950 let mut shape: [usize; M] = [0; M];
951 let mut strides: [usize; M] = [0; M];
952
953 let (ndim, offset) =
954 slice_layout(&self.shape, &self.strides, &mut shape, &mut strides, range)?;
955
956 if ndim != M {
957 return Err(SliceError::OutputDimsMismatch {
958 actual: ndim,
959 expected: M,
960 });
961 }
962
963 let layout = NdLayout { shape, strides };
964 Ok((offset..offset + layout.min_data_len(), layout))
965 }
966
967 fn slice_dyn(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError> {
968 self.as_dyn().slice_dyn(range)
969 }
970
971 fn squeezed(&self) -> DynLayout {
972 self.as_dyn().squeezed()
973 }
974
975 fn split(&self, axis: usize, mid: usize) -> ((Range<usize>, Self), (Range<usize>, Self)) {
976 assert!(axis < self.ndim());
977 assert!(mid <= self.size(axis));
978
979 let left_shape = std::array::from_fn(|i| if i == axis { mid } else { self.shape[i] });
980 let right_shape = std::array::from_fn(|i| {
981 if i == axis {
982 self.size(axis) - mid
983 } else {
984 self.shape[i]
985 }
986 });
987
988 let left = NdLayout {
989 shape: left_shape,
990 strides: self.strides,
991 };
992 let right = NdLayout {
993 shape: right_shape,
994 strides: self.strides,
995 };
996
997 let mid_offset = mid * self.strides[axis];
998 let left_offsets = 0..left.min_data_len();
999 let end_offset = self.min_data_len();
1000
1001 let right_offsets = if right.is_empty() {
1002 end_offset..end_offset
1003 } else {
1004 mid_offset..end_offset
1005 };
1006
1007 ((left_offsets, left), (right_offsets, right))
1008 }
1009}
1010
1011impl MutLayout for DynLayout {
1012 fn from_shape(shape: &[usize]) -> Self {
1013 DynLayout {
1014 shape_and_strides: Self::contiguous_shape_and_strides(shape),
1015 }
1016 }
1017
1018 fn from_shape_and_strides(
1019 shape: &[usize],
1020 strides: &[usize],
1021 overlap: OverlapPolicy,
1022 ) -> Result<Self, FromDataError> {
1023 let mut shape_and_strides = SmallVec::with_capacity(shape.len() + strides.len());
1024 shape_and_strides.extend_from_slice(shape);
1025 shape_and_strides.extend_from_slice(strides);
1026 let layout = DynLayout { shape_and_strides };
1027
1028 match overlap {
1029 OverlapPolicy::DisallowOverlap => {
1030 if may_have_internal_overlap(layout.shape(), layout.strides()) {
1031 return Err(FromDataError::MayOverlap);
1032 }
1033 }
1034 OverlapPolicy::AllowOverlap => {}
1035 }
1036
1037 Ok(layout)
1038 }
1039
1040 fn move_axis(&mut self, from: usize, to: usize) {
1041 let ndim = self.ndim();
1042 assert!(from < ndim && to < ndim);
1043
1044 let size = self.shape_and_strides.remove(from);
1045 let stride = self.shape_and_strides.remove(ndim - 1 + from);
1046 self.shape_and_strides.insert(to, size);
1047 self.shape_and_strides.insert(ndim + to, stride);
1048 }
1049
1050 fn permuted(&self, order: &[usize]) -> DynLayout {
1051 let mut permuted = self.clone();
1052 permuted.permute(order);
1053 permuted
1054 }
1055
1056 fn resize_dim(&mut self, dim: usize, size: usize) {
1057 self.shape_and_strides[dim] = size;
1058 }
1059
1060 fn transposed(&self) -> DynLayout {
1061 let mut transposed = self.clone();
1062 transposed.transpose();
1063 transposed
1064 }
1065
1066 fn slice<const M: usize>(
1067 &self,
1068 range: &[SliceItem],
1069 ) -> Result<(Range<usize>, NdLayout<M>), SliceError> {
1070 let (offset_range, dyn_layout) = self.slice_dyn(range)?;
1071 let nd_layout =
1072 NdLayout::try_from(&dyn_layout).map_err(|_| SliceError::OutputDimsMismatch {
1073 actual: dyn_layout.ndim(),
1074 expected: M,
1075 })?;
1076 Ok((offset_range, nd_layout))
1077 }
1078
1079 fn slice_dyn(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError> {
1080 if self.ndim() < range.len() {
1081 return Err(SliceError::TooManyDims {
1082 ndim: self.ndim(),
1083 range_ndim: range.len(),
1084 });
1085 }
1086
1087 let out_dims = self.ndim()
1088 - range
1089 .iter()
1090 .filter(|item| matches!(item, SliceItem::Index(_)))
1091 .count();
1092 let mut shape_and_strides = smallvec![0; out_dims * 2];
1093 let (out_shape, out_strides) = shape_and_strides.as_mut_slice().split_at_mut(out_dims);
1094
1095 let (_ndim, offset) =
1096 slice_layout(self.shape(), self.strides(), out_shape, out_strides, range)?;
1097
1098 let layout = Self { shape_and_strides };
1099 Ok((offset..offset + layout.min_data_len(), layout))
1100 }
1101
1102 fn squeezed(&self) -> DynLayout {
1103 let shape = self.shape().iter().copied().filter(|&size| size != 1);
1104 let strides = self
1105 .shape()
1106 .iter()
1107 .zip(self.strides())
1108 .filter_map(|(&size, &stride)| if size != 1 { Some(stride) } else { None });
1109 DynLayout {
1110 shape_and_strides: shape.chain(strides).collect(),
1111 }
1112 }
1113
1114 fn split(&self, axis: usize, mid: usize) -> ((Range<usize>, Self), (Range<usize>, Self)) {
1115 assert!(axis < self.ndim());
1116 assert!(mid <= self.size(axis));
1117
1118 let mut left_shape_strides: SmallVec<[usize; 8]> = (0..self.ndim())
1119 .map(|i| if i == axis { mid } else { self.size(i) })
1120 .collect();
1121 left_shape_strides.extend(self.strides().iter().copied());
1122
1123 let mut right_shape_strides: SmallVec<[usize; 8]> = (0..self.ndim())
1124 .map(|i| {
1125 if i == axis {
1126 self.size(axis) - mid
1127 } else {
1128 self.size(i)
1129 }
1130 })
1131 .collect();
1132 right_shape_strides.extend(self.strides().iter().copied());
1133
1134 let left = DynLayout {
1135 shape_and_strides: left_shape_strides,
1136 };
1137 let right = DynLayout {
1138 shape_and_strides: right_shape_strides,
1139 };
1140
1141 let mid_offset = mid * self.stride(axis);
1142 let left_offsets = 0..left.min_data_len();
1143 let end_offset = self.min_data_len();
1144
1145 let right_offsets = if right.is_empty() {
1146 end_offset..end_offset
1147 } else {
1148 mid_offset..end_offset
1149 };
1150
1151 ((left_offsets, left), (right_offsets, right))
1152 }
1153}
1154
1155pub trait IntoLayout: AsRef<[usize]> + std::fmt::Debug {
1160 type Layout: MutLayout;
1162
1163 fn into_layout(self) -> Self::Layout;
1165}
1166
1167impl<const N: usize> IntoLayout for [usize; N] {
1168 type Layout = NdLayout<N>;
1169
1170 #[inline]
1171 fn into_layout(self) -> NdLayout<N> {
1172 NdLayout::from_shape(self)
1173 }
1174}
1175
1176impl IntoLayout for &[usize] {
1177 type Layout = DynLayout;
1178
1179 #[inline]
1180 fn into_layout(self) -> DynLayout {
1181 DynLayout::from_shape(self)
1182 }
1183}
1184
1185pub trait ResizeLayout: MutLayout {
1191 fn insert_axis(&mut self, index: usize);
1194
1195 #[track_caller]
1202 fn remove_axis(&mut self, index: usize) {
1203 assert!(
1204 self.size(index) == 1,
1205 "cannot remove axis of size {}",
1206 self.size(index)
1207 );
1208 self.remove_axis_of_any_size(index)
1209 }
1210
1211 fn remove_axis_of_any_size(&mut self, index: usize);
1216
1217 fn merge_axes(&mut self);
1222}
1223
1224impl ResizeLayout for DynLayout {
1225 fn insert_axis(&mut self, index: usize) {
1226 let ndim = self.ndim();
1227 let new_size = 1;
1228
1229 let (max_stride, size_for_max_stride) = self
1233 .strides()
1234 .iter()
1235 .copied()
1236 .zip(self.shape().iter().copied())
1237 .max_by_key(|(stride, _size)| *stride)
1238 .unwrap_or((1, 1));
1239 let new_stride = max_stride * size_for_max_stride;
1240
1241 self.shape_and_strides.insert(index, new_size);
1242 self.shape_and_strides.insert(ndim + 1 + index, new_stride);
1243 }
1244
1245 fn remove_axis_of_any_size(&mut self, index: usize) {
1246 self.shape_and_strides.remove(index);
1247 self.shape_and_strides.remove(self.ndim() + index);
1248 }
1249
1250 fn merge_axes(&mut self) {
1251 let merged = merge_axes(self.shape(), self.strides());
1252 self.shape_and_strides = merged
1253 .iter()
1254 .map(|dim| dim.0)
1255 .chain(merged.iter().map(|dim| dim.1))
1256 .collect();
1257 }
1258}
1259
1260pub trait AsIndex<L: Layout> {
1266 fn as_index(&self) -> L::Index<'_>;
1268}
1269
1270impl<T: AsRef<[usize]>> AsIndex<DynLayout> for T {
1271 fn as_index(&self) -> &[usize] {
1272 self.as_ref()
1273 }
1274}
1275
1276impl<const N: usize> AsIndex<NdLayout<N>> for [usize; N] {
1277 fn as_index(&self) -> [usize; N] {
1278 *self
1279 }
1280}
1281
1282impl AsIndex<NdLayout<1>> for usize {
1283 fn as_index(&self) -> [usize; 1] {
1284 [*self]
1285 }
1286}
1287
1288pub trait RemoveDim {
1290 type Output: MutLayout;
1291
1292 fn remove_dim(&self, dim: usize) -> Self::Output;
1294}
1295
1296impl<R: RemoveDim> RemoveDim for &R {
1297 type Output = R::Output;
1298
1299 fn remove_dim(&self, dim: usize) -> Self::Output {
1300 (*self).remove_dim(dim)
1301 }
1302}
1303
1304impl RemoveDim for DynLayout {
1305 type Output = DynLayout;
1306
1307 fn remove_dim(&self, dim: usize) -> DynLayout {
1308 let ndim = self.ndim();
1309 assert!(ndim > 0, "cannot remove axis from tensor with 0 dims");
1310
1311 let shape = (0..ndim - 1).map(|i| {
1312 if i < dim {
1313 self.size(i)
1314 } else {
1315 self.size(i + 1)
1316 }
1317 });
1318 let strides = (0..ndim - 1).map(|i| {
1319 if i < dim {
1320 self.stride(i)
1321 } else {
1322 self.stride(i + 1)
1323 }
1324 });
1325 DynLayout {
1326 shape_and_strides: shape.chain(strides).collect(),
1327 }
1328 }
1329}
1330
1331macro_rules! impl_remove_dim {
1332 ($in_dims:expr, $out_dims:expr) => {
1333 impl RemoveDim for NdLayout<$in_dims> {
1334 type Output = NdLayout<$out_dims>;
1335
1336 fn remove_dim(&self, dim: usize) -> Self::Output {
1337 let shape = std::array::from_fn(|i| {
1338 if i < dim {
1339 self.shape[i]
1340 } else {
1341 self.shape[i + 1]
1342 }
1343 });
1344 let strides = std::array::from_fn(|i| {
1345 if i < dim {
1346 self.strides[i]
1347 } else {
1348 self.strides[i + 1]
1349 }
1350 });
1351 NdLayout { shape, strides }
1352 }
1353 }
1354 };
1355}
1356
1357impl_remove_dim!(1, 0);
1358impl_remove_dim!(2, 1);
1359impl_remove_dim!(3, 2);
1360impl_remove_dim!(4, 3);
1361impl_remove_dim!(5, 4);
1362
1363pub trait SliceWith<R: IntoSliceItems, IdxCount: OptionalUInt> {
1368 type Layout: MutLayout;
1370
1371 fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError>;
1377}
1378
1379impl<R: IntoSliceItems, L: MutLayout> SliceWith<R, Unknown> for L {
1380 type Layout = DynLayout;
1381
1382 fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
1383 self.slice_dyn(range.into_slice_items().as_ref())
1384 }
1385}
1386
1387impl<R: IntoSliceItems, const N: usize> SliceWith<R, U0> for NdLayout<N> {
1388 type Layout = NdLayout<N>;
1389
1390 fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
1391 self.slice(range.into_slice_items().as_ref())
1392 }
1393}
1394
1395macro_rules! impl_slice_with_dynlayout {
1396 ($range_ndim:ty) => {
1397 impl<R: IntoSliceItems> SliceWith<R, $range_ndim> for DynLayout {
1398 type Layout = DynLayout;
1399
1400 fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
1401 self.slice_dyn(range.into_slice_items().as_ref())
1402 }
1403 }
1404 };
1405}
1406
1407impl_slice_with_dynlayout!(U0);
1408impl_slice_with_dynlayout!(U1);
1409impl_slice_with_dynlayout!(U2);
1410impl_slice_with_dynlayout!(U3);
1411impl_slice_with_dynlayout!(U4);
1412impl_slice_with_dynlayout!(U5);
1413
1414macro_rules! impl_slice_with {
1415 ($ndim:literal, $range_ndim:ty, $out_ndim:literal) => {
1416 impl<R: IntoSliceItems> SliceWith<R, $range_ndim> for NdLayout<$ndim> {
1417 type Layout = NdLayout<$out_ndim>;
1418
1419 fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
1420 self.slice(range.into_slice_items().as_ref())
1421 }
1422 }
1423 };
1424}
1425
1426impl_slice_with!(1, U1, 0);
1427impl_slice_with!(2, U1, 1);
1428impl_slice_with!(2, U2, 0);
1429impl_slice_with!(3, U1, 2);
1430impl_slice_with!(3, U2, 1);
1431impl_slice_with!(3, U3, 0);
1432impl_slice_with!(4, U1, 3);
1433impl_slice_with!(4, U2, 2);
1434impl_slice_with!(4, U3, 1);
1435impl_slice_with!(4, U4, 0);
1436impl_slice_with!(5, U1, 4);
1437impl_slice_with!(5, U2, 3);
1438impl_slice_with!(5, U3, 2);
1439impl_slice_with!(5, U4, 1);
1440impl_slice_with!(5, U5, 0);
1441
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use std::ops::Range;

    use super::OverlapPolicy;
    use crate::SliceItem;
    use crate::errors::{ReshapeError, SliceError};
    use crate::layout::{DynLayout, Layout, MutLayout, NdLayout, ResizeLayout};

    /// Build an `NdLayout` from explicit strides using `OverlapPolicy::AllowOverlap`,
    /// so strides which make elements alias (e.g. zero strides) are accepted.
    fn layout_with_strides<const N: usize>(shape: [usize; N], strides: [usize; N]) -> NdLayout<N> {
        NdLayout::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap).unwrap()
    }

    #[test]
    fn test_is_broadcast() {
        // Contiguous, non-empty layout.
        let layout = DynLayout::from_shape(&[5, 5]);
        assert!(!layout.is_broadcast());

        // Empty layout.
        let layout = DynLayout::from_shape(&[5, 0]);
        assert!(!layout.is_broadcast());

        // All strides are zero, so every index maps to the same element.
        let layout =
            DynLayout::from_shape_and_strides(&[5, 5], &[0, 0], OverlapPolicy::AllowOverlap)
                .unwrap();
        assert!(layout.is_broadcast());
    }

    #[test]
    fn test_from_shape_and_strides() {
        #[derive(Debug)]
        struct Case<'a> {
            shape: &'a [usize],
            strides: &'a [usize],
        }

        let cases = [
            // Contiguous strides.
            Case {
                shape: &[10, 10],
                strides: &[10, 1],
            },
            // Zero stride on the last axis.
            Case {
                shape: &[10, 10],
                strides: &[10, 0],
            },
        ];

        cases.test_each(|case| {
            let layout = DynLayout::from_shape_and_strides(
                case.shape,
                case.strides,
                OverlapPolicy::AllowOverlap,
            )
            .unwrap();
            assert_eq!(layout.shape(), case.shape);
            assert_eq!(layout.strides(), case.strides);
        })
    }

    #[test]
    fn test_index_axis() {
        #[derive(Debug)]
        struct Case {
            layout: NdLayout<2>,
            axis: usize,
            index: usize,
            /// Expected start offset and layout of the indexed view.
            expected: (usize, NdLayout<1>),
        }

        let cases = [
            Case {
                layout: NdLayout::from_shape([3, 4]),
                axis: 0,
                index: 1,
                expected: (4, layout_with_strides([4], [1])),
            },
            Case {
                layout: NdLayout::from_shape([3, 4]),
                axis: 1,
                index: 2,
                expected: (2, layout_with_strides([3], [4])),
            },
        ];

        cases.test_each(|case| {
            let Case {
                layout,
                axis,
                index,
                expected,
            } = case;

            let (expected_start, expected_layout) = expected;

            let (offsets, sliced_layout) = layout.index_axis(*axis, *index);
            assert_eq!(sliced_layout, *expected_layout);
            assert_eq!(offsets.start, *expected_start);
            assert_eq!(offsets.len(), expected_layout.min_data_len());

            // The dynamic-rank layout must agree with the static-rank one.
            let (_, sliced_layout_dyn) = layout.as_dyn().index_axis(*axis, *index);
            assert_eq!(sliced_layout_dyn, expected_layout.as_dyn());
        })
    }

    #[test]
    #[should_panic(expected = "axis < self.ndim()")]
    fn test_index_axis_invalid_axis() {
        NdLayout::from_shape([2, 3]).index_axis(2, 0);
    }

    #[test]
    #[should_panic(expected = "index < self.size(axis)")]
    fn test_index_axis_invalid_index() {
        NdLayout::from_shape([2, 3]).index_axis(0, 3);
    }

    #[test]
    fn test_move_axis() {
        let mut layout = DynLayout::from_shape(&[2, 4, 8]);
        assert_eq!(layout.strides(), [32, 8, 1]);

        layout.move_axis(1, 0);
        assert_eq!(layout.shape(), [4, 2, 8]);
        assert_eq!(layout.strides(), [8, 32, 1]);

        layout.move_axis(0, 1);
        assert_eq!(layout.shape(), [2, 4, 8]);
        assert_eq!(layout.strides(), [32, 8, 1]);

        layout.move_axis(2, 1);
        assert_eq!(layout.shape(), [2, 8, 4]);
        assert_eq!(layout.strides(), [32, 1, 8]);
    }

    #[test]
    #[should_panic]
    fn test_move_axis_invalid_from() {
        let mut layout = DynLayout::from_shape(&[2, 4, 8]);
        layout.move_axis(3, 0);
    }

    #[test]
    #[should_panic]
    fn test_move_axis_invalid_to() {
        let mut layout = DynLayout::from_shape(&[2, 4, 8]);
        layout.move_axis(0, 3);
    }

    #[test]
    #[should_panic(expected = "permutation is invalid")]
    fn test_permute_invalid_len() {
        let mut layout = DynLayout::from_shape(&[5, 5]);
        layout.permute(&[1, 0, 3]);
    }

    #[test]
    #[should_panic(expected = "permutation is invalid")]
    fn test_permute_too_few_dims() {
        let mut layout = DynLayout::from_shape(&[5, 5]);
        layout.permute(&[1]);
    }

    #[test]
    #[should_panic(expected = "permutation is invalid")]
    fn test_permute_repeated_dims() {
        let mut layout = DynLayout::from_shape(&[5, 5]);
        layout.permute(&[1, 1]);
    }

    #[test]
    fn test_remove_axis_of_any_size() {
        let shape = [1, 2, 3, 4];
        for d in 0..shape.len() {
            let mut layout = DynLayout::from_shape(&shape);

            // Expected result: the original shape and strides with entry `d` dropped.
            let (expected_shape, expected_strides): (Vec<usize>, Vec<usize>) = layout
                .shape()
                .iter()
                .zip(layout.strides())
                .enumerate()
                .filter(|&(i, _)| i != d)
                .map(|(_, (&size, &stride))| (size, stride))
                .unzip();

            layout.remove_axis_of_any_size(d);

            assert_eq!(layout.shape(), expected_shape);
            assert_eq!(layout.strides(), expected_strides);
        }
    }

    #[test]
    fn test_reshaped() {
        #[derive(Debug)]
        struct Case<'a> {
            layout: DynLayout,
            new_shape: &'a [usize],
            for_copy: bool,
            error: Option<ReshapeError>,
        }

        let cases = [
            // Contiguous layout can be reshaped as a view.
            Case {
                layout: DynLayout::from_shape(&[2, 2]),
                new_shape: &[4],
                for_copy: false,
                error: None,
            },
            // Non-contiguous layout cannot be reshaped as a view.
            Case {
                layout: DynLayout::from_shape(&[2, 2]).transposed(),
                new_shape: &[4],
                for_copy: false,
                error: Some(ReshapeError::NotContiguous),
            },
            // View reshape with a different element count.
            Case {
                layout: DynLayout::from_shape(&[2, 2]),
                new_shape: &[3],
                for_copy: false,
                error: Some(ReshapeError::LengthMismatch),
            },
            // Non-contiguous layout can be reshaped for copying.
            Case {
                layout: DynLayout::from_shape(&[2, 2]).transposed(),
                new_shape: &[4],
                for_copy: true,
                error: None,
            },
            // Copy reshape still requires the element count to match.
            Case {
                layout: DynLayout::from_shape(&[2, 2]).transposed(),
                new_shape: &[3],
                for_copy: true,
                error: Some(ReshapeError::LengthMismatch),
            },
        ];

        cases.test_each(|case| {
            let Case {
                layout,
                new_shape,
                for_copy,
                error,
            } = case;

            let reshaped = if *for_copy {
                layout.reshaped_for_copy(*new_shape)
            } else {
                layout.reshaped_for_view(*new_shape)
            };

            assert_eq!(reshaped.as_ref().err(), error.as_ref());
            if let Ok(new_layout) = reshaped {
                assert_eq!(new_layout.shape(), *new_shape);
            }
        })
    }

    #[test]
    fn test_squeezed() {
        let layout = DynLayout::from_shape(&[1, 1, 10, 20]);
        let squeezed = layout.squeezed();
        assert_eq!(squeezed.shape(), &[10, 20]);
        assert_eq!(squeezed.strides(), &[20, 1]);
    }

    #[test]
    fn test_slice_axis() {
        #[derive(Clone, Debug)]
        struct Case<'a> {
            shape: &'a [usize],
            axis: usize,
            range: Range<usize>,
            sliced_shape: &'a [usize],
            offsets: Range<usize>,
        }

        let cases = [Case {
            shape: &[3, 5],
            axis: 1,
            range: 2..4,
            sliced_shape: &[3, 2],
            offsets: 2..14,
        }];

        cases.test_each_clone(|case| {
            let Case {
                shape,
                axis,
                range,
                sliced_shape,
                offsets,
            } = case;

            let layout = DynLayout::from_shape(shape);
            let (offset_range, sliced_layout) = layout.slice_axis(axis, range).unwrap();
            assert_eq!(sliced_layout.shape(), sliced_shape);
            assert_eq!(sliced_layout.strides(), layout.strides());
            assert_eq!(offset_range, offsets);
        })
    }

    #[test]
    fn test_slice_axis_invalid() {
        #[derive(Debug)]
        struct Case<'a> {
            shape: &'a [usize],
            axis: usize,
            range: Range<usize>,
            expected: SliceError,
        }

        let cases = [
            Case {
                shape: &[1, 2, 3],
                axis: 4,
                range: 0..1,
                expected: SliceError::InvalidAxis { axis: 4 },
            },
            Case {
                shape: &[1, 2, 3],
                axis: 0,
                range: 0..2,
                expected: SliceError::InvalidRange {
                    axis: 0,
                    range: (0..2).into(),
                    size: 1,
                },
            },
        ];

        cases.test_each(|case| {
            let layout = DynLayout::from_shape(case.shape);
            let result = layout.slice_axis(case.axis, case.range.clone());
            assert_eq!(result, Err(case.expected.clone()));
        })
    }

    #[test]
    fn test_slice_invalid() {
        #[derive(Debug)]
        struct Case<'a> {
            layout: DynLayout,
            ranges: &'a [SliceItem],
            expected: SliceError,
        }

        let cases = [
            Case {
                layout: DynLayout::from_shape(&[3, 5]),
                ranges: &[SliceItem::Index(4), SliceItem::Index(0)],
                expected: SliceError::InvalidIndex {
                    axis: 0,
                    index: 4,
                    size: 3,
                },
            },
            Case {
                layout: DynLayout::from_shape(&[3, 5]),
                ranges: &[SliceItem::Range((1..4).into()), SliceItem::Index(0)],
                expected: SliceError::InvalidRange {
                    axis: 0,
                    range: (1..4).into(),
                    size: 3,
                },
            },
            Case {
                layout: DynLayout::from_shape(&[3, 5]),
                ranges: &[SliceItem::Index(-4)],
                expected: SliceError::InvalidIndex {
                    axis: 0,
                    index: -4,
                    size: 3,
                },
            },
            Case {
                layout: DynLayout::from_shape(&[3, 5]),
                ranges: &[SliceItem::Range((4..).into()), SliceItem::Index(0)],
                expected: SliceError::InvalidRange {
                    axis: 0,
                    range: (4..).into(),
                    size: 3,
                },
            },
            Case {
                layout: DynLayout::from_shape(&[3, 5]),
                ranges: &[SliceItem::full_range(), SliceItem::range(0, None, -1)],
                expected: SliceError::InvalidStep { axis: 1, step: -1 },
            },
        ];

        cases.test_each(|case| {
            let result = case.layout.slice_dyn(case.ranges);
            assert_eq!(result, Err(case.expected.clone()));
        })
    }

    #[test]
    fn test_size_stride() {
        let layout = DynLayout::from_shape(&[10, 20, 30]);
        for (dim, (&size, &stride)) in layout.shape().iter().zip(layout.strides()).enumerate() {
            assert_eq!(layout.size(dim), size);
            assert_eq!(layout.stride(dim), stride);
        }
    }

    #[test]
    fn test_split() {
        #[derive(Debug)]
        struct Case {
            shape: [usize; 2],
            strides: Option<[usize; 2]>,
            axis: usize,
            mid: usize,
        }

        let mut cases = Vec::new();

        // Every axis and every split point of a contiguous layout, including
        // `mid == size`, where the right half is empty.
        let shape = [4, 2];
        for (axis, &size) in shape.iter().enumerate() {
            for mid in 0..=size {
                cases.push(Case {
                    shape,
                    axis,
                    mid,
                    strides: None,
                });
            }
        }

        // Empty layout.
        cases.push(Case {
            shape: [0, 0],
            strides: None,
            axis: 0,
            mid: 0,
        });

        // Layout with a zero stride on the last axis, split so the right
        // half is empty.
        cases.push(Case {
            shape: [1, 4],
            strides: Some([10, 0]),
            axis: 0,
            mid: 1,
        });

        /// Split `layout` and check the shapes, strides and offset ranges of both halves.
        fn check_split<L: MutLayout>(layout: L, axis: usize, mid: usize) {
            let (left, right) = layout.split(axis, mid);
            let (left_offsets, left_layout) = left;
            let (right_offsets, right_layout) = right;

            assert_eq!(left_layout.strides(), layout.strides());
            assert_eq!(right_layout.strides(), layout.strides());

            assert_eq!(left_offsets.len(), left_layout.min_data_len());
            assert_eq!(right_offsets.len(), right_layout.min_data_len());

            let orig_len = layout.min_data_len();
            assert!(left_offsets.start <= orig_len && left_offsets.end <= orig_len);
            assert!(right_offsets.start <= orig_len && right_offsets.end <= orig_len);

            for i in 0..layout.ndim() {
                assert_eq!(
                    left_layout.size(i),
                    if i == axis { mid } else { layout.size(i) }
                );
                assert_eq!(
                    right_layout.size(i),
                    if i == axis {
                        layout.size(i) - mid
                    } else {
                        layout.size(i)
                    }
                );
            }
        }

        cases.test_each(|case| {
            let Case {
                shape,
                strides,
                axis,
                mid,
            } = case;

            let layout = if let Some(strides) = strides {
                NdLayout::from_shape_and_strides(*shape, *strides, OverlapPolicy::AllowOverlap)
                    .unwrap()
            } else {
                NdLayout::from_shape(*shape)
            };
            let dyn_layout = if let Some(strides) = strides {
                DynLayout::from_shape_and_strides(
                    shape.as_slice(),
                    strides.as_slice(),
                    OverlapPolicy::AllowOverlap,
                )
                .unwrap()
            } else {
                DynLayout::from_shape(shape.as_slice())
            };

            check_split(layout, *axis, *mid);
            check_split(dyn_layout, *axis, *mid);
        })
    }

    #[test]
    fn test_merge_axes() {
        #[derive(Debug)]
        struct Case<'a> {
            shape: &'a [usize],
            strides: &'a [usize],
            merged_shape: &'a [usize],
            merged_strides: &'a [usize],
        }

        let cases = [
            // Zero axes.
            Case {
                shape: &[],
                strides: &[],
                merged_shape: &[],
                merged_strides: &[],
            },
            // A single axis is left unchanged.
            Case {
                shape: &[10],
                strides: &[2],
                merged_shape: &[10],
                merged_strides: &[2],
            },
            // Contiguous axes merge into one.
            Case {
                shape: &[10, 10],
                strides: &[10, 1],
                merged_shape: &[100],
                merged_strides: &[1],
            },
            // Transposed axes cannot be merged.
            Case {
                shape: &[10, 10],
                strides: &[1, 10],
                merged_shape: &[10, 10],
                merged_strides: &[1, 10],
            },
            // A leading size-1 axis is merged away.
            Case {
                shape: &[1, 10, 10],
                strides: &[10, 1, 10],
                merged_shape: &[10, 10],
                merged_strides: &[1, 10],
            },
            // Size-1 axes in the middle do not prevent merging contiguous axes.
            Case {
                shape: &[2, 1, 1, 2],
                strides: &[2, 2, 2, 1],
                merged_shape: &[4],
                merged_strides: &[1],
            },
            // Size-1 axes are merged even when their strides do not match
            // the strides implied by their neighbors.
            Case {
                shape: &[2, 1, 1, 2],
                strides: &[2, 4, 4, 1],
                merged_shape: &[4],
                merged_strides: &[1],
            },
        ];

        cases.test_each(|case| {
            let mut layout = DynLayout::from_shape_and_strides(
                case.shape,
                case.strides,
                OverlapPolicy::AllowOverlap,
            )
            .unwrap();
            layout.merge_axes();
            assert_eq!(layout.shape(), case.merged_shape);
            assert_eq!(layout.strides(), case.merged_strides);
        })
    }
}