use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{iter_par, run_par, UnsafeSharedRef};
use burn_tensor::{ops::TensorOps, ElementConversion, Shape};
use ndarray::s;
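
/// Multiplies two tensors of rank `D`, broadcasting over the leading batch
/// dimensions: both inputs are flattened to rank 3, multiplied batch-wise,
/// and the result is reshaped to the broadcast output shape.
///
/// For example, `[2, 3, 4] @ [2, 4, 5] -> [2, 3, 5]`, and
/// `[1, 3, 4] @ [8, 4, 5] -> [8, 3, 5]`.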
pub(crate) fn matmul<E, const D: usize>(
    lhs: NdArrayTensor<E, D>,
    rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D>
where
    E: FloatNdArrayElement,
{
    let shape_ori_lhs = lhs.shape();
    let shape_ori_rhs = rhs.shape();

    // Flatten the leading dimensions so both operands are rank 3.
    let lhs = reshape(lhs);
    let rhs = reshape(rhs);

    let [batch_size_lhs, m, _] = lhs.shape().dims;
    let [batch_size_rhs, _, n] = rhs.shape().dims;

    // The output keeps the batch dimensions of the operand with the larger
    // batch size, with the last two dimensions replaced by [m, n].
    let mut shape_out = match batch_size_lhs > batch_size_rhs {
        true => shape_ori_lhs,
        false => shape_ori_rhs,
    };
    shape_out.dims[D - 2] = m;
    shape_out.dims[D - 1] = n;

    let out = general_matmul(lhs, rhs);

    NdArrayBackend::<E>::reshape(out, shape_out)
}
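
/// Batched matrix product on rank-3 tensors, parallelized over the batch
/// dimension via `run_par!`/`iter_par!`. A batch size of 1 on either side is
/// broadcast against the other.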
fn general_matmul<E: FloatNdArrayElement>(
    lhs: NdArrayTensor<E, 3>,
    rhs: NdArrayTensor<E, 3>,
) -> NdArrayTensor<E, 3> {
    run_par!(|| {
        let [batch_size_lhs, m, _] = lhs.shape().dims;
        let [batch_size_rhs, k, n] = rhs.shape().dims;
        let batch_size = usize::max(batch_size_rhs, batch_size_lhs);

        // Broadcasting is only supported when the batch sizes match or one of
        // them is 1; anything else would need multi-dimension broadcast logic.
        if batch_size_lhs != batch_size && batch_size_lhs != 1 {
            panic!("Broadcast on multiple dimensions is not yet supported");
        }
        if batch_size_rhs != batch_size && batch_size_rhs != 1 {
            panic!("Broadcast on multiple dimensions is not yet supported");
        }

        // `general_mat_mul` computes `out = alpha * lhs @ rhs + beta * out`;
        // with alpha = 1 and beta = 0 it is a plain matrix product.
        let alpha: E = 1.0.elem();
        let beta: E = 0.0.elem();

        let mut out_array = ndarray::Array3::<E>::zeros((batch_size, m, n));
        // Each iteration writes to a disjoint batch slice of the output, so
        // sharing a mutable reference across threads is sound here.
        let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array);

        let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
        let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();

        iter_par!(0, batch_size).for_each(|b| {
            // Reuse batch 0 when the operand is broadcast.
            let lhs_slice = match batch_size_lhs == 1 {
                true => lhs_array.slice(s!(0, .., ..)),
                false => lhs_array.slice(s!(b, .., ..)),
            };
            let rhs_slice = match batch_size_rhs == 1 {
                true => rhs_array.slice(s!(0, .., ..)),
                false => rhs_array.slice(s!(b, .., ..)),
            };

            unsafe {
                let mut out_slice = unsafe_shared_out_array.get().slice_mut(s!(b, .., ..));

                ndarray::linalg::general_mat_mul(
                    alpha,
                    &lhs_slice,
                    &rhs_slice,
                    beta,
                    &mut out_slice,
                );
            }
        });

        NdArrayTensor::new(out_array.into_shared().into_dyn())
    })
}
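
/// Flattens a rank-`D` tensor to rank 3 as `[batch, rows, cols]`, where
/// `batch` is the product of all leading dimensions. A rank-1 tensor becomes
/// a single `1 x len` row matrix.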
fn reshape<E: FloatNdArrayElement, const D: usize>(
    tensor: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, 3> {
    let shape = tensor.shape();

    if D < 2 {
        NdArrayBackend::<E>::reshape(tensor, Shape::new([1, 1, shape.dims[0]]))
    } else {
        let batch_size = batch_size(&shape);
        let size0 = shape.dims[D - 2];
        let size1 = shape.dims[D - 1];

        NdArrayBackend::<E>::reshape(tensor, Shape::new([batch_size, size0, size1]))
    }
}
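
/// Product of all dimensions except the last two, i.e. the flattened batch
/// size of the tensor.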
fn batch_size<const D: usize>(shape: &Shape<D>) -> usize {
    let mut num_batch = 1;
    for i in 0..D - 2 {
        num_batch *= shape.dims[i];
    }

    num_batch
}