use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Activation functions applied to `f32` values.
///
/// All variants except `Softmax` act element-wise; `Softmax` needs the whole
/// vector (see [`Activation::apply_in_place`]). With serde, each variant is
/// represented by its name in uppercase, e.g. `"RELU"` or `"LOG10"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum Activation {
    /// `max(x, 0)`.
    Relu,
    /// Logistic function `1 / (1 + e^-x)`.
    Sigmoid,
    /// Normalized exponential across a whole slice; element-wise it yields only `e^x`.
    Softmax,
    /// `sqrt(x)` when `x > 0`, otherwise `0`.
    Sqrt,
    /// `ln(x + 1)` when `x > 0`, otherwise `0`.
    Log,
    /// `log10(x + 1)` when `x > 0`, otherwise `0`.
    Log10,
    /// Hyperbolic tangent.
    Tanh,
    /// `1 - x` (the complement, not the reciprocal `1 / x`).
    Inverse,
}
impl Activation {
    /// Looks up an activation by its uppercase name (e.g. `"RELU"`, `"LOG10"`).
    ///
    /// Matching is exact and case-sensitive; any unrecognized name yields `None`.
    pub fn get_by_name(type_name: &str) -> Option<Self> {
        const NAMED: [(&str, Activation); 8] = [
            ("RELU", Activation::Relu),
            ("SIGMOID", Activation::Sigmoid),
            ("SOFTMAX", Activation::Softmax),
            ("SQRT", Activation::Sqrt),
            ("LOG", Activation::Log),
            ("LOG10", Activation::Log10),
            ("TANH", Activation::Tanh),
            ("INVERSE", Activation::Inverse),
        ];
        let by_name: HashMap<&str, Activation> = NAMED.iter().copied().collect();
        by_name.get(type_name).copied()
    }

    /// Applies this activation to a single value.
    ///
    /// `Sqrt`, `Log` and `Log10` return `0.0` unless `x > 0.0`, so zero,
    /// negative inputs and NaN all map to `0.0`. `Softmax` returns only the
    /// unnormalized `e^x`; the normalization is done by
    /// [`Activation::apply_in_place`] over the whole slice.
    pub fn apply_single(self, x: f32) -> f32 {
        match self {
            Activation::Relu => x.max(0.0),
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Tanh => x.tanh(),
            // Complement, not reciprocal.
            Activation::Inverse => 1.0 - x,
            // Unnormalized: the denominator needs every element of the vector.
            Activation::Softmax => x.exp(),
            // Positive-only functions; anything failing `x > 0.0` (including
            // NaN) falls through to the shared 0.0 arm below.
            Activation::Sqrt if x > 0.0 => x.sqrt(),
            Activation::Log if x > 0.0 => (x + 1.0).ln(),
            Activation::Log10 if x > 0.0 => (x + 1.0).log10(),
            Activation::Sqrt | Activation::Log | Activation::Log10 => 0.0,
        }
    }

    /// Applies this activation to every element of `values`, in place.
    ///
    /// `Softmax` treats the slice as one vector: each element becomes
    /// `e^(v - max) / sum_j e^(v_j - max)`. Subtracting the maximum first keeps
    /// `exp` from overflowing on large inputs without changing the
    /// mathematical result. An empty slice is left untouched. Every other
    /// activation is applied element-wise via [`Activation::apply_single`].
    pub fn apply_in_place(self, values: &mut [f32]) {
        if self != Activation::Softmax {
            values
                .iter_mut()
                .for_each(|slot| *slot = self.apply_single(*slot));
            return;
        }

        let peak = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut total = 0.0_f32;
        for slot in values.iter_mut() {
            *slot = (*slot - peak).exp();
            total += *slot;
        }
        for slot in values.iter_mut() {
            *slot /= total;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const DELTA: f32 = 0.00005;

    /// Asserts that `actual` lies within `DELTA` of `expected`, reporting both
    /// values (at the caller's location) on failure.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < DELTA,
            "expected {} (+/- {}), got {}",
            expected,
            DELTA,
            actual
        );
    }

    #[test]
    fn test_relu() {
        assert_close(Activation::Relu.apply_single(1.0), 1.0);
        assert_close(Activation::Relu.apply_single(-1.0), 0.0);
        assert_close(Activation::Relu.apply_single(0.5), 0.5);
    }

    #[test]
    fn test_sigmoid() {
        assert_close(Activation::Sigmoid.apply_single(1.0), 0.7311);
        assert_close(Activation::Sigmoid.apply_single(0.0), 0.5);
        assert_close(Activation::Sigmoid.apply_single(-0.5), 0.3775);
    }

    #[test]
    fn test_softmax() {
        let mut values = [1.0, 2.0, 3.0];
        Activation::Softmax.apply_in_place(&mut values);
        assert_close(values[0], 0.09003057);
        assert_close(values[1], 0.24472847);
        assert_close(values[2], 0.66524096);
        assert_close(values.iter().sum(), 1.0);
    }

    #[test]
    fn test_softmax_is_stable_for_large_inputs() {
        // Without subtracting the maximum, exp(1000.0) overflows to infinity
        // and the normalized result would be NaN.
        let mut values = [1000.0_f32, 1000.0];
        Activation::Softmax.apply_in_place(&mut values);
        assert_close(values[0], 0.5);
        assert_close(values[1], 0.5);
    }

    #[test]
    fn test_softmax_empty_slice() {
        // Must not panic or divide anything on an empty input.
        let mut values: Vec<f32> = Vec::new();
        Activation::Softmax.apply_in_place(&mut values);
        assert!(values.is_empty());
    }

    #[test]
    fn test_softmax_apply_single_is_unnormalized() {
        // A lone value is not normalized; apply_single yields plain e^x.
        assert_close(Activation::Softmax.apply_single(1.0), 1.0_f32.exp());
    }

    #[test]
    fn test_apply_in_place_matches_apply_single() {
        let inputs = [-2.0_f32, -0.5, 0.0, 0.5, 3.0];
        for &activation in &[
            Activation::Relu,
            Activation::Sigmoid,
            Activation::Sqrt,
            Activation::Log,
            Activation::Log10,
            Activation::Tanh,
            Activation::Inverse,
        ] {
            let mut values = inputs;
            activation.apply_in_place(&mut values);
            for (&got, &x) in values.iter().zip(inputs.iter()) {
                assert_eq!(
                    got.to_bits(),
                    activation.apply_single(x).to_bits(),
                    "{:?} at {}",
                    activation,
                    x
                );
            }
        }
    }

    #[test]
    fn test_sqrt() {
        assert_close(Activation::Sqrt.apply_single(4.0), 2.0);
        assert_close(Activation::Sqrt.apply_single(-1.0), 0.0);
        assert_close(Activation::Sqrt.apply_single(0.0), 0.0);
        assert_close(Activation::Sqrt.apply_single(9.0), 3.0);
    }

    #[test]
    fn test_log() {
        assert_close(Activation::Log.apply_single(1.0), 2.0_f32.ln());
        assert_close(Activation::Log.apply_single(0.0), 0.0);
        assert_close(Activation::Log.apply_single(9.0), 10.0_f32.ln());
        // Non-positive inputs are clamped to 0 rather than evaluating ln(x + 1).
        assert_close(Activation::Log.apply_single(-0.5), 0.0);
    }

    #[test]
    fn test_log10() {
        assert_close(Activation::Log10.apply_single(9.0), 1.0);
        assert_close(Activation::Log10.apply_single(0.0), 0.0);
        assert_close(Activation::Log10.apply_single(99.0), 2.0);
        assert_close(Activation::Log10.apply_single(-5.0), 0.0);
    }

    #[test]
    fn test_tanh() {
        assert_close(Activation::Tanh.apply_single(0.0), 0.0);
        assert_close(Activation::Tanh.apply_single(1.0), 1.0_f32.tanh());
        assert_close(Activation::Tanh.apply_single(-1.0), (-1.0_f32).tanh());
    }

    #[test]
    fn test_inverse() {
        assert_close(Activation::Inverse.apply_single(1.0), 0.0);
        assert_close(Activation::Inverse.apply_single(0.0), 1.0);
        assert_close(Activation::Inverse.apply_single(-1.0), 2.0);
    }

    #[test]
    fn test_get_by_name() {
        let expected = [
            ("RELU", Activation::Relu),
            ("SIGMOID", Activation::Sigmoid),
            ("SOFTMAX", Activation::Softmax),
            ("SQRT", Activation::Sqrt),
            ("LOG", Activation::Log),
            ("LOG10", Activation::Log10),
            ("TANH", Activation::Tanh),
            ("INVERSE", Activation::Inverse),
        ];
        for &(name, activation) in &expected {
            assert_eq!(Activation::get_by_name(name), Some(activation));
        }
        assert_eq!(Activation::get_by_name("INVALID"), None);
        // Lookup is exact and case-sensitive.
        assert_eq!(Activation::get_by_name("relu"), None);
        assert_eq!(Activation::get_by_name(""), None);
    }
}