use scirs2_core::ndarray::{s, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
/// Computes the matrix product `a * b` by iterating over square blocks.
///
/// The output is partitioned into `blocksize x blocksize` blocks. Each output
/// block is accumulated from the matching row-block of `a` and column-block of
/// `b`, one `blocksize`-wide slice of the inner dimension at a time.
///
/// `blocksize` defaults to 64. A block size of 0 is treated as 1; otherwise the
/// block counts below would divide by zero.
///
/// # Errors
///
/// Returns [`LinalgError::ShapeError`] if the number of columns of `a` differs
/// from the number of rows of `b`.
#[allow(dead_code)]
pub fn block_matmul<F>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
    blocksize: Option<usize>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let (m, k1) = a.dim();
    let (k2, n) = b.dim();
    if k1 != k2 {
        return Err(LinalgError::ShapeError(format!(
            "Inner dimensions mismatch for matmul: {k1} vs {k2}"
        )));
    }
    // `div_ceil(0)` panics, so never let the block size reach zero.
    let blocksize = blocksize.unwrap_or(64).max(1);
    // Starts zeroed and every output entry belongs to exactly one block, so no
    // per-block reset is needed before accumulating into it.
    let mut result = Array2::zeros((m, n));
    let m_blocks = m.div_ceil(blocksize);
    let n_blocks = n.div_ceil(blocksize);
    let k_blocks = k1.div_ceil(blocksize);
    for mb in 0..m_blocks {
        let m_start = mb * blocksize;
        let m_end = (m_start + blocksize).min(m);
        for nb in 0..n_blocks {
            let n_start = nb * blocksize;
            let n_end = (n_start + blocksize).min(n);
            for kb in 0..k_blocks {
                let k_start = kb * blocksize;
                let k_end = (k_start + blocksize).min(k1);
                for i in m_start..m_end {
                    for j in n_start..n_end {
                        for k in k_start..k_end {
                            result[[i, j]] += a[[i, k]] * b[[k, j]];
                        }
                    }
                }
            }
        }
    }
    Ok(result)
}
/// Multiplies `a` (`m x k`) by `b` (`k x n`) using Strassen's algorithm.
///
/// If any of `m`, `k` or `n` is at most `cutoff`, the plain triple-loop
/// product is returned directly. Otherwise the following steps run:
/// 1. Both operands are zero-padded to the smallest power-of-two square size
///    that covers every dimension.
/// 2. The padded operands are multiplied recursively.
/// 3. The top-left `m x n` corner of the padded product is returned.
///
/// `cutoff` defaults to 128. A cutoff of 0 is treated as 1, because a zero
/// cutoff lets the recursion split 1x1 blocks into empty quadrants.
///
/// # Errors
///
/// Returns [`LinalgError::ShapeError`] if the number of columns of `a` differs
/// from the number of rows of `b`.
#[allow(dead_code)]
pub fn strassen_matmul<F>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
    cutoff: Option<usize>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let (m, k1) = a.dim();
    let (k2, n) = b.dim();
    if k1 != k2 {
        return Err(LinalgError::ShapeError(format!(
            "Inner dimensions mismatch for matmul: {k1} vs {k2}"
        )));
    }
    // With a cutoff of 0, 1x1 inputs would reach the padding computation, and
    // larger inputs would recurse into empty quadrants and yield an all-zero
    // product. Clamping guarantees the recursion bottoms out at n <= cutoff.
    let cutoff = cutoff.unwrap_or(128).max(1);
    if m <= cutoff || k1 <= cutoff || n <= cutoff {
        return standard_matmul(a, b);
    }
    // Smallest power of two >= every dimension, so each recursive halving
    // splits the padded matrices into four equal quadrants.
    let newsize = m.max(k1).max(n).next_power_of_two();
    let a_padded = padmatrix(a, newsize, newsize);
    let b_padded = padmatrix(b, newsize, newsize);
    let result_padded = strassen_recursive(&a_padded.view(), &b_padded.view(), cutoff);
    // Padding rows/columns are zero, so the top-left corner is the true product.
    let result = result_padded.slice(s![0..m, 0..n]).to_owned();
    Ok(result)
}
/// Returns a `new_rows x newcols` matrix with `a` copied into its top-left
/// corner and every remaining entry set to zero.
///
/// Indexing panics (out of bounds) if any entry of `a` falls outside the new
/// shape, i.e. the target must be at least as large as `a` in both dimensions.
#[allow(dead_code)]
fn padmatrix<F>(a: &ArrayView2<F>, new_rows: usize, newcols: usize) -> Array2<F>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    let mut padded = Array2::zeros((new_rows, newcols));
    // Visit every source entry with its (row, col) position and copy it over.
    for ((row, col), &value) in a.indexed_iter() {
        padded[[row, col]] = value;
    }
    padded
}
/// Recursive step of Strassen's algorithm for two square `n x n` matrices.
///
/// Each operand is split into four `n/2 x n/2` quadrants. The seven Strassen
/// products are formed recursively and combined into the four result
/// quadrants. The function falls back to [`standard_matmul`] in two cases:
/// - `n <= cutoff`.
/// - `n` is odd, because the quadrant split is only well-defined for even
///   sizes.
///
/// Both `a` and `b` are expected to be `n x n` with `n = a.nrows()`; this is
/// checked in debug builds only.
#[allow(dead_code)]
fn strassen_recursive<F>(a: &ArrayView2<F>, b: &ArrayView2<F>, cutoff: usize) -> Array2<F>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand,
{
    let n = a.dim().0;
    debug_assert_eq!(a.dim(), (n, n), "strassen_recursive: `a` must be square");
    debug_assert_eq!(b.dim(), (n, n), "strassen_recursive: `b` must match `a`");
    // Odd sizes cannot be split into four equal quadrants: the slices below
    // would have mismatched shapes that either panic or silently broadcast in
    // the quadrant sums. This also stops the recursion at n == 1 when `cutoff`
    // is 0.
    if n <= cutoff || n % 2 == 1 {
        // Both operands are n x n, so the inner dimensions always agree.
        return standard_matmul(a, b).expect("Operation failed");
    }
    let half_n = n / 2;
    // Quadrants: x11 top-left, x12 top-right, x21 bottom-left, x22 bottom-right.
    let a11 = a.slice(s![0..half_n, 0..half_n]);
    let a12 = a.slice(s![0..half_n, half_n..n]);
    let a21 = a.slice(s![half_n..n, 0..half_n]);
    let a22 = a.slice(s![half_n..n, half_n..n]);
    let b11 = b.slice(s![0..half_n, 0..half_n]);
    let b12 = b.slice(s![0..half_n, half_n..n]);
    let b21 = b.slice(s![half_n..n, 0..half_n]);
    let b22 = b.slice(s![half_n..n, half_n..n]);
    // The seven Strassen products.
    let p1 = strassen_recursive(&(&a11 + &a22).view(), &(&b11 + &b22).view(), cutoff);
    let p2 = strassen_recursive(&(&a21 + &a22).view(), &b11, cutoff);
    let p3 = strassen_recursive(&a11, &(&b12 - &b22).view(), cutoff);
    let p4 = strassen_recursive(&a22, &(&b21 - &b11).view(), cutoff);
    let p5 = strassen_recursive(&(&a11 + &a12).view(), &b22, cutoff);
    let p6 = strassen_recursive(&(&a21 - &a11).view(), &(&b11 + &b12).view(), cutoff);
    let p7 = strassen_recursive(&(&a12 - &a22).view(), &(&b21 + &b22).view(), cutoff);
    // Combine the products into the four result quadrants.
    let c11 = &p1 + &p4 - &p5 + &p7;
    let c12 = &p3 + &p5;
    let c21 = &p2 + &p4;
    let c22 = &p1 - &p2 + &p3 + &p6;
    let mut result = Array2::<F>::zeros((n, n));
    result.slice_mut(s![..half_n, ..half_n]).assign(&c11);
    result.slice_mut(s![..half_n, half_n..]).assign(&c12);
    result.slice_mut(s![half_n.., ..half_n]).assign(&c21);
    result.slice_mut(s![half_n.., half_n..]).assign(&c22);
    result
}
/// Reference matrix product computed with the textbook triple loop.
///
/// Entry `(i, j)` of the result is the dot product of row `i` of `a` with
/// column `j` of `b`. It is accumulated starting from zero in increasing inner
/// index order.
///
/// # Errors
///
/// Returns [`LinalgError::ShapeError`] when the number of columns of `a`
/// differs from the number of rows of `b`.
#[allow(dead_code)]
fn standard_matmul<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let (rows, k1) = a.dim();
    let (k2, cols) = b.dim();
    if k1 != k2 {
        return Err(LinalgError::ShapeError(format!(
            "Inner dimensions mismatch for matmul: {k1} vs {k2}"
        )));
    }
    let product = Array2::from_shape_fn((rows, cols), |(i, j)| {
        let lhs_row = a.row(i);
        let rhs_col = b.column(j);
        lhs_row
            .iter()
            .zip(rhs_col.iter())
            .fold(F::zero(), |acc, (&x, &y)| acc + x * y)
    });
    Ok(product)
}
/// Computes the matrix product `a * b` with loop tiling over all three indices.
///
/// Tiles are visited in (row, inner, column) order. Inside a tile, `a[[i, k]]`
/// is read once per `(i, k)` pair and reused across the innermost column loop,
/// which walks `b` along row `k`.
///
/// `tilesize` defaults to 32. A tile size of 0 is treated as 1; otherwise the
/// tile counts below would divide by zero.
///
/// # Errors
///
/// Returns [`LinalgError::ShapeError`] if the number of columns of `a` differs
/// from the number of rows of `b`.
#[allow(dead_code)]
pub fn tiled_matmul<F>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
    tilesize: Option<usize>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let (m, k1) = a.dim();
    let (k2, n) = b.dim();
    if k1 != k2 {
        return Err(LinalgError::ShapeError(format!(
            "Inner dimensions mismatch for matmul: {k1} vs {k2}"
        )));
    }
    // `div_ceil(0)` panics, so never let the tile size reach zero.
    let tilesize = tilesize.unwrap_or(32).max(1);
    let mut result = Array2::zeros((m, n));
    let m_tiles = m.div_ceil(tilesize);
    let n_tiles = n.div_ceil(tilesize);
    let k_tiles = k1.div_ceil(tilesize);
    for i_tile in 0..m_tiles {
        let i_start = i_tile * tilesize;
        let i_end = (i_start + tilesize).min(m);
        for k_tile in 0..k_tiles {
            let k_start = k_tile * tilesize;
            let k_end = (k_start + tilesize).min(k1);
            for j_tile in 0..n_tiles {
                let j_start = j_tile * tilesize;
                let j_end = (j_start + tilesize).min(n);
                for i in i_start..i_end {
                    for k in k_start..k_end {
                        // Loop-invariant for the column sweep below.
                        let a_ik = a[[i, k]];
                        for j in j_start..j_end {
                            result[[i, j]] += a_ik * b[[k, j]];
                        }
                    }
                }
            }
        }
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Checks that `result` has shape `R x C` and matches `expected` entry by
    /// entry in row-major order, using `assert_relative_eq!`'s default tolerances.
    fn assert_entries<const R: usize, const C: usize>(
        result: &Array2<f64>,
        expected: [[f64; C]; R],
    ) {
        assert_eq!(result.shape(), &[R, C]);
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_relative_eq!(result[[i, j]], want);
            }
        }
    }

    #[test]
    fn test_block_matmul_2x2() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let result = block_matmul(&a.view(), &b.view(), Some(1)).expect("Operation failed");
        assert_entries(&result, [[19.0, 22.0], [43.0, 50.0]]);
    }

    #[test]
    fn test_block_matmul_2x3_3x2() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let result = block_matmul(&a.view(), &b.view(), Some(2)).expect("Operation failed");
        assert_entries(&result, [[58.0, 64.0], [139.0, 154.0]]);
    }

    #[test]
    fn test_strassen_matmul_2x2() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let result = strassen_matmul(&a.view(), &b.view(), Some(1)).expect("Operation failed");
        assert_entries(&result, [[19.0, 22.0], [43.0, 50.0]]);
    }

    #[test]
    fn test_strassen_matmul_3x3() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let b = array![[9.0, 8.0, 7.0], [6.0, 5.0, 4.0], [3.0, 2.0, 1.0]];
        let result = strassen_matmul(&a.view(), &b.view(), Some(2)).expect("Operation failed");
        let expected: Array2<f64> =
            array![[30.0, 24.0, 18.0], [84.0, 69.0, 54.0], [138.0, 114.0, 90.0]];
        assert_eq!(result.shape(), &[3, 3]);
        for ((i, j), &want) in expected.indexed_iter() {
            assert_relative_eq!(result[[i, j]], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_tiled_matmul_2x2() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let result = tiled_matmul(&a.view(), &b.view(), Some(1)).expect("Operation failed");
        assert_entries(&result, [[19.0, 22.0], [43.0, 50.0]]);
    }

    #[test]
    fn test_tiled_matmul_2x3_3x2() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let result = tiled_matmul(&a.view(), &b.view(), Some(2)).expect("Operation failed");
        assert_entries(&result, [[58.0, 64.0], [139.0, 154.0]]);
    }

    #[test]
    fn test_largematrix_equivalence() {
        let size = 20;
        let a = Array2::<f64>::from_shape_fn((size, size), |(i, j)| (i * size + j) as f64);
        let b = Array2::<f64>::from_shape_fn((size, size), |(i, j)| {
            ((size - i) * size + (size - j)) as f64
        });
        let result_standard = standard_matmul(&a.view(), &b.view()).expect("Operation failed");
        let result_block = block_matmul(&a.view(), &b.view(), Some(4)).expect("Operation failed");
        let result_strassen =
            strassen_matmul(&a.view(), &b.view(), Some(8)).expect("Operation failed");
        let result_tiled = tiled_matmul(&a.view(), &b.view(), Some(4)).expect("Operation failed");
        // Every blocked variant must agree with the reference product.
        for ((i, j), &reference) in result_standard.indexed_iter() {
            for candidate in [&result_block, &result_strassen, &result_tiled] {
                assert_relative_eq!(candidate[[i, j]], reference, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_dimension_mismatch() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0, 7.0], [8.0, 9.0, 10.0], [11.0, 12.0, 13.0]];
        let results = [
            block_matmul(&a.view(), &b.view(), None),
            strassen_matmul(&a.view(), &b.view(), None),
            tiled_matmul(&a.view(), &b.view(), None),
        ];
        for result in &results {
            assert!(result.is_err());
        }
    }
}