ndarray_layout/transform/
broadcast.rs

1use crate::ArrayLayout;
2
3/// 广播变换参数。
4#[derive(Clone, PartialEq, Eq, Debug)]
5pub struct BroadcastArg {
6    /// 广播的轴。
7    pub axis: usize,
8    /// 广播次数。
9    pub times: usize,
10}
11
12impl<const N: usize> ArrayLayout<N> {
13    /// 广播变换将指定的长度为 1 的阶扩增指定的倍数,并将其步长固定为 0。
14    ///
15    /// ```rust
16    /// # use ndarray_layout::ArrayLayout;
17    /// let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0).broadcast(0, 10);
18    /// assert_eq!(layout.shape(), &[10, 5, 2]);
19    /// assert_eq!(layout.strides(), &[0, 2, 1]);
20    /// assert_eq!(layout.offset(), 0);
21    /// ```
22    pub fn broadcast(&self, axis: usize, times: usize) -> Self {
23        self.broadcast_many(&[BroadcastArg { axis, times }])
24    }
25
26    /// 一次对多个阶进行广播变换。
27    pub fn broadcast_many(&self, args: &[BroadcastArg]) -> Self {
28        let mut ans = self.clone();
29        let mut content = ans.content_mut();
30        for &BroadcastArg { axis, times } in args {
31            assert!(content.shape()[axis] == 1 || content.strides()[axis] == 0);
32            content.set_shape(axis, times);
33            content.set_stride(axis, 0);
34        }
35        ans
36    }
37}
38
39#[test]
40fn test_broadcast() {
41    let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0).broadcast(0, 10);
42    assert_eq!(layout.shape(), &[10, 5, 2]);
43    assert_eq!(layout.strides(), &[0, 2, 1]);
44    assert_eq!(layout.offset(), 0);
45}