rstsr_common/layout/
layoutbase.rs

1//! Layout of tensor.
2use crate::prelude_dev::*;
3use itertools::izip;
4
5/* #region Struct Definitions */
6
7/// Layout of tensor.
8///
9/// Layout is a struct that contains shape, stride, and offset of tensor.
10/// - Shape is the size of each dimension of tensor.
11/// - Stride is the number of elements to skip to get to the next element in
12///   each dimension.
13/// - Offset is the starting position of tensor.
14#[doc = include_str!("readme.md")]
15#[derive(Clone)]
16pub struct Layout<D>
17where
18    D: DimBaseAPI,
19{
20    // essential definitions to layout
21    pub(crate) shape: D,
22    pub(crate) stride: D::Stride,
23    pub(crate) offset: usize,
24}
25
26unsafe impl<D> Send for Layout<D> where D: DimBaseAPI {}
27unsafe impl<D> Sync for Layout<D> where D: DimBaseAPI {}
28
29/* #endregion */
30
31/* #region Layout */
32
33/// Getter/setter functions for layout.
34impl<D> Layout<D>
35where
36    D: DimBaseAPI,
37{
38    /// Shape of tensor. Getter function.
39    #[inline]
40    pub fn shape(&self) -> &D {
41        &self.shape
42    }
43
44    /// Stride of tensor. Getter function.
45    #[inline]
46    pub fn stride(&self) -> &D::Stride {
47        &self.stride
48    }
49
50    /// Starting offset of tensor. Getter function.
51    #[inline]
52    pub fn offset(&self) -> usize {
53        self.offset
54    }
55
56    /// Number of dimensions of tensor.
57    #[inline]
58    pub fn ndim(&self) -> usize {
59        self.shape.ndim()
60    }
61
62    /// Total number of elements in tensor.
63    ///
64    /// # Note
65    ///
66    /// This function uses cached size, instead of evaluating from shape.
67    #[inline]
68    pub fn size(&self) -> usize {
69        self.shape().as_ref().iter().product()
70    }
71
72    /// Manually set offset.
73    ///
74    /// # Safety
75    ///
76    /// We will not check whether this offset is valid or not.
77    /// In most cases, it is not intended to be used by user.
78    pub unsafe fn set_offset(&mut self, offset: usize) -> &mut Self {
79        self.offset = offset;
80        return self;
81    }
82}
83
84/// Properties of layout.
85impl<D> Layout<D>
86where
87    D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
88{
89    /// Whether this tensor is f-preferred.
90    pub fn f_prefer(&self) -> bool {
91        // always true for 0-dimension or 0-size tensor
92        if self.ndim() == 0 || self.size() == 0 {
93            return true;
94        }
95
96        let stride = self.stride.as_ref();
97        let shape = self.shape.as_ref();
98        let mut last = 0;
99        for (&s, &d) in stride.iter().zip(shape.iter()) {
100            if d != 1 {
101                if s < last {
102                    // latter strides must larger than previous strides
103                    return false;
104                }
105                if last == 0 && s != 1 {
106                    // first stride must be 1
107                    return false;
108                }
109                last = s;
110            } else if last == 0 {
111                // if dimension is one, then consider that stride is one, counted as contiguous
112                // in last dimension
113                last = 1;
114            }
115        }
116        return true;
117    }
118
119    /// Whether this tensor is c-preferred.
120    pub fn c_prefer(&self) -> bool {
121        // always true for 0-dimension or 0-size tensor
122        if self.ndim() == 0 || self.size() == 0 {
123            return true;
124        }
125
126        let stride = self.stride.as_ref();
127        let shape = self.shape.as_ref();
128        let mut last = 0;
129        for (&s, &d) in stride.iter().zip(shape.iter()).rev() {
130            if d != 1 {
131                if s < last {
132                    // previous strides must larger than latter strides
133                    return false;
134                }
135                if last == 0 && s != 1 {
136                    // last stride must be 1
137                    return false;
138                }
139                last = s;
140            } else if last == 0 {
141                // if dimension is one, then consider that stride is one, counted as contiguous
142                // in last dimension
143                last = 1;
144            }
145        }
146        return true;
147    }
148
149    /// Least number of dimensions that is f-contiguous for layout.
150    ///
151    /// This function can be useful determining when to iterate by contiguous,
152    /// and when to iterate by index.
153    pub fn ndim_of_f_contig(&self) -> usize {
154        if self.ndim() == 0 || self.size() == 0 {
155            return self.ndim();
156        }
157        let stride = self.stride.as_ref();
158        let shape = self.shape.as_ref();
159        let mut acc = 1;
160        for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).enumerate() {
161            if d != 1 && s != acc {
162                return ndim;
163            }
164            acc *= d as isize;
165        }
166        return self.ndim();
167    }
168
169    /// Least number of dimensions that is c-contiguous for layout.
170    ///
171    /// This function can be useful determining when to iterate by contiguous,
172    /// and when to iterate by index.
173    pub fn ndim_of_c_contig(&self) -> usize {
174        if self.ndim() == 0 || self.size() == 0 {
175            return self.ndim();
176        }
177        let stride = self.stride.as_ref();
178        let shape = self.shape.as_ref();
179        let mut acc = 1;
180        for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).rev().enumerate() {
181            if d != 1 && s != acc {
182                return ndim;
183            }
184            acc *= d as isize;
185        }
186        return self.ndim();
187    }
188
189    /// Whether this tensor is f-contiguous.
190    ///
191    /// Special cases
192    /// - When length of a dimension is one, then stride to that dimension is
193    ///   not important.
194    /// - When length of a dimension is zero, then tensor contains no elements,
195    ///   thus f-contiguous.
196    pub fn f_contig(&self) -> bool {
197        self.ndim() == self.ndim_of_f_contig()
198    }
199
200    /// Whether this tensor is c-contiguous.
201    ///
202    /// Special cases
203    /// - When length of a dimension is one, then stride to that dimension is
204    ///   not important.
205    /// - When length of a dimension is zero, then tensor contains no elements,
206    ///   thus c-contiguous.
207    pub fn c_contig(&self) -> bool {
208        self.ndim() == self.ndim_of_c_contig()
209    }
210
211    /// Index of tensor by list of indexes to dimensions.
212    ///
213    /// This function does not optimized for performance.
214    pub fn index_f(&self, index: &[isize]) -> Result<usize> {
215        rstsr_assert_eq!(index.len(), self.ndim(), InvalidLayout)?;
216        let mut pos = self.offset() as isize;
217        let shape = self.shape.as_ref();
218        let stride = self.stride.as_ref();
219
220        for (&idx, &shp, &strd) in izip!(index.iter(), shape.iter(), stride.iter()) {
221            let idx = if idx < 0 { idx + shp as isize } else { idx };
222            rstsr_pattern!(idx, 0..(shp as isize), ValueOutOfRange)?;
223            pos += strd * idx;
224        }
225        rstsr_pattern!(pos, 0.., ValueOutOfRange)?;
226        return Ok(pos as usize);
227    }
228
229    /// Index of tensor by list of indexes to dimensions.
230    ///
231    /// This function does not optimized for performance. Negative index
232    /// allowed.
233    ///
234    /// # Panics
235    ///
236    /// - Index greater than shape
237    pub fn index(&self, index: &[isize]) -> usize {
238        self.index_f(index).unwrap()
239    }
240
241    /// Index range bounds of current layout. This bound is [min, max), which
242    /// could be feed into range (min..max). If min == max, then this layout
243    /// should not contains any element.
244    ///
245    /// This function will raise error when minimum index is smaller than zero.
246    pub fn bounds_index(&self) -> Result<(usize, usize)> {
247        let n = self.ndim();
248        let offset = self.offset;
249        let shape = self.shape.as_ref();
250        let stride = self.stride.as_ref();
251
252        if n == 0 {
253            return Ok((offset, offset + 1));
254        }
255
256        let mut min = offset as isize;
257        let mut max = offset as isize;
258
259        for i in 0..n {
260            if shape[i] == 0 {
261                return Ok((offset, offset));
262            }
263            if stride[i] > 0 {
264                max += stride[i] * (shape[i] as isize - 1);
265            } else {
266                min += stride[i] * (shape[i] as isize - 1);
267            }
268        }
269        rstsr_pattern!(min, 0.., ValueOutOfRange)?;
270        return Ok((min as usize, max as usize + 1));
271    }
272
273    /// Check if strides is correct (no elemenets can overlap).
274    ///
275    /// This will check if all number of elements in dimension of small strides
276    /// is less than larger strides. For example of valid stride:
277    /// ```output
278    /// shape:  (3,    2,  6)  -> sorted ->  ( 3,   6,   2)
279    /// stride: (3, -300, 15)  -> sorted ->  ( 3,  15, 300)
280    /// number of elements:                    9,  90,
281    /// stride of next dimension              15, 300,
282    /// number of elem < stride of next dim?   +,   +,
283    /// ```
284    ///
285    /// Special cases
286    /// - if length of tensor is zero, then strides will always be correct.
287    /// - if certain dimension is one, then check for this stride will be
288    ///   ignored.
289    ///
290    /// # TODO
291    ///
292    /// Correctness of this function is not fully ensured.
293    pub fn check_strides(&self) -> Result<()> {
294        let shape = self.shape.as_ref();
295        let stride = self.stride.as_ref();
296        rstsr_assert_eq!(shape.len(), stride.len(), InvalidLayout)?;
297        let n = shape.len();
298
299        // unconditionally ok if no elements (length of tensor is zero)
300        // unconditionally ok if 0-dimension
301        if self.size() == 0 || n == 0 {
302            return Ok(());
303        }
304
305        let mut indices = (0..n).filter(|&k| shape[k] > 1).collect::<Vec<_>>();
306        indices.sort_by_key(|&k| stride[k].abs());
307        let shape_sorted = indices.iter().map(|&k| shape[k]).collect::<Vec<_>>();
308        let stride_sorted = indices.iter().map(|&k| stride[k].unsigned_abs()).collect::<Vec<_>>();
309
310        // note: `indices.len() - 1` can be smaller than 0, so `.max(1)` is used
311        for i in 0..indices.len().max(1) - 1 {
312            // following function also checks that stride could not be zero
313            rstsr_pattern!(
314                shape_sorted[i] * stride_sorted[i],
315                1..stride_sorted[i + 1] + 1,
316                InvalidLayout,
317                "Either stride be zero, or stride too small that elements in tensor can be overlapped."
318            )?;
319        }
320        return Ok(());
321    }
322
323    pub fn diagonal(
324        &self,
325        offset: Option<isize>,
326        axis1: Option<isize>,
327        axis2: Option<isize>,
328    ) -> Result<Layout<<D as DimSmallerOneAPI>::SmallerOne>>
329    where
330        D: DimSmallerOneAPI,
331    {
332        // check if this layout is at least 2-dimension
333        rstsr_assert!(self.ndim() >= 2, InvalidLayout)?;
334        // unwrap optional parameters
335        let offset = offset.unwrap_or(0);
336        let axis1 = axis1.unwrap_or(0);
337        let axis2 = axis2.unwrap_or(1);
338        let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
339        let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
340        rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
341        rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
342        let axis1 = axis1 as usize;
343        let axis2 = axis2 as usize;
344
345        // shape and strides of last two dimensions
346        let d1 = self.shape()[axis1] as isize;
347        let d2 = self.shape()[axis2] as isize;
348        let t1 = self.stride()[axis1];
349        let t2 = self.stride()[axis2];
350
351        // number of elements in diagonal, and starting offset
352        let (offset_diag, d_diag) = if (-d2 + 1..0).contains(&offset) {
353            let offset = -offset;
354            let offset_diag = (self.offset() as isize + t1 * offset) as usize;
355            let d_diag = (d1 - offset).min(d2) as usize;
356            (offset_diag, d_diag)
357        } else if (0..d1).contains(&offset) {
358            let offset_diag = (self.offset() as isize + t2 * offset) as usize;
359            let d_diag = (d2 - offset).min(d1) as usize;
360            (offset_diag, d_diag)
361        } else {
362            (self.offset(), 0)
363        };
364
365        // build new layout
366        let t_diag = t1 + t2;
367        let mut shape_diag = vec![];
368        let mut stride_diag = vec![];
369        for i in 0..self.ndim() {
370            if i != axis1 && i != axis2 {
371                shape_diag.push(self.shape()[i]);
372                stride_diag.push(self.stride()[i]);
373            }
374        }
375        shape_diag.push(d_diag);
376        stride_diag.push(t_diag);
377        let layout_diag = Layout::new(shape_diag, stride_diag, offset_diag)?;
378        return layout_diag.into_dim::<<D as DimSmallerOneAPI>::SmallerOne>();
379    }
380}
381
382/// Constructors of layout. See also [`DimLayoutContigAPI`] layout from shape
383/// directly.
384impl<D> Layout<D>
385where
386    D: DimBaseAPI,
387{
388    /// Generate new layout by providing everything.
389    ///
390    /// # Error when
391    ///
392    /// - Shape and stride length mismatch
393    /// - Strides is correct (no elements can overlap)
394    /// - Minimum bound is not negative
395    #[inline]
396    pub fn new(shape: D, stride: D::Stride, offset: usize) -> Result<Self>
397    where
398        D: DimShapeAPI + DimStrideAPI,
399    {
400        let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
401        layout.bounds_index()?;
402        layout.check_strides()?;
403        return Ok(layout);
404    }
405
406    /// Generate new layout by providing everything, without checking bounds and
407    /// strides.
408    ///
409    /// # Safety
410    ///
411    /// This function does not check whether layout is valid.
412    #[inline]
413    pub unsafe fn new_unchecked(shape: D, stride: D::Stride, offset: usize) -> Self {
414        Layout { shape, stride, offset }
415    }
416
417    /// New zero shape, which number of dimensions are the same to current
418    /// layout.
419    #[inline]
420    pub fn new_shape(&self) -> D {
421        self.shape.new_shape()
422    }
423
424    /// New zero stride, which number of dimensions are the same to current
425    /// layout.
426    #[inline]
427    pub fn new_stride(&self) -> D::Stride {
428        self.shape.new_stride()
429    }
430}
431
432/// Manuplation of layout.
433impl<D> Layout<D>
434where
435    D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
436{
437    /// Transpose layout by permutation.
438    ///
439    /// # See also
440    ///
441    /// - [`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html)
442    /// - [Python array API: `permute_dims`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.permute_dims.html)
443    pub fn transpose(&self, axes: &[isize]) -> Result<Self> {
444        // check axes and cast to usize
445        let n = self.ndim();
446        rstsr_assert_eq!(
447            axes.len(),
448            n,
449            InvalidLayout,
450            "number of elements in axes should be the same to number of dimensions."
451        )?;
452        // no elements in axes can be the same
453        let mut permut_used = vec![false; n];
454        for &p in axes {
455            let p = if p < 0 { p + n as isize } else { p };
456            rstsr_pattern!(p, 0..n as isize, InvalidLayout)?;
457            let p = p as usize;
458            permut_used[p] = true;
459        }
460        rstsr_assert!(
461            permut_used.iter().all(|&b| b),
462            InvalidLayout,
463            "axes should contain all elements from 0 to n-1."
464        )?;
465        let axes = axes
466            .iter()
467            .map(|&p| if p < 0 { p + n as isize } else { p } as usize)
468            .collect::<Vec<_>>();
469
470        let shape_old = self.shape();
471        let stride_old = self.stride();
472        let mut shape = self.new_shape();
473        let mut stride = self.new_stride();
474        for i in 0..self.ndim() {
475            shape[i] = shape_old[axes[i]];
476            stride[i] = stride_old[axes[i]];
477        }
478        return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
479    }
480
481    /// Transpose layout by permutation.
482    ///
483    /// This is the same function to [`Layout::transpose`]
484    pub fn permute_dims(&self, axes: &[isize]) -> Result<Self> {
485        self.transpose(axes)
486    }
487
488    /// Reverse axes of layout.
489    pub fn reverse_axes(&self) -> Self {
490        let shape_old = self.shape();
491        let stride_old = self.stride();
492        let mut shape = self.new_shape();
493        let mut stride = self.new_stride();
494        for i in 0..self.ndim() {
495            shape[i] = shape_old[self.ndim() - i - 1];
496            stride[i] = stride_old[self.ndim() - i - 1];
497        }
498        return unsafe { Layout::new_unchecked(shape, stride, self.offset) };
499    }
500
501    /// Swap axes of layout.
502    pub fn swapaxes(&self, axis1: isize, axis2: isize) -> Result<Self> {
503        let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
504        rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
505        let axis1 = axis1 as usize;
506
507        let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
508        rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
509        let axis2 = axis2 as usize;
510
511        let mut shape = self.shape().clone();
512        let mut stride = self.stride().clone();
513        shape.as_mut().swap(axis1, axis2);
514        stride.as_mut().swap(axis1, axis2);
515        return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
516    }
517}
518
519/// Fast indexing and utilities of layout.
520///
521/// These functions are mostly internal to this crate.
522impl<D> Layout<D>
523where
524    D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
525{
526    /// Index of tensor by list of indexes to dimensions.
527    ///
528    /// # Safety
529    ///
530    /// This function does not check for bounds, including
531    /// - Negative index
532    /// - Index greater than shape
533    ///
534    /// Due to these reasons, this function may well give index smaller than
535    /// zero, which may occur in iterator; so this function returns isize.
536    #[inline]
537    pub unsafe fn index_uncheck(&self, index: &[usize]) -> isize {
538        let stride = self.stride.as_ref();
539        match self.ndim() {
540            0 => self.offset as isize,
541            1 => self.offset as isize + stride[0] * index[0] as isize,
542            2 => {
543                self.offset as isize + stride[0] * index[0] as isize + stride[1] * index[1] as isize
544            },
545            3 => {
546                self.offset as isize
547                    + stride[0] * index[0] as isize
548                    + stride[1] * index[1] as isize
549                    + stride[2] * index[2] as isize
550            },
551            4 => {
552                self.offset as isize
553                    + stride[0] * index[0] as isize
554                    + stride[1] * index[1] as isize
555                    + stride[2] * index[2] as isize
556                    + stride[3] * index[3] as isize
557            },
558            _ => {
559                let mut pos = self.offset as isize;
560                stride.iter().zip(index.iter()).for_each(|(&s, &i)| pos += s * i as isize);
561                pos
562            },
563        }
564    }
565}
566
567impl<D> PartialEq for Layout<D>
568where
569    D: DimBaseAPI,
570{
571    /// For layout, shape must be the same, while stride should be the same when
572    /// shape is not zero or one, but can be arbitary otherwise.
573    fn eq(&self, other: &Self) -> bool {
574        if self.ndim() != other.ndim() {
575            return false;
576        }
577        for i in 0..self.ndim() {
578            let s1 = self.shape()[i];
579            let s2 = other.shape()[i];
580            if s1 != s2 {
581                return false;
582            }
583            if s1 != 1 && s1 != 0 && self.stride()[i] != other.stride()[i] {
584                return false;
585            }
586        }
587        return true;
588    }
589}
590
591pub trait DimLayoutContigAPI: DimBaseAPI + DimShapeAPI + DimStrideAPI {
592    /// Generate new layout by providing shape and offset; stride fits into
593    /// c-contiguous.
594    fn new_c_contig(&self, offset: Option<usize>) -> Layout<Self> {
595        let shape = self.clone();
596        let stride = shape.stride_c_contig();
597        unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
598    }
599
600    /// Generate new layout by providing shape and offset; stride fits into
601    /// f-contiguous.
602    fn new_f_contig(&self, offset: Option<usize>) -> Layout<Self> {
603        let shape = self.clone();
604        let stride = shape.stride_f_contig();
605        unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
606    }
607
608    /// Simplified function to generate c-contiguous layout. See also
609    /// [DimLayoutContigAPI::new_c_contig].
610    fn c(&self) -> Layout<Self> {
611        self.new_c_contig(None)
612    }
613
614    /// Simplified function to generate f-contiguous layout. See also
615    /// [DimLayoutContigAPI::new_f_contig].
616    fn f(&self) -> Layout<Self> {
617        self.new_f_contig(None)
618    }
619}
620
621impl<const N: usize> DimLayoutContigAPI for Ix<N> {}
622impl DimLayoutContigAPI for IxD {}
623
624/* #endregion Layout */
625
626/* #region Dimension Conversion */
627
628pub trait DimIntoAPI<D>: DimBaseAPI
629where
630    D: DimBaseAPI,
631{
632    fn into_dim(layout: Layout<Self>) -> Result<Layout<D>>;
633}
634
635impl<D> DimIntoAPI<D> for IxD
636where
637    D: DimBaseAPI,
638{
639    fn into_dim(layout: Layout<IxD>) -> Result<Layout<D>> {
640        let shape = layout.shape().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
641        let stride = layout.stride().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
642        let offset = layout.offset();
643        return Ok(Layout { shape, stride, offset });
644    }
645}
646
647impl<const N: usize> DimIntoAPI<IxD> for Ix<N> {
648    fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<IxD>> {
649        let shape = (*layout.shape()).into();
650        let stride = (*layout.stride()).into();
651        let offset = layout.offset();
652        return Ok(Layout { shape, stride, offset });
653    }
654}
655
656impl<const N: usize, const M: usize> DimIntoAPI<Ix<M>> for Ix<N> {
657    fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<Ix<M>>> {
658        rstsr_assert_eq!(N, M, InvalidLayout)?;
659        let shape = layout.shape().to_vec().try_into().unwrap();
660        let stride = layout.stride().to_vec().try_into().unwrap();
661        let offset = layout.offset();
662        return Ok(Layout { shape, stride, offset });
663    }
664}
665
666impl<D> Layout<D>
667where
668    D: DimBaseAPI,
669{
670    /// Convert layout to another dimension.
671    pub fn into_dim<D2>(self) -> Result<Layout<D2>>
672    where
673        D2: DimBaseAPI,
674        D: DimIntoAPI<D2>,
675    {
676        D::into_dim(self)
677    }
678
679    /// Convert layout to another dimension.
680    pub fn to_dim<D2>(&self) -> Result<Layout<D2>>
681    where
682        D2: DimBaseAPI,
683        D: DimIntoAPI<D2>,
684    {
685        D::into_dim(self.clone())
686    }
687}
688
689impl<const N: usize> From<Ix<N>> for Layout<Ix<N>> {
690    fn from(shape: Ix<N>) -> Self {
691        let stride = shape.stride_contig();
692        Layout { shape, stride, offset: 0 }
693    }
694}
695
696impl From<IxD> for Layout<IxD> {
697    fn from(shape: IxD) -> Self {
698        let stride = shape.stride_contig();
699        Layout { shape, stride, offset: 0 }
700    }
701}
702
703/* #endregion */
704
705#[cfg(test)]
706mod test {
707    use std::panic::catch_unwind;
708
709    use super::*;
710
711    #[test]
712    fn test_layout_new() {
713        // a successful layout new
714        let shape = [3, 2, 6];
715        let stride = [3, -300, 15];
716        let layout = Layout::new(shape, stride, 917).unwrap();
717        assert_eq!(layout.shape(), &[3, 2, 6]);
718        assert_eq!(layout.stride(), &[3, -300, 15]);
719        assert_eq!(layout.offset(), 917);
720        assert_eq!(layout.ndim(), 3);
721        // unsuccessful layout new (offset underflow)
722        let shape = [3, 2, 6];
723        let stride = [3, -300, 15];
724        let layout = Layout::new(shape, stride, 0);
725        assert!(layout.is_err());
726        // unsuccessful layout new (zero stride for non-0/1 shape)
727        let shape = [3, 2, 6];
728        let stride = [3, -300, 0];
729        let layout = Layout::new(shape, stride, 1000);
730        assert!(layout.is_err());
731        // unsuccessful layout new (stride too small)
732        let shape = [3, 2, 6];
733        let stride = [3, 4, 7];
734        let layout = Layout::new(shape, stride, 1000);
735        assert!(layout.is_err());
736        // successful layout new (zero dim)
737        let shape = [];
738        let stride = [];
739        let layout = Layout::new(shape, stride, 1000);
740        assert!(layout.is_ok());
741        // successful layout new (stride 0 for 1-shape)
742        let shape = [3, 1, 5];
743        let stride = [1, 0, 15];
744        let layout = Layout::new(shape, stride, 1);
745        assert!(layout.is_ok());
746        // successful layout new (stride 0 for 1-shape)
747        let shape = [3, 1, 5];
748        let stride = [1, 0, 15];
749        let layout = Layout::new(shape, stride, 1);
750        assert!(layout.is_ok());
751        // successful layout new (zero-size tensor)
752        let shape = [3, 0, 5];
753        let stride = [-1, -2, -3];
754        let layout = Layout::new(shape, stride, 1);
755        assert!(layout.is_ok());
756        // anyway, if one need custom layout, use new_unchecked
757        let shape = [3, 2, 6];
758        let stride = [3, -300, 0];
759        let r = catch_unwind(|| unsafe { Layout::new_unchecked(shape, stride, 1000) });
760        assert!(r.is_ok());
761    }
762
763    #[test]
764    fn test_is_f_prefer() {
765        // general case
766        let shape = [3, 5, 7];
767        let layout = Layout::new(shape, [1, 10, 100], 0).unwrap();
768        assert!(layout.f_prefer());
769        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
770        assert!(layout.f_prefer());
771        let layout = Layout::new(shape, [1, 3, -15], 1000).unwrap();
772        assert!(!layout.f_prefer());
773        let layout = Layout::new(shape, [1, 21, 3], 0).unwrap();
774        assert!(!layout.f_prefer());
775        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
776        assert!(!layout.f_prefer());
777        let layout = Layout::new(shape, [2, 6, 30], 0).unwrap();
778        assert!(!layout.f_prefer());
779        // zero dimension
780        let layout = Layout::new([], [], 0).unwrap();
781        assert!(layout.f_prefer());
782        // zero size
783        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
784        assert!(layout.f_prefer());
785        // shape with 1
786        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
787        assert!(layout.f_prefer());
788    }
789
790    #[test]
791    fn test_is_c_prefer() {
792        // general case
793        let shape = [3, 5, 7];
794        let layout = Layout::new(shape, [100, 10, 1], 0).unwrap();
795        assert!(layout.c_prefer());
796        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
797        assert!(layout.c_prefer());
798        let layout = Layout::new(shape, [-35, 7, 1], 1000).unwrap();
799        assert!(!layout.c_prefer());
800        let layout = Layout::new(shape, [7, 21, 1], 0).unwrap();
801        assert!(!layout.c_prefer());
802        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
803        assert!(!layout.c_prefer());
804        let layout = Layout::new(shape, [70, 14, 2], 0).unwrap();
805        assert!(!layout.c_prefer());
806        // zero dimension
807        let layout = Layout::new([], [], 0).unwrap();
808        assert!(layout.c_prefer());
809        // zero size
810        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
811        assert!(layout.c_prefer());
812        // shape with 1
813        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
814        assert!(layout.c_prefer());
815    }
816
817    #[test]
818    fn test_is_f_contig() {
819        // general case
820        let shape = [3, 5, 7];
821        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
822        assert!(layout.f_contig());
823        let layout = Layout::new(shape, [1, 4, 20], 0).unwrap();
824        assert!(!layout.f_contig());
825        // zero dimension
826        let layout = Layout::new([], [], 0).unwrap();
827        assert!(layout.f_contig());
828        // zero size
829        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
830        assert!(layout.f_contig());
831        // shape with 1
832        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
833        assert!(layout.f_contig());
834    }
835
836    #[test]
837    fn test_is_c_contig() {
838        // general case
839        let shape = [3, 5, 7];
840        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
841        assert!(layout.c_contig());
842        let layout = Layout::new(shape, [36, 7, 1], 0).unwrap();
843        assert!(!layout.c_contig());
844        // zero dimension
845        let layout = Layout::new([], [], 0).unwrap();
846        assert!(layout.c_contig());
847        // zero size
848        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
849        assert!(layout.c_contig());
850        // shape with 1
851        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
852        assert!(layout.c_contig());
853    }
854
855    #[test]
856    fn test_index() {
857        // a = np.arange(9 * 12 * 15)
858        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
859        //       .transpose(2, 0, 1)
860        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
861        assert_eq!(layout.index(&[0, 0, 0]), 782);
862        assert_eq!(layout.index(&[2, 1, 4]), 668);
863        assert_eq!(layout.index(&[1, -2, -3]), 830);
864        // zero-dim
865        let layout = Layout::new([], [], 10).unwrap();
866        assert_eq!(layout.index(&[]), 10);
867    }
868
869    #[test]
870    fn test_bounds_index() {
871        // a = np.arange(9 * 12 * 15)
872        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
873        //       .transpose(2, 0, 1)
874        // a.min() = 602, a.max() = 863
875        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
876        assert_eq!(layout.bounds_index().unwrap(), (602, 864));
877        // situation that fails
878        let layout = unsafe { Layout::new_unchecked([3, 2, 6], [3, -180, 15], 15) };
879        assert!(layout.bounds_index().is_err());
880        // zero-dim
881        let layout = Layout::new([], [], 10).unwrap();
882        assert_eq!(layout.bounds_index().unwrap(), (10, 11));
883    }
884
885    #[test]
886    fn test_transpose() {
887        // general
888        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
889        let trans = layout.transpose(&[2, 0, 1]).unwrap();
890        assert_eq!(trans.shape(), &[6, 3, 2]);
891        assert_eq!(trans.stride(), &[15, 3, -180]);
892        // permute_dims is alias of transpose
893        let trans = layout.permute_dims(&[2, 0, 1]).unwrap();
894        assert_eq!(trans.shape(), &[6, 3, 2]);
895        assert_eq!(trans.stride(), &[15, 3, -180]);
896        // negative axis also allowed
897        let trans = layout.transpose(&[-1, 0, 1]).unwrap();
898        assert_eq!(trans.shape(), &[6, 3, 2]);
899        assert_eq!(trans.stride(), &[15, 3, -180]);
900        // repeated axis
901        let trans = layout.transpose(&[-2, 0, 1]);
902        assert!(trans.is_err());
903        // non-valid dimension
904        let trans = layout.transpose(&[1, 0]);
905        assert!(trans.is_err());
906        // zero-dim
907        let layout = Layout::new([], [], 0).unwrap();
908        let trans = layout.transpose(&[]);
909        assert!(trans.is_ok());
910    }
911
912    #[test]
913    fn test_reverse_axes() {
914        // general
915        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
916        let trans = layout.reverse_axes();
917        assert_eq!(trans.shape(), &[6, 2, 3]);
918        assert_eq!(trans.stride(), &[15, -180, 3]);
919        // zero-dim
920        let layout = Layout::new([], [], 782).unwrap();
921        let trans = layout.reverse_axes();
922        assert_eq!(trans.shape(), &[]);
923        assert_eq!(trans.stride(), &[]);
924    }
925
926    #[test]
927    fn test_swapaxes() {
928        // general
929        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
930        let trans = layout.swapaxes(-1, -2).unwrap();
931        assert_eq!(trans.shape(), &[3, 6, 2]);
932        assert_eq!(trans.stride(), &[3, 15, -180]);
933        // same index is allowed
934        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
935        let trans = layout.swapaxes(-1, -1).unwrap();
936        assert_eq!(trans.shape(), &[3, 2, 6]);
937        assert_eq!(trans.stride(), &[3, -180, 15]);
938    }
939
940    #[test]
941    fn test_index_uncheck() {
942        // a = np.arange(9 * 12 * 15)
943        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
944        //       .transpose(2, 0, 1)
945        unsafe {
946            // fixed dim
947            let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
948            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
949            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
950            // dynamic dim
951            let layout = Layout::new(vec![3, 2, 6], vec![3, -180, 15], 782).unwrap();
952            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
953            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
954            // zero-dim
955            let layout = Layout::new([], [], 10).unwrap();
956            assert_eq!(layout.index_uncheck(&[]), 10);
957        }
958    }
959
960    #[test]
961    fn test_diagonal() {
962        let layout = [2, 3, 4].c();
963        let diag = layout.diagonal(None, None, None).unwrap();
964        assert_eq!(diag, Layout::new([4, 2], [1, 16], 0).unwrap());
965        let diag = layout.diagonal(Some(-1), Some(-2), Some(-1)).unwrap();
966        assert_eq!(diag, Layout::new([2, 2], [12, 5], 0).unwrap());
967        let diag = layout.diagonal(Some(-4), Some(-2), Some(-1)).unwrap();
968        assert_eq!(diag, Layout::new([2, 0], [12, 5], 0).unwrap());
969    }
970
971    #[test]
972    fn test_new_contig() {
973        let layout = [3, 2, 6].c();
974        assert_eq!(layout.shape(), &[3, 2, 6]);
975        assert_eq!(layout.stride(), &[12, 6, 1]);
976        let layout = [3, 2, 6].f();
977        assert_eq!(layout.shape(), &[3, 2, 6]);
978        assert_eq!(layout.stride(), &[1, 3, 6]);
979        // following code generates contiguous layout
980        // c/f-contig depends on cargo feature
981        let layout: Layout<_> = [3, 2, 6].into();
982        println!("{:?}", layout);
983    }
984
985    #[test]
986    fn test_layout_cast() {
987        let layout = [3, 2, 6].c();
988        assert!(layout.clone().into_dim::<IxD>().is_ok());
989        assert!(layout.clone().into_dim::<Ix3>().is_ok());
990        let layout = vec![3, 2, 6].c();
991        assert!(layout.clone().into_dim::<IxD>().is_ok());
992        assert!(layout.clone().into_dim::<Ix3>().is_ok());
993        assert!(layout.clone().into_dim::<Ix2>().is_err());
994    }
995
996    #[test]
997    fn test_unravel_index() {
998        unsafe {
999            let shape = [3, 2, 6];
1000            assert_eq!(shape.unravel_index_f(0), [0, 0, 0]);
1001            assert_eq!(shape.unravel_index_f(16), [1, 1, 2]);
1002            assert_eq!(shape.unravel_index_c(0), [0, 0, 0]);
1003            assert_eq!(shape.unravel_index_c(16), [1, 0, 4]);
1004        }
1005    }
1006}