1use super::{matmatadd, matmatdiv, matmatmul, matmatsub, Matrix};
2
/// How one operand of an element-wise binary matrix operation has to be
/// expanded so that both operands end up with the same shape.
///
/// `calc_broadcast_shape` returns one `Broadcast` per operand, positionally
/// matching its arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Broadcast {
    /// The operand is an `r x 1` column that is repeated horizontally this
    /// many times (the number of columns of the result).
    Hstack(usize),
    /// The operand is a `1 x c` row that is repeated vertically this many
    /// times (the number of rows of the result).
    Vstack(usize),
    /// The operand is a `1 x 1` matrix and is applied as a scalar.
    IsScalar,
    /// The operand is used as-is.
    None,
    /// The two operand shapes cannot be broadcast together.
    Invalid,
}
11
12pub(crate) fn calc_broadcast_shape(m1: &Matrix, m2: &Matrix) -> [Broadcast; 2] {
13 if m1.shape() == m2.shape() {
14 [Broadcast::None, Broadcast::None]
15 } else if m1.shape().contains(&1) {
16 if m1.nrows == 1 {
17 assert!(
18 m1.ncols == m2.ncols || m2.ncols == 1 || m1.ncols == 1 );
22 if m1.ncols == m2.ncols {
23 [Broadcast::Vstack(m2.nrows), Broadcast::None]
24 } else if m2.ncols == 1 {
25 [Broadcast::Vstack(m2.nrows), Broadcast::Hstack(m1.ncols)]
26 } else if m1.ncols == 1 {
27 [Broadcast::IsScalar, Broadcast::None]
28 } else {
29 [Broadcast::Invalid, Broadcast::Invalid]
30 }
31 } else {
32 assert!(
34 m1.nrows == m2.nrows || m2.nrows == 1 || m1.nrows == 1 );
38 if m1.nrows == m2.nrows {
39 [Broadcast::Hstack(m2.ncols), Broadcast::None]
40 } else if m2.nrows == 1 {
41 [Broadcast::Hstack(m2.ncols), Broadcast::Vstack(m1.nrows)]
42 } else if m1.nrows == 1 {
43 [Broadcast::IsScalar, Broadcast::None]
44 } else {
45 [Broadcast::Invalid, Broadcast::Invalid]
46 }
47 }
48 } else if m2.shape().contains(&1) {
49 let [b1, b2] = calc_broadcast_shape(m2, m1);
50 [b2, b1]
51 } else {
52 [Broadcast::Invalid, Broadcast::Invalid]
53 }
54}
55
/// Generates `pub(crate) fn $fnname(m1: &Matrix, m2: &Matrix) -> Matrix`,
/// which computes `m1 $op m2` element-wise after broadcasting the operands
/// according to `calc_broadcast_shape`.
///
/// * `$op` — the arithmetic operator token (`+`, `-`, `*`, `/`);
/// * `$fnname` — name of the generated function;
/// * `$matmatfn` — same-shape element-wise kernel used when no broadcasting
///   is needed.
///
/// `m1` is always the left-hand operand, so non-commutative operators keep
/// their order. The generated function panics with a message starting with
/// `"invalid broadcast shape"` (followed by both shapes) when the operands
/// cannot be broadcast together.
macro_rules! broadcast_op {
    ($op:tt, $fnname:ident, $matmatfn:ident) => {
        pub(crate) fn $fnname(m1: &Matrix, m2: &Matrix) -> Matrix {
            match calc_broadcast_shape(m1, m2) {
                // Identical shapes: no broadcasting needed.
                [Broadcast::None, Broadcast::None] => {
                    assert_eq!(m1.shape(), m2.shape());
                    $matmatfn(m1, m2)
                }
                // m1 is an r x 1 column applied across every column of m2.
                [Broadcast::Hstack(hstack), Broadcast::None] => {
                    assert_eq!(hstack, m2.ncols);
                    let mut new = m2.clone();
                    for i in 0..m1.nrows {
                        new.apply_along_row(i, |x| m1[i][0] $op x);
                    }
                    new
                }
                // m1 is a 1 x c row applied down every row of m2.
                [Broadcast::Vstack(vstack), Broadcast::None] => {
                    assert_eq!(vstack, m2.nrows);
                    let mut new = m2.clone();
                    for i in 0..new.nrows {
                        new[i].iter_mut().zip(&m1[0]).for_each(|(x, y)| *x = y $op *x);
                    }
                    new
                }
                // m2 is an r x 1 column applied across every column of m1.
                [Broadcast::None, Broadcast::Hstack(hstack)] => {
                    assert_eq!(hstack, m1.ncols);
                    let mut new = m1.clone();
                    for i in 0..m2.nrows {
                        new.apply_along_row(i, |x| x $op m2[i][0]);
                    }
                    new
                }
                // m2 is a 1 x c row applied down every row of m1.
                [Broadcast::None, Broadcast::Vstack(vstack)] => {
                    assert_eq!(vstack, m1.nrows);
                    let mut new = m1.clone();
                    for i in 0..new.nrows {
                        new[i].iter_mut().zip(&m2[0]).for_each(|(x, y)| *x = *x $op y);
                    }
                    new
                }
                // m1 is an r x 1 column and m2 a 1 x c row: the result is r x c.
                [Broadcast::Hstack(hstack), Broadcast::Vstack(vstack)] => {
                    assert_eq!(m2.ncols, hstack);
                    assert_eq!(m1.nrows, vstack);
                    assert_eq!(m2.nrows, 1);
                    assert_eq!(m1.ncols, 1);
                    let mut new = Matrix::zeros(m1.nrows, m2.ncols);
                    for i in 0..new.nrows {
                        for j in 0..new.ncols {
                            new[i][j] = m1[i][0] $op m2[0][j];
                        }
                    }
                    new
                }
                // m1 is a 1 x c row and m2 an r x 1 column: the result is r x c.
                [Broadcast::Vstack(vstack), Broadcast::Hstack(hstack)] => {
                    assert_eq!(m1.ncols, hstack);
                    assert_eq!(m2.nrows, vstack);
                    assert_eq!(m1.nrows, 1);
                    assert_eq!(m2.ncols, 1);
                    let mut new = Matrix::zeros(m2.nrows, m1.ncols);
                    for i in 0..new.nrows {
                        for j in 0..new.ncols {
                            new[i][j] = m1[0][j] $op m2[i][0];
                        }
                    }
                    new
                }
                // m1 is 1 x 1: apply its single value as a scalar on the left.
                [Broadcast::IsScalar, _] => {
                    assert!(m1.nrows == 1 && m1.ncols == 1);
                    m1[0][0] $op m2
                }
                // m2 is 1 x 1: apply its single value as a scalar on the right.
                [_, Broadcast::IsScalar] => {
                    assert!(m2.nrows == 1 && m2.ncols == 1);
                    m1 $op m2[0][0]
                }
                // Include both shapes so the failure is diagnosable; the
                // message keeps its original prefix.
                _ => panic!(
                    "invalid broadcast shape: {}x{} and {}x{}",
                    m1.nrows, m1.ncols, m2.nrows, m2.ncols
                ),
            }
        }
    };
}
171
172broadcast_op!(+, broadcast_add, matmatadd);
173broadcast_op!(-, broadcast_sub, matmatsub);
174broadcast_op!(*, broadcast_mul, matmatmul);
175broadcast_op!(/, broadcast_div, matmatdiv);
176
#[cfg(test)]
mod tests {
    use super::super::super::arange;
    use super::super::Vector;
    use super::*;

    /// Row vector vs. square matrix, column vector vs. square matrix, and
    /// row vector vs. column vector (outer-product shape).
    #[test]
    fn test_broadcast_1() {
        let mut a = Matrix::new([8., 9., 2., 5., 4., 9., 1., 6., 3.], 3, 3);
        let mut b = Vector::new([1., 2., 3.]).to_matrix();

        let c = broadcast_add(&a, &b);
        assert_eq!(c, Matrix::new([9., 11., 5., 6., 6., 12., 2., 8., 6.], 3, 3));

        let d = broadcast_mul(&a, &b);
        assert_eq!(
            d,
            Matrix::new([8., 18., 6., 5., 8., 27., 1., 12., 9.], 3, 3)
        );

        // Transpose `b` so it broadcasts across rows instead of columns.
        b.t_mut();
        let e = broadcast_sub(&b, &a);
        assert_eq!(
            e,
            Matrix::new([-7., -8., -1., -3., -2., -7., 2., -3., 0.], 3, 3)
        );

        // Flatten `a` to a single row; dividing by the 3 x 1 column `b`
        // yields a 3 x 9 result.
        a.reshape_mut(1, -1);
        let f = broadcast_div(&a, &b);
        assert_eq!(
            f,
            Matrix::new(
                vec![
                    8., 9., 2., 5., 4., 9., 1., 6., 3.,
                    4., 4.5, 1., 2.5, 2., 4.5, 0.5, 3., 1.5,
                    2. + 2. / 3., 3., 2. / 3., 1. + 2. / 3., 1. + 1. / 3., 3., 1. / 3., 2., 1.,
                ],
                3,
                9
            )
        );
    }

    /// Column vectors and their transposes against a 4 x 4 matrix, on both
    /// sides of non-commutative operators.
    #[test]
    fn test_broadcast_2() {
        let a = Matrix::new(
            [
                -0.699, -1.031, 1.235, 0.328,
                0.026, 0.046, 1.501, 0.438,
                1.304, 0.728, 1., -0.417,
                -0.265, 0.091, 0.422, 0.602,
            ],
            4,
            4,
        );
        let b = Matrix::new([0.896, 0.488, 0.577, 0.316], 4, 1);

        let c = broadcast_sub(&a, &b);
        assert_eq!(
            c,
            Matrix::new(
                [
                    -1.595, -1.927, 0.339, -0.568,
                    -0.462, -0.442, 1.013, -0.05,
                    0.727, 0.151, 0.423, -0.994,
                    -0.581, -0.225, 0.106, 0.286
                ],
                4,
                4
            )
        );

        let d = broadcast_sub(&a, &b.t());
        assert_eq!(
            d,
            Matrix::new(
                [
                    -1.595, -1.519, 0.658, 0.012,
                    -0.87, -0.442, 0.924, 0.122,
                    0.408, 0.24, 0.423, -0.733,
                    -1.161, -0.397, -0.155, 0.286
                ],
                4,
                4
            )
        );

        let e = broadcast_div(&b, &a);
        assert_eq!(
            e,
            Matrix::new(
                [
                    -1.2818311874105868,
                    -0.86905916585839,
                    0.7255060728744939,
                    2.7317073170731705,
                    18.76923076923077,
                    10.608695652173912,
                    0.3251165889407062,
                    1.1141552511415524,
                    0.44248466257668706,
                    0.7925824175824175,
                    0.577,
                    -1.3836930455635492,
                    -1.1924528301886792,
                    3.4725274725274726,
                    0.7488151658767773,
                    0.5249169435215947
                ],
                4,
                4
            )
        );

        let f = broadcast_div(&a.t(), &b);
        assert_eq!(
            f,
            Matrix::new(
                [
                    -0.7801339285714285,
                    0.02901785714285714,
                    1.4553571428571428,
                    -0.2957589285714286,
                    -2.1127049180327866,
                    0.0942622950819672,
                    1.4918032786885247,
                    0.1864754098360656,
                    2.1403812824956674,
                    2.601386481802426,
                    1.733102253032929,
                    0.7313691507798961,
                    1.0379746835443038,
                    1.3860759493670887,
                    -1.3196202531645569,
                    1.9050632911392404
                ],
                4,
                4
            )
        );
    }

    /// Equal shapes take the plain element-wise path.
    #[test]
    fn test_broadcast_3() {
        let a = Matrix::new([1., 2., 3., 4.], 2, 2);
        let b = Matrix::new([3., 4., 1., 1.], 2, 2);

        let c = broadcast_add(&a, &b);
        assert_eq!(c, Matrix::new([4., 6., 4., 5.], 2, 2));

        let d = broadcast_sub(&a, &b.t());
        assert_eq!(d, Matrix::new([-2., 1., -1., 3.], 2, 2));

        let e = broadcast_div(&a.t(), &b);
        assert_eq!(e, Matrix::new([1. / 3., 0.75, 2., 4.], 2, 2));
    }

    /// Row vs. column (outer shape) and same-shape vectors.
    #[test]
    fn test_broadcast_4() {
        let a = arange(0., 4., 1.).to_matrix().reshape(1, 4);
        let b = arange(0., 4., 1.).to_matrix().reshape(4, 1);

        let c = broadcast_sub(&a, &b);
        assert_eq!(
            c,
            Matrix::new(
                [0., 1., 2., 3., -1., 0., 1., 2., -2., -1., 0., 1., -3., -2., -1., 0.],
                4,
                4
            )
        );

        let d = broadcast_mul(&a, &b.t());
        assert_eq!(d, Matrix::new([0., 1., 4., 9.], 1, 4));

        let e = broadcast_add(&b, &a.t());
        assert_eq!(e, Matrix::new([0., 2., 4., 6.], 4, 1));
    }

    /// A 1 x 1 operand on either side is applied as a scalar.
    #[test]
    fn test_broadcast_scalar() {
        let a = Matrix::new([1., 2., 3., 4.], 2, 2);
        let s = Matrix::new([2.], 1, 1);

        assert_eq!(broadcast_mul(&s, &a), Matrix::new([2., 4., 6., 8.], 2, 2));
        assert_eq!(broadcast_sub(&a, &s), Matrix::new([-1., 0., 1., 2.], 2, 2));
    }

    /// Shapes with no broadcastable dimension are rejected.
    #[test]
    #[should_panic(expected = "invalid broadcast shape")]
    fn test_broadcast_invalid() {
        let a = Matrix::new([1., 2., 3., 4., 5., 6.], 2, 3);
        let b = Matrix::new([1., 2., 3., 4., 5., 6.], 3, 2);
        let _ = broadcast_add(&a, &b);
    }
}