// ndarray_layout/transform/slice.rs
use crate::ArrayLayout;
use std::iter::zip;
/// Arguments describing a slice transform applied to one axis of a layout.
///
/// Consumed by `ArrayLayout::slice_many`; `ArrayLayout::slice` builds a single
/// value of this type from its positional parameters.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct SliceArg {
/// Index of the axis to slice.
pub axis: usize,
/// Index along the axis at which the slice starts.
///
/// With a negative `step` this value is clamped to the last valid index of
/// the axis; otherwise it must be smaller than the extent of the axis.
pub start: usize,
/// Step between consecutive selected elements, in units of the original
/// axis. May be negative (walk backwards) or zero (stride becomes 0).
pub step: isize,
/// Requested length of the slice.
///
/// For a non-zero `step` this is an upper bound: the resulting length is
/// reduced to the number of elements actually reachable from `start`.
pub len: usize,
}
impl<const N: usize> ArrayLayout<N> {
    /// Slices the layout along a single axis.
    ///
    /// Keeps at most `len` elements of `axis`, beginning at index `start` and
    /// advancing by `step` elements of the original axis each time. This is a
    /// thin wrapper that forwards one [`SliceArg`] to
    /// [`slice_many`](Self::slice_many); see that method for the clamping
    /// rules and the conditions under which it panics.
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// // axis = 1, start = 2, step = -1, len = 2
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 2, -1, 2);
    /// assert_eq!(layout.shape(), &[2, 2, 4]);
    /// assert_eq!(layout.strides(), &[12, -4, 1]);
    /// assert_eq!(layout.offset(), 8);
    /// ```
    pub fn slice(&self, axis: usize, start: usize, step: isize, len: usize) -> Self {
        self.slice_many(&[SliceArg {
            axis,
            start,
            step,
            len,
        }])
    }

    /// Slices several axes in a single pass.
    ///
    /// `args` must be ordered by strictly increasing `axis`, and every axis
    /// must be smaller than the number of dimensions of the layout. Axes that
    /// are not mentioned in `args` are copied unchanged.
    ///
    /// For a sliced axis with extent `d` and stride `s`:
    ///
    /// - `step > 0`: the new length is `min(len, ceil((d - start) / step))`.
    /// - `step == 0`: the new length is `len` and the new stride is `0`.
    /// - `step < 0`: `start` is first clamped to `d - 1`, then the new length
    ///   is `min(len, ceil((start + 1) / -step))`.
    ///
    /// In every case the offset grows by `start * s` and the new stride is
    /// `s * step`.
    ///
    /// # Panics
    ///
    /// - If the axes in `args` are not strictly increasing or any of them is
    ///   not smaller than the number of dimensions.
    /// - If `step >= 0` and `start` is not smaller than the extent of the axis.
    /// - If the sliced axis has extent 0.
    pub fn slice_many(&self, mut args: &[SliceArg]) -> Self {
        // Every later axis is validated when its predecessor is consumed, but
        // the first one has no predecessor: without this check an out-of-range
        // first axis would never match a dimension and the whole request would
        // be ignored silently.
        if let [first, ..] = args {
            assert!(
                first.axis < self.ndim,
                "first.axis = {} !in [0, {})",
                first.axis,
                self.ndim,
            );
        }

        let content = self.content();
        let mut offset = content.offset();
        let iter = zip(content.shape(), content.strides()).enumerate();

        let mut ans = Self::with_ndim(self.ndim);
        let mut content = ans.content_mut();
        for (i, (&d, &s)) in iter {
            match args {
                [arg, tail @ ..] if arg.axis == i => {
                    let &SliceArg {
                        axis,
                        start,
                        step,
                        len,
                    } = arg;
                    use std::cmp::Ordering::*;
                    let len = match step.cmp(&0) {
                        Greater => {
                            assert!(start < d);
                            offset += start as isize * s;
                            (d - start).div_ceil(step.unsigned_abs()).min(len)
                        }
                        Equal => {
                            assert!(start < d);
                            offset += start as isize * s;
                            len
                        }
                        Less => {
                            // `d - 1` below would underflow on an empty axis:
                            // a panic in debug builds, but a wrapped value (and
                            // therefore a bogus non-empty slice) in release.
                            assert!(d > 0, "cannot slice axis {} of length 0", axis);
                            let start = start.min(d - 1);
                            offset += start as isize * s;
                            // `unsigned_abs` avoids the overflow of `-step`
                            // when `step == isize::MIN`.
                            (start + 1).div_ceil(step.unsigned_abs()).min(len)
                        }
                    };
                    content.set_shape(i, len);
                    content.set_stride(i, s * step);
                    // Enforce strictly increasing, in-range axes so that every
                    // remaining argument is guaranteed to be reached.
                    if let [next, ..] = tail {
                        assert!(
                            axis < next.axis && next.axis < self.ndim,
                            "next.axis = {} !in ({}, {})",
                            next.axis,
                            axis,
                            self.ndim,
                        );
                    }
                    args = tail;
                }
                // Axis not targeted by the next argument: keep it as is.
                _ => {
                    content.set_shape(i, d);
                    content.set_stride(i, s);
                }
            }
        }
        content.set_offset(offset as _);
        ans
    }
}