#[cfg(feature = "serde_flexbuffers")]
#[macro_use]
extern crate serde_derive;
use rand_distr::{Distribution, Normal};
use std::iter::{Enumerate, Zip};
use std::slice;
use std::time::{Duration, Instant};
use HaltCondition::{Epochs, Timer, MSE};
use LearningMode::Incremental;
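// Training defaults used by `NN::train`.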
const DEFAULT_LEARNING_RATE: f64 = 0.3;
const DEFAULT_LAMBDA: f64 = 0.0;
const DEFAULT_MOMENTUM: f64 = 0.0;
const DEFAULT_EPOCHS: u32 = 1000;
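// Activation constants: `SELU_FACTOR_A` is SELU's outer scale (lambda) and
// `SELU_FACTOR_B` its alpha. Note that the SELU paper (Klambauer et al.,
// 2017) uses lambda ~= 1.05070 and alpha ~= 1.67326, which differ from the
// values below.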
const SELU_FACTOR_A: f64 = 1.06071;
const SELU_FACTOR_B: f64 = 1.97126;
const PELU_FACTOR_A: f64 = 1.5;
const PELU_FACTOR_B: f64 = 2.0;
const LRELU_FACTOR: f64 = 0.33;
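/// Activation function applied to a layer's nodes.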
#[cfg_attr(feature = "serde_flexbuffers", derive(Deserialize, Serialize))]
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum Activation {
Sigmoid,
SELU,
PELU,
LRELU,
Linear,
Tanh,
}
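/// When training stops: after a fixed number of epochs, once the error drops
/// to a target MSE, or after a wall-clock duration.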
#[derive(Debug, Copy, Clone)]
pub enum HaltCondition {
Epochs(u32),
MSE(f64),
Timer(Duration),
}
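/// How weight updates are applied; `Incremental` updates the weights after
/// every example.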
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum LearningMode {
Incremental,
}
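/// Builder returned by [`NN::train`]: each setter validates its argument and
/// returns `&mut self` for chaining, and [`Trainer::go`] runs the training
/// and returns the final error value.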
#[derive(Debug)]
pub struct Trainer<'a, 'b> {
examples: &'b [(Vec<f64>, Vec<f64>)],
rate: f64,
momentum: f64,
lambda: f64,
log_interval: Option<u32>,
halt_condition: HaltCondition,
learning_mode: LearningMode,
nn: &'a mut NN,
}
impl<'a, 'b> Trainer<'a, 'b> {
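    /// Sets the learning rate (must be positive).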
pub fn rate(&mut self, rate: f64) -> &mut Trainer<'a, 'b> {
if rate <= 0f64 {
panic!("the learning rate must be a positive number");
}
self.rate = rate;
self
}
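    /// Sets the L2 weight-decay strength (must be non-negative).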
pub fn lambda(&mut self, lambda: f64) -> &mut Trainer<'a, 'b> {
if lambda < 0f64 {
panic!("the lambda value must be a positive number");
}
self.lambda = lambda;
self
}
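    /// Sets the momentum coefficient (must be non-negative).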
pub fn momentum(&mut self, momentum: f64) -> &mut Trainer<'a, 'b> {
if momentum < 0f64 {
panic!("momentum must be positive");
}
self.momentum = momentum;
self
}
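    /// Prints the error every `interval` epochs when given `Some(interval)`;
    /// `None` disables logging.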
pub fn log_interval(&mut self, log_interval: Option<u32>) -> &mut Trainer<'a, 'b> {
match log_interval {
Some(interval) if interval < 1 => {
panic!("log interval must be Some positive number or None")
}
_ => (),
}
self.log_interval = log_interval;
self
}
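    /// Sets the condition that ends training; see [`HaltCondition`].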
pub fn halt_condition(&mut self, halt_condition: HaltCondition) -> &mut Trainer<'a, 'b> {
match halt_condition {
Epochs(epochs) if epochs < 1 => {
panic!("must train for at least one epoch")
}
MSE(mse) if mse <= 0f64 => {
panic!("MSE must be greater than 0")
}
_ => (),
}
self.halt_condition = halt_condition;
self
}
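    /// Sets the learning mode (`Incremental` is currently the only mode).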
pub fn learning_mode(&mut self, learning_mode: LearningMode) -> &mut Trainer<'a, 'b> {
self.learning_mode = learning_mode;
self
}
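    /// Trains the network until the halt condition is met and returns the
    /// error value of the final epoch.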
pub fn go(&mut self) -> f64 {
self.nn.train_details(
self.examples,
self.rate,
self.lambda,
self.momentum,
self.log_interval,
self.halt_condition,
)
}
}
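/// A fully connected feed-forward neural network trained with
/// backpropagation. `layers` holds one weight vector per node for each
/// non-input layer, with the node's bias stored at index 0.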
#[cfg_attr(feature = "serde_flexbuffers", derive(Deserialize, Serialize))]
#[derive(Debug, Clone)]
pub struct NN {
layers: Vec<Vec<Vec<f64>>>,
num_inputs: u32,
hid_act: Activation,
out_act: Activation,
}
impl NN {
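    /// Creates a network from `layers_sizes`, where the first entry is the
    /// number of inputs and the last the number of outputs. Hidden layers use
    /// `hidden_activation`; the output layer uses `output_activation`. Biases
    /// start at zero and weights are drawn from a normal distribution scaled
    /// by the previous layer's size.
    ///
    /// # Panics
    ///
    /// Panics if fewer than two layer sizes are given or any layer is empty.
    ///
    /// A sketch:
    ///
    /// ```ignore
    /// // 3 inputs, a hidden layer of 4 nodes, and 2 outputs.
    /// let net = NN::new(&[3, 4, 2], Activation::Tanh, Activation::Linear);
    /// ```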
pub fn new(
layers_sizes: &[u32],
hidden_activation: Activation,
output_activation: Activation,
) -> NN {
let mut rng = rand::thread_rng();
if layers_sizes.len() < 2 {
panic!("must have at least two layers");
}
for &layer_size in layers_sizes.iter() {
if layer_size < 1 {
panic!("can't have any empty layers");
}
}
let mut layers = Vec::new();
let mut it = layers_sizes.iter();
let first_layer_size = *it.next().unwrap();
let mut prev_layer_size = first_layer_size;
for &layer_size in it {
let mut layer: Vec<Vec<f64>> = Vec::new();
            // He-style initialization (std = sqrt(2 / fan_in)) by default;
            // SELU networks use LeCun-style (std = sqrt(1 / fan_in)).
            let init_std_scale = if hidden_activation == Activation::SELU {
                1.0
            } else {
                2.0
            };
            let normal = Normal::new(0.0, (init_std_scale / prev_layer_size as f64).sqrt())
                .expect("cannot init the normal distribution");
for _ in 0..layer_size {
let mut node: Vec<f64> = Vec::new();
                // Slot 0 is the node's bias (zero-initialized); the rest are
                // the incoming weights.
                for i in 0..prev_layer_size + 1 {
                    if i == 0 {
                        node.push(0.0);
                    } else {
let random_weight: f64 = normal.sample(&mut rng);
node.push(random_weight);
}
}
node.shrink_to_fit();
                layer.push(node);
}
layer.shrink_to_fit();
layers.push(layer);
prev_layer_size = layer_size;
}
layers.shrink_to_fit();
NN {
layers,
num_inputs: first_layer_size,
hid_act: hidden_activation,
out_act: output_activation,
}
}
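    /// Feeds `inputs` through the network and returns the output layer's
    /// values.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is not the same length as the input layer.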
pub fn run(&self, inputs: &[f64]) -> Vec<f64> {
if inputs.len() as u32 != self.num_inputs {
panic!("input has a different length than the network's input layer");
}
self.do_run(inputs).pop().unwrap()
}
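    /// Returns a [`Trainer`] over `examples` (pairs of input and target
    /// vectors), seeded with the defaults: learning rate 0.3, momentum 0.0,
    /// lambda 0.0, and a halt condition of 1000 epochs.
    ///
    /// A sketch of training XOR; `your_crate` is a placeholder for this
    /// crate's actual name:
    ///
    /// ```ignore
    /// use your_crate::{Activation, HaltCondition, NN};
    ///
    /// let examples = [
    ///     (vec![0.0, 0.0], vec![0.0]),
    ///     (vec![0.0, 1.0], vec![1.0]),
    ///     (vec![1.0, 0.0], vec![1.0]),
    ///     (vec![1.0, 1.0], vec![0.0]),
    /// ];
    /// let mut net = NN::new(&[2, 3, 1], Activation::Sigmoid, Activation::Sigmoid);
    /// net.train(&examples)
    ///     .halt_condition(HaltCondition::MSE(0.01))
    ///     .log_interval(Some(100))
    ///     .go();
    /// ```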
pub fn train<'b>(&'b mut self, examples: &'b [(Vec<f64>, Vec<f64>)]) -> Trainer {
Trainer {
examples,
rate: DEFAULT_LEARNING_RATE,
momentum: DEFAULT_MOMENTUM,
lambda: DEFAULT_LAMBDA,
log_interval: None,
halt_condition: Epochs(DEFAULT_EPOCHS),
learning_mode: Incremental,
nn: self,
}
}
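    /// Serializes the network to a Flexbuffers byte vector. Only available
    /// with the `serde_flexbuffers` feature. A round trip, given some trained
    /// network `net`:
    ///
    /// ```ignore
    /// let bytes = net.to_flexbuffers();
    /// let restored = NN::from_flexbuffers(&bytes);
    /// ```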
#[cfg(feature = "serde_flexbuffers")]
pub fn to_flexbuffers(&self) -> Vec<u8> {
flexbuffers::to_vec(self).expect("encoding Flexbuffers failed")
}
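    /// Deserializes a network produced by [`NN::to_flexbuffers`]. Only
    /// available with the `serde_flexbuffers` feature.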
#[cfg(feature = "serde_flexbuffers")]
    pub fn from_flexbuffers(encoded: &[u8]) -> NN {
        flexbuffers::from_slice(encoded).expect("decoding Flexbuffers failed")
    }
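    /// Checks every example against the network's input and output sizes,
    /// then dispatches to the incremental training loop.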
fn train_details(
&mut self,
examples: &[(Vec<f64>, Vec<f64>)],
rate: f64,
lambda: f64,
momentum: f64,
log_interval: Option<u32>,
halt_condition: HaltCondition,
) -> f64 {
let input_layer_size = self.num_inputs;
let output_layer_size = self.layers[self.layers.len() - 1].len();
for &(ref inputs, ref outputs) in examples.iter() {
if inputs.len() as u32 != input_layer_size {
panic!("input has a different length than the network's input layer");
}
if outputs.len() != output_layer_size {
panic!("output has a different length than the network's output layer");
}
}
self.train_incremental(
examples,
rate,
lambda,
momentum,
log_interval,
halt_condition,
)
}
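    /// Incremental (online) backpropagation: weights are updated after each
    /// example rather than once per epoch. Returns the summed per-example
    /// error of the final epoch.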
fn train_incremental(
&mut self,
examples: &[(Vec<f64>, Vec<f64>)],
rate: f64,
lambda: f64,
momentum: f64,
log_interval: Option<u32>,
halt_condition: HaltCondition,
) -> f64 {
let mut prev_deltas = self.make_weights_tracker(0.0f64);
let mut epochs = 0u32;
let mut training_error_rate = 0f64;
let start_time = Instant::now();
loop {
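            // Skip logging and the halt check until at least one full epoch
            // has run.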
if epochs > 0 {
match log_interval {
Some(interval) if epochs % interval == 0 => {
println!("error rate: {}", training_error_rate);
}
_ => (),
}
match halt_condition {
Epochs(epochs_halt) => {
if epochs == epochs_halt {
break;
}
}
MSE(target_error) => {
if training_error_rate <= target_error {
break;
}
}
Timer(duration) => {
if start_time.elapsed() >= duration {
break;
}
}
}
}
training_error_rate = 0f64;
            for &(ref inputs, ref targets) in examples.iter() {
                let results = self.do_run(inputs);
                let weight_updates = self.calculate_weight_updates(&results, targets);
                training_error_rate += calculate_error(&results, targets);
                self.update_weights(&weight_updates, &mut prev_deltas, rate, lambda, momentum);
            }
epochs += 1;
}
training_error_rate
}
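    /// Feeds `inputs` forward through every layer and returns each layer's
    /// results, with the raw inputs at index 0.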
    fn do_run(&self, inputs: &[f64]) -> Vec<Vec<f64>> {
        let mut results = Vec::new();
        results.push(inputs.to_vec());
        for (layer_index, layer) in self.layers.iter().enumerate() {
            // The last layer uses the output activation; every other layer
            // uses the hidden activation.
            let activation = if layer_index == self.layers.len() - 1 {
                self.out_act
            } else {
                self.hid_act
            };
            let mut layer_results = Vec::new();
            for node in layer.iter() {
                let sum = modified_dotprod(node, &results[layer_index]);
                layer_results.push(match activation {
                    Activation::Sigmoid => sigmoid(sum),
                    Activation::SELU => selu(sum),
                    Activation::PELU => pelu(sum),
                    Activation::LRELU => lrelu(sum),
                    Activation::Linear => linear(sum),
                    Activation::Tanh => tanh(sum),
                });
            }
            results.push(layer_results);
        }
        results
}
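    /// Applies one round of weight updates using the learning rate, momentum
    /// over the previous deltas, and L2 weight decay (`lambda`).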
fn update_weights(
&mut self,
network_weight_updates: &[Vec<Vec<f64>>],
prev_deltas: &mut Vec<Vec<Vec<f64>>>,
rate: f64,
lambda: f64,
momentum: f64,
) {
for layer_index in 0..self.layers.len() {
let layer = &mut self.layers[layer_index];
let layer_weight_updates = &network_weight_updates[layer_index];
for node_index in 0..layer.len() {
let node = &mut layer[node_index];
let node_weight_updates = &layer_weight_updates[node_index];
for weight_index in 0..node.len() {
let weight_update = node_weight_updates[weight_index];
let prev_delta = prev_deltas[layer_index][node_index][weight_index];
let delta = (rate * weight_update) + (momentum * prev_delta);
node[weight_index] = (1.0 - rate * lambda) * node[weight_index] + delta;
prev_deltas[layer_index][node_index][weight_index] = delta;
}
}
}
}
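    /// Backpropagates the error for a single example and returns an update
    /// value for every weight; the learning rate is applied later in
    /// `update_weights`.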
fn calculate_weight_updates(
&self,
results: &[Vec<f64>],
targets: &[f64],
) -> Vec<Vec<Vec<f64>>> {
let mut network_errors: Vec<Vec<f64>> = Vec::new();
let mut network_weight_updates = Vec::new();
let layers = &self.layers;
let network_results = &results[1..];
let mut next_layer_nodes: Option<&Vec<Vec<f64>>> = None;
for (layer_index, (layer_nodes, layer_results)) in
iter_zip_enum(layers, network_results).rev()
{
let prev_layer_results = &results[layer_index];
let mut layer_errors = Vec::new();
let mut layer_weight_updates = Vec::new();
for (node_index, (node, &result)) in iter_zip_enum(layer_nodes, layer_results) {
let mut node_weight_updates = Vec::new();
                let node_error = if layer_index == layers.len() - 1 {
                    // Output layer: the delta is the activation derivative
                    // times the difference between target and output.
                    activation_derivative(self.out_act, result) * (targets[node_index] - result)
                } else {
                    // Hidden layer: propagate the next layer's deltas back
                    // through its weights; `node_index + 1` skips the bias
                    // weight at position 0.
                    let next_layer_errors = &network_errors[network_errors.len() - 1];
                    let mut sum = 0f64;
                    for (next_node, &next_node_error) in next_layer_nodes
                        .unwrap()
                        .iter()
                        .zip(next_layer_errors.iter())
                    {
                        sum += next_node[node_index + 1] * next_node_error;
                    }
                    activation_derivative(self.hid_act, result) * sum
                };
for weight_index in 0..node.len() {
                    // Weight 0 is the bias, whose input is the constant 1.
                    let prev_layer_result = if weight_index == 0 {
                        1f64
                    } else {
                        prev_layer_results[weight_index - 1]
                    };
let weight_update = node_error * prev_layer_result;
node_weight_updates.push(weight_update);
}
layer_errors.push(node_error);
layer_weight_updates.push(node_weight_updates);
}
network_errors.push(layer_errors);
network_weight_updates.push(layer_weight_updates);
            next_layer_nodes = Some(layer_nodes);
}
network_weight_updates.reverse();
network_weight_updates
}
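    /// Builds a structure with the same shape as the network's weights,
    /// filled with copies of `place_holder`.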
fn make_weights_tracker<T: Clone>(&self, place_holder: T) -> Vec<Vec<Vec<T>>> {
let mut network_level = Vec::new();
for layer in self.layers.iter() {
let mut layer_level = Vec::new();
for node in layer.iter() {
let mut node_level = Vec::new();
for _ in node.iter() {
node_level.push(place_holder.clone());
}
layer_level.push(node_level);
}
network_level.push(layer_level);
}
network_level
}
}
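/// Derivative of `activation` written in terms of the node's activated
/// output `result` rather than its pre-activation input; this is the single
/// source for the output- and hidden-layer deltas in
/// `calculate_weight_updates`.
fn activation_derivative(activation: Activation, result: f64) -> f64 {
    match activation {
        Activation::Sigmoid => result * (1.0 - result),
        Activation::SELU => {
            if result >= 0.0 {
                SELU_FACTOR_A
            } else {
                result + SELU_FACTOR_A * SELU_FACTOR_B
            }
        }
        Activation::PELU => {
            if result >= 0.0 {
                PELU_FACTOR_A / PELU_FACTOR_B
            } else {
                (result + PELU_FACTOR_A) / PELU_FACTOR_B
            }
        }
        Activation::LRELU => {
            if result >= 0.0 {
                1.0
            } else {
                LRELU_FACTOR
            }
        }
        Activation::Linear => 1.0,
        Activation::Tanh => 1.0 - result * result,
    }
}

/// Bias-plus-dot-product: `node[0]` is the bias and `node[1..]` are the
/// weights applied to `values`.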
fn modified_dotprod(node: &[f64], values: &[f64]) -> f64 {
let mut it = node.iter();
let mut total = *it.next().unwrap();
for (weight, value) in it.zip(values.iter()) {
total += weight * value;
}
total
}
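// Logistic sigmoid: 1 / (1 + exp(-y)).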
fn sigmoid(y: f64) -> f64 {
1f64 / (1f64 + (-y).exp())
}
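// SELU: A * y for y >= 0, A * B * (exp(y) - 1) for y < 0.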
fn selu(y: f64) -> f64 {
SELU_FACTOR_A
* if y < 0.0 {
SELU_FACTOR_B * y.exp() - SELU_FACTOR_B
} else {
y
}
}
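// PELU: (A / B) * y for y >= 0, A * (exp(y / B) - 1) for y < 0.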
fn pelu(y: f64) -> f64 {
if y < 0.0 {
PELU_FACTOR_A * (y / PELU_FACTOR_B).exp() - PELU_FACTOR_A
} else {
(PELU_FACTOR_A / PELU_FACTOR_B) * y
}
}
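// Leaky ReLU: identity for y >= 0, slope `LRELU_FACTOR` for y < 0.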
fn lrelu(y: f64) -> f64 {
if y < 0.0 {
LRELU_FACTOR * y
} else {
y
}
}
fn linear(y: f64) -> f64 {
y
}
fn tanh(y: f64) -> f64 {
y.tanh()
}
fn iter_zip_enum<'s, 't, S: 's, T: 't>(
s: &'s [S],
t: &'t [T],
) -> Enumerate<Zip<slice::Iter<'s, S>, slice::Iter<'t, T>>> {
s.iter().zip(t.iter()).enumerate()
}
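/// Mean squared error of the final layer's results against `targets` for a
/// single example.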
fn calculate_error(results: &[Vec<f64>], targets: &[f64]) -> f64 {
let last_results = &results[results.len() - 1];
let mut total: f64 = 0f64;
for (&result, &target) in last_results.iter().zip(targets.iter()) {
total += (target - result).powi(2);
}
total / (last_results.len() as f64)
}