use super::super::super::super::*;
/// GELU on a small symmetric range: zero maps to zero, negative inputs give
/// small negative outputs, positive inputs land just below the identity.
#[test]
fn test_gelu_basic() {
    let input = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let activated = input.gelu().unwrap();
    let out = &activated.data;

    assert_eq!(out[2], 0.0);

    // Negative side: gelu(-2) ≈ -0.045, gelu(-1) ≈ -0.159.
    let (neg_two, neg_one) = (out[0], out[1]);
    assert!(neg_two < 0.0 && neg_two > -0.1);
    assert!(neg_one < 0.0 && neg_one > -0.2);

    // Positive side: gelu(1) ≈ 0.841, gelu(2) ≈ 1.955.
    assert!(out[3] > 0.8);
    assert!(out[4] > 1.8);
}
/// Every zero input must map to exactly zero.
#[test]
fn test_gelu_zero() {
    let zeros = Vector::from_slice(&[0.0, 0.0, 0.0]);
    let activated = zeros.gelu().unwrap();
    activated
        .as_slice()
        .iter()
        .for_each(|&y| assert_eq!(y, 0.0, "gelu(0) should be 0"));
}
/// GELU over [-3, 3] must stay finite everywhere and be strictly increasing
/// on the non-negative half-axis.
///
/// GELU is not monotonic for negative inputs (it dips to a minimum near
/// x ≈ -0.75), so only the x >= 0 samples are compared pairwise. The two
/// endpoint comparisons across the origin are kept as a coarse sanity check.
#[test]
fn test_gelu_smoothness() {
    let v = Vector::from_slice(&[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]);
    let result = v.gelu().unwrap();
    let out = result.as_slice();

    for &val in out {
        assert!(val.is_finite(), "GELU output should be finite");
    }

    // gelu(-3) ≈ -0.004 < gelu(0) = 0 < gelu(3) ≈ 2.996
    assert!(result.data[0] < result.data[3]);
    assert!(result.data[3] < result.data[6]);

    // Indices 3..=6 hold the outputs for inputs 0, 1, 2, 3.
    for pair in out[3..].windows(2) {
        assert!(
            pair[0] < pair[1],
            "GELU should be strictly increasing for x >= 0, got {} then {}",
            pair[0],
            pair[1]
        );
    }
}
/// For large positive x, GELU converges to the identity (within 0.01).
#[test]
fn test_gelu_large_positive() {
    let inputs = Vector::from_slice(&[5.0, 10.0, 20.0]);
    let outputs = inputs.gelu().unwrap();
    for idx in 0..inputs.len() {
        let (x, y) = (inputs.data[idx], outputs.data[idx]);
        assert!(
            (y - x).abs() < 0.01,
            "gelu({}) = {} should ≈ {} for large positive x",
            x,
            y,
            x
        );
    }
}
/// For large negative x, GELU decays toward zero (within 0.001).
#[test]
fn test_gelu_large_negative() {
    let inputs = Vector::from_slice(&[-5.0, -10.0, -20.0]);
    let activated = inputs.gelu().unwrap();
    for y in activated.as_slice().iter().copied() {
        assert!(y.abs() < 0.001, "gelu should approach 0 for large negative inputs, got {}", y);
    }
}
/// GELU on an empty vector is rejected with `TruenoError::EmptyVector`.
#[test]
fn test_gelu_empty_vector() {
    let empty = Vector::from_slice(&[]);
    assert!(matches!(empty.gelu(), Err(TruenoError::EmptyVector)));
}
/// Swish at a handful of reference points, each checked to within 0.01;
/// swish(0) must be exactly zero.
#[test]
fn test_swish_basic() {
    let input = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let activated = input.swish().unwrap();
    let ys = activated.as_slice();

    // Reference values for x = -2, -1, 1, 2.
    let (want_m2, want_m1, want_p1, want_p2) = (-0.238, -0.269, 0.731, 1.762);
    assert!((ys[0] - want_m2).abs() < 0.01);
    assert!((ys[1] - want_m1).abs() < 0.01);
    assert_eq!(ys[2], 0.0);
    assert!((ys[3] - want_p1).abs() < 0.01);
    assert!((ys[4] - want_p2).abs() < 0.01);
}
/// swish(0) is exactly zero.
#[test]
fn test_swish_zero() {
    let single = Vector::from_slice(&[0.0]);
    let activated = single.swish().unwrap();
    let first = activated.as_slice()[0];
    assert_eq!(first, 0.0);
}
/// Samples around Swish's global minimum (≈ -0.278 near x ≈ -1.278) must
/// never fall below -0.3, and the sample at x = -1.278 must lie strictly
/// between -0.29 and -0.27.
#[test]
fn test_swish_minimum() {
    let samples = Vector::from_slice(&[-2.0, -1.5, -1.278, -1.0, -0.5]);
    let activated = samples.swish().unwrap();
    let ys = activated.as_slice();

    ys.iter()
        .for_each(|&val| assert!(val > -0.3, "Swish value {} below minimum", val));

    let at_min = ys[2];
    assert!(at_min < -0.27);
    assert!(at_min > -0.29);
}
/// For large positive x, swish(x) ≈ x (within 0.01).
#[test]
fn test_swish_large_positive() {
    let big = Vector::from_slice(&[10.0, 20.0, 50.0]);
    let activated = big.swish().unwrap();
    let ys = activated.as_slice();
    let (a, b, c) = (ys[0], ys[1], ys[2]);
    assert!((a - 10.0).abs() < 0.01);
    assert!((b - 20.0).abs() < 0.01);
    assert!((c - 50.0).abs() < 0.01);
}
/// For large negative x, swish(x) vanishes; the tolerance tightens as the
/// input grows more negative.
#[test]
fn test_swish_large_negative() {
    let big_neg = Vector::from_slice(&[-10.0, -20.0, -50.0]);
    let activated = big_neg.swish().unwrap();
    let ys = activated.as_slice();
    assert!(ys[0].abs() < 1e-3);
    assert!(ys[1].abs() < 1e-7);
    assert!(ys[2].abs() < 1e-15);
}
/// Swish on an empty vector is rejected with `TruenoError::EmptyVector`.
#[test]
fn test_swish_empty_vector() {
    let empty = Vector::from_slice(&[]);
    assert!(matches!(empty.swish(), Err(TruenoError::EmptyVector)));
}
/// HardSwish across all three regimes: zero at or below -3, quadratic ramp in
/// between, identity at or above 3.
#[test]
fn test_hardswish_basic() {
    let input = Vector::from_slice(&[-4.0, -3.0, -1.5, 0.0, 1.5, 3.0, 4.0]);
    let activated = input.hardswish().unwrap();
    let ys = activated.as_slice();

    // Clamped-off region.
    assert_eq!(ys[0], 0.0);
    assert_eq!(ys[1], 0.0);
    // Ramp: -1.5 * 1.5 / 6 = -0.375.
    assert!((ys[2] + 0.375).abs() < 1e-5);
    assert_eq!(ys[3], 0.0);
    // Ramp: 1.5 * 4.5 / 6 = 1.125.
    assert!((ys[4] - 1.125).abs() < 1e-5);
    // Identity region.
    assert_eq!(ys[5], 3.0);
    assert_eq!(ys[6], 4.0);
}
/// hardswish(0) is exactly zero.
#[test]
fn test_hardswish_zero() {
    let single = Vector::from_slice(&[0.0]);
    let activated = single.hardswish().unwrap();
    let first = activated.as_slice()[0];
    assert_eq!(first, 0.0);
}
/// At the knees x = -3 and x = 3, HardSwish yields exactly 0 and exactly 3.
#[test]
fn test_hardswish_boundary_values() {
    let knees = Vector::from_slice(&[-3.0, 3.0]);
    let activated = knees.hardswish().unwrap();
    let ys = activated.as_slice();
    let (lower, upper) = (ys[0], ys[1]);
    assert_eq!(lower, 0.0);
    assert_eq!(upper, 3.0);
}
/// Far outside the ramp, HardSwish is exactly 0 (negative side) or exactly the
/// input (positive side).
#[test]
fn test_hardswish_large_values() {
    let extremes = Vector::from_slice(&[-100.0, -10.0, 10.0, 100.0]);
    let activated = extremes.hardswish().unwrap();
    let ys = activated.as_slice();

    // Negative extremes are clamped to zero.
    assert_eq!(ys[0], 0.0);
    assert_eq!(ys[1], 0.0);
    // Positive extremes pass through unchanged.
    assert_eq!(ys[2], 10.0);
    assert_eq!(ys[3], 100.0);
}
/// Inside the ramp (-3, 3), HardSwish follows x * (x + 3) / 6.
#[test]
fn test_hardswish_transition_region() {
    let ramp = Vector::from_slice(&[-2.0, -1.0, 1.0, 2.0]);
    let activated = ramp.hardswish().unwrap();
    let ys = activated.as_slice();

    // x * (x + 3) / 6 for x = -2, -1, 1, 2.
    let expected = [-1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0, 5.0 / 3.0];
    for (idx, &want) in expected.iter().enumerate() {
        assert!((ys[idx] - want).abs() < 1e-5);
    }
}
/// HardSwish on an empty vector is rejected with `TruenoError::EmptyVector`.
#[test]
fn test_hardswish_empty_vector() {
    let empty = Vector::from_slice(&[]);
    assert!(matches!(empty.hardswish(), Err(TruenoError::EmptyVector)));
}
/// Mish preserves sign: negative inputs give negative outputs, zero gives
/// (approximately) zero, positive inputs give positive outputs.
#[test]
fn test_mish_basic() {
    let input = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let activated = input.mish().unwrap();
    let ys = activated.as_slice();

    for &neg in &ys[..2] {
        assert!(neg < 0.0);
    }
    assert!(ys[2].abs() < 1e-5);
    for &pos in &ys[3..5] {
        assert!(pos > 0.0);
    }
}
/// mish(0) is zero to within 1e-10.
#[test]
fn test_mish_zero() {
    let single = Vector::from_slice(&[0.0]);
    let activated = single.mish().unwrap();
    let y = activated.as_slice()[0];
    assert!(y.abs() < 1e-10);
}
/// For large positive x, mish(x) ≈ x (within 0.001).
#[test]
fn test_mish_large_positive() {
    let big = Vector::from_slice(&[10.0, 20.0, 50.0]);
    let activated = big.mish().unwrap();
    let ys = activated.as_slice();

    let targets = [10.0, 20.0, 50.0];
    for (idx, &target) in targets.iter().enumerate() {
        assert!((ys[idx] - target).abs() < 0.001);
    }
}
/// For large negative x, mish(x) vanishes; the tolerance tightens as the input
/// grows more negative.
#[test]
fn test_mish_large_negative() {
    let big_neg = Vector::from_slice(&[-10.0, -20.0, -50.0]);
    let activated = big_neg.mish().unwrap();
    let ys = activated.as_slice();

    let tolerances = [0.001, 1e-6, 1e-10];
    for (idx, &tol) in tolerances.iter().enumerate() {
        assert!(ys[idx].abs() < tol);
    }
}
/// Near Mish's global minimum (≈ -0.309 around x ≈ -1.19), the output must
/// lie strictly between -0.4 and -0.2.
#[test]
fn test_mish_minimum() {
    let probe = Vector::from_slice(&[-1.19]);
    let activated = probe.mish().unwrap();
    let y = activated.as_slice()[0];
    assert!(y < -0.2);
    assert!(y > -0.4);
}
/// Mish on an empty vector is rejected with `TruenoError::EmptyVector`.
#[test]
fn test_mish_empty_vector() {
    let empty = Vector::from_slice(&[]);
    assert!(matches!(empty.mish(), Err(TruenoError::EmptyVector)));
}
/// SELU checked against its closed form on negative, zero and positive inputs:
/// selu(x) = λ·x for x > 0 and λ·α·(eˣ − 1) for x <= 0.
#[test]
fn test_selu_basic() {
    // λ ≈ 1.0507009873554805 and α ≈ 1.6732632423543772, written with only the
    // digits an f32 can hold. Longer literals trip `clippy::excessive_precision`
    // and round to these exact same f32 values anyway.
    const LAMBDA: f32 = 1.050_701;
    const ALPHA: f32 = 1.673_263_2;

    let v = Vector::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let result = v.selu().unwrap();
    let data = result.as_slice();

    // Positive branch: plain scaling by λ.
    assert!((data[3] - LAMBDA * 1.0).abs() < 1e-5);
    assert!((data[4] - LAMBDA * 2.0).abs() < 1e-5);

    // Zero sits on the boundary of both branches.
    assert!(data[2].abs() < 1e-5);

    // Negative branch, checked at both negative samples.
    let expected_neg1 = LAMBDA * ALPHA * ((-1.0_f32).exp() - 1.0);
    assert!((data[1] - expected_neg1).abs() < 1e-5);
    let expected_neg2 = LAMBDA * ALPHA * ((-2.0_f32).exp() - 1.0);
    assert!((data[0] - expected_neg2).abs() < 1e-5);
}
/// selu(0) is zero to within 1e-10.
#[test]
fn test_selu_zero() {
    let single = Vector::from_slice(&[0.0]);
    let activated = single.selu().unwrap();
    let y = activated.as_slice()[0];
    assert!(y.abs() < 1e-10);
}
/// On the positive half-axis SELU is a pure scaling: selu(x) = λ·x.
#[test]
fn test_selu_positive_scaling() {
    // λ ≈ 1.0507009873554805, written with only the digits an f32 can hold
    // (avoids `clippy::excessive_precision`; the resulting f32 is identical).
    const LAMBDA: f32 = 1.050_701;

    let v = Vector::from_slice(&[1.0, 2.0, 3.0, 10.0]);
    let result = v.selu().unwrap();
    let data = result.as_slice();

    // Guard the zip below so a short output cannot silently skip inputs.
    assert_eq!(data.len(), v.len(), "selu output length should match input length");

    // Iterate the vector's own inputs rather than repeating the literal array.
    for (&x, &y) in v.as_slice().iter().zip(data) {
        assert!(
            (y - LAMBDA * x).abs() < 1e-5,
            "selu({}) should be {} but got {}",
            x,
            LAMBDA * x,
            y
        );
    }
}
/// For very negative x, eˣ → 0, so selu(x) = λ·α·(eˣ − 1) approaches −λ·α.
#[test]
fn test_selu_negative_asymptote() {
    // λ ≈ 1.0507009873554805 and α ≈ 1.6732632423543772, written with only the
    // digits an f32 can hold (avoids `clippy::excessive_precision`; the
    // resulting f32 values are identical).
    const LAMBDA: f32 = 1.050_701;
    const ALPHA: f32 = 1.673_263_2;

    let v = Vector::from_slice(&[-100.0]);
    let result = v.selu().unwrap();
    let got = result.as_slice()[0];

    let asymptote = -LAMBDA * ALPHA;
    assert!(
        (got - asymptote).abs() < 1e-4,
        "selu(-100) should approach {} but got {}",
        asymptote,
        got
    );
}
/// Probing SELU just either side of the origin: outputs at ±1e-6 must be within
/// 1e-3 of zero, and zero itself must map to zero (within 1e-10).
#[test]
fn test_selu_continuity_at_zero() {
    let eps = 1e-6;
    let probe = Vector::from_slice(&[-eps, 0.0, eps]);
    let activated = probe.selu().unwrap();
    let ys = activated.as_slice();

    let (left, centre, right) = (ys[0], ys[1], ys[2]);
    assert!(left.abs() < 1e-3);
    assert!(centre.abs() < 1e-10);
    assert!(right.abs() < 1e-3);
}
/// SELU on an empty vector is rejected with `TruenoError::EmptyVector`.
#[test]
fn test_selu_empty_vector() {
    let empty = Vector::from_slice(&[]);
    assert!(matches!(empty.selu(), Err(TruenoError::EmptyVector)));
}