pub struct GateController<B: Backend> {
    pub input_transform: Linear<B>,
    pub hidden_transform: Linear<B>,
}
A GateController represents a gate in an LSTM cell. An LSTM cell generally contains three gates: an input gate, a forget gate, and an output gate. An additional cell gate is used only to compute the cell state.
An LSTM gate is modeled as two linear transformations, one applied to the input vector and one applied to the hidden state. The results of these transformations are combined to calculate the gate's output.
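To illustrate how the four gates fit together, here is a minimal plain-Rust sketch of one step of a standard LSTM cell. It does not use burn: `ScalarGate`, `LstmCell`, and `sigmoid` are hypothetical names, and the hidden size is fixed at 1 so that each weight matrix reduces to a scalar. The activations (sigmoid for the input, forget, and output gates, tanh for the cell gate) follow the textbook LSTM equations and are an assumption about how the gate outputs are consumed.

```rust
/// Hypothetical scalar gate (hidden size 1): pre-activation = w_x * x + w_h * h + b.
#[derive(Clone, Copy)]
struct ScalarGate {
    w_x: f64,
    w_h: f64,
    b: f64,
}

impl ScalarGate {
    /// Sum of the gate's two affine transformations (input and hidden state).
    fn gate_product(&self, x: f64, h: f64) -> f64 {
        self.w_x * x + self.w_h * h + self.b
    }
}

fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical LSTM cell holding one gate per role.
struct LstmCell {
    input_gate: ScalarGate,
    forget_gate: ScalarGate,
    output_gate: ScalarGate,
    cell_gate: ScalarGate,
}

impl LstmCell {
    /// One time step: takes input x and previous (h, c), returns the new (h, c).
    fn step(&self, x: f64, h: f64, c: f64) -> (f64, f64) {
        let i = sigmoid(self.input_gate.gate_product(x, h));
        let f = sigmoid(self.forget_gate.gate_product(x, h));
        let o = sigmoid(self.output_gate.gate_product(x, h));
        // The cell gate is squashed with tanh and only feeds the cell state update.
        let g = self.cell_gate.gate_product(x, h).tanh();
        let c_new = f * c + i * g;
        let h_new = o * c_new.tanh();
        (h_new, c_new)
    }
}

fn main() {
    let gate = ScalarGate { w_x: 0.5, w_h: -0.25, b: 0.0 };
    let cell = LstmCell { input_gate: gate, forget_gate: gate, output_gate: gate, cell_gate: gate };
    let (h, c) = cell.step(1.0, 0.0, 0.0);
    println!("h = {h:.4}, c = {c:.4}");
}
```

With all weights set to zero, every sigmoid gate outputs 0.5 and the cell gate outputs 0, so one step simply halves the previous cell state.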
Fields
input_transform: Linear<B>
Represents the affine transformation applied to the input vector.
hidden_transform: Linear<B>
Represents the affine transformation applied to the hidden state.
Implementations
impl<B: Backend> GateController<B>
pub fn new(
    d_input: usize,
    d_output: usize,
    bias: bool,
    initializer: Initializer,
    device: &B::Device,
) -> Self
Initializes a new GateController module.
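The sketch below spells out the parameter shapes a gate controller is assumed to have after `new(d_input, d_output, bias, ..)`: the input transformation maps a `d_input`-sized input to `d_output` values, and the hidden transformation maps the `d_output`-sized hidden state to `d_output` values. `AffineShape` and `gate_controller_shapes` are hypothetical helpers, and applying `bias` to both transformations is an assumption rather than something this page states.

```rust
/// Hypothetical description of one affine layer's parameter shapes.
#[derive(Debug, PartialEq)]
struct AffineShape {
    d_input: usize,
    d_output: usize,
    bias: bool,
}

impl AffineShape {
    /// Weight entries, plus one bias entry per output when bias is enabled.
    fn num_params(&self) -> usize {
        self.d_input * self.d_output + if self.bias { self.d_output } else { 0 }
    }
}

/// Assumed shapes of (input_transform, hidden_transform) for a gate controller.
/// Assumption: the hidden state has size d_output, and `bias` applies to both layers.
fn gate_controller_shapes(d_input: usize, d_output: usize, bias: bool) -> (AffineShape, AffineShape) {
    (
        AffineShape { d_input, d_output, bias },
        AffineShape { d_input: d_output, d_output, bias },
    )
}

fn main() {
    let (input_t, hidden_t) = gate_controller_shapes(3, 4, true);
    // 3*4 + 4 = 16 and 4*4 + 4 = 20, so this prints 36.
    println!("{}", input_t.num_params() + hidden_t.num_params());
}
```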
pub fn gate_product(
    &self,
    input: Tensor<B, 2>,
    hidden: Tensor<B, 2>,
) -> Tensor<B, 2>
Helper function that computes the weighted matrix product for a gate and adds the bias, if any.
Mathematically, it computes $W_x X + W_h H + b$, where:
$W_x$ = weight matrix for the connection to the input vector $X$
$W_h$ = weight matrix for the connection to the hidden state $H$
$X$ = input vector
$H$ = hidden state
$b$ = bias terms
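A minimal plain-Rust sketch of the formula above, assuming each transformation computes `x W + b` on a batch of row vectors (so a weight is stored as `d_in x d_out`) and that the gate product is simply the sum of the two results. `Affine` and `Gate` are hypothetical stand-ins for `Linear<B>` and `GateController<B>`, not burn's types.

```rust
/// Hypothetical affine layer y = x W + b; `weight` has shape [d_in][d_out].
struct Affine {
    weight: Vec<Vec<f64>>,
    bias: Option<Vec<f64>>, // shape [d_out]; None when the layer has no bias
}

impl Affine {
    /// Applies the layer to a batch of rows: [batch][d_in] -> [batch][d_out].
    fn forward(&self, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let d_out = self.weight[0].len();
        x.iter()
            .map(|row| {
                (0..d_out)
                    .map(|j| {
                        let dot: f64 = row.iter().zip(&self.weight).map(|(xi, wi)| xi * wi[j]).sum();
                        dot + self.bias.as_ref().map_or(0.0, |b| b[j])
                    })
                    .collect::<Vec<f64>>()
            })
            .collect()
    }
}

/// Hypothetical gate holding the two transformations.
struct Gate {
    input_transform: Affine,
    hidden_transform: Affine,
}

impl Gate {
    /// W_x X + W_h H + b, computed as the element-wise sum of both affine outputs.
    fn gate_product(&self, input: &[Vec<f64>], hidden: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let from_input = self.input_transform.forward(input);
        let from_hidden = self.hidden_transform.forward(hidden);
        from_input
            .iter()
            .zip(&from_hidden)
            .map(|(ra, rb)| ra.iter().zip(rb).map(|(u, v)| u + v).collect::<Vec<f64>>())
            .collect()
    }
}

fn main() {
    // d_input = 2, hidden size = 1, batch of one row.
    let gate = Gate {
        input_transform: Affine { weight: vec![vec![0.5], vec![-1.0]], bias: Some(vec![0.25]) },
        hidden_transform: Affine { weight: vec![vec![2.0]], bias: None },
    };
    let out = gate.gate_product(&[vec![1.0, 2.0]], &[vec![0.25]]);
    // 0.5*1 - 1*2 + 0.25 + 2*0.25 = -0.75, so this prints [[-0.75]].
    println!("{:?}", out);
}
```

Here only the input transformation carries a bias; if both carried one, the $b$ in the formula would be the sum of the two bias vectors.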
Trait Implementations
impl<B> AutodiffModule<B> for GateController<B>
type InnerModule = GateController<<B as AutodiffBackend>::InnerBackend>
fn valid(&self) -> Self::InnerModule
impl<B: Backend> Clone for GateController<B>
impl<B: Backend> Display for GateController<B>
impl<B: Backend> Module<B> for GateController<B>
type Record = GateControllerRecord<B>
fn load_record(self, record: Self::Record) -> Self
fn into_record(self) -> Self::Record
fn num_params(&self) -> usize
fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor)
fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self
fn collect_devices(&self, devices: Devices<B>) -> Devices<B>
fn to_device(self, device: &B::Device) -> Self
fn fork(self, device: &B::Device) -> Self
fn devices(&self) -> Devices<B>
fn save_file<FR, PB>(
    self,
    file_path: PB,
    recorder: &FR,
) -> Result<(), RecorderError>
Available on crate feature std only.
fn load_file<FR, PB>(
    self,
    file_path: PB,
    recorder: &FR,
    device: &B::Device,
) -> Result<Self, RecorderError>
Available on crate feature std only.
fn quantize_weights(self, quantizer: &mut Quantizer) -> Self
impl<B: Backend> ModuleDisplay for GateController<B>
fn format(&self, passed_settings: DisplaySettings) -> String
fn custom_settings(&self) -> Option<DisplaySettings>
Auto Trait Implementations
impl<B> !Freeze for GateController<B>
impl<B> !RefUnwindSafe for GateController<B>
impl<B> Send for GateController<B>
impl<B> !Sync for GateController<B>
impl<B> Unpin for GateController<B>
where
    <B as Backend>::FloatTensorPrimitive: Unpin,
    <B as Backend>::QuantizedTensorPrimitive: Unpin,
    <B as Backend>::Device: Unpin,
impl<B> UnwindSafe for GateController<B>
where
    <B as Backend>::FloatTensorPrimitive: UnwindSafe,
    <B as Backend>::QuantizedTensorPrimitive: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.