use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Float;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
/// SIMD-accelerated element-wise arithmetic and reductions on [`Array`].
///
/// This module implements the trait for `Array<f32>` and `Array<f64>`. Those
/// implementations copy the operands into flat `ndarray` vectors, run a
/// `scirs2_core` SIMD kernel, and reshape the result to `self`'s shape.
pub trait SimdOps<T> {
    // ---- element-wise operations between arrays of identical shape ----

    /// Element-wise `self + other`.
    ///
    /// The implementations in this module return `NumRs2Error::ShapeMismatch`
    /// when the shapes differ; the same holds for the other array-array
    /// operations below.
    fn simd_add(&self, other: &Array<T>) -> Result<Array<T>>;
    /// Element-wise `self - other`.
    fn simd_sub(&self, other: &Array<T>) -> Result<Array<T>>;
    /// Element-wise `self * other`.
    fn simd_mul(&self, other: &Array<T>) -> Result<Array<T>>;
    /// Element-wise `self / other`.
    fn simd_div(&self, other: &Array<T>) -> Result<Array<T>>;
    /// Element-wise multiply-add `self * mul + add`; all three shapes must match.
    fn simd_fma(&self, mul: &Array<T>, add: &Array<T>) -> Result<Array<T>>;
    /// Dot product of the two arrays' elements taken in flattened order; the
    /// shapes must match.
    fn simd_dot(&self, other: &Array<T>) -> Result<T>;

    // ---- reductions over every element ----

    /// Sum of all elements.
    fn simd_sum(&self) -> T;
    /// Arithmetic mean of all elements.
    fn simd_mean(&self) -> T;

    // ---- a scalar applied to every element ----

    /// Adds `scalar` to every element.
    fn simd_add_scalar(&self, scalar: T) -> Array<T>;
    /// Subtracts `scalar` from every element.
    fn simd_sub_scalar(&self, scalar: T) -> Array<T>;
    /// Multiplies every element by `scalar`.
    fn simd_mul_scalar(&self, scalar: T) -> Array<T>;
    /// Divides every element by `scalar`.
    ///
    /// The implementations in this module return
    /// `NumRs2Error::InvalidOperation` when `scalar` is zero.
    fn simd_div_scalar(&self, scalar: T) -> Result<Array<T>>;
}
/// Copies the elements of `arr` into a one-dimensional `ndarray` vector, in the
/// order returned by `Array::to_vec`. The original shape is not carried over;
/// callers reshape the result themselves.
///
/// This never fails as written; the `Result` return type is kept so call sites
/// can use `?` uniformly.
fn to_ndarray_1d<T: Clone>(arr: &Array<T>) -> Result<Array1<T>> {
    let flat: Vec<T> = arr.to_vec();
    Ok(Array1::from(flat))
}
impl SimdOps<f32> for Array<f32> {
fn simd_add(&self, other: &Array<f32>) -> Result<Array<f32>> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
let result = f32::simd_add(&a.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_sub(&self, other: &Array<f32>) -> Result<Array<f32>> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
let result = f32::simd_sub(&a.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_mul(&self, other: &Array<f32>) -> Result<Array<f32>> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
let result = f32::simd_mul(&a.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_div(&self, other: &Array<f32>) -> Result<Array<f32>> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
let result = f32::simd_div(&a.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_dot(&self, other: &Array<f32>) -> Result<f32> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
Ok(f32::simd_dot(&a.view(), &b.view()))
}
fn simd_sum(&self) -> f32 {
let a = to_ndarray_1d(self).expect("Array conversion to ndarray should succeed");
f32::simd_sum(&a.view())
}
fn simd_mean(&self) -> f32 {
let a = to_ndarray_1d(self).expect("Array conversion to ndarray should succeed");
f32::simd_mean(&a.view())
}
fn simd_fma(&self, mul: &Array<f32>, add: &Array<f32>) -> Result<Array<f32>> {
if self.shape() != mul.shape() || self.shape() != add.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: if self.shape() != mul.shape() {
mul.shape()
} else {
add.shape()
},
});
}
let a = to_ndarray_1d(self)?;
let m = to_ndarray_1d(mul)?;
let b = to_ndarray_1d(add)?;
let result = f32::simd_fma(&a.view(), &m.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_add_scalar(&self, scalar: f32) -> Array<f32> {
let a = to_ndarray_1d(self).expect("Array conversion to ndarray should succeed");
let result = f32::simd_add(&a.view(), &ArrayView1::from(&vec![scalar; a.len()]));
Array::from_vec(result.to_vec()).reshape(&self.shape())
}
fn simd_mul_scalar(&self, scalar: f32) -> Array<f32> {
let a = to_ndarray_1d(self).expect("Array conversion to ndarray should succeed");
let result = f32::simd_scalar_mul(&a.view(), scalar);
Array::from_vec(result.to_vec()).reshape(&self.shape())
}
fn simd_sub_scalar(&self, scalar: f32) -> Array<f32> {
self.simd_add_scalar(-scalar)
}
fn simd_div_scalar(&self, scalar: f32) -> Result<Array<f32>> {
if scalar == 0.0 {
return Err(NumRs2Error::InvalidOperation(
"Division by zero".to_string(),
));
}
Ok(self.simd_mul_scalar(1.0 / scalar))
}
}
impl SimdOps<f64> for Array<f64> {
fn simd_add(&self, other: &Array<f64>) -> Result<Array<f64>> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
let result = f64::simd_add(&a.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_sub(&self, other: &Array<f64>) -> Result<Array<f64>> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
let result = f64::simd_sub(&a.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_mul(&self, other: &Array<f64>) -> Result<Array<f64>> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
let result = f64::simd_mul(&a.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_div(&self, other: &Array<f64>) -> Result<Array<f64>> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
let result = f64::simd_div(&a.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_dot(&self, other: &Array<f64>) -> Result<f64> {
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let a = to_ndarray_1d(self)?;
let b = to_ndarray_1d(other)?;
Ok(f64::simd_dot(&a.view(), &b.view()))
}
fn simd_sum(&self) -> f64 {
let a = to_ndarray_1d(self).expect("Array conversion to ndarray should succeed");
f64::simd_sum(&a.view())
}
fn simd_mean(&self) -> f64 {
let a = to_ndarray_1d(self).expect("Array conversion to ndarray should succeed");
f64::simd_mean(&a.view())
}
fn simd_fma(&self, mul: &Array<f64>, add: &Array<f64>) -> Result<Array<f64>> {
if self.shape() != mul.shape() || self.shape() != add.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: if self.shape() != mul.shape() {
mul.shape()
} else {
add.shape()
},
});
}
let a = to_ndarray_1d(self)?;
let m = to_ndarray_1d(mul)?;
let b = to_ndarray_1d(add)?;
let result = f64::simd_fma(&a.view(), &m.view(), &b.view());
Ok(Array::from_vec(result.to_vec()).reshape(&self.shape()))
}
fn simd_add_scalar(&self, scalar: f64) -> Array<f64> {
let a = to_ndarray_1d(self).expect("Array conversion to ndarray should succeed");
let result = f64::simd_add(&a.view(), &ArrayView1::from(&vec![scalar; a.len()]));
Array::from_vec(result.to_vec()).reshape(&self.shape())
}
fn simd_mul_scalar(&self, scalar: f64) -> Array<f64> {
let a = to_ndarray_1d(self).expect("Array conversion to ndarray should succeed");
let result = f64::simd_scalar_mul(&a.view(), scalar);
Array::from_vec(result.to_vec()).reshape(&self.shape())
}
fn simd_sub_scalar(&self, scalar: f64) -> Array<f64> {
self.simd_add_scalar(-scalar)
}
fn simd_div_scalar(&self, scalar: f64) -> Result<Array<f64>> {
if scalar == 0.0 {
return Err(NumRs2Error::InvalidOperation(
"Division by zero".to_string(),
));
}
Ok(self.simd_mul_scalar(1.0 / scalar))
}
}
/// Element-wise sum of two arrays of identical shape.
///
/// Free-function form of [`SimdOps::simd_add`].
///
/// # Errors
///
/// Propagates the error from [`SimdOps::simd_add`]. In the `f32`/`f64`
/// implementations of this module, that is a shape mismatch.
pub fn simd_add<T: Float + 'static>(a: &Array<T>, b: &Array<T>) -> Result<Array<T>>
where
    Array<T>: SimdOps<T>,
{
    <Array<T> as SimdOps<T>>::simd_add(a, b)
}
/// Element-wise product of two arrays of identical shape.
///
/// Free-function form of [`SimdOps::simd_mul`].
///
/// # Errors
///
/// Propagates the error from [`SimdOps::simd_mul`]. In the `f32`/`f64`
/// implementations of this module, that is a shape mismatch.
pub fn simd_mul<T: Float + 'static>(a: &Array<T>, b: &Array<T>) -> Result<Array<T>>
where
    Array<T>: SimdOps<T>,
{
    <Array<T> as SimdOps<T>>::simd_mul(a, b)
}
/// Element-wise quotient `a / b` of two arrays of identical shape.
///
/// Free-function form of [`SimdOps::simd_div`].
///
/// # Errors
///
/// Propagates the error from [`SimdOps::simd_div`]. In the `f32`/`f64`
/// implementations of this module, that is a shape mismatch.
pub fn simd_div<T: Float + 'static>(a: &Array<T>, b: &Array<T>) -> Result<Array<T>>
where
    Array<T>: SimdOps<T>,
{
    <Array<T> as SimdOps<T>>::simd_div(a, b)
}
/// Sum of all elements of `a`.
///
/// Free-function form of [`SimdOps::simd_sum`].
pub fn simd_sum<T: Float + 'static>(a: &Array<T>) -> T
where
    Array<T>: SimdOps<T>,
{
    <Array<T> as SimdOps<T>>::simd_sum(a)
}
/// Arithmetic mean of all elements of `a`.
///
/// Free-function form of [`SimdOps::simd_mean`].
pub fn simd_mean<T: Float + 'static>(a: &Array<T>) -> T
where
    Array<T>: SimdOps<T>,
{
    <Array<T> as SimdOps<T>>::simd_mean(a)
}
/// Product of all elements of `a`, multiplied left to right starting from one.
///
/// Despite the name, this is a plain scalar loop; no SIMD kernel is involved.
/// An empty array yields `T::one()`.
pub fn simd_prod<T: Float + 'static>(a: &Array<T>) -> T {
    let mut product = T::one();
    for value in a.to_vec() {
        product = product * value;
    }
    product
}
pub fn simd_exp<T: Float + 'static>(a: &Array<T>) -> Array<T> {
let shape = a.shape();
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let data = a.to_vec();
let data_f64: Vec<f64> = data
.iter()
.map(|&x| x.to_f64().expect("f64 conversion should succeed"))
.collect();
let nd_arr = Array1::from_vec(data_f64);
let result = f64::simd_exp(&nd_arr.view());
let result_vec: Vec<T> = result
.iter()
.map(|&x| T::from(x).expect("conversion from f64 should succeed"))
.collect();
return Array::from_vec(result_vec).reshape(&shape);
}
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let data = a.to_vec();
let data_f32: Vec<f32> = data
.iter()
.map(|&x| x.to_f32().expect("f32 conversion should succeed"))
.collect();
let nd_arr = Array1::from_vec(data_f32);
let result = f32::simd_exp(&nd_arr.view());
let result_vec: Vec<T> = result
.iter()
.map(|&x| T::from(x).expect("conversion from f32 should succeed"))
.collect();
return Array::from_vec(result_vec).reshape(&shape);
}
let data = a.to_vec();
let result: Vec<T> = data.iter().map(|&x| x.exp()).collect();
Array::from_vec(result).reshape(&shape)
}
pub fn simd_log<T: Float + 'static>(a: &Array<T>) -> Array<T> {
let shape = a.shape();
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let data = a.to_vec();
let data_f64: Vec<f64> = data
.iter()
.map(|&x| x.to_f64().expect("f64 conversion should succeed"))
.collect();
let nd_arr = Array1::from_vec(data_f64);
let result = f64::simd_ln(&nd_arr.view());
let result_vec: Vec<T> = result
.iter()
.map(|&x| T::from(x).expect("conversion from f64 should succeed"))
.collect();
return Array::from_vec(result_vec).reshape(&shape);
}
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let data = a.to_vec();
let data_f32: Vec<f32> = data
.iter()
.map(|&x| x.to_f32().expect("f32 conversion should succeed"))
.collect();
let nd_arr = Array1::from_vec(data_f32);
let result = f32::simd_ln(&nd_arr.view());
let result_vec: Vec<T> = result
.iter()
.map(|&x| T::from(x).expect("conversion from f32 should succeed"))
.collect();
return Array::from_vec(result_vec).reshape(&shape);
}
let data = a.to_vec();
let result: Vec<T> = data.iter().map(|&x| x.ln()).collect();
Array::from_vec(result).reshape(&shape)
}
pub fn simd_sqrt<T: Float + 'static>(a: &Array<T>) -> Array<T> {
let shape = a.shape();
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let data = a.to_vec();
let data_f64: Vec<f64> = data
.iter()
.map(|&x| x.to_f64().expect("f64 conversion should succeed"))
.collect();
let nd_arr = Array1::from_vec(data_f64);
let result = f64::simd_sqrt(&nd_arr.view());
let result_vec: Vec<T> = result
.iter()
.map(|&x| T::from(x).expect("conversion from f64 should succeed"))
.collect();
return Array::from_vec(result_vec).reshape(&shape);
}
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let data = a.to_vec();
let data_f32: Vec<f32> = data
.iter()
.map(|&x| x.to_f32().expect("f32 conversion should succeed"))
.collect();
let nd_arr = Array1::from_vec(data_f32);
let result = f32::simd_sqrt(&nd_arr.view());
let result_vec: Vec<T> = result
.iter()
.map(|&x| T::from(x).expect("conversion from f32 should succeed"))
.collect();
return Array::from_vec(result_vec).reshape(&shape);
}
let data = a.to_vec();
let result: Vec<T> = data.iter().map(|&x| x.sqrt()).collect();
Array::from_vec(result).reshape(&shape)
}
pub fn get_simd_implementation_name() -> String {
let caps = PlatformCapabilities::detect();
format!(
"NumRS2 SIMD via scirs2-core: AVX512={}, AVX2={}, NEON={}, SIMD={}",
caps.avx512_available, caps.avx2_available, caps.neon_available, caps.simd_available
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_add() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let b = Array::from_vec(vec![5.0f64, 6.0, 7.0, 8.0]);
        let c = simd_add(&a, &b).expect("simd_add should succeed");
        assert_eq!(c.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_simd_mul() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let b = Array::from_vec(vec![5.0f64, 6.0, 7.0, 8.0]);
        let c = simd_mul(&a, &b).expect("simd_mul should succeed");
        assert_eq!(c.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
    }

    #[test]
    fn test_simd_div() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let b = Array::from_vec(vec![5.0f64, 6.0, 7.0, 8.0]);
        let c = simd_div(&a, &b).expect("simd_div should succeed");
        assert_relative_eq!(c.to_vec()[0], 0.2, epsilon = 1e-10);
        assert_relative_eq!(c.to_vec()[1], 2.0 / 6.0, epsilon = 1e-10);
        assert_relative_eq!(c.to_vec()[2], 3.0 / 7.0, epsilon = 1e-10);
        assert_relative_eq!(c.to_vec()[3], 4.0 / 8.0, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_sqrt() {
        let a = Array::from_vec(vec![1.0f64, 4.0, 9.0, 16.0]);
        let b = simd_sqrt(&a);
        assert_relative_eq!(b.to_vec()[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(b.to_vec()[1], 2.0, epsilon = 1e-10);
        assert_relative_eq!(b.to_vec()[2], 3.0, epsilon = 1e-10);
        assert_relative_eq!(b.to_vec()[3], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_sum() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let sum = simd_sum(&a);
        assert_relative_eq!(sum, 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_prod() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let prod = simd_prod(&a);
        assert_relative_eq!(prod, 24.0, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_mean() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let mean = a.simd_mean();
        assert_relative_eq!(mean, 2.5, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_add_scalar() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let result = a.simd_add_scalar(10.0);
        assert_eq!(result.to_vec(), vec![11.0, 12.0, 13.0, 14.0]);
    }

    #[test]
    fn test_simd_mul_scalar() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let result = a.simd_mul_scalar(2.0);
        assert_eq!(result.to_vec(), vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_simd_sub_scalar() {
        let a = Array::from_vec(vec![10.0f64, 20.0, 30.0, 40.0]);
        let result = a.simd_sub_scalar(5.0);
        assert_eq!(result.to_vec(), vec![5.0, 15.0, 25.0, 35.0]);
    }

    #[test]
    fn test_simd_div_scalar() {
        let a = Array::from_vec(vec![10.0f64, 20.0, 30.0, 40.0]);
        let result = a
            .simd_div_scalar(2.0)
            .expect("simd_div_scalar should succeed");
        assert_eq!(result.to_vec(), vec![5.0, 10.0, 15.0, 20.0]);
    }

    #[test]
    fn test_simd_div_scalar_zero() {
        let a = Array::from_vec(vec![10.0f64, 20.0, 30.0, 40.0]);
        let result = a.simd_div_scalar(0.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_simd_fma() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0]);
        let b = Array::from_vec(vec![2.0f64, 3.0, 4.0]);
        let c = Array::from_vec(vec![1.0f64, 1.0, 1.0]);
        let result = a.simd_fma(&b, &c).expect("simd_fma should succeed");
        assert_eq!(result.to_vec(), vec![3.0, 7.0, 13.0]);
    }

    #[test]
    fn test_simd_shape_mismatch_is_error() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0]);
        let b = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        assert!(simd_add(&a, &b).is_err());
        assert!(a.simd_sub(&b).is_err());
        assert!(simd_mul(&a, &b).is_err());
        assert!(simd_div(&a, &b).is_err());
        assert!(a.simd_dot(&b).is_err());
    }

    #[test]
    fn test_simd_fma_shape_mismatch_is_error() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0]);
        let good = Array::from_vec(vec![1.0f64, 1.0, 1.0]);
        let bad = Array::from_vec(vec![1.0f64, 1.0]);
        assert!(a.simd_fma(&bad, &good).is_err());
        assert!(a.simd_fma(&good, &bad).is_err());
    }

    #[test]
    fn test_simd_dot() {
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0]);
        let b = Array::from_vec(vec![4.0f64, 5.0, 6.0]);
        let dot = a.simd_dot(&b).expect("simd_dot should succeed");
        assert_relative_eq!(dot, 32.0, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_f32_ops() {
        let a = Array::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let b = Array::from_vec(vec![4.0f32, 3.0, 2.0, 1.0]);
        let sum = simd_add(&a, &b).expect("simd_add should succeed");
        assert_eq!(sum.to_vec(), vec![5.0f32; 4]);
        let halved = a
            .simd_div_scalar(2.0)
            .expect("simd_div_scalar should succeed");
        assert_eq!(halved.to_vec(), vec![0.5f32, 1.0, 1.5, 2.0]);
        assert_relative_eq!(a.simd_sum(), 10.0f32, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_preserves_multidimensional_shape() {
        let shape = vec![2, 2];
        let a = Array::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]).reshape(&shape);
        let b = Array::from_vec(vec![1.0f64, 1.0, 1.0, 1.0]).reshape(&shape);
        let c = simd_add(&a, &b).expect("simd_add should succeed");
        assert_eq!(c.shape(), shape);
        assert_eq!(c.to_vec(), vec![2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_cpu_feature_detection() {
        let info = get_simd_implementation_name();
        println!("Detected SIMD capabilities: {}", info);
        assert!(info.contains("NumRS2 SIMD via scirs2-core"));
    }

    #[test]
    fn test_simd_large_array() {
        let data: Vec<f64> = (0..1024).map(|i| i as f64).collect();
        let a = Array::from_vec(data.clone());
        let b = Array::from_vec(data);
        let result = simd_add(&a, &b).expect("simd_add should succeed");
        assert_relative_eq!(
            result.get(&[0]).expect("get element should succeed"),
            0.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            result.get(&[512]).expect("get element should succeed"),
            1024.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            result.get(&[1023]).expect("get element should succeed"),
            2046.0,
            epsilon = 1e-10
        );
    }
}