1use crate::prelude_dev::*;
3use itertools::izip;
4
5#[doc = include_str!("readme.md")]
15#[derive(Clone)]
16pub struct Layout<D>
17where
18 D: DimBaseAPI,
19{
20 pub(crate) shape: D,
22 pub(crate) stride: D::Stride,
23 pub(crate) offset: usize,
24}
25
26unsafe impl<D> Send for Layout<D> where D: DimBaseAPI {}
27unsafe impl<D> Sync for Layout<D> where D: DimBaseAPI {}
28
29impl<D> Layout<D>
35where
36 D: DimBaseAPI,
37{
38 #[inline]
40 pub fn shape(&self) -> &D {
41 &self.shape
42 }
43
44 #[inline]
46 pub fn stride(&self) -> &D::Stride {
47 &self.stride
48 }
49
50 #[inline]
52 pub fn offset(&self) -> usize {
53 self.offset
54 }
55
56 #[inline]
58 pub fn ndim(&self) -> usize {
59 self.shape.ndim()
60 }
61
62 #[inline]
68 pub fn size(&self) -> usize {
69 self.shape().as_ref().iter().product()
70 }
71
72 pub unsafe fn set_offset(&mut self, offset: usize) -> &mut Self {
79 self.offset = offset;
80 return self;
81 }
82}
83
84impl<D> Layout<D>
86where
87 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
88{
89 pub fn f_prefer(&self) -> bool {
91 if self.ndim() == 0 || self.size() == 0 {
93 return true;
94 }
95
96 let stride = self.stride.as_ref();
97 let shape = self.shape.as_ref();
98 let mut last = 0;
99 for (&s, &d) in stride.iter().zip(shape.iter()) {
100 if d != 1 {
101 if s < last {
102 return false;
104 }
105 if last == 0 && s != 1 {
106 return false;
108 }
109 last = s;
110 } else if last == 0 {
111 last = 1;
114 }
115 }
116 return true;
117 }
118
119 pub fn c_prefer(&self) -> bool {
121 if self.ndim() == 0 || self.size() == 0 {
123 return true;
124 }
125
126 let stride = self.stride.as_ref();
127 let shape = self.shape.as_ref();
128 let mut last = 0;
129 for (&s, &d) in stride.iter().zip(shape.iter()).rev() {
130 if d != 1 {
131 if s < last {
132 return false;
134 }
135 if last == 0 && s != 1 {
136 return false;
138 }
139 last = s;
140 } else if last == 0 {
141 last = 1;
144 }
145 }
146 return true;
147 }
148
149 pub fn ndim_of_f_contig(&self) -> usize {
154 if self.ndim() == 0 || self.size() == 0 {
155 return self.ndim();
156 }
157 let stride = self.stride.as_ref();
158 let shape = self.shape.as_ref();
159 let mut acc = 1;
160 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).enumerate() {
161 if d != 1 && s != acc {
162 return ndim;
163 }
164 acc *= d as isize;
165 }
166 return self.ndim();
167 }
168
169 pub fn ndim_of_c_contig(&self) -> usize {
174 if self.ndim() == 0 || self.size() == 0 {
175 return self.ndim();
176 }
177 let stride = self.stride.as_ref();
178 let shape = self.shape.as_ref();
179 let mut acc = 1;
180 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).rev().enumerate() {
181 if d != 1 && s != acc {
182 return ndim;
183 }
184 acc *= d as isize;
185 }
186 return self.ndim();
187 }
188
189 pub fn f_contig(&self) -> bool {
197 self.ndim() == self.ndim_of_f_contig()
198 }
199
200 pub fn c_contig(&self) -> bool {
208 self.ndim() == self.ndim_of_c_contig()
209 }
210
211 pub fn index_f(&self, index: &[isize]) -> Result<usize> {
215 rstsr_assert_eq!(index.len(), self.ndim(), InvalidLayout)?;
216 let mut pos = self.offset() as isize;
217 let shape = self.shape.as_ref();
218 let stride = self.stride.as_ref();
219
220 for (&idx, &shp, &strd) in izip!(index.iter(), shape.iter(), stride.iter()) {
221 let idx = if idx < 0 { idx + shp as isize } else { idx };
222 rstsr_pattern!(idx, 0..(shp as isize), ValueOutOfRange)?;
223 pos += strd * idx;
224 }
225 rstsr_pattern!(pos, 0.., ValueOutOfRange)?;
226 return Ok(pos as usize);
227 }
228
229 pub fn index(&self, index: &[isize]) -> usize {
238 self.index_f(index).unwrap()
239 }
240
241 pub fn bounds_index(&self) -> Result<(usize, usize)> {
247 let n = self.ndim();
248 let offset = self.offset;
249 let shape = self.shape.as_ref();
250 let stride = self.stride.as_ref();
251
252 if n == 0 {
253 return Ok((offset, offset + 1));
254 }
255
256 let mut min = offset as isize;
257 let mut max = offset as isize;
258
259 for i in 0..n {
260 if shape[i] == 0 {
261 return Ok((offset, offset));
262 }
263 if stride[i] > 0 {
264 max += stride[i] * (shape[i] as isize - 1);
265 } else {
266 min += stride[i] * (shape[i] as isize - 1);
267 }
268 }
269 rstsr_pattern!(min, 0.., ValueOutOfRange)?;
270 return Ok((min as usize, max as usize + 1));
271 }
272
273 pub fn check_strides(&self) -> Result<()> {
294 let shape = self.shape.as_ref();
295 let stride = self.stride.as_ref();
296 rstsr_assert_eq!(shape.len(), stride.len(), InvalidLayout)?;
297 let n = shape.len();
298
299 if self.size() == 0 || n == 0 {
302 return Ok(());
303 }
304
305 let mut indices = (0..n).filter(|&k| shape[k] > 1).collect::<Vec<_>>();
306 indices.sort_by_key(|&k| stride[k].abs());
307 let shape_sorted = indices.iter().map(|&k| shape[k]).collect::<Vec<_>>();
308 let stride_sorted = indices.iter().map(|&k| stride[k].unsigned_abs()).collect::<Vec<_>>();
309
310 for i in 0..indices.len().max(1) - 1 {
312 rstsr_pattern!(
314 shape_sorted[i] * stride_sorted[i],
315 1..stride_sorted[i + 1] + 1,
316 InvalidLayout,
317 "Either stride be zero, or stride too small that elements in tensor can be overlapped."
318 )?;
319 }
320 return Ok(());
321 }
322
323 pub fn diagonal(
324 &self,
325 offset: Option<isize>,
326 axis1: Option<isize>,
327 axis2: Option<isize>,
328 ) -> Result<Layout<<D as DimSmallerOneAPI>::SmallerOne>>
329 where
330 D: DimSmallerOneAPI,
331 {
332 rstsr_assert!(self.ndim() >= 2, InvalidLayout)?;
334 let offset = offset.unwrap_or(0);
336 let axis1 = axis1.unwrap_or(0);
337 let axis2 = axis2.unwrap_or(1);
338 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
339 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
340 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
341 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
342 let axis1 = axis1 as usize;
343 let axis2 = axis2 as usize;
344
345 let d1 = self.shape()[axis1] as isize;
347 let d2 = self.shape()[axis2] as isize;
348 let t1 = self.stride()[axis1];
349 let t2 = self.stride()[axis2];
350
351 let (offset_diag, d_diag) = if (-d2 + 1..0).contains(&offset) {
353 let offset = -offset;
354 let offset_diag = (self.offset() as isize + t1 * offset) as usize;
355 let d_diag = (d1 - offset).min(d2) as usize;
356 (offset_diag, d_diag)
357 } else if (0..d1).contains(&offset) {
358 let offset_diag = (self.offset() as isize + t2 * offset) as usize;
359 let d_diag = (d2 - offset).min(d1) as usize;
360 (offset_diag, d_diag)
361 } else {
362 (self.offset(), 0)
363 };
364
365 let t_diag = t1 + t2;
367 let mut shape_diag = vec![];
368 let mut stride_diag = vec![];
369 for i in 0..self.ndim() {
370 if i != axis1 && i != axis2 {
371 shape_diag.push(self.shape()[i]);
372 stride_diag.push(self.stride()[i]);
373 }
374 }
375 shape_diag.push(d_diag);
376 stride_diag.push(t_diag);
377 let layout_diag = Layout::new(shape_diag, stride_diag, offset_diag)?;
378 return layout_diag.into_dim::<<D as DimSmallerOneAPI>::SmallerOne>();
379 }
380}
381
382impl<D> Layout<D>
385where
386 D: DimBaseAPI,
387{
388 #[inline]
396 pub fn new(shape: D, stride: D::Stride, offset: usize) -> Result<Self>
397 where
398 D: DimShapeAPI + DimStrideAPI,
399 {
400 let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
401 layout.bounds_index()?;
402 layout.check_strides()?;
403 return Ok(layout);
404 }
405
406 #[inline]
413 pub unsafe fn new_unchecked(shape: D, stride: D::Stride, offset: usize) -> Self {
414 Layout { shape, stride, offset }
415 }
416
417 #[inline]
420 pub fn new_shape(&self) -> D {
421 self.shape.new_shape()
422 }
423
424 #[inline]
427 pub fn new_stride(&self) -> D::Stride {
428 self.shape.new_stride()
429 }
430}
431
432impl<D> Layout<D>
434where
435 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
436{
437 pub fn transpose(&self, axes: &[isize]) -> Result<Self> {
444 let n = self.ndim();
446 rstsr_assert_eq!(
447 axes.len(),
448 n,
449 InvalidLayout,
450 "number of elements in axes should be the same to number of dimensions."
451 )?;
452 let mut permut_used = vec![false; n];
454 for &p in axes {
455 let p = if p < 0 { p + n as isize } else { p };
456 rstsr_pattern!(p, 0..n as isize, InvalidLayout)?;
457 let p = p as usize;
458 permut_used[p] = true;
459 }
460 rstsr_assert!(
461 permut_used.iter().all(|&b| b),
462 InvalidLayout,
463 "axes should contain all elements from 0 to n-1."
464 )?;
465 let axes = axes
466 .iter()
467 .map(|&p| if p < 0 { p + n as isize } else { p } as usize)
468 .collect::<Vec<_>>();
469
470 let shape_old = self.shape();
471 let stride_old = self.stride();
472 let mut shape = self.new_shape();
473 let mut stride = self.new_stride();
474 for i in 0..self.ndim() {
475 shape[i] = shape_old[axes[i]];
476 stride[i] = stride_old[axes[i]];
477 }
478 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
479 }
480
481 pub fn permute_dims(&self, axes: &[isize]) -> Result<Self> {
485 self.transpose(axes)
486 }
487
488 pub fn reverse_axes(&self) -> Self {
490 let shape_old = self.shape();
491 let stride_old = self.stride();
492 let mut shape = self.new_shape();
493 let mut stride = self.new_stride();
494 for i in 0..self.ndim() {
495 shape[i] = shape_old[self.ndim() - i - 1];
496 stride[i] = stride_old[self.ndim() - i - 1];
497 }
498 return unsafe { Layout::new_unchecked(shape, stride, self.offset) };
499 }
500
501 pub fn swapaxes(&self, axis1: isize, axis2: isize) -> Result<Self> {
503 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
504 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
505 let axis1 = axis1 as usize;
506
507 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
508 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
509 let axis2 = axis2 as usize;
510
511 let mut shape = self.shape().clone();
512 let mut stride = self.stride().clone();
513 shape.as_mut().swap(axis1, axis2);
514 stride.as_mut().swap(axis1, axis2);
515 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
516 }
517}
518
519impl<D> Layout<D>
523where
524 D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
525{
526 #[inline]
537 pub unsafe fn index_uncheck(&self, index: &[usize]) -> isize {
538 let stride = self.stride.as_ref();
539 match self.ndim() {
540 0 => self.offset as isize,
541 1 => self.offset as isize + stride[0] * index[0] as isize,
542 2 => {
543 self.offset as isize + stride[0] * index[0] as isize + stride[1] * index[1] as isize
544 },
545 3 => {
546 self.offset as isize
547 + stride[0] * index[0] as isize
548 + stride[1] * index[1] as isize
549 + stride[2] * index[2] as isize
550 },
551 4 => {
552 self.offset as isize
553 + stride[0] * index[0] as isize
554 + stride[1] * index[1] as isize
555 + stride[2] * index[2] as isize
556 + stride[3] * index[3] as isize
557 },
558 _ => {
559 let mut pos = self.offset as isize;
560 stride.iter().zip(index.iter()).for_each(|(&s, &i)| pos += s * i as isize);
561 pos
562 },
563 }
564 }
565}
566
567impl<D> PartialEq for Layout<D>
568where
569 D: DimBaseAPI,
570{
571 fn eq(&self, other: &Self) -> bool {
574 if self.ndim() != other.ndim() {
575 return false;
576 }
577 for i in 0..self.ndim() {
578 let s1 = self.shape()[i];
579 let s2 = other.shape()[i];
580 if s1 != s2 {
581 return false;
582 }
583 if s1 != 1 && s1 != 0 && self.stride()[i] != other.stride()[i] {
584 return false;
585 }
586 }
587 return true;
588 }
589}
590
591pub trait DimLayoutContigAPI: DimBaseAPI + DimShapeAPI + DimStrideAPI {
592 fn new_c_contig(&self, offset: Option<usize>) -> Layout<Self> {
595 let shape = self.clone();
596 let stride = shape.stride_c_contig();
597 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
598 }
599
600 fn new_f_contig(&self, offset: Option<usize>) -> Layout<Self> {
603 let shape = self.clone();
604 let stride = shape.stride_f_contig();
605 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
606 }
607
608 fn c(&self) -> Layout<Self> {
611 self.new_c_contig(None)
612 }
613
614 fn f(&self) -> Layout<Self> {
617 self.new_f_contig(None)
618 }
619}
620
// Fixed-rank `Ix<N>` and dynamic-rank `IxD` shapes both get the contiguous-layout
// constructors through the trait's default methods.
impl<const N: usize> DimLayoutContigAPI for Ix<N> {}
impl DimLayoutContigAPI for IxD {}
623
624pub trait DimIntoAPI<D>: DimBaseAPI
629where
630 D: DimBaseAPI,
631{
632 fn into_dim(layout: Layout<Self>) -> Result<Layout<D>>;
633}
634
635impl<D> DimIntoAPI<D> for IxD
636where
637 D: DimBaseAPI,
638{
639 fn into_dim(layout: Layout<IxD>) -> Result<Layout<D>> {
640 let shape = layout.shape().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
641 let stride = layout.stride().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
642 let offset = layout.offset();
643 return Ok(Layout { shape, stride, offset });
644 }
645}
646
647impl<const N: usize> DimIntoAPI<IxD> for Ix<N> {
648 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<IxD>> {
649 let shape = (*layout.shape()).into();
650 let stride = (*layout.stride()).into();
651 let offset = layout.offset();
652 return Ok(Layout { shape, stride, offset });
653 }
654}
655
656impl<const N: usize, const M: usize> DimIntoAPI<Ix<M>> for Ix<N> {
657 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<Ix<M>>> {
658 rstsr_assert_eq!(N, M, InvalidLayout)?;
659 let shape = layout.shape().to_vec().try_into().unwrap();
660 let stride = layout.stride().to_vec().try_into().unwrap();
661 let offset = layout.offset();
662 return Ok(Layout { shape, stride, offset });
663 }
664}
665
666impl<D> Layout<D>
667where
668 D: DimBaseAPI,
669{
670 pub fn into_dim<D2>(self) -> Result<Layout<D2>>
672 where
673 D2: DimBaseAPI,
674 D: DimIntoAPI<D2>,
675 {
676 D::into_dim(self)
677 }
678
679 pub fn to_dim<D2>(&self) -> Result<Layout<D2>>
681 where
682 D2: DimBaseAPI,
683 D: DimIntoAPI<D2>,
684 {
685 D::into_dim(self.clone())
686 }
687}
688
689impl<const N: usize> From<Ix<N>> for Layout<Ix<N>> {
690 fn from(shape: Ix<N>) -> Self {
691 let stride = shape.stride_contig();
692 Layout { shape, stride, offset: 0 }
693 }
694}
695
696impl From<IxD> for Layout<IxD> {
697 fn from(shape: IxD) -> Self {
698 let stride = shape.stride_contig();
699 Layout { shape, stride, offset: 0 }
700 }
701}
702
#[cfg(test)]
mod test {
    use std::panic::catch_unwind;

    use super::*;

    /// `Layout::new` accepts valid layouts and rejects ones that underflow or overlap.
    #[test]
    fn test_layout_new() {
        // valid layout with a negative stride; the offset keeps every position non-negative
        let shape = [3, 2, 6];
        let stride = [3, -300, 15];
        let layout = Layout::new(shape, stride, 917).unwrap();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[3, -300, 15]);
        assert_eq!(layout.offset(), 917);
        assert_eq!(layout.ndim(), 3);
        // same strides at offset 0: the negative stride reaches a negative position
        let shape = [3, 2, 6];
        let stride = [3, -300, 15];
        let layout = Layout::new(shape, stride, 0);
        assert!(layout.is_err());
        // zero stride on an axis with extent > 1 is rejected by `check_strides`
        let shape = [3, 2, 6];
        let stride = [3, -300, 0];
        let layout = Layout::new(shape, stride, 1000);
        assert!(layout.is_err());
        // strides too small: elements would overlap
        let shape = [3, 2, 6];
        let stride = [3, 4, 7];
        let layout = Layout::new(shape, stride, 1000);
        assert!(layout.is_err());
        // zero-dimensional layout
        let shape = [];
        let stride = [];
        let layout = Layout::new(shape, stride, 1000);
        assert!(layout.is_ok());
        // zero stride on a length-1 axis is fine
        let shape = [3, 1, 5];
        let stride = [1, 0, 15];
        let layout = Layout::new(shape, stride, 1);
        assert!(layout.is_ok());
        // (repeats the case above)
        let shape = [3, 1, 5];
        let stride = [1, 0, 15];
        let layout = Layout::new(shape, stride, 1);
        assert!(layout.is_ok());
        // empty layout: its (negative) strides are never exercised
        let shape = [3, 0, 5];
        let stride = [-1, -2, -3];
        let layout = Layout::new(shape, stride, 1);
        assert!(layout.is_ok());
        // the unchecked constructor performs no validation, so it does not panic
        let shape = [3, 2, 6];
        let stride = [3, -300, 0];
        let r = catch_unwind(|| unsafe { Layout::new_unchecked(shape, stride, 1000) });
        assert!(r.is_ok());
    }

    /// F-prefer: non-decreasing strides starting at 1, ignoring length-1 axes.
    #[test]
    fn test_is_f_prefer() {
        let shape = [3, 5, 7];
        // increasing strides from 1 (not contiguous, but F-prefer)
        let layout = Layout::new(shape, [1, 10, 100], 0).unwrap();
        assert!(layout.f_prefer());
        // F-contiguous
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(layout.f_prefer());
        // negative stride breaks the ordering
        let layout = Layout::new(shape, [1, 3, -15], 1000).unwrap();
        assert!(!layout.f_prefer());
        // decreasing strides
        let layout = Layout::new(shape, [1, 21, 3], 0).unwrap();
        assert!(!layout.f_prefer());
        // C-contiguous
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(!layout.f_prefer());
        // first stride is not 1
        let layout = Layout::new(shape, [2, 6, 30], 0).unwrap();
        assert!(!layout.f_prefer());
        // zero-dimensional
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.f_prefer());
        // empty
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.f_prefer());
        // length-1 axis is skipped
        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
        assert!(layout.f_prefer());
    }

    /// C-prefer: mirror of F-prefer, scanning axes from last to first.
    #[test]
    fn test_is_c_prefer() {
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [100, 10, 1], 0).unwrap();
        assert!(layout.c_prefer());
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(layout.c_prefer());
        let layout = Layout::new(shape, [-35, 7, 1], 1000).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [7, 21, 1], 0).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new(shape, [70, 14, 2], 0).unwrap();
        assert!(!layout.c_prefer());
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.c_prefer());
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.c_prefer());
        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
        assert!(layout.c_prefer());
    }

    /// F-contiguity requires every stride to equal the product of preceding extents.
    #[test]
    fn test_is_f_contig() {
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
        assert!(layout.f_contig());
        let layout = Layout::new(shape, [1, 4, 20], 0).unwrap();
        assert!(!layout.f_contig());
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.f_contig());
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.f_contig());
        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
        assert!(layout.f_contig());
    }

    /// C-contiguity requires every stride to equal the product of following extents.
    #[test]
    fn test_is_c_contig() {
        let shape = [3, 5, 7];
        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
        assert!(layout.c_contig());
        let layout = Layout::new(shape, [36, 7, 1], 0).unwrap();
        assert!(!layout.c_contig());
        let layout = Layout::new([], [], 0).unwrap();
        assert!(layout.c_contig());
        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
        assert!(layout.c_contig());
        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
        assert!(layout.c_contig());
    }

    /// Checked indexing, including negative (from-the-end) indices.
    #[test]
    fn test_index() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert_eq!(layout.index(&[0, 0, 0]), 782);
        // 782 + 3*2 - 180*1 + 15*4
        assert_eq!(layout.index(&[2, 1, 4]), 668);
        // [1, -2, -3] wraps to [1, 0, 3]: 782 + 3 + 45
        assert_eq!(layout.index(&[1, -2, -3]), 830);
        let layout = Layout::new([], [], 10).unwrap();
        assert_eq!(layout.index(&[]), 10);
    }

    /// Half-open range of addressable positions.
    #[test]
    fn test_bounds_index() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        // min = 782 - 180; max = 782 + 3*2 + 15*5, plus one for the half-open end
        assert_eq!(layout.bounds_index().unwrap(), (602, 864));
        // the negative stride would reach a negative position
        let layout = unsafe { Layout::new_unchecked([3, 2, 6], [3, -180, 15], 15) };
        assert!(layout.bounds_index().is_err());
        let layout = Layout::new([], [], 10).unwrap();
        assert_eq!(layout.bounds_index().unwrap(), (10, 11));
    }

    /// Axis permutation, with negative axes and invalid permutations.
    #[test]
    fn test_transpose() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.transpose(&[2, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        let trans = layout.permute_dims(&[2, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // -1 is the last axis (2)
        let trans = layout.transpose(&[-1, 0, 1]).unwrap();
        assert_eq!(trans.shape(), &[6, 3, 2]);
        assert_eq!(trans.stride(), &[15, 3, -180]);
        // -2 is axis 1, which repeats axis 1 and misses axis 2
        let trans = layout.transpose(&[-2, 0, 1]);
        assert!(trans.is_err());
        // wrong number of axes
        let trans = layout.transpose(&[1, 0]);
        assert!(trans.is_err());
        let layout = Layout::new([], [], 0).unwrap();
        let trans = layout.transpose(&[]);
        assert!(trans.is_ok());
    }

    /// Reversing all axes.
    #[test]
    fn test_reverse_axes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.reverse_axes();
        assert_eq!(trans.shape(), &[6, 2, 3]);
        assert_eq!(trans.stride(), &[15, -180, 3]);
        let layout = Layout::new([], [], 782).unwrap();
        let trans = layout.reverse_axes();
        assert_eq!(trans.shape(), &[]);
        assert_eq!(trans.stride(), &[]);
    }

    /// Swapping two axes, including swapping an axis with itself.
    #[test]
    fn test_swapaxes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.swapaxes(-1, -2).unwrap();
        assert_eq!(trans.shape(), &[3, 6, 2]);
        assert_eq!(trans.stride(), &[3, 15, -180]);
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let trans = layout.swapaxes(-1, -1).unwrap();
        assert_eq!(trans.shape(), &[3, 2, 6]);
        assert_eq!(trans.stride(), &[3, -180, 15]);
    }

    /// Unchecked indexing for fixed-rank, dynamic-rank and zero-dimensional layouts.
    #[test]
    fn test_index_uncheck() {
        unsafe {
            let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
            let layout = Layout::new(vec![3, 2, 6], vec![3, -180, 15], 782).unwrap();
            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
            let layout = Layout::new([], [], 10).unwrap();
            assert_eq!(layout.index_uncheck(&[]), 10);
        }
    }

    /// Diagonal layouts; note that `PartialEq` does not compare offsets, so the offsets
    /// passed to the expected layouts below are not actually checked.
    #[test]
    fn test_diagonal() {
        let layout = [2, 3, 4].c();
        // main diagonal over axes (0, 1): remaining axis 2, then diagonal of length 2
        let diag = layout.diagonal(None, None, None).unwrap();
        assert_eq!(diag, Layout::new([4, 2], [1, 16], 0).unwrap());
        // sub-diagonal over the last two axes
        let diag = layout.diagonal(Some(-1), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 2], [12, 5], 0).unwrap());
        // offset outside the matrix: empty diagonal
        let diag = layout.diagonal(Some(-4), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 0], [12, 5], 0).unwrap());
    }

    /// `c()` is row-major, `f()` is column-major; shapes convert into layouts.
    #[test]
    fn test_new_contig() {
        let layout = [3, 2, 6].c();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[12, 6, 1]);
        let layout = [3, 2, 6].f();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[1, 3, 6]);
        let layout: Layout<_> = [3, 2, 6].into();
        println!("{:?}", layout);
    }

    /// Conversions between fixed- and dynamic-rank layouts; rank mismatches fail.
    #[test]
    fn test_layout_cast() {
        let layout = [3, 2, 6].c();
        assert!(layout.clone().into_dim::<IxD>().is_ok());
        assert!(layout.clone().into_dim::<Ix3>().is_ok());
        let layout = vec![3, 2, 6].c();
        assert!(layout.clone().into_dim::<IxD>().is_ok());
        assert!(layout.clone().into_dim::<Ix3>().is_ok());
        assert!(layout.clone().into_dim::<Ix2>().is_err());
    }

    /// Flat index -> multi-index in F and C order.
    #[test]
    fn test_unravel_index() {
        unsafe {
            let shape = [3, 2, 6];
            assert_eq!(shape.unravel_index_f(0), [0, 0, 0]);
            // 16 = 1 + 3*1 + 6*2
            assert_eq!(shape.unravel_index_f(16), [1, 1, 2]);
            assert_eq!(shape.unravel_index_c(0), [0, 0, 0]);
            // 16 = 12*1 + 6*0 + 4
            assert_eq!(shape.unravel_index_c(16), [1, 0, 4]);
        }
    }
}