ndarray_layout/transform/
broadcast.rs1use crate::ArrayLayout;
2
/// One broadcast request for [`ArrayLayout::broadcast_many`].
///
/// Broadcasting sets the length of axis `axis` to `times` and its stride to 0.
/// As a result, every index along that axis addresses the same elements. The
/// target axis must have length 1 or already be broadcast (stride 0).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct BroadcastArg {
    /// Index of the axis to broadcast.
    pub axis: usize,
    /// Length the axis will have after broadcasting.
    pub times: usize,
}
11
12impl<const N: usize> ArrayLayout<N> {
13 pub fn broadcast(&self, axis: usize, times: usize) -> Self {
23 self.broadcast_many(&[BroadcastArg { axis, times }])
24 }
25
26 pub fn broadcast_many(&self, args: &[BroadcastArg]) -> Self {
28 let mut ans = self.clone();
29 let mut content = ans.content_mut();
30 for &BroadcastArg { axis, times } in args {
31 assert!(content.shape()[axis] == 1 || content.strides()[axis] == 0);
32 content.set_shape(axis, times);
33 content.set_stride(axis, 0);
34 }
35 ans
36 }
37}