1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
use crate::{ArrayLayout, Endian};
use std::iter::zip;
/// Arguments describing how a single axis is tiled (split into several axes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TileArg<'t> {
    /// Index of the axis to split.
    pub axis: usize,
    /// Order in which the resulting tiles are laid out.
    pub endian: Endian,
    /// Sizes of the resulting tiles; their product must equal the size of `axis`.
    pub tiles: &'t [usize],
}
impl<const N: usize> ArrayLayout<N> {
    /// Tiling splits a single axis into several tile axes.
    /// Big-endian tiling places the tile with the larger range (larger stride)
    /// earlier in the shape.
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]);
    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
    /// assert_eq!(layout.strides(), &[18, 6, 3, 1]);
    /// assert_eq!(layout.offset(), 0);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::tile_many`].
    #[inline]
    pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self {
        self.tile_many(&[TileArg {
            axis,
            endian: Endian::BigEndian,
            tiles,
        }])
    }

    /// Tiling splits a single axis into several tile axes.
    /// Little-endian tiling places the tile with the smaller range (smaller stride)
    /// earlier in the shape.
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]);
    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
    /// assert_eq!(layout.strides(), &[18, 6, 1, 2]);
    /// assert_eq!(layout.offset(), 0);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::tile_many`].
    #[inline]
    pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self {
        self.tile_many(&[TileArg {
            axis,
            endian: Endian::LittleEndian,
            tiles,
        }])
    }

    /// Tiles several axes at once.
    ///
    /// Each [`TileArg`] replaces axis `axis` with `tiles.len()` new axes whose
    /// sizes are `tiles`. Every other axis keeps its size and stride, and the
    /// offset is unchanged. An empty `args` returns a clone of `self`.
    ///
    /// For an axis with stride `s` split into tiles `[a, b, c]`:
    /// - [`Endian::BigEndian`] produces strides `[s * b * c, s * c, s]`;
    /// - [`Endian::LittleEndian`] produces strides `[s, s * a, s * a * b]`.
    ///
    /// # Panics
    ///
    /// - if some `axis` is out of range, or the product of its `tiles` differs
    ///   from that axis' size;
    /// - if the `axis` values in `args` are not strictly increasing.
    pub fn tile_many(&self, mut args: &[TileArg]) -> Self {
        let content = self.content();
        let shape = content.shape();
        let iter = zip(shape, content.strides()).enumerate();
        // A tile argument is valid when its axis exists and the tiles exactly cover it.
        let check = |&TileArg { axis, tiles, .. }| {
            shape
                .get(axis)
                .filter(|&&d| d == tiles.iter().product())
                .is_some()
        };
        // `new` counts the axes produced by tiling; `last_axis` enforces ordering.
        let (mut new, mut last_axis) = match args {
            [first, ..] => {
                assert!(
                    check(first),
                    "cannot tile axis {} of shape {:?} into {:?}",
                    first.axis,
                    shape,
                    first.tiles,
                );
                (first.tiles.len(), first.axis)
            }
            [] => return self.clone(),
        };
        for arg in &args[1..] {
            assert!(
                check(arg),
                "cannot tile axis {} of shape {:?} into {:?}",
                arg.axis,
                shape,
                arg.tiles,
            );
            assert!(
                arg.axis > last_axis,
                "tile axes must be strictly increasing, got {} after {}",
                arg.axis,
                last_axis,
            );
            new += arg.tiles.len();
            last_axis = arg.axis;
        }
        // Every tiled axis is removed and replaced by its `tiles.len()` tile axes.
        let mut ans = Self::with_ndim(self.ndim + new - args.len());
        let mut content = ans.content_mut();
        content.set_offset(self.offset());
        // Appends one (size, stride) pair to the output layout.
        let mut j = 0;
        let mut push = |t, s| {
            content.set_shape(j, t);
            content.set_stride(j, s);
            j += 1;
        };
        for (i, (&d, &s)) in iter {
            match *args {
                [TileArg {
                    axis,
                    endian,
                    tiles,
                }, ref tail @ ..]
                    if axis == i =>
                {
                    match endian {
                        Endian::BigEndian => {
                            // tiles  : [a,         b,     c]
                            // strides: [s * b * c, s * c, s]
                            // Each stride is `s` times the product of the tiles after it.
                            // It is built by multiplication only, so a zero-sized tile
                            // (allowed when the axis size is 0) cannot cause a division
                            // by zero.
                            for (k, &t) in tiles.iter().enumerate() {
                                let inner = tiles[k + 1..].iter().product::<usize>();
                                push(t, s * inner as isize);
                            }
                        }
                        Endian::LittleEndian => {
                            // tiles  : [a, b,     c        ]
                            // strides: [s, s * a, s * a * b]
                            let mut s = s;
                            for &t in tiles {
                                push(t, s);
                                s *= t as isize;
                            }
                        }
                    }
                    // This argument is consumed; the next one targets a later axis.
                    args = tail;
                }
                // Axes that are not tiled are copied unchanged.
                [..] => push(d, s),
            }
        }
        ans
    }
}
#[test]
fn test_tile_be() {
    // Split the last axis (size 6) into 2 x 3, larger-stride tile first.
    let base = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = base.tile_be(2, &[2, 3]);
    assert_eq!(tiled.shape(), &[2, 3, 2, 3]);
    assert_eq!(tiled.strides(), &[18, 6, 3, 1]);
    assert_eq!(tiled.offset(), 0);
}
#[test]
fn test_tile_le() {
    // Split the last axis (size 6) into 2 x 3, smaller-stride tile first.
    let base = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = base.tile_le(2, &[2, 3]);
    assert_eq!(tiled.shape(), &[2, 3, 2, 3]);
    assert_eq!(tiled.strides(), &[18, 6, 1, 2]);
    assert_eq!(tiled.offset(), 0);
}
#[test]
fn test_empty_tile() {
    // With no tile arguments the layout must come back unchanged.
    let no_args: &[TileArg] = &[];
    let base = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = base.tile_many(no_args);
    assert_eq!(tiled.shape(), &[2, 3, 6]);
    assert_eq!(tiled.strides(), &[18, 6, 1]);
    assert_eq!(tiled.offset(), 0);
}
#[test]
fn test_multiple_tiles() {
    // Tile axis 0 (size 2) into [2, 1] and axis 2 (size 6) into [2, 3],
    // both big-endian; axis 1 is left untouched.
    let args = [
        TileArg {
            axis: 0,
            endian: Endian::BigEndian,
            tiles: &[2, 1],
        },
        TileArg {
            axis: 2,
            endian: Endian::BigEndian,
            tiles: &[2, 3],
        },
    ];
    let base = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = base.tile_many(&args);
    assert_eq!(tiled.shape(), &[2, 1, 3, 2, 3]);
    assert_eq!(tiled.strides(), &[18, 18, 6, 3, 1]);
    assert_eq!(tiled.offset(), 0);
}