use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::layer_weight::{LSTMGateWeight, LSTMLayerWeight, LayerWeight};
use crate::neural_network::layer::recurrent_layer::apply_sigmoid;
use crate::neural_network::layer::recurrent_layer::gate::{
Gate, compute_gate_value, store_gate_gradients, take_cache, update_gate_ada_grad,
update_gate_adam, update_gate_rmsprop, update_gate_sgd,
};
use crate::neural_network::layer::recurrent_layer::input_validation_function::{
validate_input_3d, validate_recurrent_dimensions,
};
use crate::neural_network::neural_network_trait::{ActivationLayer, Layer};
use ndarray::{Array2, Array3, Axis, Ix2, Ix3};
/// Minimum `batch * units` size at which the four gate computations are
/// dispatched through `rayon::join` instead of running sequentially; below
/// this the fork/join overhead outweighs the parallel speedup.
const LSTM_PARALLEL_THRESHOLD: usize = 1024;
/// Long Short-Term Memory recurrent layer.
///
/// Holds the four gates (input, forget, cell candidate, output) plus the
/// per-timestep caches that `forward` records and `backward` consumes.
/// The layer emits only the activated final hidden state (many-to-one).
pub struct LSTM<T: ActivationLayer> {
    // Feature size of each input timestep.
    input_dim: usize,
    // Number of hidden units (output width).
    units: usize,
    // Gate producing i_t: how much of the candidate enters the cell state.
    input_gate: Gate,
    // Gate producing f_t: how much of the previous cell state is kept.
    forget_gate: Gate,
    // Gate producing the tanh cell candidate g_t.
    cell_gate: Gate,
    // Gate producing o_t: how much of tanh(c_t) leaks into h_t.
    output_gate: Gate,
    // Input batch (batch, timesteps, features) cached for backprop.
    input_cache: Option<Array3<f32>>,
    // Hidden states h_0..h_T (T+1 entries, index 0 is the zero initial state).
    hidden_cache: Option<Vec<Array2<f32>>>,
    // Cell states c_0..c_T (T+1 entries, index 0 is the zero initial state).
    cell_cache: Option<Vec<Array2<f32>>>,
    // tanh(c_t) per timestep (T entries).
    cell_activated_cache: Option<Vec<Array2<f32>>>,
    // Activated gate values per timestep (T entries each).
    i_cache: Option<Vec<Array2<f32>>>,
    f_cache: Option<Vec<Array2<f32>>>,
    g_cache: Option<Vec<Array2<f32>>>,
    o_cache: Option<Vec<Array2<f32>>>,
    // Activation applied to the final hidden state before it leaves the layer.
    activation: T,
}
impl<T: ActivationLayer> LSTM<T> {
pub fn new(input_dim: usize, units: usize, activation: T) -> Result<Self, ModelError> {
validate_recurrent_dimensions(input_dim, units)?;
Ok(Self {
input_dim,
units,
input_gate: Gate::new(input_dim, units, 0.0)?,
forget_gate: Gate::new(input_dim, units, 1.0)?, cell_gate: Gate::new(input_dim, units, 0.0)?,
output_gate: Gate::new(input_dim, units, 0.0)?,
input_cache: None,
hidden_cache: None,
cell_cache: None,
cell_activated_cache: None,
i_cache: None,
f_cache: None,
g_cache: None,
o_cache: None,
activation,
})
}
pub fn set_weights(
&mut self,
input_kernel: Array2<f32>,
input_recurrent_kernel: Array2<f32>,
input_bias: Array2<f32>,
forget_kernel: Array2<f32>,
forget_recurrent_kernel: Array2<f32>,
forget_bias: Array2<f32>,
cell_kernel: Array2<f32>,
cell_recurrent_kernel: Array2<f32>,
cell_bias: Array2<f32>,
output_kernel: Array2<f32>,
output_recurrent_kernel: Array2<f32>,
output_bias: Array2<f32>,
) {
self.input_gate.kernel = input_kernel;
self.input_gate.recurrent_kernel = input_recurrent_kernel;
self.input_gate.bias = input_bias;
self.forget_gate.kernel = forget_kernel;
self.forget_gate.recurrent_kernel = forget_recurrent_kernel;
self.forget_gate.bias = forget_bias;
self.cell_gate.kernel = cell_kernel;
self.cell_gate.recurrent_kernel = cell_recurrent_kernel;
self.cell_gate.bias = cell_bias;
self.output_gate.kernel = output_kernel;
self.output_gate.recurrent_kernel = output_recurrent_kernel;
self.output_gate.bias = output_bias;
}
}
impl<T: ActivationLayer> Layer for LSTM<T> {
    /// Forward pass over a 3-D input, presumed `(batch, timesteps, features)`.
    ///
    /// Runs the standard LSTM recurrence
    /// `c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t`, `h_t = o_t ⊙ tanh(c_t)` and returns
    /// the layer activation applied to the *final* hidden state only
    /// (many-to-one output of shape `(batch, units)`). Every per-timestep
    /// intermediate is cached on `self` for the backward pass.
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        validate_input_3d(input)?;
        // unwrap is safe assuming validate_input_3d rejects non-3-D input — TODO confirm.
        let x3 = input.view().into_dimensionality::<Ix3>().unwrap();
        let (batch, timesteps, _) = (x3.shape()[0], x3.shape()[1], x3.shape()[2]);
        self.input_cache = Some(x3.to_owned());
        // Initial hidden and cell states are zero.
        let mut h_prev = Array2::<f32>::zeros((batch, self.units));
        let mut c_prev = Array2::<f32>::zeros((batch, self.units));
        // hs/cs hold T+1 entries (zero state at index 0) so that in backward
        // hs[t]/cs[t] are the states *entering* timestep t.
        let mut hs = Vec::with_capacity(timesteps + 1);
        let mut cs = Vec::with_capacity(timesteps + 1);
        let mut cs_activated = Vec::with_capacity(timesteps);
        let mut i_vals = Vec::with_capacity(timesteps);
        let mut f_vals = Vec::with_capacity(timesteps);
        let mut g_vals = Vec::with_capacity(timesteps);
        let mut o_vals = Vec::with_capacity(timesteps);
        hs.push(h_prev.clone());
        cs.push(c_prev.clone());
        // Fork the four independent gate computations only when the batch is
        // large enough to amortise rayon's scheduling overhead.
        let use_parallel = batch * self.units >= LSTM_PARALLEL_THRESHOLD;
        for t in 0..timesteps {
            let x_t = x3.index_axis(Axis(1), t).to_owned();
            // Pre-activation gate values (kernel·x_t + recurrent·h_prev + bias,
            // per compute_gate_value).
            let (i_raw, f_raw, g_raw, o_raw) = if use_parallel {
                // Nested joins run all four gates concurrently (2x2 fork).
                let ((i_raw, f_raw), (g_raw, o_raw)) = rayon::join(
                    || {
                        rayon::join(
                            || compute_gate_value(&self.input_gate, &x_t, &h_prev),
                            || compute_gate_value(&self.forget_gate, &x_t, &h_prev),
                        )
                    },
                    || {
                        rayon::join(
                            || compute_gate_value(&self.cell_gate, &x_t, &h_prev),
                            || compute_gate_value(&self.output_gate, &x_t, &h_prev),
                        )
                    },
                );
                (i_raw, f_raw, g_raw, o_raw)
            } else {
                (
                    compute_gate_value(&self.input_gate, &x_t, &h_prev),
                    compute_gate_value(&self.forget_gate, &x_t, &h_prev),
                    compute_gate_value(&self.cell_gate, &x_t, &h_prev),
                    compute_gate_value(&self.output_gate, &x_t, &h_prev),
                )
            };
            // Gate activations: sigmoid for i/f/o, tanh for the cell candidate g.
            let (i_t, f_t, g_t, o_t) = if use_parallel {
                let ((i_t, f_t), (g_t, o_t)) = rayon::join(
                    || rayon::join(|| apply_sigmoid(i_raw), || apply_sigmoid(f_raw)),
                    || {
                        rayon::join(
                            || {
                                g_raw.mapv(|x| {
                                    // Clamp before tanh — presumably a guard
                                    // against non-finite inputs; tanh itself
                                    // saturates long before ±500.
                                    let clipped_x = x.clamp(-500.0, 500.0);
                                    clipped_x.tanh()
                                })
                            },
                            || apply_sigmoid(o_raw),
                        )
                    },
                );
                (i_t, f_t, g_t, o_t)
            } else {
                (
                    apply_sigmoid(i_raw),
                    apply_sigmoid(f_raw),
                    g_raw.mapv(|x| {
                        let clipped_x = x.clamp(-500.0, 500.0);
                        clipped_x.tanh()
                    }),
                    apply_sigmoid(o_raw),
                )
            };
            // Cell update: c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t.
            let c_t = &f_t * &c_prev + &i_t * &g_t;
            let c_t_activated = c_t.mapv(|x| {
                let clipped_x = x.clamp(-500.0, 500.0);
                clipped_x.tanh()
            });
            // Hidden state: h_t = o_t ⊙ tanh(c_t).
            let h_t = &o_t * &c_t_activated;
            i_vals.push(i_t);
            f_vals.push(f_t);
            g_vals.push(g_t);
            o_vals.push(o_t);
            cs.push(c_t.clone());
            cs_activated.push(c_t_activated);
            hs.push(h_t.clone());
            h_prev = h_t;
            c_prev = c_t;
        }
        self.hidden_cache = Some(hs);
        self.cell_cache = Some(cs);
        self.cell_activated_cache = Some(cs_activated);
        self.i_cache = Some(i_vals);
        self.f_cache = Some(f_vals);
        self.g_cache = Some(g_vals);
        self.o_cache = Some(o_vals);
        // Only the final hidden state is emitted (many-to-one layer).
        self.activation.forward(&h_prev.into_dyn())
    }
    /// Backpropagation through time (BPTT).
    ///
    /// `grad_output` is the gradient w.r.t. the activated final hidden state.
    /// Weight gradients are accumulated over all timesteps and stored on each
    /// gate via `store_gate_gradients`; the return value is the gradient
    /// w.r.t. the 3-D input.
    ///
    /// The caches written by `forward` are *taken* (moved out), so `forward`
    /// must run again before another call to `backward`.
    ///
    /// # Errors
    /// Fails when any cache is empty ("Forward pass has not been run") or when
    /// the activation's backward pass fails.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        let grad_upstream = self.activation.backward(grad_output)?;
        // NOTE(review): unwrap assumes the upstream gradient is 2-D
        // (batch, units) — confirm against callers.
        let grad_h_t = grad_upstream.into_dimensionality::<Ix2>().unwrap();
        let error_msg = "Forward pass has not been run";
        let x3 = take_cache(&mut self.input_cache, error_msg)?;
        let hs = take_cache(&mut self.hidden_cache, error_msg)?;
        let cs = take_cache(&mut self.cell_cache, error_msg)?;
        let cs_activated = take_cache(&mut self.cell_activated_cache, error_msg)?;
        let i_vals = take_cache(&mut self.i_cache, error_msg)?;
        let f_vals = take_cache(&mut self.f_cache, error_msg)?;
        let g_vals = take_cache(&mut self.g_cache, error_msg)?;
        let o_vals = take_cache(&mut self.o_cache, error_msg)?;
        let batch = x3.shape()[0];
        let timesteps = x3.shape()[1];
        let feat = x3.shape()[2];
        // Per-gate gradient accumulators, summed over all timesteps.
        let mut grad_i_kernel = Array2::<f32>::zeros((self.input_dim, self.units));
        let mut grad_i_recurrent = Array2::<f32>::zeros((self.units, self.units));
        let mut grad_i_bias = Array2::<f32>::zeros((1, self.units));
        let mut grad_f_kernel = Array2::<f32>::zeros((self.input_dim, self.units));
        let mut grad_f_recurrent = Array2::<f32>::zeros((self.units, self.units));
        let mut grad_f_bias = Array2::<f32>::zeros((1, self.units));
        let mut grad_g_kernel = Array2::<f32>::zeros((self.input_dim, self.units));
        let mut grad_g_recurrent = Array2::<f32>::zeros((self.units, self.units));
        let mut grad_g_bias = Array2::<f32>::zeros((1, self.units));
        let mut grad_o_kernel = Array2::<f32>::zeros((self.input_dim, self.units));
        let mut grad_o_recurrent = Array2::<f32>::zeros((self.units, self.units));
        let mut grad_o_bias = Array2::<f32>::zeros((1, self.units));
        let mut grad_x3 = Array3::<f32>::zeros((batch, timesteps, feat));
        // Only the last hidden state reaches the output, so grad_h starts as
        // the upstream gradient and is thereafter fed purely by the recurrent
        // connection; grad_c starts at zero and flows back through f_t.
        let mut grad_h = grad_h_t;
        let mut grad_c = Array2::<f32>::zeros((batch, self.units));
        let use_parallel = batch * self.units >= LSTM_PARALLEL_THRESHOLD;
        // Walk timesteps in reverse; hs[t]/cs[t] are the states entering step t.
        for t in (0..timesteps).rev() {
            let h_prev = &hs[t];
            let c_prev = &cs[t];
            let c_t_activated = &cs_activated[t];
            let i_t = &i_vals[t];
            let f_t = &f_vals[t];
            let g_t = &g_vals[t];
            let o_t = &o_vals[t];
            // dL/do_t = dL/dh_t ⊙ tanh(c_t)
            let grad_o_t = &grad_h * c_t_activated;
            // dL/dc_t += dL/dh_t ⊙ o_t ⊙ (1 - tanh(c_t)^2)
            grad_c = grad_c + &(&grad_h * o_t * &(1.0 - c_t_activated * c_t_activated));
            // Gradients w.r.t. the activated gate outputs.
            let grad_f_t = &grad_c * c_prev;
            let grad_i_t = &grad_c * g_t;
            let grad_g_t = &grad_c * i_t;
            // Cell gradient carried to the previous timestep through f_t.
            let grad_c_prev = &grad_c * f_t;
            // Gradients w.r.t. pre-activation values: sigmoid' = s(1-s) for
            // i/f/o, tanh' = 1 - g^2 for the candidate gate.
            let (grad_o_raw, grad_f_raw, grad_i_raw, grad_g_raw) = if use_parallel {
                let ((grad_o_raw, grad_f_raw), (grad_i_raw, grad_g_raw)) = rayon::join(
                    || {
                        rayon::join(
                            || &grad_o_t * o_t * &(1.0 - o_t),
                            || &grad_f_t * f_t * &(1.0 - f_t),
                        )
                    },
                    || {
                        rayon::join(
                            || &grad_i_t * i_t * &(1.0 - i_t),
                            || &grad_g_t * &(1.0 - g_t * g_t),
                        )
                    },
                );
                (grad_o_raw, grad_f_raw, grad_i_raw, grad_g_raw)
            } else {
                (
                    &grad_o_t * o_t * &(1.0 - o_t),
                    &grad_f_t * f_t * &(1.0 - f_t),
                    &grad_i_t * i_t * &(1.0 - i_t),
                    &grad_g_t * &(1.0 - g_t * g_t),
                )
            };
            let x_t = x3.index_axis(Axis(1), t).to_owned();
            let x_t_t = x_t.t();
            let h_prev_t = h_prev.t();
            // One step's contribution to a gate's (kernel, recurrent, bias)
            // gradients: x_tᵀ·grad, h_prevᵀ·grad, and the batch-summed grad.
            let compute_gate_gradients = |grad_raw: &Array2<f32>| {
                let kernel_update = x_t_t.dot(grad_raw);
                let recurrent_update = h_prev_t.dot(grad_raw);
                let bias_update = grad_raw.sum_axis(Axis(0)).insert_axis(Axis(0));
                (kernel_update, recurrent_update, bias_update)
            };
            let (o_updates, f_updates, i_updates, g_updates) = if use_parallel {
                let ((o_updates, f_updates), (i_updates, g_updates)) = rayon::join(
                    || {
                        rayon::join(
                            || compute_gate_gradients(&grad_o_raw),
                            || compute_gate_gradients(&grad_f_raw),
                        )
                    },
                    || {
                        rayon::join(
                            || compute_gate_gradients(&grad_i_raw),
                            || compute_gate_gradients(&grad_g_raw),
                        )
                    },
                );
                (o_updates, f_updates, i_updates, g_updates)
            } else {
                (
                    compute_gate_gradients(&grad_o_raw),
                    compute_gate_gradients(&grad_f_raw),
                    compute_gate_gradients(&grad_i_raw),
                    compute_gate_gradients(&grad_g_raw),
                )
            };
            // Accumulate (kernel, recurrent, bias) tuples into the totals.
            grad_o_kernel = grad_o_kernel + &o_updates.0;
            grad_o_recurrent = grad_o_recurrent + &o_updates.1;
            grad_o_bias = grad_o_bias + &o_updates.2;
            grad_f_kernel = grad_f_kernel + &f_updates.0;
            grad_f_recurrent = grad_f_recurrent + &f_updates.1;
            grad_f_bias = grad_f_bias + &f_updates.2;
            grad_i_kernel = grad_i_kernel + &i_updates.0;
            grad_i_recurrent = grad_i_recurrent + &i_updates.1;
            grad_i_bias = grad_i_bias + &i_updates.2;
            grad_g_kernel = grad_g_kernel + &g_updates.0;
            grad_g_recurrent = grad_g_recurrent + &g_updates.1;
            grad_g_bias = grad_g_bias + &g_updates.2;
            // Input gradient (via kernels) and h_{t-1} gradient (via recurrent
            // kernels): sum of the four gates' contributions.
            let (dx, grad_h_next) = if use_parallel {
                rayon::join(
                    || {
                        grad_o_raw.dot(&self.output_gate.kernel.t())
                            + grad_f_raw.dot(&self.forget_gate.kernel.t())
                            + grad_i_raw.dot(&self.input_gate.kernel.t())
                            + grad_g_raw.dot(&self.cell_gate.kernel.t())
                    },
                    || {
                        grad_o_raw.dot(&self.output_gate.recurrent_kernel.t())
                            + grad_f_raw.dot(&self.forget_gate.recurrent_kernel.t())
                            + grad_i_raw.dot(&self.input_gate.recurrent_kernel.t())
                            + grad_g_raw.dot(&self.cell_gate.recurrent_kernel.t())
                    },
                )
            } else {
                (
                    grad_o_raw.dot(&self.output_gate.kernel.t())
                        + grad_f_raw.dot(&self.forget_gate.kernel.t())
                        + grad_i_raw.dot(&self.input_gate.kernel.t())
                        + grad_g_raw.dot(&self.cell_gate.kernel.t()),
                    grad_o_raw.dot(&self.output_gate.recurrent_kernel.t())
                        + grad_f_raw.dot(&self.forget_gate.recurrent_kernel.t())
                        + grad_i_raw.dot(&self.input_gate.recurrent_kernel.t())
                        + grad_g_raw.dot(&self.cell_gate.recurrent_kernel.t()),
                )
            };
            grad_x3.index_axis_mut(Axis(1), t).assign(&dx);
            grad_h = grad_h_next;
            grad_c = grad_c_prev;
        }
        // Hand the accumulated gradients to the gates for the optimizer step.
        store_gate_gradients(
            &mut self.input_gate,
            grad_i_kernel,
            grad_i_recurrent,
            grad_i_bias,
        );
        store_gate_gradients(
            &mut self.forget_gate,
            grad_f_kernel,
            grad_f_recurrent,
            grad_f_bias,
        );
        store_gate_gradients(
            &mut self.cell_gate,
            grad_g_kernel,
            grad_g_recurrent,
            grad_g_bias,
        );
        store_gate_gradients(
            &mut self.output_gate,
            grad_o_kernel,
            grad_o_recurrent,
            grad_o_bias,
        );
        Ok(grad_x3.into_dyn())
    }
    /// Human-readable layer identifier used in model summaries.
    fn layer_type(&self) -> &str {
        "LSTM"
    }
    /// Output shape as a summary string; the batch dimension is dynamic.
    fn output_shape(&self) -> String {
        format!("(None, {})", self.units)
    }
    /// Trainable parameter count: 4 gates × (kernel + recurrent kernel + bias).
    fn param_count(&self) -> TrainingParameters {
        TrainingParameters::Trainable(
            4 * (self.input_dim * self.units + self.units * self.units + self.units),
        )
    }
    /// Plain SGD step applied independently to each gate's weights.
    fn update_parameters_sgd(&mut self, lr: f32) {
        update_gate_sgd(&mut self.input_gate, lr);
        update_gate_sgd(&mut self.forget_gate, lr);
        update_gate_sgd(&mut self.cell_gate, lr);
        update_gate_sgd(&mut self.output_gate, lr);
    }
    /// Adam step for each gate; `t` is the 1-based iteration count used for
    /// bias correction inside `update_gate_adam`.
    fn update_parameters_adam(&mut self, lr: f32, beta1: f32, beta2: f32, epsilon: f32, t: u64) {
        update_gate_adam(
            &mut self.input_gate,
            self.input_dim,
            self.units,
            lr,
            beta1,
            beta2,
            epsilon,
            t,
        );
        update_gate_adam(
            &mut self.forget_gate,
            self.input_dim,
            self.units,
            lr,
            beta1,
            beta2,
            epsilon,
            t,
        );
        update_gate_adam(
            &mut self.cell_gate,
            self.input_dim,
            self.units,
            lr,
            beta1,
            beta2,
            epsilon,
            t,
        );
        update_gate_adam(
            &mut self.output_gate,
            self.input_dim,
            self.units,
            lr,
            beta1,
            beta2,
            epsilon,
            t,
        );
    }
    /// RMSProp step for each gate.
    fn update_parameters_rmsprop(&mut self, lr: f32, rho: f32, epsilon: f32) {
        update_gate_rmsprop(
            &mut self.input_gate,
            self.input_dim,
            self.units,
            lr,
            rho,
            epsilon,
        );
        update_gate_rmsprop(
            &mut self.forget_gate,
            self.input_dim,
            self.units,
            lr,
            rho,
            epsilon,
        );
        update_gate_rmsprop(
            &mut self.cell_gate,
            self.input_dim,
            self.units,
            lr,
            rho,
            epsilon,
        );
        update_gate_rmsprop(
            &mut self.output_gate,
            self.input_dim,
            self.units,
            lr,
            rho,
            epsilon,
        );
    }
    /// AdaGrad step for each gate.
    fn update_parameters_ada_grad(&mut self, lr: f32, epsilon: f32) {
        update_gate_ada_grad(
            &mut self.input_gate,
            self.input_dim,
            self.units,
            lr,
            epsilon,
        );
        update_gate_ada_grad(
            &mut self.forget_gate,
            self.input_dim,
            self.units,
            lr,
            epsilon,
        );
        update_gate_ada_grad(&mut self.cell_gate, self.input_dim, self.units, lr, epsilon);
        update_gate_ada_grad(
            &mut self.output_gate,
            self.input_dim,
            self.units,
            lr,
            epsilon,
        );
    }
    /// Borrowed view of all gate weights, e.g. for serialization.
    fn get_weights(&self) -> LayerWeight<'_> {
        LayerWeight::LSTM(LSTMLayerWeight {
            input: LSTMGateWeight {
                kernel: &self.input_gate.kernel,
                recurrent_kernel: &self.input_gate.recurrent_kernel,
                bias: &self.input_gate.bias,
            },
            forget: LSTMGateWeight {
                kernel: &self.forget_gate.kernel,
                recurrent_kernel: &self.forget_gate.recurrent_kernel,
                bias: &self.forget_gate.bias,
            },
            cell: LSTMGateWeight {
                kernel: &self.cell_gate.kernel,
                recurrent_kernel: &self.cell_gate.recurrent_kernel,
                bias: &self.cell_gate.bias,
            },
            output: LSTMGateWeight {
                kernel: &self.output_gate.kernel,
                recurrent_kernel: &self.output_gate.recurrent_kernel,
                bias: &self.output_gate.bias,
            },
        })
    }
}