connect4_lib/ai/
mod.rs

1use super::game::{BoardState, ChipDescrip, Game};
2use rand::prelude::*;
3use serde::{Deserialize, Serialize};
4
5
6#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
7pub struct AIConfig {
8    carlo_iter: isize,
9    minmax_depth: isize,
10}
11
12pub const EASY_AI: AIConfig = AIConfig {
13    carlo_iter: 5,
14    minmax_depth: 2,
15};
16
17pub const MID_AI: AIConfig = AIConfig {
18    carlo_iter: 1000,
19    minmax_depth: 4,
20};
21
22pub const HARD_AI: AIConfig = AIConfig {
23    carlo_iter: 4000,
24    minmax_depth: 6,
25};
26
27pub fn get_best_move(game: &mut Game, ai_conf: AIConfig) -> (isize, ChipDescrip) {
28    let (_, mov, chip) = evaluate_board(game, ai_conf);
29    (mov, chip)
30}
31
32const MINMAX_SHIFT: isize = 14;
33// returns board evaluation and next best move
34pub fn evaluate_board(game: &mut Game, ai_conf: AIConfig) -> (isize, isize, ChipDescrip) {
35    let is_max = game.get_turn() % 2 == 0;
36
37    fn test_move(mov: isize, chip: ChipDescrip, game: &mut Game, ai_conf: AIConfig) -> isize {
38        game.play(mov, chip);
39        let mut score = minmax_search(game, ai_conf.minmax_depth) << MINMAX_SHIFT;
40        if score == 0 {
41            score = monte_carlo_search(game, ai_conf);
42        }
43        game.undo_move();
44        score
45    }
46
47    let mut potentials: Vec<(isize, isize, ChipDescrip)> = game
48        .get_board()
49        .get_valid_moves()
50        .iter()
51        .flat_map(|&mov| {
52            game.current_player()
53                .chip_options
54                .iter()
55                .map(move |&c| (mov, c))
56        })
57        .map(|(mov, c)| (test_move(mov, c, &mut game.clone(), ai_conf), mov, c))
58        .collect();
59
60    potentials.sort_by(|a, b| {
61        if is_max {
62            (b.0).partial_cmp(&a.0).unwrap()
63        } else {
64            (a.0).partial_cmp(&b.0).unwrap()
65        }
66    });
67
68    // println!("{:?}", potentials);
69    let (score, b_mov, c) = potentials[0];
70    (score >> MINMAX_SHIFT, b_mov, c)
71}
72
73fn monte_carlo_search(game: &mut Game, ai_conf: AIConfig) -> isize {
74    let mut score = 0;
75    (0..ai_conf.carlo_iter).for_each(|_| {
76        let mut moves = 0;
77        let mut res = BoardState::Ongoing;
78        let mut finished = false;
79        while !finished {
80            match res {
81                BoardState::Ongoing => {
82                    let m = game.get_board().get_valid_moves();
83                    let ove = random::<usize>() % m.len();
84                    let mov = m[ove];
85                    let chip = random::<usize>() % game.current_player().chip_options.len();
86                    let chip = game.current_player().chip_options[chip];
87                    res = game.play(mov, chip);
88                    moves += 1;
89                }
90                BoardState::Invalid => {
91                    moves -= 1;
92                    res = BoardState::Ongoing;
93                }
94                BoardState::Draw => {
95                    finished = true;
96                }
97                BoardState::Win(x) => {
98                    if x == 1 {
99                        score += 1
100                    } else {
101                        score -= 1
102                    }
103                    finished = true;
104                }
105            }
106        }
107        for _ in 0..moves {
108            game.undo_move()
109        }
110    });
111
112    score
113}
114
115static mut COUNT: isize = 0;
116// specifically a 2 player AI
117// returns < 0 if player 2 wins
118// returns > 0 if player 1 wins
119fn minmax_search(game: &mut Game, depth: isize) -> isize {
120    unsafe {
121        COUNT += 1;
122    }
123    if depth == 0 {
124        return 0;
125    }
126
127    let is_max = game.get_turn() % 2 == 0;
128    if game.get_player(1).just_won(&game) {
129        return -(depth as isize);
130    }
131    if game.get_player(0).just_won(&game) {
132        return depth as isize;
133    }
134
135    let minmax: fn(isize, isize) -> isize = if is_max { std::cmp::max } else { std::cmp::min };
136
137    let mut score = if is_max {
138        std::isize::MIN
139    } else {
140        std::isize::MAX
141    };
142
143    let moves = game.get_board().get_valid_moves();
144    let player = game.current_player().clone();
145    for mov in moves {
146        for chip in &player.chip_options {
147            game.play_no_check(mov, *chip);
148            score = minmax(score, minmax_search(game, depth - 1));
149            game.undo_move();
150        }
151    }
152
153    score
154}
155
#[cfg(test)]
mod tests {
    // Bring everything from the parent module into scope for the tests.
    use super::*;

    use std::time::Instant;

    /// Creates the AI connect-four game and plays `moves` in order; every move uses the
    /// first chip option of whoever is to move.
    fn make_game(moves: Vec<isize>) -> Game {
        let mut game = crate::games::connect4_ai();
        moves.into_iter().for_each(|mov| {
            let chip = game.current_player().chip_options[0];
            game.play(mov, chip);
        });
        game
    }

    /// After 1,2,1,2,1,2 the side to move is expected to pick move 1 with a full-depth win.
    #[test]
    fn test_win_1() {
        let ai = MID_AI;
        let mut game = make_game(vec![1, 2, 1, 2, 1, 2]);
        crate::io::draw_term_board(game.get_board());

        let (eval, mov, _) = evaluate_board(&mut game, ai);
        println!("Best move = {} which is {}", mov, eval);

        assert_eq!(eval, ai.minmax_depth);
        assert_eq!(mov, 1);
    }

    /// Same position with one extra move at 0: the other side is expected to pick move 2.
    #[test]
    fn test_win_1_p2() {
        let ai = MID_AI;
        let mut game = make_game(vec![1, 2, 1, 2, 1, 2, 0]);

        let (eval, mov, _) = evaluate_board(&mut game, ai);

        assert_eq!(eval, -ai.minmax_depth);
        assert_eq!(mov, 2);
    }

    /// Performance probe; ignored by default and always fails so that its output is shown.
    #[test]
    #[ignore]
    fn test_timing() {
        let mut game = make_game(Vec::new());

        unsafe {
            COUNT = 0;
        }
        let mut ai = HARD_AI;
        ai.carlo_iter += 1;

        let started = Instant::now();
        get_best_move(&mut game, ai);
        let time = started.elapsed().as_micros();

        println!("This test is supposed to fail. It is for keeping track of performance");
        // Copy the counter out by value so no reference to the `static mut` is created.
        let searched = unsafe { COUNT };
        println!(
            "Took {}µs for depth of {}. Searched {} iterations",
            time, HARD_AI.minmax_depth, searched
        );
        assert!(false);
    }
}