use std::fmt::{self, Debug};
use clap::Args;
use derivative::Derivative;
use derive_builder::Builder;
use rand::Rng;
use serde::{Deserialize, Serialize};
use tracing::{debug, trace};
use crate::{
core::{
engines::{
breed_engine::{Breed, BreedEngine},
fitness_engine::{Fitness, FitnessEngine},
freeze_engine::{Freeze, FreezeEngine},
generate_engine::{Generate, GenerateEngine},
mutate_engine::{Mutate, MutateEngine},
reset_engine::{Reset, ResetEngine},
status_engine::{Status, StatusEngine},
},
environment::{RlState, State},
instruction::InstructionGeneratorParameters,
program::{Program, ProgramGeneratorParameters},
registers::{ActionRegister, ArgmaxInput, Registers},
},
utils::{float_ops, random::generator},
};
/// Table of Q-values with one row per register and one column per action.
///
/// Rows are looked up by register index and columns by action index, i.e.
/// `table[register][action]`. The `Debug` implementation prints only the rows.
#[derive(Clone, Serialize, Deserialize)]
pub struct QTable {
/// Q-values, indexed as `[register][action]`.
table: Vec<Vec<f64>>,
/// Learning constants together with their currently active (decayed) values.
q_consts: QConsts,
/// When `true`, `update` stops decaying the learning constants.
freeze: bool,
}
/// Freezing a [`QTable`] raises its `freeze` flag; nothing else is touched.
impl Freeze<QTable> for FreezeEngine {
    fn freeze(q_table: &mut QTable) {
        q_table.freeze = true;
    }
}
/// Builds a zero-initialised [`QTable`] sized from the instruction parameters.
impl Generate<(InstructionGeneratorParameters, QConsts), QTable> for GenerateEngine {
    /// Creates `n_registers()` rows of `n_actions` zeros, stores the given constants, and
    /// resets the table so the active learning values start from their base values.
    fn generate(using: (InstructionGeneratorParameters, QConsts)) -> QTable {
        let (instruction_parameters, q_consts) = using;
        // One row of zeroed action values; cloned once per register below.
        let zeroed_row = vec![0.; instruction_parameters.n_actions];
        let mut q_table = QTable {
            table: vec![zeroed_row; instruction_parameters.n_registers()],
            q_consts,
            freeze: false,
        };
        ResetEngine::reset(&mut q_table);
        q_table
    }
}
/// Formats a [`QTable`] as a list of its rows; the constants and freeze flag are omitted.
impl Debug for QTable {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut rows = f.debug_list();
        for row in &self.table {
            rows.entry(row);
        }
        rows.finish()
    }
}
/// An action index paired with the register index it was selected for.
///
/// `register` picks a row of the Q-table and `action` a column within that row.
#[derive(Debug, Clone, Copy)]
pub struct ActionRegisterPair {
/// Column index into a Q-table row; also the action handed to the environment.
action: usize,
/// Row index into the Q-table (the winning register).
register: usize,
}
/// Resetting a [`QTable`] resets only its learning constants; the stored Q-values are kept.
impl Reset<QTable> for ResetEngine {
    fn reset(q_table: &mut QTable) {
        let consts = &mut q_table.q_consts;
        ResetEngine::reset(consts);
    }
}
impl QTable {
    /// Returns an action index drawn uniformly from the table's action columns.
    ///
    /// # Panics
    ///
    /// Panics if the table has no rows, or if its first row has no actions.
    pub fn action_random(&self) -> usize {
        let n_actions = self.table[0].len();
        generator().gen_range(0..n_actions)
    }

    /// Returns the action index that `float_ops::argmax` selects from the row belonging to
    /// `register_number`.
    ///
    /// # Panics
    ///
    /// Panics if `register_number` is not a row of the table, or if `float_ops::argmax`
    /// yields no index for that row.
    pub fn action_argmax(&self, register_number: usize) -> usize {
        let available_actions = self
            .table
            .get(register_number)
            .expect("Register number to be less than length of QTable.");
        float_ops::argmax(available_actions.iter().copied())
            .expect("Available action to yield an index.")
    }

    /// Chooses an action for the register that wins the argmax over `registers`.
    ///
    /// With probability `epsilon_active` a random action is taken; otherwise the action with
    /// the best Q-value for the winning register is used. Returns `None` when the register
    /// argmax does not resolve to a single register value.
    pub fn get_action_register(&self, registers: &Registers) -> Option<ActionRegisterPair> {
        let winning_register = match registers.argmax(ArgmaxInput::All).any() {
            ActionRegister::Value(register) => register,
            _ => {
                return None;
            }
        };

        // `prob` is sampled from [0, 1), so a strict comparison explores with probability
        // exactly `epsilon_active` and never explores when it is zero.
        let prob = generator().gen_range(0.0..1.0);
        let winning_action = if prob < self.q_consts.epsilon_active {
            self.action_random()
        } else {
            self.action_argmax(winning_register)
        };

        Some(ActionRegisterPair {
            action: winning_action,
            register: winning_register,
        })
    }

    /// Applies one Q-learning step to the entry addressed by `current_action_state`:
    ///
    /// `Q(s, a) += alpha * (reward + gamma * max_a' Q(s', a') - Q(s, a))`
    ///
    /// where `s'` is the register of `next_action_state`. Afterwards the active learning and
    /// exploration rates are decayed, unless the table is frozen.
    ///
    /// # Panics
    ///
    /// Panics if either pair addresses a register or action outside the table.
    pub fn update(
        &mut self,
        current_action_state: ActionRegisterPair,
        current_reward: f64,
        next_action_state: ActionRegisterPair,
    ) {
        let current_q_value =
            self.table[current_action_state.register][current_action_state.action];

        // The target uses the best Q-*value* of the next register, not the index of the best
        // action.
        let best_next_action = self.action_argmax(next_action_state.register);
        let next_q_value = self.table[next_action_state.register][best_next_action];

        let delta_q = self.q_consts.alpha_active
            * (current_reward + (self.q_consts.gamma * next_q_value) - current_q_value);
        self.table[current_action_state.register][current_action_state.action] += delta_q;

        trace!(
            register = current_action_state.register,
            action = current_action_state.action,
            reward = current_reward,
            old_q = current_q_value,
            delta_q = delta_q,
            alpha = self.q_consts.alpha_active,
            gamma = self.q_consts.gamma,
            "Q-table update"
        );

        // NOTE(review): a frozen table still has its Q-values updated above; only the decay
        // of the learning constants is skipped — confirm this is the intended meaning of
        // `freeze`.
        if !self.freeze {
            self.q_consts.decay();
        }
    }
}
/// A program bundled with the Q-table used to pick its actions.
///
/// Equality and ordering are derived from `program` alone; `q_table` is ignored.
#[derive(Debug, Clone, Serialize, Deserialize, Derivative)]
#[derivative(PartialEq, PartialOrd, Ord, Eq)]
pub struct QProgram {
/// Q-table consulted after each program run; excluded from all comparisons.
#[derivative(PartialEq = "ignore", PartialOrd = "ignore", Ord = "ignore")]
pub q_table: QTable,
/// The underlying program; its registers select the Q-table row.
pub program: Program,
}
impl Freeze<QProgram> for FreezeEngine {
fn freeze(item: &mut QProgram) {
FreezeEngine::freeze(&mut item.q_table);
}
}
impl Reset<QProgram> for ResetEngine {
fn reset(item: &mut QProgram) {
ResetEngine::reset(&mut item.program);
}
}
fn get_action_state<T>(environment: &mut T, q_program: &mut QProgram) -> Option<ActionRegisterPair>
where
T: State,
{
q_program.program.run(environment);
q_program
.q_table
.get_action_register(&q_program.program.registers)
}
/// Fitness of a [`QProgram`]: the sum of rewards collected over one pass through the
/// environment, with the Q-table being updated along the way.
impl<T: RlState> Fitness<QProgram, T, ()> for FitnessEngine {
    /// Plays the environment until it is exhausted or a terminal state is reached.
    ///
    /// Returns `f64::NEG_INFINITY` as soon as no action can be determined. The Q-table is
    /// updated only on steps where the winning register changes.
    fn eval_fitness(program: &mut QProgram, states: &mut T) -> f64 {
        let mut current = match get_action_state(states, program) {
            Some(pair) => pair,
            None => return f64::NEG_INFINITY,
        };
        let mut total_reward = 0.;

        while let Some(state) = states.get() {
            let reward = state.execute_action(current.action);
            total_reward += reward;

            if state.is_terminal() {
                break;
            }

            let next = match get_action_state(state, program) {
                Some(pair) => pair,
                None => return f64::NEG_INFINITY,
            };

            let register_changed = current.register != next.register;
            if register_changed {
                program.q_table.update(current, reward, next);
            }

            current = next;
        }

        debug!(
            program_id = %program.program.id,
            score = total_reward,
            "Q-Learning fitness evaluation complete"
        );
        trace!(
            program_id = %program.program.id,
            q_table = serde_json::to_string(&program.q_table).ok(),
            initial_state = serde_json::to_string(&states.get_initial_state()).ok(),
            "Full Q-Learning evaluation details"
        );

        total_reward
    }
}
impl Breed<QProgram> for BreedEngine {
fn two_point_crossover(mate_1: &QProgram, mate_2: &QProgram) -> (QProgram, QProgram) {
let (child_1_program, child_2_program) =
BreedEngine::two_point_crossover(&mate_1.program, &mate_2.program);
let mut child_1 = mate_1.clone();
let mut child_2 = mate_2.clone();
child_1.program = child_1_program;
child_2.program = child_2_program;
ResetEngine::reset(&mut child_1.program.id);
ResetEngine::reset(&mut child_2.program.id);
ResetEngine::reset(&mut child_1.program);
ResetEngine::reset(&mut child_2.program);
ResetEngine::reset(&mut child_1.q_table);
ResetEngine::reset(&mut child_2.q_table);
(child_1, child_2)
}
}
/// Status queries for a [`QProgram`] are answered by its inner program.
impl Status<QProgram> for StatusEngine {
    /// A Q-program is valid exactly when its inner program is valid.
    fn valid(item: &QProgram) -> bool {
        let inner = &item.program;
        StatusEngine::valid(inner)
    }

    /// Stores `fitness` on the inner program.
    fn set_fitness(program: &mut QProgram, fitness: f64) {
        let inner = &mut program.program;
        inner.fitness = fitness;
    }

    /// Reads the fitness stored on the inner program.
    fn get_fitness(program: &QProgram) -> f64 {
        let inner = &program.program;
        inner.fitness
    }

    /// A Q-program counts as evaluated exactly when its inner program does.
    fn evaluated(item: &QProgram) -> bool {
        let inner = &item.program;
        StatusEngine::evaluated(inner)
    }
}
impl Mutate<QProgramGeneratorParameters, QProgram> for MutateEngine {
fn mutate(item: &mut QProgram, using: QProgramGeneratorParameters) {
MutateEngine::mutate(&mut item.program, using.program_parameters);
ResetEngine::reset(&mut item.program);
ResetEngine::reset(&mut item.program.id);
ResetEngine::reset(&mut item.q_table);
}
}
impl Generate<QProgramGeneratorParameters, QProgram> for GenerateEngine {
fn generate(using: QProgramGeneratorParameters) -> QProgram {
let program = GenerateEngine::generate(using.program_parameters);
let q_table = GenerateEngine::generate((
using.program_parameters.instruction_generator_parameters,
using.consts,
));
QProgram { q_table, program }
}
}
// Parameters used to generate and mutate a `QProgram`.
//
// NOTE: plain `//` comments are used here on purpose. This struct derives `clap::Args`,
// and clap turns `///` doc comments into CLI help text.
#[derive(Debug, Clone, Args, Deserialize, Serialize, Copy, Builder)]
pub struct QProgramGeneratorParameters {
// Parameters forwarded to program generation and program mutation.
#[command(flatten)]
pub program_parameters: ProgramGeneratorParameters,
// Q-learning constants for the generated Q-table. When not set on the builder,
// `QConsts::default()` is used.
#[builder(default)]
#[command(flatten)]
pub consts: QConsts,
}
// Q-learning constants plus the currently active (decayed) learning and exploration rates.
//
// NOTE: plain `//` comments are used here on purpose. This struct derives `clap::Args`,
// and clap turns `///` doc comments into CLI help text.
#[derive(Debug, Clone, Copy, Args, Serialize, Deserialize, Builder)]
pub struct QConsts {
// Base learning rate; `alpha_active` is restored to this value on reset.
#[arg(long, default_value = "0.1")]
#[builder(default = "0.1")]
alpha: f64,
// Discount factor applied to the next-state Q-value in `QTable::update`.
#[arg(long, default_value = "0.9")]
#[builder(default = "0.9")]
gamma: f64,
// Base exploration rate; `epsilon_active` is restored to this value on reset.
#[arg(long, default_value = "0.05")]
#[builder(default = "0.05")]
epsilon: f64,
// Fraction of `alpha_active` removed by each call to `decay`.
#[arg(long, default_value = "0.01")]
#[builder(default = "0.01")]
alpha_decay: f64,
// Fraction of `epsilon_active` removed by each call to `decay`.
#[arg(long, default_value = "0.001")]
#[builder(default = "0.001")]
epsilon_decay: f64,
// Learning rate actually used by `QTable::update`. Not a CLI argument, not serialized,
// and not settable through the builder.
#[arg(skip)]
#[serde(skip)]
#[builder(setter(skip), default)]
alpha_active: f64,
// Exploration rate actually used when choosing actions. Not a CLI argument, not
// serialized, and not settable through the builder.
#[serde(skip)]
#[arg(skip)]
#[builder(setter(skip), default)]
epsilon_active: f64,
}
impl Reset<QConsts> for ResetEngine {
fn reset(item: &mut QConsts) {
item.alpha_active = item.alpha;
item.epsilon_active = item.epsilon;
}
}
impl QConsts {
    /// Creates constants whose active learning and exploration rates start at `alpha` and
    /// `epsilon` respectively.
    pub fn new(alpha: f64, gamma: f64, epsilon: f64, alpha_decay: f64, epsilon_decay: f64) -> Self {
        Self {
            alpha,
            gamma,
            epsilon,
            alpha_decay,
            epsilon_decay,
            alpha_active: alpha,
            epsilon_active: epsilon,
        }
    }

    /// Shrinks the active learning rate by the factor `1 - alpha_decay` and the active
    /// exploration rate by the factor `1 - epsilon_decay`.
    pub fn decay(&mut self) {
        let alpha_retained = 1. - self.alpha_decay;
        let epsilon_retained = 1. - self.epsilon_decay;
        self.alpha_active = self.alpha_active * alpha_retained;
        self.epsilon_active = self.epsilon_active * epsilon_retained;
    }
}
/// Randomised defaults for [`QConsts`].
impl Default for QConsts {
    /// Draws every constant independently and uniformly from `[0, 1)`, and starts the active
    /// learning and exploration rates at the freshly drawn `alpha` and `epsilon`.
    fn default() -> Self {
        // The draw order is kept stable so seeded runs consume the generator identically.
        let alpha = generator().gen_range(0.0..1.);
        let gamma = generator().gen_range(0.0..1.);
        let epsilon = generator().gen_range(0.0..1.);
        let alpha_decay = generator().gen_range(0.0..1.);
        let epsilon_decay = generator().gen_range(0.0..1.);

        Self {
            alpha,
            gamma,
            epsilon,
            alpha_decay,
            epsilon_decay,
            alpha_active: alpha,
            // Exploration starts at `epsilon` itself (not at its decay rate), in line with
            // `QConsts::new` and the `Reset` implementation.
            epsilon_active: epsilon,
        }
    }
}