1use csv::ReaderBuilder;
2use rand::seq::SliceRandom;
3use serde::{Deserialize, Serialize};
4use strum::EnumCount;
5
6use crate::{
7 core::{
8 engines::{
9 breed_engine::BreedEngine,
10 core_engine::Core,
11 fitness_engine::FitnessEngine,
12 freeze_engine::FreezeEngine,
13 generate_engine::{Generate, GenerateEngine},
14 mutate_engine::MutateEngine,
15 reset_engine::{Reset, ResetEngine},
16 status_engine::StatusEngine,
17 },
18 environment::State,
19 program::{Program, ProgramGeneratorParameters},
20 },
21 utils::random::generator,
22};
23
24const IRIS_CSV: &str = include_str!("iris.csv");
25
26#[derive(
27 Debug,
28 Clone,
29 Copy,
30 Eq,
31 PartialEq,
32 EnumCount,
33 PartialOrd,
34 Ord,
35 strum::Display,
36 Serialize,
37 Deserialize,
38 Hash,
39)]
40pub enum IrisClass {
41 #[serde(rename = "Iris-setosa")]
42 Setosa = 0,
43 #[serde(rename = "Iris-versicolor")]
44 Versicolour = 1,
45 #[serde(rename = "Iris-virginica")]
46 Virginica = 2,
47}
48
49pub struct IrisLgp;
50
51#[derive(Deserialize, Debug, Clone, PartialEq, PartialOrd, Serialize)]
52pub struct IrisInput {
53 sepal_length: f64,
54 sepal_width: f64,
55 petal_length: f64,
56 petal_width: f64,
57 class: IrisClass,
58}
59
60#[derive(Clone)]
61pub struct IrisState {
62 data: Vec<IrisInput>,
63 idx: usize,
64}
65
66impl State for IrisState {
67 fn get_value(&self, idx: usize) -> f64 {
68 let item = &self.data[self.idx];
69
70 match idx {
71 0 => item.sepal_length,
72 1 => item.sepal_width,
73 2 => item.petal_length,
74 3 => item.petal_width,
75 _ => unreachable!(),
76 }
77 }
78
79 fn execute_action(&mut self, action: usize) -> f64 {
80 let item = &self.data[self.idx];
81 self.idx += 1;
82 let correct_class = item.class as usize;
83 let is_correct = correct_class == action;
84 is_correct as usize as f64
85 }
86
87 fn get(&mut self) -> Option<&mut Self> {
88 if self.idx >= self.data.len() {
89 return None;
90 }
91
92 Some(self)
93 }
94}
95
96impl Reset<IrisState> for ResetEngine {
97 fn reset(item: &mut IrisState) {
98 item.idx = 0;
99 }
100}
101
102impl Generate<(), IrisState> for GenerateEngine {
103 fn generate(_using: ()) -> IrisState {
104 let mut csv_reader = ReaderBuilder::new()
105 .has_headers(false)
106 .from_reader(IRIS_CSV.as_bytes());
107
108 let mut data: Vec<IrisInput> = csv_reader
109 .deserialize()
110 .collect::<Result<_, _>>()
111 .expect("Failed to parse iris dataset");
112
113 data.shuffle(&mut generator());
114
115 IrisState { data, idx: 0 }
116 }
117}
118
119#[derive(Clone)]
120pub struct IrisEngine;
121
122impl Core for IrisEngine {
123 type State = IrisState;
124 type Individual = Program;
125 type ProgramParameters = ProgramGeneratorParameters;
126 type FitnessMarker = ();
127 type Generate = GenerateEngine;
128 type Fitness = FitnessEngine;
129 type Reset = ResetEngine;
130 type Breed = BreedEngine;
131 type Mutate = MutateEngine;
132 type Status = StatusEngine;
133 type Freeze = FreezeEngine;
134}