1use crate as burn;
2
3use crate::config::Config;
4use crate::module::Module;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::nn::Initializer;
7use crate::nn::rnn::gate_controller;
8use crate::tensor::Tensor;
9use crate::tensor::activation;
10use crate::tensor::backend::Backend;
11
12use super::gate_controller::GateController;
13
/// Configuration for a [`Gru`] module; build the module with [`GruConfig::init`].
///
/// `d_input`, `d_hidden` and `bias` have no default and are passed to `GruConfig::new` in that
/// order (e.g. `GruConfig::new(64, 1024, true)`). `reset_after` and `initializer` have defaults
/// that can be overridden with the generated `with_*` setters, such as `with_reset_after`.
#[derive(Config)]
pub struct GruConfig {
    /// Number of input features per time step (input size of every gate).
    pub d_input: usize,
    /// Size of the hidden state (output size of every gate).
    pub d_hidden: usize,
    /// Whether the gates are built with bias terms (forwarded to each `GateController`).
    pub bias: bool,
    /// Where the reset gate is applied when computing the candidate state (default `true`).
    ///
    /// * `true` - the reset gate multiplies the new gate's hidden-side product,
    ///   `r ⊙ (h·W_h + b_h)`, i.e. after the weight multiplication.
    /// * `false` - the reset gate multiplies the previous hidden state first, and `r ⊙ h` is
    ///   then passed through the new gate's hidden transform.
    ///
    /// The two formulations generally produce different outputs.
    #[config(default = "true")]
    pub reset_after: bool,
    /// Initializer passed to each gate's `GateController` (default: Xavier normal, gain 1.0).
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}
42
/// Gated recurrent unit (GRU) layer, created with [`GruConfig::init`].
///
/// The layer holds three gates, each with an input transform and a hidden transform.
/// `forward` runs them over a sequence one time step at a time.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Gru<B: Backend> {
    /// Update gate, z(t): weighs the previous hidden state against the candidate state.
    pub update_gate: GateController<B>,
    /// Reset gate, r(t): scales the previous hidden state's contribution to the candidate.
    pub reset_gate: GateController<B>,
    /// New gate: its pre-activation sum is passed through `tanh` to give the candidate state g(t).
    pub new_gate: GateController<B>,
    /// Size of the hidden state.
    pub d_hidden: usize,
    /// Where the reset gate is applied; see `GruConfig::reset_after`.
    pub reset_after: bool,
}
62
63impl<B: Backend> ModuleDisplay for Gru<B> {
64 fn custom_settings(&self) -> Option<DisplaySettings> {
65 DisplaySettings::new()
66 .with_new_line_after_attribute(false)
67 .optional()
68 }
69
70 fn custom_content(&self, content: Content) -> Option<Content> {
71 let [d_input, _] = self.update_gate.input_transform.weight.shape().dims();
72 let bias = self.update_gate.input_transform.bias.is_some();
73
74 content
75 .add("d_input", &d_input)
76 .add("d_hidden", &self.d_hidden)
77 .add("bias", &bias)
78 .add("reset_after", &self.reset_after)
79 .optional()
80 }
81}
82
83impl GruConfig {
84 pub fn init<B: Backend>(&self, device: &B::Device) -> Gru<B> {
86 let d_output = self.d_hidden;
87
88 let update_gate = gate_controller::GateController::new(
89 self.d_input,
90 d_output,
91 self.bias,
92 self.initializer.clone(),
93 device,
94 );
95 let reset_gate = gate_controller::GateController::new(
96 self.d_input,
97 d_output,
98 self.bias,
99 self.initializer.clone(),
100 device,
101 );
102 let new_gate = gate_controller::GateController::new(
103 self.d_input,
104 d_output,
105 self.bias,
106 self.initializer.clone(),
107 device,
108 );
109
110 Gru {
111 update_gate,
112 reset_gate,
113 new_gate,
114 d_hidden: self.d_hidden,
115 reset_after: self.reset_after,
116 }
117 }
118}
119
120impl<B: Backend> Gru<B> {
121 pub fn forward(
132 &self,
133 batched_input: Tensor<B, 3>,
134 state: Option<Tensor<B, 2>>,
135 ) -> Tensor<B, 3> {
136 let device = batched_input.device();
137 let [batch_size, seq_length, _] = batched_input.shape().dims();
138
139 let mut batched_hidden_state =
140 Tensor::empty([batch_size, seq_length, self.d_hidden], &device);
141
142 let mut hidden_t = match state {
143 Some(state) => state,
144 None => Tensor::zeros([batch_size, self.d_hidden], &device),
145 };
146
147 for (t, input_t) in batched_input.iter_dim(1).enumerate() {
148 let input_t = input_t.squeeze(1);
149 let biased_ug_input_sum =
151 self.gate_product(&input_t, &hidden_t, None, &self.update_gate);
152 let update_values = activation::sigmoid(biased_ug_input_sum); let biased_rg_input_sum =
156 self.gate_product(&input_t, &hidden_t, None, &self.reset_gate);
157 let reset_values = activation::sigmoid(biased_rg_input_sum); let biased_ng_input_sum = if self.reset_after {
161 self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate)
162 } else {
163 let reset_t = hidden_t.clone().mul(reset_values); self.gate_product(&input_t, &reset_t, None, &self.new_gate)
165 };
166 let candidate_state = biased_ng_input_sum.tanh(); hidden_t = candidate_state
171 .clone()
172 .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) + update_values.clone().mul(hidden_t);
174
175 let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1);
176
177 batched_hidden_state = batched_hidden_state.slice_assign(
178 [0..batch_size, t..(t + 1), 0..self.d_hidden],
179 unsqueezed_hidden_state,
180 );
181 }
182
183 batched_hidden_state
184 }
185
186 fn gate_product(
197 &self,
198 input: &Tensor<B, 2>,
199 hidden: &Tensor<B, 2>,
200 reset: Option<&Tensor<B, 2>>,
201 gate: &GateController<B>,
202 ) -> Tensor<B, 2> {
203 let input_product = input.clone().matmul(gate.input_transform.weight.val());
204 let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());
205
206 let input_bias = gate
207 .input_transform
208 .bias
209 .as_ref()
210 .map(|bias_param| bias_param.val());
211 let hidden_bias = gate
212 .hidden_transform
213 .bias
214 .as_ref()
215 .map(|bias_param| bias_param.val());
216
217 match (input_bias, hidden_bias, reset) {
218 (Some(input_bias), Some(hidden_bias), Some(r)) => {
219 input_product
220 + input_bias.unsqueeze()
221 + r.clone().mul(hidden_product + hidden_bias.unsqueeze())
222 }
223 (Some(input_bias), Some(hidden_bias), None) => {
224 input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
225 }
226 (Some(input_bias), None, Some(r)) => {
227 input_product + input_bias.unsqueeze() + r.clone().mul(hidden_product)
228 }
229 (Some(input_bias), None, None) => {
230 input_product + input_bias.unsqueeze() + hidden_product
231 }
232 (None, Some(hidden_bias), Some(r)) => {
233 input_product + r.clone().mul(hidden_product + hidden_bias.unsqueeze())
234 }
235 (None, Some(hidden_bias), None) => {
236 input_product + hidden_product + hidden_bias.unsqueeze()
237 }
238 (None, None, Some(r)) => input_product + r.clone().mul(hidden_product),
239 (None, None, None) => input_product + hidden_product,
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Distribution, TensorData};
    use crate::{TestBackend, module::Param, nn::LinearRecord};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Builds a GRU with one input feature and one hidden unit. Its gates are replaced with
    /// fixed scalar weights (update 0.5, reset 0.6, new 0.7) and zero biases, so expected
    /// outputs can be worked out by hand.
    fn init_gru<B: Backend>(reset_after: bool, device: &B::Device) -> Gru<B> {
        // Creates a gate from two identical records (presumably the input and hidden
        // transforms), each holding the 1x1 weight `weights` and the length-1 bias `biases`.
        // NOTE(review): both records always carry a `Some` bias even when `bias` is false; how
        // `create_with_weights` treats that combination is not visible here — confirm.
        fn create_gate_controller<B: Backend>(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &B::Device,
        ) -> GateController<B> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            gate_controller::GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        let config = GruConfig::new(1, 1, false).with_reset_after(reset_after);
        let mut gru = config.init::<B>(device);

        gru.update_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru.reset_gate = create_gate_controller(
            0.6,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru.new_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru
    }

    /// One step on input 0.1 from a zero initial state:
    /// z = sigmoid(0.5 * 0.1) ≈ 0.5125, candidate = tanh(0.7 * 0.1) ≈ 0.0699, and
    /// h = (1 - z) * candidate ≈ 0.034.
    /// The reset gate only scales the (zero) previous state here, so both `reset_after`
    /// settings must give the same result.
    #[test]
    fn tests_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let device = Default::default();
        let mut gru = init_gru::<TestBackend>(false, &device);

        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
        let expected = TensorData::from([[0.034]]);

        let state = gru.forward(input.clone(), None);

        // Drop the batch dimension: [1, seq, hidden] -> [seq, hidden].
        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        let tolerance = Tolerance::default();
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        gru.reset_after = true;
        let state = gru.forward(input, None);

        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    /// Three steps on inputs 0.1, 0.2, 0.3 from a zero initial state.
    /// With 1x1 weights and zero biases, r * (U_n h) equals U_n (r h), so switching
    /// `reset_after` must not change the sequence of hidden states.
    #[test]
    fn tests_forward_seq_len_3() {
        TestBackend::seed(0);
        let device = Default::default();
        let mut gru = init_gru::<TestBackend>(true, &device);

        let input =
            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);
        let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]);

        let result = gru.forward(input.clone(), None);
        let output = result
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        let tolerance = Tolerance::default();
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        gru.reset_after = false;
        let state = gru.forward(input, None);

        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    /// Shape check only: a random [8, 10, 64] input through a 64 -> 1024 GRU yields hidden
    /// states of shape [8, 10, 1024].
    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let gru = GruConfig::new(64, 1024, true).init::<TestBackend>(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let hidden_state = gru.forward(batched_input, None);

        assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
    }

    /// Checks the single-line display. The 288 params are consistent with
    /// 3 gates x (2*8 + 8 input-transform + 8*8 + 8 hidden-transform) parameters.
    #[test]
    fn display() {
        let config = GruConfig::new(2, 8, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}"
        );
    }
}