Skip to main content

entrenar/search/mcts/search/
algorithm.rs

1//! MCTS search algorithm implementation.
2//!
3//! This module contains the main search algorithm including
4//! selection, expansion, simulation, and backpropagation phases.
5
6#![allow(clippy::field_reassign_with_default)]
7
8use super::result::MctsResult;
9use super::stats::MctsStats;
10use crate::search::mcts::config::MctsConfig;
11use crate::search::mcts::node::NodeId;
12use crate::search::mcts::traits::{Action, ActionSpace, PolicyNetwork, State, StateSpace};
13use crate::search::mcts::tree::SearchTree;
14use crate::search::mcts::Reward;
15
16/// Main MCTS search algorithm
17pub struct MctsSearch<S: State, A: Action> {
18    /// Search tree
19    tree: SearchTree<S, A>,
20    /// Configuration
21    config: MctsConfig,
22    /// Random number generator
23    rng: rand::rngs::StdRng,
24}
25
26impl<S: State + Send + Sync, A: Action + Send + Sync> MctsSearch<S, A> {
27    /// Create a new MCTS search from initial state
28    pub fn new<AS: ActionSpace<S, A>>(
29        initial_state: S,
30        action_space: &AS,
31        config: MctsConfig,
32    ) -> Self {
33        use rand::SeedableRng;
34        let actions = action_space.legal_actions(&initial_state);
35        let tree = SearchTree::new(initial_state, actions);
36        Self { tree, config, rng: rand::rngs::StdRng::from_os_rng() }
37    }
38
39    /// Create a new MCTS search with a seed for reproducibility
40    pub fn with_seed<AS: ActionSpace<S, A>>(
41        initial_state: S,
42        action_space: &AS,
43        config: MctsConfig,
44        seed: u64,
45    ) -> Self {
46        use rand::SeedableRng;
47        let actions = action_space.legal_actions(&initial_state);
48        let tree = SearchTree::new(initial_state, actions);
49        Self { tree, config, rng: rand::rngs::StdRng::seed_from_u64(seed) }
50    }
51
52    /// Run the MCTS search
53    pub fn search<SS, AS>(
54        &mut self,
55        state_space: &SS,
56        action_space: &AS,
57        policy: Option<&dyn PolicyNetwork<S, A>>,
58    ) -> MctsResult<S, A>
59    where
60        SS: StateSpace<S, A>,
61        AS: ActionSpace<S, A>,
62    {
63        let mut total_sim_length = 0usize;
64        let mut max_depth = 0usize;
65
66        for _ in 0..self.config.max_iterations {
67            // Selection: traverse tree to find a leaf node
68            let (leaf_id, depth) = self.select();
69            max_depth = max_depth.max(depth);
70
71            // Get the leaf state
72            let leaf_state = self.tree.get(leaf_id).map(|n| n.state.clone());
73            let Some(leaf_state) = leaf_state else {
74                continue;
75            };
76
77            // Check if terminal
78            if leaf_state.is_terminal() {
79                let reward = state_space.evaluate(&leaf_state);
80                self.backpropagate(leaf_id, reward);
81                continue;
82            }
83
84            // Expansion: add a child node
85            let child_id = self.expand(leaf_id, state_space, action_space, policy);
86            let Some(child_id) = child_id else {
87                continue;
88            };
89
90            // Simulation: random playout from child
91            let child_state = self.tree.get(child_id).map(|n| n.state.clone());
92            let Some(child_state) = child_state else {
93                continue;
94            };
95
96            let (reward, sim_length) = self.simulate(&child_state, state_space, action_space);
97            total_sim_length += sim_length;
98
99            // Backpropagation: update statistics up the tree
100            self.backpropagate(child_id, reward);
101        }
102
103        // Compute results
104        let root = self.tree.root();
105        let root_visits = root.stats.visits;
106
107        // Get action visits
108        let action_visits: Vec<(A, usize)> = self
109            .tree
110            .children(self.tree.root_id)
111            .iter()
112            .filter_map(|child| child.action.clone().map(|a| (a, child.stats.visits)))
113            .collect();
114
115        // Select best action based on visits (robust child selection)
116        let best_child =
117            self.tree.children(self.tree.root_id).into_iter().max_by_key(|n| n.stats.visits);
118
119        let (best_action, expected_reward, resulting_state) = if let Some(child) = best_child {
120            (child.action.clone(), child.stats.mean_reward, Some(child.state.clone()))
121        } else {
122            (None, 0.0, None)
123        };
124
125        let avg_simulation_length = if self.config.max_iterations > 0 {
126            total_sim_length as f64 / self.config.max_iterations as f64
127        } else {
128            0.0
129        };
130
131        MctsResult {
132            best_action,
133            expected_reward,
134            action_visits,
135            stats: MctsStats {
136                iterations: self.config.max_iterations,
137                tree_size: self.tree.size(),
138                max_depth,
139                avg_simulation_length,
140                root_visits,
141            },
142            resulting_state,
143        }
144    }
145
146    /// Selection phase: traverse tree using UCB1/PUCT
147    fn select(&self) -> (NodeId, usize) {
148        let mut current_id = self.tree.root_id;
149        let mut depth = 0;
150
151        loop {
152            let node = match self.tree.get(current_id) {
153                Some(n) => n,
154                None => return (current_id, depth),
155            };
156
157            // If node has untried actions or is terminal, return it
158            if !node.untried_actions.is_empty() || node.state.is_terminal() {
159                return (current_id, depth);
160            }
161
162            // If no children, return current
163            if node.children.is_empty() {
164                return (current_id, depth);
165            }
166
167            // Select best child using UCB1/PUCT
168            let parent_visits = node.stats.visits;
169            let best_child =
170                node.children.iter().filter_map(|&cid| self.tree.get(cid)).max_by(|a, b| {
171                    let score_a = if self.config.use_policy_priors {
172                        a.stats.puct(parent_visits, self.config.exploration_constant)
173                    } else {
174                        a.stats.ucb1(parent_visits, self.config.exploration_constant)
175                    };
176                    let score_b = if self.config.use_policy_priors {
177                        b.stats.puct(parent_visits, self.config.exploration_constant)
178                    } else {
179                        b.stats.ucb1(parent_visits, self.config.exploration_constant)
180                    };
181                    score_a.partial_cmp(&score_b).unwrap_or(std::cmp::Ordering::Equal)
182                });
183
184            match best_child {
185                Some(child) => {
186                    current_id = child.id;
187                    depth += 1;
188                }
189                None => return (current_id, depth),
190            }
191        }
192    }
193
194    /// Expansion phase: add a child node for an untried action
195    fn expand<SS, AS>(
196        &mut self,
197        node_id: NodeId,
198        state_space: &SS,
199        action_space: &AS,
200        policy: Option<&dyn PolicyNetwork<S, A>>,
201    ) -> Option<NodeId>
202    where
203        SS: StateSpace<S, A>,
204        AS: ActionSpace<S, A>,
205    {
206        // Get an untried action
207        let (action, parent_state) = {
208            let node = self.tree.get_mut(node_id)?;
209            let action = node.untried_actions.pop()?;
210            let parent_state = node.state.clone();
211            node.expanded = node.untried_actions.is_empty();
212            (action, parent_state)
213        };
214
215        // Compute new state
216        let new_state = state_space.apply(&parent_state, &action);
217        let new_actions = action_space.legal_actions(&new_state);
218
219        // Get prior from policy network
220        let prior = policy
221            .and_then(|p| {
222                p.predict(&parent_state).iter().find(|(a, _)| a == &action).map(|(_, p)| *p)
223            })
224            .unwrap_or(1.0 / (new_actions.len().max(1) as f64));
225
226        // Add child
227        let child_id = self.tree.add_child(node_id, new_state, action, new_actions, prior);
228        Some(child_id)
229    }
230
231    /// Simulation phase: random playout from state
232    fn simulate<SS, AS>(
233        &mut self,
234        initial_state: &S,
235        state_space: &SS,
236        action_space: &AS,
237    ) -> (Reward, usize)
238    where
239        SS: StateSpace<S, A>,
240        AS: ActionSpace<S, A>,
241    {
242        use rand::prelude::IndexedRandom;
243
244        let mut state = initial_state.clone();
245        let mut depth = 0;
246
247        while !state.is_terminal() && depth < self.config.max_simulation_depth {
248            let actions = action_space.legal_actions(&state);
249            if actions.is_empty() {
250                break;
251            }
252
253            // Random action selection
254            if let Some(action) = actions.choose(&mut self.rng) {
255                state = state_space.apply(&state, action);
256            }
257            depth += 1;
258        }
259
260        (state_space.evaluate(&state), depth)
261    }
262
263    /// Backpropagation phase: update statistics up the tree
264    fn backpropagate(&mut self, leaf_id: NodeId, reward: Reward) {
265        let mut current_id = Some(leaf_id);
266
267        while let Some(id) = current_id {
268            if let Some(node) = self.tree.get_mut(id) {
269                node.stats.update(reward);
270                current_id = node.parent;
271            } else {
272                break;
273            }
274        }
275    }
276
277    /// Get the current tree size
278    #[must_use]
279    pub fn tree_size(&self) -> usize {
280        self.tree.size()
281    }
282
283    /// Get reference to the search tree
284    #[must_use]
285    pub fn tree(&self) -> &SearchTree<S, A> {
286        &self.tree
287    }
288}