#[cfg(feature = "simd")]
use crate::error::{LinalgError, LinalgResult};
#[cfg(feature = "simd")]
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ArrayViewMut2};
#[cfg(feature = "simd")]
use scirs2_core::simd_ops::SimdUnifiedOps;
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simdmatrix_add_f32(a: &ArrayView2<f32>, b: &ArrayView2<f32>) -> LinalgResult<Array2<f32>> {
if a.dim() != b.dim() {
return Err(LinalgError::ShapeError(format!(
"Matrix dimensions must match: {:?} vs {:?}",
a.dim(),
b.dim()
)));
}
let (rows, cols) = a.dim();
let mut result = Array2::zeros((rows, cols));
for i in 0..rows {
let a_row = a.row(i);
let b_row = b.row(i);
let sum_row = f32::simd_add(&a_row, &b_row);
for (j, &val) in sum_row.iter().enumerate() {
result[[i, j]] = val;
}
}
Ok(result)
}
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simdmatrix_add_f64(a: &ArrayView2<f64>, b: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
if a.dim() != b.dim() {
return Err(LinalgError::ShapeError(format!(
"Matrix dimensions must match: {:?} vs {:?}",
a.dim(),
b.dim()
)));
}
let (rows, cols) = a.dim();
let mut result = Array2::zeros((rows, cols));
for i in 0..rows {
let a_row = a.row(i);
let b_row = b.row(i);
let sum_row = f64::simd_add(&a_row, &b_row);
for (j, &val) in sum_row.iter().enumerate() {
result[[i, j]] = val;
}
}
Ok(result)
}
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simdmatrix_add_inplace_f32(
a: &mut ArrayViewMut2<f32>,
b: &ArrayView2<f32>,
) -> LinalgResult<()> {
if a.dim() != b.dim() {
return Err(LinalgError::ShapeError(format!(
"Matrix dimensions must match: {:?} vs {:?}",
a.dim(),
b.dim()
)));
}
for (mut a_row, b_row) in a.rows_mut().into_iter().zip(b.rows()) {
let a_copy = Array1::from(a_row.to_vec());
let sum_row = f32::simd_add(&a_copy.view(), &b_row);
for (a_elem, &sum_elem) in a_row.iter_mut().zip(sum_row.iter()) {
*a_elem = sum_elem;
}
}
Ok(())
}
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simdmatrix_mul_elementwise_f32(
a: &ArrayView2<f32>,
b: &ArrayView2<f32>,
) -> LinalgResult<Array2<f32>> {
if a.dim() != b.dim() {
return Err(LinalgError::ShapeError(format!(
"Matrix dimensions must match: {:?} vs {:?}",
a.dim(),
b.dim()
)));
}
let (rows, cols) = a.dim();
let mut result = Array2::zeros((rows, cols));
for i in 0..rows {
let a_row = a.row(i);
let b_row = b.row(i);
let mul_row = f32::simd_mul(&a_row, &b_row);
for (j, &val) in mul_row.iter().enumerate() {
result[[i, j]] = val;
}
}
Ok(result)
}
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simdmatrix_scale_f32(a: &ArrayView2<f32>, scalar: f32) -> LinalgResult<Array2<f32>> {
let (rows, cols) = a.dim();
let mut result = Array2::zeros((rows, cols));
for i in 0..rows {
let a_row = a.row(i);
let scaled_row = f32::simd_scalar_mul(&a_row, scalar);
for (j, &val) in scaled_row.iter().enumerate() {
result[[i, j]] = val;
}
}
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::array;
#[test]
#[cfg(feature = "simd")]
fn test_simdmatrix_add_f32() {
let a = array![[1.0f32, 2.0], [3.0, 4.0]];
let b = array![[5.0f32, 6.0], [7.0, 8.0]];
let result = simdmatrix_add_f32(&a.view(), &b.view()).expect("Operation failed");
assert_relative_eq!(result[[0, 0]], 6.0);
assert_relative_eq!(result[[0, 1]], 8.0);
assert_relative_eq!(result[[1, 0]], 10.0);
assert_relative_eq!(result[[1, 1]], 12.0);
}
#[test]
#[cfg(feature = "simd")]
fn test_simdmatrix_scale_f32() {
let a = array![[1.0f32, 2.0], [3.0, 4.0]];
let scalar = 2.5f32;
let result = simdmatrix_scale_f32(&a.view(), scalar).expect("Operation failed");
assert_relative_eq!(result[[0, 0]], 2.5);
assert_relative_eq!(result[[0, 1]], 5.0);
assert_relative_eq!(result[[1, 0]], 7.5);
assert_relative_eq!(result[[1, 1]], 10.0);
}
#[test]
#[cfg(feature = "simd")]
fn test_simdmatrix_mul_elementwise_f32() {
let a = array![[1.0f32, 2.0], [3.0, 4.0]];
let b = array![[2.0f32, 3.0], [4.0, 5.0]];
let result =
simdmatrix_mul_elementwise_f32(&a.view(), &b.view()).expect("Operation failed");
assert_relative_eq!(result[[0, 0]], 2.0);
assert_relative_eq!(result[[0, 1]], 6.0);
assert_relative_eq!(result[[1, 0]], 12.0);
assert_relative_eq!(result[[1, 1]], 20.0);
}
}