use rand::distributions::Range;
use ndarray_rand::RandomExt;
use ndarray::prelude::*;
use ndarray::{Zip, Ix};
use itertools::Itertools;
use traits::{LearnRate, LearnMomentum, Predict, UpdateGradients, UpdateWeights};
use activation::Activation;
use topology::*;
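/// A fully connected layer of a feed-forward neural network.
///
/// Each output neuron owns one row of the weight matrix; the last
/// column of every row is the bias weight. Delta weights and gradients
/// are stored alongside to support momentum-based backpropagation.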
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
struct FullyConnectedLayer {
weights : Array2<f32>,
delta_weights: Array2<f32>,
outputs : Array1<f32>,
gradients : Array1<f32>,
activation : Activation,
}
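/// A feed-forward neural network composed of fully connected layers.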
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct NeuralNet {
layers: Vec<FullyConnectedLayer>,
}
impl FullyConnectedLayer {
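    /// Creates a layer from the given weight matrix, one row per output
    /// neuron with the bias weight in the last column.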
fn with_weights(weights: Array2<f32>, activation: Activation) -> Self {
        let (n_outputs, _) = weights.dim();
        // One gradient per output neuron plus one for the bias neuron.
        let biased_gradients = n_outputs + 1;
        let weights_shape = weights.dim();
        FullyConnectedLayer {
            weights,
            delta_weights: Array2::zeros(weights_shape),
            outputs: Array1::zeros(n_outputs),
            gradients: Array1::zeros(biased_gradients),
            activation,
        }
}
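    /// Creates a layer with `n_inputs` inputs and `n_outputs` outputs whose
    /// weights, including the bias column, are drawn uniformly from `[-1, 1)`.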
fn random(n_inputs: Ix, n_outputs: Ix, activation: Activation) -> Self {
assert!(n_inputs >= 1 && n_outputs >= 1);
let biased_inputs = n_inputs + 1;
let biased_shape = (n_outputs, biased_inputs);
FullyConnectedLayer::with_weights(
Array2::random(biased_shape, Range::new(-1.0, 1.0)), activation)
}
#[inline]
fn count_outputs(&self) -> Ix {
self.outputs.dim()
}
#[inline]
fn count_gradients(&self) -> Ix {
self.gradients.dim()
}
#[inline]
fn output_view(&self) -> ArrayView1<f32> {
self.outputs.view()
}
#[inline]
#[cfg(test)]
fn gradients_view(&self) -> ArrayView1<f32> {
self.gradients.view()
}
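    /// Feeds `input` forward through this layer and returns a view of the
    /// resulting outputs. Each output is the activation of the weighted sum
    /// of the inputs plus the row's bias weight.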
fn feed_forward(&mut self,
input: ArrayView1<f32>)
-> ArrayView1<f32> {
debug_assert_eq!(self.weights.rows(), self.count_outputs());
debug_assert_eq!(self.weights.cols(), input.len() + 1);
let act = self.activation;
Zip::from(&mut self.outputs).and(self.weights.genrows()).apply(|output, weights| {
let s = weights.len();
*output = act.base(weights.slice(s![..-1]).dot(&input) + weights[s-1]);
});
        self.output_view()
    }
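    /// Computes this layer's error gradients from the given target values.
    /// Only used for the output layer; the trailing bias gradient is left
    /// untouched.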
fn calculate_output_gradients(&mut self,
target_values: ArrayView1<f32>)
-> &Self {
        debug_assert_eq!(self.count_outputs(), target_values.len());
debug_assert_eq!(self.count_gradients(), target_values.len() + 1);
let act = self.activation;
Zip::from(&mut self.gradients.slice_mut(s![..-1]))
.and(&target_values)
.and(&self.outputs)
.apply(|gradient, &target, &output| {
*gradient = (target - output) * act.derived(output)
});
self
}
#[inline]
fn reset_gradients(&mut self) {
self.gradients.fill(0.0);
debug_assert!(self.gradients.iter().all(|&g| g == 0.0));
}
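    /// Multiplies every gradient with the derived activation of its
    /// associated output, treating the bias neuron as a constant `1.0`.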
fn apply_activation(&mut self) {
debug_assert_eq!(self.count_gradients(), self.count_outputs() + 1);
        use std::iter;
        let act = self.activation;
izip!(self.gradients.iter_mut(), self.outputs.iter().chain(iter::once(&1.0)))
.foreach(|(gradient, &output)| *gradient *= act.derived(output));
}
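    /// Accumulates into this layer's gradients the gradients of `prev`, the
    /// layer that follows this one in forward order, weighted by `prev`'s
    /// weights, and then applies the activation derivative.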
fn propagate_gradients(&mut self,
prev: &FullyConnectedLayer)
-> &Self {
debug_assert_eq!(prev.weights.rows(), prev.count_gradients() - 1);
debug_assert_eq!(prev.weights.cols(), self.count_gradients());
izip!(prev.weights.genrows(), prev.gradients.iter())
.foreach(|(prev_weights_row, prev_gradient)| {
izip!(self.gradients.iter_mut(), prev_weights_row.iter())
.foreach(|(gradient, weight)| *gradient += weight * prev_gradient)
});
self.apply_activation();
        self
    }
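    /// Adjusts the weights using the accumulated gradients and the outputs
    /// of the preceding layer, scaled by the learn rate and momentum, then
    /// resets the gradients and returns a view of this layer's outputs.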
fn update_weights(&mut self,
prev_outputs: ArrayView1<f32>,
learn_rate : LearnRate,
learn_mom : LearnMomentum)
-> ArrayView1<f32> {
debug_assert_eq!(prev_outputs.len() + 1, self.weights.cols());
debug_assert_eq!(self.count_gradients(), self.weights.rows() + 1);
use std::iter;
izip!(self.weights.genrows_mut(),
self.delta_weights.genrows_mut(),
self.gradients.iter())
.foreach(|(mut weights_row, mut delta_weights_row, gradient)| {
izip!(prev_outputs.iter().chain(iter::once(&1.0)), delta_weights_row.iter_mut())
.foreach(|(prev_output, delta_weight)| {
*delta_weight =
learn_rate.0 * prev_output * gradient
+ learn_mom.0 * *delta_weight;
});
weights_row += &delta_weights_row;
});
self.reset_gradients();
self.output_view()
}
}
impl NeuralNet {
    /// Creates a neural net from the given layers.
    fn from_vec(layers: Vec<FullyConnectedLayer>) -> Self {
        NeuralNet { layers }
    }
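    /// Creates a neural net with randomly initialized layers from the given
    /// topology description.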
pub fn from_topology(topology: Topology) -> Self {
NeuralNet::from_vec(topology
.iter_layers()
.map(|&layer| {
FullyConnectedLayer::random(
layer.inputs, layer.outputs, layer.activation)
})
.collect()
)
}
}
impl<'b, A> Predict<A> for NeuralNet
where A: Into<ArrayView1<'b, f32>>
{
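    /// Feeds `input` forward through all layers and returns a view of the
    /// last layer's outputs.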
fn predict(&mut self, input: A) -> ArrayView1<f32> {
let input = input.into();
if let Some((first, tail)) = self.layers.split_first_mut() {
tail.iter_mut()
.fold(first.feed_forward(input),
|prev, layer| layer.feed_forward(prev))
} else {
panic!("A Neural Net is guaranteed to have at least one layer so this situation \
should never happen!");
}
}
}
impl<'a, A> UpdateGradients<A> for NeuralNet
where A: Into<ArrayView1<'a, f32>>
{
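    /// Computes the output gradients of the last layer from the target
    /// values and propagates them backwards through the hidden layers.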
fn update_gradients(&mut self, target_values: A) {
        if let Some((last, tail)) = self.layers.split_last_mut() {
tail.iter_mut()
.rev()
.fold(last.calculate_output_gradients(target_values.into()),
|prev, layer| layer.propagate_gradients(prev));
}
}
}
impl<'b, A> UpdateWeights<A> for NeuralNet
where A: Into<ArrayView1<'b, f32>>
{
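    /// Updates all weights from the accumulated gradients, feeding the
    /// original input to the first layer and each layer's outputs to its
    /// successor.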
fn update_weights(&mut self, input: A, rate: LearnRate, momentum: LearnMomentum) {
let input = input.into();
if let Some((first, tail)) = self.layers.split_first_mut() {
tail.iter_mut()
.fold(first.update_weights(input, rate, momentum),
|prev, layer| layer.update_weights(prev, rate, momentum));
}
}
}
#[cfg(test)]
mod tests {
pub use super::*;
mod fully_connected_layer {
use super::*;
use std::iter;
#[test]
fn construction_invariants() {
use self::Activation::{Identity};
let weights = Array1::linspace(1.0, 12.0, 12).into_shape((3, 4)).unwrap();
let layer = FullyConnectedLayer::with_weights(weights.clone(), Identity);
assert_eq!(layer.weights, weights);
assert_eq!(layer.delta_weights, Array::zeros((3, 4)));
assert_eq!(layer.gradients, Array1::zeros(4));
let expected_outputs = Array1::from_iter(iter::repeat(0.0).take(3));
assert_eq!(layer.outputs, expected_outputs);
}
#[test]
fn feed_forward() {
use self::Activation::{Identity};
let mut layer = FullyConnectedLayer::with_weights(
Array1::linspace(1.0, 12.0, 12).into_shape((3, 4)).unwrap(), Identity);
let applier = Array1::linspace(1.0, 3.0, 3);
let outputs = layer.feed_forward(applier.view()).to_owned();
let targets = Array1::from_vec(vec![18.0, 46.0, 74.0]);
assert_eq!(outputs, targets);
}
#[test]
fn update_output_gradients() {
use self::Activation::{Identity};
let mut layer = FullyConnectedLayer::with_weights(
Array1::linspace(1.0, 12.0, 12).into_shape((3, 4)).unwrap(), Identity);
let expected = Array1::linspace(1.0, 3.0, 3);
let gradients = layer.gradients_view().to_owned();
let outputs = layer.output_view().to_owned();
let expected_gradients = Array1::zeros(4);
let expected_outputs = Array1::from_iter(iter::repeat(0.0).take(3));
assert_eq!(gradients, expected_gradients);
assert_eq!(outputs , expected_outputs);
            layer.calculate_output_gradients(expected.view());
let targets = Array1::from_vec(vec![1.0, 2.0, 3.0, 0.0]);
let gradients = layer.gradients_view().to_owned();
assert_eq!(gradients, targets);
}
#[test]
fn propagate_gradients() {
use self::Activation::{Identity};
let fst_layer = FullyConnectedLayer{
weights : Array1::linspace(1.0, 12.0, 12).into_shape((3, 4)).unwrap(),
delta_weights: Array::zeros((3, 4)),
outputs : Array1::from_iter(iter::repeat(0.0).take(3)),
gradients : Array1::linspace(10.0, 40.0, 4),
activation : Identity
};
let mut snd_layer = FullyConnectedLayer::with_weights(
Array1::linspace(1.0, 12.0, 12).into_shape((3, 4)).unwrap(), Identity);
snd_layer.propagate_gradients(&fst_layer);
let expected_gradients = Array1::from_vec(vec![380.0, 440.0, 500.0, 560.0]);
assert_eq!(snd_layer.gradients_view().to_owned(), expected_gradients);
}
#[test]
fn update_weights() {
use self::Activation::{Identity};
let lr = LearnRate(0.5);
let lm = LearnMomentum(1.0);
let outputs = Array1::from_iter(iter::repeat(0.0).take(3));
let mut layer = FullyConnectedLayer{
weights : Array1::linspace(1.0, 12.0, 12).into_shape((3, 4)).unwrap(),
delta_weights: Array::zeros((3, 4)),
outputs : Array1::from_iter(iter::repeat(0.0).take(3)),
gradients : Array1::linspace(10.0, 40.0, 4),
activation : Identity
};
let result_outputs = layer.update_weights(outputs.view(), lr, lm).to_owned();
let target_outputs = Array::from_vec(vec![0.0, 0.0, 0.0]);
let result_weights = layer.weights.clone();
let target_weights = Array::from_vec(vec![
1.0, 2.0, 3.0, 9.0,
5.0, 6.0, 7.0, 18.0,
9.0, 10.0, 11.0, 27.0]).into_shape((3, 4)).unwrap();
assert_eq!(result_outputs, target_outputs);
assert_eq!(result_weights, target_weights);
}
}
#[test]
#[ignore]
fn equivalence() {
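        // Checks whether feeding forward through a single Tanh layer is
        // equivalent to a purely linear layer followed by an identity-weighted
        // Tanh layer. Note that the identity layer yields three outputs while
        // the merged layer yields two, so the result shapes differ; presumably
        // this is why the test is ignored.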
use self::Activation::{Identity, Tanh};
println!("1");
let mut merged = FullyConnectedLayer{
weights: Array::from_vec(vec![
1.0, 2.0, 3.0,
4.0, 5.0, 6.0
]).into_shape((2, 3)).unwrap(),
delta_weights: Array::zeros((2, 3)),
outputs: Array::zeros(2),
gradients: Array::zeros(3),
activation: Tanh
};
println!("2");
let mut weights_part = FullyConnectedLayer{
weights: Array::from_vec(vec![
1.0, 2.0, 3.0,
4.0, 5.0, 6.0
]).into_shape((2, 3)).unwrap(),
delta_weights: Array::zeros((2, 3)),
outputs: Array::zeros(2),
gradients: Array::zeros(3),
activation: Identity
};
println!("3");
let mut activation_part = FullyConnectedLayer{
weights: Array::from_vec(vec![
1.0, 0.0, 0.0,
0.0, 1.0, 0.0,
0.0, 0.0, 1.0
]).into_shape((3, 3)).unwrap(),
delta_weights: Array::zeros((3, 3)),
            outputs: Array::zeros(3),
            gradients: Array::zeros(4),
activation: Tanh
};
println!("4");
let input = Array::from_vec(vec![10.0, 20.0]);
println!("5");
let result_merged = merged.feed_forward(input.view()).to_owned();
println!("6");
weights_part.feed_forward(input.view());
let result_split_temp = weights_part.output_view().to_owned();
println!("7");
let result_split = activation_part.feed_forward(result_split_temp.view()).to_owned();
println!("8");
println!("result_merged = {:?}", result_merged);
println!("result_split_temp = {:?}", result_split_temp);
println!("result_split = {:?}", result_split);
assert_eq!(result_merged, result_split);
}
}