1use crate::engine::evaluate::evaluate;
2use crate::engine::generate_instructions::generate_instructions_from_move_pair;
3use crate::engine::state::MoveChoice;
4use crate::instruction::StateInstructions;
5use crate::state::State;
6use rand::distributions::WeightedIndex;
7use rand::prelude::*;
8use rand::thread_rng;
9use std::collections::HashMap;
10use std::time::Duration;
11
/// Squashes `x` into the range `[0, 1]` with a logistic curve.
///
/// The input is scaled by `0.0125` before the exponential is taken, so
/// `x = 0.0` maps to exactly `0.5`, positive inputs map above `0.5` and
/// negative inputs map below it.
fn sigmoid(x: f32) -> f32 {
    let scaled: f32 = -0.0125 * x;
    let denominator = 1.0 + scaled.exp();
    1.0 / denominator
}
16
17#[derive(Debug)]
18pub struct Node {
19 pub root: bool,
20 pub parent: *mut Node,
21 pub children: HashMap<(usize, usize), Vec<Node>>,
22 pub times_visited: u32,
23
24 pub instructions: StateInstructions,
26 pub s1_choice: u8,
27 pub s2_choice: u8,
28
29 pub s1_options: Option<Vec<MoveNode>>,
32 pub s2_options: Option<Vec<MoveNode>>,
33}
34
35impl Node {
36 fn new() -> Node {
37 Node {
38 root: false,
39 parent: std::ptr::null_mut(),
40 instructions: StateInstructions::default(),
41 times_visited: 0,
42 children: HashMap::new(),
43 s1_choice: 0,
44 s2_choice: 0,
45 s1_options: None,
46 s2_options: None,
47 }
48 }
49 unsafe fn populate(&mut self, s1_options: Vec<MoveChoice>, s2_options: Vec<MoveChoice>) {
50 let s1_options_vec: Vec<MoveNode> = s1_options
51 .iter()
52 .map(|x| MoveNode {
53 move_choice: x.clone(),
54 total_score: 0.0,
55 visits: 0,
56 })
57 .collect();
58 let s2_options_vec: Vec<MoveNode> = s2_options
59 .iter()
60 .map(|x| MoveNode {
61 move_choice: x.clone(),
62 total_score: 0.0,
63 visits: 0,
64 })
65 .collect();
66
67 self.s1_options = Some(s1_options_vec);
68 self.s2_options = Some(s2_options_vec);
69 }
70
71 pub fn maximize_ucb_for_side(&self, side_map: &[MoveNode]) -> usize {
72 let mut choice = 0;
73 let mut best_ucb1 = f32::MIN;
74 for (index, node) in side_map.iter().enumerate() {
75 let this_ucb1 = node.ucb1(self.times_visited);
76 if this_ucb1 > best_ucb1 {
77 best_ucb1 = this_ucb1;
78 choice = index;
79 }
80 }
81 choice
82 }
83
84 pub unsafe fn selection(&mut self, state: &mut State) -> (*mut Node, usize, usize) {
85 let return_node = self as *mut Node;
86 if self.s1_options.is_none() {
87 let (s1_options, s2_options) = state.get_all_options();
88 self.populate(s1_options, s2_options);
89 }
90
91 let s1_mc_index = self.maximize_ucb_for_side(&self.s1_options.as_ref().unwrap());
92 let s2_mc_index = self.maximize_ucb_for_side(&self.s2_options.as_ref().unwrap());
93 let child_vector = self.children.get_mut(&(s1_mc_index, s2_mc_index));
94 match child_vector {
95 Some(child_vector) => {
96 let child_vec_ptr = child_vector as *mut Vec<Node>;
97 let chosen_child = self.sample_node(child_vec_ptr);
98 state.apply_instructions(&(*chosen_child).instructions.instruction_list);
99 (*chosen_child).selection(state)
100 }
101 None => (return_node, s1_mc_index, s2_mc_index),
102 }
103 }
104
105 unsafe fn sample_node(&self, move_vector: *mut Vec<Node>) -> *mut Node {
106 let mut rng = thread_rng();
107 let weights: Vec<f64> = (*move_vector)
108 .iter()
109 .map(|x| x.instructions.percentage as f64)
110 .collect();
111 let dist = WeightedIndex::new(weights).unwrap();
112 let chosen_node = &mut (&mut *move_vector)[dist.sample(&mut rng)];
113 let chosen_node_ptr = chosen_node as *mut Node;
114 chosen_node_ptr
115 }
116
117 pub unsafe fn expand(
118 &mut self,
119 state: &mut State,
120 s1_move_index: usize,
121 s2_move_index: usize,
122 ) -> *mut Node {
123 let s1_move = &self.s1_options.as_ref().unwrap()[s1_move_index].move_choice;
124 let s2_move = &self.s2_options.as_ref().unwrap()[s2_move_index].move_choice;
125 if (state.battle_is_over() != 0.0 && !self.root)
127 || (s1_move == &MoveChoice::None && s2_move == &MoveChoice::None)
128 {
129 return self as *mut Node;
130 }
131 let should_branch_on_damage = self.root || (*self.parent).root;
132 let mut new_instructions =
133 generate_instructions_from_move_pair(state, s1_move, s2_move, should_branch_on_damage);
134 let mut this_pair_vec = Vec::with_capacity(new_instructions.len());
135 for state_instructions in new_instructions.drain(..) {
136 let mut new_node = Node::new();
137 new_node.parent = self;
138 new_node.instructions = state_instructions;
139 new_node.s1_choice = s1_move_index as u8;
140 new_node.s2_choice = s2_move_index as u8;
141
142 this_pair_vec.push(new_node);
143 }
144
145 let new_node_ptr = self.sample_node(&mut this_pair_vec);
148 state.apply_instructions(&(*new_node_ptr).instructions.instruction_list);
149 self.children
150 .insert((s1_move_index, s2_move_index), this_pair_vec);
151 new_node_ptr
152 }
153
154 pub unsafe fn backpropagate(&mut self, score: f32, state: &mut State) {
155 self.times_visited += 1;
156 if self.root {
157 return;
158 }
159
160 let parent_s1_movenode =
161 &mut (*self.parent).s1_options.as_mut().unwrap()[self.s1_choice as usize];
162 parent_s1_movenode.total_score += score;
163 parent_s1_movenode.visits += 1;
164
165 let parent_s2_movenode =
166 &mut (*self.parent).s2_options.as_mut().unwrap()[self.s2_choice as usize];
167 parent_s2_movenode.total_score += 1.0 - score;
168 parent_s2_movenode.visits += 1;
169
170 state.reverse_instructions(&self.instructions.instruction_list);
171 (*self.parent).backpropagate(score, state);
172 }
173
174 pub fn rollout(&mut self, state: &mut State, root_eval: &f32) -> f32 {
175 let battle_is_over = state.battle_is_over();
176 if battle_is_over == 0.0 {
177 let eval = evaluate(state);
178 sigmoid(eval - root_eval)
179 } else {
180 if battle_is_over == -1.0 {
181 0.0
182 } else {
183 battle_is_over
184 }
185 }
186 }
187}
188
189#[derive(Debug)]
190pub struct MoveNode {
191 pub move_choice: MoveChoice,
192 pub total_score: f32,
193 pub visits: u32,
194}
195
196impl MoveNode {
197 pub fn ucb1(&self, parent_visits: u32) -> f32 {
198 if self.visits == 0 {
199 return f32::INFINITY;
200 }
201 let score = (self.total_score / self.visits as f32)
202 + (2.0 * (parent_visits as f32).ln() / self.visits as f32).sqrt();
203 score
204 }
205 pub fn average_score(&self) -> f32 {
206 let score = self.total_score / self.visits as f32;
207 score
208 }
209}
210
211#[derive(Clone)]
212pub struct MctsSideResult {
213 pub move_choice: MoveChoice,
214 pub total_score: f32,
215 pub visits: u32,
216}
217
218impl MctsSideResult {
219 pub fn average_score(&self) -> f32 {
220 if self.visits == 0 {
221 return 0.0;
222 }
223 let score = self.total_score / self.visits as f32;
224 score
225 }
226}
227
228pub struct MctsResult {
229 pub s1: Vec<MctsSideResult>,
230 pub s2: Vec<MctsSideResult>,
231 pub iteration_count: u32,
232}
233
234fn do_mcts(root_node: &mut Node, state: &mut State, root_eval: &f32) {
235 let (mut new_node, s1_move, s2_move) = unsafe { root_node.selection(state) };
236 new_node = unsafe { (*new_node).expand(state, s1_move, s2_move) };
237 let rollout_result = unsafe { (*new_node).rollout(state, root_eval) };
238 unsafe { (*new_node).backpropagate(rollout_result, state) }
239}
240
241pub fn perform_mcts(
242 state: &mut State,
243 side_one_options: Vec<MoveChoice>,
244 side_two_options: Vec<MoveChoice>,
245 max_time: Duration,
246) -> MctsResult {
247 let mut root_node = Node::new();
248 unsafe {
249 root_node.populate(side_one_options, side_two_options);
250 }
251 root_node.root = true;
252
253 let root_eval = evaluate(state);
254 let start_time = std::time::Instant::now();
255 while start_time.elapsed() < max_time {
256 for _ in 0..1000 {
257 do_mcts(&mut root_node, state, &root_eval);
258 }
259
260 if root_node.times_visited == 10_000_000 {
273 break;
274 }
275 }
276
277 let result = MctsResult {
278 s1: root_node
279 .s1_options
280 .as_ref()
281 .unwrap()
282 .iter()
283 .map(|v| MctsSideResult {
284 move_choice: v.move_choice.clone(),
285 total_score: v.total_score,
286 visits: v.visits,
287 })
288 .collect(),
289 s2: root_node
290 .s2_options
291 .as_ref()
292 .unwrap()
293 .iter()
294 .map(|v| MctsSideResult {
295 move_choice: v.move_choice.clone(),
296 total_score: v.total_score,
297 visits: v.visits,
298 })
299 .collect(),
300 iteration_count: root_node.times_visited,
301 };
302
303 result
304}