//! Matrix multiplication and cross-product operations for the ndarray backend
//! (`burn_ndarray/ops/matmul.rs`).

use crate::UnsafeSharedRef;
use crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par};

use alloc::{vec, vec::Vec};
use burn_backend::ElementConversion;
use burn_backend::Shape;
use ndarray::{IxDyn, s};
8
/// Batched matrix multiplication with broadcasting over the leading (batch)
/// dimensions.
///
/// The last two dimensions of each input are the matrix dimensions; all
/// leading dimensions are batch dimensions and are broadcast against each
/// other (see `output_shape` for the exact rules).
///
/// # Panics
/// Propagates the panics of `output_shape` (rank < 2, incompatible matrix
/// dimensions, non-broadcastable batch dimensions).
pub(crate) fn matmul<E: NdArrayElement>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
) -> SharedArray<E> {
    let shape_lhs = lhs.shape();
    let shape_rhs = rhs.shape();
    // NOTE(review): `ndims` is taken from `lhs` but also used to index
    // `shape_rhs`; this assumes both inputs have the same rank — confirm
    // callers guarantee it.
    let ndims = shape_lhs.num_dims();
    let m = shape_lhs[ndims - 2]; // # of left rows
    let k = shape_rhs[ndims - 2]; // # of left cols and right rows
    let n = shape_rhs[ndims - 1]; // # of right cols

    let (out_shape, strides_lhs, strides_rhs, strides_out) = output_shape(shape_lhs, shape_rhs);
    let l_mat_size = m * k; // size of matrix component of left array
    let r_mat_size = k * n; // size of matrix component of right array
    let out_mat_size = m * n; // size of matrix component of output array

    let num_l_batches = shape_lhs.num_elements() / l_mat_size;
    let num_r_batches = shape_rhs.num_elements() / r_mat_size;
    let num_out_batches = out_shape.num_elements() / out_mat_size;

    // Flatten the batch dimensions so each input is a rank-3 stack of matrices.
    let lhs_array = NdArrayOps::reshape(lhs, Shape::new([num_l_batches, m, k]));
    let rhs_array = NdArrayOps::reshape(rhs, Shape::new([num_r_batches, k, n]));

    // gemm coefficients: out = alpha * (lhs @ rhs) + beta * out.
    let alpha: E = 1.0.elem();
    let beta: E = 0.0.elem();

    let out = run_par!(|| {
        let mut out_array = ndarray::Array3::<E>::zeros((num_out_batches, m, n));
        let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array);

        iter_range_par!(0, num_out_batches).for_each(|out_batch| {
            // Here, we:
            //   1. Un-flatten the output batch into a component-based batch index.
            //   2. Use the strides for left and right batch indices to convert it to a flattened
            //      batch for left and right. (Broadcast dimensions carry a stride of 0, so every
            //      output batch maps back to the single source batch.)
            let out_index = strides_out.unflatten(out_batch);
            let l_batch = strides_lhs.flatten(&out_index);
            let r_batch = strides_rhs.flatten(&out_index);

            let lhs_slice = lhs_array.slice(s!(l_batch, .., ..));
            let rhs_slice = rhs_array.slice(s!(r_batch, .., ..));

            // SAFETY: each iteration writes only to its own `out_batch` slice of the
            // shared output array, so the parallel mutable accesses are disjoint.
            unsafe {
                let mut out_slice = unsafe_shared_out_array
                    .get()
                    .slice_mut(s!(out_batch, .., ..));

                ndarray::linalg::general_mat_mul(
                    alpha,
                    &lhs_slice,
                    &rhs_slice,
                    beta,
                    &mut out_slice,
                )
            }
        });

        out_array.into_shared().into_dyn()
    });

    // Restore the broadcasted batch dimensions on the output.
    NdArrayOps::reshape(out, out_shape)
}
71
/// Row-major strides over the batch dimensions, used to map between a linear
/// batch index and its per-dimension coordinates.
#[derive(Debug, PartialEq)]
struct Strides {
    strides: Vec<usize>,
}
impl Strides {
    /// Wraps a stride vector (outermost dimension first).
    fn new(strides: Vec<usize>) -> Self {
        Strides { strides }
    }

    /// Converts a linear index into per-dimension coordinates.
    ///
    /// Strides must be contiguous row-major strides (non-zero, decreasing);
    /// this is only called on output strides, which are never broadcast (0).
    fn unflatten(&self, linear_index: usize) -> Vec<usize> {
        let mut coord = Vec::with_capacity(self.strides.len());
        let mut rem = linear_index;
        for stride in self.strides.iter() {
            coord.push(rem / stride);
            rem %= stride;
        }
        coord
    }

    /// Converts per-dimension coordinates back into a linear index.
    ///
    /// A stride of 0 collapses its dimension, implementing broadcasting:
    /// every coordinate along that dimension maps to the same source batch.
    ///
    /// # Panics
    /// If `index` does not have the same length as the strides.
    fn flatten(&self, index: &[usize]) -> usize {
        assert_eq!(self.strides.len(), index.len());
        self.strides
            .iter()
            .zip(index)
            .map(|(stride, index)| stride * index)
            .sum()
    }
}
100
/// Compute the (broadcasted) output shape of matrix multiplication, along with strides for
/// the non-matrix (batch) dimensions of all arrays.
///
/// Each returned `Strides` has `ndims - 2` entries (outermost first). A broadcast dimension
/// gets a stride of 0 so that every output batch index maps back to the single source batch.
///
/// # Arguments
/// * `lsh`: Shape of the first (left-hand) matrix multiplication argument.
/// * `rsh`: Shape of the second (right-hand) matrix multiplication argument.
///
/// # Panics
/// * If the shapes have fewer than 2 dimensions.
/// * If the matrix multiplication dimensions (last 2) are incompatible.
/// * If any other dimension is not the same for both tensors, or equal to 1. (Any dimension where
///   one dim is equal to 1 is broadcast.)
fn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) {
    // NOTE(review): `ndims` is taken from `lsh` but also used to index `rsh`;
    // this assumes both shapes have the same rank — confirm callers guarantee it.
    let ndims = lsh.num_dims();
    if ndims < 2 {
        panic!("Matrix multiplication requires an array with at least 2 dimensions.");
    }

    // Fetch matrix dimensions and check compatibility.
    let l_rows = lsh[ndims - 2];
    let l_cols = lsh[ndims - 1];
    let r_rows = rsh[ndims - 2];
    let r_cols = rsh[ndims - 1];
    if l_cols != r_rows {
        panic!("Dimensions are incompatible for matrix multiplication.");
    }
    // Set matrix dimensions of the output shape.
    let mut osh = vec![0; ndims];
    osh[ndims - 2] = l_rows;
    osh[ndims - 1] = r_cols;

    // Set other array dimensions, broadcasting as necessary.
    // Compute the strides inline, walking batch dimensions from innermost to
    // outermost and accumulating row-major strides as we go.
    let mut cur_l_stride: usize = 1;
    let mut cur_r_stride: usize = 1;
    let mut cur_o_stride: usize = 1;
    let mut l_strides = Vec::with_capacity(ndims - 2);
    let mut r_strides = Vec::with_capacity(ndims - 2);
    let mut o_strides = Vec::with_capacity(ndims - 2);
    for i in (0..ndims - 2).rev() {
        let l_dim = lsh[i];
        let r_dim = rsh[i];

        // Compatible dimensions are:
        //   1. Both dimensions are equal.
        //   2. One of the dimensions is equal to 1.
        // A broadcast side gets stride 0 so flatten() collapses that coordinate.
        let o_dim: usize;
        if l_dim == r_dim {
            o_dim = l_dim; // both dimensions are equal
            l_strides.push(cur_l_stride);
            r_strides.push(cur_r_stride);
        } else if l_dim == 1 {
            o_dim = r_dim; // broadcast the left
            l_strides.push(0);
            r_strides.push(cur_r_stride);
        } else if r_dim == 1 {
            o_dim = l_dim; // broadcast the right
            l_strides.push(cur_l_stride);
            r_strides.push(0);
        } else {
            panic!("Dimensions differ and cannot be broadcasted.");
        }
        osh[i] = o_dim;
        o_strides.push(cur_o_stride);
        cur_o_stride *= o_dim;

        cur_l_stride *= l_dim;
        cur_r_stride *= r_dim;
    }
    // The loop pushed innermost-first; reverse to get outermost-first order.
    l_strides.reverse();
    r_strides.reverse();
    o_strides.reverse();

    (
        Shape::from(osh),
        Strides::new(l_strides),
        Strides::new(r_strides),
        Strides::new(o_strides),
    )
}
181
/// Computes the 3-element cross product of `lhs` and `rhs` along dimension `dim`.
///
/// All dimensions except `dim` are broadcast against each other; `dim` itself is
/// assumed to have size 3 in both inputs (validation is expected to happen in the
/// caller — see the inline comment below).
pub(crate) fn cross<E: NdArrayElement>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
    dim: usize,
) -> SharedArray<E> {
    let shape_lhs = lhs.shape();
    let shape_rhs = rhs.shape();
    // NOTE(review): assumes both inputs have the same rank — confirm callers
    // guarantee this before `shape_rhs[i]` is indexed below.
    let ndims = shape_lhs.num_dims();

    // Broadcast the shapes except along dim
    let mut broadcast_shape = vec![0; ndims];
    for i in 0..ndims {
        if i == dim {
            broadcast_shape[i] = shape_lhs[i]; // already checked to be 3
        } else {
            let l = shape_lhs[i];
            let r = shape_rhs[i];
            if l == r {
                broadcast_shape[i] = l;
            } else if l == 1 {
                broadcast_shape[i] = r;
            } else if r == 1 {
                broadcast_shape[i] = l;
            } else {
                panic!("Tensors are not broadcastable along dimension {}", i);
            }
        }
    }

    // Broadcast lhs and rhs (skip the expand when the shape already matches).
    let lhs_broadcast = if shape_lhs == broadcast_shape.as_slice() {
        lhs
    } else {
        NdArrayOps::expand(lhs, Shape::from(broadcast_shape.clone()))
    };
    let rhs_broadcast = if shape_rhs == broadcast_shape.as_slice() {
        rhs
    } else {
        NdArrayOps::expand(rhs, Shape::from(broadcast_shape.clone()))
    };

    // Now, move dim to the last dimension
    let mut perm = (0..ndims).collect::<Vec<_>>();
    perm.remove(dim);
    perm.push(dim);

    let lhs_permuted = NdArrayOps::permute(lhs_broadcast, &perm);
    let rhs_permuted = NdArrayOps::permute(rhs_broadcast, &perm);

    // Reshape to (*, 3): every row is one 3-vector to cross.
    let total_elements = lhs_permuted.shape().num_elements();
    let batch_size = total_elements / 3;
    let lhs_reshaped = NdArrayOps::reshape(lhs_permuted, Shape::new([batch_size, 3]));
    let rhs_reshaped = NdArrayOps::reshape(rhs_permuted, Shape::new([batch_size, 3]));

    // Compute cross product: c = a × b, one 3-vector per batch row.
    let mut result = ndarray::ArrayD::<E>::zeros(IxDyn(&[batch_size, 3]));
    for i in 0..batch_size {
        let a1 = lhs_reshaped[IxDyn(&[i, 0])];
        let a2 = lhs_reshaped[IxDyn(&[i, 1])];
        let a3 = lhs_reshaped[IxDyn(&[i, 2])];
        let b1 = rhs_reshaped[IxDyn(&[i, 0])];
        let b2 = rhs_reshaped[IxDyn(&[i, 1])];
        let b3 = rhs_reshaped[IxDyn(&[i, 2])];
        // Standard component formulas: (a2*b3 - a3*b2, a3*b1 - a1*b3, a1*b2 - a2*b1).
        result[IxDyn(&[i, 0])] = a2.mul(b3).sub(a3.mul(b2));
        result[IxDyn(&[i, 1])] = a3.mul(b1).sub(a1.mul(b3));
        result[IxDyn(&[i, 2])] = a1.mul(b2).sub(a2.mul(b1));
    }

    let result_shared = result.into_shared();

    // Reshape back to the broadcast shape with dim at the end
    let mut result_shape = broadcast_shape;
    result_shape.remove(dim);
    result_shape.push(3);
    let result_reshaped = NdArrayOps::reshape(result_shared, Shape::from(result_shape));

    // Permute back: invert `perm` so the result matches the input layout.
    let mut inv_perm = vec![0; ndims];
    for (i, &p) in perm.iter().enumerate() {
        inv_perm[p] = i;
    }
    NdArrayOps::permute(result_reshaped, &inv_perm)
}
266
#[cfg(test)]
mod tests {
    use super::*;

    impl Strides {
        /// Strides for a rank-2 operand, which has no batch dimensions.
        fn empty() -> Self {
            Strides {
                strides: Vec::new(),
            }
        }
    }

    /// Asserts that `output_shape(lsh, rsh)` produces exactly `expected`.
    fn assert_output(
        lsh: &[usize],
        rsh: &[usize],
        expected: (Shape, Strides, Strides, Strides),
    ) {
        assert_eq!(output_shape(lsh, rsh), expected);
    }

    #[test]
    fn test_output_shape() {
        // plain matrix multiply
        assert_output(
            &[5, 3],
            &[3, 7],
            (
                Shape::from([5, 7]),
                Strides::empty(),
                Strides::empty(),
                Strides::empty(),
            ),
        );
        // matrix multiply with one extra stack dimension
        assert_output(
            &[4, 5, 3],
            &[4, 3, 7],
            (
                Shape::from([4, 5, 7]),
                Strides::new(vec![1]),
                Strides::new(vec![1]),
                Strides::new(vec![1]),
            ),
        );
        // rank 3, broadcast left
        assert_output(
            &[1, 5, 3],
            &[4, 3, 7],
            (
                Shape::from([4, 5, 7]),
                Strides::new(vec![0]),
                Strides::new(vec![1]),
                Strides::new(vec![1]),
            ),
        );
        // rank 3, broadcast right
        assert_output(
            &[4, 5, 3],
            &[1, 3, 7],
            (
                Shape::from([4, 5, 7]),
                Strides::new(vec![1]),
                Strides::new(vec![0]),
                Strides::new(vec![1]),
            ),
        );
        // rank 4, multi broadcast
        assert_output(
            &[1, 4, 5, 3],
            &[8, 1, 3, 7],
            (
                Shape::from([8, 4, 5, 7]),
                Strides::new(vec![0, 1]),
                Strides::new(vec![1, 0]),
                Strides::new(vec![4, 1]),
            ),
        );
        // rank 5, multi-broadcast
        assert_output(
            &[1, 3, 4, 5, 3],
            &[8, 3, 1, 3, 7],
            (
                Shape::from([8, 3, 4, 5, 7]),
                Strides::new(vec![0, 4, 1]),
                Strides::new(vec![3, 1, 0]),
                Strides::new(vec![12, 4, 1]),
            ),
        );
    }

    #[test]
    #[should_panic(
        expected = "Matrix multiplication requires an array with at least 2 dimensions."
    )]
    fn test_output_shape_too_small() {
        output_shape(&[4], &[4]);
    }

    #[test]
    #[should_panic(expected = "Dimensions are incompatible for matrix multiplication.")]
    fn test_output_shape_bad_matrix_dims() {
        output_shape(&[5, 3], &[4, 7]);
    }

    #[test]
    #[should_panic(expected = "Dimensions differ and cannot be broadcasted.")]
    fn test_output_shape_non_broadcast() {
        output_shape(&[4, 5, 3], &[2, 3, 7]);
    }
}
362}