//! Gate controller used to model a single gate of an LSTM cell.
use crate as burn;
use crate::module::Module;
use crate::nn::Initializer;
use crate::nn::Linear;
use crate::nn::LinearConfig;
use burn_tensor::backend::Backend;
/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate,
/// forget gate, and output gate.
///
/// An Lstm gate is modeled as two linear transformations: one
/// applied to the input vector and one applied to the hidden state.
/// The results of these transformations are used to calculate
/// the gate's output.
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
/// Represents the affine transformation applied to the input vector.
pub(crate) input_transform: Linear<B>,
/// Represents the affine transformation applied to the hidden state.
pub(crate) hidden_transform: Linear<B>,
}
impl<B: Backend> GateController<B> {
    /// Initialize a new [gate_controller](GateController) module.
    ///
    /// # Arguments
    ///
    /// * `d_input` - Input size of the input transform.
    /// * `d_output` - Output size of both transforms. It is also used as the
    ///   input size of the hidden transform.
    /// * `bias` - Forwarded unchanged to both linear configurations.
    /// * `initializer` - Forwarded to both linear configurations (cloned for
    ///   the first one, moved into the second).
    pub fn new(d_input: usize, d_output: usize, bias: bool, initializer: Initializer) -> Self {
        // The input transform is initialized before the hidden transform;
        // keep this order so any initializer side effects stay unchanged.
        Self {
            input_transform: LinearConfig {
                d_input,
                d_output,
                bias,
                initializer: initializer.clone(),
            }
            .init(),
            hidden_transform: LinearConfig {
                d_input: d_output,
                d_output,
                bias,
                initializer,
            }
            .init(),
        }
    }

    /// Initialize a new [gate_controller](GateController) module with a [record](GateControllerRecord).
    ///
    /// The same `linear_config` is passed to `init_with` for both the input
    /// and the hidden transform; each layer receives its own sub-record.
    pub fn new_with(linear_config: &LinearConfig, record: GateControllerRecord<B>) -> Self {
        let input_transform = linear_config.init_with(record.input_transform);
        let hidden_transform = linear_config.init_with(record.hidden_transform);

        Self {
            input_transform,
            hidden_transform,
        }
    }

    /// Used to initialize a gate controller with known weight layers,
    /// allowing for predictable behavior. Used only for testing in
    /// lstm.
    ///
    /// The hidden transform is configured with `d_output` as its input size,
    /// matching [`GateController::new`].
    #[cfg(test)]
    pub fn create_with_weights(
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        input_record: crate::nn::LinearRecord<B>,
        hidden_record: crate::nn::LinearRecord<B>,
    ) -> Self {
        let input_transform = LinearConfig {
            d_input,
            d_output,
            bias,
            initializer: initializer.clone(),
        }
        .init_with(input_record);

        // The hidden state has the gate's output size, so the hidden
        // transform maps `d_output -> d_output`, exactly as in `new`.
        let hidden_transform = LinearConfig {
            d_input: d_output,
            d_output,
            bias,
            initializer,
        }
        .init_with(hidden_record);

        Self {
            input_transform,
            hidden_transform,
        }
    }
}