use crate::UnsafeSharedRef;
use crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par};

use alloc::{vec, vec::Vec};
use burn_backend::{ElementConversion, Shape};
use ndarray::{IxDyn, s};

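/// Batched matrix multiplication.
///
/// The last two dimensions of `lhs` and `rhs` are multiplied as matrices;
/// all leading dimensions are treated as batch dimensions and broadcast
/// against each other.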
pub(crate) fn matmul<E: NdArrayElement>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
) -> SharedArray<E> {
    let shape_lhs = lhs.shape();
    let shape_rhs = rhs.shape();
    let ndims = shape_lhs.num_dims();
    let m = shape_lhs[ndims - 2];
    let k = shape_rhs[ndims - 2];
    let n = shape_rhs[ndims - 1];
    let (out_shape, strides_lhs, strides_rhs, strides_out) = output_shape(shape_lhs, shape_rhs);

    // Sizes of the individual matrices and the batch counts on each side.
    let l_mat_size = m * k;
    let r_mat_size = k * n;
    let out_mat_size = m * n;
    let num_l_batches = shape_lhs.num_elements() / l_mat_size;
    let num_r_batches = shape_rhs.num_elements() / r_mat_size;
    let num_out_batches = out_shape.num_elements() / out_mat_size;

    // Collapse all batch dimensions into a single leading dimension.
    let lhs_array = NdArrayOps::reshape(lhs, Shape::new([num_l_batches, m, k]));
    let rhs_array = NdArrayOps::reshape(rhs, Shape::new([num_r_batches, k, n]));

    // `general_mat_mul` computes `out = alpha * lhs * rhs + beta * out`.
    let alpha: E = 1.0.elem();
    let beta: E = 0.0.elem();

    let out = run_par!(|| {
        let mut out_array = ndarray::Array3::<E>::zeros((num_out_batches, m, n));
        let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array);

        iter_range_par!(0, num_out_batches).for_each(|out_batch| {
            // Translate the output batch index into the corresponding
            // (possibly broadcast) batch indices of the inputs.
            let out_index = strides_out.unflatten(out_batch);
            let l_batch = strides_lhs.flatten(&out_index);
            let r_batch = strides_rhs.flatten(&out_index);

            let lhs_slice = lhs_array.slice(s!(l_batch, .., ..));
            let rhs_slice = rhs_array.slice(s!(r_batch, .., ..));

            // SAFETY: each iteration writes to a distinct `out_batch` slice,
            // so the mutable accesses never overlap across threads.
            unsafe {
                let mut out_slice = unsafe_shared_out_array
                    .get()
                    .slice_mut(s!(out_batch, .., ..));

                ndarray::linalg::general_mat_mul(
                    alpha,
                    &lhs_slice,
                    &rhs_slice,
                    beta,
                    &mut out_slice,
                )
            }
        });

        out_array.into_shared().into_dyn()
    });

    NdArrayOps::reshape(out, out_shape)
}

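/// Strides over the batch dimensions, used to translate between linear
/// batch indices and multi-dimensional batch coordinates. A stride of
/// zero marks a broadcast dimension: it contributes nothing to the
/// flattened index, so every output batch maps onto the same input batch.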
#[derive(Debug, PartialEq)]
struct Strides {
    strides: Vec<usize>,
}

impl Strides {
    fn new(strides: Vec<usize>) -> Self {
        Strides { strides }
    }

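    /// Converts a linear index into a multi-dimensional coordinate.
    /// Only meaningful for contiguous (non-zero) strides, such as the
    /// output strides produced by `output_shape`.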
    fn unflatten(&self, linear_index: usize) -> Vec<usize> {
        let mut coord = Vec::with_capacity(self.strides.len());
        let mut rem = linear_index;
        for stride in self.strides.iter() {
            coord.push(rem / stride);
            rem %= stride;
        }
        coord
    }

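    /// Converts a multi-dimensional coordinate back into a linear index.
    /// Broadcast dimensions (stride zero) contribute nothing, collapsing
    /// all coordinates along them onto the single broadcast batch.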
    fn flatten(&self, index: &[usize]) -> usize {
        assert_eq!(self.strides.len(), index.len());
        self.strides
            .iter()
            .zip(index)
            .map(|(stride, index)| stride * index)
            .sum()
    }
}

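/// Computes the broadcast output shape for inputs of shape `lsh` and
/// `rsh`, along with the batch strides of both inputs and of the output.
/// The two trailing (matrix) dimensions are excluded from broadcasting.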
fn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) {
    let ndims = lsh.num_dims();
    if ndims < 2 {
        panic!("Matrix multiplication requires an array with at least 2 dimensions.");
    }

    let l_rows = lsh[ndims - 2];
    let l_cols = lsh[ndims - 1];
    let r_rows = rsh[ndims - 2];
    let r_cols = rsh[ndims - 1];
    if l_cols != r_rows {
        panic!("Dimensions are incompatible for matrix multiplication.");
    }

    let mut osh = vec![0; ndims];
    osh[ndims - 2] = l_rows;
    osh[ndims - 1] = r_cols;

    // Walk the batch dimensions from innermost to outermost, accumulating
    // row-major strides. A broadcast input dimension gets a stride of zero.
    let mut cur_l_stride: usize = 1;
    let mut cur_r_stride: usize = 1;
    let mut cur_o_stride: usize = 1;
    let mut l_strides = Vec::with_capacity(ndims - 2);
    let mut r_strides = Vec::with_capacity(ndims - 2);
    let mut o_strides = Vec::with_capacity(ndims - 2);
    for i in (0..ndims - 2).rev() {
        let l_dim = lsh[i];
        let r_dim = rsh[i];

        let o_dim: usize;
        if l_dim == r_dim {
            o_dim = l_dim;
            l_strides.push(cur_l_stride);
            r_strides.push(cur_r_stride);
        } else if l_dim == 1 {
            // Broadcast the left-hand side along this dimension.
            o_dim = r_dim;
            l_strides.push(0);
            r_strides.push(cur_r_stride);
        } else if r_dim == 1 {
            // Broadcast the right-hand side along this dimension.
            o_dim = l_dim;
            l_strides.push(cur_l_stride);
            r_strides.push(0);
        } else {
            panic!("Dimensions differ and cannot be broadcasted.");
        }
        osh[i] = o_dim;
        o_strides.push(cur_o_stride);
        cur_o_stride *= o_dim;

        cur_l_stride *= l_dim;
        cur_r_stride *= r_dim;
    }
    // Strides were collected innermost-first; reverse into dimension order.
    l_strides.reverse();
    r_strides.reverse();
    o_strides.reverse();

    (
        Shape::from(osh),
        Strides::new(l_strides),
        Strides::new(r_strides),
        Strides::new(o_strides),
    )
}

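/// Cross product along dimension `dim`, which is expected to hold exactly
/// 3 components in both inputs; all other dimensions are broadcast.
/// For each pair of vectors `a` and `b`, the result is
/// `(a2*b3 - a3*b2, a3*b1 - a1*b3, a1*b2 - a2*b1)`.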
pub(crate) fn cross<E: NdArrayElement>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
    dim: usize,
) -> SharedArray<E> {
    let shape_lhs = lhs.shape();
    let shape_rhs = rhs.shape();
    let ndims = shape_lhs.num_dims();

    // Compute the broadcast shape; along `dim` both inputs are expected
    // to hold the 3 vector components.
    let mut broadcast_shape = vec![0; ndims];
    for i in 0..ndims {
        if i == dim {
            broadcast_shape[i] = shape_lhs[i];
        } else {
            let l = shape_lhs[i];
            let r = shape_rhs[i];
            if l == r {
                broadcast_shape[i] = l;
            } else if l == 1 {
                broadcast_shape[i] = r;
            } else if r == 1 {
                broadcast_shape[i] = l;
            } else {
                panic!("Tensors are not broadcastable along dimension {}", i);
            }
        }
    }

    // Materialize the broadcast only where a shape actually changes.
    let lhs_broadcast = if shape_lhs == broadcast_shape.as_slice() {
        lhs
    } else {
        NdArrayOps::expand(lhs, Shape::from(broadcast_shape.clone()))
    };
    let rhs_broadcast = if shape_rhs == broadcast_shape.as_slice() {
        rhs
    } else {
        NdArrayOps::expand(rhs, Shape::from(broadcast_shape.clone()))
    };

    // Move `dim` to the last axis.
    let mut perm = (0..ndims).collect::<Vec<_>>();
    perm.remove(dim);
    perm.push(dim);

    let lhs_permuted = NdArrayOps::permute(lhs_broadcast, &perm);
    let rhs_permuted = NdArrayOps::permute(rhs_broadcast, &perm);

    // Flatten to a batch of 3-component vectors.
    let total_elements = lhs_permuted.shape().num_elements();
    let batch_size = total_elements / 3;
    let lhs_reshaped = NdArrayOps::reshape(lhs_permuted, Shape::new([batch_size, 3]));
    let rhs_reshaped = NdArrayOps::reshape(rhs_permuted, Shape::new([batch_size, 3]));

    // Apply the cross-product formula to each vector pair.
    let mut result = ndarray::ArrayD::<E>::zeros(IxDyn(&[batch_size, 3]));
    for i in 0..batch_size {
        let a1 = lhs_reshaped[IxDyn(&[i, 0])];
        let a2 = lhs_reshaped[IxDyn(&[i, 1])];
        let a3 = lhs_reshaped[IxDyn(&[i, 2])];
        let b1 = rhs_reshaped[IxDyn(&[i, 0])];
        let b2 = rhs_reshaped[IxDyn(&[i, 1])];
        let b3 = rhs_reshaped[IxDyn(&[i, 2])];
        result[IxDyn(&[i, 0])] = a2.mul(b3).sub(a3.mul(b2));
        result[IxDyn(&[i, 1])] = a3.mul(b1).sub(a1.mul(b3));
        result[IxDyn(&[i, 2])] = a1.mul(b2).sub(a2.mul(b1));
    }

    let result_shared = result.into_shared();

    // Restore the permuted shape, then undo the permutation.
    let mut result_shape = broadcast_shape;
    result_shape.remove(dim);
    result_shape.push(3);
    let result_reshaped = NdArrayOps::reshape(result_shared, Shape::from(result_shape));

    let mut inv_perm = vec![0; ndims];
    for (i, &p) in perm.iter().enumerate() {
        inv_perm[p] = i;
    }
    NdArrayOps::permute(result_reshaped, &inv_perm)
}

#[cfg(test)]
mod tests {
    use super::*;

    impl Strides {
        fn empty() -> Self {
            Strides {
                strides: Vec::with_capacity(0),
            }
        }
    }

    #[test]
    fn test_output_shape() {
        assert_eq!(
            output_shape(&[5, 3], &[3, 7]),
            (
                Shape::from([5, 7]),
                Strides::empty(),
                Strides::empty(),
                Strides::empty()
            )
        );
        assert_eq!(
            output_shape(&[4, 5, 3], &[4, 3, 7]),
            (
                Shape::from([4, 5, 7]),
                Strides::new(vec![1]),
                Strides::new(vec![1]),
                Strides::new(vec![1])
            )
        );
        assert_eq!(
            output_shape(&[1, 5, 3], &[4, 3, 7]),
            (
                Shape::from([4, 5, 7]),
                Strides::new(vec![0]),
                Strides::new(vec![1]),
                Strides::new(vec![1])
            )
        );
        assert_eq!(
            output_shape(&[4, 5, 3], &[1, 3, 7]),
            (
                Shape::from([4, 5, 7]),
                Strides::new(vec![1]),
                Strides::new(vec![0]),
                Strides::new(vec![1])
            )
        );
        assert_eq!(
            output_shape(&[1, 4, 5, 3], &[8, 1, 3, 7]),
            (
                Shape::from([8, 4, 5, 7]),
                Strides::new(vec![0, 1]),
                Strides::new(vec![1, 0]),
                Strides::new(vec![4, 1])
            )
        );
        assert_eq!(
            output_shape(&[1, 3, 4, 5, 3], &[8, 3, 1, 3, 7]),
            (
                Shape::from([8, 3, 4, 5, 7]),
                Strides::new(vec![0, 4, 1]),
                Strides::new(vec![3, 1, 0]),
                Strides::new(vec![12, 4, 1])
            )
        );
    }
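
    // Illustrative addition (not part of the original suite): for
    // contiguous strides such as the output strides above, `flatten`
    // should invert `unflatten` for every linear index.
    #[test]
    fn test_strides_roundtrip() {
        let strides = Strides::new(vec![12, 4, 1]);
        for i in 0..24 {
            assert_eq!(strides.flatten(&strides.unflatten(i)), i);
        }
    }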

    #[test]
    #[should_panic(
        expected = "Matrix multiplication requires an array with at least 2 dimensions."
    )]
    fn test_output_shape_too_small() {
        output_shape(&[4], &[4]);
    }

    #[test]
    #[should_panic(expected = "Dimensions are incompatible for matrix multiplication.")]
    fn test_output_shape_bad_matrix_dims() {
        output_shape(&[5, 3], &[4, 7]);
    }

    #[test]
    #[should_panic(expected = "Dimensions differ and cannot be broadcasted.")]
    fn test_output_shape_non_broadcast() {
        output_shape(&[4, 5, 3], &[2, 3, 7]);
    }
}