entrenar/search/mcts/search/
algorithm.rs1#![allow(clippy::field_reassign_with_default)]
7
8use super::result::MctsResult;
9use super::stats::MctsStats;
10use crate::search::mcts::config::MctsConfig;
11use crate::search::mcts::node::NodeId;
12use crate::search::mcts::traits::{Action, ActionSpace, PolicyNetwork, State, StateSpace};
13use crate::search::mcts::tree::SearchTree;
14use crate::search::mcts::Reward;
15
16pub struct MctsSearch<S: State, A: Action> {
18 tree: SearchTree<S, A>,
20 config: MctsConfig,
22 rng: rand::rngs::StdRng,
24}
25
26impl<S: State + Send + Sync, A: Action + Send + Sync> MctsSearch<S, A> {
27 pub fn new<AS: ActionSpace<S, A>>(
29 initial_state: S,
30 action_space: &AS,
31 config: MctsConfig,
32 ) -> Self {
33 use rand::SeedableRng;
34 let actions = action_space.legal_actions(&initial_state);
35 let tree = SearchTree::new(initial_state, actions);
36 Self { tree, config, rng: rand::rngs::StdRng::from_os_rng() }
37 }
38
39 pub fn with_seed<AS: ActionSpace<S, A>>(
41 initial_state: S,
42 action_space: &AS,
43 config: MctsConfig,
44 seed: u64,
45 ) -> Self {
46 use rand::SeedableRng;
47 let actions = action_space.legal_actions(&initial_state);
48 let tree = SearchTree::new(initial_state, actions);
49 Self { tree, config, rng: rand::rngs::StdRng::seed_from_u64(seed) }
50 }
51
52 pub fn search<SS, AS>(
54 &mut self,
55 state_space: &SS,
56 action_space: &AS,
57 policy: Option<&dyn PolicyNetwork<S, A>>,
58 ) -> MctsResult<S, A>
59 where
60 SS: StateSpace<S, A>,
61 AS: ActionSpace<S, A>,
62 {
63 let mut total_sim_length = 0usize;
64 let mut max_depth = 0usize;
65
66 for _ in 0..self.config.max_iterations {
67 let (leaf_id, depth) = self.select();
69 max_depth = max_depth.max(depth);
70
71 let leaf_state = self.tree.get(leaf_id).map(|n| n.state.clone());
73 let Some(leaf_state) = leaf_state else {
74 continue;
75 };
76
77 if leaf_state.is_terminal() {
79 let reward = state_space.evaluate(&leaf_state);
80 self.backpropagate(leaf_id, reward);
81 continue;
82 }
83
84 let child_id = self.expand(leaf_id, state_space, action_space, policy);
86 let Some(child_id) = child_id else {
87 continue;
88 };
89
90 let child_state = self.tree.get(child_id).map(|n| n.state.clone());
92 let Some(child_state) = child_state else {
93 continue;
94 };
95
96 let (reward, sim_length) = self.simulate(&child_state, state_space, action_space);
97 total_sim_length += sim_length;
98
99 self.backpropagate(child_id, reward);
101 }
102
103 let root = self.tree.root();
105 let root_visits = root.stats.visits;
106
107 let action_visits: Vec<(A, usize)> = self
109 .tree
110 .children(self.tree.root_id)
111 .iter()
112 .filter_map(|child| child.action.clone().map(|a| (a, child.stats.visits)))
113 .collect();
114
115 let best_child =
117 self.tree.children(self.tree.root_id).into_iter().max_by_key(|n| n.stats.visits);
118
119 let (best_action, expected_reward, resulting_state) = if let Some(child) = best_child {
120 (child.action.clone(), child.stats.mean_reward, Some(child.state.clone()))
121 } else {
122 (None, 0.0, None)
123 };
124
125 let avg_simulation_length = if self.config.max_iterations > 0 {
126 total_sim_length as f64 / self.config.max_iterations as f64
127 } else {
128 0.0
129 };
130
131 MctsResult {
132 best_action,
133 expected_reward,
134 action_visits,
135 stats: MctsStats {
136 iterations: self.config.max_iterations,
137 tree_size: self.tree.size(),
138 max_depth,
139 avg_simulation_length,
140 root_visits,
141 },
142 resulting_state,
143 }
144 }
145
146 fn select(&self) -> (NodeId, usize) {
148 let mut current_id = self.tree.root_id;
149 let mut depth = 0;
150
151 loop {
152 let node = match self.tree.get(current_id) {
153 Some(n) => n,
154 None => return (current_id, depth),
155 };
156
157 if !node.untried_actions.is_empty() || node.state.is_terminal() {
159 return (current_id, depth);
160 }
161
162 if node.children.is_empty() {
164 return (current_id, depth);
165 }
166
167 let parent_visits = node.stats.visits;
169 let best_child =
170 node.children.iter().filter_map(|&cid| self.tree.get(cid)).max_by(|a, b| {
171 let score_a = if self.config.use_policy_priors {
172 a.stats.puct(parent_visits, self.config.exploration_constant)
173 } else {
174 a.stats.ucb1(parent_visits, self.config.exploration_constant)
175 };
176 let score_b = if self.config.use_policy_priors {
177 b.stats.puct(parent_visits, self.config.exploration_constant)
178 } else {
179 b.stats.ucb1(parent_visits, self.config.exploration_constant)
180 };
181 score_a.partial_cmp(&score_b).unwrap_or(std::cmp::Ordering::Equal)
182 });
183
184 match best_child {
185 Some(child) => {
186 current_id = child.id;
187 depth += 1;
188 }
189 None => return (current_id, depth),
190 }
191 }
192 }
193
194 fn expand<SS, AS>(
196 &mut self,
197 node_id: NodeId,
198 state_space: &SS,
199 action_space: &AS,
200 policy: Option<&dyn PolicyNetwork<S, A>>,
201 ) -> Option<NodeId>
202 where
203 SS: StateSpace<S, A>,
204 AS: ActionSpace<S, A>,
205 {
206 let (action, parent_state) = {
208 let node = self.tree.get_mut(node_id)?;
209 let action = node.untried_actions.pop()?;
210 let parent_state = node.state.clone();
211 node.expanded = node.untried_actions.is_empty();
212 (action, parent_state)
213 };
214
215 let new_state = state_space.apply(&parent_state, &action);
217 let new_actions = action_space.legal_actions(&new_state);
218
219 let prior = policy
221 .and_then(|p| {
222 p.predict(&parent_state).iter().find(|(a, _)| a == &action).map(|(_, p)| *p)
223 })
224 .unwrap_or(1.0 / (new_actions.len().max(1) as f64));
225
226 let child_id = self.tree.add_child(node_id, new_state, action, new_actions, prior);
228 Some(child_id)
229 }
230
231 fn simulate<SS, AS>(
233 &mut self,
234 initial_state: &S,
235 state_space: &SS,
236 action_space: &AS,
237 ) -> (Reward, usize)
238 where
239 SS: StateSpace<S, A>,
240 AS: ActionSpace<S, A>,
241 {
242 use rand::prelude::IndexedRandom;
243
244 let mut state = initial_state.clone();
245 let mut depth = 0;
246
247 while !state.is_terminal() && depth < self.config.max_simulation_depth {
248 let actions = action_space.legal_actions(&state);
249 if actions.is_empty() {
250 break;
251 }
252
253 if let Some(action) = actions.choose(&mut self.rng) {
255 state = state_space.apply(&state, action);
256 }
257 depth += 1;
258 }
259
260 (state_space.evaluate(&state), depth)
261 }
262
263 fn backpropagate(&mut self, leaf_id: NodeId, reward: Reward) {
265 let mut current_id = Some(leaf_id);
266
267 while let Some(id) = current_id {
268 if let Some(node) = self.tree.get_mut(id) {
269 node.stats.update(reward);
270 current_id = node.parent;
271 } else {
272 break;
273 }
274 }
275 }
276
277 #[must_use]
279 pub fn tree_size(&self) -> usize {
280 self.tree.size()
281 }
282
283 #[must_use]
285 pub fn tree(&self) -> &SearchTree<S, A> {
286 &self.tree
287 }
288}