use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::Zero;
/// Multiplies two 2-D matrices, choosing a kernel from the problem size.
///
/// `a` is read as an `m x k` matrix and `b` as a `k x n` matrix; the returned
/// array is `m x n`. When `m`, `n` and `k` are all at most 64 the plain triple
/// loop in [`matmul_simple_optimized`] is used; otherwise the work is handed to
/// [`matmul_blocked`].
///
/// # Panics
///
/// Panics if the number of columns of `a` differs from the number of rows of `b`.
pub fn matmul_2d_optimized<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
    T: Clone + Zero + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Default,
{
    let (m, k) = a.dim();
    let (b_rows, n) = b.dim();
    // Reject mismatched operands up front: without this check a `b` with extra
    // rows would be multiplied as if those rows did not exist.
    assert_eq!(
        k, b_rows,
        "matmul_2d_optimized: incompatible shapes {}x{} and {}x{}",
        m, k, b_rows, n
    );

    let fits_in_one_block = m <= 64 && n <= 64 && k <= 64;
    if fits_in_one_block {
        matmul_simple_optimized(a, b)
    } else {
        matmul_blocked(a, b)
    }
}
/// Multiplies two matrices by converting them to `f32`, calling `dot`, and
/// converting the product back to `T`.
///
/// Only compiled when the `blas` feature is enabled.
///
/// # Panics
///
/// Shape validation is delegated to `dot`; per the ndarray documentation it
/// panics when the operand shapes are incompatible.
#[cfg(feature = "blas")]
pub fn matmul_blas_f32<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
    T: Clone + Default + Into<f32> + From<f32>,
{
    // NOTE(review): kept from the original; presumably pulls in the BLAS-backed
    // implementation — confirm whether this glob import is still required.
    use ndarray_linalg::*;
    // `mapv` already hands the closure an owned (cloned) element, so no extra
    // `clone()` is needed before converting.
    let a_f32: Array2<f32> = a.mapv(|x| x.into());
    let b_f32: Array2<f32> = b.mapv(|x| x.into());
    let product = a_f32.dot(&b_f32);
    product.mapv(T::from)
}
/// Multiplies two matrices by converting them to `f64`, calling `dot`, and
/// converting the product back to `T`.
///
/// Only compiled when the `blas` feature is enabled.
///
/// # Panics
///
/// Shape validation is delegated to `dot`; per the ndarray documentation it
/// panics when the operand shapes are incompatible.
#[cfg(feature = "blas")]
pub fn matmul_blas_f64<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
    T: Clone + Default + Into<f64> + From<f64>,
{
    // NOTE(review): kept from the original; presumably pulls in the BLAS-backed
    // implementation — confirm whether this glob import is still required.
    use ndarray_linalg::*;
    // `mapv` already hands the closure an owned (cloned) element, so no extra
    // `clone()` is needed before converting.
    let a_f64: Array2<f64> = a.mapv(|x| x.into());
    let b_f64: Array2<f64> = b.mapv(|x| x.into());
    let product = a_f64.dot(&b_f64);
    product.mapv(T::from)
}
/// Multiplies `a` (`m x k`) by `b` (`k x n`) using 64x64x64 tiles.
///
/// The three index spaces are walked in blocks of `BLOCK_SIZE`; for every tile
/// the partial dot products are added onto the value already stored in the
/// result, so each output element accumulates its terms in increasing `k` order.
///
/// # Panics
///
/// Panics if the number of columns of `a` differs from the number of rows of `b`.
pub fn matmul_blocked<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
    T: Clone + Zero + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Default,
{
    const BLOCK_SIZE: usize = 64;

    let (m, k) = a.dim();
    let (b_rows, n) = b.dim();
    // Without this check a `b` with too few rows fails with an opaque index
    // panic and a `b` with too many rows silently yields a wrong product.
    assert_eq!(
        k, b_rows,
        "matmul_blocked: incompatible shapes {}x{} and {}x{}",
        m, k, b_rows, n
    );

    let mut result = Array2::<T>::zeros((m, n));
    for i_block in (0..m).step_by(BLOCK_SIZE) {
        // Block bounds depend only on their own loop variable, so compute each
        // one once per block instead of once per innermost tile.
        let i_end = (i_block + BLOCK_SIZE).min(m);
        for j_block in (0..n).step_by(BLOCK_SIZE) {
            let j_end = (j_block + BLOCK_SIZE).min(n);
            for k_block in (0..k).step_by(BLOCK_SIZE) {
                let k_end = (k_block + BLOCK_SIZE).min(k);
                for i in i_block..i_end {
                    for j in j_block..j_end {
                        // Continue from the partial sum left by earlier k-blocks.
                        let mut sum = result[[i, j]].clone();
                        for k_idx in k_block..k_end {
                            sum = sum + (a[[i, k_idx]].clone() * b[[k_idx, j]].clone());
                        }
                        result[[i, j]] = sum;
                    }
                }
            }
        }
    }
    result
}
/// Multiplies `a` (`m x k`) by `b` (`k x n`) with a straightforward triple loop.
///
/// Each output element is the sum, starting from `T::zero()`, of
/// `a[i, p] * b[p, j]` for `p` in increasing order.
///
/// # Panics
///
/// Panics if the number of columns of `a` differs from the number of rows of `b`.
pub fn matmul_simple_optimized<T>(a: ArrayView2<T>, b: ArrayView2<T>) -> Array2<T>
where
    T: Clone + Zero + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Default,
{
    let (m, k) = a.dim();
    let (b_rows, n) = b.dim();
    // Without this check a `b` with too few rows fails with an opaque index
    // panic and a `b` with too many rows silently yields a wrong product.
    assert_eq!(
        k, b_rows,
        "matmul_simple_optimized: incompatible shapes {}x{} and {}x{}",
        m, k, b_rows, n
    );

    let mut result = Array2::<T>::zeros((m, n));
    for i in 0..m {
        for j in 0..n {
            // Dot product of row `i` of `a` with column `j` of `b`.
            result[[i, j]] = (0..k).fold(T::zero(), |acc, p| {
                acc + (a[[i, p]].clone() * b[[p, j]].clone())
            });
        }
    }
    result
}
/// Multiplies `a` (`m x k`) by `b` (`k x n`), switching strategy by size.
///
/// If any of `m`, `n`, `k` exceeds 128 the product is computed by
/// [`matmul_blocked_optimized`] with a block size of 64. Otherwise an `i-k-j`
/// loop order is used: each `a[i, p]` is read once and applied across the whole
/// row `p` of `b`, updating row `i` of the result in place.
///
/// # Panics
///
/// Panics if the number of columns of `a` differs from the number of rows of `b`.
pub fn matmul_cache_optimized<T>(a: &Array2<T>, b: &Array2<T>) -> Array2<T>
where
    T: Clone + Zero + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Default,
{
    let (m, k) = a.dim();
    let (b_rows, n) = b.dim();
    // Without this check a `b` with too few rows fails with an opaque index
    // panic and a `b` with too many rows silently yields a wrong product.
    assert_eq!(
        k, b_rows,
        "matmul_cache_optimized: incompatible shapes {}x{} and {}x{}",
        m, k, b_rows, n
    );

    if m > 128 || n > 128 || k > 128 {
        return matmul_blocked_optimized(a, b, 64);
    }

    let mut result = Array2::<T>::zeros((m, n));
    for i in 0..m {
        for p in 0..k {
            // Hoisted so the element of `a` is fetched once per (i, p) pair.
            let a_ip = a[[i, p]].clone();
            for j in 0..n {
                let updated = result[[i, j]].clone() + (a_ip.clone() * b[[p, j]].clone());
                result[[i, j]] = updated;
            }
        }
    }
    result
}
/// Multiplies `a` (`m x k`) by `b` (`k x n`) using square tiles of a
/// caller-chosen edge length.
///
/// The three index spaces are walked in steps of `block_size`; for every tile
/// the partial dot products are added onto the value already stored in the
/// result, so each output element accumulates its terms in increasing `k` order.
///
/// A `block_size` of zero is treated as 1 (it previously caused a panic inside
/// `step_by`).
///
/// # Panics
///
/// Panics if the number of columns of `a` differs from the number of rows of `b`.
pub fn matmul_blocked_optimized<T>(a: &Array2<T>, b: &Array2<T>, block_size: usize) -> Array2<T>
where
    T: Clone + Zero + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Default,
{
    let (m, k) = a.dim();
    let (b_rows, n) = b.dim();
    // Without this check a `b` with too few rows fails with an opaque index
    // panic and a `b` with too many rows silently yields a wrong product.
    assert_eq!(
        k, b_rows,
        "matmul_blocked_optimized: incompatible shapes {}x{} and {}x{}",
        m, k, b_rows, n
    );

    // `step_by(0)` panics; the smallest meaningful tile is a single element.
    let block_size = block_size.max(1);

    let mut result = Array2::<T>::zeros((m, n));
    for i in (0..m).step_by(block_size) {
        let i_end = (i + block_size).min(m);
        for j in (0..n).step_by(block_size) {
            let j_end = (j + block_size).min(n);
            for k_idx in (0..k).step_by(block_size) {
                let k_end = (k_idx + block_size).min(k);
                for ii in i..i_end {
                    for jj in j..j_end {
                        // Continue from the partial sum left by earlier k-blocks.
                        let mut sum = result[[ii, jj]].clone();
                        for kk in k_idx..k_end {
                            sum = sum + (a[[ii, kk]].clone() * b[[kk, jj]].clone());
                        }
                        result[[ii, jj]] = sum;
                    }
                }
            }
        }
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// 2x2 product through the plain triple-loop kernel.
    #[test]
    fn test_matmul_simple_optimized() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let result = matmul_simple_optimized(a.view(), b.view());
        let expected = array![[19.0, 22.0], [43.0, 50.0]];
        assert_eq!(result, expected);
    }

    /// Rectangular product small enough to fit in a single tile.
    #[test]
    fn test_matmul_blocked() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let result = matmul_blocked(a.view(), b.view());
        let expected = array![[58.0, 64.0], [139.0, 154.0]];
        assert_eq!(result, expected);
    }

    /// Every dimension exceeds the 64-wide tile, so partial sums must be
    /// carried across several k-blocks and edge tiles are ragged.
    #[test]
    fn test_matmul_blocked_spans_multiple_blocks() {
        let a = Array2::<f64>::from_elem((70, 130), 1.0);
        let b = Array2::<f64>::from_elem((130, 65), 2.0);
        let result = matmul_blocked(a.view(), b.view());
        assert_eq!(result, Array2::<f64>::from_elem((70, 65), 260.0));
    }

    /// The blocked kernel and the triple loop must agree on non-uniform data.
    /// All values are small integers, so the f64 sums are exact.
    #[test]
    fn test_matmul_blocked_matches_simple() {
        let a = Array2::<f64>::from_shape_fn((70, 130), |(i, j)| ((i + 2 * j) % 7) as f64);
        let b = Array2::<f64>::from_shape_fn((130, 65), |(i, j)| ((3 * i + j) % 5) as f64);
        let blocked = matmul_blocked(a.view(), b.view());
        let simple = matmul_simple_optimized(a.view(), b.view());
        assert_eq!(blocked, simple);
    }

    /// The dispatcher must give the same answer on both of its branches.
    #[test]
    fn test_matmul_2d_optimized() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let expected = array![[19.0, 22.0], [43.0, 50.0]];
        assert_eq!(matmul_2d_optimized(a.view(), b.view()), expected);

        let big_a = Array2::<f64>::from_shape_fn((66, 3), |(i, j)| ((i + j) % 4) as f64);
        let big_b = Array2::<f64>::from_shape_fn((3, 5), |(i, j)| ((i * 2 + j) % 3) as f64);
        assert_eq!(
            matmul_2d_optimized(big_a.view(), big_b.view()),
            matmul_simple_optimized(big_a.view(), big_b.view())
        );
    }

    /// The reference-taking kernels, including a block size smaller than the
    /// matrices so that several tiles are visited.
    #[test]
    fn test_matmul_cache_and_blocked_optimized() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let expected = array![[58.0, 64.0], [139.0, 154.0]];
        assert_eq!(matmul_cache_optimized(&a, &b), expected);
        assert_eq!(matmul_blocked_optimized(&a, &b, 2), expected);
        assert_eq!(matmul_blocked_optimized(&a, &b, 64), expected);
    }
}