use std::rc::Rc;
use std::sync::Arc;
use ndarray;
use rand;
use nodes;
use nodes::{HogwildParameter, Node, ParameterNode};
use nn::uniform;
use {Arr, DataInput, Variable};
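/// Holds shared parameters for an LSTM cell.
///
/// Construct this once, then use the `build*` methods to create cells and
/// layers that share (or, via `clone`, copy) these parameters.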
#[derive(Debug, Serialize, Deserialize)]
pub struct Parameters {
input_dim: usize,
hidden_dim: usize,
forget_weights: Arc<nodes::HogwildParameter>,
forget_biases: Arc<nodes::HogwildParameter>,
update_gate_weights: Arc<nodes::HogwildParameter>,
update_gate_biases: Arc<nodes::HogwildParameter>,
update_value_weights: Arc<nodes::HogwildParameter>,
update_value_biases: Arc<nodes::HogwildParameter>,
output_gate_weights: Arc<nodes::HogwildParameter>,
output_gate_biases: Arc<nodes::HogwildParameter>,
}
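// Manual `Clone` impl: clones the underlying parameter values rather than
// the `Arc` handles, so the clone owns an independent copy of the weights.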
impl Clone for Parameters {
fn clone(&self) -> Self {
Parameters {
input_dim: self.input_dim,
hidden_dim: self.hidden_dim,
forget_weights: Arc::new(self.forget_weights.as_ref().clone()),
forget_biases: Arc::new(self.forget_biases.as_ref().clone()),
update_gate_weights: Arc::new(self.update_gate_weights.as_ref().clone()),
update_gate_biases: Arc::new(self.update_gate_biases.as_ref().clone()),
update_value_weights: Arc::new(self.update_value_weights.as_ref().clone()),
update_value_biases: Arc::new(self.update_value_biases.as_ref().clone()),
output_gate_weights: Arc::new(self.output_gate_weights.as_ref().clone()),
output_gate_biases: Arc::new(self.output_gate_biases.as_ref().clone()),
}
}
}
impl Parameters {
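/// Create a new LSTM parameter set with the given input and hidden
/// dimensions, initializing all weights uniformly in
/// `(-1 / sqrt(hidden_dim), 1 / sqrt(hidden_dim))`.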
pub fn new<R: rand::Rng>(input_dim: usize, hidden_dim: usize, rng: &mut R) -> Self {
let max = 1.0 / (hidden_dim as f32).sqrt();
let min = -max;
Self {
input_dim,
hidden_dim,
forget_weights: Arc::new(HogwildParameter::new(uniform(
input_dim + hidden_dim,
hidden_dim,
min,
max,
rng,
))),
forget_biases: Arc::new(HogwildParameter::new(uniform(1, hidden_dim, min, max, rng))),
update_gate_weights: Arc::new(HogwildParameter::new(uniform(
input_dim + hidden_dim,
hidden_dim,
min,
max,
rng,
))),
update_gate_biases: Arc::new(HogwildParameter::new(uniform(
1, hidden_dim, min, max, rng,
))),
update_value_weights: Arc::new(HogwildParameter::new(uniform(
input_dim + hidden_dim,
hidden_dim,
min,
max,
rng,
))),
update_value_biases: Arc::new(HogwildParameter::new(uniform(
1, hidden_dim, min, max, rng,
))),
output_gate_weights: Arc::new(HogwildParameter::new(uniform(
input_dim + hidden_dim,
hidden_dim,
min,
max,
rng,
))),
output_gate_biases: Arc::new(HogwildParameter::new(uniform(
1, hidden_dim, min, max, rng,
))),
}
}
fn inner_build_cell(&self, coupled: bool) -> Cell {
Cell {
input_dim: self.input_dim,
hidden_dim: self.hidden_dim,
coupled_input: coupled,
forget_weights: ParameterNode::shared(self.forget_weights.clone()),
forget_biases: ParameterNode::shared(self.forget_biases.clone()),
update_gate_weights: ParameterNode::shared(self.update_gate_weights.clone()),
update_gate_biases: ParameterNode::shared(self.update_gate_biases.clone()),
update_value_weights: ParameterNode::shared(self.update_value_weights.clone()),
update_value_biases: ParameterNode::shared(self.update_value_biases.clone()),
output_gate_weights: ParameterNode::shared(self.output_gate_weights.clone()),
output_gate_biases: ParameterNode::shared(self.output_gate_biases.clone()),
}
}
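/// Build a stateful LSTM layer from these parameters.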
pub fn build(&self) -> Layer {
Layer::new(self.build_cell())
}
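/// Build a stateful LSTM layer with coupled forget and input gates.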
pub fn build_coupled(&self) -> Layer {
Layer::new(self.build_coupled_cell())
}
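/// Build a single LSTM cell from these parameters.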
pub fn build_cell(&self) -> Cell {
self.inner_build_cell(false)
}
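/// Build a single LSTM cell whose input gate is coupled to the forget
/// gate, computed as `1 - forget_gate`.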
pub fn build_coupled_cell(&self) -> Cell {
self.inner_build_cell(true)
}
}
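/// A single LSTM cell. Applies one step of the recurrence, mapping a
/// (cell state, hidden state) pair and an input to a new pair of states.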
#[derive(Debug)]
pub struct Cell {
input_dim: usize,
hidden_dim: usize,
coupled_input: bool,
forget_weights: Variable<ParameterNode>,
forget_biases: Variable<ParameterNode>,
update_gate_weights: Variable<ParameterNode>,
update_gate_biases: Variable<ParameterNode>,
update_value_weights: Variable<ParameterNode>,
update_value_biases: Variable<ParameterNode>,
output_gate_weights: Variable<ParameterNode>,
output_gate_biases: Variable<ParameterNode>,
}
impl Cell {
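/// Run a single step of the recurrence, returning the new cell and
/// hidden states.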
#[cfg_attr(feature = "cargo-clippy", allow(needless_pass_by_value, type_complexity))]
pub fn forward<C, H, I>(
&self,
state: (Variable<C>, Variable<H>),
input: Variable<I>,
) -> (
Variable<Rc<Node<Value = Arr, InputGradient = Arr>>>,
Variable<Rc<Node<Value = Arr, InputGradient = Arr>>>,
)
where
C: Node<Value = Arr, InputGradient = Arr>,
H: Node<Value = Arr, InputGradient = Arr>,
I: Node<Value = Arr, InputGradient = Arr>,
{
let (cell, hidden) = state;
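// Concatenate the previous hidden state with the current input, then
// compute the forget gate from the concatenated vector.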
let stacked_input = hidden.stack(&input, ndarray::Axis(1));
let forget_gate =
(stacked_input.dot(&self.forget_weights) + self.forget_biases.clone()).sigmoid();
let cell = forget_gate.clone() * cell;
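// With coupled gates the input gate is tied to the forget gate as
// `1 - forget_gate`; otherwise it has its own weights and biases.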
let update_gate = if self.coupled_input {
(1.0 - forget_gate).boxed()
} else {
(stacked_input.dot(&self.update_gate_weights) + self.update_gate_biases.clone())
.sigmoid()
.boxed()
};
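// Candidate values used to update the cell state.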
let update_value = (stacked_input.dot(&self.update_value_weights)
+ self.update_value_biases.clone())
.tanh();
let update = update_gate * update_value;
let cell = cell + update;
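// Squash the new cell state and gate it to produce the new hidden state.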
let output_value = cell.tanh();
let output_gate = (stacked_input.dot(&self.output_gate_weights)
+ self.output_gate_biases.clone())
.sigmoid();
let hidden = output_gate * output_value;
(cell.boxed(), hidden.boxed())
}
}
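/// A stateful LSTM layer: a cell plus zero-initialized internal cell and
/// hidden state nodes.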
#[derive(Debug)]
pub struct Layer {
cell: Cell,
state: Variable<nodes::InputNode>,
hidden: Variable<nodes::InputNode>,
}
impl Layer {
fn new(cell: Cell) -> Self {
let hidden_dim = cell.hidden_dim;
Layer {
cell,
state: nodes::InputNode::new(Arr::zeros((1, hidden_dim))),
hidden: nodes::InputNode::new(Arr::zeros((1, hidden_dim))),
}
}
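/// Run the layer over a sequence of inputs, returning the hidden state
/// after every step.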
pub fn forward<T>(
&self,
inputs: &[Variable<T>],
) -> Vec<Variable<Rc<Node<Value = Arr, InputGradient = Arr>>>>
where
T: Node<Value = Arr, InputGradient = Arr>,
{
let mut state = (self.state.clone().boxed(), self.hidden.clone().boxed());
let outputs: Vec<_> = inputs
.iter()
.map(|input| {
state = self.cell.forward(state.clone(), input.clone());
state.1.clone()
})
.collect();
outputs
}
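/// Reset the internal cell and hidden states to zero.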
pub fn reset_state(&self) {
self.state.set_value(0.0);
self.hidden.set_value(0.0);
}
}
#[cfg(test)]
mod tests {
use std::ops::Deref;
use super::*;
use finite_difference;
use nn::losses::sparse_categorical_crossentropy;
use nn::xavier_normal;
use nodes::InputNode;
use optim::{Adam, Optimizer};
use DataInput;
const TOLERANCE: f32 = 0.2;
fn assert_close(x: &Arr, y: &Arr, tol: f32) {
assert!(
x.all_close(y, tol),
"{:#?} not within {} of {:#?}",
x,
tol,
y
);
}
fn pi_digits(num: usize) -> Vec<usize> {
let pi_str = include_str!("pi.txt");
pi_str
.chars()
.filter_map(|x| x.to_digit(10))
.map(|x| x as usize)
.take(num)
.collect()
}
#[test]
fn lstm_finite_difference() {
let num_steps = 10;
let dim = 10;
let mut xs: Vec<_> = (0..num_steps)
.map(|_| ParameterNode::new(xavier_normal(1, dim)))
.collect();
let lstm_params = Parameters::new(dim, dim, &mut rand::thread_rng());
let lstm = lstm_params.build();
let mut hidden_states = lstm.forward(&xs);
let mut hidden = hidden_states.last_mut().unwrap();
for x in &mut xs {
let (difference, gradient) = finite_difference(x, &mut hidden);
assert_close(&difference, &gradient, TOLERANCE);
}
let mut params = hidden.parameters().to_owned();
for x in params.iter_mut() {
let (difference, gradient) = finite_difference(x, hidden);
assert_close(&difference, &gradient, TOLERANCE);
}
}
#[test]
fn test_basic_lstm() {
let input_dim = 10;
let hidden_dim = 5;
let lstm_params = Parameters::new(input_dim, hidden_dim, &mut rand::thread_rng());
let lstm = lstm_params.build_cell();
let state = InputNode::new(Arr::zeros((1, hidden_dim)));
let hidden = InputNode::new(Arr::zeros((1, hidden_dim)));
let input = InputNode::new(xavier_normal(1, input_dim));
let mut state = lstm.forward((state, hidden), input.clone());
for _ in 0..200 {
state = lstm.forward(state.clone(), input.clone());
}
let (_, mut hidden) = state;
hidden.forward();
hidden.backward(1.0);
hidden.zero_gradient();
}
fn predicted_label(softmax_output: &Arr) -> usize {
softmax_output
.iter()
.enumerate()
.max_by(|&(_, x), &(_, y)| x.partial_cmp(y).unwrap())
.unwrap()
.0
}
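// Train the LSTM to predict the next digit of pi from the preceding
// `sequence_length` digits, and check that accuracy exceeds 75%.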
#[test]
fn test_pi_digits() {
let num_epochs = 50;
let sequence_length = 4;
let num_digits = 10;
let input_dim = 16;
let hidden_dim = 32;
let lstm_params = Parameters::new(input_dim, hidden_dim, &mut rand::thread_rng());
let lstm = lstm_params.build();
let final_layer = ParameterNode::new(xavier_normal(hidden_dim, num_digits));
let embeddings = ParameterNode::new(xavier_normal(num_digits, input_dim));
let y = nodes::IndexInputNode::new(&vec![0]);
let inputs: Vec<_> = (0..sequence_length)
.map(|_| nodes::IndexInputNode::new(&vec![0]))
.collect();
let embeddings: Vec<_> = inputs
.iter()
.map(|input| embeddings.index(&input))
.collect();
let hidden_states = lstm.forward(&embeddings);
let hidden = hidden_states.last().unwrap();
let prediction = hidden.dot(&final_layer);
let mut loss = sparse_categorical_crossentropy(&prediction, &y);
let optimizer = Adam::new().learning_rate(0.01);
let digits = pi_digits(100);
let mut correct = 0;
let mut total = 0;
for _ in 0..num_epochs {
let mut loss_val = 0.0;
correct = 0;
total = 0;
for i in 0..(digits.len() - sequence_length - 1) {
let digit_chunk = &digits[i..(i + sequence_length + 1)];
if digit_chunk.len() < sequence_length + 1 {
break;
}
for (&digit, input) in digit_chunk[..digit_chunk.len() - 1].iter().zip(&inputs) {
input.set_value(digit);
}
let target_digit = *digit_chunk.last().unwrap();
y.set_value(target_digit);
loss.forward();
loss.backward(1.0);
loss_val += loss.value().scalar_sum();
optimizer.step(loss.parameters());
loss.zero_gradient();
if target_digit == predicted_label(prediction.value().deref()) {
correct += 1;
}
total += 1;
}
println!(
"Loss {}, accuracy {}",
loss_val,
correct as f32 / total as f32
);
}
assert!((correct as f32 / total as f32) > 0.75);
}
}