use mctrust::*;
use std::collections::HashMap;
use std::time::Duration;
/// `Reward::new` must hand back the value it was constructed with.
#[test]
fn reward_new_preserves_value() {
    let expected: f64 = 0.75;
    let reward = Reward::new(expected);
    let delta = (reward.value() - expected).abs();
    assert!(delta < f64::EPSILON);
}
/// `WIN`, `LOSS` and `DRAW` must expose 1.0, -1.0 and 0.0 respectively.
#[test]
fn reward_constants_are_correct() {
    let table = [
        (Reward::WIN, 1.0),
        (Reward::LOSS, -1.0),
        (Reward::DRAW, 0.0),
    ];
    for (constant, expected) in table {
        assert_eq!(constant.value(), expected);
    }
}
/// The default reward must compare equal to `Reward::DRAW`.
#[test]
fn reward_default_is_draw() {
    let fallback = Reward::default();
    assert_eq!(fallback, Reward::DRAW);
}
/// Adding two rewards with `+` yields a reward holding the summed value.
#[test]
fn reward_addition_works() {
    let total = Reward::new(0.3) + Reward::new(0.4);
    let error = (total.value() - 0.7).abs();
    assert!(error < f64::EPSILON);
}
/// `+=` accumulates another reward into the left-hand side.
#[test]
fn reward_add_assign_works() {
    let mut accumulator = Reward::new(0.1);
    let increment = Reward::new(0.2);
    accumulator += increment;

    let drift = (accumulator.value() - 0.3).abs();
    assert!(drift < f64::EPSILON);
}
/// An `f64` converts into a `Reward` carrying the same value.
#[test]
fn reward_from_f64() {
    let raw: f64 = 0.42;
    let converted: Reward = raw.into();
    let gap = (converted.value() - 0.42).abs();
    assert!(gap < f64::EPSILON);
}
/// The `Display` output of a 0.5 reward is the text `0.5`.
#[test]
fn reward_display_formats() {
    // `to_string` goes through `Display`, exactly like `format!("{r}")`.
    let rendered = Reward::new(0.5).to_string();
    assert_eq!(rendered, "0.5");
}
/// Rewards order as WIN > DRAW > LOSS.
#[test]
fn reward_partial_ord() {
    let (win, draw, loss) = (Reward::WIN, Reward::DRAW, Reward::LOSS);
    assert!(win > draw);
    assert!(draw > loss);
}
/// A reward survives a JSON encode/decode cycle unchanged.
#[test]
fn reward_serde_roundtrip() {
    let original = Reward::new(0.123);
    let decoded = serde_json::to_string(&original)
        .and_then(|encoded| serde_json::from_str::<Reward>(&encoded))
        .unwrap();
    assert_eq!(original, decoded);
}
/// Every outcome except `Ongoing` reports itself as terminal.
#[test]
fn outcome_terminal_detection() {
    let terminal_outcomes = [
        Outcome::Terminal(Reward::WIN),
        Outcome::Success(Reward::new(0.5)),
        Outcome::Failure,
        Outcome::Neutral,
    ];
    for outcome in terminal_outcomes {
        assert!(outcome.is_terminal());
    }

    let still_running = Outcome::Ongoing;
    assert!(!still_running.is_terminal());
}
/// `Outcome::reward` yields the reward expected for each variant, and
/// `None` while the outcome is still ongoing.
#[test]
fn outcome_reward_extraction() {
    // Fresh instance per use so the test does not rely on `Reward: Copy`.
    let half = || Reward::new(0.5);

    let expectations = [
        (Outcome::Ongoing, None),
        (Outcome::Terminal(half()), Some(half())),
        (Outcome::Success(half()), Some(half())),
        (Outcome::Failure, Some(Reward::LOSS)),
        (Outcome::Neutral, Some(Reward::DRAW)),
    ];
    for (outcome, expected) in expectations {
        assert_eq!(outcome.reward(), expected);
    }
}
/// `Display` renders the data-less variants as fixed lowercase words and
/// includes the reward digit for a terminal win.
#[test]
fn outcome_display_formats() {
    let fixed_labels = [
        (Outcome::Ongoing, "ongoing"),
        (Outcome::Failure, "failure"),
        (Outcome::Neutral, "neutral"),
    ];
    for (outcome, label) in fixed_labels {
        assert_eq!(outcome.to_string(), label);
    }

    let terminal_text = Outcome::Terminal(Reward::WIN).to_string();
    assert!(terminal_text.contains("1"));
}
/// Each outcome variant survives a JSON encode/decode cycle unchanged.
#[test]
fn outcome_serde_roundtrip() {
    let samples = [
        Outcome::Ongoing,
        Outcome::Terminal(Reward::WIN),
        Outcome::Success(Reward::new(0.5)),
        Outcome::Failure,
        Outcome::Neutral,
    ];
    for original in samples {
        let decoded = serde_json::to_string(&original)
            .and_then(|encoded| serde_json::from_str::<Outcome>(&encoded))
            .unwrap();
        assert_eq!(original, decoded);
    }
}
/// A default heuristic carries no value.
#[test]
fn heuristic_default_is_none() {
    let heuristic = Heuristic::default();
    assert!(heuristic.value.is_none());
}
/// `Heuristic::from_reward` stores the given reward in `value`.
#[test]
fn heuristic_from_reward() {
    let heuristic = Heuristic::from_reward(Reward::new(0.8));
    let expected = Some(Reward::new(0.8));
    assert_eq!(heuristic.value, expected);
}
/// A heuristic survives a TOML encode/decode cycle unchanged.
#[test]
fn heuristic_serde_roundtrip() {
    let original = Heuristic::from_reward(Reward::new(0.33));
    let encoded = toml::to_string(&original).unwrap();
    let decoded = toml::from_str::<Heuristic>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// All four `NodeStats` fields survive a TOML encode/decode cycle.
#[test]
fn nodestats_serde_roundtrip() {
    let original = NodeStats {
        visits: 42,
        average_reward: 0.75,
        children_count: 5,
        unexpanded_count: 3,
    };
    let encoded = toml::to_string(&original).unwrap();
    let NodeStats {
        visits,
        average_reward,
        children_count,
        unexpanded_count,
    } = toml::from_str::<NodeStats>(&encoded).unwrap();

    assert_eq!(original.visits, visits);
    let reward_gap = (original.average_reward - average_reward).abs();
    assert!(reward_gap < f64::EPSILON);
    assert_eq!(original.children_count, children_count);
    assert_eq!(original.unexpanded_count, unexpanded_count);
}
/// Checks the values `SearchConfig::default()` assigns to the fields
/// asserted below.
#[test]
fn searchconfig_default_values() {
    use std::f64::consts::SQRT_2;

    let is_close = |actual: f64, expected: f64| (actual - expected).abs() < f64::EPSILON;
    let defaults = SearchConfig::default();

    assert_eq!(defaults.iterations, 10_000);
    assert!(is_close(defaults.exploration_constant, SQRT_2));
    assert_eq!(defaults.max_depth, 50);
    assert_eq!(defaults.tree_policy, TreePolicy::Uct);
    assert!(is_close(defaults.heuristic_weight, 0.35));
    assert!(defaults.rave.enabled);
    assert!(is_close(defaults.rave.bias, 300.0));
    assert!(defaults.progressive_widening.is_none());
    assert!(defaults.time_budget.is_none());
}
/// Every setter on the `SearchConfig` builder must be reflected in the
/// built configuration.
#[test]
fn searchconfig_builder_all_fields() {
    let is_close = |actual: f64, expected: f64| (actual - expected).abs() < f64::EPSILON;

    let rave = RaveConfig {
        enabled: false,
        bias: 100.0,
    };
    let widening = ProgressiveWideningConfig {
        minimum_children: 2,
        coefficient: 2.0,
        exponent: 0.75,
    };
    let budget = Duration::from_millis(100);

    let built = SearchConfig::builder()
        .iterations(500)
        .exploration_constant(3.0)
        .max_depth(10)
        .tree_policy(TreePolicy::Puct { prior_weight: 2.0 })
        .heuristic_weight(0.5)
        .rave(rave)
        .progressive_widening(widening)
        .time_budget(budget)
        .build();

    assert_eq!(built.iterations, 500);
    assert!(is_close(built.exploration_constant, 3.0));
    assert_eq!(built.max_depth, 10);
    assert_eq!(built.tree_policy, TreePolicy::Puct { prior_weight: 2.0 });
    assert!(is_close(built.heuristic_weight, 0.5));
    assert!(!built.rave.enabled);
    assert!(is_close(built.rave.bias, 100.0));
    assert!(built.progressive_widening.is_some());
    assert_eq!(built.time_budget, Some(budget));
}
/// A zero iteration count is reported by `sanitize` (the warning names the
/// field) and reset to 10_000.
#[test]
fn searchconfig_sanitize_zero_iterations() {
    let mut config = SearchConfig::default();
    config.iterations = 0;

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());
    let names_field = warnings
        .iter()
        .any(|message| message.contains("iterations"));
    assert!(names_field);

    assert_eq!(config.iterations, 10_000);
}
/// A NaN exploration constant is reported by `sanitize` (the warning names
/// the field) and reset to sqrt(2).
#[test]
fn searchconfig_sanitize_nan_exploration() {
    use std::f64::consts::SQRT_2;

    let mut config = SearchConfig::default();
    config.exploration_constant = f64::NAN;

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());
    let names_field = warnings
        .iter()
        .any(|message| message.contains("exploration_constant"));
    assert!(names_field);

    let gap = (config.exploration_constant - SQRT_2).abs();
    assert!(gap < f64::EPSILON);
}
/// A negative exploration constant triggers a warning and is reset to
/// sqrt(2).
#[test]
fn searchconfig_sanitize_negative_exploration() {
    use std::f64::consts::SQRT_2;

    let mut config = SearchConfig::default();
    config.exploration_constant = -1.0;

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    let gap = (config.exploration_constant - SQRT_2).abs();
    assert!(gap < f64::EPSILON);
}
/// A heuristic weight of 2.0 triggers a warning and ends up as 1.0 after
/// `sanitize`.
#[test]
fn searchconfig_sanitize_heuristic_weight_out_of_range() {
    let mut config = SearchConfig::default();
    config.heuristic_weight = 2.0;

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    let gap = (config.heuristic_weight - 1.0).abs();
    assert!(gap < f64::EPSILON);
}
/// A NaN heuristic weight triggers a warning and is reset to 0.35.
#[test]
fn searchconfig_sanitize_heuristic_weight_nan() {
    let mut config = SearchConfig::default();
    config.heuristic_weight = f64::NAN;

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    let gap = (config.heuristic_weight - 0.35).abs();
    assert!(gap < f64::EPSILON);
}
/// A NaN RAVE bias triggers a warning and is reset to 300.0.
#[test]
fn searchconfig_sanitize_rave_bias_nan() {
    let mut config = SearchConfig::default();
    config.rave.bias = f64::NAN;

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    let gap = (config.rave.bias - 300.0).abs();
    assert!(gap < f64::EPSILON);
}
/// A progressive-widening `minimum_children` of zero triggers a warning and
/// is raised to 1.
#[test]
fn searchconfig_sanitize_progressive_widening_zero_minimum() {
    let widening = ProgressiveWideningConfig {
        minimum_children: 0,
        coefficient: 1.5,
        exponent: 0.5,
    };
    let mut config = SearchConfig::default();
    config.progressive_widening = Some(widening);

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    let sanitized = config.progressive_widening.as_ref().unwrap();
    assert_eq!(sanitized.minimum_children, 1);
}
/// A NaN PUCT prior weight triggers a warning and is reset to 1.0 while the
/// policy stays `Puct`.
#[test]
fn searchconfig_sanitize_puct_prior_weight_nan() {
    let mut config = SearchConfig::default();
    config.tree_policy = TreePolicy::Puct {
        prior_weight: f64::NAN,
    };

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    if let TreePolicy::Puct { prior_weight } = config.tree_policy {
        assert_eq!(prior_weight, 1.0);
    } else {
        panic!("Expected Puct");
    }
}
/// A NaN Thompson-sampling temperature triggers a warning and is reset to
/// 1.0 while the policy stays `ThompsonSampling`.
#[test]
fn searchconfig_sanitize_thompson_temperature_nan() {
    let mut config = SearchConfig::default();
    config.tree_policy = TreePolicy::ThompsonSampling {
        temperature: f64::NAN,
    };

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    if let TreePolicy::ThompsonSampling { temperature } = config.tree_policy {
        assert_eq!(temperature, 1.0);
    } else {
        panic!("Expected ThompsonSampling");
    }
}
/// A Gumbel policy with zero sampled actions triggers a warning and has
/// `sampled_actions` reset to 16.
#[test]
fn searchconfig_sanitize_gumbel_zero_sampled_actions() {
    let mut config = SearchConfig::default();
    config.tree_policy = TreePolicy::Gumbel {
        sampled_actions: 0,
        max_completions_coeff: 50.0,
    };

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    if let TreePolicy::Gumbel {
        sampled_actions, ..
    } = config.tree_policy
    {
        assert_eq!(sampled_actions, 16);
    } else {
        panic!("Expected Gumbel");
    }
}
/// A Gumbel policy with a NaN completion coefficient triggers a warning and
/// has `max_completions_coeff` reset to 50.0.
#[test]
fn searchconfig_sanitize_gumbel_nan_coeff() {
    let mut config = SearchConfig::default();
    config.tree_policy = TreePolicy::Gumbel {
        sampled_actions: 8,
        max_completions_coeff: f64::NAN,
    };

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    if let TreePolicy::Gumbel {
        max_completions_coeff,
        ..
    } = config.tree_policy
    {
        assert_eq!(max_completions_coeff, 50.0);
    } else {
        panic!("Expected Gumbel");
    }
}
/// The iteration count and tree policy of a built config survive a TOML
/// encode/decode cycle.
#[test]
fn searchconfig_toml_roundtrip() {
    let policy = TreePolicy::ThompsonSampling { temperature: 0.5 };
    let original = SearchConfig::builder()
        .iterations(256)
        .tree_policy(policy)
        .build();

    let encoded = toml::to_string(&original).unwrap();
    let decoded = toml::from_str::<SearchConfig>(&encoded).unwrap();

    assert_eq!(original.iterations, decoded.iterations);
    assert_eq!(original.tree_policy, decoded.tree_policy);
}
/// `SearchConfig::from_toml_str` picks up the scalar fields present in a
/// TOML document.
#[test]
fn searchconfig_from_toml_str() {
    // The document is kept flush-left so the parsed text is unchanged.
    let document = r#"
iterations = 128
max_depth = 20
exploration_constant = 1.5
[tree_policy]
kind = "uct"
"#;
    let parsed = SearchConfig::from_toml_str(document).unwrap();

    assert_eq!(parsed.iterations, 128);
    assert_eq!(parsed.max_depth, 20);
    let gap = (parsed.exploration_constant - 1.5).abs();
    assert!(gap < f64::EPSILON);
}
/// Loading a config from a path that does not exist must yield an error.
#[test]
fn searchconfig_from_toml_file_not_found() {
    let missing_path = "/definitely/does/not/exist.toml";
    let outcome = SearchConfig::from_toml_file(missing_path);
    assert!(outcome.is_err());
}
/// The default tree policy is `Uct`.
#[test]
fn treepolicy_default_is_uct() {
    assert!(matches!(TreePolicy::default(), TreePolicy::Uct));
}
/// Each tree-policy variant survives a TOML encode/decode cycle unchanged.
#[test]
fn treepolicy_serde_roundtrip() {
    let samples = [
        TreePolicy::Uct,
        TreePolicy::Puct { prior_weight: 1.5 },
        TreePolicy::ThompsonSampling { temperature: 0.25 },
        TreePolicy::Gumbel {
            sampled_actions: 16,
            max_completions_coeff: 50.0,
        },
    ];
    for original in samples {
        let encoded = toml::to_string(&original).unwrap();
        let decoded = toml::from_str::<TreePolicy>(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}
/// The default RAVE config is enabled with a bias of 300.0.
#[test]
fn raveconfig_default() {
    let defaults = RaveConfig::default();
    assert!(defaults.enabled);

    let bias_gap = (defaults.bias - 300.0).abs();
    assert!(bias_gap < f64::EPSILON);
}
/// A non-default RAVE config survives a TOML encode/decode cycle unchanged.
#[test]
fn raveconfig_serde_roundtrip() {
    let original = RaveConfig {
        enabled: false,
        bias: 100.0,
    };
    let encoded = toml::to_string(&original).unwrap();
    let decoded = toml::from_str::<RaveConfig>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// Default progressive widening: one minimum child, coefficient 1.5,
/// exponent 0.5.
#[test]
fn progressivewidening_default() {
    let defaults = ProgressiveWideningConfig::default();
    assert_eq!(defaults.minimum_children, 1);

    let coefficient_gap = (defaults.coefficient - 1.5).abs();
    assert!(coefficient_gap < f64::EPSILON);

    let exponent_gap = (defaults.exponent - 0.5).abs();
    assert!(exponent_gap < f64::EPSILON);
}
/// A progressive-widening config survives a TOML encode/decode cycle
/// unchanged.
#[test]
fn progressivewidening_serde_roundtrip() {
    let original = ProgressiveWideningConfig {
        minimum_children: 3,
        coefficient: 2.5,
        exponent: 0.33,
    };
    let encoded = toml::to_string(&original).unwrap();
    let decoded = toml::from_str::<ProgressiveWideningConfig>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// A default scalarizer has no per-signal weights and a default weight of
/// 0.0.
#[test]
fn scalarizer_default() {
    let Scalarizer {
        signal_weights,
        default_weight,
    } = Scalarizer::default();

    assert!(signal_weights.is_empty());
    assert_eq!(default_weight, 0.0);
}
/// Signals with configured weights contribute `weight * value` to the
/// scalarized reward: 1.0 * 0.5 + 10.0 * 1.0 = 10.5.
#[test]
fn scalarizer_scalarize_with_weights() {
    let scalarizer = Scalarizer {
        signal_weights: HashMap::from([
            (String::from("coverage"), 1.0),
            (String::from("crash"), 10.0),
        ]),
        default_weight: 0.0,
    };

    let combined = scalarizer.scalarize(&[("coverage", 0.5), ("crash", 1.0)]);
    let gap = (combined - 10.5).abs();
    assert!(gap < f64::EPSILON);
}
/// A signal without a configured weight is scaled by `default_weight`:
/// 2.0 * 3.0 = 6.0.
#[test]
fn scalarizer_scalarize_uses_default_weight() {
    let no_weights = HashMap::new();
    let scalarizer = Scalarizer {
        signal_weights: no_weights,
        default_weight: 2.0,
    };

    let combined = scalarizer.scalarize(&[("unknown", 3.0)]);
    let gap = (combined - 6.0).abs();
    assert!(gap < f64::EPSILON);
}
/// A NaN signal value is left out of the sum, so only `a` and `c`
/// contribute: 1.0 + 2.0 = 3.0.
#[test]
fn scalarizer_scalarize_skips_nan_signals() {
    let no_weights = HashMap::new();
    let scalarizer = Scalarizer {
        signal_weights: no_weights,
        default_weight: 1.0,
    };

    let combined = scalarizer.scalarize(&[("a", 1.0), ("b", f64::NAN), ("c", 2.0)]);
    let gap = (combined - 3.0).abs();
    assert!(gap < f64::EPSILON);
}
/// A signal whose configured weight is NaN is left out of the sum, so only
/// `a` contributes: 1.0 * 2.0 = 2.0.
#[test]
fn scalarizer_scalarize_skips_nan_weights() {
    let scalarizer = Scalarizer {
        signal_weights: HashMap::from([
            (String::from("a"), 1.0),
            (String::from("b"), f64::NAN),
        ]),
        default_weight: 1.0,
    };

    let combined = scalarizer.scalarize(&[("a", 2.0), ("b", 3.0)]);
    let gap = (combined - 2.0).abs();
    assert!(gap < f64::EPSILON);
}
/// A scalarizer with one weighted signal survives a TOML encode/decode
/// cycle unchanged.
#[test]
fn scalarizer_serde_roundtrip() {
    let original = Scalarizer {
        signal_weights: HashMap::from([(String::from("x"), 1.5)]),
        default_weight: 0.5,
    };
    let encoded = toml::to_string(&original).unwrap();
    let decoded = toml::from_str::<Scalarizer>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// Default bandit config: sqrt(2) exploration, RAVE bias 500.0, zero
/// `max_pulls`, and a default scalarizer.
#[test]
fn banditconfig_default() {
    use std::f64::consts::SQRT_2;

    let defaults = BanditConfig::default();

    let exploration_gap = (defaults.exploration_constant - SQRT_2).abs();
    assert!(exploration_gap < f64::EPSILON);

    let rave_gap = (defaults.rave_bias - 500.0).abs();
    assert!(rave_gap < f64::EPSILON);

    assert_eq!(defaults.max_pulls, 0);
    assert_eq!(defaults.scalarizer, Scalarizer::default());
}
/// The `BanditConfig` builder setters are reflected in the built config.
#[test]
fn banditconfig_builder() {
    let is_close = |actual: f64, expected: f64| (actual - expected).abs() < f64::EPSILON;

    let built = BanditConfig::builder()
        .exploration_constant(2.0)
        .rave_bias(100.0)
        .max_pulls(1000)
        .build();

    assert!(is_close(built.exploration_constant, 2.0));
    assert!(is_close(built.rave_bias, 100.0));
    assert_eq!(built.max_pulls, 1000);
}
/// A NaN exploration constant triggers a warning and is reset to sqrt(2).
#[test]
fn banditconfig_sanitize_nan_exploration() {
    use std::f64::consts::SQRT_2;

    let mut config = BanditConfig::default();
    config.exploration_constant = f64::NAN;

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    let gap = (config.exploration_constant - SQRT_2).abs();
    assert!(gap < f64::EPSILON);
}
/// A negative RAVE bias triggers a warning and is reset to 500.0.
#[test]
fn banditconfig_sanitize_negative_rave_bias() {
    let mut config = BanditConfig::default();
    config.rave_bias = -1.0;

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    let gap = (config.rave_bias - 500.0).abs();
    assert!(gap < f64::EPSILON);
}
/// `sanitize` warns about a NaN signal weight and removes that entry while
/// keeping the valid one.
#[test]
fn banditconfig_sanitize_invalid_signal_weights() {
    let mut config = BanditConfig::default();
    {
        let weights = &mut config.scalarizer.signal_weights;
        weights.insert("bad".to_string(), f64::NAN);
        weights.insert("good".to_string(), 1.0);
    }

    let warnings = config.sanitize();
    assert!(!warnings.is_empty());

    let remaining = &config.scalarizer.signal_weights;
    assert!(!remaining.contains_key("bad"));
    assert!(remaining.contains_key("good"));
}
/// A built bandit config survives a TOML encode/decode cycle unchanged.
#[test]
fn banditconfig_serde_roundtrip() {
    let original = BanditConfig::builder()
        .exploration_constant(1.5)
        .rave_bias(200.0)
        .max_pulls(500)
        .build();

    let encoded = toml::to_string(&original).unwrap();
    let decoded = toml::from_str::<BanditConfig>(&encoded).unwrap();
    assert_eq!(original, decoded);
}
/// Every `GroupStats` field must survive a TOML encode/decode cycle.
///
/// All six fields are compared individually, so a field that is dropped or
/// mangled during (de)serialization fails the test. The float field is
/// compared with an epsilon tolerance; the integer fields must match exactly.
#[test]
fn groupstats_serde_roundtrip() {
    let stats = GroupStats {
        group_id: 42,
        visits: 100,
        average_reward: 0.75,
        total_arms: 10,
        explored_arms: 5,
        rave_visits: 50,
    };
    let toml_str = toml::to_string(&stats).unwrap();
    let restored: GroupStats = toml::from_str(&toml_str).unwrap();

    assert_eq!(stats.group_id, restored.group_id);
    assert_eq!(stats.visits, restored.visits);
    assert!((stats.average_reward - restored.average_reward).abs() < f64::EPSILON);
    // These three were previously constructed but never verified.
    assert_eq!(stats.total_arms, restored.total_arms);
    assert_eq!(stats.explored_arms, restored.explored_arms);
    assert_eq!(stats.rave_visits, restored.rave_visits);
}
/// Data-less test environment. Its `Environment` impl in this file offers a
/// single legal action, ignores `apply`, and always evaluates to
/// `Outcome::Neutral`.
#[derive(Clone)]
struct MinimalEnv;
/// Data-less action type used as `MinimalEnv`'s `Environment::Action`.
#[derive(Clone, Debug, PartialEq)]
struct MinimalAction;
/// Smallest `Environment` implementation the tests need: only the three
/// methods written out here are provided by hand.
impl Environment for MinimalEnv {
    type Action = MinimalAction;

    /// Always offers exactly one action.
    fn legal_actions(&self) -> Vec<Self::Action> {
        Vec::from([MinimalAction])
    }

    /// Applying an action leaves the environment untouched.
    fn apply(&mut self, _action: &Self::Action) {}

    /// The environment always evaluates to a neutral outcome.
    fn evaluate(&self) -> Outcome {
        Outcome::Neutral
    }
}
/// Checks what `MinimalEnv` returns from the `Environment` methods it does
/// not implement itself.
#[test]
fn environment_defaults() {
    let env = MinimalEnv;

    assert_eq!(env.current_player(), 0);
    assert_eq!(env.num_players(), 1);
    assert_eq!(env.heuristic(), Heuristic::default());
    assert!(env.max_depth().is_none());
    assert!(env.action_priors(&[MinimalAction]).is_none());
    assert!(env.state_hash().is_none());
}
/// Test evaluator wrapping the `f64` that its `Evaluator` impl in this file
/// returns as the reward for any environment.
struct ConstantEval(f64);
/// Evaluates every environment to the wrapped constant.
impl<E: Environment> Evaluator<E> for ConstantEval {
    fn evaluate(&self, _env: &E) -> Reward {
        let Self(constant) = self;
        Reward::new(*constant)
    }
}
/// A `ConstantEval` reports its wrapped value as the reward for
/// `MinimalEnv`.
#[test]
fn evaluator_returns_expected_value() {
    let evaluator = ConstantEval(0.75);
    let env = MinimalEnv;

    let reward = evaluator.evaluate(&env);
    let gap = (reward.value() - 0.75).abs();
    assert!(gap < f64::EPSILON);
}