1use crate::evaluate::evaluate;
2use crate::generate_instructions::generate_instructions_from_move_pair;
3use crate::instruction::StateInstructions;
4use crate::state::{MoveChoice, State};
5use rand::distributions::WeightedIndex;
6use rand::prelude::*;
7use rand::thread_rng;
8use std::collections::HashMap;
9use std::time::Duration;
10
/// Squashes an evaluation difference into the range `0.0..=1.0`.
///
/// An input of `0.0` maps to exactly `0.5`; positive inputs map above it and negative
/// inputs below it. The small steepness factor keeps the curve from saturating quickly.
fn sigmoid(x: f32) -> f32 {
    const STEEPNESS: f32 = 0.0125;
    let decay = (-STEEPNESS * x).exp();
    1.0 / (1.0 + decay)
}
15
/// A node in the simultaneous-move search tree.
///
/// Children store a raw pointer back to their parent (assigned from `self` during
/// expansion), so a node must not move in memory while it has children.
#[derive(Debug)]
pub struct Node {
    /// `true` only for the tree root (set by `perform_mcts`).
    pub root: bool,
    /// Raw pointer to the parent node; null for the root, which never has it assigned.
    pub parent: *mut Node,
    /// Expanded outcomes keyed by `(side-one option index, side-two option index)`.
    /// Each entry holds one child per instruction branch generated for that move pair.
    pub children: HashMap<(usize, usize), Vec<Node>>,
    /// Number of times backpropagation has passed through this node.
    pub times_visited: i64,

    /// Instructions that take the parent's state to this node's state. Selection applies
    /// them when descending; backpropagation reverses them.
    pub instructions: StateInstructions,
    /// Index into the parent's `s1_options` of the move that led to this node.
    pub s1_choice: usize,
    /// Index into the parent's `s2_options` of the move that led to this node.
    pub s2_choice: usize,

    /// Statistics for each move side one can choose from this node.
    pub s1_options: Vec<MoveNode>,
    /// Statistics for each move side two can choose from this node.
    pub s2_options: Vec<MoveNode>,
}
33
34impl Node {
35 fn new(s1_options: Vec<MoveChoice>, s2_options: Vec<MoveChoice>) -> Node {
36 let s1_options_vec = s1_options
37 .iter()
38 .map(|x| MoveNode {
39 move_choice: x.clone(),
40 total_score: 0.0,
41 visits: 0,
42 })
43 .collect();
44 let s2_options_vec = s2_options
45 .iter()
46 .map(|x| MoveNode {
47 move_choice: x.clone(),
48 total_score: 0.0,
49 visits: 0,
50 })
51 .collect();
52
53 Node {
54 root: false,
55 parent: std::ptr::null_mut(),
56 instructions: StateInstructions::default(),
57 times_visited: 0,
58 children: HashMap::new(),
59 s1_choice: 0,
60 s2_choice: 0,
61 s1_options: s1_options_vec,
62 s2_options: s2_options_vec,
63 }
64 }
65
66 pub fn maximize_ucb_for_side(&self, side_map: &[MoveNode]) -> usize {
67 let mut choice = 0;
68 let mut best_ucb1 = f32::MIN;
69 for (index, node) in side_map.iter().enumerate() {
70 let this_ucb1 = node.ucb1(self.times_visited);
71 if this_ucb1 > best_ucb1 {
72 best_ucb1 = this_ucb1;
73 choice = index;
74 }
75 }
76 choice
77 }
78
79 pub unsafe fn selection(&mut self, state: &mut State) -> (*mut Node, usize, usize) {
80 let return_node = self as *mut Node;
81
82 let s1_mc_index = self.maximize_ucb_for_side(&self.s1_options);
83 let s2_mc_index = self.maximize_ucb_for_side(&self.s2_options);
84 let child_vector = self.children.get_mut(&(s1_mc_index, s2_mc_index));
85 match child_vector {
86 Some(child_vector) => {
87 let child_vec_ptr = child_vector as *mut Vec<Node>;
88 let chosen_child = self.sample_node(child_vec_ptr);
89 state.apply_instructions(&(*chosen_child).instructions.instruction_list);
90 (*chosen_child).selection(state)
91 }
92 None => (return_node, s1_mc_index, s2_mc_index),
93 }
94 }
95
96 unsafe fn sample_node(&self, move_vector: *mut Vec<Node>) -> *mut Node {
97 let mut rng = thread_rng();
98 let weights: Vec<f64> = (*move_vector)
99 .iter()
100 .map(|x| x.instructions.percentage as f64)
101 .collect();
102 let dist = WeightedIndex::new(weights).unwrap();
103 let chosen_node = &mut (*move_vector)[dist.sample(&mut rng)];
104 let chosen_node_ptr = chosen_node as *mut Node;
105 chosen_node_ptr
106 }
107
108 pub unsafe fn expand(
109 &mut self,
110 state: &mut State,
111 s1_move_index: usize,
112 s2_move_index: usize,
113 ) -> *mut Node {
114 let s1_move = &self.s1_options[s1_move_index].move_choice;
115 let s2_move = &self.s2_options[s2_move_index].move_choice;
116 if (state.battle_is_over() != 0.0 && !self.root)
118 || (s1_move == &MoveChoice::None && s2_move == &MoveChoice::None)
119 {
120 return self as *mut Node;
121 }
122 let should_branch_on_damage = self.root || (*self.parent).root;
123 let mut new_instructions =
124 generate_instructions_from_move_pair(state, s1_move, s2_move, should_branch_on_damage);
125 let mut this_pair_vec = Vec::with_capacity(2);
126 for state_instructions in new_instructions.drain(..) {
127 state.apply_instructions(&state_instructions.instruction_list);
128 let (s1_options, s2_options) = state.get_all_options();
129 state.reverse_instructions(&state_instructions.instruction_list);
130
131 let mut new_node = Node::new(s1_options, s2_options);
132 new_node.parent = self;
133 new_node.instructions = state_instructions;
134 new_node.s1_choice = s1_move_index;
135 new_node.s2_choice = s2_move_index;
136
137 this_pair_vec.push(new_node);
138 }
139
140 let new_node_ptr = self.sample_node(&mut this_pair_vec);
143 state.apply_instructions(&(*new_node_ptr).instructions.instruction_list);
144 self.children
145 .insert((s1_move_index, s2_move_index), this_pair_vec);
146 new_node_ptr
147 }
148
149 pub unsafe fn backpropagate(&mut self, score: f32, state: &mut State) {
150 self.times_visited += 1;
151 if self.root {
152 return;
153 }
154
155 let parent_s1_movenode = &mut (*self.parent).s1_options[self.s1_choice];
156 parent_s1_movenode.total_score += score;
157 parent_s1_movenode.visits += 1;
158
159 let parent_s2_movenode = &mut (*self.parent).s2_options[self.s2_choice];
160 parent_s2_movenode.total_score += 1.0 - score;
161 parent_s2_movenode.visits += 1;
162
163 state.reverse_instructions(&self.instructions.instruction_list);
164 (*self.parent).backpropagate(score, state);
165 }
166
167 pub fn rollout(&mut self, state: &mut State, root_eval: &f32) -> f32 {
168 let battle_is_over = state.battle_is_over();
169 if battle_is_over == 0.0 {
170 let eval = evaluate(state);
171 sigmoid(eval - root_eval)
172 } else {
173 if battle_is_over == -1.0 {
174 0.0
175 } else {
176 battle_is_over
177 }
178 }
179 }
180}
181
182#[derive(Debug)]
183pub struct MoveNode {
184 pub move_choice: MoveChoice,
185 pub total_score: f32,
186 pub visits: i64,
187}
188
189impl MoveNode {
190 pub fn ucb1(&self, parent_visits: i64) -> f32 {
191 if self.visits == 0 {
192 return f32::INFINITY;
193 }
194 let score = (self.total_score / self.visits as f32)
195 + (2.0 * (parent_visits as f32).ln() / self.visits as f32).sqrt();
196 score
197 }
198 pub fn average_score(&self) -> f32 {
199 let score = self.total_score / self.visits as f32;
200 score
201 }
202}
203
204#[derive(Clone)]
205pub struct MctsSideResult {
206 pub move_choice: MoveChoice,
207 pub total_score: f32,
208 pub visits: i64,
209}
210
211impl MctsSideResult {
212 pub fn average_score(&self) -> f32 {
213 if self.visits == 0 {
214 return 0.0;
215 }
216 let score = self.total_score / self.visits as f32;
217 score
218 }
219}
220
221pub struct MctsResult {
222 pub s1: Vec<MctsSideResult>,
223 pub s2: Vec<MctsSideResult>,
224 pub iteration_count: i64,
225}
226
227fn do_mcts(root_node: &mut Node, state: &mut State, root_eval: &f32) {
228 let (mut new_node, s1_move, s2_move) = unsafe { root_node.selection(state) };
229 new_node = unsafe { (*new_node).expand(state, s1_move, s2_move) };
230 let rollout_result = unsafe { (*new_node).rollout(state, root_eval) };
231 unsafe { (*new_node).backpropagate(rollout_result, state) }
232}
233
234pub fn perform_mcts(
235 state: &mut State,
236 side_one_options: Vec<MoveChoice>,
237 side_two_options: Vec<MoveChoice>,
238 max_time: Duration,
239) -> MctsResult {
240 let mut root_node = Node::new(side_one_options, side_two_options);
241 root_node.root = true;
242
243 let root_eval = evaluate(state);
244 let start_time = std::time::Instant::now();
245 while start_time.elapsed() < max_time {
246 for _ in 0..1000 {
247 do_mcts(&mut root_node, state, &root_eval);
248 }
249
250 if root_node.times_visited == 10_000_000 {
263 break;
264 }
265 }
266
267 let result = MctsResult {
268 s1: root_node
269 .s1_options
270 .iter()
271 .map(|v| MctsSideResult {
272 move_choice: v.move_choice.clone(),
273 total_score: v.total_score,
274 visits: v.visits,
275 })
276 .collect(),
277 s2: root_node
278 .s2_options
279 .iter()
280 .map(|v| MctsSideResult {
281 move_choice: v.move_choice.clone(),
282 total_score: v.total_score,
283 visits: v.visits,
284 })
285 .collect(),
286 iteration_count: root_node.times_visited,
287 };
288
289 result
290}