/// Element-wise activation functions applied to layer pre-activations.
///
/// Marked `#[non_exhaustive]` so new variants can be added without a breaking
/// change; serde support is gated behind the optional `serde` feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Activation {
    /// Rectified linear unit: `max(0, z)`. The default activation.
    #[default]
    Relu,
    /// Logistic function `1 / (1 + e^-z)`; squashes values into (0, 1).
    Sigmoid,
    /// Hyperbolic tangent; squashes values into (-1, 1).
    Tanh,
    /// Pass-through (no nonlinearity).
    Identity,
}
impl Activation {
    /// Applies the activation function element-wise, in place, to the
    /// pre-activations `z`.
    pub fn forward(&self, z: &mut [f64]) {
        match self {
            Self::Relu => {
                for v in z.iter_mut() {
                    if *v < 0.0 {
                        *v = 0.0;
                    }
                }
            }
            Self::Sigmoid => {
                for v in z.iter_mut() {
                    *v = sigmoid(*v);
                }
            }
            Self::Tanh => {
                for v in z.iter_mut() {
                    *v = v.tanh();
                }
            }
            // Identity leaves the values untouched.
            Self::Identity => {}
        }
    }

    /// Multiplies `grad_out` in place by the activation's derivative.
    ///
    /// `z` holds the pre-activations and `activated` the post-activations.
    /// Sigmoid and Tanh derive their gradient from `activated` (cheaper than
    /// re-evaluating the function), while ReLU gates on the sign of `z`.
    ///
    /// `z` and `activated` must be at least as long as `grad_out`
    /// (checked in debug builds).
    pub fn backward_from_activated(&self, z: &[f64], activated: &[f64], grad_out: &mut [f64]) {
        debug_assert!(z.len() >= grad_out.len(), "z shorter than grad_out");
        debug_assert!(
            activated.len() >= grad_out.len(),
            "activated shorter than grad_out"
        );
        match self {
            Self::Relu => {
                // d/dz relu(z) = 0 for z <= 0, 1 otherwise (subgradient 0 at z = 0).
                for (g, &zi) in grad_out.iter_mut().zip(z) {
                    if zi <= 0.0 {
                        *g = 0.0;
                    }
                }
            }
            Self::Sigmoid => {
                // sigma'(z) = sigma(z) * (1 - sigma(z)) = a * (1 - a).
                for (g, &a) in grad_out.iter_mut().zip(activated) {
                    *g *= a * (1.0 - a);
                }
            }
            Self::Tanh => {
                // tanh'(z) = 1 - tanh(z)^2 = 1 - a^2.
                for (g, &a) in grad_out.iter_mut().zip(activated) {
                    *g *= 1.0 - a * a;
                }
            }
            // Identity has derivative 1: the gradient passes through unchanged.
            Self::Identity => {}
        }
    }

    /// Whether layers using this activation should be initialized with
    /// He initialization (only ReLU here); others presumably get a
    /// Xavier-style scheme elsewhere.
    pub(crate) fn uses_he_init(self) -> bool {
        matches!(self, Self::Relu)
    }

    /// Maps this activation to its GPU-side counterpart.
    pub(crate) fn to_gpu(self) -> crate::accel::GpuActivation {
        match self {
            Self::Relu => crate::accel::GpuActivation::Relu,
            Self::Sigmoid => crate::accel::GpuActivation::Sigmoid,
            Self::Tanh => crate::accel::GpuActivation::Tanh,
            Self::Identity => crate::accel::GpuActivation::Identity,
        }
    }

    /// GPU forward pass: dispatches to the backend kernel for this activation.
    /// Identity returns the input tensor unchanged (no kernel launch).
    #[allow(dead_code)]
    pub(crate) fn forward_gpu(
        self,
        z: crate::accel::GpuTensor,
        backend: &dyn crate::accel::ComputeBackend,
    ) -> crate::accel::GpuTensor {
        match self {
            Self::Relu => backend.gpu_relu(&z),
            Self::Sigmoid => backend.gpu_sigmoid(&z),
            Self::Tanh => backend.gpu_tanh(&z),
            Self::Identity => z,
        }
    }

    /// GPU backward pass. Mirrors `backward_from_activated`: ReLU gates on
    /// the pre-activations `z`, Sigmoid/Tanh use the cached `activated`
    /// tensor, and Identity passes the gradient through unchanged.
    #[allow(dead_code)]
    pub(crate) fn backward_gpu(
        self,
        grad: crate::accel::GpuTensor,
        z: &crate::accel::GpuTensor,
        activated: &crate::accel::GpuTensor,
        backend: &dyn crate::accel::ComputeBackend,
    ) -> crate::accel::GpuTensor {
        match self {
            Self::Relu => backend.gpu_relu_backward(&grad, z),
            Self::Sigmoid => backend.gpu_sigmoid_backward(&grad, activated),
            Self::Tanh => backend.gpu_tanh_backward(&grad, activated),
            Self::Identity => grad,
        }
    }
}
/// Numerically stable logistic function.
///
/// `exp` is only ever evaluated on a non-positive argument, so the
/// intermediate lies in (0, 1] and can never overflow: large positive `x`
/// yields ~1.0 and large negative `x` underflows gracefully toward 0.0.
#[inline]
fn sigmoid(x: f64) -> f64 {
    // exp(-|x|) collapses both branches' exponentials into one expression.
    let e = (-x.abs()).exp();
    if x >= 0.0 {
        1.0 / (1.0 + e)
    } else {
        e / (1.0 + e)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn relu_forward() {
        // Negatives clamp to zero; zero and positives pass through.
        let mut vals = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        Activation::Relu.forward(&mut vals);
        assert_eq!(vals, [0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn relu_backward() {
        // Gradient is zeroed wherever the pre-activation was <= 0.
        let pre = vec![-2.0, 0.0, 1.0, 3.0];
        let post = vec![0.0, 0.0, 1.0, 3.0];
        let mut g = vec![1.0; 4];
        Activation::Relu.backward_from_activated(&pre, &post, &mut g);
        assert_eq!(g, [0.0, 0.0, 1.0, 1.0]);
    }

    #[test]
    fn sigmoid_forward() {
        // sigma(0) = 0.5; saturates to 1 and 0 at the extremes.
        for (input, expected) in [(0.0, 0.5), (100.0, 1.0), (-100.0, 0.0)] {
            let mut vals = vec![input];
            Activation::Sigmoid.forward(&mut vals);
            assert!((vals[0] - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn sigmoid_backward() {
        // sigma'(0) = 0.5 * (1 - 0.5) = 0.25.
        let mut g = vec![1.0];
        Activation::Sigmoid.backward_from_activated(&[0.0], &[0.5], &mut g);
        assert!((g[0] - 0.25).abs() < 1e-10);
    }

    #[test]
    fn tanh_forward() {
        let mut vals = vec![0.0];
        Activation::Tanh.forward(&mut vals);
        assert!(vals[0].abs() < 1e-10);
    }

    #[test]
    fn tanh_backward() {
        // tanh'(0) = 1 - 0^2 = 1.
        let mut g = vec![1.0];
        Activation::Tanh.backward_from_activated(&[0.0], &[0.0], &mut g);
        assert!((g[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn identity_is_noop() {
        // Identity must touch neither the forward values nor the gradient.
        let mut vals = vec![1.0, -2.0, 3.0];
        let before = vals.clone();
        Activation::Identity.forward(&mut vals);
        assert_eq!(vals, before);

        let mut g = vec![1.0, 2.0, 3.0];
        let g_before = g.clone();
        Activation::Identity.backward_from_activated(&vals, &vals, &mut g);
        assert_eq!(g, g_before);
    }

    #[test]
    fn sigmoid_numerical_stability() {
        // Deep in the negative tail exp(x) underflows; the result must stay
        // finite and non-negative rather than producing NaN or -inf.
        let mut vals = vec![-750.0];
        Activation::Sigmoid.forward(&mut vals);
        assert!(vals[0].is_finite());
        assert!(vals[0] >= 0.0);
    }
}