burn_core/nn/rnn/
lstm.rs

1use crate as burn;
2
3use crate::config::Config;
4use crate::module::Module;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::nn::rnn::gate_controller::GateController;
7use crate::nn::Initializer;
8use crate::tensor::activation;
9use crate::tensor::backend::Backend;
10use crate::tensor::Tensor;
11
12/// A LstmState is used to store cell state and hidden state in LSTM.
13pub struct LstmState<B: Backend, const D: usize> {
14    /// The cell state.
15    pub cell: Tensor<B, D>,
16    /// The hidden state.
17    pub hidden: Tensor<B, D>,
18}
19
20impl<B: Backend, const D: usize> LstmState<B, D> {
21    /// Initialize a new [LSTM State](LstmState).
22    pub fn new(cell: Tensor<B, D>, hidden: Tensor<B, D>) -> Self {
23        Self { cell, hidden }
24    }
25}
26
27/// Configuration to create a [Lstm](Lstm) module using the [init function](LstmConfig::init).
28#[derive(Config)]
29pub struct LstmConfig {
30    /// The size of the input features.
31    pub d_input: usize,
32    /// The size of the hidden state.
33    pub d_hidden: usize,
34    /// If a bias should be applied during the Lstm transformation.
35    pub bias: bool,
36    /// Lstm initializer
37    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
38    pub initializer: Initializer,
39}
40
41/// The Lstm module. This implementation is for a unidirectional, stateless, Lstm.
42///
43/// Introduced in the paper: [Long Short-Term Memory](https://www.researchgate.net/publication/13853244).
44///
45/// Should be created with [LstmConfig].
46#[derive(Module, Debug)]
47#[module(custom_display)]
48pub struct Lstm<B: Backend> {
49    /// The input gate regulates which information to update and store in the cell state at each time step.
50    pub input_gate: GateController<B>,
51    /// The forget gate is used to control which information to discard or keep in the memory cell at each time step.
52    pub forget_gate: GateController<B>,
53    /// The output gate determines which information from the cell state to output at each time step.
54    pub output_gate: GateController<B>,
55    /// The cell gate is used to compute the cell state that stores and carries information through time.
56    pub cell_gate: GateController<B>,
57    /// The hidden state of the LSTM.
58    pub d_hidden: usize,
59}
60
61impl<B: Backend> ModuleDisplay for Lstm<B> {
62    fn custom_settings(&self) -> Option<DisplaySettings> {
63        DisplaySettings::new()
64            .with_new_line_after_attribute(false)
65            .optional()
66    }
67
68    fn custom_content(&self, content: Content) -> Option<Content> {
69        let [d_input, _] = self.input_gate.input_transform.weight.shape().dims();
70        let bias = self.input_gate.input_transform.bias.is_some();
71
72        content
73            .add("d_input", &d_input)
74            .add("d_hidden", &self.d_hidden)
75            .add("bias", &bias)
76            .optional()
77    }
78}
79
80impl LstmConfig {
81    /// Initialize a new [lstm](Lstm) module.
82    pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {
83        let d_output = self.d_hidden;
84
85        let new_gate = || {
86            GateController::new(
87                self.d_input,
88                d_output,
89                self.bias,
90                self.initializer.clone(),
91                device,
92            )
93        };
94
95        Lstm {
96            input_gate: new_gate(),
97            forget_gate: new_gate(),
98            output_gate: new_gate(),
99            cell_gate: new_gate(),
100            d_hidden: self.d_hidden,
101        }
102    }
103}
104
105impl<B: Backend> Lstm<B> {
106    /// Applies the forward pass on the input tensor. This LSTM implementation
107    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
108    ///
109    /// ## Parameters:
110    /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`.
111    /// - state: An optional `LstmState` representing the initial cell state and hidden state.
112    ///          Each state tensor has shape `[batch_size, hidden_size]`.
113    ///          If no initial state is provided, these tensors are initialized to zeros.
114    ///
115    /// ## Returns:
116    /// - output: A tensor represents the output features of LSTM. Shape: `[batch_size, sequence_length, hidden_size]`
117    /// - state: A `LstmState` represents the final states. Both `state.cell` and `state.hidden` have the shape
118    ///          `[batch_size, hidden_size]`.
119    pub fn forward(
120        &self,
121        batched_input: Tensor<B, 3>,
122        state: Option<LstmState<B, 2>>,
123    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
124        let device = batched_input.device();
125        let [batch_size, seq_length, _] = batched_input.dims();
126
127        self.forward_iter(
128            batched_input.iter_dim(1).zip(0..seq_length),
129            state,
130            batch_size,
131            seq_length,
132            &device,
133        )
134    }
135
136    fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
137        &self,
138        input_timestep_iter: I,
139        state: Option<LstmState<B, 2>>,
140        batch_size: usize,
141        seq_length: usize,
142        device: &B::Device,
143    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
144        let mut batched_hidden_state =
145            Tensor::empty([batch_size, seq_length, self.d_hidden], device);
146
147        let (mut cell_state, mut hidden_state) = match state {
148            Some(state) => (state.cell, state.hidden),
149            None => (
150                Tensor::zeros([batch_size, self.d_hidden], device),
151                Tensor::zeros([batch_size, self.d_hidden], device),
152            ),
153        };
154
155        for (input_t, t) in input_timestep_iter {
156            let input_t = input_t.squeeze(1);
157            // f(orget)g(ate) tensors
158            let biased_fg_input_sum = self
159                .forget_gate
160                .gate_product(input_t.clone(), hidden_state.clone());
161            let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state
162
163            // i(nput)g(ate) tensors
164            let biased_ig_input_sum = self
165                .input_gate
166                .gate_product(input_t.clone(), hidden_state.clone());
167            let add_values = activation::sigmoid(biased_ig_input_sum);
168
169            // o(output)g(ate) tensors
170            let biased_og_input_sum = self
171                .output_gate
172                .gate_product(input_t.clone(), hidden_state.clone());
173            let output_values = activation::sigmoid(biased_og_input_sum);
174
175            // c(ell)g(ate) tensors
176            let biased_cg_input_sum = self
177                .cell_gate
178                .gate_product(input_t.clone(), hidden_state.clone());
179            let candidate_cell_values = biased_cg_input_sum.tanh();
180
181            cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values;
182            hidden_state = output_values * cell_state.clone().tanh();
183
184            let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);
185
186            // store the hidden state for this timestep
187            batched_hidden_state = batched_hidden_state.slice_assign(
188                [0..batch_size, t..(t + 1), 0..self.d_hidden],
189                unsqueezed_hidden_state.clone(),
190            );
191        }
192
193        (
194            batched_hidden_state,
195            LstmState::new(cell_state, hidden_state),
196        )
197    }
198}
199
200/// Configuration to create a [BiLstm](BiLstm) module using the [init function](BiLstmConfig::init).
201#[derive(Config)]
202pub struct BiLstmConfig {
203    /// The size of the input features.
204    pub d_input: usize,
205    /// The size of the hidden state.
206    pub d_hidden: usize,
207    /// If a bias should be applied during the BiLstm transformation.
208    pub bias: bool,
209    /// BiLstm initializer
210    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
211    pub initializer: Initializer,
212}
213
214/// The BiLstm module. This implementation is for Bidirectional LSTM.
215///
216/// Introduced in the paper: [Framewise phoneme classification with bidirectional LSTM and other neural network architectures](https://www.cs.toronto.edu/~graves/ijcnn_2005.pdf).
217///
218/// Should be created with [BiLstmConfig].
219#[derive(Module, Debug)]
220#[module(custom_display)]
221pub struct BiLstm<B: Backend> {
222    /// LSTM for the forward direction.
223    pub forward: Lstm<B>,
224    /// LSTM for the reverse direction.
225    pub reverse: Lstm<B>,
226    /// The size of the hidden state.
227    pub d_hidden: usize,
228}
229
230impl<B: Backend> ModuleDisplay for BiLstm<B> {
231    fn custom_settings(&self) -> Option<DisplaySettings> {
232        DisplaySettings::new()
233            .with_new_line_after_attribute(false)
234            .optional()
235    }
236
237    fn custom_content(&self, content: Content) -> Option<Content> {
238        let [d_input, _] = self
239            .forward
240            .input_gate
241            .input_transform
242            .weight
243            .shape()
244            .dims();
245        let bias = self.forward.input_gate.input_transform.bias.is_some();
246
247        content
248            .add("d_input", &d_input)
249            .add("d_hidden", &self.d_hidden)
250            .add("bias", &bias)
251            .optional()
252    }
253}
254
255impl BiLstmConfig {
256    /// Initialize a new [Bidirectional LSTM](BiLstm) module.
257    pub fn init<B: Backend>(&self, device: &B::Device) -> BiLstm<B> {
258        BiLstm {
259            forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias)
260                .with_initializer(self.initializer.clone())
261                .init(device),
262            reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias)
263                .with_initializer(self.initializer.clone())
264                .init(device),
265            d_hidden: self.d_hidden,
266        }
267    }
268}
269
270impl<B: Backend> BiLstm<B> {
271    /// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation
272    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
273    ///
274    /// ## Parameters:
275    /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`.
276    /// - state: An optional `LstmState` representing the initial cell state and hidden state.
277    ///          Each state tensor has shape `[2, batch_size, hidden_size]`.
278    ///          If no initial state is provided, these tensors are initialized to zeros.
279    ///
280    /// ## Returns:
281    /// - output: A tensor represents the output features of LSTM. Shape: `[batch_size, sequence_length, hidden_size * 2]`
282    /// - state: A `LstmState` represents the final forward and reverse states. Both `state.cell` and
283    ///          `state.hidden` have the shape `[2, batch_size, hidden_size]`.
284    pub fn forward(
285        &self,
286        batched_input: Tensor<B, 3>,
287        state: Option<LstmState<B, 3>>,
288    ) -> (Tensor<B, 3>, LstmState<B, 3>) {
289        let device = batched_input.clone().device();
290        let [batch_size, seq_length, _] = batched_input.shape().dims();
291
292        let [init_state_forward, init_state_reverse] = match state {
293            Some(state) => {
294                let cell_state_forward = state
295                    .cell
296                    .clone()
297                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
298                    .squeeze(0);
299                let hidden_state_forward = state
300                    .hidden
301                    .clone()
302                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
303                    .squeeze(0);
304                let cell_state_reverse = state
305                    .cell
306                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
307                    .squeeze(0);
308                let hidden_state_reverse = state
309                    .hidden
310                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
311                    .squeeze(0);
312
313                [
314                    Some(LstmState::new(cell_state_forward, hidden_state_forward)),
315                    Some(LstmState::new(cell_state_reverse, hidden_state_reverse)),
316                ]
317            }
318            None => [None, None],
319        };
320
321        // forward direction
322        let (batched_hidden_state_forward, final_state_forward) = self
323            .forward
324            .forward(batched_input.clone(), init_state_forward);
325
326        // reverse direction
327        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
328            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
329            init_state_reverse,
330            batch_size,
331            seq_length,
332            &device,
333        );
334
335        let output = Tensor::cat(
336            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
337            2,
338        );
339
340        let state = LstmState::new(
341            Tensor::stack(
342                [final_state_forward.cell, final_state_reverse.cell].to_vec(),
343                0,
344            ),
345            Tensor::stack(
346                [final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
347                0,
348            ),
349        );
350
351        (output, state)
352    }
353}
354
355#[cfg(test)]
356mod tests {
357    use super::*;
358    use crate::tensor::{Device, Distribution, TensorData};
359    use crate::{module::Param, nn::LinearRecord, TestBackend};
360
361    #[cfg(feature = "std")]
362    use crate::TestAutodiffBackend;
363
364    #[test]
365    fn test_with_uniform_initializer() {
366        TestBackend::seed(0);
367
368        let config = LstmConfig::new(5, 5, false)
369            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
370        let lstm = config.init::<TestBackend>(&Default::default());
371
372        let gate_to_data =
373            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();
374
375        gate_to_data(lstm.input_gate).assert_within_range(0..1);
376        gate_to_data(lstm.forget_gate).assert_within_range(0..1);
377        gate_to_data(lstm.output_gate).assert_within_range(0..1);
378        gate_to_data(lstm.cell_gate).assert_within_range(0..1);
379    }
380
381    /// Test forward pass with simple input vector.
382    ///
383    /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928
384    /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725
385    /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723
386    /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937
387    /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243
388    /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648
389    #[test]
390    fn test_forward_single_input_single_feature() {
391        TestBackend::seed(0);
392        let config = LstmConfig::new(1, 1, false);
393        let device = Default::default();
394        let mut lstm = config.init::<TestBackend>(&device);
395
396        fn create_gate_controller(
397            weights: f32,
398            biases: f32,
399            d_input: usize,
400            d_output: usize,
401            bias: bool,
402            initializer: Initializer,
403            device: &Device<TestBackend>,
404        ) -> GateController<TestBackend> {
405            let record_1 = LinearRecord {
406                weight: Param::from_data(TensorData::from([[weights]]), device),
407                bias: Some(Param::from_data(TensorData::from([biases]), device)),
408            };
409            let record_2 = LinearRecord {
410                weight: Param::from_data(TensorData::from([[weights]]), device),
411                bias: Some(Param::from_data(TensorData::from([biases]), device)),
412            };
413            GateController::create_with_weights(
414                d_input,
415                d_output,
416                bias,
417                initializer,
418                record_1,
419                record_2,
420            )
421        }
422
423        lstm.input_gate = create_gate_controller(
424            0.5,
425            0.0,
426            1,
427            1,
428            false,
429            Initializer::XavierUniform { gain: 1.0 },
430            &device,
431        );
432        lstm.forget_gate = create_gate_controller(
433            0.7,
434            0.0,
435            1,
436            1,
437            false,
438            Initializer::XavierUniform { gain: 1.0 },
439            &device,
440        );
441        lstm.cell_gate = create_gate_controller(
442            0.9,
443            0.0,
444            1,
445            1,
446            false,
447            Initializer::XavierUniform { gain: 1.0 },
448            &device,
449        );
450        lstm.output_gate = create_gate_controller(
451            1.1,
452            0.0,
453            1,
454            1,
455            false,
456            Initializer::XavierUniform { gain: 1.0 },
457            &device,
458        );
459
460        // single timestep with single feature
461        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
462
463        let (output, state) = lstm.forward(input, None);
464
465        let expected = TensorData::from([[0.046]]);
466        state.cell.to_data().assert_approx_eq(&expected, 3);
467
468        let expected = TensorData::from([[0.024]]);
469        state.hidden.to_data().assert_approx_eq(&expected, 3);
470
471        output
472            .select(0, Tensor::arange(0..1, &device))
473            .squeeze::<2>(0)
474            .to_data()
475            .assert_approx_eq(&state.hidden.to_data(), 3);
476    }
477
478    #[test]
479    fn test_batched_forward_pass() {
480        let device = Default::default();
481        let lstm = LstmConfig::new(64, 1024, true).init(&device);
482        let batched_input =
483            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);
484
485        let (output, state) = lstm.forward(batched_input, None);
486
487        assert_eq!(output.dims(), [8, 10, 1024]);
488        assert_eq!(state.cell.dims(), [8, 1024]);
489        assert_eq!(state.hidden.dims(), [8, 1024]);
490    }
491
492    #[test]
493    fn test_batched_forward_pass_batch_of_one() {
494        let device = Default::default();
495        let lstm = LstmConfig::new(64, 1024, true).init(&device);
496        let batched_input =
497            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);
498
499        let (output, state) = lstm.forward(batched_input, None);
500
501        assert_eq!(output.dims(), [1, 2, 1024]);
502        assert_eq!(state.cell.dims(), [1, 1024]);
503        assert_eq!(state.hidden.dims(), [1, 1024]);
504    }
505
506    #[test]
507    #[cfg(feature = "std")]
508    fn test_batched_backward_pass() {
509        use crate::tensor::Shape;
510        let device = Default::default();
511        let lstm = LstmConfig::new(64, 32, true).init(&device);
512        let shape: Shape = [8, 10, 64].into();
513        let batched_input =
514            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);
515
516        let (output, _) = lstm.forward(batched_input.clone(), None);
517        let fake_loss = output;
518        let grads = fake_loss.backward();
519
520        let some_gradient = lstm
521            .output_gate
522            .hidden_transform
523            .weight
524            .grad(&grads)
525            .unwrap();
526
527        // Asserts that the gradients exist and are non-zero
528        assert!(
529            some_gradient
530                .any()
531                .into_data()
532                .iter::<f32>()
533                .next()
534                .unwrap()
535                != 0.0
536        );
537    }
538
539    #[test]
540    fn test_bidirectional() {
541        TestBackend::seed(0);
542        let config = BiLstmConfig::new(2, 3, true);
543        let device = Default::default();
544        let mut lstm = config.init(&device);
545
546        fn create_gate_controller<const D1: usize, const D2: usize>(
547            input_weights: [[f32; D1]; D2],
548            input_biases: [f32; D1],
549            hidden_weights: [[f32; D1]; D1],
550            hidden_biases: [f32; D1],
551            device: &Device<TestBackend>,
552        ) -> GateController<TestBackend> {
553            let d_input = input_weights[0].len();
554            let d_output = input_weights.len();
555
556            let input_record = LinearRecord {
557                weight: Param::from_data(TensorData::from(input_weights), device),
558                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
559            };
560            let hidden_record = LinearRecord {
561                weight: Param::from_data(TensorData::from(hidden_weights), device),
562                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
563            };
564            GateController::create_with_weights(
565                d_input,
566                d_output,
567                true,
568                Initializer::XavierUniform { gain: 1.0 },
569                input_record,
570                hidden_record,
571            )
572        }
573
574        let input = Tensor::<TestBackend, 3>::from_data(
575            TensorData::from([[
576                [0.949, -0.861],
577                [0.892, 0.927],
578                [-0.173, -0.301],
579                [-0.081, 0.992],
580            ]]),
581            &device,
582        );
583        let h0 = Tensor::<TestBackend, 3>::from_data(
584            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
585            &device,
586        );
587        let c0 = Tensor::<TestBackend, 3>::from_data(
588            TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]),
589            &device,
590        );
591
592        lstm.forward.input_gate = create_gate_controller(
593            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
594            [-0.196, 0.354, 0.209],
595            [
596                [-0.320, 0.232, -0.165],
597                [0.093, -0.572, -0.315],
598                [-0.467, 0.325, 0.046],
599            ],
600            [0.181, -0.190, -0.245],
601            &device,
602        );
603
604        lstm.forward.forget_gate = create_gate_controller(
605            [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]],
606            [0.315, -0.413, -0.041],
607            [
608                [0.453, 0.063, 0.561],
609                [0.211, 0.149, 0.213],
610                [-0.499, -0.158, 0.068],
611            ],
612            [-0.431, -0.535, 0.125],
613            &device,
614        );
615
616        lstm.forward.cell_gate = create_gate_controller(
617            [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]],
618            [-0.358, 0.282, -0.078],
619            [
620                [-0.358, 0.109, 0.139],
621                [-0.345, 0.091, -0.368],
622                [-0.508, 0.221, -0.507],
623            ],
624            [0.502, -0.509, -0.247],
625            &device,
626        );
627
628        lstm.forward.output_gate = create_gate_controller(
629            [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]],
630            [-0.227, -0.274, 0.039],
631            [
632                [-0.383, 0.449, 0.222],
633                [-0.357, -0.093, 0.449],
634                [-0.106, 0.236, 0.360],
635            ],
636            [-0.361, -0.209, -0.454],
637            &device,
638        );
639
640        lstm.reverse.input_gate = create_gate_controller(
641            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
642            [0.540, -0.164, 0.033],
643            [
644                [0.159, 0.180, -0.037],
645                [-0.443, 0.485, -0.488],
646                [0.098, -0.085, -0.140],
647            ],
648            [-0.510, 0.105, 0.114],
649            &device,
650        );
651
652        lstm.reverse.forget_gate = create_gate_controller(
653            [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]],
654            [0.141, 0.004, 0.055],
655            [
656                [-0.005, -0.277, -0.515],
657                [-0.011, -0.101, -0.365],
658                [0.426, 0.379, 0.337],
659            ],
660            [-0.382, 0.331, -0.176],
661            &device,
662        );
663
664        lstm.reverse.cell_gate = create_gate_controller(
665            [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]],
666            [-0.206, -0.546, 0.462],
667            [
668                [0.449, -0.240, 0.071],
669                [-0.045, 0.131, 0.124],
670                [0.138, -0.201, 0.191],
671            ],
672            [-0.030, 0.211, -0.352],
673            &device,
674        );
675
676        lstm.reverse.output_gate = create_gate_controller(
677            [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]],
678            [-0.387, -0.250, 0.066],
679            [
680                [-0.030, 0.268, 0.299],
681                [-0.019, -0.280, -0.314],
682                [0.466, -0.365, -0.248],
683            ],
684            [-0.398, -0.199, -0.566],
685            &device,
686        );
687
688        let expected_output_with_init_state = TensorData::from([[
689            [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798],
690            [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742],
691            [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012],
692            [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872],
693        ]]);
694        let expected_output_without_init_state = TensorData::from([[
695            [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863],
696            [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142],
697            [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846],
698            [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550],
699        ]]);
700        let expected_hn_with_init_state = TensorData::from([
701            [[-0.03420, 0.07774, -0.09774]],
702            [[-0.15635, -0.03366, -0.05798]],
703        ]);
704        let expected_cn_with_init_state = TensorData::from([
705            [[-0.13593, 0.17125, -0.22395]],
706            [[-0.45425, -0.11206, -0.12908]],
707        ]);
708        let expected_hn_without_init_state = TensorData::from([
709            [[-0.04026, 0.07178, -0.10189]],
710            [[-0.15969, -0.05322, -0.08863]],
711        ]);
712        let expected_cn_without_init_state = TensorData::from([
713            [[-0.15839, 0.15923, -0.23569]],
714            [[-0.47407, -0.17493, -0.19643]],
715        ]);
716
717        let (output_with_init_state, state_with_init_state) =
718            lstm.forward(input.clone(), Some(LstmState::new(c0, h0)));
719        let (output_without_init_state, state_without_init_state) = lstm.forward(input, None);
720
721        output_with_init_state
722            .to_data()
723            .assert_approx_eq(&expected_output_with_init_state, 3);
724        output_without_init_state
725            .to_data()
726            .assert_approx_eq(&expected_output_without_init_state, 3);
727        state_with_init_state
728            .hidden
729            .to_data()
730            .assert_approx_eq(&expected_hn_with_init_state, 3);
731        state_with_init_state
732            .cell
733            .to_data()
734            .assert_approx_eq(&expected_cn_with_init_state, 3);
735        state_without_init_state
736            .hidden
737            .to_data()
738            .assert_approx_eq(&expected_hn_without_init_state, 3);
739        state_without_init_state
740            .cell
741            .to_data()
742            .assert_approx_eq(&expected_cn_without_init_state, 3);
743    }
744
745    #[test]
746    fn display_lstm() {
747        let config = LstmConfig::new(2, 3, true);
748
749        let layer = config.init::<TestBackend>(&Default::default());
750
751        assert_eq!(
752            alloc::format!("{}", layer),
753            "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
754        );
755    }
756
757    #[test]
758    fn display_bilstm() {
759        let config = BiLstmConfig::new(2, 3, true);
760
761        let layer = config.init::<TestBackend>(&Default::default());
762
763        assert_eq!(
764            alloc::format!("{}", layer),
765            "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
766        );
767    }
768}