burn_nn/modules/rnn/lstm.rs

use burn_core as burn;

use crate::GateController;
use crate::activation::{Activation, ActivationConfig};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// An `LstmState` stores the cell state and hidden state of an LSTM.
pub struct LstmState<B: Backend, const D: usize> {
    /// The cell state.
    pub cell: Tensor<B, D>,
    /// The hidden state.
    pub hidden: Tensor<B, D>,
}

impl<B: Backend, const D: usize> LstmState<B, D> {
    /// Initialize a new [LSTM State](LstmState).
    pub fn new(cell: Tensor<B, D>, hidden: Tensor<B, D>) -> Self {
        Self { cell, hidden }
    }
}

/// Configuration to create a [Lstm](Lstm) module using the [init function](LstmConfig::init).
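///
/// # Example
///
/// A minimal configuration sketch; `MyBackend` and `device` below are placeholders
/// assumed to be in scope, not items defined in this module:
///
/// ```rust,ignore
/// let config = LstmConfig::new(32, 64, true) // d_input, d_hidden, bias
///     .with_batch_first(false)               // expect [seq_length, batch_size, input_size]
///     .with_clip(Some(3.0))                  // clamp the cell state to [-3.0, 3.0]
///     .with_input_forget(true);              // couple the gates: f_t = 1 - i_t
/// let lstm = config.init::<MyBackend>(&device);
/// ```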
#[derive(Config, Debug)]
pub struct LstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Lstm transformation.
    pub bias: bool,
    /// Lstm initializer
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// If true, process the sequence in reverse order.
    /// This is useful for implementing reverse-direction LSTMs (e.g., ONNX reverse direction).
    #[config(default = false)]
    pub reverse: bool,
    /// Optional cell state clip threshold. If provided, cell state values are clipped
    /// to the range `[-clip, +clip]` after each timestep. This can help prevent
    /// exploding values during inference.
    pub clip: Option<f64>,
    /// If true, couples the input and forget gates: `f_t = 1 - i_t`.
    /// This follows the coupled input-forget gate (CIFG) simplification, similar in
    /// spirit to a GRU; the forget gate's own parameters are left unused.
    #[config(default = false)]
    pub input_forget: bool,
    /// Activation function for the input, forget, and output gates.
    /// Default is Sigmoid, which is standard for LSTM gates.
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// Activation function for the cell gate (candidate cell state).
    /// Default is Tanh, which is standard for LSTM.
    #[config(default = "ActivationConfig::Tanh")]
    pub cell_activation: ActivationConfig,
    /// Activation function applied to the cell state before computing the hidden output.
    /// Default is Tanh, which is standard for LSTM.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
}

/// The Lstm module. This implementation is for a unidirectional, stateless Lstm.
///
/// Introduced in the paper: [Long Short-Term Memory](https://www.researchgate.net/publication/13853244).
///
/// Should be created with [LstmConfig].
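///
/// # Example
///
/// A minimal sketch of stateless use: the recurrent state is returned by
/// [Lstm::forward] and passed back in explicitly. `MyBackend`, `device`, `chunk_1`,
/// and `chunk_2` are placeholders assumed to be in scope:
///
/// ```rust,ignore
/// let lstm = LstmConfig::new(8, 16, true).init::<MyBackend>(&device);
/// let (out_1, state) = lstm.forward(chunk_1, None);
/// // Carry the final state into the next chunk of the same sequences.
/// let (out_2, state) = lstm.forward(chunk_2, Some(state));
/// ```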
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Lstm<B: Backend> {
    /// The input gate regulates which information to update and store in the cell state at each time step.
    pub input_gate: GateController<B>,
    /// The forget gate is used to control which information to discard or keep in the memory cell at each time step.
    /// Note: When `input_forget` is true, this gate is not used (forget = 1 - input).
    pub forget_gate: GateController<B>,
    /// The output gate determines which information from the cell state to output at each time step.
    pub output_gate: GateController<B>,
    /// The cell gate is used to compute the cell state that stores and carries information through time.
    pub cell_gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If true, input is `[batch_size, seq_length, input_size]`.
    /// If false, input is `[seq_length, batch_size, input_size]`.
    pub batch_first: bool,
    /// If true, process the sequence in reverse order.
    pub reverse: bool,
    /// Optional cell state clip threshold.
    pub clip: Option<f64>,
    /// If true, couples input and forget gates: f_t = 1 - i_t.
    pub input_forget: bool,
    /// Activation function for gates (input, forget, output).
    pub gate_activation: Activation<B>,
    /// Activation function for cell gate (candidate cell state).
    pub cell_activation: Activation<B>,
    /// Activation function for hidden output.
    pub hidden_activation: Activation<B>,
}

impl<B: Backend> ModuleDisplay for Lstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self.input_gate.input_transform.weight.shape().dims();
        let bias = self.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl LstmConfig {
    /// Initialize a new [lstm](Lstm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {
        let d_output = self.d_hidden;

        let new_gate = || {
            GateController::new(
                self.d_input,
                d_output,
                self.bias,
                self.initializer.clone(),
                device,
            )
        };

        Lstm {
            input_gate: new_gate(),
            forget_gate: new_gate(),
            output_gate: new_gate(),
            cell_gate: new_gate(),
            d_hidden: self.d_hidden,
            batch_first: self.batch_first,
            reverse: self.reverse,
            clip: self.clip,
            input_forget: self.input_forget,
            gate_activation: self.gate_activation.init(device),
            cell_activation: self.cell_activation.init(device),
            hidden_activation: self.hidden_activation.init(device),
        }
    }
}

impl<B: Backend> Lstm<B> {
    /// Applies the forward pass on the input tensor. This LSTM implementation
    /// returns the hidden state for each element in the sequence (i.e., across seq_length) and a final state.
    ///
    /// ## Parameters:
    /// - batched_input: The input tensor of shape:
    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false
    /// - state: An optional `LstmState` representing the initial cell state and hidden state.
    ///   Each state tensor has shape `[batch_size, hidden_size]`.
    ///   If no initial state is provided, these tensors are initialized to zeros.
    ///
    /// ## Returns:
    /// - output: A tensor representing the output features of the LSTM. Shape:
    ///   - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true
    ///   - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false
    /// - state: An `LstmState` representing the final states. Both `state.cell` and `state.hidden` have the shape
    ///   `[batch_size, hidden_size]`.
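    ///
    /// # Example
    ///
    /// A rough usage sketch; `MyBackend` and `device` are placeholders assumed to be
    /// in scope, and the shapes follow the default `batch_first = true` layout:
    ///
    /// ```rust,ignore
    /// let lstm = LstmConfig::new(16, 32, true).init::<MyBackend>(&device);
    /// // [batch_size = 4, seq_length = 10, input_size = 16]
    /// let input = Tensor::<MyBackend, 3>::random([4, 10, 16], Distribution::Default, &device);
    /// let (output, state) = lstm.forward(input, None);
    /// // output: [4, 10, 32]; state.cell and state.hidden: [4, 32]
    /// ```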
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 2>>,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        // Convert to batch-first layout internally if needed
        let batched_input = if self.batch_first {
            batched_input
        } else {
            batched_input.swap_dims(0, 1)
        };

        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.dims();

        // Process sequence in forward or reverse order based on config
        let (output, state) = if self.reverse {
            self.forward_iter(
                batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
                state,
                batch_size,
                seq_length,
                &device,
            )
        } else {
            self.forward_iter(
                batched_input.iter_dim(1).zip(0..seq_length),
                state,
                batch_size,
                seq_length,
                &device,
            )
        };

        // Convert output back to seq-first layout if needed
        let output = if self.batch_first {
            output
        } else {
            output.swap_dims(0, 1)
        };

        (output, state)
    }

    fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
        &self,
        input_timestep_iter: I,
        state: Option<LstmState<B, 2>>,
        batch_size: usize,
        seq_length: usize,
        device: &B::Device,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        let mut batched_hidden_state =
            Tensor::empty([batch_size, seq_length, self.d_hidden], device);

        let (mut cell_state, mut hidden_state) = match state {
            Some(state) => (state.cell, state.hidden),
            None => (
                Tensor::zeros([batch_size, self.d_hidden], device),
                Tensor::zeros([batch_size, self.d_hidden], device),
            ),
        };

        for (input_t, t) in input_timestep_iter {
            let input_t = input_t.squeeze_dim(1);

            // i(nput)g(ate) tensors
            let biased_ig_input_sum = self
                .input_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let input_values = self.gate_activation.forward(biased_ig_input_sum);

            // f(orget)g(ate) tensors - either computed or coupled to input gate
            let forget_values = if self.input_forget {
                // Coupled mode: f_t = 1 - i_t
                input_values.clone().neg().add_scalar(1.0)
            } else {
                let biased_fg_input_sum = self
                    .forget_gate
                    .gate_product(input_t.clone(), hidden_state.clone());
                self.gate_activation.forward(biased_fg_input_sum)
            };

            // o(utput)g(ate) tensors
            let biased_og_input_sum = self
                .output_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let output_values = self.gate_activation.forward(biased_og_input_sum);

            // c(ell)g(ate) tensors
            let biased_cg_input_sum = self
                .cell_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let candidate_cell_values = self.cell_activation.forward(biased_cg_input_sum);

            cell_state = forget_values * cell_state.clone() + input_values * candidate_cell_values;

            // Apply cell state clipping if configured
            if let Some(clip) = self.clip {
                cell_state = cell_state.clamp(-clip, clip);
            }

            hidden_state = output_values * self.hidden_activation.forward(cell_state.clone());

            let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);

            // Store the hidden state for this timestep
            batched_hidden_state = batched_hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                unsqueezed_hidden_state.clone(),
            );
        }

        (
            batched_hidden_state,
            LstmState::new(cell_state, hidden_state),
        )
    }
}

/// Configuration to create a [BiLstm](BiLstm) module using the [init function](BiLstmConfig::init).
#[derive(Config, Debug)]
pub struct BiLstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the BiLstm transformation.
    pub bias: bool,
    /// BiLstm initializer
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.
    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// Optional cell state clip threshold.
    pub clip: Option<f64>,
    /// If true, couples the input and forget gates.
    #[config(default = false)]
    pub input_forget: bool,
    /// Activation function for the input, forget, and output gates.
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// Activation function for the cell gate (candidate cell state).
    #[config(default = "ActivationConfig::Tanh")]
    pub cell_activation: ActivationConfig,
    /// Activation function applied to the cell state before computing hidden output.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
}

/// The BiLstm module. This implementation is for a bidirectional LSTM.
///
/// Introduced in the paper: [Framewise phoneme classification with bidirectional LSTM and other neural network architectures](https://www.cs.toronto.edu/~graves/ijcnn_2005.pdf).
///
/// Should be created with [BiLstmConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiLstm<B: Backend> {
    /// LSTM for the forward direction.
    pub forward: Lstm<B>,
    /// LSTM for the reverse direction.
    pub reverse: Lstm<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If true, input is `[batch_size, seq_length, input_size]`.
    /// If false, input is `[seq_length, batch_size, input_size]`.
    pub batch_first: bool,
}

impl<B: Backend> ModuleDisplay for BiLstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self
            .forward
            .input_gate
            .input_transform
            .weight
            .shape()
            .dims();
        let bias = self.forward.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl BiLstmConfig {
    /// Initialize a new [Bidirectional LSTM](BiLstm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> BiLstm<B> {
        // Internal LSTMs always use batch_first=true; BiLstm handles layout conversion
        let base_config = LstmConfig::new(self.d_input, self.d_hidden, self.bias)
            .with_initializer(self.initializer.clone())
            .with_batch_first(true)
            .with_clip(self.clip)
            .with_input_forget(self.input_forget)
            .with_gate_activation(self.gate_activation.clone())
            .with_cell_activation(self.cell_activation.clone())
            .with_hidden_activation(self.hidden_activation.clone());

        BiLstm {
            forward: base_config.clone().init(device),
            reverse: base_config.init(device),
            d_hidden: self.d_hidden,
            batch_first: self.batch_first,
        }
    }
}

impl<B: Backend> BiLstm<B> {
    /// Applies the forward pass on the input tensor. This bidirectional LSTM implementation
    /// returns the hidden state for each element in the sequence (i.e., across seq_length) and a final state.
    ///
    /// ## Parameters:
    /// - batched_input: The input tensor of shape:
    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)
    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false
    /// - state: An optional `LstmState` representing the initial cell state and hidden state.
    ///   Each state tensor has shape `[2, batch_size, hidden_size]`.
    ///   If no initial state is provided, these tensors are initialized to zeros.
    ///
    /// ## Returns:
    /// - output: A tensor representing the output features of the LSTM. Shape:
    ///   - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true
    ///   - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false
    /// - state: An `LstmState` representing the final forward and reverse states. Both `state.cell` and
    ///   `state.hidden` have the shape `[2, batch_size, hidden_size]`.
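    ///
    /// # Example
    ///
    /// A rough usage sketch; `MyBackend` and `device` are placeholders assumed to be
    /// in scope. The two directions are concatenated along the feature dimension:
    ///
    /// ```rust,ignore
    /// let bilstm = BiLstmConfig::new(16, 32, true).init::<MyBackend>(&device);
    /// // [batch_size = 4, seq_length = 10, input_size = 16]
    /// let input = Tensor::<MyBackend, 3>::random([4, 10, 16], Distribution::Default, &device);
    /// let (output, state) = bilstm.forward(input, None);
    /// // output: [4, 10, 64]; state.cell and state.hidden: [2, 4, 32]
    /// ```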
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 3>>,
    ) -> (Tensor<B, 3>, LstmState<B, 3>) {
        // Convert to batch-first layout internally if needed
        let batched_input = if self.batch_first {
            batched_input
        } else {
            batched_input.swap_dims(0, 1)
        };

        let device = batched_input.clone().device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        let [init_state_forward, init_state_reverse] = match state {
            Some(state) => {
                let cell_state_forward = state
                    .cell
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                let hidden_state_forward = state
                    .hidden
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                let cell_state_reverse = state
                    .cell
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                let hidden_state_reverse = state
                    .hidden
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);

                [
                    Some(LstmState::new(cell_state_forward, hidden_state_forward)),
                    Some(LstmState::new(cell_state_reverse, hidden_state_reverse)),
                ]
            }
            None => [None, None],
        };

        // forward direction
        let (batched_hidden_state_forward, final_state_forward) = self
            .forward
            .forward(batched_input.clone(), init_state_forward);

        // reverse direction
        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
            init_state_reverse,
            batch_size,
            seq_length,
            &device,
        );

        let output = Tensor::cat(
            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
            2,
        );

        // Convert output back to seq-first layout if needed
        let output = if self.batch_first {
            output
        } else {
            output.swap_dims(0, 1)
        };

        let state = LstmState::new(
            Tensor::stack(
                [final_state_forward.cell, final_state_reverse.cell].to_vec(),
                0,
            ),
            Tensor::stack(
                [final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
                0,
            ),
        );

        (output, state)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LinearRecord, TestBackend};
    use burn::module::Param;
    use burn::tensor::{Device, Distribution, TensorData};
    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

    #[test]
    fn test_with_uniform_initializer() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LstmConfig::new(5, 5, false)
            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
        let lstm = config.init::<TestBackend>(&Default::default());

        let gate_to_data =
            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();

        gate_to_data(lstm.input_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.forget_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.output_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.cell_gate).assert_within_range::<FT>(0.elem()..1.elem());
    }

    /// Test forward pass with simple input vector.
    ///
    /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928
    /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725
    /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723
    /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937
    /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243
    /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648
    #[test]
    fn test_forward_single_input_single_feature() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LstmConfig::new(1, 1, false);
        let device = Default::default();
        let mut lstm = config.init::<TestBackend>(&device);

        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        lstm.input_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.forget_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.cell_gate = create_gate_controller(
            0.9,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.output_gate = create_gate_controller(
            1.1,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );

        // single timestep with single feature
        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);

        let (output, state) = lstm.forward(input, None);

        let expected = TensorData::from([[0.046]]);
        let tolerance = Tolerance::default();
        state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        let expected = TensorData::from([[0.0242]]);
        state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        output
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0)
            .to_data()
            .assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
    }

    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [8, 10, 1024]);
        assert_eq!(state.cell.dims(), [8, 1024]);
        assert_eq!(state.hidden.dims(), [8, 1024]);
    }

    #[test]
    fn test_batched_forward_pass_batch_of_one() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [1, 2, 1024]);
        assert_eq!(state.cell.dims(), [1, 1024]);
        assert_eq!(state.hidden.dims(), [1, 1024]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        use burn::tensor::Shape;
        let device = Default::default();
        let lstm = LstmConfig::new(64, 32, true).init(&device);
        let shape: Shape = [8, 10, 64].into();
        let batched_input =
            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);

        let (output, _) = lstm.forward(batched_input.clone(), None);
        let fake_loss = output;
        let grads = fake_loss.backward();

        let some_gradient = lstm
            .output_gate
            .hidden_transform
            .weight
            .grad(&grads)
            .unwrap();

        // Asserts that the gradients exist and are non-zero
        assert_ne!(
            some_gradient
                .any()
                .into_data()
                .iter::<f32>()
                .next()
                .unwrap(),
            0.0
        );
    }

    #[test]
    fn test_bidirectional() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = BiLstmConfig::new(2, 3, true);
        let device = Default::default();
        let mut lstm = config.init(&device);

        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let d_input = input_weights[0].len();
            let d_output = input_weights.len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [0.949, -0.861],
                [0.892, 0.927],
                [-0.173, -0.301],
                [-0.081, 0.992],
            ]]),
            &device,
        );
        let h0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
            &device,
        );
        let c0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]),
            &device,
        );

        lstm.forward.input_gate = create_gate_controller(
            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
            [-0.196, 0.354, 0.209],
            [
                [-0.320, 0.232, -0.165],
                [0.093, -0.572, -0.315],
                [-0.467, 0.325, 0.046],
            ],
            [0.181, -0.190, -0.245],
            &device,
        );

        lstm.forward.forget_gate = create_gate_controller(
            [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]],
            [0.315, -0.413, -0.041],
            [
                [0.453, 0.063, 0.561],
                [0.211, 0.149, 0.213],
                [-0.499, -0.158, 0.068],
            ],
            [-0.431, -0.535, 0.125],
            &device,
        );

        lstm.forward.cell_gate = create_gate_controller(
            [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]],
            [-0.358, 0.282, -0.078],
            [
                [-0.358, 0.109, 0.139],
                [-0.345, 0.091, -0.368],
                [-0.508, 0.221, -0.507],
            ],
            [0.502, -0.509, -0.247],
            &device,
        );

        lstm.forward.output_gate = create_gate_controller(
            [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]],
            [-0.227, -0.274, 0.039],
            [
                [-0.383, 0.449, 0.222],
                [-0.357, -0.093, 0.449],
                [-0.106, 0.236, 0.360],
            ],
            [-0.361, -0.209, -0.454],
            &device,
        );

        lstm.reverse.input_gate = create_gate_controller(
            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
            [0.540, -0.164, 0.033],
            [
                [0.159, 0.180, -0.037],
                [-0.443, 0.485, -0.488],
                [0.098, -0.085, -0.140],
            ],
            [-0.510, 0.105, 0.114],
            &device,
        );

        lstm.reverse.forget_gate = create_gate_controller(
            [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]],
            [0.141, 0.004, 0.055],
            [
                [-0.005, -0.277, -0.515],
                [-0.011, -0.101, -0.365],
                [0.426, 0.379, 0.337],
            ],
            [-0.382, 0.331, -0.176],
            &device,
        );

        lstm.reverse.cell_gate = create_gate_controller(
            [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]],
            [-0.206, -0.546, 0.462],
            [
                [0.449, -0.240, 0.071],
                [-0.045, 0.131, 0.124],
                [0.138, -0.201, 0.191],
            ],
            [-0.030, 0.211, -0.352],
            &device,
        );

        lstm.reverse.output_gate = create_gate_controller(
            [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]],
            [-0.387, -0.250, 0.066],
            [
                [-0.030, 0.268, 0.299],
                [-0.019, -0.280, -0.314],
                [0.466, -0.365, -0.248],
            ],
            [-0.398, -0.199, -0.566],
            &device,
        );

        let expected_output_with_init_state = TensorData::from([[
            [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798],
            [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742],
            [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012],
            [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872],
        ]]);
        let expected_output_without_init_state = TensorData::from([[
            [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863],
            [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142],
            [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846],
            [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550],
        ]]);
        let expected_hn_with_init_state = TensorData::from([
            [[-0.03420, 0.07774, -0.09774]],
            [[-0.15635, -0.03366, -0.05798]],
        ]);
        let expected_cn_with_init_state = TensorData::from([
            [[-0.13593, 0.17125, -0.22395]],
            [[-0.45425, -0.11206, -0.12908]],
        ]);
        let expected_hn_without_init_state = TensorData::from([
            [[-0.04026, 0.07178, -0.10189]],
            [[-0.15969, -0.05322, -0.08863]],
        ]);
        let expected_cn_without_init_state = TensorData::from([
            [[-0.15839, 0.15923, -0.23569]],
            [[-0.47407, -0.17493, -0.19643]],
        ]);

        let (output_with_init_state, state_with_init_state) =
            lstm.forward(input.clone(), Some(LstmState::new(c0, h0)));
        let (output_without_init_state, state_without_init_state) = lstm.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
        output_without_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
        state_with_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
        state_with_init_state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected_cn_with_init_state, tolerance);
        state_without_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
        state_without_init_state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected_cn_without_init_state, tolerance);
    }

    #[test]
    fn display_lstm() {
        let config = LstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
        );
    }

    #[test]
    fn display_bilstm() {
        let config = BiLstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
        );
    }
}