oxidized_transformers/layers/feedforward.rs

//! Feed-forward layers.
use candle_core::{Module, ModuleT, Tensor};
use candle_nn::{linear, linear_no_bias, Linear, VarBuilder};
use snafu::{ResultExt, Snafu};

use crate::error::BoxedError;
use crate::layers::activation::Activation;
use crate::layers::build_module::BuildModule;
use crate::layers::identity::Identity;

/// Configuration for pointwise feed-forward layers.
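///
/// A minimal configuration sketch using the builder methods (the widths and
/// the crate path in the example are illustrative assumptions, not values
/// required by the layer):
///
/// ```ignore
/// use oxidized_transformers::layers::feedforward::PointwiseFeedForwardConfig;
///
/// // Start from the defaults and override only what differs.
/// let config = PointwiseFeedForwardConfig::default()
///     .hidden_width(1024)
///     .intermediate_width(4096)
///     .use_bias(false)
///     .use_gate(true);
/// ```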
#[derive(Debug)]
pub struct PointwiseFeedForwardConfig {
    activation: Box<dyn BuildModule>,
    dropout: Box<dyn BuildModule>,
    hidden_width: usize,
    intermediate_width: usize,
    layer_norm: Box<dyn BuildModule>,
    use_bias: bool,
    use_gate: bool,
}

impl PointwiseFeedForwardConfig {
    /// Activation function.
    ///
    /// Default: `GELU`
    pub fn activation(mut self, activation: Box<dyn BuildModule>) -> Self {
        self.activation = activation;
        self
    }

    /// Dropout to apply after the feed-forward layer.
    ///
    /// Default: `Identity`
    pub fn dropout(mut self, dropout: Box<dyn BuildModule>) -> Self {
        self.dropout = dropout;
        self
    }

    /// Hidden width of the transformer.
    ///
    /// Default: `768`
    pub fn hidden_width(mut self, hidden_width: usize) -> Self {
        self.hidden_width = hidden_width;
        self
    }

    /// Intermediate width in the feed-forward layer.
    ///
    /// Default: `3072`
    pub fn intermediate_width(mut self, intermediate_width: usize) -> Self {
        self.intermediate_width = intermediate_width;
        self
    }

    /// Layer normalization module.
    ///
    /// Default: `Identity`
    pub fn layer_norm(mut self, layer_norm: Box<dyn BuildModule>) -> Self {
        self.layer_norm = layer_norm;
        self
    }

    /// Use bias in linear layers.
    ///
    /// Default: `true`
    pub fn use_bias(mut self, use_bias: bool) -> Self {
        self.use_bias = use_bias;
        self
    }

    /// Use Gated Linear Units.
    ///
    /// Default: `false`
    pub fn use_gate(mut self, use_gate: bool) -> Self {
        self.use_gate = use_gate;
        self
    }
}

impl Default for PointwiseFeedForwardConfig {
    fn default() -> Self {
        Self {
            activation: Box::new(Activation::Gelu),
            dropout: Box::new(Identity),
            hidden_width: 768,
            intermediate_width: 3072,
            layer_norm: Box::new(Identity),
            use_bias: true,
            use_gate: false,
        }
    }
}

impl BuildModule for PointwiseFeedForwardConfig {
    fn build(&self, vb: VarBuilder) -> Result<Box<dyn ModuleT>, BoxedError> {
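        // Select the linear constructor once so biased and bias-free layers
        // are built the same way below.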
        let linear_ctor = if self.use_bias {
            linear
        } else {
            linear_no_bias
        };

        let intermediate = linear_ctor(
            self.hidden_width,
            self.intermediate_width,
            vb.push_prefix("intermediate"),
        )
        .context(CreateLinearSnafu)?;

        let gate = if self.use_gate {
            Some(
                linear_ctor(
                    self.hidden_width,
                    self.intermediate_width,
                    vb.push_prefix("gate"),
                )
                .context(CreateLinearSnafu)?,
            )
        } else {
            None
        };

        let output = linear_ctor(
            self.intermediate_width,
            self.hidden_width,
            vb.push_prefix("output"),
        )
        .context(CreateLinearSnafu)?;

        Ok(Box::new(PointwiseFeedForward {
            activation: self
                .activation
                .build(vb.clone())
                .context(BuildActivationSnafu)?,
            dropout: self
                .dropout
                .build(vb.push_prefix("dropout"))
                .context(BuildDropoutSnafu)?,
            gate,
            intermediate,
            layer_norm: self
                .layer_norm
                .build(vb.push_prefix("layer_norm"))
                .context(CreateLayerNormSnafu)?,
            output,
        }))
    }
}

#[derive(Debug, Snafu)]
pub enum PointwiseFeedForwardError {
    #[snafu(display("Cannot build activation"))]
    BuildActivation { source: BoxedError },

    #[snafu(display("Cannot build dropout"))]
    BuildDropout { source: BoxedError },

    #[snafu(display("Cannot create layer norm"))]
    CreateLayerNorm { source: BoxedError },

    #[snafu(display("Cannot create linear layer"))]
    CreateLinear { source: candle_core::Error },
}

/// Point-wise feed-forward layer (_Vaswani et al., 2017_).
///
/// This layer is applied pointwise, meaning that the same
/// transformation is applied to each sequence element. This
/// transformation is:
///
/// `g(xW_1 + b_1)W_2 + b_2`
///
/// `W_1` and `b_1` transform the input to an
/// intermediate width, `g` is a non-linear activation
/// function and `W_2` and `b_2` transform the
/// output of the activation back to the input width.
///
/// Gated Linear Units (_Dauphin et al., 2016_; _Shazeer, 2020_) are also
/// supported. Gating applies the following transformation:
///
/// `(g(xW_g + b_g) * (xW_1 + b_1))W_2 + b_2`
///
/// `W_g` and `b_g` are the parameters of the gate's affine transformation.
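///
/// With the default configuration (`hidden_width = 768`,
/// `intermediate_width = 3072`), the first transformation maps each
/// sequence position from width 768 to width 3072 and the second maps it
/// back to 768, so an input of shape `(batch, seq_len, 768)` produces an
/// output of the same shape.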
///
/// * _Vaswani et al., 2017_: <https://arxiv.org/abs/1706.03762>
/// * _Dauphin et al., 2016_: <https://arxiv.org/abs/1612.08083>
/// * _Shazeer, 2020_: <https://arxiv.org/abs/2002.05202>
pub struct PointwiseFeedForward {
    activation: Box<dyn ModuleT>,
    dropout: Box<dyn ModuleT>,
    gate: Option<Linear>,
    intermediate: Linear,
    layer_norm: Box<dyn ModuleT>,
    output: Linear,
}

impl ModuleT for PointwiseFeedForward {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor, candle_core::Error> {
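        // Apply the configured layer norm (identity by default) before the
        // feed-forward transformation.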
        let xs = self.layer_norm.forward_t(xs, train)?;

        let output = match &self.gate {
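            // Gated (GLU) path: activation(gate(x)) is multiplied elementwise
            // with intermediate(x) before the output projection.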
            Some(gate) => self.output.forward(
                &self
                    .activation
                    .forward_t(&gate.forward(&xs)?, train)?
                    .mul(&self.intermediate.forward(&xs)?)?,
            ),
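            // Standard path: the activation is applied to intermediate(x)
            // directly, followed by the output projection.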
            None => self.output.forward(
                &self
                    .activation
                    .forward_t(&self.intermediate.forward(&xs)?, train)?,
            ),
        }?;

        self.dropout.forward_t(&output, train)
    }
}
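
// A minimal test sketch of how the layer can be built and applied. This is an
// illustrative addition, not part of the original module: it assumes that a
// `VarMap`-backed `VarBuilder` creates fresh parameters on demand and that the
// default `Activation`/`Identity` sub-modules build without parameters.
#[cfg(test)]
mod tests {
    use candle_core::{DType, Device, ModuleT, Tensor};
    use candle_nn::{VarBuilder, VarMap};

    use super::PointwiseFeedForwardConfig;
    use crate::layers::build_module::BuildModule;

    #[test]
    fn pointwise_feed_forward_preserves_shape() {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        // Build a small gated feed-forward layer from its configuration.
        let ff = PointwiseFeedForwardConfig::default()
            .hidden_width(16)
            .intermediate_width(32)
            .use_gate(true)
            .build(vb)
            .unwrap();

        // The layer is applied pointwise, so the output shape matches the
        // input shape.
        let xs = Tensor::zeros((2, 4, 16), DType::F32, &device).unwrap();
        let ys = ff.forward_t(&xs, false).unwrap();
        assert_eq!(ys.dims(), xs.dims());
    }
}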