use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::Initializer;
use crate::nn::rnn::gate_controller::GateController;
use crate::tensor::Tensor;
use crate::tensor::activation;
use crate::tensor::backend::Backend;

/// Stores the cell state and hidden state of an LSTM.
pub struct LstmState<B: Backend, const D: usize> {
    /// The cell state.
    pub cell: Tensor<B, D>,
    /// The hidden state.
    pub hidden: Tensor<B, D>,
}

impl<B: Backend, const D: usize> LstmState<B, D> {
    /// Creates a new LSTM state from a cell state and a hidden state.
    pub fn new(cell: Tensor<B, D>, hidden: Tensor<B, D>) -> Self {
        Self { cell, hidden }
    }
}

/// Configuration to create an [Lstm] module using the [init function](LstmConfig::init).
#[derive(Config)]
pub struct LstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// Whether a bias is applied in the gate transformations.
    pub bias: bool,
    /// The initializer for the gate weights. Defaults to Xavier normal.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}

/// A unidirectional, stateless LSTM module.
///
/// Should be created with [LstmConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Lstm<B: Backend> {
    /// Decides which values to update in the cell state.
    pub input_gate: GateController<B>,
    /// Decides which values to discard from the cell state.
    pub forget_gate: GateController<B>,
    /// Decides which parts of the cell state to expose as the hidden state.
    pub output_gate: GateController<B>,
    /// Produces the candidate values added to the cell state.
    pub cell_gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
}

impl<B: Backend> ModuleDisplay for Lstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self.input_gate.input_transform.weight.shape().dims();
        let bias = self.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl LstmConfig {
    /// Initializes a new [Lstm] module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {
        let d_output = self.d_hidden;

        let new_gate = || {
            GateController::new(
                self.d_input,
                d_output,
                self.bias,
                self.initializer.clone(),
                device,
            )
        };

        Lstm {
            input_gate: new_gate(),
            forget_gate: new_gate(),
            output_gate: new_gate(),
            cell_gate: new_gate(),
            d_hidden: self.d_hidden,
        }
    }
}

impl<B: Backend> Lstm<B> {
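    /// Applies the forward pass on the input tensor, returning the hidden
    /// state at every timestep along with the final [LstmState].
    ///
    /// Shapes:
    /// - batched_input: `[batch_size, sequence_length, d_input]`
    /// - output: `[batch_size, sequence_length, d_hidden]`
    /// - state: cell and hidden tensors, each `[batch_size, d_hidden]`
    ///
    /// A minimal usage sketch (the `MyBackend` alias and the dimensions below
    /// are illustrative assumptions, not part of the original source):
    ///
    /// ```rust,ignore
    /// let device = Default::default();
    /// let lstm = LstmConfig::new(2, 8, true).init::<MyBackend>(&device);
    /// let input = Tensor::<MyBackend, 3>::random([4, 10, 2], Distribution::Default, &device);
    /// // Passing `None` starts from zeroed cell and hidden states.
    /// let (output, state) = lstm.forward(input, None);
    /// assert_eq!(output.dims(), [4, 10, 8]);
    /// assert_eq!(state.hidden.dims(), [4, 8]);
    /// ```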
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 2>>,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.dims();

        self.forward_iter(
            batched_input.iter_dim(1).zip(0..seq_length),
            state,
            batch_size,
            seq_length,
            &device,
        )
    }
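
    /// Shared sequence loop used by `forward` (and by [BiLstm] for the
    /// reverse direction). At each timestep `t`, with `σ` the sigmoid and
    /// `⊙` elementwise multiplication, the loop below computes:
    ///
    ///   f_t = σ(W_f x_t + U_f h_{t-1} + b_f)    (forget gate)
    ///   i_t = σ(W_i x_t + U_i h_{t-1} + b_i)    (input gate)
    ///   o_t = σ(W_o x_t + U_o h_{t-1} + b_o)    (output gate)
    ///   g_t = tanh(W_g x_t + U_g h_{t-1} + b_g) (cell/candidate gate)
    ///   c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
    ///   h_t = o_t ⊙ tanh(c_t)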
    fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
        &self,
        input_timestep_iter: I,
        state: Option<LstmState<B, 2>>,
        batch_size: usize,
        seq_length: usize,
        device: &B::Device,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        let mut batched_hidden_state =
            Tensor::empty([batch_size, seq_length, self.d_hidden], device);

        let (mut cell_state, mut hidden_state) = match state {
            Some(state) => (state.cell, state.hidden),
            None => (
                Tensor::zeros([batch_size, self.d_hidden], device),
                Tensor::zeros([batch_size, self.d_hidden], device),
            ),
        };

        for (input_t, t) in input_timestep_iter {
            let input_t = input_t.squeeze(1);

            let biased_fg_input_sum = self
                .forget_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let forget_values = activation::sigmoid(biased_fg_input_sum);

            let biased_ig_input_sum = self
                .input_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let add_values = activation::sigmoid(biased_ig_input_sum);

            let biased_og_input_sum = self
                .output_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let output_values = activation::sigmoid(biased_og_input_sum);

            let biased_cg_input_sum = self
                .cell_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let candidate_cell_values = biased_cg_input_sum.tanh();

            cell_state = forget_values * cell_state + add_values * candidate_cell_values;
            hidden_state = output_values * cell_state.clone().tanh();

            let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);

            // Store the hidden state for this timestep.
            batched_hidden_state = batched_hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                unsqueezed_hidden_state,
            );
        }

        (
            batched_hidden_state,
            LstmState::new(cell_state, hidden_state),
        )
    }
}

/// Configuration to create a [BiLstm] module using the [init function](BiLstmConfig::init).
#[derive(Config)]
pub struct BiLstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state (per direction).
    pub d_hidden: usize,
    /// Whether a bias is applied in the gate transformations.
    pub bias: bool,
    /// The initializer for the gate weights. Defaults to Xavier normal.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}

/// A bidirectional LSTM module.
///
/// Should be created with [BiLstmConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiLstm<B: Backend> {
    /// The LSTM for the forward direction.
    pub forward: Lstm<B>,
    /// The LSTM for the reverse direction.
    pub reverse: Lstm<B>,
    /// The size of the hidden state (per direction).
    pub d_hidden: usize,
}

impl<B: Backend> ModuleDisplay for BiLstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self
            .forward
            .input_gate
            .input_transform
            .weight
            .shape()
            .dims();
        let bias = self.forward.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl BiLstmConfig {
    /// Initializes a new [BiLstm] module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> BiLstm<B> {
        BiLstm {
            forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
            reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
            d_hidden: self.d_hidden,
        }
    }
}

impl<B: Backend> BiLstm<B> {
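    /// Applies the forward pass, running the input through the forward LSTM
    /// and, in reversed timestep order, through the reverse LSTM, then
    /// concatenating the hidden states of both directions along the feature
    /// dimension.
    ///
    /// Shapes:
    /// - batched_input: `[batch_size, sequence_length, d_input]`
    /// - output: `[batch_size, sequence_length, 2 * d_hidden]`
    /// - state: cell and hidden tensors, each `[2, batch_size, d_hidden]`,
    ///   where index 0 holds the forward direction and index 1 the reverse.
    ///
    /// A minimal usage sketch (the `MyBackend` alias and the dimensions below
    /// are illustrative assumptions, not part of the original source):
    ///
    /// ```rust,ignore
    /// let device = Default::default();
    /// let bilstm = BiLstmConfig::new(2, 8, true).init::<MyBackend>(&device);
    /// let input = Tensor::<MyBackend, 3>::random([4, 10, 2], Distribution::Default, &device);
    /// let (output, state) = bilstm.forward(input, None);
    /// assert_eq!(output.dims(), [4, 10, 16]);
    /// assert_eq!(state.hidden.dims(), [2, 4, 8]);
    /// ```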
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 3>>,
    ) -> (Tensor<B, 3>, LstmState<B, 3>) {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        // Split the initial state (if any) into one state per direction.
        let [init_state_forward, init_state_reverse] = match state {
            Some(state) => {
                let cell_state_forward = state
                    .cell
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);
                let hidden_state_forward = state
                    .hidden
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);
                let cell_state_reverse = state
                    .cell
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);
                let hidden_state_reverse = state
                    .hidden
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);

                [
                    Some(LstmState::new(cell_state_forward, hidden_state_forward)),
                    Some(LstmState::new(cell_state_reverse, hidden_state_reverse)),
                ]
            }
            None => [None, None],
        };

        let (batched_hidden_state_forward, final_state_forward) = self
            .forward
            .forward(batched_input.clone(), init_state_forward);

        // Run the reverse direction by iterating the timesteps backwards.
        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
            init_state_reverse,
            batch_size,
            seq_length,
            &device,
        );

        let output = Tensor::cat(
            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
            2,
        );

        let state = LstmState::new(
            Tensor::stack(
                [final_state_forward.cell, final_state_reverse.cell].to_vec(),
                0,
            ),
            Tensor::stack(
                [final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
                0,
            ),
        );

        (output, state)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Device, Distribution, TensorData};
    use crate::{TestBackend, module::Param, nn::LinearRecord};
    use burn_tensor::{ElementConversion, Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

    #[test]
    fn test_with_uniform_initializer() {
        TestBackend::seed(0);

        let config = LstmConfig::new(5, 5, false)
            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
        let lstm = config.init::<TestBackend>(&Default::default());

        let gate_to_data =
            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();

        gate_to_data(lstm.input_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.forget_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.output_gate).assert_within_range::<FT>(0.elem()..1.elem());
        gate_to_data(lstm.cell_gate).assert_within_range::<FT>(0.elem()..1.elem());
    }

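    // The expected values in the next test can be derived by hand from the
    // gate equations documented on `forward_iter` (a sketch, assuming
    // x = 0.1, zero initial state, zero biases, and the weights set below):
    //   i = σ(0.5·0.1) ≈ 0.51250
    //   f = σ(0.7·0.1) ≈ 0.51749
    //   g = tanh(0.9·0.1) ≈ 0.08976
    //   o = σ(1.1·0.1) ≈ 0.52747
    //   c1 = f·0 + i·g ≈ 0.046
    //   h1 = o·tanh(c1) ≈ 0.0242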
    #[test]
    fn test_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let config = LstmConfig::new(1, 1, false);
        let device = Default::default();
        let mut lstm = config.init::<TestBackend>(&device);

        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        lstm.input_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.forget_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.cell_gate = create_gate_controller(
            0.9,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.output_gate = create_gate_controller(
            1.1,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );

        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);

        let (output, state) = lstm.forward(input, None);

        let expected = TensorData::from([[0.046]]);
        let tolerance = Tolerance::default();
        state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        let expected = TensorData::from([[0.0242]]);
        state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        output
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0)
            .to_data()
            .assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
    }

    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [8, 10, 1024]);
        assert_eq!(state.cell.dims(), [8, 1024]);
        assert_eq!(state.hidden.dims(), [8, 1024]);
    }

    #[test]
    fn test_batched_forward_pass_batch_of_one() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [1, 2, 1024]);
        assert_eq!(state.cell.dims(), [1, 1024]);
        assert_eq!(state.hidden.dims(), [1, 1024]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        use crate::tensor::Shape;
        let device = Default::default();
        let lstm = LstmConfig::new(64, 32, true).init(&device);
        let shape: Shape = [8, 10, 64].into();
        let batched_input =
            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);

        let (output, _) = lstm.forward(batched_input.clone(), None);
        let fake_loss = output;
        let grads = fake_loss.backward();

        let some_gradient = lstm
            .output_gate
            .hidden_transform
            .weight
            .grad(&grads)
            .unwrap();

        assert!(
            some_gradient
                .any()
                .into_data()
                .iter::<f32>()
                .next()
                .unwrap()
                != 0.0
        );
    }

    #[test]
    fn test_bidirectional() {
        TestBackend::seed(0);
        let config = BiLstmConfig::new(2, 3, true);
        let device = Default::default();
        let mut lstm = config.init(&device);

        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let d_input = input_weights[0].len();
            let d_output = input_weights.len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [0.949, -0.861],
                [0.892, 0.927],
                [-0.173, -0.301],
                [-0.081, 0.992],
            ]]),
            &device,
        );
        let h0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
            &device,
        );
        let c0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]),
            &device,
        );

        lstm.forward.input_gate = create_gate_controller(
            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
            [-0.196, 0.354, 0.209],
            [
                [-0.320, 0.232, -0.165],
                [0.093, -0.572, -0.315],
                [-0.467, 0.325, 0.046],
            ],
            [0.181, -0.190, -0.245],
            &device,
        );

        lstm.forward.forget_gate = create_gate_controller(
            [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]],
            [0.315, -0.413, -0.041],
            [
                [0.453, 0.063, 0.561],
                [0.211, 0.149, 0.213],
                [-0.499, -0.158, 0.068],
            ],
            [-0.431, -0.535, 0.125],
            &device,
        );

        lstm.forward.cell_gate = create_gate_controller(
            [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]],
            [-0.358, 0.282, -0.078],
            [
                [-0.358, 0.109, 0.139],
                [-0.345, 0.091, -0.368],
                [-0.508, 0.221, -0.507],
            ],
            [0.502, -0.509, -0.247],
            &device,
        );

        lstm.forward.output_gate = create_gate_controller(
            [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]],
            [-0.227, -0.274, 0.039],
            [
                [-0.383, 0.449, 0.222],
                [-0.357, -0.093, 0.449],
                [-0.106, 0.236, 0.360],
            ],
            [-0.361, -0.209, -0.454],
            &device,
        );

        lstm.reverse.input_gate = create_gate_controller(
            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
            [0.540, -0.164, 0.033],
            [
                [0.159, 0.180, -0.037],
                [-0.443, 0.485, -0.488],
                [0.098, -0.085, -0.140],
            ],
            [-0.510, 0.105, 0.114],
            &device,
        );

        lstm.reverse.forget_gate = create_gate_controller(
            [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]],
            [0.141, 0.004, 0.055],
            [
                [-0.005, -0.277, -0.515],
                [-0.011, -0.101, -0.365],
                [0.426, 0.379, 0.337],
            ],
            [-0.382, 0.331, -0.176],
            &device,
        );

        lstm.reverse.cell_gate = create_gate_controller(
            [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]],
            [-0.206, -0.546, 0.462],
            [
                [0.449, -0.240, 0.071],
                [-0.045, 0.131, 0.124],
                [0.138, -0.201, 0.191],
            ],
            [-0.030, 0.211, -0.352],
            &device,
        );

        lstm.reverse.output_gate = create_gate_controller(
            [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]],
            [-0.387, -0.250, 0.066],
            [
                [-0.030, 0.268, 0.299],
                [-0.019, -0.280, -0.314],
                [0.466, -0.365, -0.248],
            ],
            [-0.398, -0.199, -0.566],
            &device,
        );

        let expected_output_with_init_state = TensorData::from([[
            [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798],
            [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742],
            [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012],
            [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872],
        ]]);
        let expected_output_without_init_state = TensorData::from([[
            [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863],
            [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142],
            [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846],
            [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550],
        ]]);
        let expected_hn_with_init_state = TensorData::from([
            [[-0.03420, 0.07774, -0.09774]],
            [[-0.15635, -0.03366, -0.05798]],
        ]);
        let expected_cn_with_init_state = TensorData::from([
            [[-0.13593, 0.17125, -0.22395]],
            [[-0.45425, -0.11206, -0.12908]],
        ]);
        let expected_hn_without_init_state = TensorData::from([
            [[-0.04026, 0.07178, -0.10189]],
            [[-0.15969, -0.05322, -0.08863]],
        ]);
        let expected_cn_without_init_state = TensorData::from([
            [[-0.15839, 0.15923, -0.23569]],
            [[-0.47407, -0.17493, -0.19643]],
        ]);

        let (output_with_init_state, state_with_init_state) =
            lstm.forward(input.clone(), Some(LstmState::new(c0, h0)));
        let (output_without_init_state, state_without_init_state) = lstm.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
        output_without_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
        state_with_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
        state_with_init_state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected_cn_with_init_state, tolerance);
        state_without_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
        state_without_init_state
            .cell
            .to_data()
            .assert_approx_eq::<FT>(&expected_cn_without_init_state, tolerance);
    }

    #[test]
    fn display_lstm() {
        let config = LstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
        );
    }

    #[test]
    fn display_bilstm() {
        let config = BiLstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
        );
    }
}