ndarray_layout/transform/
broadcast.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
use crate::ArrayLayout;

/// 索引变换参数。
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct BroadcastArg {
    /// 广播的轴。
    pub axis: usize,
    /// 广播次数。
    pub times: usize,
}

impl<const N: usize> ArrayLayout<N> {
    /// 广播变换将指定的长度为 1 的阶扩增指定的倍数,并将其步长固定为 0。
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0).broadcast(0, 10);
    /// assert_eq!(layout.shape(), &[10, 5, 2]);
    /// assert_eq!(layout.strides(), &[0, 2, 1]);
    /// assert_eq!(layout.offset(), 0);
    /// ```
    pub fn broadcast(&self, axis: usize, times: usize) -> Self {
        self.broadcast_many(&[BroadcastArg { axis, times }])
    }

    /// 一次对多个阶进行广播变换。
    pub fn broadcast_many(&self, args: &[BroadcastArg]) -> Self {
        let mut ans = self.clone();
        let mut content = ans.content_mut();
        for &BroadcastArg { axis, times } in args {
            assert!(content.shape()[axis] == 1 || content.strides()[axis] == 0);
            content.set_shape(axis, times);
            content.set_stride(axis, 0);
        }
        ans
    }
}