use crate::board::{Board, GameResult, Move, Stone, BOARD_SIZE, NUM_CELLS};
use crate::eval::IncrementalEval;
use crate::heuristic::{scan_line, DIR};
use crate::transposition::{Bound, TranspositionTable, TtStats};
use crate::vct::{classify_move, search_vct, ThreatKind, VctConfig};
use noru::network::NnueWeights;
use std::time::{Duration, Instant};
// Search-window bound: full-width searches use (-INF, INF), and a node's
// best score starts at -INF before any move has been searched.
const INF: i32 = 1_000_000;
// Magnitude of a decided game. Terminal nodes score `-(WIN_SCORE - ply)` (so a
// nearer win scores higher), a root VCT hit reports exactly `WIN_SCORE`, and
// iterative deepening stops once |score| > WIN_SCORE - 100.
const WIN_SCORE: i32 = 999_000;
// Root VCT time budget used when `search` is called without a time limit.
const ROOT_VCT_BUDGET_MS: u64 = 150;
// Maximum depth handed to the root VCT probe (`VctConfig::max_depth`).
const ROOT_VCT_DEPTH: u32 = 14;
// With a time limit, the root VCT budget starts from `time_limit / FRACTION`
// and is then adjusted with the FLOOR (lower clamp) and CAP (upper clamp) below.
const ROOT_VCT_BUDGET_FRACTION: u32 = 8;
const ROOT_VCT_BUDGET_CAP_MS: u64 = 2_000;
const ROOT_VCT_BUDGET_FLOOR_MS: u64 = 100;
// Size parameter passed to `TranspositionTable::new` (presumably log2 of the
// bucket count — confirm in the transposition module).
const TT_BUCKET_BITS: u32 = 18;
// Aspiration windows are used from this iteration depth onward. The window
// starts at previous score ± ASPIRATION_INITIAL_DELTA, and the widening step
// doubles on every fail-low / fail-high.
const ASPIRATION_MIN_DEPTH: u32 = 4;
const ASPIRATION_INITIAL_DELTA: i32 = 50;
// Quiescence search returns the static eval once this many quiescence plies
// have been played.
const QSEARCH_MAX_PLY: u32 = 4;
// Late-move reductions apply only at depth >= LMR_MIN_DEPTH and move index
// >= LMR_MIN_MOVE_IDX (never to forcing or killer moves).
const LMR_MIN_DEPTH: u32 = 3;
const LMR_MIN_MOVE_IDX: usize = 3;
// Internal iterative reduction: non-PV nodes without a TT move are searched
// one ply shallower once depth reaches this value.
const IIR_MIN_DEPTH: u32 = 4;
// Late-move pruning at non-PV nodes with depth in [LMP_MIN_DEPTH, LMP_MAX_DEPTH]:
// non-forcing, non-killer moves at index >= LMP_BASE + LMP_PER_DEPTH * depth
// are skipped.
const LMP_MIN_DEPTH: u32 = 1;
const LMP_MAX_DEPTH: u32 = 3;
const LMP_BASE: usize = 8;
const LMP_PER_DEPTH: usize = 4;
pub struct SearchResult {
pub best_move: Option<Move>,
pub score: i32,
pub depth: u32,
pub nodes: u64,
}
/// Iterative-deepening principal-variation searcher.
///
/// Owns the per-search move-ordering heuristics (killers, history), the
/// time-control state and a transposition table.
pub struct Searcher {
    /// Nodes visited (alpha-beta plus quiescence) by the current/last search.
    pub nodes: u64,
    /// Transposition-table cutoffs taken by the current/last search.
    pub tt_cutoffs: u64,
    /// Transposition table used across the iterations of a search.
    tt: TranspositionTable,
    /// Wall-clock instant after which the search aborts, if a limit was given.
    deadline: Option<Instant>,
    /// Set once the deadline has been observed to pass; unwinds the search.
    aborted: bool,
    /// Two killer-move slots for each of the first 64 plies.
    killers: [[Option<Move>; 2]; 64],
    /// History-heuristic bonuses indexed by side to move (as `usize`), then cell.
    history: [[i32; NUM_CELLS]; 2],
}
impl Searcher {
    /// Creates a searcher with empty killer/history tables and a transposition
    /// table built from `TT_BUCKET_BITS`.
    pub fn new() -> Self {
        Self {
            nodes: 0,
            tt_cutoffs: 0,
            killers: [[None; 2]; 64],
            history: [[0; NUM_CELLS]; 2],
            deadline: None,
            aborted: false,
            tt: TranspositionTable::new(TT_BUCKET_BITS),
        }
    }

    /// Transposition-table statistics; they are reset at the start of every
    /// [`Searcher::search`] call.
    pub fn tt_stats(&self) -> TtStats {
        self.tt.stats()
    }

    /// Occupancy figures as reported by `TranspositionTable::occupancy`.
    /// NOTE(review): the meaning of each tuple element is defined there — confirm
    /// before relying on the order.
    pub fn tt_occupancy(&self) -> (usize, usize, usize) {
        self.tt.occupancy()
    }

    /// Picks a move for the side to move on `board`.
    ///
    /// Per-search state (node/cutoff counters, killers, history, abort flag,
    /// TT contents and TT stats) is reset first. A root VCT probe runs next. If
    /// it returns a non-empty sequence, its first move is returned at once with
    /// `score == WIN_SCORE` and `depth` equal to the sequence length.
    ///
    /// Otherwise iterative deepening runs from depth 1 to `max_depth`, with
    /// aspiration windows from `ASPIRATION_MIN_DEPTH` on. Deepening stops when:
    /// - the deadline passes (`time_limit` is measured from the start of this
    ///   call, so the VCT probe counts against it); the last *completed*
    ///   iteration is then reported;
    /// - a score with magnitude above `WIN_SCORE - 100` is found.
    ///
    /// If the deadline passes before depth 1 completes, `best_move` falls back
    /// to the first move in root move ordering; `score` and `depth` stay 0.
    /// `best_move` can still be `None` when there are no candidate moves or
    /// `max_depth == 0`.
    ///
    /// `board` is mutated during the search, but every `make_move` is paired
    /// with an `undo_move`.
    pub fn search(
        &mut self,
        board: &mut Board,
        weights: &NnueWeights,
        max_depth: u32,
        time_limit: Option<Duration>,
    ) -> SearchResult {
        self.nodes = 0;
        self.tt_cutoffs = 0;
        self.aborted = false;
        self.killers = [[None; 2]; 64];
        self.history = [[0; NUM_CELLS]; 2];
        self.deadline = time_limit.map(|d| Instant::now() + d);
        self.tt.reset_stats();
        self.tt.clear();

        let vct_budget = match time_limit {
            Some(limit) => (limit / ROOT_VCT_BUDGET_FRACTION)
                .max(Duration::from_millis(ROOT_VCT_BUDGET_FLOOR_MS))
                .min(Duration::from_millis(ROOT_VCT_BUDGET_CAP_MS))
                // The floor alone can exceed a short limit (50 ms → 100 ms),
                // overrunning the caller's budget before alpha-beta even starts.
                // Never give the probe more than half of the overall limit.
                .min(limit / 2),
            None => Duration::from_millis(ROOT_VCT_BUDGET_MS),
        };
        let vct_cfg = VctConfig {
            max_depth: ROOT_VCT_DEPTH,
            time_budget: Some(vct_budget),
        };
        if let Some(seq) = search_vct(board, &vct_cfg) {
            if let Some(&first) = seq.first() {
                return SearchResult {
                    best_move: Some(first),
                    score: WIN_SCORE,
                    depth: seq.len() as u32,
                    nodes: self.nodes,
                };
            }
        }

        let mut best_result = SearchResult {
            best_move: None,
            score: 0,
            depth: 0,
            nodes: 0,
        };
        let mut inc = IncrementalEval::new(weights);
        inc.refresh(board, weights);
        let mut prev_best: Option<Move> = None;
        let mut prev_score: Option<i32> = None;

        for depth in 1..=max_depth {
            // Aspiration windows need a previous score and enough depth for the
            // previous score to be a reasonable estimate.
            let window_center = prev_score.filter(|_| depth >= ASPIRATION_MIN_DEPTH);
            let aspirate = window_center.is_some();
            let (mut alpha, mut beta) = match window_center {
                Some(s) => (s - ASPIRATION_INITIAL_DELTA, s + ASPIRATION_INITIAL_DELTA),
                None => (-INF, INF),
            };
            let mut delta = ASPIRATION_INITIAL_DELTA;

            let (best_move, score) = loop {
                let result = self.root_pvs_iteration(
                    board, weights, &mut inc, depth, alpha, beta, prev_best,
                );
                if self.aborted || !aspirate {
                    break result;
                }
                let score = result.1;
                if score <= alpha {
                    // Fail low: widen downwards with a doubling step.
                    delta = (delta * 2).min(INF / 4);
                    alpha = (alpha - delta).max(-INF);
                } else if score >= beta {
                    // Fail high: widen upwards with a doubling step.
                    delta = (delta * 2).min(INF / 4);
                    beta = (beta + delta).min(INF);
                } else {
                    break result;
                }
                if alpha == -INF || beta == INF {
                    // The window has degenerated: finish with one full-width search.
                    break self.root_pvs_iteration(
                        board, weights, &mut inc, depth, -INF, INF, prev_best,
                    );
                }
            };
            if self.aborted {
                // Keep the last completed iteration; partial results are discarded.
                break;
            }
            best_result = SearchResult {
                best_move,
                score,
                depth,
                nodes: self.nodes,
            };
            if score.abs() > WIN_SCORE - 100 {
                break;
            }
            prev_best = best_move;
            prev_score = Some(score);
        }

        if self.aborted && best_result.best_move.is_none() {
            // Timed out before any iteration completed: still return a legal,
            // sensibly ordered move instead of none at all.
            best_result.best_move = self.order_moves(board, 0).first().map(|&(mv, _)| mv);
        }
        best_result
    }

    /// One root iteration of principal-variation search at `depth` inside the
    /// window `(alpha_init, beta_init)`.
    ///
    /// `prev_best` (the previous iteration's best move) is searched first.
    /// Returns the move that last raised alpha (or `None` if none did) and the
    /// final alpha. The result is meaningless if `self.aborted` was set during
    /// the call.
    fn root_pvs_iteration(
        &mut self,
        board: &mut Board,
        weights: &NnueWeights,
        inc: &mut IncrementalEval,
        depth: u32,
        alpha_init: i32,
        beta_init: i32,
        prev_best: Option<Move>,
    ) -> (Option<Move>, i32) {
        let mut alpha = alpha_init;
        let beta = beta_init;
        let mut best_move: Option<Move> = None;
        let mut moves = self.order_moves(board, 0);
        if let Some(pv) = prev_best {
            Self::promote_to_front(&mut moves, pv);
        }
        for (move_idx, &(mv, is_forcing)) in moves.iter().enumerate() {
            let is_killer = self.is_killer_move(0, mv);
            board.make_move(mv);
            inc.push_move(board, mv, weights);
            let score = if move_idx == 0 {
                -self.alpha_beta(board, weights, inc, depth - 1, 1, -beta, -alpha)
            } else {
                // PVS: null-window probe (possibly LMR-reduced), then a
                // full-depth null-window check, then a full-window re-search only
                // if the move lands strictly inside (alpha, beta).
                let reduction = lmr_reduction(depth, move_idx, is_forcing, is_killer);
                let reduced_depth = (depth - 1).saturating_sub(reduction);
                let mut null = -self.alpha_beta(
                    board, weights, inc, reduced_depth, 1, -alpha - 1, -alpha,
                );
                if !self.aborted && reduction > 0 && null > alpha {
                    null = -self.alpha_beta(
                        board, weights, inc, depth - 1, 1, -alpha - 1, -alpha,
                    );
                }
                if !self.aborted && null > alpha && null < beta {
                    -self.alpha_beta(board, weights, inc, depth - 1, 1, -beta, -alpha)
                } else {
                    null
                }
            };
            inc.pop_move();
            board.undo_move();
            if self.aborted {
                break;
            }
            if score > alpha {
                alpha = score;
                best_move = Some(mv);
            }
        }
        (best_move, alpha)
    }

    /// Negamax alpha-beta (PVS) search `depth` plies below a node that is `ply`
    /// plies from the root.
    ///
    /// Returns a score from the side to move's perspective. Quiescence search
    /// takes over at `depth == 0`. When the deadline passes this returns 0
    /// without touching the TT; callers must check `self.aborted`.
    fn alpha_beta(
        &mut self,
        board: &mut Board,
        weights: &NnueWeights,
        inc: &mut IncrementalEval,
        depth: u32,
        ply: usize,
        mut alpha: i32,
        beta: i32,
    ) -> i32 {
        if self.visit_node() {
            return 0;
        }
        match board.game_result() {
            GameResult::BlackWin | GameResult::WhiteWin => {
                // Any decided game is scored as a loss for the side to move
                // (i.e. the last mover won); `- ply` prefers nearer wins.
                return -(WIN_SCORE - ply as i32);
            }
            GameResult::Draw => return 0,
            GameResult::Ongoing => {}
        }
        if depth == 0 {
            return self.qsearch(board, weights, inc, 0, ply, alpha, beta);
        }

        let original_alpha = alpha;
        let tt_hit = self.tt.probe(board.zobrist);
        let mut tt_move: Option<Move> = None;
        if let Some(entry) = tt_hit {
            // `u16::MAX` is the "no move stored" sentinel.
            tt_move = if entry.best_move == u16::MAX {
                None
            } else {
                Some(entry.best_move as Move)
            };
            if entry.depth as u32 >= depth {
                let cached = entry.score;
                match entry.bound {
                    Bound::Exact => {
                        self.tt_cutoffs += 1;
                        return cached;
                    }
                    Bound::Lower if cached >= beta => {
                        self.tt_cutoffs += 1;
                        return cached;
                    }
                    Bound::Upper if cached <= alpha => {
                        self.tt_cutoffs += 1;
                        return cached;
                    }
                    _ => {}
                }
            }
        }

        let is_pv = beta - alpha > 1;
        // Internal iterative reduction: without a TT move the ordering is less
        // trustworthy, so non-PV nodes are searched one ply shallower.
        let depth = if depth >= IIR_MIN_DEPTH && tt_move.is_none() && !is_pv {
            depth - 1
        } else {
            depth
        };
        let mut moves = self.order_moves(board, ply);
        if moves.is_empty() {
            return 0;
        }
        if let Some(tt_mv) = tt_move {
            Self::promote_to_front(&mut moves, tt_mv);
        }

        let mut best_score = -INF;
        let mut best_move_at_node: Option<Move> = None;
        let side = board.side_to_move as usize;
        for (move_idx, &(mv, is_forcing)) in moves.iter().enumerate() {
            let is_killer = self.is_killer_move(ply, mv);
            // Late-move pruning: skip late non-forcing, non-killer moves at
            // shallow non-PV nodes.
            let lmp_applies = !is_pv
                && !is_forcing
                && !is_killer
                && (LMP_MIN_DEPTH..=LMP_MAX_DEPTH).contains(&depth);
            if lmp_applies && move_idx >= LMP_BASE + LMP_PER_DEPTH * depth as usize {
                continue;
            }

            board.make_move(mv);
            inc.push_move(board, mv, weights);
            let score = if move_idx == 0 {
                -self.alpha_beta(board, weights, inc, depth - 1, ply + 1, -beta, -alpha)
            } else {
                let reduction = lmr_reduction(depth, move_idx, is_forcing, is_killer);
                let reduced_depth = (depth - 1).saturating_sub(reduction);
                let mut null_score = -self.alpha_beta(
                    board, weights, inc, reduced_depth, ply + 1, -alpha - 1, -alpha,
                );
                if !self.aborted && reduction > 0 && null_score > alpha {
                    null_score = -self.alpha_beta(
                        board, weights, inc, depth - 1, ply + 1, -alpha - 1, -alpha,
                    );
                }
                if !self.aborted && null_score > alpha && null_score < beta {
                    -self.alpha_beta(board, weights, inc, depth - 1, ply + 1, -beta, -alpha)
                } else {
                    null_score
                }
            };
            inc.pop_move();
            board.undo_move();
            if self.aborted {
                return 0;
            }

            if score > best_score {
                best_score = score;
                best_move_at_node = Some(mv);
            }
            if score > alpha {
                alpha = score;
                // Saturate so very long searches cannot overflow the i32 table.
                let bonus = (depth * depth) as i32;
                let slot = &mut self.history[side][mv];
                *slot = slot.saturating_add(bonus);
            }
            if alpha >= beta {
                self.record_killer(ply, mv);
                break;
            }
        }

        let bound = if best_score <= original_alpha {
            Bound::Upper
        } else if best_score >= beta {
            Bound::Lower
        } else {
            Bound::Exact
        };
        self.tt.store(
            board.zobrist,
            best_score,
            depth.min(255) as u8,
            bound,
            best_move_at_node,
        );
        best_score
    }

    /// Quiescence search over forcing moves only.
    ///
    /// At most `QSEARCH_MAX_PLY` quiescence plies (`qply`) are played; `ply`
    /// is the distance from the root and is used for terminal scores. Returns a
    /// score from the side to move's perspective, or 0 if aborted.
    fn qsearch(
        &mut self,
        board: &mut Board,
        weights: &NnueWeights,
        inc: &mut IncrementalEval,
        qply: u32,
        ply: usize,
        mut alpha: i32,
        beta: i32,
    ) -> i32 {
        if self.visit_node() {
            return 0;
        }
        match board.game_result() {
            GameResult::BlackWin | GameResult::WhiteWin => {
                return -(WIN_SCORE - ply as i32);
            }
            GameResult::Draw => return 0,
            GameResult::Ongoing => {}
        }
        let stand_pat = inc.eval(weights);
        if qply >= QSEARCH_MAX_PLY {
            return stand_pat;
        }
        // NOTE(review): the stand-pat cutoff runs before the opponent's
        // five-threats are examined below, so it can be optimistic when the side
        // to move must block immediately.
        if stand_pat >= beta {
            return stand_pat;
        }
        if stand_pat > alpha {
            alpha = stand_pat;
        }
        let candidates = board.candidate_moves();
        if candidates.is_empty() {
            return stand_pat;
        }
        let (my, opp) = match board.side_to_move {
            Stone::Black => (&board.black, &board.white),
            Stone::White => (&board.white, &board.black),
        };
        // If the opponent can make five next move, only our own fives and
        // blocks of the opponent's five-cells are considered.
        let opp_has_five = candidates
            .iter()
            .any(|&m| matches!(classify_move(opp, my, m), ThreatKind::Five));
        let mut forcing: Vec<(Move, i32)> = Vec::new();
        for &mv in &candidates {
            let my_kind = classify_move(my, opp, mv);
            if matches!(my_kind, ThreatKind::Five) {
                forcing.push((mv, 1_000_000));
                continue;
            }
            if opp_has_five {
                let opp_kind = classify_move(opp, my, mv);
                if matches!(opp_kind, ThreatKind::Five) {
                    forcing.push((mv, 900_000));
                }
                continue;
            }
            let attack_score = match my_kind {
                ThreatKind::OpenFour => Some(800_000),
                ThreatKind::DoubleFour | ThreatKind::FourThree => Some(600_000),
                _ => None,
            };
            if let Some(s) = attack_score {
                forcing.push((mv, s));
                continue;
            }
            let opp_kind = classify_move(opp, my, mv);
            if matches!(opp_kind, ThreatKind::OpenFour) {
                forcing.push((mv, 700_000));
            }
        }
        if forcing.is_empty() {
            return stand_pat;
        }
        forcing.sort_unstable_by(|a, b| b.1.cmp(&a.1));
        let mut best = stand_pat;
        for &(mv, _) in &forcing {
            board.make_move(mv);
            inc.push_move(board, mv, weights);
            let score = -self.qsearch(board, weights, inc, qply + 1, ply + 1, -beta, -alpha);
            inc.pop_move();
            board.undo_move();
            if self.aborted {
                return 0;
            }
            if score > best {
                best = score;
            }
            if score > alpha {
                alpha = score;
            }
            if alpha >= beta {
                break;
            }
        }
        best
    }

    /// Candidate moves for the side to move, sorted by descending ordering
    /// score, each paired with its `is_forcing` flag.
    fn order_moves(&self, board: &Board, ply: usize) -> Vec<(Move, bool)> {
        let candidates = board.candidate_moves();
        let side = board.side_to_move as usize;
        let (my, opp) = match board.side_to_move {
            Stone::Black => (&board.black, &board.white),
            Stone::White => (&board.white, &board.black),
        };
        let mut scored: Vec<(Move, i32, bool)> = candidates
            .into_iter()
            .map(|m| {
                let (s, f) = self.move_score_and_forcing(m, ply, side, my, opp);
                (m, s, f)
            })
            .collect();
        scored.sort_unstable_by(|a, b| b.1.cmp(&a.1));
        scored.into_iter().map(|(m, _, f)| (m, f)).collect()
    }

    /// Ordering score for `mv` and whether it is forcing.
    ///
    /// The score is a threat tier (own threat ranked above blocking the same
    /// opponent threat), plus killer and (capped) history bonuses, plus small
    /// per-direction bonuses for open twos. A move is forcing when it creates
    /// or blocks any recognised threat.
    fn move_score_and_forcing(
        &self,
        mv: Move,
        ply: usize,
        side: usize,
        my: &crate::board::BitBoard,
        opp: &crate::board::BitBoard,
    ) -> (i32, bool) {
        let row = (mv / BOARD_SIZE) as i32;
        let col = (mv % BOARD_SIZE) as i32;
        let my_kind = classify_move(my, opp, mv);
        let opp_kind = classify_move(opp, my, mv);
        // Every recognised threat counts as forcing, including DoubleThree, which
        // was previously omitted even though it outranks ClosedFour/OpenThree.
        // Forcing moves are exempt from LMR and LMP.
        let is_forcing =
            !matches!(my_kind, ThreatKind::None) || !matches!(opp_kind, ThreatKind::None);
        if matches!(my_kind, ThreatKind::Five) {
            return (TIER_WIN, true);
        }
        if matches!(opp_kind, ThreatKind::Five) {
            return (TIER_BLOCK_WIN, true);
        }
        let tier_score = if matches!(my_kind, ThreatKind::OpenFour) {
            TIER_OPEN_FOUR
        } else if matches!(opp_kind, ThreatKind::OpenFour) {
            TIER_BLOCK_OPEN_FOUR
        } else if matches!(my_kind, ThreatKind::DoubleFour | ThreatKind::FourThree) {
            TIER_DOUBLE_FOUR
        } else if matches!(opp_kind, ThreatKind::DoubleFour | ThreatKind::FourThree) {
            TIER_BLOCK_DOUBLE_FOUR
        } else if matches!(my_kind, ThreatKind::DoubleThree) {
            TIER_DOUBLE_THREE
        } else if matches!(opp_kind, ThreatKind::DoubleThree) {
            TIER_BLOCK_DOUBLE_THREE
        } else if matches!(my_kind, ThreatKind::ClosedFour) {
            TIER_CLOSED_FOUR
        } else if matches!(opp_kind, ThreatKind::ClosedFour) {
            TIER_BLOCK_CLOSED_FOUR
        } else if matches!(my_kind, ThreatKind::OpenThree) {
            TIER_OPEN_THREE
        } else if matches!(opp_kind, ThreatKind::OpenThree) {
            TIER_BLOCK_OPEN_THREE
        } else {
            0
        };
        let mut score = tier_score;
        if let Some(slots) = self.killers.get(ply) {
            if slots[0] == Some(mv) {
                score += 80_000;
            } else if slots[1] == Some(mv) {
                score += 40_000;
            }
        }
        score += self.history[side][mv].min(50_000);
        for &(dr, dc) in &DIR {
            let my_info = scan_line(my, opp, row, col, dr, dc);
            if my_info.count == 2 && my_info.open_front && my_info.open_back {
                score += 200;
            }
            let opp_info = scan_line(opp, my, row, col, dr, dc);
            if opp_info.count == 2 && opp_info.open_front && opp_info.open_back {
                score += 150;
            }
        }
        (score, is_forcing)
    }

    /// Counts one visited node and polls the deadline every 128 nodes.
    /// Returns `true` (after setting `aborted`) once the deadline has passed.
    fn visit_node(&mut self) -> bool {
        self.nodes += 1;
        if self.nodes & 127 != 0 {
            return false;
        }
        match self.deadline {
            Some(deadline) if Instant::now() >= deadline => {
                self.aborted = true;
                true
            }
            _ => false,
        }
    }

    /// Whether `mv` occupies one of the killer slots for `ply`. Plies beyond
    /// the killer table never have killers.
    fn is_killer_move(&self, ply: usize, mv: Move) -> bool {
        matches!(self.killers.get(ply), Some(slots) if slots.contains(&Some(mv)))
    }

    /// Records `mv` as the newest killer for `ply`.
    ///
    /// If `mv` is already the primary killer, the slots are left unchanged.
    /// Shifting anyway would copy the same move into both slots and evict the
    /// second killer.
    fn record_killer(&mut self, ply: usize, mv: Move) {
        if let Some(slots) = self.killers.get_mut(ply) {
            if slots[0] != Some(mv) {
                slots[1] = slots[0];
                slots[0] = Some(mv);
            }
        }
    }

    /// Moves `target` (if present) to index 0.
    ///
    /// Moves that were ahead of it shift back by one, so the remaining
    /// ordering is preserved. A plain swap would push the best-ordered move to a
    /// late slot, where LMR/LMP would reduce or prune it.
    fn promote_to_front(moves: &mut [(Move, bool)], target: Move) {
        if let Some(pos) = moves.iter().position(|&(m, _)| m == target) {
            moves[..=pos].rotate_right(1);
        }
    }
}

impl Default for Searcher {
    fn default() -> Self {
        Self::new()
    }
}
// Move-ordering tiers used by `Searcher::move_score_and_forcing`. Moves are
// sorted by descending score, so higher tiers are searched first. Each
// `TIER_BLOCK_*` value (a cell where the opponent would make that threat) ranks
// just below the matching own-threat tier.
// NOTE(review): the killer (up to 80_000) and capped history (up to 50_000)
// bonuses are added on top. Together they can exceed the 100_000 gaps between
// CLOSED_FOUR / BLOCK_CLOSED_FOUR / OPEN_THREE / BLOCK_OPEN_THREE, so those
// tiers can interleave.
const TIER_WIN: i32 = 10_000_000;
const TIER_BLOCK_WIN: i32 = 9_000_000;
const TIER_OPEN_FOUR: i32 = 8_000_000;
const TIER_BLOCK_OPEN_FOUR: i32 = 7_000_000;
const TIER_DOUBLE_FOUR: i32 = 6_000_000;
const TIER_BLOCK_DOUBLE_FOUR: i32 = 5_000_000;
const TIER_DOUBLE_THREE: i32 = 4_000_000;
const TIER_BLOCK_DOUBLE_THREE: i32 = 3_000_000;
const TIER_CLOSED_FOUR: i32 = 1_500_000;
const TIER_BLOCK_CLOSED_FOUR: i32 = 1_400_000;
const TIER_OPEN_THREE: i32 = 1_000_000;
const TIER_BLOCK_OPEN_THREE: i32 = 900_000;
/// Late-move reduction (in plies) for the move at `move_idx` searched at `depth`.
///
/// Forcing moves, killer moves, shallow nodes (`depth < LMR_MIN_DEPTH`) and
/// early moves (`move_idx < LMR_MIN_MOVE_IDX`) are never reduced. Otherwise the
/// base reduction of 1 grows by one at `depth >= 6` and by one more for
/// `move_idx >= 6`. The total is capped at `depth - 2`, so the reduced search
/// keeps at least one ply.
fn lmr_reduction(depth: u32, move_idx: usize, is_forcing: bool, is_killer: bool) -> u32 {
    let exempt = is_forcing || is_killer;
    let eligible = !exempt && depth >= LMR_MIN_DEPTH && move_idx >= LMR_MIN_MOVE_IDX;
    if !eligible {
        return 0;
    }
    let deep_bonus = u32::from(depth >= 6);
    let late_bonus = u32::from(move_idx >= 6);
    let cap = depth.saturating_sub(2);
    (1 + deep_bonus + late_bonus).min(cap)
}
// Currently unused by the search (move ordering uses the TIER_* constants);
// kept as a standalone threat weighting.
#[allow(dead_code)]
/// Weight of a threat of `kind`; when `defending` (blocking the opponent's
/// threat) the weight is scaled to 90% with integer division.
fn threat_priority(kind: ThreatKind, defending: bool) -> i32 {
    let attack_value = match kind {
        ThreatKind::None => 0,
        ThreatKind::OpenThree => 30_000,
        ThreatKind::ClosedFour => 100_000,
        ThreatKind::DoubleThree => 200_000,
        ThreatKind::DoubleFour | ThreatKind::FourThree => 300_000,
        ThreatKind::OpenFour => 500_000,
        ThreatKind::Five => 1_000_000,
    };
    match defending {
        true => attack_value * 9 / 10,
        false => attack_value,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::board::{to_idx, Board};
    use crate::features::GOMOKU_NNUE_CONFIG;

    /// The first player gets four in a row on row 7 (cols 3..=6), with both ends
    /// open, while the second player fills row 8. The search must complete five
    /// at either end.
    #[test]
    fn test_search_finds_winning_move() {
        let weights = NnueWeights::zeros(GOMOKU_NNUE_CONFIG);
        let mut board = Board::new();
        for col in 3..=6 {
            board.make_move(to_idx(7, col));
            board.make_move(to_idx(8, col));
        }
        let result = Searcher::new().search(&mut board, &weights, 2, None);
        let best = result.best_move;
        assert!(best.is_some());
        let winning_moves = [to_idx(7, 7), to_idx(7, 2)];
        assert!(
            winning_moves.contains(&best.unwrap()),
            "should find the winning move, got {:?}",
            best
        );
    }

    /// Even a one-ply search on an empty board must produce a move.
    #[test]
    fn test_search_depth_1() {
        let weights = NnueWeights::zeros(GOMOKU_NNUE_CONFIG);
        let mut board = Board::new();
        let mut searcher = Searcher::new();
        let outcome = searcher.search(&mut board, &weights, 1, None);
        assert!(outcome.best_move.is_some());
    }
}