use crate::shape_error_taxonomy::{validate_matmul_shapes, ShapeErrorUtils};
use crate::tensor::TensorStorage;
use crate::ultra_performance_profiler::record_matmul_performance;
use crate::{Result, Tensor, TensorError};
use scirs2_core::metrics::Timer;
use scirs2_core::ndarray::{Array2, ArrayD, ArrayView2};
use scirs2_core::numeric::{One, Zero};
use std::time::{Duration, Instant};
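/// Submatrix offsets into A, B, and C. Currently unused; the fields mirror the
/// offset arguments of `cache_oblivious_multiply`.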
#[derive(Clone, Copy)]
#[allow(dead_code)]
struct MatrixOffsets {
a_row_start: usize,
a_col_start: usize,
b_row_start: usize,
b_col_start: usize,
c_row_start: usize,
c_col_start: usize,
}
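/// Problem dimensions (m, k, n) plus a recursion cutoff. Currently unused; the
/// fields mirror the size arguments of `cache_oblivious_multiply`.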
#[derive(Clone, Copy)]
#[allow(dead_code)]
struct MatrixDims {
m: usize,
k: usize,
n: usize,
cutoff: usize,
}
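/// Adaptive matrix multiplication entry point.
///
/// Validates that both tensors live on the same device and have rank >= 2, then
/// dispatches: 2D CPU inputs go to a size-adaptive kernel, higher-rank CPU
/// inputs to the batched/broadcasting path, and (with the `gpu` feature) GPU
/// inputs to `crate::ops::matmul::matmul`. Timing is recorded via
/// `record_matmul_performance` and printed by `log_performance_metrics`.
///
/// A minimal usage sketch (mirroring `test_ultra_matmul_basic` below):
/// ```ignore
/// let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
/// let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2])?;
/// let c = ultra_matmul(&a, &b)?;
/// assert_eq!(c.shape().dims(), &[2, 2]); // values: [19.0, 22.0, 43.0, 50.0]
/// ```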
pub fn ultra_matmul<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
let timer = Timer::new("ultra_matmul".to_string());
let _timer_guard = timer.start();
let start_time = Instant::now();
if a.device() != b.device() {
return Err(TensorError::device_mismatch(
"ultra_matmul",
&a.device().to_string(),
&b.device().to_string(),
));
}
let a_shape_obj = a.shape();
let b_shape_obj = b.shape();
if a_shape_obj.rank() < 2 {
return Err(ShapeErrorUtils::rank_mismatch(
"ultra_matmul",
2,
a_shape_obj,
));
}
if b_shape_obj.rank() < 2 {
return Err(ShapeErrorUtils::rank_mismatch(
"ultra_matmul",
2,
b_shape_obj,
));
}
let result_shape_obj =
validate_matmul_shapes("ultra_matmul", a_shape_obj, b_shape_obj, false, false)?;
let result_shape = result_shape_obj.dims().to_vec();
let a_shape = a_shape_obj.dims();
let b_shape = b_shape_obj.dims();
let m = a_shape[a_shape.len() - 2];
let k1 = a_shape[a_shape.len() - 1];
let n = b_shape[b_shape.len() - 1];
match (&a.storage, &b.storage) {
(TensorStorage::Cpu(arr_a), TensorStorage::Cpu(arr_b)) => {
let result = if a_shape.len() == 2 && b_shape.len() == 2 {
ultra_matmul_2d_adaptive(arr_a, arr_b, m, n, k1)
} else {
ultra_matmul_batch_adaptive(arr_a, arr_b, &result_shape)
};
let elapsed = start_time.elapsed();
record_matmul_performance("ultra_matmul_cpu", m, n, k1, elapsed);
log_performance_metrics(m, n, k1, elapsed);
result
}
#[cfg(feature = "gpu")]
(TensorStorage::Gpu(_), TensorStorage::Gpu(_)) => {
crate::ops::matmul::matmul(a, b)
}
#[cfg(feature = "gpu")]
_ => Err(TensorError::unsupported_operation_simple(
"Mixed CPU/GPU ultra matrix multiplication not supported".to_string(),
)),
}
}
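/// Picks a 2D kernel by problem size: outer product when k == 1, the unrolled
/// micro-kernel when any dimension is <= 8, the small cache/SIMD kernel when all
/// dimensions are <= 64, the cache-oblivious recursion when any dimension is
/// >= 512, and the blocked medium kernel otherwise.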
fn ultra_matmul_2d_adaptive<T>(
a: &ArrayD<T>,
b: &ArrayD<T>,
m: usize,
n: usize,
k: usize,
) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static,
{
let a_2d = a
.view()
.into_dimensionality::<scirs2_core::ndarray::Ix2>()
.map_err(|e| TensorError::invalid_shape_simple(e.to_string()))?;
let b_2d = b
.view()
.into_dimensionality::<scirs2_core::ndarray::Ix2>()
.map_err(|e| TensorError::invalid_shape_simple(e.to_string()))?;
let result = if k == 1 {
ultra_outer_product(a_2d, b_2d)
} else if m <= 8 || n <= 8 || k <= 8 {
ultra_matmul_micro_simd(a_2d, b_2d)
} else if m <= 64 && n <= 64 && k <= 64 {
ultra_matmul_small_cache_simd(a_2d, b_2d)
} else if m >= 512 || n >= 512 || k >= 512 {
ultra_matmul_large_cache_oblivious(a_2d, b_2d)
} else {
ultra_matmul_medium_adaptive(a_2d, b_2d)
};
let result_dyn = result.into_dyn();
Ok(Tensor::from_array(result_dyn))
}
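/// Micro-kernel for tiny matrices: i-j loops with the k loop unrolled by 4.
/// f32 inputs are detected via `TypeId` and routed to the specialized f32 path.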
fn ultra_matmul_micro_simd<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ 'static,
{
let (m, k) = a.dim();
let (_, n) = b.dim();
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
return ultra_matmul_f32_simd(a, b);
}
let mut result = Array2::<T>::zeros((m, n));
for i in 0..m {
for j in 0..n {
let mut sum = T::zero();
let mut k_idx = 0;
while k_idx + 4 <= k {
sum = sum + a[[i, k_idx]].clone() * b[[k_idx, j]].clone();
sum = sum + a[[i, k_idx + 1]].clone() * b[[k_idx + 1, j]].clone();
sum = sum + a[[i, k_idx + 2]].clone() * b[[k_idx + 2, j]].clone();
sum = sum + a[[i, k_idx + 3]].clone() * b[[k_idx + 3, j]].clone();
k_idx += 4;
}
while k_idx < k {
sum = sum + a[[i, k_idx]].clone() * b[[k_idx, j]].clone();
k_idx += 1;
}
result[[i, j]] = sum;
}
}
result
}
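/// Runtime f32 specialization. When `T` is `f32` (checked via `TypeId`), the
/// inputs are reinterpreted element-by-element as `f32` (the unsafe casts are
/// guarded by that check), multiplied with the optimized f32 kernels, and the
/// result is converted back to `T`. Other types fall through to the generic
/// cache-optimized kernel.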
fn ultra_matmul_f32_simd<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ 'static,
{
let (m, k) = a.dim();
let (_, n) = b.dim();
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let mut result = Array2::<T>::zeros((m, n));
let mut a_f32_data = Vec::with_capacity(m * k);
let mut b_f32_data = Vec::with_capacity(k * n);
for i in 0..m {
for j in 0..k {
let val_ptr = &a[[i, j]] as *const T as *const f32;
unsafe {
a_f32_data.push(*val_ptr);
}
}
}
for i in 0..k {
for j in 0..n {
let val_ptr = &b[[i, j]] as *const T as *const f32;
unsafe {
b_f32_data.push(*val_ptr);
}
}
}
let a_f32 = Array2::from_shape_vec((m, k), a_f32_data)
.expect("a_f32_data must match expected shape (m, k)");
let b_f32 = Array2::from_shape_vec((k, n), b_f32_data)
.expect("b_f32_data must match expected shape (k, n)");
let result_f32 = ultra_matmul_f32_optimized(&a_f32.view(), &b_f32.view());
for i in 0..m {
for j in 0..n {
let f32_val = result_f32[[i, j]];
let t_val_ptr = &f32_val as *const f32 as *const T;
unsafe {
result[[i, j]] = (*t_val_ptr).clone();
}
}
}
return result;
}
ultra_matmul_cache_optimized_generic(a, b)
}
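/// Dispatches f32 multiplication to the small, blocked, or cache-optimized
/// kernel depending on problem size.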
fn ultra_matmul_f32_optimized(a: &ArrayView2<f32>, b: &ArrayView2<f32>) -> Array2<f32> {
let (m, k) = a.dim();
let (_, n) = b.dim();
let mut result = Array2::<f32>::zeros((m, n));
if m <= 32 && n <= 32 && k <= 32 {
ultra_matmul_f32_small(&mut result, a, b, m, k, n);
} else if m >= 128 || n >= 128 || k >= 128 {
ultra_matmul_f32_blocked(&mut result, a, b, m, k, n);
} else {
ultra_matmul_f32_cache_optimized(&mut result, a, b, m, k, n);
}
result
}
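/// Generic fallback with a j-k-i loop order: each `b[k, j]` element is loaded
/// once and accumulated down column `j` of the result.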
fn ultra_matmul_cache_optimized_generic<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
T: Clone + Default + Zero + One + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
let (m, k) = a.dim();
let (_, n) = b.dim();
let mut result = Array2::<T>::zeros((m, n));
for j in 0..n {
for k_idx in 0..k {
let b_val = &b[[k_idx, j]];
for i in 0..m {
result[[i, j]] = result[[i, j]].clone() + a[[i, k_idx]].clone() * b_val.clone();
}
}
}
result
}
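/// Small f32 kernel: i-j loops with the k loop unrolled by 4.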
fn ultra_matmul_f32_small(
result: &mut Array2<f32>,
a: &ArrayView2<f32>,
b: &ArrayView2<f32>,
m: usize,
k: usize,
n: usize,
) {
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
let mut k_idx = 0;
while k_idx + 4 <= k {
sum += a[[i, k_idx]] * b[[k_idx, j]];
sum += a[[i, k_idx + 1]] * b[[k_idx + 1, j]];
sum += a[[i, k_idx + 2]] * b[[k_idx + 2, j]];
sum += a[[i, k_idx + 3]] * b[[k_idx + 3, j]];
k_idx += 4;
}
while k_idx < k {
sum += a[[i, k_idx]] * b[[k_idx, j]];
k_idx += 1;
}
result[[i, j]] = sum;
}
}
}
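/// Medium f32 kernel using the j-k-i loop order.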
fn ultra_matmul_f32_cache_optimized(
result: &mut Array2<f32>,
a: &ArrayView2<f32>,
b: &ArrayView2<f32>,
m: usize,
k: usize,
n: usize,
) {
for j in 0..n {
for k_idx in 0..k {
let b_val = b[[k_idx, j]];
for i in 0..m {
result[[i, j]] += a[[i, k_idx]] * b_val;
}
}
}
}
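/// Large f32 kernel: 64-wide blocking over j, k, and i with the j-k-i order
/// inside each block.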
fn ultra_matmul_f32_blocked(
result: &mut Array2<f32>,
a: &ArrayView2<f32>,
b: &ArrayView2<f32>,
m: usize,
k: usize,
n: usize,
) {
let block_size = 64;
for j_block in (0..n).step_by(block_size) {
for k_block in (0..k).step_by(block_size) {
for i_block in (0..m).step_by(block_size) {
let i_end = (i_block + block_size).min(m);
let j_end = (j_block + block_size).min(n);
let k_end = (k_block + block_size).min(k);
for j in j_block..j_end {
for k_idx in k_block..k_end {
let b_val = b[[k_idx, j]];
for i in i_block..i_end {
result[[i, j]] += a[[i, k_idx]] * b_val;
}
}
}
}
}
}
}
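/// Cache-oblivious multiplication for large matrices: recursively halves the
/// largest of m, k, n until all dimensions fall at or below the cutoff.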
fn ultra_matmul_large_cache_oblivious<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
T: Clone + Default + Zero + One + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
let (m, k) = a.dim();
let (_, n) = b.dim();
let mut result = Array2::<T>::zeros((m, n));
let cutoff = determine_optimal_cutoff();
cache_oblivious_multiply(
&a,
&b,
&mut result.view_mut(),
0,
0,
0,
0,
0,
0,
m,
k,
n,
cutoff,
);
result
}
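/// Recursion cutoff for the cache-oblivious path. Fixed at 64; a hardware-aware
/// heuristic (e.g. derived from cache sizes) could replace this constant.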
fn determine_optimal_cutoff() -> usize {
64
}
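/// Recursive step of the cache-oblivious algorithm. At or below the cutoff it
/// runs the base case; otherwise it halves the largest of m, k, or n and
/// recurses on both halves. Splitting k yields two products that accumulate
/// into the same block of `c`.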
#[allow(clippy::too_many_arguments)]
fn cache_oblivious_multiply<T>(
a: &ArrayView2<T>,
b: &ArrayView2<T>,
c: &mut scirs2_core::ndarray::ArrayViewMut2<T>,
a_row_start: usize,
a_col_start: usize,
b_row_start: usize,
b_col_start: usize,
c_row_start: usize,
c_col_start: usize,
m: usize,
k: usize,
n: usize,
cutoff: usize,
) where
T: Clone + Default + Zero + One + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
if m <= cutoff && k <= cutoff && n <= cutoff {
cache_oblivious_base_case(
a,
b,
c,
a_row_start,
a_col_start,
b_row_start,
b_col_start,
c_row_start,
c_col_start,
m,
k,
n,
);
return;
}
if m >= k && m >= n {
let m1 = m / 2;
let m2 = m - m1;
cache_oblivious_multiply(
a,
b,
c,
a_row_start,
a_col_start,
b_row_start,
b_col_start,
c_row_start,
c_col_start,
m1,
k,
n,
cutoff,
);
cache_oblivious_multiply(
a,
b,
c,
a_row_start + m1,
a_col_start,
b_row_start,
b_col_start,
c_row_start + m1,
c_col_start,
m2,
k,
n,
cutoff,
);
} else if k >= n {
let k1 = k / 2;
let k2 = k - k1;
cache_oblivious_multiply(
a,
b,
c,
a_row_start,
a_col_start,
b_row_start,
b_col_start,
c_row_start,
c_col_start,
m,
k1,
n,
cutoff,
);
cache_oblivious_multiply(
a,
b,
c,
a_row_start,
a_col_start + k1,
b_row_start + k1,
b_col_start,
c_row_start,
c_col_start,
m,
k2,
n,
cutoff,
);
} else {
let n1 = n / 2;
let n2 = n - n1;
cache_oblivious_multiply(
a,
b,
c,
a_row_start,
a_col_start,
b_row_start,
b_col_start,
c_row_start,
c_col_start,
m,
k,
n1,
cutoff,
);
cache_oblivious_multiply(
a,
b,
c,
a_row_start,
a_col_start,
b_row_start,
b_col_start + n1,
c_row_start,
c_col_start + n1,
m,
k,
n2,
cutoff,
);
}
}
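/// Base case: a j-k-i triple loop over the addressed m x k and k x n
/// submatrices, accumulating into the corresponding block of `c`.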
#[allow(clippy::too_many_arguments)]
fn cache_oblivious_base_case<T>(
a: &ArrayView2<T>,
b: &ArrayView2<T>,
c: &mut scirs2_core::ndarray::ArrayViewMut2<T>,
a_row_start: usize,
a_col_start: usize,
b_row_start: usize,
b_col_start: usize,
c_row_start: usize,
c_col_start: usize,
m: usize,
k: usize,
n: usize,
) where
T: Clone + Default + Zero + One + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
for j in 0..n {
for k_idx in 0..k {
let b_val = &b[[b_row_start + k_idx, b_col_start + j]];
for i in 0..m {
let a_val = &a[[a_row_start + i, a_col_start + k_idx]];
c[[c_row_start + i, c_col_start + j]] =
c[[c_row_start + i, c_col_start + j]].clone() + a_val.clone() * b_val.clone();
}
}
}
}
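/// Rank-1 (outer) product for k == 1: result[i, j] = a[i, 0] * b[0, j], computed
/// in 64x64 tiles when both output dimensions are at least 64.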
fn ultra_outer_product<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
T: Clone + Default + Zero + One + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
let (m, k) = a.dim();
let (_, n) = b.dim();
assert_eq!(k, 1, "ultra_outer_product requires an inner dimension of 1 (got k = {k})");
let mut result = Array2::<T>::zeros((m, n));
let a_col = a.column(0);
let b_row = b.row(0);
if m >= 64 && n >= 64 {
let chunk_size = 64;
for i_chunk in (0..m).step_by(chunk_size) {
let i_end = (i_chunk + chunk_size).min(m);
for j_chunk in (0..n).step_by(chunk_size) {
let j_end = (j_chunk + chunk_size).min(n);
for i in i_chunk..i_end {
let a_val = a_col[i].clone();
for j in j_chunk..j_end {
result[[i, j]] = a_val.clone() * b_row[j].clone();
}
}
}
}
} else {
for i in 0..m {
let a_val = a_col[i].clone();
for j in 0..n {
result[[i, j]] = a_val.clone() * b_row[j].clone();
}
}
}
result
}
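/// Kernel for matrices up to 64x64x64: f32 uses the specialized path, other
/// types the generic cache-optimized loop.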
fn ultra_matmul_small_cache_simd<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ 'static,
{
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
return ultra_matmul_f32_simd(a, b);
}
ultra_matmul_cache_optimized_generic(a, b)
}
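/// Medium-size kernel: chooses a 32 or 64 block size from the m*n*k work
/// estimate and runs the blocked kernel.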
fn ultra_matmul_medium_adaptive<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
T: Clone + Default + Zero + One + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
let (m, k) = a.dim();
let (_, n) = b.dim();
let block_size = if m * n * k < 1_000_000 { 32 } else { 64 };
ultra_matmul_blocked_optimized(a, b, block_size)
}
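/// Generic blocked multiplication: iterates over j/k/i blocks of `block_size`
/// with the j-k-i order inside each block.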
fn ultra_matmul_blocked_optimized<T>(
a: ArrayView2<T>,
b: ArrayView2<T>,
block_size: usize,
) -> Array2<T>
where
T: Clone + Default + Zero + One + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
let (m, k) = a.dim();
let (_, n) = b.dim();
let mut result = Array2::<T>::zeros((m, n));
for j_block in (0..n).step_by(block_size) {
for k_block in (0..k).step_by(block_size) {
for i_block in (0..m).step_by(block_size) {
let i_end = (i_block + block_size).min(m);
let j_end = (j_block + block_size).min(n);
let k_end = (k_block + block_size).min(k);
for j in j_block..j_end {
for k_idx in k_block..k_end {
let b_val = &b[[k_idx, j]];
for i in i_block..i_end {
result[[i, j]] =
result[[i, j]].clone() + a[[i, k_idx]].clone() * b_val.clone();
}
}
}
}
}
}
result
}
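/// Batched multiplication for rank > 2 inputs. For each batch index of the
/// (broadcast) result shape it maps back to the operand batch indices, extracts
/// the 2D operand slices, multiplies them with `ultra_matmul_2d_raw`, and stores
/// the product into the result tensor.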
fn ultra_matmul_batch_adaptive<T>(
a: &ArrayD<T>,
b: &ArrayD<T>,
result_shape: &[usize],
) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static,
{
let a_shape = a.shape();
let b_shape = b.shape();
let a_ndim = a_shape.len();
let b_ndim = b_shape.len();
let m = a_shape[a_ndim - 2];
let k = a_shape[a_ndim - 1];
let n = b_shape[b_ndim - 1];
let mut result = ArrayD::zeros(scirs2_core::ndarray::IxDyn(result_shape));
let batch_size: usize = result_shape[..result_shape.len() - 2].iter().product();
for batch_idx in 0..batch_size {
let mut batch_indices = vec![0; result_shape.len() - 2];
let mut idx = batch_idx;
for i in (0..batch_indices.len()).rev() {
batch_indices[i] = idx % result_shape[i];
idx /= result_shape[i];
}
let a_indices = compute_broadcast_indices(&batch_indices, &a_shape[..a_ndim - 2]);
let b_indices = compute_broadcast_indices(&batch_indices, &b_shape[..b_ndim - 2]);
let a_mat = extract_2d_slice(a, &a_indices, m, k);
let b_mat = extract_2d_slice(b, &b_indices, k, n);
let c_mat = ultra_matmul_2d_raw(&a_mat, &b_mat);
store_2d_slice(&mut result, &batch_indices, &c_mat);
}
Ok(Tensor::from_array(result))
}
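/// Prints an achieved-throughput estimate (2 * m * n * k flops over the elapsed
/// time) to stdout.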
fn log_performance_metrics(m: usize, n: usize, k: usize, elapsed: Duration) {
let ops = 2.0 * m as f64 * n as f64 * k as f64; // multiply-adds counted as 2 flops
let gflops = ops / elapsed.as_secs_f64() / 1e9;
println!(
"Ultra MatMul Performance: {}x{}x{} = {:.2} GFLOP/s in {:.2}ms",
m,
n,
k,
gflops,
elapsed.as_secs_f64() * 1000.0
);
}
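/// Computes the broadcast result shape of a batched matmul from raw dims.
/// Standalone helper; `ultra_matmul` itself relies on `validate_matmul_shapes`.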
#[allow(dead_code)]
fn compute_matmul_shape(a_shape: &[usize], b_shape: &[usize]) -> Result<Vec<usize>> {
let a_ndim = a_shape.len();
let b_ndim = b_shape.len();
let a_batch = &a_shape[..a_ndim - 2];
let b_batch = &b_shape[..b_ndim - 2];
let batch_shape = broadcast_shapes(a_batch, b_batch)?;
let mut result_shape = batch_shape;
result_shape.push(a_shape[a_ndim - 2]);
result_shape.push(b_shape[b_ndim - 1]);
Ok(result_shape)
}
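/// NumPy-style broadcasting of two batch shapes, aligned from the trailing
/// dimension: size-1 dimensions broadcast, any other mismatch is an error.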
fn broadcast_shapes(a: &[usize], b: &[usize]) -> Result<Vec<usize>> {
let max_len = a.len().max(b.len());
let mut result = Vec::with_capacity(max_len);
for i in 0..max_len {
let a_dim = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
let b_dim = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
if a_dim == b_dim {
result.push(a_dim);
} else if a_dim == 1 {
result.push(b_dim);
} else if b_dim == 1 {
result.push(a_dim);
} else {
return Err(TensorError::invalid_argument(format!(
"Cannot broadcast shapes {a:?} and {b:?}"
)));
}
}
result.reverse();
Ok(result)
}
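/// Maps result batch indices to an operand's batch indices: size-1 (broadcast)
/// dimensions map to index 0 and missing leading dimensions are skipped.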
fn compute_broadcast_indices(indices: &[usize], shape: &[usize]) -> Vec<usize> {
let mut result = vec![0; shape.len()];
let offset = indices.len() - shape.len();
for i in 0..shape.len() {
if i >= offset {
let idx = indices[i - offset];
result[i] = if shape[i] == 1 { 0 } else { idx };
}
}
result
}
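/// Copies the `rows x cols` matrix at the given batch indices out of a
/// dynamic-rank array.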
fn extract_2d_slice<T: Clone + Zero>(
arr: &ArrayD<T>,
batch_indices: &[usize],
rows: usize,
cols: usize,
) -> Array2<T> {
let mut result = Array2::zeros((rows, cols));
for i in 0..rows {
for j in 0..cols {
let mut indices = batch_indices.to_vec();
indices.push(i);
indices.push(j);
if let Some(val) = arr.get(indices.as_slice()) {
result[[i, j]] = val.clone();
}
}
}
result
}
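/// Writes a 2D matrix back into the dynamic-rank result at the given batch
/// indices.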
fn store_2d_slice<T: Clone>(arr: &mut ArrayD<T>, batch_indices: &[usize], mat: &Array2<T>) {
let (rows, cols) = mat.dim();
for i in 0..rows {
for j in 0..cols {
let mut indices = batch_indices.to_vec();
indices.push(i);
indices.push(j);
if let Some(dst) = arr.get_mut(indices.as_slice()) {
*dst = mat[[i, j]].clone();
}
}
}
}
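/// 2D kernel for the batched path: f32 goes to the specialized path, large
/// generic matrices (> 128 in both output dimensions) to the 64-blocked kernel,
/// and everything else to the small cache kernel.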
fn ultra_matmul_2d_raw<T>(a: &Array2<T>, b: &Array2<T>) -> Array2<T>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ 'static,
{
let (m, _k) = a.dim();
let (_, n) = b.dim();
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
return ultra_matmul_f32_simd(a.view(), b.view());
}
if m > 128 && n > 128 {
ultra_matmul_blocked_optimized(a.view(), b.view(), 64)
} else {
ultra_matmul_small_cache_simd(a.view(), b.view())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ultra_matmul_basic() {
let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
.expect("test: from_vec should succeed");
let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2])
.expect("test: from_vec should succeed");
let result = ultra_matmul(&a, &b).expect("test: ultra_matmul should succeed");
assert_eq!(result.shape().dims(), &[2, 2]);
if let Some(data) = result.as_slice() {
assert_eq!(data, &[19.0, 22.0, 43.0, 50.0]);
}
}
#[test]
fn test_ultra_matmul_large() {
let size = 128;
let a_data: Vec<f32> = (0..size * size).map(|i| i as f32).collect();
let b_data: Vec<f32> = (0..size * size).map(|i| (i + 1) as f32).collect();
let a =
Tensor::<f32>::from_vec(a_data, &[size, size]).expect("test: from_vec should succeed");
let b =
Tensor::<f32>::from_vec(b_data, &[size, size]).expect("test: from_vec should succeed");
let start = Instant::now();
let result = ultra_matmul(&a, &b).expect("test: ultra_matmul should succeed");
let elapsed = start.elapsed();
assert_eq!(result.shape().dims(), &[size, size]);
println!(
"Ultra MatMul {}x{} completed in {:.2}ms",
size,
size,
elapsed.as_secs_f64() * 1000.0
);
}
#[test]
fn test_ultra_matmul_simd_f32() {
let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
.expect("test: from_vec should succeed");
let b = Tensor::<f32>::from_vec(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2])
.expect("test: from_vec should succeed");
let result = ultra_matmul(&a, &b).expect("test: ultra_matmul should succeed");
assert_eq!(result.shape().dims(), &[2, 2]);
if let Some(data) = result.as_slice() {
assert_eq!(data, &[58.0, 64.0, 139.0, 154.0]);
}
}
}