use crate as burn;

use crate::config::Config;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::activation::silu;
use crate::tensor::{backend::Backend, Tensor};

use super::{Initializer, Linear, LinearConfig};

/// Configuration to create a [SwiGlu](SwiGlu) activation layer using the [init function](SwiGluConfig::init).
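///
/// A construction sketch (dimensions and initializer are illustrative; `new`, `with_bias`
/// and `with_initializer` are generated by the `Config` derive):
///
/// ```ignore
/// let config = SwiGluConfig::new(256, 1024) // d_input = 256, d_output = 1024
///     .with_bias(true)                      // bias defaults to false
///     .with_initializer(Initializer::Constant { value: 0.5 });
/// ```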
#[derive(Config, Debug)]
pub struct SwiGluConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// Whether a bias should be applied during the linear transformations. Defaults to `false`,
    /// which is the usual choice for SwiGLU implementations.
    #[config(default = false)]
    pub bias: bool,
    /// The type of function used to initialize the linear layer parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// Applies the SwiGLU, or Swish Gated Linear Unit, to the input tensor.
///
/// The SwiGLU activation function is defined as:
/// `SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)`
/// where `Swish` is the SiLU activation (Swish with β = 1).
///
/// Should be created with [SwiGluConfig].
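///
/// # Example
///
/// A minimal initialization sketch (a backend type `B` and a `device` are assumed to be
/// in scope):
///
/// ```ignore
/// let swiglu: SwiGlu<B> = SwiGluConfig::new(3, 5).init(&device);
/// ```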
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct SwiGlu<B: Backend> {
    /// The inner linear layer, whose output is passed through the Swish (SiLU) activation,
    /// with `d_input` input features and `d_output` output features.
    pub linear_inner: Linear<B>,
    /// The outer linear layer, whose output gates the activated inner branch by element-wise
    /// multiplication, with `d_input` input features and `d_output` output features.
    pub linear_outer: Linear<B>,
}

impl<B: Backend> ModuleDisplay for SwiGlu<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, d_output] = self.linear_inner.weight.shape().dims;
        content
            .add("d_input", &d_input)
            .add("d_output", &d_output)
            .add("bias", &self.linear_inner.bias.is_some())
            .optional()
    }
}

impl SwiGluConfig {
    /// Initialize a new [SwiGLU](SwiGlu) activation layer.
    pub fn init<B: Backend>(&self, device: &B::Device) -> SwiGlu<B> {
        SwiGlu {
            linear_inner: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
            linear_outer: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
        }
    }
}

impl<B: Backend> SwiGlu<B> {
    /// Applies the Swish Gated Linear Unit to the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, seq_length, d_input]`
    /// - output: `[batch_size, seq_length, d_output]`
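    ///
    /// A shape sketch (backend `B` and `device` assumed in scope; dimensions are illustrative):
    ///
    /// ```ignore
    /// let swiglu = SwiGluConfig::new(3, 5).init::<B>(&device);
    /// let input = Tensor::<B, 3>::zeros([2, 4, 3], &device); // [batch_size, seq_length, d_input]
    /// let output = swiglu.forward(input);
    /// assert_eq!(output.dims(), [2, 4, 5]); // [batch_size, seq_length, d_output]
    /// ```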
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Inner branch: linear projection followed by the SiLU (Swish) activation.
        let x = self.linear_inner.forward(input.clone());
        let x = silu(x);
        // Gate the activated inner branch with the outer linear projection.
        x.mul(self.linear_outer.forward(input))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_swiglu_forward_no_bias() {
        TestBackend::seed(0);
        let device = Default::default();
        let config = SwiGluConfig::new(3, 3).with_initializer(Initializer::Constant { value: 0.5 });
        let swiglu = config.init(&device);
        let input =
            Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
        let output = swiglu.forward(input);
        let expected_output = Tensor::<TestBackend, 2>::from_data(
            [[8.5732, 8.5732, 8.5732], [56.2189, 56.2189, 56.2189]],
            &device,
        );
        output
            .to_data()
            .assert_approx_eq(&expected_output.to_data(), 4);
    }

    #[test]
    fn test_swiglu_forward_with_bias() {
        TestBackend::seed(0);
        let device = Default::default();
        let config = SwiGluConfig::new(3, 3)
            .with_bias(true)
            .with_initializer(Initializer::Constant { value: 0.5 });
        let swiglu = config.init(&device);
        let input =
            Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
        let output = swiglu.forward(input);
        let expected_output = Tensor::<TestBackend, 2>::from_data(
            [[11.8909, 11.8909, 11.8909], [63.9785, 63.9785, 63.9785]],
            &device,
        );
        output
            .to_data()
            .assert_approx_eq(&expected_output.to_data(), 4);
    }

    #[test]
    fn display() {
        let config = SwiGluConfig::new(3, 5);
        let swiglu = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", swiglu),
            "SwiGlu {d_input: 3, d_output: 5, bias: false, params: 30}"
        );
    }
}