1use crate::board::{Board, Bound, GameOutcome, Player};
2use crate::mcts_node::MctsNode;
3use crate::random::{RandomGenerator, StandardRandomGenerator};
4use ego_tree::{NodeId, NodeRef, Tree};
5use std::collections::HashSet;
6use std::ops::{Deref, DerefMut};
7
8pub struct MonteCarloTreeSearch<T: Board, K: RandomGenerator> {
12 tree: Tree<MctsNode<T>>,
13 root_id: NodeId,
14 random: K,
15 use_alpha_beta_pruning: bool,
16 next_action: MctsAction,
17}
18
19pub struct MonteCarloTreeSearchBuilder<T: Board, K: RandomGenerator> {
23 board: T,
24 random_generator: K,
25 use_alpha_beta_pruning: bool,
26}
27
28impl<T: Board, K: RandomGenerator> MonteCarloTreeSearchBuilder<T, K> {
29 pub fn new(board: T) -> Self {
31 Self {
32 board,
33 random_generator: K::default(),
34 use_alpha_beta_pruning: true,
35 }
36 }
37
38 pub fn with_random_generator(mut self, rg: K) -> Self {
40 self.random_generator = rg;
41 self
42 }
43
44 pub fn with_alpha_beta_pruning(mut self, use_abp: bool) -> Self {
46 self.use_alpha_beta_pruning = use_abp;
47 self
48 }
49
50 pub fn build(self) -> MonteCarloTreeSearch<T, K> {
52 MonteCarloTreeSearch::new(
53 self.board,
54 self.random_generator,
55 self.use_alpha_beta_pruning,
56 )
57 }
58}
59
60impl<T: Board, K: RandomGenerator> MonteCarloTreeSearch<T, K> {
61 pub fn builder(board: T) -> MonteCarloTreeSearchBuilder<T, K> {
63 MonteCarloTreeSearchBuilder::new(board)
64 }
65
66 pub fn new(board: T, rg: K, use_alpha_beta_pruning: bool) -> Self {
70 let root_mcts_node = MctsNode::new(0, Box::new(board));
71 let tree: Tree<MctsNode<T>> = Tree::new(root_mcts_node);
72 let root_id = tree.root().id();
73
74 Self {
75 tree,
76 root_id: root_id.clone(),
77 random: rg,
78 use_alpha_beta_pruning,
79 next_action: MctsAction::Selection {
80 R: root_id.clone(),
81 RP: vec![],
82 },
83 }
84 }
85
86 pub fn get_tree(&self) -> &Tree<MctsNode<T>> {
88 &self.tree
89 }
90
91 pub fn get_next_mcts_action(&self) -> &MctsAction {
93 &self.next_action
94 }
95
96 pub fn execute_action(&mut self) {
98 match self.next_action.clone() {
99 MctsAction::Selection { R, RP: _cr } => {
100 let maybe_selected_node = self.select_next_node(R);
101 self.next_action = match maybe_selected_node {
102 None => MctsAction::EverythingIsCalculated,
103 Some(selected_node) => MctsAction::Expansion { L: selected_node },
104 };
105 }
106 MctsAction::Expansion { L } => {
107 let (children, selected_child) = self.expand_node(L);
108 self.next_action = MctsAction::Simulation {
109 C: selected_child,
110 AC: children,
111 };
112 }
113 MctsAction::Simulation { C, AC: _ac } => {
114 let outcome = self.simulate(C);
115 self.next_action = MctsAction::Backpropagation { C, result: outcome };
116 }
117 MctsAction::Backpropagation { C, result } => {
118 let affected_nodes = self.backpropagate(C, result);
119 self.next_action = MctsAction::Selection {
120 R: self.root_id.clone(),
121 RP: affected_nodes,
122 }
123 }
124 MctsAction::EverythingIsCalculated => {}
125 }
126 }
127
128 pub fn do_iteration(&mut self) -> Vec<NodeId> {
131 self.execute_action();
132 let mut is_selection = matches!(self.next_action, MctsAction::Selection { R: _, RP: _ });
133 let mut is_fully_calculated =
134 matches!(self.next_action, MctsAction::EverythingIsCalculated);
135 while !is_selection && !is_fully_calculated {
136 self.execute_action();
137 is_selection = matches!(self.next_action, MctsAction::Selection { R: _, RP: _ });
138 is_fully_calculated = matches!(self.next_action, MctsAction::EverythingIsCalculated);
139 }
140
141 match self.next_action.clone() {
142 MctsAction::Selection { R: _, RP: rp } => rp,
143 _ => vec![],
144 }
145 }
146
147 pub fn iterate_n_times(&mut self, n: u32) {
149 let mut iteration = 0;
150 while iteration < n {
151 self.do_iteration();
152 iteration += 1;
153 }
154 }
155
156 pub fn get_root(&self) -> MctsTreeNode<T> {
158 let root = self.tree.root();
159 root.into()
160 }
161
162 fn select_next_node(&self, root_id: NodeId) -> Option<NodeId> {
164 let mut promising_node_id = root_id.clone();
165 let mut has_changed = false;
166 loop {
167 let mut best_child_id: Option<NodeId> = None;
168 let mut max_ucb = f64::MIN;
169 let node = self.tree.get(promising_node_id).unwrap();
170 for child in node.children() {
171 if child.value().is_fully_calculated {
172 continue;
173 }
174
175 let current_ucb = MonteCarloTreeSearch::<T, K>::ucb_value(
176 node.value().visits,
177 child.value().wins,
178 child.value().visits,
179 );
180 if current_ucb > max_ucb {
181 max_ucb = current_ucb;
182 best_child_id = Some(child.id());
183 }
184 }
185 if best_child_id.is_none() {
186 break;
187 }
188 promising_node_id = best_child_id.unwrap();
189 has_changed = true;
190 }
191
192 if has_changed {
193 Some(promising_node_id.clone())
194 } else {
195 let root = self.tree.root();
196 if root.children().count() == 0 {
197 Some(root_id.clone())
198 } else {
199 None
200 }
201 }
202 }
203
204 fn expand_node(&mut self, node_id: NodeId) -> (Vec<NodeId>, NodeId) {
206 let node = self.tree.get(node_id).unwrap();
207 if !node.children().count() == 0 {
208 panic!("BUG: expanding already expanded node");
209 }
210 if node.value().outcome != GameOutcome::InProgress {
211 return (vec![], node_id.clone());
212 }
213
214
215 let children_height = node.value().height + 1;
216 let all_possible_moves = self.get_available_moves(node_id);
217 let mut new_mcts_nodes = Vec::with_capacity(all_possible_moves.len());
218
219 for possible_move in all_possible_moves {
220 let mut board_clone = node.value().board.clone();
221 board_clone.perform_move(&possible_move);
222 let new_node_id = self.random.next();
223 let mut mcts_node = MctsNode::new(new_node_id, board_clone);
224 mcts_node.prev_move = Some(possible_move);
225 mcts_node.height = children_height;
226 new_mcts_nodes.push(mcts_node);
227 }
228
229 let mut new_node_ids = Vec::with_capacity(new_mcts_nodes.len());
230 for mcts_node in new_mcts_nodes {
231 let mut node = self.tree.get_mut(node_id).unwrap();
232 node.append(mcts_node);
233 new_node_ids.push(node_id.clone());
234 }
235
236 let children: Vec<_> = self.tree.get(node_id).unwrap().children().collect();
237 let selected_child_index = self.random.next_range(0, children.len() as i32) as usize;
238 let selected_child = children[selected_child_index].id();
239 (new_node_ids, selected_child)
240 }
241
242 fn simulate(&mut self, node_id: NodeId) -> GameOutcome {
244 let node = self.tree.get(node_id).unwrap();
245 let mut board = node.value().board.clone();
246 let mut outcome = board.get_outcome();
247 let mut hashes = self.get_branch_hashes(node_id);
248
249 while outcome == GameOutcome::InProgress {
250 let mut all_possible_moves = board.get_available_moves();
251
252 while !all_possible_moves.is_empty() {
253 let random_move_index =
254 self.random.next_range(0, all_possible_moves.len() as i32) as usize;
255 let random_move = all_possible_moves.get(random_move_index).unwrap();
256 let mut new_board = board.clone();
257 new_board.perform_move(random_move);
258 let new_board_hash = new_board.get_hash();
259 if hashes.contains(&new_board_hash) {
260 all_possible_moves.remove(random_move_index);
261 continue;
262 } else {
263 hashes.insert(new_board_hash);
264 board = new_board;
265 break;
266 }
267 }
268
269 if all_possible_moves.is_empty() {
270 return GameOutcome::Lose;
271 }
272
273 outcome = board.get_outcome();
274 }
275 outcome
276 }
277
278 fn backpropagate(&mut self, node_id: NodeId, outcome: GameOutcome) -> Vec<NodeId> {
280 let mut branch = vec![node_id.clone()];
281
282 loop {
283 let temp_node = self.tree.get(*branch.last().unwrap()).unwrap();
284 match temp_node.parent() {
285 None => break,
286 Some(parent) => branch.push(parent.id()),
287 }
288 }
289
290 let is_win = outcome == GameOutcome::Win;
291 let is_draw = outcome == GameOutcome::Draw;
292
293 for node_id in &branch {
294 let bound = self.get_bound(*node_id);
295 let is_fully_calculated = self.is_fully_calculated(*node_id, bound);
296 let mut temp_node = self.tree.get_mut(*node_id).unwrap();
297 let mcts_node = temp_node.value();
298 mcts_node.visits += 1;
299 if is_win {
300 mcts_node.wins += 1;
301 }
302
303 if is_draw {
304 mcts_node.draws += 1;
305 }
306
307 if is_fully_calculated {
308 mcts_node.is_fully_calculated = true;
309 }
310
311 if bound != Bound::None {
312 mcts_node.bound = bound;
313 }
314 }
315
316 branch
317 }
318
319 fn get_bound(&self, node_id: NodeId) -> Bound {
321 if !self.use_alpha_beta_pruning {
322 return Bound::None;
323 }
324
325 let node = self.tree.get(node_id).unwrap();
326 let mcts_node = node.value();
327 if mcts_node.bound != Bound::None {
328 return mcts_node.bound;
329 }
330
331 if mcts_node.outcome == GameOutcome::Win {
332 return Bound::DefoWin;
333 }
334
335 if mcts_node.outcome == GameOutcome::Lose {
336 return Bound::DefoLose;
337 }
338
339 if node.children().count() == 0 {
340 return Bound::None;
341 }
342
343 match mcts_node.current_player {
344 Player::Me => {
345 if node.children().all(|x| x.value().bound == Bound::DefoLose) {
346 return Bound::DefoLose;
347 }
348
349 if node.children().any(|x| x.value().bound == Bound::DefoWin) {
350 return Bound::DefoWin;
351 }
352 }
353 Player::Other => {
354 if node.children().all(|x| x.value().bound == Bound::DefoWin) {
355 return Bound::DefoWin;
356 }
357
358 if node.children().any(|x| x.value().bound == Bound::DefoLose) {
359 return Bound::DefoLose;
360 }
361 }
362 }
363
364 Bound::None
365 }
366
367 fn is_fully_calculated(&self, node_id: NodeId, bound: Bound) -> bool {
369 if bound != Bound::None {
370 return true;
371 }
372
373 let node = self.tree.get(node_id).unwrap();
374 if node.value().outcome != GameOutcome::InProgress {
375 return true;
376 }
377
378 if node.children().count() == 0 {
379 return false;
380 }
381
382 let all_children_calculated = node.children().all(|x| x.value().is_fully_calculated);
383
384 all_children_calculated
385 }
386
/// Upper-confidence-bound score of a child node.
///
/// Returns `i32::MAX` (as `f64`) for a child that has no visits; otherwise
/// `wins / visits + sqrt(2) * sqrt(ln(total_visits) / visits)`.
fn ucb_value(total_visits: i32, node_wins: i32, node_visit: i32) -> f64 {
    const EXPLORATION_PARAMETER: f64 = std::f64::consts::SQRT_2;

    if node_visit == 0 {
        return f64::from(i32::MAX);
    }

    let visits = f64::from(node_visit);
    let exploitation = f64::from(node_wins) / visits;
    let exploration = (f64::from(total_visits).ln() / visits).sqrt();
    exploitation + EXPLORATION_PARAMETER * exploration
}
399
400 fn get_branch_hashes(&self, node_id: NodeId) -> HashSet<u128> {
402 let mut current_node = self.tree.get(node_id).unwrap();
403 let mut branch_hashes = HashSet::with_capacity(current_node.value().height + 1);
404 loop {
405 branch_hashes.insert(current_node.value().board_hash);
406 match current_node.parent() {
407 None => break,
408 Some(parent) => current_node = parent,
409 }
410 }
411 branch_hashes
412 }
413
414 fn get_available_moves(&self, node_id: NodeId) -> Vec<T::Move> {
416 let node = self.tree.get(node_id).unwrap();
417 let hashes = self.get_branch_hashes(node_id);
418
419 let available_moves = node.value().board.get_available_moves();
420 let mut filtered_moves = Vec::with_capacity(available_moves.len());
421 for available_move in &available_moves {
422 let mut board_clone = node.value().board.clone();
423 board_clone.perform_move(available_move);
424 let hash = board_clone.get_hash();
425 if hashes.contains(&hash) {
426 filtered_moves.push(available_move);
427 }
428 }
429 available_moves
430 }
431}
432
433impl<T: Board> MonteCarloTreeSearch<T, StandardRandomGenerator> {
434 pub fn from_board(board: T) -> Self {
435 MonteCarloTreeSearchBuilder::new(board).build()
436 }
437}
438
439#[allow(non_snake_case)]
443#[derive(Debug, PartialEq, Clone)]
444pub enum MctsAction {
445 Selection {
447 R: NodeId,
449 RP: Vec<NodeId>,
451 },
452 Expansion {
454 L: NodeId,
456 },
457 Simulation {
459 C: NodeId,
461 AC: Vec<NodeId>,
463 },
464 Backpropagation {
466 C: NodeId,
468 result: GameOutcome,
470 },
471 EverythingIsCalculated,
473}
474
475impl MctsAction {
476 pub fn get_name(&self) -> String {
478 match self {
479 MctsAction::Selection { R: _, RP: _ } => "Selection".to_string(),
480 MctsAction::Expansion { L: _ } => "Expansion".to_string(),
481 MctsAction::Simulation { C: _, AC: _ } => "Simulation".to_string(),
482 MctsAction::Backpropagation { C: _, result: _ } => "Backpropagation".to_string(),
483 MctsAction::EverythingIsCalculated => "EverythingIsCalculated".to_string(),
484 }
485 }
486}
487
488pub struct MctsTreeNode<'a, T: Board>(pub NodeRef<'a, MctsNode<T>>);
489
490impl<'a, T: Board> Deref for MctsTreeNode<'a, T> {
491 type Target = NodeRef<'a, MctsNode<T>>;
492
493 fn deref(&self) -> &Self::Target {
494 &self.0
495 }
496}
497
498impl<'a, T: Board> DerefMut for MctsTreeNode<'a, T> {
499 fn deref_mut(&mut self) -> &mut Self::Target {
500 &mut self.0
501 }
502}
503
504impl<'a, T: Board> Into<NodeRef<'a, MctsNode<T>>> for MctsTreeNode<'a, T> {
505 fn into(self) -> NodeRef<'a, MctsNode<T>> {
506 self.0
507 }
508}
509
510impl<'a, T: Board> From<NodeRef<'a, MctsNode<T>>> for MctsTreeNode<'a, T> {
511 fn from(node: NodeRef<'a, MctsNode<T>>) -> Self {
512 Self(node)
513 }
514}
515
516impl<'a, T: Board> MctsTreeNode<'a, T> {
517 pub fn get_best_child(&self) -> Option<MctsTreeNode<'a, T>> {
519 let mut best_child = None;
520 let mut best_child_value = f64::MIN;
521
522 for child in self
524 .children()
525 .filter(|x| x.value().bound == Bound::DefoWin)
526 {
527 let child_value = child.value().wins_rate();
528 if child_value > best_child_value {
529 best_child = Some(child);
530 best_child_value = child_value;
531 }
532 }
533
534 if best_child.is_some() {
535 return best_child.map(|x| x.into());
536 }
537
538 for child in self.children() {
540 let child_value = child.value().wins_rate();
541 if child_value > best_child_value {
542 best_child = Some(child);
543 best_child_value = child_value;
544 }
545 }
546
547 best_child.map(|x| x.into())
548 }
549}