Skip to main content

lgp/extensions/
interactive.rs

1use core::fmt::Debug;
2
3use serde::Serialize;
4use tracing::{instrument, trace};
5
6use crate::core::engines::fitness_engine::Fitness;
7use crate::core::engines::fitness_engine::FitnessEngine;
8
9use crate::core::environment::RlState;
10use crate::core::program::Program;
11use crate::core::registers::ActionRegister;
12use crate::core::registers::ArgmaxInput;
13
14#[derive(Debug, Serialize, Clone, Copy)]
15pub enum Reward {
16    Continue(f64),
17    Terminal(f64),
18}
19
20impl Reward {
21    pub fn get_value(&self) -> f64 {
22        match *self {
23            Reward::Continue(reward) => reward,
24            Reward::Terminal(reward) => reward,
25        }
26    }
27
28    pub fn is_terminal(&self) -> bool {
29        match self {
30            Reward::Continue(_) => false,
31            Reward::Terminal(_) => true,
32        }
33    }
34}
35
/// Zero-sized marker that selects the reinforcement-learning fitness strategy.
///
/// It carries no data. It is used only as the strategy type parameter of
/// [`Fitness`], choosing the episode-based `eval_fitness` implementation on
/// [`FitnessEngine`] for [`RlState`] environments.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct UseRlFitness;
37
38impl<T> Fitness<Program, T, UseRlFitness> for FitnessEngine
39where
40    T: RlState,
41{
42    #[instrument(skip_all, fields(program_id = %program.id), level = "trace")]
43    fn eval_fitness(program: &mut crate::core::program::Program, states: &mut T) -> f64 {
44        let mut score = 0.;
45        let mut step = 0;
46
47        while let Some(state) = states.get() {
48            // Run program.
49            program.run(state);
50
51            // Eval
52            let reward = match program.registers.argmax(ArgmaxInput::ActionRegisters).any() {
53                ActionRegister::Value(action) => {
54                    let r = state.execute_action(action);
55                    trace!(step = step, action = action, reward = r, "Step executed");
56                    r
57                }
58                ActionRegister::Overflow => {
59                    trace!(step = step, "Register overflow - returning NEG_INFINITY");
60                    return f64::NEG_INFINITY;
61                }
62            };
63
64            score += reward;
65            step += 1;
66        }
67
68        trace!(total_steps = step, final_score = score, "Episode complete");
69        score
70    }
71}