mcts_lib/mcts_node.rs
1use crate::board::{Board, Bound, GameOutcome, Player};
2
/// Represents a single node in the Monte Carlo search tree.
///
/// Each node owns a boxed copy of the game state it represents, the running
/// simulation statistics gathered for that state, and the move that led to it.
///
/// Node identity is based solely on [`id`](Self::id): the `PartialEq` and
/// `Hash` impls for this type compare/hash only that field. Two nodes with the
/// same `id` are therefore treated as the same node even if their boards or
/// statistics differ.
#[derive(Debug, Clone)]
pub struct MctsNode<T: Board> {
    /// Identifier of the node. Equality and hashing look only at this field,
    /// so it is expected to be unique within a search tree.
    pub id: i32,
    /// The depth of the node in the tree. [`MctsNode::new`] initialises it to
    /// `0`; presumably callers set it when attaching children (NOTE(review): confirm).
    pub height: i32,
    /// The game state that this node represents.
    pub board: Box<T>,
    /// The move that led to this node's state from its parent. `None` for the
    /// root node; [`MctsNode::new`] always starts with `None`.
    pub prev_move: Option<T::Move>,
    /// The player whose turn it is in this node's game state. Cached from
    /// `board.get_current_player()` when the node is created by [`MctsNode::new`].
    pub current_player: Player,
    /// The outcome of the game at this node, if it is terminal. Cached from
    /// `board.get_outcome()` when the node is created by [`MctsNode::new`].
    pub outcome: GameOutcome,
    /// The number of times this node has been visited during the search.
    pub visits: i32,
    /// The number of simulations through this node that were counted as wins.
    /// NOTE(review): documented as wins for the current player; confirm against
    /// the backpropagation code whether wins are credited to `current_player`
    /// or to the player who made `prev_move`.
    pub wins: i32,
    /// The number of simulations through this node that ended in a draw.
    pub draws: i32,
    /// The bound of the node, presumably used for alpha-beta style pruning.
    /// Starts as `Bound::None` in [`MctsNode::new`].
    pub bound: Bound,
    /// Whether the outcome of this node is definitively known.
    /// [`MctsNode::new`] initialises it to `false`.
    pub is_fully_calculated: bool,
}
32
33impl<T: Board> MctsNode<T> {
34 /// Creates a new `MctsNode` with the given ID and board state.
35 pub fn new(id: i32, boxed_board: Box<T>) -> Self {
36 let player = boxed_board.get_current_player();
37 let outcome = boxed_board.get_outcome();
38 MctsNode {
39 id,
40 height: 0,
41 board: boxed_board,
42 prev_move: None,
43 current_player: player,
44 outcome,
45 visits: 0,
46 wins: 0,
47 draws: 0,
48 bound: Bound::None,
49 is_fully_calculated: false,
50 }
51 }
52
53 /// Calculates the win rate of this node.
54 pub fn wins_rate(&self) -> f64 {
55 if self.visits == 0 {
56 0.0
57 } else {
58 (self.wins as f64) / (self.visits as f64)
59 }
60 }
61
62 /// Calculates the draw rate of this node.
63 pub fn draws_rate(&self) -> f64 {
64 if self.visits == 0 {
65 0.0
66 } else {
67 (self.draws as f64) / (self.visits as f64)
68 }
69 }
70}
71
72impl<T: Board> PartialEq<Self> for MctsNode<T> {
73 fn eq(&self, other: &Self) -> bool {
74 self.id == other.id
75 }
76}
77
// `PartialEq` compares only the `i32` ids. Integer equality is reflexive,
// symmetric and transitive, so the full `Eq` contract holds.
impl<T: Board> Eq for MctsNode<T> {}
79
80impl<T: Board> std::hash::Hash for MctsNode<T> {
81 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
82 self.id.hash(state);
83 }
84}