//! Gate controller used by recurrent cells: two linear transformations
//! (over the input vector and the hidden state) whose sum forms a gate's
//! pre-activation value.
use crate as burn;
use crate::module::Module;
use crate::nn::{Initializer, Linear, LinearConfig};
use crate::tensor::{backend::Backend, Tensor};
/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate,
/// forget gate, and output gate. Additionally, cell gate
/// is just used to compute the cell state.
///
/// An Lstm gate is modeled as two linear transformations.
/// The results of these transformations are used to calculate
/// the gate's output: the gate's pre-activation is
/// `input_transform(X) + hidden_transform(H)` (see
/// [`gate_product`](GateController::gate_product)).
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
    /// Represents the affine transformation applied to input vector
    pub input_transform: Linear<B>,
    /// Represents the affine transformation applied to the hidden state
    pub hidden_transform: Linear<B>,
}
impl<B: Backend> GateController<B> {
    /// Initialize a new [gate_controller](GateController) module.
    ///
    /// The input transform maps `d_input -> d_output`; the hidden transform
    /// maps `d_output -> d_output` (the hidden state has the gate's output
    /// dimension). Both share the same bias setting and initializer.
    pub fn new(
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        device: &B::Device,
    ) -> Self {
        // Both linears differ only in their input dimension and the
        // (cloned) initializer, so build them through one local helper.
        let build_linear = |d_in: usize, init: Initializer| {
            LinearConfig {
                d_input: d_in,
                d_output,
                bias,
                initializer: init,
            }
            .init(device)
        };

        Self {
            input_transform: build_linear(d_input, initializer.clone()),
            hidden_transform: build_linear(d_output, initializer),
        }
    }

    /// Helper function for performing weighted matrix product for a gate and adds
    /// bias, if any.
    ///
    /// Mathematically, performs `Wx*X + Wh*H + b`, where:
    ///     Wx = weight matrix for the connection to input vector X
    ///     Wh = weight matrix for the connection to hidden state H
    ///     X = input vector
    ///     H = hidden state
    ///     b = bias terms
    pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
        let from_input = self.input_transform.forward(input);
        let from_hidden = self.hidden_transform.forward(hidden);
        from_input + from_hidden
    }

    /// Used to initialize a gate controller with known weight layers,
    /// allowing for predictable behavior. Used only for testing in
    /// lstm.
    #[cfg(test)]
    pub fn create_with_weights(
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        input_record: crate::nn::LinearRecord<B>,
        hidden_record: crate::nn::LinearRecord<B>,
    ) -> Self {
        // NOTE(review): unlike `new`, the hidden config here uses `d_input`
        // rather than `d_output`. This is inconsequential because
        // `load_record` replaces the initialized parameters with the
        // provided record's — but confirm the records carry the intended
        // shapes at the call sites.
        let input_transform = LinearConfig {
            d_input,
            d_output,
            bias,
            initializer: initializer.clone(),
        }
        .init(&input_record.weight.device())
        .load_record(input_record);

        let hidden_transform = LinearConfig {
            d_input,
            d_output,
            bias,
            initializer,
        }
        .init(&hidden_record.weight.device())
        .load_record(hidden_record);

        Self {
            input_transform,
            hidden_transform,
        }
    }
}