ndarray_layout/transform/
index.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use crate::ArrayLayout;
use std::iter::zip;

/// Parameters of a single index transform: which axis to index and which
/// position along that axis to select.
///
/// This is a small plain-data pair, so it is `Copy` and `Hash` in addition to
/// the comparison/debug traits, which lets callers pass it by value or use it
/// as a map/set key.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct IndexArg {
    /// The axis being indexed.
    pub axis: usize,
    /// Which element along `axis` to select.
    pub index: usize,
}

impl<const N: usize> ArrayLayout<N> {
    /// Selects a single entry along `axis`, removing that axis from the layout.
    ///
    /// The index transform picks one item on a given axis of the tensor — for
    /// example one number of a vector, or one row or column of a matrix. The
    /// result has one axis fewer than `self`; the selected position is folded
    /// into the offset (`offset += index * strides[axis]`).
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).index(1, 2);
    /// assert_eq!(layout.shape(), &[2, 4]);
    /// assert_eq!(layout.strides(), &[12, 1]);
    /// assert_eq!(layout.offset(), 8);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `axis` is not an axis of `self`, or if `index` is not smaller
    /// than the length of that axis.
    pub fn index(&self, axis: usize, index: usize) -> Self {
        self.index_many(&[IndexArg { axis, index }])
    }

    /// Applies several index transforms at once.
    ///
    /// Every axis named in `args` is removed from the result, and its selected
    /// position is folded into the offset. Passing an empty `args` returns a
    /// clone of `self`.
    ///
    /// # Panics
    ///
    /// All arguments are validated before the result is built. This panics if:
    /// - any argument names an axis out of range, or an index not smaller than
    ///   that axis' length (`"Invalid index arg: ..."`);
    /// - the axes in `args` are not strictly ascending, which also rules out
    ///   duplicates (`"Index args must be in ascending order"`).
    pub fn index_many(&self, args: &[IndexArg]) -> Self {
        if args.is_empty() {
            return self.clone();
        }

        let content = self.content();
        let shape = content.shape();
        let strides = content.strides();

        // Validate up front: the output rank `self.ndim - args.len()` is only
        // meaningful (and cannot underflow) once `args` is known to be a list
        // of distinct, in-range axes.
        let mut prev_axis = None;
        for arg in args {
            assert!(
                matches!(shape.get(arg.axis), Some(&d) if arg.index < d),
                "Invalid index arg: {arg:?}"
            );
            if let Some(prev) = prev_axis {
                assert!(arg.axis > prev, "Index args must be in ascending order");
            }
            prev_axis = Some(arg.axis);
        }

        let mut offset = content.offset();
        let mut ans = Self::with_ndim(self.ndim - args.len());
        let mut out = ans.content_mut();
        // `args` is sorted by axis, so one forward pass over the axes consumes
        // the pending args in order.
        let mut pending = args.iter().peekable();
        let mut j = 0;
        for (i, (&d, &s)) in zip(shape, strides).enumerate() {
            if let Some(arg) = pending.next_if(|arg| arg.axis == i) {
                // Indexed axis: dropped from the output, position goes into the offset.
                offset += arg.index as isize * s;
            } else {
                // Kept axis: copied to the next output slot.
                out.set_shape(j, d);
                out.set_stride(j, s);
                j += 1;
            }
        }
        out.set_offset(offset as _);
        ans
    }
}

#[test]
fn test() {
    let layout = ArrayLayout::<1>::new(&[2, 3, 4], &[12, 4, 1], 0);
    let layout = layout.index(1, 2);
    assert_eq!(layout.shape(), &[2, 4]);
    assert_eq!(layout.strides(), &[12, 1]);
    assert_eq!(layout.offset(), 8);

    // A negative stride moves the offset backwards: 20 + 2 * (-4) = 12.
    let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
    let layout = layout.index(1, 2);
    assert_eq!(layout.shape(), &[2, 4]);
    assert_eq!(layout.strides(), &[12, 1]);
    assert_eq!(layout.offset(), 12);
}

#[test]
fn test_index_many() {
    let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, 4, 1], 0);

    // No args: the layout is returned unchanged.
    let same = layout.index_many(&[]);
    assert_eq!(same.shape(), &[2, 3, 4]);
    assert_eq!(same.strides(), &[12, 4, 1]);
    assert_eq!(same.offset(), 0);

    // Two axes at once: offset = 1 * 12 + 3 * 1 = 15; only axis 1 remains.
    let many = layout.index_many(&[
        IndexArg { axis: 0, index: 1 },
        IndexArg { axis: 2, index: 3 },
    ]);
    assert_eq!(many.shape(), &[3]);
    assert_eq!(many.strides(), &[4]);
    assert_eq!(many.offset(), 15);
}

#[test]
#[should_panic(expected = "Invalid index arg")]
fn test_index_out_of_range() {
    let _ = ArrayLayout::<4>::new(&[2, 3, 4], &[12, 4, 1], 0).index(1, 3);
}

#[test]
#[should_panic(expected = "ascending order")]
fn test_index_many_unsorted() {
    let _ = ArrayLayout::<4>::new(&[2, 3, 4], &[12, 4, 1], 0).index_many(&[
        IndexArg { axis: 2, index: 0 },
        IndexArg { axis: 0, index: 0 },
    ]);
}