Struct Trainer

pub struct Trainer {
    pub evaluate_fn: EvaluationFn,
    pub hidden_activation: ActivationFn,
    pub output_activation: ActivationFn,
    /* private fields */
}

A Trainer manages the training cycle for a population of Genomes.

Fields

evaluate_fn: EvaluationFn

A function that scores a Genome after each piece of TrainingData is fed to it. It is passed a Vec of the Genome’s outputs together with the expected value or values from the TrainingData, and should compare the two and return an f32 representing the Genome’s “score” for that piece of data. The scores from every call to evaluate_fn are summed to form each Genome’s final fitness value.
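As a rough illustration of how the per-sample scores add up to a fitness value, the sketch below models evaluate_fn as a plain function pointer over slices. The EvaluationFn alias shown here, negative_abs_error and total_fitness are hypothetical names chosen for this sketch; the crate’s actual EvaluationFn signature is not shown on this page.

```rust
// Assumption: the real EvaluationFn type is not shown on this page, so this
// sketch models it as a function over slices of outputs and expected values.
type EvaluationFn = fn(&[f32], &[f32]) -> f32;

// Hypothetical scorer: returns the negated total absolute error, so a perfect
// match scores 0.0 and larger errors score lower.
fn negative_abs_error(outputs: &[f32], expected: &[f32]) -> f32 {
    outputs
        .iter()
        .zip(expected)
        .map(|(o, e)| -(o - e).abs())
        .sum()
}

// Sums one evaluate_fn score per piece of training data, mirroring how the
// Trainer is described as combining scores into a Genome's fitness.
fn total_fitness(evaluate_fn: EvaluationFn, samples: &[(Vec<f32>, Vec<f32>)]) -> f32 {
    samples
        .iter()
        .map(|(outputs, expected)| evaluate_fn(outputs, expected))
        .sum()
}

fn main() {
    // Each pair is (Genome outputs for one sample, expected values for that sample).
    let samples: Vec<(Vec<f32>, Vec<f32>)> = vec![
        (vec![3.0], vec![3.0]), // exact match: scores 0.0
        (vec![4.5], vec![5.0]), // off by 0.5: scores -0.5
        (vec![7.0], vec![6.0]), // off by 1.0: scores -1.0
    ];
    let fitness = total_fitness(negative_abs_error, &samples);
    println!("fitness = {}", fitness); // fitness = -1.5
}
```

Whether higher or lower scores count as better is up to the crate’s selection logic; the negated error here simply makes “closer to expected” produce a larger number.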

hidden_activation: ActivationFn

The ActivationFn to use for the hidden layers of each Genome’s network.

output_activation: ActivationFn

The ActivationFn to use for the output layer of each Genome’s network.
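The two fields let the hidden and output layers use different activation functions. The sketch below assumes ActivationFn behaves like fn(f32) -> f32 and uses a hypothetical fixed two-layer network (not an evolved Genome) to show where each function is applied; layer, forward and sigmoid_activation are illustrative names, not the crate’s API.

```rust
// Assumption: ActivationFn is modeled as a plain function pointer on f32.
type ActivationFn = fn(f32) -> f32;

fn linear_activation(x: f32) -> f32 {
    x
}

fn sigmoid_activation(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Hypothetical fully connected layer: each row of `weights` feeds one node,
// whose weighted input sum is passed through `activation`.
fn layer(inputs: &[f32], weights: &[Vec<f32>], activation: ActivationFn) -> Vec<f32> {
    weights
        .iter()
        .map(|row| activation(row.iter().zip(inputs).map(|(w, x)| w * x).sum::<f32>()))
        .collect()
}

// The hidden layer uses hidden_activation; the output layer uses output_activation.
fn forward(
    inputs: &[f32],
    hidden_w: &[Vec<f32>],
    output_w: &[Vec<f32>],
    hidden_activation: ActivationFn,
    output_activation: ActivationFn,
) -> Vec<f32> {
    let hidden = layer(inputs, hidden_w, hidden_activation);
    layer(&hidden, output_w, output_activation)
}

fn main() {
    let inputs = [1.0_f32, 2.0, 3.0, 4.0];
    // Two hidden nodes: one sums the first two inputs, the other the last two.
    let hidden_w: Vec<Vec<f32>> = vec![vec![1.0, 1.0, 0.0, 0.0], vec![0.0, 0.0, 1.0, 1.0]];
    // One output node that adds the two hidden nodes.
    let output_w: Vec<Vec<f32>> = vec![vec![1.0, 1.0]];

    let linear_out = forward(&inputs, &hidden_w, &output_w, linear_activation, linear_activation);
    println!("linear output:  {:?}", linear_out); // [10.0]
    let squashed = forward(&inputs, &hidden_w, &output_w, linear_activation, sigmoid_activation);
    println!("sigmoid output: {:?}", squashed);
}
```

With a linear function on both layers the output reproduces the sum 1 + 2 + 3 + 4 = 10, whereas a sigmoid on the output squashes it into the range (0, 1). That may be why the adding_managed example sets both fields to linear_activation when the network has to emit a plain sum.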

Implementations

impl Trainer

pub fn new(training_data: Vec<TrainingData>) -> Self

Create a new Trainer that will use the given training_data to train a Pool’s Genomes when train is called. Other parameters (e.g. evaluate_fn) may be customized before calling train by setting them directly on the returned Trainer.
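The construct-then-customize pattern described above can be sketched with a hypothetical SketchTrainer that mirrors the documented public fields. The types, the default functions chosen in new, and the sample_count helper are all assumptions for illustration; the real Trainer’s defaults are not documented on this page.

```rust
// Hypothetical stand-ins for the crate's types, which this page does not define.
type EvaluationFn = fn(&[f32], &[f32]) -> f32;
type ActivationFn = fn(f32) -> f32;

pub struct TrainingData {
    pub inputs: Vec<f32>,
    pub expected: Vec<f32>,
}

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

fn linear(x: f32) -> f32 {
    x
}

// Negated total absolute error: 0.0 for a perfect match, lower otherwise.
fn neg_abs_error(outputs: &[f32], expected: &[f32]) -> f32 {
    outputs.iter().zip(expected).map(|(o, e)| -(o - e).abs()).sum()
}

pub struct SketchTrainer {
    pub evaluate_fn: EvaluationFn,
    pub hidden_activation: ActivationFn,
    pub output_activation: ActivationFn,
    // Private, like the real Trainer's private fields: code outside this
    // module can only supply it through new().
    training_data: Vec<TrainingData>,
}

impl SketchTrainer {
    // Stores the data and fills the public fields with hypothetical defaults.
    pub fn new(training_data: Vec<TrainingData>) -> Self {
        SketchTrainer {
            evaluate_fn: neg_abs_error,
            hidden_activation: sigmoid,
            output_activation: sigmoid,
            training_data,
        }
    }

    pub fn sample_count(&self) -> usize {
        self.training_data.len()
    }
}

fn main() {
    let data = vec![TrainingData { inputs: vec![1.0, 2.0], expected: vec![3.0] }];
    let mut trainer = SketchTrainer::new(data);
    // Customize before training by assigning to the public fields,
    // as adding_managed.rs does with linear_activation.
    trainer.hidden_activation = linear;
    trainer.output_activation = linear;
    println!(
        "{} sample(s); output_activation(5.0) = {}",
        trainer.sample_count(),
        (trainer.output_activation)(5.0)
    );
}
```

Public function-pointer fields keep the constructor short: only the required data goes through new, and every tunable has a default that callers overwrite as needed.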

Examples found in repository?
examples/adding_managed.rs (line 32, the Trainer::new call; the excerpt starts at line 13)
fn main() {
    // Not shown in this excerpt: the imports and the items TRAINING_DATA_STRING,
    // load_training_data, adding_fitness_func and linear_activation, which are
    // defined or imported elsewhere in the file.
    let args: Vec<String> = env::args().collect();

    if args.len() < 2 {
        println!("Usage: '{} train' to train a new comparer.", args[0]);
        println!("Usage: '{} evaluate serialized_genome.json input1 input2 input3 input4' to evaluate with an existing genome.", args[0]);
        return;
    }

    if args[1] == "train" {
        // One input node for each input in the training data structure
        let input_nodes = 4;
        // One output node with the "prediction"
        let output_nodes = 1;
        // Create a new gene pool with an initial population of genomes
        let mut gene_pool = Pool::with_defaults(input_nodes, output_nodes);

        let training_data = load_training_data(TRAINING_DATA_STRING, 4, 1);

        let mut trainer = Trainer::new(training_data);
        // This function will be called once per Genome per piece of
        // TrainingData in each generation, passing the values of the
        // output nodes of the Genome as well as the expected result
        // from the TrainingData.
        trainer.evaluate_fn = adding_fitness_func;
        trainer.hidden_activation = linear_activation;
        trainer.output_activation = linear_activation;

        // Train over the course of 100 generations
        trainer.train(&mut gene_pool, 100);

        let best_genome = gene_pool.get_best_genome();

        println!("Serializing best genome to winner.json");
        serde_json::to_writer(&File::create("winner.json").unwrap(), &best_genome).unwrap();
    } else {
        if args.len() < 7 {
            println!("Usage: '{} evaluate serialized_genome.json input1 input2 input3 input4' to evaluate with an existing genome.", args[0]);
            return;
        }
        let mut genome: Genome = serde_json::from_reader(File::open(&args[2]).unwrap()).unwrap();
        let input1 = args[3]
            .parse::<f32>()
            .expect("Couldn't parse input1 as f32");
        let input2 = args[4]
            .parse::<f32>()
            .expect("Couldn't parse input2 as f32");
        let input3 = args[5]
            .parse::<f32>()
            .expect("Couldn't parse input3 as f32");
        let input4 = args[6]
            .parse::<f32>()
            .expect("Couldn't parse input4 as f32");
        // Note that this is the exact same function we used in training
        // further up!
        genome.evaluate(
            &vec![input1, input2, input3, input4],
            Some(linear_activation),
            Some(linear_activation),
        );
        println!(
            "Sum of inputs is..........{}",
            genome.get_outputs()[0] as u32
        );
    }
}
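In the evaluate branch of the example, the four inputs are parsed by four near-identical statements. One possible tidier approach parses the arguments in a loop and reports which input failed. parse_inputs is a hypothetical helper written for this sketch, not part of the example or the crate.

```rust
use std::env;

// Hypothetical helper: parses every argument as an f32, naming the failing
// input (input1, input2, ...) in the error message.
fn parse_inputs(args: &[String]) -> Result<Vec<f32>, String> {
    args.iter()
        .enumerate()
        .map(|(i, arg)| {
            arg.parse::<f32>()
                .map_err(|e| format!("Couldn't parse input{} ({:?}) as f32: {}", i + 1, arg, e))
        })
        // Collecting into Result stops at the first parse error.
        .collect()
}

fn main() {
    let args: Vec<String> = env::args().collect();
    // As in the example: args[1] is "evaluate", args[2] the genome file,
    // and args[3..7] hold input1..input4.
    if args.len() < 7 {
        eprintln!("Usage: '{} evaluate serialized_genome.json input1 input2 input3 input4'", args[0]);
        return;
    }
    match parse_inputs(&args[3..7]) {
        Ok(inputs) => println!("Parsed inputs: {:?}", inputs),
        Err(msg) => eprintln!("{}", msg),
    }
}
```

Because the label is derived from the loop index, the error messages cannot drift out of sync with the argument positions.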

pub fn train(self, gene_pool: &mut Pool, generations: usize)

Train the given gene_pool for the number of generations given by generations. The training_data provided when creating the Trainer is used in this process, along with the current values of evaluate_fn, hidden_activation and output_activation. Because train takes self by value, the Trainer is consumed by the call.
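The following is a minimal sketch of the kind of loop train runs, under much simpler assumptions than the crate’s: each toy Genome is a weight vector feeding one output node (so there is no hidden_activation), and the evolution step is plain elitist hill climbing rather than the crate’s actual algorithm, which this page does not describe. What it does mirror from the description above is that fitness is the sum of one evaluate_fn score per piece of training data, recomputed every generation, and that train takes self by value. All type and function names are hypothetical.

```rust
// Hypothetical stand-ins: the crate's real types are not shown on this page.
type EvaluationFn = fn(&[f32], &[f32]) -> f32;
type ActivationFn = fn(f32) -> f32;

struct TrainingData { inputs: Vec<f32>, expected: Vec<f32> }

// A toy "genome": one weight per input feeding a single output node.
#[derive(Clone)]
struct Genome { weights: Vec<f32>, fitness: f32 }

impl Genome {
    fn evaluate(&self, inputs: &[f32], output_activation: ActivationFn) -> Vec<f32> {
        let sum: f32 = self.weights.iter().zip(inputs).map(|(w, x)| w * x).sum();
        vec![output_activation(sum)]
    }
}

struct Trainer { evaluate_fn: EvaluationFn, output_activation: ActivationFn, training_data: Vec<TrainingData> }

impl Trainer {
    // Fitness = sum of one evaluate_fn score per piece of training data.
    fn score_pool(&self, pool: &mut [Genome]) {
        for genome in pool.iter_mut() {
            let fitness: f32 = self
                .training_data
                .iter()
                .map(|d| {
                    let outputs = genome.evaluate(&d.inputs, self.output_activation);
                    (self.evaluate_fn)(&outputs, &d.expected)
                })
                .sum();
            genome.fitness = fitness;
        }
    }

    // Takes self by value, like the documented train, so the trainer is consumed.
    fn train(self, pool: &mut [Genome], generations: usize) {
        for _ in 0..generations {
            self.score_pool(pool);
            evolve(pool);
        }
        self.score_pool(pool); // score the final generation too
    }
}

// Hypothetical evolution step (elitist hill climbing, NOT the crate's algorithm):
// slot 0 keeps the fittest genome; other slots nudge one weight by +/-0.1.
fn evolve(pool: &mut [Genome]) {
    let best = pool
        .iter()
        .max_by(|a, b| a.fitness.total_cmp(&b.fitness))
        .expect("pool must not be empty")
        .clone();
    let n_weights = best.weights.len();
    for (i, genome) in pool.iter_mut().enumerate() {
        *genome = best.clone();
        if i == 0 {
            continue;
        }
        let idx = ((i - 1) / 2) % n_weights;
        let step = if (i - 1) % 2 == 0 { 0.1 } else { -0.1 };
        genome.weights[idx] += step;
    }
}

fn neg_abs_error(outputs: &[f32], expected: &[f32]) -> f32 {
    outputs.iter().zip(expected).map(|(o, e)| -(o - e).abs()).sum()
}

fn linear(x: f32) -> f32 { x }

fn main() {
    // Toy "adding" data: the ideal weights are [1.0, 1.0].
    let training_data = vec![
        TrainingData { inputs: vec![1.0, 2.0], expected: vec![3.0] },
        TrainingData { inputs: vec![3.0, 1.0], expected: vec![4.0] },
    ];
    let trainer = Trainer { evaluate_fn: neg_abs_error, output_activation: linear, training_data };
    let mut pool = vec![Genome { weights: vec![0.0, 0.0], fitness: f32::MIN }; 5];
    trainer.train(&mut pool, 30);
    // trainer.train(&mut pool, 30); // would not compile: trainer was moved above
    let best = pool.iter().max_by(|a, b| a.fitness.total_cmp(&b.fitness)).unwrap();
    println!("best weights {:?}, fitness {}", best.weights, best.fitness);
}
```

Because train consumes the trainer in this sketch, training again means building a new Trainer and supplying the training data once more.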

Examples found in repository?
examples/adding_managed.rs (line 42, the trainer.train(&mut gene_pool, 100) call): the same main function as listed under new above.

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V