1use crate::board::{Board, Bound, GameOutcome, Player};
2use crate::mcts_node::MctsNode;
3use crate::random::{RandomGenerator, StandardRandomGenerator};
4use ego_tree::{NodeId, NodeRef, Tree};
5use std::collections::HashSet;
6use std::ops::{Deref, DerefMut};
7
8pub struct MonteCarloTreeSearch<T: Board, K: RandomGenerator> {
12 tree: Tree<MctsNode<T>>,
13 root_id: NodeId,
14 random: K,
15 use_alpha_beta_pruning: bool,
16 next_action: MctsAction,
17}
18
19pub struct MonteCarloTreeSearchBuilder<T: Board, K: RandomGenerator> {
23 board: T,
24 random_generator: K,
25 use_alpha_beta_pruning: bool,
26}
27
28impl<T: Board, K: RandomGenerator> MonteCarloTreeSearchBuilder<T, K> {
29 pub fn new(board: T) -> Self {
31 Self {
32 board,
33 random_generator: K::default(),
34 use_alpha_beta_pruning: true,
35 }
36 }
37
38 pub fn with_random_generator(mut self, rg: K) -> Self {
40 self.random_generator = rg;
41 self
42 }
43
44 pub fn with_alpha_beta_pruning(mut self, use_abp: bool) -> Self {
46 self.use_alpha_beta_pruning = use_abp;
47 self
48 }
49
50 pub fn build(self) -> MonteCarloTreeSearch<T, K> {
52 MonteCarloTreeSearch::new(
53 self.board,
54 self.random_generator,
55 self.use_alpha_beta_pruning,
56 )
57 }
58}
59
60impl<T: Board, K: RandomGenerator> MonteCarloTreeSearch<T, K> {
61 pub fn builder(board: T) -> MonteCarloTreeSearchBuilder<T, K> {
63 MonteCarloTreeSearchBuilder::new(board)
64 }
65
66 pub fn new(board: T, rg: K, use_alpha_beta_pruning: bool) -> Self {
70 let root_mcts_node = MctsNode::new(0, Box::new(board));
71 let tree: Tree<MctsNode<T>> = Tree::new(root_mcts_node);
72 let root_id = tree.root().id();
73
74 Self {
75 tree,
76 root_id: root_id.clone(),
77 random: rg,
78 use_alpha_beta_pruning,
79 next_action: MctsAction::Selection {
80 R: root_id.clone(),
81 RP: vec![],
82 },
83 }
84 }
85
86 pub fn get_tree(&self) -> &Tree<MctsNode<T>> {
88 &self.tree
89 }
90
91 pub fn get_next_mcts_action(&self) -> &MctsAction {
93 &self.next_action
94 }
95
96 pub fn execute_action(&mut self) {
98 match self.next_action.clone() {
99 MctsAction::Selection { R, RP: _cr } => {
100 let maybe_selected_node = self.select_next_node(R);
101 self.next_action = match maybe_selected_node {
102 None => MctsAction::EverythingIsCalculated,
103 Some(selected_node) => MctsAction::Expansion { L: selected_node },
104 };
105 }
106 MctsAction::Expansion { L } => {
107 let (children, selected_child) = self.expand_node(L);
108 self.next_action = MctsAction::Simulation {
109 C: selected_child,
110 AC: children,
111 };
112 }
113 MctsAction::Simulation { C, AC: _ac } => {
114 let outcome = self.simulate(C);
115 self.next_action = MctsAction::Backpropagation { C, result: outcome };
116 }
117 MctsAction::Backpropagation { C, result } => {
118 let affected_nodes = self.backpropagate(C, result);
119 self.next_action = MctsAction::Selection {
120 R: self.root_id.clone(),
121 RP: affected_nodes,
122 }
123 }
124 MctsAction::EverythingIsCalculated => {}
125 }
126 }
127
128 pub fn do_iteration(&mut self) -> Vec<NodeId> {
131 self.execute_action();
132 let mut is_selection = matches!(self.next_action, MctsAction::Selection { R: _, RP: _ });
133 let mut is_fully_calculated =
134 matches!(self.next_action, MctsAction::EverythingIsCalculated);
135 while !is_selection && !is_fully_calculated {
136 self.execute_action();
137 is_selection = matches!(self.next_action, MctsAction::Selection { R: _, RP: _ });
138 is_fully_calculated = matches!(self.next_action, MctsAction::EverythingIsCalculated);
139 }
140
141 match self.next_action.clone() {
142 MctsAction::Selection { R: _, RP: rp } => rp,
143 _ => vec![],
144 }
145 }
146
147 pub fn iterate_n_times(&mut self, n: u32) {
149 let mut iteration = 0;
150 while iteration < n {
151 self.do_iteration();
152 iteration += 1;
153 }
154 }
155
156 pub fn get_root(&self) -> MctsTreeNode<T> {
158 let root = self.tree.root();
159 root.into()
160 }
161
162 fn select_next_node(&self, root_id: NodeId) -> Option<NodeId> {
164 let mut promising_node_id = root_id.clone();
165 let mut has_changed = false;
166 loop {
167 let mut best_child_id: Option<NodeId> = None;
168 let mut max_ucb = f64::MIN;
169 let node = self.tree.get(promising_node_id).unwrap();
170 for child in node.children() {
171 if child.value().is_fully_calculated {
172 continue;
173 }
174
175 let current_ucb = MonteCarloTreeSearch::<T, K>::ucb_value(
176 node.value().visits,
177 child.value().wins,
178 child.value().visits,
179 );
180 if current_ucb > max_ucb {
181 max_ucb = current_ucb;
182 best_child_id = Some(child.id());
183 }
184 }
185 if best_child_id.is_none() {
186 break;
187 }
188 promising_node_id = best_child_id.unwrap();
189 has_changed = true;
190 }
191
192 if has_changed {
193 Some(promising_node_id.clone())
194 } else {
195 let root = self.tree.root();
196 if root.children().count() == 0 {
197 Some(root_id.clone())
198 } else {
199 None
200 }
201 }
202 }
203
204 fn expand_node(&mut self, node_id: NodeId) -> (Vec<NodeId>, NodeId) {
206 let node = self.tree.get(node_id).unwrap();
207 if !node.children().count() == 0 {
208 panic!("BUG: expanding already expanded node");
209 }
210 if node.value().outcome != GameOutcome::InProgress {
211 return (vec![], node_id.clone());
212 }
213
214 let children_height = node.value().height + 1;
215 let all_possible_moves = node.value().board.get_available_moves();
216 let mut new_mcts_nodes = Vec::with_capacity(all_possible_moves.len());
217
218 for possible_move in all_possible_moves {
219 let mut board_clone = node.value().board.clone();
220 board_clone.perform_move(&possible_move);
221 let new_node_id = self.random.next();
222 let mut mcts_node = MctsNode::new(new_node_id, board_clone);
223 mcts_node.prev_move = Some(possible_move);
224 mcts_node.height = children_height;
225 new_mcts_nodes.push(mcts_node);
226 }
227
228 let mut new_node_ids = Vec::with_capacity(new_mcts_nodes.len());
229 for mcts_node in new_mcts_nodes {
230 let mut node = self.tree.get_mut(node_id).unwrap();
231 node.append(mcts_node);
232 new_node_ids.push(node_id.clone());
233 }
234
235 let children: Vec<_> = self.tree.get(node_id).unwrap().children().collect();
236 let selected_child_index = self.random.next_range(0, children.len() as i32) as usize;
237 let selected_child = children[selected_child_index].id();
238 (new_node_ids, selected_child)
239 }
240
241 fn simulate(&mut self, node_id: NodeId) -> GameOutcome {
243 let node = self.tree.get(node_id).unwrap();
244 let mut board = node.value().board.clone();
245 let mut outcome = board.get_outcome();
246 let mut visited_states = HashSet::new();
247 visited_states.insert(board.get_hash());
248
249 while outcome == GameOutcome::InProgress {
250 let mut all_possible_moves = board.get_available_moves();
251
252 while !all_possible_moves.is_empty() {
253 let random_move_index =
254 self.random.next_range(0, all_possible_moves.len() as i32) as usize;
255 let random_move = all_possible_moves.get(random_move_index).unwrap();
256 let mut new_board = board.clone();
257 new_board.perform_move(random_move);
258 let new_board_hash = new_board.get_hash();
259 if visited_states.contains(&new_board_hash) {
260 all_possible_moves.remove(random_move_index);
261 continue;
262 } else {
263 visited_states.insert(new_board_hash);
264 board = new_board;
265 break;
266 }
267 }
268
269 if all_possible_moves.is_empty() {
270 return GameOutcome::Draw;
271 }
272
273 outcome = board.get_outcome();
274 }
275 outcome
276 }
277
278 fn backpropagate(&mut self, node_id: NodeId, outcome: GameOutcome) -> Vec<NodeId> {
280 let mut branch = vec![node_id.clone()];
281
282 loop {
283 let temp_node = self.tree.get(*branch.last().unwrap()).unwrap();
284 match temp_node.parent() {
285 None => break,
286 Some(parent) => branch.push(parent.id()),
287 }
288 }
289
290 let is_win = outcome == GameOutcome::Win;
291 let is_draw = outcome == GameOutcome::Draw;
292
293 for node_id in &branch {
294 let bound = self.get_bound(*node_id);
295 let is_fully_calculated = self.is_fully_calculated(*node_id, bound);
296 let mut temp_node = self.tree.get_mut(*node_id).unwrap();
297 let mcts_node = temp_node.value();
298 mcts_node.visits += 1;
299 if is_win {
300 mcts_node.wins += 1;
301 }
302
303 if is_draw {
304 mcts_node.draws += 1;
305 }
306
307 if is_fully_calculated {
308 mcts_node.is_fully_calculated = true;
309 }
310
311 if bound != Bound::None {
312 mcts_node.bound = bound;
313 }
314 }
315
316 branch
317 }
318
319 fn get_bound(&self, node_id: NodeId) -> Bound {
321 if !self.use_alpha_beta_pruning {
322 return Bound::None;
323 }
324
325 let node = self.tree.get(node_id).unwrap();
326 let mcts_node = node.value();
327 if mcts_node.bound != Bound::None {
328 return mcts_node.bound;
329 }
330
331 if mcts_node.outcome == GameOutcome::Win {
332 return Bound::DefoWin;
333 }
334
335 if mcts_node.outcome == GameOutcome::Lose {
336 return Bound::DefoLose;
337 }
338
339 if node.children().count() == 0 {
340 return Bound::None;
341 }
342
343 match mcts_node.current_player {
344 Player::Me => {
345 if node.children().all(|x| x.value().bound == Bound::DefoLose) {
346 return Bound::DefoLose;
347 }
348
349 if node.children().any(|x| x.value().bound == Bound::DefoWin) {
350 return Bound::DefoWin;
351 }
352 }
353 Player::Other => {
354 if node.children().all(|x| x.value().bound == Bound::DefoWin) {
355 return Bound::DefoWin;
356 }
357
358 if node.children().any(|x| x.value().bound == Bound::DefoLose) {
359 return Bound::DefoLose;
360 }
361 }
362 }
363
364 Bound::None
365 }
366
367 fn is_fully_calculated(&self, node_id: NodeId, bound: Bound) -> bool {
369 if bound != Bound::None {
370 return true;
371 }
372
373 let node = self.tree.get(node_id).unwrap();
374 if node.value().outcome != GameOutcome::InProgress {
375 return true;
376 }
377
378 if node.children().count() == 0 {
379 return false;
380 }
381
382 let all_children_calculated = node.children().all(|x| x.value().is_fully_calculated);
383
384 all_children_calculated
385 }
386
/// Upper confidence bound of a child: win ratio plus an exploration term
/// `sqrt(2) * sqrt(ln(total_visits) / node_visit)`.
///
/// A child that was never visited gets `i32::MAX` as its value.
fn ucb_value(total_visits: i32, node_wins: i32, node_visit: i32) -> f64 {
    const EXPLORATION_PARAMETER: f64 = std::f64::consts::SQRT_2;

    if node_visit == 0 {
        return f64::from(i32::MAX);
    }

    let visits = node_visit as f64;
    let exploitation = (node_wins as f64) / visits;
    let exploration = ((total_visits as f64).ln() / visits).sqrt();
    exploitation + EXPLORATION_PARAMETER * exploration
}
399}
400
401impl<T: Board> MonteCarloTreeSearch<T, StandardRandomGenerator> {
402 pub fn from_board(board: T) -> Self {
403 MonteCarloTreeSearchBuilder::new(board).build()
404 }
405}
406
407#[allow(non_snake_case)]
411#[derive(Debug, PartialEq, Clone)]
412pub enum MctsAction {
413 Selection {
415 R: NodeId,
417 RP: Vec<NodeId>,
419 },
420 Expansion {
422 L: NodeId,
424 },
425 Simulation {
427 C: NodeId,
429 AC: Vec<NodeId>,
431 },
432 Backpropagation {
434 C: NodeId,
436 result: GameOutcome,
438 },
439 EverythingIsCalculated,
441}
442
443impl MctsAction {
444 pub fn get_name(&self) -> String {
446 match self {
447 MctsAction::Selection { R: _, RP: _ } => "Selection".to_string(),
448 MctsAction::Expansion { L: _ } => "Expansion".to_string(),
449 MctsAction::Simulation { C: _, AC: _ } => "Simulation".to_string(),
450 MctsAction::Backpropagation { C: _, result: _ } => "Backpropagation".to_string(),
451 MctsAction::EverythingIsCalculated => "EverythingIsCalculated".to_string(),
452 }
453 }
454}
455
456pub struct MctsTreeNode<'a, T: Board>(pub NodeRef<'a, MctsNode<T>>);
457
458impl<'a, T: Board> Deref for MctsTreeNode<'a, T> {
459 type Target = NodeRef<'a, MctsNode<T>>;
460
461 fn deref(&self) -> &Self::Target {
462 &self.0
463 }
464}
465
466impl<'a, T: Board> DerefMut for MctsTreeNode<'a, T> {
467 fn deref_mut(&mut self) -> &mut Self::Target {
468 &mut self.0
469 }
470}
471
472impl<'a, T: Board> Into<NodeRef<'a, MctsNode<T>>> for MctsTreeNode<'a, T> {
473 fn into(self) -> NodeRef<'a, MctsNode<T>> {
474 self.0
475 }
476}
477
478impl<'a, T: Board> From<NodeRef<'a, MctsNode<T>>> for MctsTreeNode<'a, T> {
479 fn from(node: NodeRef<'a, MctsNode<T>>) -> Self {
480 Self(node)
481 }
482}
483
484impl<'a, T: Board> MctsTreeNode<'a, T> {
485 pub fn get_best_child(&self) -> Option<MctsTreeNode<'a, T>> {
487 let mut best_child = None;
488 let mut best_child_value = f64::MIN;
489
490 for child in self
492 .children()
493 .filter(|x| x.value().bound == Bound::DefoWin)
494 {
495 let child_value = child.value().wins_rate();
496 if child_value > best_child_value {
497 best_child = Some(child);
498 best_child_value = child_value;
499 }
500 }
501
502 if best_child.is_some() {
503 return best_child.map(|x| x.into());
504 }
505
506 for child in self.children() {
508 let child_value = child.value().wins_rate();
509 if child_value > best_child_value {
510 best_child = Some(child);
511 best_child_value = child_value;
512 }
513 }
514
515 best_child.map(|x| x.into())
516 }
517}