burn_core/nn/rnn/gate_controller.rs
1use crate as burn;
2
3use crate::module::Module;
4use crate::nn::{Initializer, Linear, LinearConfig};
5use crate::tensor::{Tensor, backend::Backend};
6
7/// A GateController represents a gate in an LSTM cell. An
8/// LSTM cell generally contains three gates: an input gate,
9/// forget gate, and output gate. Additionally, cell gate
10/// is just used to compute the cell state.
11///
12/// An Lstm gate is modeled as two linear transformations.
13/// The results of these transformations are used to calculate
14/// the gate's output.
15#[derive(Module, Debug)]
16pub struct GateController<B: Backend> {
17 /// Represents the affine transformation applied to input vector
18 pub input_transform: Linear<B>,
19 /// Represents the affine transformation applied to the hidden state
20 pub hidden_transform: Linear<B>,
21}
22
23impl<B: Backend> GateController<B> {
24 /// Initialize a new [gate_controller](GateController) module.
25 pub fn new(
26 d_input: usize,
27 d_output: usize,
28 bias: bool,
29 initializer: Initializer,
30 device: &B::Device,
31 ) -> Self {
32 Self {
33 input_transform: LinearConfig {
34 d_input,
35 d_output,
36 bias,
37 initializer: initializer.clone(),
38 }
39 .init(device),
40 hidden_transform: LinearConfig {
41 d_input: d_output,
42 d_output,
43 bias,
44 initializer,
45 }
46 .init(device),
47 }
48 }
49
50 /// Helper function for performing weighted matrix product for a gate and adds
51 /// bias, if any.
52 ///
53 /// Mathematically, performs `Wx*X + Wh*H + b`, where:
54 /// Wx = weight matrix for the connection to input vector X
55 /// Wh = weight matrix for the connection to hidden state H
56 /// X = input vector
57 /// H = hidden state
58 /// b = bias terms
59 pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
60 self.input_transform.forward(input) + self.hidden_transform.forward(hidden)
61 }
62
63 /// Used to initialize a gate controller with known weight layers,
64 /// allowing for predictable behavior. Used only for testing in
65 /// lstm.
66 #[cfg(test)]
67 pub fn create_with_weights(
68 d_input: usize,
69 d_output: usize,
70 bias: bool,
71 initializer: Initializer,
72 input_record: crate::nn::LinearRecord<B>,
73 hidden_record: crate::nn::LinearRecord<B>,
74 ) -> Self {
75 let l1 = LinearConfig {
76 d_input,
77 d_output,
78 bias,
79 initializer: initializer.clone(),
80 }
81 .init(&input_record.weight.device())
82 .load_record(input_record);
83 let l2 = LinearConfig {
84 d_input,
85 d_output,
86 bias,
87 initializer,
88 }
89 .init(&hidden_record.weight.device())
90 .load_record(hidden_record);
91
92 Self {
93 input_transform: l1,
94 hidden_transform: l2,
95 }
96 }
97}