use super::{CloneBuild, EnvDistribution, EnvStructure, Environment, Successor};
use crate::feedback::Reward;
use crate::logging::StatsLogger;
use crate::spaces::{IndexSpace, IntervalSpace, SingletonSpace};
use crate::utils::distributions::{Bernoulli, Bounded, Deterministic, FromMean};
use crate::Prng;
use rand::distributions::{Distribution, Uniform};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
/// Multi-armed bandit environment.
///
/// Each arm (action) has an associated reward distribution of type `D`.
/// Episodes consist of a single step: pull one arm, receive one reward.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Bandit<D> {
// One reward distribution per arm; the arm index is the action.
distributions: Vec<D>,
}
impl<D: Clone> CloneBuild for Bandit<D> {}
impl<D> Bandit<D> {
#[must_use]
/// Create a bandit with one arm per given reward distribution.
pub fn new(distributions: Vec<D>) -> Self {
Self { distributions }
}
}
impl<D: Bounded<f64>> EnvStructure for Bandit<D> {
    type ObservationSpace = SingletonSpace;
    type ActionSpace = IndexSpace;
    type FeedbackSpace = IntervalSpace<Reward>;

    /// Observations carry no information.
    fn observation_space(&self) -> Self::ObservationSpace {
        SingletonSpace::new()
    }

    /// One discrete action per arm.
    fn action_space(&self) -> Self::ActionSpace {
        IndexSpace::new(self.distributions.len())
    }

    /// Reward interval spanning the bounds of every arm's distribution.
    /// Degenerates to `[0.0, 0.0]` when there are no arms.
    fn feedback_space(&self) -> Self::FeedbackSpace {
        let mut arm_bounds = self.distributions.iter().map(Bounded::bounds);
        let (mut low, mut high) = arm_bounds.next().unwrap_or((0.0, 0.0));
        for (lo, hi) in arm_bounds {
            low = low.min(lo);
            high = high.max(hi);
        }
        IntervalSpace::new(Reward(low), Reward(high))
    }

    /// Episodes are a single step, so no discounting is applied.
    fn discount_factor(&self) -> f64 {
        1.0
    }
}
impl<D: Distribution<f64> + Bounded<f64>> Environment for Bandit<D> {
    type State = ();
    type Observation = ();
    type Action = usize;
    type Feedback = Reward;

    fn initial_state(&self, _: &mut Prng) -> Self::State {}

    // Fix: the return type was previously annotated as `Self::State`.
    // That only compiled because both `State` and `Observation` are `()`;
    // `Self::Observation` is what the trait method returns.
    fn observe(&self, _: &Self::State, _: &mut Prng) -> Self::Observation {}

    /// Sample a reward from the selected arm and terminate the episode.
    ///
    /// # Panics
    /// Panics if `action` is not a valid arm index.
    fn step(
        &self,
        _state: Self::State,
        action: &Self::Action,
        rng: &mut Prng,
        _logger: &mut dyn StatsLogger,
    ) -> (Successor<Self::State>, Self::Feedback) {
        let reward = self.distributions[*action].sample(rng);
        (Successor::Terminate, reward.into())
    }
}
impl<D: FromMean<f64>> Bandit<D> {
    /// Build a bandit whose arm distributions are constructed from the given means.
    ///
    /// # Errors
    /// Returns the distribution's construction error for the first mean that
    /// cannot be converted into a `D`.
    pub fn from_means<I: IntoIterator<Item = T>, T: Borrow<f64>>(
        means: I,
    ) -> Result<Self, D::Error> {
        let mut distributions = Vec::new();
        for mean in means {
            distributions.push(D::from_mean(*mean.borrow())?);
        }
        Ok(Self::new(distributions))
    }
}
/// Bandit whose arms emit Bernoulli-distributed (0/1) rewards.
pub type BernoulliBandit = Bandit<Bernoulli>;

impl BernoulliBandit {
    /// Create a bandit with arm success probabilities drawn uniformly from `[0, 1]`.
    pub fn uniform<R: Rng>(num_arms: usize, rng: &mut R) -> Self {
        let probability = Uniform::new_inclusive(0.0, 1.0);
        let distributions = (0..num_arms)
            .map(|_| Bernoulli::new(probability.sample(&mut *rng)).unwrap())
            .collect();
        Self { distributions }
    }
}
/// Bandit whose arms each yield a fixed, deterministic reward.
pub type DeterministicBandit = Bandit<Deterministic<f64>>;
impl DeterministicBandit {
/// Create a bandit where arm `i` always yields `values[i]`.
pub fn from_values<I: IntoIterator<Item = T>, T: Borrow<f64>>(values: I) -> Self {
// NOTE(review): unwrap assumes `Deterministic::from_mean` is infallible
// for f64 inputs — confirm its `FromMean::Error` type.
Self::from_means(values).unwrap()
}
}
/// Distribution over `BernoulliBandit` environments whose arm success
/// probabilities are sampled uniformly from `[0, 1]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct UniformBernoulliBandits {
// Number of arms in every sampled bandit.
pub num_arms: usize,
}
impl UniformBernoulliBandits {
#[must_use]
/// Create a distribution over bandits with the given number of arms.
pub const fn new(num_arms: usize) -> Self {
Self { num_arms }
}
}
impl Default for UniformBernoulliBandits {
    /// Defaults to the classic two-armed bandit.
    fn default() -> Self {
        Self::new(2)
    }
}
impl CloneBuild for UniformBernoulliBandits {}
impl EnvStructure for UniformBernoulliBandits {
type ObservationSpace = SingletonSpace;
type ActionSpace = IndexSpace;
type FeedbackSpace = IntervalSpace<Reward>;
// Observations carry no information.
fn observation_space(&self) -> Self::ObservationSpace {
SingletonSpace::new()
}
// One discrete action per arm.
fn action_space(&self) -> Self::ActionSpace {
IndexSpace::new(self.num_arms)
}
// Bernoulli rewards are always within [0, 1].
fn feedback_space(&self) -> Self::FeedbackSpace {
IntervalSpace::new(Reward(0.0), Reward(1.0))
}
// Single-step episodes; no discounting.
fn discount_factor(&self) -> f64 {
1.0
}
}
impl EnvDistribution for UniformBernoulliBandits {
// Forward the associated types of the sampled environment.
type State = <Self::Environment as Environment>::State;
type Observation = <Self::Environment as Environment>::Observation;
type Action = <Self::Environment as Environment>::Action;
type Feedback = <Self::Environment as Environment>::Feedback;
type Environment = BernoulliBandit;
#[inline]
// Draw each arm's success probability uniformly from [0, 1].
fn sample_environment(&self, rng: &mut Prng) -> Self::Environment {
BernoulliBandit::uniform(self.num_arms, rng)
}
}
/// Distribution over deterministic bandits in which exactly one arm
/// (chosen uniformly at random) pays 1.0 and all others pay 0.0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct OneHotBandits {
// Number of arms in every sampled bandit.
pub num_arms: usize,
}
impl OneHotBandits {
#[must_use]
/// Create a distribution over one-hot bandits with the given number of arms.
pub const fn new(num_arms: usize) -> Self {
Self { num_arms }
}
}
impl Default for OneHotBandits {
    /// Defaults to two arms.
    fn default() -> Self {
        Self::new(2)
    }
}
impl CloneBuild for OneHotBandits {}
impl EnvStructure for OneHotBandits {
type ObservationSpace = SingletonSpace;
type ActionSpace = IndexSpace;
type FeedbackSpace = IntervalSpace<Reward>;
// Observations carry no information.
fn observation_space(&self) -> Self::ObservationSpace {
SingletonSpace::new()
}
// One discrete action per arm.
fn action_space(&self) -> Self::ActionSpace {
IndexSpace::new(self.num_arms)
}
// Rewards are exactly 0.0 or 1.0, so [0, 1] covers them.
fn feedback_space(&self) -> Self::FeedbackSpace {
IntervalSpace::new(Reward(0.0), Reward(1.0))
}
// Single-step episodes; no discounting.
fn discount_factor(&self) -> f64 {
1.0
}
}
impl EnvDistribution for OneHotBandits {
    // Forward the associated types of the sampled environment.
    type State = <Self::Environment as Environment>::State;
    type Observation = <Self::Environment as Environment>::Observation;
    type Action = <Self::Environment as Environment>::Action;
    type Feedback = <Self::Environment as Environment>::Feedback;
    type Environment = DeterministicBandit;

    /// Sample a deterministic bandit in which one uniformly-chosen arm
    /// pays 1.0 and every other arm pays 0.0.
    ///
    /// # Panics
    /// Panics if `num_arms` is zero (empty sampling range).
    #[inline]
    fn sample_environment(&self, rng: &mut Prng) -> Self::Environment {
        let best_arm = rng.gen_range(0..self.num_arms);
        let means: Vec<f64> = (0..self.num_arms)
            .map(|arm| if arm == best_arm { 1.0 } else { 0.0 })
            .collect();
        DeterministicBandit::from_means(means).unwrap()
    }
}
#[cfg(test)]
mod bernoulli_bandit {
use super::super::testing;
use super::*;
// Structural invariants hold for a concrete two-armed Bernoulli bandit.
#[test]
fn run() {
let env = BernoulliBandit::from_means(vec![0.2, 0.8]).unwrap();
testing::check_structured_env(&env, 1000, 0);
}
// Statistical check: the empirical reward-1 frequency of a p=0.2 arm
// stays within 3.5 binomial standard deviations of its expectation.
#[test]
fn rewards() {
let mean = 0.2;
let num_samples = 10000;
let env = BernoulliBandit::from_means(vec![mean]).unwrap();
let mut rng = Prng::seed_from_u64(1);
let mut reward_1_count = 0;
for _ in 0..num_samples {
let (_, feedback) = env.step((), &0, &mut rng, &mut ());
// Every Bernoulli reward must be exactly 0.0 or 1.0.
#[allow(clippy::float_cmp)] if feedback < Reward(0.5) {
assert_eq!(feedback, Reward(0.0));
} else {
assert_eq!(feedback, Reward(1.0));
reward_1_count += 1
}
}
// Binomial mean and standard deviation for num_samples draws.
let bin_mean = f64::from(num_samples) * mean;
let bin_stddev = (f64::from(num_samples) * mean * (1.0 - mean)).sqrt();
assert!(
((bin_mean - 3.5 * bin_stddev)..=(bin_mean + 3.5 * bin_stddev))
.contains(&reward_1_count.into())
);
}
}
#[cfg(test)]
mod deterministic_bandit {
use super::super::testing;
use super::*;
// Structural invariants hold for a concrete two-armed deterministic bandit.
#[test]
fn run() {
let env = DeterministicBandit::from_values(vec![0.2, 0.8]);
testing::check_structured_env(&env, 1000, 0);
}
// Each arm always returns exactly its configured value.
#[test]
#[allow(clippy::float_cmp)] fn rewards() {
let mut rng = Prng::seed_from_u64(0);
let env = DeterministicBandit::from_values(vec![0.2, 0.8]);
let (_, reward_0) = env.step((), &0, &mut rng, &mut ());
assert_eq!(reward_0, Reward(0.2));
let (_, reward_1) = env.step((), &1, &mut rng, &mut ());
assert_eq!(reward_1, Reward(0.8));
}
}
#[cfg(test)]
mod uniform_bernoulli_bandits {
    // Renamed from `uniform_determistic_bandits`: the old name contained a
    // typo ("determistic") and misdescribed the subject — these tests cover
    // `UniformBernoulliBandits`, not a deterministic distribution.
    use super::super::testing;
    use super::*;
    use rand::SeedableRng;

    /// A sampled environment satisfies the structural invariants.
    #[test]
    fn run_sample() {
        let env_dist = UniformBernoulliBandits::new(3);
        let mut rng = Prng::seed_from_u64(284);
        let env = env_dist.sample_environment(&mut rng);
        testing::check_structured_env(&env, 1000, 286);
    }

    /// Sampled environments' structures are consistent with the distribution's.
    #[test]
    fn subset_env_structure() {
        let env_dist = UniformBernoulliBandits::new(3);
        testing::check_env_distribution_structure(&env_dist, 2);
    }
}
#[cfg(test)]
mod needle_haystack_bandits {
use super::super::testing;
use super::*;
use rand::SeedableRng;
// A sampled one-hot environment satisfies the structural invariants.
#[test]
fn run_sample() {
let env_dist = OneHotBandits::new(3);
let mut rng = Prng::seed_from_u64(284);
let env = env_dist.sample_environment(&mut rng);
testing::check_structured_env(&env, 1000, 286);
}
// Sampled environments' structures are consistent with the distribution's.
#[test]
fn subset_env_structure() {
let env_dist = OneHotBandits::new(3);
testing::check_env_distribution_structure(&env_dist, 2);
}
}