astrai 2.2.0

A pretty bad neural network library
Documentation
use super::*;

mu!(traits);

/// A supervised-learning dataset: `inputs[i]` is fed to the network and
/// `outputs[i]` is the output expected for it.
///
/// The two vectors are paired by index (with `zip`), so they should have the
/// same length; surplus rows in the longer one are ignored by the evaluator.
#[derive(Clone, Debug)]
pub struct Dataset {
    /// Input vectors, one per example.
    pub inputs: Vec<Vec<f64>>,
    /// Expected output vectors, one per example, in the same order as `inputs`.
    pub outputs: Vec<Vec<f64>>,
}

impl Dataset {
    /// Builds a dataset from `(input, expected_output)` pairs, preserving their order.
    ///
    /// An empty vector yields an empty dataset.
    pub fn from_tuples(tuples: Vec<(Vec<f64>, Vec<f64>)>) -> Dataset {
        // `unzip` splits the pairs into the two parallel vectors in a single pass.
        let (inputs, outputs): (Vec<Vec<f64>>, Vec<Vec<f64>>) = tuples.into_iter().unzip();
        Dataset { inputs, outputs }
    }
}

/// Scores agents/networks, either against a fixed [`Dataset`] or with a
/// user-supplied evaluation or loss function.
///
/// Each `new_with_*` constructor sets exactly one of the three fields to `Some`;
/// the methods choose their behaviour based on which one is present.
// NOTE(review): `derive(Clone)` requires `Box<dyn CustomEval>` and
// `Box<dyn CustomLoss>` to implement `Clone` — presumably provided by the
// `traits` module; confirm there.
#[derive(Clone)]
pub struct Evaluator {
    // Labelled examples used by `evaluate`, `loss` and the dataset helpers.
    dataset: Option<Dataset>,
    // User-defined fitness function used by `evaluate` and `fitness`.
    custom_eval_function: Option<Box<dyn CustomEval>>,
    // User-defined loss function used by `loss`.
    custom_loss_function: Option<Box<dyn CustomLoss>>,
}

#[profiling::all_functions]
impl Evaluator {
    /// Creates an evaluator that scores agents against a fixed [`Dataset`].
    pub fn new_with_dataset(dataset: Dataset) -> Evaluator {
        Evaluator {
            dataset: Some(dataset),
            custom_eval_function: None,
            custom_loss_function: None,
        }
    }

    /// Creates an evaluator whose fitness is computed by `custom_eval_function`.
    ///
    /// Such an evaluator has no dataset and no loss function, so [`Evaluator::loss`]
    /// and the dataset helpers panic when called on it.
    pub fn new_with_custom_eval_function(custom_eval_function: Box<dyn CustomEval>) -> Evaluator {
        Evaluator {
            dataset: None,
            custom_eval_function: Some(custom_eval_function),
            custom_loss_function: None,
        }
    }

    /// Creates an evaluator whose loss is computed by `custom_loss_function`.
    ///
    /// Such an evaluator has no dataset and no evaluation function, so
    /// [`Evaluator::evaluate`], [`Evaluator::fitness`] and the dataset helpers
    /// panic when called on it.
    pub fn new_with_custom_loss_function(custom_loss_function: Box<dyn CustomLoss>) -> Evaluator {
        Evaluator {
            dataset: None,
            custom_eval_function: None,
            custom_loss_function: Some(custom_loss_function),
        }
    }

    /// Computes the fitness of `agent` (higher is better).
    ///
    /// With a dataset, every output element contributes `1 - |output - expected|`,
    /// summed over all examples. Otherwise the custom evaluation function is used.
    /// Outputs and expected values are paired with `zip`, so surplus elements of
    /// the longer side are ignored.
    ///
    /// # Panics
    /// Panics if the evaluator has neither a dataset nor a custom evaluation function.
    pub fn evaluate(&self, agent: &mut Agent) -> f64 {
        if let Some(dataset) = &self.dataset {
            let mut fitness = 0.0;
            for (inputs, outputs) in dataset.inputs.iter().zip(dataset.outputs.iter()) {
                // NOTE(review): this path passes `true` as the second argument of
                // `Network::activate`, whereas `loss_on_pair` and
                // `show_dataset_performance` pass `false` — confirm what the flag controls.
                let result = agent.network.activate(inputs.clone(), true);
                for (output, expected) in result.iter().zip(outputs.iter()) {
                    fitness += 1f64 - (output - expected).abs();
                }
            }
            fitness
        } else if let Some(custom_eval_function) = &self.custom_eval_function {
            custom_eval_function.evaluate(agent)
        } else {
            panic!("No dataset or environment provided");
        }
    }

    /// Computes the loss of `agent` (lower is better).
    ///
    /// With a dataset this is the *sum* of squared errors over every output of
    /// every example. It is not divided by the number of examples or outputs, so
    /// it is not a mean squared error. Otherwise the custom loss function is used.
    ///
    /// # Panics
    /// Panics if the evaluator has neither a dataset nor a custom loss function.
    pub fn loss(&self, agent: &mut Agent) -> f64 {
        if let Some(dataset) = &self.dataset {
            let mut loss = 0.0;
            for (inputs, outputs) in dataset.inputs.iter().zip(dataset.outputs.iter()) {
                let result = agent.activate(inputs.clone());
                loss = accumulate_squared_error(loss, result.iter(), outputs);
            }
            loss
        } else if let Some(custom_loss_function) = &self.custom_loss_function {
            custom_loss_function.loss(agent)
        } else {
            panic!("No dataset or loss function provided, but loss was called");
        }
    }

    /// Returns the sum of squared errors between `net`'s output for `inputs` and `target`.
    ///
    /// Uses neither the evaluator's dataset nor its custom functions. `net` is
    /// activated with `false` as the second argument.
    pub fn loss_on_pair(&self, net: &mut Network, inputs: &[f64], target: &[f64]) -> f64 {
        let result = net.activate(inputs.to_vec(), false);
        accumulate_squared_error(0.0, result.iter(), target)
    }

    /// Scores `network` with the custom evaluation function.
    ///
    /// The network is cloned into a fresh [`Agent`] (with fitness `0.0`) before
    /// evaluation, so whatever the evaluation does to that agent does not affect
    /// `network` itself.
    ///
    /// # Panics
    /// Panics if no custom evaluation function was provided. Dataset-based
    /// evaluators are not handled here; use [`Evaluator::evaluate`] for those.
    pub fn fitness(&self, network: &mut Network) -> f64 {
        if let Some(custom_eval_function) = &self.custom_eval_function {
            custom_eval_function.evaluate(&mut Agent {
                network: network.clone(),
                fitness: 0.0,
            })
        } else {
            panic!("No custom evaluation function provided, but fitness was called");
        }
    }

    /// Prints one line per dataset example showing input, raw output, rounded
    /// output and expected output.
    ///
    /// A line is printed green when the element-wise rounded output equals the
    /// expected output exactly, and red otherwise.
    ///
    /// # Panics
    /// Panics if the evaluator has no dataset.
    pub fn show_dataset_performance(&self, network: &mut Network) {
        if let Some(dataset) = &self.dataset {
            for (inputs, outputs) in dataset.inputs.iter().zip(dataset.outputs.iter()) {
                let output = network.activate(inputs.clone(), false);
                let rounded_output = output.iter().map(|x| x.round()).collect::<Vec<f64>>();
                let text = format!(
                    "{:?} -> {:.5?} rounded to {:?} (expected {:?})",
                    inputs, output, rounded_output, outputs
                );

                if rounded_output == *outputs {
                    println!("{}", text.green())
                } else {
                    println!("{}", text.red())
                }
            }
        } else if self.custom_eval_function.is_some() {
            panic!("Cannot show dataset performance for custom evaluation function");
        } else {
            panic!("No dataset or environment provided");
        }
    }

    /// Iterates over the dataset's `(inputs, expected_outputs)` pairs in order.
    ///
    /// # Panics
    /// Panics if the evaluator has no dataset.
    pub fn dataset_iter(&self) -> impl Iterator<Item = (&Vec<f64>, &Vec<f64>)> {
        let dataset = self
            .dataset
            .as_ref()
            .expect("No dataset provided, but dataset_iter was called");
        dataset.inputs.iter().zip(dataset.outputs.iter())
    }

    /// Prints every `(inputs, expected_outputs)` pair of the dataset.
    ///
    /// # Panics
    /// Panics if the evaluator has no dataset.
    pub fn debug_dataset(&self) {
        for (inputs, outputs) in self.dataset_iter() {
            println!("{:?} -> {:?}", inputs, outputs);
        }
    }
}

/// Adds the squared difference of each `(output, expected)` pair to `total`
/// and returns the new running total.
///
/// Elements are paired with `zip`, so surplus elements of the longer sequence are
/// ignored. Taking the running total as input keeps the floating-point addition
/// order identical to one flat accumulation loop.
fn accumulate_squared_error<'a>(
    mut total: f64,
    outputs: impl IntoIterator<Item = &'a f64>,
    expected: &[f64],
) -> f64 {
    for (output, target) in outputs.into_iter().zip(expected.iter()) {
        total += (output - target).powi(2);
    }
    total
}