poke_engine/
search.rs

1use crate::engine::evaluate::evaluate;
2use crate::engine::generate_instructions::generate_instructions_from_move_pair;
3use crate::engine::state::MoveChoice;
4use crate::state::State;
5use std::sync::mpsc::{channel, Receiver, Sender};
6use std::sync::{Arc, Mutex};
7use std::thread;
8use std::time::Duration;
9
10enum IterativeDeependingThreadMessage {
11    Stop((Vec<MoveChoice>, Vec<MoveChoice>, Vec<f32>, i8)),
12}
13
14pub fn expectiminimax_search(
15    state: &mut State,
16    mut depth: i8,
17    side_one_options: Vec<MoveChoice>,
18    side_two_options: Vec<MoveChoice>,
19    ab_prune: bool,
20    mtx: &Arc<Mutex<bool>>,
21) -> Vec<f32> {
22    depth -= 1;
23    let num_s1_moves = side_one_options.len();
24    let num_s2_moves = side_two_options.len();
25    let mut score_lookup: Vec<f32> = Vec::with_capacity(num_s1_moves * num_s2_moves);
26
27    if *mtx.lock().unwrap() == false {
28        for _ in 0..(num_s1_moves * num_s2_moves) {
29            score_lookup.push(0.0);
30        }
31        return score_lookup;
32    }
33
34    let battle_is_over = state.battle_is_over();
35    if battle_is_over != 0.0 {
36        for _ in 0..(num_s1_moves * num_s2_moves) {
37            score_lookup.push(((100.0 * depth as f32) * battle_is_over) + evaluate(state));
38        }
39        return score_lookup;
40    }
41
42    let mut skip;
43    let mut alpha = f32::MIN;
44    for side_one_move in side_one_options.iter().as_ref() {
45        let mut beta = f32::MAX;
46        skip = false;
47
48        for side_two_move in side_two_options.iter().as_ref() {
49            if skip {
50                score_lookup.push(f32::NAN);
51                continue;
52            }
53
54            let mut score = 0.0;
55            let instructions =
56                generate_instructions_from_move_pair(state, &side_one_move, &side_two_move, false);
57            if depth == 0 {
58                for instruction in instructions.iter() {
59                    state.apply_instructions(&instruction.instruction_list);
60                    score += instruction.percentage * evaluate(state) / 100.0;
61                    state.reverse_instructions(&instruction.instruction_list);
62                }
63            } else {
64                for instruction in instructions.iter() {
65                    state.apply_instructions(&instruction.instruction_list);
66                    let (next_turn_side_one_options, next_turn_side_two_options) =
67                        state.get_all_options();
68
69                    let next_turn_side_one_options_len = next_turn_side_one_options.len();
70                    let next_turn_side_two_options_len = next_turn_side_two_options.len();
71                    let (_, safest) = pick_safest(
72                        &expectiminimax_search(
73                            state,
74                            depth,
75                            next_turn_side_one_options,
76                            next_turn_side_two_options,
77                            true, // until there is something better than `pick_safest` for evaluating a sub-game, there is no point in this being anything other than `true`
78                            &mtx,
79                        ),
80                        next_turn_side_one_options_len,
81                        next_turn_side_two_options_len,
82                    );
83                    score += instruction.percentage * safest / 100.0;
84
85                    state.reverse_instructions(&instruction.instruction_list);
86                }
87            }
88            score_lookup.push(score);
89
90            if ab_prune {
91                if score < beta {
92                    beta = score;
93                }
94                if score <= alpha {
95                    skip = true;
96                }
97            }
98        }
99        if beta > alpha {
100            alpha = beta;
101        }
102    }
103    score_lookup
104}
105
/// Picks side one's "safest" option from a row-major `side_one x side_two` score matrix.
///
/// For each side-one row the worst (lowest) score is taken. `NaN` entries, such as
/// the pruned cells written by [`expectiminimax_search`], never count as the
/// minimum. The row whose worst case is highest wins; ties go to the earliest row.
///
/// Returns `(row_index, worst_case_score)`. With no rows this is `(0, f32::MIN)`,
/// and a row with no columns has a worst case of `f32::MAX`.
///
/// # Panics
///
/// Panics if `score_lookup` holds fewer than `num_s1_moves * num_s2_moves` entries.
pub fn pick_safest(
    score_lookup: &Vec<f32>,
    num_s1_moves: usize,
    num_s2_moves: usize,
) -> (usize, f32) {
    let row_worst_case = |row: usize| -> f32 {
        let start = row * num_s2_moves;
        score_lookup[start..start + num_s2_moves]
            .iter()
            .fold(f32::MAX, |lowest, &score| if score < lowest { score } else { lowest })
    };

    (0..num_s1_moves)
        .map(|row| (row, row_worst_case(row)))
        .fold((0, f32::MIN), |best, candidate| {
            if candidate.1 > best.1 {
                candidate
            } else {
                best
            }
        })
}
132
133fn re_order_moves_for_iterative_deepening(
134    last_search_result: &Vec<f32>,
135    side_one_options: Vec<MoveChoice>,
136    side_two_options: Vec<MoveChoice>,
137) -> (Vec<MoveChoice>, Vec<MoveChoice>) {
138    let num_s1_moves = side_one_options.len();
139    let num_s2_moves = side_two_options.len();
140    let mut worst_case_s1_scores: Vec<(MoveChoice, f32)> = vec![];
141    let mut vec_index = 0;
142
143    for s1_index in 0..num_s1_moves {
144        let mut worst_case_this_row = f32::MAX;
145        for _ in 0..num_s2_moves {
146            let score = last_search_result[vec_index];
147            vec_index += 1;
148            if score < worst_case_this_row {
149                worst_case_this_row = score;
150            }
151        }
152        worst_case_s1_scores.push((side_one_options[s1_index].clone(), worst_case_this_row));
153    }
154
155    worst_case_s1_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
156    let new_s1_vec = worst_case_s1_scores.iter().map(|x| x.0.clone()).collect();
157
158    (new_s1_vec, side_two_options)
159}
160
161pub fn iterative_deepen_expectiminimax(
162    state: &mut State,
163    side_one_options: Vec<MoveChoice>,
164    side_two_options: Vec<MoveChoice>,
165    max_time: Duration,
166) -> (Vec<MoveChoice>, Vec<MoveChoice>, Vec<f32>, i8) {
167    let mut state_clone = state.clone();
168
169    let mut result = expectiminimax_search(
170        state,
171        1,
172        side_one_options.clone(),
173        side_two_options.clone(),
174        true,
175        &Arc::new(Mutex::new(true)),
176    );
177    let (mut re_ordered_s1_options, mut re_ordered_s2_options) =
178        re_order_moves_for_iterative_deepening(&result, side_one_options, side_two_options);
179    let mut i = 1;
180    let running = Arc::new(Mutex::new(true));
181    let running_clone = Arc::clone(&running);
182
183    let (sender, receiver): (
184        Sender<IterativeDeependingThreadMessage>,
185        Receiver<IterativeDeependingThreadMessage>,
186    ) = channel();
187
188    let handle = thread::spawn(move || {
189        let mut previous_turn_s1_options = re_ordered_s1_options.clone();
190        let mut previous_turn_s2_options = re_ordered_s2_options.clone();
191        loop {
192            let previous_result = result;
193            i += 1;
194            result = expectiminimax_search(
195                &mut state_clone,
196                i,
197                re_ordered_s1_options.clone(),
198                re_ordered_s2_options.clone(),
199                true,
200                &running_clone,
201            );
202
203            // when we are told to stop, return the *previous* result.
204            // the current result will be invalid
205            if *running_clone.lock().unwrap() == false {
206                sender
207                    .send(IterativeDeependingThreadMessage::Stop((
208                        previous_turn_s1_options,
209                        previous_turn_s2_options,
210                        previous_result,
211                        i - 1,
212                    )))
213                    .unwrap();
214                break;
215            }
216            previous_turn_s1_options = re_ordered_s1_options.clone();
217            previous_turn_s2_options = re_ordered_s2_options.clone();
218            (re_ordered_s1_options, re_ordered_s2_options) = re_order_moves_for_iterative_deepening(
219                &result,
220                re_ordered_s1_options,
221                re_ordered_s2_options,
222            );
223        }
224    });
225
226    thread::sleep(max_time);
227    *running.lock().unwrap() = false;
228    match receiver.recv() {
229        Ok(IterativeDeependingThreadMessage::Stop(result)) => {
230            handle.join().unwrap();
231            result
232        }
233        _ => panic!("Failed to receive stop message"),
234    }
235}