use rand_distr::Distribution;
use rand_distr::Normal;
/// Common interface for an arm that can be pulled to obtain an `f64` reward.
pub trait Arm {
    /// Returns the value recorded for this arm, or `None` when the
    /// implementor has no value to report.
    fn value(&self) -> Option<f64>;

    /// Pulls the arm once and returns the resulting reward.
    fn pull(&self) -> f64;
}
/// An arm whose rewards are sampled from a distribution over `f64`.
#[derive(Clone, Debug)]
pub struct RandomArm<D>
where
    D: Distribution<f64>,
{
    /// Value reported by [`Arm::value`]; `None` when no value was supplied.
    value: Option<f64>,
    /// Distribution sampled on every call to [`Arm::pull`].
    reward_distribution: D,
}
impl<D> RandomArm<D>
where
    D: Distribution<f64>,
{
    /// Builds an arm from an optional value and the distribution its rewards
    /// are sampled from.
    ///
    /// `value` is stored as given; it is not checked against
    /// `reward_distribution` in any way.
    pub fn from_distribution(value: Option<f64>, reward_distribution: D) -> Self {
        Self {
            value,
            reward_distribution,
        }
    }
}
impl RandomArm<Normal<f64>> {
    /// Creates an arm whose rewards follow a normal distribution with mean
    /// `value` and a standard deviation of `1.0`.
    ///
    /// The arm reports `Some(value)` as its value.
    pub fn normal(value: f64) -> Self {
        // The standard deviation is a fixed, valid constant, so construction
        // is not expected to fail; the message records that invariant.
        let reward_distribution = Normal::new(value, 1.0)
            .expect("a standard deviation of 1.0 is always accepted by Normal::new");
        Self {
            value: Some(value),
            reward_distribution,
        }
    }
}
impl<D> Arm for RandomArm<D>
where
    D: Distribution<f64>,
{
    /// Returns the value stored in this arm.
    fn value(&self) -> Option<f64> {
        self.value
    }

    /// Draws one sample from the reward distribution, using the generator
    /// returned by `rand::thread_rng()`.
    fn pull(&self) -> f64 {
        let mut rng = rand::thread_rng();
        self.reward_distribution.sample(&mut rng)
    }
}
/// A collection of arms addressed by their position in the underlying vector.
#[derive(Clone, Debug)]
pub struct MultiArm<A: Arm> {
    /// Arms in the order they were supplied; the indices accepted by `pull`
    /// and returned by `optimal_arm` refer to this order.
    arms: Vec<A>,
}

impl<A: Arm> MultiArm<A> {
    /// Wraps `arms`, keeping their order.
    pub fn new(arms: Vec<A>) -> MultiArm<A> {
        Self { arms }
    }

    /// Pulls the arm at index `k` and returns its reward.
    ///
    /// # Panics
    ///
    /// Panics if `k` is not a valid index into the arms.
    pub fn pull(&self, k: usize) -> f64 {
        self.arms[k].pull()
    }

    /// Returns the index of the arm with the greatest value.
    ///
    /// Values are ordered with [`f64::total_cmp`]. When several arms share
    /// the greatest value, the highest index among them is returned.
    ///
    /// Returns `None` when there are no arms, or when any arm encountered
    /// during the scan reports no value.
    pub fn optimal_arm(&self) -> Option<usize> {
        let mut best: Option<(usize, f64)> = None;
        for (index, arm) in self.arms.iter().enumerate() {
            // Each arm is queried exactly once; an unknown value means the
            // optimum cannot be determined.
            let value = arm.value()?;
            // Only a strictly greater leader keeps its place, so ties go to
            // the later arm.
            let leader_stays =
                matches!(best, Some((_, leader)) if leader.total_cmp(&value).is_gt());
            if !leader_stays {
                best = Some((index, value));
            }
        }
        best.map(|(index, _)| index)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Uniform;

    /// A normal arm reports the mean it was built with as its value.
    #[test]
    fn standard_normal_arm() {
        let arm = RandomArm::normal(0.0);
        assert_eq!(arm.value, Some(0.0));
        assert_eq!(arm.value(), arm.value);
    }

    /// An arm built without a value reports `None`, and its rewards stay
    /// inside the half-open range of the uniform distribution.
    #[test]
    fn uniform_arm() {
        let arm = RandomArm::from_distribution(None, Uniform::new(0.0, 1.0));
        assert_eq!(arm.value, None);
        assert_eq!(arm.value(), arm.value);

        let reward = arm.pull();
        assert!((0.0..1.0).contains(&reward));
    }

    /// `optimal_arm` is `None` when any value is unknown, and otherwise the
    /// index of the arm with the largest value.
    #[test]
    fn optimal_arm() {
        // One arm has no value, so no optimum can be reported.
        let with_unknown = MultiArm::new(vec![
            RandomArm::from_distribution(None, Uniform::new(0.0, 1.0)),
            RandomArm::from_distribution(Some(5.0), Uniform::new(-10.0, 10.0)),
            RandomArm::from_distribution(Some(0.5), Uniform::new(-1.0, 1.0)),
        ]);
        assert_eq!(with_unknown.optimal_arm(), None);

        // Every value is known; the arm at index 1 has the largest one.
        let all_known = MultiArm::new(vec![
            RandomArm::from_distribution(Some(1.0), Uniform::new(0.0, 1.0)),
            RandomArm::from_distribution(Some(5.0), Uniform::new(-10.0, 10.0)),
            RandomArm::from_distribution(Some(0.5), Uniform::new(-1.0, 1.0)),
        ]);
        assert_eq!(all_known.optimal_arm(), Some(1));
    }
}