1use treant::{GameState, ProvenValue};
48use rand::Rng;
49use rand::SeedableRng;
50use rand_xoshiro::Xoshiro256PlusPlus;
51
/// Position evaluator queried by [`GumbelSearch`].
///
/// Implementors must be `Send`.
pub trait GumbelEvaluator<G: GameState>: Send {
    /// Scores `moves` for `state` and estimates the value of `state`.
    ///
    /// Returns `(logits, value)`. The search feeds the first element through a
    /// softmax and indexes the result by position in `moves`, so it is expected
    /// to hold one entry per move, in the same order. The second element is a
    /// scalar value estimate; the search negates it when backing it up to the
    /// parent edge.
    // NOTE(review): `value` is presumably from the side-to-move's perspective
    // and on the same scale as proven results (-1.0..=1.0) — confirm.
    fn evaluate(&self, state: &G, moves: &[G::Move]) -> (Vec<f64>, f64);
}
70
/// Tunable parameters for [`GumbelSearch`].
#[derive(Clone, Copy, Debug)]
pub struct GumbelConfig {
    /// Upper bound on the number of root moves kept as candidates after the
    /// initial Gumbel-perturbed ranking (capped at the number of legal moves).
    pub m_actions: usize,

    /// Exploration constant handed to PUCT selection below the root.
    pub c_puct: f64,

    /// Depth at which a simulation stops descending and asks the evaluator for
    /// a value instead.
    pub max_depth: usize,

    /// Multiplier applied to a move's completed Q-value before it is added to
    /// that move's logit.
    pub value_scale: f64,

    /// Seed for the search's random number generator.
    pub seed: u64,
}
94
impl Default for GumbelConfig {
    /// Returns the stock configuration: 16 candidate root moves, `c_puct` of
    /// 1.25, a depth limit of 200, `value_scale` of 50.0 and seed 42.
    fn default() -> Self {
        Self {
            m_actions: 16,
            c_puct: 1.25,
            max_depth: 200,
            value_scale: 50.0,
            seed: 42,
        }
    }
}
106
/// Statistics reported for one root move at the end of a search.
///
/// `Debug` is derived; the output lists the fields in declaration order
/// (`mov`, `visits`, `completed_q`, `improved_policy`).
#[derive(Clone, Debug)]
pub struct MoveStats<M: Clone> {
    /// The move these statistics describe.
    pub mov: M,
    /// Number of simulations that were routed through this move.
    pub visits: u32,
    /// Mean backed-up value of the move when it was visited at least once,
    /// otherwise the root value estimate used as a stand-in.
    pub completed_q: f64,
    /// Probability of the move under the softmax of
    /// `logit + value_scale * completed_q` taken over all root moves.
    pub improved_policy: f64,
}
129
/// Outcome of [`GumbelSearch::search`].
#[must_use]
pub struct SearchResult<M: Clone> {
    /// The move selected by the search.
    pub best_move: M,

    /// The evaluator's value estimate for the root position.
    pub root_value: f64,

    /// Statistics for the root moves, in the order the game state listed them.
    pub move_stats: Vec<MoveStats<M>>,

    /// Number of simulations that were actually run.
    pub simulations_used: u32,
}
145
146impl<M: Clone + std::fmt::Debug> std::fmt::Debug for SearchResult<M> {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("SearchResult")
149 .field("best_move", &self.best_move)
150 .field("root_value", &self.root_value)
151 .field("simulations_used", &self.simulations_used)
152 .field("move_stats", &self.move_stats)
153 .finish()
154 }
155}
156
/// Root search driver parameterised over a game `G` and an evaluator `E`.
pub struct GumbelSearch<G: GameState, E: GumbelEvaluator<G>> {
    /// Search parameters.
    config: GumbelConfig,
    /// Source of move logits and value estimates.
    evaluator: E,
    /// Random number generator used to draw Gumbel noise; seeded from
    /// `config.seed` at construction and replaceable via `set_seed`.
    rng: Xoshiro256PlusPlus,
    /// Ties the type to `G` without storing a game state.
    _phantom: std::marker::PhantomData<G>,
}
170
171impl<G: GameState, E: GumbelEvaluator<G>> std::fmt::Debug for GumbelSearch<G, E> {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("GumbelSearch")
174 .field("config", &self.config)
175 .finish_non_exhaustive()
176 }
177}
178
/// A position in the search tree.
struct Node<M: Clone> {
    /// One edge per move available from this position; left empty for
    /// positions that were expanded as terminal or move-less.
    edges: Vec<Edge<M>>,
    /// Number of simulations that selected one of this node's edges.
    visits: u32,
}
187
/// A move out of a [`Node`] together with its accumulated statistics.
struct Edge<M: Clone> {
    /// The move this edge plays.
    mov: M,
    /// Softmax probability of the move under the evaluator's logits.
    prior: f64,
    /// Number of simulations that traversed this edge.
    visits: u32,
    /// Sum of the values backed up through this edge (each is the negation of
    /// the value obtained for the position after the move).
    value_sum: f64,
    /// Subtree reached by the move; `None` until the edge is first traversed.
    child: Option<Box<Node<M>>>,
}
195
196impl<G, E> GumbelSearch<G, E>
201where
202 G: GameState,
203 E: GumbelEvaluator<G>,
204{
205 #[must_use]
207 pub fn new(evaluator: E, config: GumbelConfig) -> Self {
208 let rng = Xoshiro256PlusPlus::seed_from_u64(config.seed);
209 Self {
210 config,
211 evaluator,
212 rng,
213 _phantom: std::marker::PhantomData,
214 }
215 }
216
217 #[must_use]
219 pub fn evaluator(&self) -> &E {
220 &self.evaluator
221 }
222
223 #[must_use]
225 pub fn config(&self) -> &GumbelConfig {
226 &self.config
227 }
228
229 pub fn set_seed(&mut self, seed: u64) {
231 self.rng = Xoshiro256PlusPlus::seed_from_u64(seed);
232 }
233
234 pub fn search(&mut self, state: &G, n_simulations: u32) -> SearchResult<G::Move> {
240 let moves: Vec<G::Move> = state.available_moves().into_iter().collect();
241 assert!(!moves.is_empty(), "cannot search from terminal state");
242
243 if moves.len() == 1 {
245 let (_, root_value) = self.evaluator.evaluate(state, &moves);
246 return SearchResult {
247 best_move: moves[0].clone(),
248 root_value,
249 move_stats: vec![MoveStats {
250 mov: moves[0].clone(),
251 visits: 0,
252 completed_q: root_value,
253 improved_policy: 1.0,
254 }],
255 simulations_used: 0,
256 };
257 }
258
259 let (logits, root_value) = self.evaluator.evaluate(state, &moves);
261 let priors = softmax(&logits);
262
263 let gumbels: Vec<f64> = (0..moves.len())
265 .map(|_| sample_gumbel(&mut self.rng))
266 .collect();
267
268 let mut root = Node {
270 edges: moves
271 .iter()
272 .enumerate()
273 .map(|(i, m)| Edge {
274 mov: m.clone(),
275 prior: priors[i],
276 visits: 0,
277 value_sum: 0.0,
278 child: None,
279 })
280 .collect(),
281 visits: 0,
282 };
283
284 let m = self.config.m_actions.min(moves.len());
286 let mut alive: Vec<usize> = (0..moves.len()).collect();
287 alive.sort_by(|&a, &b| {
288 let sa = gumbels[a] + logits[a];
289 let sb = gumbels[b] + logits[b];
290 sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
291 });
292 alive.truncate(m);
293
294 let n_seq = if m <= 1 {
296 1
297 } else {
298 (m as f64).log2().ceil() as u32
299 };
300 let mut budget = n_simulations;
301 let mut total_sims = 0u32;
302
303 for phase in 0..n_seq {
304 if alive.len() <= 1 || total_sims >= n_simulations {
305 break;
306 }
307
308 let phases_left = n_seq - phase;
310 let n_a = budget / (alive.len() as u32 * phases_left);
311 if n_a == 0 {
312 break; }
314
315 for &action_idx in &alive {
316 for _ in 0..n_a {
317 if total_sims >= n_simulations {
318 break;
319 }
320 let mut s = state.clone();
321 self.simulate(&mut root, &mut s, action_idx);
322 total_sims += 1;
323 }
324 }
325 budget = budget.saturating_sub(alive.len() as u32 * n_a);
326
327 let mut scored: Vec<(usize, f64)> = alive
329 .iter()
330 .map(|&idx| {
331 let q = completed_q(&root.edges[idx], root_value);
332 let score = gumbels[idx] + logits[idx] + self.config.value_scale * q;
333 (idx, score)
334 })
335 .collect();
336 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
337
338 let keep = alive.len().div_ceil(2);
339 alive = scored[..keep].iter().map(|&(idx, _)| idx).collect();
340 }
341
342 if total_sims < n_simulations && !alive.is_empty() {
344 let mut remaining = n_simulations - total_sims;
345 for (i, &action_idx) in alive.iter().enumerate() {
346 let actions_left = alive.len() as u32 - i as u32;
347 let share = remaining / actions_left;
348 for _ in 0..share {
349 let mut s = state.clone();
350 self.simulate(&mut root, &mut s, action_idx);
351 total_sims += 1;
352 }
353 remaining -= share;
354 }
355 }
356
357 let best_idx = if alive.len() > 1 {
359 *alive
360 .iter()
361 .max_by(|&&a, &&b| {
362 let sa = gumbels[a]
363 + logits[a]
364 + self.config.value_scale * completed_q(&root.edges[a], root_value);
365 let sb = gumbels[b]
366 + logits[b]
367 + self.config.value_scale * completed_q(&root.edges[b], root_value);
368 sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
369 })
370 .unwrap()
371 } else {
372 alive[0]
373 };
374
375 let improved_scores: Vec<f64> = root
377 .edges
378 .iter()
379 .enumerate()
380 .map(|(i, e)| logits[i] + self.config.value_scale * completed_q(e, root_value))
381 .collect();
382 let improved_probs = softmax(&improved_scores);
383
384 let move_stats: Vec<MoveStats<G::Move>> = root
385 .edges
386 .iter()
387 .zip(improved_probs.iter())
388 .map(|(e, &p)| MoveStats {
389 mov: e.mov.clone(),
390 visits: e.visits,
391 completed_q: completed_q(e, root_value),
392 improved_policy: p,
393 })
394 .collect();
395
396 SearchResult {
397 best_move: root.edges[best_idx].mov.clone(),
398 root_value,
399 move_stats,
400 simulations_used: total_sims,
401 }
402 }
403
404 fn simulate(&self, root: &mut Node<G::Move>, state: &mut G, forced_action: usize) {
406 let mov = root.edges[forced_action].mov.clone();
407 state.make_move(&mov);
408
409 let child_value = if root.edges[forced_action].child.is_some() {
410 self.descend(root.edges[forced_action].child.as_mut().unwrap(), state, 1)
411 } else {
412 let (child_node, leaf_value) = self.expand(state);
413 root.edges[forced_action].child = Some(Box::new(child_node));
414 leaf_value
415 };
416
417 root.edges[forced_action].value_sum += -child_value;
419 root.edges[forced_action].visits += 1;
420 root.visits += 1;
421 }
422
423 fn descend(&self, node: &mut Node<G::Move>, state: &mut G, depth: usize) -> f64 {
425 if node.edges.is_empty() {
427 return terminal_value(state);
428 }
429
430 if depth >= self.config.max_depth {
432 let moves: Vec<G::Move> = state.available_moves().into_iter().collect();
433 if moves.is_empty() {
434 return terminal_value(state);
435 }
436 let (_, value) = self.evaluator.evaluate(state, &moves);
437 return value;
438 }
439
440 let action_idx = puct_select(node, self.config.c_puct);
442
443 let mov = node.edges[action_idx].mov.clone();
444 state.make_move(&mov);
445
446 let child_value = if node.edges[action_idx].child.is_some() {
447 self.descend(
448 node.edges[action_idx].child.as_mut().unwrap(),
449 state,
450 depth + 1,
451 )
452 } else {
453 let (child_node, leaf_value) = self.expand(state);
454 node.edges[action_idx].child = Some(Box::new(child_node));
455 leaf_value
456 };
457
458 let my_value = -child_value;
460 node.edges[action_idx].value_sum += my_value;
461 node.edges[action_idx].visits += 1;
462 node.visits += 1;
463
464 my_value
465 }
466
467 fn expand(&self, state: &G) -> (Node<G::Move>, f64) {
469 if let Some(pv) = state.terminal_value() {
470 return (
471 Node {
472 edges: vec![],
473 visits: 0,
474 },
475 proven_to_value(pv),
476 );
477 }
478
479 let moves: Vec<G::Move> = state.available_moves().into_iter().collect();
480 if moves.is_empty() {
481 return (
482 Node {
483 edges: vec![],
484 visits: 0,
485 },
486 0.0,
487 );
488 }
489
490 let (logits, value) = self.evaluator.evaluate(state, &moves);
491 let priors = softmax(&logits);
492
493 let node = Node {
494 edges: moves
495 .into_iter()
496 .enumerate()
497 .map(|(i, m)| Edge {
498 mov: m,
499 prior: priors[i],
500 visits: 0,
501 value_sum: 0.0,
502 child: None,
503 })
504 .collect(),
505 visits: 0,
506 };
507
508 (node, value)
509 }
510}
511
512fn sample_gumbel(rng: &mut impl Rng) -> f64 {
518 let u: f64 = rng.gen();
519 let u = u.clamp(1e-20, 1.0 - 1e-20);
520 -(-u.ln()).ln()
521}
522
/// Converts `logits` into a probability distribution.
///
/// The maximum logit is subtracted before exponentiating, so large magnitudes
/// do not overflow. An empty slice yields an empty vector.
///
/// The uniform distribution is returned whenever a proper distribution cannot
/// be formed: when the largest logit is not finite (all `-inf`, all NaN, or a
/// `+inf` is present), or when any NaN logit makes the normaliser non-finite.
fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let uniform = || vec![1.0 / logits.len() as f64; logits.len()];

    // `f64::max` ignores NaN operands, so `max` is NaN-free here.
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !max.is_finite() {
        return uniform();
    }

    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    // A NaN logit survives into `sum`; dividing by it would turn every
    // probability into NaN.
    if !sum.is_finite() || sum <= 0.0 {
        return uniform();
    }
    exps.into_iter().map(|e| e / sum).collect()
}
542
543fn puct_select<M: Clone>(node: &Node<M>, c: f64) -> usize {
545 let sqrt_n = (node.visits as f64).sqrt();
546
547 node.edges
548 .iter()
549 .enumerate()
550 .max_by(|(_, a), (_, b)| {
551 let sa = puct_score(a, c, sqrt_n);
552 let sb = puct_score(b, c, sqrt_n);
553 sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
554 })
555 .map(|(i, _)| i)
556 .unwrap_or(0)
557}
558
559fn puct_score<M: Clone>(edge: &Edge<M>, c: f64, sqrt_parent_visits: f64) -> f64 {
560 let q = if edge.visits > 0 {
561 edge.value_sum / edge.visits as f64
562 } else {
563 0.0
564 };
565 let u = c * edge.prior * sqrt_parent_visits / (1.0 + edge.visits as f64);
566 q + u
567}
568
569fn completed_q<M: Clone>(edge: &Edge<M>, default_value: f64) -> f64 {
572 if edge.visits > 0 {
573 edge.value_sum / edge.visits as f64
574 } else {
575 default_value
576 }
577}
578
579fn proven_to_value(pv: ProvenValue) -> f64 {
581 match pv {
582 ProvenValue::Win => 1.0,
583 ProvenValue::Loss => -1.0,
584 ProvenValue::Draw | ProvenValue::Unknown => 0.0,
585 }
586}
587
588fn terminal_value<G: GameState>(state: &G) -> f64 {
590 state.terminal_value().map(proven_to_value).unwrap_or(0.0)
591}
592
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a childless edge with the given statistics.
    fn leaf_edge(mov: u32, prior: f64, visits: u32, value_sum: f64) -> Edge<u32> {
        Edge {
            mov,
            prior,
            visits,
            value_sum,
            child: None,
        }
    }

    /// Builds a node over exactly two edges.
    fn pair(first: Edge<u32>, second: Edge<u32>, visits: u32) -> Node<u32> {
        Node {
            edges: vec![first, second],
            visits,
        }
    }

    // --- sample_gumbel ---------------------------------------------------

    #[test]
    fn test_sample_gumbel_mean() {
        // The sample mean should approach the Euler–Mascheroni constant.
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(123);
        let n = 50_000;
        let mut sum = 0.0_f64;
        for _ in 0..n {
            sum += sample_gumbel(&mut rng);
        }
        let mean = sum / n as f64;
        assert!(
            (mean - 0.5772).abs() < 0.02,
            "Gumbel mean {mean} too far from 0.5772"
        );
    }

    // --- softmax ---------------------------------------------------------

    #[test]
    fn test_softmax_sums_to_one() {
        let probs = softmax(&[1.0, 2.0, 3.0, 4.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_ordering() {
        let probs = softmax(&[1.0, 3.0, 2.0]);
        assert!(probs[1] > probs[2]);
        assert!(probs[2] > probs[0]);
    }

    #[test]
    fn test_softmax_uniform() {
        let probs = softmax(&[0.0, 0.0, 0.0]);
        for &p in &probs {
            assert!((p - 1.0 / 3.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_softmax_empty() {
        assert!(softmax(&[]).is_empty());
    }

    #[test]
    fn test_softmax_single() {
        let probs = softmax(&[42.0]);
        assert_eq!(probs.len(), 1);
        assert!((probs[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_extreme_large_logits() {
        let probs = softmax(&[1000.0, 1001.0, 999.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "sum = {sum}");
        assert!(probs[1] > probs[0]);
    }

    #[test]
    fn test_softmax_extreme_negative_logits() {
        let probs = softmax(&[-1000.0, -1001.0, -999.0]);
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "sum = {sum}");
        assert!(probs[2] > probs[0]);
    }

    #[test]
    fn test_softmax_all_neg_infinity_returns_uniform() {
        let probs = softmax(&[f64::NEG_INFINITY; 3]);
        for &p in &probs {
            assert!((p - 1.0 / 3.0).abs() < 1e-10, "should be uniform, got {p}");
        }
    }

    #[test]
    fn test_softmax_nan_returns_uniform() {
        let probs = softmax(&[f64::NAN, f64::NAN]);
        for &p in &probs {
            assert!((p - 0.5).abs() < 1e-10, "NaN logits should produce uniform");
        }
    }

    // --- puct_select -----------------------------------------------------

    #[test]
    fn test_puct_prefers_high_prior_initially() {
        let node = pair(leaf_edge(0, 0.1, 0, 0.0), leaf_edge(1, 0.9, 0, 0.0), 1);
        let selected = puct_select(&node, 1.25);
        assert_eq!(selected, 1);
    }

    #[test]
    fn test_puct_prefers_high_value_after_visits() {
        let node = pair(leaf_edge(0, 0.5, 10, 8.0), leaf_edge(1, 0.5, 10, 2.0), 20);
        let selected = puct_select(&node, 1.25);
        assert_eq!(selected, 0);
    }

    #[test]
    fn test_puct_zero_priors_degenerates_to_exploitation() {
        // With zero priors the exploration term vanishes, leaving mean value.
        let node = pair(leaf_edge(0, 0.0, 5, 3.0), leaf_edge(1, 0.0, 5, 1.0), 10);
        let selected = puct_select(&node, 1.25);
        assert_eq!(selected, 0);
    }

    // --- completed_q -----------------------------------------------------

    #[test]
    fn test_completed_q_visited() {
        let edge = leaf_edge(0, 0.5, 4, 2.0);
        assert!((completed_q(&edge, 0.0) - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_completed_q_unvisited() {
        let edge = leaf_edge(0, 0.5, 0, 0.0);
        assert!((completed_q(&edge, 0.7) - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_completed_q_negative() {
        let edge = leaf_edge(0, 0.5, 4, -2.0);
        assert!((completed_q(&edge, 0.0) - (-0.5)).abs() < 1e-10);
    }

    // --- proven_to_value -------------------------------------------------

    #[test]
    fn test_proven_to_value() {
        assert!((proven_to_value(ProvenValue::Win) - 1.0).abs() < 1e-10);
        assert!((proven_to_value(ProvenValue::Loss) - (-1.0)).abs() < 1e-10);
        assert!((proven_to_value(ProvenValue::Draw) - 0.0).abs() < 1e-10);
        assert!((proven_to_value(ProvenValue::Unknown) - 0.0).abs() < 1e-10);
    }
}