rosomaxa/algorithms/rl/
slot_machine.rs1#[cfg(test)]
2#[path = "../../../tests/unit/algorithms/rl/slot_machine_test.rs"]
3mod slot_machine_test;
4
5use crate::utils::{DistributionSampler, Float};
6use std::fmt::{Display, Formatter};
7
8pub trait SlotAction {
10 type Context;
12 type Feedback: SlotFeedback;
14
15 fn take(&self, context: Self::Context) -> Self::Feedback;
17}
18
19pub trait SlotFeedback {
21 fn reward(&self) -> Float;
23}
24
25#[derive(Clone)]
28pub struct SlotMachine<A, S> {
29 n: usize,
31 alpha: Float,
33 beta: Float,
35 mu: Float,
37 v: Float,
39 sampler: S,
41 action: A,
43}
44
45impl<A, S> SlotMachine<A, S>
46where
47 A: SlotAction + Clone,
48 S: DistributionSampler + Clone,
49{
50 pub fn new(prior_mean: Float, action: A, sampler: S) -> Self {
52 let alpha = 1.;
53 let beta = 10.;
54 let mu = prior_mean;
55 let v = beta / (alpha + 1.);
56
57 Self { n: 0, alpha, beta, mu, v, action, sampler }
58 }
59
60 pub fn sample(&self) -> Float {
62 let precision = self.sampler.gamma(self.alpha, 1. / self.beta);
63 let precision = if precision == 0. || self.n == 0 { 0.001 } else { precision };
64 let variance = 1. / precision;
65
66 self.sampler.normal(self.mu, variance.sqrt())
67 }
68
69 pub fn play(&self, context: A::Context) -> A::Feedback {
71 self.action.take(context)
72 }
73
74 pub fn update(&mut self, feedback: &A::Feedback) {
76 let reward = feedback.reward();
77
78 let n = 1.;
79 let v = self.n as Float;
80
81 self.alpha += n / 2.;
82 self.beta += (n * v / (v + n)) * (reward - self.mu).powi(2) / 2.;
83
84 self.v = self.beta / (self.alpha + 1.);
86 self.n += 1;
87 self.mu += (reward - self.mu) / self.n as Float;
88 }
89
90 pub fn get_params(&self) -> (Float, Float, Float, Float, usize) {
92 (self.alpha, self.beta, self.mu, self.v, self.n)
93 }
94}
95
96impl<T, S> Display for SlotMachine<T, S>
97where
98 T: Clone,
99 S: Clone,
100{
101 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102 write!(f, "n={},alpha={},beta={},mu={},v={}", self.n, self.alpha, self.beta, self.mu, self.v)
103 }
104}