//! GELU activation module (`burn_nn/activation/gelu.rs`).
1use burn_core as burn;
2
3use burn::module::Module;
4use burn::tensor::Tensor;
5use burn::tensor::backend::Backend;
6
/// Applies the Gaussian Error Linear Units function element-wise.
///
/// See also [gelu](burn::tensor::activation::gelu)
///
/// When `approximate` is true, uses the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
#[derive(Module, Clone, Debug, Default)]
pub struct Gelu {
    /// Whether to use the tanh approximation. The `Default` derive makes this
    /// `false`, so a default-constructed module computes the exact GELU.
    pub approximate: bool,
}
18
19impl Gelu {
20    /// Create the module with exact GELU.
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    /// Create the module with tanh approximation.
26    pub fn new_approximate() -> Self {
27        Self { approximate: true }
28    }
29
30    /// Applies the forward pass on the input tensor.
31    ///
32    /// # Shapes
33    ///
34    /// - input: `[..., any]`
35    /// - output: `[..., any]`
36    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
37        if self.approximate {
38            burn::tensor::activation::gelu_approximate(input)
39        } else {
40            burn::tensor::activation::gelu(input)
41        }
42    }
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    type FT = FloatElem<TestBackend>;

    // The Display output should reflect the default (exact) configuration.
    #[test]
    fn display() {
        let layer = Gelu::new();

        assert_eq!(alloc::format!("{layer}"), "Gelu {\n  approximate: false\n}");
    }

    // The approximate module should match PyTorch's tanh-approximate GELU.
    #[test]
    fn forward_approximate() {
        let device = Default::default();
        let layer = Gelu::new_approximate();

        let x = Tensor::<TestBackend, 2>::from_data([[-1.0, 0.0, 1.0], [0.5, -0.5, 2.0]], &device);
        let y = layer.forward(x);

        // Reference values: torch.nn.functional.gelu(x, approximate="tanh")
        let reference = Tensor::<TestBackend, 2>::from_data(
            [
                [-0.1588079929, 0.0000000000, 0.8411920071],
                [0.3457140028, -0.1542859972, 1.9545977116],
            ],
            &device,
        );

        y.into_data()
            .assert_approx_eq::<FT>(&reference.into_data(), Tolerance::rel_abs(1e-5, 1e-5));
    }
}