1use crate::prelude_dev::*;
3use itertools::izip;
4
5#[doc = include_str!("readme.md")]
14#[derive(Clone)]
15pub struct Layout<D>
16where
17 D: DimBaseAPI,
18{
19 pub(crate) shape: D,
21 pub(crate) stride: D::Stride,
22 pub(crate) offset: usize,
23}
24
25unsafe impl<D> Send for Layout<D> where D: DimBaseAPI {}
26unsafe impl<D> Sync for Layout<D> where D: DimBaseAPI {}
27
28impl<D> Layout<D>
34where
35 D: DimBaseAPI,
36{
37 #[inline]
39 pub fn shape(&self) -> &D {
40 &self.shape
41 }
42
43 #[inline]
45 pub fn stride(&self) -> &D::Stride {
46 &self.stride
47 }
48
49 #[inline]
51 pub fn offset(&self) -> usize {
52 self.offset
53 }
54
55 #[inline]
57 pub fn ndim(&self) -> usize {
58 self.shape.ndim()
59 }
60
61 #[inline]
67 pub fn size(&self) -> usize {
68 self.shape().as_ref().iter().product()
69 }
70
71 pub unsafe fn set_offset(&mut self, offset: usize) -> &mut Self {
78 self.offset = offset;
79 return self;
80 }
81}
82
83impl<D> Layout<D>
85where
86 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
87{
88 pub fn f_prefer(&self) -> bool {
90 if self.ndim() == 0 || self.size() == 0 {
92 return true;
93 }
94
95 let stride = self.stride.as_ref();
96 let shape = self.shape.as_ref();
97 let mut last = 0;
98 for (&s, &d) in stride.iter().zip(shape.iter()) {
99 if d != 1 {
100 if s < last {
101 return false;
103 }
104 if last == 0 && s != 1 {
105 return false;
107 }
108 last = s;
109 } else if last == 0 {
110 last = 1;
113 }
114 }
115 return true;
116 }
117
118 pub fn c_prefer(&self) -> bool {
120 if self.ndim() == 0 || self.size() == 0 {
122 return true;
123 }
124
125 let stride = self.stride.as_ref();
126 let shape = self.shape.as_ref();
127 let mut last = 0;
128 for (&s, &d) in stride.iter().zip(shape.iter()).rev() {
129 if d != 1 {
130 if s < last {
131 return false;
133 }
134 if last == 0 && s != 1 {
135 return false;
137 }
138 last = s;
139 } else if last == 0 {
140 last = 1;
143 }
144 }
145 return true;
146 }
147
148 pub fn ndim_of_f_contig(&self) -> usize {
153 if self.ndim() == 0 || self.size() == 0 {
154 return self.ndim();
155 }
156 let stride = self.stride.as_ref();
157 let shape = self.shape.as_ref();
158 let mut acc = 1;
159 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).enumerate() {
160 if d != 1 && s != acc {
161 return ndim;
162 }
163 acc *= d as isize;
164 }
165 return self.ndim();
166 }
167
168 pub fn ndim_of_c_contig(&self) -> usize {
173 if self.ndim() == 0 || self.size() == 0 {
174 return self.ndim();
175 }
176 let stride = self.stride.as_ref();
177 let shape = self.shape.as_ref();
178 let mut acc = 1;
179 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).rev().enumerate() {
180 if d != 1 && s != acc {
181 return ndim;
182 }
183 acc *= d as isize;
184 }
185 return self.ndim();
186 }
187
188 pub fn f_contig(&self) -> bool {
194 self.ndim() == self.ndim_of_f_contig()
195 }
196
197 pub fn c_contig(&self) -> bool {
203 self.ndim() == self.ndim_of_c_contig()
204 }
205
206 pub fn index_f(&self, index: &[isize]) -> Result<usize> {
210 rstsr_assert_eq!(index.len(), self.ndim(), InvalidLayout)?;
211 let mut pos = self.offset() as isize;
212 let shape = self.shape.as_ref();
213 let stride = self.stride.as_ref();
214
215 for (&idx, &shp, &strd) in izip!(index.iter(), shape.iter(), stride.iter()) {
216 let idx = if idx < 0 { idx + shp as isize } else { idx };
217 rstsr_pattern!(idx, 0..(shp as isize), ValueOutOfRange)?;
218 pos += strd * idx;
219 }
220 rstsr_pattern!(pos, 0.., ValueOutOfRange)?;
221 return Ok(pos as usize);
222 }
223
224 pub fn index(&self, index: &[isize]) -> usize {
233 self.index_f(index).unwrap()
234 }
235
236 pub fn bounds_index(&self) -> Result<(usize, usize)> {
242 let n = self.ndim();
243 let offset = self.offset;
244 let shape = self.shape.as_ref();
245 let stride = self.stride.as_ref();
246
247 if n == 0 {
248 return Ok((offset, offset + 1));
249 }
250
251 let mut min = offset as isize;
252 let mut max = offset as isize;
253
254 for i in 0..n {
255 if shape[i] == 0 {
256 return Ok((offset, offset));
257 }
258 if stride[i] > 0 {
259 max += stride[i] * (shape[i] as isize - 1);
260 } else {
261 min += stride[i] * (shape[i] as isize - 1);
262 }
263 }
264 rstsr_pattern!(min, 0.., ValueOutOfRange)?;
265 return Ok((min as usize, max as usize + 1));
266 }
267
268 pub fn check_strides(&self) -> Result<()> {
288 let shape = self.shape.as_ref();
289 let stride = self.stride.as_ref();
290 rstsr_assert_eq!(shape.len(), stride.len(), InvalidLayout)?;
291 let n = shape.len();
292
293 if self.size() == 0 || n == 0 {
296 return Ok(());
297 }
298
299 let mut indices = (0..n).filter(|&k| shape[k] > 1).collect::<Vec<_>>();
300 indices.sort_by_key(|&k| stride[k].abs());
301 let shape_sorted = indices.iter().map(|&k| shape[k]).collect::<Vec<_>>();
302 let stride_sorted = indices.iter().map(|&k| stride[k].unsigned_abs()).collect::<Vec<_>>();
303
304 let mut elem_cum = 0;
306 for i in 0..indices.len() {
307 rstsr_pattern!(
309 elem_cum,
310 0..stride_sorted[i],
311 InvalidLayout,
312 "Either stride be zero, or stride too small that elements in tensor can be overlapped."
313 )?;
314 elem_cum += (shape_sorted[i] - 1) * stride_sorted[i];
315 }
316 return Ok(());
317 }
318
319 pub fn diagonal(
320 &self,
321 offset: Option<isize>,
322 axis1: Option<isize>,
323 axis2: Option<isize>,
324 ) -> Result<Layout<<D as DimSmallerOneAPI>::SmallerOne>>
325 where
326 D: DimSmallerOneAPI,
327 {
328 rstsr_assert!(self.ndim() >= 2, InvalidLayout)?;
330 let offset = offset.unwrap_or(0);
332 let axis1 = axis1.unwrap_or(0);
333 let axis2 = axis2.unwrap_or(1);
334 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
335 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
336 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
337 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
338 let axis1 = axis1 as usize;
339 let axis2 = axis2 as usize;
340
341 let d1 = self.shape()[axis1] as isize;
343 let d2 = self.shape()[axis2] as isize;
344 let t1 = self.stride()[axis1];
345 let t2 = self.stride()[axis2];
346
347 let (offset_diag, d_diag) = if (-d2 + 1..0).contains(&offset) {
349 let offset = -offset;
350 let offset_diag = (self.offset() as isize + t1 * offset) as usize;
351 let d_diag = (d1 - offset).min(d2) as usize;
352 (offset_diag, d_diag)
353 } else if (0..d1).contains(&offset) {
354 let offset_diag = (self.offset() as isize + t2 * offset) as usize;
355 let d_diag = (d2 - offset).min(d1) as usize;
356 (offset_diag, d_diag)
357 } else {
358 (self.offset(), 0)
359 };
360
361 let t_diag = t1 + t2;
363 let mut shape_diag = vec![];
364 let mut stride_diag = vec![];
365 for i in 0..self.ndim() {
366 if i != axis1 && i != axis2 {
367 shape_diag.push(self.shape()[i]);
368 stride_diag.push(self.stride()[i]);
369 }
370 }
371 shape_diag.push(d_diag);
372 stride_diag.push(t_diag);
373 let layout_diag = Layout::new(shape_diag, stride_diag, offset_diag)?;
374 return layout_diag.into_dim::<<D as DimSmallerOneAPI>::SmallerOne>();
375 }
376}
377
378impl<D> Layout<D>
381where
382 D: DimBaseAPI,
383{
384 #[inline]
392 pub fn new(shape: D, stride: D::Stride, offset: usize) -> Result<Self>
393 where
394 D: DimShapeAPI + DimStrideAPI,
395 {
396 let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
397 layout.bounds_index()?;
398 layout.check_strides()?;
399 return Ok(layout);
400 }
401
402 #[inline]
409 pub unsafe fn new_unchecked(shape: D, stride: D::Stride, offset: usize) -> Self {
410 Layout { shape, stride, offset }
411 }
412
413 #[inline]
416 pub fn new_shape(&self) -> D {
417 self.shape.new_shape()
418 }
419
420 #[inline]
423 pub fn new_stride(&self) -> D::Stride {
424 self.shape.new_stride()
425 }
426}
427
428impl<D> Layout<D>
430where
431 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
432{
433 pub fn transpose(&self, axes: &[isize]) -> Result<Self> {
440 let n = self.ndim();
442 rstsr_assert_eq!(
443 axes.len(),
444 n,
445 InvalidLayout,
446 "number of elements in axes should be the same to number of dimensions."
447 )?;
448 let mut permut_used = vec![false; n];
450 for &p in axes {
451 let p = if p < 0 { p + n as isize } else { p };
452 rstsr_pattern!(p, 0..n as isize, InvalidLayout)?;
453 let p = p as usize;
454 permut_used[p] = true;
455 }
456 rstsr_assert!(
457 permut_used.iter().all(|&b| b),
458 InvalidLayout,
459 "axes should contain all elements from 0 to n-1."
460 )?;
461 let axes = axes.iter().map(|&p| if p < 0 { p + n as isize } else { p } as usize).collect::<Vec<_>>();
462
463 let shape_old = self.shape();
464 let stride_old = self.stride();
465 let mut shape = self.new_shape();
466 let mut stride = self.new_stride();
467 for i in 0..self.ndim() {
468 shape[i] = shape_old[axes[i]];
469 stride[i] = stride_old[axes[i]];
470 }
471 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
472 }
473
474 pub fn permute_dims(&self, axes: &[isize]) -> Result<Self> {
478 self.transpose(axes)
479 }
480
481 pub fn reverse_axes(&self) -> Self {
483 let shape_old = self.shape();
484 let stride_old = self.stride();
485 let mut shape = self.new_shape();
486 let mut stride = self.new_stride();
487 for i in 0..self.ndim() {
488 shape[i] = shape_old[self.ndim() - i - 1];
489 stride[i] = stride_old[self.ndim() - i - 1];
490 }
491 return unsafe { Layout::new_unchecked(shape, stride, self.offset) };
492 }
493
494 pub fn swapaxes(&self, axis1: isize, axis2: isize) -> Result<Self> {
496 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
497 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
498 let axis1 = axis1 as usize;
499
500 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
501 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
502 let axis2 = axis2 as usize;
503
504 let mut shape = self.shape().clone();
505 let mut stride = self.stride().clone();
506 shape.as_mut().swap(axis1, axis2);
507 stride.as_mut().swap(axis1, axis2);
508 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
509 }
510}
511
512impl<D> Layout<D>
516where
517 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
518{
519 #[inline]
530 pub unsafe fn index_uncheck(&self, index: &[usize]) -> isize {
531 let stride = self.stride.as_ref();
532 match self.ndim() {
533 0 => self.offset as isize,
534 1 => self.offset as isize + stride[0] * index[0] as isize,
535 2 => self.offset as isize + stride[0] * index[0] as isize + stride[1] * index[1] as isize,
536 3 => {
537 self.offset as isize
538 + stride[0] * index[0] as isize
539 + stride[1] * index[1] as isize
540 + stride[2] * index[2] as isize
541 },
542 4 => {
543 self.offset as isize
544 + stride[0] * index[0] as isize
545 + stride[1] * index[1] as isize
546 + stride[2] * index[2] as isize
547 + stride[3] * index[3] as isize
548 },
549 _ => {
550 let mut pos = self.offset as isize;
551 stride.iter().zip(index.iter()).for_each(|(&s, &i)| pos += s * i as isize);
552 pos
553 },
554 }
555 }
556}
557
558impl<D> PartialEq for Layout<D>
559where
560 D: DimBaseAPI,
561{
562 fn eq(&self, other: &Self) -> bool {
565 if self.ndim() != other.ndim() {
566 return false;
567 }
568 for i in 0..self.ndim() {
569 let s1 = self.shape()[i];
570 let s2 = other.shape()[i];
571 if s1 != s2 {
572 return false;
573 }
574 if s1 != 1 && s1 != 0 && self.stride()[i] != other.stride()[i] {
575 return false;
576 }
577 }
578 return true;
579 }
580}
581
582pub trait DimLayoutContigAPI: DimBaseAPI + DimShapeAPI + DimStrideAPI {
583 fn new_c_contig(&self, offset: Option<usize>) -> Layout<Self> {
586 let shape = self.clone();
587 let stride = shape.stride_c_contig();
588 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
589 }
590
591 fn new_f_contig(&self, offset: Option<usize>) -> Layout<Self> {
594 let shape = self.clone();
595 let stride = shape.stride_f_contig();
596 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
597 }
598
599 fn c(&self) -> Layout<Self> {
602 self.new_c_contig(None)
603 }
604
605 fn f(&self) -> Layout<Self> {
608 self.new_f_contig(None)
609 }
610
611 fn new_contig(&self, offset: Option<usize>, order: FlagOrder) -> Layout<Self> {
613 match order {
614 FlagOrder::C => self.new_c_contig(offset),
615 FlagOrder::F => self.new_f_contig(offset),
616 }
617 }
618}
619
620impl<const N: usize> DimLayoutContigAPI for Ix<N> {}
621impl DimLayoutContigAPI for IxD {}
622
623pub trait DimIntoAPI<D>: DimBaseAPI
628where
629 D: DimBaseAPI,
630{
631 fn into_dim(layout: Layout<Self>) -> Result<Layout<D>>;
632}
633
634impl<D> DimIntoAPI<D> for IxD
635where
636 D: DimBaseAPI,
637{
638 fn into_dim(layout: Layout<IxD>) -> Result<Layout<D>> {
639 let shape = layout.shape().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
640 let stride = layout.stride().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
641 let offset = layout.offset();
642 return Ok(Layout { shape, stride, offset });
643 }
644}
645
646impl<const N: usize> DimIntoAPI<IxD> for Ix<N> {
647 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<IxD>> {
648 let shape = (*layout.shape()).into();
649 let stride = (*layout.stride()).into();
650 let offset = layout.offset();
651 return Ok(Layout { shape, stride, offset });
652 }
653}
654
655impl<const N: usize, const M: usize> DimIntoAPI<Ix<M>> for Ix<N> {
656 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<Ix<M>>> {
657 rstsr_assert_eq!(N, M, InvalidLayout)?;
658 let shape = layout.shape().to_vec().try_into().unwrap();
659 let stride = layout.stride().to_vec().try_into().unwrap();
660 let offset = layout.offset();
661 return Ok(Layout { shape, stride, offset });
662 }
663}
664
665impl<D> Layout<D>
666where
667 D: DimBaseAPI,
668{
669 pub fn into_dim<D2>(self) -> Result<Layout<D2>>
671 where
672 D2: DimBaseAPI,
673 D: DimIntoAPI<D2>,
674 {
675 D::into_dim(self)
676 }
677
678 pub fn to_dim<D2>(&self) -> Result<Layout<D2>>
680 where
681 D2: DimBaseAPI,
682 D: DimIntoAPI<D2>,
683 {
684 D::into_dim(self.clone())
685 }
686}
687
688impl<const N: usize> From<Ix<N>> for Layout<Ix<N>> {
689 fn from(shape: Ix<N>) -> Self {
690 let stride = shape.stride_contig();
691 Layout { shape, stride, offset: 0 }
692 }
693}
694
695impl From<IxD> for Layout<IxD> {
696 fn from(shape: IxD) -> Self {
697 let stride = shape.stride_contig();
698 Layout { shape, stride, offset: 0 }
699 }
700}
701
#[cfg(test)]
mod test {
    use std::panic::catch_unwind;

    use super::*;

    #[test]
    fn test_layout_new() {
        // valid layout with a negative stride
        let shape = [3, 2, 6];
        let stride = [3, -300, 15];
        let layout = Layout::new(shape, stride, 917).unwrap();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[3, -300, 15]);
        assert_eq!(layout.offset(), 917);
        assert_eq!(layout.ndim(), 3);
        // the negative stride would reach below position 0
        let layout = Layout::new(shape, stride, 0);
        assert!(layout.is_err());
        // zero stride on an axis of extent > 1
        let layout = Layout::new(shape, [3, -300, 0], 1000);
        assert!(layout.is_err());
        // strides too small: elements would overlap
        let layout = Layout::new(shape, [3, 4, 7], 1000);
        assert!(layout.is_err());
        // 0-dim layout
        let layout = Layout::new([], [], 1000);
        assert!(layout.is_ok());
        // any stride (even zero) is fine on an extent-1 axis
        let layout = Layout::new([3, 1, 5], [1, 0, 15], 1);
        assert!(layout.is_ok());
        // zero-size layouts are always valid
        let layout = Layout::new([3, 0, 5], [-1, -2, -3], 1);
        assert!(layout.is_ok());
        // the unchecked constructor performs no validation and does not panic
        let r = catch_unwind(|| unsafe { Layout::new_unchecked(shape, [3, -300, 0], 1000) });
        assert!(r.is_ok());
    }

    #[test]
    fn test_is_f_prefer() {
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [1, 10, 100], 0).unwrap();
        assert!(layout.f_prefer());
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(layout.f_prefer());
        let layout = Layout::new(shape, [1, 3, -15], 1000).unwrap();
        assert!(!layout.f_prefer());
        let layout = Layout::new(shape, [1, 21, 3], 0).unwrap();
        assert!(!layout.f_prefer());
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(!layout.f_prefer());
        let layout = Layout::new(shape, [2, 6, 30], 0).unwrap();
        assert!(!layout.f_prefer());
        // 0-dim
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.f_prefer());
        // zero-size
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.f_prefer());
        // extent-1 axis in the middle
        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
        assert!(layout.f_prefer());
    }

    #[test]
    fn test_is_c_prefer() {
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [100, 10, 1], 0).unwrap();
        assert!(layout.c_prefer());
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(layout.c_prefer());
        let layout = Layout::new(shape, [-35, 7, 1], 1000).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [7, 21, 1], 0).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [70, 14, 2], 0).unwrap();
        assert!(!layout.c_prefer());
        // 0-dim
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.c_prefer());
        // zero-size
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.c_prefer());
        // extent-1 axis in the middle
        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
        assert!(layout.c_prefer());
    }

    #[test]
    fn test_is_f_contig() {
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(layout.f_contig());
        let layout = Layout::new(shape, [1, 4, 20], 0).unwrap();
        assert!(!layout.f_contig());
        // 0-dim
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.f_contig());
        // zero-size
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.f_contig());
        // extent-1 axis may carry any stride
        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
        assert!(layout.f_contig());
    }

    #[test]
    fn test_is_c_contig() {
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(layout.c_contig());
        let layout = Layout::new(shape, [36, 7, 1], 0).unwrap();
        assert!(!layout.c_contig());
        // 0-dim
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.c_contig());
        // zero-size
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.c_contig());
        // extent-1 axis may carry any stride
        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
        assert!(layout.c_contig());
    }

    #[test]
    fn test_index() {
        // positions: 782 + 3 * i - 180 * j + 15 * k
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert_eq!(layout.index(&[0, 0, 0]), 782);
        assert_eq!(layout.index(&[2, 1, 4]), 668);
        assert_eq!(layout.index(&[1, -2, -3]), 830);
        // 0-dim
        let layout = Layout::new([], [], 10).unwrap();
        assert_eq!(layout.index(&[]), 10);
    }

    #[test]
    fn test_index_out_of_range() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert!(layout.index_f(&[3, 0, 0]).is_err());
        assert!(layout.index_f(&[0, -3, 0]).is_err());
        assert!(layout.index_f(&[0, 0]).is_err());
    }

    #[test]
    fn test_bounds_index() {
        // min = 782 - 180 = 602; max = 782 + 3 * 2 + 15 * 5 = 863
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert_eq!(layout.bounds_index().unwrap(), (602, 864));
        // offset too small for the negative stride
        let layout = unsafe { Layout::new_unchecked([3, 2, 6], [3, -180, 15], 15) };
        assert!(layout.bounds_index().is_err());
        // 0-dim
        let layout = Layout::new([], [], 10).unwrap();
        assert_eq!(layout.bounds_index().unwrap(), (10, 11));
    }

    #[test]
    fn test_transpose() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.transpose(&[2, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // permute_dims is an alias
        let trans = layout.permute_dims(&[2, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // negative axes count from the end
        let trans = layout.transpose(&[-1, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // -2 normalises to 1, so axis 2 is missing
        let trans = layout.transpose(&[-2, 0, 1]);
        assert!(trans.is_err());
        // wrong number of axes
        let trans = layout.transpose(&[1, 0]);
        assert!(trans.is_err());
        // 0-dim
        let layout = Layout::new([], [], 0).unwrap();
        let trans = layout.transpose(&[]);
        assert!(trans.is_ok());
    }

    #[test]
    fn test_transpose_rejects_bad_axes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert!(layout.transpose(&[0, 0, 1]).is_err());
        assert!(layout.transpose(&[0, 1, 3]).is_err());
    }

    #[test]
    fn test_reverse_axes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.reverse_axes();
        assert_eq!(trans.shape(), &[6, 2, 3]);
        assert_eq!(trans.stride(), &[15, -180, 3]);
        // 0-dim
        let layout = Layout::new([], [], 782).unwrap();
        let trans = layout.reverse_axes();
        assert_eq!(trans.shape(), &[]);
        assert_eq!(trans.stride(), &[]);
    }

    #[test]
    fn test_swapaxes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.swapaxes(-1, -2).unwrap();
        assert_eq!(trans.shape(), &[3, 6, 2]);
        assert_eq!(trans.stride(), &[3, 15, -180]);
        // swapping an axis with itself is a no-op
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.swapaxes(-1, -1).unwrap();
        assert_eq!(trans.shape(), &[3, 2, 6]);
        assert_eq!(trans.stride(), &[3, -180, 15]);
    }

    #[test]
    fn test_swapaxes_out_of_range() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert!(layout.swapaxes(3, 0).is_err());
        assert!(layout.swapaxes(0, -4).is_err());
    }

    #[test]
    fn test_index_uncheck() {
        unsafe {
            // fixed rank
            let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
            // dynamic rank
            let layout = Layout::new(vec![3, 2, 6], vec![3, -180, 15], 782).unwrap();
            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
            // 0-dim
            let layout = Layout::new([], [], 10).unwrap();
            assert_eq!(layout.index_uncheck(&[]), 10);
        }
    }

    #[test]
    fn test_diagonal() {
        let layout = [2, 3, 4].c();
        let diag = layout.diagonal(None, None, None).unwrap();
        assert_eq!(diag, Layout::new([4, 2], [1, 16], 0).unwrap());
        let diag = layout.diagonal(Some(-1), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 2], [12, 5], 0).unwrap());
        let diag = layout.diagonal(Some(-4), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 0], [12, 5], 0).unwrap());
    }

    #[test]
    fn test_diagonal_square() {
        // 3x3 row-major: strides [3, 1]
        let layout = [3, 3].c();
        let diag = layout.diagonal(Some(1), None, None).unwrap();
        assert_eq!(diag, Layout::new([2], [4], 1).unwrap());
        assert_eq!(diag.offset(), 1);
        let diag = layout.diagonal(Some(-2), None, None).unwrap();
        assert_eq!(diag, Layout::new([1], [4], 6).unwrap());
        assert_eq!(diag.offset(), 6);
    }

    #[test]
    fn test_new_contig() {
        let layout = [3, 2, 6].c();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[12, 6, 1]);
        let layout = [3, 2, 6].f();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[1, 3, 6]);
        // the default order depends on configuration, but the result must be contiguous
        let layout: Layout<_> = [3, 2, 6].into();
        assert!(!format!("{layout:?}").is_empty());
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert!(layout.c_contig() || layout.f_contig());
    }

    #[test]
    fn test_layout_cast() {
        let layout = [3, 2, 6].c();
        assert!(layout.clone().into_dim::<IxD>().is_ok());
        assert!(layout.clone().into_dim::<Ix3>().is_ok());
        let layout = vec![3, 2, 6].c();
        assert!(layout.clone().into_dim::<IxD>().is_ok());
        assert!(layout.clone().into_dim::<Ix3>().is_ok());
        assert!(layout.clone().into_dim::<Ix2>().is_err());
    }

    #[test]
    fn test_unravel_index() {
        unsafe {
            let shape = [3, 2, 6];
            assert_eq!(shape.unravel_index_f(0), [0, 0, 0]);
            assert_eq!(shape.unravel_index_f(16), [1, 1, 2]);
            assert_eq!(shape.unravel_index_c(0), [0, 0, 0]);
            assert_eq!(shape.unravel_index_c(16), [1, 0, 4]);
        }
    }

    #[test]
    fn fix_too_strict_stride_check() {
        // A reversed, stepped slice must pass `check_strides` even though its strides are not
        // sorted by magnitude in axis order.
        let layout = [10, 11, 12].c();
        let slc = (.., slice!(-1, 0, -4));
        let slc: AxesIndex<Indexer> = slc.try_into().unwrap();
        let indexed = layout.dim_slice(slc.as_ref()).unwrap();
        assert_eq!(indexed.shape(), &[10, 3, 12]);
        assert_eq!(indexed.stride(), &[132, -48, 1]);
    }
}