extern crate rand;
use std::fmt;
use std::mem;
use std::any::Any;
use std::sync::{Arc, RwLock};
use super::{
layertype::LayerType,
layer::Layer,
dense::Dense,
vectorops
};
use super::super::{
activation::Activation,
neatenv::NeatEnvironment,
};
use crate::Genome;
/// Bookkeeping recorded while an LSTM runs, consumed again during backpropagation.
///
/// Every call to [`LSTMState::update_forward`] appends one entry to each of the five
/// history vectors, so index `t` of each vector belongs to the same recorded step.
/// The two `d_prev_*` fields hold gradients carried from one backward step to the next.
#[derive(Debug, Default)]
pub struct LSTMState {
    /// Recorded forget-gate outputs, one entry per recorded step.
    pub f_gate_output: Vec<Vec<f32>>,
    /// Recorded input-gate outputs, one entry per recorded step.
    pub i_gate_output: Vec<Vec<f32>>,
    /// Recorded candidate ("state") gate outputs, one entry per recorded step.
    pub s_gate_output: Vec<Vec<f32>>,
    /// Recorded output-gate outputs, one entry per recorded step.
    pub o_gate_output: Vec<Vec<f32>>,
    /// Recorded memory vectors, one entry per recorded step.
    pub memory_states: Vec<Vec<f32>>,
    /// Memory gradient carried over from the previously processed backward step, if any.
    pub d_prev_memory: Option<Vec<f32>>,
    /// Hidden gradient carried over from the previously processed backward step, if any.
    pub d_prev_hidden: Option<Vec<f32>>,
}

impl LSTMState {
    /// Creates an empty state: no recorded steps and no carried gradients.
    ///
    /// Equivalent to [`LSTMState::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends the values of one forward step to the histories.
    ///
    /// * `fg` - forget-gate output
    /// * `ig` - input-gate output
    /// * `sg` - candidate ("state") gate output
    /// * `og` - output-gate output
    /// * `mem_state` - memory vector to remember for this step
    pub fn update_forward(
        &mut self,
        fg: Vec<f32>,
        ig: Vec<f32>,
        sg: Vec<f32>,
        og: Vec<f32>,
        mem_state: Vec<f32>,
    ) {
        self.f_gate_output.push(fg);
        self.i_gate_output.push(ig);
        self.s_gate_output.push(sg);
        self.o_gate_output.push(og);
        self.memory_states.push(mem_state);
    }
}
/// Long short-term memory layer assembled from five `Dense` sub-layers.
///
/// Four gates (`g_gate`, `i_gate`, `f_gate`, `o_gate`) read the hidden vector concatenated
/// with the current input; `v_gate` turns the resulting hidden vector into the layer output.
#[derive(Debug)]
pub struct LSTM {
/// Length of the external input vector handed to `forward`.
pub input_size: u32,
/// Length of the `memory` and `hidden` vectors, and the output width of each of the four gates.
pub memory_size: u32,
/// Length of the vector produced by `v_gate`, i.e. the layer's output.
pub output_size: u32,
/// Memory (cell) vector carried between forward steps; starts as all zeros.
pub memory: Vec<f32>,
/// Hidden vector carried between forward steps; starts as all zeros.
pub hidden: Vec<f32>,
/// Recorded gate outputs and carried gradients used by backpropagation.
pub states: LSTMState,
/// Candidate gate; constructed with `Activation::Tahn`.
pub g_gate: Dense,
/// Input gate; constructed with `Activation::Sigmoid`.
pub i_gate: Dense,
/// Forget gate; constructed with `Activation::Sigmoid`.
pub f_gate: Dense,
/// Output gate; constructed with `Activation::Sigmoid`.
pub o_gate: Dense,
/// Maps the hidden vector (`memory_size`) to the layer output (`output_size`).
pub v_gate: Dense
}
impl LSTM {
    /// Builds an LSTM layer whose memory and hidden vectors start as zeros.
    ///
    /// Each of the four gates accepts `input_size + memory_size` values (the hidden vector
    /// concatenated with the input, as assembled in `forward`) and emits `memory_size`
    /// values. `v_gate` maps the `memory_size` hidden vector to `output_size` outputs.
    pub fn new(input_size: u32, memory_size: u32, output_size: u32) -> Self {
        let cell_input = input_size + memory_size;
        LSTM {
            input_size,
            memory_size,
            output_size,
            memory: vec![0.0; memory_size as usize],
            hidden: vec![0.0; memory_size as usize],
            states: LSTMState::new(),
            g_gate: Dense::new(cell_input, memory_size, LayerType::DensePool, Activation::Tahn),
            i_gate: Dense::new(cell_input, memory_size, LayerType::DensePool, Activation::Sigmoid),
            f_gate: Dense::new(cell_input, memory_size, LayerType::DensePool, Activation::Sigmoid),
            o_gate: Dense::new(cell_input, memory_size, LayerType::DensePool, Activation::Sigmoid),
            v_gate: Dense::new(memory_size, output_size, LayerType::DensePool, Activation::Sigmoid),
        }
    }

    /// Backpropagates one recorded step, newest first.
    ///
    /// Pops the most recently recorded gate outputs and memory vector from `self.states`,
    /// pushes `errors` back through `v_gate` and then through the four gates, and stores the
    /// hidden/memory gradients to be carried into the next call.
    ///
    /// * `errors` - error of this step's layer output
    /// * `l_rate` - learning rate forwarded to every `Dense::backward` call
    ///
    /// Returns the gradient with respect to this step's external input. Returns `None` when
    /// no carried gradient has been set, when no recorded step is left, or when any
    /// `Dense::backward` call returns `None`.
    #[inline]
    pub fn step_back(&mut self, errors: &Vec<f32>, l_rate: f32) -> Option<Vec<f32>> {
        // Gradients carried over from the step processed just before this one.
        let dh_next = self.states.d_prev_hidden.clone()?;
        let dc_next = self.states.d_prev_memory.clone()?;

        // NOTE(review): `forward` records the memory vector *after* updating it, so this is
        // the cell state of the step being processed, despite the name. It is used below both
        // where the current state is needed (`dho`, `dc`) and where the previous state would
        // be expected (`dhf`) - confirm whether that is intended.
        let c_old = self.states.memory_states.pop()?;

        // Gate outputs recorded for this step.
        let g_curr = self.states.s_gate_output.pop()?;
        let i_curr = self.states.i_gate_output.pop()?;
        let f_curr = self.states.f_gate_output.pop()?;
        let o_curr = self.states.o_gate_output.pop()?;

        // The hidden vector feeds `v_gate` at this step and every gate at the following
        // step, so its gradient is the SUM of both contributions. Multiplying here would
        // zero the whole backward pass, because the carried gradient starts out as zeros.
        let mut dh = self.v_gate.backward(errors, l_rate)?;
        vectorops::element_add(&mut dh, &dh_next);

        // Output gate error.
        let mut dho = vectorops::element_activate(&c_old, Activation::Tahn);
        vectorops::element_multiply(&mut dho, &dh);
        vectorops::element_multiply(&mut dho, &vectorops::element_deactivate(&o_curr, self.o_gate.activation));

        // Memory (cell) gradient, including what flows back from the following step.
        let mut dc = vectorops::product(&o_curr, &dh);
        vectorops::element_multiply(&mut dc, &vectorops::element_deactivate(&c_old, Activation::Tahn));
        vectorops::element_add(&mut dc, &dc_next);

        // Forget gate error.
        let mut dhf = vectorops::product(&c_old, &dc);
        vectorops::element_multiply(&mut dhf, &vectorops::element_deactivate(&f_curr, self.f_gate.activation));

        // Input gate error.
        let mut dhi = vectorops::product(&g_curr, &dc);
        vectorops::element_multiply(&mut dhi, &vectorops::element_deactivate(&i_curr, self.i_gate.activation));

        // Candidate gate error.
        let mut dhc = vectorops::product(&i_curr, &dc);
        vectorops::element_multiply(&mut dhc, &vectorops::element_deactivate(&g_curr, self.g_gate.activation));

        // Push each gate's error through its Dense layer.
        let f_error = self.f_gate.backward(&dhf, l_rate)?;
        let i_error = self.i_gate.backward(&dhi, l_rate)?;
        let g_error = self.g_gate.backward(&dhc, l_rate)?;
        let o_error = self.o_gate.backward(&dho, l_rate)?;

        // All four gates read the same concatenated [hidden, input] vector, so the
        // gradient with respect to that vector is the sum of the four gate gradients.
        let mut dx = vec![0.0; (self.input_size + self.memory_size) as usize];
        vectorops::element_add(&mut dx, &f_error);
        vectorops::element_add(&mut dx, &i_error);
        vectorops::element_add(&mut dx, &g_error);
        vectorops::element_add(&mut dx, &o_error);

        // The first `memory_size` entries belong to the hidden part, the rest to the input.
        let dh_next = dx[..self.memory_size as usize].to_vec();
        let dc_next = vectorops::product(&f_curr, &dc);

        self.states.d_prev_hidden = Some(dh_next);
        self.states.d_prev_memory = Some(dc_next);

        Some(dx[self.memory_size as usize..].to_vec())
    }
}
impl Layer for LSTM {
    /// Runs one forward step and returns the output of `v_gate`.
    ///
    /// The gates are fed the current hidden vector followed by `inputs`. The memory vector
    /// is updated in place as `memory * forget + candidate * input`, and the new hidden
    /// vector is `output_gate * tanh(memory)`. When the forget gate has trace states
    /// enabled, the gate outputs and the updated memory are recorded for backpropagation.
    ///
    /// Returns `None` if any of the underlying `Dense::forward` calls returns `None`.
    #[inline]
    fn forward(&mut self, inputs: &Vec<f32>) -> Option<Vec<f32>> {
        // Gate input is [hidden, inputs].
        let mut hidden_input = self.hidden.clone();
        hidden_input.extend(inputs);

        let f_output = self.f_gate.forward(&hidden_input)?;
        let i_output = self.i_gate.forward(&hidden_input)?;
        let o_output = self.o_gate.forward(&hidden_input)?;
        let g_output = self.g_gate.forward(&hidden_input)?;

        // Work on copies: the raw gate outputs may still be needed for the trace below.
        let mut current_state = g_output.clone();
        let mut current_output = o_output.clone();

        vectorops::element_multiply(&mut self.memory, &f_output);
        vectorops::element_multiply(&mut current_state, &i_output);
        vectorops::element_add(&mut self.memory, &current_state);
        vectorops::element_multiply(
            &mut current_output,
            &vectorops::element_activate(&self.memory, Activation::Tahn),
        );

        // The forget gate's tracer doubles as the "is tracing enabled" switch for the layer.
        if self.f_gate.trace_states.is_some() {
            self.states
                .update_forward(f_output, i_output, g_output, o_output, self.memory.clone());
        }

        self.hidden = current_output;
        self.v_gate.forward(&self.hidden)
    }

    /// Backpropagates `errors` through the most recently recorded step.
    ///
    /// On the first call after construction or `reset` (both carried gradients unset) the
    /// carried hidden and memory gradients are seeded with zeros before delegating to
    /// `step_back`. Returns whatever `step_back` returns.
    #[inline]
    fn backward(&mut self, errors: &Vec<f32>, learning_rate: f32) -> Option<Vec<f32>> {
        if self.states.d_prev_hidden.is_none() && self.states.d_prev_memory.is_none() {
            self.states.d_prev_memory = Some(vec![0.0; self.memory_size as usize]);
            self.states.d_prev_hidden = Some(vec![0.0; self.memory_size as usize]);
        }
        self.step_back(errors, learning_rate)
    }

    /// Resets every gate, discards the recorded states and zeroes memory and hidden vectors.
    fn reset(&mut self) {
        self.g_gate.reset();
        self.i_gate.reset();
        self.f_gate.reset();
        self.o_gate.reset();
        self.v_gate.reset();
        self.states = LSTMState::new();
        self.memory = vec![0.0; self.memory_size as usize];
        self.hidden = vec![0.0; self.memory_size as usize];
    }

    /// Enables tracing on all five sub-layers.
    fn add_tracer(&mut self) {
        self.g_gate.add_tracer();
        self.i_gate.add_tracer();
        self.f_gate.add_tracer();
        self.o_gate.add_tracer();
        self.v_gate.add_tracer();
    }

    /// Disables tracing on all five sub-layers.
    fn remove_tracer(&mut self) {
        self.g_gate.remove_tracer();
        self.i_gate.remove_tracer();
        self.f_gate.remove_tracer();
        self.o_gate.remove_tracer();
        self.v_gate.remove_tracer();
    }

    /// Exposes the layer as `&dyn Any` so callers can downcast it.
    fn as_ref_any(&self) -> &dyn Any
        where Self: Sized + 'static
    {
        self
    }

    /// Exposes the layer as `&mut dyn Any` so callers can downcast it.
    fn as_mut_any(&mut self) -> &mut dyn Any
        where Self: Sized + 'static
    {
        self
    }

    /// Returns `(input_size, output_size)`.
    fn shape(&self) -> (usize, usize) {
        (self.input_size as usize, self.output_size as usize)
    }
}
impl Clone for LSTM {
    /// Copies the dimensions and the five sub-layers, but not the runtime state.
    ///
    /// The clone starts with zeroed memory and hidden vectors and an empty `LSTMState`,
    /// regardless of what the source layer currently holds.
    #[inline]
    fn clone(&self) -> Self {
        let cell_count = self.memory_size as usize;
        let zeroed_cells = vec![0.0; cell_count];

        Self {
            input_size: self.input_size,
            memory_size: self.memory_size,
            output_size: self.output_size,
            memory: zeroed_cells.clone(),
            hidden: zeroed_cells,
            states: LSTMState::new(),
            g_gate: self.g_gate.clone(),
            i_gate: self.i_gate.clone(),
            f_gate: self.f_gate.clone(),
            o_gate: self.o_gate.clone(),
            v_gate: self.v_gate.clone(),
        }
    }
}
impl Genome<LSTM, NeatEnvironment> for LSTM
where LSTM: Layer
{
#[inline]
fn crossover(child: &LSTM, parent_two: &LSTM, env: &Arc<RwLock<NeatEnvironment>>, crossover_rate: f32) -> Option<LSTM> {
let child = LSTM {
input_size: child.input_size,
memory_size: child.memory_size,
output_size: child.output_size,
memory: vec![0.0; child.memory_size as usize],
hidden: vec![0.0; child.memory_size as usize],
states: LSTMState::new(),
g_gate: Dense::crossover(&child.g_gate, &parent_two.g_gate, env, crossover_rate)?,
i_gate: Dense::crossover(&child.i_gate, &parent_two.i_gate, env, crossover_rate)?,
f_gate: Dense::crossover(&child.f_gate, &parent_two.f_gate, env, crossover_rate)?,
o_gate: Dense::crossover(&child.o_gate, &parent_two.o_gate, env, crossover_rate)?,
v_gate: Dense::crossover(&child.v_gate, &parent_two.v_gate, env, crossover_rate)?
};
Some(child)
}
#[inline]
fn distance(one: &LSTM, two: &LSTM, env: &Arc<RwLock<NeatEnvironment>>) -> f32 {
let mut result = 0.0;
result += Dense::distance(&one.g_gate, &two.g_gate, env);
result += Dense::distance(&one.i_gate, &two.i_gate, env);
result += Dense::distance(&one.f_gate, &two.f_gate, env);
result += Dense::distance(&one.o_gate, &two.o_gate, env);
result += Dense::distance(&one.v_gate, &two.v_gate, env);
result
}
}
// SAFETY(review): these impls promise the compiler that an `LSTM` may be moved to another
// thread (`Send`) and shared by reference between threads (`Sync`). Nothing in this file
// establishes that promise; it presumably rests on how `Dense` stores its internals.
// TODO confirm against `Dense` - if `Dense` is already `Send + Sync`, these impls are redundant.
unsafe impl Send for LSTM {}
unsafe impl Sync for LSTM {}
impl fmt::Display for LSTM {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
unsafe {
let address: u64 = mem::transmute(self);
write!(f, "LSTM=[{}]", address)
}
}
}