use crate::UnsafeSharedRef;
use crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par};
use alloc::{vec, vec::Vec};
use burn_backend::ElementConversion;
use burn_backend::Shape;
use ndarray::{IxDyn, s};
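/// Batched matrix multiplication over the last two dimensions, broadcasting
/// the leading (batch) dimensions. Each output batch is computed with
/// `ndarray`'s `general_mat_mul`, parallelized across output batches.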
pub(crate) fn matmul<E: NdArrayElement>(
lhs: SharedArray<E>,
rhs: SharedArray<E>,
) -> SharedArray<E> {
let shape_lhs = lhs.shape();
let shape_rhs = rhs.shape();
let ndims = shape_lhs.num_dims();
// lhs is [..., m, k] and rhs is [..., k, n]; compatibility of the inner
// dimensions is checked in `output_shape`.
let m = shape_lhs[ndims - 2];
let k = shape_rhs[ndims - 2];
let n = shape_rhs[ndims - 1];
let (out_shape, strides_lhs, strides_rhs, strides_out) = output_shape(shape_lhs, shape_rhs);
let l_mat_size = m * k;
let r_mat_size = k * n;
let out_mat_size = m * n;
let num_l_batches = shape_lhs.num_elements() / l_mat_size;
let num_r_batches = shape_rhs.num_elements() / r_mat_size;
let num_out_batches = out_shape.num_elements() / out_mat_size;
let lhs_array = NdArrayOps::reshape(lhs, Shape::new([num_l_batches, m, k]));
let rhs_array = NdArrayOps::reshape(rhs, Shape::new([num_r_batches, k, n]));
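// `general_mat_mul` computes `out = alpha * lhs @ rhs + beta * out`;
// alpha = 1 and beta = 0 yield the plain matrix product.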
let alpha: E = 1.0.elem();
let beta: E = 0.0.elem();
let out = run_par!(|| {
let mut out_array = ndarray::Array3::<E>::zeros((num_out_batches, m, n));
let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array);
iter_range_par!(0, num_out_batches).for_each(|out_batch| {
let out_index = strides_out.unflatten(out_batch);
let l_batch = strides_lhs.flatten(&out_index);
let r_batch = strides_rhs.flatten(&out_index);
let lhs_slice = lhs_array.slice(s!(l_batch, .., ..));
let rhs_slice = rhs_array.slice(s!(r_batch, .., ..));
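// SAFETY: each `out_batch` index is visited by exactly one parallel
// iteration, so the mutable slices taken from the shared output array
// never overlap.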
unsafe {
let mut out_slice = unsafe_shared_out_array
.get()
.slice_mut(s!(out_batch, .., ..));
ndarray::linalg::general_mat_mul(
alpha,
&lhs_slice,
&rhs_slice,
beta,
&mut out_slice,
);
}
});
out_array.into_shared().into_dyn()
});
NdArrayOps::reshape(out, out_shape)
}
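/// Row-major strides over the batch dimensions of a tensor, used to translate
/// between linear batch indices and per-dimension batch coordinates. A stride
/// of 0 marks a broadcast dimension.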
#[derive(Debug, PartialEq)]
struct Strides {
strides: Vec<usize>,
}
impl Strides {
fn new(strides: Vec<usize>) -> Self {
Strides { strides }
}
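/// Decomposes a linear batch index into a per-dimension coordinate.
/// Must only be used with non-zero strides (e.g. the output strides).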
fn unflatten(&self, linear_index: usize) -> Vec<usize> {
let mut coord = Vec::with_capacity(self.strides.len());
let mut rem = linear_index;
for stride in self.strides.iter() {
coord.push(rem / stride);
rem %= stride;
}
coord
}
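/// Recombines a batch coordinate into a linear index. Broadcast dimensions
/// (stride 0) contribute nothing, collapsing onto the single source batch.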
fn flatten(&self, index: &[usize]) -> usize {
assert_eq!(self.strides.len(), index.len());
self.strides
.iter()
.zip(index)
.map(|(stride, index)| stride * index)
.sum()
}
}
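/// Computes the broadcast output shape of a batched matmul together with
/// batch strides for lhs, rhs, and output. Broadcast batch dimensions get a
/// stride of 0, so `Strides::flatten` maps every output batch coordinate back
/// to the single source batch. Assumes both inputs have the same rank; panics
/// if the inputs have fewer than two dimensions or are incompatible.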
fn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) {
let ndims = lsh.num_dims();
if ndims < 2 {
panic!("Matrix multiplication requires an array with at least 2 dimensions.");
}
let l_rows = lsh[ndims - 2];
let l_cols = lsh[ndims - 1];
let r_rows = rsh[ndims - 2];
let r_cols = rsh[ndims - 1];
if l_cols != r_rows {
panic!("Dimensions are incompatible for matrix multiplication.");
}
let mut osh = vec![0; ndims];
osh[ndims - 2] = l_rows;
osh[ndims - 1] = r_cols;
let mut cur_l_stride: usize = 1;
let mut cur_r_stride: usize = 1;
let mut cur_o_stride: usize = 1;
let mut l_strides = Vec::with_capacity(ndims - 2);
let mut r_strides = Vec::with_capacity(ndims - 2);
let mut o_strides = Vec::with_capacity(ndims - 2);
for i in (0..ndims - 2).rev() {
let l_dim = lsh[i];
let r_dim = rsh[i];
let o_dim = if l_dim == r_dim {
    // Dimensions match: both inputs step through this batch dimension.
    l_strides.push(cur_l_stride);
    r_strides.push(cur_r_stride);
    l_dim
} else if l_dim == 1 {
    // lhs is broadcast here; stride 0 pins it to its single batch.
    l_strides.push(0);
    r_strides.push(cur_r_stride);
    r_dim
} else if r_dim == 1 {
    // rhs is broadcast here.
    l_strides.push(cur_l_stride);
    r_strides.push(0);
    l_dim
} else {
    panic!("Dimensions differ and cannot be broadcasted.");
};
osh[i] = o_dim;
o_strides.push(cur_o_stride);
cur_o_stride *= o_dim;
cur_l_stride *= l_dim;
cur_r_stride *= r_dim;
}
l_strides.reverse();
r_strides.reverse();
o_strides.reverse();
(
Shape::from(osh),
Strides::new(l_strides),
Strides::new(r_strides),
Strides::new(o_strides),
)
}
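/// Cross product of `lhs` and `rhs` along dimension `dim`, broadcasting all
/// other dimensions. Assumes both inputs have the same rank and size 3 along
/// `dim`.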
pub(crate) fn cross<E: NdArrayElement>(
lhs: SharedArray<E>,
rhs: SharedArray<E>,
dim: usize,
) -> SharedArray<E> {
let shape_lhs = lhs.shape();
let shape_rhs = rhs.shape();
let ndims = shape_lhs.num_dims();
let mut broadcast_shape = vec![0; ndims];
for i in 0..ndims {
if i == dim {
    // The cross dimension is not broadcast; both inputs are expected to
    // have size 3 here (the kernel below hardcodes 3 components).
    broadcast_shape[i] = shape_lhs[i];
} else {
let l = shape_lhs[i];
let r = shape_rhs[i];
if l == r {
broadcast_shape[i] = l;
} else if l == 1 {
broadcast_shape[i] = r;
} else if r == 1 {
broadcast_shape[i] = l;
} else {
panic!("Tensors are not broadcastable along dimension {}", i);
}
}
}
let lhs_broadcast = if shape_lhs == broadcast_shape.as_slice() {
lhs
} else {
NdArrayOps::expand(lhs, Shape::from(broadcast_shape.clone()))
};
let rhs_broadcast = if shape_rhs == broadcast_shape.as_slice() {
rhs
} else {
NdArrayOps::expand(rhs, Shape::from(broadcast_shape.clone()))
};
let mut perm = (0..ndims).collect::<Vec<_>>();
perm.remove(dim);
perm.push(dim);
let lhs_permuted = NdArrayOps::permute(lhs_broadcast, &perm);
let rhs_permuted = NdArrayOps::permute(rhs_broadcast, &perm);
let total_elements = lhs_permuted.shape().num_elements();
let batch_size = total_elements / 3;
let lhs_reshaped = NdArrayOps::reshape(lhs_permuted, Shape::new([batch_size, 3]));
let rhs_reshaped = NdArrayOps::reshape(rhs_permuted, Shape::new([batch_size, 3]));
let mut result = ndarray::ArrayD::<E>::zeros(IxDyn(&[batch_size, 3]));
// Component-wise cross product for each row: c = a × b.
for i in 0..batch_size {
let a1 = lhs_reshaped[IxDyn(&[i, 0])];
let a2 = lhs_reshaped[IxDyn(&[i, 1])];
let a3 = lhs_reshaped[IxDyn(&[i, 2])];
let b1 = rhs_reshaped[IxDyn(&[i, 0])];
let b2 = rhs_reshaped[IxDyn(&[i, 1])];
let b3 = rhs_reshaped[IxDyn(&[i, 2])];
result[IxDyn(&[i, 0])] = a2.mul(b3).sub(a3.mul(b2));
result[IxDyn(&[i, 1])] = a3.mul(b1).sub(a1.mul(b3));
result[IxDyn(&[i, 2])] = a1.mul(b2).sub(a2.mul(b1));
}
let result_shared = result.into_shared();
let mut result_shape = broadcast_shape;
result_shape.remove(dim);
result_shape.push(3);
let result_reshaped = NdArrayOps::reshape(result_shared, Shape::from(result_shape));
let mut inv_perm = vec![0; ndims];
for (i, &p) in perm.iter().enumerate() {
inv_perm[p] = i;
}
NdArrayOps::permute(result_reshaped, &inv_perm)
}
#[cfg(test)]
mod tests {
use super::*;
impl Strides {
fn empty() -> Self {
Strides {
strides: Vec::new(),
}
}
}
#[test]
fn test_output_shape() {
assert_eq!(
output_shape(&[5, 3], &[3, 7]),
(
Shape::from([5, 7]),
Strides::empty(),
Strides::empty(),
Strides::empty()
)
);
assert_eq!(
output_shape(&[4, 5, 3], &[4, 3, 7]),
(
Shape::from([4, 5, 7]),
Strides::new(vec![1]),
Strides::new(vec![1]),
Strides::new(vec![1])
)
);
assert_eq!(
output_shape(&[1, 5, 3], &[4, 3, 7]),
(
Shape::from([4, 5, 7]),
Strides::new(vec![0]),
Strides::new(vec![1]),
Strides::new(vec![1])
)
);
assert_eq!(
output_shape(&[4, 5, 3], &[1, 3, 7]),
(
Shape::from([4, 5, 7]),
Strides::new(vec![1]),
Strides::new(vec![0]),
Strides::new(vec![1])
)
);
assert_eq!(
output_shape(&[1, 4, 5, 3], &[8, 1, 3, 7]),
(
Shape::from([8, 4, 5, 7]),
Strides::new(vec![0, 1]),
Strides::new(vec![1, 0]),
Strides::new(vec![4, 1])
)
);
assert_eq!(
output_shape(&[1, 3, 4, 5, 3], &[8, 3, 1, 3, 7]),
(
Shape::from([8, 3, 4, 5, 7]),
Strides::new(vec![0, 4, 1]),
Strides::new(vec![3, 1, 0]),
Strides::new(vec![12, 4, 1])
)
)
}
#[test]
#[should_panic(
expected = "Matrix multiplication requires an array with at least 2 dimensions."
)]
fn test_output_shape_too_small() {
output_shape(&[4], &[4]);
}
#[test]
#[should_panic(expected = "Dimensions are incompatible for matrix multiplication.")]
fn test_output_shape_bad_matrix_dims() {
output_shape(&[5, 3], &[4, 7]);
}
#[test]
#[should_panic(expected = "Dimensions differ and cannot be broadcasted.")]
fn test_output_shape_non_broadcast() {
output_shape(&[4, 5, 3], &[2, 3, 7]);
}
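// The following tests are additions: small sanity checks whose expected
// values follow directly from the definitions above. The tensor tests
// assume `f32: NdArrayElement` and that `SharedArray<E>` is ndarray's
// `ArcArray<E, IxDyn>`, as suggested by the `into_shared()` calls above.
#[test]
fn test_strides_flatten_unflatten_roundtrip() {
    // Output strides for a batch shape of [2, 4].
    let strides = Strides::new(vec![4, 1]);
    assert_eq!(strides.unflatten(5), vec![1, 1]);
    assert_eq!(strides.flatten(&[1, 1]), 5);
    // A broadcast dimension (stride 0) collapses onto batch 0.
    let broadcast = Strides::new(vec![0, 1]);
    assert_eq!(broadcast.flatten(&[3, 1]), 1);
}
#[test]
fn test_matmul_2x2() {
    let lhs = ndarray::ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0f32, 2.0, 3.0, 4.0])
        .unwrap()
        .into_shared();
    let rhs = ndarray::ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![5.0f32, 6.0, 7.0, 8.0])
        .unwrap()
        .into_shared();
    let out = matmul(lhs, rhs);
    assert_eq!(out[IxDyn(&[0, 0])], 19.0);
    assert_eq!(out[IxDyn(&[0, 1])], 22.0);
    assert_eq!(out[IxDyn(&[1, 0])], 43.0);
    assert_eq!(out[IxDyn(&[1, 1])], 50.0);
}
#[test]
fn test_cross_unit_vectors() {
    // x × y = z for unit vectors along dimension 0.
    let x = ndarray::ArrayD::from_shape_vec(IxDyn(&[3]), vec![1.0f32, 0.0, 0.0])
        .unwrap()
        .into_shared();
    let y = ndarray::ArrayD::from_shape_vec(IxDyn(&[3]), vec![0.0f32, 1.0, 0.0])
        .unwrap()
        .into_shared();
    let z = cross(x, y, 0);
    assert_eq!(z[IxDyn(&[0])], 0.0);
    assert_eq!(z[IxDyn(&[1])], 0.0);
    assert_eq!(z[IxDyn(&[2])], 1.0);
}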
}