
burn_nn/activation/activation_wrapper.rs

use burn_core as burn;

use crate::activation::{
    Gelu, HardSigmoid, HardSigmoidConfig, HardSwish, LeakyRelu, LeakyReluConfig, PRelu,
    PReluConfig, Relu, Sigmoid, Softplus, SoftplusConfig, SwiGlu, SwiGluConfig, Tanh,
};
use burn::config::Config;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// [`Activation`] Configuration.
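///
/// # Example
///
/// A minimal sketch: unit variants are constructed directly, while
/// parameterized variants convert from their inner configs via the
/// [`From`] impls below.
///
/// ```rust,ignore
/// let relu = ActivationConfig::Relu;
/// let leaky: ActivationConfig = LeakyReluConfig::new().into();
/// ```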
#[derive(Config, Debug)]
#[non_exhaustive]
pub enum ActivationConfig {
    /// [`Gelu`] activation layer.
    Gelu,

    /// [`PRelu`] activation layer.
    PRelu(PReluConfig),

    /// [`Relu`] activation layer.
    Relu,

    /// [`LeakyRelu`] activation layer.
    LeakyRelu(LeakyReluConfig),

    /// [`SwiGlu`] activation layer.
    SwiGlu(SwiGluConfig),

    /// [`Sigmoid`] activation layer.
    Sigmoid,

    /// [`Tanh`] activation layer.
    Tanh,

    /// [`HardSigmoid`] activation layer.
    HardSigmoid(HardSigmoidConfig),

    /// [`HardSwish`] activation layer.
    HardSwish,

    /// [`Softplus`] activation layer.
    Softplus(SoftplusConfig),
}

impl From<PReluConfig> for ActivationConfig {
    fn from(config: PReluConfig) -> Self {
        Self::PRelu(config)
    }
}

impl From<LeakyReluConfig> for ActivationConfig {
    fn from(config: LeakyReluConfig) -> Self {
        Self::LeakyRelu(config)
    }
}

impl From<SwiGluConfig> for ActivationConfig {
    fn from(config: SwiGluConfig) -> Self {
        Self::SwiGlu(config)
    }
}

impl From<HardSigmoidConfig> for ActivationConfig {
    fn from(config: HardSigmoidConfig) -> Self {
        Self::HardSigmoid(config)
    }
}

impl From<SoftplusConfig> for ActivationConfig {
    fn from(config: SoftplusConfig) -> Self {
        Self::Softplus(config)
    }
}

impl ActivationConfig {
    /// Initialize a wrapped activation layer.
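    ///
    /// Stateless variants ignore `device`; the parameterized variants
    /// ([`PRelu`], [`SwiGlu`]) allocate their parameters on it.
    ///
    /// # Example
    ///
    /// A minimal sketch (`MyBackend` is a placeholder for any [`Backend`]
    /// implementation):
    ///
    /// ```rust,ignore
    /// let device = Default::default();
    /// let act: Activation<MyBackend> = ActivationConfig::Gelu.init(&device);
    /// ```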
    pub fn init<B: Backend>(&self, device: &B::Device) -> Activation<B> {
        match self {
            ActivationConfig::Relu => Relu.into(),
            ActivationConfig::LeakyRelu(conf) => conf.init().into(),
            ActivationConfig::Gelu => Gelu.into(),
            ActivationConfig::PRelu(conf) => conf.init(device).into(),
            ActivationConfig::SwiGlu(conf) => conf.init(device).into(),
            ActivationConfig::HardSigmoid(conf) => conf.init().into(),
            ActivationConfig::HardSwish => HardSwish.into(),
            ActivationConfig::Softplus(conf) => conf.init().into(),
            ActivationConfig::Sigmoid => Sigmoid.into(),
            ActivationConfig::Tanh => Tanh.into(),
        }
    }
}

/// Activation Layer Wrapper.
///
/// Provides support for many built-in `burn::nn` activations.
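///
/// # Example
///
/// A minimal sketch; each supported layer converts into the wrapper via
/// the [`From`] impls below.
///
/// ```rust,ignore
/// // `MyBackend` and `input` are placeholders for illustration.
/// let act: Activation<MyBackend> = Relu.into();
/// let output = act.forward(input);
/// ```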
#[derive(Module, Debug)]
#[non_exhaustive]
#[allow(clippy::large_enum_variant)]
pub enum Activation<B: Backend> {
    /// [`Gelu`] activation layer.
    Gelu(Gelu),

    /// [`PRelu`] activation layer.
    PRelu(PRelu<B>),

    /// [`Relu`] activation layer.
    Relu(Relu),

    /// [`LeakyRelu`] activation layer.
    LeakyRelu(LeakyRelu),

    /// [`SwiGlu`] activation layer.
    SwiGlu(SwiGlu<B>),

    /// [`Sigmoid`] activation layer.
    Sigmoid(Sigmoid),

    /// [`Tanh`] activation layer.
    Tanh(Tanh),

    /// [`HardSigmoid`] activation layer.
    HardSigmoid(HardSigmoid),

    /// [`HardSwish`] activation layer.
    HardSwish(HardSwish),

    /// [`Softplus`] activation layer.
    Softplus(Softplus),
}

impl<B: Backend> From<Gelu> for Activation<B> {
    fn from(layer: Gelu) -> Self {
        Self::Gelu(layer)
    }
}

impl<B: Backend> From<PRelu<B>> for Activation<B> {
    fn from(layer: PRelu<B>) -> Self {
        Self::PRelu(layer)
    }
}

impl<B: Backend> From<Relu> for Activation<B> {
    fn from(layer: Relu) -> Self {
        Self::Relu(layer)
    }
}

impl<B: Backend> From<LeakyRelu> for Activation<B> {
    fn from(layer: LeakyRelu) -> Self {
        Self::LeakyRelu(layer)
    }
}

impl<B: Backend> From<SwiGlu<B>> for Activation<B> {
    fn from(layer: SwiGlu<B>) -> Self {
        Self::SwiGlu(layer)
    }
}

impl<B: Backend> From<Sigmoid> for Activation<B> {
    fn from(layer: Sigmoid) -> Self {
        Self::Sigmoid(layer)
    }
}

impl<B: Backend> From<Tanh> for Activation<B> {
    fn from(layer: Tanh) -> Self {
        Self::Tanh(layer)
    }
}

impl<B: Backend> From<HardSigmoid> for Activation<B> {
    fn from(layer: HardSigmoid) -> Self {
        Self::HardSigmoid(layer)
    }
}

impl<B: Backend> From<HardSwish> for Activation<B> {
    fn from(layer: HardSwish) -> Self {
        Self::HardSwish(layer)
    }
}

impl<B: Backend> From<Softplus> for Activation<B> {
    fn from(layer: Softplus) -> Self {
        Self::Softplus(layer)
    }
}

impl<B: Backend> Activation<B> {
    /// Forward pass.
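    ///
    /// Dispatches to the wrapped layer's own `forward`. Note that
    /// [`SwiGlu`] projects the last dimension from `d_input` to
    /// `d_output`, while every other variant preserves the input shape.
    ///
    /// ```rust,ignore
    /// // Minimal sketch for a shape-preserving variant; `MyBackend`,
    /// // `device`, and `input` are placeholders for illustration.
    /// let act: Activation<MyBackend> = ActivationConfig::Relu.init(&device);
    /// let dims = input.dims();
    /// let output = act.forward(input);
    /// assert_eq!(output.dims(), dims);
    /// ```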
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        match self {
            Activation::Relu(layer) => layer.forward(input),
            Activation::LeakyRelu(layer) => layer.forward(input),
            Activation::Gelu(layer) => layer.forward(input),
            Activation::PRelu(layer) => layer.forward(input),
            Activation::SwiGlu(layer) => layer.forward(input),
            Activation::HardSigmoid(layer) => layer.forward(input),
            Activation::HardSwish(layer) => layer.forward(input),
            Activation::Softplus(layer) => layer.forward(input),
            Activation::Sigmoid(layer) => layer.forward(input),
            Activation::Tanh(layer) => layer.forward(input),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::module::Module;

    fn make_input<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::from_data([[-1.0, -0.5, 0.0], [1.0, 0.5, 0.0]], device)
    }

    fn expect_tensor<B: Backend, const D: usize>(actual: Tensor<B, D>, expected: Tensor<B, D>) {
        actual.to_data().assert_eq(&expected.to_data(), true);
    }

    fn check_stateless_config_output<B: Backend, const D: usize>(
        config: ActivationConfig,
        input: Tensor<B, D>,
        expected: Tensor<B, D>,
        device: &B::Device,
    ) {
        let act = config.init(device);
        let output = act.forward(input);
        expect_tensor(output, expected);
    }

    #[test]
    fn test_gelu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Gelu.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Gelu, input, expected, &device)
    }

    #[test]
    fn test_prelu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = PReluConfig::new();
        let expected = inner_config.init(&device).forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Relu.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Relu, input, expected, &device)
    }

    #[test]
    fn test_leaky_relu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = LeakyReluConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_swi_glu() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let d_input = input.shape().dims[1];
        let d_output = 2 * d_input;

        let inner_config = SwiGluConfig::new(d_input, d_output);
        let mut reference: SwiGlu<TestBackend> = inner_config.init(&device);

        let config: ActivationConfig = inner_config.into();
        let layer = config.init(&device);

        match &layer {
            Activation::SwiGlu(inner) => {
                // Clone the initialized weights.
                let state = inner.clone().into_record();
                reference = reference.load_record(state);
            }
            _ => unreachable!(),
        };

        expect_tensor(
            layer.forward(input.clone()),
            reference.forward(input.clone()),
        )
    }

    #[test]
    fn test_sigmoid() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Sigmoid.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Sigmoid, input, expected, &device)
    }

    #[test]
    fn test_tanh() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let expected = Tanh.forward(input.clone());

        check_stateless_config_output(ActivationConfig::Tanh, input, expected, &device)
    }

    #[test]
    fn test_hard_sigmoid() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = HardSigmoidConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }

    #[test]
    fn test_softplus() {
        let device = Default::default();
        let input = make_input::<TestBackend>(&device);

        let inner_config = SoftplusConfig::new();
        let expected = inner_config.init().forward(input.clone());

        check_stateless_config_output(inner_config.into(), input, expected, &device)
    }
}