burn_nn/activation/gelu.rs

use burn_core as burn;

use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Applies the Gaussian Error Linear Units (GELU) function element-wise.
#[derive(Module, Clone, Debug, Default)]
pub struct Gelu {
    /// When `true`, use the tanh approximation of GELU instead of the exact
    /// erf-based form. Defaults to `false`.
    pub approximate: bool,
}

impl Gelu {
    /// Create the module using the exact (erf-based) GELU.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create the module using the tanh approximation of GELU.
    pub fn new_approximate() -> Self {
        Self { approximate: true }
    }

    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any]`
    /// - output: `[..., any]`
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        if self.approximate {
            burn::tensor::activation::gelu_approximate(input)
        } else {
            burn::tensor::activation::gelu(input)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    type FT = FloatElem<TestBackend>;

    #[test]
    fn display() {
        let layer = Gelu::new();

        assert_eq!(alloc::format!("{layer}"), "Gelu {\n approximate: false\n}");
    }

    #[test]
    fn forward_approximate() {
        let device = Default::default();
        let input =
            Tensor::<TestBackend, 2>::from_data([[-1.0, 0.0, 1.0], [0.5, -0.5, 2.0]], &device);

        let output = Gelu::new_approximate().forward(input);

        let expected = Tensor::<TestBackend, 2>::from_data(
            [
                [-0.1588079929, 0.0000000000, 0.8411920071],
                [0.3457140028, -0.1542859972, 1.9545977116],
            ],
            &device,
        );

        output
            .into_data()
            .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::rel_abs(1e-5, 1e-5));
    }
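
    // A companion test for the exact path, added as a hedged sketch: the
    // expected values are x * Φ(x) computed from the standard normal CDF,
    // not taken from the upstream test suite.
    #[test]
    fn forward_exact() {
        let device = Default::default();
        let input =
            Tensor::<TestBackend, 2>::from_data([[-1.0, 0.0, 1.0], [0.5, -0.5, 2.0]], &device);

        let output = Gelu::new().forward(input);

        let expected = Tensor::<TestBackend, 2>::from_data(
            [
                [-0.1586552539, 0.0000000000, 0.8413447461],
                [0.3457312306, -0.1542687694, 1.9544997361],
            ],
            &device,
        );

        output
            .into_data()
            .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::rel_abs(1e-5, 1e-5));
    }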
}