1use burn_core as burn;
2
3use crate::activation::{
4 Celu, CeluConfig, Elu, EluConfig, Gelu, HardShrink, HardShrinkConfig, HardSigmoid,
5 HardSigmoidConfig, HardSwish, LeakyRelu, LeakyReluConfig, PRelu, PReluConfig, Relu, Selu,
6 Shrink, ShrinkConfig, Sigmoid, SoftShrink, SoftShrinkConfig, Softplus, SoftplusConfig,
7 Softsign, SwiGlu, SwiGluConfig, Tanh, ThresholdedRelu, ThresholdedReluConfig,
8};
9use burn::config::Config;
10use burn::module::Module;
11use burn::tensor::Tensor;
12use burn::tensor::backend::Backend;
13
/// Configuration that selects one activation function.
///
/// Unit variants name activations that need no settings; tuple variants carry
/// that activation's own config. Build the layer with [`ActivationConfig::init`].
#[derive(Config, Debug)]
// `non_exhaustive`: new activations can be added without breaking downstream matches.
#[non_exhaustive]
pub enum ActivationConfig {
    /// [`Gelu`] activation, built with `Gelu::new()`.
    Gelu,

    /// [`Gelu`] activation, built with `Gelu::new_approximate()`.
    GeluApproximate,

    /// [`PRelu`] activation; its config is initialized on the target device.
    PRelu(PReluConfig),

    /// [`Relu`] activation.
    Relu,

    /// [`LeakyRelu`] activation.
    LeakyRelu(LeakyReluConfig),

    /// [`SwiGlu`] activation; its config is initialized on the target device.
    SwiGlu(SwiGluConfig),

    /// [`Selu`] activation.
    Selu,

    /// [`Sigmoid`] activation.
    Sigmoid,

    /// [`Tanh`] activation.
    Tanh,

    /// [`HardSigmoid`] activation.
    HardSigmoid(HardSigmoidConfig),

    /// [`HardSwish`] activation.
    HardSwish,

    /// [`Softplus`] activation.
    Softplus(SoftplusConfig),

    /// [`Softsign`] activation.
    Softsign,

    /// [`Elu`] activation.
    Elu(EluConfig),

    /// [`Celu`] activation.
    Celu(CeluConfig),

    /// [`ThresholdedRelu`] activation.
    ThresholdedRelu(ThresholdedReluConfig),

    /// [`HardShrink`] activation.
    HardShrink(HardShrinkConfig),

    /// [`SoftShrink`] activation.
    SoftShrink(SoftShrinkConfig),

    /// [`Shrink`] activation.
    Shrink(ShrinkConfig),
}
75
76impl From<PReluConfig> for ActivationConfig {
77 fn from(config: PReluConfig) -> Self {
78 Self::PRelu(config)
79 }
80}
81
82impl From<LeakyReluConfig> for ActivationConfig {
83 fn from(config: LeakyReluConfig) -> Self {
84 Self::LeakyRelu(config)
85 }
86}
87
88impl From<SwiGluConfig> for ActivationConfig {
89 fn from(config: SwiGluConfig) -> Self {
90 Self::SwiGlu(config)
91 }
92}
93
94impl From<HardSigmoidConfig> for ActivationConfig {
95 fn from(config: HardSigmoidConfig) -> Self {
96 Self::HardSigmoid(config)
97 }
98}
99
100impl From<SoftplusConfig> for ActivationConfig {
101 fn from(config: SoftplusConfig) -> Self {
102 Self::Softplus(config)
103 }
104}
105
106impl From<EluConfig> for ActivationConfig {
107 fn from(config: EluConfig) -> Self {
108 Self::Elu(config)
109 }
110}
111
112impl From<CeluConfig> for ActivationConfig {
113 fn from(config: CeluConfig) -> Self {
114 Self::Celu(config)
115 }
116}
117
118impl From<ThresholdedReluConfig> for ActivationConfig {
119 fn from(config: ThresholdedReluConfig) -> Self {
120 Self::ThresholdedRelu(config)
121 }
122}
123
124impl From<HardShrinkConfig> for ActivationConfig {
125 fn from(config: HardShrinkConfig) -> Self {
126 Self::HardShrink(config)
127 }
128}
129
130impl From<SoftShrinkConfig> for ActivationConfig {
131 fn from(config: SoftShrinkConfig) -> Self {
132 Self::SoftShrink(config)
133 }
134}
135
136impl From<ShrinkConfig> for ActivationConfig {
137 fn from(config: ShrinkConfig) -> Self {
138 Self::Shrink(config)
139 }
140}
141
142impl ActivationConfig {
143 pub fn init<B: Backend>(&self, device: &B::Device) -> Activation<B> {
145 match self {
146 ActivationConfig::Relu => Relu.into(),
147 ActivationConfig::LeakyRelu(conf) => conf.init().into(),
148 ActivationConfig::Gelu => Gelu::new().into(),
149 ActivationConfig::GeluApproximate => Gelu::new_approximate().into(),
150 ActivationConfig::PRelu(conf) => conf.init(device).into(),
151 ActivationConfig::SwiGlu(conf) => conf.init(device).into(),
152 ActivationConfig::HardSigmoid(conf) => conf.init().into(),
153 ActivationConfig::HardSwish => HardSwish.into(),
154 ActivationConfig::Softplus(conf) => conf.init().into(),
155 ActivationConfig::Selu => Selu.into(),
156 ActivationConfig::Sigmoid => Sigmoid.into(),
157 ActivationConfig::Tanh => Tanh.into(),
158 ActivationConfig::Softsign => Softsign.into(),
159 ActivationConfig::Elu(conf) => conf.init().into(),
160 ActivationConfig::Celu(conf) => conf.init().into(),
161 ActivationConfig::HardShrink(conf) => conf.init().into(),
162 ActivationConfig::SoftShrink(conf) => conf.init().into(),
163 ActivationConfig::Shrink(conf) => conf.init().into(),
164 ActivationConfig::ThresholdedRelu(conf) => conf.init().into(),
165 }
166 }
167}
168
/// An activation layer chosen at runtime, wrapping exactly one concrete activation module.
///
/// Usually built with [`ActivationConfig::init`] and applied with [`Activation::forward`].
#[derive(Module, Debug)]
// `non_exhaustive`: new activations can be added without breaking downstream matches.
#[non_exhaustive]
// NOTE(review): presumably allowed because the backend-generic variants
// (`PRelu<B>`, `SwiGlu<B>`) are much larger than the unit-like ones and boxing
// them is not wanted here — confirm.
#[allow(clippy::large_enum_variant)]
pub enum Activation<B: Backend> {
    /// [`Gelu`] layer.
    Gelu(Gelu),

    /// [`PRelu`] layer; generic over the backend `B`.
    PRelu(PRelu<B>),

    /// [`Relu`] layer.
    Relu(Relu),

    /// [`LeakyRelu`] layer.
    LeakyRelu(LeakyRelu),

    /// [`SwiGlu`] layer; generic over the backend `B`.
    SwiGlu(SwiGlu<B>),

    /// [`Selu`] layer.
    Selu(Selu),

    /// [`Sigmoid`] layer.
    Sigmoid(Sigmoid),

    /// [`Tanh`] layer.
    Tanh(Tanh),

    /// [`HardSigmoid`] layer.
    HardSigmoid(HardSigmoid),

    /// [`HardSwish`] layer.
    HardSwish(HardSwish),

    /// [`Softplus`] layer.
    Softplus(Softplus),

    /// [`Softsign`] layer.
    Softsign(Softsign),

    /// [`Elu`] layer.
    Elu(Elu),

    /// [`Celu`] layer.
    Celu(Celu),

    /// [`ThresholdedRelu`] layer.
    ThresholdedRelu(ThresholdedRelu),

    /// [`HardShrink`] layer.
    HardShrink(HardShrink),

    /// [`SoftShrink`] layer.
    SoftShrink(SoftShrink),

    /// [`Shrink`] layer.
    Shrink(Shrink),
}
230
231impl<B: Backend> From<Gelu> for Activation<B> {
232 fn from(layer: Gelu) -> Self {
233 Self::Gelu(layer)
234 }
235}
236
237impl<B: Backend> From<PRelu<B>> for Activation<B> {
238 fn from(layer: PRelu<B>) -> Self {
239 Self::PRelu(layer)
240 }
241}
242
243impl<B: Backend> From<Relu> for Activation<B> {
244 fn from(layer: Relu) -> Self {
245 Self::Relu(layer)
246 }
247}
248
249impl<B: Backend> From<LeakyRelu> for Activation<B> {
250 fn from(layer: LeakyRelu) -> Self {
251 Self::LeakyRelu(layer)
252 }
253}
254
255impl<B: Backend> From<SwiGlu<B>> for Activation<B> {
256 fn from(layer: SwiGlu<B>) -> Self {
257 Self::SwiGlu(layer)
258 }
259}
260
261impl<B: Backend> From<Selu> for Activation<B> {
262 fn from(layer: Selu) -> Self {
263 Self::Selu(layer)
264 }
265}
266
267impl<B: Backend> From<Sigmoid> for Activation<B> {
268 fn from(layer: Sigmoid) -> Self {
269 Self::Sigmoid(layer)
270 }
271}
272
273impl<B: Backend> From<Tanh> for Activation<B> {
274 fn from(layer: Tanh) -> Self {
275 Self::Tanh(layer)
276 }
277}
278
279impl<B: Backend> From<HardSigmoid> for Activation<B> {
280 fn from(layer: HardSigmoid) -> Self {
281 Self::HardSigmoid(layer)
282 }
283}
284
285impl<B: Backend> From<HardSwish> for Activation<B> {
286 fn from(layer: HardSwish) -> Self {
287 Self::HardSwish(layer)
288 }
289}
290
291impl<B: Backend> From<Softplus> for Activation<B> {
292 fn from(layer: Softplus) -> Self {
293 Self::Softplus(layer)
294 }
295}
296
297impl<B: Backend> From<Softsign> for Activation<B> {
298 fn from(layer: Softsign) -> Self {
299 Self::Softsign(layer)
300 }
301}
302
303impl<B: Backend> From<Elu> for Activation<B> {
304 fn from(layer: Elu) -> Self {
305 Self::Elu(layer)
306 }
307}
308
309impl<B: Backend> From<Celu> for Activation<B> {
310 fn from(layer: Celu) -> Self {
311 Self::Celu(layer)
312 }
313}
314
315impl<B: Backend> From<ThresholdedRelu> for Activation<B> {
316 fn from(layer: ThresholdedRelu) -> Self {
317 Self::ThresholdedRelu(layer)
318 }
319}
320
321impl<B: Backend> From<HardShrink> for Activation<B> {
322 fn from(layer: HardShrink) -> Self {
323 Self::HardShrink(layer)
324 }
325}
326
327impl<B: Backend> From<SoftShrink> for Activation<B> {
328 fn from(layer: SoftShrink) -> Self {
329 Self::SoftShrink(layer)
330 }
331}
332
333impl<B: Backend> From<Shrink> for Activation<B> {
334 fn from(layer: Shrink) -> Self {
335 Self::Shrink(layer)
336 }
337}
338
339impl<B: Backend> Activation<B> {
340 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
342 match self {
343 Activation::Relu(layer) => layer.forward(input),
344 Activation::LeakyRelu(layer) => layer.forward(input),
345 Activation::Gelu(layer) => layer.forward(input),
346 Activation::PRelu(layer) => layer.forward(input),
347 Activation::SwiGlu(layer) => layer.forward(input),
348 Activation::HardSigmoid(layer) => layer.forward(input),
349 Activation::HardSwish(layer) => layer.forward(input),
350 Activation::Softplus(layer) => layer.forward(input),
351 Activation::Selu(layer) => layer.forward(input),
352 Activation::Sigmoid(layer) => layer.forward(input),
353 Activation::Tanh(layer) => layer.forward(input),
354 Activation::Softsign(layer) => layer.forward(input),
355 Activation::Elu(layer) => layer.forward(input),
356 Activation::Celu(layer) => layer.forward(input),
357 Activation::ThresholdedRelu(layer) => layer.forward(input),
358 Activation::HardShrink(layer) => layer.forward(input),
359 Activation::SoftShrink(layer) => layer.forward(input),
360 Activation::Shrink(layer) => layer.forward(input),
361 }
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::module::Module;

    /// 2x3 input mixing negative, zero and positive values.
    fn make_input<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::from_data([[-1.0, -0.5, 0.0], [1.0, 0.5, 0.0]], device)
    }

    /// Asserts that `actual` and `expected` hold equal data.
    fn expect_tensor<B: Backend, const D: usize>(actual: Tensor<B, D>, expected: Tensor<B, D>) {
        actual.to_data().assert_eq(&expected.to_data(), true);
    }

    /// Builds `config` on `device`, runs it on `input`, and compares the result to `expected`.
    ///
    /// Only meaningful when two independent inits of the layer give the same output.
    /// Layers whose parameters may differ between inits (e.g. `SwiGlu`) need their
    /// own test that copies the parameters first.
    fn check_stateless_config_output<B: Backend, const D: usize>(
        config: ActivationConfig,
        input: Tensor<B, D>,
        expected: Tensor<B, D>,
        device: &B::Device,
    ) {
        let act = config.init(device);
        let output = act.forward(input);
        expect_tensor(output, expected);
    }

    #[test]
    fn test_gelu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Gelu::new().forward(input.clone());

        check_stateless_config_output(ActivationConfig::Gelu, input, expected, &device)
    }

    #[test]
    fn test_gelu_approximate() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Gelu::new_approximate().forward(input.clone());

        check_stateless_config_output(ActivationConfig::GeluApproximate, input, expected, &device)
    }

    #[test]
    fn test_prelu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = PReluConfig::new();
        let expected = inner_config.init(&device).forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Relu.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Relu, input, expected, &device)
    }

    #[test]
    fn test_leaky_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = LeakyReluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_swi_glu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let d_input = input.shape()[1];
        let d_output = 2 * d_input;

        let inner_config = SwiGluConfig::new(d_input, d_output);
        let reference: SwiGlu<TestBackend> = inner_config.init(&device);

        let config: ActivationConfig = inner_config.into();
        let layer = config.init(&device);

        let layer_output = layer.forward(input.clone());

        // Two separate `init` calls may yield different weights, so load the
        // wrapped layer's parameters into the reference before comparing.
        let Activation::SwiGlu(inner) = &layer else {
            unreachable!("ActivationConfig::SwiGlu must initialize to Activation::SwiGlu");
        };
        let reference = reference.load_record(inner.clone().into_record());

        expect_tensor(layer_output, reference.forward(input))
    }

    #[test]
    fn test_selu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Selu.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Selu, input, expected, &device)
    }

    #[test]
    fn test_sigmoid() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Sigmoid.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Sigmoid, input, expected, &device)
    }

    #[test]
    fn test_tanh() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Tanh.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Tanh, input, expected, &device)
    }

    #[test]
    fn test_hard_sigmoid() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = HardSigmoidConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    // Previously the only `ActivationConfig` variant without coverage.
    #[test]
    fn test_hard_swish() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = HardSwish.forward(input.clone());

        check_stateless_config_output(ActivationConfig::HardSwish, input, expected, &device)
    }

    #[test]
    fn test_softsign() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Softsign.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Softsign, input, expected, &device)
    }

    #[test]
    fn test_elu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = EluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_softplus() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = SoftplusConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_celu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = CeluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_thresholded_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = ThresholdedReluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_hard_shrink() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = HardShrinkConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_soft_shrink() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = SoftShrinkConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_shrink() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = ShrinkConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }
}