// burn_nn/modules/rnn/gate_controller.rs
1use burn_core as burn;
2
3use crate::{Linear, LinearConfig, LinearLayout};
4use burn::module::{Initializer, Module};
5use burn::tensor::{Tensor, backend::Backend};
6
/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate,
/// forget gate, and output gate. Additionally, cell gate
/// is just used to compute the cell state.
///
/// An Lstm gate is modeled as two linear transformations.
/// The results of these transformations are used to calculate
/// the gate's output: one transformation is applied to the
/// input vector and one to the previous hidden state.
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
    /// Represents the affine transformation applied to input vector
    pub input_transform: Linear<B>,
    /// Represents the affine transformation applied to the hidden state
    pub hidden_transform: Linear<B>,
}
22
23impl<B: Backend> GateController<B> {
24 /// Initialize a new [gate_controller](GateController) module.
25 pub fn new(
26 d_input: usize,
27 d_output: usize,
28 bias: bool,
29 initializer: Initializer,
30 device: &B::Device,
31 ) -> Self {
32 Self {
33 input_transform: LinearConfig {
34 d_input,
35 d_output,
36 bias,
37 initializer: initializer.clone(),
38 layout: LinearLayout::Row,
39 }
40 .init(device),
41 hidden_transform: LinearConfig {
42 d_input: d_output,
43 d_output,
44 bias,
45 initializer,
46 layout: LinearLayout::Row,
47 }
48 .init(device),
49 }
50 }
51
52 /// Helper function for performing weighted matrix product for a gate and adds
53 /// bias, if any.
54 ///
55 /// Mathematically, performs `Wx*X + Wh*H + b`, where:
56 /// Wx = weight matrix for the connection to input vector X
57 /// Wh = weight matrix for the connection to hidden state H
58 /// X = input vector
59 /// H = hidden state
60 /// b = bias terms
61 pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
62 self.input_transform.forward(input) + self.hidden_transform.forward(hidden)
63 }
64
65 /// Used to initialize a gate controller with known weight layers,
66 /// allowing for predictable behavior. Used only for testing in
67 /// lstm.
68 #[cfg(test)]
69 pub fn create_with_weights(
70 d_input: usize,
71 d_output: usize,
72 bias: bool,
73 initializer: Initializer,
74 input_record: crate::LinearRecord<B>,
75 hidden_record: crate::LinearRecord<B>,
76 ) -> Self {
77 let l1 = LinearConfig {
78 d_input,
79 d_output,
80 bias,
81 initializer: initializer.clone(),
82 layout: LinearLayout::Row,
83 }
84 .init(&input_record.weight.device())
85 .load_record(input_record);
86 let l2 = LinearConfig {
87 d_input,
88 d_output,
89 bias,
90 initializer,
91 layout: LinearLayout::Row,
92 }
93 .init(&hidden_record.weight.device())
94 .load_record(hidden_record);
95
96 Self {
97 input_transform: l1,
98 hidden_transform: l2,
99 }
100 }
101}