pub mod config;
mod node;
#[cfg(test)]
mod tests;
#[allow(unused_imports)] pub use config::{BanditConfig, BanditConfigBuilder, Scalarizer};
use node::BanditNode;
pub use node::GroupStats;
use std::collections::HashMap;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
/// Bandit search state over arms that are grouped under a single root node.
///
/// Arms are registered with `add_arm`, handed out with `next_arm` /
/// `next_arms`, and scored through `observe` and its variants.
pub struct BanditSearch {
/// Search configuration (pull budget, RAVE bias, scalarizer, ...).
config: BanditConfig,
/// Flat node storage. Index 0 is the root; `add_arm` appends one node per
/// group and links it from the root's `children`.
nodes: Vec<BanditNode>,
/// Maps a group id to the index of its node in `nodes`.
group_to_node: HashMap<u32, u32>,
/// Maps an arm id to the index (in `nodes`) of the group node that owns it.
arm_to_node: HashMap<u64, u32>,
/// Number of arms handed out so far by `next_arm`; compared against
/// `config.max_pulls`.
pulls_executed: u64,
/// RNG used by `next_arm` to shuffle candidate groups before selection.
rng: ChaCha8Rng,
}
/// Serializable snapshot of a `BanditSearch`.
///
/// Carries every field of the search except the RNG; `BanditSearch::restore`
/// and `BanditSearch::restore_with_seed` supply a fresh RNG when rebuilding.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BanditSearchCheckpoint {
/// Configuration the search was running with.
pub(crate) config: BanditConfig,
/// Copy of the node table (index 0 is the root).
nodes: Vec<BanditNode>,
/// Group id -> node index.
group_to_node: HashMap<u32, u32>,
/// Arm id -> index of the owning group node.
arm_to_node: HashMap<u64, u32>,
/// Number of arms handed out at the time of the snapshot.
pulls_executed: u64,
}
impl BanditSearch {
/// Creates an empty search whose RNG is seeded from the system entropy source.
///
/// # Panics
///
/// Panics if the system RNG cannot provide a seed.
pub fn new(config: BanditConfig) -> Self {
    let rng = entropy_rng();
    Self::with_seed(config, rng)
}
/// Creates an empty search whose RNG is derived deterministically from `seed`.
pub fn new_seeded(config: BanditConfig, seed: u64) -> Self {
    let rng = ChaCha8Rng::seed_from_u64(seed);
    Self::with_seed(config, rng)
}
/// Builds the initial state: a node table holding only the root, empty
/// lookup maps, and a zero pull counter.
///
/// The root uses `u32::MAX` as its group id and starts with no children.
fn with_seed(config: BanditConfig, rng: ChaCha8Rng) -> Self {
    let nodes = vec![BanditNode {
        visits: 0,
        reward: 0.0,
        rave_visits: 0,
        rave_reward: 0.0,
        group_id: u32::MAX,
        bias: 0.0,
        children: Vec::new(),
        arms: Vec::new(),
        next_untried: 0,
    }];
    let group_to_node = HashMap::new();
    let arm_to_node = HashMap::new();
    Self {
        config,
        nodes,
        group_to_node,
        arm_to_node,
        pulls_executed: 0,
        rng,
    }
}
pub fn checkpoint(&self) -> BanditSearchCheckpoint {
BanditSearchCheckpoint {
config: self.config.clone(),
nodes: self.nodes.clone(),
group_to_node: self.group_to_node.clone(),
arm_to_node: self.arm_to_node.clone(),
pulls_executed: self.pulls_executed,
}
}
pub fn restore(checkpoint: BanditSearchCheckpoint) -> Self {
Self {
config: checkpoint.config,
nodes: checkpoint.nodes,
group_to_node: checkpoint.group_to_node,
arm_to_node: checkpoint.arm_to_node,
pulls_executed: checkpoint.pulls_executed,
rng: entropy_rng(),
}
}
pub fn restore_with_seed(checkpoint: BanditSearchCheckpoint, seed: u64) -> Self {
Self {
config: checkpoint.config,
nodes: checkpoint.nodes,
group_to_node: checkpoint.group_to_node,
arm_to_node: checkpoint.arm_to_node,
pulls_executed: checkpoint.pulls_executed,
rng: ChaCha8Rng::seed_from_u64(seed),
}
}
/// Registers `arm_id` under `group_id`, creating the group node on first use.
///
/// The call is a no-op when the arm is already registered, when the node
/// table already holds `u32::MAX` entries, or when the next node index does
/// not fit in a `u32`.
pub fn add_arm(&mut self, arm_id: u64, group_id: u32) {
    let already_registered = self.arm_to_node.contains_key(&arm_id);
    if already_registered || self.nodes.len() == u32::MAX as usize {
        return;
    }
    let slot = match self.group_to_node.get(&group_id).copied() {
        Some(existing) => existing,
        None => {
            // A new group node is appended at the end of the table.
            let Ok(fresh) = u32::try_from(self.nodes.len()) else {
                return;
            };
            self.nodes.push(BanditNode {
                visits: 0,
                reward: 0.0,
                rave_visits: 0,
                rave_reward: 0.0,
                group_id,
                bias: 0.0,
                children: Vec::new(),
                arms: Vec::new(),
                next_untried: 0,
            });
            // Every group hangs directly off the root.
            self.nodes[0].children.push(fresh);
            self.group_to_node.insert(group_id, fresh);
            fresh
        }
    };
    self.nodes[slot as usize].arms.push(arm_id);
    self.arm_to_node.insert(arm_id, slot);
}
/// Hands out the next untried arm, or `None` when nothing can be pulled.
///
/// Returns `None` when `config.max_pulls` is non-zero and already reached, or
/// when no group has an untried arm left. Otherwise the group with the
/// highest score is chosen; equal (or incomparable) scores fall back to
/// comparing `bias`. Candidates are shuffled first, so any remaining tie is
/// settled by the shuffled order. A successful call advances the group's
/// `next_untried` cursor and the pull counter.
pub fn next_arm(&mut self) -> Option<u64> {
    let budget = self.config.max_pulls;
    if budget > 0 && self.pulls_executed >= budget {
        return None;
    }
    // Only groups that still have arms to hand out are candidates.
    let mut candidates: Vec<u32> = match self.nodes.first() {
        Some(root) => root
            .children
            .iter()
            .copied()
            .filter(|&child| self.nodes[child as usize].has_untried())
            .collect(),
        None => Vec::new(),
    };
    candidates.shuffle(&mut self.rng);
    let chosen = candidates.iter().copied().max_by(|&lhs, &rhs| {
        let left = &self.nodes[lhs as usize];
        let right = &self.nodes[rhs as usize];
        let root_visits = self.nodes[0].visits;
        let left_score = left.score(root_visits, &self.config);
        let right_score = right.score(root_visits, &self.config);
        left_score
            .partial_cmp(&right_score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| {
                left.bias
                    .partial_cmp(&right.bias)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
    })?;
    let group = &mut self.nodes[chosen as usize];
    let arm_id = *group.arms.get(group.next_untried)?;
    group.next_untried += 1;
    self.pulls_executed += 1;
    Some(arm_id)
}
/// Hands out up to `n` arms by calling `next_arm` repeatedly.
///
/// The result is shorter than `n` when the untried arms or the pull budget
/// run out first. The output vector is pre-sized to the smallest of `n`, the
/// number of untried arms, and the remaining budget.
pub fn next_arms(&mut self, n: usize) -> Vec<u64> {
    // Saturating arithmetic: a node whose cursor is past the end of its arm
    // list (possible after restoring a malformed checkpoint) counts as
    // having zero untried arms instead of underflowing.
    let available_arms = self.nodes.first().map_or(0, |root| {
        root.children
            .iter()
            .map(|&idx| {
                let node = &self.nodes[idx as usize];
                node.arms.len().saturating_sub(node.next_untried)
            })
            .fold(0usize, usize::saturating_add)
    });
    // `max_pulls == 0` means there is no budget limit.
    let remaining_budget = if self.config.max_pulls > 0 {
        usize::try_from(self.config.max_pulls.saturating_sub(self.pulls_executed))
            .unwrap_or(usize::MAX)
    } else {
        usize::MAX
    };
    let cap = n.min(available_arms).min(remaining_budget);
    let mut arms = Vec::with_capacity(cap);
    // Stops at the first `None` from `next_arm`, or after `n` arms.
    arms.extend(std::iter::from_fn(|| self.next_arm()).take(n));
    arms
}
/// Records `reward` for `arm_id`.
///
/// Non-finite rewards and unknown arm ids are ignored. Otherwise the visit
/// count and reward total of both the arm's group node and the root are
/// updated. When `config.rave_bias` is positive, every other group under the
/// root also receives the reward in its RAVE counters.
pub fn observe(&mut self, arm_id: u64, reward: f64) {
    if !reward.is_finite() {
        return;
    }
    let Some(&group_idx) = self.arm_to_node.get(&arm_id) else {
        return;
    };

    let group = &mut self.nodes[group_idx as usize];
    group.visits += 1;
    group.reward += reward;

    let root = &mut self.nodes[0];
    root.visits += 1;
    root.reward += reward;

    if self.config.rave_bias > 0.0 {
        // Walk the root's children by position so the list need not be copied.
        let sibling_count = self.nodes[0].children.len();
        for pos in 0..sibling_count {
            let sibling_idx = self.nodes[0].children[pos];
            if sibling_idx != group_idx {
                let sibling = &mut self.nodes[sibling_idx as usize];
                sibling.rave_visits += 1;
                sibling.rave_reward += reward;
            }
        }
    }
}
/// Collapses `signals` into one reward with the configured scalarizer and
/// records it for `arm_id` via `observe`.
pub fn observe_with_signals(&mut self, arm_id: u64, signals: &[(&str, f64)]) {
    let scalar = self.config.scalarizer.scalarize(signals);
    self.observe(arm_id, scalar)
}
/// Inserts or overwrites scalarizer signal weights.
///
/// Each `(name, weight)` pair with a finite weight is written into
/// `config.scalarizer.signal_weights`; pairs with a non-finite weight are
/// skipped.
pub fn reweight_signals(&mut self, updates: &[(&str, f64)]) {
    let weights = &mut self.config.scalarizer.signal_weights;
    for &(name, weight) in updates {
        if !weight.is_finite() {
            continue;
        }
        weights.insert(name.to_owned(), weight);
    }
}
/// Sets the bias of the group identified by `group_id`.
///
/// Unknown group ids are ignored. A NaN `bias` is also ignored: `next_arm`
/// compares biases with `partial_cmp`, and a NaN makes every such comparison
/// collapse to `Equal`. Infinite biases are still accepted.
pub fn set_group_bias(&mut self, group_id: u32, bias: f64) {
    if bias.is_nan() {
        return;
    }
    let Some(&node_idx) = self.group_to_node.get(&group_id) else {
        return;
    };
    self.nodes[node_idx as usize].bias = bias;
}
/// Returns statistics for every group, in the order groups were linked
/// under the root.
///
/// `average_reward` is `0.0` for a group with no visits. An empty node table
/// (no root) yields an empty vector instead of panicking, matching how
/// `next_arm` and `next_arms` treat a missing root.
pub fn group_stats(&self) -> Vec<GroupStats> {
    let Some(root) = self.nodes.first() else {
        return Vec::new();
    };
    root.children
        .iter()
        .map(|&idx| {
            let node = &self.nodes[idx as usize];
            let average_reward = if node.visits > 0 {
                node.reward / f64::from(node.visits)
            } else {
                0.0
            };
            GroupStats {
                group_id: node.group_id,
                visits: node.visits,
                average_reward,
                total_arms: node.arms.len(),
                explored_arms: node.next_untried,
                rave_visits: node.rave_visits,
            }
        })
        .collect()
}
pub fn total_pulls(&self) -> u64 {
self.pulls_executed
}
/// Records the largest signal value as the reward for `arm_id`.
///
/// Signal names are ignored. The running maximum starts at negative
/// infinity, so an empty `signals` slice produces a non-finite reward, which
/// `observe` discards.
pub fn observe_multi(&mut self, arm_id: u64, signals: &[(&str, f64)]) {
    let mut best = f64::NEG_INFINITY;
    for &(_, value) in signals {
        best = best.max(value);
    }
    self.observe(arm_id, best);
}
/// Returns statistics for one group, or `None` if `group_id` is unknown.
///
/// `average_reward` is `0.0` for a group with no visits.
pub fn group_stat(&self, group_id: u32) -> Option<GroupStats> {
    self.group_to_node.get(&group_id).map(|&idx| {
        let node = &self.nodes[idx as usize];
        let average_reward = (node.visits > 0)
            .then(|| node.reward / f64::from(node.visits))
            .unwrap_or(0.0);
        GroupStats {
            group_id: node.group_id,
            visits: node.visits,
            average_reward,
            total_arms: node.arms.len(),
            explored_arms: node.next_untried,
            rave_visits: node.rave_visits,
        }
    })
}
}
/// Builds a `ChaCha8Rng` seeded from the system RNG.
///
/// # Panics
///
/// Panics if the system RNG reports an error while producing the seed.
fn entropy_rng() -> ChaCha8Rng {
    ChaCha8Rng::try_from_rng(&mut rand::rngs::SysRng)
        .unwrap_or_else(|error| panic!("failed to seed ChaCha8Rng from system RNG: {error}"))
}