use super::config::BanditConfig;
use serde::{Deserialize, Serialize};
/// A node of the bandit search tree.
///
/// Holds plain visit/reward statistics, RAVE statistics, and an additive
/// `bias`; [`BanditNode::score`] combines all of these into a selection score.
/// It also tracks which of its `arms` remain to be tried.
///
/// Note: serde's derive serializes fields in declaration order, so reordering
/// fields can break previously persisted data in order-sensitive formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct BanditNode {
    /// Number of times this node has been visited. `score` treats 0 as
    /// "never tried" and returns `f64::INFINITY`.
    pub visits: u32,
    /// Cumulative reward; `score` divides it by `visits` to get the mean.
    pub reward: f64,
    /// Number of RAVE samples. When non-zero, it weights how strongly the
    /// RAVE mean is blended into `score`.
    pub rave_visits: u32,
    /// Cumulative RAVE reward; `score` divides it by `rave_visits`.
    pub rave_reward: f64,
    /// Group this node belongs to. Presumably matches `GroupStats::group_id`;
    /// confirm against the code that builds the tree.
    pub group_id: u32,
    /// Constant added unchanged to the final `score` (e.g. a prior or
    /// heuristic bonus — NOTE(review): confirm intended meaning).
    pub bias: f64,
    /// Child node identifiers. Presumably indices into a node arena owned
    /// elsewhere — TODO confirm.
    pub children: Vec<u32>,
    /// Arm identifiers available from this node. Entries at positions
    /// `>= next_untried` are presumably the ones not yet tried.
    pub arms: Vec<u64>,
    /// Index into `arms` of the next arm to try; `has_untried` reports
    /// `true` while this is below `arms.len()`.
    pub next_untried: usize,
}
impl BanditNode {
    /// Selection score for this node under UCT blended with RAVE.
    ///
    /// * Unvisited nodes (`visits == 0`) score `f64::INFINITY`, so every child is
    ///   tried at least once before any is revisited.
    /// * Otherwise the UCT value is `mean_reward + c * sqrt(ln(parent_visits) / visits)`,
    ///   with `c = config.exploration_constant`.
    /// * If RAVE samples exist, the result is `(1 - beta) * uct + beta * rave_mean`.
    ///   Here `beta = nr / (nr + n + 4 * b^2 * n * nr)`, `n = visits`,
    ///   `nr = rave_visits` and `b = config.rave_bias`.
    /// * `self.bias` is then added.
    ///
    /// `parent_visits` values of 0 are treated as 1, giving a zero exploration term.
    /// Without the clamp, `ln(0) = -inf` would make the score NaN.
    pub(crate) fn score(&self, parent_visits: u32, config: &BanditConfig) -> f64 {
        if self.visits == 0 {
            return f64::INFINITY;
        }
        let n = f64::from(self.visits);
        let exploitation = self.reward / n;

        // Clamp so ln() never sees 0: ln(0) = -inf, and sqrt(-inf) = NaN, which
        // would make this node incomparable during selection.
        let parent_n = f64::from(parent_visits.max(1));
        let exploration = config.exploration_constant * (parent_n.ln() / n).sqrt();
        let uct = exploitation + exploration;

        // beta -> 1 when RAVE data dominates, -> 0 as real visits accumulate.
        let (beta, rave_value) = if self.rave_visits > 0 {
            let nr = f64::from(self.rave_visits);
            let b2 = config.rave_bias * config.rave_bias;
            (nr / (nr + n + 4.0 * b2 * n * nr), self.rave_reward / nr)
        } else {
            (0.0, 0.0)
        };

        // NOTE(review): the exploration bonus is blended by (1 - beta) together
        // with the mean. Some MC-RAVE variants blend only the mean and add
        // exploration afterwards; kept as-is to preserve existing behaviour.
        (1.0 - beta) * uct + beta * rave_value + self.bias
    }

    /// Returns `true` while `next_untried` still indexes into `arms`, i.e. at
    /// least one arm has not been tried yet.
    pub(crate) fn has_untried(&self) -> bool {
        self.next_untried < self.arms.len()
    }
}
/// Aggregate statistics reported for one group.
///
/// Presumably aggregated over the `BanditNode`s sharing a `group_id` —
/// NOTE(review): confirm against the code that produces these values.
/// Field order is part of the serialized form for order-sensitive serde
/// formats; do not reorder casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupStats {
    /// Group identifier (presumably the same id space as `BanditNode::group_id`).
    pub group_id: u32,
    /// Visit count for the group (presumably summed over its nodes).
    pub visits: u32,
    /// Mean reward for the group (presumably total reward / `visits`).
    pub average_reward: f64,
    /// Number of arms belonging to the group.
    pub total_arms: usize,
    /// Number of those arms that have been explored so far (presumably
    /// `<= total_arms`).
    pub explored_arms: usize,
    /// RAVE sample count for the group (presumably summed over its nodes).
    pub rave_visits: u32,
}