1use crate::prelude_dev::*;
3use itertools::izip;
4
5#[doc = include_str!("readme.md")]
15#[derive(Clone)]
16pub struct Layout<D>
17where
18 D: DimBaseAPI,
19{
20 pub(crate) shape: D,
22 pub(crate) stride: D::Stride,
23 pub(crate) offset: usize,
24}
25
26unsafe impl<D> Send for Layout<D> where D: DimBaseAPI {}
27unsafe impl<D> Sync for Layout<D> where D: DimBaseAPI {}
28
29impl<D> Layout<D>
35where
36 D: DimBaseAPI,
37{
38 #[inline]
40 pub fn shape(&self) -> &D {
41 &self.shape
42 }
43
44 #[inline]
46 pub fn stride(&self) -> &D::Stride {
47 &self.stride
48 }
49
50 #[inline]
52 pub fn offset(&self) -> usize {
53 self.offset
54 }
55
56 #[inline]
58 pub fn ndim(&self) -> usize {
59 self.shape.ndim()
60 }
61
62 #[inline]
68 pub fn size(&self) -> usize {
69 self.shape().as_ref().iter().product()
70 }
71
72 pub unsafe fn set_offset(&mut self, offset: usize) -> &mut Self {
79 self.offset = offset;
80 return self;
81 }
82}
83
84impl<D> Layout<D>
86where
87 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
88{
89 pub fn f_prefer(&self) -> bool {
91 if self.ndim() == 0 || self.size() == 0 {
93 return true;
94 }
95
96 let stride = self.stride.as_ref();
97 let shape = self.shape.as_ref();
98 let mut last = 0;
99 for (&s, &d) in stride.iter().zip(shape.iter()) {
100 if d != 1 {
101 if s < last {
102 return false;
104 }
105 if last == 0 && s != 1 {
106 return false;
108 }
109 last = s;
110 } else if last == 0 {
111 last = 1;
114 }
115 }
116 return true;
117 }
118
119 pub fn c_prefer(&self) -> bool {
121 if self.ndim() == 0 || self.size() == 0 {
123 return true;
124 }
125
126 let stride = self.stride.as_ref();
127 let shape = self.shape.as_ref();
128 let mut last = 0;
129 for (&s, &d) in stride.iter().zip(shape.iter()).rev() {
130 if d != 1 {
131 if s < last {
132 return false;
134 }
135 if last == 0 && s != 1 {
136 return false;
138 }
139 last = s;
140 } else if last == 0 {
141 last = 1;
144 }
145 }
146 return true;
147 }
148
149 pub fn ndim_of_f_contig(&self) -> usize {
154 if self.ndim() == 0 || self.size() == 0 {
155 return self.ndim();
156 }
157 let stride = self.stride.as_ref();
158 let shape = self.shape.as_ref();
159 let mut acc = 1;
160 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).enumerate() {
161 if d != 1 && s != acc {
162 return ndim;
163 }
164 acc *= d as isize;
165 }
166 return self.ndim();
167 }
168
169 pub fn ndim_of_c_contig(&self) -> usize {
174 if self.ndim() == 0 || self.size() == 0 {
175 return self.ndim();
176 }
177 let stride = self.stride.as_ref();
178 let shape = self.shape.as_ref();
179 let mut acc = 1;
180 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).rev().enumerate() {
181 if d != 1 && s != acc {
182 return ndim;
183 }
184 acc *= d as isize;
185 }
186 return self.ndim();
187 }
188
189 pub fn f_contig(&self) -> bool {
197 self.ndim() == self.ndim_of_f_contig()
198 }
199
200 pub fn c_contig(&self) -> bool {
208 self.ndim() == self.ndim_of_c_contig()
209 }
210
211 pub fn index_f(&self, index: &[isize]) -> Result<usize> {
215 rstsr_assert_eq!(index.len(), self.ndim(), InvalidLayout)?;
216 let mut pos = self.offset() as isize;
217 let shape = self.shape.as_ref();
218 let stride = self.stride.as_ref();
219
220 for (&idx, &shp, &strd) in izip!(index.iter(), shape.iter(), stride.iter()) {
221 let idx = if idx < 0 { idx + shp as isize } else { idx };
222 rstsr_pattern!(idx, 0..(shp as isize), ValueOutOfRange)?;
223 pos += strd * idx;
224 }
225 rstsr_pattern!(pos, 0.., ValueOutOfRange)?;
226 return Ok(pos as usize);
227 }
228
229 pub fn index(&self, index: &[isize]) -> usize {
238 self.index_f(index).unwrap()
239 }
240
241 pub fn bounds_index(&self) -> Result<(usize, usize)> {
247 let n = self.ndim();
248 let offset = self.offset;
249 let shape = self.shape.as_ref();
250 let stride = self.stride.as_ref();
251
252 if n == 0 {
253 return Ok((offset, offset + 1));
254 }
255
256 let mut min = offset as isize;
257 let mut max = offset as isize;
258
259 for i in 0..n {
260 if shape[i] == 0 {
261 return Ok((offset, offset));
262 }
263 if stride[i] > 0 {
264 max += stride[i] * (shape[i] as isize - 1);
265 } else {
266 min += stride[i] * (shape[i] as isize - 1);
267 }
268 }
269 rstsr_pattern!(min, 0.., ValueOutOfRange)?;
270 return Ok((min as usize, max as usize + 1));
271 }
272
273 pub fn check_strides(&self) -> Result<()> {
294 let shape = self.shape.as_ref();
295 let stride = self.stride.as_ref();
296 rstsr_assert_eq!(shape.len(), stride.len(), InvalidLayout)?;
297 let n = shape.len();
298
299 if self.size() == 0 || n == 0 {
302 return Ok(());
303 }
304
305 let mut indices = (0..n).filter(|&k| shape[k] > 1).collect::<Vec<_>>();
306 indices.sort_by_key(|&k| stride[k].abs());
307 let shape_sorted = indices.iter().map(|&k| shape[k]).collect::<Vec<_>>();
308 let stride_sorted = indices.iter().map(|&k| stride[k].unsigned_abs()).collect::<Vec<_>>();
309
310 let mut elem_cum = 0;
312 for i in 0..indices.len() {
313 rstsr_pattern!(
315 elem_cum,
316 0..stride_sorted[i],
317 InvalidLayout,
318 "Either stride be zero, or stride too small that elements in tensor can be overlapped."
319 )?;
320 elem_cum += (shape_sorted[i] - 1) * stride_sorted[i];
321 }
322 return Ok(());
323 }
324
325 pub fn diagonal(
326 &self,
327 offset: Option<isize>,
328 axis1: Option<isize>,
329 axis2: Option<isize>,
330 ) -> Result<Layout<<D as DimSmallerOneAPI>::SmallerOne>>
331 where
332 D: DimSmallerOneAPI,
333 {
334 rstsr_assert!(self.ndim() >= 2, InvalidLayout)?;
336 let offset = offset.unwrap_or(0);
338 let axis1 = axis1.unwrap_or(0);
339 let axis2 = axis2.unwrap_or(1);
340 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
341 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
342 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
343 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
344 let axis1 = axis1 as usize;
345 let axis2 = axis2 as usize;
346
347 let d1 = self.shape()[axis1] as isize;
349 let d2 = self.shape()[axis2] as isize;
350 let t1 = self.stride()[axis1];
351 let t2 = self.stride()[axis2];
352
353 let (offset_diag, d_diag) = if (-d2 + 1..0).contains(&offset) {
355 let offset = -offset;
356 let offset_diag = (self.offset() as isize + t1 * offset) as usize;
357 let d_diag = (d1 - offset).min(d2) as usize;
358 (offset_diag, d_diag)
359 } else if (0..d1).contains(&offset) {
360 let offset_diag = (self.offset() as isize + t2 * offset) as usize;
361 let d_diag = (d2 - offset).min(d1) as usize;
362 (offset_diag, d_diag)
363 } else {
364 (self.offset(), 0)
365 };
366
367 let t_diag = t1 + t2;
369 let mut shape_diag = vec![];
370 let mut stride_diag = vec![];
371 for i in 0..self.ndim() {
372 if i != axis1 && i != axis2 {
373 shape_diag.push(self.shape()[i]);
374 stride_diag.push(self.stride()[i]);
375 }
376 }
377 shape_diag.push(d_diag);
378 stride_diag.push(t_diag);
379 let layout_diag = Layout::new(shape_diag, stride_diag, offset_diag)?;
380 return layout_diag.into_dim::<<D as DimSmallerOneAPI>::SmallerOne>();
381 }
382}
383
384impl<D> Layout<D>
387where
388 D: DimBaseAPI,
389{
390 #[inline]
398 pub fn new(shape: D, stride: D::Stride, offset: usize) -> Result<Self>
399 where
400 D: DimShapeAPI + DimStrideAPI,
401 {
402 let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
403 layout.bounds_index()?;
404 layout.check_strides()?;
405 return Ok(layout);
406 }
407
408 #[inline]
415 pub unsafe fn new_unchecked(shape: D, stride: D::Stride, offset: usize) -> Self {
416 Layout { shape, stride, offset }
417 }
418
419 #[inline]
422 pub fn new_shape(&self) -> D {
423 self.shape.new_shape()
424 }
425
426 #[inline]
429 pub fn new_stride(&self) -> D::Stride {
430 self.shape.new_stride()
431 }
432}
433
434impl<D> Layout<D>
436where
437 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
438{
439 pub fn transpose(&self, axes: &[isize]) -> Result<Self> {
446 let n = self.ndim();
448 rstsr_assert_eq!(
449 axes.len(),
450 n,
451 InvalidLayout,
452 "number of elements in axes should be the same to number of dimensions."
453 )?;
454 let mut permut_used = vec![false; n];
456 for &p in axes {
457 let p = if p < 0 { p + n as isize } else { p };
458 rstsr_pattern!(p, 0..n as isize, InvalidLayout)?;
459 let p = p as usize;
460 permut_used[p] = true;
461 }
462 rstsr_assert!(
463 permut_used.iter().all(|&b| b),
464 InvalidLayout,
465 "axes should contain all elements from 0 to n-1."
466 )?;
467 let axes = axes
468 .iter()
469 .map(|&p| if p < 0 { p + n as isize } else { p } as usize)
470 .collect::<Vec<_>>();
471
472 let shape_old = self.shape();
473 let stride_old = self.stride();
474 let mut shape = self.new_shape();
475 let mut stride = self.new_stride();
476 for i in 0..self.ndim() {
477 shape[i] = shape_old[axes[i]];
478 stride[i] = stride_old[axes[i]];
479 }
480 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
481 }
482
483 pub fn permute_dims(&self, axes: &[isize]) -> Result<Self> {
487 self.transpose(axes)
488 }
489
490 pub fn reverse_axes(&self) -> Self {
492 let shape_old = self.shape();
493 let stride_old = self.stride();
494 let mut shape = self.new_shape();
495 let mut stride = self.new_stride();
496 for i in 0..self.ndim() {
497 shape[i] = shape_old[self.ndim() - i - 1];
498 stride[i] = stride_old[self.ndim() - i - 1];
499 }
500 return unsafe { Layout::new_unchecked(shape, stride, self.offset) };
501 }
502
503 pub fn swapaxes(&self, axis1: isize, axis2: isize) -> Result<Self> {
505 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
506 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
507 let axis1 = axis1 as usize;
508
509 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
510 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
511 let axis2 = axis2 as usize;
512
513 let mut shape = self.shape().clone();
514 let mut stride = self.stride().clone();
515 shape.as_mut().swap(axis1, axis2);
516 stride.as_mut().swap(axis1, axis2);
517 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
518 }
519}
520
521impl<D> Layout<D>
525where
526 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
527{
528 #[inline]
539 pub unsafe fn index_uncheck(&self, index: &[usize]) -> isize {
540 let stride = self.stride.as_ref();
541 match self.ndim() {
542 0 => self.offset as isize,
543 1 => self.offset as isize + stride[0] * index[0] as isize,
544 2 => {
545 self.offset as isize + stride[0] * index[0] as isize + stride[1] * index[1] as isize
546 },
547 3 => {
548 self.offset as isize
549 + stride[0] * index[0] as isize
550 + stride[1] * index[1] as isize
551 + stride[2] * index[2] as isize
552 },
553 4 => {
554 self.offset as isize
555 + stride[0] * index[0] as isize
556 + stride[1] * index[1] as isize
557 + stride[2] * index[2] as isize
558 + stride[3] * index[3] as isize
559 },
560 _ => {
561 let mut pos = self.offset as isize;
562 stride.iter().zip(index.iter()).for_each(|(&s, &i)| pos += s * i as isize);
563 pos
564 },
565 }
566 }
567}
568
569impl<D> PartialEq for Layout<D>
570where
571 D: DimBaseAPI,
572{
573 fn eq(&self, other: &Self) -> bool {
576 if self.ndim() != other.ndim() {
577 return false;
578 }
579 for i in 0..self.ndim() {
580 let s1 = self.shape()[i];
581 let s2 = other.shape()[i];
582 if s1 != s2 {
583 return false;
584 }
585 if s1 != 1 && s1 != 0 && self.stride()[i] != other.stride()[i] {
586 return false;
587 }
588 }
589 return true;
590 }
591}
592
593pub trait DimLayoutContigAPI: DimBaseAPI + DimShapeAPI + DimStrideAPI {
594 fn new_c_contig(&self, offset: Option<usize>) -> Layout<Self> {
597 let shape = self.clone();
598 let stride = shape.stride_c_contig();
599 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
600 }
601
602 fn new_f_contig(&self, offset: Option<usize>) -> Layout<Self> {
605 let shape = self.clone();
606 let stride = shape.stride_f_contig();
607 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
608 }
609
610 fn c(&self) -> Layout<Self> {
613 self.new_c_contig(None)
614 }
615
616 fn f(&self) -> Layout<Self> {
619 self.new_f_contig(None)
620 }
621
622 fn new_contig(&self, offset: Option<usize>, order: FlagOrder) -> Layout<Self> {
624 match order {
625 FlagOrder::C => self.new_c_contig(offset),
626 FlagOrder::F => self.new_f_contig(offset),
627 }
628 }
629}
630
631impl<const N: usize> DimLayoutContigAPI for Ix<N> {}
632impl DimLayoutContigAPI for IxD {}
633
634pub trait DimIntoAPI<D>: DimBaseAPI
639where
640 D: DimBaseAPI,
641{
642 fn into_dim(layout: Layout<Self>) -> Result<Layout<D>>;
643}
644
645impl<D> DimIntoAPI<D> for IxD
646where
647 D: DimBaseAPI,
648{
649 fn into_dim(layout: Layout<IxD>) -> Result<Layout<D>> {
650 let shape = layout.shape().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
651 let stride = layout.stride().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
652 let offset = layout.offset();
653 return Ok(Layout { shape, stride, offset });
654 }
655}
656
657impl<const N: usize> DimIntoAPI<IxD> for Ix<N> {
658 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<IxD>> {
659 let shape = (*layout.shape()).into();
660 let stride = (*layout.stride()).into();
661 let offset = layout.offset();
662 return Ok(Layout { shape, stride, offset });
663 }
664}
665
666impl<const N: usize, const M: usize> DimIntoAPI<Ix<M>> for Ix<N> {
667 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<Ix<M>>> {
668 rstsr_assert_eq!(N, M, InvalidLayout)?;
669 let shape = layout.shape().to_vec().try_into().unwrap();
670 let stride = layout.stride().to_vec().try_into().unwrap();
671 let offset = layout.offset();
672 return Ok(Layout { shape, stride, offset });
673 }
674}
675
676impl<D> Layout<D>
677where
678 D: DimBaseAPI,
679{
680 pub fn into_dim<D2>(self) -> Result<Layout<D2>>
682 where
683 D2: DimBaseAPI,
684 D: DimIntoAPI<D2>,
685 {
686 D::into_dim(self)
687 }
688
689 pub fn to_dim<D2>(&self) -> Result<Layout<D2>>
691 where
692 D2: DimBaseAPI,
693 D: DimIntoAPI<D2>,
694 {
695 D::into_dim(self.clone())
696 }
697}
698
699impl<const N: usize> From<Ix<N>> for Layout<Ix<N>> {
700 fn from(shape: Ix<N>) -> Self {
701 let stride = shape.stride_contig();
702 Layout { shape, stride, offset: 0 }
703 }
704}
705
706impl From<IxD> for Layout<IxD> {
707 fn from(shape: IxD) -> Self {
708 let stride = shape.stride_contig();
709 Layout { shape, stride, offset: 0 }
710 }
711}
712
#[cfg(test)]
mod test {
    use std::panic::catch_unwind;

    use super::*;

    #[test]
    fn test_layout_new() {
        // general layout with a negative stride and a large enough offset
        let shape = [3, 2, 6];
        let stride = [3, -300, 15];
        let layout = Layout::new(shape, stride, 917).unwrap();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[3, -300, 15]);
        assert_eq!(layout.offset(), 917);
        assert_eq!(layout.ndim(), 3);
        // negative stride reaches below storage position 0
        let shape = [3, 2, 6];
        let stride = [3, -300, 15];
        let layout = Layout::new(shape, stride, 0);
        assert!(layout.is_err());
        // zero stride on a non-unit axis aliases elements
        let shape = [3, 2, 6];
        let stride = [3, -300, 0];
        let layout = Layout::new(shape, stride, 1000);
        assert!(layout.is_err());
        // strides too small: elements overlap
        let shape = [3, 2, 6];
        let stride = [3, 4, 7];
        let layout = Layout::new(shape, stride, 1000);
        assert!(layout.is_err());
        // zero-dimensional layout
        let shape = [];
        let stride = [];
        let layout = Layout::new(shape, stride, 1000);
        assert!(layout.is_ok());
        // the stride of an extent-1 axis is irrelevant, even when zero ...
        let shape = [3, 1, 5];
        let stride = [1, 0, 15];
        let layout = Layout::new(shape, stride, 1);
        assert!(layout.is_ok());
        // ... or when larger than every other stride
        let shape = [3, 1, 5];
        let stride = [1, 100, 3];
        let layout = Layout::new(shape, stride, 1);
        assert!(layout.is_ok());
        // zero-size layouts are always valid
        let shape = [3, 0, 5];
        let stride = [-1, -2, -3];
        let layout = Layout::new(shape, stride, 1);
        assert!(layout.is_ok());
        // the unchecked constructor never panics, even on an invalid layout
        let shape = [3, 2, 6];
        let stride = [3, -300, 0];
        let r = catch_unwind(|| unsafe { Layout::new_unchecked(shape, stride, 1000) });
        assert!(r.is_ok());
    }

    #[test]
    fn test_is_f_prefer() {
        // general cases
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [1, 10, 100], 0).unwrap();
        assert!(layout.f_prefer());
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(layout.f_prefer());
        let layout = Layout::new(shape, [1, 3, -15], 1000).unwrap();
        assert!(!layout.f_prefer());
        let layout = Layout::new(shape, [1, 21, 3], 0).unwrap();
        assert!(!layout.f_prefer());
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(!layout.f_prefer());
        let layout = Layout::new(shape, [2, 6, 30], 0).unwrap();
        assert!(!layout.f_prefer());
        // zero-dimensional
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.f_prefer());
        // zero-size
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.f_prefer());
        // extent-1 axis in the middle
        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
        assert!(layout.f_prefer());
    }

    #[test]
    fn test_is_c_prefer() {
        // general cases
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [100, 10, 1], 0).unwrap();
        assert!(layout.c_prefer());
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(layout.c_prefer());
        let layout = Layout::new(shape, [-35, 7, 1], 1000).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [7, 21, 1], 0).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [70, 14, 2], 0).unwrap();
        assert!(!layout.c_prefer());
        // zero-dimensional
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.c_prefer());
        // zero-size
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.c_prefer());
        // extent-1 axis in the middle
        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
        assert!(layout.c_prefer());
    }

    #[test]
    fn test_is_f_contig() {
        // general cases
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(layout.f_contig());
        let layout = Layout::new(shape, [1, 4, 20], 0).unwrap();
        assert!(!layout.f_contig());
        // zero-dimensional
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.f_contig());
        // zero-size
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.f_contig());
        // extent-1 axis in the middle
        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
        assert!(layout.f_contig());
    }

    #[test]
    fn test_is_c_contig() {
        // general cases
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(layout.c_contig());
        let layout = Layout::new(shape, [36, 7, 1], 0).unwrap();
        assert!(!layout.c_contig());
        // zero-dimensional
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.c_contig());
        // zero-size
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.c_contig());
        // extent-1 axis in the middle
        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
        assert!(layout.c_contig());
    }

    #[test]
    fn test_index() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert_eq!(layout.index(&[0, 0, 0]), 782);
        assert_eq!(layout.index(&[2, 1, 4]), 668);
        assert_eq!(layout.index(&[1, -2, -3]), 830);
        // zero-dimensional
        let layout = Layout::new([], [], 10).unwrap();
        assert_eq!(layout.index(&[]), 10);
    }

    #[test]
    fn test_bounds_index() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert_eq!(layout.bounds_index().unwrap(), (602, 864));
        // offset too small for the negative stride
        let layout = unsafe { Layout::new_unchecked([3, 2, 6], [3, -180, 15], 15) };
        assert!(layout.bounds_index().is_err());
        // zero-dimensional
        let layout = Layout::new([], [], 10).unwrap();
        assert_eq!(layout.bounds_index().unwrap(), (10, 11));
    }

    #[test]
    fn test_transpose() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.transpose(&[2, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        let trans = layout.permute_dims(&[2, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // negative axis
        let trans = layout.transpose(&[-1, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // not a permutation
        let trans = layout.transpose(&[-2, 0, 1]);
        assert!(trans.is_err());
        // wrong number of axes
        let trans = layout.transpose(&[1, 0]);
        assert!(trans.is_err());
        // zero-dimensional
        let layout = Layout::new([], [], 0).unwrap();
        let trans = layout.transpose(&[]);
        assert!(trans.is_ok());
    }

    #[test]
    fn test_reverse_axes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.reverse_axes();
        assert_eq!(trans.shape(), &[6, 2, 3]);
        assert_eq!(trans.stride(), &[15, -180, 3]);
        // zero-dimensional
        let layout = Layout::new([], [], 782).unwrap();
        let trans = layout.reverse_axes();
        assert_eq!(trans.shape(), &[]);
        assert_eq!(trans.stride(), &[]);
    }

    #[test]
    fn test_swapaxes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.swapaxes(-1, -2).unwrap();
        assert_eq!(trans.shape(), &[3, 6, 2]);
        assert_eq!(trans.stride(), &[3, 15, -180]);
        // swapping an axis with itself is a no-op
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.swapaxes(-1, -1).unwrap();
        assert_eq!(trans.shape(), &[3, 2, 6]);
        assert_eq!(trans.stride(), &[3, -180, 15]);
    }

    #[test]
    fn test_index_uncheck() {
        unsafe {
            // fixed-rank dimension
            let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
            // dynamic-rank dimension
            let layout = Layout::new(vec![3, 2, 6], vec![3, -180, 15], 782).unwrap();
            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
            // zero-dimensional
            let layout = Layout::new([], [], 10).unwrap();
            assert_eq!(layout.index_uncheck(&[]), 10);
        }
    }

    #[test]
    fn test_diagonal() {
        let layout = [2, 3, 4].c();
        let diag = layout.diagonal(None, None, None).unwrap();
        assert_eq!(diag, Layout::new([4, 2], [1, 16], 0).unwrap());
        let diag = layout.diagonal(Some(-1), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 2], [12, 5], 0).unwrap());
        let diag = layout.diagonal(Some(-4), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 0], [12, 5], 0).unwrap());
    }

    #[test]
    fn test_new_contig() {
        let layout = [3, 2, 6].c();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[12, 6, 1]);
        let layout = [3, 2, 6].f();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[1, 3, 6]);
        // default contiguous order (via `From`)
        let layout: Layout<_> = [3, 2, 6].into();
        println!("{layout:?}");
    }

    #[test]
    fn test_layout_cast() {
        let layout = [3, 2, 6].c();
        assert!(layout.clone().into_dim::<IxD>().is_ok());
        assert!(layout.clone().into_dim::<Ix3>().is_ok());
        let layout = vec![3, 2, 6].c();
        assert!(layout.clone().into_dim::<IxD>().is_ok());
        assert!(layout.clone().into_dim::<Ix3>().is_ok());
        assert!(layout.clone().into_dim::<Ix2>().is_err());
    }

    #[test]
    fn test_unravel_index() {
        unsafe {
            let shape = [3, 2, 6];
            assert_eq!(shape.unravel_index_f(0), [0, 0, 0]);
            assert_eq!(shape.unravel_index_f(16), [1, 1, 2]);
            assert_eq!(shape.unravel_index_c(0), [0, 0, 0]);
            assert_eq!(shape.unravel_index_c(16), [1, 0, 4]);
        }
    }

    #[test]
    fn fix_too_strict_stride_check() {
        let layout = [10, 11, 12].c();
        let slc = (.., slice!(-1, 0, -4));
        let slc: AxesIndex<Indexer> = slc.try_into().unwrap();
        let indexed = layout.dim_slice(slc.as_ref()).unwrap();
        assert_eq!(indexed.shape(), &[10, 3, 12]);
        assert_eq!(indexed.stride(), &[132, -48, 1]);
    }
}
1026}