
burn_nn/modules/rnn/gru.rs

use burn_core as burn;

use super::gate_controller::GateController;
use crate::activation::{Activation, ActivationConfig};
use burn::config::Config;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// Configuration to create a [gru](Gru) module using the [init function](GruConfig::init).
#[derive(Config, Debug)]
pub struct GruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Gru transformation.
    pub bias: bool,
    /// If the reset gate should be applied after the weight multiplication.
    ///
    /// This configuration option controls how the reset gate is applied to the hidden state.
    /// * `true` - (Default) Match the initial arXiv version of the paper [Learning Phrase Representations using RNN Encoder-Decoder for
    ///   Statistical Machine Translation (v1)](https://arxiv.org/abs/1406.1078v1) and apply the reset gate after multiplication by
    ///   the weights. This matches the behavior of [PyTorch GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU).
    /// * `false` - Match the most recent revision of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine
    ///   Translation (v3)](https://arxiv.org/abs/1406.1078) and apply the reset gate before the weight multiplication.
    ///
    /// The two implementations can give slightly different numerical results and have different efficiencies. For more
    /// motivation for why `true` can be more efficient, see [Optimizing RNNs with Differentiable Graphs](https://svail.github.io/diff_graphs).
    ///
    /// To set this field to `false`, use [`with_reset_after`](`GruConfig::with_reset_after`).
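    ///
    /// As a sketch of how the candidate gate `n` differs between the two settings
    /// (with `r` the reset gate, `x` the input, `h` the previous hidden state,
    /// `W`/`U` the input/hidden weight matrices, `b_ih`/`b_hh` the biases, and
    /// `tanh` the default hidden activation):
    /// * `true`:  `n = tanh(W*x + b_ih + r .* (U*h + b_hh))`
    /// * `false`: `n = tanh(W*x + b_ih + U*(r .* h) + b_hh)`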
    #[config(default = "true")]
    pub reset_after: bool,
    /// Gru initializer
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// Activation function for the update and reset gates.
    /// Default is Sigmoid, which is standard for GRU gates.
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// Activation function for the new/candidate gate.
    /// Default is Tanh, which is standard for GRU.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
    /// Optional hidden state clip threshold. If provided, hidden state values are clipped
    /// to the range `[-clip, +clip]` after each timestep. This can help prevent
    /// exploding values during inference.
    pub clip: Option<f64>,
}

/// The Gru (Gated recurrent unit) module. This implementation is for a unidirectional, stateless GRU.
///
/// Introduced in the paper: [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078).
///
/// Should be created with [GruConfig].
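///
/// # Example
///
/// A minimal usage sketch; the backend type `MyBackend` and all shapes here are
/// illustrative assumptions, not part of this module:
///
/// ```ignore
/// let device = Default::default();
/// // 16 input features, 32 hidden units, with bias.
/// let gru = GruConfig::new(16, 32, true)
///     .with_reset_after(false) // optional: apply the reset gate before the weight multiplication
///     .init::<MyBackend>(&device);
/// // Input shape: [batch_size, seq_length, d_input].
/// let input = Tensor::<MyBackend, 3>::zeros([4, 10, 16], &device);
/// // Output shape: [batch_size, seq_length, d_hidden].
/// let output = gru.forward(input, None);
/// ```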
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Gru<B: Backend> {
    /// The update gate controller.
    pub update_gate: GateController<B>,
    /// The reset gate controller.
    pub reset_gate: GateController<B>,
    /// The new gate controller.
    pub new_gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If reset gate should be applied after weight multiplication.
    pub reset_after: bool,
    /// Activation function for gates (update, reset).
    pub gate_activation: Activation<B>,
    /// Activation function for new/candidate gate.
    pub hidden_activation: Activation<B>,
    /// Optional hidden state clip threshold.
    pub clip: Option<f64>,
}

impl<B: Backend> ModuleDisplay for Gru<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self.update_gate.input_transform.weight.shape().dims();
        let bias = self.update_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .add("reset_after", &self.reset_after)
            .optional()
    }
}

impl GruConfig {
    /// Initialize a new [gru](Gru) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Gru<B> {
        let d_output = self.d_hidden;

        let update_gate = GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
            device,
        );
        let reset_gate = GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
            device,
        );
        let new_gate = GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
            device,
        );

        Gru {
            update_gate,
            reset_gate,
            new_gate,
            d_hidden: self.d_hidden,
            reset_after: self.reset_after,
            gate_activation: self.gate_activation.init(device),
            hidden_activation: self.hidden_activation.init(device),
            clip: self.clip,
        }
    }
}

impl<B: Backend> Gru<B> {
    /// Applies the forward pass on the input tensor. This GRU implementation
    /// returns a state tensor with dimensions `[batch_size, sequence_length, hidden_size]`.
    ///
    /// # Parameters
    /// - batched_input: `[batch_size, sequence_length, input_size]`.
    /// - state: An optional tensor representing an initial hidden state with dimensions
    ///   `[batch_size, hidden_size]`. If none is provided, a zero-initialized state is used.
    ///
    /// # Returns
    /// - output: `[batch_size, sequence_length, hidden_size]`
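    ///
    /// A short sketch of seeding the recurrence with an explicit initial state
    /// (`MyBackend` and the shapes are illustrative assumptions):
    ///
    /// ```ignore
    /// // h0 shape: [batch_size, d_hidden].
    /// let h0 = Tensor::<MyBackend, 2>::zeros([4, 32], &device);
    /// let output = gru.forward(input, Some(h0));
    /// ```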
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 2>>,
    ) -> Tensor<B, 3> {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        self.forward_iter(
            batched_input.iter_dim(1).zip(0..seq_length),
            state,
            batch_size,
            seq_length,
            &device,
        )
        .0
    }

    /// Forward pass variant that accepts an iterator over timesteps.
    /// Used by BiGru to process sequences in either direction.
    ///
    /// # Parameters
    /// - input_timestep_iter: Iterator yielding (input_tensor, timestep_index) pairs.
    ///   The timestep_index determines where in the output tensor to store results.
    /// - state: Optional initial hidden state with shape `[batch_size, hidden_size]`.
    /// - batch_size: Batch size of the input.
    /// - seq_length: Sequence length of the input.
    /// - device: Device to create tensors on.
    ///
    /// # Returns
    /// - output: `[batch_size, sequence_length, hidden_size]`
    /// - final_hidden: Final hidden state `[batch_size, hidden_size]`
    pub(crate) fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
        &self,
        input_timestep_iter: I,
        state: Option<Tensor<B, 2>>,
        batch_size: usize,
        seq_length: usize,
        device: &B::Device,
    ) -> (Tensor<B, 3>, Tensor<B, 2>) {
        let mut batched_hidden_state =
            Tensor::empty([batch_size, seq_length, self.d_hidden], device);

        let mut hidden_t = match state {
            Some(state) => state,
            None => Tensor::zeros([batch_size, self.d_hidden], device),
        };

        for (input_t, t) in input_timestep_iter {
            let input_t = input_t.squeeze_dim(1);

            // u(pdate)g(ate) tensors
            let biased_ug_input_sum =
                self.gate_product(&input_t, &hidden_t, None, &self.update_gate);
            let update_values = self.gate_activation.forward(biased_ug_input_sum);

            // r(eset)g(ate) tensors
            let biased_rg_input_sum =
                self.gate_product(&input_t, &hidden_t, None, &self.reset_gate);
            let reset_values = self.gate_activation.forward(biased_rg_input_sum);

            // n(ew)g(ate) tensor
            let biased_ng_input_sum = if self.reset_after {
                self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate)
            } else {
                let reset_t = hidden_t.clone().mul(reset_values);
                self.gate_product(&input_t, &reset_t, None, &self.new_gate)
            };
            let candidate_state = self.hidden_activation.forward(biased_ng_input_sum);

            // calculate linear interpolation between previous hidden state and candidate state:
            // h_t = (1 - z_t) * g_t + z_t * h_{t-1}
            let one_minus_z = update_values.clone().neg().add_scalar(1.0);
            hidden_t = candidate_state.mul(one_minus_z) + update_values.mul(hidden_t);

            // Apply hidden state clipping if configured
            if let Some(clip) = self.clip {
                hidden_t = hidden_t.clamp(-clip, clip);
            }

            let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1);

            batched_hidden_state = batched_hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                unsqueezed_hidden_state,
            );
        }

        (batched_hidden_state, hidden_t)
    }

    /// Helper function that performs the weighted matrix product for a gate, adds the
    /// bias (if any), and optionally applies the reset gate to the hidden state.
    ///
    /// Mathematically, performs `Wx*X + bx + r .* (Wh*H + bh)`, where:
    /// - `Wx` = weight matrix for the connection to input vector `X`
    /// - `Wh` = weight matrix for the connection to hidden state `H`
    /// - `X` = input vector
    /// - `H` = hidden state
    /// - `bx`, `bh` = input and hidden bias terms, if biases are enabled
    /// - `r` = reset state (when `None`, the hidden part is added without scaling)
    fn gate_product(
        &self,
        input: &Tensor<B, 2>,
        hidden: &Tensor<B, 2>,
        reset: Option<&Tensor<B, 2>>,
        gate: &GateController<B>,
    ) -> Tensor<B, 2> {
        let input_product = input.clone().matmul(gate.input_transform.weight.val());
        let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());

        let input_part = match &gate.input_transform.bias {
            Some(bias) => input_product + bias.val().unsqueeze(),
            None => input_product,
        };

        let hidden_part = match &gate.hidden_transform.bias {
            Some(bias) => hidden_product + bias.val().unsqueeze(),
            None => hidden_product,
        };

        match reset {
            Some(r) => input_part + r.clone().mul(hidden_part),
            None => input_part + hidden_part,
        }
    }
}

/// Configuration to create a [BiGru](BiGru) module using the [init function](BiGruConfig::init).
#[derive(Config, Debug)]
pub struct BiGruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the BiGru transformation.
    pub bias: bool,
    /// If reset gate should be applied after weight multiplication.
    #[config(default = "true")]
    pub reset_after: bool,
    /// BiGru initializer
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// Activation function for the update and reset gates.
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// Activation function for the new/candidate gate.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
    /// Optional hidden state clip threshold.
    pub clip: Option<f64>,
}

/// The BiGru module. This implementation is for a bidirectional GRU.
///
/// Based on the paper: [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078).
///
/// Should be created with [BiGruConfig].
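///
/// # Example
///
/// A minimal usage sketch; the backend type `MyBackend` and all shapes here are
/// illustrative assumptions, not part of this module:
///
/// ```ignore
/// let device = Default::default();
/// let bigru = BiGruConfig::new(16, 32, true).init::<MyBackend>(&device);
/// // Input shape: [batch_size, seq_length, d_input] with the default batch_first = true.
/// let input = Tensor::<MyBackend, 3>::zeros([4, 10, 16], &device);
/// // output: [batch_size, seq_length, d_hidden * 2]; state: [2, batch_size, d_hidden].
/// let (output, state) = bigru.forward(input, None);
/// ```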
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiGru<B: Backend> {
    /// GRU for the forward direction.
    pub forward: Gru<B>,
    /// GRU for the reverse direction.
    pub reverse: Gru<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If true, input is `[batch_size, seq_length, input_size]`.
    /// If false, input is `[seq_length, batch_size, input_size]`.
    pub batch_first: bool,
}

impl<B: Backend> ModuleDisplay for BiGru<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self
            .forward
            .update_gate
            .input_transform
            .weight
            .shape()
            .dims();
        let bias = self.forward.update_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl BiGruConfig {
    /// Initialize a new [Bidirectional GRU](BiGru) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> BiGru<B> {
        // Internal GRUs always use batch_first=true; BiGru handles layout conversion
        let base_config = GruConfig::new(self.d_input, self.d_hidden, self.bias)
            .with_initializer(self.initializer.clone())
            .with_reset_after(self.reset_after)
            .with_gate_activation(self.gate_activation.clone())
            .with_hidden_activation(self.hidden_activation.clone())
            .with_clip(self.clip);

        BiGru {
            forward: base_config.clone().init(device),
            reverse: base_config.init(device),
            d_hidden: self.d_hidden,
            batch_first: self.batch_first,
        }
    }
}

impl<B: Backend> BiGru<B> {
    /// Applies the forward pass on the input tensor. This bidirectional GRU implementation
    /// returns the state for each element in the sequence (i.e., across seq_length) and a final state.
    ///
    /// ## Parameters:
    /// - batched_input: The input tensor of shape:
    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false
    /// - state: An optional tensor representing the initial hidden state with shape
    ///   `[2, batch_size, hidden_size]`. If no initial state is provided, it is initialized to zeros.
    ///
    /// ## Returns:
    /// - output: A tensor representing the output features. Shape:
    ///   - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true
    ///   - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false
    /// - state: The final forward and reverse hidden states stacked along dimension 0
    ///   with shape `[2, batch_size, hidden_size]`.
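    ///
    /// A short sketch of passing an explicit initial state for both directions
    /// (`MyBackend` and the shapes are illustrative assumptions):
    ///
    /// ```ignore
    /// // h0 shape: [2, batch_size, d_hidden] - index 0 is forward, index 1 is reverse.
    /// let h0 = Tensor::<MyBackend, 3>::zeros([2, 4, 32], &device);
    /// let (output, state) = bigru.forward(input, Some(h0));
    /// ```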
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, Tensor<B, 3>) {
        // Convert to batch-first layout internally if needed
        let batched_input = if self.batch_first {
            batched_input
        } else {
            batched_input.swap_dims(0, 1)
        };

        let device = batched_input.clone().device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        let [init_state_forward, init_state_reverse] = match state {
            Some(state) => {
                let hidden_state_forward = state
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                let hidden_state_reverse = state
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);

                [Some(hidden_state_forward), Some(hidden_state_reverse)]
            }
            None => [None, None],
        };

        // forward direction
        let (batched_hidden_state_forward, final_state_forward) = self.forward.forward_iter(
            batched_input.clone().iter_dim(1).zip(0..seq_length),
            init_state_forward,
            batch_size,
            seq_length,
            &device,
        );

        // reverse direction
        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
            init_state_reverse,
            batch_size,
            seq_length,
            &device,
        );

        let output = Tensor::cat(
            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
            2,
        );

        // Convert output back to seq-first layout if needed
        let output = if self.batch_first {
            output
        } else {
            output.swap_dims(0, 1)
        };

        let state = Tensor::stack([final_state_forward, final_state_reverse].to_vec(), 0);

        (output, state)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LinearRecord, TestBackend};
    use burn::module::Param;
    use burn::tensor::{Distribution, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    fn init_gru<B: Backend>(reset_after: bool, device: &B::Device) -> Gru<B> {
        fn create_gate_controller<B: Backend>(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &B::Device,
        ) -> GateController<B> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        let config = GruConfig::new(1, 1, false).with_reset_after(reset_after);
        let mut gru = config.init::<B>(device);

        gru.update_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru.reset_gate = create_gate_controller(
            0.6,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru.new_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
            device,
        );
        gru
    }

    /// Test forward pass with a simple input vector.
    ///
    /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125
    /// r_t = sigmoid(0.6*0.1 + 0.6*0) = 0.5150
    /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699
    ///
    /// h_t = z_t * h_{t-1} + (1 - z_t) * g_t = 0.0341
    #[test]
    fn tests_forward_single_input_single_feature() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let mut gru = init_gru::<TestBackend>(false, &device);

        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
        let expected = TensorData::from([[0.034]]);

        // Reset gate applied to hidden state before the matrix multiplication
        let state = gru.forward(input.clone(), None);

        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0);

        let tolerance = Tolerance::default();
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        // Reset gate applied to hidden state after the matrix multiplication
        gru.reset_after = true; // override forward behavior
        let state = gru.forward(input, None);

        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    #[test]
    fn tests_forward_seq_len_3() {
        let device = Default::default();
        TestBackend::seed(&device, 0);
        let mut gru = init_gru::<TestBackend>(true, &device);

        let input =
            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);
        let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]);

        let result = gru.forward(input.clone(), None);
        let output = result
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0);

        let tolerance = Tolerance::default();
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        // Reset gate applied to hidden state before the matrix multiplication
        gru.reset_after = false; // override forward behavior
        let state = gru.forward(input, None);

        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let gru = GruConfig::new(64, 1024, true).init::<TestBackend>(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let hidden_state = gru.forward(batched_input, None);

        assert_eq!(&*hidden_state.shape(), [8, 10, 1024]);
    }

    #[test]
    fn display() {
        let config = GruConfig::new(2, 8, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}"
        );
    }

    #[test]
    fn test_bigru_batched_forward_pass() {
        let device = Default::default();
        let bigru = BiGruConfig::new(64, 1024, true).init::<TestBackend>(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let (output, state) = bigru.forward(batched_input, None);

        // Output should have hidden_size * 2 features (forward + reverse concatenated)
        assert_eq!(&*output.shape(), [8, 10, 2048]);
        // State should have shape [2, batch_size, hidden_size]
        assert_eq!(&*state.shape(), [2, 8, 1024]);
    }

    #[test]
    fn test_bigru_with_initial_state() {
        let device = Default::default();
        let bigru = BiGruConfig::new(32, 64, true).init::<TestBackend>(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([4, 5, 32], Distribution::Default, &device);
        let initial_state =
            Tensor::<TestBackend, 3>::random([2, 4, 64], Distribution::Default, &device);

        let (output, state) = bigru.forward(batched_input, Some(initial_state));

        assert_eq!(&*output.shape(), [4, 5, 128]);
        assert_eq!(&*state.shape(), [2, 4, 64]);
    }

    #[test]
    fn test_bigru_seq_first() {
        let device = Default::default();
        let bigru = BiGruConfig::new(32, 64, true)
            .with_batch_first(false)
            .init::<TestBackend>(&device);
        // Input shape: [seq_length, batch_size, input_size] when batch_first=false
        let batched_input =
            Tensor::<TestBackend, 3>::random([5, 4, 32], Distribution::Default, &device);

        let (output, state) = bigru.forward(batched_input, None);

        // Output shape: [seq_length, batch_size, hidden_size * 2]
        assert_eq!(&*output.shape(), [5, 4, 128]);
        assert_eq!(&*state.shape(), [2, 4, 64]);
    }

    /// Test BiGru against PyTorch reference implementation.
    /// Expected values computed with PyTorch nn.GRU(bidirectional=True).
    #[test]
    fn test_bigru_against_pytorch() {
        use burn::tensor::Device;

        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = BiGruConfig::new(2, 3, true);
        let mut bigru = config.init::<TestBackend>(&device);

        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            // Burn's Linear weight is `[d_input, d_output]`: rows are the input
            // dimension, columns the output dimension.
            let d_input = input_weights.len();
            let d_output = input_weights[0].len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [0.949, -0.861],
                [0.892, 0.927],
                [-0.173, -0.301],
                [-0.081, 0.992],
            ]]),
            &device,
        );
        let h0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
            &device,
        );

        // Forward GRU gates (weights from PyTorch with seed 42, transposed for burn)
        bigru.forward.update_gate = create_gate_controller(
            [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]],
            [0.2932, -0.3519, -0.5715],
            [
                [-0.3471, 0.5214, 0.0961],
                [0.0545, -0.4904, -0.1875],
                [-0.5702, 0.4457, 0.3568],
            ],
            [-0.0100, 0.4518, -0.4102],
            &device,
        );

        bigru.forward.reset_gate = create_gate_controller(
            [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]],
            [-0.2524, 0.3333, 0.1033],
            [
                [-0.2695, -0.0677, -0.4557],
                [0.1472, -0.2345, -0.2662],
                [-0.2660, 0.3830, -0.1630],
            ],
            [0.1663, 0.2391, 0.1826],
            &device,
        );

        bigru.forward.new_gate = create_gate_controller(
            [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]],
            [-0.2231, -0.4428, 0.4737],
            [
                [0.0900, -0.1821, 0.2430],
                [0.4665, 0.1551, 0.5155],
                [0.0631, -0.1566, 0.3337],
            ],
            [0.0364, -0.3941, 0.1780],
            &device,
        );

        // Reverse GRU gates
        bigru.reverse.update_gate = create_gate_controller(
            [[-0.3444, 0.1924, -0.4765], [0.5193, 0.5556, -0.5727]],
            [0.1090, 0.1779, -0.5385],
            [
                [0.1221, 0.3925, 0.5287],
                [-0.1472, -0.4187, -0.1948],
                [0.3441, -0.3082, -0.2047],
            ],
            [0.0016, -0.2148, -0.0400],
            &device,
        );

        bigru.reverse.reset_gate = create_gate_controller(
            [[-0.1988, -0.1203, -0.3422], [0.1769, 0.4788, -0.3443]],
            [-0.5053, -0.3676, 0.5771],
            [
                [-0.3936, 0.3504, -0.4486],
                [0.3063, -0.1370, -0.2914],
                [-0.2334, 0.3303, 0.1760],
            ],
            [-0.5080, -0.2488, -0.3456],
            &device,
        );

        bigru.reverse.new_gate = create_gate_controller(
            [[-0.4517, 0.2339, 0.4797], [-0.3884, 0.2067, -0.2982]],
            [-0.3792, -0.1922, 0.0903],
            [
                [-0.5586, -0.0762, -0.3944],
                [-0.3306, -0.4191, -0.4898],
                [0.1442, 0.0135, -0.3179],
            ],
            [-0.3912, -0.3963, -0.3368],
            &device,
        );

        // Expected values from PyTorch
        let expected_output_with_init = TensorData::from([[
            [0.24537, 0.14018, 0.19449, -0.49777, -0.15647, 0.48392],
            [0.27468, -0.14514, 0.56205, -0.60381, -0.04986, 0.15683],
            [-0.04062, -0.33486, 0.52330, -0.42244, -0.12644, -0.12034],
            [-0.11743, -0.53873, 0.54429, -0.64943, 0.30127, -0.41943],
        ]]);

        let expected_hn_with_init = TensorData::from([
            [[-0.11743, -0.53873, 0.54429]],
            [[-0.49777, -0.15647, 0.48392]],
        ]);

        let expected_output_without_init = TensorData::from([[
            [0.07452, -0.08247, 0.46677, -0.46770, -0.18086, 0.47519],
            [0.15843, -0.27144, 0.65781, -0.50286, -0.12806, 0.14884],
            [-0.10704, -0.41573, 0.53954, -0.24794, -0.24003, -0.10294],
            [-0.16505, -0.57952, 0.53565, -0.23598, -0.07137, -0.28937],
        ]]);

        let expected_hn_without_init = TensorData::from([
            [[-0.16505, -0.57952, 0.53565]],
            [[-0.46770, -0.18086, 0.47519]],
        ]);

        let (output_with_init, hn_with_init) = bigru.forward(input.clone(), Some(h0));
        let (output_without_init, hn_without_init) = bigru.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_init
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_init, tolerance);
        output_without_init
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_without_init, tolerance);
        hn_with_init
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_with_init, tolerance);
        hn_without_init
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_without_init, tolerance);
    }

    #[test]
    fn bigru_display() {
        let config = BiGruConfig::new(2, 8, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "BiGru {d_input: 2, d_hidden: 8, bias: true, params: 576}"
        );
    }

    #[test]
    fn test_gru_custom_activations() {
        let device = Default::default();

        // Create GRU with custom activations (ReLU instead of Sigmoid/Tanh)
        let config = GruConfig::new(4, 8, true)
            .with_gate_activation(ActivationConfig::Relu)
            .with_hidden_activation(ActivationConfig::Relu);
        let gru = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::random([2, 3, 4], Distribution::Default, &device);

        // Should run without panicking and produce valid output
        let output = gru.forward(input, None);
        assert_eq!(&*output.shape(), [2, 3, 8]);
    }

    #[test]
    fn test_bigru_custom_activations() {
        let device = Default::default();

        // Create BiGRU with custom activations
        let config = BiGruConfig::new(4, 8, true)
            .with_gate_activation(ActivationConfig::Relu)
            .with_hidden_activation(ActivationConfig::Relu);
        let bigru = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::random([2, 3, 4], Distribution::Default, &device);

        let (output, state) = bigru.forward(input, None);
        assert_eq!(&*output.shape(), [2, 3, 16]); // hidden_size * 2
        assert_eq!(&*state.shape(), [2, 2, 8]);
    }

    #[test]
    fn test_gru_clipping() {
        let device = Default::default();

        // Create GRU with clipping enabled
        let clip_value = 0.5;
        let config = GruConfig::new(4, 8, true).with_clip(Some(clip_value));
        let gru = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);

        let output = gru.forward(input, None);

        // Verify output values are within the clip range
        let output_data: Vec<f32> = output.to_data().to_vec().unwrap();
        for val in output_data {
            assert!(
                val >= -clip_value as f32 && val <= clip_value as f32,
                "Value {} is outside clip range [-{}, {}]",
                val,
                clip_value,
                clip_value
            );
        }
    }

    #[test]
    fn test_bigru_clipping() {
        let device = Default::default();

        // Create BiGRU with clipping enabled
        let clip_value = 0.3;
        let config = BiGruConfig::new(4, 8, true).with_clip(Some(clip_value));
        let bigru = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);

        let (output, state) = bigru.forward(input, None);

        // Verify output values are within the clip range
        let output_data: Vec<f32> = output.to_data().to_vec().unwrap();
        for val in output_data {
            assert!(
                val >= -clip_value as f32 && val <= clip_value as f32,
                "Output value {} is outside clip range [-{}, {}]",
                val,
                clip_value,
                clip_value
            );
        }

        // Verify state values are within the clip range
        let state_data: Vec<f32> = state.to_data().to_vec().unwrap();
        for val in state_data {
            assert!(
                val >= -clip_value as f32 && val <= clip_value as f32,
                "State value {} is outside clip range [-{}, {}]",
                val,
                clip_value,
                clip_value
            );
        }
    }

    /// Test Gru against PyTorch reference implementation.
    /// Expected values computed with PyTorch nn.GRU (seed=42 for weights, seed=123 for input).
    #[test]
    fn test_gru_against_pytorch() {
        use burn::tensor::Device;

        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = GruConfig::new(2, 3, true);
        let mut gru = config.init::<TestBackend>(&device);

        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            // Burn's Linear weight is `[d_input, d_output]`: rows are the input
            // dimension, columns the output dimension.
            let d_input = input_weights.len();
            let d_output = input_weights[0].len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        // Input: [batch=1, seq=4, input=2]
        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [-0.11147, 0.12036],
                [-0.36963, -0.24042],
                [-1.19692, 0.20927],
                [-0.97236, -0.75505],
            ]]),
            &device,
        );

        // Initial hidden state: [batch=1, hidden=3]
        let h0 = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.3239, -0.10852, 0.21033]]),
            &device,
        );

        // Update gate (z) - weights from PyTorch, transposed for Burn's row layout
        gru.update_gate = create_gate_controller(
            [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]],
            [0.2932, -0.3519, -0.5715],
            [
                [-0.3471, 0.5214, 0.0961],
                [0.0545, -0.4904, -0.1875],
                [-0.5702, 0.4457, 0.3568],
            ],
            [-0.0100, 0.4518, -0.4102],
            &device,
        );

        // Reset gate (r)
        gru.reset_gate = create_gate_controller(
            [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]],
            [-0.2524, 0.3333, 0.1033],
            [
                [-0.2695, -0.0677, -0.4557],
                [0.1472, -0.2345, -0.2662],
                [-0.2660, 0.3830, -0.1630],
            ],
            [0.1663, 0.2391, 0.1826],
            &device,
        );

        // New gate (n)
        gru.new_gate = create_gate_controller(
            [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]],
            [-0.2231, -0.4428, 0.4737],
            [
                [0.0900, -0.1821, 0.2430],
                [0.4665, 0.1551, 0.5155],
                [0.0631, -0.1566, 0.3337],
            ],
            [0.0364, -0.3941, 0.1780],
            &device,
        );

        // Expected values from PyTorch
        let expected_output_with_h0 = TensorData::from([[
            [0.05665, -0.34932, 0.43267],
            [-0.1737, -0.49246, 0.38099],
            [-0.35401, -0.68099, 0.05061],
            [-0.47854, -0.70427, -0.13648],
        ]]);

        let expected_output_no_h0 = TensorData::from([[
            [-0.0985, -0.31661, 0.36126],
            [-0.24563, -0.47784, 0.34609],
            [-0.39497, -0.67659, 0.03083],
            [-0.50146, -0.70066, -0.14894],
        ]]);

        let output_with_h0 = gru.forward(input.clone(), Some(h0));
        let output_no_h0 = gru.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_h0
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_h0, tolerance);
        output_no_h0
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_no_h0, tolerance);
    }
}