burn_core/nn/rnn/gru.rs

use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::rnn::gate_controller;
use crate::nn::Initializer;
use crate::tensor::activation;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

use super::gate_controller::GateController;

/// Configuration to create a [gru](Gru) module using the [init function](GruConfig::init).
#[derive(Config)]
pub struct GruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Gru transformation.
    pub bias: bool,
    /// Gru initializer
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}

/// The Gru (Gated recurrent unit) module. This implementation is for a unidirectional, stateless Gru.
///
/// Introduced in the paper: [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078).
///
/// Should be created with [GruConfig].
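///
/// # Example
///
/// A minimal usage sketch; `MyBackend` and `device` are placeholders for whichever backend
/// and device the caller uses:
///
/// ```rust,ignore
/// // GRU with 32 input features, 64 hidden units, and biases enabled.
/// let gru = GruConfig::new(32, 64, true).init::<MyBackend>(&device);
///
/// // Input shape: [batch_size, sequence_length, input_size].
/// let input = Tensor::<MyBackend, 3>::zeros([8, 10, 32], &device);
///
/// // Output shape: [batch_size, sequence_length, hidden_size] = [8, 10, 64].
/// let output = gru.forward(input, None);
/// ```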
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Gru<B: Backend> {
    /// The update gate controller.
    pub update_gate: GateController<B>,
    /// The reset gate controller.
    pub reset_gate: GateController<B>,
    /// The new gate controller.
    pub new_gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
}

impl<B: Backend> ModuleDisplay for Gru<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self.update_gate.input_transform.weight.shape().dims();
        let bias = self.update_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl GruConfig {
    /// Initialize a new [gru](Gru) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Gru<B> {
        let d_output = self.d_hidden;

        let update_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
            device,
        );
        let reset_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
            device,
        );
        let new_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
            device,
        );

        Gru {
            update_gate,
            reset_gate,
            new_gate,
            d_hidden: self.d_hidden,
        }
    }
}

impl<B: Backend> Gru<B> {
    /// Applies the forward pass on the input tensor. This GRU implementation
    /// returns a single state tensor with dimensions `[batch_size, sequence_length, hidden_size]`.
    ///
    /// # Shapes
    /// - batched_input: `[batch_size, sequence_length, input_size]`.
    /// - state: An optional tensor representing the initial hidden state, with dimensions
    ///          `[batch_size, sequence_length, hidden_size]`. If none is provided, a zero
    ///          tensor is used.
    /// - output: `[batch_size, sequence_length, hidden_size]`.
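    ///
    /// For reference, the quantities computed at each timestep `t` follow the standard GRU
    /// gate formulation, where `x` is the input slice and `h` the state slice at `t`:
    ///
    /// ```text
    /// z = sigmoid(x·Wx_z + h·Wh_z + b_z)        // update gate
    /// r = sigmoid(x·Wx_r + h·Wh_r + b_r)        // reset gate
    /// g = tanh(x·Wx_g + (r * h)·Wh_g + b_g)     // candidate state
    /// output[t] = (1 - z) * g + z * h           // linear interpolation
    /// ```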
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        let mut hidden_state = match state {
            Some(state) => state,
            None => Tensor::zeros(
                [batch_size, seq_length, self.d_hidden],
                &batched_input.device(),
            ),
        };

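        // Iterate over the sequence dimension: each step combines the input slice at `t` with
        // the corresponding slice of the state tensor to produce a new state vector, which is
        // written back into `hidden_state` at index `t`.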
        for (t, (input_t, hidden_t)) in batched_input
            .iter_dim(1)
            .zip(hidden_state.clone().iter_dim(1))
            .enumerate()
        {
            let input_t = input_t.squeeze(1);
            let hidden_t = hidden_t.squeeze(1);
            // u(pdate)g(ate) tensors
            let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate);
            let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t)

            // r(eset)g(ate) tensors
            let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate);
            let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t)
            let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate

            // n(ew)g(ate) tensor
            let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate);
            let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t)

            // calculate linear interpolation between previous hidden state and candidate state:
            // g(t) * (1 - z(t)) + z(t) * hidden_t
            let state_vector = candidate_state
                .clone()
                .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
                + update_values.clone().mul(hidden_t);

            let current_shape = state_vector.shape().dims;
            let unsqueezed_shape = [current_shape[0], 1, current_shape[1]];
            let reshaped_state_vector = state_vector.reshape(unsqueezed_shape);
            hidden_state = hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                reshaped_state_vector,
            );
        }

        hidden_state
    }

    /// Helper function for performing the weighted matrix products for a gate and adding the
    /// bias, if any.
    ///
    /// Mathematically, computes `Wx*X + Wh*H + b`, where:
    ///
    /// - `Wx` = weight matrix for the connection to the input vector `X`
    /// - `Wh` = weight matrix for the connection to the hidden state `H`
    /// - `X` = input vector
    /// - `H` = hidden state
    /// - `b` = bias terms
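    ///
    /// `X` and `H` here are the `[batch_size, features]` slices for a single timestep, so the
    /// result has shape `[batch_size, d_hidden]`; the biases have shape `[d_hidden]` and are
    /// broadcast over the batch dimension via `unsqueeze`.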
    fn gate_product(
        &self,
        input: &Tensor<B, 2>,
        hidden: &Tensor<B, 2>,
        gate: &GateController<B>,
    ) -> Tensor<B, 2> {
        let input_product = input.clone().matmul(gate.input_transform.weight.val());
        let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());

        let input_bias = gate
            .input_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());
        let hidden_bias = gate
            .hidden_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());

        match (input_bias, hidden_bias) {
            (Some(input_bias), Some(hidden_bias)) => {
                input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
            }
            (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product,
            (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(),
            (None, None) => input_product + hidden_product,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Distribution, TensorData};
    use crate::{module::Param, nn::LinearRecord, TestBackend};

    /// Test forward pass with a simple input vector.
    ///
    /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125
    /// r_t = sigmoid(0.6*0.1 + 0.6*0) = 0.5150
    /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699
    ///
    /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341
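    ///
    /// These values assume the gate weights set below: 0.5 (update), 0.6 (reset), 0.7 (new),
    /// no bias, a zero initial hidden state h', and a single input value of 0.1.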
    #[test]
    fn tests_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let config = GruConfig::new(1, 1, false);
        let device = Default::default();
        let mut gru = config.init::<TestBackend>(&device);

        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &<TestBackend as Backend>::Device,
        ) -> GateController<TestBackend> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            gate_controller::GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        gru.update_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            &device,
        );
        gru.reset_gate = create_gate_controller(
            0.6,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            &device,
        );
        gru.new_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            &device,
        );

        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);

        let state = gru.forward(input, None);

        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        let expected = TensorData::from([[0.034]]);
        output.to_data().assert_approx_eq(&expected, 3);
    }

    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let gru = GruConfig::new(64, 1024, true).init::<TestBackend>(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let hidden_state = gru.forward(batched_input, None);

        assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
    }

    #[test]
    fn display() {
        let config = GruConfig::new(2, 8, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", layer),
            "Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}"
        );
    }
}