rstsr_common/layout/
layoutbase.rs

1//! Layout of tensor.
2use crate::prelude_dev::*;
3use itertools::izip;
4
5/* #region Struct Definitions */
6
7/// Layout of tensor.
8///
9/// Layout is a struct that contains shape, stride, and offset of tensor.
10/// - Shape is the size of each dimension of tensor.
11/// - Stride is the number of elements to skip to get to the next element in
12///   each dimension.
13/// - Offset is the starting position of tensor.
14#[doc = include_str!("readme.md")]
15#[derive(Clone)]
16pub struct Layout<D>
17where
18    D: DimBaseAPI,
19{
20    // essential definitions to layout
21    pub(crate) shape: D,
22    pub(crate) stride: D::Stride,
23    pub(crate) offset: usize,
24}
25
26unsafe impl<D> Send for Layout<D> where D: DimBaseAPI {}
27unsafe impl<D> Sync for Layout<D> where D: DimBaseAPI {}
28
29/* #endregion */
30
31/* #region Layout */
32
33/// Getter/setter functions for layout.
34impl<D> Layout<D>
35where
36    D: DimBaseAPI,
37{
38    /// Shape of tensor. Getter function.
39    #[inline]
40    pub fn shape(&self) -> &D {
41        &self.shape
42    }
43
44    /// Stride of tensor. Getter function.
45    #[inline]
46    pub fn stride(&self) -> &D::Stride {
47        &self.stride
48    }
49
50    /// Starting offset of tensor. Getter function.
51    #[inline]
52    pub fn offset(&self) -> usize {
53        self.offset
54    }
55
56    /// Number of dimensions of tensor.
57    #[inline]
58    pub fn ndim(&self) -> usize {
59        self.shape.ndim()
60    }
61
62    /// Total number of elements in tensor.
63    ///
64    /// # Note
65    ///
66    /// This function uses cached size, instead of evaluating from shape.
67    #[inline]
68    pub fn size(&self) -> usize {
69        self.shape().as_ref().iter().product()
70    }
71
72    /// Manually set offset.
73    ///
74    /// # Safety
75    ///
76    /// We will not check whether this offset is valid or not.
77    /// In most cases, it is not intended to be used by user.
78    pub unsafe fn set_offset(&mut self, offset: usize) -> &mut Self {
79        self.offset = offset;
80        return self;
81    }
82}
83
84/// Properties of layout.
85impl<D> Layout<D>
86where
87    D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
88{
89    /// Whether this tensor is f-preferred.
90    pub fn f_prefer(&self) -> bool {
91        // always true for 0-dimension or 0-size tensor
92        if self.ndim() == 0 || self.size() == 0 {
93            return true;
94        }
95
96        let stride = self.stride.as_ref();
97        let shape = self.shape.as_ref();
98        let mut last = 0;
99        for (&s, &d) in stride.iter().zip(shape.iter()) {
100            if d != 1 {
101                if s < last {
102                    // latter strides must larger than previous strides
103                    return false;
104                }
105                if last == 0 && s != 1 {
106                    // first stride must be 1
107                    return false;
108                }
109                last = s;
110            } else if last == 0 {
111                // if dimension is one, then consider that stride is one, counted as contiguous
112                // in last dimension
113                last = 1;
114            }
115        }
116        return true;
117    }
118
119    /// Whether this tensor is c-preferred.
120    pub fn c_prefer(&self) -> bool {
121        // always true for 0-dimension or 0-size tensor
122        if self.ndim() == 0 || self.size() == 0 {
123            return true;
124        }
125
126        let stride = self.stride.as_ref();
127        let shape = self.shape.as_ref();
128        let mut last = 0;
129        for (&s, &d) in stride.iter().zip(shape.iter()).rev() {
130            if d != 1 {
131                if s < last {
132                    // previous strides must larger than latter strides
133                    return false;
134                }
135                if last == 0 && s != 1 {
136                    // last stride must be 1
137                    return false;
138                }
139                last = s;
140            } else if last == 0 {
141                // if dimension is one, then consider that stride is one, counted as contiguous
142                // in last dimension
143                last = 1;
144            }
145        }
146        return true;
147    }
148
149    /// Least number of dimensions that is f-contiguous for layout.
150    ///
151    /// This function can be useful determining when to iterate by contiguous,
152    /// and when to iterate by index.
153    pub fn ndim_of_f_contig(&self) -> usize {
154        if self.ndim() == 0 || self.size() == 0 {
155            return self.ndim();
156        }
157        let stride = self.stride.as_ref();
158        let shape = self.shape.as_ref();
159        let mut acc = 1;
160        for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).enumerate() {
161            if d != 1 && s != acc {
162                return ndim;
163            }
164            acc *= d as isize;
165        }
166        return self.ndim();
167    }
168
169    /// Least number of dimensions that is c-contiguous for layout.
170    ///
171    /// This function can be useful determining when to iterate by contiguous,
172    /// and when to iterate by index.
173    pub fn ndim_of_c_contig(&self) -> usize {
174        if self.ndim() == 0 || self.size() == 0 {
175            return self.ndim();
176        }
177        let stride = self.stride.as_ref();
178        let shape = self.shape.as_ref();
179        let mut acc = 1;
180        for (ndim, (&s, &d)) in stride.iter().zip(shape.iter()).rev().enumerate() {
181            if d != 1 && s != acc {
182                return ndim;
183            }
184            acc *= d as isize;
185        }
186        return self.ndim();
187    }
188
189    /// Whether this tensor is f-contiguous.
190    ///
191    /// Special cases
192    /// - When length of a dimension is one, then stride to that dimension is
193    ///   not important.
194    /// - When length of a dimension is zero, then tensor contains no elements,
195    ///   thus f-contiguous.
196    pub fn f_contig(&self) -> bool {
197        self.ndim() == self.ndim_of_f_contig()
198    }
199
200    /// Whether this tensor is c-contiguous.
201    ///
202    /// Special cases
203    /// - When length of a dimension is one, then stride to that dimension is
204    ///   not important.
205    /// - When length of a dimension is zero, then tensor contains no elements,
206    ///   thus c-contiguous.
207    pub fn c_contig(&self) -> bool {
208        self.ndim() == self.ndim_of_c_contig()
209    }
210
211    /// Index of tensor by list of indexes to dimensions.
212    ///
213    /// This function does not optimized for performance.
214    pub fn index_f(&self, index: &[isize]) -> Result<usize> {
215        rstsr_assert_eq!(index.len(), self.ndim(), InvalidLayout)?;
216        let mut pos = self.offset() as isize;
217        let shape = self.shape.as_ref();
218        let stride = self.stride.as_ref();
219
220        for (&idx, &shp, &strd) in izip!(index.iter(), shape.iter(), stride.iter()) {
221            let idx = if idx < 0 { idx + shp as isize } else { idx };
222            rstsr_pattern!(idx, 0..(shp as isize), ValueOutOfRange)?;
223            pos += strd * idx;
224        }
225        rstsr_pattern!(pos, 0.., ValueOutOfRange)?;
226        return Ok(pos as usize);
227    }
228
229    /// Index of tensor by list of indexes to dimensions.
230    ///
231    /// This function does not optimized for performance. Negative index
232    /// allowed.
233    ///
234    /// # Panics
235    ///
236    /// - Index greater than shape
237    pub fn index(&self, index: &[isize]) -> usize {
238        self.index_f(index).unwrap()
239    }
240
241    /// Index range bounds of current layout. This bound is [min, max), which
242    /// could be feed into range (min..max). If min == max, then this layout
243    /// should not contains any element.
244    ///
245    /// This function will raise error when minimum index is smaller than zero.
246    pub fn bounds_index(&self) -> Result<(usize, usize)> {
247        let n = self.ndim();
248        let offset = self.offset;
249        let shape = self.shape.as_ref();
250        let stride = self.stride.as_ref();
251
252        if n == 0 {
253            return Ok((offset, offset + 1));
254        }
255
256        let mut min = offset as isize;
257        let mut max = offset as isize;
258
259        for i in 0..n {
260            if shape[i] == 0 {
261                return Ok((offset, offset));
262            }
263            if stride[i] > 0 {
264                max += stride[i] * (shape[i] as isize - 1);
265            } else {
266                min += stride[i] * (shape[i] as isize - 1);
267            }
268        }
269        rstsr_pattern!(min, 0.., ValueOutOfRange)?;
270        return Ok((min as usize, max as usize + 1));
271    }
272
273    /// Check if strides is correct (no elemenets can overlap).
274    ///
275    /// This will check if all number of elements in dimension of small strides
276    /// is less than larger strides. For example of valid stride:
277    /// ```output
278    /// shape:  (3,    2,  6)  -> sorted ->  ( 3,   6,   2)
279    /// stride: (3, -300, 15)  -> sorted ->  ( 3,  15, 300)
280    /// number of elements:                    9,  90,
281    /// stride of next dimension              15, 300,
282    /// number of elem < stride of next dim?   +,   +,
283    /// ```
284    ///
285    /// Special cases
286    /// - if length of tensor is zero, then strides will always be correct.
287    /// - if certain dimension is one, then check for this stride will be
288    ///   ignored.
289    ///
290    /// # TODO
291    ///
292    /// Correctness of this function is not fully ensured.
293    pub fn check_strides(&self) -> Result<()> {
294        let shape = self.shape.as_ref();
295        let stride = self.stride.as_ref();
296        rstsr_assert_eq!(shape.len(), stride.len(), InvalidLayout)?;
297        let n = shape.len();
298
299        // unconditionally ok if no elements (length of tensor is zero)
300        // unconditionally ok if 0-dimension
301        if self.size() == 0 || n == 0 {
302            return Ok(());
303        }
304
305        let mut indices = (0..n).filter(|&k| shape[k] > 1).collect::<Vec<_>>();
306        indices.sort_by_key(|&k| stride[k].abs());
307        let shape_sorted = indices.iter().map(|&k| shape[k]).collect::<Vec<_>>();
308        let stride_sorted = indices.iter().map(|&k| stride[k].unsigned_abs()).collect::<Vec<_>>();
309
310        // elem_cum: cumulative number count of elements in tensor for small strides
311        let mut elem_cum = 0;
312        for i in 0..indices.len() {
313            // following function also checks that stride could not be zero
314            rstsr_pattern!(
315                elem_cum,
316                0..stride_sorted[i],
317                InvalidLayout,
318                "Either stride be zero, or stride too small that elements in tensor can be overlapped."
319            )?;
320            elem_cum += (shape_sorted[i] - 1) * stride_sorted[i];
321        }
322        return Ok(());
323    }
324
325    pub fn diagonal(
326        &self,
327        offset: Option<isize>,
328        axis1: Option<isize>,
329        axis2: Option<isize>,
330    ) -> Result<Layout<<D as DimSmallerOneAPI>::SmallerOne>>
331    where
332        D: DimSmallerOneAPI,
333    {
334        // check if this layout is at least 2-dimension
335        rstsr_assert!(self.ndim() >= 2, InvalidLayout)?;
336        // unwrap optional parameters
337        let offset = offset.unwrap_or(0);
338        let axis1 = axis1.unwrap_or(0);
339        let axis2 = axis2.unwrap_or(1);
340        let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
341        let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
342        rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
343        rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
344        let axis1 = axis1 as usize;
345        let axis2 = axis2 as usize;
346
347        // shape and strides of last two dimensions
348        let d1 = self.shape()[axis1] as isize;
349        let d2 = self.shape()[axis2] as isize;
350        let t1 = self.stride()[axis1];
351        let t2 = self.stride()[axis2];
352
353        // number of elements in diagonal, and starting offset
354        let (offset_diag, d_diag) = if (-d2 + 1..0).contains(&offset) {
355            let offset = -offset;
356            let offset_diag = (self.offset() as isize + t1 * offset) as usize;
357            let d_diag = (d1 - offset).min(d2) as usize;
358            (offset_diag, d_diag)
359        } else if (0..d1).contains(&offset) {
360            let offset_diag = (self.offset() as isize + t2 * offset) as usize;
361            let d_diag = (d2 - offset).min(d1) as usize;
362            (offset_diag, d_diag)
363        } else {
364            (self.offset(), 0)
365        };
366
367        // build new layout
368        let t_diag = t1 + t2;
369        let mut shape_diag = vec![];
370        let mut stride_diag = vec![];
371        for i in 0..self.ndim() {
372            if i != axis1 && i != axis2 {
373                shape_diag.push(self.shape()[i]);
374                stride_diag.push(self.stride()[i]);
375            }
376        }
377        shape_diag.push(d_diag);
378        stride_diag.push(t_diag);
379        let layout_diag = Layout::new(shape_diag, stride_diag, offset_diag)?;
380        return layout_diag.into_dim::<<D as DimSmallerOneAPI>::SmallerOne>();
381    }
382}
383
384/// Constructors of layout. See also [`DimLayoutContigAPI`] layout from shape
385/// directly.
386impl<D> Layout<D>
387where
388    D: DimBaseAPI,
389{
390    /// Generate new layout by providing everything.
391    ///
392    /// # Error when
393    ///
394    /// - Shape and stride length mismatch
395    /// - Strides is correct (no elements can overlap)
396    /// - Minimum bound is not negative
397    #[inline]
398    pub fn new(shape: D, stride: D::Stride, offset: usize) -> Result<Self>
399    where
400        D: DimShapeAPI + DimStrideAPI,
401    {
402        let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
403        layout.bounds_index()?;
404        layout.check_strides()?;
405        return Ok(layout);
406    }
407
408    /// Generate new layout by providing everything, without checking bounds and
409    /// strides.
410    ///
411    /// # Safety
412    ///
413    /// This function does not check whether layout is valid.
414    #[inline]
415    pub unsafe fn new_unchecked(shape: D, stride: D::Stride, offset: usize) -> Self {
416        Layout { shape, stride, offset }
417    }
418
419    /// New zero shape, which number of dimensions are the same to current
420    /// layout.
421    #[inline]
422    pub fn new_shape(&self) -> D {
423        self.shape.new_shape()
424    }
425
426    /// New zero stride, which number of dimensions are the same to current
427    /// layout.
428    #[inline]
429    pub fn new_stride(&self) -> D::Stride {
430        self.shape.new_stride()
431    }
432}
433
434/// Manuplation of layout.
435impl<D> Layout<D>
436where
437    D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
438{
439    /// Transpose layout by permutation.
440    ///
441    /// # See also
442    ///
443    /// - [`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html)
444    /// - [Python array API: `permute_dims`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.permute_dims.html)
445    pub fn transpose(&self, axes: &[isize]) -> Result<Self> {
446        // check axes and cast to usize
447        let n = self.ndim();
448        rstsr_assert_eq!(
449            axes.len(),
450            n,
451            InvalidLayout,
452            "number of elements in axes should be the same to number of dimensions."
453        )?;
454        // no elements in axes can be the same
455        let mut permut_used = vec![false; n];
456        for &p in axes {
457            let p = if p < 0 { p + n as isize } else { p };
458            rstsr_pattern!(p, 0..n as isize, InvalidLayout)?;
459            let p = p as usize;
460            permut_used[p] = true;
461        }
462        rstsr_assert!(
463            permut_used.iter().all(|&b| b),
464            InvalidLayout,
465            "axes should contain all elements from 0 to n-1."
466        )?;
467        let axes = axes
468            .iter()
469            .map(|&p| if p < 0 { p + n as isize } else { p } as usize)
470            .collect::<Vec<_>>();
471
472        let shape_old = self.shape();
473        let stride_old = self.stride();
474        let mut shape = self.new_shape();
475        let mut stride = self.new_stride();
476        for i in 0..self.ndim() {
477            shape[i] = shape_old[axes[i]];
478            stride[i] = stride_old[axes[i]];
479        }
480        return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
481    }
482
483    /// Transpose layout by permutation.
484    ///
485    /// This is the same function to [`Layout::transpose`]
486    pub fn permute_dims(&self, axes: &[isize]) -> Result<Self> {
487        self.transpose(axes)
488    }
489
490    /// Reverse axes of layout.
491    pub fn reverse_axes(&self) -> Self {
492        let shape_old = self.shape();
493        let stride_old = self.stride();
494        let mut shape = self.new_shape();
495        let mut stride = self.new_stride();
496        for i in 0..self.ndim() {
497            shape[i] = shape_old[self.ndim() - i - 1];
498            stride[i] = stride_old[self.ndim() - i - 1];
499        }
500        return unsafe { Layout::new_unchecked(shape, stride, self.offset) };
501    }
502
503    /// Swap axes of layout.
504    pub fn swapaxes(&self, axis1: isize, axis2: isize) -> Result<Self> {
505        let axis1 = if axis1 < 0 { self.ndim() as isize + axis1 } else { axis1 };
506        rstsr_pattern!(axis1, 0..self.ndim() as isize, ValueOutOfRange)?;
507        let axis1 = axis1 as usize;
508
509        let axis2 = if axis2 < 0 { self.ndim() as isize + axis2 } else { axis2 };
510        rstsr_pattern!(axis2, 0..self.ndim() as isize, ValueOutOfRange)?;
511        let axis2 = axis2 as usize;
512
513        let mut shape = self.shape().clone();
514        let mut stride = self.stride().clone();
515        shape.as_mut().swap(axis1, axis2);
516        stride.as_mut().swap(axis1, axis2);
517        return unsafe { Ok(Layout::new_unchecked(shape, stride, self.offset)) };
518    }
519}
520
521/// Fast indexing and utilities of layout.
522///
523/// These functions are mostly internal to this crate.
524impl<D> Layout<D>
525where
526    D: DimBaseAPI + DimShapeAPI + DimStrideAPI,
527{
528    /// Index of tensor by list of indexes to dimensions.
529    ///
530    /// # Safety
531    ///
532    /// This function does not check for bounds, including
533    /// - Negative index
534    /// - Index greater than shape
535    ///
536    /// Due to these reasons, this function may well give index smaller than
537    /// zero, which may occur in iterator; so this function returns isize.
538    #[inline]
539    pub unsafe fn index_uncheck(&self, index: &[usize]) -> isize {
540        let stride = self.stride.as_ref();
541        match self.ndim() {
542            0 => self.offset as isize,
543            1 => self.offset as isize + stride[0] * index[0] as isize,
544            2 => {
545                self.offset as isize + stride[0] * index[0] as isize + stride[1] * index[1] as isize
546            },
547            3 => {
548                self.offset as isize
549                    + stride[0] * index[0] as isize
550                    + stride[1] * index[1] as isize
551                    + stride[2] * index[2] as isize
552            },
553            4 => {
554                self.offset as isize
555                    + stride[0] * index[0] as isize
556                    + stride[1] * index[1] as isize
557                    + stride[2] * index[2] as isize
558                    + stride[3] * index[3] as isize
559            },
560            _ => {
561                let mut pos = self.offset as isize;
562                stride.iter().zip(index.iter()).for_each(|(&s, &i)| pos += s * i as isize);
563                pos
564            },
565        }
566    }
567}
568
569impl<D> PartialEq for Layout<D>
570where
571    D: DimBaseAPI,
572{
573    /// For layout, shape must be the same, while stride should be the same when
574    /// shape is not zero or one, but can be arbitary otherwise.
575    fn eq(&self, other: &Self) -> bool {
576        if self.ndim() != other.ndim() {
577            return false;
578        }
579        for i in 0..self.ndim() {
580            let s1 = self.shape()[i];
581            let s2 = other.shape()[i];
582            if s1 != s2 {
583                return false;
584            }
585            if s1 != 1 && s1 != 0 && self.stride()[i] != other.stride()[i] {
586                return false;
587            }
588        }
589        return true;
590    }
591}
592
593pub trait DimLayoutContigAPI: DimBaseAPI + DimShapeAPI + DimStrideAPI {
594    /// Generate new layout by providing shape and offset; stride fits into
595    /// c-contiguous.
596    fn new_c_contig(&self, offset: Option<usize>) -> Layout<Self> {
597        let shape = self.clone();
598        let stride = shape.stride_c_contig();
599        unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
600    }
601
602    /// Generate new layout by providing shape and offset; stride fits into
603    /// f-contiguous.
604    fn new_f_contig(&self, offset: Option<usize>) -> Layout<Self> {
605        let shape = self.clone();
606        let stride = shape.stride_f_contig();
607        unsafe { Layout::new_unchecked(shape, stride, offset.unwrap_or(0)) }
608    }
609
610    /// Simplified function to generate c-contiguous layout. See also
611    /// [DimLayoutContigAPI::new_c_contig].
612    fn c(&self) -> Layout<Self> {
613        self.new_c_contig(None)
614    }
615
616    /// Simplified function to generate f-contiguous layout. See also
617    /// [DimLayoutContigAPI::new_f_contig].
618    fn f(&self) -> Layout<Self> {
619        self.new_f_contig(None)
620    }
621
622    /// Generate new layout by providing shape, offset and order.
623    fn new_contig(&self, offset: Option<usize>, order: FlagOrder) -> Layout<Self> {
624        match order {
625            FlagOrder::C => self.new_c_contig(offset),
626            FlagOrder::F => self.new_f_contig(offset),
627        }
628    }
629}
630
631impl<const N: usize> DimLayoutContigAPI for Ix<N> {}
632impl DimLayoutContigAPI for IxD {}
633
634/* #endregion Layout */
635
636/* #region Dimension Conversion */
637
638pub trait DimIntoAPI<D>: DimBaseAPI
639where
640    D: DimBaseAPI,
641{
642    fn into_dim(layout: Layout<Self>) -> Result<Layout<D>>;
643}
644
645impl<D> DimIntoAPI<D> for IxD
646where
647    D: DimBaseAPI,
648{
649    fn into_dim(layout: Layout<IxD>) -> Result<Layout<D>> {
650        let shape = layout.shape().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
651        let stride = layout.stride().clone().try_into().map_err(|_| rstsr_error!(InvalidLayout))?;
652        let offset = layout.offset();
653        return Ok(Layout { shape, stride, offset });
654    }
655}
656
657impl<const N: usize> DimIntoAPI<IxD> for Ix<N> {
658    fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<IxD>> {
659        let shape = (*layout.shape()).into();
660        let stride = (*layout.stride()).into();
661        let offset = layout.offset();
662        return Ok(Layout { shape, stride, offset });
663    }
664}
665
666impl<const N: usize, const M: usize> DimIntoAPI<Ix<M>> for Ix<N> {
667    fn into_dim(layout: Layout<Ix<N>>) -> Result<Layout<Ix<M>>> {
668        rstsr_assert_eq!(N, M, InvalidLayout)?;
669        let shape = layout.shape().to_vec().try_into().unwrap();
670        let stride = layout.stride().to_vec().try_into().unwrap();
671        let offset = layout.offset();
672        return Ok(Layout { shape, stride, offset });
673    }
674}
675
676impl<D> Layout<D>
677where
678    D: DimBaseAPI,
679{
680    /// Convert layout to another dimension.
681    pub fn into_dim<D2>(self) -> Result<Layout<D2>>
682    where
683        D2: DimBaseAPI,
684        D: DimIntoAPI<D2>,
685    {
686        D::into_dim(self)
687    }
688
689    /// Convert layout to another dimension.
690    pub fn to_dim<D2>(&self) -> Result<Layout<D2>>
691    where
692        D2: DimBaseAPI,
693        D: DimIntoAPI<D2>,
694    {
695        D::into_dim(self.clone())
696    }
697}
698
699impl<const N: usize> From<Ix<N>> for Layout<Ix<N>> {
700    fn from(shape: Ix<N>) -> Self {
701        let stride = shape.stride_contig();
702        Layout { shape, stride, offset: 0 }
703    }
704}
705
706impl From<IxD> for Layout<IxD> {
707    fn from(shape: IxD) -> Self {
708        let stride = shape.stride_contig();
709        Layout { shape, stride, offset: 0 }
710    }
711}
712
713/* #endregion */
714
715#[cfg(test)]
716mod test {
717    use std::panic::catch_unwind;
718
719    use super::*;
720
721    #[test]
722    fn test_layout_new() {
723        // a successful layout new
724        let shape = [3, 2, 6];
725        let stride = [3, -300, 15];
726        let layout = Layout::new(shape, stride, 917).unwrap();
727        assert_eq!(layout.shape(), &[3, 2, 6]);
728        assert_eq!(layout.stride(), &[3, -300, 15]);
729        assert_eq!(layout.offset(), 917);
730        assert_eq!(layout.ndim(), 3);
731        // unsuccessful layout new (offset underflow)
732        let shape = [3, 2, 6];
733        let stride = [3, -300, 15];
734        let layout = Layout::new(shape, stride, 0);
735        assert!(layout.is_err());
736        // unsuccessful layout new (zero stride for non-0/1 shape)
737        let shape = [3, 2, 6];
738        let stride = [3, -300, 0];
739        let layout = Layout::new(shape, stride, 1000);
740        assert!(layout.is_err());
741        // unsuccessful layout new (stride too small)
742        let shape = [3, 2, 6];
743        let stride = [3, 4, 7];
744        let layout = Layout::new(shape, stride, 1000);
745        assert!(layout.is_err());
746        // successful layout new (zero dim)
747        let shape = [];
748        let stride = [];
749        let layout = Layout::new(shape, stride, 1000);
750        assert!(layout.is_ok());
751        // successful layout new (stride 0 for 1-shape)
752        let shape = [3, 1, 5];
753        let stride = [1, 0, 15];
754        let layout = Layout::new(shape, stride, 1);
755        assert!(layout.is_ok());
756        // successful layout new (stride 0 for 1-shape)
757        let shape = [3, 1, 5];
758        let stride = [1, 0, 15];
759        let layout = Layout::new(shape, stride, 1);
760        assert!(layout.is_ok());
761        // successful layout new (zero-size tensor)
762        let shape = [3, 0, 5];
763        let stride = [-1, -2, -3];
764        let layout = Layout::new(shape, stride, 1);
765        assert!(layout.is_ok());
766        // anyway, if one need custom layout, use new_unchecked
767        let shape = [3, 2, 6];
768        let stride = [3, -300, 0];
769        let r = catch_unwind(|| unsafe { Layout::new_unchecked(shape, stride, 1000) });
770        assert!(r.is_ok());
771    }
772
773    #[test]
774    fn test_is_f_prefer() {
775        // general case
776        let shape = [3, 5, 7];
777        let layout = Layout::new(shape, [1, 10, 100], 0).unwrap();
778        assert!(layout.f_prefer());
779        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
780        assert!(layout.f_prefer());
781        let layout = Layout::new(shape, [1, 3, -15], 1000).unwrap();
782        assert!(!layout.f_prefer());
783        let layout = Layout::new(shape, [1, 21, 3], 0).unwrap();
784        assert!(!layout.f_prefer());
785        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
786        assert!(!layout.f_prefer());
787        let layout = Layout::new(shape, [2, 6, 30], 0).unwrap();
788        assert!(!layout.f_prefer());
789        // zero dimension
790        let layout = Layout::new([], [], 0).unwrap();
791        assert!(layout.f_prefer());
792        // zero size
793        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
794        assert!(layout.f_prefer());
795        // shape with 1
796        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
797        assert!(layout.f_prefer());
798    }
799
800    #[test]
801    fn test_is_c_prefer() {
802        // general case
803        let shape = [3, 5, 7];
804        let layout = Layout::new(shape, [100, 10, 1], 0).unwrap();
805        assert!(layout.c_prefer());
806        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
807        assert!(layout.c_prefer());
808        let layout = Layout::new(shape, [-35, 7, 1], 1000).unwrap();
809        assert!(!layout.c_prefer());
810        let layout = Layout::new(shape, [7, 21, 1], 0).unwrap();
811        assert!(!layout.c_prefer());
812        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
813        assert!(!layout.c_prefer());
814        let layout = Layout::new(shape, [70, 14, 2], 0).unwrap();
815        assert!(!layout.c_prefer());
816        // zero dimension
817        let layout = Layout::new([], [], 0).unwrap();
818        assert!(layout.c_prefer());
819        // zero size
820        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
821        assert!(layout.c_prefer());
822        // shape with 1
823        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
824        assert!(layout.c_prefer());
825    }
826
827    #[test]
828    fn test_is_f_contig() {
829        // general case
830        let shape = [3, 5, 7];
831        let layout = Layout::new(shape, [1, 3, 15], 0).unwrap();
832        assert!(layout.f_contig());
833        let layout = Layout::new(shape, [1, 4, 20], 0).unwrap();
834        assert!(!layout.f_contig());
835        // zero dimension
836        let layout = Layout::new([], [], 0).unwrap();
837        assert!(layout.f_contig());
838        // zero size
839        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
840        assert!(layout.f_contig());
841        // shape with 1
842        let layout = Layout::new([2, 1, 4], [1, 1, 2], 0).unwrap();
843        assert!(layout.f_contig());
844    }
845
846    #[test]
847    fn test_is_c_contig() {
848        // general case
849        let shape = [3, 5, 7];
850        let layout = Layout::new(shape, [35, 7, 1], 0).unwrap();
851        assert!(layout.c_contig());
852        let layout = Layout::new(shape, [36, 7, 1], 0).unwrap();
853        assert!(!layout.c_contig());
854        // zero dimension
855        let layout = Layout::new([], [], 0).unwrap();
856        assert!(layout.c_contig());
857        // zero size
858        let layout = Layout::new([2, 0, 4], [1, 10, 100], 0).unwrap();
859        assert!(layout.c_contig());
860        // shape with 1
861        let layout = Layout::new([2, 1, 4], [4, 1, 1], 0).unwrap();
862        assert!(layout.c_contig());
863    }
864
865    #[test]
866    fn test_index() {
867        // a = np.arange(9 * 12 * 15)
868        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
869        //       .transpose(2, 0, 1)
870        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
871        assert_eq!(layout.index(&[0, 0, 0]), 782);
872        assert_eq!(layout.index(&[2, 1, 4]), 668);
873        assert_eq!(layout.index(&[1, -2, -3]), 830);
874        // zero-dim
875        let layout = Layout::new([], [], 10).unwrap();
876        assert_eq!(layout.index(&[]), 10);
877    }
878
879    #[test]
880    fn test_bounds_index() {
881        // a = np.arange(9 * 12 * 15)
882        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
883        //       .transpose(2, 0, 1)
884        // a.min() = 602, a.max() = 863
885        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
886        assert_eq!(layout.bounds_index().unwrap(), (602, 864));
887        // situation that fails
888        let layout = unsafe { Layout::new_unchecked([3, 2, 6], [3, -180, 15], 15) };
889        assert!(layout.bounds_index().is_err());
890        // zero-dim
891        let layout = Layout::new([], [], 10).unwrap();
892        assert_eq!(layout.bounds_index().unwrap(), (10, 11));
893    }
894
895    #[test]
896    fn test_transpose() {
897        // general
898        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
899        let trans = layout.transpose(&[2, 0, 1]).unwrap();
900        assert_eq!(trans.shape(), &[6, 3, 2]);
901        assert_eq!(trans.stride(), &[15, 3, -180]);
902        // permute_dims is alias of transpose
903        let trans = layout.permute_dims(&[2, 0, 1]).unwrap();
904        assert_eq!(trans.shape(), &[6, 3, 2]);
905        assert_eq!(trans.stride(), &[15, 3, -180]);
906        // negative axis also allowed
907        let trans = layout.transpose(&[-1, 0, 1]).unwrap();
908        assert_eq!(trans.shape(), &[6, 3, 2]);
909        assert_eq!(trans.stride(), &[15, 3, -180]);
910        // repeated axis
911        let trans = layout.transpose(&[-2, 0, 1]);
912        assert!(trans.is_err());
913        // non-valid dimension
914        let trans = layout.transpose(&[1, 0]);
915        assert!(trans.is_err());
916        // zero-dim
917        let layout = Layout::new([], [], 0).unwrap();
918        let trans = layout.transpose(&[]);
919        assert!(trans.is_ok());
920    }
921
922    #[test]
923    fn test_reverse_axes() {
924        // general
925        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
926        let trans = layout.reverse_axes();
927        assert_eq!(trans.shape(), &[6, 2, 3]);
928        assert_eq!(trans.stride(), &[15, -180, 3]);
929        // zero-dim
930        let layout = Layout::new([], [], 782).unwrap();
931        let trans = layout.reverse_axes();
932        assert_eq!(trans.shape(), &[]);
933        assert_eq!(trans.stride(), &[]);
934    }
935
936    #[test]
937    fn test_swapaxes() {
938        // general
939        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
940        let trans = layout.swapaxes(-1, -2).unwrap();
941        assert_eq!(trans.shape(), &[3, 6, 2]);
942        assert_eq!(trans.stride(), &[3, 15, -180]);
943        // same index is allowed
944        let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
945        let trans = layout.swapaxes(-1, -1).unwrap();
946        assert_eq!(trans.shape(), &[3, 2, 6]);
947        assert_eq!(trans.stride(), &[3, -180, 15]);
948    }
949
950    #[test]
951    fn test_index_uncheck() {
952        // a = np.arange(9 * 12 * 15)
953        //       .reshape(9, 12, 15)[4:2:-1, 4:10, 2:10:3]
954        //       .transpose(2, 0, 1)
955        unsafe {
956            // fixed dim
957            let layout = Layout::new([3, 2, 6], [3, -180, 15], 782).unwrap();
958            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
959            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
960            // dynamic dim
961            let layout = Layout::new(vec![3, 2, 6], vec![3, -180, 15], 782).unwrap();
962            assert_eq!(layout.index_uncheck(&[0, 0, 0]), 782);
963            assert_eq!(layout.index_uncheck(&[2, 1, 4]), 668);
964            // zero-dim
965            let layout = Layout::new([], [], 10).unwrap();
966            assert_eq!(layout.index_uncheck(&[]), 10);
967        }
968    }
969
970    #[test]
971    fn test_diagonal() {
972        let layout = [2, 3, 4].c();
973        let diag = layout.diagonal(None, None, None).unwrap();
974        assert_eq!(diag, Layout::new([4, 2], [1, 16], 0).unwrap());
975        let diag = layout.diagonal(Some(-1), Some(-2), Some(-1)).unwrap();
976        assert_eq!(diag, Layout::new([2, 2], [12, 5], 0).unwrap());
977        let diag = layout.diagonal(Some(-4), Some(-2), Some(-1)).unwrap();
978        assert_eq!(diag, Layout::new([2, 0], [12, 5], 0).unwrap());
979    }
980
981    #[test]
982    fn test_new_contig() {
983        let layout = [3, 2, 6].c();
984        assert_eq!(layout.shape(), &[3, 2, 6]);
985        assert_eq!(layout.stride(), &[12, 6, 1]);
986        let layout = [3, 2, 6].f();
987        assert_eq!(layout.shape(), &[3, 2, 6]);
988        assert_eq!(layout.stride(), &[1, 3, 6]);
989        // following code generates contiguous layout
990        // c/f-contig depends on cargo feature
991        let layout: Layout<_> = [3, 2, 6].into();
992        println!("{layout:?}");
993    }
994
995    #[test]
996    fn test_layout_cast() {
997        let layout = [3, 2, 6].c();
998        assert!(layout.clone().into_dim::<IxD>().is_ok());
999        assert!(layout.clone().into_dim::<Ix3>().is_ok());
1000        let layout = vec![3, 2, 6].c();
1001        assert!(layout.clone().into_dim::<IxD>().is_ok());
1002        assert!(layout.clone().into_dim::<Ix3>().is_ok());
1003        assert!(layout.clone().into_dim::<Ix2>().is_err());
1004    }
1005
1006    #[test]
1007    fn test_unravel_index() {
1008        unsafe {
1009            let shape = [3, 2, 6];
1010            assert_eq!(shape.unravel_index_f(0), [0, 0, 0]);
1011            assert_eq!(shape.unravel_index_f(16), [1, 1, 2]);
1012            assert_eq!(shape.unravel_index_c(0), [0, 0, 0]);
1013            assert_eq!(shape.unravel_index_c(16), [1, 0, 4]);
1014        }
1015    }
1016
1017    #[test]
1018    fn fix_too_strict_stride_check() {
1019        let layout = [10, 11, 12].c();
1020        let slc = (.., slice!(-1, 0, -4));
1021        let slc: AxesIndex<Indexer> = slc.try_into().unwrap();
1022        let indexed = layout.dim_slice(slc.as_ref()).unwrap();
1023        assert_eq!(indexed.shape(), &[10, 3, 12]);
1024        assert_eq!(indexed.stride(), &[132, -48, 1]);
1025    }
1026}