use treant::tree_policy::*;
use treant::*;
/// State of the toy score game: P1 picks a branch, P2 answers, and the game ends.
#[derive(Debug, Clone)]
struct ScoreGame {
    /// Number of moves played so far: 0 at the root, 1 after a `Pick`, 2 after a `Respond`.
    depth: u8,
    /// Branch chosen by P1's `Pick`; `None` until that move is made.
    branch: Option<Branch>,
    /// Leaf score recorded by P2's `Respond`; `None` until that move is made.
    score: Option<i32>,
    /// The player whose turn it is.
    current: Player,
}
/// The two sides of the game: `P1` moves first (picks a branch), `P2` responds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Player {
    /// First mover; scores are reported from this player's point of view.
    P1,
    /// Second mover; sees the negation of P1's score.
    P2,
}
/// The three branches P1 can choose from at the root.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Branch {
    A,
    B,
    C,
}
/// A move in the score game.
#[derive(Debug, Clone, PartialEq)]
enum ScoreMove {
    /// P1's opening move: choose one of the branches.
    Pick(Branch),
    /// P2's reply: the index of the leaf within the chosen branch.
    Respond(u8),
}
/// Renders a move compactly: a branch pick as its letter (`A`, `B`, `C`) and a
/// response as `R` followed by its index (`R0`, `R1`).
///
/// The text is emitted through [`std::fmt::Formatter::pad`] so that width, fill
/// and alignment flags requested by the caller (e.g. `{:<3}`) are honoured.
/// Writing with `write!(f, ...)` directly would start a fresh format spec and
/// silently drop them.
impl std::fmt::Display for ScoreMove {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let text = match self {
            ScoreMove::Pick(b) => format!("{b:?}"),
            ScoreMove::Respond(i) => format!("R{i}"),
        };
        f.pad(&text)
    }
}
/// Leaf score (from P1's point of view) reached by picking `branch` and then
/// answering with `response`.
///
/// # Panics
/// Panics via `unreachable!()` if `response` is not 0 or 1.
fn terminal_scores(branch: Branch, response: u8) -> i32 {
    // Each branch has exactly two leaves: [response 0, response 1].
    let [left, right] = match branch {
        Branch::A => [2, 5],
        Branch::B => [1, 3],
        Branch::C => [8, 6],
    };
    match response {
        0 => left,
        1 => right,
        _ => unreachable!(),
    }
}
/// Game rules: P1 picks a branch at depth 0, P2 picks a leaf at depth 1, and
/// the game is over at depth 2.
impl GameState for ScoreGame {
    type Move = ScoreMove;
    type Player = Player;
    type MoveList = Vec<ScoreMove>;

    /// The player whose turn it is.
    fn current_player(&self) -> Player {
        self.current
    }

    /// Legal moves: three branch picks at depth 0, two responses at depth 1,
    /// and none once the game has finished.
    fn available_moves(&self) -> Vec<ScoreMove> {
        match self.depth {
            0 => vec![
                ScoreMove::Pick(Branch::A),
                ScoreMove::Pick(Branch::B),
                ScoreMove::Pick(Branch::C),
            ],
            1 => vec![ScoreMove::Respond(0), ScoreMove::Respond(1)],
            _ => vec![],
        }
    }

    /// Applies `mov`, advancing the depth and passing the turn to the other player.
    ///
    /// # Panics
    /// Panics if a `Respond` is applied before any `Pick` has chosen a branch.
    fn make_move(&mut self, mov: &ScoreMove) {
        match mov {
            ScoreMove::Pick(b) => {
                self.branch = Some(*b);
                self.depth = 1;
                self.current = Player::P2;
            }
            ScoreMove::Respond(i) => {
                let branch = self
                    .branch
                    .expect("Respond is only legal after a Pick has chosen a branch");
                self.score = Some(terminal_scores(branch, *i));
                self.depth = 2;
                self.current = Player::P1;
            }
        }
    }

    /// The leaf score recorded by the `Respond` move once the game is over
    /// (depth 2), otherwise `None`.
    fn terminal_score(&self) -> Option<i32> {
        if self.depth == 2 {
            self.score
        } else {
            None
        }
    }
}
/// Evaluator for the score game. A state's evaluation is its recorded leaf
/// score, or 0 when no score has been recorded yet.
struct ScoreEval;

impl Evaluator<ScoreMCTS> for ScoreEval {
    type StateEvaluation = i32;

    /// Evaluates a freshly expanded state. Returns one unit placeholder per
    /// move (no move priors) and the state's score, defaulting to 0.
    fn evaluate_new_state(
        &self,
        state: &ScoreGame,
        moves: &Vec<ScoreMove>,
        _handle: Option<SearchHandle<ScoreMCTS>>,
    ) -> (Vec<()>, i32) {
        let per_move = moves.iter().map(|_| ()).collect();
        (per_move, state.score.unwrap_or(0))
    }

    /// Converts the P1-perspective evaluation into `player`'s reward: P1 gets
    /// it unchanged and P2 gets its negation (zero-sum).
    fn interpret_evaluation_for_player(&self, evaln: &i32, player: &Player) -> i64 {
        let p1_view = i64::from(*evaln);
        match player {
            Player::P1 => p1_view,
            Player::P2 => -p1_view,
        }
    }

    /// Revisited states keep their cached evaluation unchanged.
    fn evaluate_existing_state(
        &self,
        _state: &ScoreGame,
        evaln: &i32,
        _handle: SearchHandle<ScoreMCTS>,
    ) -> i32 {
        *evaln
    }
}
/// Search configuration for the score game: UCT tree policy, no per-node data,
/// no transposition table, and both score-bounded search and the solver on.
#[derive(Default)]
struct ScoreMCTS;

impl MCTS for ScoreMCTS {
    type State = ScoreGame;
    type Eval = ScoreEval;
    type TreePolicy = UCTPolicy;
    type NodeData = ();
    type ExtraThreadData = ();
    type TranspositionTable = ();

    /// Turns on treant's score-bounded search (presumably tracks lower/upper
    /// minimax bounds per node — see treant docs).
    fn score_bounded_enabled(&self) -> bool {
        true
    }

    /// Turns on treant's solver for proving node values.
    fn solver_enabled(&self) -> bool {
        true
    }

    /// No visit threshold before a node is expanded (presumably expands on
    /// first visit — confirm against treant docs).
    fn visits_before_expansion(&self) -> u64 {
        0
    }
}
/// Runs 100 score-bounded MCTS playouts on the toy game and prints the root
/// bounds, proven value, best move, tree size and per-child statistics.
///
/// # Panics
/// Panics if the root is not proven with lower bound 6 (the hand-computed
/// minimax value).
fn main() {
    // Explanation of the game tree and its hand-computed minimax value.
    let intro = [
        "=== Score-Bounded MCTS ===\n",
        "A depth-2 two-player game with known terminal scores.",
        "P1 picks branch A/B/C, then P2 responds.\n",
        "Tree (scores from P1's perspective):",
        " Root (P1)",
        " / | \\",
        " A B C",
        " (P2) (P2) (P2)",
        " / \\ / \\ / \\",
        " 2 5 1 3 8 6",
        "",
        "Minimax: A=min(2,5)=2, B=min(1,3)=1, C=min(8,6)=6",
        "Root = max(2,1,6) = 6 via branch C\n",
    ];
    for line in &intro {
        println!("{line}");
    }

    let root = ScoreGame {
        depth: 0,
        branch: None,
        score: None,
        current: Player::P1,
    };
    let mut mcts = MCTSManager::new(root, ScoreMCTS, ScoreEval, UCTPolicy::new(1.0), ());
    mcts.playout_n(100);

    let bounds = mcts.root_score_bounds();
    println!("Root score bounds: [{}, {}]", bounds.lower, bounds.upper);
    println!("Bounds converged: {}", bounds.is_proven());

    let proven = mcts.root_proven_value();
    println!("Root proven value: {proven:?}");

    let best = match mcts.best_move() {
        Some(m) => m.to_string(),
        None => String::from("-"),
    };
    println!("Best move: {best}");

    let nodes = mcts.tree().num_nodes();
    println!("Nodes: {nodes}\n");

    println!("Child stats (bounds from P1's perspective):");
    let stats = mcts.root_child_stats();
    for child in &stats {
        // The code negates each child's bounds and swaps the ends to report
        // them for P1 (presumably they are stored from the child mover's view).
        let lo = format_bound(negate(child.score_bounds.upper));
        let hi = format_bound(negate(child.score_bounds.lower));
        println!(
            " {:<3} — visits: {:4}, avg_reward: {:6.1}, \
             value: [{:>4}, {:>4}], proven: {:?}",
            child.mov, child.visits, child.avg_reward, lo, hi, child.proven_value,
        );
    }
    println!();

    assert!(
        bounds.is_proven() && bounds.lower == 6,
        "Expected proven minimax value 6, got [{}, {}]",
        bounds.lower,
        bounds.upper
    );
    println!("Verified: minimax value = {} (exact)", bounds.lower);
}
/// Negates a score bound, treating `i32::MIN`/`i32::MAX` as -inf/+inf so they
/// map onto each other instead of overflowing.
fn negate(v: i32) -> i32 {
    if v == i32::MIN {
        i32::MAX
    } else if v == i32::MAX {
        i32::MIN
    } else {
        -v
    }
}
/// Formats a score bound for display, rendering the `i32::MIN`/`i32::MAX`
/// sentinels as `-inf`/`+inf` and any other value as its decimal text.
fn format_bound(v: i32) -> String {
    if v == i32::MAX {
        return String::from("+inf");
    }
    if v == i32::MIN {
        return String::from("-inf");
    }
    v.to_string()
}