use super::*;
impl Network {
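    /// Backpropagates the error for one training example, filling in each
    /// layer's `delta` (the gradient of the loss with respect to that layer's
    /// pre-activations `z`).
    ///
    /// Output layer: delta = (activation - target) * f'(z), which matches a
    /// squared-error loss. Hidden layers: delta = (sum of weight * downstream
    /// delta over enabled outgoing connections) * f'(z).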
fn backward_pass(&mut self, targets: &[f64]) {
        let output_layer_idx = self.layers.len() - 1;
        let output_layer = &self.layers[output_layer_idx];
        // f'(z), evaluated element-wise at the output layer's pre-activations.
        let activation_derivative = output_layer
            .z
            .mapv(|z| output_layer.activation_function.derivative(z.unwrap()));
        // delta = (activation - target) * f'(z)
        let output_delta = (output_layer.activation.mapv(|a| a.unwrap())
            - Array1::from(targets.to_vec()))
            * activation_derivative;
        self.layers[output_layer_idx].delta = output_delta;
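        // Walk the hidden layers backwards; layer 0 is the input layer and
        // carries no delta, and the output layer was handled above.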
for layer_idx in (1..output_layer_idx).rev() {
for neuron_idx in 0..self.layers[layer_idx].neuron_amt {
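                // Sum weight * downstream delta over this neuron's enabled
                // outgoing connections.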
let delta_sum = self
.connections
.iter()
.filter(|c| c.enabled && c.in_neuron_id == (layer_idx, neuron_idx))
.map(|c| c.weight * self.layers[c.out_neuron_id.0].delta[c.out_neuron_id.1])
.sum::<f64>();
let activation_derivative = self.layers[layer_idx]
.activation_function
.derivative(self.layers[layer_idx].z[neuron_idx].unwrap());
self.layers[layer_idx].delta[neuron_idx] = delta_sum * activation_derivative;
}
}
}
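    /// Applies one gradient-descent step to every enabled connection weight
    /// (dL/dw = source activation * target delta) and clamps the result to
    /// [-weight_clamp, weight_clamp].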
fn update_connection_weights(&mut self, learning_rate: f64, weight_clamp: f64) {
        // Borrow the layers up front so the closure does not capture `self`
        // while `connections` is mutably borrowed.
        let layers = &self.layers;
        self.connections
            .iter_mut()
            // A disabled connection does not affect the output, so its true
            // gradient is zero; leave its weight untouched.
            .filter(|connection| connection.enabled)
            .for_each(|connection| {
                // dL/dw = activation of the source neuron * delta of the target neuron.
                let grad_weight = layers[connection.in_neuron_id.0].activation
                    [connection.in_neuron_id.1]
                    .unwrap()
                    * layers[connection.out_neuron_id.0].delta[connection.out_neuron_id.1];
                connection.weight -= learning_rate * grad_weight;
                connection.weight = connection.weight.clamp(-weight_clamp, weight_clamp);
            });
}
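    /// Applies one gradient-descent step to every trainable bias
    /// (dL/db = delta) and clamps the result to [-weight_clamp, weight_clamp].
    /// The input layer is skipped.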
fn update_layer_biases(&mut self, learning_rate: f64, weight_clamp: f64) {
        self.layers.iter_mut().skip(1).for_each(|layer| {
            // dL/db = delta, so the bias step is just learning_rate * delta.
            layer.bias -= &(&layer.delta * learning_rate);
            // Clamp in place; a bare `mapv` would discard the clamped values.
            layer.bias.mapv_inplace(|b| b.clamp(-weight_clamp, weight_clamp));
        });
}
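    /// Convenience wrapper: one weight update followed by one bias update.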
fn update_parameters(&mut self, learning_rate: f64, weight_clamp: f64) {
self.update_connection_weights(learning_rate, weight_clamp);
self.update_layer_biases(learning_rate, weight_clamp);
}
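    /// Runs gradient descent for `optimizer.epochs` epochs. Each epoch sweeps
    /// the evaluator's dataset, backpropagating and updating the parameters
    /// after every (inputs, targets) pair, i.e. per-sample SGD.
    ///
    /// Optionally prints progress every `print_every` epochs and stops early
    /// once `optimizer.stop_on_loss` is reached. Returns the last epoch loss
    /// recorded by the optimizer.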
pub fn graddec(
&mut self,
evaluator: &Evaluator,
optimizer: &mut Optimizer,
weight_clamp: f64,
print_every: Option<usize>,
) -> f64 {
let mut last_print: usize = 0;
for epoch in 0..optimizer.epochs {
let mut total_loss = 0.0;
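            // `loss_on_pair` runs the forward pass for `inputs`, so the stored
            // activations are fresh when `backward_pass` reads them.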
evaluator.dataset_iter().for_each(|(inputs, targets)| {
                total_loss += evaluator.loss_on_pair(self, inputs, &targets[..]);
self.backward_pass(targets);
self.update_parameters(optimizer.learning_rate, weight_clamp);
});
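            // Log progress every `pe` epochs (every epoch when `pe == 1`).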
if let Some(pe) = print_every {
if pe == 1 || epoch - last_print >= pe {
println!(
"Epoch {}: Loss = {:.6}, learning rate {:.6}",
epoch, total_loss, optimizer.learning_rate
);
last_print = epoch;
}
}
            // Record this epoch's loss before the early-stop check so that
            // `last_loss` is set even when we break on the first epoch.
            optimizer.tick(total_loss);
            if let Some(stop_on_loss) = optimizer.stop_on_loss {
                if total_loss <= stop_on_loss {
                    println!("Stopping early due to loss threshold");
                    break;
                }
            }
}
        optimizer
            .last_loss
            .expect("graddec requires optimizer.epochs > 0")
}
}