1use crate::{Error, Result};
11
/// Element-wise activation functions available to the network.
///
/// All variants are parameter-free except [`Activation::LeakyReLU`], which
/// carries its negative-side slope `alpha` (checked by `validate` to be
/// finite and non-negative).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Activation {
    /// Hyperbolic tangent; outputs lie in (-1, 1).
    Tanh,
    /// Rectified linear unit: `max(x, 0)`.
    ReLU,
    /// Leaky ReLU: `x` for positive inputs, `alpha * x` otherwise.
    LeakyReLU { alpha: f32 },
    /// Logistic sigmoid; outputs lie in (0, 1).
    Sigmoid,
    /// Pass-through: returns the input unchanged.
    Identity,
}
21
22impl Activation {
23 pub fn validate(self) -> Result<()> {
25 match self {
26 Activation::LeakyReLU { alpha } => {
27 if !(alpha.is_finite() && alpha >= 0.0) {
28 return Err(Error::InvalidConfig(format!(
29 "leaky ReLU alpha must be finite and >= 0, got {alpha}"
30 )));
31 }
32 }
33 Activation::Tanh | Activation::ReLU | Activation::Sigmoid | Activation::Identity => {}
34 }
35
36 Ok(())
37 }
38
39 #[inline]
40 pub(crate) fn forward(self, x: f32) -> f32 {
41 match self {
42 Activation::Tanh => x.tanh(),
43 Activation::ReLU => x.max(0.0),
44 Activation::LeakyReLU { alpha } => {
45 if x > 0.0 {
46 x
47 } else {
48 alpha * x
49 }
50 }
51 Activation::Sigmoid => sigmoid(x),
52 Activation::Identity => x,
53 }
54 }
55
56 #[inline]
59 pub(crate) fn grad_from_output(self, y: f32) -> f32 {
60 match self {
61 Activation::Tanh => 1.0 - y * y,
62 Activation::ReLU => {
63 if y > 0.0 {
64 1.0
65 } else {
66 0.0
67 }
68 }
69 Activation::LeakyReLU { alpha } => {
70 if y > 0.0 {
71 1.0
72 } else {
73 alpha
74 }
75 }
76 Activation::Sigmoid => y * (1.0 - y),
77 Activation::Identity => 1.0,
78 }
79 }
80}
81
/// Numerically stable logistic function `1 / (1 + e^-x)`.
///
/// Branches on the sign of `x` so that the exponential is only ever taken
/// of a non-positive argument; `exp` of a value <= 0 is at most 1 and can
/// never overflow to infinity.
#[inline]
fn sigmoid(x: f32) -> f32 {
    if x < 0.0 {
        // Here e = e^x is in (0, 1], so the denominator stays in (1, 2].
        let e = x.exp();
        e / (1.0 + e)
    } else {
        let e = (-x).exp();
        1.0 / (1.0 + e)
    }
}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn leaky_relu_alpha_must_be_finite_and_non_negative() {
        let nan = Activation::LeakyReLU { alpha: f32::NAN };
        let negative = Activation::LeakyReLU { alpha: -0.1 };
        let valid = Activation::LeakyReLU { alpha: 0.1 };

        assert!(nan.validate().is_err());
        assert!(negative.validate().is_err());
        assert!(valid.validate().is_ok());
    }

    #[test]
    fn sigmoid_basic_values() {
        // Midpoint and saturation on both sides.
        assert!((Activation::Sigmoid.forward(0.0) - 0.5).abs() < 1e-6);
        assert!(Activation::Sigmoid.forward(10.0) > 0.999);
        assert!(Activation::Sigmoid.forward(-10.0) < 0.001);
    }

    #[test]
    fn relu_and_leaky_relu_shapes() {
        let relu = Activation::ReLU;
        let leaky = Activation::LeakyReLU { alpha: 0.1 };

        assert_eq!(relu.forward(3.0), 3.0);
        assert_eq!(relu.forward(-2.0), 0.0);
        assert_eq!(leaky.forward(3.0), 3.0);
        assert_eq!(leaky.forward(-2.0), -0.2);

        assert_eq!(relu.grad_from_output(1.0), 1.0);
        assert_eq!(relu.grad_from_output(0.0), 0.0);
        assert_eq!(leaky.grad_from_output(3.0), 1.0);
        assert_eq!(leaky.grad_from_output(-0.2), 0.1);
    }

    #[test]
    fn tanh_and_sigmoid_gradients_from_output() {
        let y = Activation::Tanh.forward(0.3);
        assert!((Activation::Tanh.grad_from_output(y) - (1.0 - y * y)).abs() < 1e-6);

        // sigmoid(0) = 0.5, so the gradient there is 0.5 * 0.5 = 0.25.
        let y = Activation::Sigmoid.forward(0.0);
        assert!((Activation::Sigmoid.grad_from_output(y) - 0.25).abs() < 1e-6);
    }
}