ndarray_layout/transform/
tile.rs

1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4/// 分块变换参数。
5#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct TileArg<'a> {
7    /// 分块的轴。
8    pub axis: usize,
9    /// 分块的顺序。
10    pub endian: Endian,
11    /// 分块的大小。
12    pub tiles: &'a [usize],
13}
14
15impl<const N: usize> ArrayLayout<N> {
16    /// 分块变换是将单个维度划分为多个分块的变换。
17    /// 大端分块使得分块后范围更大的维度在形状中更靠前的位置。
18    ///
19    /// ```rust
20    /// # use ndarray_layout::ArrayLayout;
21    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]);
22    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
23    /// assert_eq!(layout.strides(), &[18, 6, 3, 1]);
24    /// assert_eq!(layout.offset(), 0);
25    /// ```
26    #[inline]
27    pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self {
28        self.tile_many(&[TileArg {
29            axis,
30            endian: Endian::BigEndian,
31            tiles,
32        }])
33    }
34
35    /// 分块变换是将单个维度划分为多个分块的变换。
36    /// 小端分块使得分块后范围更小的维度在形状中更靠前的位置。
37    ///
38    /// ```rust
39    /// # use ndarray_layout::ArrayLayout;
40    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]);
41    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
42    /// assert_eq!(layout.strides(), &[18, 6, 1, 2]);
43    /// assert_eq!(layout.offset(), 0);
44    /// ```
45    #[inline]
46    pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self {
47        self.tile_many(&[TileArg {
48            axis,
49            endian: Endian::LittleEndian,
50            tiles,
51        }])
52    }
53
54    /// 一次对多个阶进行分块变换。
55    pub fn tile_many(&self, mut args: &[TileArg]) -> Self {
56        let content = self.content();
57        let shape = content.shape();
58        let iter = zip(shape, content.strides()).enumerate();
59
60        let check = |&TileArg { axis, tiles, .. }| {
61            shape
62                .get(axis)
63                .filter(|&&d| d == tiles.iter().product())
64                .is_some()
65        };
66
67        let (mut new, mut last_axis) = match args {
68            [first, ..] => {
69                assert!(check(first));
70                (first.tiles.len(), first.axis)
71            }
72            [..] => return self.clone(),
73        };
74        for arg in &args[1..] {
75            assert!(check(arg));
76            assert!(arg.axis > last_axis);
77            new += arg.tiles.len();
78            last_axis = arg.axis;
79        }
80
81        let mut ans = Self::with_ndim(self.ndim + new - args.len());
82
83        let mut content = ans.content_mut();
84        content.set_offset(self.offset());
85        let mut j = 0;
86        let mut push = |t, s| {
87            content.set_shape(j, t);
88            content.set_stride(j, s);
89            j += 1;
90        };
91
92        for (i, (&d, &s)) in iter {
93            match *args {
94                [
95                    TileArg {
96                        axis,
97                        endian,
98                        tiles,
99                    },
100                    ref tail @ ..,
101                ] if axis == i => {
102                    match endian {
103                        Endian::BigEndian => {
104                            // tile   : [a,         b    , c]
105                            // strides: [s * c * b, s * c, s]
106                            let mut s = s * d as isize;
107                            for &t in tiles {
108                                s /= t as isize;
109                                push(t, s);
110                            }
111                        }
112                        Endian::LittleEndian => {
113                            // tile   : [a, b    , c        ]
114                            // strides: [s, s * a, s * a * b]
115                            let mut s = s;
116                            for &t in tiles {
117                                push(t, s);
118                                s *= t as isize;
119                            }
120                        }
121                    }
122                    args = tail;
123                }
124                [..] => push(d, s),
125            }
126        }
127        ans
128    }
129}
130
#[test]
fn test_tile_be() {
    // Split the last axis (length 6) into big-endian tiles [2, 3].
    let source = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = source.tile_be(2, &[2, 3]);

    assert_eq!(tiled.shape(), &[2, 3, 2, 3]);
    assert_eq!(tiled.strides(), &[18, 6, 3, 1]);
    assert_eq!(tiled.offset(), 0);
}
138
#[test]
fn test_tile_le() {
    // Split the last axis (length 6) into little-endian tiles [2, 3].
    let source = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = source.tile_le(2, &[2, 3]);

    assert_eq!(tiled.shape(), &[2, 3, 2, 3]);
    assert_eq!(tiled.strides(), &[18, 6, 1, 2]);
    assert_eq!(tiled.offset(), 0);
}
146
#[test]
fn test_empty_tile() {
    // With no tile arguments the layout must come back unchanged.
    let source = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let untouched = source.tile_many(&[]);

    assert_eq!(untouched.shape(), &[2, 3, 6]);
    assert_eq!(untouched.strides(), &[18, 6, 1]);
    assert_eq!(untouched.offset(), 0);
}
154
#[test]
fn test_multiple_tiles() {
    // Tile axis 0 into [2, 1] and axis 2 into [2, 3] in a single call.
    let args = [
        TileArg {
            axis: 0,
            endian: Endian::BigEndian,
            tiles: &[2, 1],
        },
        TileArg {
            axis: 2,
            endian: Endian::BigEndian,
            tiles: &[2, 3],
        },
    ];
    let source = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = source.tile_many(&args);

    assert_eq!(tiled.shape(), &[2, 1, 3, 2, 3]);
    assert_eq!(tiled.strides(), &[18, 18, 6, 3, 1]);
    assert_eq!(tiled.offset(), 0);
}