use crate::autograd::Variable;
use crate::nn::{
recurrent_common::{collect_recurrent_parameters, RecurrentOps},
Module,
};
use crate::tensor::Tensor;
use num_traits::Float;
use std::fmt::Debug;
/// A single-step LSTM cell.
///
/// Holds the input-to-hidden and hidden-to-hidden weights (plus optional biases) that
/// `forward` combines with one input step and the previous `(hidden, cell)` pair to
/// produce the next `(hidden, cell)` pair.
#[derive(Debug)]
pub struct LSTMCell<T: Float + Send + Sync + ndarray::ScalarOperand + num_traits::FromPrimitive> {
    // Weights applied to the current input when computing the gate pre-activations.
    weight_ih: Variable<T>,
    // Weights applied to the previous hidden state when computing the gate pre-activations.
    weight_hh: Variable<T>,
    // Bias for the input projection; `None` when the cell was built with `bias == false`.
    bias_ih: Option<Variable<T>>,
    // Bias for the hidden projection; `None` when the cell was built with `bias == false`.
    bias_hh: Option<Variable<T>>,
    // Number of input features the cell was constructed for.
    input_size: usize,
    // Width of the hidden and cell states (and of each gate slice).
    hidden_size: usize,
    // Training-mode flag toggled by `train`/`eval`/`set_training`; the forward pass in
    // this file does not read it.
    training: bool,
}
impl<T> LSTMCell<T>
where
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
pub fn new(input_size: usize, hidden_size: usize, bias: bool) -> Self {
let (weight_ih, weight_hh) = RecurrentOps::init_weights(input_size, hidden_size, 4);
let (bias_ih, bias_hh) = if bias {
RecurrentOps::init_bias(hidden_size, 4)
} else {
(None, None)
};
LSTMCell {
weight_ih,
weight_hh,
bias_ih,
bias_hh,
input_size,
hidden_size,
training: true,
}
}
pub fn forward(
&self,
input: &Variable<T>,
hidden: Option<(&Variable<T>, &Variable<T>)>,
) -> (Variable<T>, Variable<T>) {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let batch_size = input_data.shape()[0];
let (h_prev, c_prev) = match hidden {
Some((h, c)) => (h.clone(), c.clone()),
None => {
let h = Variable::new(Tensor::zeros(&[batch_size, self.hidden_size]), false);
let c = Variable::new(Tensor::zeros(&[batch_size, self.hidden_size]), false);
(h, c)
}
};
let gi = RecurrentOps::linear_transform(input, &self.weight_ih, self.bias_ih.as_ref());
let gh = RecurrentOps::linear_transform(&h_prev, &self.weight_hh, self.bias_hh.as_ref());
let gates = RecurrentOps::add_variables(&gi, &gh);
let input_gate =
RecurrentOps::sigmoid(&RecurrentOps::slice_gates(&gates, 0, self.hidden_size));
let forget_gate =
RecurrentOps::sigmoid(&RecurrentOps::slice_gates(&gates, 1, self.hidden_size));
let cell_gate = RecurrentOps::tanh(&RecurrentOps::slice_gates(&gates, 2, self.hidden_size));
let output_gate =
RecurrentOps::sigmoid(&RecurrentOps::slice_gates(&gates, 3, self.hidden_size));
let forget_term = RecurrentOps::multiply_variables(&forget_gate, &c_prev);
let input_term = RecurrentOps::multiply_variables(&input_gate, &cell_gate);
let new_cell = RecurrentOps::add_variables(&forget_term, &input_term);
let cell_tanh = RecurrentOps::tanh(&new_cell);
let new_hidden = RecurrentOps::multiply_variables(&output_gate, &cell_tanh);
(new_hidden, new_cell)
}
pub fn input_size(&self) -> usize {
self.input_size
}
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
pub fn set_training(&mut self, training: bool) {
self.training = training;
}
pub fn is_training(&self) -> bool {
self.training
}
}
impl<T> Module<T> for LSTMCell<T>
where
    T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Single-input `Module` entry point: steps the cell from a zero initial state and
    /// yields only the new hidden state, discarding the new cell state.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        // `Self::forward` resolves to the inherent two-argument step (inherent items take
        // precedence over trait items); `.0` is the hidden half of `(hidden, cell)`.
        Self::forward(self, input, None).0
    }

    /// All trainable tensors of the cell, gathered by `collect_recurrent_parameters`
    /// from both weight matrices and whichever biases are present.
    fn parameters(&self) -> Vec<Variable<T>> {
        let (w_ih, w_hh) = (&self.weight_ih, &self.weight_hh);
        let (b_ih, b_hh) = (&self.bias_ih, &self.bias_hh);
        collect_recurrent_parameters(w_ih, w_hh, b_ih, b_hh)
    }

    /// Exposes the concrete type for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Switches the cell into training mode.
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Switches the cell into evaluation mode.
    fn eval(&mut self) {
        self.set_training(false);
    }
}