1use crate::prelude_dev::*;
3use itertools::izip;
4
5#[doc = include_str!("readme.md")]
15#[derive(Clone)]
16pub struct Layout<D>
17where
18 D: DimBaseAPI,
19{
20 pub(crate) shape: D,
22 pub(crate) stride: D::Stride,
23 pub(crate) offset: usize,
24}
25
26unsafe impl<D> Send for Layout<D> where D: DimBaseAPI {}
27unsafe impl<D> Sync for Layout<D> where D: DimBaseAPI {}
28
29impl<D> Layout<D>
35where
36 D: DimBaseAPI,
37{
38 #[inline]
40 pub fn shape(&self) -> &D {
41 &self.shape
42 }
43
44 #[inline]
46 pub fn stride(&self) -> &D::Stride {
47 &self.stride
48 }
49
50 #[inline]
52 pub fn offset(&self) -> usize {
53 self.offset
54 }
55
56 #[inline]
58 pub fn ndim(&self) -> usize {
59 self.shape.ndim()
60 }
61
62 #[inline]
68 pub fn size(&self) -> usize {
69 self.shape().as_ref().iter().product()
70 }
71
72 pub unsafe fn set_offset(&mut self, offset: usize) -> &mut Self {
79 self.offset = offset;
80 return self;
81 }
82}
83
84impl<D> Layout<D>
86where
87 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
88{
89 pub fn f_prefer(&self) -> bool {
91 if self.ndim() == 0 || self.size() == 0 {
93 return true;
94 }
95
96 let stride = self.stride.as_ref();
97 let shape = self.shape.as_ref();
98 let mut last = 0;
99 for (&s, &d) in stride.iter().zip(shape.iter()) {
100 if d != 1 {
101 if s < last {
102 return false;
104 }
105 if last == 0 && s != 1 {
106 return false;
108 }
109 last = s;
110 } else if last == 0 {
111 last = 1;
114 }
115 }
116 return true;
117 }
118
119 pub fn c_prefer(&self) -> bool {
121 if self.ndim() == 0 || self.size() == 0 {
123 return true;
124 }
125
126 let stride = self.stride.as_ref();
127 let shape = self.shape.as_ref();
128 let mut last = 0;
129 for (&s, &d) in stride.iter().zip(shape.iter()).rev() {
130 if d != 1 {
131 if s < last {
132 return false;
134 }
135 if last == 0 && s != 1 {
136 return false;
138 }
139 last = s;
140 } else if last == 0 {
141 last = 1;
144 }
145 }
146 return true;
147 }
148
149 pub fn ndim_of_f_contig(&self) -> usize {
154 if self.ndim() == 0 || self.size() == 0 {
155 return self.ndim();
156 }
157 let stride = self.stride.as_ref();
158 let shape = self.shape.as_ref();
159 let mut acc = 1;
160 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).enumerate() {
161 if d != 1 && s != acc {
162 return ndim;
163 }
164 acc *= d as isize;
165 }
166 return self.ndim();
167 }
168
169 pub fn ndim_of_c_contig(&self) -> usize {
174 if self.ndim() == 0 || self.size() == 0 {
175 return self.ndim();
176 }
177 let stride = self.stride.as_ref();
178 let shape = self.shape.as_ref();
179 let mut acc = 1;
180 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).rev().enumerate() {
181 if d != 1 && s != acc {
182 return ndim;
183 }
184 acc *= d as isize;
185 }
186 return self.ndim();
187 }
188
189 pub fn f_contig(&self) -> bool {
197 self.ndim() == self.ndim_of_f_contig()
198 }
199
200 pub fn c_contig(&self) -> bool {
208 self.ndim() == self.ndim_of_c_contig()
209 }
210
211 pub fn index_f(&self, index: &[isize]) -> Result<usize> {
215 rstsr_assert_eq!(index.len(), self.ndim(), InvalidLayout)?;
216 let mut pos = self.offset() as isize;
217 let shape = self.shape.as_ref();
218 let stride = self.stride.as_ref();
219
220 for (&idx, &shp, &strd) in izip!(index.iter(), shape.iter(), stride.iter()) {
221 let idx = if idx < 0 { idx + shp as isize } else { idx };
222 rstsr_pattern!(idx, 0..(shp as isize), ValueOutOfRange)?;
223 pos += strd * idx;
224 }
225 rstsr_pattern!(pos, 0.., ValueOutOfRange)?;
226 return Ok(pos as usize);
227 }
228
229 pub fn index(&self, index: &[isize]) -> usize {
238 self.index_f(index).unwrap()
239 }
240
241 pub fn bounds_index(&self) -> Result<(usize, usize)> {
247 let n = self.ndim();
248 let offset = self.offset;
249 let shape = self.shape.as_ref();
250 let stride = self.stride.as_ref();
251
252 if n == 0 {
253 return Ok((offset, offset + 1));
254 }
255
256 let mut min = offset as isize;
257 let mut max = offset as isize;
258
259 for i in 0..n {
260 if shape[i] == 0 {
261 return Ok((offset, offset));
262 }
263 if stride[i] > 0 {
264 max += stride[i] * (shape[i] as isize - 1);
265 } else {
266 min += stride[i] * (shape[i] as isize - 1);
267 }
268 }
269 rstsr_pattern!(min, 0.., ValueOutOfRange)?;
270 return Ok((min as usize, max as usize + 1));
271 }
272
273 pub fn check_strides(&self) -> Result<()> {
294 let shape = self.shape.as_ref();
295 let stride = self.stride.as_ref();
296 rstsr_assert_eq!(shape.len(), stride.len(), InvalidLayout)?;
297 let n = shape.len();
298
299 if self.size() == 0 || n == 0 {
302 return Ok(());
303 }
304
305 let mut indices = (0..n).filter(|&k| shape[k] > 1).collect::<Vec<_>>();
306 indices.sort_by_key(|&k| stride[k].abs());
307 let shape_sorted = indices.iter().map(|&k| shape[k]).collect::<Vec<_>>();
308 let stride_sorted = indices.iter().map(|&k| stride[k].unsigned_abs()).collect::<Vec<_>>();
309
310 for i in 0..indices.len().max(1) - 1 {
312 rstsr_pattern!(
314 shape_sorted[i] * stride_sorted[i],
315 1..stride_sorted[i + 1] + 1,
316 InvalidLayout,
317 "Either stride be zero, or stride too small that elements in tensor can be overlapped."
318 )?;
319 }
320 return Ok(());
321 }
322
323 pub fn diagonal(
324 &self,
325 offset: Option<isize>,
326 axis1: Option<isize>,
327 axis2: Option<isize>,
328 ) -> Result<Layout<<D as DimSmallerOneAPI>::SmallerOne>>
329 where
330 D: DimSmallerOneAPI,
331 {
332 rstsr_assert!(self.ndim() >= 2, InvalidLayout)?;
334 let offset = offset.unwrap_or(0);
336 let axis1 = axis1.unwrap_or(0);
337 let axis2 = axis2.unwrap_or(1);
338 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
339 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
340 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
341 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
342 let axis1 = axis1 as usize;
343 let axis2 = axis2 as usize;
344
345 let d1 = self.shape()[axis1] as isize;
347 let d2 = self.shape()[axis2] as isize;
348 let t1 = self.stride()[axis1];
349 let t2 = self.stride()[axis2];
350
351 let (offset_diag, d_diag) = if (-d2 + 1..0).contains(&offset) {
353 let offset = -offset;
354 let offset_diag = (self.offset() as isize + t1 * offset) as usize;
355 let d_diag = (d1 - offset).min(d2) as usize;
356 (offset_diag, d_diag)
357 } else if (0..d1).contains(&offset) {
358 let offset_diag = (self.offset() as isize + t2 * offset) as usize;
359 let d_diag = (d2 - offset).min(d1) as usize;
360 (offset_diag, d_diag)
361 } else {
362 (self.offset(), 0)
363 };
364
365 let t_diag = t1 + t2;
367 let mut shape_diag = vec![];
368 let mut stride_diag = vec![];
369 for i in 0..self.ndim() {
370 if i != axis1 && i != axis2 {
371 shape_diag.push(self.shape()[i]);
372 stride_diag.push(self.stride()[i]);
373 }
374 }
375 shape_diag.push(d_diag);
376 stride_diag.push(t_diag);
377 let layout_diag = Layout::new(shape_diag, stride_diag, offset_diag)?;
378 return layout_diag.into_dim::<<D as DimSmallerOneAPI>::SmallerOne>();
379 }
380}
381
382impl<D> Layout<D>
385where
386 D: DimBaseAPI,
387{
388 #[inline]
396 pub fn new(shape: D, stride: D::Stride, offset: usize) -> Result<Self>
397 where
398 D: DimShapeAPI + DimStrideAPI,
399 {
400 let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
401 layout.bounds_index()?;
402 layout.check_strides()?;
403 return Ok(layout);
404 }
405
406 #[inline]
413 pub unsafe fn new_unchecked(shape: D, stride: D::Stride, offset: usize) -> Self {
414 Layout { shape, stride, offset }
415 }
416
417 #[inline]
420 pub fn new_shape(&self) -> D {
421 self.shape.new_shape()
422 }
423
424 #[inline]
427 pub fn new_stride(&self) -> D::Stride {
428 self.shape.new_stride()
429 }
430}
431
432impl<D> Layout<D>
434where
435 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
436{
437 pub fn transpose(&self, axes: &[isize]) -> Result<Self> {
444 let n = self.ndim();
446 rstsr_assert_eq!(
447 axes.len(),
448 n,
449 InvalidLayout,
450 "number of elements in axes should be the same to number of dimensions."
451 )?;
452 let mut permut_used = vec![false; n];
454 for &p in axes {
455 let p = if p < 0 { p + n as isize } else { p };
456 rstsr_pattern!(p, 0..n as isize, InvalidLayout)?;
457 let p = p as usize;
458 permut_used[p] = true;
459 }
460 rstsr_assert!(
461 permut_used.iter().all(|&b| b),
462 InvalidLayout,
463 "axes should contain all elements from 0 to n-1."
464 )?;
465 let axes = axes
466 .iter()
467 .map(|&p| if p < 0 { p + n as isize } else { p } as usize)
468 .collect::<Vec<_>>();
469
470 let shape_old = self.shape();
471 let stride_old = self.stride();
472 let mut shape = self.new_shape();
473 let mut stride = self.new_stride();
474 for i in 0..self.ndim() {
475 shape[i] = shape_old[axes[i]];
476 stride[i] = stride_old[axes[i]];
477 }
478 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
479 }
480
481 pub fn permute_dims(&self, axes: &[isize]) -> Result<Self> {
485 self.transpose(axes)
486 }
487
488 pub fn reverse_axes(&self) -> Self {
490 let shape_old = self.shape();
491 let stride_old = self.stride();
492 let mut shape = self.new_shape();
493 let mut stride = self.new_stride();
494 for i in 0..self.ndim() {
495 shape[i] = shape_old[self.ndim() - i - 1];
496 stride[i] = stride_old[self.ndim() - i - 1];
497 }
498 return unsafe { Layout::new_unchecked(shape, stride, self.offset) };
499 }
500
501 pub fn swapaxes(&self, axis1: isize, axis2: isize) -> Result<Self> {
503 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
504 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
505 let axis1 = axis1 as usize;
506
507 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
508 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
509 let axis2 = axis2 as usize;
510
511 let mut shape = self.shape().clone();
512 let mut stride = self.stride().clone();
513 shape.as_mut().swap(axis1, axis2);
514 stride.as_mut().swap(axis1, axis2);
515 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
516 }
517}
518
519impl<D> Layout<D>
523where
524 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
525{
526 #[inline]
537 pub unsafe fn index_uncheck(&self, index: &[usize]) -> isize {
538 let stride = self.stride.as_ref();
539 match self.ndim() {
540 0 => self.offset as isize,
541 1 => self.offset as isize + stride[0] * index[0] as isize,
542 2 => {
543 self.offset as isize + stride[0] * index[0] as isize + stride[1] * index[1] as isize
544 },
545 3 => {
546 self.offset as isize
547 + stride[0] * index[0] as isize
548 + stride[1] * index[1] as isize
549 + stride[2] * index[2] as isize
550 },
551 4 => {
552 self.offset as isize
553 + stride[0] * index[0] as isize
554 + stride[1] * index[1] as isize
555 + stride[2] * index[2] as isize
556 + stride[3] * index[3] as isize
557 },
558 _ => {
559 let mut pos = self.offset as isize;
560 stride.iter().zip(index.iter()).for_each(|(&s, &i)| pos += s * i as isize);
561 pos
562 },
563 }
564 }
565}
566
567impl<D> PartialEq for Layout<D>
568where
569 D: DimBaseAPI,
570{
571 fn eq(&self, other: &Self) -> bool {
574 if self.ndim() != other.ndim() {
575 return false;
576 }
577 for i in 0..self.ndim() {
578 let s1 = self.shape()[i];
579 let s2 = other.shape()[i];
580 if s1 != s2 {
581 return false;
582 }
583 if s1 != 1 && s1 != 0 && self.stride()[i] != other.stride()[i] {
584 return false;
585 }
586 }
587 return true;
588 }
589}
590
591pub trait DimLayoutContigAPI: DimBaseAPI + DimShapeAPI + DimStrideAPI {
592 fn new_c_contig(&self, offset: Option<usize>) -> Layout<Self> {
595 let shape = self.clone();
596 let stride = shape.stride_c_contig();
597 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
598 }
599
600 fn new_f_contig(&self, offset: Option<usize>) -> Layout<Self> {
603 let shape = self.clone();
604 let stride = shape.stride_f_contig();
605 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
606 }
607
608 fn c(&self) -> Layout<Self> {
611 self.new_c_contig(None)
612 }
613
614 fn f(&self) -> Layout<Self> {
617 self.new_f_contig(None)
618 }
619
620 fn new_contig(&self, offset: Option<usize>, order: FlagOrder) -> Layout<Self> {
622 match order {
623 FlagOrder::C => self.new_c_contig(offset),
624 FlagOrder::F => self.new_f_contig(offset),
625 }
626 }
627}
628
629impl<const N: usize> DimLayoutContigAPI for Ix<N> {}
630impl DimLayoutContigAPI for IxD {}
631
632pub trait DimIntoAPI<D>: DimBaseAPI
637where
638 D: DimBaseAPI,
639{
640 fn into_dim(layout: Layout<Self>) -> Result<Layout<D>>;
641}
642
643impl<D> DimIntoAPI<D> for IxD
644where
645 D: DimBaseAPI,
646{
647 fn into_dim(layout: Layout<IxD>) -> Result<Layout<D>> {
648 let shape = layout.shape().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
649 let stride = layout.stride().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
650 let offset = layout.offset();
651 return Ok(Layout { shape, stride, offset });
652 }
653}
654
655impl<const N: usize> DimIntoAPI<IxD> for Ix<N> {
656 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<IxD>> {
657 let shape = (*layout.shape()).into();
658 let stride = (*layout.stride()).into();
659 let offset = layout.offset();
660 return Ok(Layout { shape, stride, offset });
661 }
662}
663
664impl<const N: usize, const M: usize> DimIntoAPI<Ix<M>> for Ix<N> {
665 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<Ix<M>>> {
666 rstsr_assert_eq!(N, M, InvalidLayout)?;
667 let shape = layout.shape().to_vec().try_into().unwrap();
668 let stride = layout.stride().to_vec().try_into().unwrap();
669 let offset = layout.offset();
670 return Ok(Layout { shape, stride, offset });
671 }
672}
673
674impl<D> Layout<D>
675where
676 D: DimBaseAPI,
677{
678 pub fn into_dim<D2>(self) -> Result<Layout<D2>>
680 where
681 D2: DimBaseAPI,
682 D: DimIntoAPI<D2>,
683 {
684 D::into_dim(self)
685 }
686
687 pub fn to_dim<D2>(&self) -> Result<Layout<D2>>
689 where
690 D2: DimBaseAPI,
691 D: DimIntoAPI<D2>,
692 {
693 D::into_dim(self.clone())
694 }
695}
696
697impl<const N: usize> From<Ix<N>> for Layout<Ix<N>> {
698 fn from(shape: Ix<N>) -> Self {
699 let stride = shape.stride_contig();
700 Layout { shape, stride, offset: 0 }
701 }
702}
703
704impl From<IxD> for Layout<IxD> {
705 fn from(shape: IxD) -> Self {
706 let stride = shape.stride_contig();
707 Layout { shape, stride, offset: 0 }
708 }
709}
710
#[cfg(test)]
mod test {
    use std::panic::catch_unwind;

    use super::*;

    #[test]
    fn test_layout_new() {
        // valid layout: negative stride compensated by a large enough offset
        let shape = [3, 2, 6];
        let stride = [3, -300, 15];
        let layout = Layout::new(shape, stride, 917).unwrap();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[3, -300, 15]);
        assert_eq!(layout.offset(), 917);
        assert_eq!(layout.ndim(), 3);
        // same strides with offset 0 reach a negative position -> error
        let shape = [3, 2, 6];
        let stride = [3, -300, 15];
        let layout = Layout::new(shape, stride, 0);
        assert!(layout.is_err());
        // zero stride on an axis longer than 1 -> error
        let shape = [3, 2, 6];
        let stride = [3, -300, 0];
        let layout = Layout::new(shape, stride, 1000);
        assert!(layout.is_err());
        // strides too small to keep elements apart -> error
        let shape = [3, 2, 6];
        let stride = [3, 4, 7];
        let layout = Layout::new(shape, stride, 1000);
        assert!(layout.is_err());
        // 0-dimensional layout
        let shape = [];
        let stride = [];
        let layout = Layout::new(shape, stride, 1000);
        assert!(layout.is_ok());
        // zero stride is accepted on a length-1 axis
        let shape = [3, 1, 5];
        let stride = [1, 0, 15];
        let layout = Layout::new(shape, stride, 1);
        assert!(layout.is_ok());
        let shape = [3, 1, 5];
        let stride = [1, 0, 15];
        let layout = Layout::new(shape, stride, 1);
        assert!(layout.is_ok());
        // zero-size layout: bounds and strides are not checked
        let shape = [3, 0, 5];
        let stride = [-1, -2, -3];
        let layout = Layout::new(shape, stride, 1);
        assert!(layout.is_ok());
        // new_unchecked performs no validation and must not panic on an invalid layout
        let shape = [3, 2, 6];
        let stride = [3, -300, 0];
        let r = catch_unwind(|| unsafe { Layout::new_unchecked(shape, stride, 1000) });
        assert!(r.is_ok());
    }

    #[test]
    fn test_is_f_prefer() {
        // unit first stride and non-decreasing strides (with gaps allowed)
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [1, 10, 100], 0).unwrap();
        assert!(layout.f_prefer());
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(layout.f_prefer());
        // negative stride breaks the ordering
        let layout = Layout::new(shape, [1, 3, -15], 1000).unwrap();
        assert!(!layout.f_prefer());
        // strides not non-decreasing
        let layout = Layout::new(shape, [1, 21, 3], 0).unwrap();
        assert!(!layout.f_prefer());
        // first stride is not 1
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(!layout.f_prefer());
        let layout = Layout::new(shape, [2, 6, 30], 0).unwrap();
        assert!(!layout.f_prefer());
        // 0-dimensional layout
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.f_prefer());
        // zero-size layout
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.f_prefer());
        // length-1 axis with an arbitrary stride is skipped
        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
        assert!(layout.f_prefer());
    }

    #[test]
    fn test_is_c_prefer() {
        // unit last stride and non-increasing strides (with gaps allowed)
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [100, 10, 1], 0).unwrap();
        assert!(layout.c_prefer());
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(layout.c_prefer());
        // negative stride breaks the ordering
        let layout = Layout::new(shape, [-35, 7, 1], 1000).unwrap();
        assert!(!layout.c_prefer());
        // strides not non-increasing
        let layout = Layout::new(shape, [7, 21, 1], 0).unwrap();
        assert!(!layout.c_prefer());
        // last stride is not 1
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [70, 14, 2], 0).unwrap();
        assert!(!layout.c_prefer());
        // 0-dimensional layout
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.c_prefer());
        // zero-size layout
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.c_prefer());
        // length-1 axis with an arbitrary stride is skipped
        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
        assert!(layout.c_prefer());
    }

    #[test]
    fn test_is_f_contig() {
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(layout.f_contig());
        // gaps between elements -> not contiguous
        let layout = Layout::new(shape, [1, 4, 20], 0).unwrap();
        assert!(!layout.f_contig());
        // 0-dimensional layout
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.f_contig());
        // zero-size layout
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.f_contig());
        // stride of a length-1 axis is ignored
        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
        assert!(layout.f_contig());
    }

    #[test]
    fn test_is_c_contig() {
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(layout.c_contig());
        // gap on the first axis -> not contiguous
        let layout = Layout::new(shape, [36, 7, 1], 0).unwrap();
        assert!(!layout.c_contig());
        // 0-dimensional layout
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.c_contig());
        // zero-size layout
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.c_contig());
        // stride of a length-1 axis is ignored
        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
        assert!(layout.c_contig());
    }

    #[test]
    fn test_index() {
        // 782 + 2*3 + 1*(-180) + 4*15 = 668; negative indices count from the end
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert_eq!(layout.index(&[0, 0, 0]), 782);
        assert_eq!(layout.index(&[2, 1, 4]), 668);
        assert_eq!(layout.index(&[1, -2, -3]), 830);
        // 0-dimensional layout addresses only its offset
        let layout = Layout::new([], [], 10).unwrap();
        assert_eq!(layout.index(&[]), 10);
    }

    #[test]
    fn test_bounds_index() {
        // min = 782 - 180 = 602, max = 782 + 2*3 + 5*15 = 863 -> half-open (602, 864)
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert_eq!(layout.bounds_index().unwrap(), (602, 864));
        // offset too small for the negative stride -> error
        let layout = unsafe { Layout::new_unchecked([3, 2, 6], [3, -180, 15], 15) };
        assert!(layout.bounds_index().is_err());
        // 0-dimensional layout touches exactly one position
        let layout = Layout::new([], [], 10).unwrap();
        assert_eq!(layout.bounds_index().unwrap(), (10, 11));
    }

    #[test]
    fn test_transpose() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.transpose(&[2, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // permute_dims is an alias of transpose
        let trans = layout.permute_dims(&[2, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // negative axis: -1 is the last axis
        let trans = layout.transpose(&[-1, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // -2 normalizes to 1, duplicating axis 1 -> not a permutation
        let trans = layout.transpose(&[-2, 0, 1]);
        assert!(trans.is_err());
        // wrong number of axes
        let trans = layout.transpose(&[1, 0]);
        assert!(trans.is_err());
        // 0-dimensional layout with an empty permutation
        let layout = Layout::new([], [], 0).unwrap();
        let trans = layout.transpose(&[]);
        assert!(trans.is_ok());
    }

    #[test]
    fn test_reverse_axes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.reverse_axes();
        assert_eq!(trans.shape(), &[6, 2, 3]);
        assert_eq!(trans.stride(), &[15, -180, 3]);
        // 0-dimensional layout stays empty
        let layout = Layout::new([], [], 782).unwrap();
        let trans = layout.reverse_axes();
        assert_eq!(trans.shape(), &[]);
        assert_eq!(trans.stride(), &[]);
    }

    #[test]
    fn test_swapaxes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.swapaxes(-1, -2).unwrap();
        assert_eq!(trans.shape(), &[3, 6, 2]);
        assert_eq!(trans.stride(), &[3, 15, -180]);
        // swapping an axis with itself is a no-op
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.swapaxes(-1, -1).unwrap();
        assert_eq!(trans.shape(), &[3, 2, 6]);
        assert_eq!(trans.stride(), &[3, -180, 15]);
    }

    #[test]
    fn test_index_uncheck() {
        unsafe {
            // fixed-dimension layout
            let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
            // dynamic-dimension layout gives the same positions
            let layout = Layout::new(vec![3, 2, 6], vec![3, -180, 15], 782).unwrap();
            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
            // 0-dimensional layout
            let layout = Layout::new([], [], 10).unwrap();
            assert_eq!(layout.index_uncheck(&[]), 10);
        }
    }

    #[test]
    fn test_diagonal() {
        // PartialEq ignores the offset, so only shape and stride are compared below
        let layout = [2, 3, 4].c();
        // default: main diagonal over axes 0 and 1, appended as the last axis
        let diag = layout.diagonal(None, None, None).unwrap();
        assert_eq!(diag, Layout::new([4, 2], [1, 16], 0).unwrap());
        // below-main diagonal over the last two axes
        let diag = layout.diagonal(Some(-1), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 2], [12, 5], 0).unwrap());
        // offset beyond the matrix gives an empty diagonal
        let diag = layout.diagonal(Some(-4), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 0], [12, 5], 0).unwrap());
    }

    #[test]
    fn test_new_contig() {
        // C order: last axis is unit-stride
        let layout = [3, 2, 6].c();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[12, 6, 1]);
        // F order: first axis is unit-stride
        let layout = [3, 2, 6].f();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[1, 3, 6]);
        // conversion from a bare shape
        let layout: Layout<_> = [3, 2, 6].into();
        println!("{layout:?}");
    }

    #[test]
    fn test_layout_cast() {
        // fixed -> dynamic and fixed -> same fixed rank
        let layout = [3, 2, 6].c();
        assert!(layout.clone().into_dim::<IxD>().is_ok());
        assert!(layout.clone().into_dim::<Ix3>().is_ok());
        // dynamic -> dynamic, matching fixed rank, and mismatching fixed rank
        let layout = vec![3, 2, 6].c();
        assert!(layout.clone().into_dim::<IxD>().is_ok());
        assert!(layout.clone().into_dim::<Ix3>().is_ok());
        assert!(layout.clone().into_dim::<Ix2>().is_err());
    }

    #[test]
    fn test_unravel_index() {
        unsafe {
            // 16 = 1 + 1*3 + 2*6 in F order; 16 = 1*12 + 0*6 + 4 in C order
            let shape = [3, 2, 6];
            assert_eq!(shape.unravel_index_f(0), [0, 0, 0]);
            assert_eq!(shape.unravel_index_f(16), [1, 1, 2]);
            assert_eq!(shape.unravel_index_c(0), [0, 0, 0]);
            assert_eq!(shape.unravel_index_c(16), [1, 0, 4]);
        }
    }
}