mcts_lib/
mcts_node.rs

1use crate::board::{Board, Bound, GameOutcome, Player};
2
3/// Represents a single node in the Monte Carlo search tree.
4///
5/// Each node stores the state of the game, statistics about the outcomes of simulations,
6/// and information about the move that led to this state.
7#[derive(Debug, Clone)]
8pub struct MctsNode<T: Board> {
9    /// A unique identifier for the node.
10    pub id: i32,
11    /// The depth of the node in the tree.
12    pub height: usize,
13    /// The game state that this node represents.
14    pub board: Box<T>,
15    /// The move that led to this node's state from its parent. `None` for the root node.
16    pub prev_move: Option<T::Move>,
17    /// The player whose turn it is in this node's game state.
18    pub current_player: Player,
19    /// The outcome of the game at this node, if it is terminal.
20    pub outcome: GameOutcome,
21    /// The number of times this node has been visited during the search.
22    pub visits: i32,
23    /// The number of times simulations from this node have resulted in a win for the current player.
24    pub wins: i32,
25    /// The number of times simulations from this node have resulted in a draw.
26    pub draws: i32,
27    /// The bound of the node, used for alpha-beta pruning.
28    pub bound: Bound,
29    /// A flag indicating whether the outcome of this node is definitively known.
30    pub is_fully_calculated: bool,
31    /// A hash value representing the board state, used for quick comparisons and lookups.
32    pub board_hash: u128,
33}
34
35impl<T: Board> MctsNode<T> {
36    /// Creates a new `MctsNode` with the given ID and board state.
37    pub fn new(id: i32, boxed_board: Box<T>) -> Self {
38        let player = boxed_board.get_current_player();
39        let outcome = boxed_board.get_outcome();
40        let board_hash = boxed_board.get_hash();
41        MctsNode {
42            id,
43            height: 0,
44            board: boxed_board,
45            prev_move: None,
46            current_player: player,
47            outcome,
48            visits: 0,
49            wins: 0,
50            draws: 0,
51            bound: Bound::None,
52            is_fully_calculated: false,
53            board_hash,
54        }
55    }
56
57    /// Calculates the win rate of this node.
58    pub fn wins_rate(&self) -> f64 {
59        if self.visits == 0 {
60            0.0
61        } else {
62            (self.wins as f64) / (self.visits as f64)
63        }
64    }
65
66    /// Calculates the draw rate of this node.
67    pub fn draws_rate(&self) -> f64 {
68        if self.visits == 0 {
69            0.0
70        } else {
71            (self.draws as f64) / (self.visits as f64)
72        }
73    }
74}
75
76impl<T: Board> PartialEq<Self> for MctsNode<T> {
77    fn eq(&self, other: &Self) -> bool {
78        self.id == other.id
79    }
80}
81
82impl<T: Board> Eq for MctsNode<T> {}
83
84impl<T: Board> std::hash::Hash for MctsNode<T> {
85    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
86        self.id.hash(state);
87    }
88}