//! Blocked GEMM (general matrix multiply) kernels with cache-aware packing,
//! plus batched and symmetric-product helpers built on top of them.

use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, Array3, ArrayView2, Axis};
/// Default block size along the n dimension (columns of B/C); tiles are
/// square by default, so this matches `MB_L1`.
const NB_L1: usize = 64;
/// Default block size along the m dimension (rows of A/C), sized so an
/// accumulator tile stays resident in L1 cache.
const MB_L1: usize = 64;
/// Block size along the shared k dimension, sized so the packed A and B
/// panels fit in L2 cache.
const KB_L2: usize = 256;
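/// Kernel implementation that [`gemm`] dispatches to.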
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum GemmBackend {
    /// Portable blocked CPU kernel implemented in this module.
    CpuNaive,
    /// CPU BLAS path (currently served by the same blocked kernel).
    CpuBlas,
    /// GPU path (currently falls back to the blocked CPU kernel).
    GpuOxiblas,
}
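/// Parameters for a GEMM call: backend selection, blocking, operand
/// transposes, and the scalars in `C = alpha * op(A) * op(B) + beta * C`.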
#[derive(Clone, Debug)]
pub struct GemmConfig {
    /// Kernel implementation to dispatch to.
    pub backend: GemmBackend,
    /// Block size along m (and n) for the blocked CPU kernel.
    pub block_size: usize,
    /// If true, use `A^T` in place of `A`.
    pub transpose_a: bool,
    /// If true, use `B^T` in place of `B`.
    pub transpose_b: bool,
    /// Scalar multiplier applied to `op(A) * op(B)`.
    pub alpha: f64,
    /// Scalar multiplier applied to the initial `C` (ignored when no `C` is given).
    pub beta: f64,
}
impl Default for GemmConfig {
fn default() -> Self {
Self {
backend: GemmBackend::CpuNaive,
block_size: MB_L1,
transpose_a: false,
transpose_b: false,
alpha: 1.0,
beta: 0.0,
}
}
}
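/// General matrix multiply: computes `alpha * op(A) * op(B) + beta * C`,
/// where `op` applies the optional transposes requested in `config` and `C`
/// defaults to zero when not supplied.
///
/// Returns `LinalgError::DimensionError` if the inner dimensions of `op(A)`
/// and `op(B)` disagree.
///
/// # Examples
///
/// A minimal sketch (marked `ignore` because the crate's public re-export
/// path is not shown in this module):
///
/// ```ignore
/// let a = array![[1.0, 2.0], [3.0, 4.0]];
/// let b = array![[5.0], [6.0]];
/// let c = gemm(&a, &b, None, &GemmConfig::default()).unwrap();
/// assert_eq!(c.shape(), &[2, 1]); // c = [[17.0], [39.0]]
/// ```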
pub fn gemm(
a: &Array2<f64>,
b: &Array2<f64>,
c: Option<&Array2<f64>>,
config: &GemmConfig,
) -> LinalgResult<Array2<f64>> {
let a_eff: Array2<f64>;
let b_eff: Array2<f64>;
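    // Materialize any requested transpose eagerly in row-major order so that
    // the downstream kernel can read contiguous slices.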
let a_ref: &Array2<f64> = if config.transpose_a {
let rows = a.ncols();
let cols = a.nrows();
a_eff = Array2::from_shape_fn((rows, cols), |(i, j)| a[[j, i]]);
&a_eff
} else {
a
};
let b_ref: &Array2<f64> = if config.transpose_b {
let rows = b.ncols();
let cols = b.nrows();
b_eff = Array2::from_shape_fn((rows, cols), |(i, j)| b[[j, i]]);
&b_eff
} else {
b
};
let (m, k_a) = (a_ref.nrows(), a_ref.ncols());
let (k_b, n) = (b_ref.nrows(), b_ref.ncols());
if k_a != k_b {
return Err(LinalgError::DimensionError(format!(
"GEMM: inner dimensions must match: A has k={k_a}, B has k={k_b}"
)));
}
let k = k_a;
match config.backend {
        GemmBackend::GpuOxiblas => {
            // GPU dispatch is not implemented yet: whether or not any of the
            // accelerator features ("cuda", "opencl", "rocm", "metal") are
            // enabled, this arm currently falls back to the blocked CPU
            // kernel.
            gemm_cpu_blocked(a_ref, b_ref, c, m, n, k, config)
        }
GemmBackend::CpuBlas | GemmBackend::CpuNaive => {
gemm_cpu_blocked(a_ref, b_ref, c, m, n, k, config)
}
}
}
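/// Batched GEMM over the leading axis: multiplies `a[i]` with `b[i]` for
/// each batch index `i` and stacks the results into a 3-D array. The
/// transpose flags and `alpha` from `config` apply to every batch element;
/// `beta` has no effect because no initial `C` is supplied.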
pub fn batched_gemm(
a: &Array3<f64>,
b: &Array3<f64>,
config: &GemmConfig,
) -> LinalgResult<Array3<f64>> {
    let batch_a = a.shape()[0];
    let batch_b = b.shape()[0];
    // Read the per-matrix dimensions with the transpose flags applied, so
    // the shape checks and the output allocation agree with what `gemm`
    // will actually compute for each batch element.
    let (m, k_a) = if config.transpose_a {
        (a.shape()[2], a.shape()[1])
    } else {
        (a.shape()[1], a.shape()[2])
    };
    let (k_b, n) = if config.transpose_b {
        (b.shape()[2], b.shape()[1])
    } else {
        (b.shape()[1], b.shape()[2])
    };
if batch_a != batch_b {
return Err(LinalgError::DimensionError(format!(
"Batched GEMM: batch sizes must match: got {batch_a} and {batch_b}"
)));
}
if k_a != k_b {
return Err(LinalgError::DimensionError(format!(
"Batched GEMM: inner dimensions must match: A has k={k_a}, B has k={k_b}"
)));
}
let batch = batch_a;
let mut result = Array3::<f64>::zeros((batch, m, n));
for i in 0..batch {
let a_slice: Array2<f64> = a.index_axis(Axis(0), i).to_owned();
let b_slice: Array2<f64> = b.index_axis(Axis(0), i).to_owned();
let c_slice = gemm(&a_slice, &b_slice, None, config)?;
result.index_axis_mut(Axis(0), i).assign(&c_slice);
}
Ok(result)
}
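/// Computes the Gram matrix `A * A^T` (scaled by `config.alpha`) by forcing
/// `transpose_b` on a copy of `config`; the result is symmetric positive
/// semi-definite. Despite the name, this is closer to BLAS `syrk` than
/// `symm`.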
pub fn symm_gemm(a: &Array2<f64>, config: &GemmConfig) -> LinalgResult<Array2<f64>> {
let mut cfg = config.clone();
    cfg.transpose_b = true;
    gemm(a, a, None, &cfg)
}
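/// Cache-blocked GEMM kernel: walks C in `mb x nb` tiles, packs the matching
/// panels of A and B into contiguous buffers, and lets `micro_kernel`
/// accumulate each tile. `c_init` is pre-scaled by `beta` before the update.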
fn gemm_cpu_blocked(
a: &Array2<f64>,
b: &Array2<f64>,
c_init: Option<&Array2<f64>>,
m: usize,
n: usize,
k: usize,
config: &GemmConfig,
) -> LinalgResult<Array2<f64>> {
    let mb = config.block_size.max(1);
    // Tile n with NB_L1 by default; when the caller overrides the block
    // size, keep the m and n tiles square.
    let nb = if config.block_size == MB_L1 { NB_L1 } else { mb };
    let kb = KB_L2;
let mut c: Array2<f64> = match c_init {
Some(c0) => {
if c0.nrows() != m || c0.ncols() != n {
return Err(LinalgError::DimensionError(format!(
"GEMM: initial C has shape [{}, {}], expected [{m}, {n}]",
c0.nrows(),
c0.ncols()
)));
}
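            // Standard GEMM semantics: beta == 0 means the initial C is
            // ignored entirely (even if it contains non-finite values).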
if config.beta == 0.0 {
Array2::<f64>::zeros((m, n))
} else {
c0.mapv(|v| v * config.beta)
}
}
None => Array2::<f64>::zeros((m, n)),
};
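    // The packing loops below index A and B as flat row-major slices, so any
    // operand that is not already in standard (C-order) layout is copied
    // into one first.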
let a_owned: Array2<f64>;
let a_c: &Array2<f64> = if a.is_standard_layout() {
a
} else {
a_owned = Array2::from_shape_fn((a.nrows(), a.ncols()), |(i, j)| a[[i, j]]);
&a_owned
};
let b_owned: Array2<f64>;
let b_c: &Array2<f64> = if b.is_standard_layout() {
b
} else {
b_owned = Array2::from_shape_fn((b.nrows(), b.ncols()), |(i, j)| b[[i, j]]);
&b_owned
};
let a_slice = a_c.as_slice().ok_or_else(|| {
LinalgError::ComputationError(
"A matrix could not be converted to a contiguous slice".to_string(),
)
})?;
let b_slice = b_c.as_slice().ok_or_else(|| {
LinalgError::ComputationError(
"B matrix could not be converted to a contiguous slice".to_string(),
)
})?;
let c_slice = c.as_slice_mut().ok_or_else(|| {
LinalgError::ComputationError("C matrix is not contiguous in memory".to_string())
})?;
let alpha = config.alpha;
let mut kb_start = 0;
while kb_start < k {
let kb_end = (kb_start + kb).min(k);
let kb_size = kb_end - kb_start;
let mut mb_start = 0;
while mb_start < m {
let mb_end = (mb_start + mb).min(m);
let mb_size = mb_end - mb_start;
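            // Pack the current A panel (mb_size x kb_size) into a contiguous
            // row-major buffer for the micro-kernel.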
let mut a_pack = vec![0.0_f64; mb_size * kb_size];
for i in 0..mb_size {
for p in 0..kb_size {
a_pack[i * kb_size + p] = a_slice[(mb_start + i) * k + (kb_start + p)];
}
}
let mut nb_start = 0;
while nb_start < n {
let nb_end = (nb_start + nb).min(n);
let nb_size = nb_end - nb_start;
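                // Pack the matching B panel (kb_size x nb_size) the same way.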
let mut b_pack = vec![0.0_f64; kb_size * nb_size];
for p in 0..kb_size {
for j in 0..nb_size {
b_pack[p * nb_size + j] = b_slice[(kb_start + p) * n + (nb_start + j)];
}
}
micro_kernel(
&a_pack, &b_pack, c_slice, mb_size, nb_size, kb_size, mb_start, nb_start, n,
alpha,
);
nb_start += nb;
}
mb_start += mb;
}
kb_start += kb;
}
Ok(c)
}
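/// Innermost loop of the blocked kernel: accumulates `alpha * A_pack *
/// B_pack` into the `mb x nb` tile of `c` whose top-left corner is at
/// (`c_row_offset`, `c_col_offset`). Entries of A that scale to exactly
/// zero are skipped, which cheaply short-circuits work for sparse-ish
/// inputs.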
#[inline(always)]
fn micro_kernel(
a_pack: &[f64],
b_pack: &[f64],
c: &mut [f64],
mb: usize,
nb: usize,
kb: usize,
c_row_offset: usize,
c_col_offset: usize,
c_stride: usize,
alpha: f64,
) {
for i in 0..mb {
for p in 0..kb {
let a_ip = a_pack[i * kb + p] * alpha;
if a_ip == 0.0 {
continue;
}
for j in 0..nb {
let c_idx = (c_row_offset + i) * c_stride + (c_col_offset + j);
c[c_idx] += a_ip * b_pack[p * nb + j];
}
}
}
}
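/// View-based convenience wrapper around [`gemm`]; the operands are copied
/// into owned arrays before dispatch, so prefer [`gemm`] when you already
/// hold owned matrices.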
pub fn gemm_view(
a: &ArrayView2<f64>,
b: &ArrayView2<f64>,
c: Option<&ArrayView2<f64>>,
config: &GemmConfig,
) -> LinalgResult<Array2<f64>> {
let a_owned = a.to_owned();
let b_owned = b.to_owned();
let c_owned = c.map(|v| v.to_owned());
gemm(&a_owned, &b_owned, c_owned.as_ref(), config)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::{array, Array2, Array3};
fn naive_gemm(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
let (m, k) = (a.nrows(), a.ncols());
let n = b.ncols();
let mut c = Array2::<f64>::zeros((m, n));
for i in 0..m {
for p in 0..k {
for j in 0..n {
c[[i, j]] += a[[i, p]] * b[[p, j]];
}
}
}
c
}
#[test]
fn test_gemm_identity() {
let eye = Array2::<f64>::eye(3);
let b = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
let c = gemm(&eye, &b, None, &GemmConfig::default()).unwrap();
for i in 0..3 {
for j in 0..3 {
assert_abs_diff_eq!(c[[i, j]], b[[i, j]], epsilon = 1e-12);
}
}
}
#[test]
fn test_gemm_transpose_a() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let b = array![[1.0_f64, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let config = GemmConfig {
transpose_a: true,
..Default::default()
};
let c = gemm(&a, &b, None, &config).unwrap();
assert_eq!(c.shape(), &[2, 2]);
assert_abs_diff_eq!(c[[0, 0]], 6.0, epsilon = 1e-12);
assert_abs_diff_eq!(c[[0, 1]], 8.0, epsilon = 1e-12);
assert_abs_diff_eq!(c[[1, 0]], 8.0, epsilon = 1e-12);
assert_abs_diff_eq!(c[[1, 1]], 10.0, epsilon = 1e-12);
}
#[test]
fn test_gemm_transpose_b() {
let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
let b = array![[5.0_f64, 7.0], [6.0, 8.0]];
let config = GemmConfig {
transpose_b: true,
..Default::default()
};
let c = gemm(&a, &b, None, &config).unwrap();
assert_abs_diff_eq!(c[[0, 0]], 19.0, epsilon = 1e-12);
assert_abs_diff_eq!(c[[0, 1]], 22.0, epsilon = 1e-12);
assert_abs_diff_eq!(c[[1, 0]], 43.0, epsilon = 1e-12);
assert_abs_diff_eq!(c[[1, 1]], 50.0, epsilon = 1e-12);
}
#[test]
fn test_gemm_alpha_beta() {
let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
let b = array![[3.0_f64, 0.0], [0.0, 3.0]];
let c_init = array![[1.0_f64, 2.0], [3.0, 4.0]];
let config = GemmConfig {
alpha: 2.0,
beta: 0.5,
..GemmConfig::default()
};
let c = gemm(&a, &b, Some(&c_init), &config).unwrap();
assert_abs_diff_eq!(c[[0, 0]], 6.0 + 0.5, epsilon = 1e-12);
assert_abs_diff_eq!(c[[0, 1]], 0.0 + 1.0, epsilon = 1e-12);
assert_abs_diff_eq!(c[[1, 0]], 0.0 + 1.5, epsilon = 1e-12);
assert_abs_diff_eq!(c[[1, 1]], 6.0 + 2.0, epsilon = 1e-12);
}
#[test]
fn test_gemm_blocked_vs_naive() {
use scirs2_core::ndarray::Array2;
let m = 70;
let k = 90;
let n = 80;
let a: Array2<f64> =
Array2::from_shape_fn((m, k), |(i, j)| ((i * k + j) as f64) / (m * k) as f64);
let b: Array2<f64> =
Array2::from_shape_fn((k, n), |(i, j)| ((i * n + j) as f64) / (k * n) as f64);
let expected = naive_gemm(&a, &b);
let got = gemm(&a, &b, None, &GemmConfig::default()).unwrap();
for i in 0..m {
for j in 0..n {
assert_abs_diff_eq!(got[[i, j]], expected[[i, j]], epsilon = 1e-9);
}
}
}
#[test]
fn test_gemm_non_square() {
let a = Array2::<f64>::from_shape_fn((3, 5), |(i, j)| (i + j) as f64);
let b = Array2::<f64>::from_shape_fn((5, 4), |(i, j)| (i * j) as f64);
let got = gemm(&a, &b, None, &GemmConfig::default()).unwrap();
let expected = naive_gemm(&a, &b);
assert_eq!(got.shape(), &[3, 4]);
for i in 0..3 {
for j in 0..4 {
assert_abs_diff_eq!(got[[i, j]], expected[[i, j]], epsilon = 1e-9);
}
}
}
#[test]
fn test_gemm_dimension_mismatch() {
let a = Array2::<f64>::zeros((3, 4));
        let b = Array2::<f64>::zeros((5, 2));
        assert!(gemm(&a, &b, None, &GemmConfig::default()).is_err());
}
#[test]
fn test_batched_gemm_shape() {
let a = Array3::<f64>::from_shape_fn((4, 3, 5), |(b, i, j)| (b + i + j) as f64);
let b = Array3::<f64>::from_shape_fn((4, 5, 2), |(b, i, j)| (b + i * j) as f64);
let c = batched_gemm(&a, &b, &GemmConfig::default()).unwrap();
assert_eq!(c.shape(), &[4, 3, 2]);
}
#[test]
fn test_batched_gemm_result() {
let batch = 3;
let a =
Array3::<f64>::from_shape_fn((batch, 2, 2), |(b, i, j)| (b * 4 + i * 2 + j + 1) as f64);
let b =
Array3::<f64>::from_shape_fn((batch, 2, 2), |(b, i, j)| (b * 4 + i * 2 + j + 1) as f64);
let c_batched = batched_gemm(&a, &b, &GemmConfig::default()).unwrap();
for i in 0..batch {
let a_slice = a.index_axis(Axis(0), i).to_owned();
let b_slice = b.index_axis(Axis(0), i).to_owned();
let c_single = gemm(&a_slice, &b_slice, None, &GemmConfig::default()).unwrap();
let c_slice = c_batched.index_axis(Axis(0), i).to_owned();
for r in 0..2 {
for col in 0..2 {
assert_abs_diff_eq!(c_slice[[r, col]], c_single[[r, col]], epsilon = 1e-12);
}
}
}
}
#[test]
fn test_batched_gemm_batch_mismatch() {
let a = Array3::<f64>::zeros((3, 2, 2));
let b = Array3::<f64>::zeros((4, 2, 2));
assert!(batched_gemm(&a, &b, &GemmConfig::default()).is_err());
}
#[test]
fn test_symm_gemm_symmetry() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let c = symm_gemm(&a, &GemmConfig::default()).unwrap();
assert_eq!(c.shape(), &[2, 2]);
assert_abs_diff_eq!(c[[0, 1]], c[[1, 0]], epsilon = 1e-12);
}
#[test]
fn test_symm_gemm_psd() {
let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
let c = symm_gemm(&a, &GemmConfig::default()).unwrap();
for i in 0..c.nrows() {
assert!(c[[i, i]] >= 0.0);
}
}
#[test]
fn test_symm_gemm_values() {
let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
let c = symm_gemm(&a, &GemmConfig::default()).unwrap();
assert_abs_diff_eq!(c[[0, 0]], 1.0, epsilon = 1e-12);
assert_abs_diff_eq!(c[[1, 1]], 1.0, epsilon = 1e-12);
assert_abs_diff_eq!(c[[0, 1]], 0.0, epsilon = 1e-12);
}
#[test]
fn test_gemm_large_block_boundary() {
let m = 130;
let k = 270;
let n = 140;
let a: Array2<f64> = Array2::from_shape_fn((m, k), |(i, j)| ((i + j) as f64) * 0.001);
let b: Array2<f64> = Array2::from_shape_fn((k, n), |(i, j)| ((i * 2 + j) as f64) * 0.001);
let got = gemm(&a, &b, None, &GemmConfig::default()).unwrap();
let expected = naive_gemm(&a, &b);
for i in 0..m {
for j in 0..n {
assert_abs_diff_eq!(got[[i, j]], expected[[i, j]], epsilon = 1e-6);
}
}
}
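    // Additional coverage sketches: the view wrapper, a non-default block
    // size that forces partial tiles on every edge, and the transpose-aware
    // batched path. Shapes and fill values here are illustrative choices,
    // not fixtures taken from elsewhere.
    #[test]
    fn test_gemm_view_matches_owned() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let b = array![[5.0_f64, 6.0], [7.0, 8.0]];
        let expected = gemm(&a, &b, None, &GemmConfig::default()).unwrap();
        let got = gemm_view(&a.view(), &b.view(), None, &GemmConfig::default()).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(got[[i, j]], expected[[i, j]], epsilon = 1e-12);
            }
        }
    }
    #[test]
    fn test_gemm_custom_block_size() {
        let (m, k, n) = (13, 17, 11);
        let a: Array2<f64> = Array2::from_shape_fn((m, k), |(i, j)| ((i + 2 * j) as f64) * 0.01);
        let b: Array2<f64> = Array2::from_shape_fn((k, n), |(i, j)| ((3 * i + j) as f64) * 0.01);
        let config = GemmConfig {
            block_size: 7,
            ..GemmConfig::default()
        };
        let got = gemm(&a, &b, None, &config).unwrap();
        let expected = naive_gemm(&a, &b);
        for i in 0..m {
            for j in 0..n {
                assert_abs_diff_eq!(got[[i, j]], expected[[i, j]], epsilon = 1e-9);
            }
        }
    }
    #[test]
    fn test_batched_gemm_transpose_b() {
        let a = Array3::<f64>::from_shape_fn((2, 2, 3), |(b, i, j)| (b + i + j) as f64);
        let b = Array3::<f64>::from_shape_fn((2, 4, 3), |(b, i, j)| (b * 3 + i) as f64 - j as f64);
        let config = GemmConfig {
            transpose_b: true,
            ..GemmConfig::default()
        };
        let c = batched_gemm(&a, &b, &config).unwrap();
        assert_eq!(c.shape(), &[2, 2, 4]);
        for s in 0..2 {
            let a_s = a.index_axis(Axis(0), s).to_owned();
            let b_s = b.index_axis(Axis(0), s).to_owned();
            let c_s = gemm(&a_s, &b_s, None, &config).unwrap();
            for i in 0..2 {
                for j in 0..4 {
                    assert_abs_diff_eq!(c[[s, i, j]], c_s[[i, j]], epsilon = 1e-12);
                }
            }
        }
    }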
}