use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// General matrix product dispatcher, following `torch.matmul` semantics.
///
/// Routes to the specialized kernel for each rank combination:
/// 1-D @ 1-D -> dot product (scalar), 2-D @ 1-D -> `mv`, 1-D @ 2-D -> `vm`,
/// 2-D @ 2-D -> `mm`, anything higher-rank -> batched broadcast matmul.
///
/// # Errors
/// `InvalidArgument` for scalar (0-D) operands; shape errors propagate
/// from the underlying kernel.
pub fn matmul<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let (a_rank, b_rank) = (a.ndim(), b.ndim());
    // Scalars have no matmul semantics; reject them up front.
    if a_rank == 0 || b_rank == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "matmul: scalar operands not supported, got shapes {:?} and {:?}",
                a.shape(),
                b.shape()
            ),
        });
    }
    match (a_rank, b_rank) {
        (1, 1) => dot(a, b),
        (2, 1) => mv(a, b),
        (1, 2) => vm(a, b),
        (2, 2) => mm(a, b),
        _ => broadcast_matmul(a, b),
    }
}
/// Compute the numpy-style broadcast of two batch shapes.
///
/// Shapes are aligned at the trailing axis; a missing or size-1 dimension
/// broadcasts against the other shape's dimension.
///
/// # Errors
/// `ShapeMismatch` when a pair of dimensions differs and neither is 1.
fn broadcast_batch_shapes(a: &[usize], b: &[usize]) -> FerrotorchResult<Vec<usize>> {
    let out_len = a.len().max(b.len());
    let mut dims = vec![1usize; out_len];
    // Walk both shapes from the trailing axis, applying the broadcast rule
    // dimension by dimension.
    for i in 0..out_len {
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        dims[out_len - 1 - i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => {
                return Err(FerrotorchError::ShapeMismatch {
                    message: format!(
                        "matmul: batch dimensions cannot be broadcast: {:?} vs {:?}",
                        a, b
                    ),
                })
            }
        };
    }
    Ok(dims)
}
fn broadcast_matmul<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let device = a.device();
let squeeze_row = a.ndim() == 1;
let squeeze_col = b.ndim() == 1;
let a_shape: Vec<usize> = if squeeze_row {
let mut s = vec![1];
s.extend_from_slice(a.shape());
s
} else {
a.shape().to_vec()
};
let b_shape: Vec<usize> = if squeeze_col {
let mut s = b.shape().to_vec();
s.push(1);
s
} else {
b.shape().to_vec()
};
let a_nd = a_shape.len();
let b_nd = b_shape.len();
let m = a_shape[a_nd - 2];
let k_a = a_shape[a_nd - 1];
let k_b = b_shape[b_nd - 2];
let n = b_shape[b_nd - 1];
if k_a != k_b {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"matmul: inner dimensions mismatch: {:?} @ {:?}",
a.shape(),
b.shape()
),
});
}
let k = k_a;
let a_batch = &a_shape[..a_nd - 2];
let b_batch = &b_shape[..b_nd - 2];
let batch_shape = broadcast_batch_shapes(a_batch, b_batch)?;
let batch_size: usize = batch_shape.iter().product::<usize>().max(1);
let a_batch_strides = broadcast_strides(a_batch, &batch_shape);
let b_batch_strides = broadcast_strides(b_batch, &batch_shape);
let a_mat_size = m * k;
let b_mat_size = k * n;
let c_mat_size = m * n;
let a_data = a.data_vec()?;
let b_data = b.data_vec()?;
let mut result = vec![<T as num_traits::Zero>::zero(); batch_size * c_mat_size];
for bi in 0..batch_size {
let a_off = batch_linear_index(bi, &a_batch_strides, &batch_shape) * a_mat_size;
let b_off = batch_linear_index(bi, &b_batch_strides, &batch_shape) * b_mat_size;
let c_off = bi * c_mat_size;
for i in 0..m {
for j in 0..n {
let mut acc = <T as num_traits::Zero>::zero();
for p in 0..k {
acc += a_data[a_off + i * k + p] * b_data[b_off + p * n + j];
}
result[c_off + i * n + j] = acc;
}
}
}
let mut out_shape = batch_shape;
out_shape.push(m);
out_shape.push(n);
if squeeze_row {
let pos = out_shape.len() - 2;
out_shape.remove(pos);
}
if squeeze_col {
out_shape.pop();
}
let t = Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)?;
Ok(if device.is_cuda() { t.to(device)? } else { t })
}
/// Row-major strides of `src` expressed over the broadcast batch shape.
///
/// `src` is right-aligned against `broadcast`. Axes that `src` lacks, and
/// axes where `src` has extent 1, get stride 0 so the same source batch is
/// reused across the broadcast; all other axes keep their natural
/// row-major stride (in units of whole batch entries).
fn broadcast_strides(src: &[usize], broadcast: &[usize]) -> Vec<usize> {
    let mut strides = vec![0usize; broadcast.len()];
    if src.is_empty() {
        return strides;
    }
    // Natural row-major strides of `src` itself.
    let mut src_strides = vec![1usize; src.len()];
    for i in (0..src.len() - 1).rev() {
        src_strides[i] = src_strides[i + 1] * src[i + 1];
    }
    // Right-align: leading broadcast axes keep the 0 they were filled with.
    let offset = broadcast.len() - src.len();
    for (si, (&dim, slot)) in src.iter().zip(strides[offset..].iter_mut()).enumerate() {
        *slot = if dim == 1 { 0 } else { src_strides[si] };
    }
    strides
}
/// Convert a flat index over the broadcast batch `shape` into a linear
/// batch offset using the given per-axis `strides` (stride 0 on an axis
/// means that axis is broadcast and does not advance the source).
fn batch_linear_index(flat: usize, strides: &[usize], shape: &[usize]) -> usize {
    shape
        .iter()
        .zip(strides.iter())
        .rev()
        .fold((0usize, flat), |(acc, rem), (&dim, &stride)| {
            // Peel coordinates off the flat index from the trailing axis.
            (acc + (rem % dim) * stride, rem / dim)
        })
        .0
}
/// Inner (dot) product of two 1-D tensors, returned as a 0-D tensor.
///
/// # Errors
/// `ShapeMismatch` if either operand is not 1-D or the lengths differ.
pub fn dot<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if a.ndim() != 1 || b.ndim() != 1 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "dot requires 1-D tensors, got {:?} and {:?}",
                a.shape(),
                b.shape()
            ),
        });
    }
    if a.shape()[0] != b.shape()[0] {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "dot product dimension mismatch: {} vs {}",
                a.shape()[0],
                b.shape()[0]
            ),
        });
    }
    let lhs = a.data()?;
    let rhs = b.data()?;
    let mut sum = <T as num_traits::Zero>::zero();
    for (&x, &y) in lhs.iter().zip(rhs.iter()) {
        sum += x * y;
    }
    // Empty shape => scalar (0-D) tensor.
    Tensor::from_storage(TensorStorage::cpu(vec![sum]), vec![], false)
}
/// Largest single dimension for which the hand-rolled triple loop is used;
/// above this the kernels delegate to faer's GEMM.
const DIRECT_MM_THRESHOLD: usize = 128;

/// Choose faer's parallelism for an (m, k) @ (k, n) problem: rayon for
/// large work, sequential otherwise.
#[inline]
fn faer_par(m: usize, k: usize, n: usize) -> faer::Par {
    // checked_mul avoids usize overflow for extreme dims (debug builds
    // would panic, release builds could wrap and pick the wrong branch);
    // an overflowing product is certainly "big", so treat it as parallel.
    let big = m
        .checked_mul(k)
        .and_then(|mk| mk.checked_mul(n))
        .map_or(true, |work| work >= 512 * 512 * 512);
    if big {
        faer::Par::rayon(0)
    } else {
        faer::Par::Seq
    }
}
/// Row-major GEMM kernel: computes `A (m×k) @ B (k×n)` into a new `m×n` vec.
///
/// Caller contract: `a_data.len() >= m*k` and `b_data.len() >= k*n`, both
/// row-major contiguous. Small problems (every dim <= `DIRECT_MM_THRESHOLD`)
/// use an i-p-j triple loop (streaming rows of B in the inner loop); larger
/// problems dispatch to faer's GEMM. `f32`/`f64` are handed to faer
/// zero-copy via a TypeId-checked reinterpret; any other `Float` type is
/// converted through `f64` and back (NOTE(review): lossy for types wider
/// than f64 — confirm acceptable for the crate's Float impls).
pub fn mm_raw<T: Float>(a_data: &[T], b_data: &[T], m: usize, k: usize, n: usize) -> Vec<T> {
    let max_dim = m.max(n).max(k);
    let zero = <T as num_traits::Zero>::zero();
    if max_dim <= DIRECT_MM_THRESHOLD {
        let mut result = vec![zero; m * n];
        // SAFETY: all indices are in bounds under the caller contract:
        // a_row + p < m*k, b_row + j < k*n, r_row + j < m*n (result sized m*n).
        unsafe {
            for i in 0..m {
                let a_row = i * k;
                let r_row = i * n;
                for p in 0..k {
                    let a_ip = *a_data.get_unchecked(a_row + p);
                    let b_row = p * n;
                    for j in 0..n {
                        let r = result.get_unchecked_mut(r_row + j);
                        *r += a_ip * *b_data.get_unchecked(b_row + j);
                    }
                }
            }
        }
        result
    } else {
        let mut result = vec![zero; m * n];
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
            // SAFETY: the TypeId check proves T == f32, so these casts
            // reinterpret a slice as its own type.
            let a_f32 = unsafe { &*(a_data as *const [T] as *const [f32]) };
            let b_f32 = unsafe { &*(b_data as *const [T] as *const [f32]) };
            let c_f32 = unsafe { &mut *(result.as_mut_slice() as *mut [T] as *mut [f32]) };
            let a_mat = faer::mat::MatRef::from_row_major_slice(a_f32, m, k);
            let b_mat = faer::mat::MatRef::from_row_major_slice(b_f32, k, n);
            let mut c_mat = faer::mat::MatMut::from_row_major_slice_mut(c_f32, m, n);
            let par = faer_par(m, k, n);
            // C = 1.0 * A @ B, overwriting C (Accum::Replace).
            faer::linalg::matmul::matmul(
                &mut c_mat,
                faer::Accum::Replace,
                &a_mat,
                &b_mat,
                1.0f32,
                par,
            );
        } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
            // SAFETY: the TypeId check proves T == f64 (same reasoning as above).
            let a_f64 = unsafe { &*(a_data as *const [T] as *const [f64]) };
            let b_f64 = unsafe { &*(b_data as *const [T] as *const [f64]) };
            let c_f64 = unsafe { &mut *(result.as_mut_slice() as *mut [T] as *mut [f64]) };
            let a_mat = faer::mat::MatRef::from_row_major_slice(a_f64, m, k);
            let b_mat = faer::mat::MatRef::from_row_major_slice(b_f64, k, n);
            let mut c_mat = faer::mat::MatMut::from_row_major_slice_mut(c_f64, m, n);
            let par = faer_par(m, k, n);
            faer::linalg::matmul::matmul(
                &mut c_mat,
                faer::Accum::Replace,
                &a_mat,
                &b_mat,
                1.0f64,
                par,
            );
        } else {
            // Generic fallback: round-trip through f64 buffers for faer.
            let a_f64: Vec<f64> = a_data.iter().map(|&v| v.to_f64().unwrap()).collect();
            let b_f64: Vec<f64> = b_data.iter().map(|&v| v.to_f64().unwrap()).collect();
            let mut r_f64 = vec![0.0f64; m * n];
            let a_mat = faer::mat::MatRef::from_row_major_slice(&a_f64, m, k);
            let b_mat = faer::mat::MatRef::from_row_major_slice(&b_f64, k, n);
            let mut c_mat = faer::mat::MatMut::from_row_major_slice_mut(&mut r_f64, m, n);
            let par = faer_par(m, k, n);
            faer::linalg::matmul::matmul(
                &mut c_mat,
                faer::Accum::Replace,
                &a_mat,
                &b_mat,
                1.0f64,
                par,
            );
            // Convert the f64 product back to T.
            for (r, &v) in result.iter_mut().zip(r_f64.iter()) {
                *r = T::from(v).unwrap();
            }
        }
        result
    }
}
/// Row-major GEMM with transposed RHS: computes `A (m×k) @ Bᵀ` where `B`
/// is stored row-major as `n×k`; returns a new `m×n` vec.
///
/// Caller contract: `a_data.len() >= m*k`, `b_data.len() >= n*k`, both
/// row-major contiguous. The B-transposed layout makes the small-problem
/// inner loop a pure dot product over two contiguous rows. Large problems
/// go to faer with a transposed `MatRef` view (no copy for f32/f64;
/// other Float types round-trip through f64).
pub fn mm_raw_bt<T: Float>(a_data: &[T], b_data: &[T], m: usize, k: usize, n: usize) -> Vec<T> {
    let max_dim = m.max(n).max(k);
    let zero = <T as num_traits::Zero>::zero();
    if max_dim <= DIRECT_MM_THRESHOLD {
        let mut result = vec![zero; m * n];
        // SAFETY: in bounds under the caller contract: a_row + p < m*k,
        // b_row + p < n*k, r_row + j < m*n (result sized m*n).
        unsafe {
            for i in 0..m {
                let a_row = i * k;
                let r_row = i * n;
                for j in 0..n {
                    let b_row = j * k;
                    let mut acc = zero;
                    // Dot of row i of A with row j of B (== column j of Bᵀ).
                    for p in 0..k {
                        acc += *a_data.get_unchecked(a_row + p) * *b_data.get_unchecked(b_row + p);
                    }
                    *result.get_unchecked_mut(r_row + j) = acc;
                }
            }
        }
        result
    } else {
        let mut result = vec![zero; m * n];
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
            // SAFETY: the TypeId check proves T == f32, so these casts
            // reinterpret a slice as its own type.
            let a_f32 = unsafe { &*(a_data as *const [T] as *const [f32]) };
            let b_f32 = unsafe { &*(b_data as *const [T] as *const [f32]) };
            let c_f32 = unsafe { &mut *(result.as_mut_slice() as *mut [T] as *mut [f32]) };
            let a_mat = faer::mat::MatRef::from_row_major_slice(a_f32, m, k);
            // View B as n×k and transpose to get Bᵀ (k×n) without copying.
            let b_mat = faer::mat::MatRef::from_row_major_slice(b_f32, n, k).transpose();
            let mut c_mat = faer::mat::MatMut::from_row_major_slice_mut(c_f32, m, n);
            let par = faer_par(m, k, n);
            faer::linalg::matmul::matmul(
                &mut c_mat,
                faer::Accum::Replace,
                &a_mat,
                &b_mat,
                1.0f32,
                par,
            );
        } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
            // SAFETY: the TypeId check proves T == f64 (same reasoning as above).
            let a_f64 = unsafe { &*(a_data as *const [T] as *const [f64]) };
            let b_f64 = unsafe { &*(b_data as *const [T] as *const [f64]) };
            let c_f64 = unsafe { &mut *(result.as_mut_slice() as *mut [T] as *mut [f64]) };
            let a_mat = faer::mat::MatRef::from_row_major_slice(a_f64, m, k);
            let b_mat = faer::mat::MatRef::from_row_major_slice(b_f64, n, k).transpose();
            let mut c_mat = faer::mat::MatMut::from_row_major_slice_mut(c_f64, m, n);
            let par = faer_par(m, k, n);
            faer::linalg::matmul::matmul(
                &mut c_mat,
                faer::Accum::Replace,
                &a_mat,
                &b_mat,
                1.0f64,
                par,
            );
        } else {
            // Generic fallback: round-trip through f64 buffers for faer.
            let a_f64: Vec<f64> = a_data.iter().map(|&v| v.to_f64().unwrap()).collect();
            let b_f64: Vec<f64> = b_data.iter().map(|&v| v.to_f64().unwrap()).collect();
            let mut r_f64 = vec![0.0f64; m * n];
            let a_mat = faer::mat::MatRef::from_row_major_slice(&a_f64, m, k);
            let b_mat = faer::mat::MatRef::from_row_major_slice(&b_f64, n, k).transpose();
            let mut c_mat = faer::mat::MatMut::from_row_major_slice_mut(&mut r_f64, m, n);
            let par = faer_par(m, k, n);
            faer::linalg::matmul::matmul(
                &mut c_mat,
                faer::Accum::Replace,
                &a_mat,
                &b_mat,
                1.0f64,
                par,
            );
            for (r, &v) in result.iter_mut().zip(r_f64.iter()) {
                *r = T::from(v).unwrap();
            }
        }
        result
    }
}
/// Row-major GEMM with transposed LHS: computes `Aᵀ @ B (k×n)` where `A`
/// is stored row-major as `k×m`; returns a new `m×n` vec.
///
/// Caller contract: `a_data.len() >= k*m`, `b_data.len() >= k*n`, both
/// row-major contiguous. The small-problem loop iterates p outermost so
/// both A's row p and B's row p are read contiguously. Large problems go
/// to faer with a transposed `MatRef` view (no copy for f32/f64; other
/// Float types round-trip through f64).
pub fn mm_raw_at<T: Float>(a_data: &[T], b_data: &[T], m: usize, k: usize, n: usize) -> Vec<T> {
    let max_dim = m.max(n).max(k);
    let zero = <T as num_traits::Zero>::zero();
    if max_dim <= DIRECT_MM_THRESHOLD {
        let mut result = vec![zero; m * n];
        // SAFETY: in bounds under the caller contract: a_row + i < k*m,
        // b_row + j < k*n, r_row + j < m*n (result sized m*n).
        unsafe {
            for p in 0..k {
                let a_row = p * m;
                let b_row = p * n;
                for i in 0..m {
                    // Aᵀ[i][p] == A[p][i] in the k×m row-major layout.
                    let a_val = *a_data.get_unchecked(a_row + i);
                    let r_row = i * n;
                    for j in 0..n {
                        let r = result.get_unchecked_mut(r_row + j);
                        *r += a_val * *b_data.get_unchecked(b_row + j);
                    }
                }
            }
        }
        result
    } else {
        let mut result = vec![zero; m * n];
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
            // SAFETY: the TypeId check proves T == f32, so these casts
            // reinterpret a slice as its own type.
            let a_f32 = unsafe { &*(a_data as *const [T] as *const [f32]) };
            let b_f32 = unsafe { &*(b_data as *const [T] as *const [f32]) };
            let c_f32 = unsafe { &mut *(result.as_mut_slice() as *mut [T] as *mut [f32]) };
            // View A as k×m and transpose to get Aᵀ (m×k) without copying.
            let a_mat = faer::mat::MatRef::from_row_major_slice(a_f32, k, m).transpose();
            let b_mat = faer::mat::MatRef::from_row_major_slice(b_f32, k, n);
            let mut c_mat = faer::mat::MatMut::from_row_major_slice_mut(c_f32, m, n);
            let par = faer_par(m, k, n);
            faer::linalg::matmul::matmul(
                &mut c_mat,
                faer::Accum::Replace,
                &a_mat,
                &b_mat,
                1.0f32,
                par,
            );
        } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
            // SAFETY: the TypeId check proves T == f64 (same reasoning as above).
            let a_f64 = unsafe { &*(a_data as *const [T] as *const [f64]) };
            let b_f64 = unsafe { &*(b_data as *const [T] as *const [f64]) };
            let c_f64 = unsafe { &mut *(result.as_mut_slice() as *mut [T] as *mut [f64]) };
            let a_mat = faer::mat::MatRef::from_row_major_slice(a_f64, k, m).transpose();
            let b_mat = faer::mat::MatRef::from_row_major_slice(b_f64, k, n);
            let mut c_mat = faer::mat::MatMut::from_row_major_slice_mut(c_f64, m, n);
            let par = faer_par(m, k, n);
            faer::linalg::matmul::matmul(
                &mut c_mat,
                faer::Accum::Replace,
                &a_mat,
                &b_mat,
                1.0f64,
                par,
            );
        } else {
            // Generic fallback: round-trip through f64 buffers for faer.
            let a_f64: Vec<f64> = a_data.iter().map(|&v| v.to_f64().unwrap()).collect();
            let b_f64: Vec<f64> = b_data.iter().map(|&v| v.to_f64().unwrap()).collect();
            let mut r_f64 = vec![0.0f64; m * n];
            let a_mat = faer::mat::MatRef::from_row_major_slice(&a_f64, k, m).transpose();
            let b_mat = faer::mat::MatRef::from_row_major_slice(&b_f64, k, n);
            let mut c_mat = faer::mat::MatMut::from_row_major_slice_mut(&mut r_f64, m, n);
            let par = faer_par(m, k, n);
            faer::linalg::matmul::matmul(
                &mut c_mat,
                faer::Accum::Replace,
                &a_mat,
                &b_mat,
                1.0f64,
                par,
            );
            for (r, &v) in result.iter_mut().zip(r_f64.iter()) {
                *r = T::from(v).unwrap();
            }
        }
        result
    }
}
/// 2-D matrix multiply: `(m, k) @ (k, n) -> (m, n)`.
///
/// Operands are normalized to contiguous row-major storage before being
/// handed to the raw kernel.
///
/// # Errors
/// `ShapeMismatch` if either operand is not 2-D or the inner dimensions
/// disagree.
pub fn mm<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if a.ndim() != 2 || b.ndim() != 2 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "mm requires 2-D tensors, got {:?} and {:?}",
                a.shape(),
                b.shape()
            ),
        });
    }
    let m = a.shape()[0];
    let k = a.shape()[1];
    let n = b.shape()[1];
    // Validate shapes before paying for the contiguous copies below
    // (previously the copies happened even when the shapes were invalid).
    if k != b.shape()[0] {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "mm: inner dimensions mismatch: ({},{}) @ ({},{})",
                m,
                k,
                b.shape()[0],
                n
            ),
        });
    }
    let a = if a.is_contiguous() { a.clone() } else { a.contiguous()? };
    let b = if b.is_contiguous() { b.clone() } else { b.contiguous()? };
    let a_data = a.data()?;
    let b_data = b.data()?;
    let result = mm_raw(a_data, b_data, m, k, n);
    Tensor::from_storage(TensorStorage::cpu(result), vec![m, n], false)
}
/// Matrix–vector product: `(m, k) @ (k,) -> (m,)`.
///
/// # Errors
/// `ShapeMismatch` if `a` is not 2-D, `b` is not 1-D, or the inner
/// dimensions disagree.
pub fn mv<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if a.ndim() != 2 || b.ndim() != 1 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "mv requires (2-D, 1-D), got {:?} and {:?}",
                a.shape(),
                b.shape()
            ),
        });
    }
    let m = a.shape()[0];
    let k = a.shape()[1];
    if k != b.shape()[0] {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "mv: dimension mismatch: ({},{}) @ ({},)",
                m,
                k,
                b.shape()[0]
            ),
        });
    }
    // Normalize to contiguous row-major storage, mirroring `mm`, so the
    // flat-index arithmetic below is valid for strided views too.
    let a = if a.is_contiguous() { a.clone() } else { a.contiguous()? };
    let b = if b.is_contiguous() { b.clone() } else { b.contiguous()? };
    let a_data = a.data()?;
    let b_data = b.data()?;
    let mut result = vec![<T as num_traits::Zero>::zero(); m];
    for (i, out) in result.iter_mut().enumerate() {
        // Dot of row i of `a` with `b`.
        let mut acc = <T as num_traits::Zero>::zero();
        for p in 0..k {
            acc += a_data[i * k + p] * b_data[p];
        }
        *out = acc;
    }
    Tensor::from_storage(TensorStorage::cpu(result), vec![m], false)
}
/// Vector–matrix product: `(k,) @ (k, n) -> (n,)`.
///
/// # Errors
/// `ShapeMismatch` if `a` is not 1-D, `b` is not 2-D, or the inner
/// dimensions disagree.
fn vm<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    // Validate ranks explicitly, mirroring `mv`/`mm`; previously a
    // wrong-rank `b` would panic on `b.shape()[1]` instead of returning a
    // shape error (only reachable if called outside `matmul`'s dispatch,
    // but cheap to guard).
    if a.ndim() != 1 || b.ndim() != 2 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "vm requires (1-D, 2-D), got {:?} and {:?}",
                a.shape(),
                b.shape()
            ),
        });
    }
    let k = a.shape()[0];
    let n = b.shape()[1];
    if k != b.shape()[0] {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "vm: dimension mismatch: ({},) @ ({},{})",
                k,
                b.shape()[0],
                n
            ),
        });
    }
    let a_data = a.data()?;
    let b_data = b.data()?;
    let mut result = vec![<T as num_traits::Zero>::zero(); n];
    for (j, out) in result.iter_mut().enumerate() {
        // Dot of `a` with column j of `b`.
        let mut acc = <T as num_traits::Zero>::zero();
        for p in 0..k {
            acc += a_data[p] * b_data[p * n + j];
        }
        *out = acc;
    }
    Tensor::from_storage(TensorStorage::cpu(result), vec![n], false)
}
pub fn bmm<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if a.ndim() != 3 || b.ndim() != 3 {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"bmm requires 3-D tensors, got {:?} and {:?}",
a.shape(),
b.shape()
),
});
}
let batch = a.shape()[0];
let m = a.shape()[1];
let k = a.shape()[2];
let n = b.shape()[2];
if b.shape()[0] != batch {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"bmm: batch dimensions mismatch: {} vs {}",
batch,
b.shape()[0]
),
});
}
if k != b.shape()[1] {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"bmm: inner dimensions mismatch: ({},{},{}) @ ({},{},{})",
batch,
m,
k,
b.shape()[0],
b.shape()[1],
n
),
});
}
let a_data = a.data()?;
let b_data = b.data()?;
let slice_a = m * k;
let slice_b = k * n;
let slice_c = m * n;
let mut result = vec![<T as num_traits::Zero>::zero(); batch * slice_c];
for bi in 0..batch {
let a_off = bi * slice_a;
let b_off = bi * slice_b;
let c_off = bi * slice_c;
for i in 0..m {
for j in 0..n {
let mut acc = <T as num_traits::Zero>::zero();
for p in 0..k {
acc += a_data[a_off + i * k + p] * b_data[b_off + p * n + j];
}
result[c_off + i * n + j] = acc;
}
}
}
Tensor::from_storage(TensorStorage::cpu(result), vec![batch, m, n], false)
}
/// Transpose a 2-D tensor: `(m, n) -> (n, m)` with materialized data.
///
/// # Errors
/// `InvalidArgument` if `input` is not 2-D.
pub fn transpose<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if input.ndim() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("transpose requires 2-D tensor, got {:?}", input.shape()),
        });
    }
    let rows = input.shape()[0];
    let cols = input.shape()[1];
    let src = input.data()?;
    // Write the output row-major: out[j][i] = in[i][j].
    let mut out = Vec::with_capacity(rows * cols);
    for j in 0..cols {
        for i in 0..rows {
            out.push(src[i * cols + j]);
        }
    }
    Tensor::from_storage(TensorStorage::cpu(out), vec![cols, rows], false)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a CPU f32 tensor from raw data and a shape (panics on error).
    fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }

    #[test]
    fn test_dot() {
        // 1*4 + 2*5 + 3*6 = 32, returned as a 0-D scalar tensor.
        let a = t(&[1.0, 2.0, 3.0], &[3]);
        let b = t(&[4.0, 5.0, 6.0], &[3]);
        let c = dot(&a, &b).unwrap();
        assert!(c.is_scalar());
        assert!((c.item().unwrap() - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_mm() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let c = mm(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        let d = c.data().unwrap();
        assert!((d[0] - 19.0).abs() < 1e-6);
        assert!((d[1] - 22.0).abs() < 1e-6);
        assert!((d[2] - 43.0).abs() < 1e-6);
        assert!((d[3] - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_mv() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[5.0, 6.0], &[2]);
        let c = mv(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2]);
        let d = c.data().unwrap();
        assert!((d[0] - 17.0).abs() < 1e-6);
        assert!((d[1] - 39.0).abs() < 1e-6);
    }

    #[test]
    fn test_matmul_dispatch() {
        // (1-D, 1-D) dispatches to dot (scalar result).
        let a = t(&[1.0, 2.0, 3.0], &[3]);
        let b = t(&[4.0, 5.0, 6.0], &[3]);
        let c = matmul(&a, &b).unwrap();
        assert!(c.is_scalar());
        // (2-D, 2-D) dispatches to mm.
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let c = matmul(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
    }

    #[test]
    fn test_matmul_3d_3d_same_batch() {
        #[rustfmt::skip]
        let a = t(&[
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, ], &[2, 2, 3]);
        #[rustfmt::skip]
        let b = t(&[
            1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, ], &[2, 3, 2]);
        let c = matmul(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
        let d = c.data().unwrap();
        // Spot-check the first batch's 2x2 product.
        assert!((d[0] - 4.0).abs() < 1e-6);
        assert!((d[1] - 2.0).abs() < 1e-6);
        assert!((d[2] - 10.0).abs() < 1e-6);
        assert!((d[3] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_matmul_3d_2d_broadcast() {
        // 2-D rhs is broadcast over the 3-D lhs's batch dimension.
        let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
        let b = t(&vec![1.0; 4 * 2], &[4, 2]);
        let c = matmul(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 3, 2]);
        for &v in c.data().unwrap().iter() {
            assert!((v - 4.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_matmul_2d_3d_broadcast() {
        let a = t(&vec![1.0; 3 * 4], &[3, 4]);
        let b = t(&vec![1.0; 2 * 4 * 2], &[2, 4, 2]);
        let c = matmul(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 3, 2]);
    }

    #[test]
    fn test_matmul_batch_broadcast_1_vs_n() {
        // Batch dim 1 broadcasts against batch dim 4.
        let a = t(&vec![1.0; 1 * 2 * 3], &[1, 2, 3]);
        let b = t(&vec![1.0; 4 * 3 * 2], &[4, 3, 2]);
        let c = matmul(&a, &b).unwrap();
        assert_eq!(c.shape(), &[4, 2, 2]);
    }

    #[test]
    fn test_matmul_4d() {
        let a = t(&vec![1.0; 2 * 3 * 2 * 4], &[2, 3, 2, 4]);
        let b = t(&vec![1.0; 2 * 3 * 4 * 5], &[2, 3, 4, 5]);
        let c = matmul(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 3, 2, 5]);
    }

    #[test]
    fn test_matmul_3d_1d() {
        // 1-D rhs is treated as a column and the trailing axis is squeezed.
        let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
        let b = t(&vec![1.0; 4], &[4]);
        let c = matmul(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        for &v in c.data().unwrap().iter() {
            assert!((v - 4.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_matmul_1d_3d() {
        // 1-D lhs is treated as a row and the row axis is squeezed.
        let a = t(&vec![1.0; 4], &[4]);
        let b = t(&vec![1.0; 2 * 4 * 3], &[2, 4, 3]);
        let c = matmul(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
    }

    #[test]
    fn test_matmul_broadcast_mismatch() {
        // Batch dims 2 vs 3: not broadcastable.
        let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
        let b = t(&vec![1.0; 3 * 4 * 2], &[3, 4, 2]);
        assert!(matmul(&a, &b).is_err());
    }

    #[test]
    fn test_matmul_inner_dim_mismatch() {
        // Inner dims 4 vs 5.
        let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
        let b = t(&vec![1.0; 2 * 5 * 2], &[2, 5, 2]);
        assert!(matmul(&a, &b).is_err());
    }

    #[test]
    fn test_mm_shape_mismatch() {
        let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        assert!(mm(&a, &b).is_err());
    }

    #[test]
    fn test_transpose() {
        let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = transpose(&a).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(b.data().unwrap(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_bmm_forward_shape() {
        let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
        let b = t(&vec![1.0; 2 * 4 * 5], &[2, 4, 5]);
        let c = bmm(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 3, 5]);
    }

    #[test]
    fn test_bmm_forward_correctness() {
        #[rustfmt::skip]
        let a_data: Vec<f32> = vec![
            1.0, 2.0, 3.0, 4.0,
            1.0, 0.0, 0.0, 1.0,
        ];
        #[rustfmt::skip]
        let b_data: Vec<f32> = vec![
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
        ];
        let a = t(&a_data, &[2, 2, 2]);
        let b = t(&b_data, &[2, 2, 2]);
        let c = bmm(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
        let d = c.data().unwrap();
        // Batch 0: standard 2x2 product; batch 1: identity passes B through.
        assert!((d[0] - 19.0).abs() < 1e-6);
        assert!((d[1] - 22.0).abs() < 1e-6);
        assert!((d[2] - 43.0).abs() < 1e-6);
        assert!((d[3] - 50.0).abs() < 1e-6);
        assert!((d[4] - 9.0).abs() < 1e-6);
        assert!((d[5] - 10.0).abs() < 1e-6);
        assert!((d[6] - 11.0).abs() < 1e-6);
        assert!((d[7] - 12.0).abs() < 1e-6);
    }

    #[test]
    fn test_bmm_batch_size_1() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[1, 2, 2]);
        let b = t(&[5.0, 6.0, 7.0, 8.0], &[1, 2, 2]);
        let c = bmm(&a, &b).unwrap();
        assert_eq!(c.shape(), &[1, 2, 2]);
        let d = c.data().unwrap();
        assert!((d[0] - 19.0).abs() < 1e-6);
        assert!((d[1] - 22.0).abs() < 1e-6);
        assert!((d[2] - 43.0).abs() < 1e-6);
        assert!((d[3] - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_bmm_shape_mismatch() {
        // Mismatched batch sizes.
        let a = t(&vec![1.0; 2 * 2 * 2], &[2, 2, 2]);
        let b = t(&vec![1.0; 3 * 2 * 2], &[3, 2, 2]);
        assert!(bmm(&a, &b).is_err());
        // Mismatched inner dims.
        let a = t(&vec![1.0; 2 * 2 * 3], &[2, 2, 3]);
        let b = t(&vec![1.0; 2 * 4 * 2], &[2, 4, 2]);
        assert!(bmm(&a, &b).is_err());
        // Non-3-D operand.
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&vec![1.0; 1 * 2 * 2], &[1, 2, 2]);
        assert!(bmm(&a, &b).is_err());
    }
}