use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
use crate::arrayvec::{self, ArrayVec};
use crate::coretypes::{Cp, Move, MoveInfo, MoveKind, PieceKind, PlyKind, MAX_DEPTH};
use crate::eval::{draw, terminal};
use crate::movelist::{Line, MoveInfoList};
use crate::moveorder::order_all_moves;
use crate::position::{Cache, Position};
use crate::search::{quiescence, History, SearchResult};
use crate::timeman::Mode;
use crate::transposition::{Entry, NodeKind, TranspositionTable};
use crate::zobrist::HashKind;
/// Searches `position` to a fixed depth of `ply` with recursive alpha-beta negamax.
///
/// Unlike `iterative_negamax`, this search cannot be stopped early and does not check the
/// fifty-move rule or threefold repetition.
///
/// The returned `score` is the side-to-move score multiplied by the root player's sign, and
/// `best_move` is the first move of the principal variation.
///
/// # Panics
/// Panics if `ply` is not strictly between `0` and `MAX_DEPTH`, if `position` has no legal
/// moves, or if the search produced an empty principal variation.
pub fn negamax(mut position: Position, ply: PlyKind, tt: &TranspositionTable) -> SearchResult {
    assert!(0 < ply && ply < MAX_DEPTH);
    // Same precondition as `iterative_negamax`: a root without legal moves has no best move,
    // so fail loudly up front instead of on an empty principal variation below.
    assert_ne!(position.get_legal_moves().len(), 0);

    let root_player = *position.player();
    let hash = tt.generate_hash(&position);
    let instant = Instant::now();
    let age = position.age();
    let mut pv = Line::new();
    let mut nodes = 0;

    let best_score = negamax_impl(
        &mut position,
        tt,
        hash,
        &mut pv,
        &mut nodes,
        ply,
        Cp::MIN,
        Cp::MAX,
        age,
    );

    let best_move = *pv
        .get(0)
        .expect("root search produced an empty principal variation");

    SearchResult {
        player: root_player,
        depth: ply,
        best_move,
        score: best_score * root_player.sign(),
        pv,
        nodes,
        elapsed: instant.elapsed(),
        ..Default::default()
    }
}
/// Recursive alpha-beta negamax over the window `(alpha, beta)`.
///
/// Returns the best score for the side to move in `position`. Whenever a move raises alpha,
/// the principal variation starting at this node is written into `pv`. `position` is
/// restored to its original state before returning. `nodes` is incremented once per call
/// and is also passed to `quiescence` at the depth horizon.
///
/// Results are stored in `tt` as `Cut` (fail high), `Pv` (alpha raised) or `All` entries.
fn negamax_impl(
    position: &mut Position,
    tt: &TranspositionTable,
    hash: HashKind,
    pv: &mut Line,
    nodes: &mut u64,
    ply: PlyKind,
    mut alpha: Cp,
    beta: Cp,
    age: u8,
) -> Cp {
    *nodes += 1;
    let legal_moves = position.get_legal_moves();
    let num_moves = legal_moves.len();
    let mut hash_move = None;

    if num_moves == 0 {
        pv.clear();
        return terminal(&position);
    }

    if let Some(entry) = tt.get(hash) {
        // NOTE(review): `entry.score` is returned as an exact score regardless of
        // `entry.node_kind`, although `Cut`/`All` entries hold scores from cut-off or
        // fail-low searches — confirm this is intended.
        if entry.ply >= ply && legal_moves.contains(&entry.key_move) {
            pv.clear();
            pv.push(entry.key_move);
            return entry.score;
        }
        hash_move = Some(entry.key_move);
    }

    // Checked independently of the TT lookup. If this were an `else` branch, a TT entry whose
    // key move is not legal here (presumably a stale/colliding entry) would let a zero-depth
    // node be expanded, and the recursion below would compute `ply - 1` from 0.
    if ply == 0 {
        pv.clear();
        let q_ply = 10;
        return quiescence(position, alpha, beta, q_ply, nodes);
    }

    let legal_moves = legal_moves
        .into_iter()
        .map(|move_| position.move_info(move_))
        .collect();
    let ordered_legal_moves = order_all_moves(legal_moves, hash_move);
    debug_assert_eq!(num_moves, ordered_legal_moves.len());

    let cache = position.cache();
    let mut best_move = Move::illegal();
    let mut local_pv = Line::new();
    let mut best_score = Cp::MIN;
    let mut alpha_raised = false;

    // The ordered list is iterated from the back, matching how `iterative_negamax` pops it.
    for legal_move_info in ordered_legal_moves.into_iter().rev() {
        position.do_move_info(legal_move_info);
        let move_hash = tt.update_from_hash(hash, &position, legal_move_info, cache);
        let move_score = -negamax_impl(
            position,
            tt,
            move_hash,
            &mut local_pv,
            nodes,
            ply - 1,
            -beta,
            -alpha,
            age,
        );
        position.undo_move(legal_move_info, cache);

        if move_score > best_score {
            best_score = move_score;
            best_move = legal_move_info.move_();
        }
        if move_score >= beta {
            // Fail high: the opponent will avoid this node, so stop searching it.
            let cut_move = legal_move_info.move_();
            let entry = Entry::new(hash, cut_move, move_score, ply, NodeKind::Cut);
            tt.replace_by(entry, age, replace_scheme);
            return move_score;
        }
        if best_score > alpha {
            alpha_raised = true;
            alpha = best_score;
            pv.clear();
            pv.push(best_move);
            arrayvec::append(pv, local_pv.clone());
        }
    }

    let node_kind = if alpha_raised {
        NodeKind::Pv
    } else {
        NodeKind::All
    };
    let entry = Entry::new(hash, best_move, best_score, ply, node_kind);
    if node_kind == NodeKind::Pv {
        tt.replace(entry, age);
    } else {
        tt.replace_by(entry, age, replace_scheme);
    }
    best_score
}
/// State of a single frame in the explicit search stack used by `iterative_negamax`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Label {
    /// Frame was just entered: evaluate leaf conditions or generate and order moves.
    Initialize,
    /// Frame should play its next move (or finish if none remain).
    Search,
    /// Frame's child just returned: undo the move and fold in the child's score.
    Retrieve,
}
/// One level of the explicit search stack used by `iterative_negamax`.
#[derive(Clone, Debug)]
struct Frame {
    /// Which step of the frame's state machine runs next.
    pub label: Label,
    /// Principal variation written into this frame by its child frame.
    pub local_pv: Line,
    /// Ordered moves still to be searched; consumed by popping from the back.
    pub legal_moves: MoveInfoList,
    /// Lower bound of the search window.
    pub alpha: Cp,
    /// Upper bound of the search window.
    pub beta: Cp,
    /// Best score found so far for the side to move in this frame.
    pub best_score: Cp,
    /// Move that produced `best_score`.
    pub best_move: Move,
    /// Hash of this frame's position.
    pub hash: HashKind,
    /// Move currently being searched from this frame.
    pub move_info: MoveInfo,
    /// Position cache captured before moving, used to undo `move_info`.
    pub cache: Cache,
    /// Whether any move raised alpha (the node is then stored as a PV node).
    pub alpha_raised: bool,
}

impl Default for Frame {
    /// A fresh frame in the `Initialize` state with a full window and sentinel move/cache.
    fn default() -> Self {
        let sentinel = Move::illegal();
        let placeholder_info = MoveInfo {
            from: sentinel.from,
            to: sentinel.to,
            promotion: sentinel.promotion,
            piece_kind: PieceKind::Pawn,
            move_kind: MoveKind::Quiet,
        };
        Frame {
            label: Label::Initialize,
            local_pv: Line::new(),
            legal_moves: MoveInfoList::new(),
            alpha: Cp::MIN,
            beta: Cp::MAX,
            best_score: Cp::MIN,
            best_move: Move::illegal(),
            hash: 0,
            move_info: placeholder_info,
            cache: Cache::illegal(),
            alpha_raised: false,
        }
    }
}
/// Mutably borrows the parent (`idx - 1`), current (`idx`) and child (`idx + 1`) frames at
/// the same time.
///
/// # Panics
/// Panics if `idx` is 0, or if `idx + 1` is not a valid index into `frames`.
#[inline(always)]
fn split_window_frames(frames: &mut [Frame], idx: usize) -> (&mut Frame, &mut Frame, &mut Frame) {
    debug_assert!(idx > 0, "cannot get parent frame of index 0");
    let (ancestors, descendants) = frames.split_at_mut(idx);
    let parent = ancestors.last_mut().unwrap();
    let (current, deeper) = descendants.split_first_mut().unwrap();
    let child = deeper.first_mut().unwrap();
    (parent, current, child)
}
/// Stack index of the frame one level shallower than `frame_idx`.
#[inline(always)]
fn parent_idx(frame_idx: usize) -> usize {
    // The stack grows toward higher indices, so moving up one level is one slot down.
    const STEP: usize = 1;
    frame_idx - STEP
}
/// Stack index of the frame one level deeper than `frame_idx`.
#[inline(always)]
fn child_idx(frame_idx: usize) -> usize {
    // The stack grows toward higher indices, so moving down one level is one slot up.
    const STEP: usize = 1;
    frame_idx + STEP
}
/// Number of plies between the root frame (index 1) and the frame at `frame_idx`.
#[inline(always)]
fn curr_ply(frame_idx: usize) -> PlyKind {
    debug_assert!(frame_idx > 0);
    // Index 0 is the base frame, so the root (ply 0) lives at index 1.
    let plies_below_root = frame_idx - 1;
    plies_below_root as PlyKind
}
/// Replacement predicate passed to `TranspositionTable::replace_by`.
///
/// Returns `true` when the ages differ. With equal ages, it returns `true` only when the
/// existing entry is not a PV node and the new entry has at least the existing entry's depth.
#[inline]
fn replace_scheme(new_entry: &Entry, new_age: u8, existing: &Entry, existing_age: u8) -> bool {
    if new_age != existing_age {
        return true;
    }
    let existing_is_pv = existing.node_kind == NodeKind::Pv;
    let at_least_as_deep = new_entry.ply >= existing.ply;
    !existing_is_pv && at_least_as_deep
}
pub fn iterative_negamax(
mut position: Position,
ply: PlyKind,
mode: Mode,
mut history: History,
tt: &TranspositionTable,
stopper: Arc<AtomicBool>,
) -> Option<SearchResult> {
assert!(0 < ply && ply <= MAX_DEPTH);
assert_ne!(position.get_legal_moves().len(), 0);
let instant = Instant::now(); let root_position = position.clone(); let root_hash = tt.generate_hash(&position); let root_history = history.clone();
let age = position.age();
let nodes_per_stop_check = 2000; let mut stopped = false; let mut stop_check_counter = nodes_per_stop_check;
let contempt = Cp(50);
let mut metrics = SearchResult::default();
metrics.player = root_position.player;
metrics.depth = ply;
metrics.stopped = false;
const BASE_IDX: usize = 0; const ROOT_IDX: usize = 1; let mut stack: ArrayVec<Frame, { (MAX_DEPTH + 1) as usize }> = ArrayVec::new();
while !stack.is_full() {
stack.push(Default::default());
}
stack[ROOT_IDX].label = Label::Initialize;
stack[ROOT_IDX].hash = root_hash;
stack[ROOT_IDX].cache = root_position.cache();
let mut frame_idx: usize = ROOT_IDX;
while frame_idx > 0 {
let (parent, us, child) = split_window_frames(&mut stack, frame_idx);
let remaining_ply = ply - curr_ply(frame_idx);
let label: Label = us.label;
if label == Label::Initialize && stop_check_counter <= 0 {
stop_check_counter = nodes_per_stop_check;
stopped |= stopper.load(Ordering::Acquire);
stopped |= mode.stop(root_position.player, ply);
}
if stopped {
break;
}
if Label::Initialize == label {
stop_check_counter -= 1;
metrics.nodes += 1;
let legal_moves = position.get_legal_moves();
let num_moves = legal_moves.len();
let mut hash_move = None;
if num_moves == 0 {
parent.label = Label::Retrieve;
parent.local_pv.clear();
us.best_score = terminal(&position);
frame_idx = parent_idx(frame_idx);
continue;
}
else if position.fifty_move_rule(num_moves)
|| history.is_threefold_repetition(us.hash)
{
parent.label = Label::Retrieve;
parent.local_pv.clear();
us.best_score = draw(root_position.player == position.player, contempt);
frame_idx = parent_idx(frame_idx);
continue;
}
else if let Some(entry) = tt.get(us.hash) {
metrics.tt_hits += 1;
if entry.ply >= remaining_ply && legal_moves.contains(&entry.key_move) {
metrics.tt_cuts += 1;
parent.label = Label::Retrieve;
parent.local_pv.clear();
parent.local_pv.push(entry.key_move);
us.best_score = entry.score;
us.best_move = entry.key_move;
frame_idx = parent_idx(frame_idx);
continue;
}
hash_move = Some(entry.key_move);
}
else if remaining_ply == 0 {
parent.label = Label::Retrieve;
parent.local_pv.clear();
let q_ply = 10;
let q_instant = Instant::now();
let mut q_nodes = 0;
us.best_score = quiescence(&mut position, us.alpha, us.beta, q_ply, &mut q_nodes);
metrics.q_elapsed += q_instant.elapsed();
metrics.nodes += q_nodes;
metrics.q_nodes += q_nodes;
frame_idx = parent_idx(frame_idx);
continue;
}
let legal_moves: MoveInfoList = legal_moves
.into_iter()
.map(|move_| position.move_info(move_))
.collect();
us.legal_moves = order_all_moves(legal_moves, hash_move);
us.cache = position.cache();
us.label = Label::Search;
} else if Label::Search == label {
if let Some(legal_move) = us.legal_moves.pop() {
us.move_info = legal_move;
position.do_move_info(legal_move);
history.push(us.hash, us.move_info.is_unrepeatable());
let child_hash = tt.update_from_hash(us.hash, &position, us.move_info, us.cache);
child.label = Label::Initialize;
child.hash = child_hash;
child.alpha = -us.beta;
child.beta = -us.alpha;
child.best_score = Cp::MIN;
child.alpha_raised = false;
frame_idx = child_idx(frame_idx);
} else {
let node_kind = match us.alpha_raised {
true => {
metrics.pv_nodes += 1;
NodeKind::Pv
}
false => {
metrics.all_nodes += 1;
NodeKind::All
}
};
let entry = Entry::new(
us.hash,
us.best_move,
us.best_score,
remaining_ply,
node_kind,
);
if node_kind == NodeKind::Pv {
tt.replace(entry, age);
} else {
tt.replace_by(entry, age, replace_scheme);
}
parent.label = Label::Retrieve;
frame_idx = parent_idx(frame_idx);
}
} else if Label::Retrieve == label {
position.undo_move(us.move_info, us.cache);
history.pop();
let move_score = -child.best_score;
if move_score > us.best_score {
us.best_score = move_score;
us.best_move = us.move_info.move_();
}
if us.best_score >= us.beta {
metrics.cut_nodes += 1;
let entry = Entry::new(
us.hash,
us.best_move,
us.best_score,
remaining_ply,
NodeKind::Cut,
);
tt.replace_by(entry, age, replace_scheme);
parent.label = Label::Retrieve;
frame_idx = parent_idx(frame_idx);
continue;
}
if us.best_score > us.alpha {
us.alpha_raised = true;
us.alpha = us.best_score;
parent.local_pv.clear();
parent.local_pv.push(us.best_move);
arrayvec::append(&mut parent.local_pv, us.local_pv.clone());
}
us.label = Label::Search;
}
}
if !stopped {
debug_assert_eq!(root_hash, tt.generate_hash(&position));
debug_assert_eq!(root_hash, stack[ROOT_IDX].hash);
debug_assert_eq!(root_history, history);
}
if stack[BASE_IDX].local_pv.len() == 0 {
None
} else {
let best_move = stack[ROOT_IDX].best_move;
let score = stack[ROOT_IDX].best_score * root_position.player.sign();
let pv = stack[BASE_IDX].local_pv.clone();
assert_ne!(best_move, Move::illegal());
assert_eq!(metrics.player, root_position.player);
assert_eq!(metrics.depth, ply);
metrics.elapsed = instant.elapsed();
metrics.best_move = best_move;
metrics.score = score;
metrics.pv = pv;
metrics.stopped = stopped;
Some(metrics)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coretypes::{Color, Move, Square::*};
    use crate::fen::Fen;

    /// Ignored by default (run with `--ignored`): a depth-6 search of this position should
    /// favour White and start with the move E4 -> F6.
    #[test]
    #[ignore]
    fn mate_pv() {
        let fen = "r4rk1/1b3ppp/pp2p3/2p5/P1B1NR1Q/3P3P/2q3P1/7K w - - 0 24";
        let position = Position::parse_fen(fen).unwrap();
        let table = TranspositionTable::new();
        let outcome = negamax(position, 6, &table);
        assert_eq!(outcome.leading(), Some(Color::White));
        assert_eq!(outcome.best_move, Move::new(E4, F6, None));
        println!("{:?}", outcome.pv);
    }

    /// Multiplying by a color's sign keeps White scores and negates Black scores.
    #[test]
    fn color_sign() {
        let magnitude = Cp(40);
        assert_eq!(magnitude * Color::White.sign(), Cp(40));
        assert_eq!(magnitude * Color::Black.sign(), Cp(-40));
    }

    /// PV nodes must rank above both other node kinds.
    #[test]
    fn nodetype_ordering() {
        assert!(NodeKind::All < NodeKind::Pv);
        assert!(NodeKind::Cut < NodeKind::Pv);
    }
}