1use crate::graph::NeuralNetwork;
4use crate::enecode::EneCode;
5use pyo3::prelude::*;
6use std::path::PathBuf;
7use std::fs::File;
8use std::io::prelude::*;
9use std::io::BufWriter;
10use std::io::Result as FileResult;
11
/// A value that can be flattened into a `Vec<f32>` to be used as network input.
///
/// The `Send` supertrait means every implementor can be transferred to another thread.
///
/// NOTE(review): despite the `into_` prefix, the method borrows `self` and returns a
/// freshly allocated vector (closer to a `to_…` conversion); kept as-is for API stability.
pub trait NnInputVector: Send {
    /// Returns an owned copy of this value as a flat vector of `f32`.
    fn into_vec_f32(&self) -> Vec<f32>;
}

/// A `Vec<f32>` already has the required shape, so conversion is a plain element copy.
impl NnInputVector for Vec<f32> {
    fn into_vec_f32(&self) -> Vec<f32> {
        // Copy out of the borrowed slice; the original vector is left untouched.
        self.to_vec()
    }
}
21
22#[pyclass]
23pub struct Agent {
24 pub nn: Box<NeuralNetwork>,
25 pub fitness: f32,
26}
27
28#[pymethods]
29impl Agent {
30 #[new]
31 pub fn new(genome_base: EneCode) -> Self {
32 let agent = NeuralNetwork::new(genome_base.clone());
33
34 Agent {
35 nn: Box::new(agent),
36 fitness: 0.
37 }
38 }
39
40 pub fn fwd(&mut self, input: Vec<f32>) {
41 self.nn.fwd(input);
42 }
43
44 pub fn output(&self) -> Vec<f32> {
45 self.nn.fetch_network_output()
46 }
47
48 pub fn mutate(&mut self, mutation_rate: f32, mutation_sd: f32, topology_mutation_rate: f32) {
49 self.nn.mutate(mutation_rate, mutation_sd, topology_mutation_rate);
50 }
51
52 pub fn update_fitness(&mut self, new_fitness: f32) {
53 self.fitness = new_fitness;
54 }
55
56 pub fn write_genome(&self, file_path: PathBuf) -> FileResult<()> {
57 let serialized_genome = self.nn.serialize_genome();
58 let file = File::create(file_path)?;
59 let mut writer = BufWriter::new(file);
60 writer.write_all(serialized_genome.as_bytes())?;
61 writer.flush()?;
62 Ok(())
63 }
64}
65