connect_four_solver/
solver.rs

1use std::cmp::{max, min, Ordering};
2
3use crate::{
4    precalculated::precalculated_score, transposition_table::TranspositionTable, Column, ConnectFour
5};
6
7/// Reusing the same solver instead of repeatedly running score in order to calculate similar
8/// positions, may have performance benefits, because we can reuse the transposition table.
9pub struct Solver {
10    transposition_table: TranspositionTable,
11}
12
13impl Default for Solver {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl Solver {
20    pub fn new() -> Solver {
21        // 64Bit per entry. Let's hardcode it to use a prime close to 16777213 which multiplied by 8
22        // Byte should be close to 128MiB.
23        let transposition_table = TranspositionTable::new(16777213);
24        Solver {
25            transposition_table,
26        }
27    }
28
29    /// Calculates the score of a connect four game. The score is set up so always picking the move with
30    /// the lowest score results in perfect play. Perfect meaning winning as fast as possible, drawing
31    /// or loosing as late as possible.
32    ///
33    /// A positive score means the player who can put in the next stone can win. Positions which can be
34    /// won faster are scored higher. The score is 1 if the current player can win with his last stone.
35    /// Two if he can win with his second to last stone and so on. A score of zero means the game will
36    /// end in a draw if both players play perfectly. A negative score means the opponent (the player
37    /// which is not putting in the next stone) is winnig. It is `-1` if the opponent is winning with
38    /// his last stone. `-2` if he is winning second to last stone and so on.
39    pub fn score(&mut self, game: &ConnectFour) -> i8 {
40        precalculated_score(game)
41            .unwrap_or_else(|| self.score_without_precalculated(game))
42    }
43
44    fn score_without_precalculated(&mut self, game: &ConnectFour) -> i8 {
45        if game.is_victory() {
46            return score_from_num_stones(game.stones() as i8);
47        }
48
49        // Check if we can win in the next move because `alpha_beta` assumes that the next move can not
50        // win the game.
51        if game.can_win_in_next_move() {
52            return -score_from_num_stones(game.stones() as i8 + 1);
53        }
54
55        // 64Bit per entry. Let's hardcode it to use a prime close to 16777213 which multiplied by 8
56        // Byte should be close to 128MiB.
57        let mut min = -(42 - game.stones() as i8) / 2;
58        let mut max = (42 + 1 - game.stones() as i8) / 2;
59
60        // Iterative deepening
61        while min < max {
62            let median = min + (max - min) / 2;
63            let alpha = if median <= 0 && min / 2 < median {
64                // Explore loosing path deeper
65                min / 2
66            } else if median >= 0 && max / 2 > median {
67                // Explore winning path deeper
68                max / 2
69            } else {
70                median
71            };
72            let result = alpha_beta(game, alpha, alpha + 1, &mut self.transposition_table);
73            if result <= alpha {
74                max = result;
75            } else {
76                min = result;
77            }
78        }
79        debug_assert_eq!(min, max);
80        min
81    }
82
83    /// Fills `best_moves` with all the legal moves, which have the best strong score.
84    pub fn best_moves(&mut self, game: &ConnectFour, best_moves: &mut Vec<Column>) {
85        if game.is_over() {
86            return;
87        }
88        let mut min = i8::MAX;
89        for column in game.legal_moves() {
90            let mut board = *game;
91            board.play(column);
92            let score = self.score(&board);
93            match score.cmp(&min) {
94                Ordering::Less => {
95                    min = score; 
96                    best_moves.clear();
97                    best_moves.push(column);
98                },
99                Ordering::Equal => {
100                    best_moves.push(column);
101                },
102                Ordering::Greater => (),
103            };
104        }
105    }
106}
107
108/// Calculates the score of a connect four game. The score is set up so always picking the move with
109/// the lowest score results in perfect play. Perfect meaning winning as fast as possible, drawing
110/// or loosing as late as possible.
111///
112/// A positive score means the player who can put in the next stone can win. Positions which can be
113/// won faster are scored higher. The score is 1 if the current player can win with his last stone.
114/// Two if he can win with his second to last stone and so on. A score of zero means the game will
115/// end in a draw if both players play perfectly. A negative score means the opponent (the player
116/// which is not putting in the next stone) is winnig. It is `-1` if the opponent is winning with
117/// his last stone. `-2` if he is winning second to last stone and so on.
118pub fn score(game: &ConnectFour) -> i8 {
119    Solver::new().score(game)
120}
121
122/// Score of the position with alepha beta pruning.
123///
124/// Assumes that position can not be won in a single move. Assumes that position is not won position
125/// already.
126///
127/// * If actual score is smaller than alpha then: actual score <= return value <= alpha
128/// * If actual score is bigger than beta then: actual score >= return value >= beta
129/// * If score is within alpha beta window precise score is returned
130///
131/// If alpha is higher (or equal) than the score of this position, we can prune this position,
132/// because the current player would not play this route, since he is guaranteed to achieve a better
133/// outcome with some other play.
134///
135/// Similarly if this positions score is higher than beta we can prune it, since the opponent would
136/// choose a different line of play, which leavs him in a better position.
137///
138/// Alpha is a lower bound on what the current player can expect. Beta is as upper bound on what he
139/// can expect.
140fn alpha_beta(
141    game: &ConnectFour,
142    mut alpha: i8,
143    mut beta: i8,
144    cached_beta: &mut TranspositionTable,
145) -> i8 {
146    debug_assert!(alpha < beta);
147    debug_assert!(!game.can_win_in_next_move());
148
149    let possibilities = game.non_loosing_moves_impl();
150    if possibilities.is_empty() {
151        // If there are no possibilities for the current player not to loose, the opponent wins.
152        return score_from_num_stones(game.stones() as i8 + 2);
153    }
154
155    // Check for draw
156    if game.stones() >= 42 - 2 {
157        return 0;
158    }
159
160    // Opponent can not win within one move, this gives us a lower bound for the score
161    alpha = max(alpha, score_from_num_stones(game.stones() as i8 + 4));
162    if alpha >= beta {
163        return alpha;
164    }
165
166    // We may also find an upper bound in the cache. If not we use the fact that we know we can not
167    // win with our next stone, which puts the fastest possible win at least three stones away.
168    let upper_bound_beta = cached_beta
169        .get(game.encode())
170        .unwrap_or_else(|| -score_from_num_stones(game.stones() as i8 + 3));
171    beta = min(beta, upper_bound_beta);
172    if alpha >= beta {
173        return beta;
174    }
175
176    let mut move_explorer = MoveExplorer::new();
177    for col in 0..7 {
178        if possibilities.contains(col) {
179            move_explorer.add(col, game);
180        }
181    }
182    move_explorer.sort();
183
184    // We play the position which is the worst for our opponent
185    for position in move_explorer.next_positions() {
186        // Score from the perspective of the current player is the negative of the opponents.
187        let score = -alpha_beta(&position, -beta, -alpha, cached_beta);
188        // prune the exploration if we find a possible move better than what we were looking for.
189        if score >= beta {
190            return score;
191        }
192        // We only need to search for positions, which are better than the best so far.
193        alpha = max(alpha, score);
194    }
195
196    // save the upper bound of the position
197    cached_beta.put(game.encode(), alpha);
198    alpha
199}
200
/// Score of a finished game from the perspective of the player who would put in the next stone
/// (and can no longer do so, because the game is over), assuming stone number `num_stones`
/// completed four in a row.
///
/// The sooner the win, the more stones the winner still had left, and the more negative the
/// score.
fn score_from_num_stones(num_stones: i8) -> i8 {
    // Stones the winner never got to place: the unplayed cells are split evenly between both
    // players.
    let unplayed_by_winner = (42 - num_stones) / 2;
    // The position is lost for the player to move, hence the negative sign.
    -1 - unplayed_by_winner
}
210
211/// Stack allocated container for possible moves. Iterates over moves in a fashion which allows to
212/// prune the search tree sooner.
213struct MoveExplorer {
214    /// Up to seven indices are possible. Store index, score and position.
215    col_indices: [(u8, u32, ConnectFour); 7],
216    /// Up to this index the moves are valid.
217    len: usize,
218}
219
220impl MoveExplorer {
221    pub fn new() -> Self {
222        Self {
223            col_indices: [(0, 0, ConnectFour::new()); 7],
224            len: 0,
225        }
226    }
227
228    pub fn add(&mut self, col_index: u8, from: &ConnectFour) {
229        let mut next_position = *from;
230        let is_legal = next_position.play(Column::from_index(col_index));
231        debug_assert!(is_legal);
232        let score = next_position.heuristic();
233        self.col_indices[self.len] = (col_index, score, next_position);
234        self.len += 1;
235    }
236
237    pub fn sort(&mut self) {
238        /// Indices which should get explored first get smaller values. Explore center moves first.
239        /// These are better on average. This allows for faster pruning.
240        const COLUMN_PRIORITY: [u8; 7] = [6, 4, 2, 0, 1, 3, 5];
241        self.col_indices[..self.len].sort_unstable_by(|a, b| {
242            // sort by score first, then by column priority. We prefer higher scores, therfore a, b
243            // are switched in order.
244            b.1.cmp(&a.1)
245                .then_with(|| COLUMN_PRIORITY[a.0 as usize].cmp(&COLUMN_PRIORITY[b.0 as usize]))
246        });
247    }
248
249    pub fn next_positions(&self) -> impl Iterator<Item = ConnectFour> + '_ {
250        self.col_indices[..self.len].iter().map(|(_, _, pos)| *pos)
251    }
252}