mcts_lib/mcts_node.rs
1use crate::board::{Board, Bound, GameOutcome, Player};
2
3/// Represents a single node in the Monte Carlo search tree.
4///
5/// Each node stores the state of the game, statistics about the outcomes of simulations,
6/// and information about the move that led to this state.
7#[derive(Debug, Clone)]
8pub struct MctsNode<T: Board> {
9 /// A unique identifier for the node.
10 pub id: i32,
11 /// The depth of the node in the tree.
12 pub height: usize,
13 /// The game state that this node represents.
14 pub board: Box<T>,
15 /// The move that led to this node's state from its parent. `None` for the root node.
16 pub prev_move: Option<T::Move>,
17 /// The player whose turn it is in this node's game state.
18 pub current_player: Player,
19 /// The outcome of the game at this node, if it is terminal.
20 pub outcome: GameOutcome,
21 /// The number of times this node has been visited during the search.
22 pub visits: i32,
23 /// The number of times simulations from this node have resulted in a win for the current player.
24 pub wins: i32,
25 /// The number of times simulations from this node have resulted in a draw.
26 pub draws: i32,
27 /// The bound of the node, used for alpha-beta pruning.
28 pub bound: Bound,
29 /// A flag indicating whether the outcome of this node is definitively known.
30 pub is_fully_calculated: bool,
31 /// A hash value representing the board state, used for quick comparisons and lookups.
32 pub board_hash: u128,
33}
34
35impl<T: Board> MctsNode<T> {
36 /// Creates a new `MctsNode` with the given ID and board state.
37 pub fn new(id: i32, boxed_board: Box<T>) -> Self {
38 let player = boxed_board.get_current_player();
39 let outcome = boxed_board.get_outcome();
40 let board_hash = boxed_board.get_hash();
41 MctsNode {
42 id,
43 height: 0,
44 board: boxed_board,
45 prev_move: None,
46 current_player: player,
47 outcome,
48 visits: 0,
49 wins: 0,
50 draws: 0,
51 bound: Bound::None,
52 is_fully_calculated: false,
53 board_hash,
54 }
55 }
56
57 /// Calculates the win rate of this node.
58 pub fn wins_rate(&self) -> f64 {
59 if self.visits == 0 {
60 0.0
61 } else {
62 (self.wins as f64) / (self.visits as f64)
63 }
64 }
65
66 /// Calculates the draw rate of this node.
67 pub fn draws_rate(&self) -> f64 {
68 if self.visits == 0 {
69 0.0
70 } else {
71 (self.draws as f64) / (self.visits as f64)
72 }
73 }
74}
75
76impl<T: Board> PartialEq<Self> for MctsNode<T> {
77 fn eq(&self, other: &Self) -> bool {
78 self.id == other.id
79 }
80}
81
82impl<T: Board> Eq for MctsNode<T> {}
83
84impl<T: Board> std::hash::Hash for MctsNode<T> {
85 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
86 self.id.hash(state);
87 }
88}