1use crate::engine::evaluate::evaluate;
2use crate::engine::generate_instructions::generate_instructions_from_move_pair;
3use crate::engine::state::MoveChoice;
4use crate::instruction::StateInstructions;
5use crate::state::State;
6use rand::distributions::WeightedIndex;
7use rand::prelude::*;
8use rand::thread_rng;
9use std::collections::HashMap;
10use std::time::Duration;
11
/// Logistic function with steepness `0.0125`, used to squash an evaluation
/// difference into a score between 0 and 1.
///
/// `sigmoid(0.0) == 0.5`, and the curve is symmetric: `sigmoid(x) + sigmoid(-x) == 1`
/// up to rounding.
fn sigmoid(x: f32) -> f32 {
    let decay = (-0.0125 * x).exp();
    (1.0 + decay).recip()
}
16
/// A node in the simultaneous-move search tree.
///
/// Children store a raw pointer to the node that created them (`parent`), so a
/// node must not be moved once children have been attached to it.
#[derive(Debug)]
pub struct Node {
    /// `true` only for the tree's root. The root keeps a null `parent` and
    /// default `instructions`.
    pub root: bool,
    /// Pointer to the node this one was expanded from; null until set by the caller.
    pub parent: *mut Node,
    /// Expanded children keyed by `(side-one option index, side-two option index)`.
    /// Each entry holds one child per outcome generated for that move pair.
    pub children: HashMap<(usize, usize), Vec<Node>>,
    /// Number of iterations that have backpropagated through this node.
    pub times_visited: i64,

    /// Instructions applied to the parent's state to reach this node's state.
    pub instructions: StateInstructions,
    /// Index into the parent's `s1_options` of the move that led here.
    pub s1_choice: usize,
    /// Index into the parent's `s2_options` of the move that led here.
    pub s2_choice: usize,

    /// Per-option search statistics for side one at this node.
    pub s1_options: Vec<MoveNode>,
    /// Per-option search statistics for side two at this node.
    pub s2_options: Vec<MoveNode>,
}
34
35impl Node {
36 fn new(s1_options: Vec<MoveChoice>, s2_options: Vec<MoveChoice>) -> Node {
37 let s1_options_vec = s1_options
38 .iter()
39 .map(|x| MoveNode {
40 move_choice: x.clone(),
41 total_score: 0.0,
42 visits: 0,
43 })
44 .collect();
45 let s2_options_vec = s2_options
46 .iter()
47 .map(|x| MoveNode {
48 move_choice: x.clone(),
49 total_score: 0.0,
50 visits: 0,
51 })
52 .collect();
53
54 Node {
55 root: false,
56 parent: std::ptr::null_mut(),
57 instructions: StateInstructions::default(),
58 times_visited: 0,
59 children: HashMap::new(),
60 s1_choice: 0,
61 s2_choice: 0,
62 s1_options: s1_options_vec,
63 s2_options: s2_options_vec,
64 }
65 }
66
67 pub fn maximize_ucb_for_side(&self, side_map: &[MoveNode]) -> usize {
68 let mut choice = 0;
69 let mut best_ucb1 = f32::MIN;
70 for (index, node) in side_map.iter().enumerate() {
71 let this_ucb1 = node.ucb1(self.times_visited);
72 if this_ucb1 > best_ucb1 {
73 best_ucb1 = this_ucb1;
74 choice = index;
75 }
76 }
77 choice
78 }
79
80 pub unsafe fn selection(&mut self, state: &mut State) -> (*mut Node, usize, usize) {
81 let return_node = self as *mut Node;
82
83 let s1_mc_index = self.maximize_ucb_for_side(&self.s1_options);
84 let s2_mc_index = self.maximize_ucb_for_side(&self.s2_options);
85 let child_vector = self.children.get_mut(&(s1_mc_index, s2_mc_index));
86 match child_vector {
87 Some(child_vector) => {
88 let child_vec_ptr = child_vector as *mut Vec<Node>;
89 let chosen_child = self.sample_node(child_vec_ptr);
90 state.apply_instructions(&(*chosen_child).instructions.instruction_list);
91 (*chosen_child).selection(state)
92 }
93 None => (return_node, s1_mc_index, s2_mc_index),
94 }
95 }
96
97 unsafe fn sample_node(&self, move_vector: *mut Vec<Node>) -> *mut Node {
98 let mut rng = thread_rng();
99 let weights: Vec<f64> = (*move_vector)
100 .iter()
101 .map(|x| x.instructions.percentage as f64)
102 .collect();
103 let dist = WeightedIndex::new(weights).unwrap();
104 let chosen_node = &mut (*move_vector)[dist.sample(&mut rng)];
105 let chosen_node_ptr = chosen_node as *mut Node;
106 chosen_node_ptr
107 }
108
109 pub unsafe fn expand(
110 &mut self,
111 state: &mut State,
112 s1_move_index: usize,
113 s2_move_index: usize,
114 ) -> *mut Node {
115 let s1_move = &self.s1_options[s1_move_index].move_choice;
116 let s2_move = &self.s2_options[s2_move_index].move_choice;
117 if (state.battle_is_over() != 0.0 && !self.root)
119 || (s1_move == &MoveChoice::None && s2_move == &MoveChoice::None)
120 {
121 return self as *mut Node;
122 }
123 let should_branch_on_damage = self.root || (*self.parent).root;
124 let mut new_instructions =
125 generate_instructions_from_move_pair(state, s1_move, s2_move, should_branch_on_damage);
126 let mut this_pair_vec = Vec::with_capacity(new_instructions.len());
127 for state_instructions in new_instructions.drain(..) {
128 state.apply_instructions(&state_instructions.instruction_list);
129 let (s1_options, s2_options) = state.get_all_options();
130 state.reverse_instructions(&state_instructions.instruction_list);
131
132 let mut new_node = Node::new(s1_options, s2_options);
133 new_node.parent = self;
134 new_node.instructions = state_instructions;
135 new_node.s1_choice = s1_move_index;
136 new_node.s2_choice = s2_move_index;
137
138 this_pair_vec.push(new_node);
139 }
140
141 let new_node_ptr = self.sample_node(&mut this_pair_vec);
144 state.apply_instructions(&(*new_node_ptr).instructions.instruction_list);
145 self.children
146 .insert((s1_move_index, s2_move_index), this_pair_vec);
147 new_node_ptr
148 }
149
150 pub unsafe fn backpropagate(&mut self, score: f32, state: &mut State) {
151 self.times_visited += 1;
152 if self.root {
153 return;
154 }
155
156 let parent_s1_movenode = &mut (*self.parent).s1_options[self.s1_choice];
157 parent_s1_movenode.total_score += score;
158 parent_s1_movenode.visits += 1;
159
160 let parent_s2_movenode = &mut (*self.parent).s2_options[self.s2_choice];
161 parent_s2_movenode.total_score += 1.0 - score;
162 parent_s2_movenode.visits += 1;
163
164 state.reverse_instructions(&self.instructions.instruction_list);
165 (*self.parent).backpropagate(score, state);
166 }
167
168 pub fn rollout(&mut self, state: &mut State, root_eval: &f32) -> f32 {
169 let battle_is_over = state.battle_is_over();
170 if battle_is_over == 0.0 {
171 let eval = evaluate(state);
172 sigmoid(eval - root_eval)
173 } else {
174 if battle_is_over == -1.0 {
175 0.0
176 } else {
177 battle_is_over
178 }
179 }
180 }
181}
182
/// Search statistics for one move option of one side at a `Node`.
#[derive(Debug)]
pub struct MoveNode {
    /// The move this option represents.
    pub move_choice: MoveChoice,
    /// Sum of backpropagated scores from this side's perspective.
    /// Side two is credited with `1.0 - score`.
    pub total_score: f32,
    /// Number of times this option has been credited during backpropagation.
    pub visits: i64,
}
189
190impl MoveNode {
191 pub fn ucb1(&self, parent_visits: i64) -> f32 {
192 if self.visits == 0 {
193 return f32::INFINITY;
194 }
195 let score = (self.total_score / self.visits as f32)
196 + (2.0 * (parent_visits as f32).ln() / self.visits as f32).sqrt();
197 score
198 }
199 pub fn average_score(&self) -> f32 {
200 let score = self.total_score / self.visits as f32;
201 score
202 }
203}
204
205#[derive(Clone)]
206pub struct MctsSideResult {
207 pub move_choice: MoveChoice,
208 pub total_score: f32,
209 pub visits: i64,
210}
211
212impl MctsSideResult {
213 pub fn average_score(&self) -> f32 {
214 if self.visits == 0 {
215 return 0.0;
216 }
217 let score = self.total_score / self.visits as f32;
218 score
219 }
220}
221
/// Outcome of `perform_mcts`: root-level statistics for both sides.
pub struct MctsResult {
    /// One entry per side-one option at the root, in the order the options were supplied.
    pub s1: Vec<MctsSideResult>,
    /// One entry per side-two option at the root, in the order the options were supplied.
    pub s2: Vec<MctsSideResult>,
    /// Number of completed search iterations (the root node's visit count).
    pub iteration_count: i64,
}
227
228fn do_mcts(root_node: &mut Node, state: &mut State, root_eval: &f32) {
229 let (mut new_node, s1_move, s2_move) = unsafe { root_node.selection(state) };
230 new_node = unsafe { (*new_node).expand(state, s1_move, s2_move) };
231 let rollout_result = unsafe { (*new_node).rollout(state, root_eval) };
232 unsafe { (*new_node).backpropagate(rollout_result, state) }
233}
234
235pub fn perform_mcts(
236 state: &mut State,
237 side_one_options: Vec<MoveChoice>,
238 side_two_options: Vec<MoveChoice>,
239 max_time: Duration,
240) -> MctsResult {
241 let mut root_node = Node::new(side_one_options, side_two_options);
242 root_node.root = true;
243
244 let root_eval = evaluate(state);
245 let start_time = std::time::Instant::now();
246 while start_time.elapsed() < max_time {
247 for _ in 0..1000 {
248 do_mcts(&mut root_node, state, &root_eval);
249 }
250
251 if root_node.times_visited == 10_000_000 {
264 break;
265 }
266 }
267
268 let result = MctsResult {
269 s1: root_node
270 .s1_options
271 .iter()
272 .map(|v| MctsSideResult {
273 move_choice: v.move_choice.clone(),
274 total_score: v.total_score,
275 visits: v.visits,
276 })
277 .collect(),
278 s2: root_node
279 .s2_options
280 .iter()
281 .map(|v| MctsSideResult {
282 move_choice: v.move_choice.clone(),
283 total_score: v.total_score,
284 visits: v.visits,
285 })
286 .collect(),
287 iteration_count: root_node.times_visited,
288 };
289
290 result
291}