#[cfg(feature = "simd")]
use crate::error::{LinalgError, LinalgResult};
#[cfg(feature = "simd")]
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
#[cfg(feature = "simd")]
use scirs2_core::simd_ops::SimdUnifiedOps;
#[cfg(feature = "simd")]
/// Cache/register blocking parameters for a blocked GEMM.
///
/// The field names follow the common Goto/BLIS blocking convention (presumably:
/// `mc`/`kc`/`nc` are cache-block sizes along the M/K/N dimensions and `mr`×`nr`
/// is the micro-kernel register tile) — TODO confirm against the kernel that
/// consumes them.
///
/// NOTE(review): `simd_gemm_f32` / `simd_gemm_f64` in this module currently
/// accept an `Option<GemmBlockSizes>` but ignore it (the parameter is
/// `_blocksizes`), so changing these values has no effect on those functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GemmBlockSizes {
    /// Block size along the M (rows of A / C) dimension.
    pub mc: usize,
    /// Block size along the shared K dimension.
    pub kc: usize,
    /// Block size along the N (columns of B / C) dimension.
    pub nc: usize,
    /// Micro-tile height (rows).
    pub mr: usize,
    /// Micro-tile width (columns).
    pub nr: usize,
}
#[cfg(feature = "simd")]
impl Default for GemmBlockSizes {
fn default() -> Self {
Self {
mc: 64, kc: 256, nc: 512, mr: 8, nr: 8, }
}
}
/// Computes `C = alpha * A * B + beta * C` for `f32` matrices by delegating to
/// `f32::simd_gemm` from `scirs2_core`.
///
/// `a` must be `m × k`, `b` must be `k × n`, and `c` must already be `m × n`.
/// The `_blocksizes` argument is accepted for API compatibility but is not used.
///
/// # Errors
/// Returns `LinalgError::ShapeError` if the inner dimensions of `a` and `b`
/// differ, or if `c` does not have shape `(a.nrows(), b.ncols())`. The inner
/// dimension check is reported first.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_gemm_f32(
    alpha: f32,
    a: &ArrayView2<f32>,
    b: &ArrayView2<f32>,
    beta: f32,
    c: &mut Array2<f32>,
    _blocksizes: Option<GemmBlockSizes>,
) -> LinalgResult<()> {
    let (m, k1) = a.dim();
    let (k2, n) = b.dim();
    let (cm, cn) = c.dim();

    // Conformability of A and B is checked before the shape of C.
    let shape_problem = if k1 != k2 {
        Some(format!(
            "Matrix inner dimensions must match: A({m}, {k1}) * B({k2}, {n})"
        ))
    } else if (cm, cn) != (m, n) {
        Some(format!(
            "Result matrix dimensions must match: C({cm}, {cn}) for A({m}, {k1}) * B({k2}, {n})"
        ))
    } else {
        None
    };

    if let Some(message) = shape_problem {
        return Err(LinalgError::ShapeError(message));
    }

    f32::simd_gemm(alpha, a, b, beta, c);
    Ok(())
}
/// Computes `C = alpha * A * B + beta * C` for `f64` matrices by delegating to
/// `f64::simd_gemm` from `scirs2_core`.
///
/// `a` must be `m × k`, `b` must be `k × n`, and `c` must already be `m × n`.
/// The `_blocksizes` argument is accepted for API compatibility but is not used.
///
/// NOTE(review): `test_simd_gemm_f64_basic` in this module is `#[ignore]`d with a
/// reported `Option::unwrap()` panic inside the `scirs2_core` dot kernel;
/// confirm this path works for your inputs before relying on it.
///
/// # Errors
/// Returns `LinalgError::ShapeError` if the inner dimensions of `a` and `b`
/// differ, or if `c` does not have shape `(a.nrows(), b.ncols())`. The inner
/// dimension check is reported first.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_gemm_f64(
    alpha: f64,
    a: &ArrayView2<f64>,
    b: &ArrayView2<f64>,
    beta: f64,
    c: &mut Array2<f64>,
    _blocksizes: Option<GemmBlockSizes>,
) -> LinalgResult<()> {
    let (m, k1) = a.dim();
    let (k2, n) = b.dim();
    let (cm, cn) = c.dim();

    // Conformability of A and B is checked before the shape of C.
    let shape_problem = if k1 != k2 {
        Some(format!(
            "Matrix inner dimensions must match: A({m}, {k1}) * B({k2}, {n})"
        ))
    } else if (cm, cn) != (m, n) {
        Some(format!(
            "Result matrix dimensions must match: C({cm}, {cn}) for A({m}, {k1}) * B({k2}, {n})"
        ))
    } else {
        None
    };

    if let Some(message) = shape_problem {
        return Err(LinalgError::ShapeError(message));
    }

    f64::simd_gemm(alpha, a, b, beta, c);
    Ok(())
}
/// Returns the product `A * B` of two `f32` matrices as a newly allocated
/// `m × n` array, computed with `simd_gemm_f32` (`alpha = 1`, `beta = 0`).
///
/// # Errors
/// Propagates the `LinalgError::ShapeError` reported by `simd_gemm_f32` when the
/// inner dimensions of `a` and `b` do not match.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_matmul_optimized_f32(
    a: &ArrayView2<f32>,
    b: &ArrayView2<f32>,
) -> LinalgResult<Array2<f32>> {
    let mut product = Array2::zeros((a.nrows(), b.ncols()));
    simd_gemm_f32(1.0, a, b, 0.0, &mut product, None).map(|()| product)
}
/// Returns the product `A * B` of two `f64` matrices as a newly allocated
/// `m × n` array, computed with `simd_gemm_f64` (`alpha = 1`, `beta = 0`).
///
/// # Errors
/// Propagates the `LinalgError::ShapeError` reported by `simd_gemm_f64` when the
/// inner dimensions of `a` and `b` do not match.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_matmul_optimized_f64(
    a: &ArrayView2<f64>,
    b: &ArrayView2<f64>,
) -> LinalgResult<Array2<f64>> {
    let mut product = Array2::zeros((a.nrows(), b.ncols()));
    simd_gemm_f64(1.0, a, b, 0.0, &mut product, None).map(|()| product)
}
/// Computes the matrix-vector update `y = alpha * A * x + beta * y` for `f32`
/// data, using `f32::simd_gemv` from `scirs2_core` for the `A * x` product.
///
/// `a` must be `m × n`, `x` must have length `n`, and `y` must have length `m`.
/// As in BLAS `?GEMV`, when `beta == 0` the incoming contents of `y` are ignored
/// (treated as output-only), so NaN/Inf left in `y` do not affect the result.
///
/// # Errors
/// Returns `LinalgError::ShapeError` if `x.len()` differs from the number of
/// columns of `a`, or `y.len()` differs from the number of rows of `a`.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_gemv_f32(
    alpha: f32,
    a: &ArrayView2<f32>,
    x: &ArrayView1<f32>,
    beta: f32,
    y: &mut Array1<f32>,
) -> LinalgResult<()> {
    let (m, n) = a.dim();
    if x.len() != n {
        return Err(LinalgError::ShapeError(format!(
            "Vector x length ({}) must match matrix columns ({})",
            x.len(),
            n
        )));
    }
    if y.len() != m {
        return Err(LinalgError::ShapeError(format!(
            "Vector y length ({}) must match matrix rows ({})",
            y.len(),
            m
        )));
    }
    if beta == 0.0 {
        // y is output-only here. Clear it explicitly instead of relying on the
        // kernel's `0.0 * y`, because IEEE 754 gives 0.0 * NaN == NaN.
        y.fill(0.0);
        f32::simd_gemv(a, x, 0.0, y);
        if alpha != 1.0 {
            y.mapv_inplace(|v| v * alpha);
        }
    } else {
        // The kernel takes no alpha, so form A·x in a separate zeroed buffer and
        // then combine with the caller's y in one bounds-check-free pass.
        let mut ax = Array1::<f32>::zeros(m);
        f32::simd_gemv(a, x, 0.0, &mut ax);
        y.zip_mut_with(&ax, |yi, &axi| *yi = alpha * axi + beta * *yi);
    }
    Ok(())
}
/// Computes the matrix-vector update `y = alpha * A * x + beta * y` for `f64`
/// data, using `f64::simd_gemv` from `scirs2_core` for the `A * x` product.
///
/// `a` must be `m × n`, `x` must have length `n`, and `y` must have length `m`.
/// As in BLAS `?GEMV`, when `beta == 0` the incoming contents of `y` are ignored
/// (treated as output-only), so NaN/Inf left in `y` do not affect the result.
///
/// # Errors
/// Returns `LinalgError::ShapeError` if `x.len()` differs from the number of
/// columns of `a`, or `y.len()` differs from the number of rows of `a`.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_gemv_f64(
    alpha: f64,
    a: &ArrayView2<f64>,
    x: &ArrayView1<f64>,
    beta: f64,
    y: &mut Array1<f64>,
) -> LinalgResult<()> {
    let (m, n) = a.dim();
    if x.len() != n {
        return Err(LinalgError::ShapeError(format!(
            "Vector x length ({}) must match matrix columns ({})",
            x.len(),
            n
        )));
    }
    if y.len() != m {
        return Err(LinalgError::ShapeError(format!(
            "Vector y length ({}) must match matrix rows ({})",
            y.len(),
            m
        )));
    }
    if beta == 0.0 {
        // y is output-only here. Clear it explicitly instead of relying on the
        // kernel's `0.0 * y`, because IEEE 754 gives 0.0 * NaN == NaN.
        y.fill(0.0);
        f64::simd_gemv(a, x, 0.0, y);
        if alpha != 1.0 {
            y.mapv_inplace(|v| v * alpha);
        }
    } else {
        // The kernel takes no alpha, so form A·x in a separate zeroed buffer and
        // then combine with the caller's y in one bounds-check-free pass.
        let mut ax = Array1::<f64>::zeros(m);
        f64::simd_gemv(a, x, 0.0, &mut ax);
        y.zip_mut_with(&ax, |yi, &axi| *yi = alpha * axi + beta * *yi);
    }
    Ok(())
}
// Unit tests for the SIMD GEMM/GEMV wrappers. The whole module is already gated
// on `feature = "simd"`, so individual tests do not repeat that cfg.
#[cfg(all(test, feature = "simd"))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_simd_gemm_f32_basic() {
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let mut c = Array2::zeros((2, 2));
        simd_gemm_f32(1.0, &a.view(), &b.view(), 0.0, &mut c, None).expect("Operation failed");
        assert_relative_eq!(c[[0, 0]], 58.0, epsilon = 1e-6);
        assert_relative_eq!(c[[0, 1]], 64.0, epsilon = 1e-6);
        assert_relative_eq!(c[[1, 0]], 139.0, epsilon = 1e-6);
        assert_relative_eq!(c[[1, 1]], 154.0, epsilon = 1e-6);
    }

    #[test]
    #[ignore = "Panics in simd/dot.rs:1167 - Option::unwrap() on None value"]
    fn test_simd_gemm_f64_basic() {
        let a = array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f64, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let mut c = Array2::zeros((2, 2));
        simd_gemm_f64(1.0, &a.view(), &b.view(), 0.0, &mut c, None).expect("Operation failed");
        assert_relative_eq!(c[[0, 0]], 58.0, epsilon = 1e-12);
        assert_relative_eq!(c[[0, 1]], 64.0, epsilon = 1e-12);
        assert_relative_eq!(c[[1, 0]], 139.0, epsilon = 1e-12);
        assert_relative_eq!(c[[1, 1]], 154.0, epsilon = 1e-12);
    }

    #[test]
    fn test_simd_gemm_alpha_beta() {
        // C = 2 * A*B + 3 * C, with A*B = [[19, 22], [43, 50]].
        let a = array![[1.0f32, 2.0], [3.0, 4.0]];
        let b = array![[5.0f32, 6.0], [7.0, 8.0]];
        let mut c = array![[1.0f32, 2.0], [3.0, 4.0]];
        let alpha = 2.0;
        let beta = 3.0;
        simd_gemm_f32(alpha, &a.view(), &b.view(), beta, &mut c, None).expect("Operation failed");
        assert_relative_eq!(c[[0, 0]], 41.0, epsilon = 1e-6);
        assert_relative_eq!(c[[0, 1]], 50.0, epsilon = 1e-6);
        assert_relative_eq!(c[[1, 0]], 95.0, epsilon = 1e-6);
        assert_relative_eq!(c[[1, 1]], 112.0, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_matmul_optimized() {
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let c = simd_matmul_optimized_f32(&a.view(), &b.view()).expect("Operation failed");
        assert_relative_eq!(c[[0, 0]], 58.0, epsilon = 1e-6);
        assert_relative_eq!(c[[0, 1]], 64.0, epsilon = 1e-6);
        assert_relative_eq!(c[[1, 0]], 139.0, epsilon = 1e-6);
        assert_relative_eq!(c[[1, 1]], 154.0, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_gemv() {
        // y = 2 * A x + 3 * y, with A x = [50, 122].
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let x = array![7.0f32, 8.0, 9.0];
        let mut y = array![1.0f32, 2.0];
        let alpha = 2.0;
        let beta = 3.0;
        simd_gemv_f32(alpha, &a.view(), &x.view(), beta, &mut y).expect("Operation failed");
        assert_relative_eq!(y[0], 103.0, epsilon = 1e-6);
        assert_relative_eq!(y[1], 250.0, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_gemv_beta_zero() {
        // y = 2 * A x with beta = 0, with A x = [50, 122].
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let x = array![7.0f32, 8.0, 9.0];
        let mut y = array![0.0f32, 0.0];
        simd_gemv_f32(2.0, &a.view(), &x.view(), 0.0, &mut y).expect("Operation failed");
        assert_relative_eq!(y[0], 100.0, epsilon = 1e-6);
        assert_relative_eq!(y[1], 244.0, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_gemm_largematrix() {
        let m = 100;
        let k = 80;
        let n = 60;
        let a = Array2::from_shape_fn((m, k), |(i, j)| (i + j) as f32 * 0.01);
        let b = Array2::from_shape_fn((k, n), |(i, j)| (i * 2 + j) as f32 * 0.01);
        let mut c = Array2::zeros((m, n));
        // NOTE: simd_gemm_f32 currently ignores its block-size argument; this test
        // checks correctness on a non-trivial size rather than the blocking itself.
        let blocksizes = GemmBlockSizes {
            mc: 32,
            kc: 64,
            nc: 48,
            mr: 8,
            nr: 8,
        };
        simd_gemm_f32(1.0, &a.view(), &b.view(), 0.0, &mut c, Some(blocksizes))
            .expect("Operation failed");
        let c_ref = a.dot(&b);
        for ((i, j), &val) in c.indexed_iter() {
            assert_relative_eq!(val, c_ref[[i, j]], epsilon = 1e-4);
        }
    }

    #[test]
    fn test_gemm_error_handling() {
        // A is 2x2 but B is 3x3, so the inner dimensions disagree.
        let a = array![[1.0f32, 2.0], [3.0, 4.0]];
        let b = array![[5.0f32, 6.0, 7.0], [8.0, 9.0, 10.0], [11.0, 12.0, 13.0]];
        let mut c = Array2::zeros((2, 3));
        let result = simd_gemm_f32(1.0, &a.view(), &b.view(), 0.0, &mut c, None);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LinalgError::ShapeError(_)));
    }

    #[test]
    fn test_gemm_result_shape_error() {
        // A and B are conformable (2x2 * 2x2), but C has the wrong shape.
        let a = array![[1.0f32, 2.0], [3.0, 4.0]];
        let b = array![[5.0f32, 6.0], [7.0, 8.0]];
        let mut c = Array2::zeros((3, 2));
        let result = simd_gemm_f32(1.0, &a.view(), &b.view(), 0.0, &mut c, None);
        assert!(matches!(result, Err(LinalgError::ShapeError(_))));
    }

    #[test]
    fn test_gemm_f64_shape_errors() {
        // Shape validation runs before the f64 kernel is invoked.
        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
        let b_bad = array![[1.0f64, 2.0, 3.0]];
        let mut c = Array2::zeros((2, 3));
        let result = simd_gemm_f64(1.0, &a.view(), &b_bad.view(), 0.0, &mut c, None);
        assert!(matches!(result, Err(LinalgError::ShapeError(_))));

        let b = array![[1.0f64, 2.0], [3.0, 4.0]];
        let mut c_bad = Array2::zeros((2, 3));
        let result = simd_gemm_f64(1.0, &a.view(), &b.view(), 0.0, &mut c_bad, None);
        assert!(matches!(result, Err(LinalgError::ShapeError(_))));
    }

    #[test]
    fn test_gemv_shape_errors() {
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];

        // x is too short for the 3 columns of A.
        let short_x = array![1.0f32, 2.0];
        let mut y = array![0.0f32, 0.0];
        let result = simd_gemv_f32(1.0, &a.view(), &short_x.view(), 0.0, &mut y);
        assert!(matches!(result, Err(LinalgError::ShapeError(_))));

        // y is too long for the 2 rows of A.
        let x = array![1.0f32, 2.0, 3.0];
        let mut long_y = array![0.0f32, 0.0, 0.0];
        let result = simd_gemv_f32(1.0, &a.view(), &x.view(), 0.0, &mut long_y);
        assert!(matches!(result, Err(LinalgError::ShapeError(_))));
    }
}