1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
use crate::{shapes::*, tensor::*};
/// Broadcasts a value into another shape.
///
/// Implementors only have to provide [`BroadcastTo::try_broadcast_like`]; the
/// remaining methods are provided conveniences that forward to it.
pub trait BroadcastTo: HasErr + HasShape {
    /// Broadcasts `self` into the compile-time shape `Dst` along axes `Ax`.
    ///
    /// The target shape value is obtained from `Dst`'s `Default` impl.
    ///
    /// # Panics
    /// Panics if [`BroadcastTo::try_broadcast_like`] returns an `Err`.
    fn broadcast<Dst: ConstShape, Ax: Axes>(self) -> Self::WithShape<Dst>
    where
        Self::Shape: BroadcastShapeTo<Dst, Ax>,
    {
        let target: Dst = Default::default();
        self.try_broadcast_like::<Dst, Ax>(&target).unwrap()
    }

    /// Fallible version of [`BroadcastTo::broadcast`].
    ///
    /// # Errors
    /// Returns whatever error [`BroadcastTo::try_broadcast_like`] reports.
    fn try_broadcast<Dst: ConstShape, Ax: Axes>(self) -> Result<Self::WithShape<Dst>, Self::Err>
    where
        Self::Shape: BroadcastShapeTo<Dst, Ax>,
    {
        let target: Dst = Default::default();
        self.try_broadcast_like::<Dst, Ax>(&target)
    }

    /// Broadcasts `self` into the shape given by the value `dst` along axes `Ax`.
    ///
    /// Unlike [`BroadcastTo::broadcast`], `Dst` only has to be a [`Shape`], and the
    /// shape value is supplied by the caller instead of being default-constructed.
    ///
    /// # Panics
    /// Panics if [`BroadcastTo::try_broadcast_like`] returns an `Err`.
    fn broadcast_like<Dst: Shape, Ax: Axes>(self, dst: &Dst) -> Self::WithShape<Dst>
    where
        Self::Shape: BroadcastShapeTo<Dst, Ax>,
    {
        let outcome = self.try_broadcast_like::<Dst, Ax>(dst);
        outcome.unwrap()
    }

    /// Fallible version of [`BroadcastTo::broadcast_like`], and the single
    /// required method of this trait.
    fn try_broadcast_like<Dst: Shape, Ax: Axes>(
        self,
        dst: &Dst,
    ) -> Result<Self::WithShape<Dst>, Self::Err>
    where
        Self::Shape: BroadcastShapeTo<Dst, Ax>;
}
impl<S: Shape, E: Unit, D: DeviceStorage, T: Tape<E, D>> BroadcastTo for Tensor<S, E, D, T> {
    /// Produces a tensor of shape `dst` from `self`.
    ///
    /// The `id`, `data`, `device` and `tape` fields are moved into the result
    /// unchanged; only `shape` (replaced by `*dst`) and `strides` (recomputed via
    /// `broadcast_strides`) differ from the input.
    ///
    /// As written this method always returns `Ok`.
    fn try_broadcast_like<Dst: Shape, Ax: Axes>(
        self,
        dst: &Dst,
    ) -> Result<Self::WithShape<Dst>, Self::Err>
    where
        Self::Shape: BroadcastShapeTo<Dst, Ax>,
    {
        // NOTE(review): anything `check` returns is discarded, so it presumably
        // reports an incompatible `dst` by panicking - confirm against its definition.
        self.shape().check(dst);

        let broadcasted: Self::WithShape<Dst> = Tensor {
            shape: *dst,
            strides: self.shape.broadcast_strides(self.strides),
            id: self.id,
            data: self.data,
            device: self.device,
            tape: self.tape,
        };
        Ok(broadcasted)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_ops::*;
    use crate::tests::*;

    /// Broadcasting a runtime-sized `(5,)` tensor to `(Const<3>, 7)` must panic.
    #[test]
    #[should_panic]
    fn test_broadcast_incorrect_dims() {
        let dev: TestDevice = Default::default();
        let a: Tensor<(usize,), TestDtype, _> = dev.zeros_like(&(5,));
        let _: Tensor<(Const<3>, usize), TestDtype, _> = a.broadcast_like(&(Const, 7));
    }

    /// Compile-time check: every way of adding exactly one axis is accepted.
    #[test]
    fn test_valid_1d_broadcasts() {
        let dev: TestDevice = Default::default();

        // Rank0 -> Rank1
        let _: Tensor<Rank1<5>, TestDtype, _> = dev.zeros::<Rank0>().broadcast();

        // Rank1 -> Rank2
        let _: Tensor<Rank2<5, 3>, TestDtype, _> = dev.zeros::<Rank1<3>>().broadcast();
        let _: Tensor<Rank2<5, 3>, TestDtype, _> = dev.zeros::<Rank1<5>>().broadcast();

        // Rank2 -> Rank3 (one case per added axis)
        let _: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.zeros::<Rank2<5, 7>>().broadcast();
        let _: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.zeros::<Rank2<3, 7>>().broadcast();
        let _: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.zeros::<Rank2<3, 5>>().broadcast();

        // Rank3 -> Rank4 (one case per added axis)
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank3<5, 7, 9>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank3<3, 7, 9>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank3<3, 5, 9>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank3<3, 5, 7>>().broadcast();
    }

    /// Compile-time check: every way of adding exactly two axes is accepted.
    #[test]
    fn test_valid_2d_broadcasts() {
        let dev: TestDevice = Default::default();

        // Rank0 -> Rank2
        let _: Tensor<Rank2<5, 3>, TestDtype, _> = dev.zeros::<Rank0>().broadcast();

        // Rank1 -> Rank3
        let _: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.zeros::<Rank1<3>>().broadcast();
        let _: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.zeros::<Rank1<5>>().broadcast();
        let _: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.zeros::<Rank1<7>>().broadcast();

        // Rank2 -> Rank4
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank2<3, 5>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank2<3, 7>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank2<3, 9>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank2<5, 7>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank2<5, 9>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank2<7, 9>>().broadcast();
    }

    /// Compile-time check: every way of adding exactly three axes is accepted.
    #[test]
    fn test_valid_3d_broadcasts() {
        let dev: TestDevice = Default::default();

        // Rank0 -> Rank3
        let _: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.zeros::<Rank0>().broadcast();

        // Rank1 -> Rank4
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank1<3>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank1<5>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank1<7>>().broadcast();
        let _: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.zeros::<Rank1<9>>().broadcast();
    }

    /// Checks forward values and gradients of `mean(exp(broadcast(a) * b))`.
    #[test]
    fn test_broadcast_backwards() {
        let dev: TestDevice = Default::default();
        let a: Tensor<Rank1<3>, TestDtype, _> = dev.sample_normal();
        let b: Tensor<Rank2<5, 3>, TestDtype, _> = dev.sample_normal();

        // Forward: each of the 5 rows of the broadcast result equals `a`.
        let a_up = a.leaky_trace().broadcast::<Rank2<5, 3>, _>();
        a_up.array().assert_close(&[a.array(); 5], 1e-4);

        let r = a_up * b.clone();
        let g = r.exp().mean().backward();

        // Expected gradients; the mean is over 5 * 3 = 15 elements.
        //   d/da = sum over the broadcast axis of b * exp(a * b) / 15
        //   d/db = a * exp(a * b) / 15
        let a_up = a.clone().broadcast::<Rank2<5, 3>, _>();
        let a_grad = (b.clone() * (b.clone() * a_up.clone()).exp()).sum::<Rank1<3>, _>() / 15.0;
        let b_grad = (a_up.clone() * (b.clone() * a_up).exp()) / 15.0;
        g.get(&a).array().assert_close(&a_grad.array(), 1e-4);
        g.get(&b).array().assert_close(&b_grad.array(), 1e-4);
    }

    /// Gradients flowing back through a broadcast are summed over the new axis.
    #[test]
    fn test_broadcast_summed() {
        let dev: TestDevice = Default::default();
        let a: Tensor<Rank1<3>, TestDtype, _> = dev.sample_normal();
        let g = a
            .leaky_trace()
            .broadcast::<Rank2<4, 3>, _>()
            .exp()
            .mean()
            .backward();
        // Each element of `a` appears 4 times among 12 averaged values:
        // 4 * exp(a) / 12 = exp(a) / 3.
        assert_close(&g.get(&a).array(), &a.array().map(|x| x.exp() / 3.0));
    }
}