burn_core/nn/gelu.rs

use crate as burn;

use crate::module::Module;
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;

/// Applies the Gaussian Error Linear Units function element-wise.
/// See also [gelu](burn::tensor::activation::gelu)
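///
/// The GELU function is defined as `GELU(x) = x · Φ(x)`, where `Φ` is the
/// cumulative distribution function of the standard normal distribution.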
#[derive(Module, Clone, Debug, Default)]
pub struct Gelu;

impl Gelu {
    /// Create the module.
    pub fn new() -> Self {
        Self {}
    }

    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any]`
    /// - output: `[..., any]`
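    ///
    /// # Example
    ///
    /// A minimal usage sketch; `MyBackend` is a placeholder for any
    /// [`Backend`] implementation, and the exact `from_floats` signature may
    /// differ between Burn versions:
    ///
    /// ```rust,ignore
    /// let device = Default::default();
    /// let gelu = Gelu::new();
    /// let input = Tensor::<MyBackend, 2>::from_floats([[-1.0, 0.0, 1.0]], &device);
    /// let output = gelu.forward(input);
    /// ```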
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        crate::tensor::activation::gelu(input)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display() {
        let layer = Gelu::new();

        assert_eq!(alloc::format!("{layer}"), "Gelu");
    }
}