burn_core/nn/rnn/gate_controller.rs

use crate as burn;

use crate::module::Module;
use crate::nn::{Initializer, Linear, LinearConfig};
use crate::tensor::{Tensor, backend::Backend};

/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate, a
/// forget gate, and an output gate. Additionally, a cell gate
/// is used to compute the candidate cell state.
///
/// An LSTM gate is modeled as two linear transformations.
/// The results of these transformations are used to calculate
/// the gate's output.
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
    /// Represents the affine transformation applied to the input vector.
    pub input_transform: Linear<B>,
    /// Represents the affine transformation applied to the hidden state.
    pub hidden_transform: Linear<B>,
}

impl<B: Backend> GateController<B> {
    /// Initialize a new [gate_controller](GateController) module.
    pub fn new(
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        device: &B::Device,
    ) -> Self {
        Self {
            input_transform: LinearConfig {
                d_input,
                d_output,
                bias,
                initializer: initializer.clone(),
            }
            .init(device),
            // The hidden state has the same dimensionality as the gate's
            // output, so the hidden transform maps `d_output -> d_output`.
            hidden_transform: LinearConfig {
                d_input: d_output,
                d_output,
                bias,
                initializer,
            }
            .init(device),
        }
    }
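
    // Usage sketch (illustration only; the names below are assumptions, not
    // taken from this file): an LSTM with input size `d_input` and hidden
    // size `d_hidden` typically builds one `GateController` per gate, e.g.
    //
    //     let forget_gate = GateController::<B>::new(d_input, d_hidden, true, initializer.clone(), device);
    //
    // so each gate owns its own pair of input/hidden transforms.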

    /// Helper function that performs the weighted matrix product for a gate
    /// and adds the bias, if any.
    ///
    /// Mathematically, computes `Wx*X + Wh*H + b`, where:
    /// - `Wx` = weight matrix for the connection to the input vector `X`
    /// - `Wh` = weight matrix for the connection to the hidden state `H`
    /// - `X` = input vector
    /// - `H` = hidden state
    /// - `b` = bias terms
    pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
        self.input_transform.forward(input) + self.hidden_transform.forward(hidden)
    }
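
    // Illustrative note (an assumption about how the surrounding LSTM module
    // uses this, not stated in this file): the LSTM forward pass applies a
    // non-linearity to this product, roughly
    //
    //     i_t = sigmoid(input_gate.gate_product(x_t, h_prev))   // input/forget/output gates
    //     g_t = tanh(cell_gate.gate_product(x_t, h_prev))       // cell gate
    //
    // so `GateController` only supplies the affine part of each gate.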

    /// Used to initialize a gate controller with known weight layers,
    /// allowing for predictable behavior. Used only for testing in
    /// the LSTM module.
    #[cfg(test)]
    pub fn create_with_weights(
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        input_record: crate::nn::LinearRecord<B>,
        hidden_record: crate::nn::LinearRecord<B>,
    ) -> Self {
        let l1 = LinearConfig {
            d_input,
            d_output,
            bias,
            initializer: initializer.clone(),
        }
        .init(&input_record.weight.device())
        .load_record(input_record);
        let l2 = LinearConfig {
            d_input,
            d_output,
            bias,
            initializer,
        }
        .init(&hidden_record.weight.device())
        .load_record(hidden_record);

        Self {
            input_transform: l1,
            hidden_transform: l2,
        }
    }
}
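
// A minimal test sketch (added for illustration; not from the original file).
// It assumes the crate exposes a `TestBackend` alias for unit tests and that
// `Initializer::XavierUniform { gain }` is an available variant; adapt these
// to the crate's actual test setup if they differ.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn gate_product_has_expected_shape() {
        let device = Default::default();

        // A gate mapping a 4-dimensional input and a 3-dimensional hidden
        // state to a 3-dimensional gate pre-activation.
        let gate: GateController<TestBackend> =
            GateController::new(4, 3, true, Initializer::XavierUniform { gain: 1.0 }, &device);

        // One batch element: shapes [batch_size, d_input] and [batch_size, d_hidden].
        let input = Tensor::<TestBackend, 2>::zeros([1, 4], &device);
        let hidden = Tensor::<TestBackend, 2>::zeros([1, 3], &device);

        // `Wx*X + Wh*H + b`; an LSTM would follow this with sigmoid or tanh.
        let out = gate.gate_product(input, hidden);
        assert_eq!(out.dims(), [1, 3]);
    }
}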