// ndarray_layout/transform/tile.rs
use crate::{ArrayLayout, Endian};
use std::iter::zip;

/// Parameters describing how a single axis is split into tiles.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TileArg<'t> {
    /// Index of the axis being split.
    pub axis: usize,
    /// Order in which the produced tiles are placed in the resulting shape.
    pub endian: Endian,
    /// Lengths of the tiles; `tile_many` requires their product to equal the axis length.
    pub tiles: &'t [usize],
}

impl<const N: usize> ArrayLayout<N> {
    /// Tiling splits a single axis into several consecutive axes.
    /// Big-endian tiling places the tiles with the larger span (larger stride) earlier in
    /// the shape.
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]);
    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
    /// assert_eq!(layout.strides(), &[18, 6, 3, 1]);
    /// assert_eq!(layout.offset(), 0);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `axis` is out of range or if the product of `tiles` differs from the
    /// length of `axis`.
    #[inline]
    pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self {
        self.tile_many(&[TileArg {
            axis,
            endian: Endian::BigEndian,
            tiles,
        }])
    }

    /// Tiling splits a single axis into several consecutive axes.
    /// Little-endian tiling places the tiles with the smaller span (smaller stride) earlier
    /// in the shape.
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]);
    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
    /// assert_eq!(layout.strides(), &[18, 6, 1, 2]);
    /// assert_eq!(layout.offset(), 0);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `axis` is out of range or if the product of `tiles` differs from the
    /// length of `axis`.
    #[inline]
    pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self {
        self.tile_many(&[TileArg {
            axis,
            endian: Endian::LittleEndian,
            tiles,
        }])
    }

    /// Applies several tiling transforms in one pass.
    ///
    /// Each [`TileArg`] replaces its `axis` with `tiles.len()` new axes whose lengths are
    /// `tiles`, with strides laid out according to its `endian`. Axes not named in `args`
    /// are copied unchanged, and the offset is preserved. An empty `args` returns a clone.
    ///
    /// # Panics
    ///
    /// - if any `axis` is out of range;
    /// - if the product of any `tiles` differs from the length of its `axis`;
    /// - if the `axis` values in `args` are not strictly increasing.
    pub fn tile_many(&self, mut args: &[TileArg]) -> Self {
        if args.is_empty() {
            return self.clone();
        }

        let content = self.content();
        let shape = content.shape();

        // Validate every argument before building the new layout and count the axes
        // the tiles will add.
        let mut new_axes = 0;
        let mut prev_axis: Option<usize> = None;
        for &TileArg { axis, tiles, .. } in args {
            assert!(
                shape.get(axis) == Some(&tiles.iter().product::<usize>()),
                "axis {axis} is out of range or its length differs from the product of tiles {tiles:?}"
            );
            if let Some(prev) = prev_axis {
                assert!(axis > prev, "tile axes must be strictly increasing");
            }
            prev_axis = Some(axis);
            new_axes += tiles.len();
        }

        // Each tiled axis is removed and replaced by its tiles. Axes are distinct and in
        // range, so `args.len() <= self.ndim`, and adding first cannot underflow.
        let mut ans = Self::with_ndim(self.ndim + new_axes - args.len());

        let mut dst = ans.content_mut();
        dst.set_offset(self.offset());
        // `j` is the next output axis to write.
        let mut j = 0;
        for (i, (&d, &s)) in zip(shape, content.strides()).enumerate() {
            match *args {
                [TileArg {
                    axis,
                    endian,
                    tiles,
                }, ref tail @ ..]
                    if axis == i =>
                {
                    match endian {
                        Endian::BigEndian => {
                            // tiles  : [a,         b    , c]
                            // strides: [s * c * b, s * c, s]
                            // Fill from the innermost tile outwards so every stride is built
                            // by multiplication only. Dividing `s * d` by each tile (the
                            // previous approach) panics when a tile has length 0, which the
                            // product check allows for a zero-length axis.
                            let mut stride = s;
                            for (k, &t) in tiles.iter().enumerate().rev() {
                                dst.set_shape(j + k, t);
                                dst.set_stride(j + k, stride);
                                stride *= t as isize;
                            }
                        }
                        Endian::LittleEndian => {
                            // tiles  : [a, b    , c        ]
                            // strides: [s, s * a, s * a * b]
                            let mut stride = s;
                            for (k, &t) in tiles.iter().enumerate() {
                                dst.set_shape(j + k, t);
                                dst.set_stride(j + k, stride);
                                stride *= t as isize;
                            }
                        }
                    }
                    j += tiles.len();
                    args = tail;
                }
                // Axis not being tiled: copy it through unchanged.
                [..] => {
                    dst.set_shape(j, d);
                    dst.set_stride(j, s);
                    j += 1;
                }
            }
        }
        ans
    }
}