use crate::tensor::TensorStorage;
use crate::ultra_performance_profiler::record_matmul_performance;
use crate::{Result, Tensor, TensorError};
use scirs2_core::metrics::Timer;
use scirs2_core::ndarray::{Array2, ArrayD, ArrayView2};
use scirs2_core::numeric::Num;
use std::time::Instant;
/// Multiply two tensors (`a @ b`) with size-specialized CPU kernels.
///
/// Accepts tensors of rank >= 2: the trailing two axes are the matrix
/// dimensions and any leading axes are treated as (broadcastable) batch
/// dimensions. On the CPU path the elapsed time is fed into the ultra
/// performance profiler under the key `"ultra_matmul_v2_cpu"`.
///
/// # Errors
/// * `device_mismatch` — `a` and `b` live on different devices.
/// * `invalid_shape_simple` — either operand has rank < 2.
/// * `shape_mismatch` — the inner dimensions (`k`) differ.
pub fn ultra_matmul_v2<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
    T: Clone + Default + Num + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
{
    // Scoped metrics timer plus a wall-clock stamp for the profiler record.
    let timer = Timer::new("ultra_matmul_v2".to_string());
    let _timer_guard = timer.start();
    let start_time = Instant::now();
    // Both operands must live on the same device.
    if a.device() != b.device() {
        return Err(TensorError::device_mismatch(
            "ultra_matmul_v2",
            &a.device().to_string(),
            &b.device().to_string(),
        ));
    }
    let a_shape = a.shape().dims();
    let b_shape = b.shape().dims();
    if a_shape.len() < 2 || b_shape.len() < 2 {
        return Err(TensorError::invalid_shape_simple(
            "Ultra matrix multiplication V2 requires at least 2D tensors".to_string(),
        ));
    }
    // Trailing axes form (m x k1) @ (k2 x n); k1 must equal k2.
    let m = a_shape[a_shape.len() - 2];
    let k1 = a_shape[a_shape.len() - 1];
    let k2 = b_shape[b_shape.len() - 2];
    let n = b_shape[b_shape.len() - 1];
    if k1 != k2 {
        return Err(TensorError::shape_mismatch(
            "ultra_matmul_v2",
            "inner dimensions to match",
            &format!("{k1} vs {k2}"),
        ));
    }
    // Broadcast batch dims + [m, n]; validated before any compute starts.
    let result_shape = compute_matmul_shape_v2(a_shape, b_shape)?;
    match (&a.storage, &b.storage) {
        (TensorStorage::Cpu(arr_a), TensorStorage::Cpu(arr_b)) => {
            // Plain 2-D inputs take the specialized kernels; anything of
            // higher rank goes through the broadcasting batch loop.
            let result = if a_shape.len() == 2 && b_shape.len() == 2 {
                ultra_matmul_2d_optimized_v2(arr_a, arr_b, m, n, k1)
            } else {
                ultra_matmul_batch_optimized_v2(arr_a, arr_b, &result_shape)
            };
            let elapsed = start_time.elapsed();
            record_matmul_performance("ultra_matmul_v2_cpu", m, n, k1, elapsed);
            result
        }
        // GPU-resident operands are delegated to the generic matmul op.
        #[cfg(feature = "gpu")]
        (TensorStorage::Gpu(_), TensorStorage::Gpu(_)) => {
            crate::ops::matmul::matmul(a, b)
        }
        // NOTE(review): this catch-all is only compiled with the `gpu`
        // feature; presumably the Gpu storage variant is feature-gated
        // too, so the Cpu-only match stays exhaustive without it — confirm
        // against the TensorStorage definition.
        #[cfg(feature = "gpu")]
        _ => Err(TensorError::unsupported_operation_simple(
            "Mixed CPU/GPU ultra matrix multiplication V2 not supported".to_string(),
        )),
    }
}
/// Dispatch a plain 2-D matrix product to a kernel chosen by operand size.
///
/// `a` and `b` are dynamic-rank arrays that must already be rank-2
/// (`m x k` and `k x n`); the caller has validated the dimensions. The
/// `m x n` product is wrapped back into a `Tensor`.
fn ultra_matmul_2d_optimized_v2<T>(
    a: &ArrayD<T>,
    b: &ArrayD<T>,
    m: usize,
    n: usize,
    k: usize,
) -> Result<Tensor<T>>
where
    T: Clone + Default + Num + Send + Sync + 'static,
{
    // Reinterpret the dynamic-rank views as statically 2-D; this only
    // fails if a non-2-D array slipped through the caller's checks.
    let lhs = a
        .view()
        .into_dimensionality::<scirs2_core::ndarray::Ix2>()
        .map_err(|e| TensorError::invalid_shape_simple(e.to_string()))?;
    let rhs = b
        .view()
        .into_dimensionality::<scirs2_core::ndarray::Ix2>()
        .map_err(|e| TensorError::invalid_shape_simple(e.to_string()))?;
    // Kernel selection: tiny matrices use the specialized small kernel,
    // k == 1 is a pure outer product, everything else takes the
    // cache-friendly triple loop.
    let product = if (1..=8).contains(&m) && (1..=8).contains(&n) && (1..=8).contains(&k) {
        ultra_small_matrix_multiply(lhs, rhs)
    } else if k == 1 {
        ultra_outer_product_v2(lhs, rhs)
    } else {
        ultra_cache_optimized_multiply(lhs, rhs)
    };
    Ok(Tensor::from_array(product.into_dyn()))
}
/// Product of two small matrices (every dimension known to be <= 8).
///
/// A 2x2 input hits a dedicated two-term kernel; every other shape uses
/// a j/k/i loop nest that accumulates into the zero-initialized output
/// while reusing each `b` element across a whole column.
fn ultra_small_matrix_multiply<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
    T: Clone + Default + Num,
{
    let (rows, inner) = a.dim();
    let cols = b.dim().1;
    let mut out = Array2::<T>::zeros((rows, cols));
    if (rows, cols, inner) == (2, 2, 2) {
        // 2x2 kernel: each output cell is exactly two products and one add.
        for i in 0..2 {
            for j in 0..2 {
                out[[i, j]] = a[[i, 0]].clone() * b[[0, j]].clone()
                    + a[[i, 1]].clone() * b[[1, j]].clone();
            }
        }
    } else {
        // General small path: hoist b[p][j] out of the innermost loop so
        // it is cloned once per (column, inner-index) pair.
        for j in 0..cols {
            for p in 0..inner {
                let factor = b[[p, j]].clone();
                for i in 0..rows {
                    out[[i, j]] = out[[i, j]].clone() + a[[i, p]].clone() * factor.clone();
                }
            }
        }
    }
    out
}
/// Outer-product special case: `a` is `m x 1` and `b` is `1 x n`, so every
/// output element is a single scalar multiplication with no reduction loop.
fn ultra_outer_product_v2<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
    T: Clone + Default + Num,
{
    let rows = a.dim().0;
    let cols = b.dim().1;
    let lhs = a.column(0);
    let rhs = b.row(0);
    // Fill element-wise: out[i][j] = a[i][0] * b[0][j].
    Array2::from_shape_fn((rows, cols), |(i, j)| lhs[i].clone() * rhs[j].clone())
}
/// General dense matrix product using a j/k/i loop nest.
///
/// The k-ascending accumulation order per output cell is deliberate and
/// must be preserved: for floating-point `T` a different summation order
/// would produce slightly different results.
fn ultra_cache_optimized_multiply<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
    T: Clone + Default + Num,
{
    let (rows, inner) = a.dim();
    let cols = b.dim().1;
    let mut out = Array2::<T>::zeros((rows, cols));
    for col in 0..cols {
        for p in 0..inner {
            // Clone b[p][col] once per (col, p) pair instead of per cell.
            let factor = b[[p, col]].clone();
            for row in 0..rows {
                out[[row, col]] = out[[row, col]].clone() + a[[row, p]].clone() * factor.clone();
            }
        }
    }
    out
}
/// Batched matmul over the leading (broadcast) dimensions.
///
/// `result_shape` is the already-validated broadcast output shape
/// (batch dims followed by `[m, n]`). Each output batch slot is filled by
/// copying out the corresponding 2-D slices of `a` and `b` — collapsing
/// broadcast (size-1) batch dims to index 0 — and multiplying them with
/// the cache-optimized kernel.
fn ultra_matmul_batch_optimized_v2<T>(
    a: &ArrayD<T>,
    b: &ArrayD<T>,
    result_shape: &[usize],
) -> Result<Tensor<T>>
where
    T: Clone + Default + Num + Send + Sync + 'static,
{
    let a_shape = a.shape();
    let b_shape = b.shape();
    let a_ndim = a_shape.len();
    let b_ndim = b_shape.len();
    // Matrix dims from the trailing axes: (m x k) @ (k x n).
    let m = a_shape[a_ndim - 2];
    let k = a_shape[a_ndim - 1];
    let n = b_shape[b_ndim - 1];
    let mut result = ArrayD::zeros(scirs2_core::ndarray::IxDyn(result_shape));
    // Number of output matrices = product of the batch dims (empty
    // product is 1, but the pure 2-D case is handled elsewhere).
    let batch_size: usize = result_shape[..result_shape.len() - 2].iter().product();
    for batch_idx in 0..batch_size {
        // Decompose the flat index into per-dimension batch indices
        // (row-major: the last batch dim varies fastest).
        let mut batch_indices = vec![0; result_shape.len() - 2];
        let mut idx = batch_idx;
        for i in (0..batch_indices.len()).rev() {
            batch_indices[i] = idx % result_shape[i];
            idx /= result_shape[i];
        }
        // Map the output batch position onto each operand's own batch
        // shape, honoring broadcasting rules.
        let a_indices = compute_broadcast_indices_v2(&batch_indices, &a_shape[..a_ndim - 2]);
        let b_indices = compute_broadcast_indices_v2(&batch_indices, &b_shape[..b_ndim - 2]);
        let a_mat = extract_2d_slice_v2(a, &a_indices, m, k);
        let b_mat = extract_2d_slice_v2(b, &b_indices, k, n);
        let c_mat = ultra_cache_optimized_multiply(a_mat.view(), b_mat.view());
        store_2d_slice_v2(&mut result, &batch_indices, &c_mat);
    }
    Ok(Tensor::from_array(result))
}
/// Map a batch position in the broadcast output space back onto one
/// operand's (possibly smaller) batch shape.
///
/// `indices` addresses the output batch dims; `shape` is the operand's
/// batch shape, right-aligned against `indices`. Dimensions of size 1
/// always map to index 0 (they were broadcast), and any operand dims
/// with no corresponding output index stay at 0.
fn compute_broadcast_indices_v2(indices: &[usize], shape: &[usize]) -> Vec<usize> {
    // Right-align `shape` against `indices`; when `shape` is the longer
    // of the two the offset is simply 0.
    let offset = indices.len().saturating_sub(shape.len());
    shape
        .iter()
        .enumerate()
        .map(|(i, &dim)| match indices.get(i + offset) {
            Some(&idx) if dim != 1 => idx,
            _ => 0,
        })
        .collect()
}
/// Copy the trailing `rows x cols` matrix at `batch_indices` out of `arr`
/// into a dense owned matrix.
///
/// Positions that fall outside `arr` (possible for broadcast operands)
/// are simply left at zero rather than panicking.
fn extract_2d_slice_v2<T: Clone + Default + Num>(
    arr: &ArrayD<T>,
    batch_indices: &[usize],
    rows: usize,
    cols: usize,
) -> Array2<T> {
    let mut out = Array2::zeros((rows, cols));
    // One index buffer reused across all cells: batch dims + [row, col].
    let mut index = Vec::with_capacity(batch_indices.len() + 2);
    for r in 0..rows {
        for c in 0..cols {
            index.clear();
            index.extend_from_slice(batch_indices);
            index.push(r);
            index.push(c);
            if let Some(value) = arr.get(index.as_slice()) {
                out[[r, c]] = value.clone();
            }
        }
    }
    out
}
/// Write `mat` into `arr` at the trailing 2-D position addressed by
/// `batch_indices`; destinations outside `arr` are silently skipped.
fn store_2d_slice_v2<T: Clone>(arr: &mut ArrayD<T>, batch_indices: &[usize], mat: &Array2<T>) {
    let (rows, cols) = mat.dim();
    // One index buffer reused across all cells: batch dims + [row, col].
    let mut index = Vec::with_capacity(batch_indices.len() + 2);
    for r in 0..rows {
        for c in 0..cols {
            index.clear();
            index.extend_from_slice(batch_indices);
            index.push(r);
            index.push(c);
            if let Some(slot) = arr.get_mut(index.as_slice()) {
                *slot = mat[[r, c]].clone();
            }
        }
    }
}
/// Build the output shape of a (batched) matmul: broadcast the leading
/// batch dims of both operands, then append `m` (rows of `a`) and `n`
/// (cols of `b`). Callers guarantee both shapes have rank >= 2.
///
/// # Errors
/// Propagates the broadcast error when the batch dims are incompatible.
fn compute_matmul_shape_v2(a_shape: &[usize], b_shape: &[usize]) -> Result<Vec<usize>> {
    let (a_ndim, b_ndim) = (a_shape.len(), b_shape.len());
    // Everything except the trailing two axes participates in broadcasting.
    let mut shape = broadcast_shapes_v2(&a_shape[..a_ndim - 2], &b_shape[..b_ndim - 2])?;
    shape.extend([a_shape[a_ndim - 2], b_shape[b_ndim - 1]]);
    Ok(shape)
}
/// NumPy-style broadcast of two batch shapes, aligned from the trailing
/// dimension: a dim of 1 stretches to match the other side, and missing
/// leading dims behave like 1.
///
/// # Errors
/// `invalid_argument` when a pair of dims differ and neither is 1.
fn broadcast_shapes_v2(a: &[usize], b: &[usize]) -> Result<Vec<usize>> {
    let len = a.len().max(b.len());
    let mut dims = Vec::with_capacity(len);
    for pos in 0..len {
        // Read the pos-th dimension counted from the right of each shape;
        // shapes too short to have one contribute an implicit 1.
        let da = a.len().checked_sub(pos + 1).map_or(1, |i| a[i]);
        let db = b.len().checked_sub(pos + 1).map_or(1, |i| b[i]);
        let merged = if da == db || db == 1 {
            da
        } else if da == 1 {
            db
        } else {
            return Err(TensorError::invalid_argument(format!(
                "Cannot broadcast shapes {a:?} and {b:?}"
            )));
        };
        dims.push(merged);
    }
    // Dims were collected right-to-left; flip back to shape order.
    dims.reverse();
    Ok(dims)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2 product exercises the fully specialized small-matrix kernel.
    #[test]
    fn test_ultra_matmul_v2_basic() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let product = ultra_matmul_v2(&lhs, &rhs).expect("test: ultra_matmul_v2 should succeed");
        assert_eq!(product.shape().dims(), &[2, 2]);
        if let Some(values) = product.as_slice() {
            assert_eq!(values, &[19.0, 22.0, 43.0, 50.0]);
        }
    }

    /// Same 2x2 data, asserting the row-major layout of the result.
    #[test]
    fn test_ultra_matmul_v2_small() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let rhs = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let product = ultra_matmul_v2(&lhs, &rhs).expect("test: ultra_matmul_v2 should succeed");
        // Kept for documentation: the product in 2-D form.
        let _expected = [[19.0, 22.0], [43.0, 50.0]];
        if let Some(values) = product.as_slice() {
            assert_eq!(values, &[19.0, 22.0, 43.0, 50.0]);
        }
    }

    /// (3 x 1) @ (1 x 2) routes through the k == 1 outer-product kernel.
    #[test]
    fn test_ultra_matmul_v2_outer_product() {
        let col = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3, 1])
            .expect("test: from_vec should succeed");
        let row = Tensor::<f32>::from_vec(vec![4.0, 5.0], &[1, 2])
            .expect("test: from_vec should succeed");
        let product = ultra_matmul_v2(&col, &row).expect("test: ultra_matmul_v2 should succeed");
        assert_eq!(product.shape().dims(), &[3, 2]);
        if let Some(values) = product.as_slice() {
            assert_eq!(values, &[4.0, 5.0, 8.0, 10.0, 12.0, 15.0]);
        }
    }

    /// Smoke test: a 64x64 product must finish inside a generous budget.
    #[test]
    fn test_ultra_matmul_v2_performance() {
        let size = 64;
        let lhs_data: Vec<f32> = (0..size * size).map(|i| i as f32).collect();
        let rhs_data: Vec<f32> = (0..size * size).map(|i| (i + 1) as f32).collect();
        let lhs = Tensor::<f32>::from_vec(lhs_data, &[size, size])
            .expect("test: from_vec should succeed");
        let rhs = Tensor::<f32>::from_vec(rhs_data, &[size, size])
            .expect("test: from_vec should succeed");
        let started = Instant::now();
        let _product =
            ultra_matmul_v2(&lhs, &rhs).expect("test: ultra_matmul_v2 should succeed");
        let elapsed = started.elapsed();
        println!(
            "Ultra MatMul V2 {}x{} completed in {:.2}ms",
            size,
            size,
            elapsed.as_secs_f64() * 1000.0
        );
        assert!(elapsed.as_millis() < 3000);
    }
}