poke_engine/
mcts.rs

1use crate::engine::evaluate::evaluate;
2use crate::engine::generate_instructions::generate_instructions_from_move_pair;
3use crate::engine::state::MoveChoice;
4use crate::instruction::StateInstructions;
5use crate::state::State;
6use rand::distributions::WeightedIndex;
7use rand::prelude::*;
8use rand::thread_rng;
9use std::collections::HashMap;
10use std::time::Duration;
11
/// Squashes an evaluation difference into the open interval (0, 1).
///
/// An input of 0 maps to exactly 0.5. The steepness is chosen so that a
/// difference of about 200 points already lands above 0.92.
fn sigmoid(x: f32) -> f32 {
    const STEEPNESS: f32 = 0.0125;
    let denominator = 1.0 + (-STEEPNESS * x).exp();
    denominator.recip()
}
16
17#[derive(Debug)]
18pub struct Node {
19    pub root: bool,
20    pub parent: *mut Node,
21    pub children: HashMap<(usize, usize), Vec<Node>>,
22    pub times_visited: u32,
23
24    // represents the instructions & s1/s2 moves that led to this node from the parent
25    pub instructions: StateInstructions,
26    pub s1_choice: u8,
27    pub s2_choice: u8,
28
29    // represents the total score and number of visits for this node
30    // de-coupled for s1 and s2
31    pub s1_options: Option<Vec<MoveNode>>,
32    pub s2_options: Option<Vec<MoveNode>>,
33}
34
35impl Node {
36    fn new() -> Node {
37        Node {
38            root: false,
39            parent: std::ptr::null_mut(),
40            instructions: StateInstructions::default(),
41            times_visited: 0,
42            children: HashMap::new(),
43            s1_choice: 0,
44            s2_choice: 0,
45            s1_options: None,
46            s2_options: None,
47        }
48    }
49    unsafe fn populate(&mut self, s1_options: Vec<MoveChoice>, s2_options: Vec<MoveChoice>) {
50        let s1_options_vec: Vec<MoveNode> = s1_options
51            .iter()
52            .map(|x| MoveNode {
53                move_choice: x.clone(),
54                total_score: 0.0,
55                visits: 0,
56            })
57            .collect();
58        let s2_options_vec: Vec<MoveNode> = s2_options
59            .iter()
60            .map(|x| MoveNode {
61                move_choice: x.clone(),
62                total_score: 0.0,
63                visits: 0,
64            })
65            .collect();
66
67        self.s1_options = Some(s1_options_vec);
68        self.s2_options = Some(s2_options_vec);
69    }
70
71    pub fn maximize_ucb_for_side(&self, side_map: &[MoveNode]) -> usize {
72        let mut choice = 0;
73        let mut best_ucb1 = f32::MIN;
74        for (index, node) in side_map.iter().enumerate() {
75            let this_ucb1 = node.ucb1(self.times_visited);
76            if this_ucb1 > best_ucb1 {
77                best_ucb1 = this_ucb1;
78                choice = index;
79            }
80        }
81        choice
82    }
83
84    pub unsafe fn selection(&mut self, state: &mut State) -> (*mut Node, usize, usize) {
85        let return_node = self as *mut Node;
86        if self.s1_options.is_none() {
87            let (s1_options, s2_options) = state.get_all_options();
88            self.populate(s1_options, s2_options);
89        }
90
91        let s1_mc_index = self.maximize_ucb_for_side(&self.s1_options.as_ref().unwrap());
92        let s2_mc_index = self.maximize_ucb_for_side(&self.s2_options.as_ref().unwrap());
93        let child_vector = self.children.get_mut(&(s1_mc_index, s2_mc_index));
94        match child_vector {
95            Some(child_vector) => {
96                let child_vec_ptr = child_vector as *mut Vec<Node>;
97                let chosen_child = self.sample_node(child_vec_ptr);
98                state.apply_instructions(&(*chosen_child).instructions.instruction_list);
99                (*chosen_child).selection(state)
100            }
101            None => (return_node, s1_mc_index, s2_mc_index),
102        }
103    }
104
105    unsafe fn sample_node(&self, move_vector: *mut Vec<Node>) -> *mut Node {
106        let mut rng = thread_rng();
107        let weights: Vec<f64> = (*move_vector)
108            .iter()
109            .map(|x| x.instructions.percentage as f64)
110            .collect();
111        let dist = WeightedIndex::new(weights).unwrap();
112        let chosen_node = &mut (&mut *move_vector)[dist.sample(&mut rng)];
113        let chosen_node_ptr = chosen_node as *mut Node;
114        chosen_node_ptr
115    }
116
117    pub unsafe fn expand(
118        &mut self,
119        state: &mut State,
120        s1_move_index: usize,
121        s2_move_index: usize,
122    ) -> *mut Node {
123        let s1_move = &self.s1_options.as_ref().unwrap()[s1_move_index].move_choice;
124        let s2_move = &self.s2_options.as_ref().unwrap()[s2_move_index].move_choice;
125        // if the battle is over or both moves are none there is no need to expand
126        if (state.battle_is_over() != 0.0 && !self.root)
127            || (s1_move == &MoveChoice::None && s2_move == &MoveChoice::None)
128        {
129            return self as *mut Node;
130        }
131        let should_branch_on_damage = self.root || (*self.parent).root;
132        let mut new_instructions =
133            generate_instructions_from_move_pair(state, s1_move, s2_move, should_branch_on_damage);
134        let mut this_pair_vec = Vec::with_capacity(new_instructions.len());
135        for state_instructions in new_instructions.drain(..) {
136            let mut new_node = Node::new();
137            new_node.parent = self;
138            new_node.instructions = state_instructions;
139            new_node.s1_choice = s1_move_index as u8;
140            new_node.s2_choice = s2_move_index as u8;
141
142            this_pair_vec.push(new_node);
143        }
144
145        // sample a node from the new instruction list.
146        // this is the node that the rollout will be done on
147        let new_node_ptr = self.sample_node(&mut this_pair_vec);
148        state.apply_instructions(&(*new_node_ptr).instructions.instruction_list);
149        self.children
150            .insert((s1_move_index, s2_move_index), this_pair_vec);
151        new_node_ptr
152    }
153
154    pub unsafe fn backpropagate(&mut self, score: f32, state: &mut State) {
155        self.times_visited += 1;
156        if self.root {
157            return;
158        }
159
160        let parent_s1_movenode =
161            &mut (*self.parent).s1_options.as_mut().unwrap()[self.s1_choice as usize];
162        parent_s1_movenode.total_score += score;
163        parent_s1_movenode.visits += 1;
164
165        let parent_s2_movenode =
166            &mut (*self.parent).s2_options.as_mut().unwrap()[self.s2_choice as usize];
167        parent_s2_movenode.total_score += 1.0 - score;
168        parent_s2_movenode.visits += 1;
169
170        state.reverse_instructions(&self.instructions.instruction_list);
171        (*self.parent).backpropagate(score, state);
172    }
173
174    pub fn rollout(&mut self, state: &mut State, root_eval: &f32) -> f32 {
175        let battle_is_over = state.battle_is_over();
176        if battle_is_over == 0.0 {
177            let eval = evaluate(state);
178            sigmoid(eval - root_eval)
179        } else {
180            if battle_is_over == -1.0 {
181                0.0
182            } else {
183                battle_is_over
184            }
185        }
186    }
187}
188
189#[derive(Debug)]
190pub struct MoveNode {
191    pub move_choice: MoveChoice,
192    pub total_score: f32,
193    pub visits: u32,
194}
195
196impl MoveNode {
197    pub fn ucb1(&self, parent_visits: u32) -> f32 {
198        if self.visits == 0 {
199            return f32::INFINITY;
200        }
201        let score = (self.total_score / self.visits as f32)
202            + (2.0 * (parent_visits as f32).ln() / self.visits as f32).sqrt();
203        score
204    }
205    pub fn average_score(&self) -> f32 {
206        let score = self.total_score / self.visits as f32;
207        score
208    }
209}
210
211#[derive(Clone)]
212pub struct MctsSideResult {
213    pub move_choice: MoveChoice,
214    pub total_score: f32,
215    pub visits: u32,
216}
217
218impl MctsSideResult {
219    pub fn average_score(&self) -> f32 {
220        if self.visits == 0 {
221            return 0.0;
222        }
223        let score = self.total_score / self.visits as f32;
224        score
225    }
226}
227
228pub struct MctsResult {
229    pub s1: Vec<MctsSideResult>,
230    pub s2: Vec<MctsSideResult>,
231    pub iteration_count: u32,
232}
233
234fn do_mcts(root_node: &mut Node, state: &mut State, root_eval: &f32) {
235    let (mut new_node, s1_move, s2_move) = unsafe { root_node.selection(state) };
236    new_node = unsafe { (*new_node).expand(state, s1_move, s2_move) };
237    let rollout_result = unsafe { (*new_node).rollout(state, root_eval) };
238    unsafe { (*new_node).backpropagate(rollout_result, state) }
239}
240
241pub fn perform_mcts(
242    state: &mut State,
243    side_one_options: Vec<MoveChoice>,
244    side_two_options: Vec<MoveChoice>,
245    max_time: Duration,
246) -> MctsResult {
247    let mut root_node = Node::new();
248    unsafe {
249        root_node.populate(side_one_options, side_two_options);
250    }
251    root_node.root = true;
252
253    let root_eval = evaluate(state);
254    let start_time = std::time::Instant::now();
255    while start_time.elapsed() < max_time {
256        for _ in 0..1000 {
257            do_mcts(&mut root_node, state, &root_eval);
258        }
259
260        /*
261        Cut off after 10 million iterations
262
263        Under normal circumstances the bot will only run for 2.5-3.5 million iterations
264        however towards the end of a battle the bot may perform tens of millions of iterations
265
266        Beyond about 30 million iterations some floating point nonsense happens where
267        MoveNode.total_score stops updating because f32 does not have enough precision
268
269        I can push the problem farther out by using f64 but if the bot is running for 10 million iterations
270        then it almost certainly sees a forced win
271        */
272        if root_node.times_visited == 10_000_000 {
273            break;
274        }
275    }
276
277    let result = MctsResult {
278        s1: root_node
279            .s1_options
280            .as_ref()
281            .unwrap()
282            .iter()
283            .map(|v| MctsSideResult {
284                move_choice: v.move_choice.clone(),
285                total_score: v.total_score,
286                visits: v.visits,
287            })
288            .collect(),
289        s2: root_node
290            .s2_options
291            .as_ref()
292            .unwrap()
293            .iter()
294            .map(|v| MctsSideResult {
295                move_choice: v.move_choice.clone(),
296                total_score: v.total_score,
297                visits: v.visits,
298            })
299            .collect(),
300        iteration_count: root_node.times_visited,
301    };
302
303    result
304}