1use burn_core as burn;
2
3use crate::GateController;
4use crate::activation::{Activation, ActivationConfig};
5use burn::config::Config;
6use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9
/// The state of a recurrent layer at a given timestep.
///
/// `D` is the rank of the hidden tensor: 2 (`[batch_size, d_hidden]`) for a
/// unidirectional [Rnn], 3 (`[2, batch_size, d_hidden]`) for a [BiRnn].
pub struct RnnState<B: Backend, const D: usize> {
    /// The hidden state tensor.
    pub hidden: Tensor<B, D>,
}
15
16impl<B: Backend, const D: usize> RnnState<B, D> {
17 pub fn new(hidden: Tensor<B, D>) -> Self {
19 Self { hidden }
20 }
21}
22
/// Configuration to create an [Rnn] module using the [init function](RnnConfig::init).
#[derive(Config, Debug)]
pub struct RnnConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the transformations.
    pub bias: bool,
    /// Initializer for the input and hidden weights.
    /// Defaults to Xavier normal with a gain of 1.0.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If `true` (default), tensors are laid out as `[batch, seq, feature]`;
    /// otherwise `[seq, batch, feature]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// If `true`, timesteps are processed from the last to the first.
    #[config(default = false)]
    pub reverse: bool,
    /// If set, the hidden state is clamped into `[-clip, clip]` after each step.
    pub clip: Option<f64>,
    /// Activation applied to the hidden state. Defaults to tanh.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
}
52
/// A vanilla (Elman-style) recurrent layer.
///
/// Should be created with [RnnConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Rnn<B: Backend> {
    /// The gate controller holding the input and hidden linear transformations.
    pub gate: GateController<B>,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// Whether tensors are laid out batch-first (`[batch, seq, feature]`).
    pub batch_first: bool,
    /// Whether timesteps are processed in reverse order.
    pub reverse: bool,
    /// Optional clamp applied to the hidden state after each step.
    pub clip: Option<f64>,
    /// Activation applied to the hidden state.
    pub hidden_activation: Activation<B>,
}
72
73impl<B: Backend> ModuleDisplay for Rnn<B> {
74 fn custom_settings(&self) -> Option<DisplaySettings> {
75 DisplaySettings::new()
76 .with_new_line_after_attribute(false)
77 .optional()
78 }
79
80 fn custom_content(&self, content: Content) -> Option<Content> {
81 let [d_input, _] = self.gate.input_transform.weight.shape().dims();
82 let bias = self.gate.input_transform.bias.is_some();
83
84 content
85 .add("d_input", &d_input)
86 .add("d_hidden", &self.d_hidden)
87 .add("bias", &bias)
88 .optional()
89 }
90}
91
92impl RnnConfig {
93 pub fn init<B: Backend>(&self, device: &B::Device) -> Rnn<B> {
95 let d_output = self.d_hidden;
96
97 let new_gate = || {
98 GateController::new(
99 self.d_input,
100 d_output,
101 self.bias,
102 self.initializer.clone(),
103 device,
104 )
105 };
106
107 Rnn {
108 gate: new_gate(),
109 d_hidden: self.d_hidden,
110 batch_first: self.batch_first,
111 reverse: self.reverse,
112 clip: self.clip,
113 hidden_activation: self.hidden_activation.init(device),
114 }
115 }
116}
117
118impl<B: Backend> Rnn<B> {
119 pub fn forward(
137 &self,
138 batched_input: Tensor<B, 3>,
139 state: Option<RnnState<B, 2>>,
140 ) -> (Tensor<B, 3>, RnnState<B, 2>) {
141 let batched_input = if self.batch_first {
143 batched_input
144 } else {
145 batched_input.swap_dims(0, 1)
146 };
147
148 let device = batched_input.device();
149 let [batch_size, seq_length, _] = batched_input.dims();
150
151 let (output, state) = if self.reverse {
153 self.forward_iter(
154 batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
155 state,
156 batch_size,
157 seq_length,
158 &device,
159 )
160 } else {
161 self.forward_iter(
162 batched_input.iter_dim(1).zip(0..seq_length),
163 state,
164 batch_size,
165 seq_length,
166 &device,
167 )
168 };
169
170 let output = if self.batch_first {
172 output
173 } else {
174 output.swap_dims(0, 1)
175 };
176
177 (output, state)
178 }
179
180 fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
181 &self,
182 input_timestep_iter: I,
183 state: Option<RnnState<B, 2>>,
184 batch_size: usize,
185 seq_length: usize,
186 device: &B::Device,
187 ) -> (Tensor<B, 3>, RnnState<B, 2>) {
188 let mut batched_hidden_state =
189 Tensor::empty([batch_size, seq_length, self.d_hidden], device);
190
191 let mut hidden_state = match state {
192 Some(state) => state.hidden,
193 None => Tensor::zeros([batch_size, self.d_hidden], device),
194 };
195
196 for (input_t, t) in input_timestep_iter {
197 let input_t = input_t.squeeze_dim(1);
198
199 let biased_gate_sum = self
201 .gate
202 .gate_product(input_t.clone(), hidden_state.clone());
203
204 let output_values = self.hidden_activation.forward(biased_gate_sum);
205
206 hidden_state = output_values;
208
209 if let Some(clip) = self.clip {
211 hidden_state = hidden_state.clamp(-clip, clip);
212 }
213
214 let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);
215
216 batched_hidden_state = batched_hidden_state.slice_assign(
218 [0..batch_size, t..(t + 1), 0..self.d_hidden],
219 unsqueezed_hidden_state.clone(),
220 );
221 }
222
223 (batched_hidden_state, RnnState::new(hidden_state))
224 }
225}
226
/// Configuration to create a [BiRnn] module using the [init function](BiRnnConfig::init).
#[derive(Config, Debug)]
pub struct BiRnnConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state (per direction).
    pub d_hidden: usize,
    /// If a bias should be applied during the transformations.
    pub bias: bool,
    /// Initializer for the input and hidden weights.
    /// Defaults to Xavier normal with a gain of 1.0.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If `true` (default), tensors are laid out as `[batch, seq, feature]`;
    /// otherwise `[seq, batch, feature]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// If set, the hidden state is clamped into `[-clip, clip]` after each step.
    pub clip: Option<f64>,
    /// Activation applied to the hidden state. Defaults to tanh.
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
}
249
/// A bidirectional recurrent layer built from two [Rnn] modules.
///
/// Should be created with [BiRnnConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiRnn<B: Backend> {
    /// The RNN processing the sequence in the forward direction.
    pub forward: Rnn<B>,
    /// The RNN processing the sequence in the reverse direction.
    pub reverse: Rnn<B>,
    /// The size of the hidden state per direction.
    pub d_hidden: usize,
    /// Whether tensors are laid out batch-first (`[batch, seq, feature]`).
    pub batch_first: bool,
}
265
266impl<B: Backend> ModuleDisplay for BiRnn<B> {
267 fn custom_settings(&self) -> Option<DisplaySettings> {
268 DisplaySettings::new()
269 .with_new_line_after_attribute(false)
270 .optional()
271 }
272
273 fn custom_content(&self, content: Content) -> Option<Content> {
274 let [d_input, _] = self.forward.gate.input_transform.weight.shape().dims();
275 let bias = self.forward.gate.input_transform.bias.is_some();
276
277 content
278 .add("d_input", &d_input)
279 .add("d_hidden", &self.d_hidden)
280 .add("bias", &bias)
281 .optional()
282 }
283}
284
285impl BiRnnConfig {
286 pub fn init<B: Backend>(&self, device: &B::Device) -> BiRnn<B> {
288 let base_config = RnnConfig::new(self.d_input, self.d_hidden, self.bias)
290 .with_initializer(self.initializer.clone())
291 .with_batch_first(true)
292 .with_clip(self.clip)
293 .with_hidden_activation(self.hidden_activation.clone());
294
295 BiRnn {
296 forward: base_config.clone().init(device),
297 reverse: base_config.init(device),
298 d_hidden: self.d_hidden,
299 batch_first: self.batch_first,
300 }
301 }
302}
303
304impl<B: Backend> BiRnn<B> {
305 pub fn forward(
323 &self,
324 batched_input: Tensor<B, 3>,
325 state: Option<RnnState<B, 3>>,
326 ) -> (Tensor<B, 3>, RnnState<B, 3>) {
327 let batched_input = if self.batch_first {
329 batched_input
330 } else {
331 batched_input.swap_dims(0, 1)
332 };
333
334 let device = batched_input.clone().device();
335 let [batch_size, seq_length, _] = batched_input.shape().dims();
336
337 let [init_state_forward, init_state_reverse] = match state {
338 Some(state) => {
339 let hidden_state_forward = state
340 .hidden
341 .clone()
342 .slice([0..1, 0..batch_size, 0..self.d_hidden])
343 .squeeze_dim(0);
344 let hidden_state_reverse = state
345 .hidden
346 .slice([1..2, 0..batch_size, 0..self.d_hidden])
347 .squeeze_dim(0);
348
349 [
350 Some(RnnState::new(hidden_state_forward)),
351 Some(RnnState::new(hidden_state_reverse)),
352 ]
353 }
354 None => [None, None],
355 };
356
357 let (batched_hidden_state_forward, final_state_forward) = self
359 .forward
360 .forward(batched_input.clone(), init_state_forward);
361
362 let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
364 batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
365 init_state_reverse,
366 batch_size,
367 seq_length,
368 &device,
369 );
370
371 let output = Tensor::cat(
372 [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
373 2,
374 );
375
376 let output = if self.batch_first {
378 output
379 } else {
380 output.swap_dims(0, 1)
381 };
382
383 let state = RnnState::new(Tensor::stack(
384 [final_state_forward.hidden, final_state_reverse.hidden].to_vec(),
385 0,
386 ));
387
388 (output, state)
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LinearRecord, TestBackend};
    use burn::module::Param;
    use burn::tensor::{Device, Distribution, TensorData};
    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

    /// Builds a 1x1 gate controller whose input and hidden transforms share the
    /// same scalar weight and bias, so expectations can be computed by hand.
    fn create_single_feature_gate_controller(
        weights: f32,
        biases: f32,
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        device: &Device<TestBackend>,
    ) -> GateController<TestBackend> {
        // record_1 -> input transform, record_2 -> hidden transform.
        let record_1 = LinearRecord {
            weight: Param::from_data(TensorData::from([[weights]]), device),
            bias: Some(Param::from_data(TensorData::from([biases]), device)),
        };
        let record_2 = LinearRecord {
            weight: Param::from_data(TensorData::from([[weights]]), device),
            bias: Some(Param::from_data(TensorData::from([biases]), device)),
        };
        GateController::create_with_weights(
            d_input,
            d_output,
            bias,
            initializer,
            record_1,
            record_2,
        )
    }

    /// A uniform initializer must produce weights inside its [min, max) range.
    #[test]
    fn test_with_uniform_initializer() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = RnnConfig::new(5, 5, false)
            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });
        let rnn = config.init::<TestBackend>(&Default::default());

        let gate_to_data =
            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();

        gate_to_data(rnn.gate).assert_within_range::<FT>(0.elem()..1.elem());
    }

    /// One timestep, one feature: with weight 0.5 and no bias the hidden state
    /// is tanh(0.5 * 0.1) ~= 0.04996, and the single output row equals it.
    #[test]
    fn test_forward_single_input_single_feature() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = RnnConfig::new(1, 1, false);
        let device = Default::default();
        let mut rnn = config.init::<TestBackend>(&device);

        rnn.gate = create_single_feature_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );

        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);

        let (output, state) = rnn.forward(input, None);

        let tolerance = Tolerance::default();
        let expected = TensorData::from([[0.04995]]);
        state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        // The only timestep's output must match the final hidden state.
        output
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0)
            .to_data()
            .assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);
    }

    /// Shape-only check: batch of 1, seq of 2 through a 64 -> 1024 RNN.
    #[test]
    fn test_batched_forward_pass_batch_of_one() {
        let device = Default::default();
        let rnn = RnnConfig::new(64, 1024, true).init(&device);
        let batched_input =
            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);

        let (output, state) = rnn.forward(batched_input, None);
        assert_eq!(output.dims(), [1, 2, 1024]);
        assert_eq!(state.hidden.dims(), [1, 1024]);
    }

    /// Gradients must flow back to the hidden-transform weights.
    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        use burn::tensor::Shape;
        let device = Default::default();
        let rnn = RnnConfig::new(64, 32, true).init(&device);
        let shape: Shape = [8, 10, 64].into();
        let batched_input =
            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);

        let (output, _) = rnn.forward(batched_input.clone(), None);
        let fake_loss = output;
        let grads = fake_loss.backward();

        let some_gradient = rnn.gate.hidden_transform.weight.grad(&grads).unwrap();

        // At least one gradient entry must be non-zero.
        assert_ne!(
            some_gradient
                .any()
                .into_data()
                .iter::<f32>()
                .next()
                .unwrap(),
            0.0
        );
    }

    /// Checks the BiRnn against externally computed reference values, both
    /// with and without a provided initial state.
    #[test]
    fn test_bidirectional() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = BiRnnConfig::new(2, 3, true);
        let mut rnn = config.init(&device);

        // Builds a gate controller from explicit weight/bias matrices.
        fn create_gate_controller<const D1: usize, const D2: usize>(
            input_weights: [[f32; D1]; D2],
            input_biases: [f32; D1],
            hidden_weights: [[f32; D1]; D1],
            hidden_biases: [f32; D1],
            device: &Device<TestBackend>,
        ) -> GateController<TestBackend> {
            let d_input = input_weights[0].len();
            let d_output = input_weights.len();

            let input_record = LinearRecord {
                weight: Param::from_data(TensorData::from(input_weights), device),
                bias: Some(Param::from_data(TensorData::from(input_biases), device)),
            };
            let hidden_record = LinearRecord {
                weight: Param::from_data(TensorData::from(hidden_weights), device),
                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
            };
            GateController::create_with_weights(
                d_input,
                d_output,
                true,
                Initializer::XavierUniform { gain: 1.0 },
                input_record,
                hidden_record,
            )
        }

        // Batch of one, four timesteps, two features per step.
        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [0.949, -0.861],
                [0.892, 0.927],
                [-0.173, -0.301],
                [-0.081, 0.992],
            ]]),
            &device,
        );

        // Initial state stacked as [direction, batch, hidden].
        let h0 = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
            &device,
        );

        rnn.forward.gate = create_gate_controller(
            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],
            [-0.196, 0.354, 0.209],
            [
                [-0.320, 0.232, -0.165],
                [0.093, -0.572, -0.315],
                [-0.467, 0.325, 0.046],
            ],
            [0.181, -0.190, -0.245],
            &device,
        );

        rnn.reverse.gate = create_gate_controller(
            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],
            [0.540, -0.164, 0.033],
            [
                [0.159, 0.180, -0.037],
                [-0.443, 0.485, -0.488],
                [0.098, -0.085, -0.140],
            ],
            [-0.510, 0.105, 0.114],
            &device,
        );

        // Reference outputs: forward-direction features in columns 0..3,
        // reverse-direction features in columns 3..6.
        let expected_output_with_init_state = TensorData::from([[
            [0.5226, -0.6370, 0.0210, 0.0685, 0.3867, 0.3602],
            [0.3580, 0.8431, 0.4129, -0.3175, 0.4374, 0.1766],
            [-0.3837, -0.2703, -0.3957, -0.1542, -0.1122, 0.0725],
            [0.5059, 0.5527, 0.1244, -0.6779, 0.3725, -0.3387],
        ]]);
        let expected_output_without_init_state = TensorData::from([[
            [0.0560, -0.2056, 0.2334, 0.0892, 0.3912, 0.3607],
            [0.4340, 0.7378, 0.3714, -0.2394, 0.4235, 0.2002],
            [-0.3962, -0.2097, -0.3798, 0.0532, -0.2067, 0.1727],
            [0.5075, 0.5298, 0.1083, -0.3200, 0.0764, -0.1282],
        ]]);

        // Final states: forward ends at the last timestep, reverse at the first.
        let expected_hn_with_init_state =
            TensorData::from([[[0.5059, 0.5527, 0.1244]], [[0.0685, 0.3867, 0.3602]]]);
        let expected_hn_without_init_state =
            TensorData::from([[[0.5075, 0.5298, 0.1083]], [[0.0892, 0.3912, 0.3607]]]);

        let (output_with_init_state, state_with_init_state) =
            rnn.forward(input.clone(), Some(RnnState::new(h0)));
        let (output_without_init_state, state_without_init_state) = rnn.forward(input, None);

        let tolerance = Tolerance::permissive();
        output_with_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);
        output_without_init_state
            .to_data()
            .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);
        state_with_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);
        state_without_init_state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);
    }

    /// Custom display must show the layer's dimensions and parameter count.
    #[test]
    fn display_rnn() {
        let config = RnnConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "Rnn {d_input: 2, d_hidden: 3, bias: true, params: 21}"
        );
    }

    /// Same as `display_rnn` for the bidirectional wrapper (twice the params).
    #[test]
    fn display_birnn() {
        let config = BiRnnConfig::new(2, 3, true);

        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "BiRnn {d_input: 2, d_hidden: 3, bias: true, params: 42}"
        );
    }

    /// With clipping enabled, every final hidden value must lie in [-clip, clip].
    #[test]
    fn test_rnn_clipping() {
        let device = Default::default();

        let clip_value = 0.3;
        let config = RnnConfig::new(4, 8, true).with_clip(Some(clip_value));
        let rnn = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);
        let (_, state) = rnn.forward(input, None);

        let hidden_state: Vec<f32> = state.hidden.to_data().to_vec().unwrap();
        for val in hidden_state {
            assert!(
                val >= -clip_value as f32 && val <= clip_value as f32,
                "Value {} is outside clip range [-{}, {}]",
                val,
                clip_value,
                clip_value
            );
        }
    }

    /// With `reverse = true` and weight 0.5 over [0.1, 0.2, 0.3] the recurrence
    /// runs t=2 -> t=0: tanh(0.15) -> tanh(0.5*0.2 + 0.5*h) -> tanh(0.5*0.1 + 0.5*h),
    /// ending at ~0.135508.
    #[test]
    fn test_forward_reverse_sequence() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = RnnConfig::new(1, 1, false).with_reverse(true);
        let mut rnn = config.init::<TestBackend>(&device);

        rnn.gate = create_single_feature_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierUniform { gain: 1.0 },
            &device,
        );

        let input =
            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);

        let (output, state) = rnn.forward(input, None);

        let expected_final_hidden = TensorData::from([[0.135508]]);

        let tolerance = Tolerance::default();
        state
            .hidden
            .to_data()
            .assert_approx_eq::<FT>(&expected_final_hidden, tolerance);

        // Output keeps the original timestep ordering and shape.
        assert_eq!(output.dims(), [1, 3, 1]);
    }
}