ndarray_layout/transform/
slice.rs

1use crate::ArrayLayout;
2use std::iter::zip;
3
/// Parameters of a slice transform applied to a single axis.
///
/// Consumed by [`ArrayLayout::slice_many`] and built internally by [`ArrayLayout::slice`].
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct SliceArg {
    /// Index of the axis to slice.
    pub axis: usize,
    /// Index of the first element kept along `axis`.
    ///
    /// With a negative `step` it is clamped to the last valid index of the axis.
    pub start: usize,
    /// Distance, in elements, between consecutive kept elements.
    ///
    /// Zero repeats the element at `start`; a negative value walks the axis backwards.
    pub step: isize,
    /// Requested number of elements.
    ///
    /// For a non-zero `step` the resulting length is clamped to the number of elements
    /// actually reachable from `start`; for `step == 0` it is used as-is.
    pub len: usize,
}
16
17impl<const N: usize> ArrayLayout<N> {
18    /// 切片变换是裁剪张量指定阶上一组连续数据的变换。
19    ///
20    /// ```rust
21    /// # use ndarray_layout::ArrayLayout;
22    /// // axis = 1, start = 1, step = -1, len = 2
23    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 2, -1, 2);
24    /// assert_eq!(layout.shape(), &[2, 2, 4]);
25    /// assert_eq!(layout.strides(), &[12, -4, 1]);
26    /// assert_eq!(layout.offset(), 8);
27    /// ```
28    pub fn slice(&self, axis: usize, start: usize, step: isize, len: usize) -> Self {
29        self.slice_many(&[SliceArg {
30            axis,
31            start,
32            step,
33            len,
34        }])
35    }
36
37    /// 一次对多个阶进行切片变换。
38    pub fn slice_many(&self, mut args: &[SliceArg]) -> Self {
39        let content = self.content();
40        let mut offset = content.offset();
41        let iter = zip(content.shape(), content.strides()).enumerate();
42
43        let mut ans = Self::with_ndim(self.ndim);
44        let mut content = ans.content_mut();
45        for (i, (&d, &s)) in iter {
46            match args {
47                [arg, tail @ ..] if arg.axis == i => {
48                    let &SliceArg {
49                        axis,
50                        start,
51                        step,
52                        len,
53                    } = arg;
54                    use std::cmp::Ordering::*;
55                    let len = match step.cmp(&0) {
56                        Greater => {
57                            assert!(start < d);
58                            offset += start as isize * s;
59                            (d - start).div_ceil(step as _).min(len)
60                        }
61                        Equal => {
62                            assert!(start < d);
63                            offset += start as isize * s;
64                            len
65                        }
66                        Less => {
67                            let start = start.min(d - 1);
68                            offset += start as isize * s;
69                            (start + 1).div_ceil((-step) as _).min(len)
70                        }
71                    };
72                    content.set_shape(i, len);
73                    content.set_stride(i, s * step);
74
75                    if let [next, ..] = tail {
76                        assert!(
77                            axis < next.axis && next.axis < self.ndim,
78                            "next.axis = {} !in ({}, {})",
79                            next.axis,
80                            axis,
81                            self.ndim,
82                        );
83                    }
84                    args = tail;
85                }
86                [..] => {
87                    content.set_shape(i, d);
88                    content.set_stride(i, s);
89                }
90            }
91        }
92        content.set_offset(offset as _);
93        ans
94    }
95}
96
#[test]
fn test_slice() {
    // A contiguous row-major 2x3x4 layout shared by every case below.
    let base = || ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);

    // Negative step: walk axis 1 backwards from its last element.
    let layout = base().slice(1, 2, -1, 2);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, -4, 1]);
    assert_eq!(layout.offset(), 8);

    // Zero step: repeat element 2 of axis 1.
    let layout = base().slice(1, 2, 0, 2);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, 0, 1]);
    assert_eq!(layout.offset(), 8);

    // Zero step may produce a length larger than the original axis.
    let layout = base().slice(2, 1, 0, 5);
    assert_eq!(layout.shape(), &[2, 3, 5]);
    assert_eq!(layout.strides(), &[12, 4, 0]);
    assert_eq!(layout.offset(), 1);

    // Positive step: leading sub-range of axis 1.
    let layout = base().slice(1, 0, 1, 2);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, 4, 1]);
    assert_eq!(layout.offset(), 0);

    // Positive step: an oversized `len` is clamped to the reachable elements.
    let layout = base().slice(1, 1, 1, 100);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, 4, 1]);
    assert_eq!(layout.offset(), 4);

    // Step greater than one: elements 1 and 3 of axis 2.
    let layout = base().slice(2, 1, 2, 10);
    assert_eq!(layout.shape(), &[2, 3, 2]);
    assert_eq!(layout.strides(), &[12, 4, 2]);
    assert_eq!(layout.offset(), 1);

    // Step less than minus one: elements 3 and 1 of axis 2.
    let layout = base().slice(2, 3, -2, 4);
    assert_eq!(layout.shape(), &[2, 3, 2]);
    assert_eq!(layout.strides(), &[12, 4, -2]);
    assert_eq!(layout.offset(), 3);

    // Negative step: an out-of-range `start` is clamped to the last index.
    let layout = base().slice(0, 5, -1, 5);
    assert_eq!(layout.shape(), &[2, 3, 4]);
    assert_eq!(layout.strides(), &[-12, 4, 1]);
    assert_eq!(layout.offset(), 12);

    // Several adjacent axes at once.
    let layout = base().slice_many(&[
        SliceArg {
            axis: 1,
            start: 0,
            step: 1,
            len: 2,
        },
        SliceArg {
            axis: 2,
            start: 0,
            step: 1,
            len: 4,
        },
    ]);
    assert_eq!(layout.shape(), &[2, 2, 4]);
    assert_eq!(layout.strides(), &[12, 4, 1]);
    assert_eq!(layout.offset(), 0);

    // Several non-adjacent axes at once; the untouched middle axis is copied as-is.
    let layout = base().slice_many(&[
        SliceArg {
            axis: 0,
            start: 1,
            step: 1,
            len: 1,
        },
        SliceArg {
            axis: 2,
            start: 3,
            step: -1,
            len: 4,
        },
    ]);
    assert_eq!(layout.shape(), &[1, 3, 4]);
    assert_eq!(layout.strides(), &[12, 4, -1]);
    assert_eq!(layout.offset(), 15);
}