#![allow(clippy::cast_possible_truncation)]
use std::collections::HashMap;
use super::*;
#[test]
fn bandit_basic_flow() {
let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
for i in 0..20u64 {
search.add_arm(i, (i / 5) as u32);
}
let mut pulled = Vec::new();
for _ in 0..20 {
if let Some(arm) = search.next_arm() {
pulled.push(arm);
let reward = if arm < 5 { 1.0 } else { 0.0 };
search.observe(arm, reward);
} else {
break;
}
}
assert_eq!(pulled.len(), 20);
}
#[test]
fn bandit_next_arms_collects_multiple() {
let mut search = BanditSearch::new_seeded(BanditConfig::default(), 7);
for i in 0..5u64 {
search.add_arm(i, 0);
}
let mut pulled = search.next_arms(3);
assert_eq!(pulled.len(), 3);
while let Some(arm) = pulled.pop() {
search.observe(arm, 0.5);
}
let pulled = search.next_arms(10);
assert_eq!(pulled.len(), 2);
for arm in pulled {
search.observe(arm, 0.5);
}
assert_eq!(search.total_pulls(), 5);
}
#[test]
fn bandit_observe_with_signals_scalarizes() {
let mut signal_weights = HashMap::new();
signal_weights.insert("quality".to_string(), 1.0);
signal_weights.insert("latency".to_string(), 2.0);
let config = BanditConfig::builder()
.scalarizer(Scalarizer {
signal_weights,
default_weight: 0.0,
})
.build();
let mut search = BanditSearch::new_seeded(config, 11);
search.add_arm(0, 0);
let arm = search.next_arm().unwrap();
search.observe_with_signals(arm, &[("quality", 0.5), ("latency", 1.0)]);
let stats = search.group_stats();
assert_eq!(stats.len(), 1);
assert!((stats[0].average_reward - 2.5).abs() < f64::EPSILON);
}
#[test]
fn bandit_reweight_signals_updates_online() {
let mut signal_weights = HashMap::new();
signal_weights.insert("quality".to_string(), 1.0);
signal_weights.insert("latency".to_string(), 0.0);
let config = BanditConfig::builder()
.scalarizer(Scalarizer {
signal_weights,
default_weight: 0.0,
})
.build();
let mut search = BanditSearch::new_seeded(config, 11);
search.add_arm(0, 0);
search.add_arm(1, 0);
let first = search.next_arm().unwrap();
search.observe_with_signals(first, &[("quality", 1.0)]);
let second = search.next_arm().unwrap();
search.reweight_signals(&[("quality", 0.0), ("latency", 2.0)]);
search.observe_with_signals(second, &[("latency", 1.0)]);
let stats = search.group_stats();
assert_eq!(search.total_pulls(), 2);
assert_eq!(stats[0].visits, 2);
assert!((stats[0].average_reward - 1.5).abs() < f64::EPSILON);
}
#[test]
fn bandit_respects_budget() {
let config = BanditConfig::builder().max_pulls(5).build();
let mut search = BanditSearch::new_seeded(config, 42);
for i in 0..50u64 {
search.add_arm(i, 0);
}
let mut count = 0;
while let Some(arm) = search.next_arm() {
count += 1;
search.observe(arm, 1.0);
assert!(count <= 100, "budget not respected");
}
assert_eq!(count, 5);
}
#[test]
fn bandit_empty_returns_none() {
let mut search = BanditSearch::new(BanditConfig::default());
assert!(search.next_arm().is_none());
}
#[test]
fn bandit_rave_propagates() {
let mut search = BanditSearch::new_seeded(BanditConfig::default(), 99);
for i in 0..10u64 {
search.add_arm(i, (i / 5) as u32);
}
let arm1 = search.next_arm().unwrap();
search.observe(arm1, 1.0);
let arm2 = search.next_arm().unwrap();
search.observe(arm2, 1.0);
let stats = search.group_stats();
let total_rave: u32 = stats.iter().map(|s| s.rave_visits).sum();
assert_eq!(total_rave, 2);
}
#[test]
fn bandit_group_bias() {
let mut search = BanditSearch::new_seeded(BanditConfig::default(), 7);
for i in 0..10u64 {
search.add_arm(i, (i / 5) as u32);
}
search.set_group_bias(1, 100.0);
let arm = search.next_arm().unwrap();
assert!(arm >= 5, "biased group should be selected first");
}
#[test]
fn bandit_total_pulls() {
let mut search = BanditSearch::new(BanditConfig::default());
for i in 0..5u64 {
search.add_arm(i, 0);
}
for _ in 0..3 {
if let Some(arm) = search.next_arm() {
search.observe(arm, 0.5);
}
}
assert_eq!(search.total_pulls(), 3);
}
#[test]
fn bandit_group_stats_correct() {
let mut search = BanditSearch::new_seeded(BanditConfig::default(), 1);
for i in 0..6u64 {
search.add_arm(i, (i / 3) as u32);
}
for _ in 0..3 {
if let Some(arm) = search.next_arm() {
search.observe(arm, 0.5);
}
}
let stats = search.group_stats();
assert_eq!(stats.len(), 2);
let total_visits: u32 = stats.iter().map(|s| s.visits).sum();
assert_eq!(total_visits, 3);
}
#[test]
fn bandit_seeded_determinism() {
let config = BanditConfig::default();
let mut s1 = BanditSearch::new_seeded(config.clone(), 42);
let mut s2 = BanditSearch::new_seeded(config, 42);
for i in 0..10u64 {
s1.add_arm(i, (i / 5) as u32);
s2.add_arm(i, (i / 5) as u32);
}
for _ in 0..10 {
let a1 = s1.next_arm();
let a2 = s2.next_arm();
assert_eq!(a1, a2, "same seed should produce same sequence");
if let (Some(arm1), Some(arm2)) = (a1, a2) {
s1.observe(arm1, 0.5);
s2.observe(arm2, 0.5);
}
}
}
#[test]
fn bandit_many_groups() {
let mut search = BanditSearch::new(BanditConfig::default());
for group in 0..100u32 {
for arm in 0..10u32 {
search.add_arm(u64::from(group) * 10 + u64::from(arm), group);
}
}
let mut count = 0;
for _ in 0..200 {
if let Some(arm) = search.next_arm() {
search.observe(arm, 0.1);
count += 1;
} else {
break;
}
}
assert!(count >= 100);
}
#[test]
fn bandit_checkpoint_roundtrip() {
let mut search = BanditSearch::new_seeded(BanditConfig::default(), 11);
for i in 0..7u64 {
search.add_arm(i, 0);
}
for _ in 0..3 {
if let Some(arm) = search.next_arm() {
search.observe(arm, 0.75);
}
}
let checkpoint = search.checkpoint();
let restored = BanditSearch::restore(checkpoint);
assert_eq!(restored.total_pulls(), 3);
}
#[test]
fn bandit_prevents_duplicate_arms() {
let mut search = BanditSearch::new(BanditConfig::default());
search.add_arm(1, 0);
search.add_arm(1, 0);
let stats = search.group_stats();
assert_eq!(stats[0].total_arms, 1);
}
#[test]
fn bandit_zero_max_pulls_disables_limit() {
let config = BanditConfig::builder().max_pulls(0).build();
let mut search = BanditSearch::new_seeded(config, 7);
for i in 0..15u64 {
search.add_arm(i, 0);
}
for _ in 0..15 {
if search.next_arm().is_none() {
break;
}
}
assert_eq!(search.total_pulls(), 15);
}
#[test]
fn test_next_arms_returns_correct_count() {
let mut search = BanditSearch::new(BanditConfig::default());
for i in 0..10u64 {
search.add_arm(i, 0);
}
let arms = search.next_arms(5);
assert_eq!(arms.len(), 5);
}
#[test]
fn test_observe_multi_uses_max() {
let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
search.add_arm(0, 0);
search.add_arm(1, 0);
let _ = search.next_arm();
search.observe_multi(0, &[("sink", 1.0), ("crash", 0.5), ("coverage", 0.3)]);
let stats = search.group_stat(0);
assert!(stats.is_some());
}