use scirs2_core::ndarray::{array, Array1};
use scirs2_core::ndarray_ext::elementwise::{
gelu_simd, mish_simd, sigmoid_simd, softplus_simd, swish_simd,
};
#[cfg(feature = "random")]
use scirs2_core::random::{thread_rng, Distribution, Uniform};
#[test]
fn test_sigmoid_simd_f64_at_zero() {
    // The logistic curve passes through its midpoint exactly at the origin.
    let input = array![0.0_f64];
    let midpoint = sigmoid_simd(&input.view())[0];
    assert!(
        (midpoint - 0.5).abs() < 1e-15,
        "sigmoid(0) should be 0.5, got {}",
        midpoint
    );
}
#[test]
fn test_sigmoid_simd_f64_basic() {
    // Spot-check sigmoid at a few points and verify sigmoid(x) + sigmoid(-x) == 1.
    let inputs = array![0.0_f64, 1.0, -1.0, 2.0, -2.0];
    let out = sigmoid_simd(&inputs.view());

    let s0 = out[0];
    assert!(
        (s0 - 0.5).abs() < 1e-15,
        "sigmoid(0) should be 0.5, got {}",
        s0
    );

    // Reference value from the closed form 1 / (1 + e^-x); sigmoid(-1) is its complement.
    let want_pos = 1.0 / (1.0 + (-1.0_f64).exp());
    let want_neg = 1.0 - want_pos;
    let s1 = out[1];
    assert!(
        (s1 - want_pos).abs() < 1e-10,
        "sigmoid(1) should be {}, got {}",
        want_pos,
        s1
    );
    let s_neg1 = out[2];
    assert!(
        (s_neg1 - want_neg).abs() < 1e-10,
        "sigmoid(-1) should be {}, got {}",
        want_neg,
        s_neg1
    );

    let pair_one = out[1] + out[2];
    assert!(
        (pair_one - 1.0).abs() < 1e-10,
        "sigmoid(1) + sigmoid(-1) should be 1, got {}",
        pair_one
    );
    let pair_two = out[3] + out[4];
    assert!(
        (pair_two - 1.0).abs() < 1e-10,
        "sigmoid(2) + sigmoid(-2) should be 1, got {}",
        pair_two
    );
}
#[test]
fn test_sigmoid_simd_f64_range() {
    // Moderate inputs must land strictly inside the open unit interval.
    let moderate = array![-10.0_f64, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0];
    let squashed = sigmoid_simd(&moderate.view());
    for (k, &input) in moderate.iter().enumerate() {
        let value = squashed[k];
        assert!(
            value > 0.0 && value < 1.0,
            "sigmoid({}) should be in (0,1), got {}",
            input,
            value
        );
    }

    // Saturated inputs may round to exactly 0 or 1, so the closed interval is accepted.
    let extreme = array![-100.0_f64, 100.0];
    let saturated = sigmoid_simd(&extreme.view());
    for (k, &input) in extreme.iter().enumerate() {
        let value = saturated[k];
        assert!(
            (0.0..=1.0).contains(&value),
            "sigmoid({}) should be in [0,1], got {}",
            input,
            value
        );
    }
}
#[test]
fn test_sigmoid_simd_f64_symmetry() {
    // Point symmetry about (0, 0.5): sigmoid(x) + sigmoid(-x) == 1 for each mirrored pair.
    let positives = array![0.5_f64, 1.0, 2.0, 5.0, 10.0];
    let negatives = array![-0.5_f64, -1.0, -2.0, -5.0, -10.0];
    let up = sigmoid_simd(&positives.view());
    let down = sigmoid_simd(&negatives.view());
    for (k, &p) in positives.iter().enumerate() {
        let (sp, sn) = (up[k], down[k]);
        let total = sp + sn;
        assert!(
            (total - 1.0).abs() < 1e-10,
            "sigmoid({}) + sigmoid({}) should be 1, got {} + {} = {}",
            p,
            negatives[k],
            sp,
            sn,
            total
        );
    }
}
#[test]
fn test_sigmoid_simd_f64_large_positive() {
    // Large positive inputs saturate towards 1 without overshooting it.
    let big = array![50.0_f64, 100.0, 500.0, 700.0];
    let sat = sigmoid_simd(&big.view());
    for (k, &input) in big.iter().enumerate() {
        let value = sat[k];
        assert!(
            value > 0.999,
            "sigmoid({}) should be close to 1, got {}",
            input,
            value
        );
        assert!(
            value <= 1.0,
            "sigmoid({}) should be <= 1, got {}",
            input,
            value
        );
    }
}
#[test]
fn test_sigmoid_simd_f64_large_negative() {
    // Large negative inputs saturate towards 0 without going below it.
    let small = array![-50.0_f64, -100.0, -500.0, -700.0];
    let sat = sigmoid_simd(&small.view());
    for (k, &input) in small.iter().enumerate() {
        let value = sat[k];
        assert!(
            value < 0.001,
            "sigmoid({}) should be close to 0, got {}",
            input,
            value
        );
        assert!(
            value >= 0.0,
            "sigmoid({}) should be >= 0, got {}",
            input,
            value
        );
    }
}
#[test]
fn test_sigmoid_simd_f64_infinity() {
    // The infinities are the limits of the logistic curve: +inf -> 1 and -inf -> 0.
    let limits = array![f64::INFINITY, f64::NEG_INFINITY];
    let out = sigmoid_simd(&limits.view());
    let upper = out[0];
    assert!(
        (upper - 1.0).abs() < 1e-10,
        "sigmoid(+∞) should be 1, got {}",
        upper
    );
    let lower = out[1];
    assert!(
        lower.abs() < 1e-10,
        "sigmoid(-∞) should be 0, got {}",
        lower
    );
}
#[test]
fn test_sigmoid_simd_f64_nan() {
    // NaN must propagate rather than be clamped into (0, 1).
    let poisoned = array![f64::NAN];
    let out = sigmoid_simd(&poisoned.view());
    assert!(out[0].is_nan(), "sigmoid(NaN) should be NaN");
}
#[test]
fn test_sigmoid_simd_f32_basic() {
    // Single-precision path: the midpoint plus the complement identity at two
    // magnitudes, with tolerances loosened to suit f32.
    let inputs = array![0.0_f32, 1.0, -1.0, 5.0, -5.0];
    let out = sigmoid_simd(&inputs.view());
    let mid = out[0];
    assert!(
        (mid - 0.5).abs() < 1e-6,
        "sigmoid(0) should be 0.5, got {}",
        mid
    );
    let sum_one = out[1] + out[2];
    assert!(
        (sum_one - 1.0).abs() < 1e-5,
        "sigmoid(1) + sigmoid(-1) should be 1, got {}",
        sum_one
    );
    let sum_five = out[3] + out[4];
    assert!(
        (sum_five - 1.0).abs() < 1e-5,
        "sigmoid(5) + sigmoid(-5) should be 1, got {}",
        sum_five
    );
}
#[test]
fn test_sigmoid_simd_empty() {
    // A zero-length input must produce a zero-length output without panicking.
    let nothing: Array1<f64> = Array1::zeros(0);
    let out = sigmoid_simd(&nothing.view());
    assert!(out.is_empty(), "sigmoid of empty array should be empty");
}
/// Exercises `sigmoid_simd` on a long `linspace` input.
///
/// Checks the output length, monotonicity across every adjacent pair, the [0, 1]
/// codomain, and the symmetry identity sigmoid(-x) = 1 - sigmoid(x) on a small
/// mirrored grid.
#[test]
fn test_sigmoid_simd_large_array() {
    let n = 10000;
    let x: Array1<f64> = Array1::linspace(-10.0, 10.0, n);
    let result = sigmoid_simd(&x.view());
    assert_eq!(result.len(), n);
    for i in 1..n {
        // When this fails the earlier sample is strictly greater than the later one,
        // so the diagnostic reports `>`.
        assert!(
            result[i] >= result[i - 1],
            "sigmoid should be monotonically increasing: sigmoid({}) = {} > sigmoid({}) = {}",
            x[i - 1],
            result[i - 1],
            x[i],
            result[i]
        );
    }
    for i in 0..n {
        assert!(
            result[i] >= 0.0 && result[i] <= 1.0,
            "sigmoid({}) out of range: {}",
            x[i],
            result[i]
        );
    }

    // Mirrored grid: element k pairs with element len-1-k, and the centre is 0.
    let symmetric_x = array![-5.0_f64, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 5.0];
    let symmetric_result = sigmoid_simd(&symmetric_x.view());
    let last = symmetric_x.len() - 1;
    for lo in 0..symmetric_x.len() / 2 {
        let hi = last - lo;
        let sum = symmetric_result[lo] + symmetric_result[hi];
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "sigmoid({}) + sigmoid({}) should be 1, got {}",
            symmetric_x[lo],
            symmetric_x[hi],
            sum
        );
    }
    let centre = last / 2;
    assert!(
        (symmetric_result[centre] - 0.5).abs() < 1e-10,
        "sigmoid({}) should be 0.5, got {}",
        symmetric_x[centre],
        symmetric_result[centre]
    );
}
#[test]
fn test_sigmoid_simd_derivative_property() {
    // Compare a central finite difference with the closed form
    // sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)).
    let h = 1e-8_f64;
    let sigma = |t: f64| sigmoid_simd(&array![t].view())[0];
    for x in [0.0_f64, 0.5, 1.0, -1.0, 2.0].iter().copied() {
        let s = sigma(x);
        let slope = (sigma(x + h) - sigma(x - h)) / (2.0 * h);
        let exact = s * (1.0 - s);
        assert!(
            (slope - exact).abs() < 1e-5,
            "Derivative at x={}: numerical {} vs analytical {}",
            x,
            slope,
            exact
        );
    }
}
#[test]
fn test_sigmoid_simd_logistic_regression() {
    // Treat the outputs as class probabilities: a zero logit is undecided, negative
    // logits fall below the 0.5 decision threshold and positive logits rise above it.
    let scores = array![0.0_f64, -2.0, 2.0, -5.0, 5.0];
    let probs = sigmoid_simd(&scores.view());
    assert!((probs[0] - 0.5).abs() < 1e-10);
    // Negative scores sit at indices 1 and 3.
    assert!(probs[1] < 0.5);
    assert!(probs[3] < 0.5);
    // Positive scores sit at indices 2 and 4.
    assert!(probs[2] > 0.5);
    assert!(probs[4] > 0.5);
}
#[test]
fn test_gelu_simd_f64_at_zero() {
    // GELU(x) = x * Phi(x), so the origin maps exactly to zero.
    let origin = array![0.0_f64];
    let value = gelu_simd(&origin.view())[0];
    assert!(value.abs() < 1e-15, "GELU(0) should be 0, got {}", value);
}
#[test]
fn test_gelu_simd_f64_basic() {
    // The reference values are x * Phi(x), with Phi the standard normal CDF (the
    // erf-based GELU). The 1e-6 tolerance presumes the implementation follows that
    // form rather than the tanh approximation.
    const GELU_AT_ONE: f64 = 0.8413447460685429;
    const GELU_AT_NEG_ONE: f64 = -0.15865525393145707;

    let inputs = array![0.0_f64, 1.0, -1.0, 2.0, -2.0];
    let out = gelu_simd(&inputs.view());

    let at_zero = out[0];
    assert!(at_zero.abs() < 1e-15, "GELU(0) should be 0, got {}", at_zero);

    let at_one = out[1];
    assert!(
        (at_one - GELU_AT_ONE).abs() < 1e-6,
        "GELU(1) should be approximately {}, got {}",
        GELU_AT_ONE,
        at_one
    );

    let at_neg_one = out[2];
    assert!(
        (at_neg_one - GELU_AT_NEG_ONE).abs() < 1e-6,
        "GELU(-1) should be approximately {}, got {}",
        GELU_AT_NEG_ONE,
        at_neg_one
    );

    let at_two = out[3];
    assert!(
        at_two > out[1],
        "GELU should be increasing: GELU(2)={} should be > GELU(1)={}",
        at_two,
        out[1]
    );
}
#[test]
fn test_gelu_simd_f64_large_positive() {
    // Phi(x) -> 1 for large x, so GELU approaches the identity on the right.
    let big = array![5.0_f64, 10.0, 20.0];
    let out = gelu_simd(&big.view());
    for (k, &xv) in big.iter().enumerate() {
        let gv = out[k];
        assert!(
            (gv - xv).abs() < 0.01,
            "GELU({}) should be approximately {}, got {}",
            xv,
            xv,
            gv
        );
    }
}
#[test]
fn test_gelu_simd_f64_large_negative() {
    // Phi(x) -> 0 for very negative x, so GELU vanishes on the far left.
    let small = array![-5.0_f64, -10.0, -20.0];
    let out = gelu_simd(&small.view());
    for (k, &xv) in small.iter().enumerate() {
        let gv = out[k];
        assert!(
            gv.abs() < 0.01,
            "GELU({}) should be approximately 0, got {}",
            xv,
            gv
        );
    }
}
#[test]
fn test_gelu_simd_f64_smoothness() {
    // Samples a hair either side of 0 must stay close to GELU(0) (no jump at the origin).
    let eps = 1e-6_f64;
    let probe = array![-eps, 0.0, eps];
    let out = gelu_simd(&probe.view());
    let (left, centre, right) = (out[0], out[1], out[2]);
    assert!(
        (centre - left).abs() < 0.001,
        "GELU should be smooth at 0: GELU({})={}, GELU(0)={}",
        -eps,
        left,
        centre
    );
    assert!(
        (right - centre).abs() < 0.001,
        "GELU should be smooth at 0: GELU(0)={}, GELU({})={}",
        centre,
        eps,
        right
    );
}
/// GELU is monotonically increasing for x >= 0. Sample [0, 10] densely and check
/// that no sample is smaller than its predecessor.
#[test]
fn test_gelu_simd_f64_monotonicity() {
    let n = 100;
    let x: Array1<f64> = Array1::linspace(0.0, 10.0, n);
    let result = gelu_simd(&x.view());
    for i in 1..n {
        // When this fails the earlier sample is greater than the later one, so the
        // diagnostic reports `>` (it previously, and wrongly, printed `<`).
        assert!(
            result[i] >= result[i - 1],
            "GELU should be monotonically increasing for positive x: GELU({})={} > GELU({})={}",
            x[i - 1],
            result[i - 1],
            x[i],
            result[i]
        );
    }
}
#[test]
fn test_gelu_simd_f64_nan() {
    // NaN input must propagate to a NaN output.
    let poisoned = array![f64::NAN];
    let out = gelu_simd(&poisoned.view());
    assert!(out[0].is_nan(), "GELU(NaN) should be NaN");
}
#[test]
fn test_gelu_simd_f32_basic() {
    // Single-precision sanity check: zero at the origin, sign follows the input at +/-1.
    let inputs = array![0.0_f32, 1.0, -1.0, 2.0, -2.0];
    let out = gelu_simd(&inputs.view());
    let origin = out[0];
    assert!(origin.abs() < 1e-6, "GELU(0) should be 0, got {}", origin);
    assert!(out[1] > 0.0, "GELU(1) should be positive");
    assert!(out[2] < 0.0, "GELU(-1) should be negative");
}
#[test]
fn test_gelu_simd_empty() {
    // A zero-length input must produce a zero-length output without panicking.
    let nothing: Array1<f64> = Array1::zeros(0);
    let out = gelu_simd(&nothing.view());
    assert!(out.is_empty(), "GELU of empty array should be empty");
}
#[test]
fn test_gelu_simd_large_array() {
    // GELU(x) = x * Phi(x) with 0 <= Phi(x) <= 1, so each output lies between 0 and x.
    // Exactly-zero inputs are skipped.
    let n = 10000;
    let grid: Array1<f64> = Array1::linspace(-5.0, 5.0, n);
    let activated = gelu_simd(&grid.view());
    assert_eq!(activated.len(), n);
    for (i, &xi) in grid.iter().enumerate() {
        let gi = activated[i];
        if xi < 0.0 {
            assert!(
                gi >= xi && gi <= 0.0,
                "For x={}, GELU={} should be between {} and 0",
                xi,
                gi,
                xi
            );
        } else if xi > 0.0 {
            assert!(
                gi >= 0.0 && gi <= xi,
                "For x={}, GELU={} should be between 0 and {}",
                xi,
                gi,
                xi
            );
        }
    }
}
#[test]
fn test_gelu_simd_vs_relu() {
    // Away from the origin GELU behaves like ReLU: roughly the identity on the right
    // and roughly zero on the left.
    let probe = array![-3.0_f64, -1.0, 0.0, 1.0, 3.0];
    let g = gelu_simd(&probe.view());
    assert!((g[4] - probe[4]).abs() < 0.01, "GELU(3) should be close to 3");
    assert!(g[0].abs() < 0.01, "GELU(-3) should be close to 0");
    assert!(g[2].abs() < 1e-10, "GELU(0) should be 0");
}
#[test]
fn test_gelu_simd_transformer_use_case() {
    // Typical hidden-layer activations: all outputs finite, and positive inputs stay positive.
    let hidden_state = array![0.5_f64, 1.2, -0.8, 2.0, -1.5];
    let activated = gelu_simd(&hidden_state.view());
    for (i, &h) in hidden_state.iter().enumerate() {
        assert!(
            activated[i].is_finite(),
            "GELU output should be finite for input {}",
            h
        );
    }
    assert!(activated[0] > 0.0);
    assert!(activated[1] > 0.0);
    assert!(activated[3] > 0.0);
}
#[test]
fn test_swish_simd_f64_at_zero() {
    // Swish(x) = x * sigmoid(x), which is exactly zero at the origin.
    let origin = array![0.0_f64];
    let value = swish_simd(&origin.view())[0];
    assert!(value.abs() < 1e-15, "Swish(0) should be 0, got {}", value);
}
#[test]
fn test_swish_simd_f64_basic() {
    // Swish(x) = x / (1 + e^-x). Each non-zero sample is compared with that closed
    // form; the label is the input as it appears in the failure message.
    let x = array![0.0_f64, 1.0, -1.0, 2.0, -2.0];
    let result = swish_simd(&x.view());
    assert!(result[0].abs() < 1e-10, "Swish(0) should be 0");
    let cases = [(1_usize, "1"), (2, "-1"), (3, "2"), (4, "-2")];
    for &(idx, label) in cases.iter() {
        let xi = x[idx];
        let expected = xi / (1.0 + (-xi).exp());
        assert!(
            (result[idx] - expected).abs() < 1e-10,
            "Swish({}) should be approximately {}, got {}",
            label,
            expected,
            result[idx]
        );
    }
}
#[test]
fn test_swish_simd_f64_large_positive() {
    // sigmoid(x) -> 1 for large x, so Swish tracks the identity to within 1% relative error.
    let big = array![5.0_f64, 10.0, 20.0];
    let out = swish_simd(&big.view());
    for (k, &xv) in big.iter().enumerate() {
        let sv = out[k];
        let rel = (sv - xv).abs() / xv;
        assert!(
            rel < 0.01,
            "Swish({}) ≈ {}, relative error should be small: {}",
            xv,
            sv,
            rel
        );
    }
}
#[test]
fn test_swish_simd_f64_large_negative() {
    // For very negative x the sigmoid factor vanishes faster than x grows, so Swish -> 0.
    let small = array![-5.0_f64, -10.0, -20.0];
    let out = swish_simd(&small.view());
    for (k, &xv) in small.iter().enumerate() {
        let sv = out[k];
        assert!(
            sv.abs() < 0.1,
            "Swish({}) should be approximately 0, got {}",
            xv,
            sv
        );
    }
}
#[test]
fn test_swish_simd_f64_smoothness() {
    // Outputs at +/-eps must stay within 1e-5 of Swish(0).
    let eps = 1e-6_f64;
    let probe = array![-eps, 0.0, eps];
    let s = swish_simd(&probe.view());
    let centre = s[1];
    let (left_gap, right_gap) = ((s[0] - centre).abs(), (s[2] - centre).abs());
    assert!(
        left_gap < 1e-5,
        "Swish should be continuous at 0 from the left"
    );
    assert!(
        right_gap < 1e-5,
        "Swish should be continuous at 0 from the right"
    );
}
#[test]
fn test_swish_simd_f64_global_minimum() {
    // Swish reaches its global minimum (about -0.278) near x = -1.278. The other
    // negative samples must not dip meaningfully below the value there.
    let samples = array![-1.278_f64, -1.5, -1.0, -2.0];
    let values = swish_simd(&samples.view());
    let min_value = values[0];
    assert!(
        min_value < -0.27 && min_value > -0.29,
        "Swish global minimum should be around -0.278, got {}",
        min_value
    );
    for (i, &v) in values.iter().enumerate().skip(1) {
        assert!(
            v >= min_value - 0.01,
            "Swish({}) = {} should be >= minimum {}",
            samples[i],
            v,
            min_value
        );
    }
}
#[test]
fn test_swish_simd_f64_nan() {
    // NaN input must propagate to a NaN output.
    let poisoned = array![f64::NAN];
    let out = swish_simd(&poisoned.view());
    assert!(out[0].is_nan(), "Swish(NaN) should be NaN");
}
#[test]
fn test_swish_simd_f32_basic() {
    // Single-precision check against x / (1 + e^-x), evaluated in f32.
    let inputs = array![0.0_f32, 1.0, -1.0, 2.0, -2.0];
    let out = swish_simd(&inputs.view());
    assert!(out[0].abs() < 1e-6, "Swish(0) should be 0");
    let want_one = 1.0_f32 / (1.0 + (-1.0_f32).exp());
    assert!(
        (out[1] - want_one).abs() < 1e-5,
        "Swish(1) should be approximately {}",
        want_one
    );
    let want_neg_one = -1.0_f32 / (1.0 + 1.0_f32.exp());
    assert!(
        (out[2] - want_neg_one).abs() < 1e-5,
        "Swish(-1) should be approximately {}",
        want_neg_one
    );
}
#[test]
fn test_swish_simd_empty() {
    // A zero-length input must produce a zero-length output without panicking.
    let nothing: Array1<f64> = Array1::zeros(0);
    let out = swish_simd(&nothing.view());
    assert!(out.is_empty(), "Swish of empty array should be empty");
}
#[test]
fn test_swish_simd_large_array() {
    // Compare the 100 samples on either side of the middle of a long input with the
    // closed form x / (1 + e^-x).
    let n = 10000;
    let grid: Array1<f64> = Array1::linspace(-5.0, 5.0, n);
    let out = swish_simd(&grid.view());
    assert_eq!(out.len(), n, "Result should have same length as input");
    let reference = |v: f64| v / (1.0 + (-v).exp());
    let centre = n / 2;
    for offset in 0..100 {
        let (hi, lo) = (centre + offset, centre - offset);
        if lo == 0 || hi >= n {
            continue;
        }
        let (x_hi, x_lo) = (grid[hi], grid[lo]);
        let (want_hi, want_lo) = (reference(x_hi), reference(x_lo));
        assert!(
            (out[hi] - want_hi).abs() < 1e-10,
            "Swish({}) should be {}",
            x_hi,
            want_hi
        );
        assert!(
            (out[lo] - want_lo).abs() < 1e-10,
            "Swish({}) should be {}",
            x_lo,
            want_lo
        );
    }
}
#[test]
fn test_swish_simd_sigmoid_relation() {
    // Cross-check the two kernels: Swish(x) must equal x * sigmoid(x) elementwise.
    let grid = array![-2.0_f64, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0];
    let swish_out = swish_simd(&grid.view());
    let sig_out = sigmoid_simd(&grid.view());
    for (i, &xi) in grid.iter().enumerate() {
        let expected = xi * sig_out[i];
        assert!(
            (swish_out[i] - expected).abs() < 1e-10,
            "Swish({}) should equal {} * sigmoid({}) = {}",
            xi,
            xi,
            xi,
            expected
        );
    }
}
/// Feature-map style inputs: outputs are finite, positive inputs stay positive,
/// zero maps to zero, and nothing falls below the -0.3 floor. That floor sits
/// safely under Swish's global minimum of about -0.278.
#[test]
fn test_swish_simd_efficientnet_use_case() {
    let features = array![0.5_f64, 1.2, -0.8, 2.0, -1.5, 0.0, 3.0, -3.0];
    let activated = swish_simd(&features.view());
    for i in 0..features.len() {
        assert!(
            activated[i].is_finite(),
            "Swish output should be finite for input {}",
            features[i]
        );
    }
    assert!(activated[0] > 0.0);
    assert!(activated[1] > 0.0);
    assert!(activated[3] > 0.0);
    assert!(activated[6] > 0.0);
    assert!(activated[5].abs() < 1e-10);
    for i in 0..features.len() {
        // The message states the bound actually checked; previously it claimed -0.278.
        assert!(
            activated[i] > -0.3,
            "Swish output should be > -0.3 (global minimum ≈ -0.278), got {}",
            activated[i]
        );
    }
}
#[test]
fn test_swish_simd_vs_relu() {
    // Unlike ReLU, Swish lets small negative values through (the test bounds them
    // within (-0.3, 0)) and stays below the identity at the sampled positive inputs.
    let inputs = array![-3.0_f64, -1.0, 0.0, 1.0, 3.0];
    let swish_result = swish_simd(&inputs.view());
    assert!(swish_result[0] < 0.0 && swish_result[0] > -0.3);
    assert!(swish_result[1] < 0.0 && swish_result[1] > -0.3);
    assert!(swish_result[2].abs() < 1e-10);
    assert!(swish_result[3] > 0.0 && swish_result[3] < 1.0);
    assert!(swish_result[4] > 0.0 && swish_result[4] < 3.0);
}
#[test]
fn test_swish_simd_derivative() {
    // Central difference versus the closed form
    // Swish'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
    let eps = 1e-6_f64;
    let points = array![0.5_f64, 1.0, 2.0, -0.5, -1.0];
    let swish_at = |t: f64| swish_simd(&array![t].view())[0];
    for &x in points.iter() {
        let numerical_derivative = (swish_at(x + eps) - swish_at(x - eps)) / (2.0 * eps);
        let sig = 1.0 / (1.0 + (-x).exp());
        let analytical_derivative = sig * (1.0 + x * (1.0 - sig));
        assert!(
            (numerical_derivative - analytical_derivative).abs() < 1e-4,
            "Swish derivative at {} should be approximately {}, numerical: {}",
            x,
            analytical_derivative,
            numerical_derivative
        );
    }
}
#[test]
fn test_softplus_simd_f64_at_zero() {
    // softplus(0) = ln(1 + e^0) = ln 2.
    let origin = array![0.0_f64];
    let out = softplus_simd(&origin.view());
    let ln_two = (2.0_f64).ln();
    let got = out[0];
    assert!(
        (got - ln_two).abs() < 1e-15,
        "Softplus(0) should be ln(2) ≈ {}, got {}",
        ln_two,
        got
    );
}
#[test]
fn test_softplus_simd_f64_basic() {
    // softplus(x) = ln(1 + e^x). Each non-zero sample is compared with that closed
    // form; the label is the input as it appears in the failure message.
    let x = array![0.0_f64, 1.0, -1.0, 2.0, -2.0];
    let result = softplus_simd(&x.view());
    let ln_two = (2.0_f64).ln();
    assert!(
        (result[0] - ln_two).abs() < 1e-10,
        "Softplus(0) should be ln(2)"
    );
    for &(idx, label) in [(1_usize, "1"), (2, "-1"), (3, "2"), (4, "-2")].iter() {
        let xi = x[idx];
        let expected = (1.0_f64 + xi.exp()).ln();
        assert!(
            (result[idx] - expected).abs() < 1e-10,
            "Softplus({}) should be approximately {}, got {}",
            label,
            expected,
            result[idx]
        );
    }
}
#[test]
fn test_softplus_simd_f64_large_positive() {
    // softplus(x) - x = ln(1 + e^-x) -> 0, so softplus matches x closely for large x.
    let big = array![10.0_f64, 20.0, 50.0];
    let out = softplus_simd(&big.view());
    for (k, &xv) in big.iter().enumerate() {
        let sv = out[k];
        let rel = (sv - xv).abs() / xv;
        assert!(
            rel < 1e-4,
            "Softplus({}) ≈ {}, relative error should be small: {}",
            xv,
            sv,
            rel
        );
    }
}
#[test]
fn test_softplus_simd_f64_large_negative() {
    // For very negative x, ln(1 + e^x) ~= e^x. When e^x is below 1e-15 the test stops
    // using relative error and instead checks the output's magnitude directly.
    let small = array![-10.0_f64, -20.0, -50.0];
    let out = softplus_simd(&small.view());
    for (k, &xv) in small.iter().enumerate() {
        let expected = xv.exp();
        let got = out[k];
        let err = if expected > 1e-15 {
            (got - expected).abs() / expected
        } else {
            got.abs()
        };
        assert!(
            err < 0.01,
            "Softplus({}) should be approximately exp({}) = {}, got {}, relative error: {}",
            xv,
            xv,
            expected,
            got,
            err
        );
    }
}
#[test]
fn test_softplus_simd_f64_always_positive() {
    // ln(1 + e^x) > 0 for every finite x, including deep in the negative tail.
    let span = array![-100.0_f64, -10.0, -1.0, 0.0, 1.0, 10.0, 100.0];
    let out = softplus_simd(&span.view());
    for (k, &xv) in span.iter().enumerate() {
        let sv = out[k];
        assert!(
            sv > 0.0,
            "Softplus({}) should be positive, got {}",
            xv,
            sv
        );
    }
}
/// Softplus is strictly increasing. Over 100 samples on [-10, 10] each output must
/// exceed its predecessor.
#[test]
fn test_softplus_simd_f64_monotonicity() {
    let n = 100;
    let x: Array1<f64> = Array1::linspace(-10.0, 10.0, n);
    let result = softplus_simd(&x.view());
    for i in 1..n {
        // When this fails the earlier sample is >= the later one, so the diagnostic
        // reports `>=` (it previously, and wrongly, printed `<=`).
        assert!(
            result[i] > result[i - 1],
            "Softplus should be monotonically increasing: softplus({}) = {} >= softplus({}) = {}",
            x[i - 1],
            result[i - 1],
            x[i],
            result[i]
        );
    }
}
#[test]
fn test_softplus_simd_f64_smoothness() {
    // Outputs at +/-eps must stay within 1e-5 of softplus(0).
    let eps = 1e-6_f64;
    let probe = array![-eps, 0.0, eps];
    let s = softplus_simd(&probe.view());
    let centre = s[1];
    let (left_gap, right_gap) = ((s[0] - centre).abs(), (s[2] - centre).abs());
    assert!(
        left_gap < 1e-5,
        "Softplus should be continuous at 0 from the left"
    );
    assert!(
        right_gap < 1e-5,
        "Softplus should be continuous at 0 from the right"
    );
}
#[test]
fn test_softplus_simd_f64_nan() {
    // NaN input must propagate to a NaN output.
    let poisoned = array![f64::NAN];
    let out = softplus_simd(&poisoned.view());
    assert!(out[0].is_nan(), "Softplus(NaN) should be NaN");
}
#[test]
fn test_softplus_simd_f32_basic() {
    // Single-precision check against ln(1 + e^x), evaluated in f32.
    let inputs = array![0.0_f32, 1.0, -1.0, 2.0, -2.0];
    let out = softplus_simd(&inputs.view());
    let want_zero = (2.0_f32).ln();
    assert!(
        (out[0] - want_zero).abs() < 1e-6,
        "Softplus(0) should be ln(2)"
    );
    let want_one = (1.0_f32 + 1.0_f32.exp()).ln();
    assert!(
        (out[1] - want_one).abs() < 1e-5,
        "Softplus(1) should be approximately {}",
        want_one
    );
    let want_neg_one = (1.0_f32 + (-1.0_f32).exp()).ln();
    assert!(
        (out[2] - want_neg_one).abs() < 1e-5,
        "Softplus(-1) should be approximately {}",
        want_neg_one
    );
}
#[test]
fn test_softplus_simd_empty() {
    // A zero-length input must produce a zero-length output without panicking.
    let nothing: Array1<f64> = Array1::zeros(0);
    let out = softplus_simd(&nothing.view());
    assert!(out.is_empty(), "Softplus of empty array should be empty");
}
/// Exercises `softplus_simd` on a long `linspace` input.
///
/// Checks the output length, strict positivity of every sample, and monotonicity
/// across every adjacent pair.
#[test]
fn test_softplus_simd_large_array() {
    let n = 10000;
    let x: Array1<f64> = Array1::linspace(-10.0, 10.0, n);
    let result = softplus_simd(&x.view());
    assert_eq!(result.len(), n, "Result should have same length as input");
    for i in 0..n {
        assert!(
            result[i] > 0.0,
            "Softplus({}) should be positive, got {}",
            x[i],
            result[i]
        );
    }
    // Compare every neighbouring pair, not just samples 100 apart, so a local dip
    // between widely spaced samples cannot go unnoticed. This matches the sigmoid
    // large-array test.
    for i in 1..n {
        assert!(
            result[i] >= result[i - 1],
            "Softplus should be monotonically increasing: softplus({}) = {} > softplus({}) = {}",
            x[i - 1],
            result[i - 1],
            x[i],
            result[i]
        );
    }
}
#[test]
fn test_softplus_simd_derivative() {
    // The derivative of softplus is the logistic sigmoid; compare with a central difference.
    let eps = 1e-6_f64;
    let points = array![0.0_f64, 0.5, 1.0, 2.0, -0.5, -1.0, -2.0];
    let softplus_at = |t: f64| softplus_simd(&array![t].view())[0];
    for &x in points.iter() {
        let numerical_derivative = (softplus_at(x + eps) - softplus_at(x - eps)) / (2.0 * eps);
        let analytical_derivative = 1.0 / (1.0 + (-x).exp());
        assert!(
            (numerical_derivative - analytical_derivative).abs() < 1e-4,
            "Softplus derivative at {} should be sigmoid({}) = {}, numerical: {}",
            x,
            x,
            analytical_derivative,
            numerical_derivative
        );
    }
}
#[test]
fn test_softplus_simd_probabilistic_model() {
    // Softplus as a positivity link for predicted variances: every output must be a
    // finite positive number, and larger raw outputs must give larger variances.
    let logits = array![-5.0_f64, -2.0, 0.0, 2.0, 5.0];
    let variances = softplus_simd(&logits.view());
    let count = logits.len();
    for v in (0..count).map(|i| variances[i]) {
        assert!(v > 0.0, "Variance should be positive, got {}", v);
        assert!(v.is_finite(), "Variance should be finite, got {}", v);
    }
    for i in 1..count {
        let (lower, upper) = (variances[i - 1], variances[i]);
        assert!(upper > lower, "Variance should increase with raw output");
    }
}
#[test]
fn test_softplus_simd_vs_relu() {
    // Softplus is a smooth upper bound on ReLU: it lies strictly above the identity for
    // the positive samples, and the gap ln(1 + e^-x) shrinks as x grows.
    let probe = array![-3.0_f64, -1.0, 0.0, 1.0, 3.0];
    let sp = softplus_simd(&probe.view());
    assert!(sp[2] > 0.0, "Softplus(0) = ln(2) > 0");
    assert!(sp[3] > probe[3], "Softplus(1) > 1");
    assert!(sp[4] > probe[4], "Softplus(3) > 3");

    let gap_at_one = sp[3] - probe[3];
    let gap_at_three = sp[4] - probe[4];
    assert!(
        gap_at_three < gap_at_one,
        "Softplus - x should decrease as x increases"
    );

    let large_x = array![10.0_f64];
    let large_result = softplus_simd(&large_x.view())[0];
    assert!((large_result - 10.0).abs() < 0.001);
}
#[test]
fn test_mish_simd_f64_at_zero() {
    // Mish(x) = x * tanh(softplus(x)), which is exactly zero at the origin.
    let origin = array![0.0_f64];
    let value = mish_simd(&origin.view())[0];
    assert!(value.abs() < 1e-15, "Mish(0) should be 0, got {}", value);
}
#[test]
fn test_mish_simd_f64_basic() {
    // Mish(x) = x * tanh(ln(1 + e^x)). Each labelled sample is compared with that
    // closed form; the label is the input as it appears in the failure message.
    let x = array![0.0_f64, 1.0, -1.0, 2.0, -2.0];
    let result = mish_simd(&x.view());
    assert!(result[0].abs() < 1e-10, "Mish(0) should be 0");
    for &(idx, label) in [(1_usize, "1"), (2, "-1"), (3, "2")].iter() {
        let xi = x[idx];
        let expected = xi * (1.0_f64 + xi.exp()).ln().tanh();
        assert!(
            (result[idx] - expected).abs() < 1e-10,
            "Mish({}) should be approximately {}, got {}",
            label,
            expected,
            result[idx]
        );
    }
}
#[test]
fn test_mish_simd_f64_large_positive() {
    // tanh(softplus(x)) -> 1 for large x, so Mish tracks the identity to within 1%.
    let big = array![5.0_f64, 10.0, 20.0];
    let out = mish_simd(&big.view());
    for (k, &xv) in big.iter().enumerate() {
        let mv = out[k];
        let rel = (mv - xv).abs() / xv;
        assert!(
            rel < 0.01,
            "Mish({}) ≈ {}, relative error should be small: {}",
            xv,
            mv,
            rel
        );
    }
}
#[test]
fn test_mish_simd_f64_large_negative() {
    // For very negative x the tanh(softplus(x)) factor vanishes, so Mish -> 0.
    let small = array![-5.0_f64, -10.0, -20.0];
    let out = mish_simd(&small.view());
    for (k, &xv) in small.iter().enumerate() {
        let mv = out[k];
        assert!(
            mv.abs() < 0.1,
            "Mish({}) should be approximately 0, got {}",
            xv,
            mv
        );
    }
}
/// Checks the neighbourhood of Mish's global minimum (about -0.309 near x = -1.19).
///
/// The value at x = -1.2 must fall in the expected band. The other sampled negative
/// inputs must not dip meaningfully below it, which mirrors the Swish global-minimum
/// test. Previously those samples were computed but never asserted on.
#[test]
fn test_mish_simd_f64_global_minimum() {
    let x = array![-1.2_f64, -1.5, -1.0, -2.0];
    let result = mish_simd(&x.view());
    let min_value = result[0];
    assert!(
        min_value < -0.28 && min_value > -0.35,
        "Mish global minimum should be around -0.31, got {}",
        min_value
    );
    for i in 1..result.len() {
        assert!(
            result[i] >= min_value - 0.01,
            "Mish({}) = {} should be >= minimum {}",
            x[i],
            result[i],
            min_value
        );
    }
}
#[test]
fn test_mish_simd_f64_smoothness() {
    // Outputs at +/-eps must stay within 1e-5 of Mish(0).
    let eps = 1e-6_f64;
    let probe = array![-eps, 0.0, eps];
    let m = mish_simd(&probe.view());
    let centre = m[1];
    let (left_gap, right_gap) = ((m[0] - centre).abs(), (m[2] - centre).abs());
    assert!(
        left_gap < 1e-5,
        "Mish should be continuous at 0 from the left"
    );
    assert!(
        right_gap < 1e-5,
        "Mish should be continuous at 0 from the right"
    );
}
#[test]
fn test_mish_simd_f64_nan() {
    // NaN input must propagate to a NaN output.
    let poisoned = array![f64::NAN];
    let out = mish_simd(&poisoned.view());
    assert!(out[0].is_nan(), "Mish(NaN) should be NaN");
}
#[test]
fn test_mish_simd_f32_basic() {
    // Single-precision check of Mish(0) and Mish(1) = tanh(ln(1 + e)), evaluated in f32.
    let inputs = array![0.0_f32, 1.0, -1.0, 2.0, -2.0];
    let out = mish_simd(&inputs.view());
    assert!(out[0].abs() < 1e-6, "Mish(0) should be 0");
    let expected_mish_1 = 1.0_f32 * (1.0_f32 + 1.0_f32.exp()).ln().tanh();
    assert!(
        (out[1] - expected_mish_1).abs() < 1e-5,
        "Mish(1) should be approximately {}",
        expected_mish_1
    );
}
#[test]
fn test_mish_simd_empty() {
    // A zero-length input must produce a zero-length output without panicking.
    let nothing: Array1<f64> = Array1::zeros(0);
    let out = mish_simd(&nothing.view());
    assert!(out.is_empty(), "Mish of empty array should be empty");
}
#[test]
fn test_mish_simd_large_array() {
    // On a long input, spot-check every 500th sample against x * tanh(ln(1 + e^x)).
    let n = 10000;
    let grid: Array1<f64> = Array1::linspace(-5.0, 5.0, n);
    let out = mish_simd(&grid.view());
    assert_eq!(out.len(), n, "Result should have same length as input");
    for (i, &xi) in grid.iter().enumerate().step_by(500) {
        let expected = xi * (1.0_f64 + xi.exp()).ln().tanh();
        assert!(
            (out[i] - expected).abs() < 1e-8,
            "Mish({}) should be {}, got {}",
            xi,
            expected,
            out[i]
        );
    }
}
#[test]
fn test_mish_simd_vs_swish() {
    // Both activations pass through the origin, Mish is above Swish at the sampled
    // positive inputs, and both approach the identity for a large positive input.
    let inputs = array![-2.0_f64, -1.0, 0.0, 1.0, 2.0];
    let mish_result = mish_simd(&inputs.view());
    let swish_result = swish_simd(&inputs.view());
    assert!(mish_result[2].abs() < 1e-10);
    assert!(swish_result[2].abs() < 1e-10);
    assert!(mish_result[3] > swish_result[3]);
    assert!(mish_result[4] > swish_result[4]);

    let big = array![10.0_f64];
    let mish_large = mish_simd(&big.view())[0];
    let swish_large = swish_simd(&big.view())[0];
    assert!((mish_large - 10.0).abs() < 0.01);
    assert!((swish_large - 10.0).abs() < 0.01);
}
/// Feature-map style inputs: outputs are finite, zero maps to zero, positive inputs
/// stay positive, and nothing falls below the -0.35 floor. That floor sits safely
/// under Mish's global minimum of about -0.309.
#[test]
fn test_mish_simd_yolov4_use_case() {
    let features = array![0.5_f64, 1.2, -0.8, 2.0, -1.5, 0.0, 3.0, -3.0];
    let activated = mish_simd(&features.view());
    for i in 0..features.len() {
        assert!(
            activated[i].is_finite(),
            "Mish output should be finite for input {}",
            features[i]
        );
    }
    assert!(activated[5].abs() < 1e-10);
    assert!(activated[0] > 0.0);
    assert!(activated[1] > 0.0);
    assert!(activated[3] > 0.0);
    assert!(activated[6] > 0.0);
    for i in 0..features.len() {
        // The message states the bound actually checked; previously it claimed -0.31.
        assert!(
            activated[i] > -0.35,
            "Mish output should be > -0.35 (global minimum ≈ -0.309), got {}",
            activated[i]
        );
    }
}