use super::*;
use crate::{ProgressiveWideningConfig, Reward, TreePolicy};
/// Toy environment: walk an integer `value` toward `target` one unit per move.
///
/// The `Environment` impl below reports success on reaching `target`, failure once
/// more than 20 units away, and a linear closeness heuristic in between.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct NumberGame {
    // Current position on the number line.
    value: i32,
    // Position that yields `Outcome::Success` when reached.
    target: i32,
    // Number of moves applied so far; the `Environment` impl only increments it.
    move_count: u32,
}
/// A unit step on the number line, shared by `NumberGame` and `HashableGame`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
enum Move {
    // Add one to the current value.
    Inc,
    // Subtract one from the current value.
    Dec,
}
impl Environment for NumberGame {
    type Action = Move;

    /// Both directions are legal from every position.
    fn legal_actions(&self) -> Vec<Move> {
        vec![Move::Inc, Move::Dec]
    }

    /// Shifts `value` one unit in the chosen direction and bumps `move_count`.
    fn apply(&mut self, action: &Move) {
        let delta = match action {
            Move::Inc => 1,
            Move::Dec => -1,
        };
        self.value += delta;
        self.move_count += 1;
    }

    /// Success on the target, failure once more than 20 units away, otherwise ongoing.
    fn evaluate(&self) -> Outcome {
        match (self.value - self.target).abs() {
            0 => Outcome::Success(Reward::WIN),
            gap if gap > 20 => Outcome::Failure,
            _ => Outcome::Ongoing,
        }
    }

    /// Linear closeness score: 1.0 at the target, falling to 0.0 at 20 or more units away.
    fn heuristic(&self) -> crate::environment::Heuristic {
        let gap = f64::from((self.value - self.target).abs());
        let penalty = (gap / 20.0).min(1.0);
        crate::environment::Heuristic::from_reward(Reward::new(1.0 - penalty))
    }

    /// Reports a depth limit of 20 to the search.
    fn max_depth(&self) -> Option<usize> {
        Some(20)
    }
}
/// Tiny environment used to exercise prior-aware tree policies.
///
/// From `state == 0` the legal actions are `1` and `-1`. From any other state the
/// only legal action is `0`. `evaluate` treats `2` as a win and `-2` as a loss.
///
/// NOTE(review): because every non-zero state only offers `0`, the `±2` terminal
/// states cannot be reached from the `state: 0` start used by the tests. Searches
/// over this game therefore see only `Outcome::Ongoing`, and action choice is
/// driven by the priors / policy rather than by terminal rewards.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct PriorGame {
    // Current integer position; starts at 0 in every test that uses it.
    state: i32,
}
impl Environment for PriorGame {
    type Action = i32;

    /// Offers `1` and `-1` at the origin and only `0` everywhere else.
    fn legal_actions(&self) -> Vec<i32> {
        match self.state {
            0 => vec![1, -1],
            _ => vec![0],
        }
    }

    /// Adds the chosen offset to `state`.
    fn apply(&mut self, action: &i32) {
        self.state += *action;
    }

    /// `2` is a win, `-2` a loss, anything else is still in play.
    fn evaluate(&self) -> Outcome {
        match self.state {
            2 => Outcome::Success(Reward::WIN),
            -2 => Outcome::Failure,
            _ => Outcome::Ongoing,
        }
    }

    /// Unnormalised priors: `1.0` for the action `1`, `0.1` for every other action.
    fn action_priors(&self, actions: &[Self::Action]) -> Option<Vec<f64>> {
        let priors: Vec<f64> = actions
            .iter()
            .map(|&candidate| if candidate == 1 { 1.0 } else { 0.1 })
            .collect();
        Some(priors)
    }
}
/// Builds the `SearchConfig` shared by the number-line tests: the given iteration
/// budget and depth cap, with `heuristic_weight` set to `1.0`.
///
/// How the heuristic weight is applied during search is defined by the crate's
/// search implementation, not here.
fn number_line_search_config(iterations: usize, max_depth: usize) -> SearchConfig {
    SearchConfig::builder()
        .iterations(iterations)
        .max_depth(max_depth)
        .heuristic_weight(1.0)
        .build()
}
/// With the target three units above the start, search should choose `Inc`.
#[test]
fn finds_correct_direction() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree = TreeSearch::with_seed(start, number_line_search_config(2_000, 10), 42);
    let chosen = tree.run();
    assert_eq!(chosen, Some(Move::Inc));
}
/// With the target three units below the start, search should choose `Dec`.
#[test]
fn finds_negative_direction() {
    let start = NumberGame {
        value: 0,
        target: -3,
        move_count: 0,
    };
    let mut tree = TreeSearch::with_seed(start, number_line_search_config(2_000, 10), 42);
    assert_eq!(tree.run(), Some(Move::Dec));
}
/// An environment with no legal actions leaves the search nothing to recommend.
#[test]
fn no_actions_returns_none() {
    #[derive(Clone)]
    struct DeadEnd;

    #[derive(Clone, Debug, PartialEq)]
    struct Noop;

    impl Environment for DeadEnd {
        type Action = Noop;

        fn legal_actions(&self) -> Vec<Noop> {
            Vec::new()
        }

        fn apply(&mut self, _action: &Noop) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }
    }

    let mut search = TreeSearch::new(DeadEnd, SearchConfig::default());
    let picked = search.run();
    assert!(picked.is_none());
}
/// After a 500-iteration search the root reports two entries, one per legal
/// `NumberGame` move. Their visit counts should sum to at least one and to no
/// more than the iteration budget.
#[test]
fn root_stats_populated() {
    let game = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let config = SearchConfig::builder().iterations(500).build();
    let mut search = TreeSearch::with_seed(game, config, 99);
    search.run();

    let stats = search.root_stats();
    // `NumberGame::legal_actions` always returns exactly `Inc` and `Dec`.
    assert_eq!(stats.len(), 2);

    let total_visits: u32 = stats.iter().map(|(_, s)| s.visits).sum();
    assert!(total_visits > 0);
    assert!(
        total_visits <= 500,
        "child visits {total_visits} exceed the 500-iteration budget"
    );
}
/// One hundred iterations should grow the tree to more than three nodes.
#[test]
fn tree_grows_with_iterations() {
    let game = NumberGame {
        value: 0,
        target: 5,
        move_count: 0,
    };
    let mut search = TreeSearch::with_seed(
        game,
        SearchConfig::builder().iterations(100).max_depth(8).build(),
        7,
    );
    search.run();
    let node_count = search.tree_size();
    assert!(node_count > 3);
}
/// Two searches built from the same game, config and seed must behave identically.
///
/// The recommended action alone is a weak signal, because this game is easy
/// enough that most seeds agree on `Inc`. The test therefore also compares the
/// resulting tree size and the per-child visit distribution at the root.
#[test]
fn deterministic_with_same_seed() {
    let game = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let config = number_line_search_config(1_000, 10);
    let mut s1 = TreeSearch::with_seed(game.clone(), config.clone(), 42);
    let mut s2 = TreeSearch::with_seed(game, config, 42);
    assert_eq!(s1.run(), s2.run());

    assert_eq!(s1.tree_size(), s2.tree_size());
    let visits1: Vec<u32> = s1.root_stats().iter().map(|(_, s)| s.visits).collect();
    let visits2: Vec<u32> = s2.root_stats().iter().map(|(_, s)| s.visits).collect();
    assert_eq!(visits1, visits2);
}
/// Without a time budget, `run` should perform exactly the configured 200 simulations.
#[test]
fn total_simulations_matches_iterations() {
    let game = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut search =
        TreeSearch::with_seed(game, SearchConfig::builder().iterations(200).build(), 1);
    search.run();
    let performed = search.total_simulations();
    assert_eq!(performed, 200);
}
/// Under PUCT, the heavily favoured prior on action `1` should make it the best root action.
#[test]
fn puct_policy_uses_priors() {
    let policy = TreePolicy::Puct { prior_weight: 2.0 };
    let config = SearchConfig::builder()
        .iterations(100)
        .tree_policy(policy)
        .build();
    let mut search = TreeSearch::with_seed(PriorGame { state: 0 }, config, 12);
    search.run();
    assert_eq!(search.best_root_action(), Some(1));
}
/// Thompson sampling completes a search on `PriorGame` and, with this seed, recommends `1`.
#[test]
fn thompson_policy_runs() {
    let policy = TreePolicy::ThompsonSampling { temperature: 0.5 };
    let config = SearchConfig::builder()
        .iterations(400)
        .tree_policy(policy)
        .build();
    let mut search = TreeSearch::with_seed(PriorGame { state: 0 }, config, 33);
    let recommendation = search.run();
    assert_eq!(recommendation, Some(1));
    assert!(search.total_simulations() > 0);
}
/// With a zero widening coefficient and a floor of one child, the root should
/// expand exactly one child even though two actions are legal.
#[test]
fn uses_progressive_widening_limit() {
    let widening = ProgressiveWideningConfig {
        minimum_children: 1,
        coefficient: 0.0,
        exponent: 1.0,
    };
    let config = SearchConfig::builder()
        .iterations(20)
        .progressive_widening(widening)
        .build();
    let mut search = TreeSearch::with_seed(PriorGame { state: 0 }, config, 5);
    search.run();
    let root_children = search.nodes[0].children.len();
    assert_eq!(root_children, 1);
}
/// Restoring from a checkpoint reproduces a tree of the same size as the original.
#[test]
fn checkpoint_restores_progress() {
    let game = NumberGame {
        value: 0,
        target: 4,
        move_count: 0,
    };
    let config = SearchConfig::builder().iterations(20).build();
    let mut original = TreeSearch::with_seed(game, config, 11);
    original.run();
    let snapshot = original.checkpoint();
    let restored = TreeSearch::restore(snapshot);
    assert_eq!(original.tree_size(), restored.tree_size());
}
/// A RAVE config with `enabled: false` must leave RAVE switched off.
/// Only the disabled side of the toggle is exercised here.
#[test]
fn uses_uct_rave_toggle() {
    let game = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let rave_off = crate::config::RaveConfig {
        enabled: false,
        bias: 1.0,
    };
    let config = SearchConfig::builder().iterations(50).rave(rave_off).build();
    let search = TreeSearch::with_seed(game, config, 1);
    assert!(!search.uses_rave());
}
/// A checkpoint taken after searching restores into a tree that still reports simulations.
#[test]
fn checkpoint_roundtrip() {
    let config = SearchConfig::builder().iterations(80).build();
    let mut search = TreeSearch::with_seed(PriorGame { state: 0 }, config, 1);
    search.run();
    let restored: TreeSearch<PriorGame> = TreeSearch::restore(search.checkpoint());
    assert!(restored.total_simulations() > 0);
}
/// The Gumbel tree policy should also steer toward a target above the start.
#[test]
fn gumbel_policy_finds_correct_direction() {
    let game = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let gumbel = TreePolicy::Gumbel {
        sampled_actions: 16,
        max_completions_coeff: 50.0,
    };
    let config = SearchConfig::builder()
        .iterations(2_000)
        .max_depth(10)
        .heuristic_weight(1.0)
        .tree_policy(gumbel)
        .build();
    let mut search = TreeSearch::with_seed(game, config, 42);
    let chosen = search.run();
    assert_eq!(chosen, Some(Move::Inc));
}
/// The Gumbel policy must still run exactly the configured number of simulations.
#[test]
fn gumbel_policy_total_simulations_correct() {
    let game = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let gumbel = TreePolicy::Gumbel {
        sampled_actions: 8,
        max_completions_coeff: 50.0,
    };
    let config = SearchConfig::builder()
        .iterations(500)
        .tree_policy(gumbel)
        .build();
    let mut search = TreeSearch::with_seed(game, config, 7);
    search.run();
    let performed = search.total_simulations();
    assert_eq!(performed, 500);
}
/// Enabling the DAG, searching, and disabling it again should keep the hit
/// counter at zero for `NumberGame`.
///
/// `NumberGame` does not override `state_hash`, so presumably (via the trait's
/// default) no transpositions can be detected for it.
#[cfg(feature = "dag")]
#[test]
fn dag_enable_disable_lifecycle() {
    let game = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let config = SearchConfig::builder().iterations(100).build();
    let mut search = TreeSearch::with_seed(game, config, 1);
    assert_eq!(search.dag_hit_count(), 0);

    search.enable_dag();
    search.run();
    // `assert_eq!` rather than `assert!(.. == 0)` so a failure reports the actual count.
    assert_eq!(search.dag_hit_count(), 0);

    search.disable_dag();
    assert_eq!(search.dag_hit_count(), 0);
}
/// Number-line environment like `NumberGame`, but with no move counter and with
/// `Hash`/`Eq` derived.
///
/// Its `state_hash` hashes the whole struct, so states with equal `value` and
/// `target` produce equal hashes whatever path reached them. This is what the
/// DAG/transposition tests rely on.
#[derive(Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
struct HashableGame {
    // Current position on the number line.
    value: i32,
    // Position that yields `Outcome::Success` when reached.
    target: i32,
}
impl Environment for HashableGame {
    type Action = Move;

    /// Both directions are legal from every position.
    fn legal_actions(&self) -> Vec<Move> {
        vec![Move::Inc, Move::Dec]
    }

    /// Shifts `value` one unit in the chosen direction.
    fn apply(&mut self, action: &Move) {
        self.value += match action {
            Move::Inc => 1,
            Move::Dec => -1,
        };
    }

    /// Success on the target, failure once more than 15 units away, otherwise ongoing.
    fn evaluate(&self) -> Outcome {
        match (self.value - self.target).abs() {
            0 => Outcome::Success(Reward::WIN),
            gap if gap > 15 => Outcome::Failure,
            _ => Outcome::Ongoing,
        }
    }

    /// Hashes the full state with the standard library's default hasher.
    fn state_hash(&self) -> Option<u64> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut digest = DefaultHasher::new();
        Hash::hash(self, &mut digest);
        Some(digest.finish())
    }
}
/// With transposition detection enabled, the same search should build a tree no
/// larger than the plain tree, and should record at least one DAG hit.
#[cfg(feature = "dag")]
#[test]
fn dag_transposition_reduces_tree_size() {
    let game = HashableGame {
        value: 0,
        target: 3,
    };
    let config = SearchConfig::builder()
        .iterations(500)
        .max_depth(10)
        .build();

    let mut baseline = TreeSearch::with_seed(game.clone(), config.clone(), 42);
    baseline.run();
    let baseline_size = baseline.tree_size();

    let mut with_dag = TreeSearch::with_seed(game, config, 42);
    with_dag.enable_dag();
    with_dag.run();

    assert!(with_dag.tree_size() <= baseline_size);
    assert!(with_dag.dag_hit_count() > 0);
}
/// A four-thread parallel search should still choose `Inc` toward a target above the start.
#[cfg(feature = "parallel")]
#[test]
fn parallel_search_finds_correct_direction() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree = TreeSearch::with_seed(start, number_line_search_config(2_000, 10), 42);
    assert_eq!(tree.run_parallel(4), Some(Move::Inc));
}
/// A four-thread parallel search should still choose `Dec` toward a target below the start.
#[cfg(feature = "parallel")]
#[test]
fn parallel_search_negative_direction() {
    let start = NumberGame {
        value: 0,
        target: -3,
        move_count: 0,
    };
    let mut tree = TreeSearch::with_seed(start, number_line_search_config(2_000, 10), 42);
    let chosen = tree.run_parallel(4);
    assert_eq!(chosen, Some(Move::Dec));
}
/// `run_parallel(1)` should agree with the serial `run` for the same game, config and seed.
#[cfg(feature = "parallel")]
#[test]
fn parallel_with_single_thread_matches_serial() {
    let game = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let config = number_line_search_config(1_000, 10);

    let serial_choice = TreeSearch::with_seed(game.clone(), config.clone(), 42).run();
    let threaded_choice = TreeSearch::with_seed(game, config, 42).run_parallel(1);

    assert_eq!(serial_choice, threaded_choice);
}
/// After searching, the principal variation is non-empty and starts with `Inc`.
#[test]
fn principal_variation_nonempty_after_search() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree = TreeSearch::with_seed(start, number_line_search_config(500, 10), 42);
    tree.run();
    let line = tree.principal_variation();
    assert!(!line.is_empty());
    assert_eq!(line.first(), Some(&Move::Inc));
}
/// Before any search has run, there is no principal variation to report.
#[test]
fn principal_variation_empty_before_search() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let untouched = TreeSearch::with_seed(start, SearchConfig::default(), 42);
    assert!(untouched.principal_variation().is_empty());
}
/// Once a search has run, the best root child has a reward to report.
#[test]
fn best_root_reward_some_after_search() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree =
        TreeSearch::with_seed(start, SearchConfig::builder().iterations(200).build(), 42);
    tree.run();
    assert!(tree.best_root_reward().is_some());
}
/// Before any search has run, there is no best root reward.
#[test]
fn best_root_reward_none_before_search() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let untouched = TreeSearch::with_seed(start, SearchConfig::default(), 42);
    assert!(untouched.best_root_reward().is_none());
}
/// Each call to `run_step` adds exactly one simulation to the running total.
#[test]
fn run_step_increments_simulations() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree =
        TreeSearch::with_seed(start, SearchConfig::builder().iterations(100).build(), 42);
    assert_eq!(tree.total_simulations(), 0);
    for expected in 1..=2 {
        tree.run_step();
        assert_eq!(tree.total_simulations(), expected);
    }
}
/// Driving the search manually with 2 000 `run_step` calls should reach the same
/// conclusion as `run`: move toward the target with `Inc`.
#[test]
fn run_step_accumulates_to_correct_result() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree = TreeSearch::with_seed(start, number_line_search_config(2_000, 10), 42);
    (0..2_000).for_each(|_| {
        tree.run_step();
    });
    assert_eq!(tree.best_root_action(), Some(Move::Inc));
}
/// A 5 ms wall-clock budget should end the search long before ten million
/// iterations, but only after at least one simulation.
#[test]
fn time_budget_stops_early() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let budget = std::time::Duration::from_millis(5);
    let config = SearchConfig::builder()
        .iterations(10_000_000)
        .time_budget(budget)
        .build();
    let mut tree = TreeSearch::with_seed(start, config, 42);
    tree.run();
    let performed = tree.total_simulations();
    assert!(performed < 10_000_000);
    assert!(performed > 0);
}
/// Evaluator stub that scores every `NumberGame` state with the same fixed reward.
struct ConstantEvaluator(f64);

impl crate::Evaluator<NumberGame> for ConstantEvaluator {
    fn evaluate(&self, _env: &NumberGame) -> crate::Reward {
        let &Self(score) = self;
        crate::Reward::new(score)
    }
}
/// Plugging in an evaluator that always returns 0.9 should lift the best root
/// reward above 0.5.
#[test]
fn evaluator_replaces_rollout() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree =
        TreeSearch::with_seed(start, SearchConfig::builder().iterations(200).build(), 42);
    let evaluator = std::sync::Arc::new(ConstantEvaluator(0.9));
    tree.with_evaluator(evaluator);
    tree.run();
    let reward = tree.best_root_reward().unwrap();
    assert!(
        reward > 0.5,
        "Evaluator should influence reward: got {reward}"
    );
}
/// A 50-node cap must hold even when the iteration budget would grow the tree much further.
#[test]
fn max_nodes_limits_tree_growth() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree =
        TreeSearch::with_seed(start, SearchConfig::builder().iterations(5_000).build(), 42);
    tree.with_max_nodes(50);
    tree.run();
    let size = tree.tree_size();
    assert!(
        size <= 50,
        "Tree should not exceed max_nodes limit: got {}",
        size
    );
}
/// The state sequence along the principal variation begins at the root state
/// (`value == 0`) and includes at least one successor.
#[test]
fn principal_variation_states_starts_with_root() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree =
        TreeSearch::with_seed(start, SearchConfig::builder().iterations(500).build(), 42);
    tree.run();
    let path = tree.principal_variation_states();
    assert!(path.len() >= 2, "PV states should have >= 2 elements");
    let root_state = &path[0];
    assert_eq!(root_state.value, 0, "First state should be the root");
}
/// `run_until` stops soon after its predicate becomes true, well short of the
/// 100 000-iteration budget.
#[test]
fn run_until_stops_on_predicate() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree =
        TreeSearch::with_seed(start, SearchConfig::builder().iterations(100_000).build(), 42);
    tree.run_until(|progress| progress.total_simulations() >= 100);
    let performed = tree.total_simulations();
    assert!(performed >= 100);
    assert!(performed < 1_000, "Should have stopped early");
}
/// The Graphviz export opens a `digraph mctrust` block, names the root node
/// `n0`, and ends with a closing brace and newline.
#[test]
fn export_dot_produces_valid_dot() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree =
        TreeSearch::with_seed(start, SearchConfig::builder().iterations(50).build(), 42);
    tree.run();
    let graph = tree.export_dot(3);
    assert!(graph.starts_with("digraph mctrust {"));
    assert!(graph.contains("n0"));
    assert!(graph.ends_with("}\n"));
}
/// Advancing along the best move re-roots the tree at that child. The tree must
/// not grow. The new root keeps the child's visit count and loses its parent
/// link and incoming action.
#[test]
fn advance_to_action_preserves_subtree() {
    let start = NumberGame {
        value: 0,
        target: 3,
        move_count: 0,
    };
    let mut tree =
        TreeSearch::with_seed(start, SearchConfig::builder().iterations(500).build(), 42);
    tree.run();

    let size_before = tree.tree_size();
    assert!(size_before > 1);

    let best = tree.principal_variation()[0].clone();
    let child_id = tree.best_root_child_id().unwrap();
    let child_visits = tree.nodes[child_id as usize].visits;

    assert!(tree.advance_to_action(&best));

    let size_after = tree.tree_size();
    assert!(size_after <= size_before);
    assert!(size_after > 0);

    let new_root = &tree.nodes[0];
    assert_eq!(new_root.visits, child_visits);
    assert!(new_root.parent.is_none());
    assert!(new_root.action.is_none());
}
/// Advancing on a tree that has never been searched must fail. Nothing has been
/// expanded under the root yet, so `Move::Inc` is not a known child.
///
/// A failed advance should also be a no-op: the tree size and the root
/// environment stay as they were.
#[test]
fn advance_to_unknown_action_returns_false() {
    let mut fresh = TreeSearch::with_seed(
        NumberGame {
            value: 0,
            target: 3,
            move_count: 0,
        },
        SearchConfig::builder().iterations(0).build(),
        1,
    );
    let size_before = fresh.tree_size();

    assert!(!fresh.advance_to_action(&Move::Inc));

    assert_eq!(fresh.tree_size(), size_before);
    assert_eq!(fresh.root_env.value, 0);
}
/// Advancing along the principal variation applies that move to the root
/// environment: `Inc` moves `value` to 1 and `Dec` moves it to -1.
#[test]
fn advance_updates_root_env() {
    let game = NumberGame {
        value: 0,
        target: 5,
        move_count: 0,
    };
    let config = SearchConfig::builder().iterations(500).build();
    let mut search = TreeSearch::with_seed(game, config, 42);
    search.run();

    let best = search.principal_variation()[0].clone();
    // If the advance silently failed, the root_env check below would inspect a
    // stale environment and report a misleading value mismatch, so check it first.
    assert!(
        search.advance_to_action(&best),
        "advancing along the principal variation must succeed"
    );

    let expected = match best {
        Move::Inc => 1,
        Move::Dec => -1,
    };
    assert_eq!(search.root_env.value, expected);
}