extern crate rand;
use std::fmt;
use std::any::Any;
use std::sync::{Arc, RwLock};
use std::thread;
use super::{
layertype::LayerType,
layer::Layer,
dense::Dense,
vectorops
};
use super::super::{
activation::Activation,
neatenv::NeatEnvironment,
};
use crate::Genome;
/// Per-step history recorded by the traced forward pass (`LSTM::step_forward_async`)
/// and consumed in reverse (stack order) by `LSTM::step_back`, plus the
/// gradients carried from one backward step to the next.
#[derive(Debug, Serialize, Deserialize)]
pub struct LSTMState {
    /// Forget-gate output of each recorded step.
    pub f_gate_output: Vec<Vec<f32>>,
    /// Input-gate output of each recorded step.
    pub i_gate_output: Vec<Vec<f32>>,
    /// Candidate (`g_gate`) output of each recorded step.
    pub s_gate_output: Vec<Vec<f32>>,
    /// Output-gate output of each recorded step.
    pub o_gate_output: Vec<Vec<f32>>,
    /// Memory (cell state) as it stood after each recorded step's update.
    pub memory_states: Vec<Vec<f32>>,
    /// Memory gradient handed back by the previous backward step;
    /// `None` until the first backward step initialises it.
    pub d_prev_memory: Option<Vec<f32>>,
    /// Hidden-state gradient handed back by the previous backward step;
    /// `None` until the first backward step initialises it.
    pub d_prev_hidden: Option<Vec<f32>>
}
impl LSTMState {
    /// Create an empty history with no carried gradients.
    pub fn new() -> Self {
        LSTMState {
            f_gate_output: Vec::new(),
            i_gate_output: Vec::new(),
            s_gate_output: Vec::new(),
            o_gate_output: Vec::new(),
            memory_states: Vec::new(),
            d_prev_memory: None,
            d_prev_hidden: None,
        }
    }

    /// Record one forward step.
    ///
    /// * `fg` - forget-gate output
    /// * `ig` - input-gate output
    /// * `sg` - candidate (`g`) gate output
    /// * `og` - output-gate output
    /// * `mem_state` - memory after this step's update
    ///
    /// Each value is pushed onto its own stack so the backward pass can pop
    /// the steps in reverse order.
    pub fn update_forward(&mut self, fg: Vec<f32>, ig: Vec<f32>, sg: Vec<f32>, og: Vec<f32>, mem_state: Vec<f32>) {
        self.f_gate_output.push(fg);
        self.i_gate_output.push(ig);
        self.s_gate_output.push(sg);
        self.o_gate_output.push(og);
        self.memory_states.push(mem_state);
    }
}

impl Default for LSTMState {
    /// An empty history; identical to [`LSTMState::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Long short-term memory layer assembled from five `Dense` sub-layers.
///
/// The four gates (`g`, `i`, `f`, `o`) each read the concatenation
/// `[hidden, input]`; `v_gate` maps the hidden state to the layer's output.
/// Every gate sits behind `Arc<RwLock<..>>` so the forward and backward
/// passes can drive the gates from separate threads.
#[derive(Debug, Serialize, Deserialize)]
pub struct LSTM {
    /// Number of values in each input vector.
    pub input_size: u32,
    /// Number of values in the memory (cell state) and hidden vectors.
    pub memory_size: u32,
    /// Number of values produced by `v_gate`.
    pub output_size: u32,
    /// Activation used by the output projection `v_gate`.
    pub activation: Activation,
    /// Memory (cell state) carried from one step to the next.
    pub memory: Vec<f32>,
    /// Hidden state carried from one step to the next; prepended to the next input.
    pub hidden: Vec<f32>,
    /// History recorded by the traced forward pass, consumed by backprop.
    pub states: LSTMState,
    /// Candidate gate (`Tanh`).
    pub g_gate: Arc<RwLock<Dense>>,
    /// Input gate (`Sigmoid`).
    pub i_gate: Arc<RwLock<Dense>>,
    /// Forget gate (`Sigmoid`).
    pub f_gate: Arc<RwLock<Dense>>,
    /// Output gate (`Sigmoid`).
    pub o_gate: Arc<RwLock<Dense>>,
    /// Output projection from the hidden state to `output_size` values.
    pub v_gate: Arc<RwLock<Dense>>
}
impl LSTM {
pub fn new(input_size: u32, memory_size: u32, output_size: u32, activation: Activation) -> Self {
let cell_input = input_size + memory_size;
LSTM {
input_size,
memory_size,
output_size,
activation,
memory: vec![0.0; memory_size as usize],
hidden: vec![0.0; memory_size as usize],
states: LSTMState::new(),
g_gate: Arc::new(RwLock::new(Dense::new(cell_input, memory_size, LayerType::DensePool, Activation::Tanh))),
i_gate: Arc::new(RwLock::new(Dense::new(cell_input, memory_size, LayerType::DensePool, Activation::Sigmoid))),
f_gate: Arc::new(RwLock::new(Dense::new(cell_input, memory_size, LayerType::DensePool, Activation::Sigmoid))),
o_gate: Arc::new(RwLock::new(Dense::new(cell_input, memory_size, LayerType::DensePool, Activation::Sigmoid))),
v_gate: Arc::new(RwLock::new(Dense::new(memory_size, output_size, LayerType::DensePool, activation)))
}
}
#[inline]
pub fn step_forward_async(&mut self, inputs: &[f32]) -> Option<Vec<f32>> {
let mut hidden_input = self.hidden.clone();
hidden_input.extend(inputs);
let g_gate_clone = Arc::clone(&self.g_gate);
let o_gate_clone = Arc::clone(&self.o_gate);
let f_gate_clone = Arc::clone(&self.f_gate);
let i_gate_clone = Arc::clone(&self.i_gate);
let hidden_async = Arc::new(hidden_input);
let g_input = Arc::clone(&hidden_async);
let o_input = Arc::clone(&hidden_async);
let f_input = Arc::clone(&hidden_async);
let i_input = Arc::clone(&hidden_async);
let g_output = thread::spawn(move || { return g_gate_clone.write().unwrap().forward(&*g_input).unwrap(); });
let o_output = thread::spawn(move || { return o_gate_clone.write().unwrap().forward(&*o_input).unwrap(); });
let f_output = thread::spawn(move || { return f_gate_clone.write().unwrap().forward(&*f_input).unwrap(); });
let i_output = thread::spawn(move || { return i_gate_clone.write().unwrap().forward(&*i_input).unwrap(); });
let mut curr_state = g_output.join().ok()?;
let mut curr_output = o_output.join().ok()?;
let f_curr = f_output.join().ok()?;
let i_curr = i_output.join().ok()?;
let g_out = curr_state.clone();
let o_out = curr_output.clone();
vectorops::element_multiply(&mut self.memory, &f_curr);
vectorops::element_multiply(&mut curr_state, &i_curr);
vectorops::element_add(&mut self.memory, &curr_state);
vectorops::element_multiply(&mut curr_output, &vectorops::element_activate(&self.memory, Activation::Tanh));
self.states.update_forward(f_curr, i_curr, g_out, o_out, self.memory.clone());
self.hidden = curr_output;
self.v_gate.write().unwrap().forward(&self.hidden)
}
#[inline]
pub fn step_forward(&mut self, inputs: &[f32]) -> Option<Vec<f32>> {
let mut hidden_input = self.hidden.clone();
hidden_input.extend(inputs);
let f_output = self.f_gate.write().unwrap().forward(&hidden_input)?;
let i_output = self.i_gate.write().unwrap().forward(&hidden_input)?;
let o_output = self.o_gate.write().unwrap().forward(&hidden_input)?;
let g_output = self.g_gate.write().unwrap().forward(&hidden_input)?;
let mut current_state = g_output.clone();
let mut current_output = o_output.clone();
vectorops::element_multiply(&mut self.memory, &f_output);
vectorops::element_multiply(&mut current_state, &i_output);
vectorops::element_add(&mut self.memory, ¤t_state);
vectorops::element_multiply(&mut current_output, &vectorops::element_activate(&self.memory, Activation::Tanh));
self.hidden = current_output;
self.v_gate.write().unwrap().forward(&self.hidden)
}
#[inline]
pub fn step_back(&mut self, errors: &Vec<f32>, l_rate: f32) -> Option<Vec<f32>> {
let dh_next = self.states.d_prev_hidden.clone()?;
let dc_next = self.states.d_prev_memory.clone()?;
let c_old = self.states.memory_states.pop()?;
let g_curr = self.states.s_gate_output.pop()?;
let i_curr = self.states.i_gate_output.pop()?;
let f_curr = self.states.f_gate_output.pop()?;
let o_curr = self.states.o_gate_output.pop()?;
let mut dh = self.v_gate.write().unwrap().backward(errors, l_rate)?;
vectorops::element_add(&mut dh, &dh_next);
let mut dho = vectorops::element_activate(&c_old, Activation::Tanh);
vectorops::element_multiply(&mut dho, &dh);
vectorops::element_multiply(&mut dho, &vectorops::element_deactivate(&o_curr, self.o_gate.read().unwrap().activation));
let o_gate_clone = Arc::clone(&self.o_gate);
let o_handle = thread::spawn(move || {
return o_gate_clone.write().unwrap().backward(&dho, l_rate).unwrap();
});
let mut dc = vectorops::product(&o_curr, &dh);
vectorops::element_multiply(&mut dc, &vectorops::element_deactivate(&c_old, Activation::Tanh));
vectorops::element_add(&mut dc, &dc_next);
let mut dhf = vectorops::product(&c_old, &dc);
vectorops::element_multiply(&mut dhf, &vectorops::element_deactivate(&f_curr, self.f_gate.read().unwrap().activation));
let f_gate_clone = Arc::clone(&self.f_gate);
let f_handle = thread::spawn(move || {
return f_gate_clone.write().unwrap().backward(&dhf, l_rate).unwrap();
});
let mut dhi = vectorops::product(&g_curr, &dc);
vectorops::element_multiply(&mut dhi, &vectorops::element_deactivate(&i_curr, self.i_gate.read().unwrap().activation));
let i_gate_clone = Arc::clone(&self.i_gate);
let i_handle = thread::spawn(move || {
return i_gate_clone.write().unwrap().backward(&dhi, l_rate).unwrap();
});
let mut dhc = vectorops::product(&i_curr, &dc);
vectorops::element_multiply(&mut dhc, &vectorops::element_deactivate(&g_curr, self.g_gate.read().unwrap().activation));
let g_gate_clone = Arc::clone(&self.g_gate);
let g_handle = thread::spawn(move || {
return g_gate_clone.write().unwrap().backward(&dhc, l_rate).unwrap();
});
let mut dx = vec![0.0; (self.input_size + self.memory_size) as usize];
vectorops::element_add(&mut dx, &o_handle.join().ok()?);
vectorops::element_add(&mut dx, &f_handle.join().ok()?);
vectorops::element_add(&mut dx, &i_handle.join().ok()?);
vectorops::element_add(&mut dx, &g_handle.join().ok()?);
let dh_next = dx[..self.memory_size as usize].to_vec();
let dc_next = vectorops::product(&f_curr, &dc);
self.states.d_prev_hidden = Some(dh_next);
self.states.d_prev_memory = Some(dc_next);
Some(dx[..self.input_size as usize].to_vec())
}
}
#[typetag::serde]
impl Layer for LSTM {
    /// Run one timestep.
    ///
    /// If `f_gate` carries a tracer, the threaded, state-recording path is
    /// used so `backward` can unroll the step. Otherwise the plain sequential
    /// path runs. Returns `None` if the `f_gate` lock is poisoned.
    #[inline]
    fn forward(&mut self, inputs: &Vec<f32>) -> Option<Vec<f32>> {
        let tracing = self.f_gate.read().map(|gate| gate.trace_states.is_some()).ok()?;
        if tracing {
            self.step_forward_async(inputs)
        } else {
            self.step_forward(inputs)
        }
    }

    /// Backpropagate one timestep.
    ///
    /// The carried gradients are seeded with zeros when neither has been set
    /// yet.
    #[inline]
    fn backward(&mut self, errors: &Vec<f32>, learning_rate: f32) -> Option<Vec<f32>> {
        let needs_seed = self.states.d_prev_hidden.is_none() && self.states.d_prev_memory.is_none();
        if needs_seed {
            let zeros: Vec<f32> = vec![0.0; self.memory_size as usize];
            self.states.d_prev_memory = Some(zeros.clone());
            self.states.d_prev_hidden = Some(zeros);
        }
        self.step_back(errors, learning_rate)
    }

    /// Reset every gate, then clear the recorded history and zero the
    /// memory and hidden state.
    fn reset(&mut self) {
        for gate in &[&self.g_gate, &self.i_gate, &self.f_gate, &self.o_gate, &self.v_gate] {
            gate.write().unwrap().reset();
        }
        let width = self.memory_size as usize;
        self.states = LSTMState::new();
        self.memory = vec![0.0; width];
        self.hidden = vec![0.0; width];
    }

    /// Attach a tracer to every gate, which turns on the recording forward path.
    fn add_tracer(&mut self) {
        for gate in &[&self.g_gate, &self.i_gate, &self.f_gate, &self.o_gate, &self.v_gate] {
            gate.write().unwrap().add_tracer();
        }
    }

    /// Detach the tracer from every gate.
    fn remove_tracer(&mut self) {
        for gate in &[&self.g_gate, &self.i_gate, &self.f_gate, &self.o_gate, &self.v_gate] {
            gate.write().unwrap().remove_tracer();
        }
    }

    fn as_ref_any(&self) -> &dyn Any
        where Self: Sized + 'static
    {
        self
    }

    fn as_mut_any(&mut self) -> &mut dyn Any
        where Self: Sized + 'static
    {
        self
    }

    /// `(input_size, output_size)` of the layer.
    fn shape(&self) -> (usize, usize) {
        let inputs = self.input_size as usize;
        let outputs = self.output_size as usize;
        (inputs, outputs)
    }
}
impl Clone for LSTM {
    /// Deep-copy the layer's configuration and gate weights into fresh locks.
    ///
    /// Runtime state is not copied: the clone starts with zeroed memory and
    /// hidden state and an empty history.
    #[inline]
    fn clone(&self) -> Self {
        // Copy the value behind a shared lock into a new, independent lock.
        fn detach<T: Clone>(shared: &Arc<RwLock<T>>) -> Arc<RwLock<T>> {
            let copy = (*shared.read().unwrap()).clone();
            Arc::new(RwLock::new(copy))
        }

        let zeros: Vec<f32> = vec![0.0; self.memory_size as usize];
        LSTM {
            input_size: self.input_size,
            memory_size: self.memory_size,
            output_size: self.output_size,
            activation: self.activation.clone(),
            memory: zeros.clone(),
            hidden: zeros,
            states: LSTMState::new(),
            g_gate: detach(&self.g_gate),
            i_gate: detach(&self.i_gate),
            f_gate: detach(&self.f_gate),
            o_gate: detach(&self.o_gate),
            v_gate: detach(&self.v_gate),
        }
    }
}
impl Genome<LSTM, NeatEnvironment> for LSTM
    where LSTM: Layer
{
    /// Breed a new layer from `child` (first parent) and `parent_two`.
    ///
    /// Dimensions and activation are taken from `child`. Each gate is bred
    /// independently with `Dense::crossover`, in the order g, i, f, o, v. The
    /// offspring starts with zeroed memory and hidden state and an empty
    /// history. Returns `None` as soon as any gate crossover does.
    #[inline]
    fn crossover(child: &LSTM, parent_two: &LSTM, env: Arc<RwLock<NeatEnvironment>>, crossover_rate: f32) -> Option<LSTM> {
        let g_bred = Dense::crossover(&child.g_gate.read().unwrap(), &parent_two.g_gate.read().unwrap(), Arc::clone(&env), crossover_rate)?;
        let i_bred = Dense::crossover(&child.i_gate.read().unwrap(), &parent_two.i_gate.read().unwrap(), Arc::clone(&env), crossover_rate)?;
        let f_bred = Dense::crossover(&child.f_gate.read().unwrap(), &parent_two.f_gate.read().unwrap(), Arc::clone(&env), crossover_rate)?;
        let o_bred = Dense::crossover(&child.o_gate.read().unwrap(), &parent_two.o_gate.read().unwrap(), Arc::clone(&env), crossover_rate)?;
        let v_bred = Dense::crossover(&child.v_gate.read().unwrap(), &parent_two.v_gate.read().unwrap(), Arc::clone(&env), crossover_rate)?;

        let width = child.memory_size as usize;
        let offspring = LSTM {
            input_size: child.input_size,
            memory_size: child.memory_size,
            output_size: child.output_size,
            activation: child.activation,
            memory: vec![0.0; width],
            hidden: vec![0.0; width],
            states: LSTMState::new(),
            g_gate: Arc::new(RwLock::new(g_bred)),
            i_gate: Arc::new(RwLock::new(i_bred)),
            f_gate: Arc::new(RwLock::new(f_bred)),
            o_gate: Arc::new(RwLock::new(o_bred)),
            v_gate: Arc::new(RwLock::new(v_bred)),
        };
        Some(offspring)
    }

    /// Genetic distance between two layers: the sum of the per-gate `Dense`
    /// distances, accumulated in the order g, i, f, o, v.
    #[inline]
    fn distance(one: &LSTM, two: &LSTM, env: Arc<RwLock<NeatEnvironment>>) -> f32 {
        let gate_pairs = [
            (&one.g_gate, &two.g_gate),
            (&one.i_gate, &two.i_gate),
            (&one.f_gate, &two.f_gate),
            (&one.o_gate, &two.o_gate),
            (&one.v_gate, &two.v_gate),
        ];
        gate_pairs.iter().fold(0.0, |total, (left, right)| {
            total + Dense::distance(&left.read().unwrap(), &right.read().unwrap(), Arc::clone(&env))
        })
    }
}
impl fmt::Display for LSTM {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "LSTM=[input={}, memory={}, output={}]",
self.input_size, self.memory_size, self.output_size)
}
}