use super::*;
mu!(traits);
/// A supervised-learning dataset: two parallel lists of samples where
/// `inputs[i]` is the input vector matching the expected `outputs[i]`.
#[derive(Clone)]
pub struct Dataset {
pub inputs: Vec<Vec<f64>>,
pub outputs: Vec<Vec<f64>>,
}
impl Dataset {
/// Builds a `Dataset` from `(input, output)` pairs, preserving order.
pub fn from_tuples(tuples: Vec<(Vec<f64>, Vec<f64>)>) -> Dataset {
// `unzip` splits the pairs into the two parallel vectors in one pass,
// replacing the manual push loop.
let (inputs, outputs) = tuples.into_iter().unzip();
Dataset { inputs, outputs }
}
}
/// Scores agents either against a fixed supervised [`Dataset`] or via
/// user-supplied callback objects. Exactly one of the three fields is
/// expected to be `Some`, depending on which constructor was used.
#[derive(Clone)]
pub struct Evaluator {
// Supervised input/output pairs; set by `new_with_dataset`.
dataset: Option<Dataset>,
// User-provided fitness callback; set by `new_with_custom_eval_function`.
custom_eval_function: Option<Box<dyn CustomEval>>,
// User-provided loss callback; set by `new_with_custom_loss_function`.
custom_loss_function: Option<Box<dyn CustomLoss>>,
}
#[profiling::all_functions]
impl Evaluator {
/// Creates an evaluator that scores agents against a fixed dataset.
pub fn new_with_dataset(dataset: Dataset) -> Evaluator {
Evaluator {
dataset: Some(dataset),
custom_eval_function: None,
custom_loss_function: None,
}
}
/// Creates an evaluator that delegates fitness scoring to a
/// user-supplied callback.
pub fn new_with_custom_eval_function(custom_eval_function: Box<dyn CustomEval>) -> Evaluator {
Evaluator {
dataset: None,
custom_eval_function: Some(custom_eval_function),
custom_loss_function: None,
}
}
/// Creates an evaluator that delegates loss computation to a
/// user-supplied callback.
pub fn new_with_custom_loss_function(custom_loss_function: Box<dyn CustomLoss>) -> Evaluator {
Evaluator {
dataset: None,
custom_eval_function: None,
custom_loss_function: Some(custom_loss_function),
}
}
/// Returns the fitness of `agent` (higher is better).
///
/// With a dataset, every output dimension of every sample contributes
/// `1 - |output - expected|`, summed. Otherwise the custom eval
/// callback is invoked.
///
/// # Panics
/// Panics if neither a dataset nor a custom eval function is set.
pub fn evaluate(&self, agent: &mut Agent) -> f64 {
if let Some(dataset) = &self.dataset {
let mut fitness = 0.0;
for (inputs, outputs) in dataset.inputs.iter().zip(dataset.outputs.iter()) {
// NOTE(review): `true` presumably selects a training/stateful
// activation mode, while `loss` below calls `agent.activate`
// instead — confirm both paths are intentional.
let result = agent.network.activate(inputs.clone(), true);
for (output, expected) in result.iter().zip(outputs.iter()) {
fitness += 1f64 - (output - expected).abs();
}
}
fitness
} else if let Some(custom_eval_function) = &self.custom_eval_function {
custom_eval_function.evaluate(agent)
} else {
panic!("No dataset or environment provided");
}
}
/// Returns the loss of `agent` (lower is better).
///
/// With a dataset this is the summed squared error over every output
/// dimension of every sample. Otherwise the custom loss callback is
/// invoked.
///
/// # Panics
/// Panics if neither a dataset nor a custom loss function is set.
pub fn loss(&self, agent: &mut Agent) -> f64 {
if let Some(dataset) = &self.dataset {
let mut loss = 0.0;
for (inputs, outputs) in dataset.inputs.iter().zip(dataset.outputs.iter()) {
let result = agent.activate(inputs.clone());
for (output, expected) in result.iter().zip(outputs.iter()) {
loss += (output - expected).powi(2);
}
}
loss
} else if let Some(custom_loss_function) = &self.custom_loss_function {
custom_loss_function.loss(agent)
} else {
panic!("No dataset or loss function provided, but loss was called");
}
}
/// Squared-error loss of `net` on a single `(inputs, target)` pair.
pub fn loss_on_pair(&self, net: &mut Network, inputs: &[f64], target: &[f64]) -> f64 {
let result = net.activate(inputs.to_vec(), false);
let mut loss = 0.0;
for (output, expected) in result.iter().zip(target.iter()) {
loss += (output - expected).powi(2);
}
loss
}
/// Scores a bare network with the custom eval function by wrapping a
/// clone of it in a throwaway `Agent`.
///
/// # Panics
/// Panics if no custom eval function is set.
pub fn fitness(&self, network: &mut Network) -> f64 {
if let Some(custom_eval_function) = &self.custom_eval_function {
// The callback takes an `Agent`, so wrap a clone of the network;
// the clone keeps the caller's network untouched.
custom_eval_function.evaluate(&mut Agent {
network: network.clone(),
fitness: 0.0,
})
} else {
panic!("No custom evaluation function provided, but fitness was called");
}
}
/// Prints each dataset sample with the network's output, green when the
/// rounded output matches the expected vector, red otherwise.
///
/// # Panics
/// Panics if no dataset is set.
pub fn show_dataset_performance(&self, agent: &mut Network) {
if let Some(dataset) = &self.dataset {
for (inputs, outputs) in dataset.inputs.iter().zip(dataset.outputs.iter()) {
let output = agent.activate(inputs.clone(), false);
let rounded_output = output.iter().map(|x| x.round()).collect::<Vec<f64>>();
let text = format!(
"{:?} -> {:.5?} rounded to {:?} (expected {:?})",
inputs, output, rounded_output, outputs
);
if rounded_output == *outputs {
println!("{}", text.green())
} else {
println!("{}", text.red())
}
}
} else if self.custom_eval_function.is_some() {
panic!("Cannot show dataset performance for custom evaluation function");
} else {
panic!("No dataset or environment provided");
}
}
/// Iterates over `(input, output)` sample pairs of the dataset.
///
/// # Panics
/// Panics (with a descriptive message) if no dataset is set.
pub fn dataset_iter(&self) -> impl Iterator<Item = (&Vec<f64>, &Vec<f64>)> {
// Bind once instead of unwrapping twice; `expect` gives a useful
// message when this is called on a dataset-less evaluator.
let dataset = self
.dataset
.as_ref()
.expect("dataset_iter called on an Evaluator without a dataset");
dataset.inputs.iter().zip(dataset.outputs.iter())
}
/// Dumps every dataset sample to stdout for debugging.
pub fn debug_dataset(&self) {
for (inputs, outputs) in self.dataset_iter() {
println!("{:?} -> {:?}", inputs, outputs);
}
}
}