use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
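/// General matrix product with rank-based dispatch, following the
/// NumPy/PyTorch `matmul` convention: 1-D @ 1-D is a dot product, 2-D @ 1-D a
/// matrix-vector product, 1-D @ 2-D a vector-matrix product, 2-D @ 2-D an
/// ordinary matrix product, and anything higher-dimensional a batched product
/// with broadcasting over the leading batch dimensions. Scalar (0-D) operands
/// are rejected.
///
/// A usage sketch (marked `ignore` because the crate's public re-export paths
/// are assumed, not shown in this module):
///
/// ```ignore
/// let a = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 6]), vec![2, 3], false).unwrap();
/// let v = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 3]), vec![3], false).unwrap();
/// let c = matmul(&a, &v).unwrap(); // dispatches to mv; shape [2]
/// assert_eq!(c.shape(), &[2]);
/// ```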
pub fn matmul<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
match (a.ndim(), b.ndim()) {
(0, _) | (_, 0) => Err(FerrotorchError::InvalidArgument {
message: format!(
"matmul: scalar operands not supported, got shapes {:?} and {:?}",
a.shape(),
b.shape()
),
}),
(1, 1) => dot(a, b),
(2, 1) => mv(a, b),
(1, 2) => vm(a, b),
(2, 2) => mm(a, b),
_ => broadcast_matmul(a, b),
}
}
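/// Broadcasts two batch shapes with the standard right-aligned rule: shapes
/// are aligned at their trailing dimensions, missing leading dimensions count
/// as 1, and each pair must either match or have one side equal to 1 (which
/// is then stretched to the other). For example, [1, 2, 3] vs [4, 1, 3]
/// broadcasts to [4, 2, 3], while [2, 3] vs [3, 3] is rejected.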
fn broadcast_batch_shapes(a: &[usize], b: &[usize]) -> FerrotorchResult<Vec<usize>> {
let max_len = a.len().max(b.len());
let mut result = Vec::with_capacity(max_len);
for i in 0..max_len {
let da = if i < max_len - a.len() { 1 } else { a[i - (max_len - a.len())] };
let db = if i < max_len - b.len() { 1 } else { b[i - (max_len - b.len())] };
if da == db {
result.push(da);
} else if da == 1 {
result.push(db);
} else if db == 1 {
result.push(da);
} else {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"matmul: batch dimensions cannot be broadcast: {:?} vs {:?}",
a, b
),
});
}
}
Ok(result)
}
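/// Batched matrix product for operands where at least one side is 3-D or
/// higher. A 1-D left operand is promoted to a (1, k) row vector and a 1-D
/// right operand to a (k, 1) column vector; the promoted dimension is
/// squeezed back out of the result shape. Batch dimensions are broadcast via
/// `broadcast_batch_shapes`, and each output matrix is computed with a naive
/// i-j-p triple loop. Operands are assumed to be contiguous row-major, which
/// is how `data()` appears to expose CPU tensors throughout this module.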
fn broadcast_matmul<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let squeeze_row = a.ndim() == 1;
let squeeze_col = b.ndim() == 1;
let a_shape: Vec<usize> = if squeeze_row {
let mut s = vec![1];
s.extend_from_slice(a.shape());
s
} else {
a.shape().to_vec()
};
let b_shape: Vec<usize> = if squeeze_col {
let mut s = b.shape().to_vec();
s.push(1);
s
} else {
b.shape().to_vec()
};
let a_nd = a_shape.len();
let b_nd = b_shape.len();
let m = a_shape[a_nd - 2];
let k_a = a_shape[a_nd - 1];
let k_b = b_shape[b_nd - 2];
let n = b_shape[b_nd - 1];
if k_a != k_b {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"matmul: inner dimensions mismatch: {:?} @ {:?}",
a.shape(),
b.shape()
),
});
}
let k = k_a;
let a_batch = &a_shape[..a_nd - 2];
let b_batch = &b_shape[..b_nd - 2];
let batch_shape = broadcast_batch_shapes(a_batch, b_batch)?;
// The empty product is already 1, and a zero-sized batch dim must yield an
// empty result rather than one spurious iteration over missing data.
let batch_size: usize = batch_shape.iter().product();
let a_batch_strides = broadcast_strides(a_batch, &batch_shape);
let b_batch_strides = broadcast_strides(b_batch, &batch_shape);
let a_mat_size = m * k;
let b_mat_size = k * n;
let c_mat_size = m * n;
let a_data = a.data()?;
let b_data = b.data()?;
let mut result = vec![<T as num_traits::Zero>::zero(); batch_size * c_mat_size];
for bi in 0..batch_size {
let a_off = batch_linear_index(bi, &a_batch_strides, &batch_shape) * a_mat_size;
let b_off = batch_linear_index(bi, &b_batch_strides, &batch_shape) * b_mat_size;
let c_off = bi * c_mat_size;
for i in 0..m {
for j in 0..n {
let mut acc = <T as num_traits::Zero>::zero();
for p in 0..k {
acc = acc + a_data[a_off + i * k + p] * b_data[b_off + p * n + j];
}
result[c_off + i * n + j] = acc;
}
}
}
let mut out_shape = batch_shape;
out_shape.push(m);
out_shape.push(n);
if squeeze_row {
let pos = out_shape.len() - 2;
out_shape.remove(pos);
}
if squeeze_col {
out_shape.pop();
}
Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)
}
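/// Maps a source batch shape onto a broadcast batch shape, returning one
/// stride per broadcast dimension measured in whole matrices (callers scale
/// by the per-matrix element count). Dimensions the source lacks, or where
/// its size is 1, get stride 0 so every broadcast index along them re-reads
/// the same source matrix.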
fn broadcast_strides(src: &[usize], broadcast: &[usize]) -> Vec<usize> {
let offset = broadcast.len() - src.len();
let mut strides = vec![0usize; broadcast.len()];
if !src.is_empty() {
let mut src_strides = vec![1usize; src.len()];
for i in (0..src.len() - 1).rev() {
src_strides[i] = src_strides[i + 1] * src[i + 1];
}
for i in 0..broadcast.len() {
if i < offset {
strides[i] = 0;
} else {
let si = i - offset;
if src[si] == 1 {
strides[i] = 0;
} else {
strides[i] = src_strides[si];
}
}
}
}
strides
}
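/// Decomposes a flat batch index into coordinates over `shape` (row-major,
/// last dimension varying fastest) and dots them with `strides`, yielding
/// the source operand's matrix offset for that batch element.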
fn batch_linear_index(flat: usize, strides: &[usize], shape: &[usize]) -> usize {
let mut idx = 0;
let mut remaining = flat;
for i in (0..shape.len()).rev() {
let coord = remaining % shape[i];
remaining /= shape[i];
idx += coord * strides[i];
}
idx
}
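/// Dot product of two equal-length 1-D tensors, returned as a 0-D scalar
/// tensor.
///
/// A usage sketch (marked `ignore` because the crate's public re-export paths
/// are assumed):
///
/// ```ignore
/// let a = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]), vec![3], false).unwrap();
/// let b = Tensor::from_storage(TensorStorage::cpu(vec![4.0f32, 5.0, 6.0]), vec![3], false).unwrap();
/// let c = dot(&a, &b).unwrap();
/// assert!((c.item().unwrap() - 32.0).abs() < 1e-6); // 1*4 + 2*5 + 3*6
/// ```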
pub fn dot<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if a.ndim() != 1 || b.ndim() != 1 {
return Err(FerrotorchError::ShapeMismatch {
message: format!("dot requires 1-D tensors, got {:?} and {:?}", a.shape(), b.shape()),
});
}
if a.shape()[0] != b.shape()[0] {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"dot product dimension mismatch: {} vs {}",
a.shape()[0],
b.shape()[0]
),
});
}
let a_data = a.data()?;
let b_data = b.data()?;
let result = a_data
.iter()
.zip(b_data.iter())
.fold(<T as num_traits::Zero>::zero(), |acc, (&x, &y)| acc + x * y);
Tensor::from_storage(TensorStorage::cpu(vec![result]), vec![], false)
}
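/// Problems whose largest dimension is at or below this size use the naive
/// triple loop; larger ones are handed to `matrixmultiply`'s cache-blocked
/// GEMM kernels, whose setup cost only pays off at scale.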
const DIRECT_MM_THRESHOLD: usize = 128;
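/// Row-major GEMM on raw slices: C = A @ B with A of shape (m, k) and B of
/// shape (k, n). The direct path orders the loops i-p-j so the innermost
/// accesses to B and C are contiguous. The GEMM path uses
/// `matrixmultiply::sgemm` for f32 and otherwise round-trips through f64 and
/// `dgemm`; note that the conversion copies the inputs even when T is
/// already f64.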
pub fn mm_raw<T: Float>(a_data: &[T], b_data: &[T], m: usize, k: usize, n: usize) -> Vec<T> {
let max_dim = m.max(n).max(k);
let zero = <T as num_traits::Zero>::zero();
if max_dim <= DIRECT_MM_THRESHOLD {
let mut result = vec![zero; m * n];
unsafe {
for i in 0..m {
let a_row = i * k;
let r_row = i * n;
for p in 0..k {
let a_ip = *a_data.get_unchecked(a_row + p);
let b_row = p * n;
for j in 0..n {
let r = result.get_unchecked_mut(r_row + j);
*r = *r + a_ip * *b_data.get_unchecked(b_row + j);
}
}
}
}
result
} else {
let mut result = vec![zero; m * n];
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let a_ptr = a_data.as_ptr() as *const f32;
let b_ptr = b_data.as_ptr() as *const f32;
let c_ptr = result.as_mut_ptr() as *mut f32;
unsafe {
matrixmultiply::sgemm(
m, k, n,
1.0,
a_ptr, k as isize, 1,
b_ptr, n as isize, 1,
0.0,
c_ptr, n as isize, 1,
);
}
} else {
let a_f64: Vec<f64> = a_data.iter().map(|&v| v.to_f64().unwrap()).collect();
let b_f64: Vec<f64> = b_data.iter().map(|&v| v.to_f64().unwrap()).collect();
let mut r_f64 = vec![0.0f64; m * n];
unsafe {
matrixmultiply::dgemm(
m, k, n,
1.0,
a_f64.as_ptr(), k as isize, 1,
b_f64.as_ptr(), n as isize, 1,
0.0,
r_f64.as_mut_ptr(), n as isize, 1,
);
}
for (r, &v) in result.iter_mut().zip(r_f64.iter()) {
*r = T::from(v).unwrap();
}
}
result
}
}
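/// Like `mm_raw`, but with B stored transposed: `b_data` holds a row-major
/// (n, k) matrix and the result is C = A @ B^T. The direct path accumulates
/// each output element from two contiguous length-k rows; the GEMM path
/// expresses the transpose through swapped strides (rsb = 1, csb = k).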
pub fn mm_raw_bt<T: Float>(a_data: &[T], b_data: &[T], m: usize, k: usize, n: usize) -> Vec<T> {
let max_dim = m.max(n).max(k);
let zero = <T as num_traits::Zero>::zero();
if max_dim <= DIRECT_MM_THRESHOLD {
let mut result = vec![zero; m * n];
unsafe {
for i in 0..m {
let a_row = i * k;
let r_row = i * n;
for j in 0..n {
let b_row = j * k;
let mut acc = zero;
for p in 0..k {
acc = acc + *a_data.get_unchecked(a_row + p) * *b_data.get_unchecked(b_row + p);
}
*result.get_unchecked_mut(r_row + j) = acc;
}
}
}
result
} else {
let mut result = vec![zero; m * n];
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let a_ptr = a_data.as_ptr() as *const f32;
let b_ptr = b_data.as_ptr() as *const f32;
let c_ptr = result.as_mut_ptr() as *mut f32;
unsafe {
matrixmultiply::sgemm(
m, k, n,
1.0,
a_ptr, k as isize, 1,
b_ptr, 1, k as isize,
0.0,
c_ptr, n as isize, 1,
);
}
} else {
let a_f64: Vec<f64> = a_data.iter().map(|&v| v.to_f64().unwrap()).collect();
let b_f64: Vec<f64> = b_data.iter().map(|&v| v.to_f64().unwrap()).collect();
let mut r_f64 = vec![0.0f64; m * n];
unsafe {
matrixmultiply::dgemm(
m, k, n,
1.0,
a_f64.as_ptr(), k as isize, 1,
b_f64.as_ptr(), 1, k as isize,
0.0,
r_f64.as_mut_ptr(), n as isize, 1,
);
}
for (r, &v) in result.iter_mut().zip(r_f64.iter()) {
*r = T::from(v).unwrap();
}
}
result
}
}
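/// Like `mm_raw`, but with A stored transposed: `a_data` holds a row-major
/// (k, m) matrix and the result is C = A^T @ B. The direct path iterates p
/// outermost so A is read contiguously and B and C are accessed contiguously
/// in the inner loop; the GEMM path expresses the transpose through swapped
/// strides (rsa = 1, csa = m).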
pub fn mm_raw_at<T: Float>(a_data: &[T], b_data: &[T], m: usize, k: usize, n: usize) -> Vec<T> {
let max_dim = m.max(n).max(k);
let zero = <T as num_traits::Zero>::zero();
if max_dim <= DIRECT_MM_THRESHOLD {
let mut result = vec![zero; m * n];
unsafe {
for p in 0..k {
let a_row = p * m;
let b_row = p * n;
for i in 0..m {
let a_val = *a_data.get_unchecked(a_row + i);
let r_row = i * n;
for j in 0..n {
let r = result.get_unchecked_mut(r_row + j);
*r = *r + a_val * *b_data.get_unchecked(b_row + j);
}
}
}
}
result
} else {
let mut result = vec![zero; m * n];
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let a_ptr = a_data.as_ptr() as *const f32;
let b_ptr = b_data.as_ptr() as *const f32;
let c_ptr = result.as_mut_ptr() as *mut f32;
unsafe {
matrixmultiply::sgemm(
m, k, n,
1.0,
a_ptr, 1, m as isize,
b_ptr, n as isize, 1,
0.0,
c_ptr, n as isize, 1,
);
}
} else {
let a_f64: Vec<f64> = a_data.iter().map(|&v| v.to_f64().unwrap()).collect();
let b_f64: Vec<f64> = b_data.iter().map(|&v| v.to_f64().unwrap()).collect();
let mut r_f64 = vec![0.0f64; m * n];
unsafe {
matrixmultiply::dgemm(
m, k, n,
1.0,
a_f64.as_ptr(), 1, m as isize,
b_f64.as_ptr(), n as isize, 1,
0.0,
r_f64.as_mut_ptr(), n as isize, 1,
);
}
for (r, &v) in result.iter_mut().zip(r_f64.iter()) {
*r = T::from(v).unwrap();
}
}
result
}
}
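/// Matrix product of two 2-D tensors: (m, k) @ (k, n) -> (m, n).
///
/// A usage sketch (marked `ignore` because the crate's public re-export paths
/// are assumed):
///
/// ```ignore
/// let a = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]), vec![2, 2], false).unwrap();
/// let b = Tensor::from_storage(TensorStorage::cpu(vec![5.0f32, 6.0, 7.0, 8.0]), vec![2, 2], false).unwrap();
/// let c = mm(&a, &b).unwrap(); // [[19, 22], [43, 50]]
/// ```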
pub fn mm<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if a.ndim() != 2 || b.ndim() != 2 {
return Err(FerrotorchError::ShapeMismatch {
message: format!("mm requires 2-D tensors, got {:?} and {:?}", a.shape(), b.shape()),
});
}
let m = a.shape()[0];
let k = a.shape()[1];
let n = b.shape()[1];
if k != b.shape()[0] {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"mm: inner dimensions mismatch: ({},{}) @ ({},{})",
m, k, b.shape()[0], n
),
});
}
let a_data = a.data()?;
let b_data = b.data()?;
let result = mm_raw(a_data, b_data, m, k, n);
Tensor::from_storage(TensorStorage::cpu(result), vec![m, n], false)
}
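/// Matrix-vector product: (m, k) @ (k,) -> (m,).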
pub fn mv<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if a.ndim() != 2 || b.ndim() != 1 {
return Err(FerrotorchError::ShapeMismatch {
message: format!("mv requires (2-D, 1-D), got {:?} and {:?}", a.shape(), b.shape()),
});
}
let m = a.shape()[0];
let k = a.shape()[1];
if k != b.shape()[0] {
return Err(FerrotorchError::ShapeMismatch {
message: format!("mv: dimension mismatch: ({},{}) @ ({},)", m, k, b.shape()[0]),
});
}
let a_data = a.data()?;
let b_data = b.data()?;
let mut result = vec![<T as num_traits::Zero>::zero(); m];
for i in 0..m {
let mut acc = <T as num_traits::Zero>::zero();
for p in 0..k {
acc = acc + a_data[i * k + p] * b_data[p];
}
result[i] = acc;
}
Tensor::from_storage(TensorStorage::cpu(result), vec![m], false)
}
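/// Vector-matrix product: (k,) @ (k, n) -> (n,). Private and only reachable
/// through the (1, 2) arm of `matmul`, which guarantees the operand ranks,
/// so only the inner dimension is re-checked here.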
fn vm<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let k = a.shape()[0];
let n = b.shape()[1];
if k != b.shape()[0] {
return Err(FerrotorchError::ShapeMismatch {
message: format!("vm: dimension mismatch: ({},) @ ({},{})", k, b.shape()[0], n),
});
}
let a_data = a.data()?;
let b_data = b.data()?;
let mut result = vec![<T as num_traits::Zero>::zero(); n];
for j in 0..n {
let mut acc = <T as num_traits::Zero>::zero();
for p in 0..k {
acc = acc + a_data[p] * b_data[p * n + j];
}
result[j] = acc;
}
Tensor::from_storage(TensorStorage::cpu(result), vec![n], false)
}
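/// Batched matrix product of two 3-D tensors with identical batch sizes:
/// (b, m, k) @ (b, k, n) -> (b, m, n). Unlike `matmul`, the batch dimension
/// is not broadcast, and each batch is computed with the naive triple loop
/// rather than the GEMM dispatch in `mm_raw`.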
pub fn bmm<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if a.ndim() != 3 || b.ndim() != 3 {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"bmm requires 3-D tensors, got {:?} and {:?}",
a.shape(),
b.shape()
),
});
}
let batch = a.shape()[0];
let m = a.shape()[1];
let k = a.shape()[2];
let n = b.shape()[2];
if b.shape()[0] != batch {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"bmm: batch dimensions mismatch: {} vs {}",
batch,
b.shape()[0]
),
});
}
if k != b.shape()[1] {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"bmm: inner dimensions mismatch: ({},{},{}) @ ({},{},{})",
batch, m, k, b.shape()[0], b.shape()[1], n
),
});
}
let a_data = a.data()?;
let b_data = b.data()?;
let slice_a = m * k;
let slice_b = k * n;
let slice_c = m * n;
let mut result = vec![<T as num_traits::Zero>::zero(); batch * slice_c];
for bi in 0..batch {
let a_off = bi * slice_a;
let b_off = bi * slice_b;
let c_off = bi * slice_c;
for i in 0..m {
for j in 0..n {
let mut acc = <T as num_traits::Zero>::zero();
for p in 0..k {
acc = acc + a_data[a_off + i * k + p] * b_data[b_off + p * n + j];
}
result[c_off + i * n + j] = acc;
}
}
}
Tensor::from_storage(TensorStorage::cpu(result), vec![batch, m, n], false)
}
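/// Transpose of a 2-D tensor, materialized as a new contiguous tensor
/// rather than a stride-swapped view.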
pub fn transpose<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if input.ndim() != 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!("transpose requires 2-D tensor, got {:?}", input.shape()),
});
}
let m = input.shape()[0];
let n = input.shape()[1];
let data = input.data()?;
let mut result = vec![<T as num_traits::Zero>::zero(); m * n];
for i in 0..m {
for j in 0..n {
result[j * m + i] = data[i * n + j];
}
}
Tensor::from_storage(TensorStorage::cpu(result), vec![n, m], false)
}
#[cfg(test)]
mod tests {
use super::*;
fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
}
#[test]
fn test_dot() {
let a = t(&[1.0, 2.0, 3.0], &[3]);
let b = t(&[4.0, 5.0, 6.0], &[3]);
let c = dot(&a, &b).unwrap();
assert!(c.is_scalar());
assert!((c.item().unwrap() - 32.0).abs() < 1e-6);
}
#[test]
fn test_mm() {
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
let c = mm(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 2]);
let d = c.data().unwrap();
assert!((d[0] - 19.0).abs() < 1e-6);
assert!((d[1] - 22.0).abs() < 1e-6);
assert!((d[2] - 43.0).abs() < 1e-6);
assert!((d[3] - 50.0).abs() < 1e-6);
}
#[test]
fn test_mv() {
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = t(&[5.0, 6.0], &[2]);
let c = mv(&a, &b).unwrap();
assert_eq!(c.shape(), &[2]);
let d = c.data().unwrap();
assert!((d[0] - 17.0).abs() < 1e-6);
assert!((d[1] - 39.0).abs() < 1e-6);
}
#[test]
fn test_matmul_dispatch() {
let a = t(&[1.0, 2.0, 3.0], &[3]);
let b = t(&[4.0, 5.0, 6.0], &[3]);
let c = matmul(&a, &b).unwrap();
assert!(c.is_scalar());
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
let c = matmul(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 2]);
}
#[test]
fn test_matmul_3d_3d_same_batch() {
#[rustfmt::skip]
let a = t(&[
1.0, 2.0, 3.0,
4.0, 5.0, 6.0,
1.0, 0.0, 0.0,
0.0, 1.0, 0.0,
], &[2, 2, 3]);
#[rustfmt::skip]
let b = t(&[
1.0, 0.0,
0.0, 1.0,
1.0, 0.0,
1.0, 2.0,
3.0, 4.0,
5.0, 6.0,
], &[2, 3, 2]);
let c = matmul(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 2, 2]);
let d = c.data().unwrap();
assert!((d[0] - 4.0).abs() < 1e-6);
assert!((d[1] - 2.0).abs() < 1e-6);
assert!((d[2] - 10.0).abs() < 1e-6);
assert!((d[3] - 5.0).abs() < 1e-6);
}
#[test]
fn test_matmul_3d_2d_broadcast() {
let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
let b = t(&vec![1.0; 4 * 2], &[4, 2]);
let c = matmul(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 3, 2]);
for &v in c.data().unwrap().iter() {
assert!((v - 4.0).abs() < 1e-6);
}
}
#[test]
fn test_matmul_2d_3d_broadcast() {
let a = t(&vec![1.0; 3 * 4], &[3, 4]);
let b = t(&vec![1.0; 2 * 4 * 2], &[2, 4, 2]);
let c = matmul(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 3, 2]);
}
#[test]
fn test_matmul_batch_broadcast_1_vs_n() {
let a = t(&vec![1.0; 1 * 2 * 3], &[1, 2, 3]);
let b = t(&vec![1.0; 4 * 3 * 2], &[4, 3, 2]);
let c = matmul(&a, &b).unwrap();
assert_eq!(c.shape(), &[4, 2, 2]);
}
#[test]
fn test_matmul_4d() {
let a = t(&vec![1.0; 2 * 3 * 2 * 4], &[2, 3, 2, 4]);
let b = t(&vec![1.0; 2 * 3 * 4 * 5], &[2, 3, 4, 5]);
let c = matmul(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 3, 2, 5]);
}
#[test]
fn test_matmul_3d_1d() {
let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
let b = t(&vec![1.0; 4], &[4]);
let c = matmul(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 3]);
for &v in c.data().unwrap().iter() {
assert!((v - 4.0).abs() < 1e-6);
}
}
#[test]
fn test_matmul_1d_3d() {
let a = t(&vec![1.0; 4], &[4]);
let b = t(&vec![1.0; 2 * 4 * 3], &[2, 4, 3]);
let c = matmul(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 3]);
}
#[test]
fn test_matmul_broadcast_mismatch() {
let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
let b = t(&vec![1.0; 3 * 4 * 2], &[3, 4, 2]);
assert!(matmul(&a, &b).is_err());
}
#[test]
fn test_matmul_inner_dim_mismatch() {
let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
let b = t(&vec![1.0; 2 * 5 * 2], &[2, 5, 2]);
assert!(matmul(&a, &b).is_err());
}
#[test]
fn test_mm_shape_mismatch() {
let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
let b = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
assert!(mm(&a, &b).is_err());
}
#[test]
fn test_transpose() {
let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
let b = transpose(&a).unwrap();
assert_eq!(b.shape(), &[3, 2]);
assert_eq!(b.data().unwrap(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
#[test]
fn test_bmm_forward_shape() {
let a = t(&vec![1.0; 2 * 3 * 4], &[2, 3, 4]);
let b = t(&vec![1.0; 2 * 4 * 5], &[2, 4, 5]);
let c = bmm(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 3, 5]);
}
#[test]
fn test_bmm_forward_correctness() {
#[rustfmt::skip]
let a_data: Vec<f32> = vec![
1.0, 2.0, 3.0, 4.0,
1.0, 0.0, 0.0, 1.0,
];
#[rustfmt::skip]
let b_data: Vec<f32> = vec![
5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0,
];
let a = t(&a_data, &[2, 2, 2]);
let b = t(&b_data, &[2, 2, 2]);
let c = bmm(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 2, 2]);
let d = c.data().unwrap();
assert!((d[0] - 19.0).abs() < 1e-6);
assert!((d[1] - 22.0).abs() < 1e-6);
assert!((d[2] - 43.0).abs() < 1e-6);
assert!((d[3] - 50.0).abs() < 1e-6);
assert!((d[4] - 9.0).abs() < 1e-6);
assert!((d[5] - 10.0).abs() < 1e-6);
assert!((d[6] - 11.0).abs() < 1e-6);
assert!((d[7] - 12.0).abs() < 1e-6);
}
#[test]
fn test_bmm_batch_size_1() {
let a = t(&[1.0, 2.0, 3.0, 4.0], &[1, 2, 2]);
let b = t(&[5.0, 6.0, 7.0, 8.0], &[1, 2, 2]);
let c = bmm(&a, &b).unwrap();
assert_eq!(c.shape(), &[1, 2, 2]);
let d = c.data().unwrap();
assert!((d[0] - 19.0).abs() < 1e-6);
assert!((d[1] - 22.0).abs() < 1e-6);
assert!((d[2] - 43.0).abs() < 1e-6);
assert!((d[3] - 50.0).abs() < 1e-6);
}
#[test]
fn test_bmm_shape_mismatch() {
let a = t(&vec![1.0; 2 * 2 * 2], &[2, 2, 2]);
let b = t(&vec![1.0; 3 * 2 * 2], &[3, 2, 2]);
assert!(bmm(&a, &b).is_err());
let a = t(&vec![1.0; 2 * 2 * 3], &[2, 2, 3]);
let b = t(&vec![1.0; 2 * 4 * 2], &[2, 4, 2]);
assert!(bmm(&a, &b).is_err());
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = t(&vec![1.0; 1 * 2 * 2], &[1, 2, 2]);
assert!(bmm(&a, &b).is_err());
}
}