rstsr_common/layout/
layoutbase.rs

1//! Layout of tensor.
2use crate::prelude_dev::*;
3use itertools::izip;
4
5/* #region Struct Definitions */
6
7/// Layout of tensor.
8///
9/// Layout is a struct that contains shape, stride, and offset of tensor.
10/// - Shape is the size of each dimension of tensor.
11/// - Stride is the number of elements to skip to get to the next element in each dimension.
12/// - Offset is the starting position of tensor.
13#[doc = include_str!("readme.md")]
14#[derive(Clone)]
15pub struct Layout<D>
16where
17    D: DimBaseAPI,
18{
19    // essential definitions to layout
20    pub(crate) shape: D,
21    pub(crate) stride: D::Stride,
22    pub(crate) offset: usize,
23}
24
25unsafe impl<D> Send for Layout<D> where D: DimBaseAPI {}
26unsafe impl<D> Sync for Layout<D> where D: DimBaseAPI {}
27
28/* #endregion */
29
30/* #region Layout */
31
32/// Getter/setter functions for layout.
33impl<D> Layout<D>
34where
35    D: DimBaseAPI,
36{
37    /// Shape of tensor. Getter function.
38    #[inline]
39    pub fn shape(&self) -> &D {
40        &self.shape
41    }
42
43    /// Stride of tensor. Getter function.
44    #[inline]
45    pub fn stride(&self) -> &D::Stride {
46        &self.stride
47    }
48
49    /// Starting offset of tensor. Getter function.
50    #[inline]
51    pub fn offset(&self) -> usize {
52        self.offset
53    }
54
55    /// Number of dimensions of tensor.
56    #[inline]
57    pub fn ndim(&self) -> usize {
58        self.shape.ndim()
59    }
60
61    /// Total number of elements in tensor.
62    ///
63    /// # Note
64    ///
65    /// This function uses cached size, instead of evaluating from shape.
66    #[inline]
67    pub fn size(&self) -> usize {
68        self.shape().as_ref().iter().product()
69    }
70
71    /// Manually set offset.
72    ///
73    /// # Safety
74    ///
75    /// We will not check whether this offset is valid or not.
76    /// In most cases, it is not intended to be used by user.
77    pub unsafe fn set_offset(&mut self, offset: usize) -> &mut Self {
78        self.offset = offset;
79        return self;
80    }
81}
82
83/// Properties of layout.
84impl<D> Layout<D>
85where
86    D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
87{
88    /// Whether this tensor is f-preferred.
89    pub fn f_prefer(&self) -> bool {
90        // always true for 0-dimension or 0-size tensor
91        if self.ndim() == 0 || self.size() == 0 {
92            return true;
93        }
94
95        let stride = self.stride.as_ref();
96        let shape = self.shape.as_ref();
97        let mut last = 0;
98        for (&s, &d) in stride.iter().zip(shape.iter()) {
99            if d != 1 {
100                if s < last {
101                    // latter strides must larger than previous strides
102                    return false;
103                }
104                if last == 0 && s != 1 {
105                    // first stride must be 1
106                    return false;
107                }
108                last = s;
109            } else if last == 0 {
110                // if dimension is one, then consider that stride is one, counted as contiguous
111                // in last dimension
112                last = 1;
113            }
114        }
115        return true;
116    }
117
118    /// Whether this tensor is c-preferred.
119    pub fn c_prefer(&self) -> bool {
120        // always true for 0-dimension or 0-size tensor
121        if self.ndim() == 0 || self.size() == 0 {
122            return true;
123        }
124
125        let stride = self.stride.as_ref();
126        let shape = self.shape.as_ref();
127        let mut last = 0;
128        for (&s, &d) in stride.iter().zip(shape.iter()).rev() {
129            if d != 1 {
130                if s < last {
131                    // previous strides must larger than latter strides
132                    return false;
133                }
134                if last == 0 && s != 1 {
135                    // last stride must be 1
136                    return false;
137                }
138                last = s;
139            } else if last == 0 {
140                // if dimension is one, then consider that stride is one, counted as contiguous
141                // in last dimension
142                last = 1;
143            }
144        }
145        return true;
146    }
147
148    /// Least number of dimensions that is f-contiguous for layout.
149    ///
150    /// This function can be useful determining when to iterate by contiguous,
151    /// and when to iterate by index.
152    pub fn ndim_of_f_contig(&self) -> usize {
153        if self.ndim() == 0 || self.size() == 0 {
154            return self.ndim();
155        }
156        let stride = self.stride.as_ref();
157        let shape = self.shape.as_ref();
158        let mut acc = 1;
159        for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).enumerate() {
160            if d != 1 && s != acc {
161                return ndim;
162            }
163            acc *= d as isize;
164        }
165        return self.ndim();
166    }
167
168    /// Least number of dimensions that is c-contiguous for layout.
169    ///
170    /// This function can be useful determining when to iterate by contiguous,
171    /// and when to iterate by index.
172    pub fn ndim_of_c_contig(&self) -> usize {
173        if self.ndim() == 0 || self.size() == 0 {
174            return self.ndim();
175        }
176        let stride = self.stride.as_ref();
177        let shape = self.shape.as_ref();
178        let mut acc = 1;
179        for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).rev().enumerate() {
180            if d != 1 && s != acc {
181                return ndim;
182            }
183            acc *= d as isize;
184        }
185        return self.ndim();
186    }
187
188    /// Whether this tensor is f-contiguous.
189    ///
190    /// Special cases
191    /// - When length of a dimension is one, then stride to that dimension is not important.
192    /// - When length of a dimension is zero, then tensor contains no elements, thus f-contiguous.
193    pub fn f_contig(&self) -> bool {
194        self.ndim() == self.ndim_of_f_contig()
195    }
196
197    /// Whether this tensor is c-contiguous.
198    ///
199    /// Special cases
200    /// - When length of a dimension is one, then stride to that dimension is not important.
201    /// - When length of a dimension is zero, then tensor contains no elements, thus c-contiguous.
202    pub fn c_contig(&self) -> bool {
203        self.ndim() == self.ndim_of_c_contig()
204    }
205
206    /// Index of tensor by list of indexes to dimensions.
207    ///
208    /// This function does not optimized for performance.
209    pub fn index_f(&self, index: &[isize]) -> Result<usize> {
210        rstsr_assert_eq!(index.len(), self.ndim(), InvalidLayout)?;
211        let mut pos = self.offset() as isize;
212        let shape = self.shape.as_ref();
213        let stride = self.stride.as_ref();
214
215        for (&idx, &shp, &strd) in izip!(index.iter(), shape.iter(), stride.iter()) {
216            let idx = if idx < 0 { idx + shp as isize } else { idx };
217            rstsr_pattern!(idx, 0..(shp as isize), ValueOutOfRange)?;
218            pos += strd * idx;
219        }
220        rstsr_pattern!(pos, 0.., ValueOutOfRange)?;
221        return Ok(pos as usize);
222    }
223
224    /// Index of tensor by list of indexes to dimensions.
225    ///
226    /// This function does not optimized for performance. Negative index
227    /// allowed.
228    ///
229    /// # Panics
230    ///
231    /// - Index greater than shape
232    pub fn index(&self, index: &[isize]) -> usize {
233        self.index_f(index).unwrap()
234    }
235
236    /// Index range bounds of current layout. This bound is [min, max), which
237    /// could be feed into range (min..max). If min == max, then this layout
238    /// should not contains any element.
239    ///
240    /// This function will raise error when minimum index is smaller than zero.
241    pub fn bounds_index(&self) -> Result<(usize, usize)> {
242        let n = self.ndim();
243        let offset = self.offset;
244        let shape = self.shape.as_ref();
245        let stride = self.stride.as_ref();
246
247        if n == 0 {
248            return Ok((offset, offset + 1));
249        }
250
251        let mut min = offset as isize;
252        let mut max = offset as isize;
253
254        for i in 0..n {
255            if shape[i] == 0 {
256                return Ok((offset, offset));
257            }
258            if stride[i] > 0 {
259                max += stride[i] * (shape[i] as isize - 1);
260            } else {
261                min += stride[i] * (shape[i] as isize - 1);
262            }
263        }
264        rstsr_pattern!(min, 0.., ValueOutOfRange)?;
265        return Ok((min as usize, max as usize + 1));
266    }
267
268    /// Check if strides is correct (no elemenets can overlap).
269    ///
270    /// This will check if all number of elements in dimension of small strides
271    /// is less than larger strides. For example of valid stride:
272    /// ```output
273    /// shape:  (3,    2,  6)  -> sorted ->  ( 3,   6,   2)
274    /// stride: (3, -300, 15)  -> sorted ->  ( 3,  15, 300)
275    /// number of elements:                    9,  90,
276    /// stride of next dimension              15, 300,
277    /// number of elem < stride of next dim?   +,   +,
278    /// ```
279    ///
280    /// Special cases
281    /// - if length of tensor is zero, then strides will always be correct.
282    /// - if certain dimension is one, then check for this stride will be ignored.
283    ///
284    /// # TODO
285    ///
286    /// Correctness of this function is not fully ensured.
287    pub fn check_strides(&self) -> Result<()> {
288        let shape = self.shape.as_ref();
289        let stride = self.stride.as_ref();
290        rstsr_assert_eq!(shape.len(), stride.len(), InvalidLayout)?;
291        let n = shape.len();
292
293        // unconditionally ok if no elements (length of tensor is zero)
294        // unconditionally ok if 0-dimension
295        if self.size() == 0 || n == 0 {
296            return Ok(());
297        }
298
299        let mut indices = (0..n).filter(|&k| shape[k] > 1).collect::<Vec<_>>();
300        indices.sort_by_key(|&k| stride[k].abs());
301        let shape_sorted = indices.iter().map(|&k| shape[k]).collect::<Vec<_>>();
302        let stride_sorted = indices.iter().map(|&k| stride[k].unsigned_abs()).collect::<Vec<_>>();
303
304        // elem_cum: cumulative number count of elements in tensor for small strides
305        let mut elem_cum = 0;
306        for i in 0..indices.len() {
307            // following function also checks that stride could not be zero
308            rstsr_pattern!(
309                elem_cum,
310                0..stride_sorted[i],
311                InvalidLayout,
312                "Either stride be zero, or stride too small that elements in tensor can be overlapped."
313            )?;
314            elem_cum += (shape_sorted[i] - 1) * stride_sorted[i];
315        }
316        return Ok(());
317    }
318
319    pub fn diagonal(
320        &self,
321        offset: Option<isize>,
322        axis1: Option<isize>,
323        axis2: Option<isize>,
324    ) -> Result<Layout<<D as DimSmallerOneAPI>::SmallerOne>>
325    where
326        D: DimSmallerOneAPI,
327    {
328        // check if this layout is at least 2-dimension
329        rstsr_assert!(self.ndim() >= 2, InvalidLayout)?;
330        // unwrap optional parameters
331        let offset = offset.unwrap_or(0);
332        let axis1 = axis1.unwrap_or(0);
333        let axis2 = axis2.unwrap_or(1);
334        let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
335        let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
336        rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
337        rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
338        let axis1 = axis1 as usize;
339        let axis2 = axis2 as usize;
340
341        // shape and strides of last two dimensions
342        let d1 = self.shape()[axis1] as isize;
343        let d2 = self.shape()[axis2] as isize;
344        let t1 = self.stride()[axis1];
345        let t2 = self.stride()[axis2];
346
347        // number of elements in diagonal, and starting offset
348        let (offset_diag, d_diag) = if (-d2 + 1..0).contains(&offset) {
349            let offset = -offset;
350            let offset_diag = (self.offset() as isize + t1 * offset) as usize;
351            let d_diag = (d1 - offset).min(d2) as usize;
352            (offset_diag, d_diag)
353        } else if (0..d1).contains(&offset) {
354            let offset_diag = (self.offset() as isize + t2 * offset) as usize;
355            let d_diag = (d2 - offset).min(d1) as usize;
356            (offset_diag, d_diag)
357        } else {
358            (self.offset(), 0)
359        };
360
361        // build new layout
362        let t_diag = t1 + t2;
363        let mut shape_diag = vec![];
364        let mut stride_diag = vec![];
365        for i in 0..self.ndim() {
366            if i != axis1 && i != axis2 {
367                shape_diag.push(self.shape()[i]);
368                stride_diag.push(self.stride()[i]);
369            }
370        }
371        shape_diag.push(d_diag);
372        stride_diag.push(t_diag);
373        let layout_diag = Layout::new(shape_diag, stride_diag, offset_diag)?;
374        return layout_diag.into_dim::<<D as DimSmallerOneAPI>::SmallerOne>();
375    }
376}
377
378/// Constructors of layout. See also [`DimLayoutContigAPI`] layout from shape
379/// directly.
380impl<D> Layout<D>
381where
382    D: DimBaseAPI,
383{
384    /// Generate new layout by providing everything.
385    ///
386    /// # Error when
387    ///
388    /// - Shape and stride length mismatch
389    /// - Strides is correct (no elements can overlap)
390    /// - Minimum bound is not negative
391    #[inline]
392    pub fn new(shape: D, stride: D::Stride, offset: usize) -> Result<Self>
393    where
394        D: DimShapeAPI + DimStrideAPI,
395    {
396        let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
397        layout.bounds_index()?;
398        layout.check_strides()?;
399        return Ok(layout);
400    }
401
402    /// Generate new layout by providing everything, without checking bounds and
403    /// strides.
404    ///
405    /// # Safety
406    ///
407    /// This function does not check whether layout is valid.
408    #[inline]
409    pub unsafe fn new_unchecked(shape: D, stride: D::Stride, offset: usize) -> Self {
410        Layout { shape, stride, offset }
411    }
412
413    /// New zero shape, which number of dimensions are the same to current
414    /// layout.
415    #[inline]
416    pub fn new_shape(&self) -> D {
417        self.shape.new_shape()
418    }
419
420    /// New zero stride, which number of dimensions are the same to current
421    /// layout.
422    #[inline]
423    pub fn new_stride(&self) -> D::Stride {
424        self.shape.new_stride()
425    }
426}
427
428/// Manuplation of layout.
429impl<D> Layout<D>
430where
431    D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
432{
433    /// Transpose layout by permutation.
434    ///
435    /// # See also
436    ///
437    /// - [`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html)
438    /// - [Python array API: `permute_dims`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.permute_dims.html)
439    pub fn transpose(&self, axes: &[isize]) -> Result<Self> {
440        // check axes and cast to usize
441        let n = self.ndim();
442        rstsr_assert_eq!(
443            axes.len(),
444            n,
445            InvalidLayout,
446            "number of elements in axes should be the same to number of dimensions."
447        )?;
448        // no elements in axes can be the same
449        let mut permut_used = vec![false; n];
450        for &p in axes {
451            let p = if p < 0 { p + n as isize } else { p };
452            rstsr_pattern!(p, 0..n as isize, InvalidLayout)?;
453            let p = p as usize;
454            permut_used[p] = true;
455        }
456        rstsr_assert!(
457            permut_used.iter().all(|&b| b),
458            InvalidLayout,
459            "axes should contain all elements from 0 to n-1."
460        )?;
461        let axes = axes.iter().map(|&p| if p < 0 { p + n as isize } else { p } as usize).collect::<Vec<_>>();
462
463        let shape_old = self.shape();
464        let stride_old = self.stride();
465        let mut shape = self.new_shape();
466        let mut stride = self.new_stride();
467        for i in 0..self.ndim() {
468            shape[i] = shape_old[axes[i]];
469            stride[i] = stride_old[axes[i]];
470        }
471        return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
472    }
473
474    /// Transpose layout by permutation.
475    ///
476    /// This is the same function to [`Layout::transpose`]
477    pub fn permute_dims(&self, axes: &[isize]) -> Result<Self> {
478        self.transpose(axes)
479    }
480
481    /// Reverse axes of layout.
482    pub fn reverse_axes(&self) -> Self {
483        let shape_old = self.shape();
484        let stride_old = self.stride();
485        let mut shape = self.new_shape();
486        let mut stride = self.new_stride();
487        for i in 0..self.ndim() {
488            shape[i] = shape_old[self.ndim() - i - 1];
489            stride[i] = stride_old[self.ndim() - i - 1];
490        }
491        return unsafe { Layout::new_unchecked(shape, stride, self.offset) };
492    }
493
494    /// Swap axes of layout.
495    pub fn swapaxes(&self, axis1: isize, axis2: isize) -> Result<Self> {
496        let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
497        rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
498        let axis1 = axis1 as usize;
499
500        let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
501        rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
502        let axis2 = axis2 as usize;
503
504        let mut shape = self.shape().clone();
505        let mut stride = self.stride().clone();
506        shape.as_mut().swap(axis1, axis2);
507        stride.as_mut().swap(axis1, axis2);
508        return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
509    }
510}
511
512/// Fast indexing and utilities of layout.
513///
514/// These functions are mostly internal to this crate.
515impl<D> Layout<D>
516where
517    D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
518{
519    /// Index of tensor by list of indexes to dimensions.
520    ///
521    /// # Safety
522    ///
523    /// This function does not check for bounds, including
524    /// - Negative index
525    /// - Index greater than shape
526    ///
527    /// Due to these reasons, this function may well give index smaller than
528    /// zero, which may occur in iterator; so this function returns isize.
529    #[inline]
530    pub unsafe fn index_uncheck(&self, index: &[usize]) -> isize {
531        let stride = self.stride.as_ref();
532        match self.ndim() {
533            0 => self.offset as isize,
534            1 => self.offset as isize + stride[0] * index[0] as isize,
535            2 => self.offset as isize + stride[0] * index[0] as isize + stride[1] * index[1] as isize,
536            3 => {
537                self.offset as isize
538                    + stride[0] * index[0] as isize
539                    + stride[1] * index[1] as isize
540                    + stride[2] * index[2] as isize
541            },
542            4 => {
543                self.offset as isize
544                    + stride[0] * index[0] as isize
545                    + stride[1] * index[1] as isize
546                    + stride[2] * index[2] as isize
547                    + stride[3] * index[3] as isize
548            },
549            _ => {
550                let mut pos = self.offset as isize;
551                stride.iter().zip(index.iter()).for_each(|(&s, &i)| pos += s * i as isize);
552                pos
553            },
554        }
555    }
556}
557
558impl<D> PartialEq for Layout<D>
559where
560    D: DimBaseAPI,
561{
562    /// For layout, shape must be the same, while stride should be the same when
563    /// shape is not zero or one, but can be arbitary otherwise.
564    fn eq(&self, other: &Self) -> bool {
565        if self.ndim() != other.ndim() {
566            return false;
567        }
568        for i in 0..self.ndim() {
569            let s1 = self.shape()[i];
570            let s2 = other.shape()[i];
571            if s1 != s2 {
572                return false;
573            }
574            if s1 != 1 && s1 != 0 && self.stride()[i] != other.stride()[i] {
575                return false;
576            }
577        }
578        return true;
579    }
580}
581
582pub trait DimLayoutContigAPI: DimBaseAPI + DimShapeAPI + DimStrideAPI {
583    /// Generate new layout by providing shape and offset; stride fits into
584    /// c-contiguous.
585    fn new_c_contig(&self, offset: Option<usize>) -> Layout<Self> {
586        let shape = self.clone();
587        let stride = shape.stride_c_contig();
588        unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
589    }
590
591    /// Generate new layout by providing shape and offset; stride fits into
592    /// f-contiguous.
593    fn new_f_contig(&self, offset: Option<usize>) -> Layout<Self> {
594        let shape = self.clone();
595        let stride = shape.stride_f_contig();
596        unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
597    }
598
599    /// Simplified function to generate c-contiguous layout. See also
600    /// [DimLayoutContigAPI::new_c_contig].
601    fn c(&self) -> Layout<Self> {
602        self.new_c_contig(None)
603    }
604
605    /// Simplified function to generate f-contiguous layout. See also
606    /// [DimLayoutContigAPI::new_f_contig].
607    fn f(&self) -> Layout<Self> {
608        self.new_f_contig(None)
609    }
610
611    /// Generate new layout by providing shape, offset and order.
612    fn new_contig(&self, offset: Option<usize>, order: FlagOrder) -> Layout<Self> {
613        match order {
614            FlagOrder::C => self.new_c_contig(offset),
615            FlagOrder::F => self.new_f_contig(offset),
616        }
617    }
618}
619
620impl<const N: usize> DimLayoutContigAPI for Ix<N> {}
621impl DimLayoutContigAPI for IxD {}
622
623/* #endregion Layout */
624
625/* #region Dimension Conversion */
626
627pub trait DimIntoAPI<D>: DimBaseAPI
628where
629    D: DimBaseAPI,
630{
631    fn into_dim(layout: Layout<Self>) -> Result<Layout<D>>;
632}
633
634impl<D> DimIntoAPI<D> for IxD
635where
636    D: DimBaseAPI,
637{
638    fn into_dim(layout: Layout<IxD>) -> Result<Layout<D>> {
639        let shape = layout.shape().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
640        let stride = layout.stride().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
641        let offset = layout.offset();
642        return Ok(Layout { shape, stride, offset });
643    }
644}
645
646impl<const N: usize> DimIntoAPI<IxD> for Ix<N> {
647    fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<IxD>> {
648        let shape = (*layout.shape()).into();
649        let stride = (*layout.stride()).into();
650        let offset = layout.offset();
651        return Ok(Layout { shape, stride, offset });
652    }
653}
654
655impl<const N: usize, const M: usize> DimIntoAPI<Ix<M>> for Ix<N> {
656    fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<Ix<M>>> {
657        rstsr_assert_eq!(N, M, InvalidLayout)?;
658        let shape = layout.shape().to_vec().try_into().unwrap();
659        let stride = layout.stride().to_vec().try_into().unwrap();
660        let offset = layout.offset();
661        return Ok(Layout { shape, stride, offset });
662    }
663}
664
665impl<D> Layout<D>
666where
667    D: DimBaseAPI,
668{
669    /// Convert layout to another dimension.
670    pub fn into_dim<D2>(self) -> Result<Layout<D2>>
671    where
672        D2: DimBaseAPI,
673        D: DimIntoAPI<D2>,
674    {
675        D::into_dim(self)
676    }
677
678    /// Convert layout to another dimension.
679    pub fn to_dim<D2>(&self) -> Result<Layout<D2>>
680    where
681        D2: DimBaseAPI,
682        D: DimIntoAPI<D2>,
683    {
684        D::into_dim(self.clone())
685    }
686}
687
688impl<const N: usize> From<Ix<N>> for Layout<Ix<N>> {
689    fn from(shape: Ix<N>) -> Self {
690        let stride = shape.stride_contig();
691        Layout { shape, stride, offset: 0 }
692    }
693}
694
695impl From<IxD> for Layout<IxD> {
696    fn from(shape: IxD) -> Self {
697        let stride = shape.stride_contig();
698        Layout { shape, stride, offset: 0 }
699    }
700}
701
702/* #endregion */
703
704#[cfg(test)]
705mod test {
706    use std::panic::catch_unwind;
707
708    use super::*;
709
710    #[test]
711    fn test_layout_new() {
712        // a successful layout new
713        let shape = [3, 2, 6];
714        let stride = [3, -300, 15];
715        let layout = Layout::new(shape, stride, 917).unwrap();
716        assert_eq!(layout.shape(), &[3, 2, 6]);
717        assert_eq!(layout.stride(), &[3, -300, 15]);
718        assert_eq!(layout.offset(), 917);
719        assert_eq!(layout.ndim(), 3);
720        // unsuccessful layout new (offset underflow)
721        let shape = [3, 2, 6];
722        let stride = [3, -300, 15];
723        let layout = Layout::new(shape, stride, 0);
724        assert!(layout.is_err());
725        // unsuccessful layout new (zero stride for non-0/1 shape)
726        let shape = [3, 2, 6];
727        let stride = [3, -300, 0];
728        let layout = Layout::new(shape, stride, 1000);
729        assert!(layout.is_err());
730        // unsuccessful layout new (stride too small)
731        let shape = [3, 2, 6];
732        let stride = [3, 4, 7];
733        let layout = Layout::new(shape, stride, 1000);
734        assert!(layout.is_err());
735        // successful layout new (zero dim)
736        let shape = [];
737        let stride = [];
738        let layout = Layout::new(shape, stride, 1000);
739        assert!(layout.is_ok());
740        // successful layout new (stride 0 for 1-shape)
741        let shape = [3, 1, 5];
742        let stride = [1, 0, 15];
743        let layout = Layout::new(shape, stride, 1);
744        assert!(layout.is_ok());
745        // successful layout new (stride 0 for 1-shape)
746        let shape = [3, 1, 5];
747        let stride = [1, 0, 15];
748        let layout = Layout::new(shape, stride, 1);
749        assert!(layout.is_ok());
750        // successful layout new (zero-size tensor)
751        let shape = [3, 0, 5];
752        let stride = [-1, -2, -3];
753        let layout = Layout::new(shape, stride, 1);
754        assert!(layout.is_ok());
755        // anyway, if one need custom layout, use new_unchecked
756        let shape = [3, 2, 6];
757        let stride = [3, -300, 0];
758        let r = catch_unwind(|| unsafe { Layout::new_unchecked(shape, stride, 1000) });
759        assert!(r.is_ok());
760    }
761
762    #[test]
763    fn test_is_f_prefer() {
764        // general case
765        let shape = [3, 5, 7];
766        let layout = Layout::new(shape, [1, 10, 100], 0).unwrap();
767        assert!(layout.f_prefer());
768        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
769        assert!(layout.f_prefer());
770        let layout = Layout::new(shape, [1, 3, -15], 1000).unwrap();
771        assert!(!layout.f_prefer());
772        let layout = Layout::new(shape, [1, 21, 3], 0).unwrap();
773        assert!(!layout.f_prefer());
774        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
775        assert!(!layout.f_prefer());
776        let layout = Layout::new(shape, [2, 6, 30], 0).unwrap();
777        assert!(!layout.f_prefer());
778        // zero dimension
779        let layout = Layout::new([], [], 0).unwrap();
780        assert!(layout.f_prefer());
781        // zero size
782        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
783        assert!(layout.f_prefer());
784        // shape with 1
785        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
786        assert!(layout.f_prefer());
787    }
788
789    #[test]
790    fn test_is_c_prefer() {
791        // general case
792        let shape = [3, 5, 7];
793        let layout = Layout::new(shape, [100, 10, 1], 0).unwrap();
794        assert!(layout.c_prefer());
795        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
796        assert!(layout.c_prefer());
797        let layout = Layout::new(shape, [-35, 7, 1], 1000).unwrap();
798        assert!(!layout.c_prefer());
799        let layout = Layout::new(shape, [7, 21, 1], 0).unwrap();
800        assert!(!layout.c_prefer());
801        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
802        assert!(!layout.c_prefer());
803        let layout = Layout::new(shape, [70, 14, 2], 0).unwrap();
804        assert!(!layout.c_prefer());
805        // zero dimension
806        let layout = Layout::new([], [], 0).unwrap();
807        assert!(layout.c_prefer());
808        // zero size
809        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
810        assert!(layout.c_prefer());
811        // shape with 1
812        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
813        assert!(layout.c_prefer());
814    }
815
816    #[test]
817    fn test_is_f_contig() {
818        // general case
819        let shape = [3, 5, 7];
820        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
821        assert!(layout.f_contig());
822        let layout = Layout::new(shape, [1, 4, 20], 0).unwrap();
823        assert!(!layout.f_contig());
824        // zero dimension
825        let layout = Layout::new([], [], 0).unwrap();
826        assert!(layout.f_contig());
827        // zero size
828        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
829        assert!(layout.f_contig());
830        // shape with 1
831        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
832        assert!(layout.f_contig());
833    }
834
835    #[test]
836    fn test_is_c_contig() {
837        // general case
838        let shape = [3, 5, 7];
839        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
840        assert!(layout.c_contig());
841        let layout = Layout::new(shape, [36, 7, 1], 0).unwrap();
842        assert!(!layout.c_contig());
843        // zero dimension
844        let layout = Layout::new([], [], 0).unwrap();
845        assert!(layout.c_contig());
846        // zero size
847        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
848        assert!(layout.c_contig());
849        // shape with 1
850        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
851        assert!(layout.c_contig());
852    }
853
854    #[test]
855    fn test_index() {
856        // a = np.arange(9 * 12 * 15)
857        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
858        //       .transpose(2, 0, 1)
859        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
860        assert_eq!(layout.index(&[0, 0, 0]), 782);
861        assert_eq!(layout.index(&[2, 1, 4]), 668);
862        assert_eq!(layout.index(&[1, -2, -3]), 830);
863        // zero-dim
864        let layout = Layout::new([], [], 10).unwrap();
865        assert_eq!(layout.index(&[]), 10);
866    }
867
868    #[test]
869    fn test_bounds_index() {
870        // a = np.arange(9 * 12 * 15)
871        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
872        //       .transpose(2, 0, 1)
873        // a.min() = 602, a.max() = 863
874        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
875        assert_eq!(layout.bounds_index().unwrap(), (602, 864));
876        // situation that fails
877        let layout = unsafe { Layout::new_unchecked([3, 2, 6], [3, -180, 15], 15) };
878        assert!(layout.bounds_index().is_err());
879        // zero-dim
880        let layout = Layout::new([], [], 10).unwrap();
881        assert_eq!(layout.bounds_index().unwrap(), (10, 11));
882    }
883
884    #[test]
885    fn test_transpose() {
886        // general
887        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
888        let trans = layout.transpose(&[2, 0, 1]).unwrap();
889        assert_eq!(trans.shape(), &[6, 3, 2]);
890        assert_eq!(trans.stride(), &[15, 3, -180]);
891        // permute_dims is alias of transpose
892        let trans = layout.permute_dims(&[2, 0, 1]).unwrap();
893        assert_eq!(trans.shape(), &[6, 3, 2]);
894        assert_eq!(trans.stride(), &[15, 3, -180]);
895        // negative axis also allowed
896        let trans = layout.transpose(&[-1, 0, 1]).unwrap();
897        assert_eq!(trans.shape(), &[6, 3, 2]);
898        assert_eq!(trans.stride(), &[15, 3, -180]);
899        // repeated axis
900        let trans = layout.transpose(&[-2, 0, 1]);
901        assert!(trans.is_err());
902        // non-valid dimension
903        let trans = layout.transpose(&[1, 0]);
904        assert!(trans.is_err());
905        // zero-dim
906        let layout = Layout::new([], [], 0).unwrap();
907        let trans = layout.transpose(&[]);
908        assert!(trans.is_ok());
909    }
910
911    #[test]
912    fn test_reverse_axes() {
913        // general
914        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
915        let trans = layout.reverse_axes();
916        assert_eq!(trans.shape(), &[6, 2, 3]);
917        assert_eq!(trans.stride(), &[15, -180, 3]);
918        // zero-dim
919        let layout = Layout::new([], [], 782).unwrap();
920        let trans = layout.reverse_axes();
921        assert_eq!(trans.shape(), &[]);
922        assert_eq!(trans.stride(), &[]);
923    }
924
925    #[test]
926    fn test_swapaxes() {
927        // general
928        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
929        let trans = layout.swapaxes(-1, -2).unwrap();
930        assert_eq!(trans.shape(), &[3, 6, 2]);
931        assert_eq!(trans.stride(), &[3, 15, -180]);
932        // same index is allowed
933        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
934        let trans = layout.swapaxes(-1, -1).unwrap();
935        assert_eq!(trans.shape(), &[3, 2, 6]);
936        assert_eq!(trans.stride(), &[3, -180, 15]);
937    }
938
939    #[test]
940    fn test_index_uncheck() {
941        // a = np.arange(9 * 12 * 15)
942        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
943        //       .transpose(2, 0, 1)
944        unsafe {
945            // fixed dim
946            let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
947            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
948            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
949            // dynamic dim
950            let layout = Layout::new(vec![3, 2, 6], vec![3, -180, 15], 782).unwrap();
951            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
952            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
953            // zero-dim
954            let layout = Layout::new([], [], 10).unwrap();
955            assert_eq!(layout.index_uncheck(&[]), 10);
956        }
957    }
958
959    #[test]
960    fn test_diagonal() {
961        let layout = [2, 3, 4].c();
962        let diag = layout.diagonal(None, None, None).unwrap();
963        assert_eq!(diag, Layout::new([4, 2], [1, 16], 0).unwrap());
964        let diag = layout.diagonal(Some(-1), Some(-2), Some(-1)).unwrap();
965        assert_eq!(diag, Layout::new([2, 2], [12, 5], 0).unwrap());
966        let diag = layout.diagonal(Some(-4), Some(-2), Some(-1)).unwrap();
967        assert_eq!(diag, Layout::new([2, 0], [12, 5], 0).unwrap());
968    }
969
970    #[test]
971    fn test_new_contig() {
972        let layout = [3, 2, 6].c();
973        assert_eq!(layout.shape(), &[3, 2, 6]);
974        assert_eq!(layout.stride(), &[12, 6, 1]);
975        let layout = [3, 2, 6].f();
976        assert_eq!(layout.shape(), &[3, 2, 6]);
977        assert_eq!(layout.stride(), &[1, 3, 6]);
978        // following code generates contiguous layout
979        // c/f-contig depends on cargo feature
980        let layout: Layout<_> = [3, 2, 6].into();
981        println!("{layout:?}");
982    }
983
984    #[test]
985    fn test_layout_cast() {
986        let layout = [3, 2, 6].c();
987        assert!(layout.clone().into_dim::<IxD>().is_ok());
988        assert!(layout.clone().into_dim::<Ix3>().is_ok());
989        let layout = vec![3, 2, 6].c();
990        assert!(layout.clone().into_dim::<IxD>().is_ok());
991        assert!(layout.clone().into_dim::<Ix3>().is_ok());
992        assert!(layout.clone().into_dim::<Ix2>().is_err());
993    }
994
995    #[test]
996    fn test_unravel_index() {
997        unsafe {
998            let shape = [3, 2, 6];
999            assert_eq!(shape.unravel_index_f(0), [0, 0, 0]);
1000            assert_eq!(shape.unravel_index_f(16), [1, 1, 2]);
1001            assert_eq!(shape.unravel_index_c(0), [0, 0, 0]);
1002            assert_eq!(shape.unravel_index_c(16), [1, 0, 4]);
1003        }
1004    }
1005
1006    #[test]
1007    fn fix_too_strict_stride_check() {
1008        let layout = [10, 11, 12].c();
1009        let slc = (.., slice!(-1, 0, -4));
1010        let slc: AxesIndex<Indexer> = slc.try_into().unwrap();
1011        let indexed = layout.dim_slice(slc.as_ref()).unwrap();
1012        assert_eq!(indexed.shape(), &[10, 3, 12]);
1013        assert_eq!(indexed.stride(), &[132, -48, 1]);
1014    }
1015}