burn_nn/activation/glu.rs
use burn_core as burn;

use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// Applies the gated linear unit function.
///
/// See also [glu](burn::tensor::activation::glu)
#[derive(Module, Clone, Debug, Default)]
pub struct GLU {
    dim: usize,
}

impl GLU {
    /// Create the module.
    ///
    /// # Arguments
    /// * `dim` - The dimension on which to split the input.
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }

    /// Applies the gated linear unit function.
    ///
    /// GLU(a, b) = a ⊗ σ(b), where `a` is the first half of the input along `dim` and `b` is the
    /// second half.
    ///
    /// **Note**: the size of the input tensor along `dim` must be divisible by 2.
    ///
    /// # Arguments
    /// * `input` - The input tensor.
    ///
    /// # Returns
    /// * A tensor with the same shape as the input, except the size along `dim` is halved.
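    ///
    /// # Example
    ///
    /// A minimal usage sketch (marked `ignore` since it needs a concrete backend in scope;
    /// `MyBackend` and `device` stand in for whichever backend type and device you use):
    ///
    /// ```ignore
    /// let glu = GLU::new(1);
    /// // Input of shape [2, 6]: dim 1 is split into two halves of size 3.
    /// let input = Tensor::<MyBackend, 2>::ones([2, 6], &device);
    /// let output = glu.forward(input); // output shape: [2, 3]
    /// ```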
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        burn::tensor::activation::glu(input, self.dim)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display() {
        let layer = GLU::new(1);

        assert_eq!(alloc::format!("{layer}"), "GLU {\n dim: 1\n}");
    }
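
    #[test]
    fn forward_halves_split_dim() {
        // A minimal sketch of the forward pass, checking shapes only: GLU splits the
        // input in two along `dim` and returns a ⊗ σ(b), so the output size along
        // `dim` is half the input size. `crate::TestBackend` is assumed to be the
        // backend alias provided by this crate's test setup; swap in any available
        // backend if that alias differs.
        type B = crate::TestBackend;

        let device = <B as Backend>::Device::default();
        let layer = GLU::new(1);
        let input = Tensor::<B, 2>::ones([2, 6], &device);

        let output = layer.forward(input);

        // Dim 1 is halved from 6 to 3; dim 0 is unchanged.
        assert_eq!(output.dims(), [2, 3]);
    }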
}