use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::rnn::gate_controller::GateController;
use crate::nn::Initializer;
use crate::tensor::activation;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

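/// The state of an LSTM: one tensor for the cell state and one for the hidden
/// state, each of shape `[batch_size, d_hidden]` for [Lstm] (with an extra
/// leading direction dimension for [BiLstm]).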
pub struct LstmState<B: Backend, const D: usize> {
    /// The cell state.
    pub cell: Tensor<B, D>,
    /// The hidden state.
    pub hidden: Tensor<B, D>,
}

impl<B: Backend, const D: usize> LstmState<B, D> {
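    /// Initialize a new [LstmState] from the given cell and hidden state tensors.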
    pub fn new(cell: Tensor<B, D>, hidden: Tensor<B, D>) -> Self {
        Self { cell, hidden }
    }
}

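/// Configuration to create an [Lstm] module using the [init function](LstmConfig::init).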
#[derive(Config)]
pub struct LstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Lstm transformation.
    pub bias: bool,
    /// The initializer used for the gate weights. Defaults to Xavier normal.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}

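/// The Lstm module. This implementation is for a unidirectional, stateless, Lstm.
///
/// Should be created with [LstmConfig]. A minimal usage sketch, assuming a
/// backend `B` with a default device (names here are illustrative):
///
/// ```ignore
/// let device = Default::default();
/// let lstm = LstmConfig::new(64, 128, true).init::<B>(&device);
/// // Input shape: [batch_size, sequence_length, d_input].
/// let input = Tensor::<B, 3>::zeros([8, 10, 64], &device);
/// // Output shape: [batch_size, sequence_length, d_hidden].
/// let (output, state) = lstm.forward(input, None);
/// ```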
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Lstm<B: Backend> {
    /// The input gate.
    pub input_gate: GateController<B>,
    /// The forget gate.
    pub forget_gate: GateController<B>,
    /// The output gate.
    pub output_gate: GateController<B>,
    /// The cell gate.
    pub cell_gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
}

impl<B: Backend> ModuleDisplay for Lstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self.input_gate.input_transform.weight.shape().dims();
        let bias = self.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl LstmConfig {
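    /// Initialize a new [lstm](Lstm) module.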
    pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {
        let d_output = self.d_hidden;

        let new_gate = || {
            GateController::new(
                self.d_input,
                d_output,
                self.bias,
                self.initializer.clone(),
                device,
            )
        };

        Lstm {
            input_gate: new_gate(),
            forget_gate: new_gate(),
            output_gate: new_gate(),
            cell_gate: new_gate(),
            d_hidden: self.d_hidden,
        }
    }
}

impl<B: Backend> Lstm<B> {
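    /// Applies the forward pass on the input tensor. This LSTM implementation
    /// returns the hidden state for each element in the sequence (i.e., across
    /// `seq_length`) as well as the final state.
    ///
    /// ## Parameters:
    /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`.
    /// - state: An optional [LstmState] with cell and hidden state tensors of
    ///   shape `[batch_size, hidden_size]`. Initialized to zeros if not provided.
    ///
    /// ## Returns:
    /// - output: A tensor of shape `[batch_size, sequence_length, hidden_size]`
    ///   holding the hidden state at every timestep.
    /// - state: The final [LstmState], with cell and hidden state tensors of
    ///   shape `[batch_size, hidden_size]`.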
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 2>>,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.dims();

        self.forward_iter(
            batched_input.iter_dim(1).zip(0..seq_length),
            state,
            batch_size,
            seq_length,
            &device,
        )
    }

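    /// Shared timestep loop behind [Lstm::forward] and [BiLstm::forward]: it
    /// consumes an iterator of `(timestep_tensor, timestep_index)` pairs, so the
    /// same code can walk a sequence either forward or in reverse.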
    fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
        &self,
        input_timestep_iter: I,
        state: Option<LstmState<B, 2>>,
        batch_size: usize,
        seq_length: usize,
        device: &B::Device,
    ) -> (Tensor<B, 3>, LstmState<B, 2>) {
        let mut batched_hidden_state =
            Tensor::empty([batch_size, seq_length, self.d_hidden], device);

        let (mut cell_state, mut hidden_state) = match state {
            Some(state) => (state.cell, state.hidden),
            None => (
                Tensor::zeros([batch_size, self.d_hidden], device),
                Tensor::zeros([batch_size, self.d_hidden], device),
            ),
        };

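        // Standard LSTM cell update at each timestep t:
        //   f_t = sigmoid(W_f x_t + U_f h_{t-1} + b_f)    (forget gate)
        //   i_t = sigmoid(W_i x_t + U_i h_{t-1} + b_i)    (input gate)
        //   o_t = sigmoid(W_o x_t + U_o h_{t-1} + b_o)    (output gate)
        //   g_t = tanh(W_g x_t + U_g h_{t-1} + b_g)       (cell/candidate gate)
        //   c_t = f_t * c_{t-1} + i_t * g_t
        //   h_t = o_t * tanh(c_t)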
        for (input_t, t) in input_timestep_iter {
            let input_t = input_t.squeeze(1);

            // Forget gate values.
            let biased_fg_input_sum = self
                .forget_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let forget_values = activation::sigmoid(biased_fg_input_sum);

            // Input gate values.
            let biased_ig_input_sum = self
                .input_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let add_values = activation::sigmoid(biased_ig_input_sum);

            // Output gate values.
            let biased_og_input_sum = self
                .output_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let output_values = activation::sigmoid(biased_og_input_sum);

            // Candidate cell values.
            let biased_cg_input_sum = self
                .cell_gate
                .gate_product(input_t.clone(), hidden_state.clone());
            let candidate_cell_values = biased_cg_input_sum.tanh();

            cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values;
            hidden_state = output_values * cell_state.clone().tanh();

            let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);

            // Store the hidden state for this timestep.
            batched_hidden_state = batched_hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                unsqueezed_hidden_state,
            );
        }

        (
            batched_hidden_state,
            LstmState::new(cell_state, hidden_state),
        )
    }
}

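/// Configuration to create a [BiLstm] module using the [init function](BiLstmConfig::init).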
#[derive(Config)]
pub struct BiLstmConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the BiLstm transformation.
    pub bias: bool,
    /// The initializer used for the gate weights. Defaults to Xavier normal.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}

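/// The BiLstm module. This implementation is for a bidirectional, stateless, Lstm:
/// one [Lstm] walks the sequence forward, an independent [Lstm] walks it in
/// reverse, and their hidden states are concatenated along the feature axis.
///
/// Should be created with [BiLstmConfig].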
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiLstm<B: Backend> {
    /// The Lstm for the forward direction.
    pub forward: Lstm<B>,
    /// The Lstm for the reverse direction.
    pub reverse: Lstm<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
}

impl<B: Backend> ModuleDisplay for BiLstm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_input, _] = self
            .forward
            .input_gate
            .input_transform
            .weight
            .shape()
            .dims();
        let bias = self.forward.input_gate.input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &bias)
            .optional()
    }
}

impl BiLstmConfig {
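    /// Initialize a new [Bidirectional LSTM](BiLstm) module.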
    pub fn init<B: Backend>(&self, device: &B::Device) -> BiLstm<B> {
        BiLstm {
            forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
            reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
            d_hidden: self.d_hidden,
        }
    }
}

impl<B: Backend> BiLstm<B> {
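    /// Applies the forward pass on the input tensor. This bidirectional LSTM
    /// implementation returns the hidden state for each element in the sequence
    /// (i.e., across `seq_length`) as well as the final state.
    ///
    /// ## Parameters:
    /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`.
    /// - state: An optional [LstmState] with cell and hidden state tensors of
    ///   shape `[2, batch_size, hidden_size]`, where index 0 along the first axis
    ///   holds the forward direction and index 1 the reverse direction.
    ///   Initialized to zeros if not provided.
    ///
    /// ## Returns:
    /// - output: A tensor of shape `[batch_size, sequence_length, 2 * hidden_size]`,
    ///   the forward and reverse hidden states concatenated along the feature axis.
    /// - state: The final [LstmState], with cell and hidden state tensors of
    ///   shape `[2, batch_size, hidden_size]`.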
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<LstmState<B, 3>>,
    ) -> (Tensor<B, 3>, LstmState<B, 3>) {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        // Split the stacked `[2, batch_size, d_hidden]` state tensors into the
        // per-direction `[batch_size, d_hidden]` states expected by each Lstm.
        let [init_state_forward, init_state_reverse] = match state {
            Some(state) => {
                let cell_state_forward = state
                    .cell
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);
                let hidden_state_forward = state
                    .hidden
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);
                let cell_state_reverse = state
                    .cell
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);
                let hidden_state_reverse = state
                    .hidden
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze(0);

                [
                    Some(LstmState::new(cell_state_forward, hidden_state_forward)),
                    Some(LstmState::new(cell_state_reverse, hidden_state_reverse)),
                ]
            }
            None => [None, None],
        };

        let (batched_hidden_state_forward, final_state_forward) = self
            .forward
            .forward(batched_input.clone(), init_state_forward);

        // Walk the sequence in reverse for the reverse-direction Lstm.
        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
            init_state_reverse,
            batch_size,
            seq_length,
            &device,
        );

        // Concatenate the two directions along the feature axis.
        let output = Tensor::cat(
            vec![batched_hidden_state_forward, batched_hidden_state_reverse],
            2,
        );

        let state = LstmState::new(
            Tensor::stack(vec![final_state_forward.cell, final_state_reverse.cell], 0),
            Tensor::stack(
                vec![final_state_forward.hidden, final_state_reverse.hidden],
                0,
            ),
        );

        (output, state)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Device, Distribution, TensorData};
    use crate::{module::Param, nn::LinearRecord, TestBackend};

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

    #[test]
    fn test_with_uniform_initializer() {
        TestBackend::seed(0);

        let config = LstmConfig::new(5, 5, false)
            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
        let lstm = config.init::<TestBackend>(&Default::default());

        let gate_to_data =
            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();

        gate_to_data(lstm.input_gate).assert_within_range(0..1);
        gate_to_data(lstm.forget_gate).assert_within_range(0..1);
        gate_to_data(lstm.output_gate).assert_within_range(0..1);
        gate_to_data(lstm.cell_gate).assert_within_range(0..1);
    }

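    /// Test forward pass with a simple scalar input.
    ///
    /// All gates use fixed scalar weights and no bias, so the expected cell and
    /// hidden values below can be checked by hand with the LSTM update equations.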
    #[test]
    fn test_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let config = LstmConfig::new(1, 1, false);
        let device = Default::default();
        let mut lstm = config.init::<TestBackend>(&device);

        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let record_1 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            let record_2 = LinearRecord {
                weight: Param::from_data(TensorData::from([[weights]]), device),
                bias: Some(Param::from_data(TensorData::from([biases]), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record_1,
                record_2,
            )
        }

        lstm.input_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.forget_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.cell_gate = create_gate_controller(
            0.9,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );
        lstm.output_gate = create_gate_controller(
            1.1,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );

        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);

        let (output, state) = lstm.forward(input, None);

        let expected = TensorData::from([[0.046]]);
        state.cell.to_data().assert_approx_eq(&expected, 3);

        let expected = TensorData::from([[0.024]]);
        state.hidden.to_data().assert_approx_eq(&expected, 3);

        output
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0)
            .to_data()
            .assert_approx_eq(&state.hidden.to_data(), 3);
    }

    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [8, 10, 1024]);
        assert_eq!(state.cell.dims(), [8, 1024]);
        assert_eq!(state.hidden.dims(), [8, 1024]);
    }

    #[test]
    fn test_batched_forward_pass_batch_of_one() {
        let device = Default::default();
        let lstm = LstmConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);

        let (output, state) = lstm.forward(batched_input, None);

        assert_eq!(output.dims(), [1, 2, 1024]);
        assert_eq!(state.cell.dims(), [1, 1024]);
        assert_eq!(state.hidden.dims(), [1, 1024]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        use crate::tensor::Shape;
        let device = Default::default();
        let lstm = LstmConfig::new(64, 32, true).init(&device);
        let shape: Shape = [8, 10, 64].into();
        let batched_input =
            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);

        let (output, _) = lstm.forward(batched_input.clone(), None);
        let fake_loss = output;
        let grads = fake_loss.backward();

        let some_gradient = lstm
            .output_gate
            .hidden_transform
            .weight
            .grad(&grads)
            .unwrap();

        // At least one gradient entry should be nonzero.
        assert!(
            some_gradient
                .any()
                .into_data()
                .iter::<f32>()
                .next()
                .unwrap()
                != 0.0
        );
    }

    #[test]
    fn test_bidirectional() {
        TestBackend::seed(0);
        let config = BiLstmConfig::new(2, 3, true);
        let device = Default::default();
        let mut lstm = config.init(&device);

        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            // The weight tensors are laid out as [d_input, d_output], so the row
            // count is the input size and the column count is the output size.
            let d_input = input_weights.len();
            let d_output = input_weights[0].len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [0.949, -0.861],
                [0.892, 0.927],
                [-0.173, -0.301],
                [-0.081, 0.992],
            ]]),
            &device,
        );
        let h0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
            &device,
        );
        let c0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]),
            &device,
        );

        lstm.forward.input_gate = create_gate_controller(
            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
            [-0.196, 0.354, 0.209],
            [
                [-0.320, 0.232, -0.165],
                [0.093, -0.572, -0.315],
                [-0.467, 0.325, 0.046],
            ],
            [0.181, -0.190, -0.245],
            &device,
        );

        lstm.forward.forget_gate = create_gate_controller(
            [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]],
            [0.315, -0.413, -0.041],
            [
                [0.453, 0.063, 0.561],
                [0.211, 0.149, 0.213],
                [-0.499, -0.158, 0.068],
            ],
            [-0.431, -0.535, 0.125],
            &device,
        );

        lstm.forward.cell_gate = create_gate_controller(
            [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]],
            [-0.358, 0.282, -0.078],
            [
                [-0.358, 0.109, 0.139],
                [-0.345, 0.091, -0.368],
                [-0.508, 0.221, -0.507],
            ],
            [0.502, -0.509, -0.247],
            &device,
        );

        lstm.forward.output_gate = create_gate_controller(
            [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]],
            [-0.227, -0.274, 0.039],
            [
                [-0.383, 0.449, 0.222],
                [-0.357, -0.093, 0.449],
                [-0.106, 0.236, 0.360],
            ],
            [-0.361, -0.209, -0.454],
            &device,
        );

        lstm.reverse.input_gate = create_gate_controller(
            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
            [0.540, -0.164, 0.033],
            [
                [0.159, 0.180, -0.037],
                [-0.443, 0.485, -0.488],
                [0.098, -0.085, -0.140],
            ],
            [-0.510, 0.105, 0.114],
            &device,
        );

        lstm.reverse.forget_gate = create_gate_controller(
            [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]],
            [0.141, 0.004, 0.055],
            [
                [-0.005, -0.277, -0.515],
                [-0.011, -0.101, -0.365],
                [0.426, 0.379, 0.337],
            ],
            [-0.382, 0.331, -0.176],
            &device,
        );

        lstm.reverse.cell_gate = create_gate_controller(
            [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]],
            [-0.206, -0.546, 0.462],
            [
                [0.449, -0.240, 0.071],
                [-0.045, 0.131, 0.124],
                [0.138, -0.201, 0.191],
            ],
            [-0.030, 0.211, -0.352],
            &device,
        );

        lstm.reverse.output_gate = create_gate_controller(
            [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]],
            [-0.387, -0.250, 0.066],
            [
                [-0.030, 0.268, 0.299],
                [-0.019, -0.280, -0.314],
                [0.466, -0.365, -0.248],
            ],
            [-0.398, -0.199, -0.566],
            &device,
        );

        let expected_output_with_init_state = TensorData::from([[
            [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798],
            [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742],
            [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012],
            [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872],
        ]]);
        let expected_output_without_init_state = TensorData::from([[
            [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863],
            [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142],
            [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846],
            [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550],
        ]]);
        let expected_hn_with_init_state = TensorData::from([
            [[-0.03420, 0.07774, -0.09774]],
            [[-0.15635, -0.03366, -0.05798]],
        ]);
        let expected_cn_with_init_state = TensorData::from([
            [[-0.13593, 0.17125, -0.22395]],
            [[-0.45425, -0.11206, -0.12908]],
        ]);
        let expected_hn_without_init_state = TensorData::from([
            [[-0.04026, 0.07178, -0.10189]],
            [[-0.15969, -0.05322, -0.08863]],
        ]);
        let expected_cn_without_init_state = TensorData::from([
            [[-0.15839, 0.15923, -0.23569]],
            [[-0.47407, -0.17493, -0.19643]],
        ]);

        let (output_with_init_state, state_with_init_state) =
            lstm.forward(input.clone(), Some(LstmState::new(c0, h0)));
        let (output_without_init_state, state_without_init_state) = lstm.forward(input, None);

        output_with_init_state
            .to_data()
            .assert_approx_eq(&expected_output_with_init_state, 3);
        output_without_init_state
            .to_data()
            .assert_approx_eq(&expected_output_without_init_state, 3);
        state_with_init_state
            .hidden
            .to_data()
            .assert_approx_eq(&expected_hn_with_init_state, 3);
        state_with_init_state
            .cell
            .to_data()
            .assert_approx_eq(&expected_cn_with_init_state, 3);
        state_without_init_state
            .hidden
            .to_data()
            .assert_approx_eq(&expected_hn_without_init_state, 3);
        state_without_init_state
            .cell
            .to_data()
            .assert_approx_eq(&expected_cn_without_init_state, 3);
    }

    #[test]
    fn display_lstm() {
        let config = LstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", layer),
            "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
        );
    }

    #[test]
    fn display_bilstm() {
        let config = BiLstmConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", layer),
            "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
        );
    }
}