use num_traits::Float;
use std::collections::HashMap;
use std::ops::{Index, Range};
use super::utils;
use crate::activation::Activation;
use crate::gene::{Gene, InputId, NeuronId};
use crate::network::NeuronInfo;
use crate::stack::Stack;
/// A cheap, copyable view over a slice of network input values, indexable by
/// [`InputId`] (see the [`Index`] impl below).
#[derive(Clone, Copy)]
pub struct Inputs<'a, T>(pub &'a [T]);
impl<'a, T> Index<InputId> for Inputs<'a, T> {
type Output = T;
fn index(&self, index: InputId) -> &Self::Output {
&self.0[index.as_usize()]
}
}
/// Evaluates the genes in `range` of `genome`, pushing each gene's weighted
/// value onto `stack`; when the call returns, the evaluated subgenome's
/// output(s) sit on top of the stack.
///
/// Genes are visited in reverse order so that by the time a neuron gene is
/// reached, the values of all of its inputs have already been pushed.
///
/// If `ignore_final_neuron_weight` is true, the weight of the neuron at
/// offset zero of `range` (the subgenome's root) is replaced with one — this
/// is used when recursively evaluating a forward jumper's source subgenome,
/// where the raw (unweighted) neuron output is needed.
///
/// Panics if a neuron has fewer values available on the stack than its
/// declared number of inputs.
pub fn evaluate_slice<'s, T: Float>(
    genome: &mut Vec<Gene<T>>,
    range: Range<usize>,
    inputs: Inputs<T>,
    stack: &'s mut Stack<T>,
    ignore_final_neuron_weight: bool,
    neuron_info: &HashMap<NeuronId, NeuronInfo>,
    activation: Activation,
) {
    // Walk the range back-to-front; `i` is the gene's offset within `range`,
    // so `i == 0` identifies the subgenome's root gene.
    for (i, gene_index) in range.enumerate().rev() {
        let weight;
        let value;
        if genome[gene_index].is_bias() {
            // Bias gene: a constant input of one, scaled by the bias value.
            let bias = genome[gene_index].as_bias().unwrap();
            weight = bias.value();
            value = T::one();
        } else if genome[gene_index].is_input() {
            // Input gene: reads the external input value for its id.
            let input = genome[gene_index].as_input().unwrap();
            weight = input.weight();
            value = inputs[input.id()];
        } else if genome[gene_index].is_neuron() {
            let neuron = genome[gene_index].as_mut_neuron().unwrap();
            // Pop and sum this neuron's inputs, which were pushed by the
            // later-indexed genes already visited in this loop.
            let sum_inputs = stack
                .pop_sum(neuron.num_inputs())
                .expect("A neuron did not receive enough inputs");
            value = activation.apply(sum_inputs);
            // Cache the activated value so jumper genes can reuse it during
            // this evaluation pass.
            neuron.set_current_value(Some(value));
            if i == 0 && ignore_final_neuron_weight {
                // Root neuron of a jumper-requested subgenome: skip its weight.
                weight = T::one();
            } else {
                weight = neuron.weight();
            }
        } else if genome[gene_index].is_forward_jumper() {
            // Forward jumper: forwards the output of a source neuron located
            // elsewhere in the genome.
            let forward = genome[gene_index].as_forward_jumper().unwrap();
            let source_subgenome_range = neuron_info[&forward.source_id()].subgenome_range();
            let source = genome[source_subgenome_range.start].as_neuron().unwrap();
            weight = forward.weight();
            let subgenome_output = if let Some(cached) = source.current_value() {
                // Source neuron already evaluated this pass — reuse the cache.
                cached
            } else {
                // Otherwise recursively evaluate the source subgenome with its
                // root weight ignored, then take its output off the stack.
                evaluate_slice(
                    genome,
                    source_subgenome_range,
                    inputs,
                    stack,
                    true,
                    neuron_info,
                    activation,
                );
                stack.pop().unwrap()
            };
            value = subgenome_output;
        } else if genome[gene_index].is_recurrent_jumper() {
            // Recurrent jumper: reads the source neuron's value from the
            // previous network evaluation rather than the current one.
            let recurrent = genome[gene_index].as_recurrent_jumper().unwrap();
            let source = utils::get_neuron(recurrent.source_id(), neuron_info, genome).unwrap();
            weight = recurrent.weight();
            value = source.previous_value();
        } else {
            // Every gene kind is handled by one of the branches above.
            unreachable!();
        }
        // Push this gene's weighted contribution for its parent neuron.
        stack.push(weight * value);
    }
}
#[cfg(test)]
mod tests {
    use assert_approx_eq::assert_approx_eq;

    use super::*;
    use crate::network::NotEnoughInputsError;
    use crate::{Network, WithRecurrentState};

    /// Returns the absolute path of `file_name` inside the crate's
    /// `test_data/` directory.
    fn get_file_path(file_name: &str) -> String {
        format!("{}/test_data/{}", env!("CARGO_MANIFEST_DIR"), file_name)
    }

    /// Loads a test network from `test_data/`, discarding its extra data and
    /// recurrent state.
    fn load_network(file_name: &str) -> Network<f64> {
        let (net, _, _) =
            Network::load_file::<(), _>(get_file_path(file_name), WithRecurrentState(false))
                .unwrap();
        net
    }

    #[test]
    fn test_evaluate_full() {
        let mut net = load_network("test_network_v1.cge");

        let output = net.evaluate(&[1.0, 1.0]).unwrap();
        assert_eq!(1, output.len());
        assert_eq!(1.014, output[0]);

        let output2 = net.evaluate(&[0.0, 0.0]).unwrap();
        assert_eq!(1, output2.len());
        assert_eq!(0.40056, output2[0]);
    }

    #[test]
    fn test_inputs() {
        let mut net = load_network("test_network_v1.cge");

        // Inputs beyond what the network reads are ignored.
        let output = net.evaluate(&[1.0, 1.0, 2.0, 3.0]).unwrap();
        assert_eq!(1, output.len());
        assert_eq!(1.014, output[0]);

        let output2 = net.evaluate(&[0.0, 0.0, 2.0, 3.0]).unwrap();
        assert_eq!(1, output2.len());
        assert_eq!(0.40056, output2[0]);

        // Supplying fewer inputs than required is an error.
        assert_eq!(Err(NotEnoughInputsError::new(2, 1)), net.evaluate(&[1.0]));
    }

    #[test]
    fn test_activation() {
        let mut net = load_network("test_network_v1.cge");
        net.set_activation(Activation::Tanh);

        let output = net.evaluate(&[1.0, 1.0]).unwrap();
        assert_eq!(1, output.len());
        assert_approx_eq!(0.3913229613565932, output[0]);

        let output2 = net.evaluate(&[0.0, 0.0]).unwrap();
        assert_eq!(1, output2.len());
        assert_approx_eq!(0.11798552468976746, output2[0]);
    }

    #[test]
    fn test_multiple_outputs() {
        let mut net = load_network("test_network_multi_output.cge");
        let inputs = [2.0, 3.0];

        let output = net.evaluate(&inputs).unwrap().to_vec();
        let expected = [3.541362029170628, 3.2752704637145316, 1.1087918551621792];
        assert_eq!(expected.len(), output.len());
        for (want, got) in expected.iter().zip(output.iter()) {
            assert_approx_eq!(*want, *got);
        }

        // Re-evaluating with the same inputs must be deterministic.
        let output2 = net.evaluate(&inputs).unwrap();
        assert_eq!(output, output2);
    }

    #[test]
    fn test_forward_jumper_cached() {
        let mut net = load_network("test_network_v1.cge");

        // Pre-fill every neuron's cached value so forward jumpers read the
        // cache instead of recursively evaluating their source subgenomes.
        for gene in &mut net.genome {
            if let Gene::Neuron(neuron) = gene {
                neuron.set_current_value(Some(100.0));
            }
        }

        let output = net.evaluate(&[0.0, 0.0]).unwrap();
        assert_eq!(1, output.len());
        assert_eq!(3.96, output[0]);
    }

    #[test]
    fn test_recurrent_previous_value() {
        let mut net = load_network("test_network_recurrent.cge");

        let output = net.evaluate(&[]).unwrap().to_vec();
        assert_eq!(1, output.len());
        assert_eq!(1.0, output[0]);

        // The second pass sees the source neuron's value from the first pass.
        let output2 = net.evaluate(&[]).unwrap();
        assert_eq!(1, output2.len());
        assert_eq!(4.0, output2[0]);
    }
}