use af;
use af::{Dim4, Array};
use af::MatProp;
use activations;
use initializations;
use params::{LSTMIndex, Input, Params};
/// Static configuration of a long short-term memory (LSTM) layer.
///
/// The struct carries only sizes and flags; weights, activations and the
/// per-step state are read from the `Params` value passed to each method.
pub struct LSTM {
    /// NOTE(review): presumably the number of input features per sample;
    /// not read by the code in this file, confirm against the constructor.
    pub input_size: usize,
    /// NOTE(review): presumably the number of hidden units (rows of each
    /// gate); confirm against the weight initialisation.
    pub output_size: usize,
    /// NOTE(review): presumably the longest sequence the layer is unrolled
    /// over; not read by the code in this file.
    pub max_seq_size: usize,
    /// When `true`, `forward` returns the hidden output joined with the cell
    /// state along dimension 1; otherwise it returns the hidden output alone.
    pub return_sequences: bool,
}
/// Names for the two activation slots of an LSTM layer.
///
/// NOTE(review): the methods in this file read the activation list with the
/// literal indices `0` and `1`; presumably these variants are meant to name
/// those two slots. Confirm before replacing the literals.
pub enum ActivationIndex {
    /// First slot (discriminant 0).
    Inner,
    /// Second slot (discriminant 1).
    Outer,
}
impl RTRL for LSTM {
pub fn rtrl(&self, delta: &Array, params: &mut Params) -> Array
{
let inner_activation = params.activation[0];
let outer_activation = params.activation[1];
let i_t = params.recurrences[LSTMIndex::Input];
let f_t = params.recurrences[LSTMIndex::Forget];
let o_t = params.recurrences[LSTMIndex::Output];
let ct_t = params.recurrences[LSTMIndex::CellTilda];
let c_t = params.recurrences[LSTMIndex::Cell];
let h_t = params.recurrences[LSTMIndex::CellOutput];
let inputs = params.inputs.last().unwrap();
let mut derivatives = params.optional.pop().unwrap();
let mut dW_tm1 = derivatives[0];
let mut dU_tm1 = derivatives[1];
let mut db_tm1 = derivatives[2];
let e_t = af::mul(&af::mul(&o_t, &activations::get_derivative(outer_activation, &c_t).unwrap()).unwrap()
, delta).unwrap();
let dz = vec![&activations::get_derivative(inner_activation, &i_t).unwrap()
, &activations::get_derivative(inner_activation, &f_t).unwrap()
, &activations::get_derivative(outer_activation, &ct_t).unwrap()];
let ct_ctm1_it = vec![&ct_t, &c_tm1, &i_t];
let dzprod = af::mul(&af::join_many(0, ct_ctm1_it).unwrap()
, af::join_many(0, dz).unwrap(), false).unwrap();
let w_lhs = af::mul(dW_tm1, &f_t, true).unwrap(); let w_rhs = af::mul(&dzprod, &inputs.data, true).unwrap();
dW_tm1 = af::add(&w_lhs, &w_rhs, false).unwrap();
let u_lhs = af::mul(dU_tm1, &f_t, true).unwrap(); let u_rhs = af::mul(&dzprod, &recurrences.data, true).unwrap();
dU_tm1 = af::add(&u_lhs, &u_rhs, false).unwrap();
let b_lhs = af::mul(db_tm1, &f_t, true).unwrap(); params.optional[2] = af::add(&b_lhs, &dzprod, false).unwrap(); }
}
impl Layer for LSTM {
fn forward(&self, params: &mut Params, inputs: &Input, train: bool) -> Input
{
assert!(inputs.data.dims().unwrap()[2] == 1);
let h_tm1 = params.recurrences[LSTMIndex::CellOutput].last().unwrap(); let c_tm1 = params.recurrences[LSTMIndex::Cell].last().unwrap(); let inner_activation = params.activations[0];
let outer_activation = params.activations[1];
let weights_ref = vec![¶ms.weights[LSTMIndex::Input]
, ¶ms.weights[LSTMIndex::Forget]
, ¶ms.weights[LSTMIndex::Output]
, ¶ms.weights[LSTMIndex::CellTilda]];
let offset = 4; let recurrents_ref = vec![¶ms.weights[LSTMIndex::Input as usize + offset]
, ¶ms.weights[LSTMIndex::Forget as usize + offset]
, ¶ms.weights[LSTMIndex::Output as usize + offset]
, ¶ms.weights[LSTMIndex::CellTilda as usize + offset]];
let bias_ref = vec![¶ms.biases[LSTMIndex::Input]
, ¶ms.biases[LSTMIndex::Forget]
, ¶ms.biases[LSTMIndex::Output]
, ¶ms.biases[LSTMIndex::CellTilda]];
let z_t = af::add(&af::add(&af::matmul(&af::join_many(0, weights_ref).unwrap(), inputs.data).unwrap()
, &af::matmul(&af::join_many(0, recurrents_ref).unwrap(), &h_tm1).unwrap(), false).unwrap()
, &af::join_many(0, bias_ref).unwrap(), true).unwrap();
let i_t = activations::get_activation(inner_activation, &af::rows(&z_t, 0, 0).unwrap());
let f_t = activations::get_activation(inner_activation, &af::rows(&z_t, 1, 1).unwrap());
let o_t = activations::get_activation(inner_activation, &af::rows(&z_t, 2, 2).unwrap());
let ct_t = activations::get_activation(inner_activation, &af::rows(&z_t, 3, 3).unwrap());
let c_t = af::add(&af::mul(&i_t, &ct_t, false).unwrap()
, &af::mul(&f_t, &c_tm1, false).unwrap()
, false).unwrap();
let h_t = af::mul(&o_t, &activations::get_activation(outer_activation, &c_t).unwrap(), false).unwrap();
if train { params.inputs.push(inputs.clone());
params.outputs.push(h_t.clone());
params.recurrences[LSTMIndex::Input].push(i_t.clone());
params.recurrences[LSTMIndex::Forget].push(f_t.clone());
params.recurrences[LSTMIndex::Output].push(o_t.clone());
params.recurrences[LSTMIndex::CellTilda].push(ct_t.clone());
params.recurrences[LSTMIndex::Cell].push(c_t.clone()); params.recurrences[LSTMIndex::CellOutput].push(h_t.clone());
}
if self.return_sequences {
Input { data: af::join_many(1, vec![&h_t, &c_t]).unwrap() , activation: self.outer_activation }
}else {
Input { data: h_t.clone()
, activation: self.outer_activation }
}
}
fn backward(&self, params: &mut Params, delta: &Array) -> Array {
let inner_activation = params.activations[0];
let outer_activation = params.activations[1];
let o_t = params.recurrences[LSTMIndex::Output].last().unwrap();
let c_t = params.recurrences[LSTMIndex::Cell].last().unwrap();
params.deltas = vec![af::mul(delta, &activations::get_derivative(¶ms.activations[0]
, ¶ms.outputs[0].data).unwrap(), false).unwrap()];
let d_h = af::mul(&af::mul(&activations::get_derivative(inner_activation, &o_t).unwrap()
, &activations::get_activation(outer_activation, &c_t).unwrap()).unwrap()
, delta).unwrap();
let d_i = self.rtrl(delta, &d_h, &mut params);
params.recurrences[LSTMIndex::Cell].pop();
params.recurrences[LSTMIndex::CellOutput].pop();
params.recurrences[LSTMIndex::CellTilda].pop();
params.recurrences[LSTMIndex::Forget].pop();
params.recurrences[LSTMIndex::Input].pop();
params.recurrences[LSTMIndex::Output].pop();
let activation_prev = activations::get_activation(self.inputs.activation[0], &self.inputs.data[DataIndex::Input]).unwrap();
let d_activation_prev = activations::get_derivative(self.inputs.activation[0], &activation_prev).unwrap();
let delta_prev = af::mul(&af::matmul(¶ms.weights[0], delta, af::MatProp::TRANS, af::MatProp::NONE).unwrap()
, &d_activation_prev, false).unwrap();
delta_prev
}
}