use std::marker::PhantomData;

use num_traits::ToPrimitive;
use rand::distributions::Distribution;
use rand::thread_rng;
use rand_distr::uniform::Uniform;

use crate::agent::agent::update;
use crate::Stepper;

use super::{Agent, ArgBounds};

/// Agent that follows the Epsilon-Greedy Algorithm.
///
/// A fixed (usually small) fraction of the time it picks a random arm; the rest of the time it
/// picks the arm with the highest estimated value.
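///
/// A minimal usage sketch (kept out of doc-tests with `ignore`, since it assumes this crate's
/// `HarmonicStepper` and the `Agent` trait are already in scope):
///
/// ```ignore
/// let stepper = HarmonicStepper::new(1, 3);
/// let mut agent: EpsilonGreedyAgent<u32> =
///     EpsilonGreedyAgent::new(vec![0.0, 0.0, 0.0], Box::new(stepper), 0.1);
/// let arm = agent.action(); // explore with probability 0.1, otherwise exploit
/// agent.step(arm, 1);       // update the estimate of the pulled arm with reward 1
/// ```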
pub struct EpsilonGreedyAgent<T> {
    /// The current estimates of the Bandit arm values.
    q_star: Vec<f64>,

    /// The Agent's rule for step size updates.
    stepper: Box<dyn Stepper>,

    /// The fraction of times a random arm is chosen.
    epsilon: f64,

    /// A random uniform distribution to determine if a random action should be chosen.
    uniform: Uniform<f64>,

    /// A random uniform distribution to choose a random arm.
    pick_arm: Uniform<usize>,
    phantom: PhantomData<T>,
}

impl<T: ToPrimitive> Agent<T> for EpsilonGreedyAgent<T> {
    /// The action chosen by the Agent. A random action with probability `epsilon` and the greedy
    /// action otherwise.
    fn action(&self) -> usize {
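        // Draw from U(0, 1); values below `epsilon` trigger exploration.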
        if self.uniform.sample(&mut thread_rng()) < self.epsilon {
            self.pick_arm.sample(&mut thread_rng())
        } else {
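            // Otherwise exploit: choose the arm with the highest current estimate.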
            self.q_star.arg_max()
        }
    }

    /// The number of arms in the Bandit the Agent is playing.
    fn arms(&self) -> usize {
        self.q_star.len()
    }

    /// The Agent's current estimate of the value of a Bandit's arm.
    fn current_estimate(&self, arm: usize) -> f64 {
        self.q_star[arm]
    }

    /// Reset the Agent's history and give it a new initial guess of the Bandit's arm values.
    fn reset(&mut self, q_init: &[f64]) {
        self.q_star = q_init.to_owned();
        self.stepper.reset()
    }

    /// Update the Agent's estimate of a Bandit arm based on a given reward.
    fn step(&mut self, arm: usize, reward: T) {
        self.q_star[arm] += update(&mut self.stepper, &self.q_star, arm, reward)
    }
}

impl<T> EpsilonGreedyAgent<T> {
    /// Initializes a new Epsilon-Greedy agent.
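    ///
    /// # Panics
    ///
    /// Panics if `epsilon` does not lie strictly between `0.0` and `1.0`.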
    pub fn new(q_init: Vec<f64>, stepper: Box<dyn Stepper>, epsilon: f64) -> EpsilonGreedyAgent<T> {
        assert!(epsilon > 0.0);
        assert!(epsilon < 1.0);
        let l = q_init.len();
        EpsilonGreedyAgent {
            q_star: q_init,
            stepper,
            epsilon,
            uniform: Uniform::new(0.0, 1.0),
            pick_arm: Uniform::new(0usize, l),
            phantom: PhantomData,
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::HarmonicStepper;

    use super::{Agent, EpsilonGreedyAgent};

    #[test]
    fn test_new() {
        let q_init = vec![0.5, 0.61, 0.7, 0.12, 0.37];
        let stepper = HarmonicStepper::new(1, q_init.len());
        let eps = 0.1;
        let epsilon: EpsilonGreedyAgent<u32> =
            EpsilonGreedyAgent::new(q_init, Box::new(stepper), eps);
        assert_eq!(epsilon.epsilon, eps);
        assert_eq!(epsilon.q_star, vec![0.5, 0.61, 0.7, 0.12, 0.37])
    }

    #[test]
    #[should_panic]
    fn test_new_big_epsilon() {
        let q_init = vec![0.5, 0.61, 0.7, 0.12, 0.37];
        let stepper = HarmonicStepper::new(1, q_init.len());
        let eps = 1.3;
        let _epsilon: EpsilonGreedyAgent<u32> =
            EpsilonGreedyAgent::new(q_init, Box::new(stepper), eps);
    }

    #[test]
    #[should_panic]
    fn test_new_small_epsilon() {
        let q_init = vec![0.5, 0.61, 0.7, 0.12, 0.37];
        let stepper = HarmonicStepper::new(1, q_init.len());
        let eps = -0.3;
        let _epsilon: EpsilonGreedyAgent<u32> =
            EpsilonGreedyAgent::new(q_init, Box::new(stepper), eps);
    }

    #[test]
    fn test_q_star() {
        let q_init = vec![0.5, 0.61, 0.7, 0.12, 0.37];
        let stepper = HarmonicStepper::new(1, q_init.len());
        let epsilon: EpsilonGreedyAgent<u32> =
            EpsilonGreedyAgent::new(q_init, Box::new(stepper), 0.1);
        assert_eq!(epsilon.q_star, vec![0.5, 0.61, 0.7, 0.12, 0.37])
    }

    #[test]
    fn test_reset() {
        let q_init = vec![0.5, 0.61, 0.7, 0.12, 0.37];
        let stepper = HarmonicStepper::new(1, q_init.len());
        let mut epsilon: EpsilonGreedyAgent<u32> =
            EpsilonGreedyAgent::new(q_init, Box::new(stepper), 0.1);
        let new_q = vec![0.01, 0.86, 0.43, 0.65, 0.66];
        epsilon.reset(&new_q);
        assert_eq!(epsilon.q_star, new_q)
    }
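
    // Additional checks for the remaining trait methods. These rely only on behaviour visible in
    // this file: `arms()` returns the length of `q_star`, `current_estimate(arm)` returns
    // `q_star[arm]`, and `action()` always returns the index of one of the arms.
    #[test]
    fn test_arms() {
        let q_init = vec![0.5, 0.61, 0.7, 0.12, 0.37];
        let stepper = HarmonicStepper::new(1, q_init.len());
        let epsilon: EpsilonGreedyAgent<u32> =
            EpsilonGreedyAgent::new(q_init, Box::new(stepper), 0.1);
        assert_eq!(epsilon.arms(), 5)
    }

    #[test]
    fn test_current_estimate() {
        let q_init = vec![0.5, 0.61, 0.7, 0.12, 0.37];
        let stepper = HarmonicStepper::new(1, q_init.len());
        let epsilon: EpsilonGreedyAgent<u32> =
            EpsilonGreedyAgent::new(q_init.clone(), Box::new(stepper), 0.1);
        for (arm, &value) in q_init.iter().enumerate() {
            assert_eq!(epsilon.current_estimate(arm), value)
        }
    }

    #[test]
    fn test_action_returns_valid_arm() {
        let q_init = vec![0.5, 0.61, 0.7, 0.12, 0.37];
        let stepper = HarmonicStepper::new(1, q_init.len());
        let epsilon: EpsilonGreedyAgent<u32> =
            EpsilonGreedyAgent::new(q_init, Box::new(stepper), 0.1);
        for _ in 0..1_000 {
            assert!(epsilon.action() < epsilon.arms())
        }
    }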
}