use crate::history::Features;
use crate::history::History;
use crate::network::Network;
use crate::node::Node;
use crate::settings::Settings;
use crate::training_sample_generator::TrainingSampleGenerator;
#[derive(Debug)]
pub struct Trainer<'a> {
settings: &'a Settings,
history: &'a mut History,
}
impl<'a> Trainer<'a> {
pub fn new(settings: &'a Settings, history: &'a mut History) -> Trainer<'a> {
Trainer { settings, history }
}
pub fn train(&mut self) {
let lr = 0.000_005;
let momentum = 0.0;
let batch_size = 1000;
let plateau_threshold = 3000;
let generator = TrainingSampleGenerator::new(self.settings, self.history);
println!(
"Evaluating error rate on current {:#?}",
self.history.network
);
let mut best_overall_network = self.history.network;
let mut best_overall_error = self
.history
.network
.average_error(&generator, batch_size * 10);
println!("Current network error rate is {}", best_overall_error);
loop {
let mut best_restart_network = Network::random();
let mut best_restart_error = 10000.0;
let mut cycles_since_best_restart_error = 0;
let mut network = Network::random();
let mut node_increments = [Node::default(), Node::default(), Node::default()];
let mut output_increments = [0.0, 0.0, 0.0, 0.0];
loop {
let mut batch_error = 0.0;
let mut batch_samples = 0.0;
generator.generate(Some(batch_size), |features: &Features, correct: bool| {
let target = if correct { 1.0 } else { -1.0 };
network.compute(features);
let error = 0.5 * (target - network.final_output).powi(2);
batch_error += error;
batch_samples += 1.0;
let d_e_d_o_3 = -(target - network.final_output);
let d_o_3_d_s_3 = 1.0 - network.final_sum.tanh().powi(2);
output_increments[0] =
momentum * output_increments[0] + lr * d_e_d_o_3 * d_o_3_d_s_3 * 1.0;
output_increments[1] = momentum * output_increments[1]
+ lr * d_e_d_o_3 * d_o_3_d_s_3 * network.hidden_node_outputs[0];
output_increments[2] = momentum * output_increments[2]
+ lr * d_e_d_o_3 * d_o_3_d_s_3 * network.hidden_node_outputs[1];
output_increments[3] = momentum * output_increments[3]
+ lr * d_e_d_o_3 * d_o_3_d_s_3 * network.hidden_node_outputs[2];
let d_s_3_d_o_0 = network.final_weights[0];
let d_s_3_d_o_1 = network.final_weights[1];
let d_s_3_d_o_2 = network.final_weights[2];
let d_o_0_d_s_0 = 1.0 - network.hidden_node_sums[0].tanh().powi(2);
let d_o_1_d_s_1 = 1.0 - network.hidden_node_sums[1].tanh().powi(2);
let d_o_2_d_s_2 = 1.0 - network.hidden_node_sums[2].tanh().powi(2);
let d_e_d_s_0 = d_e_d_o_3 * d_o_3_d_s_3 * d_s_3_d_o_0 * d_o_0_d_s_0;
let d_e_d_s_1 = d_e_d_o_3 * d_o_3_d_s_3 * d_s_3_d_o_1 * d_o_1_d_s_1;
let d_e_d_s_2 = d_e_d_o_3 * d_o_3_d_s_3 * d_s_3_d_o_2 * d_o_2_d_s_2;
node_increments[0].offset =
momentum * node_increments[0].offset + lr * d_e_d_s_0 * 1.0;
node_increments[0].age =
momentum * node_increments[0].age + lr * d_e_d_s_0 * features.age_factor;
node_increments[0].length = momentum * node_increments[0].length
+ lr * d_e_d_s_0 * features.length_factor;
node_increments[0].exit =
momentum * node_increments[0].exit + lr * d_e_d_s_0 * features.exit_factor;
node_increments[0].recent_failure = momentum
* node_increments[0].recent_failure
+ lr * d_e_d_s_0 * features.recent_failure_factor;
node_increments[0].selected_dir = momentum * node_increments[0].selected_dir
+ lr * d_e_d_s_0 * features.selected_dir_factor;
node_increments[0].dir =
momentum * node_increments[0].dir + lr * d_e_d_s_0 * features.dir_factor;
node_increments[0].overlap = momentum * node_increments[0].overlap
+ lr * d_e_d_s_0 * features.overlap_factor;
node_increments[0].immediate_overlap = momentum
* node_increments[0].immediate_overlap
+ lr * d_e_d_s_0 * features.immediate_overlap_factor;
node_increments[0].selected_occurrences = momentum
* node_increments[0].selected_occurrences
+ lr * d_e_d_s_0 * features.selected_occurrences_factor;
node_increments[0].occurrences = momentum * node_increments[0].occurrences
+ lr * d_e_d_s_0 * features.occurrences_factor;
node_increments[1].offset =
momentum * node_increments[1].offset + lr * d_e_d_s_1 * 1.0;
node_increments[1].age =
momentum * node_increments[1].age + lr * d_e_d_s_1 * features.age_factor;
node_increments[1].length = momentum * node_increments[1].length
+ lr * d_e_d_s_1 * features.length_factor;
node_increments[1].exit =
momentum * node_increments[1].exit + lr * d_e_d_s_1 * features.exit_factor;
node_increments[1].recent_failure = momentum
* node_increments[1].recent_failure
+ lr * d_e_d_s_1 * features.recent_failure_factor;
node_increments[1].selected_dir = momentum * node_increments[1].selected_dir
+ lr * d_e_d_s_1 * features.selected_dir_factor;
node_increments[1].dir =
momentum * node_increments[1].dir + lr * d_e_d_s_1 * features.dir_factor;
node_increments[1].overlap = momentum * node_increments[1].overlap
+ lr * d_e_d_s_1 * features.overlap_factor;
node_increments[1].immediate_overlap = momentum
* node_increments[1].immediate_overlap
+ lr * d_e_d_s_1 * features.immediate_overlap_factor;
node_increments[1].selected_occurrences = momentum
* node_increments[1].selected_occurrences
+ lr * d_e_d_s_1 * features.selected_occurrences_factor;
node_increments[1].occurrences = momentum * node_increments[1].occurrences
+ lr * d_e_d_s_1 * features.occurrences_factor;
node_increments[2].offset =
momentum * node_increments[2].offset + lr * d_e_d_s_2 * 1.0;
node_increments[2].age =
momentum * node_increments[2].age + lr * d_e_d_s_2 * features.age_factor;
node_increments[2].length = momentum * node_increments[2].length
+ lr * d_e_d_s_2 * features.length_factor;
node_increments[2].exit =
momentum * node_increments[2].exit + lr * d_e_d_s_2 * features.exit_factor;
node_increments[2].recent_failure = momentum
* node_increments[2].recent_failure
+ lr * d_e_d_s_2 * features.recent_failure_factor;
node_increments[2].selected_dir = momentum * node_increments[2].selected_dir
+ lr * d_e_d_s_2 * features.selected_dir_factor;
node_increments[2].dir =
momentum * node_increments[2].dir + lr * d_e_d_s_2 * features.dir_factor;
node_increments[2].overlap = momentum * node_increments[2].overlap
+ lr * d_e_d_s_2 * features.overlap_factor;
node_increments[2].immediate_overlap = momentum
* node_increments[2].immediate_overlap
+ lr * d_e_d_s_2 * features.immediate_overlap_factor;
node_increments[2].selected_occurrences = momentum
* node_increments[2].selected_occurrences
+ lr * d_e_d_s_2 * features.selected_occurrences_factor;
node_increments[2].occurrences = momentum * node_increments[2].occurrences
+ lr * d_e_d_s_2 * features.occurrences_factor;
let node0 = network.hidden_nodes[0];
let node1 = network.hidden_nodes[1];
let node2 = network.hidden_nodes[2];
network = Network {
hidden_nodes: [
Node {
offset: node0.offset - node_increments[0].offset,
age: node0.age - node_increments[0].age,
length: node0.length - node_increments[0].length,
exit: node0.exit - node_increments[0].exit,
recent_failure: node0.recent_failure
- node_increments[0].recent_failure,
selected_dir: node0.selected_dir - node_increments[0].selected_dir,
dir: node0.dir - node_increments[0].dir,
overlap: node0.overlap - node_increments[0].overlap,
immediate_overlap: node0.immediate_overlap
- node_increments[0].immediate_overlap,
selected_occurrences: node0.selected_occurrences
- node_increments[0].selected_occurrences,
occurrences: node0.occurrences - node_increments[0].occurrences,
},
Node {
offset: node1.offset - node_increments[1].offset,
age: node1.age - node_increments[1].age,
length: node1.length - node_increments[1].length,
exit: node1.exit - node_increments[1].exit,
recent_failure: node1.recent_failure
- node_increments[1].recent_failure,
selected_dir: node1.selected_dir - node_increments[1].selected_dir,
dir: node1.dir - node_increments[1].dir,
overlap: node1.overlap - node_increments[1].overlap,
immediate_overlap: node1.immediate_overlap
- node_increments[1].immediate_overlap,
selected_occurrences: node1.selected_occurrences
- node_increments[1].selected_occurrences,
occurrences: node1.occurrences - node_increments[1].occurrences,
},
Node {
offset: node2.offset - node_increments[2].offset,
age: node2.age - node_increments[2].age,
length: node2.length - node_increments[2].length,
exit: node2.exit - node_increments[2].exit,
recent_failure: node2.recent_failure
- node_increments[2].recent_failure,
selected_dir: node2.selected_dir - node_increments[2].selected_dir,
dir: node2.dir - node_increments[2].dir,
overlap: node2.overlap - node_increments[2].overlap,
immediate_overlap: node2.immediate_overlap
- node_increments[2].immediate_overlap,
selected_occurrences: node2.selected_occurrences
- node_increments[2].selected_occurrences,
occurrences: node2.occurrences - node_increments[2].occurrences,
},
],
hidden_node_sums: [0.0, 0.0, 0.0],
hidden_node_outputs: [0.0, 0.0, 0.0],
final_bias: network.final_bias - output_increments[0],
final_weights: [
network.final_weights[0] - output_increments[1],
network.final_weights[1] - output_increments[2],
network.final_weights[2] - output_increments[3],
],
final_sum: 0.0,
final_output: 0.0,
};
});
if batch_error / batch_samples < best_restart_error {
best_restart_error = batch_error / batch_samples;
best_restart_network = network;
cycles_since_best_restart_error = 0;
} else {
cycles_since_best_restart_error += 1;
if cycles_since_best_restart_error > plateau_threshold {
println!("Plateaued at {}.", batch_error / batch_samples);
if best_restart_error < best_overall_error {
best_overall_error = best_restart_error;
best_overall_network = best_restart_network;
println!(
"New best overall for {:#?} with error {} (new best)",
best_overall_network, best_overall_error
);
} else {
println!(
"Best overall remains {:#?} with error {} (old)",
best_overall_network, best_overall_error
);
}
break;
}
}
}
}
}
}