use rand::{random, Rng};
/// The agent-facing interface of a multi-armed bandit: repeatedly select an
/// arm, feed the observed reward back, and optionally reset all learning.
pub trait Bandit {
    /// Returns the index of the arm to pull next.
    fn select_arm(&mut self) -> usize;
    /// Updates the value estimate of the most recently selected arm.
    fn receive_reward(&mut self, reward: f64);
    /// Resets the bandit to its initial (possibly biased) state.
    fn restart(&mut self);
}
/// Learning state shared by all bandit algorithms.
#[derive(Debug, Default, Clone, PartialEq)]
struct BanditState {
    /// Total number of rewards received so far.
    steps: usize,
    n_available_arms: usize,
    /// Arm chosen by the most recent call to `select_arm`.
    selected_arm: usize,
    /// How many times each arm has been pulled.
    arm_pulls: Vec<usize>,
    /// Starting estimate assigned to every arm on construction and restart.
    initial_value: f64,
    /// Current value estimate per arm.
    estimated_arm_values: Vec<f64>,
}

impl BanditState {
    fn new(n_available_arms: usize) -> BanditState {
        BanditState::biased(n_available_arms, 0_f64)
    }

    fn biased(n_available_arms: usize, initial_value: f64) -> BanditState {
        BanditState {
            steps: 0,
            n_available_arms,
            selected_arm: 0,
            arm_pulls: vec![0; n_available_arms],
            initial_value,
            estimated_arm_values: vec![initial_value; n_available_arms],
        }
    }
}
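
// Optimistic initialisation note (general bandit folklore, not specific to
// this crate): starting every estimate above the largest plausible reward,
// e.g. `BanditState::biased(k, 5.0)` when rewards live in [0, 1], makes even
// a purely greedy agent explore early, because each pull drags the inflated
// estimate below those of the still-untouched arms.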
/// The exploration strategy driving a `StochasticBandit`.
#[derive(Debug, Clone)]
enum BanditAlgorithm {
    EpsilonGreedy(EpsilonGreedy),
    Ucb(Ucb),
}

#[derive(Debug, Default, Clone)]
struct EpsilonGreedy {
    /// Probability of pulling a uniformly random arm instead of the greedy one.
    epsilon: f64,
}

#[derive(Debug, Default, Clone)]
struct Ucb {
    /// Scales the exploration bonus; often written `c` in the UCB1 literature.
    exploration_degree: f64,
}
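
/// A stationary multi-armed bandit agent with a pluggable exploration
/// strategy and an optional constant learning rate.
///
/// A minimal usage sketch (hedged: it assumes this module is exposed from a
/// crate named `bandits`, and the payout rule below is a stand-in
/// environment, not part of this crate):
///
/// ```ignore
/// use bandits::{Bandit, StochasticBandit};
///
/// let mut agent = StochasticBandit::epsilon_greedy(3, 0.1);
/// for _ in 0..1_000 {
///     let arm = agent.select_arm();
///     // Stand-in environment: only arm 2 ever pays out.
///     let reward = if arm == 2 { 1.0 } else { 0.0 };
///     agent.receive_reward(reward);
/// }
/// ```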
#[derive(Debug, Clone)]
pub struct StochasticBandit {
    state: BanditState,
    algorithm: BanditAlgorithm,
    /// Constant step size; `None` means sample averaging (`1 / N(a)`).
    learning_rate: Option<f64>,
}
impl StochasticBandit {
    /// A purely greedy agent, i.e. epsilon-greedy with `epsilon = 0`.
    pub fn greedy(arms: usize) -> StochasticBandit {
        StochasticBandit::epsilon_greedy(arms, 0_f64)
    }

    /// An agent that explores a random arm with probability `epsilon` and
    /// exploits the best-looking arm otherwise.
    pub fn epsilon_greedy(arms: usize, epsilon: f64) -> StochasticBandit {
        StochasticBandit {
            state: BanditState::new(arms),
            algorithm: BanditAlgorithm::EpsilonGreedy(EpsilonGreedy { epsilon }),
            learning_rate: None,
        }
    }

    /// An upper-confidence-bound (UCB1) agent.
    pub fn ucb(arms: usize, exploration_degree: f64) -> StochasticBandit {
        StochasticBandit {
            state: BanditState::new(arms),
            algorithm: BanditAlgorithm::Ucb(Ucb { exploration_degree }),
            learning_rate: None,
        }
    }

    /// Replaces sample averaging with a constant step size in `(0, 1]`,
    /// which lets the agent track non-stationary rewards.
    ///
    /// # Panics
    ///
    /// Panics if `learning_rate` is outside `(0, 1]`.
    pub fn with_constant_learning_rate(self, learning_rate: f64) -> StochasticBandit {
        if learning_rate <= 0.0 || learning_rate > 1.0 {
            panic!("Invalid alpha value: {learning_rate}");
        }
        StochasticBandit {
            learning_rate: Some(learning_rate),
            ..self
        }
    }

    /// Starts every arm estimate at `value`; a high `value` gives optimistic
    /// initialisation (see the note above `BanditAlgorithm`).
    pub fn with_biased_state(self, value: f64) -> StochasticBandit {
        StochasticBandit {
            state: BanditState::biased(self.state.n_available_arms, value),
            ..self
        }
    }
}
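
// The Ucb branch below implements the UCB1 selection rule
//     argmax_a  Q(a) + c * sqrt(ln(t) / N(a))
// where Q(a) is the value estimate for arm a, t the total number of rewards
// seen, N(a) the pull count of arm a, and c the `exploration_degree`.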
impl Bandit for StochasticBandit {
    fn select_arm(&mut self) -> usize {
        match &self.algorithm {
            BanditAlgorithm::EpsilonGreedy(bandit) => {
                // With probability epsilon pick a uniformly random arm;
                // otherwise exploit the highest current value estimate.
                let exploration_probability: f64 = random();
                if exploration_probability < bandit.epsilon {
                    self.state.selected_arm =
                        rand::thread_rng().gen_range(0..self.state.n_available_arms);
                } else {
                    self.state.selected_arm = self
                        .state
                        .estimated_arm_values
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| a.total_cmp(b))
                        .map(|(index, _)| index)
                        .unwrap();
                }
            }
            BanditAlgorithm::Ucb(bandit) => {
                // Unpulled arms get an infinite bonus so every arm is tried
                // once before the regular bonus applies; without this guard
                // the `ln(steps) / arm_pulls[i]` term is 0 / 0 = NaN on the
                // first selections.
                self.state.selected_arm = self
                    .state
                    .estimated_arm_values
                    .iter()
                    .enumerate()
                    .map(|(i, value)| {
                        let bonus = if self.state.arm_pulls[i] == 0 {
                            f64::INFINITY
                        } else {
                            bandit.exploration_degree
                                * f64::sqrt(
                                    f64::ln(self.state.steps as f64)
                                        / self.state.arm_pulls[i] as f64,
                                )
                        };
                        (i, value + bonus)
                    })
                    .max_by(|(_, a), (_, b)| a.total_cmp(b))
                    .map(|(index, _)| index)
                    .unwrap();
            }
        }
        self.state.selected_arm
    }

    fn receive_reward(&mut self, reward: f64) {
        self.state.steps += 1;
        self.state.arm_pulls[self.state.selected_arm] += 1;
        // Sample-average step size 1 / N(a) unless a constant rate was set.
        let alpha = self
            .learning_rate
            .unwrap_or(1.0 / self.state.arm_pulls[self.state.selected_arm] as f64);
        // Incremental mean update: Q(a) <- Q(a) + alpha * (reward - Q(a)).
        self.state.estimated_arm_values[self.state.selected_arm] +=
            alpha * (reward - self.state.estimated_arm_values[self.state.selected_arm]);
    }

    fn restart(&mut self) {
        // Drop all learned estimates but keep the configured algorithm,
        // learning rate, and initial (possibly biased) value.
        self.state = BanditState::biased(self.state.n_available_arms, self.state.initial_value);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn greedy_bandit() {
        let mut greedy_bandit = StochasticBandit::greedy(5);
        assert_eq!(greedy_bandit.state.n_available_arms, 5);
        assert_eq!(greedy_bandit.state.estimated_arm_values, vec![0.0; 5]);
        // `selected_arm` starts at 0, so this reward is credited to arm 0.
        greedy_bandit.receive_reward(1.0);
        assert_eq!(greedy_bandit.select_arm(), 0);
        assert_eq!(
            greedy_bandit.state.estimated_arm_values,
            vec![1.0, 0.0, 0.0, 0.0, 0.0]
        );
        // Sample average of the two rewards: (1.0 - 5.0) / 2 = -2.0.
        greedy_bandit.receive_reward(-5.0);
        assert_eq!(
            greedy_bandit.state.estimated_arm_values,
            vec![-2.0, 0.0, 0.0, 0.0, 0.0]
        );
        assert_ne!(greedy_bandit.select_arm(), 0);
    }

    #[test]
    fn epsilon_greedy_bandit() {
        let epsilon_greedy_bandit = StochasticBandit::epsilon_greedy(10, 0.05);
        assert_eq!(epsilon_greedy_bandit.state.n_available_arms, 10);
        assert_eq!(
            epsilon_greedy_bandit.state.estimated_arm_values,
            vec![0.0; 10]
        );
        assert_eq!(epsilon_greedy_bandit.state.arm_pulls, vec![0; 10]);
        assert_eq!(epsilon_greedy_bandit.state.selected_arm, 0);
    }

    #[test]
    fn constant_learning_rate() {
        // Later builder calls win: the learning rate ends up at 1.0.
        let mut bandit = StochasticBandit::greedy(5)
            .with_constant_learning_rate(0.5)
            .with_biased_state(1.5)
            .with_constant_learning_rate(1.0);
        bandit.restart();
        assert_eq!(bandit.learning_rate, Some(1.0));
        assert_eq!(bandit.state.estimated_arm_values, vec![1.5; 5]);
    }

    #[test]
    #[should_panic(expected = "Invalid alpha value: 0")]
    fn zero_learning_rate() {
        StochasticBandit::greedy(5).with_constant_learning_rate(0.0);
    }
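
    // A hedged check of the UCB rule as implemented above: arms with zero
    // pulls carry an infinite exploration bonus, so each of the three arms
    // must be tried once before any arm is repeated.
    #[test]
    fn ucb_bandit_tries_every_arm_once() {
        let mut bandit = StochasticBandit::ucb(3, 2.0);
        let mut seen = vec![false; 3];
        for _ in 0..3 {
            let arm = bandit.select_arm();
            seen[arm] = true;
            bandit.receive_reward(0.0);
        }
        assert_eq!(seen, vec![true; 3]);
    }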
}