
burn_nn/activation/glu.rs

use burn_core as burn;

use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// Applies the gated linear unit function.
///
/// See also [glu](burn::tensor::activation::glu)
#[derive(Module, Clone, Debug, Default)]
pub struct GLU {
    dim: usize,
}

impl GLU {
    /// Create the module.
    ///
    /// # Arguments
    /// * `dim` - The dimension on which to split the input.
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }

    /// Applies the gated linear unit function.
    ///
    /// GLU(a, b) = a ⊗ σ(b), where `a` is the first half of the input along `dim` and `b` is the second half.
    ///
    /// **Note**:
    /// * The size of the input tensor along `dim` must be divisible by 2.
    ///
    /// ### Arguments
    /// * `tensor` - The input tensor.
    ///
    /// ### Returns
    /// * A tensor with the same shape as the input, except the size along `dim` is halved.
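    ///
    /// ### Example
    ///
    /// A minimal usage sketch, assuming a concrete backend `B` and its `device` are already in scope:
    ///
    /// ```rust,ignore
    /// let glu = GLU::new(1);
    /// // A `[2, 4]` input is split along dim 1 into two `[2, 2]` halves `a` and `b`.
    /// let input = Tensor::<B, 2>::ones([2, 4], &device);
    /// // The output is `a ⊗ σ(b)`, with shape `[2, 2]`.
    /// let output = glu.forward(input);
    /// ```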
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        burn::tensor::activation::glu(input, self.dim)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display() {
        let layer = GLU::new(1);

        assert_eq!(alloc::format!("{layer}"), "GLU {\n  dim: 1\n}");
    }
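
    // A minimal forward-pass sketch, assuming an ndarray-backed test backend is
    // available to this crate; the `burn_ndarray::NdArray<f32>` alias below is
    // that assumption made explicit.
    #[test]
    fn forward_halves_dim() {
        type TestBackend = burn_ndarray::NdArray<f32>;
        let device = Default::default();
        let layer = GLU::new(1);

        // GLU(a, b) = a ⊗ σ(b): the input is split in half along `dim`, so the
        // size of that dimension must be even and is halved in the output.
        let input = Tensor::<TestBackend, 2>::ones([2, 6], &device);
        let output = layer.forward(input);

        assert_eq!(output.dims(), [2, 3]);
    }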
}