Struct neuralneat::Trainer

pub struct Trainer {
    pub evaluate_fn: EvaluationFn,
    pub hidden_activation: ActivationFn,
    pub output_activation: ActivationFn,
    /* private fields */
}

A Trainer will manage the training cycle for a population of Genomes.

Fields

evaluate_fn: EvaluationFn

A function that scores a Genome after each piece of TrainingData is fed to it. It is passed a Vec of the Genome’s outputs and the expected value or values from the TrainingData, and should assess the Genome’s performance by comparing the two and returning an f32 “score”. The scores from all calls to evaluate_fn are summed to form the final fitness value of each Genome.
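As a minimal sketch, assuming EvaluationFn is a plain function pointer of the form `fn(&Vec<f32>, &Vec<f32>) -> f32` (outputs first, expected values second; check the crate's actual type alias before relying on this), an evaluation function for a regression task such as the adding example could reward small errors and have its per-sample scores summed into one fitness value. `EvalFn`, `error_fitness` and `fitness_over_data` are hypothetical names used only for illustration.

```rust
/// Assumed shape of an evaluation function: (genome outputs, expected values) -> score.
/// This alias is an illustration, not neuralneat's definition.
type EvalFn = fn(&Vec<f32>, &Vec<f32>) -> f32;

/// Hypothetical evaluation function: each output contributes 1 / (1 + |error|),
/// so a perfect match scores 1.0 per output and larger errors approach 0.
fn error_fitness(outputs: &Vec<f32>, expected: &Vec<f32>) -> f32 {
    outputs
        .iter()
        .zip(expected.iter())
        .map(|(o, e)| 1.0 / (1.0 + (o - e).abs()))
        .sum()
}

/// Sums the per-sample scores into a single fitness value for one genome,
/// mirroring how the evaluate_fn results are described as being combined.
fn fitness_over_data(eval: EvalFn, samples: &[(Vec<f32>, Vec<f32>)]) -> f32 {
    samples.iter().map(|(out, exp)| eval(out, exp)).sum()
}

fn main() {
    // (outputs produced by a genome, expected values from the training data)
    let samples = vec![
        (vec![3.0], vec![3.0]), // exact match: scores 1.0
        (vec![5.0], vec![6.0]), // off by 1: scores 0.5
    ];
    let fitness = fitness_over_data(error_fitness, &samples);
    println!("fitness = {}", fitness); // prints "fitness = 1.5"
}
```

Any scoring rule works as long as better predictions receive higher scores, since the summed value is what ranks the Genomes.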

hidden_activation: ActivationFn

The ActivationFn to use for the hidden layers of each Genome’s network.

output_activation: ActivationFn

The ActivationFn to use for the output layer of each Genome’s network.
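The sketch below assumes ActivationFn is a function pointer `fn(f32) -> f32`, which fits how the repository example passes `linear_activation` directly to these fields; `linear` and `sigmoid` here are illustrative stand-ins, not the crate's own implementations. A linear (identity) output suits the adding example because a sum can be any real number, whereas a sigmoid squashes every output into the range 0 to 1.

```rust
/// Assumed shape of an activation function; illustration only.
type Activation = fn(f32) -> f32;

/// Identity activation: passes the weighted sum through unchanged.
fn linear(x: f32) -> f32 {
    x
}

/// Logistic sigmoid: squashes any input into the range 0 to 1.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Applies an activation to a node's weighted input sum.
fn activate(f: Activation, weighted_sum: f32) -> f32 {
    f(weighted_sum)
}

fn main() {
    let hidden: Activation = sigmoid;
    let output: Activation = linear;
    println!("{}", activate(hidden, 0.0)); // prints "0.5"
    println!("{}", activate(output, 7.0)); // prints "7"
}
```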

Implementations

impl Trainer


pub fn new(training_data: Vec<TrainingData>) -> Self

Create a new Trainer that will use the given training_data to train a Pool’s Genomes when train is called. Other parameters (e.g. evaluate_fn, hidden_activation and output_activation) may be customized before calling train by setting them directly on the returned Trainer.
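The construct-then-override pattern described for new can be illustrated with a self-contained stand-in. `MiniTrainer`, `default_eval`, `identity` and `relu` are hypothetical and only mimic the shape of Trainer's public fields (assumed here to be plain function pointers); they say nothing about the real defaults that Trainer::new chooses.

```rust
/// Hypothetical stand-in that mimics Trainer's public function-pointer fields.
struct MiniTrainer {
    pub evaluate_fn: fn(&Vec<f32>, &Vec<f32>) -> f32,
    pub hidden_activation: fn(f32) -> f32,
    pub output_activation: fn(f32) -> f32,
    training_data: Vec<(Vec<f32>, Vec<f32>)>, // private, like the real struct's data
}

/// Placeholder default scorer: negative sum of absolute errors.
fn default_eval(outputs: &Vec<f32>, expected: &Vec<f32>) -> f32 {
    -outputs.iter().zip(expected).map(|(o, e)| (o - e).abs()).sum::<f32>()
}

fn identity(x: f32) -> f32 {
    x
}

fn relu(x: f32) -> f32 {
    x.max(0.0)
}

impl MiniTrainer {
    /// Stores the training data and fills the public fields with defaults.
    fn new(training_data: Vec<(Vec<f32>, Vec<f32>)>) -> Self {
        MiniTrainer {
            evaluate_fn: default_eval,
            hidden_activation: relu,
            output_activation: relu,
            training_data,
        }
    }
}

fn main() {
    let mut trainer = MiniTrainer::new(vec![(vec![1.0, 2.0], vec![3.0])]);
    // Override a default before training, as the adding example does.
    trainer.output_activation = identity;
    println!("{}", (trainer.output_activation)(-2.0)); // prints "-2"
    println!("{}", (trainer.hidden_activation)(-2.0)); // prints "0"
    println!("{}", (trainer.evaluate_fn)(&vec![2.5], &vec![3.0])); // prints "-0.5"
    println!("{}", trainer.training_data.len()); // prints "1"
}
```

Because the fields hold function pointers, calling one requires wrapping the field access in parentheses, as in `(trainer.output_activation)(x)`.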

Example from the repository: examples/adding_managed.rs (Trainer::new is called on line 32)
fn main() {
    let args: Vec<String> = env::args().collect();

    if args.len() < 2 {
        println!("Usage: '{} train' to train a new comparer.", args[0]);
        println!("Usage: '{} evaluate serialized_genome.json input1 input2 input3 input4' to evaluate with an existing genome.", args[0]);
        return;
    }

    if args[1] == "train" {
        // One input node for each input in the training data structure
        let input_nodes = 4;
        // One output node with the "prediction"
        let output_nodes = 1;
        // Create a new gene pool with an initial population of genomes
        let mut gene_pool = Pool::with_defaults(input_nodes, output_nodes);

        let training_data = load_training_data(TRAINING_DATA_STRING, 4, 1);

        let mut trainer = Trainer::new(training_data);
        // This function will be called once per Genome per piece of
        // TrainingData in each generation, passing the values of the
        // output nodes of the Genome as well as the expected result
        // from the TrainingData.
        trainer.evaluate_fn = adding_fitness_func;
        trainer.hidden_activation = linear_activation;
        trainer.output_activation = linear_activation;

        // Train over the course of 100 generations
        trainer.train(&mut gene_pool, 100);

        let best_genome = gene_pool.get_best_genome();

        println!("Serializing best genome to winner.json");
        serde_json::to_writer(&File::create("winner.json").unwrap(), &best_genome).unwrap();
    } else {
        if args.len() < 7 {
            println!("Usage: '{} evaluate serialized_genome.json input1 input2 input3 input4' to evaluate with an existing genome.", args[0]);
            return;
        }
        let mut genome: Genome = serde_json::from_reader(File::open(&args[2]).unwrap()).unwrap();
        let input1 = args[3]
            .parse::<f32>()
            .expect("Couldn't parse input1 as f32");
        let input2 = args[4]
            .parse::<f32>()
            .expect("Couldn't parse input2 as f32");
        let input3 = args[5]
            .parse::<f32>()
            .expect("Couldn't parse input3 as f32");
        let input4 = args[6]
            .parse::<f32>()
            .expect("Couldn't parse input4 as f32");
        // Note that this is the exact same function we used in training
        // further up!
        genome.evaluate(
            &vec![input1, input2, input3, input4],
            Some(linear_activation),
            Some(linear_activation),
        );
        println!(
            "Sum of inputs is..........{}",
            genome.get_outputs()[0] as u32
        );
    }
}

pub fn train(self, gene_pool: &mut Pool, generations: usize)

Train the given gene_pool for the specified number of generations. The training_data provided when the Trainer was created is used in this process, together with the current values of evaluate_fn, hidden_activation and output_activation. Because train takes self by value, the Trainer is consumed by this call.
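Per the comment in the repository example, evaluate_fn is called once per Genome per piece of TrainingData in each generation, so a run makes generations × population size × training samples scoring calls. The sketch below counts those calls using a toy population of single-weight "genomes"; `toy_output`, `score` and `run` are hypothetical names, and the loop covers only the evaluation phase. It is not neuralneat's train implementation: the selection, crossover and mutation that a real NEAT generation performs are left out.

```rust
/// Toy stand-in for a Genome (hypothetical, not neuralneat's type):
/// one weight applied to the sum of the inputs.
fn toy_output(weight: f32, inputs: &[f32]) -> f32 {
    weight * inputs.iter().sum::<f32>()
}

/// Hypothetical scoring function: negative squared error, so 0 is a perfect score.
fn score(output: f32, expected: f32) -> f32 {
    -(output - expected).powi(2)
}

/// Hypothetical outline of the evaluation phase over `generations` generations.
/// Returns the last generation's fitness values and the number of scoring calls.
/// The population never changes here because evolution steps are omitted.
fn run(generations: usize, weights: &[f32], data: &[(Vec<f32>, f32)]) -> (Vec<f32>, usize) {
    let mut calls = 0;
    let mut fitness = vec![0.0_f32; weights.len()];
    for _ in 0..generations {
        for (i, &w) in weights.iter().enumerate() {
            // One scoring call per genome per training sample, summed into fitness.
            fitness[i] = data
                .iter()
                .map(|(inputs, expected)| {
                    calls += 1;
                    score(toy_output(w, inputs), *expected)
                })
                .sum();
        }
    }
    (fitness, calls)
}

fn main() {
    // (inputs, expected sum) pairs, in the spirit of the adding example.
    let data = vec![(vec![1.0, 2.0], 3.0), (vec![4.0, 0.0], 4.0)];
    let weights = [1.0, 0.5];
    let (fitness, calls) = run(100, &weights, &data);
    // Index of the genome with the highest fitness.
    let best = fitness
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap();
    println!("best genome: {}, scoring calls: {}", best, calls); // prints "best genome: 0, scoring calls: 400"
}
```

With 100 generations, 2 genomes and 2 samples, the count is 100 × 2 × 2 = 400, which is why large training sets or populations make each generation proportionally more expensive.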

Example from the repository: examples/adding_managed.rs (line 42), the trainer.train(&mut gene_pool, 100) call in the example shown under new above.

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T where V: MultiLane<T>,

fn vzip(self) -> V