mod discounted_thompson;
mod epsilon_greedy;
mod lin_ucb;
mod thompson;
mod ucb;
pub use discounted_thompson::DiscountedThompsonSampling;
pub use epsilon_greedy::EpsilonGreedy;
pub use lin_ucb::LinUCB;
pub use thompson::ThompsonSampling;
pub use ucb::{UCBTuned, UCB1};
/// Common interface for non-contextual multi-armed bandit policies.
///
/// The `Send + Sync` supertraits allow implementors to be moved to, or shared
/// between, threads.
pub trait Bandit: Send + Sync {
    /// Returns the index of the arm to pull next.
    ///
    /// Takes `&mut self`, so implementations may advance internal state here
    /// (for example an RNG state; this depends on the implementor).
    fn select_arm(&mut self) -> usize;
    /// Feeds back the observed `reward` for the pulled `arm`.
    ///
    /// NOTE(review): behavior for `arm >= n_arms()` is implementor-defined — confirm.
    fn update(&mut self, arm: usize, reward: f64);
    /// Returns the number of arms the policy chooses among.
    fn n_arms(&self) -> usize;
    /// Returns the total pull count tracked by the policy
    /// (presumably the number of `update` calls — confirm per implementor).
    fn n_pulls(&self) -> u64;
    /// Returns the policy to its initial state.
    ///
    /// NOTE(review): whether any RNG state is also reset is implementor-defined — confirm.
    fn reset(&mut self);
    /// Per-arm values; presumably one entry per arm holding its current reward
    /// estimate — confirm against implementors.
    fn arm_values(&self) -> &[f64];
    /// Per-arm counts; presumably one entry per arm holding its pull count —
    /// confirm against implementors.
    fn arm_counts(&self) -> &[u64];
}
/// Common interface for contextual bandit policies, which choose an arm based
/// on a caller-supplied context slice.
///
/// The `Send + Sync` supertraits allow implementors to be moved to, or shared
/// between, threads.
pub trait ContextualBandit: Send + Sync {
    /// Returns the index of the arm to pull for the given `context`.
    ///
    /// NOTE(review): implementors presumably expect `context.len()` to match a
    /// fixed feature dimension — confirm.
    fn select_arm(&mut self, context: &[f64]) -> usize;
    /// Feeds back the `reward` observed after pulling `arm` under `context`.
    fn update(&mut self, arm: usize, context: &[f64], reward: f64);
    /// Returns the number of arms the policy chooses among.
    fn n_arms(&self) -> usize;
    /// Returns the total pull count tracked by the policy
    /// (presumably the number of `update` calls — confirm per implementor).
    fn n_pulls(&self) -> u64;
    /// Returns the policy to its initial state (exact semantics are implementor-defined).
    fn reset(&mut self);
}
pub(crate) use irithyll_core::rng::{standard_normal, xorshift64, xorshift64_f64};
/// Draws a single Gamma(`shape`, 1) variate, advancing the xorshift `state`.
///
/// * `shape >= 1`: Marsaglia–Tsang squeeze/rejection sampler.
/// * `shape < 1`: boosting identity `Gamma(a) = Gamma(a + 1) * U^(1/a)`, with
///   `U` drawn from `xorshift64_f64`.
///
/// The `shape > 0` precondition is only checked in debug builds (`debug_assert!`).
pub(crate) fn gamma_sample(shape: f64, state: &mut u64) -> f64 {
    debug_assert!(shape > 0.0, "gamma shape must be positive");

    if shape < 1.0 {
        // The boosting uniform is drawn *before* the inner sample; keeping this
        // order keeps seeded sequences reproducible.
        let boost_u = xorshift64_f64(state);
        let inner = gamma_sample(shape + 1.0, state);
        return inner * boost_u.powf(1.0 / shape);
    }

    let d = shape - 1.0 / 3.0;
    let scale = (9.0 * d).sqrt().recip();

    loop {
        let z = standard_normal(state);
        let base = 1.0 + scale * z;
        if base <= 0.0 {
            // (1 + c*z)^3 would be non-positive: reject before drawing a uniform.
            continue;
        }
        let v = base * base * base;
        let u = xorshift64_f64(state);

        // Cheap squeeze test first; otherwise fall back to the exact log test.
        let squeeze = 1.0 - 0.0331 * (z * z) * (z * z);
        if u < squeeze || u.ln() < 0.5 * z * z + d * (1.0 - v + v.ln()) {
            break d * v;
        }
    }
}
/// Draws a single Beta(`alpha`, `beta`) variate, advancing the xorshift `state`.
///
/// Uses the ratio `X / (X + Y)` with independent `X ~ Gamma(alpha)` and
/// `Y ~ Gamma(beta)`.
///
/// When both shapes are `>= 1` the ratio is computed directly from
/// `gamma_sample`. Otherwise the gamma draws are taken in log space. For very
/// small shapes both `X` and `Y` can underflow to `0.0`, which would make the
/// direct ratio `0/0 = NaN`. The log form `1 / (1 + exp(ln Y - ln X))` has the
/// same value without that underflow.
///
/// For positive, finite shapes the result lies in `[0, 1]`.
pub(crate) fn beta_sample(alpha: f64, beta: f64, state: &mut u64) -> f64 {
    if alpha >= 1.0 && beta >= 1.0 {
        // Shapes >= 1 keep the gamma draws well away from zero, so the direct
        // ratio is safe here.
        let x = gamma_sample(alpha, state);
        let y = gamma_sample(beta, state);
        return x / (x + y);
    }
    let ln_x = ln_gamma_sample(alpha, state);
    let ln_y = ln_gamma_sample(beta, state);
    // Logistic of the log-ratio: exp overflow yields 0.0, exp underflow yields 1.0.
    1.0 / (1.0 + (ln_y - ln_x).exp())
}

/// Returns `ln(G)` for `G ~ Gamma(shape, 1)` without materialising a tiny `G`.
///
/// For `shape < 1` this applies the boosting identity
/// `Gamma(a) = Gamma(a + 1) * U^(1/a)` in log space. The term `ln(U) / a` is
/// drawn as `-E / a` with `E ~ Exp(1) = Gamma(1)`, using the fact that
/// `-ln(U) ~ Exp(1)`. Both terms stay finite, so the log never becomes `-inf`.
fn ln_gamma_sample(shape: f64, state: &mut u64) -> f64 {
    debug_assert!(shape > 0.0, "gamma shape must be positive");
    if shape < 1.0 {
        let boosted = gamma_sample(shape + 1.0, state);
        let exponential = gamma_sample(1.0, state);
        return boosted.ln() - exponential / shape;
    }
    gamma_sample(shape, state).ln()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn xorshift64_deterministic() {
        let mut s1: u64 = 42;
        let mut s2: u64 = 42;
        let a: Vec<u64> = (0..50).map(|_| xorshift64(&mut s1)).collect();
        let b: Vec<u64> = (0..50).map(|_| xorshift64(&mut s2)).collect();
        assert_eq!(a, b, "same seed should produce identical sequence");
    }

    #[test]
    fn xorshift64_f64_in_unit_interval() {
        let mut s: u64 = 123;
        for _ in 0..1000 {
            let v = xorshift64_f64(&mut s);
            assert!(
                (0.0..1.0).contains(&v),
                "xorshift64_f64 should be in [0, 1), got {}",
                v
            );
        }
    }

    #[test]
    fn standard_normal_finite() {
        let mut s: u64 = 99;
        for _ in 0..1000 {
            let v = standard_normal(&mut s);
            assert!(v.is_finite(), "standard_normal should be finite, got {}", v);
        }
    }

    #[test]
    fn gamma_sample_positive() {
        let mut s: u64 = 77;
        for shape in [0.5, 1.0, 2.0, 5.0, 10.0] {
            for _ in 0..200 {
                let v = gamma_sample(shape, &mut s);
                assert!(
                    v > 0.0 && v.is_finite(),
                    "gamma({}) should be positive finite, got {}",
                    shape,
                    v
                );
            }
        }
    }

    #[test]
    fn gamma_sample_mean_approximately_correct() {
        // Gamma(k, 1) has mean k. The shapes cover both the boosted branch
        // (k < 1) and the direct Marsaglia–Tsang branch (k >= 1).
        let mut s: u64 = 314;
        let n = 5000;
        for shape in [0.5_f64, 1.0, 4.0] {
            let sum: f64 = (0..n).map(|_| gamma_sample(shape, &mut s)).sum();
            let mean = sum / n as f64;
            // At least ~7 standard errors of the sample mean for every shape used.
            let tolerance = 0.1 * shape.max(1.0);
            assert!(
                (mean - shape).abs() < tolerance,
                "gamma({}) mean should be ~{}, got {}",
                shape,
                shape,
                mean
            );
        }
    }

    #[test]
    fn beta_sample_in_unit_interval() {
        let mut s: u64 = 55;
        for (a, b) in [(1.0, 1.0), (2.0, 5.0), (0.5, 0.5), (10.0, 10.0)] {
            for _ in 0..200 {
                let v = beta_sample(a, b, &mut s);
                assert!(
                    (0.0..=1.0).contains(&v) && v.is_finite(),
                    "beta({}, {}) should be in [0, 1], got {}",
                    a,
                    b,
                    v
                );
            }
        }
    }

    #[test]
    fn beta_sample_mean_approximately_correct() {
        let mut s: u64 = 200;
        let alpha = 3.0;
        let beta = 7.0;
        let n = 5000;
        let sum: f64 = (0..n).map(|_| beta_sample(alpha, beta, &mut s)).sum();
        let mean = sum / n as f64;
        let expected = alpha / (alpha + beta);
        assert!(
            (mean - expected).abs() < 0.05,
            "beta({}, {}) mean should be ~{}, got {}",
            alpha,
            beta,
            expected,
            mean
        );
    }
}