pub mod config;
mod node;
#[cfg(test)]
mod tests;
#[allow(unused_imports)] pub use config::{BanditConfig, BanditConfigBuilder};
use node::BanditNode;
pub use node::GroupStats;
use std::collections::HashMap;
use rand::seq::SliceRandom;
use rand::SeedableRng;
/// Two-level bandit search over arms partitioned into groups.
///
/// The tree is stored as a flat arena: `nodes[0]` is the root and every other
/// node represents one group together with the arms registered under it.
/// Arms are handed out by `next_arm`, rewards are fed back through `observe`.
#[derive(Debug)]
pub struct BanditSearch {
    /// Tuning parameters (e.g. `max_pulls`, `rave_bias`).
    config: BanditConfig,
    /// Node arena; index 0 is the root, all other entries are group nodes.
    nodes: Vec<BanditNode>,
    /// Group id -> index of that group's node in `nodes`.
    group_to_node: HashMap<u32, u32>,
    /// Arm id -> index of the owning group's node in `nodes`.
    arm_to_node: HashMap<u64, u32>,
    /// Number of arms returned by `next_arm` so far.
    pulls_executed: u64,
    /// Shuffles candidate groups so that fully tied groups are picked at random.
    rng: rand::rngs::StdRng,
}
/// Serializable snapshot of a [`BanditSearch`], produced by
/// [`BanditSearch::checkpoint`] and turned back into a search by
/// [`BanditSearch::restore`] / [`BanditSearch::restore_with_seed`].
///
/// The search's RNG state is not captured, so a restored search gets a new RNG
/// and does not replay the original random tie-breaking sequence.
///
/// The contents are not validated on restore. A checkpoint deserialized from
/// tampered data whose indices do not point at existing nodes can make later
/// calls on the restored search panic.
///
/// NOTE(review): field names (and, for non-self-describing serde formats, field
/// order) define the serialized form — keep them stable across versions.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BanditSearchCheckpoint {
    /// Search configuration at snapshot time.
    pub(crate) config: BanditConfig,
    /// Node arena; index 0 is the root, the rest are group nodes.
    nodes: Vec<BanditNode>,
    /// Group id -> index into `nodes`.
    group_to_node: HashMap<u32, u32>,
    /// Arm id -> index of the owning group node in `nodes`.
    arm_to_node: HashMap<u64, u32>,
    /// Number of arms handed out before the snapshot.
    pulls_executed: u64,
}
impl BanditSearch {
#[must_use]
pub fn new(config: BanditConfig) -> Self {
Self::with_seed(config, rand::rngs::StdRng::from_entropy())
}
#[must_use]
pub fn new_seeded(config: BanditConfig, seed: u64) -> Self {
Self::with_seed(config, rand::rngs::StdRng::seed_from_u64(seed))
}
fn with_seed(config: BanditConfig, rng: rand::rngs::StdRng) -> Self {
let root = BanditNode {
visits: 0,
reward: 0.0,
rave_visits: 0,
rave_reward: 0.0,
group_id: u32::MAX,
bias: 0.0,
children: Vec::new(),
arms: Vec::new(),
next_untried: 0,
};
Self {
config,
nodes: vec![root],
group_to_node: HashMap::new(),
arm_to_node: HashMap::new(),
pulls_executed: 0,
rng,
}
}
#[must_use]
pub fn checkpoint(&self) -> BanditSearchCheckpoint {
BanditSearchCheckpoint {
config: self.config.clone(),
nodes: self.nodes.clone(),
group_to_node: self.group_to_node.clone(),
arm_to_node: self.arm_to_node.clone(),
pulls_executed: self.pulls_executed,
}
}
#[must_use]
pub fn restore(checkpoint: BanditSearchCheckpoint) -> Self {
Self {
config: checkpoint.config,
nodes: checkpoint.nodes,
group_to_node: checkpoint.group_to_node,
arm_to_node: checkpoint.arm_to_node,
pulls_executed: checkpoint.pulls_executed,
rng: rand::rngs::StdRng::from_entropy(),
}
}
#[must_use]
pub fn restore_with_seed(checkpoint: BanditSearchCheckpoint, seed: u64) -> Self {
Self {
config: checkpoint.config,
nodes: checkpoint.nodes,
group_to_node: checkpoint.group_to_node,
arm_to_node: checkpoint.arm_to_node,
pulls_executed: checkpoint.pulls_executed,
rng: rand::rngs::StdRng::seed_from_u64(seed),
}
}
pub fn add_arm(&mut self, arm_id: u64, group_id: u32) {
if self.arm_to_node.contains_key(&arm_id) {
return;
}
if self.nodes.len() == u32::MAX as usize {
return;
}
let node_idx = if let Some(&idx) = self.group_to_node.get(&group_id) {
idx
} else {
let Ok(idx) = u32::try_from(self.nodes.len()) else {
return;
};
self.nodes.push(BanditNode {
visits: 0,
reward: 0.0,
rave_visits: 0,
rave_reward: 0.0,
group_id,
bias: 0.0,
children: Vec::new(),
arms: Vec::new(),
next_untried: 0,
});
self.nodes[0].children.push(idx);
self.group_to_node.insert(group_id, idx);
idx
};
self.nodes[node_idx as usize].arms.push(arm_id);
self.arm_to_node.insert(arm_id, node_idx);
}
pub fn next_arm(&mut self) -> Option<u64> {
if self.config.max_pulls > 0 && self.pulls_executed >= self.config.max_pulls {
return None;
}
let available_groups: Vec<u32> = self
.nodes
.first()
.map(|root| {
root.children
.iter()
.copied()
.filter(|idx| self.nodes[*idx as usize].has_untried())
.collect()
})
.unwrap_or_default();
let mut available_groups = available_groups;
available_groups.shuffle(&mut self.rng);
let group_idx = available_groups.iter().copied().max_by(|&a, &b| {
let sa = self.nodes[a as usize].score(self.nodes[0].visits, &self.config);
let sb = self.nodes[b as usize].score(self.nodes[0].visits, &self.config);
let ord = sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal);
if ord == std::cmp::Ordering::Equal {
self.nodes[a as usize]
.bias
.partial_cmp(&self.nodes[b as usize].bias)
.unwrap_or(std::cmp::Ordering::Equal)
} else {
ord
}
})?;
let node = &mut self.nodes[group_idx as usize];
if node.next_untried < node.arms.len() {
let arm_id = node.arms[node.next_untried];
node.next_untried += 1;
self.pulls_executed += 1;
Some(arm_id)
} else {
None
}
}
pub fn observe(&mut self, arm_id: u64, reward: f64) {
let Some(node_idx) = self.arm_to_node.get(&arm_id) else {
return;
};
let node_idx = *node_idx;
let group_node = &mut self.nodes[node_idx as usize];
group_node.visits += 1;
group_node.reward += reward;
{
let root = &mut self.nodes[0];
root.visits += 1;
root.reward += reward;
}
if self.config.rave_bias > 0.0 {
let sibling_groups: Vec<u32> = self.nodes[0].children.clone();
for sibling_idx in sibling_groups {
if sibling_idx == node_idx {
continue;
}
let sibling = &mut self.nodes[sibling_idx as usize];
sibling.rave_visits += 1;
sibling.rave_reward += reward;
}
}
}
pub fn set_group_bias(&mut self, group_id: u32, bias: f64) {
if let Some(&node_idx) = self.group_to_node.get(&group_id) {
self.nodes[node_idx as usize].bias = bias;
}
}
#[must_use]
pub fn group_stats(&self) -> Vec<GroupStats> {
self.nodes[0]
.children
.iter()
.map(|&idx| {
let node = &self.nodes[idx as usize];
let avg = if node.visits > 0 {
node.reward / f64::from(node.visits)
} else {
0.0
};
GroupStats {
group_id: node.group_id,
visits: node.visits,
average_reward: avg,
total_arms: node.arms.len(),
explored_arms: node.next_untried,
rave_visits: node.rave_visits,
}
})
.collect()
}
#[must_use]
pub fn total_pulls(&self) -> u64 {
self.pulls_executed
}
}