use super::NnResult;
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{
s, Array, Array1, Array2, ArrayView, ArrayView1, ArrayView2, Axis, ScalarOperand, Zip,
};
use scirs2_core::numeric::Float;
use scirs2_core::simd_ops::SimdUnifiedOps;
/// Rectified linear unit applied element-wise to a 1-D view.
///
/// Elements strictly greater than zero are kept. Everything else maps to zero: zero,
/// negatives, and NaN (because `NaN > 0` is false).
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn relu<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let zero = T::zero();
    let rectify = move |v: T| if zero < v { v } else { zero };
    Ok(x.mapv(rectify))
}
/// Rectified linear unit applied element-wise to a 2-D view.
///
/// Elements strictly greater than zero are kept; all others (including NaN) map to zero.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn relu_2d<T>(x: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let floor = T::zero();
    let rectified = x.mapv(|elem| if floor < elem { elem } else { floor });
    Ok(rectified)
}
/// In-place rectified linear unit over a 1-D array.
///
/// Every element that is not strictly greater than zero (including NaN) is overwritten
/// with zero. Positive elements are left untouched.
pub fn relu_inplace<T>(x: &mut Array1<T>)
where
    T: Float + SimdUnifiedOps,
{
    let zero = T::zero();
    for elem in x.iter_mut() {
        if zero < *elem {
            continue;
        }
        *elem = zero;
    }
}
/// Leaky ReLU applied element-wise to a 1-D view.
///
/// Elements strictly greater than zero pass through unchanged. Every other element
/// is multiplied by `alpha`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if `alpha` is negative or NaN.
pub fn leaky_relu<T>(x: &ArrayView1<T>, alpha: T) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    // `alpha < 0` on its own lets NaN through (every comparison with NaN is false),
    // which would silently turn every non-positive input into NaN.
    if alpha.is_nan() || alpha < T::zero() {
        return Err(NumRs2Error::InvalidOperation(
            "Leaky ReLU alpha must be non-negative".to_string(),
        ));
    }
    let zero = T::zero();
    Ok(x.mapv(|v| if v > zero { v } else { alpha * v }))
}
/// Leaky ReLU applied element-wise to a 2-D view.
///
/// Elements strictly greater than zero pass through unchanged. Every other element
/// is multiplied by `alpha`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if `alpha` is negative or NaN.
pub fn leaky_relu_2d<T>(x: &ArrayView2<T>, alpha: T) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    // Reject NaN explicitly: `alpha < 0` is false for NaN.
    if alpha.is_nan() || alpha < T::zero() {
        return Err(NumRs2Error::InvalidOperation(
            "Leaky ReLU alpha must be non-negative".to_string(),
        ));
    }
    let zero = T::zero();
    Ok(x.mapv(|v| if v > zero { v } else { alpha * v }))
}
/// Exponential linear unit (ELU) applied element-wise to a 1-D view.
///
/// Computes `x` for `x > 0` and `alpha * (exp(x) - 1)` otherwise.
///
/// The negative branch is evaluated with `exp_m1`. Writing `exp(x) - 1` directly
/// cancels catastrophically for inputs close to zero and loses most significant digits.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if `alpha` is not strictly positive
/// (this includes NaN).
pub fn elu<T>(x: &ArrayView1<T>, alpha: T) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    // `alpha <= 0` is false for NaN, so NaN must be rejected separately.
    if alpha.is_nan() || alpha <= T::zero() {
        return Err(NumRs2Error::InvalidOperation(
            "ELU alpha must be positive".to_string(),
        ));
    }
    let zero = T::zero();
    Ok(x.mapv(|v| if v > zero { v } else { alpha * v.exp_m1() }))
}
/// Exponential linear unit (ELU) applied element-wise to a 2-D view.
///
/// Computes `x` for `x > 0` and `alpha * (exp(x) - 1)` otherwise. The negative
/// branch uses `exp_m1` to avoid cancellation for inputs close to zero.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if `alpha` is not strictly positive
/// (this includes NaN).
pub fn elu_2d<T>(x: &ArrayView2<T>, alpha: T) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    // `alpha <= 0` is false for NaN, so NaN must be rejected separately.
    if alpha.is_nan() || alpha <= T::zero() {
        return Err(NumRs2Error::InvalidOperation(
            "ELU alpha must be positive".to_string(),
        ));
    }
    let zero = T::zero();
    Ok(x.mapv(|v| if v > zero { v } else { alpha * v.exp_m1() }))
}
/// Scaled exponential linear unit (SELU) applied element-wise to a 1-D view.
///
/// Uses the fixed constants `lambda ≈ 1.0507` and `alpha ≈ 1.6733`:
/// `lambda * x` for `x > 0`, and `lambda * alpha * (exp(x) - 1)` otherwise. The
/// negative branch uses `exp_m1` to avoid cancellation for inputs close to zero.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if either constant cannot be
/// represented in `T`.
pub fn selu<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let lambda = T::from(1.0507009873554804934193349852946).ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert SELU lambda constant".to_string())
    })?;
    let alpha = T::from(1.6732632423543772848170429916717).ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert SELU alpha constant".to_string())
    })?;
    // Same left-to-right product as `lambda * alpha * (...)`, computed once.
    let negative_scale = lambda * alpha;
    let zero = T::zero();
    Ok(x.mapv(|v| {
        if v > zero {
            lambda * v
        } else {
            negative_scale * v.exp_m1()
        }
    }))
}
/// Scaled exponential linear unit (SELU) applied element-wise to a 2-D view.
///
/// Uses the fixed constants `lambda ≈ 1.0507` and `alpha ≈ 1.6733`:
/// `lambda * x` for `x > 0`, and `lambda * alpha * (exp(x) - 1)` otherwise. The
/// negative branch uses `exp_m1` to avoid cancellation for inputs close to zero.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if either constant cannot be
/// represented in `T`.
pub fn selu_2d<T>(x: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let lambda = T::from(1.0507009873554804934193349852946).ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert SELU lambda constant".to_string())
    })?;
    let alpha = T::from(1.6732632423543772848170429916717).ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert SELU alpha constant".to_string())
    })?;
    // Same left-to-right product as `lambda * alpha * (...)`, computed once.
    let negative_scale = lambda * alpha;
    let zero = T::zero();
    Ok(x.mapv(|v| {
        if v > zero {
            lambda * v
        } else {
            negative_scale * v.exp_m1()
        }
    }))
}
/// Logistic sigmoid `1 / (1 + exp(-x))` applied element-wise to a 1-D view.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn sigmoid<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let logistic = |v: T| T::one() / (T::one() + (-v).exp());
    Ok(x.mapv(logistic))
}
/// Logistic sigmoid `1 / (1 + exp(-x))` applied element-wise to a 2-D view.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn sigmoid_2d<T>(x: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let unit = T::one();
    let squashed = x.mapv(|v| {
        let denominator = unit + (-v).exp();
        unit / denominator
    });
    Ok(squashed)
}
/// Hyperbolic tangent applied element-wise to a 1-D view.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn tanh<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let activated = x.mapv(<T as Float>::tanh);
    Ok(activated)
}
/// Hyperbolic tangent applied element-wise to a 2-D view.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn tanh_2d<T>(x: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let activated = x.mapv(<T as Float>::tanh);
    Ok(activated)
}
/// Swish activation `x * sigmoid(x)` applied element-wise to a 1-D view.
///
/// Each element is computed as `x / (1 + exp(-x))`.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn swish<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let unit = T::one();
    let gated = x.mapv(|v| {
        let denominator = unit + (-v).exp();
        v / denominator
    });
    Ok(gated)
}
/// Swish activation `x * sigmoid(x)` applied element-wise to a 2-D view.
///
/// Each element is computed as `x / (1 + exp(-x))`.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn swish_2d<T>(x: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let unit = T::one();
    let gated = x.mapv(|v| {
        let denominator = unit + (-v).exp();
        v / denominator
    });
    Ok(gated)
}
/// SiLU (sigmoid-weighted linear unit) over a 1-D view.
///
/// This is the same function as [`swish`] and simply forwards to it.
///
/// # Errors
///
/// Propagates any error returned by [`swish`].
pub fn silu<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let activated = swish(x)?;
    Ok(activated)
}
/// SiLU (sigmoid-weighted linear unit) over a 2-D view.
///
/// This is the same function as [`swish_2d`] and simply forwards to it.
///
/// # Errors
///
/// Propagates any error returned by [`swish_2d`].
pub fn silu_2d<T>(x: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let activated = swish_2d(x)?;
    Ok(activated)
}
/// Mish activation `x * tanh(softplus(x))` applied element-wise to a 1-D view.
///
/// The inner softplus `ln(1 + exp(x))` is evaluated with `ln_1p`, and `exp` only ever
/// receives a non-positive argument. The naive `ln(1 + exp(x))` rounds to exactly zero
/// for large negative `x`, because `1 + exp(x) == 1` there; this form keeps full
/// relative precision.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn mish<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let zero = T::zero();
    Ok(x.mapv(|v| {
        // ln(1 + e^v) == v + ln(1 + e^-v); pick the branch whose exponent is <= 0.
        let softplus = if v > zero {
            v + (-v).exp().ln_1p()
        } else {
            v.exp().ln_1p()
        };
        v * softplus.tanh()
    }))
}
/// Mish activation `x * tanh(softplus(x))` applied element-wise to a 2-D view.
///
/// The inner softplus is evaluated with `ln_1p`, and `exp` only ever receives a
/// non-positive argument. This keeps full precision for large negative `x`, where
/// `ln(1 + exp(x))` would round to zero.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn mish_2d<T>(x: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let zero = T::zero();
    Ok(x.mapv(|v| {
        // ln(1 + e^v) == v + ln(1 + e^-v); pick the branch whose exponent is <= 0.
        let softplus = if v > zero {
            v + (-v).exp().ln_1p()
        } else {
            v.exp().ln_1p()
        };
        v * softplus.tanh()
    }))
}
/// GELU activation (tanh approximation) applied element-wise to a 1-D view.
///
/// Computes `0.5 * x * (1 + tanh(c * (x + 0.044715 * x^3)))`, where
/// `c = 0.7978845608028654 ≈ sqrt(2 / pi)`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if any of the constants cannot be
/// represented in `T`.
pub fn gelu<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let half = T::from(0.5)
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to convert 0.5".to_string()))?;
    let one = T::one();
    let coeff = T::from(0.7978845608028654).ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert GELU coefficient".to_string())
    })?;
    let cubic_coeff = T::from(0.044715).ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert GELU cubic coefficient".to_string())
    })?;
    Ok(x.mapv(|v| {
        let inner = coeff * (v + cubic_coeff * v.powi(3));
        half * v * (one + inner.tanh())
    }))
}
/// GELU activation (tanh approximation) applied element-wise to a 2-D view.
///
/// Computes `0.5 * x * (1 + tanh(c * (x + 0.044715 * x^3)))`, where
/// `c = 0.7978845608028654 ≈ sqrt(2 / pi)`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if any of the constants cannot be
/// represented in `T`.
pub fn gelu_2d<T>(x: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let half = T::from(0.5)
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to convert 0.5".to_string()))?;
    let one = T::one();
    let coeff = T::from(0.7978845608028654).ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert GELU coefficient".to_string())
    })?;
    let cubic_coeff = T::from(0.044715).ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert GELU cubic coefficient".to_string())
    })?;
    Ok(x.mapv(|v| {
        let inner = coeff * (v + cubic_coeff * v.powi(3));
        half * v * (one + inner.tanh())
    }))
}
/// Numerically stable softmax over a 1-D view.
///
/// The largest element is subtracted before exponentiation so `exp` cannot overflow.
/// The exponentials are then divided by their sum.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `x` is empty.
/// * [`NumRs2Error::InvalidOperation`] if the maximum element is not finite (some
///   element is `+inf`, or every element is `-inf`).
/// * [`NumRs2Error::InvalidOperation`] if the normalising sum is zero or non-finite
///   (e.g. a NaN element).
pub fn softmax<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps + ScalarOperand,
{
    if x.is_empty() {
        return Err(NumRs2Error::DimensionMismatch(
            "Softmax requires non-empty input".to_string(),
        ));
    }

    // NaN never wins a `>` comparison, so it is skipped here and caught by the sum check.
    let peak = x.fold(T::neg_infinity(), |best, &v| if v > best { v } else { best });
    if !peak.is_finite() {
        return Err(NumRs2Error::InvalidOperation(
            "Softmax input contains non-finite values".to_string(),
        ));
    }

    let exps = x.mapv(|v| (v - peak).exp());
    let total = exps.sum();
    let degenerate = total == T::zero() || !total.is_finite();
    if degenerate {
        return Err(NumRs2Error::InvalidOperation(
            "Softmax normalization failed (sum is zero or non-finite)".to_string(),
        ));
    }
    Ok(exps / total)
}
/// Softmax over one axis of a 2-D view.
///
/// With `axis == 1` each row is normalised independently. With `axis == 0` each
/// column is. Every lane is processed by [`softmax`].
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `x` is empty.
/// * [`NumRs2Error::InvalidOperation`] if `axis` is not `0` or `1`.
/// * Any error from [`softmax`] for the first lane (in iteration order) that fails.
pub fn softmax_2d<T>(x: &ArrayView2<T>, axis: usize) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps + ScalarOperand,
{
    if x.is_empty() {
        return Err(NumRs2Error::DimensionMismatch(
            "Softmax requires non-empty input".to_string(),
        ));
    }
    if axis >= 2 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Invalid axis {} for 2D array",
            axis
        )));
    }

    let mut result: Array2<T> = Array2::zeros(x.raw_dim());
    // axis == 1 -> rows, which are indexed along Axis(0);
    // axis == 0 -> columns, which are indexed along Axis(1).
    let outer = Axis(1 - axis);
    for (lane, mut out_lane) in x.axis_iter(outer).zip(result.axis_iter_mut(outer)) {
        out_lane.assign(&softmax(&lane)?);
    }
    Ok(result)
}
/// Numerically stable log-softmax over a 1-D view.
///
/// Computes `(x_i - max) - ln(sum_j exp(x_j - max))` directly, without forming the
/// softmax probabilities and taking their logarithm.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `x` is empty.
/// * [`NumRs2Error::InvalidOperation`] if the maximum element is not finite.
/// * [`NumRs2Error::InvalidOperation`] if the log-sum-exp term is not finite (e.g. a
///   NaN element).
pub fn log_softmax<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    if x.is_empty() {
        return Err(NumRs2Error::DimensionMismatch(
            "Log-softmax requires non-empty input".to_string(),
        ));
    }

    let peak = x.fold(T::neg_infinity(), |best, &v| if v > best { v } else { best });
    if !peak.is_finite() {
        return Err(NumRs2Error::InvalidOperation(
            "Log-softmax input contains non-finite values".to_string(),
        ));
    }

    let centered = x.mapv(|v| v - peak);
    let exp_sum = centered.mapv(|c| c.exp()).sum();
    let log_norm = exp_sum.ln();
    if !log_norm.is_finite() {
        return Err(NumRs2Error::InvalidOperation(
            "Log-softmax computation failed (non-finite log_sum_exp)".to_string(),
        ));
    }
    Ok(centered.mapv(|c| c - log_norm))
}
/// Log-softmax over one axis of a 2-D view.
///
/// With `axis == 1` each row is processed independently. With `axis == 0` each
/// column is. Every lane is processed by [`log_softmax`].
///
/// `ScalarOperand` is not required: no array-by-scalar arithmetic happens here and
/// [`log_softmax`] does not need it either. Dropping the bound only loosens the
/// signature, so existing callers are unaffected.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `x` is empty.
/// * [`NumRs2Error::InvalidOperation`] if `axis` is not `0` or `1`.
/// * Any error from [`log_softmax`] for the first lane (in iteration order) that fails.
pub fn log_softmax_2d<T>(x: &ArrayView2<T>, axis: usize) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    if x.is_empty() {
        return Err(NumRs2Error::DimensionMismatch(
            "Log-softmax requires non-empty input".to_string(),
        ));
    }
    if axis >= 2 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Invalid axis {} for 2D array",
            axis
        )));
    }

    let mut result: Array2<T> = Array2::zeros(x.raw_dim());
    // axis == 1 -> rows, which are indexed along Axis(0);
    // axis == 0 -> columns, which are indexed along Axis(1).
    let outer = Axis(1 - axis);
    for (lane, mut out_lane) in x.axis_iter(outer).zip(result.axis_iter_mut(outer)) {
        out_lane.assign(&log_softmax(&lane)?);
    }
    Ok(result)
}
/// Softplus `ln(1 + exp(x))` applied element-wise to a 1-D view.
///
/// Uses the identity `ln(1 + e^x) = x + ln(1 + e^-x)` for positive `x` and
/// `ln_1p(e^x)` otherwise, so `exp` only ever sees a non-positive argument. This has
/// three consequences:
///
/// * `exp` cannot overflow, so no cut-off threshold is needed.
/// * Large positive inputs keep their small correction term.
/// * Large negative inputs keep full relative precision instead of rounding to zero
///   (as `ln(1 + exp(x))` does once `1 + exp(x) == 1`).
///
/// NaN inputs produce NaN.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn softplus<T>(x: &ArrayView1<T>) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let zero = T::zero();
    Ok(x.mapv(|v| {
        if v > zero {
            v + (-v).exp().ln_1p()
        } else {
            v.exp().ln_1p()
        }
    }))
}
/// Softplus `ln(1 + exp(x))` applied element-wise to a 2-D view.
///
/// Evaluated as `x + ln_1p(exp(-x))` for positive `x` and `ln_1p(exp(x))` otherwise.
/// `exp` therefore never overflows, and large negative inputs keep full relative
/// precision instead of rounding to zero. NaN inputs produce NaN.
///
/// # Errors
///
/// Never returns an error; the `Result` wrapper keeps the signature uniform with the
/// other activations in this module.
pub fn softplus_2d<T>(x: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let zero = T::zero();
    Ok(x.mapv(|v| {
        if v > zero {
            v + (-v).exp().ln_1p()
        } else {
            v.exp().ln_1p()
        }
    }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Tolerance shared by every floating-point comparison in this module.
    const EPS: f64 = 1e-6;

    #[test]
    fn test_relu() {
        let input = array![-2.0, -1.0, 0.0, 1.0, 2.0];
        let expected = [0.0, 0.0, 0.0, 1.0, 2.0];
        let output = relu(&input.view()).expect("test: valid relu input");
        assert_eq!(output.len(), expected.len());
        for (&got, &want) in output.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = EPS);
        }
    }

    #[test]
    fn test_sigmoid() {
        let input = array![0.0];
        let output = sigmoid(&input.view()).expect("test: valid sigmoid input");
        // sigmoid(0) = 1 / (1 + e^0) = 1/2
        assert_abs_diff_eq!(output[0], 0.5, epsilon = EPS);
    }

    #[test]
    fn test_softmax() {
        let input = array![1.0, 2.0, 3.0];
        let probs = softmax(&input.view()).expect("test: valid softmax input");
        let total: f64 = probs.sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = EPS);
        assert!(probs.iter().all(|&p| p > 0.0));
    }

    #[test]
    fn test_softmax_numerical_stability() {
        // Without max-subtraction, exp(1000.0) would overflow to infinity.
        let input = array![1000.0, 1001.0, 1002.0];
        let probs = softmax(&input.view()).expect("test: valid softmax input");
        assert!(probs.iter().all(|&p| p.is_finite()));
        let total: f64 = probs.sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = EPS);
    }

    #[test]
    fn test_gelu() {
        let input = array![0.0];
        let output = gelu(&input.view()).expect("test: valid gelu input");
        assert_abs_diff_eq!(output[0], 0.0, epsilon = EPS);
    }

    #[test]
    fn test_swish() {
        let input = array![0.0];
        let output = swish(&input.view()).expect("test: valid swish input");
        assert_abs_diff_eq!(output[0], 0.0, epsilon = EPS);
    }
}