use treant::{GameState, ProvenValue};
use rand::Rng;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256PlusPlus;
/// Supplies policy logits and a value estimate for the positions the search visits.
pub trait GumbelEvaluator<G: GameState>: Send {
/// Evaluates `state`, whose available moves are passed in `moves`.
///
/// Returns `(logits, value)`. The search in this file indexes `logits` by position in
/// `moves` and runs them through a softmax to obtain priors; `value` is used as the
/// estimate for `state` itself.
// NOTE(review): the search negates a child's value when backing it up, so `value` is
// presumably from the perspective of the player to move in `state` — confirm.
fn evaluate(&self, state: &G, moves: &[G::Move]) -> (Vec<f64>, f64);
}
/// Tunable parameters for [`GumbelSearch`].
#[derive(Clone, Copy, Debug)]
pub struct GumbelConfig {
/// Upper bound on how many root moves are kept as candidates; `search` also caps it at
/// the number of available moves.
pub m_actions: usize,
/// Exploration constant handed to the PUCT selection used below the root.
pub c_puct: f64,
/// Depth at which a descent stops selecting moves and returns the evaluator's value.
pub max_depth: usize,
/// Multiplier applied to a move's completed Q-value before it is added to its logit.
pub value_scale: f64,
/// Seed for the search's random number generator.
pub seed: u64,
}
impl Default for GumbelConfig {
fn default() -> Self {
Self {
m_actions: 16,
c_puct: 1.25,
max_depth: 200,
value_scale: 50.0,
seed: 42,
}
}
}
/// Per-move statistics reported for a root move after a search.
///
/// `Debug` is derived; the output lists the fields in declaration order, exactly as the
/// previous hand-written implementation did.
#[derive(Debug)]
pub struct MoveStats<M: Clone> {
    /// The root move these statistics describe.
    pub mov: M,
    /// Number of simulations that started with this move.
    pub visits: u32,
    /// Mean backed-up value of the move, or the root value when it was never visited.
    pub completed_q: f64,
    /// Probability of the move under the softmax of logit plus scaled completed Q.
    pub improved_policy: f64,
}
/// Outcome of one call to `GumbelSearch::search`.
#[must_use]
pub struct SearchResult<M: Clone> {
/// Move chosen by the search.
pub best_move: M,
/// Value the evaluator returned for the root state.
pub root_value: f64,
/// Statistics for the root moves (a single entry when only one move was available).
pub move_stats: Vec<MoveStats<M>>,
/// Number of simulations that were actually run.
pub simulations_used: u32,
}
impl<M: Clone + std::fmt::Debug> std::fmt::Debug for SearchResult<M> {
    /// Writes the scalar fields first and the per-move statistics last, so the
    /// potentially long `move_stats` list ends the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SearchResult");
        out.field("best_move", &self.best_move);
        out.field("root_value", &self.root_value);
        out.field("simulations_used", &self.simulations_used);
        out.field("move_stats", &self.move_stats);
        out.finish()
    }
}
/// Search driver that owns the evaluator, the configuration and the random number generator.
pub struct GumbelSearch<G: GameState, E: GumbelEvaluator<G>> {
/// Search parameters supplied at construction.
config: GumbelConfig,
/// Source of logits and value estimates.
evaluator: E,
/// Generator used to draw the Gumbel noise for root moves.
rng: Xoshiro256PlusPlus,
/// Ties the struct to the game type `G` without storing a game state.
_phantom: std::marker::PhantomData<G>,
}
impl<G: GameState, E: GumbelEvaluator<G>> std::fmt::Debug for GumbelSearch<G, E> {
    /// Prints only the configuration; the evaluator and RNG are elided and the output is
    /// marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("GumbelSearch");
        out.field("config", &self.config);
        out.finish_non_exhaustive()
    }
}
/// A position in the search tree.
struct Node<M: Clone> {
/// One edge per available move; empty for nodes created from states with no moves
/// or with a known terminal value.
edges: Vec<Edge<M>>,
/// Number of simulations that have selected an edge out of this node.
visits: u32,
}
/// A move out of a [`Node`] together with its running statistics.
struct Edge<M: Clone> {
/// The move this edge represents.
mov: M,
/// Prior probability of the move (softmax of the evaluator's logits).
prior: f64,
/// Number of simulations that traversed this edge.
visits: u32,
/// Sum of the values backed up through this edge (each is the negated child value).
value_sum: f64,
/// Node reached by the move; `None` until the edge is expanded for the first time.
child: Option<Box<Node<M>>>,
}
impl<G, E> GumbelSearch<G, E>
where
G: GameState,
E: GumbelEvaluator<G>,
{
/// Creates a search that uses `evaluator` and `config`, with its random number
/// generator seeded from `config.seed`.
#[must_use]
pub fn new(evaluator: E, config: GumbelConfig) -> Self {
    Self {
        rng: Xoshiro256PlusPlus::seed_from_u64(config.seed),
        config,
        evaluator,
        _phantom: std::marker::PhantomData,
    }
}
/// Returns a shared reference to the evaluator used by this search.
#[must_use]
pub fn evaluator(&self) -> &E {
&self.evaluator
}
/// Returns a shared reference to the configuration this search was built with.
#[must_use]
pub fn config(&self) -> &GumbelConfig {
&self.config
}
/// Replaces the random number generator with a fresh one seeded from `seed`.
///
/// The `seed` stored in the configuration is left untouched.
pub fn set_seed(&mut self, seed: u64) {
self.rng = Xoshiro256PlusPlus::seed_from_u64(seed);
}
pub fn search(&mut self, state: &G, n_simulations: u32) -> SearchResult<G::Move> {
let moves: Vec<G::Move> = state.available_moves().into_iter().collect();
assert!(!moves.is_empty(), "cannot search from terminal state");
if moves.len() == 1 {
let (_, root_value) = self.evaluator.evaluate(state, &moves);
return SearchResult {
best_move: moves[0].clone(),
root_value,
move_stats: vec![MoveStats {
mov: moves[0].clone(),
visits: 0,
completed_q: root_value,
improved_policy: 1.0,
}],
simulations_used: 0,
};
}
let (logits, root_value) = self.evaluator.evaluate(state, &moves);
let priors = softmax(&logits);
let gumbels: Vec<f64> = (0..moves.len())
.map(|_| sample_gumbel(&mut self.rng))
.collect();
let mut root = Node {
edges: moves
.iter()
.enumerate()
.map(|(i, m)| Edge {
mov: m.clone(),
prior: priors[i],
visits: 0,
value_sum: 0.0,
child: None,
})
.collect(),
visits: 0,
};
let m = self.config.m_actions.min(moves.len());
let mut alive: Vec<usize> = (0..moves.len()).collect();
alive.sort_by(|&a, &b| {
let sa = gumbels[a] + logits[a];
let sb = gumbels[b] + logits[b];
sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
});
alive.truncate(m);
let n_seq = if m <= 1 {
1
} else {
(m as f64).log2().ceil() as u32
};
let mut budget = n_simulations;
let mut total_sims = 0u32;
for phase in 0..n_seq {
if alive.len() <= 1 || total_sims >= n_simulations {
break;
}
let phases_left = n_seq - phase;
let n_a = budget / (alive.len() as u32 * phases_left);
if n_a == 0 {
break; }
for &action_idx in &alive {
for _ in 0..n_a {
if total_sims >= n_simulations {
break;
}
let mut s = state.clone();
self.simulate(&mut root, &mut s, action_idx);
total_sims += 1;
}
}
budget = budget.saturating_sub(alive.len() as u32 * n_a);
let mut scored: Vec<(usize, f64)> = alive
.iter()
.map(|&idx| {
let q = completed_q(&root.edges[idx], root_value);
let score = gumbels[idx] + logits[idx] + self.config.value_scale * q;
(idx, score)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let keep = alive.len().div_ceil(2);
alive = scored[..keep].iter().map(|&(idx, _)| idx).collect();
}
if total_sims < n_simulations && !alive.is_empty() {
let mut remaining = n_simulations - total_sims;
for (i, &action_idx) in alive.iter().enumerate() {
let actions_left = alive.len() as u32 - i as u32;
let share = remaining / actions_left;
for _ in 0..share {
let mut s = state.clone();
self.simulate(&mut root, &mut s, action_idx);
total_sims += 1;
}
remaining -= share;
}
}
let best_idx = if alive.len() > 1 {
*alive
.iter()
.max_by(|&&a, &&b| {
let sa = gumbels[a]
+ logits[a]
+ self.config.value_scale * completed_q(&root.edges[a], root_value);
let sb = gumbels[b]
+ logits[b]
+ self.config.value_scale * completed_q(&root.edges[b], root_value);
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap()
} else {
alive[0]
};
let improved_scores: Vec<f64> = root
.edges
.iter()
.enumerate()
.map(|(i, e)| logits[i] + self.config.value_scale * completed_q(e, root_value))
.collect();
let improved_probs = softmax(&improved_scores);
let move_stats: Vec<MoveStats<G::Move>> = root
.edges
.iter()
.zip(improved_probs.iter())
.map(|(e, &p)| MoveStats {
mov: e.mov.clone(),
visits: e.visits,
completed_q: completed_q(e, root_value),
improved_policy: p,
})
.collect();
SearchResult {
best_move: root.edges[best_idx].mov.clone(),
root_value,
move_stats,
simulations_used: total_sims,
}
}
/// Runs one simulation whose first move is forced to be root edge `forced_action`.
///
/// `state` must be the root position; it is advanced in place. If the edge already has a
/// child the simulation continues below it via [`Self::descend`], otherwise the child is
/// expanded and its leaf value is used. The negated child value is added to the edge's
/// `value_sum`, and the visit counters of the edge and the root are incremented.
fn simulate(&self, root: &mut Node<G::Move>, state: &mut G, forced_action: usize) {
    let edge = &mut root.edges[forced_action];
    // Borrow the move straight from the edge; no clone is needed to apply it.
    state.make_move(&edge.mov);
    let child_value = match edge.child.as_mut() {
        Some(child) => self.descend(child, state, 1),
        None => {
            let (child_node, leaf_value) = self.expand(state);
            edge.child = Some(Box::new(child_node));
            leaf_value
        }
    };
    // The child's value is seen from the other side, hence the sign flip.
    edge.value_sum += -child_value;
    edge.visits += 1;
    root.visits += 1;
}
/// Continues a simulation from `node`, which must correspond to `state`, and returns the
/// value backed up to `node`.
///
/// * A node without edges returns the state's terminal value (0.0 if none is known).
/// * At `depth >= max_depth` no further move is selected: the evaluator's value for
///   `state` is returned (or the terminal value when no moves are available).
/// * Otherwise an edge is chosen by PUCT, the move is applied to `state`, and the child
///   is either descended into or expanded. The negated child value is added to the
///   edge, the visit counters are incremented, and that negated value is returned.
fn descend(&self, node: &mut Node<G::Move>, state: &mut G, depth: usize) -> f64 {
    if node.edges.is_empty() {
        return terminal_value(state);
    }
    if depth >= self.config.max_depth {
        let moves: Vec<G::Move> = state.available_moves().into_iter().collect();
        if moves.is_empty() {
            return terminal_value(state);
        }
        let (_, value) = self.evaluator.evaluate(state, &moves);
        return value;
    }

    let action_idx = puct_select(node, self.config.c_puct);
    let edge = &mut node.edges[action_idx];
    // Borrow the move straight from the edge; no clone is needed to apply it.
    state.make_move(&edge.mov);
    let child_value = match edge.child.as_mut() {
        Some(child) => self.descend(child, state, depth + 1),
        None => {
            let (child_node, leaf_value) = self.expand(state);
            edge.child = Some(Box::new(child_node));
            leaf_value
        }
    };

    // The child's value is seen from the other side, hence the sign flip.
    let my_value = -child_value;
    edge.value_sum += my_value;
    edge.visits += 1;
    node.visits += 1;
    my_value
}
/// Builds the tree node for `state` and returns it together with the node's leaf value.
///
/// States with a known terminal value, or with no available moves, yield an edgeless
/// node (valued by the terminal value, or 0.0 respectively). Any other state is evaluated
/// once; its logits become edge priors via softmax and the evaluator's value is returned.
fn expand(&self, state: &G) -> (Node<G::Move>, f64) {
    let edgeless = || -> Node<G::Move> {
        Node {
            edges: Vec::new(),
            visits: 0,
        }
    };

    if let Some(pv) = state.terminal_value() {
        return (edgeless(), proven_to_value(pv));
    }

    let moves: Vec<G::Move> = state.available_moves().into_iter().collect();
    if moves.is_empty() {
        return (edgeless(), 0.0);
    }

    let (logits, value) = self.evaluator.evaluate(state, &moves);
    let priors = softmax(&logits);
    let mut edges = Vec::with_capacity(moves.len());
    for (i, mov) in moves.into_iter().enumerate() {
        edges.push(Edge {
            mov,
            prior: priors[i],
            visits: 0,
            value_sum: 0.0,
            child: None,
        });
    }
    (Node { edges, visits: 0 }, value)
}
}
/// Draws one sample of the form `-ln(-ln(u))`, where `u` is a random `f64` from `rng`
/// clamped to `[1e-20, 1.0 - 1e-20]` before the logarithms are taken.
// NOTE(review): in `f64`, `1.0 - 1e-20` rounds to exactly 1.0, so the upper clamp bound
// does not exclude `u == 1.0`; this only matters if `rng.gen()` can return 1.0 — confirm.
fn sample_gumbel(rng: &mut impl Rng) -> f64 {
    let uniform: f64 = rng.gen();
    let bounded = uniform.clamp(1e-20, 1.0 - 1e-20);
    let neg_log = -bounded.ln();
    -neg_log.ln()
}
/// Converts `logits` into a probability distribution using a max-shifted softmax.
///
/// Returns an empty vector for empty input. Falls back to the uniform distribution
/// whenever a meaningful softmax cannot be computed: when the largest logit is not finite
/// (all `-inf`, all NaN, or any `+inf`), or when a NaN logit is mixed in with finite ones
/// (which would otherwise turn every output into NaN).
fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let uniform = || vec![1.0 / logits.len() as f64; logits.len()];

    // `f64::max` skips NaN operands, so `max` is NaN-free but may still be infinite.
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !max.is_finite() {
        return uniform();
    }

    // Subtracting the maximum keeps every exponent <= 0 and therefore free of overflow.
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    // With a finite maximum the sum is at least 1 unless a NaN logit poisoned it.
    if !sum.is_finite() || sum <= 0.0 {
        return uniform();
    }
    exps.into_iter().map(|e| e / sum).collect()
}
/// Returns the index of the edge of `node` with the highest PUCT score (0 if there are
/// no edges). Ties go to the later edge.
///
/// A freshly expanded node has `visits == 0`; its visit count is treated as 1 so the
/// exploration term is non-zero and the first selection follows the priors instead of
/// degenerating into a tie between all edges.
fn puct_select<M: Clone>(node: &Node<M>, c: f64) -> usize {
    let sqrt_n = f64::from(node.visits.max(1)).sqrt();
    node.edges
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| {
            let sa = puct_score(a, c, sqrt_n);
            let sb = puct_score(b, c, sqrt_n);
            sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(i, _)| i)
        .unwrap_or(0)
}
/// Computes the PUCT score of `edge`: its mean value (0.0 while unvisited) plus the
/// exploration bonus `c * prior * sqrt_parent_visits / (1 + visits)`.
fn puct_score<M: Clone>(edge: &Edge<M>, c: f64, sqrt_parent_visits: f64) -> f64 {
    let n = f64::from(edge.visits);
    let exploitation = if edge.visits == 0 {
        0.0
    } else {
        edge.value_sum / n
    };
    let exploration = c * edge.prior * sqrt_parent_visits / (1.0 + n);
    exploitation + exploration
}
/// Returns the mean backed-up value of `edge`, or `default_value` when the edge has
/// not been visited yet.
fn completed_q<M: Clone>(edge: &Edge<M>, default_value: f64) -> f64 {
    match edge.visits {
        0 => default_value,
        n => edge.value_sum / f64::from(n),
    }
}
/// Maps a proven game outcome to a scalar: `Win` to 1.0, `Loss` to -1.0, and both `Draw`
/// and `Unknown` to 0.0.
fn proven_to_value(pv: ProvenValue) -> f64 {
    match pv {
        ProvenValue::Win => 1.0,
        ProvenValue::Loss => -1.0,
        ProvenValue::Draw => 0.0,
        ProvenValue::Unknown => 0.0,
    }
}
fn terminal_value<G: GameState>(state: &G) -> f64 {
state.terminal_value().map(proven_to_value).unwrap_or(0.0)
}
// Unit tests for the free helper functions of this file (sampling, softmax, PUCT
// selection, completed Q and proven-value mapping). The search itself is not exercised here.
#[cfg(test)]
mod tests {
use super::*;
// The empirical mean of many Gumbel samples should be close to 0.5772.
#[test]
fn test_sample_gumbel_mean() {
let mut rng = Xoshiro256PlusPlus::seed_from_u64(123);
let n = 50_000;
let sum: f64 = (0..n).map(|_| sample_gumbel(&mut rng)).sum();
let mean = sum / n as f64;
assert!(
(mean - 0.5772).abs() < 0.02,
"Gumbel mean {mean} too far from 0.5772"
);
}
// Softmax output must be normalised.
#[test]
fn test_softmax_sums_to_one() {
let logits = vec![1.0, 2.0, 3.0, 4.0];
let probs = softmax(&logits);
let sum: f64 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-10);
}
// Larger logits must map to larger probabilities.
#[test]
fn test_softmax_ordering() {
let logits = vec![1.0, 3.0, 2.0];
let probs = softmax(&logits);
assert!(probs[1] > probs[2]);
assert!(probs[2] > probs[0]);
}
// Equal logits give the uniform distribution.
#[test]
fn test_softmax_uniform() {
let logits = vec![0.0, 0.0, 0.0];
let probs = softmax(&logits);
for &p in &probs {
assert!((p - 1.0 / 3.0).abs() < 1e-10);
}
}
// Empty input yields empty output.
#[test]
fn test_softmax_empty() {
assert!(softmax(&[]).is_empty());
}
// A single logit always gets probability 1.
#[test]
fn test_softmax_single() {
let probs = softmax(&[42.0]);
assert_eq!(probs.len(), 1);
assert!((probs[0] - 1.0).abs() < 1e-10);
}
// Very large logits must not overflow thanks to the max shift.
#[test]
fn test_softmax_extreme_large_logits() {
let logits = vec![1000.0, 1001.0, 999.0];
let probs = softmax(&logits);
let sum: f64 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-10, "sum = {sum}");
assert!(probs[1] > probs[0]);
}
// Very negative logits must not underflow to an all-zero distribution.
#[test]
fn test_softmax_extreme_negative_logits() {
let logits = vec![-1000.0, -1001.0, -999.0];
let probs = softmax(&logits);
let sum: f64 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-10, "sum = {sum}");
assert!(probs[2] > probs[0]);
}
// All logits at negative infinity fall back to the uniform distribution.
#[test]
fn test_softmax_all_neg_infinity_returns_uniform() {
let logits = vec![f64::NEG_INFINITY, f64::NEG_INFINITY, f64::NEG_INFINITY];
let probs = softmax(&logits);
for &p in &probs {
assert!((p - 1.0 / 3.0).abs() < 1e-10, "should be uniform, got {p}");
}
}
// All-NaN logits fall back to the uniform distribution.
#[test]
fn test_softmax_nan_returns_uniform() {
let logits = vec![f64::NAN, f64::NAN];
let probs = softmax(&logits);
for &p in &probs {
assert!((p - 0.5).abs() < 1e-10, "NaN logits should produce uniform");
}
}
// With no edge visited yet, the edge with the larger prior wins.
#[test]
fn test_puct_prefers_high_prior_initially() {
let node = Node {
edges: vec![
Edge {
mov: 0u32,
prior: 0.1,
visits: 0,
value_sum: 0.0,
child: None,
},
Edge {
mov: 1,
prior: 0.9,
visits: 0,
value_sum: 0.0,
child: None,
},
],
visits: 1,
};
let selected = puct_select(&node, 1.25);
assert_eq!(selected, 1);
}
// With equal priors and visit counts, the edge with the higher mean value wins.
#[test]
fn test_puct_prefers_high_value_after_visits() {
let node = Node {
edges: vec![
Edge {
mov: 0u32,
prior: 0.5,
visits: 10,
value_sum: 8.0,
child: None,
},
Edge {
mov: 1,
prior: 0.5,
visits: 10,
value_sum: 2.0,
child: None,
},
],
visits: 20,
};
let selected = puct_select(&node, 1.25);
assert_eq!(selected, 0);
}
// Zero priors remove the exploration bonus, leaving only the mean value to decide.
#[test]
fn test_puct_zero_priors_degenerates_to_exploitation() {
let node = Node {
edges: vec![
Edge {
mov: 0u32,
prior: 0.0,
visits: 5,
value_sum: 3.0,
child: None,
},
Edge {
mov: 1,
prior: 0.0,
visits: 5,
value_sum: 1.0,
child: None,
},
],
visits: 10,
};
let selected = puct_select(&node, 1.25);
assert_eq!(selected, 0);
}
// A visited edge reports value_sum / visits.
#[test]
fn test_completed_q_visited() {
let edge = Edge {
mov: 0u32,
prior: 0.5,
visits: 4,
value_sum: 2.0,
child: None,
};
assert!((completed_q(&edge, 0.0) - 0.5).abs() < 1e-10);
}
// An unvisited edge reports the supplied default value.
#[test]
fn test_completed_q_unvisited() {
let edge = Edge {
mov: 0u32,
prior: 0.5,
visits: 0,
value_sum: 0.0,
child: None,
};
assert!((completed_q(&edge, 0.7) - 0.7).abs() < 1e-10);
}
// Negative value sums produce negative means.
#[test]
fn test_completed_q_negative() {
let edge = Edge {
mov: 0u32,
prior: 0.5,
visits: 4,
value_sum: -2.0,
child: None,
};
assert!((completed_q(&edge, 0.0) - (-0.5)).abs() < 1e-10);
}
// Win/Loss map to +1/-1; Draw and Unknown both map to 0.
#[test]
fn test_proven_to_value() {
assert!((proven_to_value(ProvenValue::Win) - 1.0).abs() < 1e-10);
assert!((proven_to_value(ProvenValue::Loss) - (-1.0)).abs() < 1e-10);
assert!((proven_to_value(ProvenValue::Draw) - 0.0).abs() < 1e-10);
assert!((proven_to_value(ProvenValue::Unknown) - 0.0).abs() < 1e-10);
}
}