ndarray_layout/transform/
tile.rs

1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4/// 分块变换参数。
5#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct TileArg<'a> {
7    /// 分块的轴。
8    pub axis: usize,
9    /// 分块的顺序。
10    pub endian: Endian,
11    /// 分块的大小。
12    pub tiles: &'a [usize],
13}
14
15impl<const N: usize> ArrayLayout<N> {
16    /// 分块变换是将单个维度划分为多个分块的变换。
17    /// 大端分块使得分块后范围更大的维度在形状中更靠前的位置。
18    ///
19    /// ```rust
20    /// # use ndarray_layout::ArrayLayout;
21    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]);
22    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
23    /// assert_eq!(layout.strides(), &[18, 6, 3, 1]);
24    /// assert_eq!(layout.offset(), 0);
25    /// ```
26    #[inline]
27    pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self {
28        self.tile_many(&[TileArg {
29            axis,
30            endian: Endian::BigEndian,
31            tiles,
32        }])
33    }
34
35    /// 分块变换是将单个维度划分为多个分块的变换。
36    /// 小端分块使得分块后范围更小的维度在形状中更靠前的位置。
37    ///
38    /// ```rust
39    /// # use ndarray_layout::ArrayLayout;
40    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]);
41    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
42    /// assert_eq!(layout.strides(), &[18, 6, 1, 2]);
43    /// assert_eq!(layout.offset(), 0);
44    /// ```
45    #[inline]
46    pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self {
47        self.tile_many(&[TileArg {
48            axis,
49            endian: Endian::LittleEndian,
50            tiles,
51        }])
52    }
53
54    /// 一次对多个阶进行分块变换。
55    pub fn tile_many(&self, mut args: &[TileArg]) -> Self {
56        let content = self.content();
57        let shape = content.shape();
58        let iter = zip(shape, content.strides()).enumerate();
59
60        let check = |&TileArg { axis, tiles, .. }| {
61            shape
62                .get(axis)
63                .filter(|&&d| d == tiles.iter().product())
64                .is_some()
65        };
66
67        let (mut new, mut last_axis) = match args {
68            [first, ..] => {
69                assert!(check(first));
70                (first.tiles.len(), first.axis)
71            }
72            [..] => return self.clone(),
73        };
74        for arg in &args[1..] {
75            assert!(check(arg));
76            assert!(arg.axis > last_axis);
77            new += arg.tiles.len();
78            last_axis = arg.axis;
79        }
80
81        let mut ans = Self::with_ndim(self.ndim + new - args.len());
82
83        let mut content = ans.content_mut();
84        content.set_offset(self.offset());
85        let mut j = 0;
86        let mut push = |t, s| {
87            content.set_shape(j, t);
88            content.set_stride(j, s);
89            j += 1;
90        };
91
92        for (i, (&d, &s)) in iter {
93            match *args {
94                [TileArg {
95                    axis,
96                    endian,
97                    tiles,
98                }, ref tail @ ..]
99                    if axis == i =>
100                {
101                    match endian {
102                        Endian::BigEndian => {
103                            // tile   : [a,         b    , c]
104                            // strides: [s * c * b, s * c, s]
105                            let mut s = s * d as isize;
106                            for &t in tiles {
107                                s /= t as isize;
108                                push(t, s);
109                            }
110                        }
111                        Endian::LittleEndian => {
112                            // tile   : [a, b    , c        ]
113                            // strides: [s, s * a, s * a * b]
114                            let mut s = s;
115                            for &t in tiles {
116                                push(t, s);
117                                s *= t as isize;
118                            }
119                        }
120                    }
121                    args = tail;
122                }
123                [..] => push(d, s),
124            }
125        }
126        ans
127    }
128}