1use crate::prelude_dev::*;
3use itertools::izip;
4
5#[doc = include_str!("readme.md")]
14#[derive(Clone)]
15pub struct Layout<D>
16where
17 D: DimBaseAPI,
18{
19 pub(crate) shape: D,
21 pub(crate) stride: D::Stride,
22 pub(crate) offset: usize,
23}
24
25unsafe impl<D> Send for Layout<D> where D: DimBaseAPI {}
26unsafe impl<D> Sync for Layout<D> where D: DimBaseAPI {}
27
28impl<D> Layout<D>
34where
35 D: DimBaseAPI,
36{
37 #[inline]
39 pub fn shape(&self) -> &D {
40 &self.shape
41 }
42
43 #[inline]
45 pub fn stride(&self) -> &D::Stride {
46 &self.stride
47 }
48
49 #[inline]
51 pub fn offset(&self) -> usize {
52 self.offset
53 }
54
55 #[inline]
57 pub fn ndim(&self) -> usize {
58 self.shape.ndim()
59 }
60
61 #[inline]
67 pub fn size(&self) -> usize {
68 self.shape().as_ref().iter().product()
69 }
70
71 pub unsafe fn set_offset(&mut self, offset: usize) -> &mut Self {
78 self.offset = offset;
79 return self;
80 }
81}
82
83impl<D> Layout<D>
85where
86 D: DimBaseAPI + DimShapeAPI,
87{
88 pub fn f_prefer(&self) -> bool {
90 if self.ndim() == 0 || self.size() == 0 {
92 return true;
93 }
94
95 let stride = self.stride.as_ref();
96 let shape = self.shape.as_ref();
97 let mut last = 0;
98 for (&s, &d) in stride.iter().zip(shape.iter()) {
99 if d != 1 {
100 if s < last {
101 return false;
103 }
104 if last == 0 && s != 1 {
105 return false;
107 }
108 last = s;
109 } else if last == 0 {
110 last = 1;
113 }
114 }
115 return true;
116 }
117
118 pub fn c_prefer(&self) -> bool {
120 if self.ndim() == 0 || self.size() == 0 {
122 return true;
123 }
124
125 let stride = self.stride.as_ref();
126 let shape = self.shape.as_ref();
127 let mut last = 0;
128 for (&s, &d) in stride.iter().zip(shape.iter()).rev() {
129 if d != 1 {
130 if s < last {
131 return false;
133 }
134 if last == 0 && s != 1 {
135 return false;
137 }
138 last = s;
139 } else if last == 0 {
140 last = 1;
143 }
144 }
145 return true;
146 }
147
148 pub fn ndim_of_f_contig(&self) -> usize {
153 if self.ndim() == 0 || self.size() == 0 {
154 return self.ndim();
155 }
156 let stride = self.stride.as_ref();
157 let shape = self.shape.as_ref();
158 let mut acc = 1;
159 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).enumerate() {
160 if d != 1 && s != acc {
161 return ndim;
162 }
163 acc *= d as isize;
164 }
165 return self.ndim();
166 }
167
168 pub fn ndim_of_c_contig(&self) -> usize {
173 if self.ndim() == 0 || self.size() == 0 {
174 return self.ndim();
175 }
176 let stride = self.stride.as_ref();
177 let shape = self.shape.as_ref();
178 let mut acc = 1;
179 for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).rev().enumerate() {
180 if d != 1 && s != acc {
181 return ndim;
182 }
183 acc *= d as isize;
184 }
185 return self.ndim();
186 }
187
188 pub fn f_contig(&self) -> bool {
194 self.ndim() == self.ndim_of_f_contig()
195 }
196
197 pub fn c_contig(&self) -> bool {
203 self.ndim() == self.ndim_of_c_contig()
204 }
205
206 pub fn index_f(&self, index: &[isize]) -> Result<usize> {
210 rstsr_assert_eq!(index.len(), self.ndim(), InvalidLayout)?;
211 let mut pos = self.offset() as isize;
212 let shape = self.shape.as_ref();
213 let stride = self.stride.as_ref();
214
215 for (&idx, &shp, &strd) in izip!(index.iter(), shape.iter(), stride.iter()) {
216 let idx = if idx < 0 { idx + shp as isize } else { idx };
217 rstsr_pattern!(idx, 0..(shp as isize), ValueOutOfRange)?;
218 pos += strd * idx;
219 }
220 rstsr_pattern!(pos, 0.., ValueOutOfRange)?;
221 return Ok(pos as usize);
222 }
223
224 pub fn index(&self, index: &[isize]) -> usize {
229 self.index_f(index).unwrap()
230 }
231
232 pub fn bounds_index(&self) -> Result<(usize, usize)> {
238 let n = self.ndim();
239 let offset = self.offset;
240 let shape = self.shape.as_ref();
241 let stride = self.stride.as_ref();
242
243 if n == 0 {
244 return Ok((offset, offset + 1));
245 }
246
247 let mut min = offset as isize;
248 let mut max = offset as isize;
249
250 for i in 0..n {
251 if shape[i] == 0 {
252 return Ok((offset, offset));
253 }
254 if stride[i] > 0 {
255 max += stride[i] * (shape[i] as isize - 1);
256 } else {
257 min += stride[i] * (shape[i] as isize - 1);
258 }
259 }
260 rstsr_pattern!(min, 0.., ValueOutOfRange)?;
261 return Ok((min as usize, max as usize + 1));
262 }
263
264 pub fn check_strides(&self, skip_zero: bool) -> Result<()> {
286 let shape = self.shape.as_ref();
287 let stride = self.stride.as_ref();
288 rstsr_assert_eq!(shape.len(), stride.len(), InvalidLayout)?;
289 let n = shape.len();
290
291 if self.size() == 0 || n == 0 {
294 return Ok(());
295 }
296
297 let mut indices = (0..n).filter(|&k| shape[k] > 1).collect::<Vec<_>>();
298 indices.sort_by_key(|&k| stride[k].abs());
299 let shape_sorted = indices.iter().map(|&k| shape[k]).collect::<Vec<_>>();
300 let stride_sorted = indices.iter().map(|&k| stride[k].unsigned_abs()).collect::<Vec<_>>();
301
302 let mut elem_cum = 0;
304 for i in 0..indices.len() {
305 if stride_sorted[i] == 0 && skip_zero {
307 continue;
308 }
309 rstsr_pattern!(
311 elem_cum,
312 0..stride_sorted[i],
313 InvalidLayout,
314 "Either stride be zero, or stride too small that elements in tensor can be overlapped."
315 )?;
316
317 elem_cum += (shape_sorted[i] - 1) * stride_sorted[i];
318 }
319 return Ok(());
320 }
321
322 pub fn diagonal(
323 &self,
324 offset: Option<isize>,
325 axis1: Option<isize>,
326 axis2: Option<isize>,
327 ) -> Result<Layout<<D as DimSmallerOneAPI>::SmallerOne>>
328 where
329 D: DimSmallerOneAPI,
330 {
331 rstsr_assert!(self.ndim() >= 2, InvalidLayout)?;
333 let offset = offset.unwrap_or(0);
335 let axis1 = axis1.unwrap_or(0);
336 let axis2 = axis2.unwrap_or(1);
337 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
338 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
339 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
340 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
341 let axis1 = axis1 as usize;
342 let axis2 = axis2 as usize;
343
344 let d1 = self.shape()[axis1] as isize;
346 let d2 = self.shape()[axis2] as isize;
347 let t1 = self.stride()[axis1];
348 let t2 = self.stride()[axis2];
349
350 let (offset_diag, d_diag) = if (-d2 + 1..0).contains(&offset) {
352 let offset = -offset;
353 let offset_diag = (self.offset() as isize + t1 * offset) as usize;
354 let d_diag = (d1 - offset).min(d2) as usize;
355 (offset_diag, d_diag)
356 } else if (0..d1).contains(&offset) {
357 let offset_diag = (self.offset() as isize + t2 * offset) as usize;
358 let d_diag = (d2 - offset).min(d1) as usize;
359 (offset_diag, d_diag)
360 } else {
361 (self.offset(), 0)
362 };
363
364 let t_diag = t1 + t2;
366 let mut shape_diag = vec![];
367 let mut stride_diag = vec![];
368 for i in 0..self.ndim() {
369 if i != axis1 && i != axis2 {
370 shape_diag.push(self.shape()[i]);
371 stride_diag.push(self.stride()[i]);
372 }
373 }
374 shape_diag.push(d_diag);
375 stride_diag.push(t_diag);
376 let layout_diag = Layout::new(shape_diag, stride_diag, offset_diag)?;
377 return layout_diag.into_dim::<<D as DimSmallerOneAPI>::SmallerOne>();
378 }
379}
380
381impl<D> Layout<D>
384where
385 D: DimBaseAPI,
386{
387 #[inline]
395 pub fn new(shape: D, stride: D::Stride, offset: usize) -> Result<Self>
396 where
397 D: DimShapeAPI,
398 {
399 let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
400 layout.bounds_index()?;
401 layout.check_strides(true)?;
402 return Ok(layout);
403 }
404
405 #[inline]
412 pub unsafe fn new_unchecked(shape: D, stride: D::Stride, offset: usize) -> Self {
413 Layout { shape, stride, offset }
414 }
415
416 #[inline]
419 pub fn new_shape(&self) -> D {
420 self.shape.new_shape()
421 }
422
423 #[inline]
426 pub fn new_stride(&self) -> D::Stride {
427 self.shape.new_stride()
428 }
429}
430
431impl<D> Layout<D>
433where
434 D: DimBaseAPI + DimShapeAPI,
435{
436 pub fn transpose(&self, axes: &[isize]) -> Result<Self> {
443 let n = self.ndim();
445 rstsr_assert_eq!(
446 axes.len(),
447 n,
448 InvalidLayout,
449 "number of elements in axes should be the same to number of dimensions."
450 )?;
451 let axes = normalize_axes_index(axes.into(), n, false, false)?;
454 let axes = axes.into_iter().map(|a| a as usize).collect::<Vec<usize>>();
455
456 let shape_old = self.shape();
457 let stride_old = self.stride();
458 let mut shape = self.new_shape();
459 let mut stride = self.new_stride();
460 for i in 0..self.ndim() {
461 shape[i] = shape_old[axes[i]];
462 stride[i] = stride_old[axes[i]];
463 }
464 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
465 }
466
467 pub fn permute_dims(&self, axes: &[isize]) -> Result<Self> {
471 self.transpose(axes)
472 }
473
474 pub fn reverse_axes(&self) -> Self {
476 let shape_old = self.shape();
477 let stride_old = self.stride();
478 let mut shape = self.new_shape();
479 let mut stride = self.new_stride();
480 for i in 0..self.ndim() {
481 shape[i] = shape_old[self.ndim() - i - 1];
482 stride[i] = stride_old[self.ndim() - i - 1];
483 }
484 return unsafe { Layout::new_unchecked(shape, stride, self.offset) };
485 }
486
487 pub fn swapaxes(&self, axis1: isize, axis2: isize) -> Result<Self> {
489 let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
490 rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
491 let axis1 = axis1 as usize;
492
493 let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
494 rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
495 let axis2 = axis2 as usize;
496
497 let mut shape = self.shape().clone();
498 let mut stride = self.stride().clone();
499 shape.as_mut().swap(axis1, axis2);
500 stride.as_mut().swap(axis1, axis2);
501 return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
502 }
503}
504
505impl<D> Layout<D>
509where
510 D: DimBaseAPI + DimShapeAPI,
511{
512 #[inline]
523 pub unsafe fn index_uncheck(&self, index: &[usize]) -> isize {
524 let stride = self.stride.as_ref();
525 match self.ndim() {
526 0 => self.offset as isize,
527 1 => self.offset as isize + stride[0] * index[0] as isize,
528 2 => self.offset as isize + stride[0] * index[0] as isize + stride[1] * index[1] as isize,
529 3 => {
530 self.offset as isize
531 + stride[0] * index[0] as isize
532 + stride[1] * index[1] as isize
533 + stride[2] * index[2] as isize
534 },
535 4 => {
536 self.offset as isize
537 + stride[0] * index[0] as isize
538 + stride[1] * index[1] as isize
539 + stride[2] * index[2] as isize
540 + stride[3] * index[3] as isize
541 },
542 _ => {
543 let mut pos = self.offset as isize;
544 stride.iter().zip(index.iter()).for_each(|(&s, &i)| pos += s * i as isize);
545 pos
546 },
547 }
548 }
549}
550
551impl<D> PartialEq for Layout<D>
552where
553 D: DimBaseAPI,
554{
555 fn eq(&self, other: &Self) -> bool {
558 if self.ndim() != other.ndim() {
559 return false;
560 }
561 if self.offset != other.offset {
562 return false;
563 }
564 for i in 0..self.ndim() {
565 let s1 = self.shape()[i];
566 let s2 = other.shape()[i];
567 if s1 != s2 {
568 return false;
569 }
570 if s1 != 1 && s1 != 0 && self.stride()[i] != other.stride()[i] {
571 return false;
572 }
573 }
574 return true;
575 }
576}
577
578pub trait DimLayoutContigAPI: DimBaseAPI + DimShapeAPI {
579 fn new_c_contig(&self, offset: Option<usize>) -> Layout<Self> {
582 let shape = self.clone();
583 let stride = shape.stride_c_contig();
584 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
585 }
586
587 fn new_f_contig(&self, offset: Option<usize>) -> Layout<Self> {
590 let shape = self.clone();
591 let stride = shape.stride_f_contig();
592 unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
593 }
594
595 fn c(&self) -> Layout<Self> {
598 self.new_c_contig(None)
599 }
600
601 fn f(&self) -> Layout<Self> {
604 self.new_f_contig(None)
605 }
606
607 fn new_contig(&self, offset: Option<usize>, order: FlagOrder) -> Layout<Self> {
609 match order {
610 FlagOrder::C => self.new_c_contig(offset),
611 FlagOrder::F => self.new_f_contig(offset),
612 }
613 }
614}
615
616impl<const N: usize> DimLayoutContigAPI for Ix<N> {}
617impl DimLayoutContigAPI for IxD {}
618
619pub trait DimIntoAPI<D>: DimBaseAPI
624where
625 D: DimBaseAPI,
626{
627 fn into_dim(layout: Layout<Self>) -> Result<Layout<D>>;
628}
629
630impl<D> DimIntoAPI<D> for IxD
631where
632 D: DimBaseAPI,
633{
634 fn into_dim(layout: Layout<IxD>) -> Result<Layout<D>> {
635 let shape = layout.shape().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
636 let stride = layout.stride().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
637 let offset = layout.offset();
638 return Ok(Layout { shape, stride, offset });
639 }
640}
641
642impl<const N: usize> DimIntoAPI<IxD> for Ix<N> {
643 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<IxD>> {
644 let shape = (*layout.shape()).into();
645 let stride = (*layout.stride()).into();
646 let offset = layout.offset();
647 return Ok(Layout { shape, stride, offset });
648 }
649}
650
651impl<const N: usize, const M: usize> DimIntoAPI<Ix<M>> for Ix<N> {
652 fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<Ix<M>>> {
653 rstsr_assert_eq!(N, M, InvalidLayout)?;
654 let shape = layout.shape().to_vec().try_into().unwrap();
655 let stride = layout.stride().to_vec().try_into().unwrap();
656 let offset = layout.offset();
657 return Ok(Layout { shape, stride, offset });
658 }
659}
660
661impl<D> Layout<D>
662where
663 D: DimBaseAPI,
664{
665 pub fn into_dim<D2>(self) -> Result<Layout<D2>>
667 where
668 D2: DimBaseAPI,
669 D: DimIntoAPI<D2>,
670 {
671 D::into_dim(self)
672 }
673
674 pub fn to_dim<D2>(&self) -> Result<Layout<D2>>
676 where
677 D2: DimBaseAPI,
678 D: DimIntoAPI<D2>,
679 {
680 D::into_dim(self.clone())
681 }
682}
683
684impl<const N: usize> From<Ix<N>> for Layout<Ix<N>> {
685 fn from(shape: Ix<N>) -> Self {
686 let stride = shape.stride_contig();
687 Layout { shape, stride, offset: 0 }
688 }
689}
690
691impl From<IxD> for Layout<IxD> {
692 fn from(shape: IxD) -> Self {
693 let stride = shape.stride_contig();
694 Layout { shape, stride, offset: 0 }
695 }
696}
697
#[cfg(test)]
mod test {
    use std::panic::catch_unwind;

    use super::*;

    /// `Layout::new` accepts or rejects inputs according to its validation.
    #[test]
    fn test_layout_new() {
        // valid layout with a negative stride and a large enough offset
        let layout = Layout::new([3, 2, 6], [3, -300, 15], 917).unwrap();
        assert_eq!(layout.shape(), &[3, 2, 6]);
        assert_eq!(layout.stride(), &[3, -300, 15]);
        assert_eq!(layout.offset(), 917);
        assert_eq!(layout.ndim(), 3);
        // same strides, but the negative stride now reaches below position 0
        assert!(Layout::new([3, 2, 6], [3, -300, 15], 0).is_err());
        // strides too small: elements would overlap
        assert!(Layout::new([3, 2, 6], [3, 4, 7], 1000).is_err());
        // a zero stride is tolerated by `new`
        assert!(Layout::new([3, 2, 6], [3, -300, 0], 1000).is_ok());
        // 0-dimension layout
        assert!(Layout::new([], [], 1000).is_ok());
        // zero stride on an axis of extent 1
        assert!(Layout::new([3, 1, 5], [1, 0, 15], 1).is_ok());
        // same input again
        assert!(Layout::new([3, 1, 5], [1, 0, 15], 1).is_ok());
        // zero-size layout: strides are not checked
        assert!(Layout::new([3, 0, 5], [-1, -2, -3], 1).is_ok());
        // the unchecked constructor must not panic
        let shape = [3, 2, 6];
        let stride = [3, -300, 0];
        let outcome = catch_unwind(|| unsafe { Layout::new_unchecked(shape, stride, 1000) });
        assert!(outcome.is_ok());
    }

    /// `f_prefer` on a `[3, 5, 7]` shape and on the degenerate cases.
    #[test]
    fn test_is_f_prefer() {
        // (stride, offset, expected)
        let cases = [
            ([1, 10, 100], 0, true),
            ([1, 3, 15], 0, true),
            ([1, 3, -15], 1000, false),
            ([1, 21, 3], 0, false),
            ([35, 7, 1], 0, false),
            ([2, 6, 30], 0, false),
        ];
        for (stride, offset, expected) in cases {
            let layout = Layout::new([3, 5, 7], stride, offset).unwrap();
            assert_eq!(layout.f_prefer(), expected);
        }
        // 0-dimension
        assert!(Layout::new([], [], 0).unwrap().f_prefer());
        // zero-size
        assert!(Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap().f_prefer());
        // axis of extent 1 in the middle
        assert!(Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap().f_prefer());
    }

    /// `c_prefer` on a `[3, 5, 7]` shape and on the degenerate cases.
    #[test]
    fn test_is_c_prefer() {
        // (stride, offset, expected)
        let cases = [
            ([100, 10, 1], 0, true),
            ([35, 7, 1], 0, true),
            ([-35, 7, 1], 1000, false),
            ([7, 21, 1], 0, false),
            ([1, 3, 15], 0, false),
            ([70, 14, 2], 0, false),
        ];
        for (stride, offset, expected) in cases {
            let layout = Layout::new([3, 5, 7], stride, offset).unwrap();
            assert_eq!(layout.c_prefer(), expected);
        }
        // 0-dimension
        assert!(Layout::new([], [], 0).unwrap().c_prefer());
        // zero-size
        assert!(Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap().c_prefer());
        // axis of extent 1 in the middle
        assert!(Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap().c_prefer());
    }

    /// `f_contig` on contiguous, padded and degenerate layouts.
    #[test]
    fn test_is_f_contig() {
        for (stride, expected) in [([1, 3, 15], true), ([1, 4, 20], false)] {
            let layout = Layout::new([3, 5, 7], stride, 0).unwrap();
            assert_eq!(layout.f_contig(), expected);
        }
        // 0-dimension
        assert!(Layout::new([], [], 0).unwrap().f_contig());
        // zero-size
        assert!(Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap().f_contig());
        // axis of extent 1 in the middle
        assert!(Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap().f_contig());
    }

    /// `c_contig` on contiguous, padded and degenerate layouts.
    #[test]
    fn test_is_c_contig() {
        for (stride, expected) in [([35, 7, 1], true), ([36, 7, 1], false)] {
            let layout = Layout::new([3, 5, 7], stride, 0).unwrap();
            assert_eq!(layout.c_contig(), expected);
        }
        // 0-dimension
        assert!(Layout::new([], [], 0).unwrap().c_contig());
        // zero-size
        assert!(Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap().c_contig());
        // axis of extent 1 in the middle
        assert!(Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap().c_contig());
    }

    /// Checked indexing, including negative indices.
    #[test]
    fn test_index() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let expectations: [(&[isize], usize); 3] =
            [(&[0, 0, 0], 782), (&[2, 1, 4], 668), (&[1, -2, -3], 830)];
        for (index, position) in expectations {
            assert_eq!(layout.index(index), position);
        }
        // 0-dimension
        let scalar = Layout::new([], [], 10).unwrap();
        assert_eq!(scalar.index(&[]), 10);
    }

    /// Position bounds of valid, invalid and 0-dimension layouts.
    #[test]
    fn test_bounds_index() {
        let valid = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        assert_eq!(valid.bounds_index().unwrap(), (602, 864));
        // offset too small for the negative stride
        let too_low = unsafe { Layout::new_unchecked([3, 2, 6], [3, -180, 15], 15) };
        assert!(too_low.bounds_index().is_err());
        // 0-dimension
        let scalar = Layout::new([], [], 10).unwrap();
        assert_eq!(scalar.bounds_index().unwrap(), (10, 11));
    }

    /// `transpose` / `permute_dims` with valid and invalid axes.
    #[test]
    fn test_transpose() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        // plain permutation
        let permuted = layout.transpose(&[2, 0, 1]).unwrap();
        assert_eq!(permuted.shape(), &[6, 3, 2]);
        assert_eq!(permuted.stride(), &[15, 3, -180]);
        // alias
        let permuted = layout.permute_dims(&[2, 0, 1]).unwrap();
        assert_eq!(permuted.shape(), &[6, 3, 2]);
        assert_eq!(permuted.stride(), &[15, 3, -180]);
        // negative axis
        let permuted = layout.transpose(&[-1, 0, 1]).unwrap();
        assert_eq!(permuted.shape(), &[6, 3, 2]);
        assert_eq!(permuted.stride(), &[15, 3, -180]);
        // repeated axis
        assert!(layout.transpose(&[-2, 0, 1]).is_err());
        // wrong number of axes
        assert!(layout.transpose(&[1, 0]).is_err());
        // 0-dimension
        let scalar = Layout::new([], [], 0).unwrap();
        assert!(scalar.transpose(&[]).is_ok());
    }

    /// `reverse_axes` on a 3-dimension and a 0-dimension layout.
    #[test]
    fn test_reverse_axes() {
        let reversed = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap().reverse_axes();
        assert_eq!(reversed.shape(), &[6, 2, 3]);
        assert_eq!(reversed.stride(), &[15, -180, 3]);
        // 0-dimension
        let reversed = Layout::new([], [], 782).unwrap().reverse_axes();
        assert_eq!(reversed.shape(), &[]);
        assert_eq!(reversed.stride(), &[]);
    }

    /// `swapaxes` with two different axes and with the same axis twice.
    #[test]
    fn test_swapaxes() {
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let swapped = layout.swapaxes(-1, -2).unwrap();
        assert_eq!(swapped.shape(), &[3, 6, 2]);
        assert_eq!(swapped.stride(), &[3, 15, -180]);
        // swapping an axis with itself changes nothing
        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
        let swapped = layout.swapaxes(-1, -1).unwrap();
        assert_eq!(swapped.shape(), &[3, 2, 6]);
        assert_eq!(swapped.stride(), &[3, -180, 15]);
    }

    /// Unchecked indexing for fixed, dynamic and 0-dimension layouts.
    #[test]
    fn test_index_uncheck() {
        unsafe {
            // fixed dimension
            let fixed = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
            assert_eq!(fixed.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(fixed.index_uncheck(&[2, 1, 4]), 668);
            // dynamic dimension
            let dynamic = Layout::new(vec![3, 2, 6], vec![3, -180, 15], 782).unwrap();
            assert_eq!(dynamic.index_uncheck(&[0, 0, 0]), 782);
            assert_eq!(dynamic.index_uncheck(&[2, 1, 4]), 668);
            // 0-dimension
            let scalar = Layout::new([], [], 10).unwrap();
            assert_eq!(scalar.index_uncheck(&[]), 10);
        }
    }

    /// Main, lower and out-of-range diagonals of a `[2, 3, 4]` C layout.
    #[test]
    fn test_diagonal() {
        let layout = [2, 3, 4].c();
        let diag = layout.diagonal(None, None, None).unwrap();
        assert_eq!(diag, Layout::new([4, 2], [1, 16], 0).unwrap());
        let diag = layout.diagonal(Some(-1), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 2], [12, 5], 4).unwrap());
        let diag = layout.diagonal(Some(-4), Some(-2), Some(-1)).unwrap();
        assert_eq!(diag, Layout::new([2, 0], [12, 5], 0).unwrap());
    }

    /// Contiguous constructors `c()`, `f()` and `From`.
    #[test]
    fn test_new_contig() {
        let c_layout = [3, 2, 6].c();
        assert_eq!(c_layout.shape(), &[3, 2, 6]);
        assert_eq!(c_layout.stride(), &[12, 6, 1]);
        let f_layout = [3, 2, 6].f();
        assert_eq!(f_layout.shape(), &[3, 2, 6]);
        assert_eq!(f_layout.stride(), &[1, 3, 6]);
        // default order: only printed, not asserted
        let layout: Layout<_> = [3, 2, 6].into();
        println!("{layout:?}");
    }

    /// Conversions between fixed and dynamic dimension types.
    #[test]
    fn test_layout_cast() {
        let fixed = [3, 2, 6].c();
        assert!(fixed.clone().into_dim::<IxD>().is_ok());
        assert!(fixed.clone().into_dim::<Ix3>().is_ok());
        let dynamic = vec![3, 2, 6].c();
        assert!(dynamic.clone().into_dim::<IxD>().is_ok());
        assert!(dynamic.clone().into_dim::<Ix3>().is_ok());
        assert!(dynamic.clone().into_dim::<Ix2>().is_err());
    }

    /// Flat-to-multi index conversion in F and C order.
    #[test]
    fn test_unravel_index() {
        unsafe {
            let shape = [3, 2, 6];
            assert_eq!(shape.unravel_index_f(0), [0, 0, 0]);
            assert_eq!(shape.unravel_index_f(16), [1, 1, 2]);
            assert_eq!(shape.unravel_index_c(0), [0, 0, 0]);
            assert_eq!(shape.unravel_index_c(16), [1, 0, 4]);
        }
    }

    /// Regression test: slicing with a negative step must pass the stride check.
    #[test]
    fn fix_too_strict_stride_check() {
        let layout = [10, 11, 12].c();
        let slc = (.., slice!(-1, 0, -4));
        let slc: AxesIndex<Indexer> = slc.try_into().unwrap();
        let sliced = layout.dim_slice(slc.as_ref()).unwrap();
        assert_eq!(sliced.shape(), &[10, 3, 12]);
        assert_eq!(sliced.stride(), &[132, -48, 1]);
    }
}