use burn_core as burn;

use crate::GateController;
use crate::activation::{Activation, ActivationConfig};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// The state of an [Lstm] or [BiLstm] module: the cell state and the hidden state.
pub struct LstmState<B: Backend, const D: usize> {
    /// The cell state.
    pub cell: Tensor<B, D>,
    /// The hidden state.
    pub hidden: Tensor<B, D>,
}

impl<B: Backend, const D: usize> LstmState<B, D> {
    /// Initializes a new [LstmState].
    pub fn new(cell: Tensor<B, D>, hidden: Tensor<B, D>) -> Self {
        Self { cell, hidden }
    }
}

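// A minimal sketch of constructing an initial state by hand; `MyBackend`,
// `batch_size`, and `d_hidden` are illustrative names, not items defined here:
//
//     let device = Default::default();
//     let cell = Tensor::<MyBackend, 2>::zeros([batch_size, d_hidden], &device);
//     let hidden = Tensor::<MyBackend, 2>::zeros([batch_size, d_hidden], &device);
//     let state = LstmState::new(cell, hidden);
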
/// Configuration to create an [Lstm] module using the [init function](LstmConfig::init).
#[derive(Config, Debug)]
pub struct LstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Lstm transformation.
    pub bias: bool,
    /// The initializer for the gate parameters.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If `true`, input and output tensors are `[batch_size, seq_length, features]`;
    /// otherwise `[seq_length, batch_size, features]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// If `true`, the sequence is processed from the last timestep to the first.
    #[config(default = false)]
    pub reverse: bool,
    /// If set, clamps the cell state to `[-clip, clip]` after each update.
    pub clip: Option<f64>,
    /// If `true`, couples the forget gate to the input gate (`f_t = 1 - i_t`), so the
    /// forget gate's parameters are not used in the forward pass.
    #[config(default = false)]
    pub input_forget: bool,
    /// The activation applied to the input, forget, and output gates.
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// The activation applied to the candidate cell state.
    #[config(default = "ActivationConfig::Tanh")]
    pub cell_activation: ActivationConfig,
    /// The activation applied to the cell state when computing the hidden state.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
}

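// A sketch of building a configuration through the setters generated by
// `#[derive(Config)]`; the sizes and `MyBackend` are illustrative:
//
//     let lstm = LstmConfig::new(32, 64, true)
//         .with_reverse(false)
//         .with_clip(Some(5.0))
//         .init::<MyBackend>(&device);
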
/// The Lstm module. This implementation is for a unidirectional, stateless, Lstm.
///
/// Should be created with [LstmConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Lstm<B: Backend> {
    /// The input gate, which regulates what new information enters the cell state.
    pub input_gate: GateController<B>,
    /// The forget gate, which regulates what information is kept in the cell state.
    pub forget_gate: GateController<B>,
    /// The output gate, which regulates what part of the cell state is exposed as the hidden state.
    pub output_gate: GateController<B>,
    /// The cell gate, which produces the candidate cell state.
    pub cell_gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If `true`, tensors are batch-first (`[batch_size, seq_length, features]`).
    pub batch_first: bool,
    /// If `true`, the sequence is processed in reverse.
    pub reverse: bool,
    /// Optional clamp applied to the cell state.
    pub clip: Option<f64>,
    /// If `true`, the forget gate is coupled to the input gate.
    pub input_forget: bool,
    /// The gate activation.
    pub gate_activation: Activation<B>,
    /// The candidate cell activation.
    pub cell_activation: Activation<B>,
    /// The hidden-state activation.
    pub hidden_activation: Activation<B>,
}

impl<B: Backend> ModuleDisplay for Lstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self.input_gate.input_transform.weight.shape().dims();
        let bias = self.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl LstmConfig {
    /// Initialize a new [Lstm] module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {
        let d_output = self.d_hidden;

        // Each gate gets its own, independently initialized parameters.
        let new_gate = || {
            GateController::new(
                self.d_input,
                d_output,
                self.bias,
                self.initializer.clone(),
                device,
            )
        };

        Lstm {
            input_gate: new_gate(),
            forget_gate: new_gate(),
            output_gate: new_gate(),
            cell_gate: new_gate(),
            d_hidden: self.d_hidden,
            batch_first: self.batch_first,
            reverse: self.reverse,
            clip: self.clip,
            input_forget: self.input_forget,
            gate_activation: self.gate_activation.init(device),
            cell_activation: self.cell_activation.init(device),
            hidden_activation: self.hidden_activation.init(device),
        }
    }
}

impl<B: Backend> Lstm<B> {
    /// Applies the forward pass on the input tensor, returning the hidden state
    /// for every timestep along with the final state.
    ///
    /// # Parameters
    /// - `batched_input`: `[batch_size, sequence_length, d_input]` when `batch_first`
    ///   is `true`, otherwise `[sequence_length, batch_size, d_input]`.
    /// - `state`: an optional initial state whose `cell` and `hidden` tensors are each
    ///   `[batch_size, d_hidden]`; both default to zeros when `None`.
    ///
    /// # Returns
    /// - the hidden states for all timesteps, `[batch_size, sequence_length, d_hidden]`
    ///   (sequence-first when `batch_first` is `false`),
    /// - the final [LstmState], with `cell` and `hidden` of shape `[batch_size, d_hidden]`.
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 2>>,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        // Normalize to a batch-first layout for the timestep iteration below.
        let batched_input = if self.batch_first {
            batched_input
        } else {
            batched_input.swap_dims(0, 1)
        };

        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.dims();

        let (output, state) = if self.reverse {
            self.forward_iter(
                batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
                state,
                batch_size,
                seq_length,
                &device,
            )
        } else {
            self.forward_iter(
                batched_input.iter_dim(1).zip(0..seq_length),
                state,
                batch_size,
                seq_length,
                &device,
            )
        };

        // Restore the caller's layout.
        let output = if self.batch_first {
            output
        } else {
            output.swap_dims(0, 1)
        };

        (output, state)
    }

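    // The timestep loop below implements the standard LSTM cell update, with the
    // configured activations in place of the usual sigmoid/tanh:
    //
    //     i_t = gate_activation(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
    //     f_t = gate_activation(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)    (or 1 - i_t when input_forget)
    //     o_t = gate_activation(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
    //     g_t = cell_activation(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
    //     c_t = f_t * c_{t-1} + i_t * g_t    (clamped to [-clip, clip] when configured)
    //     h_t = o_t * hidden_activation(c_t)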
    fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
        &self,
        input_timestep_iter: I,
        state: Option<LstmState<B, 2>>,
        batch_size: usize,
        seq_length: usize,
        device: &B::Device,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        let mut batched_hidden_state =
            Tensor::empty([batch_size, seq_length, self.d_hidden], device);

        let (mut cell_state, mut hidden_state) = match state {
            Some(state) => (state.cell, state.hidden),
            None => (
                Tensor::zeros([batch_size, self.d_hidden], device),
                Tensor::zeros([batch_size, self.d_hidden], device),
            ),
        };

        for (input_t, t) in input_timestep_iter {
            let input_t = input_t.squeeze_dim(1);

            let biased_ig_input_sum = self
                .input_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let input_values = self.gate_activation.forward(biased_ig_input_sum);

            let forget_values = if self.input_forget {
                // Coupled input/forget gates: f_t = 1 - i_t.
                input_values.clone().neg().add_scalar(1.0)
            } else {
                let biased_fg_input_sum = self
                    .forget_gate
                    .gate_product(input_t.clone(), hidden_state.clone());
                self.gate_activation.forward(biased_fg_input_sum)
            };

            let biased_og_input_sum = self
                .output_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let output_values = self.gate_activation.forward(biased_og_input_sum);

            let biased_cg_input_sum = self
                .cell_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let candidate_cell_values = self.cell_activation.forward(biased_cg_input_sum);

            cell_state = forget_values * cell_state.clone() + input_values * candidate_cell_values;

            // Optionally clamp the cell state to mitigate exploding values.
            if let Some(clip) = self.clip {
                cell_state = cell_state.clamp(-clip, clip);
            }

            hidden_state = output_values * self.hidden_activation.forward(cell_state.clone());

            let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);

            // Store the hidden state for this timestep.
            batched_hidden_state = batched_hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                unsqueezed_hidden_state.clone(),
            );
        }

        (
            batched_hidden_state,
            LstmState::new(cell_state, hidden_state),
        )
    }
}

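// A usage sketch for the unidirectional module; `MyBackend` and all sizes below
// are illustrative, not items defined in this file:
//
//     let device = Default::default();
//     let lstm = LstmConfig::new(8, 16, true).init::<MyBackend>(&device);
//     let input = Tensor::<MyBackend, 3>::random([4, 10, 8], Distribution::Default, &device);
//     // output: [4, 10, 16]; state.cell and state.hidden: [4, 16]
//     let (output, state) = lstm.forward(input, None);
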
/// Configuration to create a [BiLstm] module using the [init function](BiLstmConfig::init).
#[derive(Config, Debug)]
pub struct BiLstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Lstm transformation.
    pub bias: bool,
    /// The initializer for the gate parameters.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If `true`, input and output tensors are `[batch_size, seq_length, features]`;
    /// otherwise `[seq_length, batch_size, features]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// If set, clamps the cell state to `[-clip, clip]` after each update.
    pub clip: Option<f64>,
    /// If `true`, couples the forget gate to the input gate (`f_t = 1 - i_t`).
    #[config(default = false)]
    pub input_forget: bool,
    /// The activation applied to the input, forget, and output gates.
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// The activation applied to the candidate cell state.
    #[config(default = "ActivationConfig::Tanh")]
    pub cell_activation: ActivationConfig,
    /// The activation applied to the cell state when computing the hidden state.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
}

/// The BiLstm module. This implementation is for a bidirectional Lstm.
///
/// Should be created with [BiLstmConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiLstm<B: Backend> {
    /// The Lstm for the forward direction.
    pub forward: Lstm<B>,
    /// The Lstm for the reverse direction.
    pub reverse: Lstm<B>,
    /// The size of the hidden state (per direction).
    pub d_hidden: usize,
    /// If `true`, tensors are batch-first (`[batch_size, seq_length, features]`).
    pub batch_first: bool,
}

impl<B: Backend> ModuleDisplay for BiLstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self
            .forward
            .input_gate
            .input_transform
            .weight
            .shape()
            .dims();
        let bias = self.forward.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl BiLstmConfig {
    /// Initialize a new [BiLstm] module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> BiLstm<B> {
        // Both directions share the same configuration. The inner modules always run
        // batch-first; the reversal is handled in [BiLstm::forward].
        let base_config = LstmConfig::new(self.d_input, self.d_hidden, self.bias)
            .with_initializer(self.initializer.clone())
            .with_batch_first(true)
            .with_clip(self.clip)
            .with_input_forget(self.input_forget)
            .with_gate_activation(self.gate_activation.clone())
            .with_cell_activation(self.cell_activation.clone())
            .with_hidden_activation(self.hidden_activation.clone());

        BiLstm {
            forward: base_config.clone().init(device),
            reverse: base_config.init(device),
            d_hidden: self.d_hidden,
            batch_first: self.batch_first,
        }
    }
}

impl<B: Backend> BiLstm<B> {
    /// Applies the forward pass on the input tensor, running both directions and
    /// concatenating their hidden states on the feature axis.
    ///
    /// # Parameters
    /// - `batched_input`: `[batch_size, sequence_length, d_input]` when `batch_first`
    ///   is `true`, otherwise `[sequence_length, batch_size, d_input]`.
    /// - `state`: an optional initial state whose `cell` and `hidden` tensors are each
    ///   `[2, batch_size, d_hidden]` (index 0 is the forward direction, index 1 the
    ///   reverse direction); both default to zeros when `None`.
    ///
    /// # Returns
    /// - the output, `[batch_size, sequence_length, 2 * d_hidden]` (sequence-first when
    ///   `batch_first` is `false`),
    /// - the final [LstmState], with `cell` and `hidden` of shape `[2, batch_size, d_hidden]`.
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 3>>,
    ) -> (Tensor<B, 3>, LstmState<B, 3>) {
        let batched_input = if self.batch_first {
            batched_input
        } else {
            batched_input.swap_dims(0, 1)
        };

        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        // Split the stacked initial state into one [batch_size, d_hidden] state per direction.
        let [init_state_forward, init_state_reverse] = match state {
            Some(state) => {
                let cell_state_forward = state
                    .cell
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                let hidden_state_forward = state
                    .hidden
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                let cell_state_reverse = state
                    .cell
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                let hidden_state_reverse = state
                    .hidden
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);

                [
                    Some(LstmState::new(cell_state_forward, hidden_state_forward)),
                    Some(LstmState::new(cell_state_reverse, hidden_state_reverse)),
                ]
            }
            None => [None, None],
        };

        let (batched_hidden_state_forward, final_state_forward) = self
            .forward
            .forward(batched_input.clone(), init_state_forward);

        // The reverse direction iterates the timesteps from last to first.
        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
            init_state_reverse,
            batch_size,
            seq_length,
            &device,
        );

        let output = Tensor::cat(
            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
            2,
        );

        let output = if self.batch_first {
            output
        } else {
            output.swap_dims(0, 1)
        };

        let state = LstmState::new(
            Tensor::stack(
                [final_state_forward.cell, final_state_reverse.cell].to_vec(),
                0,
            ),
            Tensor::stack(
                [final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
                0,
            ),
        );

        (output, state)
    }
}

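// A usage sketch for the bidirectional module; `MyBackend` and all sizes below are
// illustrative. The two directions are concatenated on the feature axis, and the
// returned state stacks the directions on a leading axis of size 2:
//
//     let device = Default::default();
//     let bilstm = BiLstmConfig::new(8, 16, true).init::<MyBackend>(&device);
//     let input = Tensor::<MyBackend, 3>::random([4, 10, 8], Distribution::Default, &device);
//     // output: [4, 10, 32]; state.cell and state.hidden: [2, 4, 16]
//     let (output, state) = bilstm.forward(input, None);
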
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LinearRecord, TestBackend};
    use burn::module::Param;
    use burn::tensor::{Device, Distribution, TensorData};
    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

    #[test]
    fn test_with_uniform_initializer() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LstmConfig::new(5, 5, false)
            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
        let lstm = config.init::<TestBackend>(&device);

        let gate_to_data =
            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();

        gate_to_data(lstm.input_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.forget_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.output_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.cell_gate).assert_within_range::<FT>(0.elem()..1.elem());
    }

    #[test]
    fn test_forward_single_input_single_feature() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = LstmConfig::new(1, 1, false);
        let mut lstm = config.init::<TestBackend>(&device);

        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        // Fix the gate weights so the expected cell and hidden values below can be
        // derived by hand from the LSTM equations.
        lstm.input_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.forget_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.cell_gate = create_gate_controller(
            0.9,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.output_gate = create_gate_controller(
            1.1,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );

        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);

        let (output, state) = lstm.forward(input, None);

        let expected = TensorData::from([[0.046]]);
        let tolerance = Tolerance::default();
        state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        let expected = TensorData::from([[0.0242]]);
        state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        // With a single timestep, the output equals the final hidden state.
        output
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0)
            .to_data()
            .assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
    }

    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [8, 10, 1024]);
        assert_eq!(state.cell.dims(), [8, 1024]);
        assert_eq!(state.hidden.dims(), [8, 1024]);
    }

    #[test]
    fn test_batched_forward_pass_batch_of_one() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [1, 2, 1024]);
        assert_eq!(state.cell.dims(), [1, 1024]);
        assert_eq!(state.hidden.dims(), [1, 1024]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        use burn::tensor::Shape;
        let device = Default::default();
        let lstm = LstmConfig::new(64, 32, true).init(&device);
        let shape: Shape = [8, 10, 64].into();
        let batched_input =
            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);

        let (output, _) = lstm.forward(batched_input.clone(), None);
        let fake_loss = output;
        let grads = fake_loss.backward();

        let some_gradient = lstm
            .output_gate
            .hidden_transform
            .weight
            .grad(&grads)
            .unwrap();

        // Asserts that the gradient exists and is non-zero.
        assert_ne!(
            some_gradient
                .any()
                .into_data()
                .iter::<f32>()
                .next()
                .unwrap(),
            0.0
        );
    }

    #[test]
    fn test_bidirectional() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = BiLstmConfig::new(2, 3, true);
        let mut lstm = config.init(&device);

        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            // The weight layout is [d_input, d_output].
            let d_input = input_weights.len();
            let d_output = input_weights[0].len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [0.949, -0.861],
                [0.892, 0.927],
                [-0.173, -0.301],
                [-0.081, 0.992],
            ]]),
            &device,
        );
        let h0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
            &device,
        );
        let c0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]),
            &device,
        );

        lstm.forward.input_gate = create_gate_controller(
            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
            [-0.196, 0.354, 0.209],
            [
                [-0.320, 0.232, -0.165],
                [0.093, -0.572, -0.315],
                [-0.467, 0.325, 0.046],
            ],
            [0.181, -0.190, -0.245],
            &device,
        );

        lstm.forward.forget_gate = create_gate_controller(
            [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]],
            [0.315, -0.413, -0.041],
            [
                [0.453, 0.063, 0.561],
                [0.211, 0.149, 0.213],
                [-0.499, -0.158, 0.068],
            ],
            [-0.431, -0.535, 0.125],
            &device,
        );

        lstm.forward.cell_gate = create_gate_controller(
            [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]],
            [-0.358, 0.282, -0.078],
            [
                [-0.358, 0.109, 0.139],
                [-0.345, 0.091, -0.368],
                [-0.508, 0.221, -0.507],
            ],
            [0.502, -0.509, -0.247],
            &device,
        );

        lstm.forward.output_gate = create_gate_controller(
            [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]],
            [-0.227, -0.274, 0.039],
            [
                [-0.383, 0.449, 0.222],
                [-0.357, -0.093, 0.449],
                [-0.106, 0.236, 0.360],
            ],
            [-0.361, -0.209, -0.454],
            &device,
        );

        lstm.reverse.input_gate = create_gate_controller(
            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
            [0.540, -0.164, 0.033],
            [
                [0.159, 0.180, -0.037],
                [-0.443, 0.485, -0.488],
                [0.098, -0.085, -0.140],
            ],
            [-0.510, 0.105, 0.114],
            &device,
        );

        lstm.reverse.forget_gate = create_gate_controller(
            [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]],
            [0.141, 0.004, 0.055],
            [
                [-0.005, -0.277, -0.515],
                [-0.011, -0.101, -0.365],
                [0.426, 0.379, 0.337],
            ],
            [-0.382, 0.331, -0.176],
            &device,
        );

        lstm.reverse.cell_gate = create_gate_controller(
            [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]],
            [-0.206, -0.546, 0.462],
            [
                [0.449, -0.240, 0.071],
                [-0.045, 0.131, 0.124],
                [0.138, -0.201, 0.191],
            ],
            [-0.030, 0.211, -0.352],
            &device,
        );

        lstm.reverse.output_gate = create_gate_controller(
            [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]],
            [-0.387, -0.250, 0.066],
            [
                [-0.030, 0.268, 0.299],
                [-0.019, -0.280, -0.314],
                [0.466, -0.365, -0.248],
            ],
            [-0.398, -0.199, -0.566],
            &device,
        );

        let expected_output_with_init_state = TensorData::from([[
            [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798],
            [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742],
            [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012],
            [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872],
        ]]);
        let expected_output_without_init_state = TensorData::from([[
            [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863],
            [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142],
            [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846],
            [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550],
        ]]);
        let expected_hn_with_init_state = TensorData::from([
            [[-0.03420, 0.07774, -0.09774]],
            [[-0.15635, -0.03366, -0.05798]],
        ]);
        let expected_cn_with_init_state = TensorData::from([
            [[-0.13593, 0.17125, -0.22395]],
            [[-0.45425, -0.11206, -0.12908]],
        ]);
        let expected_hn_without_init_state = TensorData::from([
            [[-0.04026, 0.07178, -0.10189]],
            [[-0.15969, -0.05322, -0.08863]],
        ]);
        let expected_cn_without_init_state = TensorData::from([
            [[-0.15839, 0.15923, -0.23569]],
            [[-0.47407, -0.17493, -0.19643]],
        ]);

        let (output_with_init_state, state_with_init_state) =
            lstm.forward(input.clone(), Some(LstmState::new(c0, h0)));
        let (output_without_init_state, state_without_init_state) = lstm.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
        output_without_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
        state_with_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
        state_with_init_state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected_cn_with_init_state, tolerance);
        state_without_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
        state_without_init_state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected_cn_without_init_state, tolerance);
    }

    #[test]
    fn display_lstm() {
        let config = LstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
        );
    }

    #[test]
    fn display_bilstm() {
        let config = BiLstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
        );
    }
}