use std::cmp::{max, min, Ordering};
use crate::{
precalculated::precalculated_score, transposition_table::TranspositionTable, Column, ConnectFour
};
/// Connect Four solver whose transposition table persists across calls, so repeated
/// searches through the same instance can reuse previously cached bounds.
pub struct Solver {
    /// Score bounds cached by `alpha_beta`, keyed by `ConnectFour::encode`.
    transposition_table: TranspositionTable,
}
impl Default for Solver {
fn default() -> Self {
Self::new()
}
}
impl Solver {
    /// Creates a solver with an empty transposition table.
    ///
    /// The table is created with 16_777_213 (2^24 - 3) slots; presumably a prime size chosen to
    /// spread keys — confirm against `TranspositionTable`. The table is kept for the lifetime of
    /// the solver, so reusing one instance lets later searches benefit from cached bounds.
    pub fn new() -> Solver {
        let transposition_table = TranspositionTable::new(16_777_213);
        Solver {
            transposition_table,
        }
    }

    /// Scores `game` from the point of view of the player to move.
    ///
    /// Positive scores mean the player to move wins, negative scores that it loses, and `0` a
    /// draw; the magnitude is larger the earlier the game is decided. Positions for which
    /// `precalculated_score` returns a value are answered without searching.
    pub fn score(&mut self, game: &ConnectFour) -> i8 {
        precalculated_score(game).unwrap_or_else(|| self.score_without_precalculated(game))
    }

    /// Solves `game` by a binary search over the score range, probing with null-window
    /// (`[probe, probe + 1]`) calls to `alpha_beta`.
    fn score_without_precalculated(&mut self, game: &ConnectFour) -> i8 {
        // The opponent has just connected four: the player to move has lost.
        if game.is_victory() {
            return score_from_num_stones(game.stones() as i8);
        }
        // The player to move wins with its very next stone.
        if game.can_win_in_next_move() {
            return -score_from_num_stones(game.stones() as i8 + 1);
        }

        // Range the score can lie in, given the number of empty cells left.
        // Named `lower`/`upper` so they do not shadow `std::cmp::{min, max}`.
        let mut lower = -(42 - game.stones() as i8) / 2;
        let mut upper = (42 + 1 - game.stones() as i8) / 2;
        while lower < upper {
            let median = lower + (upper - lower) / 2;
            // When half of the current bound lies beyond the midpoint (on the same side of
            // zero), probe there instead of at the midpoint.
            let probe = if median <= 0 && lower / 2 < median {
                lower / 2
            } else if median >= 0 && upper / 2 > median {
                upper / 2
            } else {
                median
            };
            // Null-window search: tells whether the true score is <= probe or > probe.
            let result = alpha_beta(game, probe, probe + 1, &mut self.transposition_table);
            if result <= probe {
                upper = result;
            } else {
                lower = result;
            }
        }
        debug_assert_eq!(lower, upper);
        lower
    }

    /// Replaces the contents of `best_moves` with every legal column that is optimal for the
    /// player to move.
    ///
    /// The vector is always cleared first, so it is left empty when the game is already over
    /// (previously stale entries from an earlier call survived in that case). A move is optimal
    /// when it minimises the score of the resulting position, which is evaluated from the
    /// opponent's point of view.
    pub fn best_moves(&mut self, game: &ConnectFour, best_moves: &mut Vec<Column>) {
        best_moves.clear();
        if game.is_over() {
            return;
        }
        let mut best_score = i8::MAX;
        for column in game.legal_moves() {
            let mut board = *game;
            board.play(column);
            // Score of the position after our move, seen by the opponent: lower is better for us.
            let score = self.score(&board);
            match score.cmp(&best_score) {
                Ordering::Less => {
                    best_score = score;
                    best_moves.clear();
                    best_moves.push(column);
                }
                Ordering::Equal => best_moves.push(column),
                Ordering::Greater => {}
            }
        }
    }
}
/// Scores `game` for the player to move using a brand-new [`Solver`].
///
/// A fresh solver (and transposition table) is built on every call, so nothing is shared
/// between calls; keep a [`Solver`] around to score many positions.
pub fn score(game: &ConnectFour) -> i8 {
    let mut solver = Solver::new();
    solver.score(game)
}
/// Negamax search with alpha-beta pruning, scoring `game` for the player to move.
///
/// Requires `alpha < beta` and that the player to move cannot win with its next stone (both
/// checked with `debug_assert!`). Callers such as the binary search in
/// `Solver::score_without_precalculated` treat a result `<= alpha` as an upper bound and a
/// result `> alpha` as a lower bound on the true score.
///
/// `cached_beta` stores, per `game.encode()` key, the `alpha` reached after every child was
/// explored without a cutoff; when present it is read back as an upper bound on the score.
fn alpha_beta(
    game: &ConnectFour,
    mut alpha: i8,
    mut beta: i8,
    cached_beta: &mut TranspositionTable,
) -> i8 {
    debug_assert!(alpha < beta);
    debug_assert!(!game.can_win_in_next_move());
    // Presumably the moves that do not hand the opponent an immediate win — confirm against
    // `ConnectFour::non_loosing_moves_impl`.
    let possibilities = game.non_loosing_moves_impl();
    // No such move: the opponent wins with the stone right after ours (stone `stones + 2`).
    if possibilities.is_empty() {
        return score_from_num_stones(game.stones() as i8 + 2);
    }
    // At most two empty cells left, no immediate win for us and a non-losing move exists:
    // the position is scored as a draw.
    if game.stones() >= 42 - 2 {
        return 0;
    }
    // Lower bound: we only play non-losing moves, so the opponent can win at the earliest with
    // its second stone from now (stone `stones + 4`).
    alpha = max(alpha, score_from_num_stones(game.stones() as i8 + 4));
    if alpha >= beta {
        return alpha;
    }
    // Upper bound: we cannot win with our next stone, so at best with stone `stones + 3`.
    // If this position was searched before, the cached bound is used instead.
    let upper_bound_beta = cached_beta
        .get(game.encode())
        .unwrap_or_else(|| -score_from_num_stones(game.stones() as i8 + 3));
    beta = min(beta, upper_bound_beta);
    if alpha >= beta {
        return beta;
    }
    // Collect the non-losing children and order them (see `MoveExplorer::sort`) so that the
    // most promising ones are searched first.
    let mut move_explorer = MoveExplorer::new();
    for col in 0..7 {
        if possibilities.contains(col) {
            move_explorer.add(col, game);
        }
    }
    move_explorer.sort();
    for position in move_explorer.next_positions() {
        // Negamax: the child's score is from the opponent's point of view, so negate it and
        // swap/negate the window.
        let score = -alpha_beta(&position, -beta, -alpha, cached_beta);
        // Fail high: this move already reaches beta, the remaining moves cannot matter.
        if score >= beta {
            return score;
        }
        alpha = max(alpha, score);
    }
    // Every child scored below beta: alpha is an upper bound for this position.
    cached_beta.put(game.encode(), alpha);
    alpha
}
/// Score of the player whose opponent connects four with the `num_stones`-th stone.
///
/// Computes `-(1 + (42 - num_stones) / 2)`: always negative for a game that ends on the
/// 42-cell board, and closer to zero the later the loss happens. Negating the result gives
/// the winner's score.
fn score_from_num_stones(num_stones: i8) -> i8 {
    let pairs_left = (42 - num_stones) / 2;
    -1 - pairs_left
}
struct MoveExplorer {
col_indices: [(u8, u32, ConnectFour); 7],
len: usize,
}
impl MoveExplorer {
pub fn new() -> Self {
Self {
col_indices: [(0, 0, ConnectFour::new()); 7],
len: 0,
}
}
pub fn add(&mut self, col_index: u8, from: &ConnectFour) {
let mut next_position = *from;
let is_legal = next_position.play(Column::from_index(col_index));
debug_assert!(is_legal);
let score = next_position.heuristic();
self.col_indices[self.len] = (col_index, score, next_position);
self.len += 1;
}
pub fn sort(&mut self) {
const COLUMN_PRIORITY: [u8; 7] = [6, 4, 2, 0, 1, 3, 5];
self.col_indices[..self.len].sort_unstable_by(|a, b| {
b.1.cmp(&a.1)
.then_with(|| COLUMN_PRIORITY[a.0 as usize].cmp(&COLUMN_PRIORITY[b.0 as usize]))
});
}
pub fn next_positions(&self) -> impl Iterator<Item = ConnectFour> + '_ {
self.col_indices[..self.len].iter().map(|(_, _, pos)| *pos)
}
}