poke_engine/
mcts.rs

1use crate::evaluate::evaluate;
2use crate::generate_instructions::generate_instructions_from_move_pair;
3use crate::instruction::StateInstructions;
4use crate::state::{MoveChoice, State};
5use rand::distributions::WeightedIndex;
6use rand::prelude::*;
7use rand::thread_rng;
8use std::collections::HashMap;
9use std::time::Duration;
10
/// Logistic squashing function mapping an evaluation difference into `(0, 1)`.
///
/// An input of `0.0` maps to exactly `0.5`. With the slope used here a
/// difference of +200 points maps to roughly `0.92`, and larger differences
/// approach `1.0`.
fn sigmoid(x: f32) -> f32 {
    // Steepness of the curve; chosen by the original author's tuning.
    const SLOPE: f32 = 0.0125;
    ((-SLOPE * x).exp() + 1.0).recip()
}
15
/// One node of the search tree.
///
/// A node stores the per-side move options available in the position it
/// represents, together with the statistics gathered for each option, and
/// links to its parent and children.
#[derive(Debug)]
pub struct Node {
    /// `true` only for the node the search starts from. `Node::new` sets this
    /// to `false`; `perform_mcts` flips it for the root it creates.
    pub root: bool,
    /// Raw pointer to the node this one was expanded from. Initialised to
    /// null by `Node::new` and assigned in `expand`; never read for the root.
    pub parent: *mut Node,
    /// Children keyed by `(s1 option index, s2 option index)`. Each entry
    /// holds one node per `StateInstructions` generated for that move pair.
    pub children: HashMap<(usize, usize), Vec<Node>>,
    /// Number of times `backpropagate` has passed through this node.
    pub times_visited: i64,

    // represents the instructions & s1/s2 moves that led to this node from the parent
    /// Instructions that transform the parent's state into this node's state.
    pub instructions: StateInstructions,
    /// Index into the parent's `s1_options` of the move that led here.
    pub s1_choice: usize,
    /// Index into the parent's `s2_options` of the move that led here.
    pub s2_choice: usize,

    // represents the total score and number of visits for this node
    // de-coupled for s1 and s2
    /// Statistics for each option available to side one at this node.
    pub s1_options: Vec<MoveNode>,
    /// Statistics for each option available to side two at this node.
    pub s2_options: Vec<MoveNode>,
}
33
34impl Node {
35    fn new(s1_options: Vec<MoveChoice>, s2_options: Vec<MoveChoice>) -> Node {
36        let s1_options_vec = s1_options
37            .iter()
38            .map(|x| MoveNode {
39                move_choice: x.clone(),
40                total_score: 0.0,
41                visits: 0,
42            })
43            .collect();
44        let s2_options_vec = s2_options
45            .iter()
46            .map(|x| MoveNode {
47                move_choice: x.clone(),
48                total_score: 0.0,
49                visits: 0,
50            })
51            .collect();
52
53        Node {
54            root: false,
55            parent: std::ptr::null_mut(),
56            instructions: StateInstructions::default(),
57            times_visited: 0,
58            children: HashMap::new(),
59            s1_choice: 0,
60            s2_choice: 0,
61            s1_options: s1_options_vec,
62            s2_options: s2_options_vec,
63        }
64    }
65
66    pub fn maximize_ucb_for_side(&self, side_map: &[MoveNode]) -> usize {
67        let mut choice = 0;
68        let mut best_ucb1 = f32::MIN;
69        for (index, node) in side_map.iter().enumerate() {
70            let this_ucb1 = node.ucb1(self.times_visited);
71            if this_ucb1 > best_ucb1 {
72                best_ucb1 = this_ucb1;
73                choice = index;
74            }
75        }
76        choice
77    }
78
79    pub unsafe fn selection(&mut self, state: &mut State) -> (*mut Node, usize, usize) {
80        let return_node = self as *mut Node;
81
82        let s1_mc_index = self.maximize_ucb_for_side(&self.s1_options);
83        let s2_mc_index = self.maximize_ucb_for_side(&self.s2_options);
84        let child_vector = self.children.get_mut(&(s1_mc_index, s2_mc_index));
85        match child_vector {
86            Some(child_vector) => {
87                let child_vec_ptr = child_vector as *mut Vec<Node>;
88                let chosen_child = self.sample_node(child_vec_ptr);
89                state.apply_instructions(&(*chosen_child).instructions.instruction_list);
90                (*chosen_child).selection(state)
91            }
92            None => (return_node, s1_mc_index, s2_mc_index),
93        }
94    }
95
96    unsafe fn sample_node(&self, move_vector: *mut Vec<Node>) -> *mut Node {
97        let mut rng = thread_rng();
98        let weights: Vec<f64> = (*move_vector)
99            .iter()
100            .map(|x| x.instructions.percentage as f64)
101            .collect();
102        let dist = WeightedIndex::new(weights).unwrap();
103        let chosen_node = &mut (*move_vector)[dist.sample(&mut rng)];
104        let chosen_node_ptr = chosen_node as *mut Node;
105        chosen_node_ptr
106    }
107
108    pub unsafe fn expand(
109        &mut self,
110        state: &mut State,
111        s1_move_index: usize,
112        s2_move_index: usize,
113    ) -> *mut Node {
114        let s1_move = &self.s1_options[s1_move_index].move_choice;
115        let s2_move = &self.s2_options[s2_move_index].move_choice;
116        // if the battle is over or both moves are none there is no need to expand
117        if (state.battle_is_over() != 0.0 && !self.root)
118            || (s1_move == &MoveChoice::None && s2_move == &MoveChoice::None)
119        {
120            return self as *mut Node;
121        }
122        let should_branch_on_damage = self.root || (*self.parent).root;
123        let mut new_instructions =
124            generate_instructions_from_move_pair(state, s1_move, s2_move, should_branch_on_damage);
125        let mut this_pair_vec = Vec::with_capacity(2);
126        for state_instructions in new_instructions.drain(..) {
127            state.apply_instructions(&state_instructions.instruction_list);
128            let (s1_options, s2_options) = state.get_all_options();
129            state.reverse_instructions(&state_instructions.instruction_list);
130
131            let mut new_node = Node::new(s1_options, s2_options);
132            new_node.parent = self;
133            new_node.instructions = state_instructions;
134            new_node.s1_choice = s1_move_index;
135            new_node.s2_choice = s2_move_index;
136
137            this_pair_vec.push(new_node);
138        }
139
140        // sample a node from the new instruction list.
141        // this is the node that the rollout will be done on
142        let new_node_ptr = self.sample_node(&mut this_pair_vec);
143        state.apply_instructions(&(*new_node_ptr).instructions.instruction_list);
144        self.children
145            .insert((s1_move_index, s2_move_index), this_pair_vec);
146        new_node_ptr
147    }
148
149    pub unsafe fn backpropagate(&mut self, score: f32, state: &mut State) {
150        self.times_visited += 1;
151        if self.root {
152            return;
153        }
154
155        let parent_s1_movenode = &mut (*self.parent).s1_options[self.s1_choice];
156        parent_s1_movenode.total_score += score;
157        parent_s1_movenode.visits += 1;
158
159        let parent_s2_movenode = &mut (*self.parent).s2_options[self.s2_choice];
160        parent_s2_movenode.total_score += 1.0 - score;
161        parent_s2_movenode.visits += 1;
162
163        state.reverse_instructions(&self.instructions.instruction_list);
164        (*self.parent).backpropagate(score, state);
165    }
166
167    pub fn rollout(&mut self, state: &mut State, root_eval: &f32) -> f32 {
168        let battle_is_over = state.battle_is_over();
169        if battle_is_over == 0.0 {
170            let eval = evaluate(state);
171            sigmoid(eval - root_eval)
172        } else {
173            if battle_is_over == -1.0 {
174                0.0
175            } else {
176                battle_is_over
177            }
178        }
179    }
180}
181
/// Search statistics for a single move option of one side at a node.
#[derive(Debug)]
pub struct MoveNode {
    /// The move these statistics belong to.
    pub move_choice: MoveChoice,
    /// Sum of the scores credited to this option during backpropagation.
    pub total_score: f32,
    /// Number of times this option was credited during backpropagation.
    pub visits: i64,
}
188
189impl MoveNode {
190    pub fn ucb1(&self, parent_visits: i64) -> f32 {
191        if self.visits == 0 {
192            return f32::INFINITY;
193        }
194        let score = (self.total_score / self.visits as f32)
195            + (2.0 * (parent_visits as f32).ln() / self.visits as f32).sqrt();
196        score
197    }
198    pub fn average_score(&self) -> f32 {
199        let score = self.total_score / self.visits as f32;
200        score
201    }
202}
203
204#[derive(Clone)]
205pub struct MctsSideResult {
206    pub move_choice: MoveChoice,
207    pub total_score: f32,
208    pub visits: i64,
209}
210
211impl MctsSideResult {
212    pub fn average_score(&self) -> f32 {
213        if self.visits == 0 {
214            return 0.0;
215        }
216        let score = self.total_score / self.visits as f32;
217        score
218    }
219}
220
/// Outcome of a `perform_mcts` run.
pub struct MctsResult {
    /// Statistics for each of side one's options at the root, in the order
    /// the options were supplied.
    pub s1: Vec<MctsSideResult>,
    /// Statistics for each of side two's options at the root, in the order
    /// the options were supplied.
    pub s2: Vec<MctsSideResult>,
    /// The root node's visit count when the search stopped.
    pub iteration_count: i64,
}
226
227fn do_mcts(root_node: &mut Node, state: &mut State, root_eval: &f32) {
228    let (mut new_node, s1_move, s2_move) = unsafe { root_node.selection(state) };
229    new_node = unsafe { (*new_node).expand(state, s1_move, s2_move) };
230    let rollout_result = unsafe { (*new_node).rollout(state, root_eval) };
231    unsafe { (*new_node).backpropagate(rollout_result, state) }
232}
233
234pub fn perform_mcts(
235    state: &mut State,
236    side_one_options: Vec<MoveChoice>,
237    side_two_options: Vec<MoveChoice>,
238    max_time: Duration,
239) -> MctsResult {
240    let mut root_node = Node::new(side_one_options, side_two_options);
241    root_node.root = true;
242
243    let root_eval = evaluate(state);
244    let start_time = std::time::Instant::now();
245    while start_time.elapsed() < max_time {
246        for _ in 0..1000 {
247            do_mcts(&mut root_node, state, &root_eval);
248        }
249
250        /*
251        Cut off after 10 million iterations
252
253        Under normal circumstances the bot will only run for 2.5-3.5 million iterations
254        however towards the end of a battle the bot may perform tens of millions of iterations
255
256        Beyond about 30 million iterations some floating point nonsense happens where
257        MoveNode.total_score stops updating because f32 does not have enough precision
258
259        I can push the problem farther out by using f64 but if the bot is running for 10 million iterations
260        then it almost certainly sees a forced win
261        */
262        if root_node.times_visited == 10_000_000 {
263            break;
264        }
265    }
266
267    let result = MctsResult {
268        s1: root_node
269            .s1_options
270            .iter()
271            .map(|v| MctsSideResult {
272                move_choice: v.move_choice.clone(),
273                total_score: v.total_score,
274                visits: v.visits,
275            })
276            .collect(),
277        s2: root_node
278            .s2_options
279            .iter()
280            .map(|v| MctsSideResult {
281                move_choice: v.move_choice.clone(),
282                total_score: v.total_score,
283                visits: v.visits,
284            })
285            .collect(),
286        iteration_count: root_node.times_visited,
287    };
288
289    result
290}