1use burn_core as burn;
2
3use super::gate_controller::GateController;
4use crate::activation::{Activation, ActivationConfig};
5use burn::config::Config;
6use burn::module::Initializer;
7use burn::module::Module;
8use burn::module::{Content, DisplaySettings, ModuleDisplay};
9use burn::tensor::Tensor;
10use burn::tensor::backend::Backend;
11
/// Configuration to create a [gru](Gru) module using the [init function](GruConfig::init).
#[derive(Config, Debug)]
pub struct GruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Gru transformations.
    pub bias: bool,
    /// If `true`, the reset gate is applied to the hidden-side product (including its
    /// bias) inside the candidate-state computation; if `false`, it scales the hidden
    /// state before the product. Default: `true`.
    #[config(default = "true")]
    pub reset_after: bool,
    /// Initializer for the gate weights. Default: Xavier normal with gain 1.0.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// Activation applied to the update and reset gates. Default: sigmoid.
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// Activation applied to the candidate hidden state. Default: tanh.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
    /// If set, the hidden state is clamped to `[-clip, clip]` after every timestep.
    pub clip: Option<f64>,
}
52
/// The Gated Recurrent Unit (GRU) module.
///
/// Should be created with [GruConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Gru<B: Backend> {
    /// Gate controller for the update gate (z).
    pub update_gate: GateController<B>,
    /// Gate controller for the reset gate (r).
    pub reset_gate: GateController<B>,
    /// Gate controller for the candidate ("new") state (n).
    pub new_gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// Whether the reset gate is applied after the hidden-state matrix multiplication.
    pub reset_after: bool,
    /// Activation applied to the update and reset gates.
    pub gate_activation: Activation<B>,
    /// Activation applied to the candidate hidden state.
    pub hidden_activation: Activation<B>,
    /// Optional symmetric clamp applied to the hidden state each timestep.
    pub clip: Option<f64>,
}
78
79impl<B: Backend> ModuleDisplay for Gru<B> {
80 fn custom_settings(&self) -> Option<DisplaySettings> {
81 DisplaySettings::new()
82 .with_new_line_after_attribute(false)
83 .optional()
84 }
85
86 fn custom_content(&self, content: Content) -> Option<Content> {
87 let [d_input, _] = self.update_gate.input_transform.weight.shape().dims();
88 let bias = self.update_gate.input_transform.bias.is_some();
89
90 content
91 .add("d_input", &d_input)
92 .add("d_hidden", &self.d_hidden)
93 .add("bias", &bias)
94 .add("reset_after", &self.reset_after)
95 .optional()
96 }
97}
98
99impl GruConfig {
100 pub fn init<B: Backend>(&self, device: &B::Device) -> Gru<B> {
102 let d_output = self.d_hidden;
103
104 let update_gate = GateController::new(
105 self.d_input,
106 d_output,
107 self.bias,
108 self.initializer.clone(),
109 device,
110 );
111 let reset_gate = GateController::new(
112 self.d_input,
113 d_output,
114 self.bias,
115 self.initializer.clone(),
116 device,
117 );
118 let new_gate = GateController::new(
119 self.d_input,
120 d_output,
121 self.bias,
122 self.initializer.clone(),
123 device,
124 );
125
126 Gru {
127 update_gate,
128 reset_gate,
129 new_gate,
130 d_hidden: self.d_hidden,
131 reset_after: self.reset_after,
132 gate_activation: self.gate_activation.init(device),
133 hidden_activation: self.hidden_activation.init(device),
134 clip: self.clip,
135 }
136 }
137}
138
impl<B: Backend> Gru<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    /// - batched_input: `[batch_size, sequence_length, d_input]`
    /// - state: optional initial hidden state, `[batch_size, d_hidden]` (zeros when `None`)
    /// - returns: hidden state for every timestep, `[batch_size, sequence_length, d_hidden]`
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 2>>,
    ) -> Tensor<B, 3> {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        self.forward_iter(
            batched_input.iter_dim(1).zip(0..seq_length),
            state,
            batch_size,
            seq_length,
            &device,
        )
        .0
    }

    /// Runs the recurrence over an iterator of `(timestep_input, timestep_index)` pairs.
    ///
    /// Each yielded tensor has shape `[batch_size, 1, d_input]`; the paired index tells
    /// where the resulting hidden state is written in the output, so iterating in reverse
    /// (as [BiGru] does) produces a reversed pass whose outputs stay time-aligned.
    ///
    /// Returns `(all_hidden_states, final_hidden_state)` with shapes
    /// `[batch_size, seq_length, d_hidden]` and `[batch_size, d_hidden]`.
    pub(crate) fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
        &self,
        input_timestep_iter: I,
        state: Option<Tensor<B, 2>>,
        batch_size: usize,
        seq_length: usize,
        device: &B::Device,
    ) -> (Tensor<B, 3>, Tensor<B, 2>) {
        let mut batched_hidden_state =
            Tensor::empty([batch_size, seq_length, self.d_hidden], device);

        // Start from the provided state, or zeros when none is given.
        let mut hidden_t = match state {
            Some(state) => state,
            None => Tensor::zeros([batch_size, self.d_hidden], device),
        };

        for (input_t, t) in input_timestep_iter {
            let input_t = input_t.squeeze_dim(1);

            // z_t: update gate.
            let biased_ug_input_sum =
                self.gate_product(&input_t, &hidden_t, None, &self.update_gate);
            let update_values = self.gate_activation.forward(biased_ug_input_sum);

            // r_t: reset gate.
            let biased_rg_input_sum =
                self.gate_product(&input_t, &hidden_t, None, &self.reset_gate);
            let reset_values = self.gate_activation.forward(biased_rg_input_sum);

            // n_t: candidate state. With `reset_after`, the reset gate scales the
            // hidden-side product (including its bias) inside `gate_product`;
            // otherwise it scales the hidden state before the product.
            let biased_ng_input_sum = if self.reset_after {
                self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate)
            } else {
                let reset_t = hidden_t.clone().mul(reset_values);
                self.gate_product(&input_t, &reset_t, None, &self.new_gate)
            };
            let candidate_state = self.hidden_activation.forward(biased_ng_input_sum);

            // h_t = (1 - z_t) * n_t + z_t * h_{t-1}
            let one_minus_z = update_values.clone().neg().add_scalar(1.0);
            hidden_t = candidate_state.mul(one_minus_z) + update_values.mul(hidden_t);

            // Optionally clamp the hidden state to [-clip, clip].
            if let Some(clip) = self.clip {
                hidden_t = hidden_t.clamp(-clip, clip);
            }

            let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1);

            // Write this timestep's hidden state into the output at position `t`.
            batched_hidden_state = batched_hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                unsqueezed_hidden_state,
            );
        }

        (batched_hidden_state, hidden_t)
    }

    /// Computes the full gate pre-activation:
    /// `input @ W + b_input + [reset ⊙] (hidden @ U + b_hidden)`,
    /// where the optional `reset` tensor scales the hidden-side term elementwise.
    fn gate_product(
        &self,
        input: &Tensor<B, 2>,
        hidden: &Tensor<B, 2>,
        reset: Option<&Tensor<B, 2>>,
        gate: &GateController<B>,
    ) -> Tensor<B, 2> {
        let input_product = input.clone().matmul(gate.input_transform.weight.val());
        let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());

        let input_part = match &gate.input_transform.bias {
            Some(bias) => input_product + bias.val().unsqueeze(),
            None => input_product,
        };

        let hidden_part = match &gate.hidden_transform.bias {
            Some(bias) => hidden_product + bias.val().unsqueeze(),
            None => hidden_product,
        };

        match reset {
            Some(r) => input_part + r.clone().mul(hidden_part),
            None => input_part + hidden_part,
        }
    }
}
277
/// Configuration to create a [bidirectional gru](BiGru) module using the
/// [init function](BiGruConfig::init).
#[derive(Config, Debug)]
pub struct BiGruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state (per direction).
    pub d_hidden: usize,
    /// If a bias should be applied during the Gru transformations.
    pub bias: bool,
    /// Whether the reset gate is applied after the hidden-state matrix multiplication.
    /// Default: `true`.
    #[config(default = "true")]
    pub reset_after: bool,
    /// Initializer for the gate weights. Default: Xavier normal with gain 1.0.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If `true`, input/output are `[batch, seq, feature]`; otherwise `[seq, batch, feature]`.
    /// Default: `true`.
    #[config(default = true)]
    pub batch_first: bool,
    /// Activation applied to the update and reset gates. Default: sigmoid.
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// Activation applied to the candidate hidden state. Default: tanh.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
    /// If set, the hidden state is clamped to `[-clip, clip]` after every timestep.
    pub clip: Option<f64>,
}
306
/// The bidirectional Gated Recurrent Unit module.
///
/// Should be created with [BiGruConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiGru<B: Backend> {
    /// The GRU processing the sequence in forward (time) order.
    pub forward: Gru<B>,
    /// The GRU processing the sequence in reverse order.
    pub reverse: Gru<B>,
    /// The size of the hidden state (per direction).
    pub d_hidden: usize,
    /// If `true`, input/output are `[batch, seq, feature]`; otherwise `[seq, batch, feature]`.
    pub batch_first: bool,
}
325
326impl<B: Backend> ModuleDisplay for BiGru<B> {
327 fn custom_settings(&self) -> Option<DisplaySettings> {
328 DisplaySettings::new()
329 .with_new_line_after_attribute(false)
330 .optional()
331 }
332
333 fn custom_content(&self, content: Content) -> Option<Content> {
334 let [d_input, _] = self
335 .forward
336 .update_gate
337 .input_transform
338 .weight
339 .shape()
340 .dims();
341 let bias = self.forward.update_gate.input_transform.bias.is_some();
342
343 content
344 .add("d_input", &d_input)
345 .add("d_hidden", &self.d_hidden)
346 .add("bias", &bias)
347 .optional()
348 }
349}
350
351impl BiGruConfig {
352 pub fn init<B: Backend>(&self, device: &B::Device) -> BiGru<B> {
354 let base_config = GruConfig::new(self.d_input, self.d_hidden, self.bias)
356 .with_initializer(self.initializer.clone())
357 .with_reset_after(self.reset_after)
358 .with_gate_activation(self.gate_activation.clone())
359 .with_hidden_activation(self.hidden_activation.clone())
360 .with_clip(self.clip);
361
362 BiGru {
363 forward: base_config.clone().init(device),
364 reverse: base_config.init(device),
365 d_hidden: self.d_hidden,
366 batch_first: self.batch_first,
367 }
368 }
369}
370
impl<B: Backend> BiGru<B> {
    /// Applies the bidirectional forward pass on the input tensor.
    ///
    /// # Shapes
    /// - batched_input: `[batch_size, seq_length, d_input]` when `batch_first`,
    ///   otherwise `[seq_length, batch_size, d_input]`
    /// - state: optional initial hidden states, `[2, batch_size, d_hidden]`
    ///   (index 0 = forward direction, index 1 = reverse direction)
    /// - returns `(output, state)`: output is `[batch_size, seq_length, 2 * d_hidden]`
    ///   (swapped back to seq-first when `batch_first` is false); state is
    ///   `[2, batch_size, d_hidden]` holding the final hidden state per direction.
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, Tensor<B, 3>) {
        // Normalize to batch-first layout for the recurrence.
        let batched_input = if self.batch_first {
            batched_input
        } else {
            batched_input.swap_dims(0, 1)
        };

        let device = batched_input.clone().device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        // Split the stacked initial state into one 2-D state per direction.
        let [init_state_forward, init_state_reverse] = match state {
            Some(state) => {
                let hidden_state_forward = state
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                let hidden_state_reverse = state
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);

                [Some(hidden_state_forward), Some(hidden_state_reverse)]
            }
            None => [None, None],
        };

        // Forward direction: timesteps in natural order.
        let (batched_hidden_state_forward, final_state_forward) = self.forward.forward_iter(
            batched_input.clone().iter_dim(1).zip(0..seq_length),
            init_state_forward,
            batch_size,
            seq_length,
            &device,
        );

        // Reverse direction: both the inputs and their target indices are reversed,
        // so outputs land at their original time positions.
        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
            init_state_reverse,
            batch_size,
            seq_length,
            &device,
        );

        // Concatenate both directions along the feature dimension.
        let output = Tensor::cat(
            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
            2,
        );

        // Restore the caller's layout if it was sequence-first.
        let output = if self.batch_first {
            output
        } else {
            output.swap_dims(0, 1)
        };

        let state = Tensor::stack([final_state_forward, final_state_reverse].to_vec(), 0);

        (output, state)
    }
}
453
454#[cfg(test)]
455mod tests {
456 use super::*;
457 use crate::{LinearRecord, TestBackend};
458 use burn::module::Param;
459 use burn::tensor::{Distribution, TensorData};
460 use burn::tensor::{Tolerance, ops::FloatElem};
461
462 type FT = FloatElem<TestBackend>;
463
    /// Builds a 1-input / 1-hidden GRU with fixed scalar weights so expected
    /// outputs can be computed by hand (update = 0.5, reset = 0.6, new = 0.7,
    /// all biases zero, bias disabled).
    fn init_gru<B: Backend>(reset_after: bool, device: &B::Device) -> Gru<B> {
        // Creates a gate whose input and hidden transforms both use the same
        // scalar weight and bias.
        fn create_gate_controller<B: Backend>(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &B::Device,
        ) -> GateController<B> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        let config = GruConfig::new(1, 1, false).with_reset_after(reset_after);
        let mut gru = config.init::<B>(device);

        gru.update_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru.reset_gate = create_gate_controller(
            0.6,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru.new_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru
    }
524
525 #[test]
533 fn tests_forward_single_input_single_feature() {
534 let device = Default::default();
535 TestBackend::seed(&device, 0);
536
537 let mut gru = init_gru::<TestBackend>(false, &device);
538
539 let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
540 let expected = TensorData::from([[0.034]]);
541
542 let state = gru.forward(input.clone(), None);
544
545 let output = state
546 .select(0, Tensor::arange(0..1, &device))
547 .squeeze_dim::<2>(0);
548
549 let tolerance = Tolerance::default();
550 output
551 .to_data()
552 .assert_approx_eq::<FT>(&expected, tolerance);
553
554 gru.reset_after = true; let state = gru.forward(input, None);
557
558 let output = state
559 .select(0, Tensor::arange(0..1, &device))
560 .squeeze_dim::<2>(0);
561
562 output
563 .to_data()
564 .assert_approx_eq::<FT>(&expected, tolerance);
565 }
566
567 #[test]
568 fn tests_forward_seq_len_3() {
569 let device = Default::default();
570 TestBackend::seed(&device, 0);
571 let mut gru = init_gru::<TestBackend>(true, &device);
572
573 let input =
574 Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);
575 let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]);
576
577 let result = gru.forward(input.clone(), None);
578 let output = result
579 .select(0, Tensor::arange(0..1, &device))
580 .squeeze_dim::<2>(0);
581
582 let tolerance = Tolerance::default();
583 output
584 .to_data()
585 .assert_approx_eq::<FT>(&expected, tolerance);
586
587 gru.reset_after = false; let state = gru.forward(input, None);
590
591 let output = state
592 .select(0, Tensor::arange(0..1, &device))
593 .squeeze_dim::<2>(0);
594
595 output
596 .to_data()
597 .assert_approx_eq::<FT>(&expected, tolerance);
598 }
599
600 #[test]
601 fn test_batched_forward_pass() {
602 let device = Default::default();
603 let gru = GruConfig::new(64, 1024, true).init::<TestBackend>(&device);
604 let batched_input =
605 Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);
606
607 let hidden_state = gru.forward(batched_input, None);
608
609 assert_eq!(&*hidden_state.shape(), [8, 10, 1024]);
610 }
611
612 #[test]
613 fn display() {
614 let config = GruConfig::new(2, 8, true);
615
616 let layer = config.init::<TestBackend>(&Default::default());
617
618 assert_eq!(
619 alloc::format!("{layer}"),
620 "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}"
621 );
622 }
623
624 #[test]
625 fn test_bigru_batched_forward_pass() {
626 let device = Default::default();
627 let bigru = BiGruConfig::new(64, 1024, true).init::<TestBackend>(&device);
628 let batched_input =
629 Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);
630
631 let (output, state) = bigru.forward(batched_input, None);
632
633 assert_eq!(&*output.shape(), [8, 10, 2048]);
635 assert_eq!(&*state.shape(), [2, 8, 1024]);
637 }
638
639 #[test]
640 fn test_bigru_with_initial_state() {
641 let device = Default::default();
642 let bigru = BiGruConfig::new(32, 64, true).init::<TestBackend>(&device);
643 let batched_input =
644 Tensor::<TestBackend, 3>::random([4, 5, 32], Distribution::Default, &device);
645 let initial_state =
646 Tensor::<TestBackend, 3>::random([2, 4, 64], Distribution::Default, &device);
647
648 let (output, state) = bigru.forward(batched_input, Some(initial_state));
649
650 assert_eq!(&*output.shape(), [4, 5, 128]);
651 assert_eq!(&*state.shape(), [2, 4, 64]);
652 }
653
654 #[test]
655 fn test_bigru_seq_first() {
656 let device = Default::default();
657 let bigru = BiGruConfig::new(32, 64, true)
658 .with_batch_first(false)
659 .init::<TestBackend>(&device);
660 let batched_input =
662 Tensor::<TestBackend, 3>::random([5, 4, 32], Distribution::Default, &device);
663
664 let (output, state) = bigru.forward(batched_input, None);
665
666 assert_eq!(&*output.shape(), [5, 4, 128]);
668 assert_eq!(&*state.shape(), [2, 4, 64]);
669 }
670
    #[test]
    fn test_bigru_against_pytorch() {
        use burn::tensor::Device;

        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = BiGruConfig::new(2, 3, true);
        let mut bigru = config.init::<TestBackend>(&device);

        // Builds a gate from explicit weight/bias matrices so both directions can
        // be loaded with fixed reference parameters.
        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let d_input = input_weights[0].len();
            let d_output = input_weights.len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        // Fixed input: batch = 1, seq = 4, features = 2.
        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [0.949, -0.861],
                [0.892, 0.927],
                [-0.173, -0.301],
                [-0.081, 0.992],
            ]]),
            &device,
        );
        // Initial hidden states: [direction = 2, batch = 1, hidden = 3].
        let h0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
            &device,
        );

        // Forward-direction gate parameters (fixed reference values).
        bigru.forward.update_gate = create_gate_controller(
            [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]],
            [0.2932, -0.3519, -0.5715],
            [
                [-0.3471, 0.5214, 0.0961],
                [0.0545, -0.4904, -0.1875],
                [-0.5702, 0.4457, 0.3568],
            ],
            [-0.0100, 0.4518, -0.4102],
            &device,
        );

        bigru.forward.reset_gate = create_gate_controller(
            [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]],
            [-0.2524, 0.3333, 0.1033],
            [
                [-0.2695, -0.0677, -0.4557],
                [0.1472, -0.2345, -0.2662],
                [-0.2660, 0.3830, -0.1630],
            ],
            [0.1663, 0.2391, 0.1826],
            &device,
        );

        bigru.forward.new_gate = create_gate_controller(
            [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]],
            [-0.2231, -0.4428, 0.4737],
            [
                [0.0900, -0.1821, 0.2430],
                [0.4665, 0.1551, 0.5155],
                [0.0631, -0.1566, 0.3337],
            ],
            [0.0364, -0.3941, 0.1780],
            &device,
        );

        // Reverse-direction gate parameters.
        bigru.reverse.update_gate = create_gate_controller(
            [[-0.3444, 0.1924, -0.4765], [0.5193, 0.5556, -0.5727]],
            [0.1090, 0.1779, -0.5385],
            [
                [0.1221, 0.3925, 0.5287],
                [-0.1472, -0.4187, -0.1948],
                [0.3441, -0.3082, -0.2047],
            ],
            [0.0016, -0.2148, -0.0400],
            &device,
        );

        bigru.reverse.reset_gate = create_gate_controller(
            [[-0.1988, -0.1203, -0.3422], [0.1769, 0.4788, -0.3443]],
            [-0.5053, -0.3676, 0.5771],
            [
                [-0.3936, 0.3504, -0.4486],
                [0.3063, -0.1370, -0.2914],
                [-0.2334, 0.3303, 0.1760],
            ],
            [-0.5080, -0.2488, -0.3456],
            &device,
        );

        bigru.reverse.new_gate = create_gate_controller(
            [[-0.4517, 0.2339, 0.4797], [-0.3884, 0.2067, -0.2982]],
            [-0.3792, -0.1922, 0.0903],
            [
                [-0.5586, -0.0762, -0.3944],
                [-0.3306, -0.4191, -0.4898],
                [0.1442, 0.0135, -0.3179],
            ],
            [-0.3912, -0.3963, -0.3368],
            &device,
        );

        // Expected values presumably generated with an equivalent PyTorch
        // bidirectional nn.GRU (per the test name) — TODO confirm provenance.
        let expected_output_with_init = TensorData::from([[
            [0.24537, 0.14018, 0.19449, -0.49777, -0.15647, 0.48392],
            [0.27468, -0.14514, 0.56205, -0.60381, -0.04986, 0.15683],
            [-0.04062, -0.33486, 0.52330, -0.42244, -0.12644, -0.12034],
            [-0.11743, -0.53873, 0.54429, -0.64943, 0.30127, -0.41943],
        ]]);

        let expected_hn_with_init = TensorData::from([
            [[-0.11743, -0.53873, 0.54429]],
            [[-0.49777, -0.15647, 0.48392]],
        ]);

        let expected_output_without_init = TensorData::from([[
            [0.07452, -0.08247, 0.46677, -0.46770, -0.18086, 0.47519],
            [0.15843, -0.27144, 0.65781, -0.50286, -0.12806, 0.14884],
            [-0.10704, -0.41573, 0.53954, -0.24794, -0.24003, -0.10294],
            [-0.16505, -0.57952, 0.53565, -0.23598, -0.07137, -0.28937],
        ]]);

        let expected_hn_without_init = TensorData::from([
            [[-0.16505, -0.57952, 0.53565]],
            [[-0.46770, -0.18086, 0.47519]],
        ]);

        let (output_with_init, hn_with_init) = bigru.forward(input.clone(), Some(h0));
        let (output_without_init, hn_without_init) = bigru.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_init
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_init, tolerance);
        output_without_init
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_without_init, tolerance);
        hn_with_init
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_with_init, tolerance);
        hn_without_init
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_without_init, tolerance);
    }
841
842 #[test]
843 fn bigru_display() {
844 let config = BiGruConfig::new(2, 8, true);
845
846 let layer = config.init::<TestBackend>(&Default::default());
847
848 assert_eq!(
849 alloc::format!("{layer}"),
850 "BiGru {d_input: 2, d_hidden: 8, bias: true, params: 576}"
851 );
852 }
853
854 #[test]
855 fn test_gru_custom_activations() {
856 let device = Default::default();
857
858 let config = GruConfig::new(4, 8, true)
860 .with_gate_activation(ActivationConfig::Relu)
861 .with_hidden_activation(ActivationConfig::Relu);
862 let gru = config.init::<TestBackend>(&device);
863
864 let input = Tensor::<TestBackend, 3>::random([2, 3, 4], Distribution::Default, &device);
865
866 let output = gru.forward(input, None);
868 assert_eq!(&*output.shape(), [2, 3, 8]);
869 }
870
871 #[test]
872 fn test_bigru_custom_activations() {
873 let device = Default::default();
874
875 let config = BiGruConfig::new(4, 8, true)
877 .with_gate_activation(ActivationConfig::Relu)
878 .with_hidden_activation(ActivationConfig::Relu);
879 let bigru = config.init::<TestBackend>(&device);
880
881 let input = Tensor::<TestBackend, 3>::random([2, 3, 4], Distribution::Default, &device);
882
883 let (output, state) = bigru.forward(input, None);
884 assert_eq!(&*output.shape(), [2, 3, 16]); assert_eq!(&*state.shape(), [2, 2, 8]);
886 }
887
888 #[test]
889 fn test_gru_clipping() {
890 let device = Default::default();
891
892 let clip_value = 0.5;
894 let config = GruConfig::new(4, 8, true).with_clip(Some(clip_value));
895 let gru = config.init::<TestBackend>(&device);
896
897 let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);
898
899 let output = gru.forward(input, None);
900
901 let output_data: Vec<f32> = output.to_data().to_vec().unwrap();
903 for val in output_data {
904 assert!(
905 val >= -clip_value as f32 && val <= clip_value as f32,
906 "Value {} is outside clip range [-{}, {}]",
907 val,
908 clip_value,
909 clip_value
910 );
911 }
912 }
913
914 #[test]
915 fn test_bigru_clipping() {
916 let device = Default::default();
917
918 let clip_value = 0.3;
920 let config = BiGruConfig::new(4, 8, true).with_clip(Some(clip_value));
921 let bigru = config.init::<TestBackend>(&device);
922
923 let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);
924
925 let (output, state) = bigru.forward(input, None);
926
927 let output_data: Vec<f32> = output.to_data().to_vec().unwrap();
929 for val in output_data {
930 assert!(
931 val >= -clip_value as f32 && val <= clip_value as f32,
932 "Output value {} is outside clip range [-{}, {}]",
933 val,
934 clip_value,
935 clip_value
936 );
937 }
938
939 let state_data: Vec<f32> = state.to_data().to_vec().unwrap();
941 for val in state_data {
942 assert!(
943 val >= -clip_value as f32 && val <= clip_value as f32,
944 "State value {} is outside clip range [-{}, {}]",
945 val,
946 clip_value,
947 clip_value
948 );
949 }
950 }
951
    #[test]
    fn test_gru_against_pytorch() {
        use burn::tensor::Device;

        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = GruConfig::new(2, 3, true);
        let mut gru = config.init::<TestBackend>(&device);

        // Builds a gate from explicit weight/bias matrices so the module can be
        // loaded with fixed reference parameters.
        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let d_input = input_weights[0].len();
            let d_output = input_weights.len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        // Fixed input: batch = 1, seq = 4, features = 2.
        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [-0.11147, 0.12036],
                [-0.36963, -0.24042],
                [-1.19692, 0.20927],
                [-0.97236, -0.75505],
            ]]),
            &device,
        );

        // Initial hidden state: [batch = 1, hidden = 3].
        let h0 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.3239, -0.10852, 0.21033]]),
            &device,
        );

        // Fixed reference gate parameters.
        gru.update_gate = create_gate_controller(
            [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]],
            [0.2932, -0.3519, -0.5715],
            [
                [-0.3471, 0.5214, 0.0961],
                [0.0545, -0.4904, -0.1875],
                [-0.5702, 0.4457, 0.3568],
            ],
            [-0.0100, 0.4518, -0.4102],
            &device,
        );

        gru.reset_gate = create_gate_controller(
            [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]],
            [-0.2524, 0.3333, 0.1033],
            [
                [-0.2695, -0.0677, -0.4557],
                [0.1472, -0.2345, -0.2662],
                [-0.2660, 0.3830, -0.1630],
            ],
            [0.1663, 0.2391, 0.1826],
            &device,
        );

        gru.new_gate = create_gate_controller(
            [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]],
            [-0.2231, -0.4428, 0.4737],
            [
                [0.0900, -0.1821, 0.2430],
                [0.4665, 0.1551, 0.5155],
                [0.0631, -0.1566, 0.3337],
            ],
            [0.0364, -0.3941, 0.1780],
            &device,
        );

        // Expected values presumably generated with an equivalent PyTorch nn.GRU
        // (per the test name) — TODO confirm provenance.
        let expected_output_with_h0 = TensorData::from([[
            [0.05665, -0.34932, 0.43267],
            [-0.1737, -0.49246, 0.38099],
            [-0.35401, -0.68099, 0.05061],
            [-0.47854, -0.70427, -0.13648],
        ]]);

        let expected_output_no_h0 = TensorData::from([[
            [-0.0985, -0.31661, 0.36126],
            [-0.24563, -0.47784, 0.34609],
            [-0.39497, -0.67659, 0.03083],
            [-0.50146, -0.70066, -0.14894],
        ]]);

        let output_with_h0 = gru.forward(input.clone(), Some(h0));
        let output_no_h0 = gru.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_h0
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_h0, tolerance);
        output_no_h0
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_no_h0, tolerance);
    }
1074}