1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use crate::ArrayLayout;
use std::iter::zip;
/// Parameters of a single-axis slice transform, consumed by [`ArrayLayout::slice_many`].
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct SliceArg {
    /// The axis to slice.
    pub axis: usize,
    /// Index of the first element kept along `axis`.
    pub start: usize,
    /// Distance between consecutive kept elements: negative walks the axis backwards,
    /// zero repeats the element at `start`.
    pub step: isize,
    /// Number of elements to keep: an upper bound when `step != 0`, exact when `step == 0`.
    pub len: usize,
}
impl<const N: usize> ArrayLayout<N> {
    /// Slice transform: keeps a strided run of elements along a single axis.
    ///
    /// Starting at index `start`, every `step`-th element along `axis` is kept, up to `len`
    /// elements. A negative `step` walks the axis backwards (the resulting stride is negated),
    /// and a zero `step` repeats the element at `start` exactly `len` times (stride 0).
    /// Other axes are unchanged. This is shorthand for [`Self::slice_many`] with one argument.
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// // axis = 1, start = 2, step = -1, len = 2
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 2, -1, 2);
    /// assert_eq!(layout.shape(), &[2, 2, 4]);
    /// assert_eq!(layout.strides(), &[12, -4, 1]);
    /// assert_eq!(layout.offset(), 8);
    /// ```
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`Self::slice_many`].
    pub fn slice(&self, axis: usize, start: usize, step: isize, len: usize) -> Self {
        self.slice_many(&[SliceArg {
            axis,
            start,
            step,
            len,
        }])
    }

    /// Applies several slice transforms at once, one per [`SliceArg`].
    ///
    /// `args` must be sorted by strictly increasing `axis`, and every `axis` must be smaller
    /// than the number of dimensions. Axes not named in `args` are copied unchanged.
    ///
    /// For each argument, with `d` the current length of its axis:
    /// - `step > 0`: keeps the indices `start, start + step, ...` that are below `d`, at most
    ///   `len` of them; requires `start < d`.
    /// - `step == 0`: the new length is exactly `len` and the new stride is 0; requires
    ///   `start < d`.
    /// - `step < 0`: `start` is first clamped to `d - 1`, then keeps the indices
    ///   `start, start + step, ...` that are non-negative, at most `len` of them.
    ///
    /// The layout offset is advanced to the element at `start` on every sliced axis.
    ///
    /// # Panics
    ///
    /// - if `start >= d` for a non-negative `step`;
    /// - if a negative `step` is applied to an axis of length 0;
    /// - if `args` is not sorted by strictly increasing `axis`, or names an axis that is out
    ///   of range.
    pub fn slice_many(&self, mut args: &[SliceArg]) -> Self {
        use std::cmp::Ordering::{Equal, Greater, Less};

        let src = self.content();
        let mut offset = src.offset();
        let mut ans = Self::with_ndim(self.ndim);
        let mut dst = ans.content_mut();
        for (i, (&d, &s)) in zip(src.shape(), src.strides()).enumerate() {
            match args {
                [arg, tail @ ..] if arg.axis == i => {
                    let &SliceArg {
                        axis,
                        start,
                        step,
                        len,
                    } = arg;
                    let len = match step.cmp(&0) {
                        Greater => {
                            assert!(start < d, "start = {start} !in [0, {d}) on axis {axis}");
                            offset += start as isize * s;
                            (d - start).div_ceil(step.unsigned_abs()).min(len)
                        }
                        Equal => {
                            assert!(start < d, "start = {start} !in [0, {d}) on axis {axis}");
                            offset += start as isize * s;
                            len
                        }
                        Less => {
                            // Out-of-range starts are clamped to the last index. An empty axis
                            // has no last index: without this check `d - 1` would underflow
                            // (panicking in debug builds, wrapping otherwise and producing a
                            // bogus non-empty axis).
                            assert!(d > 0, "negative step on empty axis {axis}");
                            let start = start.min(d - 1);
                            offset += start as isize * s;
                            // `unsigned_abs` avoids the overflow of `-step` for `isize::MIN`.
                            (start + 1).div_ceil(step.unsigned_abs()).min(len)
                        }
                    };
                    dst.set_shape(i, len);
                    dst.set_stride(i, s * step);
                    if let [next, ..] = tail {
                        assert!(
                            axis < next.axis && next.axis < self.ndim,
                            "next.axis = {} !in ({}, {})",
                            next.axis,
                            axis,
                            self.ndim,
                        );
                    }
                    args = tail;
                }
                _ => {
                    dst.set_shape(i, d);
                    dst.set_stride(i, s);
                }
            }
        }
        // A leftover argument names an axis the loop never reached (i.e. out of range);
        // previously such an argument was silently ignored.
        if let [arg, ..] = args {
            panic!("axis = {} !in [0, {})", arg.axis, self.ndim);
        }
        dst.set_offset(offset as _);
        ans
    }
}
/// Exercises `slice` / `slice_many` on a contiguous 2x3x4 layout (strides 12, 4, 1).
#[test]
fn test_slice() {
    let base = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);

    // Backward walk along axis 1 starting at index 2: stride is negated, offset moves to 2 * 4.
    let layout = base.slice(1, 2, -1, 2);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, -4, 1]);
    assert_eq!(layout.offset(), 8);

    // Zero step: element 2 of axis 1 is repeated `len` times via a zero stride.
    let layout = base.slice(1, 2, 0, 2);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, 0, 1]);
    assert_eq!(layout.offset(), 8);

    // Plain forward prefix of axis 1.
    let layout = base.slice(1, 0, 1, 2);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, 4, 1]);
    assert_eq!(layout.offset(), 0);

    // Step 2 from index 1 on axis 2 keeps indices 1 and 3; `len` (4) is not the limit.
    let layout = base.slice(2, 1, 2, 4);
    assert_eq!(layout.shape(), &[2, 3, 2]);
    assert_eq!(layout.strides(), &[12, 4, 2]);
    assert_eq!(layout.offset(), 1);

    // Negative step clamps an out-of-range start (10) to the last index (2): keeps 2 and 0.
    let layout = base.slice(1, 10, -2, 5);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, -8, 1]);
    assert_eq!(layout.offset(), 8);

    // `len` truncates the result when it is smaller than the available elements.
    let layout = base.slice(0, 0, 1, 1);
    assert_eq!(layout.shape(), &[1, 3, 4]);
    assert_eq!(layout.strides(), &[12, 4, 1]);
    assert_eq!(layout.offset(), 0);

    // Identity-like slice on two axes at once.
    let layout = base.slice_many(&[
        SliceArg {
            axis: 1,
            start: 0,
            step: 1,
            len: 2,
        },
        SliceArg {
            axis: 2,
            start: 0,
            step: 1,
            len: 4,
        },
    ]);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, 4, 1]);
    assert_eq!(layout.offset(), 0);

    // Offsets from several sliced axes accumulate: 1 * 12 + 3 * 1 = 15.
    let layout = base.slice_many(&[
        SliceArg {
            axis: 0,
            start: 1,
            step: 1,
            len: 1,
        },
        SliceArg {
            axis: 2,
            start: 3,
            step: -1,
            len: 4,
        },
    ]);
    assert_eq!(layout.shape(), &[1, 3, 4]);
    assert_eq!(layout.strides(), &[12, 4, -1]);
    assert_eq!(layout.offset(), 15);
}