poke_engine/
mcts.rs

1use crate::engine::evaluate::evaluate;
2use crate::engine::generate_instructions::generate_instructions_from_move_pair;
3use crate::engine::state::MoveChoice;
4use crate::instruction::StateInstructions;
5use crate::state::State;
6use rand::distributions::WeightedIndex;
7use rand::prelude::*;
8use rand::thread_rng;
9use std::collections::HashMap;
10use std::time::Duration;
11
/// Squashes an evaluation difference into the open interval (0, 1).
///
/// With a slope of 0.0125, an input of 0 maps to exactly 0.5, +200 maps to
/// roughly 0.92 and -200 to roughly 0.08.
fn sigmoid(x: f32) -> f32 {
    let exponent = -0.0125 * x;
    (exponent.exp() + 1.0).recip()
}
16
/// A node in the MCTS search tree.
///
/// Each non-root node is one generated outcome of a (side one, side two) move pair
/// applied to its parent's position. Selection statistics are kept separately for
/// each side in `s1_options` / `s2_options`.
#[derive(Debug)]
pub struct Node {
    /// Marks the tree root. `Node::new` sets this to `false`; `perform_mcts` sets it
    /// on the node it creates as the root.
    pub root: bool,
    /// Raw pointer to the parent node. `Node::new` leaves it null and `expand` sets it
    /// for every child it creates. `expand` and `backpropagate` dereference it without
    /// checks on non-root nodes, so it must point to a live parent that has not moved.
    pub parent: *mut Node,
    /// Children keyed by the (side-one option index, side-two option index) pair that
    /// produced them. Each key holds one node per generated outcome of that move pair;
    /// `selection` samples among them weighted by `instructions.percentage`.
    pub children: HashMap<(usize, usize), Vec<Node>>,
    /// Number of times `backpropagate` has passed through this node.
    pub times_visited: i64,

    /// Instructions that lead from the parent's position to this node's position.
    /// They are applied on the way down (`selection` / `expand`) and reversed in
    /// `backpropagate`.
    pub instructions: StateInstructions,
    /// Index into the parent's `s1_options` of side one's move that led here.
    pub s1_choice: usize,
    /// Index into the parent's `s2_options` of side two's move that led here.
    pub s2_choice: usize,

    /// Per-move statistics (total score and visits) for side one at this node,
    /// tracked independently of side two (decoupled selection).
    pub s1_options: Vec<MoveNode>,
    /// Per-move statistics for side two at this node. Side two is credited with
    /// `1.0 - score` for each backpropagated score.
    pub s2_options: Vec<MoveNode>,
}
34
35impl Node {
36    fn new(s1_options: Vec<MoveChoice>, s2_options: Vec<MoveChoice>) -> Node {
37        let s1_options_vec = s1_options
38            .iter()
39            .map(|x| MoveNode {
40                move_choice: x.clone(),
41                total_score: 0.0,
42                visits: 0,
43            })
44            .collect();
45        let s2_options_vec = s2_options
46            .iter()
47            .map(|x| MoveNode {
48                move_choice: x.clone(),
49                total_score: 0.0,
50                visits: 0,
51            })
52            .collect();
53
54        Node {
55            root: false,
56            parent: std::ptr::null_mut(),
57            instructions: StateInstructions::default(),
58            times_visited: 0,
59            children: HashMap::new(),
60            s1_choice: 0,
61            s2_choice: 0,
62            s1_options: s1_options_vec,
63            s2_options: s2_options_vec,
64        }
65    }
66
67    pub fn maximize_ucb_for_side(&self, side_map: &[MoveNode]) -> usize {
68        let mut choice = 0;
69        let mut best_ucb1 = f32::MIN;
70        for (index, node) in side_map.iter().enumerate() {
71            let this_ucb1 = node.ucb1(self.times_visited);
72            if this_ucb1 > best_ucb1 {
73                best_ucb1 = this_ucb1;
74                choice = index;
75            }
76        }
77        choice
78    }
79
80    pub unsafe fn selection(&mut self, state: &mut State) -> (*mut Node, usize, usize) {
81        let return_node = self as *mut Node;
82
83        let s1_mc_index = self.maximize_ucb_for_side(&self.s1_options);
84        let s2_mc_index = self.maximize_ucb_for_side(&self.s2_options);
85        let child_vector = self.children.get_mut(&(s1_mc_index, s2_mc_index));
86        match child_vector {
87            Some(child_vector) => {
88                let child_vec_ptr = child_vector as *mut Vec<Node>;
89                let chosen_child = self.sample_node(child_vec_ptr);
90                state.apply_instructions(&(*chosen_child).instructions.instruction_list);
91                (*chosen_child).selection(state)
92            }
93            None => (return_node, s1_mc_index, s2_mc_index),
94        }
95    }
96
97    unsafe fn sample_node(&self, move_vector: *mut Vec<Node>) -> *mut Node {
98        let mut rng = thread_rng();
99        let weights: Vec<f64> = (*move_vector)
100            .iter()
101            .map(|x| x.instructions.percentage as f64)
102            .collect();
103        let dist = WeightedIndex::new(weights).unwrap();
104        let chosen_node = &mut (*move_vector)[dist.sample(&mut rng)];
105        let chosen_node_ptr = chosen_node as *mut Node;
106        chosen_node_ptr
107    }
108
109    pub unsafe fn expand(
110        &mut self,
111        state: &mut State,
112        s1_move_index: usize,
113        s2_move_index: usize,
114    ) -> *mut Node {
115        let s1_move = &self.s1_options[s1_move_index].move_choice;
116        let s2_move = &self.s2_options[s2_move_index].move_choice;
117        // if the battle is over or both moves are none there is no need to expand
118        if (state.battle_is_over() != 0.0 && !self.root)
119            || (s1_move == &MoveChoice::None && s2_move == &MoveChoice::None)
120        {
121            return self as *mut Node;
122        }
123        let should_branch_on_damage = self.root || (*self.parent).root;
124        let mut new_instructions =
125            generate_instructions_from_move_pair(state, s1_move, s2_move, should_branch_on_damage);
126        let mut this_pair_vec = Vec::with_capacity(new_instructions.len());
127        for state_instructions in new_instructions.drain(..) {
128            state.apply_instructions(&state_instructions.instruction_list);
129            let (s1_options, s2_options) = state.get_all_options();
130            state.reverse_instructions(&state_instructions.instruction_list);
131
132            let mut new_node = Node::new(s1_options, s2_options);
133            new_node.parent = self;
134            new_node.instructions = state_instructions;
135            new_node.s1_choice = s1_move_index;
136            new_node.s2_choice = s2_move_index;
137
138            this_pair_vec.push(new_node);
139        }
140
141        // sample a node from the new instruction list.
142        // this is the node that the rollout will be done on
143        let new_node_ptr = self.sample_node(&mut this_pair_vec);
144        state.apply_instructions(&(*new_node_ptr).instructions.instruction_list);
145        self.children
146            .insert((s1_move_index, s2_move_index), this_pair_vec);
147        new_node_ptr
148    }
149
150    pub unsafe fn backpropagate(&mut self, score: f32, state: &mut State) {
151        self.times_visited += 1;
152        if self.root {
153            return;
154        }
155
156        let parent_s1_movenode = &mut (*self.parent).s1_options[self.s1_choice];
157        parent_s1_movenode.total_score += score;
158        parent_s1_movenode.visits += 1;
159
160        let parent_s2_movenode = &mut (*self.parent).s2_options[self.s2_choice];
161        parent_s2_movenode.total_score += 1.0 - score;
162        parent_s2_movenode.visits += 1;
163
164        state.reverse_instructions(&self.instructions.instruction_list);
165        (*self.parent).backpropagate(score, state);
166    }
167
168    pub fn rollout(&mut self, state: &mut State, root_eval: &f32) -> f32 {
169        let battle_is_over = state.battle_is_over();
170        if battle_is_over == 0.0 {
171            let eval = evaluate(state);
172            sigmoid(eval - root_eval)
173        } else {
174            if battle_is_over == -1.0 {
175                0.0
176            } else {
177                battle_is_over
178            }
179        }
180    }
181}
182
/// Accumulated search statistics for one move option of one side at a tree node.
#[derive(Debug)]
pub struct MoveNode {
    /// The move this entry represents.
    pub move_choice: MoveChoice,
    /// Sum of the scores credited to this option during backpropagation.
    pub total_score: f32,
    /// Number of times this option has been credited during backpropagation.
    pub visits: i64,
}
189
190impl MoveNode {
191    pub fn ucb1(&self, parent_visits: i64) -> f32 {
192        if self.visits == 0 {
193            return f32::INFINITY;
194        }
195        let score = (self.total_score / self.visits as f32)
196            + (2.0 * (parent_visits as f32).ln() / self.visits as f32).sqrt();
197        score
198    }
199    pub fn average_score(&self) -> f32 {
200        let score = self.total_score / self.visits as f32;
201        score
202    }
203}
204
/// Root-level statistics for one move option of one side, as returned by `perform_mcts`.
#[derive(Clone)]
pub struct MctsSideResult {
    /// The move these statistics belong to.
    pub move_choice: MoveChoice,
    /// Sum of the scores credited to this option at the root.
    pub total_score: f32,
    /// Number of times this option was credited during backpropagation at the root.
    pub visits: i64,
}
211
212impl MctsSideResult {
213    pub fn average_score(&self) -> f32 {
214        if self.visits == 0 {
215            return 0.0;
216        }
217        let score = self.total_score / self.visits as f32;
218        score
219    }
220}
221
/// Outcome of a `perform_mcts` search: root-level statistics for both sides plus the
/// number of iterations performed.
pub struct MctsResult {
    /// Side one's root options, in the same order as the options given to the search.
    pub s1: Vec<MctsSideResult>,
    /// Side two's root options, in the same order as the options given to the search.
    pub s2: Vec<MctsSideResult>,
    /// The root node's `times_visited` when the search stopped.
    pub iteration_count: i64,
}
227
228fn do_mcts(root_node: &mut Node, state: &mut State, root_eval: &f32) {
229    let (mut new_node, s1_move, s2_move) = unsafe { root_node.selection(state) };
230    new_node = unsafe { (*new_node).expand(state, s1_move, s2_move) };
231    let rollout_result = unsafe { (*new_node).rollout(state, root_eval) };
232    unsafe { (*new_node).backpropagate(rollout_result, state) }
233}
234
235pub fn perform_mcts(
236    state: &mut State,
237    side_one_options: Vec<MoveChoice>,
238    side_two_options: Vec<MoveChoice>,
239    max_time: Duration,
240) -> MctsResult {
241    let mut root_node = Node::new(side_one_options, side_two_options);
242    root_node.root = true;
243
244    let root_eval = evaluate(state);
245    let start_time = std::time::Instant::now();
246    while start_time.elapsed() < max_time {
247        for _ in 0..1000 {
248            do_mcts(&mut root_node, state, &root_eval);
249        }
250
251        /*
252        Cut off after 10 million iterations
253
254        Under normal circumstances the bot will only run for 2.5-3.5 million iterations
255        however towards the end of a battle the bot may perform tens of millions of iterations
256
257        Beyond about 30 million iterations some floating point nonsense happens where
258        MoveNode.total_score stops updating because f32 does not have enough precision
259
260        I can push the problem farther out by using f64 but if the bot is running for 10 million iterations
261        then it almost certainly sees a forced win
262        */
263        if root_node.times_visited == 10_000_000 {
264            break;
265        }
266    }
267
268    let result = MctsResult {
269        s1: root_node
270            .s1_options
271            .iter()
272            .map(|v| MctsSideResult {
273                move_choice: v.move_choice.clone(),
274                total_score: v.total_score,
275                visits: v.visits,
276            })
277            .collect(),
278        s2: root_node
279            .s2_options
280            .iter()
281            .map(|v| MctsSideResult {
282                move_choice: v.move_choice.clone(),
283                total_score: v.total_score,
284                visits: v.visits,
285            })
286            .collect(),
287        iteration_count: root_node.times_visited,
288    };
289
290    result
291}