burn_core/nn/rnn/lstm.rs

use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::Initializer;
use crate::nn::rnn::gate_controller::GateController;
use crate::tensor::Tensor;
use crate::tensor::activation;
use crate::tensor::backend::Backend;

/// A LstmState is used to store the cell state and hidden state in an LSTM.
pub struct LstmState<B: Backend, const D: usize> {
    /// The cell state.
    pub cell: Tensor<B, D>,
    /// The hidden state.
    pub hidden: Tensor<B, D>,
}

impl<B: Backend, const D: usize> LstmState<B, D> {
    /// Initialize a new [LSTM State](LstmState).
    pub fn new(cell: Tensor<B, D>, hidden: Tensor<B, D>) -> Self {
        Self { cell, hidden }
    }
}

/// Configuration to create a [Lstm](Lstm) module using the [init function](LstmConfig::init).
#[derive(Config)]
pub struct LstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// Whether a bias should be applied during the Lstm transformation.
    pub bias: bool,
    /// Lstm initializer
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}

/// The Lstm module. This implementation is for a unidirectional, stateless LSTM.
///
/// Introduced in the paper: [Long Short-Term Memory](https://www.researchgate.net/publication/13853244).
///
/// Should be created with [LstmConfig].
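///
/// Illustrative construction sketch (assumes a backend type `B` and a `device` are in scope;
/// both are assumptions outside this module):
///
/// ```rust,ignore
/// // d_input = 2, d_hidden = 3, with bias.
/// let lstm: Lstm<B> = LstmConfig::new(2, 3, true)
///     .with_initializer(Initializer::XavierNormal { gain: 1.0 })
///     .init(&device);
/// ```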
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Lstm<B: Backend> {
    /// The input gate regulates which information to update and store in the cell state at each time step.
    pub input_gate: GateController<B>,
    /// The forget gate is used to control which information to discard or keep in the memory cell at each time step.
    pub forget_gate: GateController<B>,
    /// The output gate determines which information from the cell state to output at each time step.
    pub output_gate: GateController<B>,
    /// The cell gate is used to compute the cell state that stores and carries information through time.
    pub cell_gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
}

impl<B: Backend> ModuleDisplay for Lstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self.input_gate.input_transform.weight.shape().dims();
        let bias = self.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl LstmConfig {
    /// Initialize a new [lstm](Lstm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {
        let d_output = self.d_hidden;

        let new_gate = || {
            GateController::new(
                self.d_input,
                d_output,
                self.bias,
                self.initializer.clone(),
                device,
            )
        };

        Lstm {
            input_gate: new_gate(),
            forget_gate: new_gate(),
            output_gate: new_gate(),
            cell_gate: new_gate(),
            d_hidden: self.d_hidden,
        }
    }
}

impl<B: Backend> Lstm<B> {
    /// Applies the forward pass on the input tensor. This LSTM implementation
    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
    ///
    /// ## Parameters:
    /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`.
    /// - state: An optional `LstmState` representing the initial cell state and hidden state.
    ///   Each state tensor has shape `[batch_size, hidden_size]`.
    ///   If no initial state is provided, these tensors are initialized to zeros.
    ///
    /// ## Returns:
    /// - output: A tensor representing the output features of the LSTM. Shape: `[batch_size, sequence_length, hidden_size]`
    /// - state: A `LstmState` representing the final states. Both `state.cell` and `state.hidden` have the shape
    ///   `[batch_size, hidden_size]`.
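    ///
    /// ## Example:
    /// Illustrative shape sketch (assumes a backend type `B` and a `device` are in scope;
    /// both are assumptions outside this module):
    ///
    /// ```rust,ignore
    /// // d_input = 8, d_hidden = 16
    /// let lstm = LstmConfig::new(8, 16, true).init::<B>(&device);
    ///
    /// let input = Tensor::<B, 3>::zeros([4, 10, 8], &device); // [batch_size, seq_length, d_input]
    /// let (output, state) = lstm.forward(input, None);
    ///
    /// assert_eq!(output.dims(), [4, 10, 16]);
    /// assert_eq!(state.cell.dims(), [4, 16]);
    /// assert_eq!(state.hidden.dims(), [4, 16]);
    /// ```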
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 2>>,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.dims();

        self.forward_iter(
            batched_input.iter_dim(1).zip(0..seq_length),
            state,
            batch_size,
            seq_length,
            &device,
        )
    }

    fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
        &self,
        input_timestep_iter: I,
        state: Option<LstmState<B, 2>>,
        batch_size: usize,
        seq_length: usize,
        device: &B::Device,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        let mut batched_hidden_state =
            Tensor::empty([batch_size, seq_length, self.d_hidden], device);

        let (mut cell_state, mut hidden_state) = match state {
            Some(state) => (state.cell, state.hidden),
            None => (
                Tensor::zeros([batch_size, self.d_hidden], device),
                Tensor::zeros([batch_size, self.d_hidden], device),
            ),
        };

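        // At each timestep t, `gate_product` computes W * x_t + U * h_{t-1} (+ bias) for a gate,
        // and the states are updated as:
        //   f_t = sigmoid(forget gate)   i_t = sigmoid(input gate)
        //   o_t = sigmoid(output gate)   g_t = tanh(cell gate)
        //   c_t = f_t * c_{t-1} + i_t * g_t
        //   h_t = o_t * tanh(c_t)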
        for (input_t, t) in input_timestep_iter {
            let input_t = input_t.squeeze(1);
            // f(orget)g(ate) tensors
            let biased_fg_input_sum = self
                .forget_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state

            // i(nput)g(ate) tensors
            let biased_ig_input_sum = self
                .input_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let add_values = activation::sigmoid(biased_ig_input_sum);

            // o(utput)g(ate) tensors
            let biased_og_input_sum = self
                .output_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let output_values = activation::sigmoid(biased_og_input_sum);

            // c(ell)g(ate) tensors
            let biased_cg_input_sum = self
                .cell_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let candidate_cell_values = biased_cg_input_sum.tanh();

            cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values;
            hidden_state = output_values * cell_state.clone().tanh();

            let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);

            // store the hidden state for this timestep
            batched_hidden_state = batched_hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                unsqueezed_hidden_state.clone(),
            );
        }

        (
            batched_hidden_state,
            LstmState::new(cell_state, hidden_state),
        )
    }
}

/// Configuration to create a [BiLstm](BiLstm) module using the [init function](BiLstmConfig::init).
#[derive(Config)]
pub struct BiLstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// Whether a bias should be applied during the BiLstm transformation.
    pub bias: bool,
    /// BiLstm initializer
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}

/// The BiLstm module. This implementation is for a bidirectional LSTM.
///
/// Introduced in the paper: [Framewise phoneme classification with bidirectional LSTM and other neural network architectures](https://www.cs.toronto.edu/~graves/ijcnn_2005.pdf).
///
/// Should be created with [BiLstmConfig].
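///
/// Illustrative construction sketch (assumes a backend type `B` and a `device` are in scope;
/// both are assumptions outside this module):
///
/// ```rust,ignore
/// // d_input = 2, d_hidden = 3, with bias; one Lstm per direction is created internally.
/// let bilstm: BiLstm<B> = BiLstmConfig::new(2, 3, true).init(&device);
/// ```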
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiLstm<B: Backend> {
    /// LSTM for the forward direction.
    pub forward: Lstm<B>,
    /// LSTM for the reverse direction.
    pub reverse: Lstm<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
}

impl<B: Backend> ModuleDisplay for BiLstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self
            .forward
            .input_gate
            .input_transform
            .weight
            .shape()
            .dims();
        let bias = self.forward.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl BiLstmConfig {
    /// Initialize a new [Bidirectional LSTM](BiLstm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> BiLstm<B> {
        BiLstm {
            forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
            reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
            d_hidden: self.d_hidden,
        }
    }
}

impl<B: Backend> BiLstm<B> {
    /// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation
    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
    ///
    /// ## Parameters:
    /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`.
    /// - state: An optional `LstmState` representing the initial cell state and hidden state.
    ///   Each state tensor has shape `[2, batch_size, hidden_size]`.
    ///   If no initial state is provided, these tensors are initialized to zeros.
    ///
    /// ## Returns:
    /// - output: A tensor representing the output features of the LSTM. Shape: `[batch_size, sequence_length, hidden_size * 2]`
    /// - state: A `LstmState` representing the final forward and reverse states. Both `state.cell` and
    ///   `state.hidden` have the shape `[2, batch_size, hidden_size]`.
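    ///
    /// ## Example:
    /// Illustrative shape sketch (assumes a backend type `B` and a `device` are in scope;
    /// both are assumptions outside this module):
    ///
    /// ```rust,ignore
    /// // d_input = 8, d_hidden = 16
    /// let bilstm = BiLstmConfig::new(8, 16, true).init::<B>(&device);
    ///
    /// let input = Tensor::<B, 3>::zeros([4, 10, 8], &device); // [batch_size, seq_length, d_input]
    /// let (output, state) = bilstm.forward(input, None);
    ///
    /// assert_eq!(output.dims(), [4, 10, 32]); // hidden_size * 2
    /// assert_eq!(state.cell.dims(), [2, 4, 16]);
    /// assert_eq!(state.hidden.dims(), [2, 4, 16]);
    /// ```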
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 3>>,
    ) -> (Tensor<B, 3>, LstmState<B, 3>) {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

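        // The incoming state tensors have shape [2, batch_size, d_hidden]: index 0 along the
        // first dimension holds the forward-direction state, index 1 the reverse-direction state.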
        let [init_state_forward, init_state_reverse] = match state {
            Some(state) => {
                let cell_state_forward = state
                    .cell
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);
                let hidden_state_forward = state
                    .hidden
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);
                let cell_state_reverse = state
                    .cell
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);
                let hidden_state_reverse = state
                    .hidden
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);

                [
                    Some(LstmState::new(cell_state_forward, hidden_state_forward)),
                    Some(LstmState::new(cell_state_reverse, hidden_state_reverse)),
                ]
            }
            None => [None, None],
        };

        // forward direction
        let (batched_hidden_state_forward, final_state_forward) = self
            .forward
            .forward(batched_input.clone(), init_state_forward);

        // reverse direction
        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
            init_state_reverse,
            batch_size,
            seq_length,
            &device,
        );

        let output = Tensor::cat(
            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
            2,
        );

        let state = LstmState::new(
            Tensor::stack(
                [final_state_forward.cell, final_state_reverse.cell].to_vec(),
                0,
            ),
            Tensor::stack(
                [final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
                0,
            ),
        );

        (output, state)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Device, Distribution, TensorData};
    use crate::{TestBackend, module::Param, nn::LinearRecord};
    use burn_tensor::{ElementConversion, Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

    #[test]
    fn test_with_uniform_initializer() {
        TestBackend::seed(0);

        let config = LstmConfig::new(5, 5, false)
            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
        let lstm = config.init::<TestBackend>(&Default::default());

        let gate_to_data =
            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();

        gate_to_data(lstm.input_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.forget_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.output_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.cell_gate).assert_within_range::<FT>(0.elem()..1.elem());
    }

    /// Test forward pass with a simple input vector.
    ///
    /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928
    /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725
    /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723
    /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937
    /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243
    /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648
    #[test]
    fn test_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let config = LstmConfig::new(1, 1, false);
        let device = Default::default();
        let mut lstm = config.init::<TestBackend>(&device);

        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        lstm.input_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.forget_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.cell_gate = create_gate_controller(
            0.9,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.output_gate = create_gate_controller(
            1.1,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );

        // single timestep with single feature
        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);

        let (output, state) = lstm.forward(input, None);

        let expected = TensorData::from([[0.046]]);
        let tolerance = Tolerance::default();
        state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        let expected = TensorData::from([[0.0242]]);
        state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        output
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0)
            .to_data()
            .assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
    }

    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [8, 10, 1024]);
        assert_eq!(state.cell.dims(), [8, 1024]);
        assert_eq!(state.hidden.dims(), [8, 1024]);
    }

    #[test]
    fn test_batched_forward_pass_batch_of_one() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [1, 2, 1024]);
        assert_eq!(state.cell.dims(), [1, 1024]);
        assert_eq!(state.hidden.dims(), [1, 1024]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        use crate::tensor::Shape;
        let device = Default::default();
        let lstm = LstmConfig::new(64, 32, true).init(&device);
        let shape: Shape = [8, 10, 64].into();
        let batched_input =
            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);

        let (output, _) = lstm.forward(batched_input.clone(), None);
        let fake_loss = output;
        let grads = fake_loss.backward();

        let some_gradient = lstm
            .output_gate
            .hidden_transform
            .weight
            .grad(&grads)
            .unwrap();

        // Asserts that the gradients exist and are non-zero
        assert!(
            some_gradient
                .any()
                .into_data()
                .iter::<f32>()
                .next()
                .unwrap()
                != 0.0
        );
    }

    #[test]
    fn test_bidirectional() {
        TestBackend::seed(0);
        let config = BiLstmConfig::new(2, 3, true);
        let device = Default::default();
        let mut lstm = config.init(&device);

        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let d_input = input_weights[0].len();
            let d_output = input_weights.len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [0.949, -0.861],
                [0.892, 0.927],
                [-0.173, -0.301],
                [-0.081, 0.992],
            ]]),
            &device,
        );
        let h0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
            &device,
        );
        let c0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]),
            &device,
        );

        lstm.forward.input_gate = create_gate_controller(
            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
            [-0.196, 0.354, 0.209],
            [
                [-0.320, 0.232, -0.165],
                [0.093, -0.572, -0.315],
                [-0.467, 0.325, 0.046],
            ],
            [0.181, -0.190, -0.245],
            &device,
        );

        lstm.forward.forget_gate = create_gate_controller(
            [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]],
            [0.315, -0.413, -0.041],
            [
                [0.453, 0.063, 0.561],
                [0.211, 0.149, 0.213],
                [-0.499, -0.158, 0.068],
            ],
            [-0.431, -0.535, 0.125],
            &device,
        );

        lstm.forward.cell_gate = create_gate_controller(
            [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]],
            [-0.358, 0.282, -0.078],
            [
                [-0.358, 0.109, 0.139],
                [-0.345, 0.091, -0.368],
                [-0.508, 0.221, -0.507],
            ],
            [0.502, -0.509, -0.247],
            &device,
        );

        lstm.forward.output_gate = create_gate_controller(
            [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]],
            [-0.227, -0.274, 0.039],
            [
                [-0.383, 0.449, 0.222],
                [-0.357, -0.093, 0.449],
                [-0.106, 0.236, 0.360],
            ],
            [-0.361, -0.209, -0.454],
            &device,
        );

        lstm.reverse.input_gate = create_gate_controller(
            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
            [0.540, -0.164, 0.033],
            [
                [0.159, 0.180, -0.037],
                [-0.443, 0.485, -0.488],
                [0.098, -0.085, -0.140],
            ],
            [-0.510, 0.105, 0.114],
            &device,
        );

        lstm.reverse.forget_gate = create_gate_controller(
            [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]],
            [0.141, 0.004, 0.055],
            [
                [-0.005, -0.277, -0.515],
                [-0.011, -0.101, -0.365],
                [0.426, 0.379, 0.337],
            ],
            [-0.382, 0.331, -0.176],
            &device,
        );

        lstm.reverse.cell_gate = create_gate_controller(
            [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]],
            [-0.206, -0.546, 0.462],
            [
                [0.449, -0.240, 0.071],
                [-0.045, 0.131, 0.124],
                [0.138, -0.201, 0.191],
            ],
            [-0.030, 0.211, -0.352],
            &device,
        );

        lstm.reverse.output_gate = create_gate_controller(
            [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]],
            [-0.387, -0.250, 0.066],
            [
                [-0.030, 0.268, 0.299],
                [-0.019, -0.280, -0.314],
                [0.466, -0.365, -0.248],
            ],
            [-0.398, -0.199, -0.566],
            &device,
        );

        let expected_output_with_init_state = TensorData::from([[
            [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798],
            [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742],
            [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012],
            [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872],
        ]]);
        let expected_output_without_init_state = TensorData::from([[
            [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863],
            [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142],
            [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846],
            [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550],
        ]]);
        let expected_hn_with_init_state = TensorData::from([
            [[-0.03420, 0.07774, -0.09774]],
            [[-0.15635, -0.03366, -0.05798]],
        ]);
        let expected_cn_with_init_state = TensorData::from([
            [[-0.13593, 0.17125, -0.22395]],
            [[-0.45425, -0.11206, -0.12908]],
        ]);
        let expected_hn_without_init_state = TensorData::from([
            [[-0.04026, 0.07178, -0.10189]],
            [[-0.15969, -0.05322, -0.08863]],
        ]);
        let expected_cn_without_init_state = TensorData::from([
            [[-0.15839, 0.15923, -0.23569]],
            [[-0.47407, -0.17493, -0.19643]],
        ]);

        let (output_with_init_state, state_with_init_state) =
            lstm.forward(input.clone(), Some(LstmState::new(c0, h0)));
        let (output_without_init_state, state_without_init_state) = lstm.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
        output_without_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
        state_with_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
        state_with_init_state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected_cn_with_init_state, tolerance);
        state_without_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
        state_without_init_state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected_cn_without_init_state, tolerance);
    }

    #[test]
    fn display_lstm() {
        let config = LstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
        );
    }

    #[test]
    fn display_bilstm() {
        let config = BiLstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
        );
    }
}