// mctrust 0.4.0
//
// Universal search & planning toolkit — MCTS, bandit search, pluggable evaluators, tree reuse,
// DAG transpositions, root parallelism. Define an Environment; the search handles the rest.
// Documentation
//! Regression tests for bugs found in mctrust.

use mctrust::*;
use std::sync::Arc;
use std::time::Duration;

// =============================================================================
// Regression: RAVE updates when rave_bias=0 should be skipped
// =============================================================================

#[test]
fn regression_rave_disabled_zero_bias() {
    // With rave_bias set to zero the bandit must not record any RAVE visits at all.
    let config = BanditConfig::builder().rave_bias(0.0).build();
    let mut bandit = BanditSearch::new_seeded(config, 99);

    // Ten arms split across two groups: ids 0..5 -> group 0, ids 5..10 -> group 1.
    (0..10u64).for_each(|id| {
        bandit.add_arm(id, (id / 5) as u32);
    });

    // Two select/observe rounds, each rewarding the chosen arm with 1.0.
    for _ in 0..2 {
        let chosen = bandit.next_arm().unwrap();
        bandit.observe(chosen, 1.0);
    }

    let rave_visits: u32 = bandit
        .group_stats()
        .iter()
        .map(|group| group.rave_visits)
        .sum();
    assert_eq!(
        rave_visits, 0,
        "RAVE should be completely disabled at bias=0"
    );
}

// =============================================================================
// Regression: NaN rewards must not poison statistics
// =============================================================================

#[test]
fn regression_nan_reward_poisoning() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    // Both arms live in group 0, so every accepted observation lands in stats[0].
    bandit.add_arm(0, 0);
    bandit.add_arm(1, 0);

    // First pull reports NaN; it must be rejected rather than folded into the stats.
    let poisoned = bandit.next_arm().unwrap();
    bandit.observe(poisoned, f64::NAN);

    // Second pull reports a finite reward; it should be the only one counted.
    let healthy = bandit.next_arm().unwrap();
    bandit.observe(healthy, 1.0);

    let stats = bandit.group_stats();
    let group = &stats[0];
    assert_eq!(group.visits, 1, "Only valid reward should count");
    assert!(
        !group.average_reward.is_nan(),
        "Average should not be NaN"
    );
}

// =============================================================================
// Regression: TreeSearch with max_nodes=0 must not panic
// =============================================================================

/// Endless single-action environment: every move bumps the counter and the game
/// never reaches a terminal outcome.
#[derive(Clone)]
struct RegressionEnv(i32);

/// The only move available in [`RegressionEnv`].
#[derive(Clone, Debug, PartialEq)]
struct RegressionAction;

impl Environment for RegressionEnv {
    type Action = RegressionAction;

    fn legal_actions(&self) -> Vec<RegressionAction> {
        Vec::from([RegressionAction])
    }

    fn apply(&mut self, _action: &RegressionAction) {
        self.0 += 1;
    }

    // Always ongoing, so any search depth is bounded only by the search itself.
    fn evaluate(&self) -> Outcome {
        Outcome::Ongoing
    }
}

#[test]
fn regression_max_nodes_zero_no_panic() {
    let config = SearchConfig::builder().iterations(10).build();
    let mut tree = TreeSearch::new(RegressionEnv(0), config);
    // A zero node budget must not panic; the tree is left at a single node.
    tree.with_max_nodes(0);
    tree.run();
    assert_eq!(tree.tree_size(), 1);
}

// =============================================================================
// Regression: time_budget=0 must not run any iterations
// =============================================================================

#[test]
fn regression_zero_time_budget_no_iterations() {
    // A generous iteration cap paired with an empty time budget: the time budget
    // must win and no simulation may run.
    let config = SearchConfig::builder()
        .iterations(1_000_000)
        .time_budget(Duration::ZERO)
        .build();
    let mut tree = TreeSearch::new(RegressionEnv(0), config);
    tree.run();
    assert_eq!(tree.total_simulations(), 0);
}

// =============================================================================
// Regression: advance_to_action on unexpanded tree must return false
// =============================================================================

#[test]
fn regression_advance_unexpanded_returns_false() {
    let config = SearchConfig::builder().iterations(0).build();
    let mut tree = TreeSearch::new(RegressionEnv(0), config);
    // Nothing has been searched, so advancing to the (legal) action must report failure.
    let advanced = tree.advance_to_action(&RegressionAction);
    assert!(!advanced);
}

// =============================================================================
// Regression: empty DOT export must be well-formed
// =============================================================================

#[test]
fn regression_empty_dot_export() {
    // Exporting a tree that has never been searched must still yield a complete graph.
    let tree = TreeSearch::new(RegressionEnv(0), SearchConfig::default());
    let rendered = tree.export_dot(0);
    assert!(rendered.starts_with("digraph mctrust {"));
    assert!(rendered.ends_with("}\n"));
}

// =============================================================================
// Regression: BanditSearch next_arms with n=0 must return empty vec
// =============================================================================

#[test]
fn regression_next_arms_zero() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    // Asking for a batch of zero arms yields nothing, even though an arm exists.
    assert!(bandit.next_arms(0).is_empty());
}

// =============================================================================
// Regression: observe_nonexistent_arm must not panic
// =============================================================================

#[test]
fn regression_observe_nonexistent_arm() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    // Arm 999 was never registered: the observation must be ignored, not panic.
    bandit.observe(999, 1.0);
    let visits = bandit.group_stats()[0].visits;
    assert_eq!(visits, 0);
}

// =============================================================================
// Regression: TreeSearch with evaluator returning NaN must not panic
// =============================================================================

/// Evaluator that scores every state as NaN, exercising the search's handling of
/// non-finite leaf values.
struct NaNEval;

impl Evaluator<RegressionEnv> for NaNEval {
    fn evaluate(&self, _state: &RegressionEnv) -> Reward {
        let not_a_number = f64::NAN;
        Reward::new(not_a_number)
    }
}

#[test]
fn regression_evaluator_nan_no_panic() {
    let config = SearchConfig::builder().iterations(10).build();
    let mut tree = TreeSearch::new(RegressionEnv(0), config);
    tree.with_evaluator(Arc::new(NaNEval));
    tree.run();
    // All ten iterations must complete even though every evaluation is NaN.
    assert_eq!(tree.total_simulations(), 10);
}

// =============================================================================
// Regression: PUCT with malformed priors must not panic
// =============================================================================

/// Single-action, stateless environment whose action priors are deliberately malformed.
#[derive(Clone)]
struct BadPriorEnv;

/// The only move available in [`BadPriorEnv`].
#[derive(Clone, Debug, PartialEq)]
struct BadPriorAction;

impl Environment for BadPriorEnv {
    type Action = BadPriorAction;

    fn legal_actions(&self) -> Vec<BadPriorAction> {
        Vec::from([BadPriorAction])
    }

    // Applying the move changes nothing: the environment carries no state.
    fn apply(&mut self, _action: &BadPriorAction) {}

    fn evaluate(&self) -> Outcome {
        Outcome::Neutral
    }

    // Two priors for a single legal action, neither of them a usable value
    // (NaN and negative infinity).
    fn action_priors(&self, _actions: &[BadPriorAction]) -> Option<Vec<f64>> {
        let malformed = vec![f64::NAN, f64::NEG_INFINITY];
        Some(malformed)
    }
}

#[test]
fn regression_puct_malformed_priors() {
    // PUCT consumes the priors directly, so it is the policy that must tolerate them.
    let policy = TreePolicy::Puct { prior_weight: 1.0 };
    let config = SearchConfig::builder()
        .iterations(10)
        .tree_policy(policy)
        .build();
    let mut tree = TreeSearch::with_seed(BadPriorEnv, config, 42);
    tree.run();
    assert_eq!(tree.total_simulations(), 10);
}

// =============================================================================
// Regression: Gumbel policy with 0 sampled_actions gets fixed to 16
// =============================================================================

#[test]
fn regression_gumbel_zero_sampled_actions() {
    // Put the invalid policy straight onto a default config, then sanitize it explicitly.
    let mut config = SearchConfig::default();
    config.tree_policy = TreePolicy::Gumbel {
        sampled_actions: 0,
        max_completions_coeff: 50.0,
    };

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    // Sanitizing must keep the Gumbel policy but replace the zero sample count with 16.
    let repaired = match config.tree_policy {
        TreePolicy::Gumbel {
            sampled_actions, ..
        } => sampled_actions,
        _ => panic!("Expected Gumbel"),
    };
    assert_eq!(repaired, 16);
}

// =============================================================================
// Regression: checkpoint before any simulation must restore correctly
// =============================================================================

/// Serializable single-action environment used for checkpoint round-trips; each move
/// increments the counter.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct CheckpointEnv(i32);

/// The only move available in [`CheckpointEnv`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
struct CheckpointAction;

impl Environment for CheckpointEnv {
    type Action = CheckpointAction;

    fn legal_actions(&self) -> Vec<CheckpointAction> {
        Vec::from([CheckpointAction])
    }

    fn apply(&mut self, _action: &CheckpointAction) {
        self.0 += 1;
    }

    fn evaluate(&self) -> Outcome {
        Outcome::Neutral
    }
}

#[test]
fn regression_checkpoint_before_search() {
    // Snapshot a search that has never run and make sure it restores cleanly.
    let fresh = TreeSearch::new(CheckpointEnv(0), SearchConfig::default());
    let snapshot = fresh.checkpoint();
    let restored: TreeSearch<CheckpointEnv> = TreeSearch::restore(snapshot);

    // Only the single starting node, and no simulations recorded.
    assert_eq!(restored.tree_size(), 1);
    assert_eq!(restored.total_simulations(), 0);
}

// =============================================================================
// Regression: Scalarizer with NaN weight must be sanitized
// =============================================================================

#[test]
fn regression_scalarizer_nan_weight_sanitized() {
    let mut config = BanditConfig::default();
    let weights = &mut config.scalarizer.signal_weights;
    weights.insert(String::from("bad"), f64::NAN);

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());
    // The NaN-weighted signal must be removed from the map, not left in place.
    assert!(!config.scalarizer.signal_weights.contains_key("bad"));
}

// =============================================================================
// Regression: GroupStats must not have NaN average_reward
// =============================================================================

#[test]
fn regression_group_stats_no_nan_average() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    // The group has never been observed; its average must still not be NaN
    // (e.g. from a 0/0 mean).
    let average = bandit.group_stats()[0].average_reward;
    assert!(!average.is_nan());
}

// =============================================================================
// Regression: run_until with immediate true predicate must not loop forever
// =============================================================================

#[test]
fn regression_run_until_immediate_stop() {
    let config = SearchConfig::builder().iterations(1_000_000).build();
    let mut tree = TreeSearch::new(RegressionEnv(0), config);
    // The stop predicate is satisfied from the start, so no simulation should run
    // and there is no best action to report.
    let best = tree.run_until(|_| true);
    assert_eq!(tree.total_simulations(), 0);
    assert!(best.is_none());
}

// =============================================================================
// Regression: principal_variation must terminate in cyclic DAG
// =============================================================================

#[cfg(feature = "dag")]
#[test]
fn regression_pv_terminates_with_dag_cycles() {
    // Two-state environment: `apply` toggles the counter between 0 and 1 and the
    // game never ends, so the same states keep recurring along the only line of play.
    #[derive(Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct CycleEnv(i32);

    #[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
    struct CycleAction;

    impl Environment for CycleEnv {
        type Action = CycleAction;

        fn legal_actions(&self) -> Vec<CycleAction> {
            Vec::from([CycleAction])
        }

        fn apply(&mut self, _action: &CycleAction) {
            self.0 = (self.0 + 1) % 2;
        }

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }

        // Equal states report equal hashes; with DAG mode enabled they are presumably
        // merged into shared nodes, which is what closes the cycle.
        fn state_hash(&self) -> Option<u64> {
            Some(self.0 as u64)
        }
    }

    let config = SearchConfig::builder().iterations(50).build();
    let mut tree = TreeSearch::with_seed(CycleEnv(0), config, 42);
    tree.enable_dag();
    tree.run();

    // Reaching the assertion proves the walk terminated; its length is also
    // bounded by the number of nodes in the search structure.
    let line = tree.principal_variation();
    assert!(line.len() <= tree.tree_size());
}

// =============================================================================
// Regression: bandit observe_with_signals with empty signals
// =============================================================================

#[test]
fn regression_observe_with_empty_signals() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    let chosen = bandit.next_arm().unwrap();
    // An observation carrying no signals still counts as a visit, scored as zero reward.
    bandit.observe_with_signals(chosen, &[]);

    let stats = bandit.group_stats();
    let group = &stats[0];
    assert_eq!(group.visits, 1);
    assert_eq!(group.average_reward, 0.0);
}

// =============================================================================
// Regression: bandit set_group_bias with NaN must not panic
// =============================================================================

#[test]
fn regression_group_bias_nan_no_panic() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);
    // A NaN bias must not break selection: the lone arm is still offered.
    bandit.set_group_bias(0, f64::NAN);
    assert!(bandit.next_arm().is_some());
}