1use burn_core as burn;
2
3use crate::activation::{
4 Gelu, HardSigmoid, HardSigmoidConfig, HardSwish, LeakyRelu, LeakyReluConfig, PRelu,
5 PReluConfig, Relu, Sigmoid, Softplus, SoftplusConfig, SwiGlu, SwiGluConfig, Tanh,
6};
7use burn::config::Config;
8use burn::module::Module;
9use burn::tensor::Tensor;
10use burn::tensor::backend::Backend;
11
/// Configuration selecting one of the activation layers wrapped by [`Activation`].
///
/// Variants without a payload correspond to layers that take no configuration;
/// the remaining variants wrap the configuration of the corresponding layer.
/// Build the layer with [`ActivationConfig::init`].
#[derive(Config, Debug)]
#[non_exhaustive]
pub enum ActivationConfig {
    /// [`Gelu`] activation layer.
    Gelu,

    /// [`PRelu`] activation layer, configured by [`PReluConfig`].
    PRelu(PReluConfig),

    /// [`Relu`] activation layer.
    Relu,

    /// [`LeakyRelu`] activation layer, configured by [`LeakyReluConfig`].
    LeakyRelu(LeakyReluConfig),

    /// [`SwiGlu`] activation layer, configured by [`SwiGluConfig`].
    SwiGlu(SwiGluConfig),

    /// [`Sigmoid`] activation layer.
    Sigmoid,

    /// [`Tanh`] activation layer.
    Tanh,

    /// [`HardSigmoid`] activation layer, configured by [`HardSigmoidConfig`].
    HardSigmoid(HardSigmoidConfig),

    /// [`HardSwish`] activation layer.
    HardSwish,

    /// [`Softplus`] activation layer, configured by [`SoftplusConfig`].
    Softplus(SoftplusConfig),
}
46
47impl From<PReluConfig> for ActivationConfig {
48 fn from(config: PReluConfig) -> Self {
49 Self::PRelu(config)
50 }
51}
52
53impl From<LeakyReluConfig> for ActivationConfig {
54 fn from(config: LeakyReluConfig) -> Self {
55 Self::LeakyRelu(config)
56 }
57}
58
59impl From<SwiGluConfig> for ActivationConfig {
60 fn from(config: SwiGluConfig) -> Self {
61 Self::SwiGlu(config)
62 }
63}
64
65impl From<HardSigmoidConfig> for ActivationConfig {
66 fn from(config: HardSigmoidConfig) -> Self {
67 Self::HardSigmoid(config)
68 }
69}
70
71impl From<SoftplusConfig> for ActivationConfig {
72 fn from(config: SoftplusConfig) -> Self {
73 Self::Softplus(config)
74 }
75}
76
77impl ActivationConfig {
78 pub fn init<B: Backend>(&self, device: &B::Device) -> Activation<B> {
80 match self {
81 ActivationConfig::Relu => Relu.into(),
82 ActivationConfig::LeakyRelu(conf) => conf.init().into(),
83 ActivationConfig::Gelu => Gelu.into(),
84 ActivationConfig::PRelu(conf) => conf.init(device).into(),
85 ActivationConfig::SwiGlu(conf) => conf.init(device).into(),
86 ActivationConfig::HardSigmoid(conf) => conf.init().into(),
87 ActivationConfig::HardSwish => HardSwish.into(),
88 ActivationConfig::Softplus(conf) => conf.init().into(),
89 ActivationConfig::Sigmoid => Sigmoid.into(),
90 ActivationConfig::Tanh => Tanh.into(),
91 }
92 }
93}
94
/// An activation layer wrapping one of the supported activation modules.
///
/// Usually built from an [`ActivationConfig`] via [`ActivationConfig::init`],
/// and applied with [`Activation::forward`].
#[derive(Module, Debug)]
#[non_exhaustive]
// `PRelu<B>` and `SwiGlu<B>` are backend-generic and initialised on a device,
// so they presumably hold parameter tensors and are much larger than the
// unit-struct variants. The size disparity is accepted rather than boxing those
// variants, which would change their public payload types.
// NOTE(review): confirm these are the variants clippy flags.
#[allow(clippy::large_enum_variant)]
pub enum Activation<B: Backend> {
    /// [`Gelu`] activation layer.
    Gelu(Gelu),

    /// [`PRelu`] activation layer.
    PRelu(PRelu<B>),

    /// [`Relu`] activation layer.
    Relu(Relu),

    /// [`LeakyRelu`] activation layer.
    LeakyRelu(LeakyRelu),

    /// [`SwiGlu`] activation layer.
    SwiGlu(SwiGlu<B>),

    /// [`Sigmoid`] activation layer.
    Sigmoid(Sigmoid),

    /// [`Tanh`] activation layer.
    Tanh(Tanh),

    /// [`HardSigmoid`] activation layer.
    HardSigmoid(HardSigmoid),

    /// [`HardSwish`] activation layer.
    HardSwish(HardSwish),

    /// [`Softplus`] activation layer.
    Softplus(Softplus),
}
132
133impl<B: Backend> From<Gelu> for Activation<B> {
134 fn from(layer: Gelu) -> Self {
135 Self::Gelu(layer)
136 }
137}
138
139impl<B: Backend> From<PRelu<B>> for Activation<B> {
140 fn from(layer: PRelu<B>) -> Self {
141 Self::PRelu(layer)
142 }
143}
144
145impl<B: Backend> From<Relu> for Activation<B> {
146 fn from(layer: Relu) -> Self {
147 Self::Relu(layer)
148 }
149}
150
151impl<B: Backend> From<LeakyRelu> for Activation<B> {
152 fn from(layer: LeakyRelu) -> Self {
153 Self::LeakyRelu(layer)
154 }
155}
156
157impl<B: Backend> From<SwiGlu<B>> for Activation<B> {
158 fn from(layer: SwiGlu<B>) -> Self {
159 Self::SwiGlu(layer)
160 }
161}
162
163impl<B: Backend> From<Sigmoid> for Activation<B> {
164 fn from(layer: Sigmoid) -> Self {
165 Self::Sigmoid(layer)
166 }
167}
168
169impl<B: Backend> From<Tanh> for Activation<B> {
170 fn from(layer: Tanh) -> Self {
171 Self::Tanh(layer)
172 }
173}
174
175impl<B: Backend> From<HardSigmoid> for Activation<B> {
176 fn from(layer: HardSigmoid) -> Self {
177 Self::HardSigmoid(layer)
178 }
179}
180
181impl<B: Backend> From<HardSwish> for Activation<B> {
182 fn from(layer: HardSwish) -> Self {
183 Self::HardSwish(layer)
184 }
185}
186
187impl<B: Backend> From<Softplus> for Activation<B> {
188 fn from(layer: Softplus) -> Self {
189 Self::Softplus(layer)
190 }
191}
192
193impl<B: Backend> Activation<B> {
194 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
196 match self {
197 Activation::Relu(layer) => layer.forward(input),
198 Activation::LeakyRelu(layer) => layer.forward(input),
199 Activation::Gelu(layer) => layer.forward(input),
200 Activation::PRelu(layer) => layer.forward(input),
201 Activation::SwiGlu(layer) => layer.forward(input),
202 Activation::HardSigmoid(layer) => layer.forward(input),
203 Activation::HardSwish(layer) => layer.forward(input),
204 Activation::Softplus(layer) => layer.forward(input),
205 Activation::Sigmoid(layer) => layer.forward(input),
206 Activation::Tanh(layer) => layer.forward(input),
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::module::Module;

    /// Small 2x3 input covering negative, zero and positive values.
    fn make_input<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::from_data([[-1.0, -0.5, 0.0], [1.0, 0.5, 0.0]], device)
    }

    /// Asserts that `actual` and `expected` hold identical data.
    fn expect_tensor<B: Backend, const D: usize>(actual: Tensor<B, D>, expected: Tensor<B, D>) {
        actual.to_data().assert_eq(&expected.to_data(), true);
    }

    /// Builds the layer from `config`, runs it on `input` and compares the result
    /// with `expected`.
    ///
    /// `expected` is computed by a separately constructed layer. This is only
    /// valid when construction is deterministic, i.e. the layer has no state or
    /// its initial state does not vary between `init` calls.
    fn check_stateless_config_output<B: Backend, const D: usize>(
        config: ActivationConfig,
        input: Tensor<B, D>,
        expected: Tensor<B, D>,
        device: &B::Device,
    ) {
        let act = config.init(device);
        let output = act.forward(input);
        expect_tensor(output, expected);
    }

    #[test]
    fn test_gelu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Gelu.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Gelu, input, expected, &device)
    }

    #[test]
    fn test_prelu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = PReluConfig::new();
        let expected = inner_config.init(&device).forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Relu.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Relu, input, expected, &device)
    }

    #[test]
    fn test_leaky_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = LeakyReluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_swi_glu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let d_input = input.shape().dims[1];
        let d_output = 2 * d_input;

        let inner_config = SwiGluConfig::new(d_input, d_output);
        let reference: SwiGlu<TestBackend> = inner_config.init(&device);

        let config: ActivationConfig = inner_config.into();
        let layer = config.init(&device);

        // SwiGlu owns parameters that may be initialised differently on each
        // `init`. Copy the wrapped layer's record into the reference so both
        // layers compute with the same weights.
        let reference = match &layer {
            Activation::SwiGlu(inner) => reference.load_record(inner.clone().into_record()),
            _ => unreachable!("ActivationConfig::SwiGlu must build Activation::SwiGlu"),
        };

        expect_tensor(layer.forward(input.clone()), reference.forward(input))
    }

    #[test]
    fn test_sigmoid() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Sigmoid.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Sigmoid, input, expected, &device)
    }

    #[test]
    fn test_tanh() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Tanh.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Tanh, input, expected, &device)
    }

    #[test]
    fn test_hard_sigmoid() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = HardSigmoidConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    // `ActivationConfig::HardSwish` was previously untested.
    #[test]
    fn test_hard_swish() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = HardSwish.forward(input.clone());

        check_stateless_config_output(ActivationConfig::HardSwish, input, expected, &device)
    }

    #[test]
    fn test_softplus() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = SoftplusConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }
}
348}