// compute/linalg/array/broadcast.rs

1use super::{matmatadd, matmatdiv, matmatmul, matmatsub, Matrix};
2
/// How one operand of an elementwise matrix operation has to be broadcast.
///
/// `calc_broadcast_shape` produces these in pairs, one per operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Broadcast {
    /// The operand is a `k x 1` column vector whose column is repeated side by side
    /// this many times.
    Hstack(usize),
    /// The operand is a `1 x k` row vector whose row is repeated top to bottom
    /// this many times.
    Vstack(usize),
    /// The operand is a `1 x 1` matrix and is applied as a scalar.
    IsScalar,
    /// The operand is used as is; no broadcasting is needed.
    None,
    /// The two shapes cannot be broadcast together.
    Invalid,
}
11
12pub(crate) fn calc_broadcast_shape(m1: &Matrix, m2: &Matrix) -> [Broadcast; 2] {
13    if m1.shape() == m2.shape() {
14        [Broadcast::None, Broadcast::None]
15    } else if m1.shape().contains(&1) {
16        if m1.nrows == 1 {
17            assert!(
18                m1.ncols == m2.ncols // single vstack broadcast for m1
19                || m2.ncols == 1 // vstack broadcast m1 and hstack broadcast m2
20                || m1.ncols == 1 // m1 is just a single element matrix, vstack and hstack it.
21            );
22            if m1.ncols == m2.ncols {
23                [Broadcast::Vstack(m2.nrows), Broadcast::None]
24            } else if m2.ncols == 1 {
25                [Broadcast::Vstack(m2.nrows), Broadcast::Hstack(m1.ncols)]
26            } else if m1.ncols == 1 {
27                [Broadcast::IsScalar, Broadcast::None]
28            } else {
29                [Broadcast::Invalid, Broadcast::Invalid]
30            }
31        } else {
32            // m1.ncols == 1
33            assert!(
34                m1.nrows == m2.nrows // single hstack broadcast for m1
35                || m2.nrows == 1 // hstack broadcast m1 and vstack broadcast m2
36                || m1.nrows == 1 // m1 is just a single element matrix
37            );
38            if m1.nrows == m2.nrows {
39                [Broadcast::Hstack(m2.ncols), Broadcast::None]
40            } else if m2.nrows == 1 {
41                [Broadcast::Hstack(m2.ncols), Broadcast::Vstack(m1.nrows)]
42            } else if m1.nrows == 1 {
43                [Broadcast::IsScalar, Broadcast::None]
44            } else {
45                [Broadcast::Invalid, Broadcast::Invalid]
46            }
47        }
48    } else if m2.shape().contains(&1) {
49        let [b1, b2] = calc_broadcast_shape(m2, m1);
50        [b2, b1]
51    } else {
52        [Broadcast::Invalid, Broadcast::Invalid]
53    }
54}
55
/// Generates a `pub(crate)` elementwise binary operation on `Matrix` that supports
/// broadcasting.
///
/// * `$op` is the operator token applied to each pair of elements (`+`, `-`, `*`, `/`).
/// * `$fnname` is the name of the generated function.
/// * `$matmatfn` is the same-shape elementwise routine used when no broadcasting is needed.
macro_rules! broadcast_op {
    ($op: tt, $fnname: ident, $matmatfn: ident) => {
        /// Combines `m1` and `m2` elementwise, broadcasting row vectors, column vectors and
        /// `1 x 1` matrices as decided by `calc_broadcast_shape`.
        ///
        /// Every arm keeps `m1` on the left-hand side of the operator, which matters for
        /// the non-commutative `-` and `/`.
        ///
        /// # Panics
        ///
        /// Panics if the shapes of `m1` and `m2` cannot be broadcast together.
        pub(crate) fn $fnname(m1: &Matrix, m2: &Matrix) -> Matrix {
            match calc_broadcast_shape(m1, m2) {
                [Broadcast::None, Broadcast::None] => {
                    assert_eq!(m1.shape(), m2.shape());
                    // Shapes already agree: use the plain elementwise routine.
                    $matmatfn(m1, m2)
                }
                [Broadcast::Hstack(hstack), Broadcast::None] => {
                    assert_eq!(hstack, m2.ncols);
                    let mut new = m2.clone();
                    // m1 is a column vector: m1[i][0] is combined with every element of
                    // row i of m2 (assumes `apply_along_row` maps each element of that row).
                    // .    ....
                    // .    ....
                    // .    ....
                    // .    ....
                    for i in 0..m1.nrows {
                        new.apply_along_row(i, |x| m1[i][0] $op x)
                    }
                    new
                }
                [Broadcast::Vstack(vstack), Broadcast::None] => {
                    assert_eq!(vstack, m2.nrows);
                    let mut new = m2.clone();
                    // m1 is a row vector: m1[0][j] is combined with every element of
                    // column j of m2.
                    // ....     ....
                    //          ....
                    //          ....
                    //          ....
                    for i in 0..new.nrows {
                        new[i].iter_mut().zip(&m1[0]).for_each(|(x, y)| *x = y $op *x);
                    }
                    new
                }
                [Broadcast::None, Broadcast::Hstack(hstack)] => {
                    assert_eq!(hstack, m1.ncols);
                    let mut new = m1.clone();
                    // m2 is a column vector: every element of row i of m1 is combined
                    // with m2[i][0].
                    // ....   .
                    // ....   .
                    // ....   .
                    // ....   .
                    for i in 0..m2.nrows {
                        new.apply_along_row(i, |x| x $op m2[i][0])
                    }
                    new
                }
                [Broadcast::None, Broadcast::Vstack(vstack)] => {
                    assert_eq!(vstack, m1.nrows);
                    let mut new = m1.clone();
                    // m2 is a row vector: every element of column j of m1 is combined
                    // with m2[0][j].
                    // ....   ....
                    // ....
                    // ....
                    // ....
                    for i in 0..new.nrows {
                        new[i].iter_mut().zip(&m2[0]).for_each(|(x, y)| *x = *x $op y);
                    }
                    new
                }
                [Broadcast::Hstack(hstack), Broadcast::Vstack(vstack)] => {
                    assert_eq!(m2.ncols, hstack);
                    assert_eq!(m1.nrows, vstack);
                    assert_eq!(m2.nrows, 1);
                    assert_eq!(m1.ncols, 1);
                    // Column vector m1 (k x 1) against row vector m2 (1 x c): the result is
                    // k x c with new[i][j] = m1[i][0] op m2[0][j].
                    // .  .  .  .
                    // .
                    // .
                    // .
                    let mut new = Matrix::zeros(m1.nrows, m2.ncols);
                    for i in 0..new.nrows {
                        for j in 0..new.ncols {
                            new[i][j] = m1[i][0] $op m2[0][j]
                        }
                    }
                    new
                }
                [Broadcast::Vstack(vstack), Broadcast::Hstack(hstack)] => {
                    assert_eq!(m1.ncols, hstack);
                    assert_eq!(m2.nrows, vstack);
                    assert_eq!(m1.nrows, 1);
                    assert_eq!(m2.ncols, 1);
                    // Row vector m1 (1 x c) against column vector m2 (r x 1): the result is
                    // r x c with new[i][j] = m1[0][j] op m2[i][0].
                    // .  .  .  .
                    //          .
                    //          .
                    //          .
                    let mut new = Matrix::zeros(m2.nrows, m1.ncols);
                    for i in 0..new.nrows {
                        for j in 0..new.ncols {
                            new[i][j] = m1[0][j] $op m2[i][0]
                        }
                    }
                    new
                }
                [Broadcast::IsScalar, _] => {
                    // calc_broadcast_shape pairs IsScalar with Broadcast::None.
                    assert!(m1.nrows == 1 && m1.ncols == 1);
                    m1[0][0] $op m2
                }
                [_, Broadcast::IsScalar] => {
                    // calc_broadcast_shape pairs IsScalar with Broadcast::None.
                    assert!(m2.nrows == 1 && m2.ncols == 1);
                    m1 $op m2[0][0]
                }
                _ => {
                    // Invalid, or [Hstack, Hstack] / [Vstack, Vstack], which
                    // calc_broadcast_shape never produces. The message keeps its original
                    // prefix so existing `should_panic(expected = ...)` checks still match.
                    panic!(
                        "invalid broadcast shape: cannot broadcast {}x{} with {}x{}",
                        m1.nrows, m1.ncols, m2.nrows, m2.ncols
                    )
                }
            }
        }
    };
}
171
172broadcast_op!(+, broadcast_add, matmatadd);
173broadcast_op!(-, broadcast_sub, matmatsub);
174broadcast_op!(*, broadcast_mul, matmatmul);
175broadcast_op!(/, broadcast_div, matmatdiv);
176
#[cfg(test)]
mod tests {
    use super::super::super::arange;
    use super::super::Vector;
    use super::*;

    #[test]
    fn test_broadcast_1() {
        let mut a = Matrix::new([8., 9., 2., 5., 4., 9., 1., 6., 3.], 3, 3);
        let mut b = Vector::new([1., 2., 3.]).to_matrix(); // 1x3 matrix
        let c = broadcast_add(&a, &b);
        assert_eq!(c, Matrix::new([9., 11., 5., 6., 6., 12., 2., 8., 6.], 3, 3));
        let d = broadcast_mul(&a, &b);
        assert_eq!(
            d,
            Matrix::new([8., 18., 6., 5., 8., 27., 1., 12., 9.], 3, 3)
        );
        b.t_mut(); // 3x1
        let e = broadcast_sub(&b, &a);
        assert_eq!(
            e,
            Matrix::new([-7., -8., -1., -3., -2., -7., 2., -3., 0.], 3, 3)
        );
        a.reshape_mut(1, -1); // flatten to a 1x9 row vector
        // A 1x9 row divided by a 3x1 column broadcasts to 3x9: row i is `a / b[i]`.
        let f = broadcast_div(&a, &b);
        assert_eq!(
            f,
            Matrix::new(
                vec![
                    8.,
                    9.,
                    2.,
                    5.,
                    4.,
                    9.,
                    1.,
                    6.,
                    3.,
                    4.,
                    4.5,
                    1.,
                    2.5,
                    2.,
                    4.5,
                    0.5,
                    3.,
                    1.5,
                    2. + 2. / 3.,
                    3.,
                    2. / 3.,
                    1. + 2. / 3.,
                    1. + 1. / 3.,
                    3.,
                    1. / 3.,
                    2.,
                    1.
                ],
                3,
                9
            )
        );
    }

    #[test]
    fn test_broadcast_2() {
        let a = Matrix::new(
            [
                -0.699, -1.031, 1.235, 0.328, 0.026, 0.046, 1.501, 0.438, 1.304, 0.728, 1., -0.417,
                -0.265, 0.091, 0.422, 0.602,
            ],
            4,
            4,
        );
        let b = Matrix::new([0.896, 0.488, 0.577, 0.316], 4, 1);
        let c = broadcast_sub(&a, &b);
        assert_eq!(
            c,
            Matrix::new(
                [
                    -1.595, -1.927, 0.339, -0.568, -0.462, -0.442, 1.013, -0.05, 0.727, 0.151,
                    0.423, -0.994, -0.581, -0.225, 0.106, 0.286
                ],
                4,
                4
            )
        );
        let d = broadcast_sub(&a, &b.t());
        assert_eq!(
            d,
            Matrix::new(
                [
                    -1.595, -1.519, 0.658, 0.012, -0.87, -0.442, 0.924, 0.122, 0.408, 0.24, 0.423,
                    -0.733, -1.161, -0.397, -0.155, 0.286
                ],
                4,
                4
            )
        );
        let e = broadcast_div(&b, &a);
        assert_eq!(
            e,
            Matrix::new(
                [
                    -1.2818311874105868,
                    -0.86905916585839,
                    0.7255060728744939,
                    2.7317073170731705,
                    18.76923076923077,
                    10.608695652173912,
                    0.3251165889407062,
                    1.1141552511415524,
                    0.44248466257668706,
                    0.7925824175824175,
                    0.577,
                    -1.3836930455635492,
                    -1.1924528301886792,
                    3.4725274725274726,
                    0.7488151658767773,
                    0.5249169435215947
                ],
                4,
                4
            )
        );
        let f = broadcast_div(&a.t(), &b);
        assert_eq!(
            f,
            Matrix::new(
                [
                    -0.7801339285714285,
                    0.02901785714285714,
                    1.4553571428571428,
                    -0.2957589285714286,
                    -2.1127049180327866,
                    0.0942622950819672,
                    1.4918032786885247,
                    0.1864754098360656,
                    2.1403812824956674,
                    2.601386481802426,
                    1.733102253032929,
                    0.7313691507798961,
                    1.0379746835443038,
                    1.3860759493670887,
                    -1.3196202531645569,
                    1.9050632911392404
                ],
                4,
                4
            )
        );
    }

    #[test]
    fn test_broadcast_3() {
        let a = Matrix::new([1., 2., 3., 4.], 2, 2);
        let b = Matrix::new([3., 4., 1., 1.], 2, 2);
        let c = broadcast_add(&a, &b);
        assert_eq!(c, Matrix::new([4., 6., 4., 5.], 2, 2));
        let d = broadcast_sub(&a, &b.t());
        assert_eq!(d, Matrix::new([-2., 1., -1., 3.], 2, 2));
        let e = broadcast_div(&a.t(), &b);
        assert_eq!(e, Matrix::new([1. / 3., 0.75, 2., 4.], 2, 2));
    }

    #[test]
    fn test_broadcast_4() {
        let a = arange(0., 4., 1.).to_matrix().reshape(1, 4);
        let b = arange(0., 4., 1.).to_matrix().reshape(4, 1);
        let c = broadcast_sub(&a, &b);
        assert_eq!(
            c,
            Matrix::new(
                [0., 1., 2., 3., -1., 0., 1., 2., -2., -1., 0., 1., -3., -2., -1., 0.],
                4,
                4
            )
        );
        let d = broadcast_mul(&a, &b.t());
        assert_eq!(d, Matrix::new([0., 1., 4., 9.], 1, 4));
        let e = broadcast_add(&b, &a.t());
        assert_eq!(e, Matrix::new([0., 2., 4., 6.], 4, 1));
    }

    /// A 1x1 matrix on either side of the operator is applied as a scalar.
    #[test]
    fn test_broadcast_scalar() {
        let a = Matrix::new([1., 2., 3., 4.], 2, 2);
        let ten = Matrix::new([10.], 1, 1);
        assert_eq!(
            broadcast_add(&ten, &a),
            Matrix::new([11., 12., 13., 14.], 2, 2)
        );
        let two = Matrix::new([2.], 1, 1);
        assert_eq!(
            broadcast_mul(&a, &two),
            Matrix::new([2., 4., 6., 8.], 2, 2)
        );
    }

    /// Two full matrices with different shapes cannot be broadcast together.
    #[test]
    #[should_panic(expected = "invalid broadcast shape")]
    fn test_broadcast_incompatible() {
        let a = Matrix::new([1., 2., 3., 4., 5., 6.], 2, 3);
        let b = Matrix::new([1., 2., 3., 4., 5., 6.], 3, 2);
        let _ = broadcast_add(&a, &b);
    }
}