ndarray_layout/transform/
broadcast.rs1use crate::ArrayLayout;
2
/// One broadcast request: stretch axis `axis` of a layout to `times` elements.
///
/// Consumed by [`ArrayLayout::broadcast_many`]. The target axis must have length 1 or
/// already have stride 0.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct BroadcastArg {
    /// Index of the axis to broadcast.
    pub axis: usize,
    /// Length of the axis after broadcasting.
    pub times: usize,
}
11
12impl<const N: usize> ArrayLayout<N> {
13 pub fn broadcast(&self, axis: usize, times: usize) -> Self {
23 self.broadcast_many(&[BroadcastArg { axis, times }])
24 }
25
26 pub fn broadcast_many(&self, args: &[BroadcastArg]) -> Self {
28 let mut ans = self.clone();
29 let mut content = ans.content_mut();
30 for &BroadcastArg { axis, times } in args {
31 assert!(content.shape()[axis] == 1 || content.strides()[axis] == 0);
32 content.set_shape(axis, times);
33 content.set_stride(axis, 0);
34 }
35 ans
36 }
37}
38
#[test]
fn test_broadcast() {
    // A length-1 axis becomes `times` long with stride 0; other axes and the offset are kept.
    let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0).broadcast(0, 10);
    assert_eq!(layout.shape(), &[10, 5, 2]);
    assert_eq!(layout.strides(), &[0, 2, 1]);
    assert_eq!(layout.offset(), 0);

    // An axis that already has stride 0 may be broadcast again to a new length.
    let layout = layout.broadcast(0, 7);
    assert_eq!(layout.shape(), &[7, 5, 2]);
    assert_eq!(layout.strides(), &[0, 2, 1]);
    assert_eq!(layout.offset(), 0);
}

#[test]
fn test_broadcast_many() {
    // Several length-1 axes broadcast in one call.
    let layout = ArrayLayout::<3>::new(&[1, 5, 1], &[10, 2, 1], 0).broadcast_many(&[
        BroadcastArg { axis: 0, times: 3 },
        BroadcastArg { axis: 2, times: 4 },
    ]);
    assert_eq!(layout.shape(), &[3, 5, 4]);
    assert_eq!(layout.strides(), &[0, 2, 0]);
    assert_eq!(layout.offset(), 0);
}

#[test]
#[should_panic]
fn test_broadcast_non_unit_axis_panics() {
    // Axis 0 has length 2 and a non-zero stride, so it cannot be broadcast.
    let _ = ArrayLayout::<3>::new(&[2, 5, 2], &[10, 2, 1], 0).broadcast(0, 10);
}