use crate::board::{Board, GameResult, Move, Stone, BOARD_SIZE, NUM_CELLS};
use crate::eval::evaluate;
use crate::heuristic::{scan_line, DIR};
use crate::vct::{classify_move, search_vct, ThreatKind, VctConfig};
use noru::network::NnueWeights;
use std::time::{Duration, Instant};
// ---- Score scale -------------------------------------------------------------------------

/// Stand-in for infinity in alpha-beta windows; assumed to exceed any evaluation score.
const INF: i32 = 1_000_000;
/// Base score of a won position. Terminal wins are reported as `WIN_SCORE - ply`, and any
/// score whose magnitude exceeds `WIN_SCORE - 100` is treated as a decided game.
const WIN_SCORE: i32 = 999_000;

// ---- Root VCT probe (all budgets in milliseconds) ----------------------------------------

/// Maximum depth passed to the root VCT search.
const ROOT_VCT_DEPTH: u32 = 14;
/// VCT time budget used when `search` is called without a time limit.
const ROOT_VCT_BUDGET_MS: u64 = 150;
/// With a time limit, the VCT budget starts as `limit / ROOT_VCT_BUDGET_FRACTION`.
const ROOT_VCT_BUDGET_FRACTION: u32 = 8;
/// Lower bound applied to the fractional VCT budget.
const ROOT_VCT_BUDGET_FLOOR_MS: u64 = 100;
/// Upper bound applied to the fractional VCT budget.
const ROOT_VCT_BUDGET_CAP_MS: u64 = 2_000;
pub struct SearchResult {
pub best_move: Option<Move>,
pub score: i32,
pub depth: u32,
pub nodes: u64,
}
pub struct Searcher {
pub nodes: u64,
killers: [[Option<Move>; 2]; 64],
history: [[i32; NUM_CELLS]; 2],
deadline: Option<Instant>,
aborted: bool,
}
impl Searcher {
    /// Creates a searcher with cleared killer/history tables and no deadline.
    pub fn new() -> Self {
        Self {
            nodes: 0,
            killers: [[None; 2]; 64],
            history: [[0; NUM_CELLS]; 2],
            deadline: None,
            aborted: false,
        }
    }

    /// Chooses a move for the side to move on `board`.
    ///
    /// 1. Runs a time-boxed root VCT probe ([`search_vct`]). If it returns a non-empty
    ///    forcing sequence, the first move is returned immediately with score `WIN_SCORE`
    ///    and `depth` set to the sequence length.
    /// 2. Otherwise, runs iterative-deepening negamax alpha-beta for depths `1..=max_depth`.
    ///    It stops early when `time_limit` expires (measured from the start of this call,
    ///    so the VCT probe counts against it) or when a completed iteration's score
    ///    magnitude exceeds `WIN_SCORE - 100`.
    ///
    /// Move, score and depth come from the deepest *completed* iteration; an iteration
    /// interrupted by the deadline is discarded. If no iteration completes (deadline hit
    /// during depth 1, or `max_depth == 0`), the best move found so far is returned. Failing
    /// that, the top-ordered candidate is returned. In both cases `score` and `depth` are 0,
    /// so `best_move` is `None` only when there are no candidate moves. `nodes` is the total
    /// number of alpha-beta nodes visited.
    ///
    /// Every `make_move` performed here is paired with an `undo_move`.
    pub fn search(
        &mut self,
        board: &mut Board,
        weights: &NnueWeights,
        max_depth: u32,
        time_limit: Option<Duration>,
    ) -> SearchResult {
        self.nodes = 0;
        self.aborted = false;
        self.killers = [[None; 2]; 64];
        self.history = [[0; NUM_CELLS]; 2];
        // Set before the VCT probe so its time counts against the caller's limit. A limit too
        // large to represent as an `Instant` is treated as "no limit" instead of panicking.
        self.deadline = time_limit.and_then(|d| Instant::now().checked_add(d));

        let vct_cfg = VctConfig {
            max_depth: ROOT_VCT_DEPTH,
            time_budget: Some(Self::root_vct_budget(time_limit)),
        };
        if let Some(seq) = search_vct(board, &vct_cfg) {
            if let Some(&first) = seq.first() {
                return SearchResult {
                    best_move: Some(first),
                    score: WIN_SCORE,
                    depth: seq.len() as u32,
                    nodes: self.nodes,
                };
            }
        }

        let mut best_result = SearchResult {
            best_move: None,
            score: 0,
            depth: 0,
            nodes: 0,
        };
        // Best move of an iteration cut short by the deadline; only used when no iteration
        // completed.
        let mut partial_best: Option<Move> = None;

        for depth in 1..=max_depth {
            // Re-ordered every iteration so history scores from shallower iterations apply.
            let moves = self.order_moves(board, 0);
            let mut best_move = None;
            let mut alpha = -INF;
            for &mv in &moves {
                board.make_move(mv);
                let score = -self.alpha_beta(board, weights, depth - 1, 1, -INF, -alpha);
                board.undo_move();
                if self.aborted {
                    break;
                }
                if score > alpha {
                    alpha = score;
                    best_move = Some(mv);
                }
            }
            if self.aborted {
                partial_best = best_move;
                break;
            }
            best_result = SearchResult {
                best_move,
                score: alpha,
                depth,
                nodes: self.nodes,
            };
            // Stop once the score indicates a forced win or loss.
            if alpha.abs() > WIN_SCORE - 100 {
                break;
            }
        }

        // Never hand back "no move" while legal candidates exist, e.g. when a very short
        // time limit expires before depth 1 finishes.
        if best_result.best_move.is_none() {
            best_result.best_move =
                partial_best.or_else(|| self.order_moves(board, 0).first().copied());
        }
        best_result.nodes = self.nodes;
        best_result
    }

    /// Time budget for the root VCT probe.
    ///
    /// Without a limit this is `ROOT_VCT_BUDGET_MS`. With a limit `d`, it is
    /// `d / ROOT_VCT_BUDGET_FRACTION`, raised to a floor and capped at
    /// `ROOT_VCT_BUDGET_CAP_MS`. The floor is `ROOT_VCT_BUDGET_FLOOR_MS`, but never more than
    /// `d / 2`. That way a short limit is not entirely (or more than entirely) consumed
    /// before alpha-beta starts.
    fn root_vct_budget(time_limit: Option<Duration>) -> Duration {
        match time_limit {
            Some(d) => {
                let floor = Duration::from_millis(ROOT_VCT_BUDGET_FLOOR_MS).min(d / 2);
                (d / ROOT_VCT_BUDGET_FRACTION)
                    .max(floor)
                    .min(Duration::from_millis(ROOT_VCT_BUDGET_CAP_MS))
            }
            None => Duration::from_millis(ROOT_VCT_BUDGET_MS),
        }
    }

    /// Fail-soft negamax alpha-beta search of `board` to `depth` plies.
    ///
    /// Scores are from the side to move's point of view. `ply` is the distance from the
    /// root. It shortens terminal win scores (so faster wins and slower losses are
    /// preferred) and indexes the killer table. When the deadline expires, this sets
    /// `self.aborted` and returns 0; callers must check `self.aborted` before using the value.
    fn alpha_beta(
        &mut self,
        board: &mut Board,
        weights: &NnueWeights,
        depth: u32,
        ply: usize,
        mut alpha: i32,
        beta: i32,
    ) -> i32 {
        self.nodes += 1;
        // Read the clock only every 128 nodes to keep it off the hot path.
        if self.nodes & 127 == 0 {
            if let Some(deadline) = self.deadline {
                if Instant::now() >= deadline {
                    self.aborted = true;
                    return 0;
                }
            }
        }

        match board.game_result() {
            // Scored as a loss for the side to move, i.e. this assumes the winning line was
            // completed by the move just played.
            GameResult::BlackWin | GameResult::WhiteWin => {
                return -(WIN_SCORE - ply as i32);
            }
            GameResult::Draw => return 0,
            GameResult::Ongoing => {}
        }

        if depth == 0 {
            return evaluate(board, weights);
        }

        let moves = self.order_moves(board, ply);
        if moves.is_empty() {
            return 0;
        }

        let mut best_score = -INF;
        let side = board.side_to_move as usize;
        for mv in moves {
            board.make_move(mv);
            let score = -self.alpha_beta(board, weights, depth - 1, ply + 1, -beta, -alpha);
            board.undo_move();
            if self.aborted {
                return 0;
            }
            if score > best_score {
                best_score = score;
            }
            if score > alpha {
                alpha = score;
                // Deeper (more expensive) improvements weigh more in future ordering.
                self.history[side][mv] += (depth * depth) as i32;
            }
            if alpha >= beta {
                // Beta cutoff: remember the refutation as a killer for this ply.
                if ply < 64 {
                    self.killers[ply][1] = self.killers[ply][0];
                    self.killers[ply][0] = Some(mv);
                }
                break;
            }
        }
        best_score
    }

    /// Candidate moves for the side to move, sorted by descending [`Self::move_score`].
    fn order_moves(&self, board: &Board, ply: usize) -> Vec<Move> {
        let mut moves = board.candidate_moves();
        let side = board.side_to_move as usize;
        let (my, opp) = match board.side_to_move {
            Stone::Black => (&board.black, &board.white),
            Stone::White => (&board.white, &board.black),
        };
        // `move_score` is expensive (two threat classifications plus eight line scans), so
        // compute it once per move rather than twice per comparison.
        moves.sort_by_cached_key(|&mv| {
            std::cmp::Reverse(self.move_score(mv, ply, side, my, opp))
        });
        moves
    }

    /// Heuristic ordering score for playing `mv`; higher values are searched first.
    ///
    /// `my` / `opp` are the stone sets of the side to move and of its opponent. The score
    /// sums the history table, threat bonuses for both colours, per-direction line-shape
    /// bonuses, killer bonuses for `ply`, and a small centre bias.
    fn move_score(
        &self,
        mv: Move,
        ply: usize,
        side: usize,
        my: &crate::board::BitBoard,
        opp: &crate::board::BitBoard,
    ) -> i32 {
        let row = (mv / BOARD_SIZE) as i32;
        let col = (mv % BOARD_SIZE) as i32;

        let mut score = self.history[side][mv];

        // Threat this move creates for us, and the threat the opponent could make on this
        // same cell (playing here pre-empts it, weighted slightly lower).
        score += threat_priority(classify_move(my, opp, mv), false);
        score += threat_priority(classify_move(opp, my, mv), true);

        // Line shapes through this cell, per direction, for both colours.
        for &(dr, dc) in &DIR {
            let mine = scan_line(my, opp, row, col, dr, dc);
            if mine.count >= 4 {
                score += 500_000;
            } else if mine.count >= 3 {
                let open = mine.open_front as u32 + mine.open_back as u32;
                if open >= 2 {
                    score += 50_000;
                } else if open >= 1 {
                    score += 5_000;
                }
            } else if mine.count >= 2 {
                let open = mine.open_front as u32 + mine.open_back as u32;
                if open >= 2 {
                    score += 500;
                }
            }

            let theirs = scan_line(opp, my, row, col, dr, dc);
            if theirs.count >= 4 {
                score += 400_000;
            } else if theirs.count >= 3 {
                let open = theirs.open_front as u32 + theirs.open_back as u32;
                if open >= 2 {
                    score += 40_000;
                } else if open >= 1 {
                    score += 4_000;
                }
            }
        }

        if ply < 64 {
            if self.killers[ply][0] == Some(mv) {
                score += 10_000;
            } else if self.killers[ply][1] == Some(mv) {
                score += 5_000;
            }
        }

        // Centre bias by Manhattan distance; the corner farthest from the centre gets 0.
        // Derived from BOARD_SIZE (identical to the former 7 / 14 constants on 15x15).
        let center = (BOARD_SIZE / 2) as i32;
        let center_dist = (row - center).abs() + (col - center).abs();
        score += (2 * center - center_dist) * 2;
        score
    }
}
/// Move-ordering bonus for a cell whose threat classification is `kind`.
///
/// Callers pass `defending = true` when `kind` describes the opponent's threat on the cell
/// rather than our own. That bonus is scaled to 90%, so an equally strong threat of our own
/// sorts ahead of it.
fn threat_priority(kind: ThreatKind, defending: bool) -> i32 {
    let attack_bonus = match kind {
        ThreatKind::None => 0,
        ThreatKind::OpenThree => 30_000,
        ThreatKind::ClosedFour => 100_000,
        ThreatKind::DoubleThree => 200_000,
        ThreatKind::DoubleFour | ThreatKind::FourThree => 300_000,
        ThreatKind::OpenFour => 500_000,
        ThreatKind::Five => 1_000_000,
    };
    // Tenths: full weight for attack, nine tenths for defence.
    let tenths = if defending { 9 } else { 10 };
    attack_bonus * tenths / 10
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::board::{to_idx, Board};
    use crate::features::GOMOKU_NNUE_CONFIG;

    /// Zero-initialised NNUE weights shared by the tests below.
    fn zero_weights() -> NnueWeights {
        NnueWeights::zeros(GOMOKU_NNUE_CONFIG)
    }

    #[test]
    fn test_search_finds_winning_move() {
        let mut board = Board::new();
        // The first player builds four in a row on row 7 (cols 3..=6) while the second
        // mirrors it on row 8. Both ends of row 7 stay empty, so (7, 2) or (7, 7) wins.
        for col in 3..=6 {
            board.make_move(to_idx(7, col));
            board.make_move(to_idx(8, col));
        }

        let mut searcher = Searcher::new();
        let result = searcher.search(&mut board, &zero_weights(), 2, None);

        assert!(result.best_move.is_some());
        let winning_moves = [to_idx(7, 7), to_idx(7, 2)];
        assert!(
            winning_moves.contains(&result.best_move.unwrap()),
            "should find the winning move, got {:?}",
            result.best_move
        );
    }

    #[test]
    fn test_search_depth_1() {
        let result = Searcher::new().search(&mut Board::new(), &zero_weights(), 1, None);
        assert!(result.best_move.is_some());
    }
}