//! Activation layer wrapper (`burn_nn/activation/activation_wrapper.rs`).
//!
//! Provides a single [`Activation`] enum covering the built-in activation
//! layers, plus an [`ActivationConfig`] enum to construct them.
1use burn_core as burn;
2
3use crate::activation::{
4    Celu, CeluConfig, Elu, EluConfig, Gelu, HardShrink, HardShrinkConfig, HardSigmoid,
5    HardSigmoidConfig, HardSwish, LeakyRelu, LeakyReluConfig, PRelu, PReluConfig, Relu, Selu,
6    Shrink, ShrinkConfig, Sigmoid, SoftShrink, SoftShrinkConfig, Softplus, SoftplusConfig,
7    Softsign, SwiGlu, SwiGluConfig, Tanh, ThresholdedRelu, ThresholdedReluConfig,
8};
9use burn::config::Config;
10use burn::module::Module;
11use burn::tensor::Tensor;
12use burn::tensor::backend::Backend;
13
14/// [`Activation`] Configuration.
15#[derive(Config, Debug)]
16#[non_exhaustive]
17pub enum ActivationConfig {
18    /// [`Gelu`] activation layer.
19    Gelu,
20
21    /// [`Gelu`] activation layer with tanh approximation.
22    GeluApproximate,
23
24    /// [`PRelu`] activation layer.
25    PRelu(PReluConfig),
26
27    /// [`Relu`] activation layer.
28    Relu,
29
30    /// [`LeakyRelu`] activation layer.
31    LeakyRelu(LeakyReluConfig),
32
33    /// [`SwiGlu`] activation layer.
34    SwiGlu(SwiGluConfig),
35
36    /// [`Selu`] activation layer.
37    Selu,
38
39    /// [`Sigmoid`] activation layer.
40    Sigmoid,
41
42    /// [`Tanh`] activation layer.
43    Tanh,
44
45    /// [`HardSigmoid`] activation layer.
46    HardSigmoid(HardSigmoidConfig),
47
48    /// [`HardSwish`] activation layer.
49    HardSwish,
50
51    /// [`Softplus`] activation layer.
52    Softplus(SoftplusConfig),
53
54    /// [`Softsign`] activation layer.
55    Softsign,
56
57    /// [`Elu`] activation layer.
58    Elu(EluConfig),
59
60    /// [`Celu`] activation layer.
61    Celu(CeluConfig),
62
63    /// [`ThresholdedRelu`] activation layer.
64    ThresholdedRelu(ThresholdedReluConfig),
65
66    /// [`HardShrink`] activation layer.
67    HardShrink(HardShrinkConfig),
68
69    /// [`SoftShrink`] activation layer.
70    SoftShrink(SoftShrinkConfig),
71
72    /// [`Shrink`] activation layer.
73    Shrink(ShrinkConfig),
74}
75
76impl From<PReluConfig> for ActivationConfig {
77    fn from(config: PReluConfig) -> Self {
78        Self::PRelu(config)
79    }
80}
81
82impl From<LeakyReluConfig> for ActivationConfig {
83    fn from(config: LeakyReluConfig) -> Self {
84        Self::LeakyRelu(config)
85    }
86}
87
88impl From<SwiGluConfig> for ActivationConfig {
89    fn from(config: SwiGluConfig) -> Self {
90        Self::SwiGlu(config)
91    }
92}
93
94impl From<HardSigmoidConfig> for ActivationConfig {
95    fn from(config: HardSigmoidConfig) -> Self {
96        Self::HardSigmoid(config)
97    }
98}
99
100impl From<SoftplusConfig> for ActivationConfig {
101    fn from(config: SoftplusConfig) -> Self {
102        Self::Softplus(config)
103    }
104}
105
106impl From<EluConfig> for ActivationConfig {
107    fn from(config: EluConfig) -> Self {
108        Self::Elu(config)
109    }
110}
111
112impl From<CeluConfig> for ActivationConfig {
113    fn from(config: CeluConfig) -> Self {
114        Self::Celu(config)
115    }
116}
117
118impl From<ThresholdedReluConfig> for ActivationConfig {
119    fn from(config: ThresholdedReluConfig) -> Self {
120        Self::ThresholdedRelu(config)
121    }
122}
123
124impl From<HardShrinkConfig> for ActivationConfig {
125    fn from(config: HardShrinkConfig) -> Self {
126        Self::HardShrink(config)
127    }
128}
129
130impl From<SoftShrinkConfig> for ActivationConfig {
131    fn from(config: SoftShrinkConfig) -> Self {
132        Self::SoftShrink(config)
133    }
134}
135
136impl From<ShrinkConfig> for ActivationConfig {
137    fn from(config: ShrinkConfig) -> Self {
138        Self::Shrink(config)
139    }
140}
141
142impl ActivationConfig {
143    /// Initialize a wrapped activation layer.
144    pub fn init<B: Backend>(&self, device: &B::Device) -> Activation<B> {
145        match self {
146            ActivationConfig::Relu => Relu.into(),
147            ActivationConfig::LeakyRelu(conf) => conf.init().into(),
148            ActivationConfig::Gelu => Gelu::new().into(),
149            ActivationConfig::GeluApproximate => Gelu::new_approximate().into(),
150            ActivationConfig::PRelu(conf) => conf.init(device).into(),
151            ActivationConfig::SwiGlu(conf) => conf.init(device).into(),
152            ActivationConfig::HardSigmoid(conf) => conf.init().into(),
153            ActivationConfig::HardSwish => HardSwish.into(),
154            ActivationConfig::Softplus(conf) => conf.init().into(),
155            ActivationConfig::Selu => Selu.into(),
156            ActivationConfig::Sigmoid => Sigmoid.into(),
157            ActivationConfig::Tanh => Tanh.into(),
158            ActivationConfig::Softsign => Softsign.into(),
159            ActivationConfig::Elu(conf) => conf.init().into(),
160            ActivationConfig::Celu(conf) => conf.init().into(),
161            ActivationConfig::HardShrink(conf) => conf.init().into(),
162            ActivationConfig::SoftShrink(conf) => conf.init().into(),
163            ActivationConfig::Shrink(conf) => conf.init().into(),
164            ActivationConfig::ThresholdedRelu(conf) => conf.init().into(),
165        }
166    }
167}
168
169/// Activation Layer Wrapper.
170///
171/// Provides support for many in-built `burn::nn` activations.
172#[derive(Module, Debug)]
173#[non_exhaustive]
174#[allow(clippy::large_enum_variant)]
175pub enum Activation<B: Backend> {
176    /// [`Gelu`] activation layer.
177    Gelu(Gelu),
178
179    /// [`PRelu`] activation layer.
180    PRelu(PRelu<B>),
181
182    /// [`Relu`] activation layer.
183    Relu(Relu),
184
185    /// [`LeakyRelu`] activation layer.
186    LeakyRelu(LeakyRelu),
187
188    /// [`SwiGlu`] activation layer.
189    SwiGlu(SwiGlu<B>),
190
191    /// [`Selu`] activation layer.
192    Selu(Selu),
193
194    /// [`Sigmoid`] activation layer.
195    Sigmoid(Sigmoid),
196
197    /// [`Tanh`] activation layer.
198    Tanh(Tanh),
199
200    /// [`HardSigmoid`] activation layer.
201    HardSigmoid(HardSigmoid),
202
203    /// [`HardSwish`] activation layer.
204    HardSwish(HardSwish),
205
206    /// [`Softplus`] activation layer.
207    Softplus(Softplus),
208
209    /// [`Softsign`] activation layer.
210    Softsign(Softsign),
211
212    /// [`Elu`] activation layer.
213    Elu(Elu),
214
215    /// [`Celu`] activation layer.
216    Celu(Celu),
217
218    /// [`ThresholdedRelu`] activation layer.
219    ThresholdedRelu(ThresholdedRelu),
220
221    /// [`HardShrink`] activation layer.
222    HardShrink(HardShrink),
223
224    /// [`SoftShrink`] activation layer.
225    SoftShrink(SoftShrink),
226
227    /// [`Shrink`] activation layer.
228    Shrink(Shrink),
229}
230
231impl<B: Backend> From<Gelu> for Activation<B> {
232    fn from(layer: Gelu) -> Self {
233        Self::Gelu(layer)
234    }
235}
236
237impl<B: Backend> From<PRelu<B>> for Activation<B> {
238    fn from(layer: PRelu<B>) -> Self {
239        Self::PRelu(layer)
240    }
241}
242
243impl<B: Backend> From<Relu> for Activation<B> {
244    fn from(layer: Relu) -> Self {
245        Self::Relu(layer)
246    }
247}
248
249impl<B: Backend> From<LeakyRelu> for Activation<B> {
250    fn from(layer: LeakyRelu) -> Self {
251        Self::LeakyRelu(layer)
252    }
253}
254
255impl<B: Backend> From<SwiGlu<B>> for Activation<B> {
256    fn from(layer: SwiGlu<B>) -> Self {
257        Self::SwiGlu(layer)
258    }
259}
260
261impl<B: Backend> From<Selu> for Activation<B> {
262    fn from(layer: Selu) -> Self {
263        Self::Selu(layer)
264    }
265}
266
267impl<B: Backend> From<Sigmoid> for Activation<B> {
268    fn from(layer: Sigmoid) -> Self {
269        Self::Sigmoid(layer)
270    }
271}
272
273impl<B: Backend> From<Tanh> for Activation<B> {
274    fn from(layer: Tanh) -> Self {
275        Self::Tanh(layer)
276    }
277}
278
279impl<B: Backend> From<HardSigmoid> for Activation<B> {
280    fn from(layer: HardSigmoid) -> Self {
281        Self::HardSigmoid(layer)
282    }
283}
284
285impl<B: Backend> From<HardSwish> for Activation<B> {
286    fn from(layer: HardSwish) -> Self {
287        Self::HardSwish(layer)
288    }
289}
290
291impl<B: Backend> From<Softplus> for Activation<B> {
292    fn from(layer: Softplus) -> Self {
293        Self::Softplus(layer)
294    }
295}
296
297impl<B: Backend> From<Softsign> for Activation<B> {
298    fn from(layer: Softsign) -> Self {
299        Self::Softsign(layer)
300    }
301}
302
303impl<B: Backend> From<Elu> for Activation<B> {
304    fn from(layer: Elu) -> Self {
305        Self::Elu(layer)
306    }
307}
308
309impl<B: Backend> From<Celu> for Activation<B> {
310    fn from(layer: Celu) -> Self {
311        Self::Celu(layer)
312    }
313}
314
315impl<B: Backend> From<ThresholdedRelu> for Activation<B> {
316    fn from(layer: ThresholdedRelu) -> Self {
317        Self::ThresholdedRelu(layer)
318    }
319}
320
321impl<B: Backend> From<HardShrink> for Activation<B> {
322    fn from(layer: HardShrink) -> Self {
323        Self::HardShrink(layer)
324    }
325}
326
327impl<B: Backend> From<SoftShrink> for Activation<B> {
328    fn from(layer: SoftShrink) -> Self {
329        Self::SoftShrink(layer)
330    }
331}
332
333impl<B: Backend> From<Shrink> for Activation<B> {
334    fn from(layer: Shrink) -> Self {
335        Self::Shrink(layer)
336    }
337}
338
339impl<B: Backend> Activation<B> {
340    /// Forward pass.
341    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
342        match self {
343            Activation::Relu(layer) => layer.forward(input),
344            Activation::LeakyRelu(layer) => layer.forward(input),
345            Activation::Gelu(layer) => layer.forward(input),
346            Activation::PRelu(layer) => layer.forward(input),
347            Activation::SwiGlu(layer) => layer.forward(input),
348            Activation::HardSigmoid(layer) => layer.forward(input),
349            Activation::HardSwish(layer) => layer.forward(input),
350            Activation::Softplus(layer) => layer.forward(input),
351            Activation::Selu(layer) => layer.forward(input),
352            Activation::Sigmoid(layer) => layer.forward(input),
353            Activation::Tanh(layer) => layer.forward(input),
354            Activation::Softsign(layer) => layer.forward(input),
355            Activation::Elu(layer) => layer.forward(input),
356            Activation::Celu(layer) => layer.forward(input),
357            Activation::ThresholdedRelu(layer) => layer.forward(input),
358            Activation::HardShrink(layer) => layer.forward(input),
359            Activation::SoftShrink(layer) => layer.forward(input),
360            Activation::Shrink(layer) => layer.forward(input),
361        }
362    }
363}
364
365#[cfg(test)]
366mod tests {
367    use super::*;
368    use crate::TestBackend;
369    use burn::module::Module;
370
371    fn make_input<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
372        Tensor::from_data([[-1.0, -0.5, 0.0], [1.0, 0.5, 0.0]], device)
373    }
374
375    fn expect_tensor<B: Backend, const D: usize>(actual: Tensor<B, D>, expected: Tensor<B, D>) {
376        actual.to_data().assert_eq(&expected.to_data(), true);
377    }
378
379    fn check_stateless_config_output<B: Backend, const D: usize>(
380        config: ActivationConfig,
381        input: Tensor<B, D>,
382        expected: Tensor<B, D>,
383        device: &B::Device,
384    ) {
385        let act = config.init(device);
386        let output = act.forward(input);
387        expect_tensor(output, expected);
388    }
389
390    #[test]
391    fn test_gelu() {
392        let device = Default::default();
393        let input = make_input::<TestBackend>(&device);
394
395        let expected = Gelu::new().forward(input.clone());
396
397        check_stateless_config_output(ActivationConfig::Gelu, input, expected, &device)
398    }
399
400    #[test]
401    fn test_gelu_approximate() {
402        let device = Default::default();
403        let input = make_input::<TestBackend>(&device);
404
405        let expected = Gelu::new_approximate().forward(input.clone());
406
407        check_stateless_config_output(ActivationConfig::GeluApproximate, input, expected, &device)
408    }
409
410    #[test]
411    fn test_prelu() {
412        let device = Default::default();
413        let input = make_input::<TestBackend>(&device);
414
415        let inner_config = PReluConfig::new();
416        let expected = inner_config.init(&device).forward(input.clone());
417
418        check_stateless_config_output(inner_config.into(), input, expected, &device)
419    }
420
421    #[test]
422    fn test_relu() {
423        let device = Default::default();
424        let input = make_input::<TestBackend>(&device);
425
426        let expected = Relu.forward(input.clone());
427
428        check_stateless_config_output(ActivationConfig::Relu, input, expected, &device)
429    }
430
431    #[test]
432    fn test_leaky_relu() {
433        let device = Default::default();
434        let input = make_input::<TestBackend>(&device);
435
436        let inner_config = LeakyReluConfig::new();
437        let expected = inner_config.init().forward(input.clone());
438
439        check_stateless_config_output(inner_config.into(), input, expected, &device)
440    }
441
442    #[test]
443    fn test_swi_glu() {
444        let device = Default::default();
445        let input = make_input::<TestBackend>(&device);
446
447        let d_input = input.shape()[1];
448        let d_output = 2 * d_input;
449
450        let inner_config = SwiGluConfig::new(d_input, d_output);
451        let mut reference: SwiGlu<TestBackend> = inner_config.init(&device);
452
453        let config: ActivationConfig = inner_config.into();
454        let layer = config.init(&device);
455
456        // Access tensors via forward pass to trigger lazy initialization, then clone weights.
457        let layer_output = layer.forward(input.clone());
458
459        match &layer {
460            Activation::SwiGlu(inner) => {
461                let state = inner.clone().into_record();
462                reference = reference.load_record(state);
463            }
464            _ => unreachable!(),
465        };
466
467        expect_tensor(layer_output, reference.forward(input.clone()))
468    }
469
470    #[test]
471    fn test_selu() {
472        let device = Default::default();
473        let input = make_input::<TestBackend>(&device);
474
475        let expected = Selu.forward(input.clone());
476
477        check_stateless_config_output(ActivationConfig::Selu, input, expected, &device)
478    }
479
480    #[test]
481    fn test_sigmoid() {
482        let device = Default::default();
483        let input = make_input::<TestBackend>(&device);
484
485        let expected = Sigmoid.forward(input.clone());
486
487        check_stateless_config_output(ActivationConfig::Sigmoid, input, expected, &device)
488    }
489
490    #[test]
491    fn test_tanh() {
492        let device = Default::default();
493        let input = make_input::<TestBackend>(&device);
494
495        let expected = Tanh.forward(input.clone());
496
497        check_stateless_config_output(ActivationConfig::Tanh, input, expected, &device)
498    }
499
500    #[test]
501    fn test_hard_sigmoid() {
502        let device = Default::default();
503        let input = make_input::<TestBackend>(&device);
504
505        let inner_config = HardSigmoidConfig::new();
506        let expected = inner_config.init().forward(input.clone());
507
508        check_stateless_config_output(inner_config.into(), input, expected, &device)
509    }
510
511    #[test]
512    fn test_softsign() {
513        let device = Default::default();
514        let input = make_input::<TestBackend>(&device);
515
516        let expected = Softsign.forward(input.clone());
517
518        check_stateless_config_output(ActivationConfig::Softsign, input, expected, &device)
519    }
520
521    #[test]
522    fn test_elu() {
523        let device = Default::default();
524        let input = make_input::<TestBackend>(&device);
525
526        let inner_config = EluConfig::new();
527        let expected = inner_config.init().forward(input.clone());
528
529        check_stateless_config_output(inner_config.into(), input, expected, &device)
530    }
531
532    #[test]
533    fn test_softplus() {
534        let device = Default::default();
535        let input = make_input::<TestBackend>(&device);
536
537        let inner_config = SoftplusConfig::new();
538        let expected = inner_config.init().forward(input.clone());
539
540        check_stateless_config_output(inner_config.into(), input, expected, &device)
541    }
542
543    #[test]
544    fn test_celu() {
545        let device = Default::default();
546        let input = make_input::<TestBackend>(&device);
547
548        let inner_config = CeluConfig::new();
549        let expected = inner_config.init().forward(input.clone());
550
551        check_stateless_config_output(inner_config.into(), input, expected, &device)
552    }
553
554    #[test]
555    fn test_thresholded_relu() {
556        let device = Default::default();
557        let input = make_input::<TestBackend>(&device);
558
559        let inner_config = ThresholdedReluConfig::new();
560        let expected = inner_config.init().forward(input.clone());
561
562        check_stateless_config_output(inner_config.into(), input, expected, &device)
563    }
564
565    #[test]
566    fn test_hard_shrink() {
567        let device = Default::default();
568        let input = make_input::<TestBackend>(&device);
569
570        let inner_config = HardShrinkConfig::new();
571        let expected = inner_config.init().forward(input.clone());
572
573        check_stateless_config_output(inner_config.into(), input, expected, &device)
574    }
575
576    #[test]
577    fn test_soft_shrink() {
578        let device = Default::default();
579        let input = make_input::<TestBackend>(&device);
580
581        let inner_config = SoftShrinkConfig::new();
582        let expected = inner_config.init().forward(input.clone());
583
584        check_stateless_config_output(inner_config.into(), input, expected, &device)
585    }
586
587    #[test]
588    fn test_shrink() {
589        let device = Default::default();
590        let input = make_input::<TestBackend>(&device);
591
592        let inner_config = ShrinkConfig::new();
593        let expected = inner_config.init().forward(input.clone());
594
595        check_stateless_config_output(inner_config.into(), input, expected, &device)
596    }
597}