use super::super::{
buffers::VecBuffer, finite::FiniteSpaceAgent, Actor, ActorMode, Agent, BatchUpdate, BuildAgent,
BuildAgentError, HistoryDataBound,
};
use crate::envs::EnvStructure;
use crate::feedback::Reward;
use crate::logging::StatsLogger;
use crate::simulation::PartialStep;
use crate::spaces::{FiniteSpace, IntervalSpace};
use crate::utils::iter::ArgMaxBy;
use crate::Prng;
use ndarray::{Array, Array2, Axis};
use rand::distributions::Distribution;
use rand_distr::Beta;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
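/// Configuration for [`BetaThompsonSamplingAgent`].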
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BetaThompsonSamplingAgentConfig {
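    /// Number of posterior samples to draw when selecting an action.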
pub num_samples: usize,
}
impl BetaThompsonSamplingAgentConfig {
#[must_use]
pub const fn new(num_samples: usize) -> Self {
Self { num_samples }
}
}
impl Default for BetaThompsonSamplingAgentConfig {
fn default() -> Self {
Self::new(1)
}
}
impl<OS, AS> BuildAgent<OS, AS, IntervalSpace<Reward>> for BetaThompsonSamplingAgentConfig
where
OS: FiniteSpace + Clone + 'static,
AS: FiniteSpace + Clone + 'static,
{
type Agent = BetaThompsonSamplingAgent<OS, AS>;
fn build_agent(
&self,
env: &dyn EnvStructure<
ObservationSpace = OS,
ActionSpace = AS,
FeedbackSpace = IntervalSpace<Reward>,
>,
_: &mut Prng,
) -> Result<Self::Agent, BuildAgentError> {
let observation_space = env.observation_space();
let action_space = env.action_space();
let IntervalSpace {
low: Reward(r_min),
high: Reward(r_max),
} = env.feedback_space();
Ok(FiniteSpaceAgent {
agent: BaseBetaThompsonSamplingAgent::new(
observation_space.size(),
action_space.size(),
(r_min, r_max),
self.num_samples,
),
observation_space,
action_space,
})
}
}
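/// A Thompson sampling agent for finite observation and action spaces.
///
/// Rewards are binarized by comparing them to the midpoint of the reward interval;
/// a Beta posterior over the probability of a high reward is maintained for each
/// observation-action pair.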
pub type BetaThompsonSamplingAgent<OS, AS> =
FiniteSpaceAgent<BaseBetaThompsonSamplingAgent, OS, AS>;
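/// Beta Thompson sampling agent that operates on index-valued observations and actions.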
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct BaseBetaThompsonSamplingAgent {
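    /// Rewards above this threshold count as "high"; all others count as "low".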
pub reward_threshold: Reward,
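    /// Number of posterior samples to draw when selecting an action.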
pub num_samples: usize,
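    /// Low / high reward counts for each observation-action pair.
    /// Initialized to `(1, 1)`, corresponding to a uniform Beta(1, 1) prior.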
low_high_reward_counts: Arc<Array2<(u64, u64)>>,
}
impl BaseBetaThompsonSamplingAgent {
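    /// Create a new `BaseBetaThompsonSamplingAgent`.
    ///
    /// The reward threshold is set to the midpoint of `reward_range`.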
pub fn new(
num_observations: usize,
num_actions: usize,
reward_range: (f64, f64),
num_samples: usize,
) -> Self {
let (reward_min, reward_max) = reward_range;
let reward_threshold = Reward((reward_min + reward_max) / 2.0);
let low_high_reward_counts =
Arc::new(Array::from_elem((num_observations, num_actions), (1, 1)));
Self {
reward_threshold,
num_samples,
low_high_reward_counts,
}
}
}
impl fmt::Display for BaseBetaThompsonSamplingAgent {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"BaseBetaThompsonSamplingAgent({})",
self.reward_threshold
)
}
}
impl BaseBetaThompsonSamplingAgent {
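    /// Update the low / high reward counts from a single environment step.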
fn step_update(&mut self, step: PartialStep<usize, usize>) {
let reward_count = Arc::get_mut(&mut self.low_high_reward_counts)
.expect("cannot update agent while actors exist")
.get_mut((step.observation, step.action))
.unwrap();
if step.feedback > self.reward_threshold {
reward_count.1 += 1;
} else {
reward_count.0 += 1;
}
}
}
impl Agent<usize, usize> for BaseBetaThompsonSamplingAgent {
type Actor = BaseBetaThompsonSamplingActor;
fn actor(&self, mode: ActorMode) -> Self::Actor {
BaseBetaThompsonSamplingActor {
mode,
num_samples: self.num_samples,
low_high_reward_counts: Arc::clone(&self.low_high_reward_counts),
}
}
}
impl BatchUpdate<usize, usize> for BaseBetaThompsonSamplingAgent {
type Feedback = Reward;
type HistoryBuffer = VecBuffer<usize, usize>;
fn buffer(&self) -> Self::HistoryBuffer {
VecBuffer::new()
}
fn min_update_size(&self) -> HistoryDataBound {
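        // A single step is enough to perform an update.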
HistoryDataBound {
min_steps: 1,
slack_steps: 0,
}
}
fn batch_update<'a, I>(&mut self, buffers: I, _logger: &mut dyn StatsLogger)
where
I: IntoIterator<Item = &'a mut Self::HistoryBuffer>,
Self::HistoryBuffer: 'a,
{
for buffer in buffers {
for step in buffer.drain_steps() {
self.step_update(step)
}
}
}
}
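/// Actor for [`BaseBetaThompsonSamplingAgent`], sharing the agent's reward counts via `Arc`.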
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct BaseBetaThompsonSamplingActor {
mode: ActorMode,
num_samples: usize,
low_high_reward_counts: Arc<Array2<(u64, u64)>>,
}
impl Actor<usize, usize> for BaseBetaThompsonSamplingActor {
type EpisodeState = ();
fn initial_state(&self, _: &mut Prng) -> Self::EpisodeState {}
fn act(&self, _: &mut Self::EpisodeState, observation: &usize, rng: &mut Prng) -> usize {
match self.mode {
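            // Thompson sampling: draw `num_samples` values from each action's
            // Beta posterior and choose the action with the largest total.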
ActorMode::Training => self
.low_high_reward_counts
.index_axis(Axis(0), *observation)
.mapv(|(beta, alpha)| -> f64 {
Beta::new(alpha as f64, beta as f64)
.unwrap()
.sample_iter(&mut *rng)
.take(self.num_samples)
.sum()
})
.into_iter()
.argmax_by(|a, b| a.partial_cmp(b).unwrap())
.expect("empty action space"),
ActorMode::Evaluation => self
.low_high_reward_counts
.index_axis(Axis(0), *observation)
.mapv(|(beta, alpha)| alpha as f64 / (alpha + beta) as f64)
.into_iter()
.argmax_by(|a, b| a.partial_cmp(b).unwrap())
.expect("empty action space"),
}
}
}
#[cfg(test)]
mod beta_thompson_sampling {
use super::super::super::testing;
use super::*;
#[test]
    fn learns_deterministic_bandit() {
testing::train_deterministic_bandit(&BetaThompsonSamplingAgentConfig::default(), 1000, 0.9);
}
}