1use crate::engine::evaluate::evaluate;
2use crate::engine::generate_instructions::generate_instructions_from_move_pair;
3use crate::engine::state::MoveChoice;
4use crate::state::State;
5use std::sync::mpsc::{channel, Receiver, Sender};
6use std::sync::{Arc, Mutex};
7use std::thread;
8use std::time::Duration;
9
10enum IterativeDeependingThreadMessage {
11 Stop((Vec<MoveChoice>, Vec<MoveChoice>, Vec<f32>, i8)),
12}
13
14pub fn expectiminimax_search(
15 state: &mut State,
16 mut depth: i8,
17 side_one_options: Vec<MoveChoice>,
18 side_two_options: Vec<MoveChoice>,
19 ab_prune: bool,
20 mtx: &Arc<Mutex<bool>>,
21) -> Vec<f32> {
22 depth -= 1;
23 let num_s1_moves = side_one_options.len();
24 let num_s2_moves = side_two_options.len();
25 let mut score_lookup: Vec<f32> = Vec::with_capacity(num_s1_moves * num_s2_moves);
26
27 if *mtx.lock().unwrap() == false {
28 for _ in 0..(num_s1_moves * num_s2_moves) {
29 score_lookup.push(0.0);
30 }
31 return score_lookup;
32 }
33
34 let battle_is_over = state.battle_is_over();
35 if battle_is_over != 0.0 {
36 for _ in 0..(num_s1_moves * num_s2_moves) {
37 score_lookup.push(((100.0 * depth as f32) * battle_is_over) + evaluate(state));
38 }
39 return score_lookup;
40 }
41
42 let mut skip;
43 let mut alpha = f32::MIN;
44 for side_one_move in side_one_options.iter().as_ref() {
45 let mut beta = f32::MAX;
46 skip = false;
47
48 for side_two_move in side_two_options.iter().as_ref() {
49 if skip {
50 score_lookup.push(f32::NAN);
51 continue;
52 }
53
54 let mut score = 0.0;
55 let instructions =
56 generate_instructions_from_move_pair(state, &side_one_move, &side_two_move, false);
57 if depth == 0 {
58 for instruction in instructions.iter() {
59 state.apply_instructions(&instruction.instruction_list);
60 score += instruction.percentage * evaluate(state) / 100.0;
61 state.reverse_instructions(&instruction.instruction_list);
62 }
63 } else {
64 for instruction in instructions.iter() {
65 state.apply_instructions(&instruction.instruction_list);
66 let (next_turn_side_one_options, next_turn_side_two_options) =
67 state.get_all_options();
68
69 let next_turn_side_one_options_len = next_turn_side_one_options.len();
70 let next_turn_side_two_options_len = next_turn_side_two_options.len();
71 let (_, safest) = pick_safest(
72 &expectiminimax_search(
73 state,
74 depth,
75 next_turn_side_one_options,
76 next_turn_side_two_options,
77 true, &mtx,
79 ),
80 next_turn_side_one_options_len,
81 next_turn_side_two_options_len,
82 );
83 score += instruction.percentage * safest / 100.0;
84
85 state.reverse_instructions(&instruction.instruction_list);
86 }
87 }
88 score_lookup.push(score);
89
90 if ab_prune {
91 if score < beta {
92 beta = score;
93 }
94 if score <= alpha {
95 skip = true;
96 }
97 }
98 }
99 if beta > alpha {
100 alpha = beta;
101 }
102 }
103 score_lookup
104}
105
/// Picks the side-one move with the best worst-case score (maximin).
///
/// `score_lookup` is read as a row-major matrix of `num_s1_moves` rows with
/// `num_s2_moves` entries each, as produced by [`expectiminimax_search`]. Any trailing
/// entries are ignored. Each row's worst case is its minimum score, skipping `NaN`
/// entries (left behind by alpha-beta pruning). The row with the largest worst case
/// wins, and ties go to the earliest row.
///
/// Returns `(row_index, worst_case_score)`. With no rows this is `(0, f32::MIN)`, and a
/// row with no (non-NaN) entries has a worst case of `f32::MAX`.
///
/// # Panics
/// Panics if `score_lookup` has fewer than `num_s1_moves * num_s2_moves` entries.
pub fn pick_safest(
    score_lookup: &[f32],
    num_s1_moves: usize,
    num_s2_moves: usize,
) -> (usize, f32) {
    let mut best_worst_case = f32::MIN;
    let mut best_worst_case_s1_index = 0;

    for s1_index in 0..num_s1_moves {
        let row = &score_lookup[s1_index * num_s2_moves..(s1_index + 1) * num_s2_moves];
        // `score < worst` is false for NaN, so pruned entries never become the minimum.
        let worst_case_this_row = row
            .iter()
            .fold(f32::MAX, |worst, &score| if score < worst { score } else { worst });
        if worst_case_this_row > best_worst_case {
            best_worst_case_s1_index = s1_index;
            best_worst_case = worst_case_this_row;
        }
    }

    (best_worst_case_s1_index, best_worst_case)
}
132
133fn re_order_moves_for_iterative_deepening(
134 last_search_result: &Vec<f32>,
135 side_one_options: Vec<MoveChoice>,
136 side_two_options: Vec<MoveChoice>,
137) -> (Vec<MoveChoice>, Vec<MoveChoice>) {
138 let num_s1_moves = side_one_options.len();
139 let num_s2_moves = side_two_options.len();
140 let mut worst_case_s1_scores: Vec<(MoveChoice, f32)> = vec![];
141 let mut vec_index = 0;
142
143 for s1_index in 0..num_s1_moves {
144 let mut worst_case_this_row = f32::MAX;
145 for _ in 0..num_s2_moves {
146 let score = last_search_result[vec_index];
147 vec_index += 1;
148 if score < worst_case_this_row {
149 worst_case_this_row = score;
150 }
151 }
152 worst_case_s1_scores.push((side_one_options[s1_index].clone(), worst_case_this_row));
153 }
154
155 worst_case_s1_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
156 let new_s1_vec = worst_case_s1_scores.iter().map(|x| x.0.clone()).collect();
157
158 (new_s1_vec, side_two_options)
159}
160
161pub fn iterative_deepen_expectiminimax(
162 state: &mut State,
163 side_one_options: Vec<MoveChoice>,
164 side_two_options: Vec<MoveChoice>,
165 max_time: Duration,
166) -> (Vec<MoveChoice>, Vec<MoveChoice>, Vec<f32>, i8) {
167 let mut state_clone = state.clone();
168
169 let mut result = expectiminimax_search(
170 state,
171 1,
172 side_one_options.clone(),
173 side_two_options.clone(),
174 true,
175 &Arc::new(Mutex::new(true)),
176 );
177 let (mut re_ordered_s1_options, mut re_ordered_s2_options) =
178 re_order_moves_for_iterative_deepening(&result, side_one_options, side_two_options);
179 let mut i = 1;
180 let running = Arc::new(Mutex::new(true));
181 let running_clone = Arc::clone(&running);
182
183 let (sender, receiver): (
184 Sender<IterativeDeependingThreadMessage>,
185 Receiver<IterativeDeependingThreadMessage>,
186 ) = channel();
187
188 let handle = thread::spawn(move || {
189 let mut previous_turn_s1_options = re_ordered_s1_options.clone();
190 let mut previous_turn_s2_options = re_ordered_s2_options.clone();
191 loop {
192 let previous_result = result;
193 i += 1;
194 result = expectiminimax_search(
195 &mut state_clone,
196 i,
197 re_ordered_s1_options.clone(),
198 re_ordered_s2_options.clone(),
199 true,
200 &running_clone,
201 );
202
203 if *running_clone.lock().unwrap() == false {
206 sender
207 .send(IterativeDeependingThreadMessage::Stop((
208 previous_turn_s1_options,
209 previous_turn_s2_options,
210 previous_result,
211 i - 1,
212 )))
213 .unwrap();
214 break;
215 }
216 previous_turn_s1_options = re_ordered_s1_options.clone();
217 previous_turn_s2_options = re_ordered_s2_options.clone();
218 (re_ordered_s1_options, re_ordered_s2_options) = re_order_moves_for_iterative_deepening(
219 &result,
220 re_ordered_s1_options,
221 re_ordered_s2_options,
222 );
223 }
224 });
225
226 thread::sleep(max_time);
227 *running.lock().unwrap() = false;
228 match receiver.recv() {
229 Ok(IterativeDeependingThreadMessage::Stop(result)) => {
230 handle.join().unwrap();
231 result
232 }
233 _ => panic!("Failed to receive stop message"),
234 }
235}