1use std::fmt::{self, Debug};
2
3use clap::Args;
4use derivative::Derivative;
5use derive_builder::Builder;
6use rand::Rng;
7use serde::{Deserialize, Serialize};
8use tracing::{debug, trace};
9
10use crate::{
11 core::{
12 engines::{
13 breed_engine::{Breed, BreedEngine},
14 fitness_engine::{Fitness, FitnessEngine},
15 freeze_engine::{Freeze, FreezeEngine},
16 generate_engine::{Generate, GenerateEngine},
17 mutate_engine::{Mutate, MutateEngine},
18 reset_engine::{Reset, ResetEngine},
19 status_engine::{Status, StatusEngine},
20 },
21 environment::{RlState, State},
22 instruction::InstructionGeneratorParameters,
23 program::{Program, ProgramGeneratorParameters},
24 registers::{ActionRegister, ArgmaxInput, Registers},
25 },
26 utils::{float_ops, random::generator},
27};
28
29#[derive(Clone, Serialize, Deserialize)]
30pub struct QTable {
31 table: Vec<Vec<f64>>,
32 q_consts: QConsts,
33 freeze: bool,
34}
35
36impl Freeze<QTable> for FreezeEngine {
37 fn freeze(item: &mut QTable) {
38 item.freeze = true;
39 }
40}
41
42impl Generate<(InstructionGeneratorParameters, QConsts), QTable> for GenerateEngine {
43 fn generate(using: (InstructionGeneratorParameters, QConsts)) -> QTable {
44 let mut table = QTable {
45 table: vec![vec![0.; using.0.n_actions]; using.0.n_registers()],
46 q_consts: using.1,
47 freeze: false,
48 };
49
50 ResetEngine::reset(&mut table);
51 table
52 }
53}
54
55impl Debug for QTable {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.debug_list().entries(self.table.iter()).finish()
58 }
59}
60
/// An action paired with the register that selected it; together they address
/// one cell of a [`QTable`].
#[derive(Clone, Copy, Debug)]
pub struct ActionRegisterPair {
    /// Column of the Q-table (the chosen action).
    action: usize,
    /// Row of the Q-table (the winning register).
    register: usize,
}
66
67impl Reset<QTable> for ResetEngine {
68 fn reset(item: &mut QTable) {
69 ResetEngine::reset(&mut item.q_consts);
70 }
71}
72
73impl QTable {
74 pub fn action_random(&self) -> usize {
75 let n_actions = self.table[0].len();
76 generator().gen_range(0..n_actions)
77 }
78
79 pub fn action_argmax(&self, register_number: usize) -> usize {
80 let available_actions = self
81 .table
82 .get(register_number)
83 .expect("Register number to be less than length of QTable.");
84
85 let iter = available_actions.iter().copied();
86 let max = float_ops::argmax(iter);
87
88 max.expect("Available action to yield an index.")
89 }
90
91 pub fn get_action_register(&self, registers: &Registers) -> Option<ActionRegisterPair> {
92 let winning_register = match registers.argmax(ArgmaxInput::All).any() {
93 ActionRegister::Value(register) => register,
94 _ => {
95 return None;
96 }
97 };
98
99 let prob = generator().gen_range(0.0..1.0);
100
101 let winning_action = if prob <= self.q_consts.epsilon_active {
102 self.action_random()
103 } else {
104 self.action_argmax(winning_register)
105 };
106
107 Some(ActionRegisterPair {
108 action: winning_action,
109 register: winning_register,
110 })
111 }
112
113 pub fn update(
114 &mut self,
115 current_action_state: ActionRegisterPair,
116 current_reward: f64,
117 next_action_state: ActionRegisterPair,
118 ) {
119 let current_q_value =
120 self.table[current_action_state.register][current_action_state.action];
121 let next_q_value = self.action_argmax(next_action_state.register) as f64;
122
123 let new_q_value = self.q_consts.alpha_active
124 * (current_reward + (self.q_consts.gamma * next_q_value) - current_q_value);
125
126 self.table[current_action_state.register][current_action_state.action] += new_q_value;
127
128 trace!(
129 register = current_action_state.register,
130 action = current_action_state.action,
131 reward = current_reward,
132 old_q = current_q_value,
133 delta_q = new_q_value,
134 alpha = self.q_consts.alpha_active,
135 gamma = self.q_consts.gamma,
136 "Q-table update"
137 );
138
139 if !self.freeze {
140 self.q_consts.decay();
141 }
142 }
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize, Derivative)]
146#[derivative(PartialEq, PartialOrd, Ord, Eq)]
147pub struct QProgram {
148 #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Ord = "ignore")]
149 pub q_table: QTable,
150 pub program: Program,
151}
152
153impl Freeze<QProgram> for FreezeEngine {
154 fn freeze(item: &mut QProgram) {
155 FreezeEngine::freeze(&mut item.q_table);
156 }
157}
158
159impl Reset<QProgram> for ResetEngine {
160 fn reset(item: &mut QProgram) {
161 ResetEngine::reset(&mut item.program);
162 }
163}
164
165fn get_action_state<T>(environment: &mut T, q_program: &mut QProgram) -> Option<ActionRegisterPair>
166where
167 T: State,
168{
169 q_program.program.run(environment);
171
172 q_program
174 .q_table
175 .get_action_register(&q_program.program.registers)
176}
177
178impl<T: RlState> Fitness<QProgram, T, ()> for FitnessEngine {
179 fn eval_fitness(program: &mut QProgram, states: &mut T) -> f64 {
180 let mut score = 0.;
181
182 let mut current_action_state = match get_action_state(states, program) {
184 Some(action_state) => action_state,
185 None => {
186 return f64::NEG_INFINITY;
187 }
188 };
189
190 while let Some(state) = states.get() {
192 let reward = state.execute_action(current_action_state.action);
194 score += reward;
195
196 if state.is_terminal() {
197 break;
198 }
199
200 let next_action_state = match get_action_state(state, program) {
201 Some(action_state) => action_state,
202 None => {
203 return f64::NEG_INFINITY;
204 }
205 };
206
207 if current_action_state.register != next_action_state.register {
210 program
211 .q_table
212 .update(current_action_state, reward, next_action_state)
213 }
214
215 current_action_state = next_action_state;
216 }
217
218 debug!(
219 program_id = %program.program.id,
220 score = score,
221 "Q-Learning fitness evaluation complete"
222 );
223
224 trace!(
225 program_id = %program.program.id,
226 q_table = serde_json::to_string(&program.q_table).ok(),
227 initial_state = serde_json::to_string(&states.get_initial_state()).ok(),
228 "Full Q-Learning evaluation details"
229 );
230
231 score
232 }
233}
234
235impl Breed<QProgram> for BreedEngine {
236 fn two_point_crossover(mate_1: &QProgram, mate_2: &QProgram) -> (QProgram, QProgram) {
237 let (child_1_program, child_2_program) =
238 BreedEngine::two_point_crossover(&mate_1.program, &mate_2.program);
239
240 let mut child_1 = mate_1.clone();
241 let mut child_2 = mate_2.clone();
242
243 child_1.program = child_1_program;
244 child_2.program = child_2_program;
245
246 ResetEngine::reset(&mut child_1.program.id);
247 ResetEngine::reset(&mut child_2.program.id);
248
249 ResetEngine::reset(&mut child_1.program);
250 ResetEngine::reset(&mut child_2.program);
251
252 ResetEngine::reset(&mut child_1.q_table);
253 ResetEngine::reset(&mut child_2.q_table);
254
255 (child_1, child_2)
256 }
257}
258
259impl Status<QProgram> for StatusEngine {
260 fn valid(item: &QProgram) -> bool {
261 StatusEngine::valid(&item.program)
262 }
263
264 fn set_fitness(program: &mut QProgram, fitness: f64) {
265 program.program.fitness = fitness;
266 }
267
268 fn get_fitness(program: &QProgram) -> f64 {
269 program.program.fitness
270 }
271
272 fn evaluated(item: &QProgram) -> bool {
273 StatusEngine::evaluated(&item.program)
274 }
275}
276
277impl Mutate<QProgramGeneratorParameters, QProgram> for MutateEngine {
278 fn mutate(item: &mut QProgram, using: QProgramGeneratorParameters) {
279 MutateEngine::mutate(&mut item.program, using.program_parameters);
280 ResetEngine::reset(&mut item.program);
281 ResetEngine::reset(&mut item.program.id);
282 ResetEngine::reset(&mut item.q_table);
283 }
284}
285
286impl Generate<QProgramGeneratorParameters, QProgram> for GenerateEngine {
287 fn generate(using: QProgramGeneratorParameters) -> QProgram {
288 let program = GenerateEngine::generate(using.program_parameters);
289 let q_table = GenerateEngine::generate((
290 using.program_parameters.instruction_generator_parameters,
291 using.consts,
292 ));
293
294 QProgram { q_table, program }
295 }
296}
297
298#[derive(Debug, Clone, Args, Deserialize, Serialize, Copy, Builder)]
299pub struct QProgramGeneratorParameters {
300 #[command(flatten)]
301 pub program_parameters: ProgramGeneratorParameters,
302 #[builder(default)]
303 #[command(flatten)]
304 pub consts: QConsts,
305}
306
307#[derive(Debug, Clone, Copy, Args, Serialize, Deserialize, Builder)]
308pub struct QConsts {
309 #[arg(long, default_value = "0.1")]
311 #[builder(default = "0.1")]
312 alpha: f64,
313 #[arg(long, default_value = "0.9")]
315 #[builder(default = "0.9")]
316 gamma: f64,
317 #[arg(long, default_value = "0.05")]
319 #[builder(default = "0.05")]
320 epsilon: f64,
321 #[arg(long, default_value = "0.01")]
323 #[builder(default = "0.01")]
324 alpha_decay: f64,
325 #[arg(long, default_value = "0.001")]
327 #[builder(default = "0.001")]
328 epsilon_decay: f64,
329
330 #[arg(skip)]
333 #[serde(skip)]
334 #[builder(setter(skip), default)]
335 alpha_active: f64,
336
337 #[serde(skip)]
338 #[arg(skip)]
339 #[builder(setter(skip), default)]
340 epsilon_active: f64,
341}
342
343impl Reset<QConsts> for ResetEngine {
344 fn reset(item: &mut QConsts) {
345 item.alpha_active = item.alpha;
346 item.epsilon_active = item.epsilon;
347 }
348}
349
350impl QConsts {
351 pub fn new(alpha: f64, gamma: f64, epsilon: f64, alpha_decay: f64, epsilon_decay: f64) -> Self {
352 Self {
353 alpha_active: alpha,
354 epsilon_active: epsilon,
355 alpha,
356 gamma,
357 epsilon,
358 alpha_decay,
359 epsilon_decay,
360 }
361 }
362
363 pub fn decay(&mut self) {
364 self.alpha_active *= 1. - self.alpha_decay;
365 self.epsilon_active *= 1. - self.epsilon_decay
366 }
367}
368
369impl Default for QConsts {
370 fn default() -> Self {
371 let alpha = generator().gen_range(0.0..1.);
372 let gamma = generator().gen_range(0.0..1.);
373 let epsilon = generator().gen_range(0.0..1.);
374 let alpha_decay = generator().gen_range(0.0..1.);
375 let epsilon_decay = generator().gen_range(0.0..1.);
376 Self {
377 alpha,
378 gamma,
379 epsilon,
380 alpha_decay,
381 epsilon_decay,
382 alpha_active: alpha,
383 epsilon_active: epsilon_decay,
384 }
385 }
386}