use af;
use af::{Array, Dim4, MatProp};
use std::sync::{Arc, Mutex};
use utils;
use activations;
use params::Params;
use layer::Layer;
/// Configuration for a simple recurrent (RNN) layer.
///
/// NOTE(review): the `Layer` implementation in this file never reads these
/// fields; it takes the hidden-state width from the dimensions of the
/// recurrent weight matrix stored in `Params`. Confirm that whoever builds
/// the `Params` keeps the two in sync.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RNN {
    /// Size of the layer's input (presumably features per sample — verify
    /// against the code that constructs the layer).
    pub input_size: usize,
    /// Size of the layer's output (presumably hidden units — verify against
    /// the code that constructs the layer).
    pub output_size: usize,
}
impl Layer for RNN
{
fn forward(&self, params: Arc<Mutex<Params>>, inputs: &Array) -> Array
{
let mut ltex = params.lock().unwrap();
let current_unroll = ltex.current_unroll;
let wx = af::matmul(&inputs , <ex.weights[0]
, MatProp::NONE
, MatProp::NONE);
let uhtm1 = match ltex.recurrences.len(){
0 => {
let output_size = ltex.weights[1].dims()[0]; let init_h_dims = Dim4::new(&[inputs.dims()[0], output_size, 1, 1]);
ltex.recurrences.push(utils::constant(init_h_dims, inputs.get_type(), 0f32));
af::matmul(&utils::constant(init_h_dims, inputs.get_type(), 0f32)
, <ex.weights[1]
, MatProp::NONE
, MatProp::NONE)
},
_ => {
af::matmul(<ex.recurrences[ltex.current_unroll]
, <ex.weights[1]
, MatProp::NONE
, MatProp::NONE)
}
};
let h_t = af::transpose(&af::add(&af::transpose(&uhtm1, false),
&af::add(&af::transpose(&wx, false)
, <ex.biases[0], true), false)
, false);
let a_t = activations::get_activation(<ex.activations[0], &h_t).unwrap();
if ltex.inputs.len() > current_unroll { ltex.inputs[current_unroll] = inputs.clone();
ltex.outputs[current_unroll] = a_t.clone();
ltex.recurrences[current_unroll] = h_t.clone();
}else{ ltex.inputs.push(inputs.clone());
ltex.outputs.push(a_t.clone());
ltex.recurrences.push(h_t.clone());
}
ltex.current_unroll += 1;
a_t.clone() }
fn backward(&self, params: Arc<Mutex<Params>>, delta: &Array) -> Array {
let mut ltex = params.lock().unwrap();
let current_unroll = ltex.current_unroll;
assert!(current_unroll > 0
, "Cannot call backward pass without at least 1 forward pass");
let dz = activations::get_derivative(<ex.activations[0]
, <ex.outputs[current_unroll - 1]).unwrap();
let delta_t = af::mul(delta, &dz, false);
let dw = af::matmul(<ex.inputs[current_unroll - 1], &delta_t , af::MatProp::TRANS
, af::MatProp::NONE);
let du = af::matmul(<ex.recurrences[current_unroll - 1], &delta_t , af::MatProp::TRANS
, af::MatProp::NONE);
let db = af::transpose(&af::sum(&delta_t, 0), false); ltex.deltas[0] = af::add(<ex.deltas[0], &dw, false);
ltex.deltas[1] = af::add(<ex.deltas[1], &du, false);
ltex.deltas[2] = af::add(<ex.deltas[2], &db, false);
ltex.current_unroll -= 1;
af::matmul(&delta_t, <ex.weights[0], af::MatProp::NONE, af::MatProp::TRANS)
}
}