Skip to main content

rstsr_common/layout/
layoutbase.rs

1//! Layout of tensor.
2use crate::prelude_dev::*;
3use itertools::izip;
4
5/* #region Struct Definitions */
6
7/// Layout of tensor.
8///
9/// Layout is a struct that contains shape, stride, and offset of tensor.
10/// - Shape is the size of each dimension of tensor.
11/// - Stride is the number of elements to skip to get to the next element in each dimension.
12/// - Offset is the starting position of tensor.
13#[doc = include_str!("readme.md")]
14#[derive(Clone)]
15pub struct Layout<D>
16where
17    D: DimBaseAPI,
18{
19    // essential definitions to layout
20    pub(crate) shape: D,
21    pub(crate) stride: D::Stride,
22    pub(crate) offset: usize,
23}
24
25unsafe impl<D> Send for Layout<D> where D: DimBaseAPI {}
26unsafe impl<D> Sync for Layout<D> where D: DimBaseAPI {}
27
28/* #endregion */
29
30/* #region Layout */
31
32/// Getter/setter functions for layout.
33impl<D> Layout<D>
34where
35    D: DimBaseAPI,
36{
37    /// Shape of tensor. Getter function.
38    #[inline]
39    pub fn shape(&self) -> &D {
40        &self.shape
41    }
42
43    /// Stride of tensor. Getter function.
44    #[inline]
45    pub fn stride(&self) -> &D::Stride {
46        &self.stride
47    }
48
49    /// Starting offset of tensor. Getter function.
50    #[inline]
51    pub fn offset(&self) -> usize {
52        self.offset
53    }
54
55    /// Number of dimensions of tensor.
56    #[inline]
57    pub fn ndim(&self) -> usize {
58        self.shape.ndim()
59    }
60
61    /// Total number of elements in tensor.
62    ///
63    /// # Note
64    ///
65    /// This function uses cached size, instead of evaluating from shape.
66    #[inline]
67    pub fn size(&self) -> usize {
68        self.shape().as_ref().iter().product()
69    }
70
71    /// Manually set offset.
72    ///
73    /// # Safety
74    ///
75    /// We will not check whether this offset is valid or not.
76    /// In most cases, it is not intended to be used by user.
77    pub unsafe fn set_offset(&mut self, offset: usize) -> &mut Self {
78        self.offset = offset;
79        return self;
80    }
81}
82
83/// Properties of layout.
84impl<D> Layout<D>
85where
86    D: DimBaseAPI + DimShapeAPI,
87{
88    /// Whether this tensor is f-preferred.
89    pub fn f_prefer(&self) -> bool {
90        // always true for 0-dimension or 0-size tensor
91        if self.ndim() == 0 || self.size() == 0 {
92            return true;
93        }
94
95        let stride = self.stride.as_ref();
96        let shape = self.shape.as_ref();
97        let mut last = 0;
98        for (&s, &d) in stride.iter().zip(shape.iter()) {
99            if d != 1 {
100                if s < last {
101                    // latter strides must larger than previous strides
102                    return false;
103                }
104                if last == 0 && s != 1 {
105                    // first stride must be 1
106                    return false;
107                }
108                last = s;
109            } else if last == 0 {
110                // if dimension is one, then consider that stride is one, counted as contiguous
111                // in last dimension
112                last = 1;
113            }
114        }
115        return true;
116    }
117
118    /// Whether this tensor is c-preferred.
119    pub fn c_prefer(&self) -> bool {
120        // always true for 0-dimension or 0-size tensor
121        if self.ndim() == 0 || self.size() == 0 {
122            return true;
123        }
124
125        let stride = self.stride.as_ref();
126        let shape = self.shape.as_ref();
127        let mut last = 0;
128        for (&s, &d) in stride.iter().zip(shape.iter()).rev() {
129            if d != 1 {
130                if s < last {
131                    // previous strides must larger than latter strides
132                    return false;
133                }
134                if last == 0 && s != 1 {
135                    // last stride must be 1
136                    return false;
137                }
138                last = s;
139            } else if last == 0 {
140                // if dimension is one, then consider that stride is one, counted as contiguous
141                // in last dimension
142                last = 1;
143            }
144        }
145        return true;
146    }
147
148    /// Least number of dimensions that is f-contiguous for layout.
149    ///
150    /// This function can be useful determining when to iterate by contiguous,
151    /// and when to iterate by index.
152    pub fn ndim_of_f_contig(&self) -> usize {
153        if self.ndim() == 0 || self.size() == 0 {
154            return self.ndim();
155        }
156        let stride = self.stride.as_ref();
157        let shape = self.shape.as_ref();
158        let mut acc = 1;
159        for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).enumerate() {
160            if d != 1 && s != acc {
161                return ndim;
162            }
163            acc *= d as isize;
164        }
165        return self.ndim();
166    }
167
168    /// Least number of dimensions that is c-contiguous for layout.
169    ///
170    /// This function can be useful determining when to iterate by contiguous,
171    /// and when to iterate by index.
172    pub fn ndim_of_c_contig(&self) -> usize {
173        if self.ndim() == 0 || self.size() == 0 {
174            return self.ndim();
175        }
176        let stride = self.stride.as_ref();
177        let shape = self.shape.as_ref();
178        let mut acc = 1;
179        for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).rev().enumerate() {
180            if d != 1 && s != acc {
181                return ndim;
182            }
183            acc *= d as isize;
184        }
185        return self.ndim();
186    }
187
188    /// Whether this tensor is f-contiguous.
189    ///
190    /// Special cases
191    /// - When length of a dimension is one, then stride to that dimension is not important.
192    /// - When length of a dimension is zero, then tensor contains no elements, thus f-contiguous.
193    pub fn f_contig(&self) -> bool {
194        self.ndim() == self.ndim_of_f_contig()
195    }
196
197    /// Whether this tensor is c-contiguous.
198    ///
199    /// Special cases
200    /// - When length of a dimension is one, then stride to that dimension is not important.
201    /// - When length of a dimension is zero, then tensor contains no elements, thus c-contiguous.
202    pub fn c_contig(&self) -> bool {
203        self.ndim() == self.ndim_of_c_contig()
204    }
205
206    /// Index of tensor by list of indexes to dimensions.
207    ///
208    /// This function does not optimized for performance.
209    pub fn index_f(&self, index: &[isize]) -> Result<usize> {
210        rstsr_assert_eq!(index.len(), self.ndim(), InvalidLayout)?;
211        let mut pos = self.offset() as isize;
212        let shape = self.shape.as_ref();
213        let stride = self.stride.as_ref();
214
215        for (&idx, &shp, &strd) in izip!(index.iter(), shape.iter(), stride.iter()) {
216            let idx = if idx < 0 { idx + shp as isize } else { idx };
217            rstsr_pattern!(idx, 0..(shp as isize), ValueOutOfRange)?;
218            pos += strd * idx;
219        }
220        rstsr_pattern!(pos, 0.., ValueOutOfRange)?;
221        return Ok(pos as usize);
222    }
223
224    /// Index of tensor by list of indexes to dimensions.
225    ///
226    /// This function does not optimized for performance. Negative index
227    /// allowed.
228    pub fn index(&self, index: &[isize]) -> usize {
229        self.index_f(index).unwrap()
230    }
231
232    /// Index range bounds of current layout. This bound is [min, max), which
233    /// could be feed into range (min..max). If min == max, then this layout
234    /// should not contains any element.
235    ///
236    /// This function will raise error when minimum index is smaller than zero.
237    pub fn bounds_index(&self) -> Result<(usize, usize)> {
238        let n = self.ndim();
239        let offset = self.offset;
240        let shape = self.shape.as_ref();
241        let stride = self.stride.as_ref();
242
243        if n == 0 {
244            return Ok((offset, offset + 1));
245        }
246
247        let mut min = offset as isize;
248        let mut max = offset as isize;
249
250        for i in 0..n {
251            if shape[i] == 0 {
252                return Ok((offset, offset));
253            }
254            if stride[i] > 0 {
255                max += stride[i] * (shape[i] as isize - 1);
256            } else {
257                min += stride[i] * (shape[i] as isize - 1);
258            }
259        }
260        rstsr_pattern!(min, 0.., ValueOutOfRange)?;
261        return Ok((min as usize, max as usize + 1));
262    }
263
264    /// Check if strides is correct (no elemenets can overlap).
265    ///
266    /// This will check if all number of elements in dimension of small strides
267    /// is less than larger strides. For example of valid stride:
268    /// ```output
269    /// shape:  (3,    2,  6)  -> sorted ->  ( 3,   6,   2)
270    /// stride: (3, -300, 15)  -> sorted ->  ( 3,  15, 300)
271    /// number of elements:                    9,  90,
272    /// stride of next dimension              15, 300,
273    /// number of elem < stride of next dim?   +,   +,
274    /// ```
275    ///
276    /// Special cases
277    /// - if length of tensor is zero, then strides will always be correct.
278    /// - if certain dimension is one, then check for this stride will be ignored.
279    /// - if stride has zero value, `skip_zero` parameter will determine whether this function will
280    ///   raise error or not.
281    ///
282    /// # TODO
283    ///
284    /// Correctness of this function is not fully ensured.
285    pub fn check_strides(&self, skip_zero: bool) -> Result<()> {
286        let shape = self.shape.as_ref();
287        let stride = self.stride.as_ref();
288        rstsr_assert_eq!(shape.len(), stride.len(), InvalidLayout)?;
289        let n = shape.len();
290
291        // unconditionally ok if no elements (length of tensor is zero)
292        // unconditionally ok if 0-dimension
293        if self.size() == 0 || n == 0 {
294            return Ok(());
295        }
296
297        let mut indices = (0..n).filter(|&k| shape[k] > 1).collect::<Vec<_>>();
298        indices.sort_by_key(|&k| stride[k].abs());
299        let shape_sorted = indices.iter().map(|&k| shape[k]).collect::<Vec<_>>();
300        let stride_sorted = indices.iter().map(|&k| stride[k].unsigned_abs()).collect::<Vec<_>>();
301
302        // elem_cum: cumulative number count of elements in tensor for small strides
303        let mut elem_cum = 0;
304        for i in 0..indices.len() {
305            // if stride is zero, then skip check for this axis
306            if stride_sorted[i] == 0 && skip_zero {
307                continue;
308            }
309            // following function also checks that stride could not be zero
310            rstsr_pattern!(
311                elem_cum,
312                0..stride_sorted[i],
313                InvalidLayout,
314                "Either stride be zero, or stride too small that elements in tensor can be overlapped."
315            )?;
316
317            elem_cum += (shape_sorted[i] - 1) * stride_sorted[i];
318        }
319        return Ok(());
320    }
321
322    pub fn diagonal(
323        &self,
324        offset: Option<isize>,
325        axis1: Option<isize>,
326        axis2: Option<isize>,
327    ) -> Result<Layout<<D as DimSmallerOneAPI>::SmallerOne>>
328    where
329        D: DimSmallerOneAPI,
330    {
331        // check if this layout is at least 2-dimension
332        rstsr_assert!(self.ndim() >= 2, InvalidLayout)?;
333        // unwrap optional parameters
334        let offset = offset.unwrap_or(0);
335        let axis1 = axis1.unwrap_or(0);
336        let axis2 = axis2.unwrap_or(1);
337        let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
338        let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
339        rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
340        rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
341        let axis1 = axis1 as usize;
342        let axis2 = axis2 as usize;
343
344        // shape and strides of last two dimensions
345        let d1 = self.shape()[axis1] as isize;
346        let d2 = self.shape()[axis2] as isize;
347        let t1 = self.stride()[axis1];
348        let t2 = self.stride()[axis2];
349
350        // number of elements in diagonal, and starting offset
351        let (offset_diag, d_diag) = if (-d2 + 1..0).contains(&offset) {
352            let offset = -offset;
353            let offset_diag = (self.offset() as isize + t1 * offset) as usize;
354            let d_diag = (d1 - offset).min(d2) as usize;
355            (offset_diag, d_diag)
356        } else if (0..d1).contains(&offset) {
357            let offset_diag = (self.offset() as isize + t2 * offset) as usize;
358            let d_diag = (d2 - offset).min(d1) as usize;
359            (offset_diag, d_diag)
360        } else {
361            (self.offset(), 0)
362        };
363
364        // build new layout
365        let t_diag = t1 + t2;
366        let mut shape_diag = vec![];
367        let mut stride_diag = vec![];
368        for i in 0..self.ndim() {
369            if i != axis1 && i != axis2 {
370                shape_diag.push(self.shape()[i]);
371                stride_diag.push(self.stride()[i]);
372            }
373        }
374        shape_diag.push(d_diag);
375        stride_diag.push(t_diag);
376        let layout_diag = Layout::new(shape_diag, stride_diag, offset_diag)?;
377        return layout_diag.into_dim::<<D as DimSmallerOneAPI>::SmallerOne>();
378    }
379}
380
381/// Constructors of layout. See also [`DimLayoutContigAPI`] layout from shape
382/// directly.
383impl<D> Layout<D>
384where
385    D: DimBaseAPI,
386{
387    /// Generate new layout by providing everything.
388    ///
389    /// # Error when
390    ///
391    /// - Shape and stride length mismatch
392    /// - Strides is correct (no elements can overlap)
393    /// - Minimum bound is not negative
394    #[inline]
395    pub fn new(shape: D, stride: D::Stride, offset: usize) -> Result<Self>
396    where
397        D: DimShapeAPI,
398    {
399        let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
400        layout.bounds_index()?;
401        layout.check_strides(true)?;
402        return Ok(layout);
403    }
404
405    /// Generate new layout by providing everything, without checking bounds and
406    /// strides.
407    ///
408    /// # Safety
409    ///
410    /// This function does not check whether layout is valid.
411    #[inline]
412    pub unsafe fn new_unchecked(shape: D, stride: D::Stride, offset: usize) -> Self {
413        Layout { shape, stride, offset }
414    }
415
416    /// New zero shape, which number of dimensions are the same to current
417    /// layout.
418    #[inline]
419    pub fn new_shape(&self) -> D {
420        self.shape.new_shape()
421    }
422
423    /// New zero stride, which number of dimensions are the same to current
424    /// layout.
425    #[inline]
426    pub fn new_stride(&self) -> D::Stride {
427        self.shape.new_stride()
428    }
429}
430
431/// Manuplation of layout.
432impl<D> Layout<D>
433where
434    D: DimBaseAPI + DimShapeAPI,
435{
436    /// Transpose layout by permutation.
437    ///
438    /// # See also
439    ///
440    /// - [`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html)
441    /// - [Python array API: `permute_dims`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.permute_dims.html)
442    pub fn transpose(&self, axes: &[isize]) -> Result<Self> {
443        // check axes and cast to usize
444        let n = self.ndim();
445        rstsr_assert_eq!(
446            axes.len(),
447            n,
448            InvalidLayout,
449            "number of elements in axes should be the same to number of dimensions."
450        )?;
451        // normalize axes; since we have checked number of elements, and not allowed duplicate, so no other
452        // check is needed for axes
453        let axes = normalize_axes_index(axes.into(), n, false, false)?;
454        let axes = axes.into_iter().map(|a| a as usize).collect::<Vec<usize>>();
455
456        let shape_old = self.shape();
457        let stride_old = self.stride();
458        let mut shape = self.new_shape();
459        let mut stride = self.new_stride();
460        for i in 0..self.ndim() {
461            shape[i] = shape_old[axes[i]];
462            stride[i] = stride_old[axes[i]];
463        }
464        return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
465    }
466
467    /// Transpose layout by permutation.
468    ///
469    /// This is the same function to [`Layout::transpose`]
470    pub fn permute_dims(&self, axes: &[isize]) -> Result<Self> {
471        self.transpose(axes)
472    }
473
474    /// Reverse axes of layout.
475    pub fn reverse_axes(&self) -> Self {
476        let shape_old = self.shape();
477        let stride_old = self.stride();
478        let mut shape = self.new_shape();
479        let mut stride = self.new_stride();
480        for i in 0..self.ndim() {
481            shape[i] = shape_old[self.ndim() - i - 1];
482            stride[i] = stride_old[self.ndim() - i - 1];
483        }
484        return unsafe { Layout::new_unchecked(shape, stride, self.offset) };
485    }
486
487    /// Swap axes of layout.
488    pub fn swapaxes(&self, axis1: isize, axis2: isize) -> Result<Self> {
489        let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
490        rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
491        let axis1 = axis1 as usize;
492
493        let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
494        rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
495        let axis2 = axis2 as usize;
496
497        let mut shape = self.shape().clone();
498        let mut stride = self.stride().clone();
499        shape.as_mut().swap(axis1, axis2);
500        stride.as_mut().swap(axis1, axis2);
501        return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
502    }
503}
504
505/// Fast indexing and utilities of layout.
506///
507/// These functions are mostly internal to this crate.
508impl<D> Layout<D>
509where
510    D: DimBaseAPI + DimShapeAPI,
511{
512    /// Index of tensor by list of indexes to dimensions.
513    ///
514    /// # Safety
515    ///
516    /// This function does not check for bounds, including
517    /// - Negative index
518    /// - Index greater than shape
519    ///
520    /// Due to these reasons, this function may well give index smaller than
521    /// zero, which may occur in iterator; so this function returns isize.
522    #[inline]
523    pub unsafe fn index_uncheck(&self, index: &[usize]) -> isize {
524        let stride = self.stride.as_ref();
525        match self.ndim() {
526            0 => self.offset as isize,
527            1 => self.offset as isize + stride[0] * index[0] as isize,
528            2 => self.offset as isize + stride[0] * index[0] as isize + stride[1] * index[1] as isize,
529            3 => {
530                self.offset as isize
531                    + stride[0] * index[0] as isize
532                    + stride[1] * index[1] as isize
533                    + stride[2] * index[2] as isize
534            },
535            4 => {
536                self.offset as isize
537                    + stride[0] * index[0] as isize
538                    + stride[1] * index[1] as isize
539                    + stride[2] * index[2] as isize
540                    + stride[3] * index[3] as isize
541            },
542            _ => {
543                let mut pos = self.offset as isize;
544                stride.iter().zip(index.iter()).for_each(|(&s, &i)| pos += s * i as isize);
545                pos
546            },
547        }
548    }
549}
550
551impl<D> PartialEq for Layout<D>
552where
553    D: DimBaseAPI,
554{
555    /// For layout, shape must be the same, offset must be the same, while stride should be the same
556    /// when shape is not zero or one, but can be arbitary otherwise.
557    fn eq(&self, other: &Self) -> bool {
558        if self.ndim() != other.ndim() {
559            return false;
560        }
561        if self.offset != other.offset {
562            return false;
563        }
564        for i in 0..self.ndim() {
565            let s1 = self.shape()[i];
566            let s2 = other.shape()[i];
567            if s1 != s2 {
568                return false;
569            }
570            if s1 != 1 && s1 != 0 && self.stride()[i] != other.stride()[i] {
571                return false;
572            }
573        }
574        return true;
575    }
576}
577
578pub trait DimLayoutContigAPI: DimBaseAPI + DimShapeAPI {
579    /// Generate new layout by providing shape and offset; stride fits into
580    /// c-contiguous.
581    fn new_c_contig(&self, offset: Option<usize>) -> Layout<Self> {
582        let shape = self.clone();
583        let stride = shape.stride_c_contig();
584        unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
585    }
586
587    /// Generate new layout by providing shape and offset; stride fits into
588    /// f-contiguous.
589    fn new_f_contig(&self, offset: Option<usize>) -> Layout<Self> {
590        let shape = self.clone();
591        let stride = shape.stride_f_contig();
592        unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
593    }
594
595    /// Simplified function to generate c-contiguous layout. See also
596    /// [DimLayoutContigAPI::new_c_contig].
597    fn c(&self) -> Layout<Self> {
598        self.new_c_contig(None)
599    }
600
601    /// Simplified function to generate f-contiguous layout. See also
602    /// [DimLayoutContigAPI::new_f_contig].
603    fn f(&self) -> Layout<Self> {
604        self.new_f_contig(None)
605    }
606
607    /// Generate new layout by providing shape, offset and order.
608    fn new_contig(&self, offset: Option<usize>, order: FlagOrder) -> Layout<Self> {
609        match order {
610            FlagOrder::C => self.new_c_contig(offset),
611            FlagOrder::F => self.new_f_contig(offset),
612        }
613    }
614}
615
616impl<const N: usize> DimLayoutContigAPI for Ix<N> {}
617impl DimLayoutContigAPI for IxD {}
618
619/* #endregion Layout */
620
621/* #region Dimension Conversion */
622
623pub trait DimIntoAPI<D>: DimBaseAPI
624where
625    D: DimBaseAPI,
626{
627    fn into_dim(layout: Layout<Self>) -> Result<Layout<D>>;
628}
629
630impl<D> DimIntoAPI<D> for IxD
631where
632    D: DimBaseAPI,
633{
634    fn into_dim(layout: Layout<IxD>) -> Result<Layout<D>> {
635        let shape = layout.shape().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
636        let stride = layout.stride().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
637        let offset = layout.offset();
638        return Ok(Layout { shape, stride, offset });
639    }
640}
641
642impl<const N: usize> DimIntoAPI<IxD> for Ix<N> {
643    fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<IxD>> {
644        let shape = (*layout.shape()).into();
645        let stride = (*layout.stride()).into();
646        let offset = layout.offset();
647        return Ok(Layout { shape, stride, offset });
648    }
649}
650
651impl<const N: usize, const M: usize> DimIntoAPI<Ix<M>> for Ix<N> {
652    fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<Ix<M>>> {
653        rstsr_assert_eq!(N, M, InvalidLayout)?;
654        let shape = layout.shape().to_vec().try_into().unwrap();
655        let stride = layout.stride().to_vec().try_into().unwrap();
656        let offset = layout.offset();
657        return Ok(Layout { shape, stride, offset });
658    }
659}
660
661impl<D> Layout<D>
662where
663    D: DimBaseAPI,
664{
665    /// Convert layout to another dimension.
666    pub fn into_dim<D2>(self) -> Result<Layout<D2>>
667    where
668        D2: DimBaseAPI,
669        D: DimIntoAPI<D2>,
670    {
671        D::into_dim(self)
672    }
673
674    /// Convert layout to another dimension.
675    pub fn to_dim<D2>(&self) -> Result<Layout<D2>>
676    where
677        D2: DimBaseAPI,
678        D: DimIntoAPI<D2>,
679    {
680        D::into_dim(self.clone())
681    }
682}
683
684impl<const N: usize> From<Ix<N>> for Layout<Ix<N>> {
685    fn from(shape: Ix<N>) -> Self {
686        let stride = shape.stride_contig();
687        Layout { shape, stride, offset: 0 }
688    }
689}
690
691impl From<IxD> for Layout<IxD> {
692    fn from(shape: IxD) -> Self {
693        let stride = shape.stride_contig();
694        Layout { shape, stride, offset: 0 }
695    }
696}
697
698/* #endregion */
699
700#[cfg(test)]
701mod test {
702    use std::panic::catch_unwind;
703
704    use super::*;
705
706    #[test]
707    fn test_layout_new() {
708        // a successful layout new
709        let shape = [3, 2, 6];
710        let stride = [3, -300, 15];
711        let layout = Layout::new(shape, stride, 917).unwrap();
712        assert_eq!(layout.shape(), &[3, 2, 6]);
713        assert_eq!(layout.stride(), &[3, -300, 15]);
714        assert_eq!(layout.offset(), 917);
715        assert_eq!(layout.ndim(), 3);
716        // unsuccessful layout new (offset underflow)
717        let shape = [3, 2, 6];
718        let stride = [3, -300, 15];
719        let layout = Layout::new(shape, stride, 0);
720        assert!(layout.is_err());
721        // unsuccessful layout new (stride too small)
722        let shape = [3, 2, 6];
723        let stride = [3, 4, 7];
724        let layout = Layout::new(shape, stride, 1000);
725        assert!(layout.is_err());
726        // successful layout new (zero stride for non-0/1 shape)
727        let shape = [3, 2, 6];
728        let stride = [3, -300, 0];
729        let layout = Layout::new(shape, stride, 1000);
730        assert!(layout.is_ok());
731        // successful layout new (zero dim)
732        let shape = [];
733        let stride = [];
734        let layout = Layout::new(shape, stride, 1000);
735        assert!(layout.is_ok());
736        // successful layout new (stride 0 for 1-shape)
737        let shape = [3, 1, 5];
738        let stride = [1, 0, 15];
739        let layout = Layout::new(shape, stride, 1);
740        assert!(layout.is_ok());
741        // successful layout new (stride 0 for 1-shape)
742        let shape = [3, 1, 5];
743        let stride = [1, 0, 15];
744        let layout = Layout::new(shape, stride, 1);
745        assert!(layout.is_ok());
746        // successful layout new (zero-size tensor)
747        let shape = [3, 0, 5];
748        let stride = [-1, -2, -3];
749        let layout = Layout::new(shape, stride, 1);
750        assert!(layout.is_ok());
751        // anyway, if one need custom layout, use new_unchecked
752        let shape = [3, 2, 6];
753        let stride = [3, -300, 0];
754        let r = catch_unwind(|| unsafe { Layout::new_unchecked(shape, stride, 1000) });
755        assert!(r.is_ok());
756    }
757
758    #[test]
759    fn test_is_f_prefer() {
760        // general case
761        let shape = [3, 5, 7];
762        let layout = Layout::new(shape, [1, 10, 100], 0).unwrap();
763        assert!(layout.f_prefer());
764        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
765        assert!(layout.f_prefer());
766        let layout = Layout::new(shape, [1, 3, -15], 1000).unwrap();
767        assert!(!layout.f_prefer());
768        let layout = Layout::new(shape, [1, 21, 3], 0).unwrap();
769        assert!(!layout.f_prefer());
770        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
771        assert!(!layout.f_prefer());
772        let layout = Layout::new(shape, [2, 6, 30], 0).unwrap();
773        assert!(!layout.f_prefer());
774        // zero dimension
775        let layout = Layout::new([], [], 0).unwrap();
776        assert!(layout.f_prefer());
777        // zero size
778        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
779        assert!(layout.f_prefer());
780        // shape with 1
781        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
782        assert!(layout.f_prefer());
783    }
784
785    #[test]
786    fn test_is_c_prefer() {
787        // general case
788        let shape = [3, 5, 7];
789        let layout = Layout::new(shape, [100, 10, 1], 0).unwrap();
790        assert!(layout.c_prefer());
791        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
792        assert!(layout.c_prefer());
793        let layout = Layout::new(shape, [-35, 7, 1], 1000).unwrap();
794        assert!(!layout.c_prefer());
795        let layout = Layout::new(shape, [7, 21, 1], 0).unwrap();
796        assert!(!layout.c_prefer());
797        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
798        assert!(!layout.c_prefer());
799        let layout = Layout::new(shape, [70, 14, 2], 0).unwrap();
800        assert!(!layout.c_prefer());
801        // zero dimension
802        let layout = Layout::new([], [], 0).unwrap();
803        assert!(layout.c_prefer());
804        // zero size
805        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
806        assert!(layout.c_prefer());
807        // shape with 1
808        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
809        assert!(layout.c_prefer());
810    }
811
812    #[test]
813    fn test_is_f_contig() {
814        // general case
815        let shape = [3, 5, 7];
816        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
817        assert!(layout.f_contig());
818        let layout = Layout::new(shape, [1, 4, 20], 0).unwrap();
819        assert!(!layout.f_contig());
820        // zero dimension
821        let layout = Layout::new([], [], 0).unwrap();
822        assert!(layout.f_contig());
823        // zero size
824        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
825        assert!(layout.f_contig());
826        // shape with 1
827        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
828        assert!(layout.f_contig());
829    }
830
831    #[test]
832    fn test_is_c_contig() {
833        // general case
834        let shape = [3, 5, 7];
835        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
836        assert!(layout.c_contig());
837        let layout = Layout::new(shape, [36, 7, 1], 0).unwrap();
838        assert!(!layout.c_contig());
839        // zero dimension
840        let layout = Layout::new([], [], 0).unwrap();
841        assert!(layout.c_contig());
842        // zero size
843        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
844        assert!(layout.c_contig());
845        // shape with 1
846        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
847        assert!(layout.c_contig());
848    }
849
850    #[test]
851    fn test_index() {
852        // a = np.arange(9 * 12 * 15)
853        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
854        //       .transpose(2, 0, 1)
855        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
856        assert_eq!(layout.index(&[0, 0, 0]), 782);
857        assert_eq!(layout.index(&[2, 1, 4]), 668);
858        assert_eq!(layout.index(&[1, -2, -3]), 830);
859        // zero-dim
860        let layout = Layout::new([], [], 10).unwrap();
861        assert_eq!(layout.index(&[]), 10);
862    }
863
864    #[test]
865    fn test_bounds_index() {
866        // a = np.arange(9 * 12 * 15)
867        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
868        //       .transpose(2, 0, 1)
869        // a.min() = 602, a.max() = 863
870        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
871        assert_eq!(layout.bounds_index().unwrap(), (602, 864));
872        // situation that fails
873        let layout = unsafe { Layout::new_unchecked([3, 2, 6], [3, -180, 15], 15) };
874        assert!(layout.bounds_index().is_err());
875        // zero-dim
876        let layout = Layout::new([], [], 10).unwrap();
877        assert_eq!(layout.bounds_index().unwrap(), (10, 11));
878    }
879
880    #[test]
881    fn test_transpose() {
882        // general
883        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
884        let trans = layout.transpose(&[2, 0, 1]).unwrap();
885        assert_eq!(trans.shape(), &[6, 3, 2]);
886        assert_eq!(trans.stride(), &[15, 3, -180]);
887        // permute_dims is alias of transpose
888        let trans = layout.permute_dims(&[2, 0, 1]).unwrap();
889        assert_eq!(trans.shape(), &[6, 3, 2]);
890        assert_eq!(trans.stride(), &[15, 3, -180]);
891        // negative axis also allowed
892        let trans = layout.transpose(&[-1, 0, 1]).unwrap();
893        assert_eq!(trans.shape(), &[6, 3, 2]);
894        assert_eq!(trans.stride(), &[15, 3, -180]);
895        // repeated axis
896        let trans = layout.transpose(&[-2, 0, 1]);
897        assert!(trans.is_err());
898        // non-valid dimension
899        let trans = layout.transpose(&[1, 0]);
900        assert!(trans.is_err());
901        // zero-dim
902        let layout = Layout::new([], [], 0).unwrap();
903        let trans = layout.transpose(&[]);
904        assert!(trans.is_ok());
905    }
906
907    #[test]
908    fn test_reverse_axes() {
909        // general
910        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
911        let trans = layout.reverse_axes();
912        assert_eq!(trans.shape(), &[6, 2, 3]);
913        assert_eq!(trans.stride(), &[15, -180, 3]);
914        // zero-dim
915        let layout = Layout::new([], [], 782).unwrap();
916        let trans = layout.reverse_axes();
917        assert_eq!(trans.shape(), &[]);
918        assert_eq!(trans.stride(), &[]);
919    }
920
921    #[test]
922    fn test_swapaxes() {
923        // general
924        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
925        let trans = layout.swapaxes(-1, -2).unwrap();
926        assert_eq!(trans.shape(), &[3, 6, 2]);
927        assert_eq!(trans.stride(), &[3, 15, -180]);
928        // same index is allowed
929        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
930        let trans = layout.swapaxes(-1, -1).unwrap();
931        assert_eq!(trans.shape(), &[3, 2, 6]);
932        assert_eq!(trans.stride(), &[3, -180, 15]);
933    }
934
935    #[test]
936    fn test_index_uncheck() {
937        // a = np.arange(9 * 12 * 15)
938        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
939        //       .transpose(2, 0, 1)
940        unsafe {
941            // fixed dim
942            let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
943            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
944            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
945            // dynamic dim
946            let layout = Layout::new(vec![3, 2, 6], vec![3, -180, 15], 782).unwrap();
947            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
948            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
949            // zero-dim
950            let layout = Layout::new([], [], 10).unwrap();
951            assert_eq!(layout.index_uncheck(&[]), 10);
952        }
953    }
954
955    #[test]
956    fn test_diagonal() {
957        let layout = [2, 3, 4].c();
958        let diag = layout.diagonal(None, None, None).unwrap();
959        assert_eq!(diag, Layout::new([4, 2], [1, 16], 0).unwrap());
960        let diag = layout.diagonal(Some(-1), Some(-2), Some(-1)).unwrap();
961        assert_eq!(diag, Layout::new([2, 2], [12, 5], 4).unwrap()); // fixed at issue 77
962        let diag = layout.diagonal(Some(-4), Some(-2), Some(-1)).unwrap();
963        assert_eq!(diag, Layout::new([2, 0], [12, 5], 0).unwrap());
964    }
965
966    #[test]
967    fn test_new_contig() {
968        let layout = [3, 2, 6].c();
969        assert_eq!(layout.shape(), &[3, 2, 6]);
970        assert_eq!(layout.stride(), &[12, 6, 1]);
971        let layout = [3, 2, 6].f();
972        assert_eq!(layout.shape(), &[3, 2, 6]);
973        assert_eq!(layout.stride(), &[1, 3, 6]);
974        // following code generates contiguous layout
975        // c/f-contig depends on cargo feature
976        let layout: Layout<_> = [3, 2, 6].into();
977        println!("{layout:?}");
978    }
979
980    #[test]
981    fn test_layout_cast() {
982        let layout = [3, 2, 6].c();
983        assert!(layout.clone().into_dim::<IxD>().is_ok());
984        assert!(layout.clone().into_dim::<Ix3>().is_ok());
985        let layout = vec![3, 2, 6].c();
986        assert!(layout.clone().into_dim::<IxD>().is_ok());
987        assert!(layout.clone().into_dim::<Ix3>().is_ok());
988        assert!(layout.clone().into_dim::<Ix2>().is_err());
989    }
990
991    #[test]
992    fn test_unravel_index() {
993        unsafe {
994            let shape = [3, 2, 6];
995            assert_eq!(shape.unravel_index_f(0), [0, 0, 0]);
996            assert_eq!(shape.unravel_index_f(16), [1, 1, 2]);
997            assert_eq!(shape.unravel_index_c(0), [0, 0, 0]);
998            assert_eq!(shape.unravel_index_c(16), [1, 0, 4]);
999        }
1000    }
1001
1002    #[test]
1003    fn fix_too_strict_stride_check() {
1004        let layout = [10, 11, 12].c();
1005        let slc = (.., slice!(-1, 0, -4));
1006        let slc: AxesIndex<Indexer> = slc.try_into().unwrap();
1007        let indexed = layout.dim_slice(slc.as_ref()).unwrap();
1008        assert_eq!(indexed.shape(), &[10, 3, 12]);
1009        assert_eq!(indexed.stride(), &[132, -48, 1]);
1010    }
1011}