burn_core/nn/rnn/gru.rs

use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::Initializer;
use crate::nn::rnn::gate_controller;
use crate::tensor::Tensor;
use crate::tensor::activation;
use crate::tensor::backend::Backend;

use super::gate_controller::GateController;
/// Configuration to create a [gru](Gru) module using the [init function](GruConfig::init).
#[derive(Config)]
pub struct GruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Gru transformation.
    pub bias: bool,
    /// If the reset gate should be applied after the weight multiplication.
    ///
    /// This configuration option controls how the reset gate is applied to the hidden state.
    /// * `true` - (Default) Match the initial arXiv version of the paper [Learning Phrase Representations using RNN Encoder-Decoder for
    ///   Statistical Machine Translation (v1)](https://arxiv.org/abs/1406.1078v1) and apply the reset gate after multiplication by
    ///   the weights. This matches the behavior of [PyTorch GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU).
    /// * `false` - Match the most recent revision of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine
    ///   Translation (v3)](https://arxiv.org/abs/1406.1078) and apply the reset gate before the weight multiplication.
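    ///
    /// Concretely, the two options compute the candidate state as follows (a sketch in
    /// PyTorch-style notation; `Wx`/`Wh` are the input/hidden weights of the new gate
    /// and `bx`/`bh` their biases, which are omitted when `bias` is `false`):
    /// * `true` - `n(t) = tanh(Wx·x(t) + bx + r(t) ∘ (Wh·h(t-1) + bh))`
    /// * `false` - `n(t) = tanh(Wx·x(t) + bx + Wh·(r(t) ∘ h(t-1)) + bh)`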
    ///
    /// The differing implementations can give slightly different numerical results and have different efficiencies. For more
    /// motivation for why `true` can be more efficient, see [Optimizing RNNs with Differentiable Graphs](https://svail.github.io/diff_graphs).
    ///
    /// To set this field to `false`, use [`with_reset_after`](`GruConfig::with_reset_after`).
    #[config(default = "true")]
    pub reset_after: bool,
    /// Gru initializer.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}

/// The Gru (Gated recurrent unit) module. This implementation is for a unidirectional, stateless GRU.
///
/// Introduced in the paper: [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078).
///
/// Should be created with [GruConfig].
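///
/// # Example
///
/// A minimal construction sketch; `MyBackend` stands in for any [Backend] implementation:
///
/// ```rust,ignore
/// let device = Default::default();
/// // 32 input features, a hidden state of size 64, with bias.
/// let gru: Gru<MyBackend> = GruConfig::new(32, 64, true).init(&device);
/// ```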
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Gru<B: Backend> {
    /// The update gate controller.
    pub update_gate: GateController<B>,
    /// The reset gate controller.
    pub reset_gate: GateController<B>,
    /// The new gate controller.
    pub new_gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If the reset gate should be applied after the weight multiplication.
    pub reset_after: bool,
}

impl<B: Backend> ModuleDisplay for Gru<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self.update_gate.input_transform.weight.shape().dims();
        let bias = self.update_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .add("reset_after", &self.reset_after)
            .optional()
    }
}

impl GruConfig {
    /// Initialize a new [gru](Gru) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Gru<B> {
        let d_output = self.d_hidden;

        let update_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
            device,
        );
        let reset_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
            device,
        );
        let new_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
            device,
        );

        Gru {
            update_gate,
            reset_gate,
            new_gate,
            d_hidden: self.d_hidden,
            reset_after: self.reset_after,
        }
    }
}

impl<B: Backend> Gru<B> {
    /// Applies the forward pass on the input tensor. This GRU implementation
    /// returns a state tensor with dimensions `[batch_size, sequence_length, hidden_size]`.
    ///
    /// # Parameters
    /// - batched_input: `[batch_size, sequence_length, input_size]`.
    /// - state: An optional tensor representing an initial hidden state with dimensions
    ///   `[batch_size, hidden_size]`. If none is provided, a zeroed state is used.
    ///
    /// # Returns
    /// - output: `[batch_size, sequence_length, hidden_size]`
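    ///
    /// # Example
    ///
    /// A minimal usage sketch; `MyBackend` and the shapes are placeholders:
    ///
    /// ```rust,ignore
    /// let device = Default::default();
    /// // `MyBackend` is a placeholder backend type.
    /// let gru = GruConfig::new(32, 64, true).init::<MyBackend>(&device);
    /// // Batch of 8 sequences, 10 steps each, 32 features per step.
    /// let input = Tensor::<MyBackend, 3>::zeros([8, 10, 32], &device);
    /// let output = gru.forward(input, None); // [8, 10, 64]
    /// ```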
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 2>>,
    ) -> Tensor<B, 3> {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        // Accumulates the hidden state at every time step.
        let mut batched_hidden_state =
            Tensor::empty([batch_size, seq_length, self.d_hidden], &device);

        let mut hidden_t = match state {
            Some(state) => state,
            None => Tensor::zeros([batch_size, self.d_hidden], &device),
        };

        for (t, input_t) in batched_input.iter_dim(1).enumerate() {
            let input_t = input_t.squeeze(1);
            // Update gate tensors
            let biased_ug_input_sum =
                self.gate_product(&input_t, &hidden_t, None, &self.update_gate);
            let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t)

            // Reset gate tensors
            let biased_rg_input_sum =
                self.gate_product(&input_t, &hidden_t, None, &self.reset_gate);
            let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t)

            // New gate tensor
            let biased_ng_input_sum = if self.reset_after {
                self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate)
            } else {
                let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate
                self.gate_product(&input_t, &reset_t, None, &self.new_gate)
            };
            let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t)

            // Calculate the linear interpolation between the previous hidden state and the
            // candidate state: g(t) * (1 - z(t)) + z(t) * hidden_t
            hidden_t = candidate_state
                .clone()
                .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
                + update_values.clone().mul(hidden_t);

            let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1);

            batched_hidden_state = batched_hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                unsqueezed_hidden_state,
            );
        }

        batched_hidden_state
    }

    /// Helper function for performing the weighted matrix product for a gate, adding
    /// the biases, if any, and optionally applying the reset gate to the hidden state.
    ///
    /// Mathematically, performs `Wx·X + bx + r ∘ (Wh·H + bh)`, where:
    /// * `Wx` = weight matrix for the connection to the input vector `X`
    /// * `Wh` = weight matrix for the connection to the hidden state `H`
    /// * `X` = input vector
    /// * `H` = hidden state
    /// * `bx`, `bh` = input and hidden bias terms
    /// * `r` = reset gate values
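    ///
    /// When `reset` is `None` (as for the update and reset gates), this reduces to the
    /// plain pre-activation `Wx·X + bx + Wh·H + bh`.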
    fn gate_product(
        &self,
        input: &Tensor<B, 2>,
        hidden: &Tensor<B, 2>,
        reset: Option<&Tensor<B, 2>>,
        gate: &GateController<B>,
    ) -> Tensor<B, 2> {
        let input_product = input.clone().matmul(gate.input_transform.weight.val());
        let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());

        let input_bias = gate
            .input_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());
        let hidden_bias = gate
            .hidden_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());

        match (input_bias, hidden_bias, reset) {
            (Some(input_bias), Some(hidden_bias), Some(r)) => {
                input_product
                    + input_bias.unsqueeze()
                    + r.clone().mul(hidden_product + hidden_bias.unsqueeze())
            }
            (Some(input_bias), Some(hidden_bias), None) => {
                input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
            }
            (Some(input_bias), None, Some(r)) => {
                input_product + input_bias.unsqueeze() + r.clone().mul(hidden_product)
            }
            (Some(input_bias), None, None) => {
                input_product + input_bias.unsqueeze() + hidden_product
            }
            (None, Some(hidden_bias), Some(r)) => {
                input_product + r.clone().mul(hidden_product + hidden_bias.unsqueeze())
            }
            (None, Some(hidden_bias), None) => {
                input_product + hidden_product + hidden_bias.unsqueeze()
            }
            (None, None, Some(r)) => input_product + r.clone().mul(hidden_product),
            (None, None, None) => input_product + hidden_product,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Distribution, TensorData};
    use crate::{TestBackend, module::Param, nn::LinearRecord};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    fn init_gru<B: Backend>(reset_after: bool, device: &B::Device) -> Gru<B> {
        fn create_gate_controller<B: Backend>(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &B::Device,
        ) -> GateController<B> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            gate_controller::GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        let config = GruConfig::new(1, 1, false).with_reset_after(reset_after);
        let mut gru = config.init::<B>(device);

        gru.update_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru.reset_gate = create_gate_controller(
            0.6,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru.new_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru
    }

    /// Test forward pass with simple input vector.
    ///
    /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125
    /// r_t = sigmoid(0.6*0.1 + 0.6*0) = 0.5150
    /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699
    ///
    /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341
    #[test]
    fn tests_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let device = Default::default();
        let mut gru = init_gru::<TestBackend>(false, &device);

        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
        let expected = TensorData::from([[0.034]]);

        // Reset gate applied to hidden state before the matrix multiplication
        let state = gru.forward(input.clone(), None);

        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        let tolerance = Tolerance::default();
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        // Reset gate applied to hidden state after the matrix multiplication
        gru.reset_after = true; // override forward behavior
        let state = gru.forward(input, None);

        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    #[test]
    fn tests_forward_seq_len_3() {
        TestBackend::seed(0);
        let device = Default::default();
        let mut gru = init_gru::<TestBackend>(true, &device);

        let input =
            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);
        let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]);

        let result = gru.forward(input.clone(), None);
        let output = result
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        let tolerance = Tolerance::default();
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        // Reset gate applied to hidden state before the matrix multiplication
        gru.reset_after = false; // override forward behavior
        let state = gru.forward(input, None);

        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let gru = GruConfig::new(64, 1024, true).init::<TestBackend>(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let hidden_state = gru.forward(batched_input, None);

        assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
    }

    #[test]
    fn display() {
        let config = GruConfig::new(2, 8, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}"
        );
    }
}