use mctrust::*;
use std::sync::Arc;
use std::time::Duration;
/// A bandit search with no registered arms: asserts that no arm is offered, that the
/// group statistics are empty, and that the pull counter reads zero.
#[test]
fn bandit_empty_search_returns_none() {
    let mut bandit = BanditSearch::new(BanditConfig::default());

    assert_eq!(bandit.next_arm(), None);
    assert!(bandit.group_stats().is_empty());
    assert_eq!(bandit.total_pulls(), 0);
}
/// Requests a batch of ten arms from a search that has none and asserts the batch is empty.
#[test]
fn bandit_empty_next_arms() {
    let mut bandit = BanditSearch::new(BanditConfig::default());

    let batch = bandit.next_arms(10);

    assert!(batch.is_empty());
}
/// Reports a reward for an arm id on a search with no arms and asserts that the group
/// statistics remain empty afterwards.
#[test]
fn bandit_empty_observe_does_nothing() {
    let mut bandit = BanditSearch::new(BanditConfig::default());

    // No arm was ever added, so arm id 0 is unknown to the search.
    bandit.observe(0, 1.0);

    assert!(bandit.group_stats().is_empty());
}
/// Runs a tree search over an environment that never offers a legal action and asserts
/// that `run`, `best_root_action` and `best_root_reward` all come back empty.
#[test]
fn treesearch_empty_actions_returns_none() {
    #[derive(Clone, Debug, PartialEq)]
    struct NoAction;

    #[derive(Clone)]
    struct EmptyEnv;

    impl Environment for EmptyEnv {
        type Action = NoAction;

        // The action list is always empty.
        fn legal_actions(&self) -> Vec<NoAction> {
            Vec::new()
        }

        fn apply(&mut self, _: &NoAction) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }
    }

    let config = SearchConfig::default();
    let mut tree = TreeSearch::new(EmptyEnv, config);

    assert_eq!(tree.run(), None);
    assert!(tree.best_root_action().is_none());
    assert!(tree.best_root_reward().is_none());
}
/// Runs ten iterations over an environment with exactly one move and asserts that this
/// move is the one returned by `run`.
#[test]
fn treesearch_single_action_always_returns_it() {
    #[derive(Clone, Debug, PartialEq)]
    struct OneAction;

    // The flag is `false` initially and flips to `true` once the move is applied.
    #[derive(Clone)]
    struct SingleEnv(bool);

    impl Environment for SingleEnv {
        type Action = OneAction;

        fn legal_actions(&self) -> Vec<OneAction> {
            // Once the flag is set there is nothing left to play.
            if self.0 {
                return Vec::new();
            }
            vec![OneAction]
        }

        fn apply(&mut self, _: &OneAction) {
            self.0 = true;
        }

        fn evaluate(&self) -> Outcome {
            let Self(played) = *self;
            if played {
                return Outcome::Neutral;
            }
            Outcome::Ongoing
        }
    }

    let config = SearchConfig::builder().iterations(10).build();
    let mut tree = TreeSearch::new(SingleEnv(false), config);

    let chosen = tree.run();

    assert_eq!(chosen, Some(OneAction));
}
/// Registers a single arm and asserts that it is offered once, that the next request
/// yields nothing, and that exactly one pull has been counted.
#[test]
fn bandit_single_arm_exhausts_after_one_pull() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);

    let first = bandit.next_arm();
    assert_eq!(first, Some(0));

    let second = bandit.next_arm();
    assert_eq!(second, None);

    assert_eq!(bandit.total_pulls(), 1);
}
/// Puts 1000 arms into one group, pulls them all in a single batch, and asserts the batch
/// size, the pull counter, and that no further arm is offered.
#[test]
fn bandit_single_group_many_arms() {
    const ARM_COUNT: usize = 1000;

    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..ARM_COUNT as u64 {
        // Every arm goes into group 0.
        bandit.add_arm(arm_id, 0);
    }

    let pulled = bandit.next_arms(ARM_COUNT);

    assert_eq!(pulled.len(), ARM_COUNT);
    assert_eq!(bandit.total_pulls(), 1000);
    assert_eq!(bandit.next_arm(), None);
}
/// Builds a config with `max_pulls(0)`, registers five arms, and asserts that a batch
/// request for five arms returns all five.
#[test]
fn bandit_zero_max_pulls_means_unlimited() {
    let config = BanditConfig::builder().max_pulls(0).build();
    let mut bandit = BanditSearch::new_seeded(config, 42);

    for arm_id in 0..5u64 {
        bandit.add_arm(arm_id, 0);
    }

    let pulled = bandit.next_arms(5);
    assert_eq!(pulled.len(), 5);
}
/// Builds a config with an exploration constant of `0.0`, registers ten arms, and asserts
/// that an arm is still offered.
#[test]
fn bandit_zero_exploration_constant() {
    let config = BanditConfig::builder().exploration_constant(0.0).build();
    let mut bandit = BanditSearch::new_seeded(config, 42);

    for arm_id in 0..10u64 {
        bandit.add_arm(arm_id, 0);
    }

    assert!(bandit.next_arm().is_some());
}
/// Runs a search whose config has `iterations` set to zero and asserts that the
/// simulation counter stays at zero.
#[test]
fn treesearch_zero_iterations() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct QuickEnv;

    impl Environment for QuickEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Neutral
        }
    }

    // The field is overwritten directly on a default config instead of using the builder.
    let mut config = SearchConfig::default();
    config.iterations = 0;

    let mut tree = TreeSearch::new(QuickEnv, config);
    tree.run();

    assert_eq!(tree.total_simulations(), 0);
}
/// Runs ten iterations with `max_depth(0)` over an environment that supplies a heuristic
/// and asserts that ten simulations are counted.
#[test]
fn treesearch_zero_max_depth() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct QuickEnv;

    impl Environment for QuickEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        // Never terminal.
        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }

        fn heuristic(&self) -> Heuristic {
            let estimate = Reward::new(0.5);
            Heuristic::from_reward(estimate)
        }
    }

    let config = SearchConfig::builder().iterations(10).max_depth(0).build();
    let mut tree = TreeSearch::new(QuickEnv, config);
    tree.run();

    assert_eq!(tree.total_simulations(), 10);
}
/// Configures one million iterations together with a zero-millisecond time budget and
/// asserts that no simulation is counted after `run`.
#[test]
fn treesearch_zero_time_budget_stops_immediately() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct QuickEnv;

    impl Environment for QuickEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Neutral
        }
    }

    let no_time = Duration::from_millis(0);
    let config = SearchConfig::builder()
        .iterations(1_000_000)
        .time_budget(no_time)
        .build();

    let mut tree = TreeSearch::new(QuickEnv, config);
    tree.run();

    assert_eq!(tree.total_simulations(), 0);
}
/// Registers an arm whose id is `u64::MAX` and asserts that the same id is offered back.
#[test]
fn bandit_u64_max_arm_id() {
    let largest_id = u64::MAX;
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);

    bandit.add_arm(largest_id, 0);

    assert_eq!(bandit.next_arm(), Some(largest_id));
}
/// Registers one arm in group `u32::MAX` and asserts that the statistics contain exactly
/// one group carrying that id.
#[test]
fn bandit_u32_max_group_id() {
    let largest_group = u32::MAX;
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);

    bandit.add_arm(0, largest_group);

    let groups = bandit.group_stats();
    assert_eq!(groups.len(), 1);
    assert_eq!(groups[0].group_id, largest_group);
}
/// Builds a config with `max_pulls(u64::MAX)`, registers five arms, requests a batch of
/// `usize::MAX` arms, and asserts that exactly five are returned.
#[test]
fn bandit_max_pulls_u64_max() {
    let config = BanditConfig::builder().max_pulls(u64::MAX).build();
    let mut bandit = BanditSearch::new_seeded(config, 42);

    for arm_id in 0..5u64 {
        bandit.add_arm(arm_id, 0);
    }

    let pulled = bandit.next_arms(usize::MAX);
    assert_eq!(pulled.len(), 5);
}
/// Configures `usize::MAX` iterations with a 5 ms time budget over a never-terminal
/// environment and asserts that the simulation count is above zero and below `u32::MAX`.
#[test]
fn treesearch_huge_iterations_with_time_budget() {
    // The wrapped counter grows by the applied action value on each step.
    #[derive(Clone)]
    struct EndlessEnv(i32);

    impl Environment for EndlessEnv {
        type Action = i32;

        fn legal_actions(&self) -> Vec<i32> {
            vec![1]
        }

        fn apply(&mut self, step: &i32) {
            self.0 += *step;
        }

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }
    }

    let budget = Duration::from_millis(5);
    let config = SearchConfig::builder()
        .iterations(usize::MAX)
        .time_budget(budget)
        .build();

    let mut tree = TreeSearch::new(EndlessEnv(0), config);
    tree.run();

    let simulations = tree.total_simulations();
    assert!(simulations > 0);
    assert!(simulations < u32::MAX);
}
/// Calls `with_max_nodes(0)` before running 100 iterations and asserts that the tree
/// size is 1 afterwards.
#[test]
fn treesearch_max_nodes_zero_prevents_expansion() {
    // The wrapped counter grows by the applied action value on each step.
    #[derive(Clone)]
    struct EndlessEnv(i32);

    impl Environment for EndlessEnv {
        type Action = i32;

        fn legal_actions(&self) -> Vec<i32> {
            vec![1]
        }

        fn apply(&mut self, step: &i32) {
            self.0 += *step;
        }

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }
    }

    let config = SearchConfig::builder().iterations(100).build();
    let mut tree = TreeSearch::new(EndlessEnv(0), config);

    tree.with_max_nodes(0);
    tree.run();

    assert_eq!(tree.tree_size(), 1);
}
/// Pulls the only arm, reports a NaN reward for it, and asserts that the group's visit
/// count is still zero.
#[test]
fn bandit_nan_reward_is_rejected() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);

    let pulled = bandit.next_arm().unwrap();
    bandit.observe(pulled, f64::NAN);

    let groups = bandit.group_stats();
    assert_eq!(groups[0].visits, 0);
}
/// Pulls the only arm, reports a positive-infinity reward for it, and asserts that the
/// group's visit count is still zero.
#[test]
fn bandit_infinity_reward_is_rejected() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);

    let pulled = bandit.next_arm().unwrap();
    bandit.observe(pulled, f64::INFINITY);

    let groups = bandit.group_stats();
    assert_eq!(groups[0].visits, 0);
}
/// Pulls the only arm, reports a negative-infinity reward for it, and asserts that the
/// group's visit count is still zero.
#[test]
fn bandit_neg_infinity_reward_is_rejected() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);

    let pulled = bandit.next_arm().unwrap();
    bandit.observe(pulled, f64::NEG_INFINITY);

    let groups = bandit.group_stats();
    assert_eq!(groups[0].visits, 0);
}
/// Adds a finite reward to a NaN reward and asserts that the resulting value is NaN.
#[test]
fn reward_arithmetic_with_nan() {
    let finite = Reward::new(0.5);
    let not_a_number = Reward::new(f64::NAN);

    let total = finite + not_a_number;

    assert!(total.value().is_nan());
}
/// Adds a finite reward to an infinite reward and asserts that the resulting value is
/// infinite.
#[test]
fn reward_arithmetic_with_infinity() {
    let finite = Reward::new(0.5);
    let unbounded = Reward::new(f64::INFINITY);

    let total = finite + unbounded;

    assert!(total.value().is_infinite());
}
/// Runs ten iterations with `max_depth(0)` over an environment whose heuristic is built
/// from a NaN reward and asserts that ten simulations are counted.
#[test]
fn treesearch_nan_heuristic_weighted() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct NanHeuristicEnv;

    impl Environment for NanHeuristicEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }

        fn heuristic(&self) -> Heuristic {
            let estimate = Reward::new(f64::NAN);
            Heuristic::from_reward(estimate)
        }
    }

    let config = SearchConfig::builder().iterations(10).max_depth(0).build();
    let mut tree = TreeSearch::new(NanHeuristicEnv, config);
    tree.run();

    assert_eq!(tree.total_simulations(), 10);
}
/// Passes `-1.0` as the bandit exploration constant and asserts that the built config
/// holds a non-negative value.
#[test]
fn bandit_negative_exploration_constant_gets_sanitized() {
    let built = BanditConfig::builder().exploration_constant(-1.0).build();

    assert!(0.0 <= built.exploration_constant);
}
/// Passes `-1.0` as the bandit RAVE bias and asserts that the built config holds a
/// non-negative value.
#[test]
fn bandit_negative_rave_bias_gets_sanitized() {
    let built = BanditConfig::builder().rave_bias(-1.0).build();

    assert!(0.0 <= built.rave_bias);
}
/// Passes `-1.0` as the tree-search exploration constant and asserts that the built
/// config holds a non-negative value.
#[test]
fn searchconfig_negative_exploration_constant_gets_sanitized() {
    let built = SearchConfig::builder().exploration_constant(-1.0).build();

    assert!(0.0 <= built.exploration_constant);
}
/// Passes `-0.5` as the heuristic weight and asserts that the built config holds a value
/// inside the closed interval from 0 to 1.
#[test]
fn searchconfig_negative_heuristic_weight_gets_clamped() {
    let built = SearchConfig::builder().heuristic_weight(-0.5).build();

    // The range check is false for NaN, just like the pair of comparisons it stands for.
    assert!((0.0..=1.0).contains(&built.heuristic_weight));
}
/// Runs ten iterations over an environment that always evaluates as a terminal win and
/// asserts that ten simulations are counted.
#[test]
fn treesearch_immediate_terminal() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct TerminalEnv;

    impl Environment for TerminalEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        // Every state, including the initial one, is reported as a terminal win.
        fn evaluate(&self) -> Outcome {
            Outcome::Terminal(Reward::WIN)
        }
    }

    let config = SearchConfig::builder().iterations(10).build();
    let mut tree = TreeSearch::new(TerminalEnv, config);
    tree.run();

    assert_eq!(tree.total_simulations(), 10);
}
/// Runs 100 iterations with `max_depth(5)` over an environment that never terminates and
/// asserts that 100 simulations are counted.
#[test]
fn treesearch_infinite_loop_environment() {
    // The counter moves up or down by one on each step, so states can repeat.
    #[derive(Clone)]
    struct LoopEnv(i32);

    impl Environment for LoopEnv {
        type Action = i32;

        fn legal_actions(&self) -> Vec<i32> {
            vec![1, -1]
        }

        fn apply(&mut self, step: &i32) {
            self.0 += *step;
        }

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }
    }

    let config = SearchConfig::builder().iterations(100).max_depth(5).build();
    let mut tree = TreeSearch::new(LoopEnv(0), config);
    tree.run();

    assert_eq!(tree.total_simulations(), 100);
}
/// Runs 50 iterations over an environment that becomes a terminal win with no legal
/// actions after its first move and asserts that 50 simulations are counted.
#[test]
fn treesearch_always_no_actions_after_first_move() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    // The flag is `false` initially and flips to `true` once the move is applied.
    #[derive(Clone)]
    struct OneMoveEnv(bool);

    impl Environment for OneMoveEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<A> {
            if self.0 {
                return Vec::new();
            }
            vec![A]
        }

        fn apply(&mut self, _: &A) {
            self.0 = true;
        }

        fn evaluate(&self) -> Outcome {
            let Self(moved) = *self;
            if moved {
                return Outcome::Terminal(Reward::WIN);
            }
            Outcome::Ongoing
        }
    }

    let config = SearchConfig::builder().iterations(50).build();
    let mut tree = TreeSearch::new(OneMoveEnv(false), config);
    tree.run();

    assert_eq!(tree.total_simulations(), 50);
}
/// Runs a seeded 200-iteration search over a two-player environment and asserts that
/// some action is returned.
#[test]
fn treesearch_adversarial_two_player_flip() {
    #[derive(Clone, Debug, PartialEq)]
    enum FlipAction {
        Left,
        Right,
    }

    // The wrapped position moves by -1 for `Left` and +1 for `Right`.
    #[derive(Clone)]
    struct FlipEnv(i32);

    impl Environment for FlipEnv {
        type Action = FlipAction;

        fn legal_actions(&self) -> Vec<FlipAction> {
            vec![FlipAction::Left, FlipAction::Right]
        }

        fn apply(&mut self, action: &FlipAction) {
            let delta = match action {
                FlipAction::Left => -1,
                FlipAction::Right => 1,
            };
            self.0 += delta;
        }

        // Reaching a distance of five from the origin ends the game as a win.
        fn evaluate(&self) -> Outcome {
            let distance = self.0.abs();
            if distance < 5 {
                Outcome::Ongoing
            } else {
                Outcome::Terminal(Reward::WIN)
            }
        }

        // The player index is the parity of the distance from the origin.
        fn current_player(&self) -> usize {
            self.0.abs() as usize % 2
        }

        fn num_players(&self) -> usize {
            2
        }
    }

    let config = SearchConfig::builder().iterations(200).build();
    let mut tree = TreeSearch::with_seed(FlipEnv(0), config, 42);

    assert!(tree.run().is_some());
}
/// Runs a seeded PUCT search over an environment whose only prior is NaN and asserts
/// that ten simulations are counted.
#[test]
fn treesearch_puct_with_malformed_priors() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct BadPriorEnv;

    impl Environment for BadPriorEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Neutral
        }

        // One prior for the one action, but its value is NaN.
        fn action_priors(&self, _actions: &[A]) -> Option<Vec<f64>> {
            let priors = vec![f64::NAN];
            Some(priors)
        }
    }

    let policy = TreePolicy::Puct { prior_weight: 1.0 };
    let config = SearchConfig::builder()
        .iterations(10)
        .tree_policy(policy)
        .build();

    let mut tree = TreeSearch::with_seed(BadPriorEnv, config, 42);
    tree.run();

    assert_eq!(tree.total_simulations(), 10);
}
/// Runs a seeded PUCT search over an environment that returns an empty prior list for
/// its single action and asserts that ten simulations are counted.
#[test]
fn treesearch_puct_with_empty_priors() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct EmptyPriorEnv;

    impl Environment for EmptyPriorEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Neutral
        }

        // Zero priors although one action is legal.
        fn action_priors(&self, _actions: &[A]) -> Option<Vec<f64>> {
            Some(Vec::new())
        }
    }

    let policy = TreePolicy::Puct { prior_weight: 1.0 };
    let config = SearchConfig::builder()
        .iterations(10)
        .tree_policy(policy)
        .build();

    let mut tree = TreeSearch::with_seed(EmptyPriorEnv, config, 42);
    tree.run();

    assert_eq!(tree.total_simulations(), 10);
}
/// Runs a seeded PUCT search over an environment that returns four priors for its single
/// action and asserts that ten simulations are counted.
#[test]
fn treesearch_puct_with_too_many_priors() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct ExtraPriorEnv;

    impl Environment for ExtraPriorEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Neutral
        }

        // Four priors although only one action is legal.
        fn action_priors(&self, _actions: &[A]) -> Option<Vec<f64>> {
            Some(vec![0.5; 4])
        }
    }

    let policy = TreePolicy::Puct { prior_weight: 1.0 };
    let config = SearchConfig::builder()
        .iterations(10)
        .tree_policy(policy)
        .build();

    let mut tree = TreeSearch::with_seed(ExtraPriorEnv, config, 42);
    tree.run();

    assert_eq!(tree.total_simulations(), 10);
}
/// Runs a seeded 50-iteration search whose actions are non-ASCII strings and asserts
/// that some action is returned.
#[test]
fn treesearch_unicode_actions() {
    #[derive(Clone, Debug, PartialEq)]
    struct UnicodeEnv;

    impl Environment for UnicodeEnv {
        type Action = String;

        fn legal_actions(&self) -> Vec<String> {
            ["α", "β", "γ"]
                .iter()
                .map(|glyph| glyph.to_string())
                .collect()
        }

        fn apply(&mut self, _: &String) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }
    }

    let config = SearchConfig::builder().iterations(50).build();
    let mut tree = TreeSearch::with_seed(UnicodeEnv, config, 42);

    assert!(tree.run().is_some());
}
/// Runs a seeded ten-iteration search whose two actions are 10000-character strings and
/// asserts that some action is returned.
#[test]
fn treesearch_very_long_action_strings() {
    #[derive(Clone, Debug, PartialEq)]
    struct LongStringEnv;

    impl Environment for LongStringEnv {
        type Action = String;

        fn legal_actions(&self) -> Vec<String> {
            let first = "a".repeat(10000);
            let second = "b".repeat(10000);
            vec![first, second]
        }

        fn apply(&mut self, _: &String) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }
    }

    let config = SearchConfig::builder().iterations(10).build();
    let mut tree = TreeSearch::with_seed(LongStringEnv, config, 42);

    assert!(tree.run().is_some());
}
/// Sets a NaN bias on both groups of a ten-arm search and asserts that an arm is still
/// offered.
#[test]
fn bandit_nan_bias_does_not_panic() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);

    // Arms 0..5 land in group 0 and arms 5..10 in group 1.
    for arm_id in 0..10u64 {
        let group = (arm_id / 5) as u32;
        bandit.add_arm(arm_id, group);
    }

    for group in 0..2 {
        bandit.set_group_bias(group, f64::NAN);
    }

    assert!(bandit.next_arm().is_some());
}
/// Sets an infinite bias on group 1 of a ten-arm search and asserts that the first arm
/// offered has an id of at least 5, i.e. belongs to group 1.
#[test]
fn bandit_infinity_bias_selects_that_group() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);

    // Arms 0..5 land in group 0 and arms 5..10 in group 1.
    for arm_id in 0..10u64 {
        let group = (arm_id / 5) as u32;
        bandit.add_arm(arm_id, group);
    }

    bandit.set_group_bias(1, f64::INFINITY);

    let pulled = bandit.next_arm().unwrap();
    assert!(5 <= pulled);
}
/// Sets a bias on a group id that has no arms and asserts that the only registered arm
/// is still the one offered.
#[test]
fn bandit_bias_on_nonexistent_group_is_noop() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    bandit.add_arm(0, 0);

    // Group 9999 was never populated; only group 0 has an arm.
    bandit.set_group_bias(9999, 1000.0);

    let pulled = bandit.next_arm();
    assert_eq!(pulled, Some(0));
}
/// Checkpoints a bandit search with no arms, restores it, and asserts that the restored
/// search reports zero pulls.
#[test]
fn bandit_checkpoint_empty_restore() {
    let original = BanditSearch::new(BanditConfig::default());

    let snapshot = original.checkpoint();
    let revived = BanditSearch::restore(snapshot);

    assert_eq!(revived.total_pulls(), 0);
}
/// Checkpoints a tree search before `run` is ever called, restores it, and asserts that
/// the restored tree has a size of 1.
#[test]
fn treesearch_checkpoint_before_any_simulations() {
    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct A;

    // The wrapped counter is incremented on every applied action.
    #[derive(Clone, serde::Serialize, serde::Deserialize)]
    struct SimpleEnv(i32);

    impl Environment for SimpleEnv {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {
            self.0 += 1;
        }

        fn evaluate(&self) -> Outcome {
            Outcome::Neutral
        }
    }

    let original = TreeSearch::new(SimpleEnv(0), SearchConfig::default());

    let snapshot = original.checkpoint();
    let revived = TreeSearch::restore(snapshot);

    assert_eq!(revived.tree_size(), 1);
}
/// Installs an evaluator that always returns a NaN reward, runs ten iterations, and
/// asserts that ten simulations are counted.
#[test]
fn treesearch_evaluator_returns_nan() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct Env;

    impl Environment for Env {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }
    }

    struct NanEval;

    impl Evaluator<Env> for NanEval {
        fn evaluate(&self, _: &Env) -> Reward {
            Reward::new(f64::NAN)
        }
    }

    let config = SearchConfig::builder().iterations(10).build();
    let mut tree = TreeSearch::new(Env, config);

    let evaluator = Arc::new(NanEval);
    tree.with_evaluator(evaluator);
    tree.run();

    assert_eq!(tree.total_simulations(), 10);
}
/// Installs an evaluator that always returns a positive-infinity reward, runs ten
/// iterations, and asserts that ten simulations are counted.
#[test]
fn treesearch_evaluator_returns_infinity() {
    #[derive(Clone, Debug, PartialEq)]
    struct A;

    #[derive(Clone)]
    struct Env;

    impl Environment for Env {
        type Action = A;

        fn legal_actions(&self) -> Vec<Self::Action> {
            vec![A]
        }

        fn apply(&mut self, _: &A) {}

        fn evaluate(&self) -> Outcome {
            Outcome::Ongoing
        }
    }

    struct InfEval;

    impl Evaluator<Env> for InfEval {
        fn evaluate(&self, _: &Env) -> Reward {
            Reward::new(f64::INFINITY)
        }
    }

    let config = SearchConfig::builder().iterations(10).build();
    let mut tree = TreeSearch::new(Env, config);

    let evaluator = Arc::new(InfEval);
    tree.with_evaluator(evaluator);
    tree.run();

    assert_eq!(tree.total_simulations(), 10);
}