// burn_nn/modules/rnn/gate_controller.rs
use burn_core as burn;

use crate::{Linear, LinearConfig, LinearLayout};
use burn::module::{Initializer, Module};
use burn::tensor::{Tensor, backend::Backend};

/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate,
/// forget gate, and output gate. Additionally, cell gate
/// is just used to compute the cell state.
///
/// An Lstm gate is modeled as two linear transformations.
/// The results of these transformations are used to calculate
/// the gate's output.
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
    /// Represents the affine transformation applied to input vector
    pub input_transform: Linear<B>,
    /// Represents the affine transformation applied to the hidden state
    pub hidden_transform: Linear<B>,
}
22
23impl<B: Backend> GateController<B> {
24    /// Initialize a new [gate_controller](GateController) module.
25    pub fn new(
26        d_input: usize,
27        d_output: usize,
28        bias: bool,
29        initializer: Initializer,
30        device: &B::Device,
31    ) -> Self {
32        Self {
33            input_transform: LinearConfig {
34                d_input,
35                d_output,
36                bias,
37                initializer: initializer.clone(),
38                layout: LinearLayout::Row,
39            }
40            .init(device),
41            hidden_transform: LinearConfig {
42                d_input: d_output,
43                d_output,
44                bias,
45                initializer,
46                layout: LinearLayout::Row,
47            }
48            .init(device),
49        }
50    }
51
52    /// Helper function for performing weighted matrix product for a gate and adds
53    /// bias, if any.
54    ///
55    ///  Mathematically, performs `Wx*X + Wh*H + b`, where:
56    ///     Wx = weight matrix for the connection to input vector X
57    ///     Wh = weight matrix for the connection to hidden state H
58    ///     X = input vector
59    ///     H = hidden state
60    ///     b = bias terms
61    pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
62        self.input_transform.forward(input) + self.hidden_transform.forward(hidden)
63    }
64
65    /// Used to initialize a gate controller with known weight layers,
66    /// allowing for predictable behavior. Used only for testing in
67    /// lstm.
68    #[cfg(test)]
69    pub fn create_with_weights(
70        d_input: usize,
71        d_output: usize,
72        bias: bool,
73        initializer: Initializer,
74        input_record: crate::LinearRecord<B>,
75        hidden_record: crate::LinearRecord<B>,
76    ) -> Self {
77        let l1 = LinearConfig {
78            d_input,
79            d_output,
80            bias,
81            initializer: initializer.clone(),
82            layout: LinearLayout::Row,
83        }
84        .init(&input_record.weight.device())
85        .load_record(input_record);
86        let l2 = LinearConfig {
87            d_input,
88            d_output,
89            bias,
90            initializer,
91            layout: LinearLayout::Row,
92        }
93        .init(&hidden_record.weight.device())
94        .load_record(hidden_record);
95
96        Self {
97            input_transform: l1,
98            hidden_transform: l2,
99        }
100    }
101}