use crate::error::Result;
use scirs2_core::ndarray::{Array, IxDyn, Zip};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// Rectified linear unit: replaces every strictly negative element with zero.
///
/// Returns a freshly allocated array with the same shape as `input`; the
/// input itself is not modified. Elements that do not compare less than zero
/// (positive values, zero, and NaN) are copied through unchanged.
///
/// # Errors
///
/// This implementation has no failure path and always returns `Ok`.
pub fn relu<F: Float + Debug + NumAssign>(input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
    let floor = F::zero();
    let mut activated = input.clone();
    activated.mapv_inplace(|value| if value < floor { floor } else { value });
    Ok(activated)
}
/// Logistic sigmoid applied element-wise: `1 / (1 + exp(-x))`.
///
/// Returns a freshly allocated array with the same shape as `input`.
///
/// # Errors
///
/// This implementation has no failure path and always returns `Ok`.
pub fn sigmoid<F: Float + Debug + NumAssign>(input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
    let unit = F::one();
    let mut squashed = input.clone();
    squashed.mapv_inplace(|value| {
        let denominator = unit + (-value).exp();
        unit / denominator
    });
    Ok(squashed)
}
/// Hyperbolic tangent applied element-wise.
///
/// Returns a freshly allocated array with the same shape as `input`.
///
/// # Errors
///
/// This implementation has no failure path and always returns `Ok`.
pub fn tanh<F: Float + Debug + NumAssign>(input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
    let mut squashed = input.clone();
    squashed.mapv_inplace(|value| value.tanh());
    Ok(squashed)
}
/// GELU activation using the tanh-based formula, applied element-wise:
///
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
///
/// Returns a freshly allocated array with the same shape as `input`.
///
/// # Errors
///
/// Returns `NeuralError::ComputationError` if any of the numeric constants
/// cannot be converted into `F`.
pub fn gelu<F: Float + Debug + NumAssign>(input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
    // Shared conversion for the f64 constants used by the formula.
    let constant = |value: f64| {
        F::from(value).ok_or_else(|| {
            crate::error::NeuralError::ComputationError("Failed to convert constant".to_string())
        })
    };
    let half = constant(0.5)?;
    let sqrt_2_over_pi = constant(0.7978845608028654)?;
    let cubic_weight = constant(0.044715)?;
    let unit = F::one();

    let mut activated = input.clone();
    activated.mapv_inplace(|value| {
        let cubed = value * value * value;
        let tanh_arg = sqrt_2_over_pi * (value + cubic_weight * cubed);
        half * value * (unit + tanh_arg.tanh())
    });
    Ok(activated)
}
/// Leaky ReLU: scales strictly negative elements by `negative_slope` and
/// leaves every other element unchanged.
///
/// Returns a freshly allocated array with the same shape as `input`.
///
/// # Errors
///
/// This implementation has no failure path and always returns `Ok`.
pub fn leaky_relu<F: Float + Debug + NumAssign>(
    input: &Array<F, IxDyn>,
    negative_slope: F,
) -> Result<Array<F, IxDyn>> {
    let threshold = F::zero();
    let mut activated = input.clone();
    activated.mapv_inplace(|value| {
        if value < threshold {
            negative_slope * value
        } else {
            value
        }
    });
    Ok(activated)
}
/// Swish activation applied element-wise: `x * (1 / (1 + exp(-x)))`,
/// i.e. each element multiplied by its own logistic sigmoid.
///
/// Returns a freshly allocated array with the same shape as `input`.
///
/// # Errors
///
/// This implementation has no failure path and always returns `Ok`.
pub fn swish<F: Float + Debug + NumAssign>(input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
    let unit = F::one();
    let mut activated = input.clone();
    activated.mapv_inplace(|value| {
        // Logistic gate for this element, then scale the element by it.
        let gate = unit / (unit + (-value).exp());
        value * gate
    });
    Ok(activated)
}
/// Mish activation applied element-wise: `x * tanh(softplus(x))` with
/// `softplus(x) = ln(1 + exp(x))`.
///
/// Returns a freshly allocated array with the same shape as `input`.
///
/// # Errors
///
/// This implementation has no failure path and always returns `Ok`.
pub fn mish<F: Float + Debug + NumAssign>(input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
    let mut output = input.clone();
    output.mapv_inplace(|x| {
        // `ln_1p` keeps the tiny `exp(x)` term for negative inputs; forming
        // `1 + exp(x)` first would round it away and yield exactly zero.
        let softplus = x.exp().ln_1p();
        x * softplus.tanh()
    });
    Ok(output)
}
/// Softmax along `axis`.
///
/// Every 1-D lane along `axis` is normalised independently: the lane maximum
/// is subtracted before exponentiation (to avoid overflow in `exp`), and the
/// exponentials are divided by their sum within that lane.
///
/// `axis` may be negative, in which case it counts from the last dimension
/// (`-1` is the last axis).
///
/// Returns a freshly allocated array with the same shape as `input`.
///
/// # Errors
///
/// Returns `NeuralError::InvalidArgument` when `axis` does not refer to an
/// existing dimension of `input` (this includes every axis for a
/// zero-dimensional array).
pub fn softmax<F: Float + Debug + NumAssign>(input: &Array<F, IxDyn>, axis: isize) -> Result<Array<F, IxDyn>> {
    let ndim = input.ndim();

    // Resolve negative axes explicitly instead of relying on a wrapping cast.
    let resolved = if axis < 0 { axis + ndim as isize } else { axis };
    if resolved < 0 || resolved as usize >= ndim {
        return Err(crate::error::NeuralError::InvalidArgument(format!(
            "Axis {} out of bounds for array with {} dimensions",
            axis, ndim
        )));
    }
    let actual_axis = resolved as usize;
    // The bounds check above guarantees `ndim >= 1`, so this cannot underflow.
    let last_axis = ndim - 1;

    let mut output = input.clone();

    // Temporarily move the requested axis to the last position so that
    // `rows_mut` yields exactly the lanes to normalise. `swap_axes` only
    // permutes shape/strides; no data is moved.
    output.swap_axes(actual_axis, last_axis);
    for mut lane in output.rows_mut() {
        let max_val = lane.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc });
        lane.mapv_inplace(|x| (x - max_val).exp());
        let sum = lane.sum();
        lane.mapv_inplace(|x| x / sum);
    }
    // Restore the caller's axis order.
    output.swap_axes(actual_axis, last_axis);

    Ok(output)
}
/// Exponential linear unit applied element-wise.
///
/// Strictly positive elements pass through unchanged; every other element
/// `x` becomes `alpha * (exp(x) - 1)`.
///
/// Returns a freshly allocated array with the same shape as `input`.
///
/// # Errors
///
/// This implementation has no failure path and always returns `Ok`.
pub fn elu<F: Float + Debug + NumAssign>(input: &Array<F, IxDyn>, alpha: F) -> Result<Array<F, IxDyn>> {
    let zero = F::zero();
    let mut output = input.clone();
    output.mapv_inplace(|x| {
        if x > zero {
            x
        } else {
            // `exp_m1` avoids the cancellation of `exp(x) - 1` near zero.
            alpha * x.exp_m1()
        }
    });
    Ok(output)
}
/// Scaled exponential linear unit applied element-wise.
///
/// With the fixed constants `scale` (about 1.0507) and `alpha` (about 1.6733)
/// defined below, strictly positive elements become `scale * x` and every
/// other element becomes `scale * alpha * (exp(x) - 1)`.
///
/// Returns a freshly allocated array with the same shape as `input`.
///
/// # Errors
///
/// Returns `NeuralError::ComputationError` if either constant cannot be
/// converted into `F`.
pub fn selu<F: Float + Debug + NumAssign>(input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
    let scale = F::from(1.0507009873554804934193349852946).ok_or_else(|| {
        crate::error::NeuralError::ComputationError("Failed to convert constant".to_string())
    })?;
    let alpha = F::from(1.6732632423543772848170429916717).ok_or_else(|| {
        crate::error::NeuralError::ComputationError("Failed to convert constant".to_string())
    })?;
    // Loop-invariant factor of the non-positive branch.
    let scaled_alpha = scale * alpha;
    let zero = F::zero();

    let mut output = input.clone();
    output.mapv_inplace(|x| {
        if x > zero {
            scale * x
        } else {
            // `exp_m1` avoids the cancellation of `exp(x) - 1` near zero.
            scaled_alpha * x.exp_m1()
        }
    });
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Asserts that `actual` is within `1e-6` of `wanted`.
    fn assert_close(actual: f64, wanted: f64) {
        assert!((actual - wanted).abs() < 1e-6);
    }

    /// Negative entries are clamped to zero; the rest are untouched.
    #[test]
    fn test_relu() {
        let samples = Array::from_vec(vec![-2.0_f64, -1.0, 0.0, 1.0, 2.0]).into_dyn();
        let wanted = Array::from_vec(vec![0.0_f64, 0.0, 0.0, 1.0, 2.0]).into_dyn();
        let got = relu(&samples).expect("ReLU failed");
        assert_eq!(got, wanted);
    }

    /// The sigmoid of zero is one half.
    #[test]
    fn test_sigmoid() {
        let samples = Array::from_vec(vec![0.0_f64]).into_dyn();
        let got = sigmoid(&samples).expect("Sigmoid failed");
        assert_close(got[[0]], 0.5);
    }

    /// The hyperbolic tangent of zero is zero.
    #[test]
    fn test_tanh() {
        let samples = Array::from_vec(vec![0.0_f64]).into_dyn();
        let got = tanh(&samples).expect("Tanh failed");
        assert_eq!(got[[0]], 0.0);
    }

    /// GELU maps zero to zero.
    #[test]
    fn test_gelu() {
        let samples = Array::from_vec(vec![0.0_f64]).into_dyn();
        let got = gelu(&samples).expect("GELU failed");
        assert_eq!(got[[0]], 0.0);
    }

    /// Negative entries are scaled by the slope; the rest are untouched.
    #[test]
    fn test_leaky_relu() {
        let samples = Array::from_vec(vec![-2.0_f64, -1.0, 0.0, 1.0, 2.0]).into_dyn();
        let got = leaky_relu(&samples, 0.01).expect("Leaky ReLU failed");
        assert_close(got[[0]], -0.02);
        assert_close(got[[1]], -0.01);
        assert_eq!(got[[2]], 0.0);
        assert_eq!(got[[3]], 1.0);
        assert_eq!(got[[4]], 2.0);
    }

    /// Swish maps zero to zero.
    #[test]
    fn test_swish() {
        let samples = Array::from_vec(vec![0.0_f64]).into_dyn();
        let got = swish(&samples).expect("Swish failed");
        assert_eq!(got[[0]], 0.0);
    }

    /// Mish maps zero to zero.
    #[test]
    fn test_mish() {
        let samples = Array::from_vec(vec![0.0_f64]).into_dyn();
        let got = mish(&samples).expect("Mish failed");
        assert_eq!(got[[0]], 0.0);
    }

    /// A 1-D softmax sums to one and preserves the ordering of its inputs.
    #[test]
    fn test_softmax_1d() {
        let logits = Array::from_vec(vec![1.0_f64, 2.0, 3.0]).into_dyn();
        let probs = softmax(&logits, -1).expect("Softmax failed");
        let total: f64 = probs.iter().sum();
        assert_close(total, 1.0);
        assert!(probs[[0]] < probs[[1]]);
        assert!(probs[[1]] < probs[[2]]);
    }

    /// Positive entries pass through; a negative entry stays in `(-2, 0)`.
    #[test]
    fn test_elu() {
        let samples = Array::from_vec(vec![-2.0_f64, -1.0, 0.0, 1.0, 2.0]).into_dyn();
        let got = elu(&samples, 1.0).expect("ELU failed");
        assert_eq!(got[[3]], 1.0);
        assert_eq!(got[[4]], 2.0);
        assert!(-2.0 < got[[0]] && got[[0]] < 0.0);
    }

    /// Zero maps to zero and a positive entry is scaled up.
    #[test]
    fn test_selu() {
        let samples = Array::from_vec(vec![-1.0_f64, 0.0, 1.0]).into_dyn();
        let got = selu(&samples).expect("SELU failed");
        assert_eq!(got[[1]], 0.0);
        assert!(got[[2]] > 1.0);
    }
}