use mctrust::*;
use std::sync::Arc;
use std::time::Duration;
/// Regression: with `rave_bias(0.0)` the summed `rave_visits` over all group
/// stats must be zero after two rewards have been observed.
#[test]
fn regression_rave_disabled_zero_bias() {
    let config = BanditConfig::builder().rave_bias(0.0).build();
    let mut search = BanditSearch::new_seeded(config, 99);

    // Ten arms; the second argument is `id / 5`, i.e. 0 for the first five
    // arms and 1 for the rest (presumably the group id).
    (0..10u64).for_each(|arm_id| {
        search.add_arm(arm_id, (arm_id / 5) as u32);
    });

    // Pull two arms in sequence, rewarding each with 1.0.
    for _ in 0..2 {
        let pulled = search.next_arm().unwrap();
        search.observe(pulled, 1.0);
    }

    let stats = search.group_stats();
    let mut total_rave: u32 = 0;
    for group in stats.iter() {
        total_rave += group.rave_visits;
    }
    assert_eq!(
        total_rave, 0,
        "RAVE should be completely disabled at bias=0"
    );
}
/// Regression: observing a NaN reward followed by a valid one must leave the
/// first group's stats with one counted visit and a non-NaN average.
#[test]
fn regression_nan_reward_poisoning() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);

    // Two arms, both registered with 0 as the second argument.
    for arm_id in 0..2 {
        search.add_arm(arm_id, 0);
    }

    // First pull gets the poisoned reward, second pull gets a valid one.
    for reward in [f64::NAN, 1.0] {
        let pulled = search.next_arm().unwrap();
        search.observe(pulled, reward);
    }

    let stats = search.group_stats();
    let first_group = &stats[0];
    assert_eq!(first_group.visits, 1, "Only valid reward should count");
    assert!(
        !first_group.average_reward.is_nan(),
        "Average should not be NaN"
    );
}
/// Minimal environment for the tree-search regression tests: wraps an `i32`
/// counter that is bumped on every applied action.
#[derive(Clone)]
struct RegressionEnv(i32);

/// The single, data-less action offered by [`RegressionEnv`].
#[derive(Clone, Debug, PartialEq)]
struct RegressionAction;

impl Environment for RegressionEnv {
    type Action = RegressionAction;

    /// Always offers exactly one action.
    fn legal_actions(&self) -> Vec<Self::Action> {
        Vec::from([RegressionAction])
    }

    /// Increments the wrapped counter; the action itself is ignored.
    fn apply(&mut self, _action: &Self::Action) {
        self.0 += 1;
    }

    /// Always reports `Outcome::Ongoing`.
    fn evaluate(&self) -> Outcome {
        Outcome::Ongoing
    }
}
/// Regression: running a search after `with_max_nodes(0)` must not panic and
/// must leave the tree at size 1.
#[test]
fn regression_max_nodes_zero_no_panic() {
    let config = SearchConfig::builder().iterations(10).build();
    let mut search = TreeSearch::new(RegressionEnv(0), config);

    search.with_max_nodes(0);
    search.run();

    let node_count = search.tree_size();
    assert_eq!(node_count, 1);
}
/// Regression: with a zero-millisecond time budget, `run` must record zero
/// simulations even though one million iterations are configured.
#[test]
fn regression_zero_time_budget_no_iterations() {
    let no_time = Duration::from_millis(0);
    let config = SearchConfig::builder()
        .iterations(1_000_000)
        .time_budget(no_time)
        .build();
    let mut search = TreeSearch::new(RegressionEnv(0), config);

    search.run();

    let simulations = search.total_simulations();
    assert_eq!(simulations, 0);
}
/// Regression: `advance_to_action` on a search that has never been run
/// (configured with zero iterations) must return `false`.
#[test]
fn regression_advance_unexpanded_returns_false() {
    let config = SearchConfig::builder().iterations(0).build();
    let mut search = TreeSearch::new(RegressionEnv(0), config);

    let advanced = search.advance_to_action(&RegressionAction);
    assert!(!advanced);
}
/// Regression: `export_dot(0)` on a search that has never been run must still
/// produce output with the expected opening and closing text.
#[test]
fn regression_empty_dot_export() {
    let search = TreeSearch::new(RegressionEnv(0), SearchConfig::default());
    let rendered = search.export_dot(0);

    let opens_graph = rendered.starts_with("digraph mctrust {");
    let closes_graph = rendered.ends_with("}\n");
    assert!(opens_graph);
    assert!(closes_graph);
}
/// Regression: requesting a batch of zero arms must return an empty result.
#[test]
fn regression_next_arms_zero() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    search.add_arm(0, 0);

    let batch = search.next_arms(0);
    assert!(batch.is_empty());
}
/// Regression: observing a reward for an arm id that was never added (999)
/// must leave the first group's visit count at zero.
#[test]
fn regression_observe_nonexistent_arm() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    search.add_arm(0, 0);

    // Only arm 0 was registered above, so this id is unknown.
    let unknown_arm = 999;
    search.observe(unknown_arm, 1.0);

    let stats = search.group_stats();
    let first_group = &stats[0];
    assert_eq!(first_group.visits, 0);
}
/// Evaluator that scores every `RegressionEnv` state as NaN.
struct NaNEval;

impl Evaluator<RegressionEnv> for NaNEval {
    /// Ignores the state and returns a reward built from `f64::NAN`.
    fn evaluate(&self, _state: &RegressionEnv) -> Reward {
        let score = f64::NAN;
        Reward::new(score)
    }
}
/// Regression: a search driven by an evaluator that always yields NaN must
/// not panic and must still record all ten configured simulations.
#[test]
fn regression_evaluator_nan_no_panic() {
    let config = SearchConfig::builder().iterations(10).build();
    let mut search = TreeSearch::new(RegressionEnv(0), config);

    search.with_evaluator(Arc::new(NaNEval));
    search.run();

    let simulations = search.total_simulations();
    assert_eq!(simulations, 10);
}
/// Environment whose `action_priors` are deliberately malformed: it offers one
/// legal action but returns two priors, and both are non-finite.
#[derive(Clone)]
struct BadPriorEnv;

/// The single action offered by [`BadPriorEnv`].
#[derive(Clone, Debug, PartialEq)]
struct BadPriorAction;

impl Environment for BadPriorEnv {
    type Action = BadPriorAction;

    /// Always offers exactly one action.
    fn legal_actions(&self) -> Vec<Self::Action> {
        Vec::from([BadPriorAction])
    }

    /// Applying an action leaves the (stateless) environment unchanged.
    fn apply(&mut self, _action: &Self::Action) {}

    /// Always reports `Outcome::Neutral`.
    fn evaluate(&self) -> Outcome {
        Outcome::Neutral
    }

    /// Returns NaN and negative infinity regardless of the actions passed in,
    /// so the prior count does not match the single legal action.
    fn action_priors(&self, _actions: &[Self::Action]) -> Option<Vec<f64>> {
        let malformed = vec![f64::NAN, f64::NEG_INFINITY];
        Some(malformed)
    }
}
/// Regression: a PUCT search over an environment that supplies malformed
/// priors must not panic and must record all ten configured simulations.
#[test]
fn regression_puct_malformed_priors() {
    let policy = TreePolicy::Puct { prior_weight: 1.0 };
    let config = SearchConfig::builder()
        .iterations(10)
        .tree_policy(policy)
        .build();
    let mut search = TreeSearch::with_seed(BadPriorEnv, config, 42);

    search.run();

    let simulations = search.total_simulations();
    assert_eq!(simulations, 10);
}
/// Regression: `sanitize` on a Gumbel policy with `sampled_actions: 0` must
/// emit at least one warning and replace the value with 16.
#[test]
fn regression_gumbel_zero_sampled_actions() {
    let mut config = SearchConfig::default();
    config.tree_policy = TreePolicy::Gumbel {
        sampled_actions: 0,
        max_completions_coeff: 50.0,
    };

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    // The policy variant must survive sanitizing; only the count changes.
    if let TreePolicy::Gumbel {
        sampled_actions, ..
    } = config.tree_policy
    {
        assert_eq!(sampled_actions, 16);
    } else {
        panic!("Expected Gumbel");
    }
}
/// Serde-serializable environment wrapping an `i32` counter; used by the
/// checkpoint regression test.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct CheckpointEnv(i32);

/// The single, data-less action offered by [`CheckpointEnv`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
struct CheckpointAction;

impl Environment for CheckpointEnv {
    type Action = CheckpointAction;

    /// Always offers exactly one action.
    fn legal_actions(&self) -> Vec<Self::Action> {
        Vec::from([CheckpointAction])
    }

    /// Increments the wrapped counter; the action itself is ignored.
    fn apply(&mut self, _action: &Self::Action) {
        self.0 += 1;
    }

    /// Always reports `Outcome::Neutral`.
    fn evaluate(&self) -> Outcome {
        Outcome::Neutral
    }
}
/// Regression: checkpointing a search that has never been run and restoring
/// it must yield a tree of size 1 with zero simulations.
#[test]
fn regression_checkpoint_before_search() {
    let pristine = TreeSearch::new(CheckpointEnv(0), SearchConfig::default());
    let snapshot = pristine.checkpoint();

    let restored: TreeSearch<CheckpointEnv> = TreeSearch::restore(snapshot);

    let node_count = restored.tree_size();
    let simulations = restored.total_simulations();
    assert_eq!(node_count, 1);
    assert_eq!(simulations, 0);
}
/// Regression: `sanitize` must warn about a NaN signal weight and remove its
/// entry from the scalarizer's weight map.
#[test]
fn regression_scalarizer_nan_weight_sanitized() {
    let signal = "bad";
    let mut config = BanditConfig::default();
    let weights = &mut config.scalarizer.signal_weights;
    weights.insert(signal.to_string(), f64::NAN);

    let warnings = config.sanitize();

    assert!(!warnings.is_empty());
    assert!(!config.scalarizer.signal_weights.contains_key(signal));
}
/// Regression: the average reward reported for a group must not be NaN when
/// no reward has been observed yet.
#[test]
fn regression_group_stats_no_nan_average() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    search.add_arm(0, 0);

    let stats = search.group_stats();
    let average = stats[0].average_reward;
    assert!(!average.is_nan());
}
/// Regression: `run_until` with a predicate that is always true must perform
/// zero simulations and return `None`.
#[test]
fn regression_run_until_immediate_stop() {
    let config = SearchConfig::builder().iterations(1_000_000).build();
    let mut search = TreeSearch::new(RegressionEnv(0), config);

    // The predicate ignores its argument and always answers `true`.
    let best = search.run_until(|_| true);

    let simulations = search.total_simulations();
    assert_eq!(simulations, 0);
    assert!(best.is_none());
}
/// Regression (only built with the `dag` feature): on an environment whose
/// state hash alternates between two values, `principal_variation` must
/// return, and its length must not exceed the tree size.
#[cfg(feature = "dag")]
#[test]
fn regression_pv_terminates_with_dag_cycles() {
    /// Two-state environment: the counter flips between 0 and 1 when started
    /// from 0, so the same state hashes recur.
    #[derive(Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct CycleEnv(i32);

    /// The single action offered by `CycleEnv`.
    #[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
    struct CycleAction;

    impl Environment for CycleEnv {
        type Action = CycleAction;

        fn legal_actions(&self) -> Vec<Self::Action> {
            Vec::from([CycleAction])
        }

        fn apply(&mut self, _action: &Self::Action) {
            let bumped = self.0 + 1;
            self.0 = bumped % 2;
        }

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }

        // The hash is the counter itself.
        fn state_hash(&self) -> Option<u64> {
            Some(self.0 as u64)
        }
    }

    let config = SearchConfig::builder().iterations(50).build();
    let mut search = TreeSearch::with_seed(CycleEnv(0), config, 42);
    search.enable_dag();
    search.run();

    let pv = search.principal_variation();
    let pv_len = pv.len();
    let node_count = search.tree_size();
    assert!(pv_len <= node_count);
}
/// Regression: observing an empty signal slice must count as one visit with
/// an average reward of exactly 0.0.
#[test]
fn regression_observe_with_empty_signals() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    search.add_arm(0, 0);

    let pulled = search.next_arm().unwrap();
    search.observe_with_signals(pulled, &[]);

    let stats = search.group_stats();
    let first_group = &stats[0];
    assert_eq!(first_group.visits, 1);
    assert_eq!(first_group.average_reward, 0.0);
}
/// Regression: setting a NaN bias on group 0 must not panic, and `next_arm`
/// must still return an arm.
#[test]
fn regression_group_bias_nan_no_panic() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    search.add_arm(0, 0);
    search.set_group_bias(0, f64::NAN);

    let selected = search.next_arm();
    assert!(selected.is_some());
}