evo_rl/
agent_wrapper.rs

//! A wrapper type around [`NeuralNetwork`] that exposes it to other environments, such as
//! Python (via PyO3).
3use crate::graph::NeuralNetwork;
4use crate::enecode::EneCode;
5use pyo3::prelude::*;
6use std::path::PathBuf;
7use std::fs::File;
8use std::io::prelude::*;
9use std::io::BufWriter;
10use std::io::Result as FileResult;
11
/// Types that can produce a flat `f32` vector, presumably used as network input
/// (inferred from the name — confirm against callers).
///
/// Implementors must be [`Send`], so trait objects of this trait can cross threads.
pub trait NnInputVector: Send {
    /// Returns the contents as a freshly allocated, owned `Vec<f32>`.
    // The `into_` prefix with a `&self` receiver is kept for compatibility with
    // existing callers, even though it deviates from the Rust naming convention.
    #[allow(clippy::wrong_self_convention)]
    fn into_vec_f32(&self) -> Vec<f32>;
}
15
16impl NnInputVector for Vec<f32> {
17    fn into_vec_f32(&self) -> Vec<f32> { 
18        self.clone()
19    }
20}
21
22#[pyclass]
23pub struct Agent {
24    pub nn: Box<NeuralNetwork>,
25    pub fitness: f32,
26}
27
28#[pymethods]
29impl Agent {
30    #[new]
31    pub fn new(genome_base: EneCode) -> Self {
32        let agent = NeuralNetwork::new(genome_base.clone());
33
34        Agent {
35            nn: Box::new(agent),
36            fitness: 0.
37        }
38    }
39
40    pub fn fwd(&mut self, input: Vec<f32>) {
41        self.nn.fwd(input);
42    }
43
44    pub fn output(&self) -> Vec<f32> {
45        self.nn.fetch_network_output()
46    }
47
48    pub fn mutate(&mut self, mutation_rate: f32, mutation_sd: f32, topology_mutation_rate: f32) {
49        self.nn.mutate(mutation_rate, mutation_sd, topology_mutation_rate);
50    }
51
52    pub fn update_fitness(&mut self, new_fitness: f32) {
53        self.fitness = new_fitness;
54    }
55
56    pub fn write_genome(&self, file_path: PathBuf) -> FileResult<()> {
57        let serialized_genome = self.nn.serialize_genome();
58        let file = File::create(file_path)?;
59        let mut writer = BufWriter::new(file);
60        writer.write_all(serialized_genome.as_bytes())?;
61        writer.flush()?;
62        Ok(())
63    }
64}
65