// poke_engine/search.rs

1use crate::evaluate::evaluate;
2use crate::generate_instructions::generate_instructions_from_move_pair;
3use crate::state::{MoveChoice, State};
4use std::sync::mpsc::{channel, Receiver, Sender};
5use std::sync::{Arc, Mutex};
6use std::thread;
7use std::time::Duration;
8
9enum IterativeDeependingThreadMessage {
10    Stop((Vec<MoveChoice>, Vec<MoveChoice>, Vec<f32>, i8)),
11}
12
13pub fn expectiminimax_search(
14    state: &mut State,
15    mut depth: i8,
16    side_one_options: Vec<MoveChoice>,
17    side_two_options: Vec<MoveChoice>,
18    ab_prune: bool,
19    mtx: &Arc<Mutex<bool>>,
20) -> Vec<f32> {
21    depth -= 1;
22    let num_s1_moves = side_one_options.len();
23    let num_s2_moves = side_two_options.len();
24    let mut score_lookup: Vec<f32> = Vec::with_capacity(num_s1_moves * num_s2_moves);
25
26    if *mtx.lock().unwrap() == false {
27        for _ in 0..(num_s1_moves * num_s2_moves) {
28            score_lookup.push(0.0);
29        }
30        return score_lookup;
31    }
32
33    let battle_is_over = state.battle_is_over();
34    if battle_is_over != 0.0 {
35        for _ in 0..(num_s1_moves * num_s2_moves) {
36            score_lookup.push(((100.0 * depth as f32) * battle_is_over) + evaluate(state));
37        }
38        return score_lookup;
39    }
40
41    let mut skip;
42    let mut alpha = f32::MIN;
43    for side_one_move in side_one_options.iter().as_ref() {
44        let mut beta = f32::MAX;
45        skip = false;
46
47        for side_two_move in side_two_options.iter().as_ref() {
48            if skip {
49                score_lookup.push(f32::NAN);
50                continue;
51            }
52
53            let mut score = 0.0;
54            let instructions =
55                generate_instructions_from_move_pair(state, &side_one_move, &side_two_move, false);
56            if depth == 0 {
57                for instruction in instructions.iter() {
58                    state.apply_instructions(&instruction.instruction_list);
59                    score += instruction.percentage * evaluate(state) / 100.0;
60                    state.reverse_instructions(&instruction.instruction_list);
61                }
62            } else {
63                for instruction in instructions.iter() {
64                    state.apply_instructions(&instruction.instruction_list);
65                    let (next_turn_side_one_options, next_turn_side_two_options) =
66                        state.get_all_options();
67
68                    let next_turn_side_one_options_len = next_turn_side_one_options.len();
69                    let next_turn_side_two_options_len = next_turn_side_two_options.len();
70                    let (_, safest) = pick_safest(
71                        &expectiminimax_search(
72                            state,
73                            depth,
74                            next_turn_side_one_options,
75                            next_turn_side_two_options,
76                            true, // until there is something better than `pick_safest` for evaluating a sub-game, there is no point in this being anything other than `true`
77                            &mtx,
78                        ),
79                        next_turn_side_one_options_len,
80                        next_turn_side_two_options_len,
81                    );
82                    score += instruction.percentage * safest / 100.0;
83
84                    state.reverse_instructions(&instruction.instruction_list);
85                }
86            }
87            score_lookup.push(score);
88
89            if ab_prune {
90                if score < beta {
91                    beta = score;
92                }
93                if score <= alpha {
94                    skip = true;
95                }
96            }
97        }
98        if beta > alpha {
99            alpha = beta;
100        }
101    }
102    score_lookup
103}
104
/// Picks side one's "safest" move from a row-major score matrix: the row whose
/// minimum (worst case over side two's replies) is largest.
///
/// `score_lookup` holds `num_s1_moves` rows of `num_s2_moves` scores each.
/// NaN cells (pruned by alpha-beta) never compare `<`, so they never become a
/// row's minimum. Ties keep the earliest row.
///
/// Returns `(row_index, row_minimum)`. With no rows the result is
/// `(0, f32::MIN)`; a row with no columns has minimum `f32::MAX`.
///
/// # Panics
/// Panics if `score_lookup` has fewer than `num_s1_moves * num_s2_moves` entries.
pub fn pick_safest(
    score_lookup: &Vec<f32>,
    num_s1_moves: usize,
    num_s2_moves: usize,
) -> (usize, f32) {
    let row_minimum = |row: &[f32]| -> f32 {
        row.iter()
            .fold(f32::MAX, |lowest, &cell| if cell < lowest { cell } else { lowest })
    };

    (0..num_s1_moves)
        .map(|s1_index| {
            let start = s1_index * num_s2_moves;
            (s1_index, row_minimum(&score_lookup[start..start + num_s2_moves]))
        })
        .fold((0, f32::MIN), |best, candidate| {
            if candidate.1 > best.1 {
                candidate
            } else {
                best
            }
        })
}
131
132fn re_order_moves_for_iterative_deepening(
133    last_search_result: &Vec<f32>,
134    side_one_options: Vec<MoveChoice>,
135    side_two_options: Vec<MoveChoice>,
136) -> (Vec<MoveChoice>, Vec<MoveChoice>) {
137    let num_s1_moves = side_one_options.len();
138    let num_s2_moves = side_two_options.len();
139    let mut worst_case_s1_scores: Vec<(MoveChoice, f32)> = vec![];
140    let mut vec_index = 0;
141
142    for s1_index in 0..num_s1_moves {
143        let mut worst_case_this_row = f32::MAX;
144        for _ in 0..num_s2_moves {
145            let score = last_search_result[vec_index];
146            vec_index += 1;
147            if score < worst_case_this_row {
148                worst_case_this_row = score;
149            }
150        }
151        worst_case_s1_scores.push((side_one_options[s1_index].clone(), worst_case_this_row));
152    }
153
154    worst_case_s1_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
155    let new_s1_vec = worst_case_s1_scores.iter().map(|x| x.0.clone()).collect();
156
157    (new_s1_vec, side_two_options)
158}
159
160pub fn iterative_deepen_expectiminimax(
161    state: &mut State,
162    side_one_options: Vec<MoveChoice>,
163    side_two_options: Vec<MoveChoice>,
164    max_time: Duration,
165) -> (Vec<MoveChoice>, Vec<MoveChoice>, Vec<f32>, i8) {
166    let mut state_clone = state.clone();
167
168    let mut result = expectiminimax_search(
169        state,
170        1,
171        side_one_options.clone(),
172        side_two_options.clone(),
173        true,
174        &Arc::new(Mutex::new(true)),
175    );
176    let (mut re_ordered_s1_options, mut re_ordered_s2_options) =
177        re_order_moves_for_iterative_deepening(&result, side_one_options, side_two_options);
178    let mut i = 1;
179    let running = Arc::new(Mutex::new(true));
180    let running_clone = Arc::clone(&running);
181
182    let (sender, receiver): (
183        Sender<IterativeDeependingThreadMessage>,
184        Receiver<IterativeDeependingThreadMessage>,
185    ) = channel();
186
187    let handle = thread::spawn(move || {
188        let mut previous_turn_s1_options = re_ordered_s1_options.clone();
189        let mut previous_turn_s2_options = re_ordered_s2_options.clone();
190        loop {
191            let previous_result = result;
192            i += 1;
193            result = expectiminimax_search(
194                &mut state_clone,
195                i,
196                re_ordered_s1_options.clone(),
197                re_ordered_s2_options.clone(),
198                true,
199                &running_clone,
200            );
201
202            // when we are told to stop, return the *previous* result.
203            // the current result will be invalid
204            if *running_clone.lock().unwrap() == false {
205                sender
206                    .send(IterativeDeependingThreadMessage::Stop((
207                        previous_turn_s1_options,
208                        previous_turn_s2_options,
209                        previous_result,
210                        i - 1,
211                    )))
212                    .unwrap();
213                break;
214            }
215            previous_turn_s1_options = re_ordered_s1_options.clone();
216            previous_turn_s2_options = re_ordered_s2_options.clone();
217            (re_ordered_s1_options, re_ordered_s2_options) = re_order_moves_for_iterative_deepening(
218                &result,
219                re_ordered_s1_options,
220                re_ordered_s2_options,
221            );
222        }
223    });
224
225    thread::sleep(max_time);
226    *running.lock().unwrap() = false;
227    match receiver.recv() {
228        Ok(IterativeDeependingThreadMessage::Stop(result)) => {
229            handle.join().unwrap();
230            result
231        }
232        _ => panic!("Failed to receive stop message"),
233    }
234}