#![allow(
clippy::cast_possible_truncation,
clippy::cast_lossless,
clippy::float_cmp
)]
use crate::node::Node;
use crate::*;
use std::sync::Arc;
use std::thread;
/// A freshly constructed bandit with no registered arms has nothing to hand out.
#[test]
fn bandit_with_zero_arms() {
    let mut empty = BanditSearch::new(BanditConfig::default());
    let picked = empty.next_arm();
    assert!(picked.is_none(), "Bandit with no arms should return None");
}
/// `group_stats` on an arm-less bandit reports no groups at all.
#[test]
fn bandit_group_stats_with_zero_arms() {
    let bandit = BanditSearch::new(BanditConfig::default());
    let summaries = bandit.group_stats();
    assert!(summaries.is_empty(), "Stats should be empty with no arms");
}
/// With no arms, no pulls have been recorded.
#[test]
fn bandit_total_pulls_with_zero_arms() {
    let bandit = BanditSearch::new(BanditConfig::default());
    let pulls = bandit.total_pulls();
    assert_eq!(pulls, 0, "Total pulls should be 0");
}
/// A single arm is handed out exactly once. After that, the bandit reports exhaustion.
#[test]
fn bandit_with_single_arm() {
    let mut bandit = BanditSearch::new(BanditConfig::default());
    bandit.add_arm(0, 0);

    assert_eq!(bandit.next_arm(), Some(0), "Should return the single arm");
    assert!(
        bandit.next_arm().is_none(),
        "Should return None after single arm is exhausted"
    );
}
/// One observation on the only arm gives one group whose mean is that reward.
#[test]
fn bandit_single_arm_with_observe() {
    let mut bandit = BanditSearch::new(BanditConfig::default());
    bandit.add_arm(42, 0);
    let chosen = bandit.next_arm().unwrap();
    bandit.observe(chosen, 1.0);

    let stats = bandit.group_stats();
    assert_eq!(stats.len(), 1);
    let only_group = &stats[0];
    assert_eq!(only_group.visits, 1);
    assert_eq!(only_group.average_reward, 1.0);
}
/// With exploration disabled (`c = 0`), selection is purely exploitative.
/// The bandit must still hand out an arm, accept an observation for it, and
/// count that round as a pull.
#[test]
fn bandit_exploration_constant_zero() {
    let config = BanditConfig::builder().exploration_constant(0.0).build();
    let mut search = BanditSearch::new_seeded(config, 42);
    for i in 0..10u64 {
        search.add_arm(i, (i / 3) as u32);
    }
    // A single `expect` replaces the `is_some()` + `unwrap()` pair: one check,
    // and a diagnostic message on failure.
    let arm = search
        .next_arm()
        .expect("zero exploration constant should still select an arm");
    search.observe(arm, 0.5);
    // One select-and-observe round is recorded as exactly one pull.
    assert_eq!(search.total_pulls(), 1);
}
/// An infinite exploration constant must not stop an arm from being selected.
#[test]
fn bandit_exploration_constant_infinity() {
    let cfg = BanditConfig::builder()
        .exploration_constant(f64::INFINITY)
        .build();
    let mut bandit = BanditSearch::new_seeded(cfg, 42);
    for arm_id in 0..10u32 {
        // Three arms per group: ids 0-2 -> group 0, 3-5 -> group 1, ...
        bandit.add_arm(u64::from(arm_id), arm_id / 3);
    }
    assert!(bandit.next_arm().is_some());
}
/// The builder rejects a NaN exploration constant and falls back to sqrt(2).
/// The resulting bandit still selects arms.
#[test]
fn bandit_exploration_constant_nan() {
    let cfg = BanditConfig::builder()
        .exploration_constant(f64::NAN)
        .build();
    let fallback = std::f64::consts::SQRT_2;
    assert!((cfg.exploration_constant - fallback).abs() < f64::EPSILON);

    let mut bandit = BanditSearch::new_seeded(cfg, 42);
    for arm_id in 0..10u32 {
        bandit.add_arm(u64::from(arm_id), arm_id / 3);
    }
    assert!(bandit.next_arm().is_some());
}
/// The builder rejects a negative exploration constant and falls back to sqrt(2).
/// The resulting bandit still selects arms.
#[test]
fn bandit_exploration_constant_negative() {
    let cfg = BanditConfig::builder().exploration_constant(-1.0).build();
    let fallback = std::f64::consts::SQRT_2;
    assert!((cfg.exploration_constant - fallback).abs() < f64::EPSILON);

    let mut bandit = BanditSearch::new_seeded(cfg, 42);
    for arm_id in 0..10u32 {
        bandit.add_arm(u64::from(arm_id), arm_id / 3);
    }
    assert!(bandit.next_arm().is_some());
}
/// With `rave_bias == 0`, observations must not add any RAVE visits to any group.
#[test]
fn bandit_rave_bias_zero() {
    let cfg = BanditConfig::builder().rave_bias(0.0).build();
    let mut bandit = BanditSearch::new_seeded(cfg, 42);
    for arm_id in 0..10u32 {
        // Two groups of five arms each.
        bandit.add_arm(u64::from(arm_id), arm_id / 5);
    }
    for _round in 0..2 {
        let chosen = bandit.next_arm().unwrap();
        bandit.observe(chosen, 1.0);
    }
    let rave_total: u32 = bandit
        .group_stats()
        .iter()
        .map(|group| group.rave_visits)
        .sum();
    assert_eq!(
        rave_total, 0,
        "RAVE visits should not increment when rave_bias=0"
    );
}
/// The builder clamps an infinite RAVE bias to 500.0, and selection still works.
#[test]
fn bandit_rave_bias_infinity() {
    let cfg = BanditConfig::builder().rave_bias(f64::INFINITY).build();
    let clamped = 500.0;
    assert!((cfg.rave_bias - clamped).abs() < f64::EPSILON);

    let mut bandit = BanditSearch::new_seeded(cfg, 42);
    for arm_id in 0..10u32 {
        bandit.add_arm(u64::from(arm_id), arm_id / 5);
    }
    assert!(bandit.next_arm().is_some());
}
/// A NaN reward is discarded: the group's visit count stays at zero.
#[test]
fn bandit_observe_nan_reward() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    let chosen = bandit.next_arm().unwrap();
    bandit.observe(chosen, f64::NAN);

    let stats = bandit.group_stats();
    let recorded = stats[0].visits;
    assert_eq!(recorded, 0, "NaN reward should be rejected");
}
/// A `+inf` reward is discarded: the group's visit count stays at zero.
#[test]
fn bandit_observe_infinity_reward() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    let chosen = bandit.next_arm().unwrap();
    bandit.observe(chosen, f64::INFINITY);

    let stats = bandit.group_stats();
    let recorded = stats[0].visits;
    assert_eq!(recorded, 0, "Infinite reward should be rejected");
}
/// A `-inf` reward is discarded: the group's visit count stays at zero.
#[test]
fn bandit_observe_negative_infinity_reward() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    let chosen = bandit.next_arm().unwrap();
    bandit.observe(chosen, f64::NEG_INFINITY);

    let stats = bandit.group_stats();
    let recorded = stats[0].visits;
    assert_eq!(
        recorded, 0,
        "Negative infinite reward should be rejected"
    );
}
/// Finite negative rewards are valid: they count as a visit and set the mean.
#[test]
fn bandit_observe_negative_reward() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    let chosen = bandit.next_arm().unwrap();
    bandit.observe(chosen, -5.0);

    let stats = bandit.group_stats();
    let group = &stats[0];
    assert_eq!(group.visits, 1);
    assert_eq!(
        group.average_reward, -5.0,
        "Negative reward should be accepted"
    );
}
/// Five arms in one group receive alternating +1 / -1 rewards.
/// Every finite observation must be accepted and averaged together.
#[test]
fn bandit_mixed_rewards_positive_and_negative() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for i in 0..5u64 {
        search.add_arm(i, 0);
    }
    for i in 0..5u64 {
        let arm = search.next_arm().unwrap();
        let reward = if i % 2 == 0 { 1.0 } else { -1.0 };
        search.observe(arm, reward);
    }
    let stats = search.group_stats();
    assert_eq!(stats.len(), 1, "All arms share group 0");
    assert_eq!(stats[0].total_arms, 5);
    // The rewards are three +1s and two -1s. All are finite, so all five
    // observations must count, and the mean is (3 - 2) / 5 = 0.2. Use a
    // tolerance because the mean may be accumulated incrementally.
    assert_eq!(
        stats[0].visits, 5,
        "Mixed-sign rewards should all be accepted"
    );
    assert!(
        (stats[0].average_reward - 0.2).abs() < 1e-9,
        "Average of mixed rewards should be 0.2"
    );
}
/// A finite reward near `f64::MAX` must be kept, not overflowed to infinity.
#[test]
fn bandit_observe_very_large_reward() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    search.add_arm(0, 0);
    let arm = search.next_arm().unwrap();
    search.observe(arm, 1e308);
    let stats = search.group_stats();
    let average = stats[0].average_reward;
    // `> 1e307` on its own also accepts +inf, which would hide an overflow in
    // the averaging. Require the result to be finite as well.
    assert!(
        average.is_finite() && average > 1e307,
        "Very large reward should be preserved"
    );
}
/// A tiny positive (subnormal) reward must not be flushed to zero by averaging.
#[test]
fn bandit_observe_very_small_reward() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    let chosen = bandit.next_arm().unwrap();
    bandit.observe(chosen, 1e-308);

    let stats = bandit.group_stats();
    let mean = stats[0].average_reward;
    let stays_tiny_but_positive = mean > 0.0 && mean < 1e-307;
    assert!(
        stays_tiny_but_positive,
        "Very small reward should be preserved"
    );
}
/// Ten threads share one bandit behind a mutex. Each does ten
/// select-and-observe rounds, and every round must be counted exactly once.
#[test]
fn bandit_concurrent_observe_calls() {
    let shared = Arc::new(std::sync::Mutex::new(BanditSearch::new_seeded(
        BanditConfig::default(),
        42,
    )));
    {
        let mut guard = shared.lock().unwrap();
        for arm_id in 0..200u32 {
            // Four groups of fifty arms each.
            guard.add_arm(u64::from(arm_id), arm_id / 50);
        }
    }

    let workers: Vec<_> = (0..10u8)
        .map(|thread_id| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                let reward = f64::from(thread_id) * 0.1;
                for _ in 0..10 {
                    let mut guard = shared.lock().unwrap();
                    if let Some(arm) = guard.next_arm() {
                        guard.observe(arm, reward);
                    }
                }
            })
        })
        .collect();
    for worker in workers {
        worker.join().unwrap();
    }

    let guard = shared.lock().unwrap();
    assert_eq!(guard.total_pulls(), 100, "Should have 100 total pulls");
}
/// Repeated observations on one arm accumulate in its group's visit count.
///
/// Note: despite the name, 1000 visits is far below `u32::MAX`. This test
/// checks accumulation, not behaviour at the `u32` boundary.
#[test]
fn bandit_many_visits_u32_boundary() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    for _round in 0..1000 {
        bandit.observe(0, 0.5);
    }
    let stats = bandit.group_stats();
    assert_eq!(stats[0].visits, 1000);
}
/// An infinite bias on group 1 (arms 5 to 9) forces the first pick into that group.
#[test]
fn bandit_group_bias_infinity() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..10u32 {
        bandit.add_arm(u64::from(arm_id), arm_id / 5);
    }
    bandit.set_group_bias(1, f64::INFINITY);

    let first_pick = bandit.next_arm().unwrap();
    assert!(
        first_pick >= 5,
        "Infinite bias should force selection from group 1"
    );
}
/// A NaN bias on one group must not stop the bandit from selecting an arm.
#[test]
fn bandit_group_bias_nan() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..10u32 {
        bandit.add_arm(u64::from(arm_id), arm_id / 5);
    }
    bandit.set_group_bias(0, f64::NAN);
    assert!(bandit.next_arm().is_some());
}
/// Biasing a group that has no arms is harmless. Existing arms stay selectable.
#[test]
fn bandit_set_bias_for_nonexistent_group() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    let unknown_group = 999;
    bandit.set_group_bias(unknown_group, 100.0);
    assert!(bandit.next_arm().is_some());
}
/// `max_pulls(1)` allows a single selection even though nine arms remain unpulled.
#[test]
fn bandit_max_pulls_of_one() {
    let cfg = BanditConfig::builder().max_pulls(1).build();
    let mut bandit = BanditSearch::new_seeded(cfg, 42);
    for arm_id in 0..10u64 {
        bandit.add_arm(arm_id, 0);
    }

    let first = bandit.next_arm();
    assert!(first.is_some());
    let second = bandit.next_arm();
    assert!(second.is_none(), "Should return None after max_pulls reached");
}
/// With a practically unlimited `max_pulls`, a lone arm is still exhausted
/// after one pick. Repeated attempts record at most one pull.
#[test]
fn bandit_max_pulls_u64_max() {
    let cfg = BanditConfig::builder().max_pulls(u64::MAX).build();
    let mut bandit = BanditSearch::new_seeded(cfg, 42);
    bandit.add_arm(0, 0);
    for _attempt in 0..100 {
        if bandit.next_arm().is_none() {
            break;
        }
    }
    assert!(bandit.total_pulls() <= 1);
}
/// Checkpointing and restoring an empty bandit gives an empty bandit.
#[test]
fn bandit_checkpoint_empty() {
    let original = BanditSearch::new(BanditConfig::default());
    let snapshot = original.checkpoint();
    let mut revived = BanditSearch::restore(snapshot);
    assert_eq!(revived.total_pulls(), 0);
    assert!(revived.next_arm().is_none());
}
/// The pull count survives a checkpoint/restore round trip.
#[test]
fn bandit_checkpoint_with_observations() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..10u64 {
        bandit.add_arm(arm_id, 0);
    }
    for _round in 0..5 {
        let chosen = bandit.next_arm().unwrap();
        bandit.observe(chosen, 1.0);
    }
    let snapshot = bandit.checkpoint();
    let revived = BanditSearch::restore(snapshot);
    assert_eq!(revived.total_pulls(), 5);
}
/// Registering the same arm id twice results in a single arm.
#[test]
fn bandit_add_duplicate_arm() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for _attempt in 0..2 {
        bandit.add_arm(0, 0);
    }
    let stats = bandit.group_stats();
    assert_eq!(stats[0].total_arms, 1, "Duplicate arm should not be added");
}
/// Group ids may use the full `u32` range, and the id is reported back unchanged.
#[test]
fn bandit_add_arm_with_u32_max_group() {
    let top_group = u32::MAX;
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, top_group);
    let stats = bandit.group_stats();
    assert_eq!(stats[0].group_id, top_group);
}
/// Arm ids may use the full `u64` range, and the id is handed back unchanged.
#[test]
fn bandit_add_arm_with_u64_max_arm_id() {
    let top_id = u64::MAX;
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(top_id, 0);
    assert_eq!(bandit.next_arm(), Some(top_id));
}
/// Each arm placed in its own group gives one stats entry per group.
#[test]
fn bandit_add_many_groups() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for group in 0..1000u32 {
        bandit.add_arm(u64::from(group), group);
    }
    let stats = bandit.group_stats();
    assert_eq!(stats.len(), 1000);
}
/// Many arms in one group are reported as a single group that counts them all.
#[test]
fn bandit_add_many_arms_single_group() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..1000u64 {
        bandit.add_arm(arm_id, 0);
    }
    let stats = bandit.group_stats();
    assert_eq!(stats.len(), 1);
    assert_eq!(stats[0].total_arms, 1000);
}
/// Observing an arm id that was never registered leaves the stats unchanged.
#[test]
fn bandit_observe_nonexistent_arm() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    let unknown_arm = 999;
    bandit.observe(unknown_arm, 1.0);

    let stats = bandit.group_stats();
    assert_eq!(
        stats[0].visits, 0,
        "Nonexistent arm observation should not affect stats"
    );
}
/// `observe` accepts a registered arm even if `next_arm` never handed it out.
#[test]
fn bandit_observe_before_next_arm() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    bandit.observe(0, 1.0);
    let stats = bandit.group_stats();
    assert_eq!(stats[0].visits, 1);
}
/// Asking for zero arms returns an empty batch.
#[test]
fn bandit_next_arms_zero() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    let batch = bandit.next_arms(0);
    assert!(
        batch.is_empty(),
        "Should return empty vec when requesting 0 arms"
    );
}
/// Asking for `usize::MAX` arms returns all ten registered arms.
#[test]
fn bandit_next_arms_max() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..10u64 {
        bandit.add_arm(arm_id, 0);
    }
    let batch = bandit.next_arms(usize::MAX);
    assert_eq!(batch.len(), 10, "Should return up to available arms");
}
/// Test environment with no legal actions, so the search has nothing to expand.
#[derive(Clone)]
struct EmptyEnv;

impl Environment for EmptyEnv {
    type Action = ();

    /// Always empty.
    fn legal_actions(&self) -> Vec<Self::Action> {
        Vec::new()
    }

    /// No actions exist, so applying one does nothing.
    fn apply(&mut self, _action: &Self::Action) {}

    /// This environment never reports a terminal outcome.
    fn evaluate(&self) -> Outcome {
        Outcome::Ongoing
    }
}
/// With no legal moves at the root, there is no best action to report.
#[test]
fn test_tree_search_empty_actions() {
    let mut tree = TreeSearch::new(EmptyEnv, SearchConfig::default());
    let best_action = tree.run();
    assert!(
        best_action.is_none(),
        "Search should return None when no legal actions exist"
    );
}
/// Test environment that always offers exactly one action (`1`) and never ends.
/// The wrapped counter is the running sum of applied actions.
#[derive(Clone)]
struct EndlessEnv(i32);

impl Environment for EndlessEnv {
    type Action = i32;

    /// The only move is always `1`.
    fn legal_actions(&self) -> Vec<Self::Action> {
        vec![1]
    }

    /// Adds the action's value to the running counter.
    fn apply(&mut self, step: &Self::Action) {
        self.0 += *step;
    }

    /// This environment never reports a terminal outcome.
    fn evaluate(&self) -> Outcome {
        Outcome::Ongoing
    }
}
/// `iterations(usize::MAX)` is capped by the 5 ms time budget. At least one
/// simulation runs, but nowhere near `u32::MAX` of them, and the only action
/// (`1`) is reported as best.
#[test]
fn test_tree_search_extreme_iterations() {
    let budget = std::time::Duration::from_millis(5);
    let config = SearchConfig::builder()
        .iterations(usize::MAX)
        .time_budget(budget)
        .build();
    let mut search = TreeSearch::new(EndlessEnv(0), config);
    let best = search.run();

    let sims = search.total_simulations();
    assert!(sims > 0, "Should have run at least one simulation");
    assert!(sims < u32::MAX, "Should not reach extreme values in 5ms");
    assert_eq!(best, Some(1));
}
/// A node budget of zero keeps the tree at just its root.
#[test]
fn test_tree_search_max_nodes_zero() {
    let settings = SearchConfig::builder().iterations(100).build();
    let mut tree = TreeSearch::new(EndlessEnv(0), settings);
    tree.with_max_nodes(0);
    tree.run();

    let node_count = tree.tree_size();
    assert_eq!(
        node_count,
        1,
        "Tree should not expand beyond root if limit is 0 (or lower than root size)"
    );
}
/// Checks that a `max_depth` far beyond anything reachable in 5 ms is handled:
/// `run` must return without panicking, and the single-iteration cap must
/// still hold.
#[test]
fn test_tree_search_huge_depth() {
    let config = SearchConfig::builder()
        .iterations(1)
        .max_depth(10_000_000)
        .time_budget(std::time::Duration::from_millis(5))
        .build();
    let mut search = TreeSearch::new(EndlessEnv(0), config);
    search.run();
    // The previous `assert!(true)` checked nothing (clippy::assertions_on_constants).
    // Checking against the iteration cap gives the test real content.
    assert!(
        search.total_simulations() <= 1,
        "iterations(1) should cap the search at one simulation"
    );
}
/// `Reward` arithmetic propagates NaN through both `+` and `+=`.
#[test]
fn test_reward_arithmetic_nan() {
    let nan = || Reward::new(f64::NAN);

    let sum = Reward::new(0.5) + nan();
    assert!(
        sum.value().is_nan(),
        "Reward should correctly propagate NaN"
    );

    let mut accumulated = Reward::new(1.0);
    accumulated += nan();
    assert!(
        accumulated.value().is_nan(),
        "Reward add_assign should propagate NaN"
    );
}
/// An unvisited root scored with an infinite exploration constant must get an
/// infinite score, not NaN.
#[test]
fn test_node_uct_score_zero_visits() {
    let fresh = Node::<i32>::root(vec![1, 2]);
    let exploration = f64::INFINITY;
    let score = fresh.uct_score(100, exploration);
    assert!(
        score.is_infinite(),
        "UCT score should handle infinite exploration constant at 0 visits by returning INFINITY"
    );
}
/// The UCT score stays finite when the node's visits and the first `uct_score`
/// argument (presumably the parent visit count) are both `u32::MAX`.
#[test]
fn test_node_uct_score_huge_visits() {
    let mut saturated = Node::<i32>::root(vec![]);
    saturated.visits = u32::MAX;
    saturated.cumulative_reward = 1.0;
    let score = saturated.uct_score(u32::MAX, 1.41);
    assert!(score.is_finite());
}
/// A negative exploration constant gives a finite score below 0.5, which is
/// this node's reward-per-visit (5.0 / 10).
#[test]
fn test_node_uct_score_negative_exploration() {
    let mut visited = Node::<i32>::root(vec![]);
    visited.visits = 10;
    visited.cumulative_reward = 5.0;

    let score = visited.uct_score(100, -1.0);
    assert!(score.is_finite());
    assert!(score < 0.5);
}
/// An infinite reward can be represented, and adding it to a finite reward via
/// `+=` gives an infinite total.
#[test]
fn test_reward_value_infinity() {
    let infinite = Reward::new(f64::INFINITY);
    assert!(infinite.value().is_infinite());

    let mut total = Reward::new(10.0);
    total += infinite;
    assert!(total.value().is_infinite());
}
/// Building a `Reward` from NaN keeps the NaN rather than sanitising it.
#[test]
fn test_reward_from_nan() {
    let converted = Reward::from(f64::NAN);
    let raw = converted.value();
    assert!(raw.is_nan());
}
/// `run` is never called, so the root has no children and advancing to any
/// action must fail.
#[test]
fn test_tree_search_advance_action_unexpanded() {
    let settings = SearchConfig::builder().iterations(0).build();
    let mut tree = TreeSearch::new(EndlessEnv(0), settings);
    let advanced = tree.advance_to_action(&1);
    assert!(!advanced);
}
/// An environment whose heuristic always reports NaN must not crash the
/// search or principal-variation extraction.
#[test]
fn test_tree_search_with_nan_heuristic() {
    /// Single-action, never-ending environment with a NaN heuristic.
    #[derive(Clone)]
    struct NanHeuristicEnv;

    impl Environment for NanHeuristicEnv {
        type Action = ();

        fn legal_actions(&self) -> Vec<()> {
            vec![()]
        }

        fn apply(&mut self, _: &()) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }

        fn heuristic(&self) -> Heuristic {
            Heuristic::from_reward(Reward::new(f64::NAN))
        }
    }

    let mut search = TreeSearch::new(
        NanHeuristicEnv,
        SearchConfig::builder().iterations(100).build(),
    );
    search.run();
    // The old `pv.is_empty() || pv.len() > 0` was always true. This test makes
    // no claim about the PV's contents: extracting it must simply not panic.
    let _ = search.principal_variation();
    // A NaN heuristic must not stop simulations from running.
    assert!(
        search.total_simulations() > 0,
        "NaN heuristic should not prevent simulations from running"
    );
}
/// A zero time budget ends the search before any simulation runs, even though
/// iterations remain.
#[test]
fn test_tree_search_time_budget_zero() {
    let no_time = std::time::Duration::from_millis(0);
    let settings = SearchConfig::builder()
        .iterations(100)
        .time_budget(no_time)
        .build();
    let mut tree = TreeSearch::new(EndlessEnv(0), settings);
    tree.run();
    assert_eq!(tree.total_simulations(), 0);
}
/// A custom evaluator that always returns NaN must not abort the search.
/// Simulations still run and are counted.
#[test]
fn test_tree_search_evaluator_panic_handling() {
    /// Evaluator whose every estimate is NaN.
    struct NanEvaluator;

    impl Evaluator<EndlessEnv> for NanEvaluator {
        fn evaluate(&self, _env: &EndlessEnv) -> Reward {
            Reward::new(f64::NAN)
        }
    }

    let settings = SearchConfig::builder().iterations(10).build();
    let mut tree = TreeSearch::new(EndlessEnv(0), settings);
    tree.with_evaluator(Arc::new(NanEvaluator));
    tree.run();
    assert!(tree.total_simulations() > 0);
}