use mctrust::*;
use std::sync::Arc;
use std::time::Duration;
/// Toy one-dimensional environment: a counter that starts at `value` and is
/// moved one unit at a time until it lands on `target`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
struct NumberLine {
    /// Current position of the counter.
    value: i32,
    /// Position the search is trying to reach.
    target: i32,
}
/// Moves available in [`NumberLine`]: shift its `value` up or down by one.
// NOTE: keep the variant order as-is. The derived `Hash` and serde's variant
// index are positional, so reordering could change seeded search behaviour.
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
enum NumberAction {
/// Adds one to `NumberLine::value`.
Increment,
/// Subtracts one from `NumberLine::value`.
Decrement,
}
impl Environment for NumberLine {
    type Action = NumberAction;

    /// Both moves are legal from every position, listed increment-first.
    fn legal_actions(&self) -> Vec<NumberAction> {
        let moves = [NumberAction::Increment, NumberAction::Decrement];
        moves.to_vec()
    }

    /// Shifts the counter one unit in the chosen direction.
    fn apply(&mut self, action: &NumberAction) {
        let step = match action {
            NumberAction::Increment => 1,
            NumberAction::Decrement => -1,
        };
        self.value += step;
    }

    /// Landing on the target wins, drifting more than 15 units away fails,
    /// and anything in between keeps the episode going.
    fn evaluate(&self) -> Outcome {
        if self.value == self.target {
            return Outcome::Success(Reward::WIN);
        }
        let gap = (self.value - self.target).abs();
        if gap > 15 {
            Outcome::Failure
        } else {
            Outcome::Ongoing
        }
    }

    /// Linear closeness score: 1.0 on the target, falling to 0.0 at a
    /// distance of 15 or more.
    fn heuristic(&self) -> Heuristic {
        let distance = f64::from((self.value - self.target).abs());
        let closeness = 1.0 - (distance / 15.0).min(1.0);
        Heuristic::from_reward(Reward::new(closeness))
    }

    /// Reports a depth limit of 15, matching the failure distance above.
    fn max_depth(&self) -> Option<usize> {
        Some(15)
    }
}
/// Shared [`SearchConfig`] for the `NumberLine` tests: `iterations` search
/// iterations, depth limit 15, heuristic weight 1.0, and RAVE switched off.
fn numberline_search_config(iterations: usize) -> SearchConfig {
    let rave_disabled = RaveConfig {
        enabled: false,
        bias: 300.0,
    };
    SearchConfig::builder()
        .iterations(iterations)
        .max_depth(15)
        .heuristic_weight(1.0)
        .rave(rave_disabled)
        .build()
}
#[test]
fn integration_numberline_positive_target() {
    // Starting at 0 with the goal at +5, stepping up is the move to pick.
    let mut search = TreeSearch::with_seed(
        NumberLine { value: 0, target: 5 },
        numberline_search_config(2_000),
        42,
    );
    assert_eq!(search.run(), Some(NumberAction::Increment));
}
#[test]
fn integration_numberline_negative_target() {
    // Mirror image of the positive-target case: the goal sits at -5.
    let start = NumberLine { value: 0, target: -5 };
    let mut search = TreeSearch::with_seed(start, numberline_search_config(2_000), 42);
    let chosen = search.run();
    assert_eq!(chosen, Some(NumberAction::Decrement));
}
#[test]
fn integration_numberline_tree_reuse() {
    let start = NumberLine { value: 0, target: 5 };
    let mut search = TreeSearch::with_seed(start, numberline_search_config(1_000), 42);
    search.run();

    // Commit to the first move of the principal variation.
    let pv = search.principal_variation();
    assert!(!pv.is_empty());
    let first = pv[0].clone();

    // Re-rooting on that move must not grow the tree.
    let size_before_advance = search.tree_size();
    assert!(search.advance_to_action(&first));
    assert!(search.tree_size() <= size_before_advance);

    // The new root state should sit one step away from the origin.
    let expected_value = match first {
        NumberAction::Increment => 1,
        NumberAction::Decrement => -1,
    };
    let states = search.principal_variation_states();
    assert_eq!(states[0].value, expected_value);
}
#[test]
fn integration_numberline_principal_variation_states() {
    let mut search = TreeSearch::with_seed(
        NumberLine { value: 0, target: 5 },
        numberline_search_config(1_000),
        42,
    );
    search.run();

    // The PV state list starts at the root (value 0) and holds at least one more state.
    let states = search.principal_variation_states();
    assert!(states.len() >= 2);
    assert_eq!(states[0].value, 0);
}
#[test]
fn integration_numberline_with_evaluator() {
    /// Leaf evaluator using the same closeness formula as `NumberLine::heuristic`.
    struct NumberLineEval;

    impl Evaluator<NumberLine> for NumberLineEval {
        fn evaluate(&self, env: &NumberLine) -> Reward {
            let distance = f64::from((env.value - env.target).abs());
            let scaled = (distance / 15.0).min(1.0);
            Reward::new(1.0 - scaled)
        }
    }

    let config = SearchConfig::builder().iterations(500).build();
    let mut search = TreeSearch::with_seed(NumberLine { value: 0, target: 5 }, config, 42);
    search.with_evaluator(Arc::new(NumberLineEval));
    assert_eq!(search.run(), Some(NumberAction::Increment));
}
#[test]
fn integration_bandit_hyperparameter_search() {
    // 30 arms in three families of ten (group id = arm / 10). Family 1 always
    // pays the highest fixed reward, so it must end up with the best average.
    let mut search = BanditSearch::new_seeded(BanditConfig::builder().max_pulls(100).build(), 42);
    for i in 0..30u64 {
        search.add_arm(i, (i / 10) as u32);
    }
    while let Some(arm) = search.next_arm() {
        let reward = if arm < 10 {
            0.7
        } else if arm < 20 {
            0.8
        } else {
            0.6
        };
        search.observe(arm, reward);
    }
    let stats = search.group_stats();
    assert_eq!(stats.len(), 3);
    // `partial_cmp` only returns `None` for NaN. State that invariant instead of
    // using a bare unwrap, so a failure says what actually went wrong.
    let best_family = stats
        .iter()
        .max_by(|a, b| {
            a.average_reward
                .partial_cmp(&b.average_reward)
                .expect("group average rewards must not be NaN")
        })
        .expect("group_stats is non-empty (its length is asserted above)");
    assert_eq!(best_family.group_id, 1);
}
#[test]
fn integration_bandit_with_signals() {
    // Scalarize two signals: reward accuracy and penalize latency (-0.01 per ms).
    let mut weights = std::collections::HashMap::new();
    weights.insert("accuracy".to_string(), 1.0);
    weights.insert("latency_ms".to_string(), -0.01);
    let mut search = BanditSearch::new_seeded(
        BanditConfig::builder()
            .scalarizer(Scalarizer {
                signal_weights: weights,
                default_weight: 0.0,
            })
            .build(),
        42,
    );
    // Five arms, all registered in group 0.
    for i in 0..5u64 {
        search.add_arm(i, 0);
    }
    for i in 0..5u64 {
        let arm = search
            .next_arm()
            .expect("next_arm returned None before five pulls were made");
        let accuracy = 0.8 + (i as f64) * 0.02;
        let latency = 100.0 - (i as f64) * 5.0;
        search.observe_with_signals(arm, &[("accuracy", accuracy), ("latency_ms", latency)]);
    }
    let stats = search.group_stats();
    // Only group 0 was registered. Check the count before indexing so a mismatch
    // reports the group count instead of an out-of-bounds panic.
    assert_eq!(stats.len(), 1, "only group 0 was registered");
    assert_eq!(stats[0].visits, 5);
}
#[test]
fn integration_bandit_reweight_online() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    // Two arms, both registered in group 0.
    search.add_arm(0, 0);
    search.add_arm(1, 0);

    // The first pull is observed before any reweighting...
    let arm1 = search.next_arm().expect("first pull should yield an arm");
    search.observe_with_signals(arm1, &[("feature_a", 1.0)]);

    // ...then the signal weights change mid-run and a second pull is observed.
    search.reweight_signals(&[("feature_a", 2.0), ("feature_b", -1.0)]);
    let arm2 = search.next_arm().expect("second pull should yield an arm");
    search.observe_with_signals(arm2, &[("feature_a", 1.0), ("feature_b", 1.0)]);

    let stats = search.group_stats();
    // Check the group count before indexing so a mismatch gives a clear message.
    assert_eq!(stats.len(), 1, "only group 0 was registered");
    assert_eq!(stats[0].visits, 2);
}
#[test]
fn integration_treesearch_uct_policy() {
    // Same settings as the shared NumberLine config, but with explicit UCT selection.
    let rave = RaveConfig {
        enabled: false,
        bias: 300.0,
    };
    let config = SearchConfig::builder()
        .iterations(1_000)
        .max_depth(15)
        .heuristic_weight(1.0)
        .rave(rave)
        .tree_policy(TreePolicy::Uct)
        .build();
    let start = NumberLine { value: 0, target: 3 };
    let mut search = TreeSearch::with_seed(start, config, 42);
    assert_eq!(search.run(), Some(NumberAction::Increment));
}
#[test]
fn integration_treesearch_puct_policy() {
    /// NumberLine-like environment (failure beyond distance 10) that also
    /// supplies action priors for PUCT.
    #[derive(Clone)]
    struct PriorNumberLine {
        value: i32,
        target: i32,
    }

    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    enum PriorAction {
        Increment,
        Decrement,
    }

    impl Environment for PriorNumberLine {
        type Action = PriorAction;

        fn legal_actions(&self) -> Vec<PriorAction> {
            [PriorAction::Increment, PriorAction::Decrement].to_vec()
        }

        fn apply(&mut self, action: &PriorAction) {
            self.value += match action {
                PriorAction::Increment => 1,
                PriorAction::Decrement => -1,
            };
        }

        fn evaluate(&self) -> Outcome {
            if self.value == self.target {
                return Outcome::Success(Reward::WIN);
            }
            if (self.value - self.target).abs() > 10 {
                Outcome::Failure
            } else {
                Outcome::Ongoing
            }
        }

        /// Prior of 0.9 for moving up and 0.1 for moving down.
        fn action_priors(&self, actions: &[PriorAction]) -> Option<Vec<f64>> {
            let priors: Vec<f64> = actions
                .iter()
                .map(|action| match action {
                    PriorAction::Increment => 0.9,
                    PriorAction::Decrement => 0.1,
                })
                .collect();
            Some(priors)
        }
    }

    let config = SearchConfig::builder()
        .iterations(1_000)
        .tree_policy(TreePolicy::Puct { prior_weight: 2.0 })
        .build();
    let mut search = TreeSearch::with_seed(PriorNumberLine { value: 0, target: 3 }, config, 42);
    assert_eq!(search.run(), Some(PriorAction::Increment));
}
#[test]
fn integration_treesearch_thompson_policy() {
    let policy = TreePolicy::ThompsonSampling { temperature: 0.5 };
    let config = SearchConfig::builder().iterations(1_000).tree_policy(policy).build();
    let mut search = TreeSearch::with_seed(NumberLine { value: 0, target: 3 }, config, 42);
    let chosen = search.run();
    assert_eq!(chosen, Some(NumberAction::Increment));
}
#[test]
fn integration_treesearch_gumbel_policy() {
    let policy = TreePolicy::Gumbel {
        sampled_actions: 16,
        max_completions_coeff: 50.0,
    };
    let config = SearchConfig::builder().iterations(1_000).tree_policy(policy).build();
    let mut search = TreeSearch::with_seed(NumberLine { value: 0, target: 3 }, config, 42);
    let chosen = search.run();
    assert_eq!(chosen, Some(NumberAction::Increment));
}
#[test]
fn integration_dot_export_basic() {
    let mut search = TreeSearch::with_seed(
        NumberLine { value: 0, target: 3 },
        SearchConfig::builder().iterations(50).build(),
        42,
    );
    search.run();

    // A depth-3 export should be a complete Graphviz digraph that mentions node `n0`.
    let rendered = search.export_dot(3);
    assert!(rendered.starts_with("digraph mctrust {"));
    assert!(rendered.contains("n0"));
    assert!(rendered.ends_with("}\n"));
}
#[test]
fn integration_dot_export_depth_zero() {
    let config = SearchConfig::builder().iterations(50).build();
    let mut search = TreeSearch::with_seed(NumberLine { value: 0, target: 3 }, config, 42);
    search.run();

    // Even a zero-depth export must still open and close the digraph.
    let rendered = search.export_dot(0);
    assert!(rendered.starts_with("digraph mctrust {"));
    assert!(rendered.ends_with("}\n"));
}
#[test]
fn integration_checkpoint_restore_treesearch() {
    let config = SearchConfig::builder().iterations(100).build();
    let mut search = TreeSearch::with_seed(NumberLine { value: 0, target: 3 }, config, 42);
    search.run();

    // Record the observable counters, then round-trip the search through a checkpoint.
    let (expected_size, expected_sims) = (search.tree_size(), search.total_simulations());
    let restored: TreeSearch<NumberLine> = TreeSearch::restore(search.checkpoint());
    assert_eq!(restored.tree_size(), expected_size);
    assert_eq!(restored.total_simulations(), expected_sims);
}
#[test]
fn integration_checkpoint_restore_bandit() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    // 20 arms split into four groups of five.
    for arm_id in 0..20u64 {
        search.add_arm(arm_id, (arm_id / 5) as u32);
    }
    // Ten pull attempts, each successful pull paid a flat 0.5.
    for _attempt in 0..10 {
        if let Some(picked) = search.next_arm() {
            search.observe(picked, 0.5);
        }
    }

    let restored = BanditSearch::restore(search.checkpoint());
    assert_eq!(restored.total_pulls(), 10);
    assert_eq!(restored.group_stats().len(), 4);
}
#[test]
fn integration_progressive_widening_limits_expansion() {
    /// Environment with a very wide branching factor: every state offers 100 moves.
    #[derive(Clone)]
    struct WideEnv(i32);

    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    struct WideAction(i32);

    impl Environment for WideEnv {
        type Action = WideAction;

        fn legal_actions(&self) -> Vec<WideAction> {
            let mut moves = Vec::with_capacity(100);
            for step in 0..100 {
                moves.push(WideAction(step));
            }
            moves
        }

        fn apply(&mut self, action: &WideAction) {
            let WideAction(delta) = action;
            self.0 += *delta;
        }

        fn evaluate(&self) -> Outcome {
            let magnitude = self.0.abs();
            if magnitude <= 50 {
                Outcome::Ongoing
            } else {
                Outcome::Terminal(Reward::WIN)
            }
        }
    }

    let widening = ProgressiveWideningConfig {
        minimum_children: 1,
        coefficient: 1.0,
        exponent: 0.5,
    };
    let config = SearchConfig::builder()
        .iterations(50)
        .progressive_widening(widening)
        .build();
    let mut search = TreeSearch::with_seed(WideEnv(0), config, 42);
    search.run();
    // With widening active, 50 iterations must not expand all 100 root moves.
    assert!(search.root_stats().len() < 100);
}
#[test]
fn integration_time_budget_stops_early() {
    // The iteration cap is far beyond what 10 ms allows, so the wall-clock
    // budget should be what ends the run.
    let mut config = SearchConfig::builder().iterations(10_000_000).build();
    config.time_budget = Some(Duration::from_millis(10));
    let mut search = TreeSearch::with_seed(NumberLine { value: 0, target: 3 }, config, 42);
    search.run();

    let sims = search.total_simulations();
    assert!(sims > 0);
    assert!(sims < 10_000_000);
}
#[test]
fn integration_run_until_predicate() {
    let config = SearchConfig::builder().iterations(100_000).build();
    let mut search = TreeSearch::with_seed(NumberLine { value: 0, target: 3 }, config, 42);
    // Ask to stop after 250 simulations, well short of the iteration cap.
    search.run_until(|searcher| searcher.total_simulations() >= 250);
    assert!(search.total_simulations() >= 250);
}
#[test]
fn integration_run_until_reward_threshold() {
    let game = NumberLine {
        value: 0,
        target: 3,
    };
    let config = SearchConfig::builder().iterations(100_000).build();
    let mut search = TreeSearch::with_seed(game, config, 42);
    // Stop once the best root reward clears 0.5. The 5000-simulation clause is
    // a fallback so the test cannot run for the full iteration cap.
    search
        .run_until(|s| s.best_root_reward().unwrap_or(0.0) > 0.5 || s.total_simulations() >= 5000);
    assert!(search.total_simulations() >= 1);
    // The previous assertion alone never checked that `run_until` honoured its
    // predicate, so also verify the stop condition holds on return.
    assert!(
        search.best_root_reward().unwrap_or(0.0) > 0.5 || search.total_simulations() >= 5000,
        "run_until returned before its stop condition held"
    );
}