use crate::{shapes::*, tensor::*};
/// Broadcasts a value into a different shape.
///
/// Every provided method funnels into [`BroadcastTo::try_broadcast_like`], which is the
/// only method implementors have to supply. The `Ax` parameter selects which
/// `BroadcastShapeTo<_, Ax>` implementation of the source shape is used.
pub trait BroadcastTo: HasErr + HasShape {
    /// Broadcasts `self` into the shape `Dst`.
    ///
    /// The target shape value is obtained from `Default::default()`, so no runtime
    /// shape argument is required.
    ///
    /// # Panics
    ///
    /// Panics if the fallible broadcast returns an error.
    fn broadcast<Dst: ConstShape, Ax: Axes>(self) -> Self::WithShape<Dst>
    where
        Self::Shape: BroadcastShapeTo<Dst, Ax>,
    {
        self.try_broadcast::<Dst, Ax>().unwrap()
    }

    /// Fallible version of [`BroadcastTo::broadcast`].
    ///
    /// # Errors
    ///
    /// Forwards whatever error [`BroadcastTo::try_broadcast_like`] reports.
    fn try_broadcast<Dst: ConstShape, Ax: Axes>(self) -> Result<Self::WithShape<Dst>, Self::Err>
    where
        Self::Shape: BroadcastShapeTo<Dst, Ax>,
    {
        // The default value of `Dst` stands in for the destination shape.
        let target: Dst = Default::default();
        self.try_broadcast_like::<Dst, Ax>(&target)
    }

    /// Broadcasts `self` into the shape reported by `dst`.
    ///
    /// Only `dst.shape()` is needed by the `Tensor` implementation in this file;
    /// `dst` itself is borrowed and left untouched.
    ///
    /// # Panics
    ///
    /// Panics if [`BroadcastTo::try_broadcast_like`] returns an error.
    fn broadcast_like<Dst: HasShape, Ax: Axes>(self, dst: &Dst) -> Self::WithShape<Dst::Shape>
    where
        Self::Shape: BroadcastShapeTo<Dst::Shape, Ax>,
    {
        self.try_broadcast_like::<Dst, Ax>(dst).unwrap()
    }

    /// Fallible version of [`BroadcastTo::broadcast_like`]; the single required method.
    fn try_broadcast_like<Dst: HasShape, Ax: Axes>(
        self,
        dst: &Dst,
    ) -> Result<Self::WithShape<Dst::Shape>, Self::Err>
    where
        Self::Shape: BroadcastShapeTo<Dst::Shape, Ax>;
}
impl<S: Shape, E, D: Storage<E>, T: Tape<E, D>> BroadcastTo for Tensor<S, E, D, T> {
    /// Re-labels this tensor with the shape of `dst`.
    ///
    /// The `id`, `data`, `device` and `tape` fields are moved into the result unchanged;
    /// only `shape` (copied from `dst`) and `strides` (recomputed through
    /// `broadcast_strides`) differ from `self`.
    ///
    /// This implementation always returns `Ok`. The result of `check` is not converted
    /// into an `Err`.
    /// NOTE(review): `check` presumably panics on incompatible runtime dimensions (the
    /// `#[should_panic]` test in this file relies on a panic along this path) — confirm.
    fn try_broadcast_like<Dst: HasShape, Ax: Axes>(
        self,
        dst: &Dst,
    ) -> Result<Self::WithShape<Dst::Shape>, Self::Err>
    where
        Self::Shape: BroadcastShapeTo<Dst::Shape, Ax>,
    {
        let target = dst.shape();
        // Validate source against destination before taking the tensor apart.
        self.shape().check(target);

        let Tensor { id, data, shape: src_shape, strides: src_strides, device, tape } = self;
        Ok(Tensor {
            id,
            data,
            shape: *target,
            strides: src_shape.broadcast_strides(src_strides),
            device,
            tape,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_ops::*;
    use crate::tests::*;

    // Type-level check that a zero tensor of shape `$src` created on `$dev` can be
    // broadcast to shape `$dst`. Each `$src => $dst` pair expands to one
    // `let _: Tensor<$dst, TestDtype, _> = $dev.zeros::<$src>().broadcast();` statement,
    // so the axes are inferred exactly as they would be when written out by hand.
    macro_rules! assert_broadcasts {
        ($dev:ident; $($src:ty => $dst:ty),+ $(,)?) => {
            $(
                let _: Tensor<$dst, TestDtype, _> = $dev.zeros::<$src>().broadcast();
            )+
        };
    }

    /// A runtime-sized `(5,)` tensor broadcast to `(Const<3>, 7)` is expected to panic.
    #[test]
    #[should_panic]
    fn test_broadcast_incorrect_dims() {
        let dev: TestDevice = Default::default();
        let a: Tensor<(usize,), TestDtype, _> = dev.zeros_like(&(5,));
        let _: Tensor<(Const<3>, usize), TestDtype, _> = a.broadcast_like(&(Const, 7));
    }

    /// `broadcast_like` accepts another tensor as the shape donor.
    #[test]
    fn test_broadcast_with_tensor() {
        let dev: TestDevice = Default::default();
        let a1: Tensor<_, TestDtype, _> = dev.zeros_like(&(5,));
        let b: Tensor<_, TestDtype, _> = dev.zeros_like(&(2, 5, 3));
        // The source dimension becomes axis 1 of the result; axes 0 and 2 are added.
        let a2 = a1.broadcast_like::<_, Axes2<0, 2>>(&b);
        assert_eq!(a2.shape(), &(2, 5, 3));
    }

    /// Broadcasts that add exactly one axis (rank N to rank N + 1).
    #[test]
    fn test_valid_1d_broadcasts() {
        let dev: TestDevice = Default::default();
        assert_broadcasts!(dev;
            Rank0 => Rank1<5>,
            Rank1<3> => Rank2<5, 3>,
            Rank1<5> => Rank2<5, 3>,
            Rank2<5, 7> => Rank3<3, 5, 7>,
            Rank2<3, 7> => Rank3<3, 5, 7>,
            Rank2<3, 5> => Rank3<3, 5, 7>,
            Rank3<5, 7, 9> => Rank4<3, 5, 7, 9>,
            Rank3<3, 7, 9> => Rank4<3, 5, 7, 9>,
            Rank3<3, 5, 9> => Rank4<3, 5, 7, 9>,
            Rank3<3, 5, 7> => Rank4<3, 5, 7, 9>,
        );
    }

    /// Broadcasts that add exactly two axes (rank N to rank N + 2).
    #[test]
    fn test_valid_2d_broadcasts() {
        let dev: TestDevice = Default::default();
        assert_broadcasts!(dev;
            Rank0 => Rank2<5, 3>,
            Rank1<3> => Rank3<3, 5, 7>,
            Rank1<5> => Rank3<3, 5, 7>,
            Rank1<7> => Rank3<3, 5, 7>,
            Rank2<3, 5> => Rank4<3, 5, 7, 9>,
            Rank2<3, 7> => Rank4<3, 5, 7, 9>,
            Rank2<3, 9> => Rank4<3, 5, 7, 9>,
            Rank2<5, 7> => Rank4<3, 5, 7, 9>,
            Rank2<5, 9> => Rank4<3, 5, 7, 9>,
            Rank2<7, 9> => Rank4<3, 5, 7, 9>,
        );
    }

    /// Broadcasts that add exactly three axes (rank N to rank N + 3).
    #[test]
    fn test_valid_3d_broadcasts() {
        let dev: TestDevice = Default::default();
        assert_broadcasts!(dev;
            Rank0 => Rank3<3, 5, 7>,
            Rank1<3> => Rank4<3, 5, 7, 9>,
            Rank1<5> => Rank4<3, 5, 7, 9>,
            Rank1<7> => Rank4<3, 5, 7, 9>,
            Rank1<9> => Rank4<3, 5, 7, 9>,
        );
    }

    /// Gradients of `mean(exp(broadcast(a) * b))` with respect to both `a` and `b`.
    #[test]
    fn test_broadcast_backwards() {
        let dev: TestDevice = Default::default();
        let a: Tensor<Rank1<3>, TestDtype, _> = dev.sample_normal();
        let b: Tensor<Rank2<5, 3>, TestDtype, _> = dev.sample_normal();

        // Forward: the broadcast value is `a` repeated along the new leading axis.
        let a_up = a.leaky_trace().broadcast::<Rank2<5, 3>, _>();
        assert_close!(a_up.array(), [a.array(); 5], 1e-4);
        let r = a_up * b.clone();
        let g = r.exp().mean().backward();

        // Expected gradients, computed by hand. The mean runs over 5 * 3 = 15 elements,
        // and the gradient for `a` is summed back over the broadcast axis.
        let a_up = a.clone().broadcast::<Rank2<5, 3>, _>();
        let a_grad = (b.clone() * (b.clone() * a_up.clone()).exp()).sum::<Rank1<3>, _>() / 15.0;
        let b_grad = (a_up.clone() * (b.clone() * a_up).exp()) / 15.0;
        assert_close_to_tensor!(g.get(&a), a_grad, 1e-4);
        assert_close_to_tensor!(g.get(&b), b_grad, 1e-4);
    }

    /// Gradient of `mean(exp(broadcast(a)))`: each of the 3 inputs appears 4 times among
    /// the 12 averaged elements, so the expected gradient is `exp(a) * 4 / 12 = exp(a) / 3`.
    #[test]
    fn test_broadcast_summed() {
        let dev: TestDevice = Default::default();
        let a: Tensor<Rank1<3>, TestDtype, _> = dev.sample_normal();
        let g = a
            .leaky_trace()
            .broadcast::<Rank2<4, 3>, _>()
            .exp()
            .mean()
            .backward();
        assert_close_to_tensor!(g.get(&a), a.exp() / 3.0);
    }
}