ndarray_layout/transform/
index.rs

1use crate::ArrayLayout;
2use std::iter::zip;
3
/// Arguments of a single index transform: the axis to index and the element to select on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IndexArg {
    /// The axis that is indexed.
    pub axis: usize,
    /// Position of the element to select along `axis`.
    pub index: usize,
}
12
13impl<const N: usize> ArrayLayout<N> {
14    /// 索引变换是选择张量指定阶上一项数据的变换,例如指定向量中的一个数、指定矩阵的一行或一列。
15    /// 索引变换导致张量降阶,确定索引的阶从张量表示移除。
16    ///
17    /// ```rust
18    /// # use ndarray_layout::ArrayLayout;
19    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).index(1, 2);
20    /// assert_eq!(layout.shape(), &[2, 4]);
21    /// assert_eq!(layout.strides(), &[12, 1]);
22    /// assert_eq!(layout.offset(), 8);
23    /// ```
24    pub fn index(&self, axis: usize, index: usize) -> Self {
25        self.index_many(&[IndexArg { axis, index }])
26    }
27
28    /// 一次对多个阶进行索引变换。
29    pub fn index_many(&self, mut args: &[IndexArg]) -> Self {
30        let content = self.content();
31        let mut offset = content.offset();
32        let shape = content.shape();
33        let iter = zip(shape, content.strides()).enumerate();
34
35        let check = |&IndexArg { axis, index }| shape.get(axis).filter(|&&d| index < d).is_some();
36
37        if let [first, ..] = args {
38            assert!(check(first), "Invalid index arg: {first:?}");
39        } else {
40            return self.clone();
41        }
42
43        let mut ans = Self::with_ndim(self.ndim - args.len());
44        let mut content = ans.content_mut();
45        let mut j = 0;
46        for (i, (&d, &s)) in iter {
47            match *args {
48                [IndexArg { axis, index }, ref tail @ ..] if axis == i => {
49                    offset += index as isize * s;
50                    if let [first, ..] = tail {
51                        assert!(check(first), "Invalid index arg: {first:?}");
52                        assert!(first.axis > axis, "Index args must be in ascending order");
53                    }
54                    args = tail;
55                }
56                [..] => {
57                    content.set_shape(j, d);
58                    content.set_stride(j, s);
59                    j += 1;
60                }
61            }
62        }
63        content.set_offset(offset as _);
64        ans
65    }
66}
67
#[test]
fn test() {
    // Positive strides: selecting element 2 on axis 1 advances the offset by 2 * 4.
    // NOTE(review): N = 1 is smaller than the three axes used here — presumably deliberate; confirm.
    let picked = ArrayLayout::<1>::new(&[2, 3, 4], &[12, 4, 1], 0).index(1, 2);
    assert_eq!(picked.shape(), &[2, 4]);
    assert_eq!(picked.strides(), &[12, 1]);
    assert_eq!(picked.offset(), 8);

    // The remaining cases share one input layout with a negative stride on axis 1.
    let base = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);

    // Negative stride: 20 + 2 * (-4) = 12.
    let reversed = base.index(1, 2);
    assert_eq!(reversed.shape(), &[2, 4]);
    assert_eq!(reversed.strides(), &[12, 1]);
    assert_eq!(reversed.offset(), 12);

    // No arguments: the layout comes back unchanged.
    let unchanged = base.index_many(&[]);
    assert_eq!(unchanged.shape(), &[2, 3, 4]);
    assert_eq!(unchanged.strides(), &[12, -4, 1]);
    assert_eq!(unchanged.offset(), 20);

    // Two axes at once: 20 + 1 * 12 + 2 * (-4) = 24, only the last axis remains.
    let reduced = base.index_many(&[
        IndexArg { axis: 0, index: 1 },
        IndexArg { axis: 1, index: 2 },
    ]);
    assert_eq!(reduced.shape(), &[4]);
    assert_eq!(reduced.strides(), &[1]);
    assert_eq!(reduced.offset(), 24);
}