//! burn_nn/modules/rnn/basic.rs — basic RNN and bidirectional RNN modules.

1use burn_core as burn;
2
3use crate::GateController;
4use crate::activation::{Activation, ActivationConfig};
5use burn::config::Config;
6use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9
/// Holds the recurrent hidden state of an RNN between timesteps or calls.
pub struct RnnState<B: Backend, const D: usize> {
    /// The hidden state tensor. For [Rnn] this is rank 2 with shape
    /// `[batch_size, hidden_size]`; for [BiRnn] it is rank 3 with shape
    /// `[2, batch_size, hidden_size]` (forward direction first, then reverse).
    pub hidden: Tensor<B, D>,
}
15
16impl<B: Backend, const D: usize> RnnState<B, D> {
17    /// Initialize a new [RNN State](RnnState).
18    pub fn new(hidden: Tensor<B, D>) -> Self {
19        Self { hidden }
20    }
21}
22
/// Configuration to create a [Rnn](Rnn) module using the [init function](RnnConfig::init).
#[derive(Config, Debug)]
pub struct RnnConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Rnn transformation.
    pub bias: bool,
    /// Rnn initializer. Defaults to Xavier/Glorot normal with gain 1.0.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// If true, process the sequence in reverse order.
    /// This is useful for implementing reverse-direction RNNs (e.g., ONNX reverse direction).
    #[config(default = false)]
    pub reverse: bool,
    /// Optional hidden state clip threshold. If provided, hidden state values are clipped
    /// to the range `[-clip, +clip]` after each timestep (applied after the activation).
    /// This can help prevent exploding values during inference.
    pub clip: Option<f64>,
    /// Activation function applied to the hidden state before computing hidden output.
    /// Default is Tanh, which is standard for Rnn.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
}
52
/// The Rnn module. This implementation is for a unidirectional, stateless, Rnn.
/// Should be created with [RnnConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Rnn<B: Backend> {
    /// Gate controller for Rnn (a plain RNN has a single gate combining the
    /// input and hidden linear transforms).
    pub gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If true, input is `[batch_size, seq_length, input_size]`.
    /// If false, input is `[seq_length, batch_size, input_size]`.
    pub batch_first: bool,
    /// If true, process the sequence in reverse order.
    pub reverse: bool,
    /// Optional hidden state clip threshold applied after each timestep.
    pub clip: Option<f64>,
    /// Activation function for hidden output.
    pub hidden_activation: Activation<B>,
}
72
73impl<B: Backend> ModuleDisplay for Rnn<B> {
74    fn custom_settings(&self) -> Option<DisplaySettings> {
75        DisplaySettings::new()
76            .with_new_line_after_attribute(false)
77            .optional()
78    }
79
80    fn custom_content(&self, content: Content) -> Option<Content> {
81        let [d_input, _] = self.gate.input_transform.weight.shape().dims();
82        let bias = self.gate.input_transform.bias.is_some();
83
84        content
85            .add("d_input", &d_input)
86            .add("d_hidden", &self.d_hidden)
87            .add("bias", &bias)
88            .optional()
89    }
90}
91
92impl RnnConfig {
93    /// Initialize a new [Rnn](Rnn) module.
94    pub fn init<B: Backend>(&self, device: &B::Device) -> Rnn<B> {
95        let d_output = self.d_hidden;
96
97        let new_gate = || {
98            GateController::new(
99                self.d_input,
100                d_output,
101                self.bias,
102                self.initializer.clone(),
103                device,
104            )
105        };
106
107        Rnn {
108            gate: new_gate(),
109            d_hidden: self.d_hidden,
110            batch_first: self.batch_first,
111            reverse: self.reverse,
112            clip: self.clip,
113            hidden_activation: self.hidden_activation.init(device),
114        }
115    }
116}
117
118impl<B: Backend> Rnn<B> {
119    /// Applies the forward pass on the input tensor. This RNN implementation
120    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
121    ///
122    /// ## Parameters:
123    /// - batched_input: The input tensor of shape:
124    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
125    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false
126    /// - state: An optional `RnnState` representing the initial hidden state.
127    ///   The state tensor has shape `[batch_size, hidden_size]`.
128    ///   If no initial state is provided, these tensors are initialized to zeros.
129    ///
130    /// ## Returns:
131    /// - output: A tensor represents the output features of Rnn. Shape:
132    ///   - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true
133    ///   - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false
134    /// - state: A `RnnState` represents the final hidden state. The hidden state tensor has the shape
135    ///   `[batch_size, hidden_size]`.
136    pub fn forward(
137        &self,
138        batched_input: Tensor<B, 3>,
139        state: Option<RnnState<B, 2>>,
140    ) -> (Tensor<B, 3>, RnnState<B, 2>) {
141        // Convert to batch-first layout internally if needed
142        let batched_input = if self.batch_first {
143            batched_input
144        } else {
145            batched_input.swap_dims(0, 1)
146        };
147
148        let device = batched_input.device();
149        let [batch_size, seq_length, _] = batched_input.dims();
150
151        // Process sequence in forward or reverse order based on config
152        let (output, state) = if self.reverse {
153            self.forward_iter(
154                batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
155                state,
156                batch_size,
157                seq_length,
158                &device,
159            )
160        } else {
161            self.forward_iter(
162                batched_input.iter_dim(1).zip(0..seq_length),
163                state,
164                batch_size,
165                seq_length,
166                &device,
167            )
168        };
169
170        // Convert output back to seq-first layout if needed
171        let output = if self.batch_first {
172            output
173        } else {
174            output.swap_dims(0, 1)
175        };
176
177        (output, state)
178    }
179
180    fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
181        &self,
182        input_timestep_iter: I,
183        state: Option<RnnState<B, 2>>,
184        batch_size: usize,
185        seq_length: usize,
186        device: &B::Device,
187    ) -> (Tensor<B, 3>, RnnState<B, 2>) {
188        let mut batched_hidden_state =
189            Tensor::empty([batch_size, seq_length, self.d_hidden], device);
190
191        let mut hidden_state = match state {
192            Some(state) => state.hidden,
193            None => Tensor::zeros([batch_size, self.d_hidden], device),
194        };
195
196        for (input_t, t) in input_timestep_iter {
197            let input_t = input_t.squeeze_dim(1);
198
199            // Compute gate output: h_t = activation(W_i @ x_t + W_h @ h_{t-1} + b)
200            let biased_gate_sum = self
201                .gate
202                .gate_product(input_t.clone(), hidden_state.clone());
203
204            let output_values = self.hidden_activation.forward(biased_gate_sum);
205
206            // Update hidden state
207            hidden_state = output_values;
208
209            // Apply hidden state clipping if configured
210            if let Some(clip) = self.clip {
211                hidden_state = hidden_state.clamp(-clip, clip);
212            }
213
214            let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);
215
216            // store the hidden state for this timestep
217            batched_hidden_state = batched_hidden_state.slice_assign(
218                [0..batch_size, t..(t + 1), 0..self.d_hidden],
219                unsqueezed_hidden_state.clone(),
220            );
221        }
222
223        (batched_hidden_state, RnnState::new(hidden_state))
224    }
225}
226
/// Configuration to create a [BiRnn](BiRnn) module using the [init function](BiRnnConfig::init).
#[derive(Config, Debug)]
pub struct BiRnnConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state (per direction; the output feature size is `2 * d_hidden`).
    pub d_hidden: usize,
    /// If a bias should be applied during the BiRnn transformation.
    pub bias: bool,
    /// BiRnn initializer, applied to both direction RNNs.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// Optional hidden state clip threshold, applied in both directions.
    pub clip: Option<f64>,
    /// Activation function applied to the hidden state before computing hidden output.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
}
249
/// The BiRnn module. This implementation is for Bidirectional RNN.
/// Should be created with [BiRnnConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiRnn<B: Backend> {
    /// RNN for the forward direction.
    pub forward: Rnn<B>,
    /// RNN for the reverse direction (a full [Rnn] module, not to be confused
    /// with the [Rnn::reverse] flag).
    pub reverse: Rnn<B>,
    /// The size of the hidden state (per direction).
    pub d_hidden: usize,
    /// If true, input is `[batch_size, seq_length, input_size]`.
    /// If false, input is `[seq_length, batch_size, input_size]`.
    pub batch_first: bool,
}
265
266impl<B: Backend> ModuleDisplay for BiRnn<B> {
267    fn custom_settings(&self) -> Option<DisplaySettings> {
268        DisplaySettings::new()
269            .with_new_line_after_attribute(false)
270            .optional()
271    }
272
273    fn custom_content(&self, content: Content) -> Option<Content> {
274        let [d_input, _] = self.forward.gate.input_transform.weight.shape().dims();
275        let bias = self.forward.gate.input_transform.bias.is_some();
276
277        content
278            .add("d_input", &d_input)
279            .add("d_hidden", &self.d_hidden)
280            .add("bias", &bias)
281            .optional()
282    }
283}
284
285impl BiRnnConfig {
286    /// Initialize a new [Bidirectional RNN](BiRnn) module.
287    pub fn init<B: Backend>(&self, device: &B::Device) -> BiRnn<B> {
288        // Internal RNNs always use batch_first=true; BiRnn handles layout conversion
289        let base_config = RnnConfig::new(self.d_input, self.d_hidden, self.bias)
290            .with_initializer(self.initializer.clone())
291            .with_batch_first(true)
292            .with_clip(self.clip)
293            .with_hidden_activation(self.hidden_activation.clone());
294
295        BiRnn {
296            forward: base_config.clone().init(device),
297            reverse: base_config.init(device),
298            d_hidden: self.d_hidden,
299            batch_first: self.batch_first,
300        }
301    }
302}
303
304impl<B: Backend> BiRnn<B> {
305    /// Applies the forward pass on the input tensor. This Bidirectional RNN implementation
306    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
307    ///
308    /// ## Parameters:
309    /// - batched_input: The input tensor of shape:
310    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
311    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false
312    /// - state: An optional `RnnState` representing the hidden state.
313    ///   Each state tensor has shape `[2, batch_size, hidden_size]`.
314    ///   If no initial state is provided, these tensors are initialized to zeros.
315    ///
316    /// ## Returns:
317    /// - output: A tensor represents the output features of RNN. Shape:
318    ///   - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true
319    ///   - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false
320    /// - state: A `RnnState` represents the final forward and reverse states.
321    ///   The `state.hidden` have the shape `[2, batch_size, hidden_size]`.
322    pub fn forward(
323        &self,
324        batched_input: Tensor<B, 3>,
325        state: Option<RnnState<B, 3>>,
326    ) -> (Tensor<B, 3>, RnnState<B, 3>) {
327        // Convert to batch-first layout internally if needed
328        let batched_input = if self.batch_first {
329            batched_input
330        } else {
331            batched_input.swap_dims(0, 1)
332        };
333
334        let device = batched_input.clone().device();
335        let [batch_size, seq_length, _] = batched_input.shape().dims();
336
337        let [init_state_forward, init_state_reverse] = match state {
338            Some(state) => {
339                let hidden_state_forward = state
340                    .hidden
341                    .clone()
342                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
343                    .squeeze_dim(0);
344                let hidden_state_reverse = state
345                    .hidden
346                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
347                    .squeeze_dim(0);
348
349                [
350                    Some(RnnState::new(hidden_state_forward)),
351                    Some(RnnState::new(hidden_state_reverse)),
352                ]
353            }
354            None => [None, None],
355        };
356
357        // forward direction
358        let (batched_hidden_state_forward, final_state_forward) = self
359            .forward
360            .forward(batched_input.clone(), init_state_forward);
361
362        // reverse direction
363        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
364            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
365            init_state_reverse,
366            batch_size,
367            seq_length,
368            &device,
369        );
370
371        let output = Tensor::cat(
372            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
373            2,
374        );
375
376        // Convert output back to seq-first layout if needed
377        let output = if self.batch_first {
378            output
379        } else {
380            output.swap_dims(0, 1)
381        };
382
383        let state = RnnState::new(Tensor::stack(
384            [final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
385            0,
386        ));
387
388        (output, state)
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LinearRecord, TestBackend};
    use burn::module::Param;
    use burn::tensor::{Device, Distribution, TensorData};
    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

    // Builds a 1x1 GateController where the input and hidden transforms share the
    // same scalar weight and bias, so expected outputs are easy to compute by hand.
    fn create_single_feature_gate_controller(
        weights: f32,
        biases: f32,
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        device: &Device<TestBackend>,
    ) -> GateController<TestBackend> {
        let record_1 = LinearRecord {
            weight: Param::from_data(TensorData::from([[weights]]), device),
            bias: Some(Param::from_data(TensorData::from([biases]), device)),
        };
        let record_2 = LinearRecord {
            weight: Param::from_data(TensorData::from([[weights]]), device),
            bias: Some(Param::from_data(TensorData::from([biases]), device)),
        };
        GateController::create_with_weights(
            d_input,
            d_output,
            bias,
            initializer,
            record_1,
            record_2,
        )
    }

    #[test]
    fn test_with_uniform_initializer() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = RnnConfig::new(5, 5, false)
            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
        let rnn = config.init::<TestBackend>(&Default::default());

        let gate_to_data =
            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();

        // All weights must fall inside the configured uniform range [0, 1).
        gate_to_data(rnn.gate).assert_within_range::<FT>(0.elem()..1.elem());
    }

    /// Test forward pass with simple input vector.
    ///
    /// Simple RNN: h_t = tanh(W_input @ x_t + W_hidden @ h_{t-1} + b)
    /// With input=0.1, weight_input=0.5, bias=0.0, h_0=0.0, weight_hidden=0.5
    /// h_t = tanh(0.5*0.1 + 0.5*0) = tanh(0.05) = 0.04995
    #[test]
    fn test_forward_single_input_single_feature() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = RnnConfig::new(1, 1, false);
        let device = Default::default();
        let mut rnn = config.init::<TestBackend>(&device);

        rnn.gate = create_single_feature_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );

        // single timestep with single feature
        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);

        let (output, state) = rnn.forward(input, None);

        let tolerance = Tolerance::default();
        let expected = TensorData::from([[0.04995]]);
        state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        // With a single timestep, the output row must equal the final hidden state.
        output
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0)
            .to_data()
            .assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
    }

    #[test]
    fn test_batched_forward_pass_batch_of_one() {
        let device = Default::default();
        let rnn = RnnConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);

        // Shape check only: [batch=1, seq=2, hidden=1024] output, [1, 1024] state.
        let (output, state) = rnn.forward(batched_input, None);
        assert_eq!(output.dims(), [1, 2, 1024]);
        assert_eq!(state.hidden.dims(), [1, 1024]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        use burn::tensor::Shape;
        let device = Default::default();
        let rnn = RnnConfig::new(64, 32, true).init(&device);
        let shape: Shape = [8, 10, 64].into();
        let batched_input =
            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);

        let (output, _) = rnn.forward(batched_input.clone(), None);
        let fake_loss = output;
        let grads = fake_loss.backward();

        let some_gradient = rnn.gate.hidden_transform.weight.grad(&grads).unwrap();

        // Asserts that the gradients exist and are non-zero
        assert_ne!(
            some_gradient
                .any()
                .into_data()
                .iter::<f32>()
                .next()
                .unwrap(),
            0.0
        );
    }

    #[test]
    fn test_bidirectional() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = BiRnnConfig::new(2, 3, true);
        let mut rnn = config.init(&device);

        // Builds a GateController from explicit weight/bias matrices so the
        // outputs can be checked against a PyTorch reference implementation.
        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let d_input = input_weights[0].len();
            let d_output = input_weights.len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        // [batch_size=1, seq_length=4, input_size=2]
        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [0.949, -0.861],
                [0.892, 0.927],
                [-0.173, -0.301],
                [-0.081, 0.992],
            ]]),
            &device,
        );

        // [2, batch_size=1, hidden_size=3]
        let h0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
            &device,
        );

        rnn.forward.gate = create_gate_controller(
            // input_weights: [input_size=2, hidden_size=3]
            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
            // input_biases: [hidden_size=3]
            [-0.196, 0.354, 0.209],
            // hidden_weights: [hidden_size=3, hidden_size=3]
            [
                [-0.320, 0.232, -0.165],
                [0.093, -0.572, -0.315],
                [-0.467, 0.325, 0.046],
            ],
            // hidden_biases: [hidden_size=3]
            [0.181, -0.190, -0.245],
            &device,
        );

        rnn.reverse.gate = create_gate_controller(
            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
            [0.540, -0.164, 0.033],
            [
                [0.159, 0.180, -0.037],
                [-0.443, 0.485, -0.488],
                [0.098, -0.085, -0.140],
            ],
            [-0.510, 0.105, 0.114],
            &device,
        );

        // [batch_size=1, sequence_length=4, hidden_size * 2 = 6]
        // The expected output values were computed from PyTorch
        let expected_output_with_init_state = TensorData::from([[
            [0.5226, -0.6370, 0.0210, 0.0685, 0.3867, 0.3602],
            [0.3580, 0.8431, 0.4129, -0.3175, 0.4374, 0.1766],
            [-0.3837, -0.2703, -0.3957, -0.1542, -0.1122, 0.0725],
            [0.5059, 0.5527, 0.1244, -0.6779, 0.3725, -0.3387],
        ]]);
        let expected_output_without_init_state = TensorData::from([[
            [0.0560, -0.2056, 0.2334, 0.0892, 0.3912, 0.3607],
            [0.4340, 0.7378, 0.3714, -0.2394, 0.4235, 0.2002],
            [-0.3962, -0.2097, -0.3798, 0.0532, -0.2067, 0.1727],
            [0.5075, 0.5298, 0.1083, -0.3200, 0.0764, -0.1282],
        ]]);

        //`[2, batch_size=1, hidden_size=3]`
        let expected_hn_with_init_state =
            TensorData::from([[[0.5059, 0.5527, 0.1244]], [[0.0685, 0.3867, 0.3602]]]);
        let expected_hn_without_init_state =
            TensorData::from([[[0.5075, 0.5298, 0.1083]], [[0.0892, 0.3912, 0.3607]]]);

        let (output_with_init_state, state_with_init_state) =
            rnn.forward(input.clone(), Some(RnnState::new(h0)));
        let (output_without_init_state, state_without_init_state) = rnn.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
        output_without_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
        state_with_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
        state_without_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
    }

    #[test]
    fn display_rnn() {
        let config = RnnConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Rnn {d_input: 2, d_hidden: 3, bias: true, params: 21}"
        );
    }

    #[test]
    fn display_birnn() {
        let config = BiRnnConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "BiRnn {d_input: 2, d_hidden: 3, bias: true, params: 42}"
        );
    }

    #[test]
    fn test_rnn_clipping() {
        let device = Default::default();

        // Create Rnn with clipping enabled
        let clip_value = 0.3;
        let config = RnnConfig::new(4, 8, true).with_clip(Some(clip_value));
        let rnn = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);
        let (_, state) = rnn.forward(input, None);

        // Verify output values are within the clip range
        let hidden_state: Vec<f32> = state.hidden.to_data().to_vec().unwrap();
        for val in hidden_state {
            assert!(
                val >= -clip_value as f32 && val <= clip_value as f32,
                "Value {} is outside clip range [-{}, {}]",
                val,
                clip_value,
                clip_value
            );
        }
    }

    #[test]
    fn test_forward_reverse_sequence() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        // Create RNN with reverse=true to process sequence in reverse order
        let config = RnnConfig::new(1, 1, false).with_reverse(true);
        let mut rnn = config.init::<TestBackend>(&device);

        rnn.gate = create_single_feature_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );

        // Create input with 3 timesteps: [0.1, 0.2, 0.3]
        // Shape: [batch_size=1, seq_length=3, input_features=1]
        let input =
            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);

        let (output, state) = rnn.forward(input, None);

        // With reverse=true and weight=0.5, sequence is processed in reverse:
        // t=2 (last): h = tanh(0.5*0.3 + 0.5*0) = tanh(0.15) ≈ 0.1488850
        // t=1 (mid):  h = tanh(0.5*0.2 + 0.5*0.1488850) ≈ 0.17269433
        // t=0 (first): h = tanh(0.5*0.1 + 0.5*0.17269433) ≈ 0.135508
        let expected_final_hidden = TensorData::from([[0.135508]]);

        let tolerance = Tolerance::default();
        state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_final_hidden, tolerance);

        // Verify output tensor has correct shape and matches state at final timestep
        assert_eq!(output.dims(), [1, 3, 1]);
    }
}