use crate::ArrayLayout;
/// A single broadcast request: stretch dimension `axis` to length `times`.
///
/// Used by `ArrayLayout::broadcast_many`, which requires the target axis to
/// currently have length 1 or stride 0.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct BroadcastArg {
    /// Index of the dimension to broadcast.
    pub axis: usize,
    /// Length the dimension will have after broadcasting.
    pub times: usize,
}
impl<const N: usize> ArrayLayout<N> {
    /// Returns a copy of this layout with axis `axis` broadcast to length `times`.
    ///
    /// Shorthand for [`broadcast_many`](Self::broadcast_many) with a single
    /// [`BroadcastArg`].
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`broadcast_many`](Self::broadcast_many).
    pub fn broadcast(&self, axis: usize, times: usize) -> Self {
        self.broadcast_many(&[BroadcastArg { axis, times }])
    }

    /// Returns a copy of this layout with every listed axis broadcast.
    ///
    /// The `BroadcastArg { axis, times }` entries in `args` are applied in order.
    /// For each entry, the axis length is set to `times` and its stride to 0.
    /// Only the shapes and strides of the listed axes are modified, and `self`
    /// is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if a listed axis has a length other than 1 and a non-zero stride.
    /// An axis whose stride is already 0 may be broadcast again to any length.
    /// Indexing the shape with an out-of-range `axis` also panics.
    pub fn broadcast_many(&self, args: &[BroadcastArg]) -> Self {
        let mut ans = self.clone();
        let mut content = ans.content_mut();
        for &BroadcastArg { axis, times } in args {
            // A zero stride repeats one element along the axis. That only matches
            // the original data when the axis had length 1 or was already
            // zero-strided. Report the offending values so misuse is diagnosable.
            assert!(
                content.shape()[axis] == 1 || content.strides()[axis] == 0,
                "cannot broadcast axis {}: length {} with stride {} (need length 1 or stride 0)",
                axis,
                content.shape()[axis],
                content.strides()[axis],
            );
            content.set_shape(axis, times);
            content.set_stride(axis, 0);
        }
        ans
    }
}
#[test]
fn test_broadcast() {
    // Start from a [1, 5, 2] layout whose leading axis has length 1.
    let base = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0);
    // Stretch that leading axis to length 10.
    let widened = base.broadcast(0, 10);

    // The broadcast axis takes the new length and a zero stride.
    // The other axes and the offset keep their original values.
    assert_eq!(widened.shape(), &[10, 5, 2]);
    assert_eq!(widened.strides(), &[0, 2, 1]);
    assert_eq!(widened.offset(), 0);
}