use std::collections::HashMap;
/// One arm of a multi-armed bandit together with its running statistics.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct BanditArm {
    /// Caller-supplied identifier of the arm.
    pub id: String,
    /// How many rewards have been recorded for this arm.
    pub pulls: u64,
    /// Sum of every reward recorded for this arm.
    pub rewards: f64,
    /// Starts at `0`; no method of this type changes it afterwards.
    pub last_updated: u64,
}

impl BanditArm {
    /// Creates an arm named `id` with all statistics zeroed.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        let id = id.into();
        Self {
            id,
            pulls: 0,
            rewards: 0.0,
            last_updated: 0,
        }
    }

    /// Average reward per pull, or `0.0` for an arm that has never been pulled.
    #[must_use]
    pub fn mean_reward(&self) -> f64 {
        match self.pulls {
            // Avoid dividing by zero for an untouched arm.
            0 => 0.0,
            count => self.rewards / count as f64,
        }
    }
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct EpsilonGreedy {
pub epsilon: f32,
pub arms: Vec<BanditArm>,
}
impl EpsilonGreedy {
#[must_use]
pub fn new(epsilon: f32, arms: Vec<BanditArm>) -> Self {
Self {
epsilon: epsilon.clamp(0.0, 1.0),
arms,
}
}
#[must_use]
pub fn select(&self, seed: u64) -> usize {
if self.arms.is_empty() {
return 0;
}
let rand_val = lcg_f64(seed);
if rand_val < f64::from(self.epsilon) {
let rand_idx = lcg_u64(seed.wrapping_add(1)) % self.arms.len() as u64;
rand_idx as usize
} else {
self.best_arm()
}
}
pub fn update(&mut self, arm_idx: usize, reward: f64) {
if let Some(arm) = self.arms.get_mut(arm_idx) {
arm.pulls += 1;
arm.rewards += reward;
}
}
#[must_use]
pub fn best_arm(&self) -> usize {
self.arms
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| {
a.mean_reward()
.partial_cmp(&b.mean_reward())
.unwrap_or(std::cmp::Ordering::Equal)
})
.map_or(0, |(idx, _)| idx)
}
}
/// UCB1 bandit: each arm is scored by its mean reward plus an exploration bonus of
/// `sqrt(2 * ln(total_pulls) / arm.pulls)`, and the highest score wins.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct Ucb1Bandit {
    /// The arms being chosen between, addressed by their index in this vector.
    pub arms: Vec<BanditArm>,
    /// Number of pulls across all arms; feeds the exploration bonus.
    pub total_pulls: u64,
}

impl Ucb1Bandit {
    /// Builds a bandit over `arms`.
    ///
    /// `total_pulls` starts as the (saturating) sum of the pulls the arms already carry, so
    /// arms with existing history are scored consistently. For fresh arms this is `0`.
    #[must_use]
    pub fn new(arms: Vec<BanditArm>) -> Self {
        let total_pulls = arms
            .iter()
            .fold(0u64, |sum, arm| sum.saturating_add(arm.pulls));
        Self { arms, total_pulls }
    }

    /// UCB1 score of `arm`; an arm that has never been pulled scores `+inf` so it is tried first.
    #[must_use]
    fn ucb1_score(&self, arm: &BanditArm) -> f64 {
        if arm.pulls == 0 {
            return f64::INFINITY;
        }
        // Both fields are public, so `total_pulls` can be 0 while the arm has pulls. `ln(0)` is
        // `-inf` and its square root NaN, which would make every comparison in `select`
        // meaningless; clamping to 1 turns that case into a zero bonus instead.
        let total = self.total_pulls.max(1) as f64;
        let exploration = ((2.0 * total.ln()) / arm.pulls as f64).sqrt();
        arm.mean_reward() + exploration
    }

    /// Index of the arm with the highest UCB1 score.
    ///
    /// Ties resolve to the highest index (the behaviour of `Iterator::max_by`); an empty arm
    /// list yields `0`.
    #[must_use]
    pub fn select(&self) -> usize {
        self.arms
            .iter()
            .map(|arm| self.ucb1_score(arm))
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx)
    }

    /// Records `reward` for the arm at `arm_idx` and bumps the global pull counter.
    ///
    /// An out-of-range index is ignored and leaves `total_pulls` untouched.
    pub fn update(&mut self, arm_idx: usize, reward: f64) {
        if let Some(arm) = self.arms.get_mut(arm_idx) {
            arm.pulls += 1;
            arm.rewards += reward;
            self.total_pulls += 1;
        }
    }
}
/// Thompson sampling with one Beta(`alpha[i]`, `beta[i]`) posterior per arm.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct ThompsonSampling {
    /// Per-arm `alpha` parameter; incremented by [`ThompsonSampling::update_success`].
    pub alpha: Vec<f64>,
    /// Per-arm `beta` parameter; incremented by [`ThompsonSampling::update_failure`].
    pub beta: Vec<f64>,
}

impl ThompsonSampling {
    /// Creates `n` arms, each starting at `alpha = beta = 1.0`.
    #[must_use]
    pub fn new(n: usize) -> Self {
        let prior = || vec![1.0_f64; n];
        Self {
            alpha: prior(),
            beta: prior(),
        }
    }

    /// Draws from Beta(`alpha`, `beta`) as `Ga / (Ga + Gb)` using two gamma variates.
    ///
    /// Both variates come from one generator state initialised with `seed`, alpha first. If
    /// both variates are zero the midpoint `0.5` is returned.
    #[must_use]
    pub fn sample_beta(alpha: f64, beta: f64, seed: u64) -> f64 {
        let mut rng_state = seed;
        let from_alpha = sample_gamma(alpha, &mut rng_state);
        let from_beta = sample_gamma(beta, &mut rng_state);
        let total = from_alpha + from_beta;
        if total == 0.0 {
            0.5
        } else {
            from_alpha / total
        }
    }

    /// Samples every arm's posterior and returns the index of the largest draw.
    ///
    /// Arm `i` is sampled with `seed + i * 6364136223846793005` (wrapping). Ties resolve to the
    /// highest index, and `0` is returned when there are no arms.
    #[must_use]
    pub fn select(&self, seed: u64) -> usize {
        const ARM_SEED_STRIDE: u64 = 6364136223846793005;

        let mut winner = 0_usize;
        let mut best_draw: Option<f64> = None;
        for (idx, (&a, &b)) in self.alpha.iter().zip(self.beta.iter()).enumerate() {
            let arm_seed = seed.wrapping_add((idx as u64).wrapping_mul(ARM_SEED_STRIDE));
            let draw = Self::sample_beta(a, b, arm_seed);
            // The incumbent survives only when it is strictly greater; equal or incomparable
            // draws hand the win to the later arm.
            let incumbent_holds = match best_draw {
                Some(current) => current.partial_cmp(&draw) == Some(std::cmp::Ordering::Greater),
                None => false,
            };
            if !incumbent_holds {
                winner = idx;
                best_draw = Some(draw);
            }
        }
        winner
    }

    /// Records a success for `arm` by adding `1.0` to its `alpha`; unknown arms are ignored.
    pub fn update_success(&mut self, arm: usize) {
        if let Some(successes) = self.alpha.get_mut(arm) {
            *successes += 1.0;
        }
    }

    /// Records a failure for `arm` by adding `1.0` to its `beta`; unknown arms are ignored.
    pub fn update_failure(&mut self, arm: usize) {
        if let Some(failures) = self.beta.get_mut(arm) {
            *failures += 1.0;
        }
    }
}
/// Epsilon-greedy bandit addressed by `u64` content ids instead of arm indices.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct ContentBandit {
    /// Underlying strategy; arm `i` corresponds to `arm_to_content[i]`.
    inner: EpsilonGreedy,
    /// Content id to arm index lookup.
    content_to_arm: HashMap<u64, usize>,
    /// Arm index to content id lookup, in the order the ids were supplied.
    arm_to_content: Vec<u64>,
}

impl ContentBandit {
    /// Creates one arm per entry of `content_ids`, explored with probability `epsilon`.
    ///
    /// Each arm's identifier is the decimal form of its content id.
    #[must_use]
    pub fn new(content_ids: Vec<u64>, epsilon: f32) -> Self {
        let mut arms = Vec::with_capacity(content_ids.len());
        let mut content_to_arm = HashMap::with_capacity(content_ids.len());
        for (slot, &content_id) in content_ids.iter().enumerate() {
            arms.push(BanditArm::new(content_id.to_string()));
            // A repeated id overwrites the earlier entry, so it resolves to its last slot.
            content_to_arm.insert(content_id, slot);
        }
        Self {
            inner: EpsilonGreedy::new(epsilon, arms),
            content_to_arm,
            arm_to_content: content_ids,
        }
    }

    /// Content id chosen by the underlying strategy for `seed`, or `None` without any content.
    #[must_use]
    pub fn select_content(&self, seed: u64) -> Option<u64> {
        let chosen = self.inner.select(seed);
        self.arm_to_content.get(chosen).copied()
    }

    /// Records `reward` for `content_id`; ids that were never registered are ignored.
    pub fn update(&mut self, content_id: u64, reward: f64) {
        let slot = self.content_to_arm.get(&content_id).copied();
        if let Some(arm_idx) = slot {
            self.inner.update(arm_idx, reward);
        }
    }

    /// Content id of the arm with the highest mean reward, or `None` without any content.
    #[must_use]
    pub fn best_content(&self) -> Option<u64> {
        self.arm_to_content.get(self.inner.best_arm()).copied()
    }

    /// Number of arms managed by this bandit.
    #[must_use]
    pub fn arm_count(&self) -> usize {
        self.inner.arms.len()
    }
}
/// Advances a 64-bit linear congruential generator by one step:
/// `state * MULTIPLIER + INCREMENT`, wrapping modulo 2^64.
#[inline]
fn lcg_next(state: u64) -> u64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let scaled = state.wrapping_mul(MULTIPLIER);
    scaled.wrapping_add(INCREMENT)
}
/// Maps `seed` to a float in `[0, 1)` by taking the top 53 bits of one [`lcg_next`] step
/// and dividing by 2^53.
#[inline]
fn lcg_f64(seed: u64) -> f64 {
    const TWO_POW_53: f64 = (1u64 << 53) as f64;

    let top_bits = lcg_next(seed) >> 11;
    top_bits as f64 / TWO_POW_53
}
/// Returns the raw 64-bit result of a single [`lcg_next`] step applied to `seed`.
///
/// Unlike `lcg_f64`, no bits are discarded, so the value includes the generator's low bits.
#[inline]
fn lcg_u64(seed: u64) -> u64 {
lcg_next(seed)
}
fn sample_gamma(shape: f64, state: &mut u64) -> f64 {
if shape < 1.0 {
let boost = sample_gamma(1.0 + shape, state);
let u = {
*state = lcg_next(*state);
(*state >> 11) as f64 / (1u64 << 53) as f64
};
return boost * u.powf(1.0 / shape);
}
let d = shape - 1.0 / 3.0;
let c = 1.0 / (9.0 * d).sqrt();
loop {
*state = lcg_next(*state);
let u1 = (*state >> 11) as f64 / (1u64 << 53) as f64;
*state = lcg_next(*state);
let u2 = (*state >> 11) as f64 / (1u64 << 53) as f64;
let x = (-2.0 * (u1 + 1e-10).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
let v = 1.0 + c * x;
if v <= 0.0 {
continue;
}
let v3 = v * v * v;
*state = lcg_next(*state);
let u = (*state >> 11) as f64 / (1u64 << 53) as f64;
if u < 1.0 - 0.0331 * (x * x) * (x * x) {
return d * v3;
}
if u.ln() < 0.5 * x * x + d * (1.0 - v3 + v3.ln()) {
return d * v3;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing accumulated floating-point values.
    const TOLERANCE: f64 = 1e-9;

    /// Produces `n` untouched arms named `arm-0`, `arm-1`, and so on.
    fn make_arms(n: usize) -> Vec<BanditArm> {
        let mut arms = Vec::with_capacity(n);
        for i in 0..n {
            arms.push(BanditArm::new(format!("arm-{i}")));
        }
        arms
    }

    /// True when `actual` lies strictly within [`TOLERANCE`] of `expected`.
    fn close_to(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// An arm without pulls reports a mean of exactly zero.
    #[test]
    fn test_bandit_arm_mean_reward_no_pulls() {
        let untouched = BanditArm::new("test");
        assert_eq!(untouched.mean_reward(), 0.0);
    }

    /// The mean is the reward sum divided by the pull count.
    #[test]
    fn test_bandit_arm_mean_reward_after_update() {
        let arm = BanditArm {
            pulls: 2,
            rewards: 1.6,
            ..BanditArm::new("test")
        };
        assert!(close_to(arm.mean_reward(), 0.8));
    }

    /// The only rewarded arm is reported as the best one.
    #[test]
    fn test_epsilon_greedy_best_arm() {
        let mut strategy = EpsilonGreedy::new(0.0, make_arms(3));
        for _ in 0..2 {
            strategy.update(1, 0.9);
        }
        assert_eq!(strategy.best_arm(), 1);
    }

    /// With epsilon zero, selection always exploits the best arm.
    #[test]
    fn test_epsilon_greedy_greedy_select() {
        let mut strategy = EpsilonGreedy::new(0.0, make_arms(3));
        for _ in 0..2 {
            strategy.update(2, 1.0);
        }
        assert_eq!(strategy.select(42), 2);
    }

    /// An update bumps the pull count and accumulates the reward.
    #[test]
    fn test_epsilon_greedy_update() {
        let mut strategy = EpsilonGreedy::new(0.1, make_arms(3));
        strategy.update(0, 0.5);
        let first = &strategy.arms[0];
        assert_eq!(first.pulls, 1);
        assert!(close_to(first.rewards, 0.5));
    }

    /// With epsilon one, the explored index stays within bounds.
    #[test]
    fn test_epsilon_greedy_explore() {
        let strategy = EpsilonGreedy::new(1.0, make_arms(5));
        let chosen = strategy.select(12345);
        assert!(chosen < 5);
    }

    /// Arms that were never pulled take priority over pulled ones.
    #[test]
    fn test_ucb1_select_unpulled_first() {
        let mut ucb = Ucb1Bandit::new(make_arms(3));
        for _ in 0..2 {
            ucb.update(0, 0.5);
        }
        assert!(matches!(ucb.select(), 1 | 2));
    }

    /// Updates are counted both per arm and globally.
    #[test]
    fn test_ucb1_update_counts() {
        let mut ucb = Ucb1Bandit::new(make_arms(2));
        ucb.update(0, 1.0);
        ucb.update(1, 0.0);
        assert_eq!(ucb.total_pulls, 2);
        let per_arm: Vec<u64> = ucb.arms.iter().map(|arm| arm.pulls).collect();
        assert_eq!(per_arm, vec![1, 1]);
    }

    /// With equal pull counts, the arm with the higher mean wins.
    #[test]
    fn test_ucb1_selects_higher_reward() {
        let mut ucb = Ucb1Bandit::new(make_arms(2));
        for _ in 0..50 {
            ucb.update(0, 0.9);
            ucb.update(1, 0.1);
        }
        assert_eq!(ucb.select(), 0);
    }

    /// Selection returns a valid arm index.
    #[test]
    fn test_thompson_sampling_select_range() {
        let sampler = ThompsonSampling::new(5);
        assert!(sampler.select(9999) < 5);
    }

    /// A success raises alpha and leaves beta alone.
    #[test]
    fn test_thompson_sampling_update_success() {
        let mut sampler = ThompsonSampling::new(3);
        sampler.update_success(0);
        assert!(close_to(sampler.alpha[0], 2.0));
        assert!(close_to(sampler.beta[0], 1.0));
    }

    /// A failure raises beta.
    #[test]
    fn test_thompson_sampling_update_failure() {
        let mut sampler = ThompsonSampling::new(3);
        sampler.update_failure(2);
        assert!(close_to(sampler.beta[2], 2.0));
    }

    /// Beta draws always fall inside the closed unit interval.
    #[test]
    fn test_sample_beta_range() {
        for seed in (0..20u64).map(|step| step * 1000) {
            let v = ThompsonSampling::sample_beta(2.0, 5.0, seed);
            assert!((0.0..=1.0).contains(&v), "value {v} out of range");
        }
    }

    /// A non-empty content bandit always yields some content id.
    #[test]
    fn test_content_bandit_select() {
        let bandit = ContentBandit::new(vec![10, 20, 30], 0.0);
        assert!(bandit.select_content(42).is_some());
    }

    /// Rewards recorded by content id determine the best content.
    #[test]
    fn test_content_bandit_update_and_best() {
        let mut bandit = ContentBandit::new(vec![10, 20, 30], 0.0);
        for _ in 0..2 {
            bandit.update(20, 1.0);
        }
        assert_eq!(bandit.best_content(), Some(20));
    }

    /// One arm is created per content id.
    #[test]
    fn test_content_bandit_arm_count() {
        let bandit = ContentBandit::new(vec![1, 2, 3, 4], 0.1);
        assert_eq!(bandit.arm_count(), 4);
    }
}