reinforcex 0.0.1

Deep Reinforcement Learning Framework
use super::base_explorer::BaseExplorer;
use rand::Rng;

/// Epsilon-greedy explorer with a linearly interpolated epsilon.
///
/// Epsilon is the probability of taking the random action instead of the
/// greedy one. It starts at `start_epsilon`, is interpolated linearly towards
/// `end_epsilon`, and is held at `end_epsilon` once the step counter has
/// passed `decay_steps`.
#[derive(Debug, Clone, PartialEq)]
pub struct EpsilonGreedy {
    /// Epsilon used at step 0; validated by `new` to lie in `[0.0, 1.0]`.
    start_epsilon: f64,
    /// Epsilon used after the decay has finished; validated by `new` to lie
    /// in `[0.0, 1.0]`.
    end_epsilon: f64,
    /// Number of steps over which epsilon is interpolated.
    decay_steps: usize,
}

impl EpsilonGreedy {
    pub fn new(start_epsilon: f64, end_epsilon: f64, decay_steps: usize) -> Self {
        assert!((0.0..=1.0).contains(&start_epsilon));
        assert!((0.0..=1.0).contains(&end_epsilon));
        assert!(decay_steps >= 0);
        EpsilonGreedy {
            start_epsilon,
            end_epsilon,
            decay_steps,
        }
    }
}

impl BaseExplorer for EpsilonGreedy {
    /// Picks an action for step `t`, exploring with probability epsilon.
    ///
    /// Epsilon is `end_epsilon` once `t` has reached `decay_steps`; before
    /// that it is interpolated linearly between `start_epsilon` and
    /// `end_epsilon` according to `t / decay_steps`.
    ///
    /// A uniform random `f64` is drawn from the thread-local RNG; if it is
    /// below epsilon, `random_action_func` is called, otherwise
    /// `greedy_action_func` is called. Only the chosen closure is invoked,
    /// and its result is returned.
    fn select_action(
        &self,
        t: usize,
        random_action_func: &dyn Fn() -> usize,
        greedy_action_func: &dyn Fn() -> usize,
    ) -> usize {
        // `>=` rather than `>`: at `t == decay_steps` the interpolation equals
        // `end_epsilon` anyway, and for `decay_steps == 0` it keeps `t == 0`
        // away from the division below, which would otherwise evaluate
        // `0.0 / 0.0 = NaN` and silently disable exploration.
        let epsilon = if t >= self.decay_steps {
            self.end_epsilon
        } else {
            let progress = t as f64 / self.decay_steps as f64;
            self.start_epsilon + (self.end_epsilon - self.start_epsilon) * progress
        };

        if rand::thread_rng().gen::<f64>() < epsilon {
            random_action_func()
        } else {
            greedy_action_func()
        }
    }
}

#[cfg(test)]
mod tests {
    //! Unit tests for `EpsilonGreedy`: constructor validation and action
    //! selection at fixed and decaying epsilon values.
    use super::*;

    /// Sentinel returned by the "random" closure in these tests.
    const RANDOM_ACTION: usize = 456;
    /// Sentinel returned by the "greedy" closure in these tests.
    const GREEDY_ACTION: usize = 123;

    #[test]
    fn test_new() {
        let explorer = EpsilonGreedy::new(0.9, 0.1, 100);
        assert_eq!(explorer.start_epsilon, 0.9);
        assert_eq!(explorer.end_epsilon, 0.1);
        assert_eq!(explorer.decay_steps, 100);
    }

    #[test]
    #[should_panic]
    fn test_new_invalid_epsilon() {
        EpsilonGreedy::new(1.2, 0.1, 100);
    }

    #[test]
    #[should_panic]
    fn test_new_invalid_end_epsilon() {
        EpsilonGreedy::new(0.9, -0.1, 100);
    }

    #[test]
    fn test_select_action_exploration() {
        // Epsilon is pinned at 1.0, so the random branch is always taken.
        let explorer = EpsilonGreedy::new(1.0, 1.0, 100);
        let random_action = || RANDOM_ACTION;
        let greedy_action = || GREEDY_ACTION;

        let action = explorer.select_action(0, &random_action, &greedy_action);
        assert_eq!(action, RANDOM_ACTION);
    }

    #[test]
    fn test_select_action_exploitation() {
        // Epsilon is pinned at 0.0, so the greedy branch is always taken.
        let explorer = EpsilonGreedy::new(0.0, 0.0, 100);
        let random_action = || RANDOM_ACTION;
        let greedy_action = || GREEDY_ACTION;

        let action = explorer.select_action(50, &random_action, &greedy_action);
        assert_eq!(action, GREEDY_ACTION);
    }

    #[test]
    fn test_select_action_schedule_endpoints() {
        // With epsilon going 1.0 -> 0.0 both ends of the schedule are
        // deterministic: always random at t = 0, always greedy past the decay.
        let explorer = EpsilonGreedy::new(1.0, 0.0, 10);
        let random_action = || RANDOM_ACTION;
        let greedy_action = || GREEDY_ACTION;

        for _ in 0..20 {
            assert_eq!(
                explorer.select_action(0, &random_action, &greedy_action),
                RANDOM_ACTION
            );
            assert_eq!(
                explorer.select_action(11, &random_action, &greedy_action),
                GREEDY_ACTION
            );
        }
    }

    #[test]
    fn test_select_action_reversed_schedule_endpoints() {
        // Epsilon may also grow: 0.0 -> 1.0 is greedy at t = 0 and random
        // once the schedule has finished.
        let explorer = EpsilonGreedy::new(0.0, 1.0, 10);
        let random_action = || RANDOM_ACTION;
        let greedy_action = || GREEDY_ACTION;

        for _ in 0..20 {
            assert_eq!(
                explorer.select_action(0, &random_action, &greedy_action),
                GREEDY_ACTION
            );
            assert_eq!(
                explorer.select_action(11, &random_action, &greedy_action),
                RANDOM_ACTION
            );
        }
    }

    #[test]
    fn test_select_action_decay() {
        let explorer = EpsilonGreedy::new(1.0, 0.3, 100);
        let random_action = || RANDOM_ACTION;
        let greedy_action = || GREEDY_ACTION;

        let mut random_count = 0;
        let mut greedy_count = 0;

        for t in 0..100 {
            match explorer.select_action(t, &random_action, &greedy_action) {
                RANDOM_ACTION => random_count += 1,
                _ => greedy_count += 1,
            }
        }

        // t = 0 has epsilon 1.0, so at least one random action is certain.
        assert!(random_count > 0);
        // Probabilistic: seeing no greedy action in 100 draws while epsilon
        // falls towards 0.3 has a negligible chance.
        assert!(greedy_count > 0);
    }
}