Skip to main content

lgp/extensions/
q_learning.rs

1use std::fmt::{self, Debug};
2
3use clap::Args;
4use derivative::Derivative;
5use derive_builder::Builder;
6use rand::Rng;
7use serde::{Deserialize, Serialize};
8use tracing::{debug, trace};
9
10use crate::{
11    core::{
12        engines::{
13            breed_engine::{Breed, BreedEngine},
14            fitness_engine::{Fitness, FitnessEngine},
15            freeze_engine::{Freeze, FreezeEngine},
16            generate_engine::{Generate, GenerateEngine},
17            mutate_engine::{Mutate, MutateEngine},
18            reset_engine::{Reset, ResetEngine},
19            status_engine::{Status, StatusEngine},
20        },
21        environment::{RlState, State},
22        instruction::InstructionGeneratorParameters,
23        program::{Program, ProgramGeneratorParameters},
24        registers::{ActionRegister, ArgmaxInput, Registers},
25    },
26    utils::{float_ops, random::generator},
27};
28
29#[derive(Clone, Serialize, Deserialize)]
30pub struct QTable {
31    table: Vec<Vec<f64>>,
32    q_consts: QConsts,
33    freeze: bool,
34}
35
36impl Freeze<QTable> for FreezeEngine {
37    fn freeze(item: &mut QTable) {
38        item.freeze = true;
39    }
40}
41
42impl Generate<(InstructionGeneratorParameters, QConsts), QTable> for GenerateEngine {
43    fn generate(using: (InstructionGeneratorParameters, QConsts)) -> QTable {
44        let mut table = QTable {
45            table: vec![vec![0.; using.0.n_actions]; using.0.n_registers()],
46            q_consts: using.1,
47            freeze: false,
48        };
49
50        ResetEngine::reset(&mut table);
51        table
52    }
53}
54
55impl Debug for QTable {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_list().entries(self.table.iter()).finish()
58    }
59}
60
/// The action chosen for one step together with the register (Q-table row)
/// it was chosen from.
#[derive(Clone, Copy, Debug)]
pub struct ActionRegisterPair {
    /// Column index into the Q-table: the action to execute.
    action: usize,
    /// Row index into the Q-table: the winning register.
    register: usize,
}
66
67impl Reset<QTable> for ResetEngine {
68    fn reset(item: &mut QTable) {
69        ResetEngine::reset(&mut item.q_consts);
70    }
71}
72
73impl QTable {
74    pub fn action_random(&self) -> usize {
75        let n_actions = self.table[0].len();
76        generator().gen_range(0..n_actions)
77    }
78
79    pub fn action_argmax(&self, register_number: usize) -> usize {
80        let available_actions = self
81            .table
82            .get(register_number)
83            .expect("Register number to be less than length of QTable.");
84
85        let iter = available_actions.iter().copied();
86        let max = float_ops::argmax(iter);
87
88        max.expect("Available action to yield an index.")
89    }
90
91    pub fn get_action_register(&self, registers: &Registers) -> Option<ActionRegisterPair> {
92        let winning_register = match registers.argmax(ArgmaxInput::All).any() {
93            ActionRegister::Value(register) => register,
94            _ => {
95                return None;
96            }
97        };
98
99        let prob = generator().gen_range(0.0..1.0);
100
101        let winning_action = if prob <= self.q_consts.epsilon_active {
102            self.action_random()
103        } else {
104            self.action_argmax(winning_register)
105        };
106
107        Some(ActionRegisterPair {
108            action: winning_action,
109            register: winning_register,
110        })
111    }
112
113    pub fn update(
114        &mut self,
115        current_action_state: ActionRegisterPair,
116        current_reward: f64,
117        next_action_state: ActionRegisterPair,
118    ) {
119        let current_q_value =
120            self.table[current_action_state.register][current_action_state.action];
121        let next_q_value = self.action_argmax(next_action_state.register) as f64;
122
123        let new_q_value = self.q_consts.alpha_active
124            * (current_reward + (self.q_consts.gamma * next_q_value) - current_q_value);
125
126        self.table[current_action_state.register][current_action_state.action] += new_q_value;
127
128        trace!(
129            register = current_action_state.register,
130            action = current_action_state.action,
131            reward = current_reward,
132            old_q = current_q_value,
133            delta_q = new_q_value,
134            alpha = self.q_consts.alpha_active,
135            gamma = self.q_consts.gamma,
136            "Q-table update"
137        );
138
139        if !self.freeze {
140            self.q_consts.decay();
141        }
142    }
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize, Derivative)]
146#[derivative(PartialEq, PartialOrd, Ord, Eq)]
147pub struct QProgram {
148    #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Ord = "ignore")]
149    pub q_table: QTable,
150    pub program: Program,
151}
152
153impl Freeze<QProgram> for FreezeEngine {
154    fn freeze(item: &mut QProgram) {
155        FreezeEngine::freeze(&mut item.q_table);
156    }
157}
158
159impl Reset<QProgram> for ResetEngine {
160    fn reset(item: &mut QProgram) {
161        ResetEngine::reset(&mut item.program);
162    }
163}
164
165fn get_action_state<T>(environment: &mut T, q_program: &mut QProgram) -> Option<ActionRegisterPair>
166where
167    T: State,
168{
169    // Run the program on the current state.
170    q_program.program.run(environment);
171
172    // Get the winning action-register pair.
173    q_program
174        .q_table
175        .get_action_register(&q_program.program.registers)
176}
177
178impl<T: RlState> Fitness<QProgram, T, ()> for FitnessEngine {
179    fn eval_fitness(program: &mut QProgram, states: &mut T) -> f64 {
180        let mut score = 0.;
181
182        // We run the program and determine what action to take at the step = 0.
183        let mut current_action_state = match get_action_state(states, program) {
184            Some(action_state) => action_state,
185            None => {
186                return f64::NEG_INFINITY;
187            }
188        };
189
190        // We execute the selected action and continue to repeat the cycle until termination.
191        while let Some(state) = states.get() {
192            // Act.
193            let reward = state.execute_action(current_action_state.action);
194            score += reward;
195
196            if state.is_terminal() {
197                break;
198            }
199
200            let next_action_state = match get_action_state(state, program) {
201                Some(action_state) => action_state,
202                None => {
203                    return f64::NEG_INFINITY;
204                }
205            };
206
207            // We only update when there is a transition.
208            // NOTE: Why?
209            if current_action_state.register != next_action_state.register {
210                program
211                    .q_table
212                    .update(current_action_state, reward, next_action_state)
213            }
214
215            current_action_state = next_action_state;
216        }
217
218        debug!(
219            program_id = %program.program.id,
220            score = score,
221            "Q-Learning fitness evaluation complete"
222        );
223
224        trace!(
225            program_id = %program.program.id,
226            q_table = serde_json::to_string(&program.q_table).ok(),
227            initial_state = serde_json::to_string(&states.get_initial_state()).ok(),
228            "Full Q-Learning evaluation details"
229        );
230
231        score
232    }
233}
234
235impl Breed<QProgram> for BreedEngine {
236    fn two_point_crossover(mate_1: &QProgram, mate_2: &QProgram) -> (QProgram, QProgram) {
237        let (child_1_program, child_2_program) =
238            BreedEngine::two_point_crossover(&mate_1.program, &mate_2.program);
239
240        let mut child_1 = mate_1.clone();
241        let mut child_2 = mate_2.clone();
242
243        child_1.program = child_1_program;
244        child_2.program = child_2_program;
245
246        ResetEngine::reset(&mut child_1.program.id);
247        ResetEngine::reset(&mut child_2.program.id);
248
249        ResetEngine::reset(&mut child_1.program);
250        ResetEngine::reset(&mut child_2.program);
251
252        ResetEngine::reset(&mut child_1.q_table);
253        ResetEngine::reset(&mut child_2.q_table);
254
255        (child_1, child_2)
256    }
257}
258
259impl Status<QProgram> for StatusEngine {
260    fn valid(item: &QProgram) -> bool {
261        StatusEngine::valid(&item.program)
262    }
263
264    fn set_fitness(program: &mut QProgram, fitness: f64) {
265        program.program.fitness = fitness;
266    }
267
268    fn get_fitness(program: &QProgram) -> f64 {
269        program.program.fitness
270    }
271
272    fn evaluated(item: &QProgram) -> bool {
273        StatusEngine::evaluated(&item.program)
274    }
275}
276
277impl Mutate<QProgramGeneratorParameters, QProgram> for MutateEngine {
278    fn mutate(item: &mut QProgram, using: QProgramGeneratorParameters) {
279        MutateEngine::mutate(&mut item.program, using.program_parameters);
280        ResetEngine::reset(&mut item.program);
281        ResetEngine::reset(&mut item.program.id);
282        ResetEngine::reset(&mut item.q_table);
283    }
284}
285
286impl Generate<QProgramGeneratorParameters, QProgram> for GenerateEngine {
287    fn generate(using: QProgramGeneratorParameters) -> QProgram {
288        let program = GenerateEngine::generate(using.program_parameters);
289        let q_table = GenerateEngine::generate((
290            using.program_parameters.instruction_generator_parameters,
291            using.consts,
292        ));
293
294        QProgram { q_table, program }
295    }
296}
297
298#[derive(Debug, Clone, Args, Deserialize, Serialize, Copy, Builder)]
299pub struct QProgramGeneratorParameters {
300    #[command(flatten)]
301    pub program_parameters: ProgramGeneratorParameters,
302    #[builder(default)]
303    #[command(flatten)]
304    pub consts: QConsts,
305}
306
307#[derive(Debug, Clone, Copy, Args, Serialize, Deserialize, Builder)]
308pub struct QConsts {
309    /// Learning Factor
310    #[arg(long, default_value = "0.1")]
311    #[builder(default = "0.1")]
312    alpha: f64,
313    /// Discount Factor
314    #[arg(long, default_value = "0.9")]
315    #[builder(default = "0.9")]
316    gamma: f64,
317    /// Exploration Factor
318    #[arg(long, default_value = "0.05")]
319    #[builder(default = "0.05")]
320    epsilon: f64,
321    /// Learning Rate Decay
322    #[arg(long, default_value = "0.01")]
323    #[builder(default = "0.01")]
324    alpha_decay: f64,
325    /// Exploration Decay
326    #[arg(long, default_value = "0.001")]
327    #[builder(default = "0.001")]
328    epsilon_decay: f64,
329
330    /// To allow new programs to start from the new state, we have active
331    /// properties to mutuate.
332    #[arg(skip)]
333    #[serde(skip)]
334    #[builder(setter(skip), default)]
335    alpha_active: f64,
336
337    #[serde(skip)]
338    #[arg(skip)]
339    #[builder(setter(skip), default)]
340    epsilon_active: f64,
341}
342
343impl Reset<QConsts> for ResetEngine {
344    fn reset(item: &mut QConsts) {
345        item.alpha_active = item.alpha;
346        item.epsilon_active = item.epsilon;
347    }
348}
349
350impl QConsts {
351    pub fn new(alpha: f64, gamma: f64, epsilon: f64, alpha_decay: f64, epsilon_decay: f64) -> Self {
352        Self {
353            alpha_active: alpha,
354            epsilon_active: epsilon,
355            alpha,
356            gamma,
357            epsilon,
358            alpha_decay,
359            epsilon_decay,
360        }
361    }
362
363    pub fn decay(&mut self) {
364        self.alpha_active *= 1. - self.alpha_decay;
365        self.epsilon_active *= 1. - self.epsilon_decay
366    }
367}
368
369impl Default for QConsts {
370    fn default() -> Self {
371        let alpha = generator().gen_range(0.0..1.);
372        let gamma = generator().gen_range(0.0..1.);
373        let epsilon = generator().gen_range(0.0..1.);
374        let alpha_decay = generator().gen_range(0.0..1.);
375        let epsilon_decay = generator().gen_range(0.0..1.);
376        Self {
377            alpha,
378            gamma,
379            epsilon,
380            alpha_decay,
381            epsilon_decay,
382            alpha_active: alpha,
383            epsilon_active: epsilon_decay,
384        }
385    }
386}