1use crate::evaluate::evaluate;
2use crate::generate_instructions::generate_instructions_from_move_pair;
3use crate::state::{MoveChoice, State};
4use std::sync::mpsc::{channel, Receiver, Sender};
5use std::sync::{Arc, Mutex};
6use std::thread;
7use std::time::Duration;
8
9enum IterativeDeependingThreadMessage {
10 Stop((Vec<MoveChoice>, Vec<MoveChoice>, Vec<f32>, i8)),
11}
12
13pub fn expectiminimax_search(
14 state: &mut State,
15 mut depth: i8,
16 side_one_options: Vec<MoveChoice>,
17 side_two_options: Vec<MoveChoice>,
18 ab_prune: bool,
19 mtx: &Arc<Mutex<bool>>,
20) -> Vec<f32> {
21 depth -= 1;
22 let num_s1_moves = side_one_options.len();
23 let num_s2_moves = side_two_options.len();
24 let mut score_lookup: Vec<f32> = Vec::with_capacity(num_s1_moves * num_s2_moves);
25
26 if *mtx.lock().unwrap() == false {
27 for _ in 0..(num_s1_moves * num_s2_moves) {
28 score_lookup.push(0.0);
29 }
30 return score_lookup;
31 }
32
33 let battle_is_over = state.battle_is_over();
34 if battle_is_over != 0.0 {
35 for _ in 0..(num_s1_moves * num_s2_moves) {
36 score_lookup.push(((100.0 * depth as f32) * battle_is_over) + evaluate(state));
37 }
38 return score_lookup;
39 }
40
41 let mut skip;
42 let mut alpha = f32::MIN;
43 for side_one_move in side_one_options.iter().as_ref() {
44 let mut beta = f32::MAX;
45 skip = false;
46
47 for side_two_move in side_two_options.iter().as_ref() {
48 if skip {
49 score_lookup.push(f32::NAN);
50 continue;
51 }
52
53 let mut score = 0.0;
54 let instructions =
55 generate_instructions_from_move_pair(state, &side_one_move, &side_two_move, false);
56 if depth == 0 {
57 for instruction in instructions.iter() {
58 state.apply_instructions(&instruction.instruction_list);
59 score += instruction.percentage * evaluate(state) / 100.0;
60 state.reverse_instructions(&instruction.instruction_list);
61 }
62 } else {
63 for instruction in instructions.iter() {
64 state.apply_instructions(&instruction.instruction_list);
65 let (next_turn_side_one_options, next_turn_side_two_options) =
66 state.get_all_options();
67
68 let next_turn_side_one_options_len = next_turn_side_one_options.len();
69 let next_turn_side_two_options_len = next_turn_side_two_options.len();
70 let (_, safest) = pick_safest(
71 &expectiminimax_search(
72 state,
73 depth,
74 next_turn_side_one_options,
75 next_turn_side_two_options,
76 true, &mtx,
78 ),
79 next_turn_side_one_options_len,
80 next_turn_side_two_options_len,
81 );
82 score += instruction.percentage * safest / 100.0;
83
84 state.reverse_instructions(&instruction.instruction_list);
85 }
86 }
87 score_lookup.push(score);
88
89 if ab_prune {
90 if score < beta {
91 beta = score;
92 }
93 if score <= alpha {
94 skip = true;
95 }
96 }
97 }
98 if beta > alpha {
99 alpha = beta;
100 }
101 }
102 score_lookup
103}
104
/// Picks side one's "safest" move from a flattened score matrix.
///
/// `score_lookup` is read row-major as `num_s1_moves` rows of `num_s2_moves`
/// scores, one row per side-one move. A row's lowest score is that move's worst
/// case. The row with the highest worst case wins, and the earliest row wins
/// ties. NaN entries (pruned cells) never lower a row's minimum.
///
/// Returns `(row index, that row's worst-case score)`. With no rows the result
/// is `(0, f32::MIN)`. A row with no columns has a worst case of `f32::MAX`.
///
/// # Panics
/// Panics if `score_lookup` has fewer than `num_s1_moves * num_s2_moves` entries.
pub fn pick_safest(
    score_lookup: &[f32],
    num_s1_moves: usize,
    num_s2_moves: usize,
) -> (usize, f32) {
    let mut best_worst_case_s1_index = 0;
    let mut best_worst_case = f32::MIN;

    for s1_index in 0..num_s1_moves {
        let row_start = s1_index * num_s2_moves;
        let row = &score_lookup[row_start..row_start + num_s2_moves];
        // Explicit `<` (rather than f32::min) keeps NaN from ever being chosen.
        let worst_case_this_row = row
            .iter()
            .fold(f32::MAX, |worst, &score| if score < worst { score } else { worst });

        // Strict `>` keeps the first row on ties.
        if worst_case_this_row > best_worst_case {
            best_worst_case_s1_index = s1_index;
            best_worst_case = worst_case_this_row;
        }
    }

    (best_worst_case_s1_index, best_worst_case)
}
131
132fn re_order_moves_for_iterative_deepening(
133 last_search_result: &Vec<f32>,
134 side_one_options: Vec<MoveChoice>,
135 side_two_options: Vec<MoveChoice>,
136) -> (Vec<MoveChoice>, Vec<MoveChoice>) {
137 let num_s1_moves = side_one_options.len();
138 let num_s2_moves = side_two_options.len();
139 let mut worst_case_s1_scores: Vec<(MoveChoice, f32)> = vec![];
140 let mut vec_index = 0;
141
142 for s1_index in 0..num_s1_moves {
143 let mut worst_case_this_row = f32::MAX;
144 for _ in 0..num_s2_moves {
145 let score = last_search_result[vec_index];
146 vec_index += 1;
147 if score < worst_case_this_row {
148 worst_case_this_row = score;
149 }
150 }
151 worst_case_s1_scores.push((side_one_options[s1_index].clone(), worst_case_this_row));
152 }
153
154 worst_case_s1_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
155 let new_s1_vec = worst_case_s1_scores.iter().map(|x| x.0.clone()).collect();
156
157 (new_s1_vec, side_two_options)
158}
159
160pub fn iterative_deepen_expectiminimax(
161 state: &mut State,
162 side_one_options: Vec<MoveChoice>,
163 side_two_options: Vec<MoveChoice>,
164 max_time: Duration,
165) -> (Vec<MoveChoice>, Vec<MoveChoice>, Vec<f32>, i8) {
166 let mut state_clone = state.clone();
167
168 let mut result = expectiminimax_search(
169 state,
170 1,
171 side_one_options.clone(),
172 side_two_options.clone(),
173 true,
174 &Arc::new(Mutex::new(true)),
175 );
176 let (mut re_ordered_s1_options, mut re_ordered_s2_options) =
177 re_order_moves_for_iterative_deepening(&result, side_one_options, side_two_options);
178 let mut i = 1;
179 let running = Arc::new(Mutex::new(true));
180 let running_clone = Arc::clone(&running);
181
182 let (sender, receiver): (
183 Sender<IterativeDeependingThreadMessage>,
184 Receiver<IterativeDeependingThreadMessage>,
185 ) = channel();
186
187 let handle = thread::spawn(move || {
188 let mut previous_turn_s1_options = re_ordered_s1_options.clone();
189 let mut previous_turn_s2_options = re_ordered_s2_options.clone();
190 loop {
191 let previous_result = result;
192 i += 1;
193 result = expectiminimax_search(
194 &mut state_clone,
195 i,
196 re_ordered_s1_options.clone(),
197 re_ordered_s2_options.clone(),
198 true,
199 &running_clone,
200 );
201
202 if *running_clone.lock().unwrap() == false {
205 sender
206 .send(IterativeDeependingThreadMessage::Stop((
207 previous_turn_s1_options,
208 previous_turn_s2_options,
209 previous_result,
210 i - 1,
211 )))
212 .unwrap();
213 break;
214 }
215 previous_turn_s1_options = re_ordered_s1_options.clone();
216 previous_turn_s2_options = re_ordered_s2_options.clone();
217 (re_ordered_s1_options, re_ordered_s2_options) = re_order_moves_for_iterative_deepening(
218 &result,
219 re_ordered_s1_options,
220 re_ordered_s2_options,
221 );
222 }
223 });
224
225 thread::sleep(max_time);
226 *running.lock().unwrap() = false;
227 match receiver.recv() {
228 Ok(IterativeDeependingThreadMessage::Stop(result)) => {
229 handle.join().unwrap();
230 result
231 }
232 _ => panic!("Failed to receive stop message"),
233 }
234}