use super::super::types::KernelType;
/// Running reward statistics for a single bandit arm (one kernel variant).
#[derive(Debug, Clone, Default)]
pub struct KernelArm {
    /// How many rewards have been recorded for this arm.
    pub pulls: u32,
    /// Sum of every recorded reward.
    pub total_reward: f32,
    /// Sum of the squares of every recorded reward.
    pub total_reward_sq: f32,
}

impl KernelArm {
    /// Average recorded reward, or `0.0` when the arm has never been pulled.
    // u32 -> f32 is exact only up to 2^24; larger pull counts lose precision.
    #[allow(clippy::cast_precision_loss)]
    pub fn mean(&self) -> f32 {
        match self.pulls {
            0 => 0.0,
            n => self.total_reward / n as f32,
        }
    }

    /// UCB1 score `mean + c * sqrt(2 * ln(N) / n)`.
    ///
    /// `N` is `total_pulls` clamped to at least 1 (so the logarithm is defined and
    /// non-negative), and `n` is this arm's own pull count. An arm that has never
    /// been pulled scores `f32::INFINITY`, so a maximizing selector prefers it over
    /// any tried arm.
    #[allow(clippy::cast_precision_loss)]
    pub fn ucb(&self, total_pulls: u32, c: f32) -> f32 {
        if self.pulls == 0 {
            return f32::INFINITY;
        }
        let log_total = (total_pulls.max(1) as f32).ln();
        let bonus = (2.0 * log_total / self.pulls as f32).sqrt();
        self.mean() + c * bonus
    }
}
/// Multi-armed bandit that learns which kernel variant yields the highest reward.
///
/// Each kernel index (as produced by `KernelType::to_index`) owns one
/// [`KernelArm`]. Selection uses UCB1 by default, or a perturbed-greedy
/// strategy when built with [`KernelBandit::with_thompson_sampling`].
#[derive(Debug, Clone)]
pub struct KernelBandit {
    /// Per-kernel statistics, indexed by `KernelType::to_index()`.
    pub(crate) arms: Vec<KernelArm>,
    /// Number of rewards recorded across all arms.
    pub(crate) total_pulls: u32,
    /// Multiplier `c` applied to the UCB exploration bonus.
    pub(crate) exploration_c: f32,
    /// When `true`, `select` uses the perturbed-greedy strategy instead of UCB.
    pub(crate) use_thompson: bool,
}

// A derived `Default` would build a bandit with zero arms and `c = 0.0`, which
// silently drops every `update`; delegate to `new` so both constructors agree.
impl Default for KernelBandit {
    fn default() -> Self {
        Self::new()
    }
}

impl KernelBandit {
    /// Number of kernel arms tracked by a freshly constructed bandit.
    pub const NUM_KERNELS: usize = 12;

    /// Width of the pseudo-random perturbation used by the Thompson-style
    /// selector: perturbed samples fall in `[mean - 0.05, mean + 0.05)`.
    const THOMPSON_NOISE_SCALE: f32 = 0.1;

    /// Creates a UCB1 bandit with [`Self::NUM_KERNELS`] untried arms and `c = 2.0`.
    pub fn new() -> Self {
        Self {
            arms: vec![KernelArm::default(); Self::NUM_KERNELS],
            total_pulls: 0,
            exploration_c: 2.0,
            use_thompson: false,
        }
    }

    /// Same as [`Self::new`], but `select` uses the Thompson-style strategy.
    pub fn with_thompson_sampling() -> Self {
        Self { use_thompson: true, ..Self::new() }
    }

    /// Chooses the kernel to try next according to the configured strategy.
    pub fn select(&self) -> KernelType {
        let idx = if self.use_thompson { self.select_thompson() } else { self.select_ucb() };
        KernelType::from_index(idx)
    }

    /// Index of the arm with the highest UCB1 score. Untried arms score
    /// `f32::INFINITY`, so every arm is tried once before scores are compared.
    fn select_ucb(&self) -> usize {
        let (total, c) = (self.total_pulls, self.exploration_c);
        // Score each arm once instead of recomputing it inside the comparator.
        Self::argmax_by_score(self.arms.iter().map(|arm| arm.ucb(total, c)).enumerate())
    }

    /// Index of the arm with the highest perturbed mean.
    ///
    /// This is not true posterior sampling. Each tried arm's mean is shifted by a
    /// deterministic pseudo-random offset that depends on `total_pulls` and on the
    /// arm index, which breaks near-ties between tried arms.
    fn select_thompson(&self) -> usize {
        let round = self.total_pulls;
        Self::argmax_by_score(self.arms.iter().enumerate().map(|(i, arm)| {
            // An untried arm's 0.0 mean is a placeholder, not evidence. Try it
            // before exploiting, as UCB does; otherwise, once any arm has a
            // positive mean, the remaining arms would never be explored.
            let sample = if arm.pulls == 0 {
                f32::INFINITY
            } else {
                arm.mean() + Self::thompson_noise(round, i)
            };
            (i, sample)
        }))
    }

    /// Deterministic pseudo-random offset in `[-0.05, 0.05)` for `arm` in `round`.
    ///
    /// The round and the arm index are hashed together so that offsets are
    /// independent across arms. Computing them as `hash(round) + arm` made
    /// neighbouring arms' offsets differ by only 0.0001, which systematically
    /// favoured higher indices.
    // The bucket is < 1000, so the u64 -> f32 cast is exact.
    #[allow(clippy::cast_precision_loss)]
    fn thompson_noise(round: u32, arm: usize) -> f32 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        round.hash(&mut hasher);
        arm.hash(&mut hasher);
        let bucket = (hasher.finish() % 1000) as f32;
        Self::THOMPSON_NOISE_SCALE * (bucket / 1000.0 - 0.5)
    }

    /// Returns the index paired with the largest score, or 0 if there are no candidates.
    ///
    /// Ties and NaN comparisons count as equal, so `max_by` keeps the last such
    /// candidate.
    fn argmax_by_score(scored: impl Iterator<Item = (usize, f32)>) -> usize {
        scored
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(i, _)| i)
    }

    /// Records `reward` for `kernel`.
    ///
    /// Kernels whose index is outside `arms` are ignored and do not count
    /// toward `total_pulls`.
    pub fn update(&mut self, kernel: KernelType, reward: f32) {
        contract_pre_update!();
        if let Some(arm) = self.arms.get_mut(kernel.to_index()) {
            arm.pulls += 1;
            arm.total_reward += reward;
            arm.total_reward_sq += reward * reward;
            self.total_pulls += 1;
        }
    }

    /// Kernel with the highest observed mean reward among arms pulled at least once.
    ///
    /// Untried arms are skipped because their `0.0` mean is a placeholder. If they
    /// were included, a never-tried kernel could beat kernels that have only seen
    /// negative rewards. Returns index 0 when nothing has been pulled yet.
    pub fn best_kernel(&self) -> KernelType {
        let idx = Self::argmax_by_score(
            self.arms
                .iter()
                .enumerate()
                .filter(|(_, arm)| arm.pulls > 0)
                .map(|(i, arm)| (i, arm.mean())),
        );
        KernelType::from_index(idx)
    }

    /// Fraction of pulls not spent on the most-pulled arm; `1.0` before any pull.
    #[allow(clippy::cast_precision_loss)]
    pub fn exploration_rate(&self) -> f32 {
        if self.total_pulls == 0 {
            return 1.0;
        }
        let best_pulls = self.arms.iter().map(|a| a.pulls).max().unwrap_or(0);
        1.0 - (best_pulls as f32 / self.total_pulls as f32)
    }

    /// Empirical regret: the sum over tried arms of
    /// `(best observed mean - arm mean) * arm pulls`. Returns `0.0` before any pull.
    ///
    /// The reference is the best mean among tried arms. Seeding the maximum with
    /// `0.0`, or letting untried arms contribute their placeholder `0.0`, would
    /// inflate the regret whenever all observed rewards are negative.
    #[allow(clippy::cast_precision_loss)]
    pub fn estimated_regret(&self) -> f32 {
        let best_mean = match self
            .arms
            .iter()
            .filter(|a| a.pulls > 0)
            .map(|a| a.mean())
            .reduce(f32::max)
        {
            Some(m) => m,
            None => return 0.0,
        };
        self.arms
            .iter()
            .filter(|a| a.pulls > 0)
            .map(|a| (best_mean - a.mean()) * a.pulls as f32)
            .sum()
    }
}