Skip to main content

treant_gumbel/
lib.rs

1//! Gumbel MuZero search: policy improvement by planning with Gumbel.
2//!
3//! Implements the algorithm from
4//! [Danihelka et al., "Policy improvement by planning with Gumbel" (ICLR 2022)](https://openreview.net/forum?id=bERaNdoegnO).
5//!
6//! Key features:
7//! - **Gumbel-Top-k** sampling at the root for action selection
8//! - **Sequential Halving** for optimal simulation budget allocation
9//! - **PUCT** selection for tree traversal below the root
10//! - **Improved policy** output — a better training target than visit counts
11//!
12//! # Design
13//!
14//! Gumbel search is fundamentally different from standard MCTS at the root level:
15//! instead of UCT/PUCT selection, it samples Gumbel noise, selects top-m actions,
16//! then uses Sequential Halving to allocate simulations. Below the root, standard
17//! PUCT guides tree traversal. This produces monotonically improving policies —
18//! more simulations always help.
19//!
20//! The crate reuses [`treant::GameState`] so any game implemented for the core MCTS
21//! crate works with Gumbel search.
22//!
23//! # Example
24//!
25//! ```
26//! use treant::GameState;
27//! use treant_gumbel::{GumbelSearch, GumbelConfig, GumbelEvaluator};
28//!
29//! # #[derive(Clone, Debug)] struct MyGame;
30//! # #[derive(Clone, Debug, PartialEq)] struct MyMove;
31//! # impl GameState for MyGame {
32//! #     type Move = MyMove; type Player = (); type MoveList = Vec<MyMove>;
33//! #     fn current_player(&self) {}
34//! #     fn available_moves(&self) -> Vec<MyMove> { vec![MyMove] }
35//! #     fn make_move(&mut self, _: &MyMove) {}
36//! #     fn terminal_value(&self) -> Option<treant::ProvenValue> { Some(treant::ProvenValue::Draw) }
37//! # }
38//! # struct Eval;
39//! # impl GumbelEvaluator<MyGame> for Eval {
40//! #     fn evaluate(&self, _: &MyGame, m: &[MyMove]) -> (Vec<f64>, f64) { (vec![0.0; m.len()], 0.0) }
41//! # }
42//! let mut search = GumbelSearch::new(Eval, GumbelConfig::default());
43//! let result = search.search(&MyGame, 100);
44//! println!("Best move: {:?}", result.best_move);
45//! ```
46
47use treant::{GameState, ProvenValue};
48use rand::Rng;
49use rand::SeedableRng;
50use rand_xoshiro::Xoshiro256PlusPlus;
51
52// ============================================================
53// Public API
54// ============================================================
55
56/// Evaluator providing policy logits and value estimates.
57///
58/// Implement this for your game to provide Gumbel search with action
59/// priors and state values. For neural networks, wrap the policy/value
60/// heads. For heuristic evaluators, return uniform logits with a
61/// heuristic value.
62pub trait GumbelEvaluator<G: GameState>: Send {
63    /// Evaluate a game state.
64    ///
65    /// Returns `(logits, value)` where:
66    /// - `logits`: one `f64` per move (unnormalized log-probabilities)
67    /// - `value`: state value for the current player, in `[-1.0, 1.0]`
68    fn evaluate(&self, state: &G, moves: &[G::Move]) -> (Vec<f64>, f64);
69}
70
/// Tunable parameters for Gumbel search.
#[derive(Clone, Copy, Debug)]
pub struct GumbelConfig {
    /// How many root actions survive the initial Gumbel-Top-k cut.
    /// A larger value explores the root more broadly but spreads the
    /// simulation budget thinner. Default: 16.
    pub m_actions: usize,

    /// Exploration constant of the PUCT rule used below the root.
    /// Default: 1.25.
    pub c_puct: f64,

    /// Depth at which a simulation stops descending and evaluates in place.
    /// Default: 200.
    pub max_depth: usize,

    /// Multiplier that brings Q-values onto the logit scale (`c_visit` in the
    /// paper). Larger values favour exploitation and sharpen the improved
    /// policy; tune it relative to the scale of your logits. Default: 50.0.
    pub value_scale: f64,

    /// Seed for the RNG that draws the Gumbel noise. Default: 42.
    pub seed: u64,
}
94
95impl Default for GumbelConfig {
96    fn default() -> Self {
97        Self {
98            m_actions: 16,
99            c_puct: 1.25,
100            max_depth: 200,
101            value_scale: 50.0,
102            seed: 42,
103        }
104    }
105}
106
/// Search statistics for one root move.
pub struct MoveStats<M: Clone> {
    /// The move these statistics describe.
    pub mov: M,
    /// How many simulations were forced through this move at the root.
    pub visits: u32,
    /// Completed Q-value: the empirical mean once the move has been visited,
    /// the root value estimate while it is still unvisited.
    pub completed_q: f64,
    /// Probability assigned by the improved policy; the probabilities of all
    /// root moves sum to 1.0.
    pub improved_policy: f64,
}
118
119impl<M: Clone + std::fmt::Debug> std::fmt::Debug for MoveStats<M> {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        f.debug_struct("MoveStats")
122            .field("mov", &self.mov)
123            .field("visits", &self.visits)
124            .field("completed_q", &self.completed_q)
125            .field("improved_policy", &self.improved_policy)
126            .finish()
127    }
128}
129
130/// Result of a Gumbel search.
131#[must_use]
132pub struct SearchResult<M: Clone> {
133    /// The best move found by search.
134    pub best_move: M,
135
136    /// Value estimate for the root state's current player.
137    pub root_value: f64,
138
139    /// Per-move statistics: visits, completed Q, and improved policy.
140    pub move_stats: Vec<MoveStats<M>>,
141
142    /// Total simulations used.
143    pub simulations_used: u32,
144}
145
146impl<M: Clone + std::fmt::Debug> std::fmt::Debug for SearchResult<M> {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        f.debug_struct("SearchResult")
149            .field("best_move", &self.best_move)
150            .field("root_value", &self.root_value)
151            .field("simulations_used", &self.simulations_used)
152            .field("move_stats", &self.move_stats)
153            .finish()
154    }
155}
156
157/// Gumbel MCTS search engine.
158///
159/// Implements Gumbel-Top-k root selection with Sequential Halving,
160/// providing monotonic policy improvement and better simulation
161/// efficiency compared to standard MCTS.
162///
163/// Two-player zero-sum games (negamax). Single-threaded.
164pub struct GumbelSearch<G: GameState, E: GumbelEvaluator<G>> {
165    config: GumbelConfig,
166    evaluator: E,
167    rng: Xoshiro256PlusPlus,
168    _phantom: std::marker::PhantomData<G>,
169}
170
171impl<G: GameState, E: GumbelEvaluator<G>> std::fmt::Debug for GumbelSearch<G, E> {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        f.debug_struct("GumbelSearch")
174            .field("config", &self.config)
175            .finish_non_exhaustive()
176    }
177}
178
179// ============================================================
180// Internal tree structures
181// ============================================================
182
183struct Node<M: Clone> {
184    edges: Vec<Edge<M>>,
185    visits: u32,
186}
187
188struct Edge<M: Clone> {
189    mov: M,
190    prior: f64,
191    visits: u32,
192    value_sum: f64,
193    child: Option<Box<Node<M>>>,
194}
195
196// ============================================================
197// Implementation
198// ============================================================
199
200impl<G, E> GumbelSearch<G, E>
201where
202    G: GameState,
203    E: GumbelEvaluator<G>,
204{
205    /// Create a new Gumbel search engine.
206    #[must_use]
207    pub fn new(evaluator: E, config: GumbelConfig) -> Self {
208        let rng = Xoshiro256PlusPlus::seed_from_u64(config.seed);
209        Self {
210            config,
211            evaluator,
212            rng,
213            _phantom: std::marker::PhantomData,
214        }
215    }
216
217    /// Access the evaluator.
218    #[must_use]
219    pub fn evaluator(&self) -> &E {
220        &self.evaluator
221    }
222
223    /// Access the configuration.
224    #[must_use]
225    pub fn config(&self) -> &GumbelConfig {
226        &self.config
227    }
228
229    /// Reset the RNG seed for reproducible searches.
230    pub fn set_seed(&mut self, seed: u64) {
231        self.rng = Xoshiro256PlusPlus::seed_from_u64(seed);
232    }
233
234    /// Run Gumbel search from the given state.
235    ///
236    /// # Panics
237    ///
238    /// Panics if the state is terminal (no available moves).
239    pub fn search(&mut self, state: &G, n_simulations: u32) -> SearchResult<G::Move> {
240        let moves: Vec<G::Move> = state.available_moves().into_iter().collect();
241        assert!(!moves.is_empty(), "cannot search from terminal state");
242
243        // Single move: evaluate for root_value but skip search
244        if moves.len() == 1 {
245            let (_, root_value) = self.evaluator.evaluate(state, &moves);
246            return SearchResult {
247                best_move: moves[0].clone(),
248                root_value,
249                move_stats: vec![MoveStats {
250                    mov: moves[0].clone(),
251                    visits: 0,
252                    completed_q: root_value,
253                    improved_policy: 1.0,
254                }],
255                simulations_used: 0,
256            };
257        }
258
259        // Evaluate root
260        let (logits, root_value) = self.evaluator.evaluate(state, &moves);
261        let priors = softmax(&logits);
262
263        // Sample Gumbel(0,1) noise for each action
264        let gumbels: Vec<f64> = (0..moves.len())
265            .map(|_| sample_gumbel(&mut self.rng))
266            .collect();
267
268        // Create root node
269        let mut root = Node {
270            edges: moves
271                .iter()
272                .enumerate()
273                .map(|(i, m)| Edge {
274                    mov: m.clone(),
275                    prior: priors[i],
276                    visits: 0,
277                    value_sum: 0.0,
278                    child: None,
279                })
280                .collect(),
281            visits: 0,
282        };
283
284        // Top-m selection by g(a) + logit(a)
285        let m = self.config.m_actions.min(moves.len());
286        let mut alive: Vec<usize> = (0..moves.len()).collect();
287        alive.sort_by(|&a, &b| {
288            let sa = gumbels[a] + logits[a];
289            let sb = gumbels[b] + logits[b];
290            sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
291        });
292        alive.truncate(m);
293
294        // Sequential Halving
295        let n_seq = if m <= 1 {
296            1
297        } else {
298            (m as f64).log2().ceil() as u32
299        };
300        let mut budget = n_simulations;
301        let mut total_sims = 0u32;
302
303        for phase in 0..n_seq {
304            if alive.len() <= 1 || total_sims >= n_simulations {
305                break;
306            }
307
308            // Remaining-budget allocation: each phase gets a fair share of what's left
309            let phases_left = n_seq - phase;
310            let n_a = budget / (alive.len() as u32 * phases_left);
311            if n_a == 0 {
312                break; // budget exhausted
313            }
314
315            for &action_idx in &alive {
316                for _ in 0..n_a {
317                    if total_sims >= n_simulations {
318                        break;
319                    }
320                    let mut s = state.clone();
321                    self.simulate(&mut root, &mut s, action_idx);
322                    total_sims += 1;
323                }
324            }
325            budget = budget.saturating_sub(alive.len() as u32 * n_a);
326
327            // Score each surviving action and halve
328            let mut scored: Vec<(usize, f64)> = alive
329                .iter()
330                .map(|&idx| {
331                    let q = completed_q(&root.edges[idx], root_value);
332                    let score = gumbels[idx] + logits[idx] + self.config.value_scale * q;
333                    (idx, score)
334                })
335                .collect();
336            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
337
338            let keep = alive.len().div_ceil(2);
339            alive = scored[..keep].iter().map(|&(idx, _)| idx).collect();
340        }
341
342        // Spend remaining budget on the survivor(s), distributing remainder fairly
343        if total_sims < n_simulations && !alive.is_empty() {
344            let mut remaining = n_simulations - total_sims;
345            for (i, &action_idx) in alive.iter().enumerate() {
346                let actions_left = alive.len() as u32 - i as u32;
347                let share = remaining / actions_left;
348                for _ in 0..share {
349                    let mut s = state.clone();
350                    self.simulate(&mut root, &mut s, action_idx);
351                    total_sims += 1;
352                }
353                remaining -= share;
354            }
355        }
356
357        // Re-rank survivors after final simulations
358        let best_idx = if alive.len() > 1 {
359            *alive
360                .iter()
361                .max_by(|&&a, &&b| {
362                    let sa = gumbels[a]
363                        + logits[a]
364                        + self.config.value_scale * completed_q(&root.edges[a], root_value);
365                    let sb = gumbels[b]
366                        + logits[b]
367                        + self.config.value_scale * completed_q(&root.edges[b], root_value);
368                    sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
369                })
370                .unwrap()
371        } else {
372            alive[0]
373        };
374
375        // Improved policy: softmax(logit + value_scale * q_completed)
376        let improved_scores: Vec<f64> = root
377            .edges
378            .iter()
379            .enumerate()
380            .map(|(i, e)| logits[i] + self.config.value_scale * completed_q(e, root_value))
381            .collect();
382        let improved_probs = softmax(&improved_scores);
383
384        let move_stats: Vec<MoveStats<G::Move>> = root
385            .edges
386            .iter()
387            .zip(improved_probs.iter())
388            .map(|(e, &p)| MoveStats {
389                mov: e.mov.clone(),
390                visits: e.visits,
391                completed_q: completed_q(e, root_value),
392                improved_policy: p,
393            })
394            .collect();
395
396        SearchResult {
397            best_move: root.edges[best_idx].mov.clone(),
398            root_value,
399            move_stats,
400            simulations_used: total_sims,
401        }
402    }
403
404    /// Run a single simulation, forcing `forced_action` at the root.
405    fn simulate(&self, root: &mut Node<G::Move>, state: &mut G, forced_action: usize) {
406        let mov = root.edges[forced_action].mov.clone();
407        state.make_move(&mov);
408
409        let child_value = if root.edges[forced_action].child.is_some() {
410            self.descend(root.edges[forced_action].child.as_mut().unwrap(), state, 1)
411        } else {
412            let (child_node, leaf_value) = self.expand(state);
413            root.edges[forced_action].child = Some(Box::new(child_node));
414            leaf_value
415        };
416
417        // Negamax: root player's value = -child's value
418        root.edges[forced_action].value_sum += -child_value;
419        root.edges[forced_action].visits += 1;
420        root.visits += 1;
421    }
422
423    /// Descend the tree with PUCT selection. Returns value for the current player.
424    fn descend(&self, node: &mut Node<G::Move>, state: &mut G, depth: usize) -> f64 {
425        // Terminal node
426        if node.edges.is_empty() {
427            return terminal_value(state);
428        }
429
430        // Depth limit: evaluate in place
431        if depth >= self.config.max_depth {
432            let moves: Vec<G::Move> = state.available_moves().into_iter().collect();
433            if moves.is_empty() {
434                return terminal_value(state);
435            }
436            let (_, value) = self.evaluator.evaluate(state, &moves);
437            return value;
438        }
439
440        // PUCT selection
441        let action_idx = puct_select(node, self.config.c_puct);
442
443        let mov = node.edges[action_idx].mov.clone();
444        state.make_move(&mov);
445
446        let child_value = if node.edges[action_idx].child.is_some() {
447            self.descend(
448                node.edges[action_idx].child.as_mut().unwrap(),
449                state,
450                depth + 1,
451            )
452        } else {
453            let (child_node, leaf_value) = self.expand(state);
454            node.edges[action_idx].child = Some(Box::new(child_node));
455            leaf_value
456        };
457
458        // Negamax
459        let my_value = -child_value;
460        node.edges[action_idx].value_sum += my_value;
461        node.edges[action_idx].visits += 1;
462        node.visits += 1;
463
464        my_value
465    }
466
467    /// Expand a leaf: evaluate the state and create a new node.
468    fn expand(&self, state: &G) -> (Node<G::Move>, f64) {
469        if let Some(pv) = state.terminal_value() {
470            return (
471                Node {
472                    edges: vec![],
473                    visits: 0,
474                },
475                proven_to_value(pv),
476            );
477        }
478
479        let moves: Vec<G::Move> = state.available_moves().into_iter().collect();
480        if moves.is_empty() {
481            return (
482                Node {
483                    edges: vec![],
484                    visits: 0,
485                },
486                0.0,
487            );
488        }
489
490        let (logits, value) = self.evaluator.evaluate(state, &moves);
491        let priors = softmax(&logits);
492
493        let node = Node {
494            edges: moves
495                .into_iter()
496                .enumerate()
497                .map(|(i, m)| Edge {
498                    mov: m,
499                    prior: priors[i],
500                    visits: 0,
501                    value_sum: 0.0,
502                    child: None,
503                })
504                .collect(),
505            visits: 0,
506        };
507
508        (node, value)
509    }
510}
511
512// ============================================================
513// Utility functions
514// ============================================================
515
516/// Sample from Gumbel(0, 1): -ln(-ln(U)), U ~ Uniform(0,1).
517fn sample_gumbel(rng: &mut impl Rng) -> f64 {
518    let u: f64 = rng.gen();
519    let u = u.clamp(1e-20, 1.0 - 1e-20);
520    -(-u.ln()).ln()
521}
522
/// Numerically stable softmax.
///
/// Subtracts the maximum logit before exponentiating so large inputs do not
/// overflow. Returns an empty vector for empty input. Falls back to the
/// uniform distribution whenever the input is degenerate: the maximum is not
/// finite (all `-inf`, all NaN, or any `+inf`), or any logit is NaN (which
/// would otherwise turn every output into NaN).
fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let uniform = || vec![1.0 / logits.len() as f64; logits.len()];

    // `f64::max` skips NaN operands, so `max` is NaN-free unless every logit is NaN.
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !max.is_finite() {
        return uniform();
    }

    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    // A NaN logit survives the max check above and poisons the sum; catch it here.
    if !sum.is_finite() || sum <= 0.0 {
        return uniform();
    }
    exps.into_iter().map(|e| e / sum).collect()
}
542
543/// PUCT action selection: argmax Q(a) + c * P(a) * sqrt(N) / (1 + n(a)).
544fn puct_select<M: Clone>(node: &Node<M>, c: f64) -> usize {
545    let sqrt_n = (node.visits as f64).sqrt();
546
547    node.edges
548        .iter()
549        .enumerate()
550        .max_by(|(_, a), (_, b)| {
551            let sa = puct_score(a, c, sqrt_n);
552            let sb = puct_score(b, c, sqrt_n);
553            sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
554        })
555        .map(|(i, _)| i)
556        .unwrap_or(0)
557}
558
559fn puct_score<M: Clone>(edge: &Edge<M>, c: f64, sqrt_parent_visits: f64) -> f64 {
560    let q = if edge.visits > 0 {
561        edge.value_sum / edge.visits as f64
562    } else {
563        0.0
564    };
565    let u = c * edge.prior * sqrt_parent_visits / (1.0 + edge.visits as f64);
566    q + u
567}
568
569/// Completed Q-value: empirical mean if visited, otherwise falls back to the
570/// parent node's value estimate (used as a surrogate for unvisited actions).
571fn completed_q<M: Clone>(edge: &Edge<M>, default_value: f64) -> f64 {
572    if edge.visits > 0 {
573        edge.value_sum / edge.visits as f64
574    } else {
575        default_value
576    }
577}
578
579/// Map a ProvenValue to a numeric value for the current player.
580fn proven_to_value(pv: ProvenValue) -> f64 {
581    match pv {
582        ProvenValue::Win => 1.0,
583        ProvenValue::Loss => -1.0,
584        ProvenValue::Draw | ProvenValue::Unknown => 0.0,
585    }
586}
587
588/// Terminal value for the current player.
589fn terminal_value<G: GameState>(state: &G) -> f64 {
590    state.terminal_value().map(proven_to_value).unwrap_or(0.0)
591}
592
593// ============================================================
594// Unit tests
595// ============================================================
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// True when `a` and `b` differ by less than [`TOL`].
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < TOL
    }

    /// Build an unexpanded edge with the given statistics.
    fn edge(mov: u32, prior: f64, visits: u32, value_sum: f64) -> Edge<u32> {
        Edge {
            mov,
            prior,
            visits,
            value_sum,
            child: None,
        }
    }

    // ---- sample_gumbel ----

    #[test]
    fn test_sample_gumbel_mean() {
        // Gumbel(0,1) has mean = Euler-Mascheroni constant ~0.5772
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(123);
        let n = 50_000;
        let sum: f64 = (0..n).map(|_| sample_gumbel(&mut rng)).sum();
        let mean = sum / n as f64;
        assert!(
            (mean - 0.5772).abs() < 0.02,
            "Gumbel mean {mean} too far from 0.5772"
        );
    }

    // ---- softmax ----

    #[test]
    fn test_softmax_sums_to_one() {
        let probs = softmax(&[1.0, 2.0, 3.0, 4.0]);
        let sum: f64 = probs.iter().sum();
        assert!(close(sum, 1.0));
    }

    #[test]
    fn test_softmax_ordering() {
        let probs = softmax(&[1.0, 3.0, 2.0]);
        assert!(probs[1] > probs[2]);
        assert!(probs[2] > probs[0]);
    }

    #[test]
    fn test_softmax_uniform() {
        let probs = softmax(&[0.0, 0.0, 0.0]);
        for &p in &probs {
            assert!(close(p, 1.0 / 3.0));
        }
    }

    #[test]
    fn test_softmax_empty() {
        assert!(softmax(&[]).is_empty());
    }

    #[test]
    fn test_softmax_single() {
        let probs = softmax(&[42.0]);
        assert_eq!(probs.len(), 1);
        assert!(close(probs[0], 1.0));
    }

    #[test]
    fn test_softmax_extreme_large_logits() {
        let probs = softmax(&[1000.0, 1001.0, 999.0]);
        let sum: f64 = probs.iter().sum();
        assert!(close(sum, 1.0), "sum = {sum}");
        assert!(probs[1] > probs[0]);
    }

    #[test]
    fn test_softmax_extreme_negative_logits() {
        let probs = softmax(&[-1000.0, -1001.0, -999.0]);
        let sum: f64 = probs.iter().sum();
        assert!(close(sum, 1.0), "sum = {sum}");
        assert!(probs[2] > probs[0]);
    }

    #[test]
    fn test_softmax_all_neg_infinity_returns_uniform() {
        let probs = softmax(&[f64::NEG_INFINITY; 3]);
        for &p in &probs {
            assert!(close(p, 1.0 / 3.0), "should be uniform, got {p}");
        }
    }

    #[test]
    fn test_softmax_nan_returns_uniform() {
        let probs = softmax(&[f64::NAN, f64::NAN]);
        for &p in &probs {
            assert!(close(p, 0.5), "NaN logits should produce uniform");
        }
    }

    // ---- puct_select ----

    #[test]
    fn test_puct_prefers_high_prior_initially() {
        let node = Node {
            edges: vec![edge(0, 0.1, 0, 0.0), edge(1, 0.9, 0, 0.0)],
            visits: 1,
        };
        assert_eq!(puct_select(&node, 1.25), 1);
    }

    #[test]
    fn test_puct_prefers_high_value_after_visits() {
        let node = Node {
            edges: vec![edge(0, 0.5, 10, 8.0), edge(1, 0.5, 10, 2.0)],
            visits: 20,
        };
        assert_eq!(puct_select(&node, 1.25), 0);
    }

    #[test]
    fn test_puct_zero_priors_degenerates_to_exploitation() {
        let node = Node {
            edges: vec![edge(0, 0.0, 5, 3.0), edge(1, 0.0, 5, 1.0)],
            visits: 10,
        };
        // Zero priors: exploration term is 0, should pick highest Q (action 0: Q=0.6)
        assert_eq!(puct_select(&node, 1.25), 0);
    }

    // ---- completed_q ----

    #[test]
    fn test_completed_q_visited() {
        let e = edge(0, 0.5, 4, 2.0);
        assert!(close(completed_q(&e, 0.0), 0.5));
    }

    #[test]
    fn test_completed_q_unvisited() {
        let e = edge(0, 0.5, 0, 0.0);
        assert!(close(completed_q(&e, 0.7), 0.7));
    }

    #[test]
    fn test_completed_q_negative() {
        let e = edge(0, 0.5, 4, -2.0);
        assert!(close(completed_q(&e, 0.0), -0.5));
    }

    // ---- proven_to_value ----

    #[test]
    fn test_proven_to_value() {
        assert!(close(proven_to_value(ProvenValue::Win), 1.0));
        assert!(close(proven_to_value(ProvenValue::Loss), -1.0));
        assert!(close(proven_to_value(ProvenValue::Draw), 0.0));
        assert!(close(proven_to_value(ProvenValue::Unknown), 0.0));
    }
}