mcts_lib/
mcts.rs

1use crate::board::{Board, Bound, GameOutcome, Player};
2use crate::mcts_node::MctsNode;
3use crate::random::{RandomGenerator, StandardRandomGenerator};
4use ego_tree::{NodeId, NodeRef, Tree};
5use std::collections::HashSet;
6use std::ops::{Deref, DerefMut};
7
8/// The main struct for running the Monte Carlo Tree Search algorithm.
9///
10/// It holds the search tree, the random number generator, and the configuration for the search.
11pub struct MonteCarloTreeSearch<T: Board, K: RandomGenerator> {
12    tree: Tree<MctsNode<T>>,
13    root_id: NodeId,
14    random: K,
15    use_alpha_beta_pruning: bool,
16    next_action: MctsAction,
17}
18
19/// A builder for creating instances of `MonteCarloTreeSearch`.
20///
21/// This provides a convenient way to configure the MCTS search with different parameters.
22pub struct MonteCarloTreeSearchBuilder<T: Board, K: RandomGenerator> {
23    board: T,
24    random_generator: K,
25    use_alpha_beta_pruning: bool,
26}
27
28impl<T: Board, K: RandomGenerator> MonteCarloTreeSearchBuilder<T, K> {
29    /// Creates a new builder with the given initial board state.
30    pub fn new(board: T) -> Self {
31        Self {
32            board,
33            random_generator: K::default(),
34            use_alpha_beta_pruning: true,
35        }
36    }
37
38    /// Sets the random number generator for the MCTS search.
39    pub fn with_random_generator(mut self, rg: K) -> Self {
40        self.random_generator = rg;
41        self
42    }
43
44    /// Enables or disables alpha-beta pruning.
45    pub fn with_alpha_beta_pruning(mut self, use_abp: bool) -> Self {
46        self.use_alpha_beta_pruning = use_abp;
47        self
48    }
49
50    /// Builds the `MonteCarloTreeSearch` instance with the configured parameters.
51    pub fn build(self) -> MonteCarloTreeSearch<T, K> {
52        MonteCarloTreeSearch::new(
53            self.board,
54            self.random_generator,
55            self.use_alpha_beta_pruning,
56        )
57    }
58}
59
60impl<T: Board, K: RandomGenerator> MonteCarloTreeSearch<T, K> {
61    /// Returns a new builder for `MonteCarloTreeSearch`.
62    pub fn builder(board: T) -> MonteCarloTreeSearchBuilder<T, K> {
63        MonteCarloTreeSearchBuilder::new(board)
64    }
65
66    /// Creates a new `MonteCarloTreeSearch` instance.
67    ///
68    /// It is recommended to use the builder pattern via `MonteCarloTreeSearch::builder()` instead.
69    pub fn new(board: T, rg: K, use_alpha_beta_pruning: bool) -> Self {
70        let root_mcts_node = MctsNode::new(0, Box::new(board));
71        let tree: Tree<MctsNode<T>> = Tree::new(root_mcts_node);
72        let root_id = tree.root().id();
73
74        Self {
75            tree,
76            root_id: root_id.clone(),
77            random: rg,
78            use_alpha_beta_pruning,
79            next_action: MctsAction::Selection {
80                R: root_id.clone(),
81                RP: vec![],
82            },
83        }
84    }
85
86    /// Returns an immutable reference to the underlying search tree.
87    pub fn get_tree(&self) -> &Tree<MctsNode<T>> {
88        &self.tree
89    }
90
91    /// Returns the next MCTS action to be performed. Useful for debugging and visualization.
92    pub fn get_next_mcts_action(&self) -> &MctsAction {
93        &self.next_action
94    }
95
96    /// Executes a single step of the MCTS algorithm (Selection, Expansion, Simulation, or Backpropagation).
97    pub fn execute_action(&mut self) {
98        match self.next_action.clone() {
99            MctsAction::Selection { R, RP: _cr } => {
100                let maybe_selected_node = self.select_next_node(R);
101                self.next_action = match maybe_selected_node {
102                    None => MctsAction::EverythingIsCalculated,
103                    Some(selected_node) => MctsAction::Expansion { L: selected_node },
104                };
105            }
106            MctsAction::Expansion { L } => {
107                let (children, selected_child) = self.expand_node(L);
108                self.next_action = MctsAction::Simulation {
109                    C: selected_child,
110                    AC: children,
111                };
112            }
113            MctsAction::Simulation { C, AC: _ac } => {
114                let outcome = self.simulate(C);
115                self.next_action = MctsAction::Backpropagation { C, result: outcome };
116            }
117            MctsAction::Backpropagation { C, result } => {
118                let affected_nodes = self.backpropagate(C, result);
119                self.next_action = MctsAction::Selection {
120                    R: self.root_id.clone(),
121                    RP: affected_nodes,
122                }
123            }
124            MctsAction::EverythingIsCalculated => {}
125        }
126    }
127
128    /// Performs one full iteration of the MCTS algorithm (Selection, Expansion, Simulation, Backpropagation).
129    /// Returns the path of nodes that were updated during backpropagation.
130    pub fn do_iteration(&mut self) -> Vec<NodeId> {
131        self.execute_action();
132        let mut is_selection = matches!(self.next_action, MctsAction::Selection { R: _, RP: _ });
133        let mut is_fully_calculated =
134            matches!(self.next_action, MctsAction::EverythingIsCalculated);
135        while !is_selection && !is_fully_calculated {
136            self.execute_action();
137            is_selection = matches!(self.next_action, MctsAction::Selection { R: _, RP: _ });
138            is_fully_calculated = matches!(self.next_action, MctsAction::EverythingIsCalculated);
139        }
140
141        match self.next_action.clone() {
142            MctsAction::Selection { R: _, RP: rp } => rp,
143            _ => vec![],
144        }
145    }
146
147    /// Runs the MCTS search for a specified number of iterations.
148    pub fn iterate_n_times(&mut self, n: u32) {
149        let mut iteration = 0;
150        while iteration < n {
151            self.do_iteration();
152            iteration += 1;
153        }
154    }
155
156    /// Returns a reference to the root node of the search tree.
157    pub fn get_root(&self) -> MctsTreeNode<T> {
158        let root = self.tree.root();
159        root.into()
160    }
161
162    /// Selects the most promising node to expand, using the UCB1 formula.
163    fn select_next_node(&self, root_id: NodeId) -> Option<NodeId> {
164        let mut promising_node_id = root_id.clone();
165        let mut has_changed = false;
166        loop {
167            let mut best_child_id: Option<NodeId> = None;
168            let mut max_ucb = f64::MIN;
169            let node = self.tree.get(promising_node_id).unwrap();
170            for child in node.children() {
171                if child.value().is_fully_calculated {
172                    continue;
173                }
174
175                let current_ucb = MonteCarloTreeSearch::<T, K>::ucb_value(
176                    node.value().visits,
177                    child.value().wins,
178                    child.value().visits,
179                );
180                if current_ucb > max_ucb {
181                    max_ucb = current_ucb;
182                    best_child_id = Some(child.id());
183                }
184            }
185            if best_child_id.is_none() {
186                break;
187            }
188            promising_node_id = best_child_id.unwrap();
189            has_changed = true;
190        }
191
192        if has_changed {
193            Some(promising_node_id.clone())
194        } else {
195            let root = self.tree.root();
196            if root.children().count() == 0 {
197                Some(root_id.clone())
198            } else {
199                None
200            }
201        }
202    }
203
204    /// Expands a leaf node by creating its children, representing all possible moves from that state.
205    fn expand_node(&mut self, node_id: NodeId) -> (Vec<NodeId>, NodeId) {
206        let node = self.tree.get(node_id).unwrap();
207        if !node.children().count() == 0 {
208            panic!("BUG: expanding already expanded node");
209        }
210        if node.value().outcome != GameOutcome::InProgress {
211            return (vec![], node_id.clone());
212        }
213
214        
215        let children_height = node.value().height + 1;
216        let all_possible_moves = self.get_available_moves(node_id);
217        let mut new_mcts_nodes = Vec::with_capacity(all_possible_moves.len());
218
219        for possible_move in all_possible_moves {
220            let mut board_clone = node.value().board.clone();
221            board_clone.perform_move(&possible_move);
222            let new_node_id = self.random.next();
223            let mut mcts_node = MctsNode::new(new_node_id, board_clone);
224            mcts_node.prev_move = Some(possible_move);
225            mcts_node.height = children_height;
226            new_mcts_nodes.push(mcts_node);
227        }
228
229        let mut new_node_ids = Vec::with_capacity(new_mcts_nodes.len());
230        for mcts_node in new_mcts_nodes {
231            let mut node = self.tree.get_mut(node_id).unwrap();
232            node.append(mcts_node);
233            new_node_ids.push(node_id.clone());
234        }
235
236        let children: Vec<_> = self.tree.get(node_id).unwrap().children().collect();
237        let selected_child_index = self.random.next_range(0, children.len() as i32) as usize;
238        let selected_child = children[selected_child_index].id();
239        (new_node_ids, selected_child)
240    }
241
242    /// Simulates a random playout from a given node until the game ends.
243    fn simulate(&mut self, node_id: NodeId) -> GameOutcome {
244        let node = self.tree.get(node_id).unwrap();
245        let mut board = node.value().board.clone();
246        let mut outcome = board.get_outcome();
247        let mut hashes = self.get_branch_hashes(node_id);
248
249        while outcome == GameOutcome::InProgress {
250            let mut all_possible_moves = board.get_available_moves();
251
252            while !all_possible_moves.is_empty() {
253                let random_move_index =
254                    self.random.next_range(0, all_possible_moves.len() as i32) as usize;
255                let random_move = all_possible_moves.get(random_move_index).unwrap();
256                let mut new_board = board.clone();
257                new_board.perform_move(random_move);
258                let new_board_hash = new_board.get_hash();
259                if hashes.contains(&new_board_hash) {
260                    all_possible_moves.remove(random_move_index);
261                    continue;
262                } else {
263                    hashes.insert(new_board_hash);
264                    board = new_board;
265                    break;
266                }
267            }
268
269            if all_possible_moves.is_empty() {
270                return GameOutcome::Lose;
271            }
272
273            outcome = board.get_outcome();
274        }
275        outcome
276    }
277
278    /// Propagates the result of a simulation back up the tree, updating node statistics.
279    fn backpropagate(&mut self, node_id: NodeId, outcome: GameOutcome) -> Vec<NodeId> {
280        let mut branch = vec![node_id.clone()];
281
282        loop {
283            let temp_node = self.tree.get(*branch.last().unwrap()).unwrap();
284            match temp_node.parent() {
285                None => break,
286                Some(parent) => branch.push(parent.id()),
287            }
288        }
289
290        let is_win = outcome == GameOutcome::Win;
291        let is_draw = outcome == GameOutcome::Draw;
292
293        for node_id in &branch {
294            let bound = self.get_bound(*node_id);
295            let is_fully_calculated = self.is_fully_calculated(*node_id, bound);
296            let mut temp_node = self.tree.get_mut(*node_id).unwrap();
297            let mcts_node = temp_node.value();
298            mcts_node.visits += 1;
299            if is_win {
300                mcts_node.wins += 1;
301            }
302
303            if is_draw {
304                mcts_node.draws += 1;
305            }
306
307            if is_fully_calculated {
308                mcts_node.is_fully_calculated = true;
309            }
310
311            if bound != Bound::None {
312                mcts_node.bound = bound;
313            }
314        }
315
316        branch
317    }
318
319    /// Determines the bound of a node for alpha-beta pruning.
320    fn get_bound(&self, node_id: NodeId) -> Bound {
321        if !self.use_alpha_beta_pruning {
322            return Bound::None;
323        }
324
325        let node = self.tree.get(node_id).unwrap();
326        let mcts_node = node.value();
327        if mcts_node.bound != Bound::None {
328            return mcts_node.bound;
329        }
330
331        if mcts_node.outcome == GameOutcome::Win {
332            return Bound::DefoWin;
333        }
334
335        if mcts_node.outcome == GameOutcome::Lose {
336            return Bound::DefoLose;
337        }
338
339        if node.children().count() == 0 {
340            return Bound::None;
341        }
342
343        match mcts_node.current_player {
344            Player::Me => {
345                if node.children().all(|x| x.value().bound == Bound::DefoLose) {
346                    return Bound::DefoLose;
347                }
348
349                if node.children().any(|x| x.value().bound == Bound::DefoWin) {
350                    return Bound::DefoWin;
351                }
352            }
353            Player::Other => {
354                if node.children().all(|x| x.value().bound == Bound::DefoWin) {
355                    return Bound::DefoWin;
356                }
357
358                if node.children().any(|x| x.value().bound == Bound::DefoLose) {
359                    return Bound::DefoLose;
360                }
361            }
362        }
363
364        Bound::None
365    }
366
367    /// Checks if a node can be considered fully calculated, meaning its outcome is certain.
368    fn is_fully_calculated(&self, node_id: NodeId, bound: Bound) -> bool {
369        if bound != Bound::None {
370            return true;
371        }
372
373        let node = self.tree.get(node_id).unwrap();
374        if node.value().outcome != GameOutcome::InProgress {
375            return true;
376        }
377
378        if node.children().count() == 0 {
379            return false;
380        }
381
382        let all_children_calculated = node.children().all(|x| x.value().is_fully_calculated);
383
384        all_children_calculated
385    }
386
387    /// Calculates the UCB1 (Upper Confidence Bound 1) value for a node.
388    fn ucb_value(total_visits: i32, node_wins: i32, node_visit: i32) -> f64 {
389        const EXPLORATION_PARAMETER: f64 = std::f64::consts::SQRT_2;
390
391        if node_visit == 0 {
392            i32::MAX.into()
393        } else {
394            ((node_wins as f64) / (node_visit as f64))
395                + EXPLORATION_PARAMETER
396                    * f64::sqrt(f64::ln(total_visits as f64) / (node_visit as f64))
397        }
398    }
399
400    /// Retrieves the hashes of all nodes in the branch from the given node to the root.
401    fn get_branch_hashes(&self, node_id: NodeId) -> HashSet<u128> {
402        let mut current_node = self.tree.get(node_id).unwrap();
403        let mut branch_hashes = HashSet::with_capacity(current_node.value().height + 1);
404        loop {
405            branch_hashes.insert(current_node.value().board_hash);
406            match current_node.parent() {
407                None => break,
408                Some(parent) => current_node = parent,
409            }
410        }
411        branch_hashes
412    }
413
414    /// Determines the available moves from a given node, avoiding cycles in the tree.
415    fn get_available_moves(&self, node_id: NodeId) -> Vec<T::Move> {
416        let node = self.tree.get(node_id).unwrap();
417        let hashes = self.get_branch_hashes(node_id);
418
419        let available_moves = node.value().board.get_available_moves();
420        let mut filtered_moves = Vec::with_capacity(available_moves.len());
421        for available_move in &available_moves {
422            let mut board_clone = node.value().board.clone();
423            board_clone.perform_move(available_move);
424            let hash = board_clone.get_hash();
425            if hashes.contains(&hash) {
426                filtered_moves.push(available_move);
427            }
428        }
429        available_moves
430    }
431}
432
433impl<T: Board> MonteCarloTreeSearch<T, StandardRandomGenerator> {
434    pub fn from_board(board: T) -> Self {
435        MonteCarloTreeSearchBuilder::new(board).build()
436    }
437}
438
439/// Represents the four main stages of the MCTS algorithm.
440///
441/// This enum is used to manage the state of the search process.
442#[allow(non_snake_case)]
443#[derive(Debug, PartialEq, Clone)]
444pub enum MctsAction {
445    /// **Selection**: Start from the root `R` and select successive child nodes until a leaf node `L` is reached.
446    Selection {
447        /// The root of the current selection phase.
448        R: NodeId,
449        /// The path of nodes visited during the last backpropagation phase.
450        RP: Vec<NodeId>,
451    },
452    /// **Expansion**: Create one or more child nodes from the selected leaf node `L`.
453    Expansion {
454        /// The leaf node to be expanded.
455        L: NodeId,
456    },
457    /// **Simulation**: Run a random playout from a newly created child node `C`.
458    Simulation {
459        /// The child node from which the simulation will start.
460        C: NodeId,
461        /// All children created during the expansion phase.
462        AC: Vec<NodeId>,
463    },
464    /// **Backpropagation**: Update the statistics of the nodes on the path from `C` to the root `R`.
465    Backpropagation {
466        /// The child node from which the simulation was run.
467        C: NodeId,
468        /// The result of the simulation.
469        result: GameOutcome,
470    },
471    /// Represents a state where the entire tree has been explored and the outcome is certain.
472    EverythingIsCalculated,
473}
474
475impl MctsAction {
476    /// Returns the name of the current MCTS action as a string.
477    pub fn get_name(&self) -> String {
478        match self {
479            MctsAction::Selection { R: _, RP: _ } => "Selection".to_string(),
480            MctsAction::Expansion { L: _ } => "Expansion".to_string(),
481            MctsAction::Simulation { C: _, AC: _ } => "Simulation".to_string(),
482            MctsAction::Backpropagation { C: _, result: _ } => "Backpropagation".to_string(),
483            MctsAction::EverythingIsCalculated => "EverythingIsCalculated".to_string(),
484        }
485    }
486}
487
488pub struct MctsTreeNode<'a, T: Board>(pub NodeRef<'a, MctsNode<T>>);
489
490impl<'a, T: Board> Deref for MctsTreeNode<'a, T> {
491    type Target = NodeRef<'a, MctsNode<T>>;
492
493    fn deref(&self) -> &Self::Target {
494        &self.0
495    }
496}
497
498impl<'a, T: Board> DerefMut for MctsTreeNode<'a, T> {
499    fn deref_mut(&mut self) -> &mut Self::Target {
500        &mut self.0
501    }
502}
503
504impl<'a, T: Board> Into<NodeRef<'a, MctsNode<T>>> for MctsTreeNode<'a, T> {
505    fn into(self) -> NodeRef<'a, MctsNode<T>> {
506        self.0
507    }
508}
509
510impl<'a, T: Board> From<NodeRef<'a, MctsNode<T>>> for MctsTreeNode<'a, T> {
511    fn from(node: NodeRef<'a, MctsNode<T>>) -> Self {
512        Self(node)
513    }
514}
515
516impl<'a, T: Board> MctsTreeNode<'a, T> {
517    /// Returns the child of the given node that is considered the most promising, based on win rate.
518    pub fn get_best_child(&self) -> Option<MctsTreeNode<'a, T>> {
519        let mut best_child = None;
520        let mut best_child_value = f64::MIN;
521
522        // get the best child amount with DefoWin bound
523        for child in self
524            .children()
525            .filter(|x| x.value().bound == Bound::DefoWin)
526        {
527            let child_value = child.value().wins_rate();
528            if child_value > best_child_value {
529                best_child = Some(child);
530                best_child_value = child_value;
531            }
532        }
533
534        if best_child.is_some() {
535            return best_child.map(|x| x.into());
536        }
537
538        // get the best child overall
539        for child in self.children() {
540            let child_value = child.value().wins_rate();
541            if child_value > best_child_value {
542                best_child = Some(child);
543                best_child_value = child_value;
544            }
545        }
546
547        best_child.map(|x| x.into())
548    }
549}