use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::Float;
use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
use super::NnResult;
use crate::error::NumRs2Error;
use crate::simd::SimdOps;
/// Returns the platform capability descriptor produced by
/// [`PlatformCapabilities::detect`].
///
/// This is a thin wrapper; all detection logic lives in `scirs2_core`.
pub fn detect_simd_capabilities() -> PlatformCapabilities {
PlatformCapabilities::detect()
}
/// Element-wise ReLU for a 1-D `f32` view.
///
/// Builds a zero-filled array of the same length as `x` and returns the
/// element-wise maximum of `x` and that array via `f32::simd_max`.
pub fn simd_relu_f32(x: &ArrayView1<f32>) -> Array1<f32> {
    let floor = Array1::<f32>::zeros(x.len());
    let floor_view = floor.view();
    f32::simd_max(x, &floor_view)
}
/// Row-wise ReLU for a 2-D `f32` view.
///
/// Allocates an output with the same shape as `x`, then applies
/// [`simd_relu_f32`] to every row (axis 0) and stores the result in the
/// matching output row.
pub fn simd_relu_2d_f32(x: &ArrayView2<f32>) -> Array2<f32> {
    let mut activated = Array2::<f32>::zeros(x.raw_dim());
    let rows_out = activated.axis_iter_mut(Axis(0));
    let rows_in = x.axis_iter(Axis(0));
    for (mut dst, src) in rows_out.zip(rows_in) {
        dst.assign(&simd_relu_f32(&src));
    }
    activated
}
/// Element-wise leaky ReLU for a 1-D `f32` view.
///
/// Returns `v` for positive elements and `alpha * v` otherwise.
///
/// # Parameters
/// * `x` - input values.
/// * `alpha` - slope applied to non-positive inputs.
///
/// # Implementation notes
/// The identity `leaky_relu(v) = max(v, alpha * v)` only holds when
/// `alpha <= 1`. For larger slopes the maximum would pick `alpha * v` for
/// positive inputs and `v` for negative ones, so those cases (and a NaN
/// `alpha`) are handled by an explicit per-element branch instead.
pub fn simd_leaky_relu_f32(x: &ArrayView1<f32>, alpha: f32) -> Array1<f32> {
    if alpha <= 1.0 {
        // Fast path: with alpha <= 1, alpha * v <= v for v > 0 and
        // alpha * v >= v for v < 0, so the element-wise max is exact.
        let scaled = x.mapv(|v| v * alpha);
        f32::simd_max(x, &scaled.view())
    } else {
        // alpha > 1 (or NaN): the max identity is invalid, branch per element.
        x.mapv(|v| if v > 0.0 { v } else { alpha * v })
    }
}
/// Element-wise logistic sigmoid for a 1-D `f32` view.
///
/// Evaluates `1 / (1 + exp(-x))`: the input is negated, exponentiated with
/// [`simd_exp_f32`], and the addition and division are performed through
/// `f32::simd_add` and `f32::simd_div`.
pub fn simd_sigmoid_f32(x: &ArrayView1<f32>) -> Array1<f32> {
    let ones = Array1::<f32>::ones(x.len());
    let negated = x.map(|&v| -v);
    let decay = simd_exp_f32(&negated.view());
    let denom = f32::simd_add(&ones.view(), &decay.view());
    f32::simd_div(&ones.view(), &denom.view())
}
pub fn simd_tanh_f32(x: &ArrayView1<f32>) -> Array1<f32> {
let two_x = x.mapv(|v| v * 2.0);
let sigmoid_2x = simd_sigmoid_f32(&two_x.view());
let two_sigmoid = sigmoid_2x.mapv(|v| v * 2.0);
let one = Array1::from_elem(x.len(), 1.0);
f32::simd_sub(&two_sigmoid.view(), &one.view())
}
pub fn simd_exp_f32(x: &ArrayView1<f32>) -> Array1<f32> {
x.mapv(|v| v.exp())
}
/// Element-wise GELU (tanh approximation) for a 1-D `f32` view.
///
/// Evaluates `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`,
/// using `f32::simd_mul` / `f32::simd_add` for the array-by-array steps and
/// [`simd_tanh_f32`] for the hyperbolic tangent.
pub fn simd_gelu_f32(x: &ArrayView1<f32>) -> Array1<f32> {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const COEFF: f32 = 0.044715;

    // x^3, built from two element-wise products.
    let squared = f32::simd_mul(x, x);
    let cubed = f32::simd_mul(&squared.view(), x);

    // Argument of tanh: sqrt(2/pi) * (x + COEFF * x^3).
    let cubic_term = cubed.map(|&v| v * COEFF);
    let polynomial = f32::simd_add(x, &cubic_term.view());
    let tanh_arg = polynomial.map(|&v| v * SQRT_2_OVER_PI);
    let tanh_out = simd_tanh_f32(&tanh_arg.view());

    // 0.5 * x * (1 + tanh(...)).
    let ones = Array1::<f32>::ones(x.len());
    let gate = f32::simd_add(&ones.view(), &tanh_out.view());
    let gated = f32::simd_mul(x, &gate.view());
    gated.map(|&v| v * 0.5)
}
/// Element-wise Swish (`x * sigmoid(x)`) for a 1-D `f32` view.
///
/// The gate comes from [`simd_sigmoid_f32`]; the product is formed with
/// `f32::simd_mul`.
pub fn simd_swish_f32(x: &ArrayView1<f32>) -> Array1<f32> {
    let gate = simd_sigmoid_f32(x);
    let gate_view = gate.view();
    f32::simd_mul(x, &gate_view)
}
/// Element-wise Mish (`x * tanh(softplus(x))`) for a 1-D `f32` view.
///
/// Softplus is computed as `ln(1 + exp(x))` using [`simd_exp_f32`] and
/// `f32::simd_add`; the result is passed through [`simd_tanh_f32`] and
/// multiplied with the input via `f32::simd_mul`.
pub fn simd_mish_f32(x: &ArrayView1<f32>) -> Array1<f32> {
    let ones = Array1::<f32>::ones(x.len());
    let exponentials = simd_exp_f32(x);

    // softplus(x) = ln(1 + exp(x))
    let shifted = f32::simd_add(&ones.view(), &exponentials.view());
    let softplus = shifted.map(|&v| v.ln());

    let squashed = simd_tanh_f32(&softplus.view());
    f32::simd_mul(x, &squashed.view())
}
pub fn simd_elu_f32(x: &ArrayView1<f32>, alpha: f32) -> Array1<f32> {
let mut result = Array1::zeros(x.len());
let zero = 0.0f32;
for (i, &val) in x.iter().enumerate() {
result[i] = if val > zero {
val
} else {
alpha * (val.exp() - 1.0)
};
}
result
}
/// Element-wise SELU for a 1-D `f32` view.
///
/// Computes `LAMBDA * elu(x, ALPHA)` using [`simd_elu_f32`] with the fixed
/// constants declared below.
pub fn simd_selu_f32(x: &ArrayView1<f32>) -> Array1<f32> {
    const LAMBDA: f32 = 1.0507009873554804934193349852946;
    const ALPHA: f32 = 1.6732632423543772848170429916717;

    simd_elu_f32(x, ALPHA)
        .iter()
        .map(|&v| v * LAMBDA)
        .collect()
}
/// Element-wise ReLU for a 1-D `f64` view.
///
/// Builds a zero-filled array of the same length as `x` and returns the
/// element-wise maximum of `x` and that array via `f64::simd_max`.
pub fn simd_relu_f64(x: &ArrayView1<f64>) -> Array1<f64> {
    let floor = Array1::<f64>::zeros(x.len());
    let floor_view = floor.view();
    f64::simd_max(x, &floor_view)
}
/// Row-wise ReLU for a 2-D `f64` view.
///
/// Allocates an output with the same shape as `x`, then applies
/// [`simd_relu_f64`] to every row (axis 0) and stores the result in the
/// matching output row.
pub fn simd_relu_2d_f64(x: &ArrayView2<f64>) -> Array2<f64> {
    let mut activated = Array2::<f64>::zeros(x.raw_dim());
    let rows_out = activated.axis_iter_mut(Axis(0));
    let rows_in = x.axis_iter(Axis(0));
    for (mut dst, src) in rows_out.zip(rows_in) {
        dst.assign(&simd_relu_f64(&src));
    }
    activated
}
/// Element-wise logistic sigmoid for a 1-D `f64` view.
///
/// Evaluates `1 / (1 + exp(-x))`: `exp(-x)` is computed per element, and
/// the addition and division go through `f64::simd_add` and
/// `f64::simd_div`.
pub fn simd_sigmoid_f64(x: &ArrayView1<f64>) -> Array1<f64> {
    let ones = Array1::<f64>::ones(x.len());
    let decay = x.map(|&v| (-v).exp());
    let denom = f64::simd_add(&ones.view(), &decay.view());
    f64::simd_div(&ones.view(), &denom.view())
}
pub fn simd_tanh_f64(x: &ArrayView1<f64>) -> Array1<f64> {
let two_x = x.mapv(|v| v * 2.0);
let sigmoid_2x = simd_sigmoid_f64(&two_x.view());
let two_sigmoid = sigmoid_2x.mapv(|v| v * 2.0);
let one = Array1::from_elem(x.len(), 1.0);
f64::simd_sub(&two_sigmoid.view(), &one.view())
}
/// Element-wise GELU (tanh approximation) for a 1-D `f64` view.
///
/// Evaluates `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`,
/// using `f64::simd_mul` / `f64::simd_add` for the array-by-array steps and
/// [`simd_tanh_f64`] for the hyperbolic tangent.
pub fn simd_gelu_f64(x: &ArrayView1<f64>) -> Array1<f64> {
    const SQRT_2_OVER_PI: f64 = 0.7978845608028654;
    const COEFF: f64 = 0.044715;

    // x^3, built from two element-wise products.
    let squared = f64::simd_mul(x, x);
    let cubed = f64::simd_mul(&squared.view(), x);

    // Argument of tanh: sqrt(2/pi) * (x + COEFF * x^3).
    let cubic_term = cubed.map(|&v| v * COEFF);
    let polynomial = f64::simd_add(x, &cubic_term.view());
    let tanh_arg = polynomial.map(|&v| v * SQRT_2_OVER_PI);
    let tanh_out = simd_tanh_f64(&tanh_arg.view());

    // 0.5 * x * (1 + tanh(...)).
    let ones = Array1::<f64>::ones(x.len());
    let gate = f64::simd_add(&ones.view(), &tanh_out.view());
    let gated = f64::simd_mul(x, &gate.view());
    gated.map(|&v| v * 0.5)
}
/// Element-wise Swish (`x * sigmoid(x)`) for a 1-D `f64` view.
///
/// The gate comes from [`simd_sigmoid_f64`]; the product is formed with
/// `f64::simd_mul`.
pub fn simd_swish_f64(x: &ArrayView1<f64>) -> Array1<f64> {
    let gate = simd_sigmoid_f64(x);
    let gate_view = gate.view();
    f64::simd_mul(x, &gate_view)
}
/// Matrix product `a * b` for 2-D `f32` views.
///
/// The output is zero-initialised with shape `(a.nrows(), b.ncols())` and
/// filled by `f32::simd_gemm` called with scale factors `1.0` and `0.0`.
///
/// # Errors
/// Returns `NumRs2Error::DimensionMismatch` when `a.ncols() != b.nrows()`.
pub fn simd_matmul_f32(a: &ArrayView2<f32>, b: &ArrayView2<f32>) -> NnResult<Array2<f32>> {
    let shapes_agree = a.ncols() == b.nrows();
    if shapes_agree {
        let mut product = Array2::<f32>::zeros((a.nrows(), b.ncols()));
        f32::simd_gemm(1.0, a, b, 0.0, &mut product);
        Ok(product)
    } else {
        let message = format!(
            "Matrix dimensions incompatible: ({}, {}) x ({}, {})",
            a.nrows(),
            a.ncols(),
            b.nrows(),
            b.ncols()
        );
        Err(NumRs2Error::DimensionMismatch(message))
    }
}
/// Matrix product `a * b` for 2-D `f64` views.
///
/// The output is zero-initialised with shape `(a.nrows(), b.ncols())` and
/// filled by `f64::simd_gemm` called with scale factors `1.0` and `0.0`.
///
/// # Errors
/// Returns `NumRs2Error::DimensionMismatch` when `a.ncols() != b.nrows()`.
pub fn simd_matmul_f64(a: &ArrayView2<f64>, b: &ArrayView2<f64>) -> NnResult<Array2<f64>> {
    let shapes_agree = a.ncols() == b.nrows();
    if shapes_agree {
        let mut product = Array2::<f64>::zeros((a.nrows(), b.ncols()));
        f64::simd_gemm(1.0, a, b, 0.0, &mut product);
        Ok(product)
    } else {
        let message = format!(
            "Matrix dimensions incompatible: ({}, {}) x ({}, {})",
            a.nrows(),
            a.ncols(),
            b.nrows(),
            b.ncols()
        );
        Err(NumRs2Error::DimensionMismatch(message))
    }
}
/// Element-wise sum of two equally long 1-D `f32` views.
///
/// Delegates to `f32::simd_add` once the lengths have been checked.
///
/// # Errors
/// Returns `NumRs2Error::DimensionMismatch` when `a.len() != b.len()`.
pub fn simd_add_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> NnResult<Array1<f32>> {
    if a.len() == b.len() {
        return Ok(f32::simd_add(a, b));
    }
    let message = format!("Array lengths must match: {} != {}", a.len(), b.len());
    Err(NumRs2Error::DimensionMismatch(message))
}
/// Element-wise product of two equally long 1-D `f32` views.
///
/// Delegates to `f32::simd_mul` once the lengths have been checked.
///
/// # Errors
/// Returns `NumRs2Error::DimensionMismatch` when `a.len() != b.len()`.
pub fn simd_mul_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> NnResult<Array1<f32>> {
    if a.len() == b.len() {
        return Ok(f32::simd_mul(a, b));
    }
    let message = format!("Array lengths must match: {} != {}", a.len(), b.len());
    Err(NumRs2Error::DimensionMismatch(message))
}
/// Element-wise difference `a - b` of two equally long 1-D `f32` views.
///
/// Delegates to `f32::simd_sub` once the lengths have been checked.
///
/// # Errors
/// Returns `NumRs2Error::DimensionMismatch` when `a.len() != b.len()`.
pub fn simd_sub_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> NnResult<Array1<f32>> {
    match (a.len(), b.len()) {
        (lhs, rhs) if lhs == rhs => Ok(f32::simd_sub(a, b)),
        _ => Err(NumRs2Error::DimensionMismatch(format!(
            "Array lengths must match: {} != {}",
            a.len(),
            b.len()
        ))),
    }
}
/// Element-wise quotient `a / b` of two equally long 1-D `f32` views.
///
/// Delegates to `f32::simd_div` once the lengths have been checked. No
/// check is made for zero divisors.
///
/// # Errors
/// Returns `NumRs2Error::DimensionMismatch` when `a.len() != b.len()`.
pub fn simd_div_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> NnResult<Array1<f32>> {
    match (a.len(), b.len()) {
        (lhs, rhs) if lhs == rhs => Ok(f32::simd_div(a, b)),
        _ => Err(NumRs2Error::DimensionMismatch(format!(
            "Array lengths must match: {} != {}",
            a.len(),
            b.len()
        ))),
    }
}
/// Dot product of two equally long 1-D `f32` views.
///
/// Delegates to `f32::simd_dot` once the lengths have been checked.
///
/// # Errors
/// Returns `NumRs2Error::DimensionMismatch` when `a.len() != b.len()`.
pub fn simd_dot_f32(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> NnResult<f32> {
    let lengths_agree = a.len() == b.len();
    if !lengths_agree {
        let message = format!("Array lengths must match: {} != {}", a.len(), b.len());
        return Err(NumRs2Error::DimensionMismatch(message));
    }
    let product = f32::simd_dot(a, b);
    Ok(product)
}
pub fn simd_sum_f32(x: &ArrayView1<f32>) -> f32 {
f32::simd_sum(x)
}
pub fn simd_mean_f32(x: &ArrayView1<f32>) -> f32 {
f32::simd_mean(x)
}
pub fn simd_norm_f32(x: &ArrayView1<f32>) -> f32 {
f32::simd_norm(x)
}
pub fn simd_min_f32(x: &ArrayView1<f32>) -> f32 {
f32::simd_min_element(x)
}
pub fn simd_max_f32(x: &ArrayView1<f32>) -> f32 {
f32::simd_max_element(x)
}
/// Builds a multi-line, human-readable summary of the detected SIMD support.
///
/// The report lists the `simd_available`, `avx2_available`,
/// `avx512_available` and `neon_available` flags returned by
/// [`detect_simd_capabilities`], followed by the vector widths (in elements)
/// that this function derives from those flags:
///
/// | flag (first match wins) | f32 | f64 |
/// |-------------------------|-----|-----|
/// | AVX512                  | 16  | 8   |
/// | AVX2                    | 8   | 4   |
/// | NEON                    | 4   | 2   |
/// | none of the above       | 1   | 1   |
pub fn get_simd_info() -> String {
    let caps = detect_simd_capabilities();

    // Both widths follow the same flag priority, so pick them together.
    let (f32_lanes, f64_lanes) = if caps.avx512_available {
        (16, 8)
    } else if caps.avx2_available {
        (8, 4)
    } else if caps.neon_available {
        (4, 2)
    } else {
        (1, 1)
    };

    format!(
        "NumRS2 Neural Network SIMD Capabilities:\n\
         - SIMD Available: {}\n\
         - AVX2: {}\n\
         - AVX512: {}\n\
         - NEON: {}\n\
         - Vector Width (f32): {} elements\n\
         - Vector Width (f64): {} elements",
        caps.simd_available,
        caps.avx2_available,
        caps.avx512_available,
        caps.neon_available,
        f32_lanes,
        f64_lanes
    )
}
/// Reports whether the detected platform capabilities have
/// `simd_available` set.
pub fn is_simd_available() -> bool {
    let caps = detect_simd_capabilities();
    caps.simd_available
}
/// Suggests a batch size based on the detected SIMD flags.
///
/// Returns 512 when AVX512 is reported, otherwise 256 for AVX2, otherwise
/// 128 for NEON, and 64 when none of these flags is set.
pub fn recommended_batch_size() -> usize {
    let caps = detect_simd_capabilities();
    let flags = (
        caps.avx512_available,
        caps.avx2_available,
        caps.neon_available,
    );
    // Arms are ordered by priority: the first matching flag wins.
    match flags {
        (true, _, _) => 512,
        (_, true, _) => 256,
        (_, _, true) => 128,
        _ => 64,
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the activation, arithmetic and capability helpers.

    use super::*;
    use scirs2_core::ndarray::array;

    /// ReLU clamps negatives to zero and leaves non-negative values untouched.
    #[test]
    fn test_simd_relu_f32() {
        let x = array![-2.0f32, -1.0, 0.0, 1.0, 2.0];
        let y = simd_relu_f32(&x.view());
        let expected = array![0.0f32, 0.0, 0.0, 1.0, 2.0];
        for (actual, expected) in y.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-6);
        }
    }

    /// Sigmoid matches reference values at 0, 1 and -1.
    #[test]
    fn test_simd_sigmoid_f32() {
        let x = array![0.0f32, 1.0, -1.0];
        let y = simd_sigmoid_f32(&x.view());
        assert!((y[0] - 0.5).abs() < 1e-6);
        assert!((y[1] - 0.7310586).abs() < 1e-5);
        assert!((y[2] - 0.26894143).abs() < 1e-5);
    }

    /// 2x2 matrix product against hand-computed entries.
    #[test]
    fn test_simd_matmul_f32() {
        let a = array![[1.0f32, 2.0], [3.0, 4.0]];
        let b = array![[5.0f32, 6.0], [7.0, 8.0]];
        let c = simd_matmul_f32(&a.view(), &b.view()).expect("matmul failed");
        assert!((c[[0, 0]] - 19.0).abs() < 1e-5);
        assert!((c[[0, 1]] - 22.0).abs() < 1e-5);
        assert!((c[[1, 0]] - 43.0).abs() < 1e-5);
        assert!((c[[1, 1]] - 50.0).abs() < 1e-5);
    }

    /// Dot product of two length-3 vectors.
    #[test]
    fn test_simd_dot_f32() {
        let a = array![1.0f32, 2.0, 3.0];
        let b = array![4.0f32, 5.0, 6.0];
        let dot = simd_dot_f32(&a.view(), &b.view()).expect("dot failed");
        assert!((dot - 32.0).abs() < 1e-5);
    }

    /// Incompatible shapes must surface as errors rather than panics.
    #[test]
    fn test_dimension_mismatch_is_reported() {
        let a = array![[1.0f32, 2.0], [3.0, 4.0]];
        let b = array![[1.0f32], [2.0], [3.0]];
        assert!(simd_matmul_f32(&a.view(), &b.view()).is_err());

        let short = array![1.0f32, 2.0];
        let long = array![1.0f32, 2.0, 3.0];
        assert!(simd_add_f32(&short.view(), &long.view()).is_err());
        assert!(simd_dot_f32(&short.view(), &long.view()).is_err());
    }

    /// Capability helpers agree with the raw detection result.
    #[test]
    fn test_simd_capabilities() {
        let caps = detect_simd_capabilities();
        // `caps` was previously bound but never read (unused-variable warning).
        assert_eq!(is_simd_available(), caps.simd_available);
        println!("{}", get_simd_info());
        assert!(recommended_batch_size() > 0);
    }
}