use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// The set of supported activation functions.
///
/// Because of `rename_all = "UPPERCASE"`, serde reads and writes every variant
/// as its upper-cased name (for example `Log10` <-> `"LOG10"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum Activation {
/// `max(x, 0)`.
Relu,
/// Logistic function `1 / (1 + e^-x)`.
Sigmoid,
/// Exponentials normalised to sum to one when applied to a whole slice;
/// applied to a single value it yields the bare `e^x`.
Softmax,
/// `sqrt(x)` for `x > 0`, otherwise `0`.
Sqrt,
/// `ln(x + 1)` for `x > 0`, otherwise `0`.
Log,
/// `log10(x + 1)` for `x > 0`, otherwise `0`.
Log10,
/// Hyperbolic tangent.
Tanh,
/// `1 - x`.
Inverse,
/// `0.5 * x * (1 + erf(x / sqrt(2)))`, using a polynomial approximation of `erf`.
Gelu,
/// `ln(1 + e^x)`, with shortcuts for inputs of large magnitude.
Softplus,
/// `e^x` with the input clamped to `[-88, 88]` first.
Exp,
/// `1` for positive input, `-1` for negative input; anything else
/// (zero, NaN) is returned unchanged.
Sign,
}
/// Rectified linear unit: the larger of `x` and zero.
///
/// Follows `f32::max` semantics, so a NaN input yields `0.0`.
#[inline(always)]
fn relu(x: f32) -> f32 {
    f32::max(x, 0.0)
}
/// Logistic sigmoid `1 / (1 + e^-x)`.
#[inline(always)]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    1.0 / denominator
}
/// Square root for strictly positive inputs; every other input
/// (zero, negatives, NaN) maps to `0.0`.
#[inline(always)]
fn sqrt_activation(x: f32) -> f32 {
    if x > 0.0 {
        return x.sqrt();
    }
    0.0
}
/// Shifted natural logarithm: `ln(1 + x)` for strictly positive `x`, `0.0` otherwise.
///
/// Zero, negative and NaN inputs all return `0.0`, because only `x > 0.0`
/// takes the logarithm branch.
#[inline(always)]
fn log_activation(x: f32) -> f32 {
    if x > 0.0 {
        // `ln_1p` evaluates ln(1 + x) without first rounding `1 + x` to an
        // `f32`, so inputs much smaller than 1 no longer collapse to ln(1) = 0.
        x.ln_1p()
    } else {
        0.0
    }
}
/// Shifted base-10 logarithm: `log10(x + 1)` for strictly positive `x`,
/// `0.0` for everything else (zero, negatives, NaN).
#[inline(always)]
fn log10_activation(x: f32) -> f32 {
    if x > 0.0 {
        let shifted = x + 1.0;
        return shifted.log10();
    }
    0.0
}
/// Hyperbolic tangent of `x`.
#[inline(always)]
fn tanh_activation(x: f32) -> f32 {
    f32::tanh(x)
}
/// Complement with respect to one: returns `1 - x`.
#[inline(always)]
fn inverse_activation(x: f32) -> f32 {
    const ONE: f32 = 1.0;
    ONE - x
}
/// Polynomial approximation of the error function.
///
/// Evaluates a five-term polynomial in `t = 1 / (1 + P * |x|)`, scales it by
/// `t * exp(-x^2)`, subtracts the result from one, and finally re-applies the
/// sign of the input so the function is odd. The constants appear to be the
/// classic Abramowitz & Stegun (7.1.26) coefficients.
#[inline(always)]
fn erf(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    // Polynomial coefficients ordered from the highest power of `t` to the lowest.
    const COEFFS: [f32; 5] = [
        1.061_405_4,
        -1.453_152_1,
        1.421_413_8,
        -0.284_496_72,
        0.254_829_6,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    // Horner's scheme: start from the leading coefficient and fold in the rest.
    let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();

    if x >= 0.0 {
        y
    } else {
        -y
    }
}
/// Gaussian error linear unit: `0.5 * x * (1 + erf(x / sqrt(2)))`.
#[inline(always)]
fn gelu_activation(x: f32) -> f32 {
    use std::f32::consts::FRAC_1_SQRT_2;

    let one_plus_erf = 1.0 + erf(x * FRAC_1_SQRT_2);
    x * 0.5 * one_plus_erf
}
/// Softplus `ln(1 + e^x)` with shortcuts for inputs of large magnitude.
///
/// * Above the cutoff the input itself is returned.
/// * Below the negated cutoff `e^x` is returned.
/// * In between (and for NaN) the full `ln_1p(e^x)` is evaluated.
#[inline(always)]
fn softplus_activation(x: f32) -> f32 {
    const CUTOFF: f32 = 13.9424;

    if x > CUTOFF {
        return x;
    }
    if x < -CUTOFF {
        return x.exp();
    }
    x.exp().ln_1p()
}
/// Exponential of `x` after clamping the input to `[-88, 88]`, which keeps the
/// result finite for arbitrarily large inputs.
#[inline(always)]
fn exp_activation(x: f32) -> f32 {
    const LIMIT: f32 = 88.0;

    let bounded = x.clamp(-LIMIT, LIMIT);
    bounded.exp()
}
/// Sign of `x`: `1.0` for positive input, `-1.0` for negative input.
///
/// Inputs that compare neither greater nor less than zero (positive or
/// negative zero, NaN) are passed through unchanged.
#[inline(always)]
fn sign_activation(x: f32) -> f32 {
    use std::cmp::Ordering;

    match x.partial_cmp(&0.0) {
        Some(Ordering::Greater) => 1.0,
        Some(Ordering::Less) => -1.0,
        Some(Ordering::Equal) | None => x,
    }
}
impl Activation {
pub fn get_by_name(type_name: &str) -> Option<Self> {
let map: HashMap<&str, Activation> = [
("RELU", Activation::Relu),
("SIGMOID", Activation::Sigmoid),
("SOFTMAX", Activation::Softmax),
("SQRT", Activation::Sqrt),
("LOG", Activation::Log),
("LOG10", Activation::Log10),
("TANH", Activation::Tanh),
("INVERSE", Activation::Inverse),
("GELU", Activation::Gelu),
("SOFTPLUS", Activation::Softplus),
("EXP", Activation::Exp),
("SIGN", Activation::Sign),
]
.iter()
.cloned()
.collect();
map.get(type_name).copied()
}
pub fn apply_single(self, x: f32) -> f32 {
match self {
Activation::Relu => relu(x),
Activation::Sigmoid => sigmoid(x),
Activation::Sqrt => sqrt_activation(x),
Activation::Log => log_activation(x),
Activation::Log10 => log10_activation(x),
Activation::Tanh => tanh_activation(x),
Activation::Inverse => inverse_activation(x),
Activation::Gelu => gelu_activation(x),
Activation::Softplus => softplus_activation(x),
Activation::Exp => exp_activation(x),
Activation::Sign => sign_activation(x),
Activation::Softmax => {
x.exp()
}
}
}
pub fn apply_in_place(self, values: &mut [f32]) {
match self {
Activation::Softmax => {
let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let mut sum = 0.0f32;
for val in values.iter_mut() {
*val = (*val - max_val).exp();
sum += *val;
}
for val in values.iter_mut() {
*val /= sum;
}
}
Activation::Relu => {
for val in values.iter_mut() {
*val = relu(*val);
}
}
Activation::Sigmoid => {
for val in values.iter_mut() {
*val = sigmoid(*val);
}
}
Activation::Sqrt => {
for val in values.iter_mut() {
*val = sqrt_activation(*val);
}
}
Activation::Log => {
for val in values.iter_mut() {
*val = log_activation(*val);
}
}
Activation::Log10 => {
for val in values.iter_mut() {
*val = log10_activation(*val);
}
}
Activation::Tanh => {
for val in values.iter_mut() {
*val = tanh_activation(*val);
}
}
Activation::Inverse => {
for val in values.iter_mut() {
*val = inverse_activation(*val);
}
}
Activation::Gelu => {
for val in values.iter_mut() {
*val = gelu_activation(*val);
}
}
Activation::Softplus => {
for val in values.iter_mut() {
*val = softplus_activation(*val);
}
}
Activation::Exp => {
for val in values.iter_mut() {
*val = exp_activation(*val);
}
}
Activation::Sign => {
for val in values.iter_mut() {
*val = sign_activation(*val);
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for floating-point comparisons.
    const DELTA: f32 = 0.00005;
    /// Looser tolerance for GELU, whose `erf` is only approximated.
    const GELU_DELTA: f32 = 0.001;
    /// Tight tolerance used by the softplus checks.
    const SOFTPLUS_DELTA: f32 = 1e-7;

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    /// `#[track_caller]` makes a failure point at the calling line.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    #[test]
    fn test_relu() {
        let relu = Activation::Relu;
        assert_close(relu.apply_single(1.0), 1.0, DELTA);
        assert_close(relu.apply_single(-1.0), 0.0, DELTA);
        assert_close(relu.apply_single(0.5), 0.5, DELTA);
    }

    #[test]
    fn test_sigmoid() {
        let sigmoid = Activation::Sigmoid;
        assert_close(sigmoid.apply_single(1.0), 0.7311, DELTA);
        assert_close(sigmoid.apply_single(0.0), 0.5, DELTA);
        assert_close(sigmoid.apply_single(-0.5), 0.3775, DELTA);
    }

    #[test]
    fn test_softmax() {
        let mut values = [1.0, 2.0, 3.0];
        Activation::Softmax.apply_in_place(&mut values);
        assert_close(values[0], 0.09003057, DELTA);
        assert_close(values[1], 0.24472847, DELTA);
        assert_close(values[2], 0.66524096, DELTA);
    }

    #[test]
    fn test_sqrt() {
        let sqrt = Activation::Sqrt;
        assert_close(sqrt.apply_single(4.0), 2.0, DELTA);
        assert_close(sqrt.apply_single(-1.0), 0.0, DELTA);
        assert_close(sqrt.apply_single(9.0), 3.0, DELTA);
    }

    #[test]
    fn test_log() {
        let log = Activation::Log;
        assert_close(log.apply_single(1.0), 2.0_f32.ln(), DELTA);
        assert_close(log.apply_single(0.0), 0.0, DELTA);
        assert_close(log.apply_single(9.0), 10.0_f32.ln(), DELTA);
    }

    #[test]
    fn test_log10() {
        let log10 = Activation::Log10;
        assert_close(log10.apply_single(9.0), 1.0, DELTA);
        assert_close(log10.apply_single(0.0), 0.0, DELTA);
        assert_close(log10.apply_single(99.0), 2.0, DELTA);
    }

    #[test]
    fn test_tanh() {
        let tanh = Activation::Tanh;
        assert_close(tanh.apply_single(0.0), 0.0, DELTA);
        assert_close(tanh.apply_single(1.0), 1.0_f32.tanh(), DELTA);
        assert_close(tanh.apply_single(-1.0), (-1.0_f32).tanh(), DELTA);
    }

    #[test]
    fn test_inverse() {
        let inverse = Activation::Inverse;
        assert_close(inverse.apply_single(1.0), 0.0, DELTA);
        assert_close(inverse.apply_single(0.0), 1.0, DELTA);
        assert_close(inverse.apply_single(-1.0), 2.0, DELTA);
    }

    #[test]
    fn test_gelu() {
        let gelu = Activation::Gelu;
        assert_close(gelu.apply_single(-2.0), -0.0454, GELU_DELTA);
        assert_close(gelu.apply_single(-1.0), -0.1587, GELU_DELTA);
        assert_close(gelu.apply_single(0.0), 0.0, DELTA);
        assert_close(gelu.apply_single(1.0), 0.8413, GELU_DELTA);
        assert_close(gelu.apply_single(2.0), 1.9545, GELU_DELTA);
    }

    #[test]
    fn test_gelu_in_place() {
        let mut values = [-2.0, -1.0, 0.0, 1.0, 2.0];
        Activation::Gelu.apply_in_place(&mut values);
        assert_close(values[0], -0.0454, GELU_DELTA);
        assert_close(values[1], -0.1587, GELU_DELTA);
        assert_close(values[2], 0.0, DELTA);
        assert_close(values[3], 0.8413, GELU_DELTA);
        assert_close(values[4], 1.9545, GELU_DELTA);
    }

    #[test]
    fn test_get_by_name() {
        let cases = [
            ("RELU", Some(Activation::Relu)),
            ("SIGMOID", Some(Activation::Sigmoid)),
            ("SOFTMAX", Some(Activation::Softmax)),
            ("GELU", Some(Activation::Gelu)),
            ("SOFTPLUS", Some(Activation::Softplus)),
            ("EXP", Some(Activation::Exp)),
            ("SIGN", Some(Activation::Sign)),
            ("INVALID", None),
        ];
        for &(name, expected) in cases.iter() {
            assert_eq!(Activation::get_by_name(name), expected);
        }
    }

    #[test]
    fn test_softplus() {
        let softplus = Activation::Softplus;
        assert_close(softplus.apply_single(0.0), 0.6931472, SOFTPLUS_DELTA);
        assert_close(softplus.apply_single(1.0), 1.3132616, SOFTPLUS_DELTA);
        assert_close(softplus.apply_single(-1.0), 0.3132617, SOFTPLUS_DELTA);
        assert_close(softplus.apply_single(100.0), 100.0, SOFTPLUS_DELTA);
        assert!(softplus.apply_single(-100.0) < SOFTPLUS_DELTA);
    }

    #[test]
    fn test_softplus_in_place() {
        let mut values = [-100.0, -1.0, 0.0, 1.0, 100.0];
        Activation::Softplus.apply_in_place(&mut values);
        assert!(values[0] < SOFTPLUS_DELTA);
        assert_close(values[1], 0.3132617, SOFTPLUS_DELTA);
        assert_close(values[2], 0.6931472, SOFTPLUS_DELTA);
        assert_close(values[3], 1.3132616, SOFTPLUS_DELTA);
        assert_close(values[4], 100.0, SOFTPLUS_DELTA);
    }

    #[test]
    fn test_exp() {
        let exp = Activation::Exp;
        assert_close(exp.apply_single(0.0), 1.0, DELTA);
        assert_close(exp.apply_single(1.0), std::f32::consts::E, DELTA);
        assert_close(exp.apply_single(-1.0), (-1.0_f32).exp(), DELTA);
        // Very large inputs are clamped, so the result stays finite.
        assert!(exp.apply_single(1000.0).is_finite());
        assert_close(exp.apply_single(1000.0), 88.0_f32.exp(), 1.0);
        // Very negative inputs produce a tiny non-negative value.
        assert!(exp.apply_single(-1000.0) >= 0.0);
        assert!(exp.apply_single(-1000.0) < DELTA);
    }

    #[test]
    fn test_exp_in_place() {
        let mut values = [0.0, 1.0, -1.0];
        Activation::Exp.apply_in_place(&mut values);
        assert_close(values[0], 1.0, DELTA);
        assert_close(values[1], std::f32::consts::E, DELTA);
        assert_close(values[2], (-1.0_f32).exp(), DELTA);
    }

    #[test]
    fn test_sign() {
        let sign = Activation::Sign;
        assert_close(sign.apply_single(3.5), 1.0, DELTA);
        assert_close(sign.apply_single(-2.0), -1.0, DELTA);
        assert_close(sign.apply_single(0.0), 0.0, DELTA);
        assert!(sign.apply_single(f32::NAN).is_nan());
    }

    #[test]
    fn test_sign_in_place() {
        let mut values = [-5.0, -0.0, 0.0, 0.001, 100.0];
        Activation::Sign.apply_in_place(&mut values);
        assert_close(values[0], -1.0, DELTA);
        assert_close(values[1], 0.0, DELTA);
        assert_close(values[2], 0.0, DELTA);
        assert_close(values[3], 1.0, DELTA);
        assert_close(values[4], 1.0, DELTA);
    }
}