use super::{HistoryTables, LOW_PLY_HISTORY_SIZE, PieceToHistory};
use crate::movegen::{ExtMove, ExtMoveBuffer};
use crate::position::Position;
use crate::types::{Color, DEPTH_QS, Depth, Move, Piece, PieceType, Value};
/// Phases the staged move picker walks through.
///
/// The enum is `#[repr(u8)]`, so each variant's discriminant follows its
/// declaration order; do not reorder the variants.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum Stage {
    /// Hand out the transposition-table move (main chain).
    MainTT,
    /// Hand out the transposition-table move (ProbCut chain).
    ProbCutTT,
    /// Generate, score and sort captures (main chain).
    CaptureInit,
    /// Generate, score and sort captures (ProbCut chain).
    ProbCutInit,
    /// Hand out captures that pass the ProbCut SEE threshold.
    ProbCut,
    /// Hand out captures that pass the SEE test.
    GoodCapture,
    /// Generate, score and partially sort quiet moves.
    QuietInit,
    /// Hand out quiet moves scoring above the good-quiet threshold.
    GoodQuiet,
    /// Hand out the captures that were set aside during `GoodCapture`.
    BadCapture,
    /// Hand out the remaining quiet moves.
    BadQuiet,
    /// Hand out the transposition-table move (evasion chain).
    EvasionTT,
    /// Generate, score and sort evasions.
    EvasionInit,
    /// Hand out evasions.
    Evasion,
    /// Hand out the transposition-table move (quiescence chain).
    QSearchTT,
    /// Generate, score and sort captures (quiescence chain).
    QCaptureInit,
    /// Hand out captures (quiescence chain).
    QCapture,
}

impl Stage {
    /// Returns the stage that follows `self` within its chain.
    ///
    /// The last stage of every chain (`ProbCut`, `BadQuiet`, `Evasion`,
    /// `QCapture`) is a fixed point and maps to itself.
    pub fn next(self) -> Self {
        use Stage::*;
        match self {
            // Main chain.
            MainTT => CaptureInit,
            CaptureInit => GoodCapture,
            GoodCapture => QuietInit,
            QuietInit => GoodQuiet,
            GoodQuiet => BadCapture,
            BadCapture => BadQuiet,
            // ProbCut chain.
            ProbCutTT => ProbCutInit,
            ProbCutInit => ProbCut,
            // Evasion chain.
            EvasionTT => EvasionInit,
            EvasionInit => Evasion,
            // Quiescence chain.
            QSearchTT => QCaptureInit,
            QCaptureInit => QCapture,
            // Terminal stages stay where they are.
            ProbCut | BadQuiet | Evasion | QCapture => self,
        }
    }
}
/// Staged move picker: `next_move` hands out one move per call, walking the
/// `Stage` state machine and generating/scoring moves lazily per stage.
pub struct MovePicker {
/// Raw pointers to the six continuation-history tables handed to the
/// constructors as references; dereferenced by `cont_history`.
continuation_history: [*const PieceToHistory; 6],
/// Current position in the stage state machine.
stage: Stage,
/// Transposition-table move: returned by the `*TT` stages and skipped by
/// every `select_*` helper afterwards.
tt_move: Move,
/// SEE threshold used by the `ProbCut` stage; `Some` only when the picker
/// was built with `new_probcut`.
probcut_threshold: Option<Value>,
/// Search depth; scales the partial-sort limit in `QuietInit`.
/// `new_evasions` and `new_probcut` store `DEPTH_QS` here.
depth: Depth,
/// Ply passed by the caller; used for the low-ply history term in
/// `score_quiets`.
ply: i32,
/// Set by `skip_quiets()`; checked in the quiet stages of `next_move`.
skip_quiets: bool,
/// Selects the `*All` generation types and is forwarded to
/// `pseudo_legal_with_all` when validating the TT move.
generate_all_legal_moves: bool,
/// `pos.side_to_move()` captured at construction.
side_to_move: Color,
/// `pos.pawn_history_index()` captured at construction.
pawn_history_index: usize,
/// Buffer holding the generated moves together with their scores.
moves: ExtMoveBuffer,
/// Index of the next buffer entry to examine.
cur: usize,
/// Exclusive end of the range currently being iterated.
end_cur: usize,
/// Number of captures that failed the SEE test in `select_good_capture`;
/// they are swapped into `[0, end_bad_captures)`.
end_bad_captures: usize,
/// Number of generated captures; quiet moves are stored from this index.
end_captures: usize,
/// Exclusive end of everything generated so far (captures plus quiets, or
/// evasions).
end_generated: usize,
/// Written in `QuietInit` from the partial-sort result.
/// NOTE(review): not read anywhere in the code visible here.
end_good_quiets: usize,
}
impl MovePicker {
/// Creates a picker and selects its starting stage.
///
/// * In check: evasion chain.
/// * Not in check and `depth > DEPTH_QS`: main chain.
/// * Otherwise: quiescence chain.
///
/// Each chain starts at its `*TT` stage only when `tt_move` is set and
/// passes `pos.pseudo_legal_with_all`; otherwise it starts at the chain's
/// init stage.
///
/// The continuation-history references are stored as raw pointers, so the
/// tables must stay alive for as long as the picker is used.
pub fn new(
    pos: &Position,
    tt_move: Move,
    depth: Depth,
    ply: i32,
    continuation_history: [&PieceToHistory; 6],
    generate_all_legal_moves: bool,
) -> Self {
    let in_check = pos.in_check();
    let full_depth = depth > DEPTH_QS;
    // `&&` short-circuits, so an empty TT move is never probed for legality.
    let tt_usable =
        tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves);

    let stage = match (in_check, full_depth, tt_usable) {
        (true, _, true) => Stage::EvasionTT,
        (true, _, false) => Stage::EvasionInit,
        (false, true, true) => Stage::MainTT,
        (false, true, false) => Stage::CaptureInit,
        (false, false, true) => Stage::QSearchTT,
        (false, false, false) => Stage::QCaptureInit,
    };

    let history_ptrs = continuation_history.map(|table| table as *const PieceToHistory);

    Self {
        continuation_history: history_ptrs,
        stage,
        tt_move,
        probcut_threshold: None,
        depth,
        ply,
        skip_quiets: false,
        generate_all_legal_moves,
        side_to_move: pos.side_to_move(),
        pawn_history_index: pos.pawn_history_index(),
        moves: ExtMoveBuffer::new(),
        cur: 0,
        end_cur: 0,
        end_bad_captures: 0,
        end_captures: 0,
        end_generated: 0,
        end_good_quiets: 0,
    }
}
/// Creates a picker that runs the evasion chain only.
///
/// Debug builds assert that `pos` is in check. The picker starts at
/// `EvasionTT` when `tt_move` is set and passes
/// `pos.pseudo_legal_with_all`, otherwise at `EvasionInit`. The stored
/// depth is `DEPTH_QS`.
///
/// The continuation-history references are stored as raw pointers, so the
/// tables must stay alive for as long as the picker is used.
pub fn new_evasions(
    pos: &Position,
    tt_move: Move,
    ply: i32,
    continuation_history: [&PieceToHistory; 6],
    generate_all_legal_moves: bool,
) -> Self {
    debug_assert!(pos.in_check());

    // `&&` short-circuits, so an empty TT move is never probed for legality.
    let tt_usable =
        tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves);
    let first_stage = if tt_usable {
        Stage::EvasionTT
    } else {
        Stage::EvasionInit
    };
    let history_ptrs = continuation_history.map(|table| table as *const PieceToHistory);

    Self {
        continuation_history: history_ptrs,
        stage: first_stage,
        tt_move,
        probcut_threshold: None,
        depth: DEPTH_QS,
        ply,
        skip_quiets: false,
        generate_all_legal_moves,
        side_to_move: pos.side_to_move(),
        pawn_history_index: pos.pawn_history_index(),
        moves: ExtMoveBuffer::new(),
        cur: 0,
        end_cur: 0,
        end_bad_captures: 0,
        end_captures: 0,
        end_generated: 0,
        end_good_quiets: 0,
    }
}
/// Creates a picker that runs the ProbCut chain with the given SEE
/// `threshold`.
///
/// Debug builds assert that `pos` is not in check. The picker starts at
/// `ProbCutTT` only when `tt_move` is set, is a capture-stage move
/// according to `pos.capture_stage`, and passes
/// `pos.pseudo_legal_with_all`; otherwise it starts at `ProbCutInit`.
/// The stored depth is `DEPTH_QS`.
///
/// The continuation-history references are stored as raw pointers, so the
/// tables must stay alive for as long as the picker is used.
pub fn new_probcut(
    pos: &Position,
    tt_move: Move,
    threshold: Value,
    ply: i32,
    continuation_history: [&PieceToHistory; 6],
    generate_all_legal_moves: bool,
) -> Self {
    debug_assert!(!pos.in_check());

    // Checks run left to right and stop at the first failure.
    let tt_is_usable_capture = tt_move.is_some()
        && pos.capture_stage(tt_move)
        && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves);
    let first_stage = if tt_is_usable_capture {
        Stage::ProbCutTT
    } else {
        Stage::ProbCutInit
    };
    let history_ptrs = continuation_history.map(|table| table as *const PieceToHistory);

    Self {
        continuation_history: history_ptrs,
        stage: first_stage,
        tt_move,
        probcut_threshold: Some(threshold),
        depth: DEPTH_QS,
        ply,
        skip_quiets: false,
        generate_all_legal_moves,
        side_to_move: pos.side_to_move(),
        pawn_history_index: pos.pawn_history_index(),
        moves: ExtMoveBuffer::new(),
        cur: 0,
        end_cur: 0,
        end_bad_captures: 0,
        end_captures: 0,
        end_generated: 0,
        end_good_quiets: 0,
    }
}
/// Tells the picker to stop handing out quiet moves.
///
/// The flag is checked by the `QuietInit`, `GoodQuiet` and `BadQuiet`
/// stages of `next_move`; nothing in this type resets it.
pub fn skip_quiets(&mut self) {
self.skip_quiets = true;
}
#[inline]
pub fn is_quiet_stage(&self) -> bool {
matches!(self.stage, Stage::QuietInit | Stage::GoodQuiet | Stage::BadQuiet)
}
/// Returns the stage the picker is currently in.
#[inline]
pub fn stage(&self) -> Stage {
self.stage
}
/// Returns the next move to try, or `Move::NONE` once the current chain is
/// exhausted.
///
/// Runs the `Stage` state machine: init stages generate, score and sort
/// moves and then fall through (via the surrounding `loop`) to the stage
/// that hands them out. `pos` is used for move generation and SEE tests,
/// `history` for scoring.
pub fn next_move(&mut self, pos: &Position, history: &HistoryTables) -> Move {
loop {
match self.stage {
// TT stages: hand out the TT move once and advance to the init stage.
Stage::MainTT | Stage::EvasionTT | Stage::QSearchTT | Stage::ProbCutTT => {
self.stage = self.stage.next();
return self.tt_move;
}
// Capture generation shared by the main, quiescence and ProbCut chains.
Stage::CaptureInit | Stage::QCaptureInit | Stage::ProbCutInit => {
self.cur = 0;
self.end_bad_captures = 0;
self.moves.clear();
// The two branches differ only in the generation type requested.
let count = if self.generate_all_legal_moves {
crate::movegen::generate_with_type(
pos,
crate::movegen::GenType::CapturesAll,
&mut self.moves,
None,
);
self.moves.len()
} else {
crate::movegen::generate_with_type(
pos,
crate::movegen::GenType::Captures,
&mut self.moves,
None,
);
self.moves.len()
};
self.end_cur = count;
self.end_captures = count;
self.score_captures(pos, history);
// A limit of i32::MIN passes every move, so the whole range is sorted.
partial_insertion_sort(self.moves.as_mut_slice(), self.end_cur, i32::MIN);
self.stage = self.stage.next();
}
// Captures passing the SEE test; failures are set aside for BadCapture.
Stage::GoodCapture => {
if let Some(m) = self.select_good_capture(pos) {
return m;
}
self.stage = Stage::QuietInit;
}
// Quiet generation: quiets are stored after the captures, starting at
// index `end_captures`.
Stage::QuietInit => {
if !self.skip_quiets {
let count = if self.generate_all_legal_moves {
// Generate into a scratch buffer, then copy the moves in with a zero score.
let mut buf = ExtMoveBuffer::new();
crate::movegen::generate_with_type(
pos,
crate::movegen::GenType::QuietsAll,
&mut buf,
None,
);
let mut c = 0;
for ext in buf.iter() {
self.moves.set(self.end_captures + c, ExtMove::new(ext.mv, 0));
c += 1;
}
c
} else {
pos.generate_quiets(&mut self.moves, self.end_captures)
};
self.end_cur = self.end_captures + count;
self.end_generated = self.end_cur;
self.moves.set_len(self.end_cur);
self.cur = self.end_captures;
self.score_quiets(pos, history);
// Only quiets scoring at least `limit` are sorted; the limit scales with depth.
let limit = -3560 * self.depth;
let quiet_count = self.end_cur - self.end_captures;
let sorted_end = partial_insertion_sort(
&mut self.moves.as_mut_slice()[self.end_captures..],
quiet_count,
limit,
);
self.end_good_quiets = self.end_captures + sorted_end;
} else {
self.end_good_quiets = self.end_captures;
}
self.stage = Stage::GoodQuiet;
}
// Quiets scoring above the good-quiet threshold.
Stage::GoodQuiet => {
if !self.skip_quiets {
self.end_cur = self.end_generated;
if let Some(m) = self.select_good_quiet() {
return m;
}
}
// Point the cursor at the bad captures parked in [0, end_bad_captures).
self.cur = 0;
self.end_cur = self.end_bad_captures;
self.stage = Stage::BadCapture;
}
// Captures that failed the SEE test during GoodCapture.
Stage::BadCapture => {
if let Some(m) = self.select_simple() {
return m;
}
// Rewind over the quiet range for a second pass.
self.cur = self.end_captures;
self.end_cur = self.end_generated;
self.stage = Stage::BadQuiet;
}
// Second pass over the quiets: those at or below the good-quiet threshold.
Stage::BadQuiet => {
if !self.skip_quiets {
if let Some(m) = self.select_bad_quiet() {
return m;
}
}
return Move::NONE;
}
// Evasion generation: all evasions are scored and fully sorted.
Stage::EvasionInit => {
let count = if self.generate_all_legal_moves {
// Generate into a scratch buffer, then copy the moves in with a zero score.
let mut buf = ExtMoveBuffer::new();
crate::movegen::generate_with_type(
pos,
crate::movegen::GenType::EvasionsAll,
&mut buf,
None,
);
let gen_count = buf.len();
for (i, ext) in buf.iter().enumerate() {
self.moves.set(i, ExtMove::new(ext.mv, 0));
}
self.moves.set_len(gen_count);
gen_count
} else {
pos.generate_evasions_ext(&mut self.moves)
};
self.cur = 0;
self.end_cur = count;
self.end_generated = count;
self.score_evasions(pos, history);
// A limit of i32::MIN passes every move, so the whole range is sorted.
partial_insertion_sort(self.moves.as_mut_slice(), self.end_cur, i32::MIN);
self.stage = Stage::Evasion;
}
// Terminal stages: keep handing out moves until the range is empty.
Stage::Evasion => {
return self.select_simple().unwrap_or(Move::NONE);
}
Stage::QCapture => {
return self.select_simple().unwrap_or(Move::NONE);
}
Stage::ProbCut => {
// Without a threshold (picker not built by `new_probcut`) nothing is returned.
if let Some(th) = self.probcut_threshold {
return self.select_probcut(pos, th).unwrap_or(Move::NONE);
}
return Move::NONE;
}
}
}
}
/// Returns the continuation-history table stored at `idx` (0..6).
///
/// Panics if `idx` is outside the six-element array.
#[inline]
fn cont_history(&self, idx: usize) -> &PieceToHistory {
// SAFETY: every pointer was created from a `&PieceToHistory` in one of the
// constructors, so it is non-null and aligned. Nothing in this type ties
// the tables' lifetime to the picker, so callers must keep the tables
// alive (and not mutate them through other paths) while the picker is used.
unsafe { &*self.continuation_history[idx] }
}
/// Scores every move in `[cur, end_cur)` as a capture.
///
/// The score is the capture-history entry for (moved piece, destination,
/// type of the piece on the destination) plus seven times the value of the
/// piece standing on the destination square.
fn score_captures(&mut self, pos: &Position, history: &HistoryTables) {
    for slot in self.cur..self.end_cur {
        let mv = self.moves.get(slot).mv;
        let dest = mv.to();
        let mover = mv.moved_piece_after();
        let victim = pos.piece_on(dest);
        let history_term = history
            .capture_history
            .get(mover, dest, victim.piece_type()) as i32;
        self.moves
            .set_value(slot, history_term + 7 * piece_value(victim));
    }
}
/// Scores every move in `[cur, end_cur)` as a quiet move.
///
/// The score is the sum, in this order, of:
/// * twice the main-history entry,
/// * twice the pawn-history entry,
/// * the continuation-history entries at indices 0, 1, 2, 3 and 5,
/// * 16384 if the destination is a check square for the moved piece type
///   and the move passes `see_ge` with threshold -75,
/// * `8 * low_ply_history / (1 + ply)` when `ply` is below
///   `LOW_PLY_HISTORY_SIZE`.
fn score_quiets(&mut self, pos: &Position, history: &HistoryTables) {
    // Continuation-history slots consulted for quiets (slot 4 is not used).
    const CONT_SLOTS: [usize; 5] = [0, 1, 2, 3, 5];

    let stm = self.side_to_move;
    let pawn_key = self.pawn_history_index;
    let ply = self.ply;

    for slot in self.cur..self.end_cur {
        let mv = self.moves.get(slot).mv;
        let dest = mv.to();
        let mover = mv.moved_piece_after();
        let mover_type = mover.piece_type();

        let main_term = 2 * history.main_history.get(stm, mv) as i32;
        let pawn_term = 2 * history.pawn_history.get(pawn_key, mover, dest) as i32;
        let cont_term: i32 = CONT_SLOTS
            .iter()
            .map(|&table| self.cont_history(table).get(mover, dest) as i32)
            .sum();

        let safe_check =
            pos.check_squares(mover_type).contains(dest) && pos.see_ge(mv, Value::new(-75));
        let check_term = if safe_check { 16384 } else { 0 };

        let low_ply_term = if ply < LOW_PLY_HISTORY_SIZE as i32 {
            8 * history.low_ply_history.get(ply as usize, mv) as i32 / (1 + ply)
        } else {
            0
        };

        self.moves.set_value(
            slot,
            main_term + pawn_term + cont_term + check_term + low_ply_term,
        );
    }
}
/// Scores every move in `[cur, end_cur)` as an evasion.
///
/// Capture-stage moves get the value of the piece on the destination plus
/// `1 << 28`, which places them ahead of every non-capture. Other moves get
/// the main-history entry plus the entry from continuation-history table 0.
fn score_evasions(&mut self, pos: &Position, history: &HistoryTables) {
    const CAPTURE_BASE: i32 = 1 << 28;

    let stm = self.side_to_move;
    for slot in self.cur..self.end_cur {
        let mv = self.moves.get(slot).mv;
        let dest = mv.to();
        let score = if pos.capture_stage(mv) {
            piece_value(pos.piece_on(dest)) + CAPTURE_BASE
        } else {
            let main_term = history.main_history.get(stm, mv) as i32;
            let cont_term = self.cont_history(0).get(mv.moved_piece_after(), dest) as i32;
            main_term + cont_term
        };
        self.moves.set_value(slot, score);
    }
}
/// Returns the next capture in `[cur, end_cur)` that is not the TT move and
/// passes `see_ge` with threshold `-score / 18`.
///
/// Captures that fail the test are swapped into the front of the buffer and
/// counted in `end_bad_captures` so the `BadCapture` stage can return them
/// later. Returns `None` when the range is exhausted.
fn select_good_capture(&mut self, pos: &Position) -> Option<Move> {
    for slot in self.cur..self.end_cur {
        // Consume the slot before any early exit.
        self.cur = slot + 1;
        let entry = self.moves.get(slot);
        let (mv, score) = (entry.mv, entry.value);
        if mv == self.tt_move {
            continue;
        }
        if pos.see_ge(mv, Value::new(-score / 18)) {
            return Some(mv);
        }
        // Losing capture: park it in the bad-capture region at the front.
        self.moves.swap(self.end_bad_captures, slot);
        self.end_bad_captures += 1;
    }
    None
}
/// Returns the next move in `[cur, end_cur)` that is not the TT move and
/// scores strictly above the good-quiet threshold (-14000), or `None` when
/// the range is exhausted. Every examined slot is consumed.
fn select_good_quiet(&mut self) -> Option<Move> {
    const GOOD_QUIET_THRESHOLD: i32 = -14000;
    for slot in self.cur..self.end_cur {
        self.cur = slot + 1;
        let entry = self.moves.get(slot);
        if entry.mv != self.tt_move && entry.value > GOOD_QUIET_THRESHOLD {
            return Some(entry.mv);
        }
    }
    None
}
/// Returns the next move in `[cur, end_cur)` that is not the TT move and
/// scores at or below the good-quiet threshold (-14000), or `None` when the
/// range is exhausted. Every examined slot is consumed.
fn select_bad_quiet(&mut self) -> Option<Move> {
    const GOOD_QUIET_THRESHOLD: i32 = -14000;
    for slot in self.cur..self.end_cur {
        self.cur = slot + 1;
        let entry = self.moves.get(slot);
        if entry.mv != self.tt_move && entry.value <= GOOD_QUIET_THRESHOLD {
            return Some(entry.mv);
        }
    }
    None
}
/// Returns the next move in `[cur, end_cur)` that is not the TT move, or
/// `None` when the range is exhausted. Every examined slot is consumed.
fn select_simple(&mut self) -> Option<Move> {
    for slot in self.cur..self.end_cur {
        self.cur = slot + 1;
        let mv = self.moves.get(slot).mv;
        if mv != self.tt_move {
            return Some(mv);
        }
    }
    None
}
/// Returns the next move in `[cur, end_cur)` that is not the TT move and
/// passes `see_ge` against `threshold`, or `None` when the range is
/// exhausted. Every examined slot is consumed.
fn select_probcut(&mut self, pos: &Position, threshold: Value) -> Option<Move> {
    for slot in self.cur..self.end_cur {
        self.cur = slot + 1;
        let mv = self.moves.get(slot).mv;
        if mv != self.tt_move && pos.see_ge(mv, threshold) {
            return Some(mv);
        }
    }
    None
}
}
/// Partially sorts `moves[..end]` in descending order of `value`.
///
/// Entries whose value is at least `limit` are moved to the front and kept
/// sorted (ties keep their relative order); the remaining entries end up
/// behind them in unspecified order. Element 0 always belongs to the sorted
/// prefix, whatever its value.
///
/// Returns the index of the last element of the sorted prefix (not its
/// length), which is 0 for empty and single-element inputs.
///
/// Panics if `end` exceeds `moves.len()`.
fn partial_insertion_sort(moves: &mut [ExtMove], end: usize, limit: i32) -> usize {
    let mut last_sorted: usize = 0;
    for scan in 1..end {
        let candidate = moves[scan];
        if candidate.value < limit {
            continue;
        }
        // Grow the sorted prefix by one slot; whatever sat there moves to `scan`.
        last_sorted += 1;
        moves[scan] = moves[last_sorted];
        // Shift strictly smaller entries right to open the insertion slot.
        let mut hole = last_sorted;
        while hole > 0 && moves[hole - 1].value < candidate.value {
            moves[hole] = moves[hole - 1];
            hole -= 1;
        }
        moves[hole] = candidate;
    }
    last_sorted
}
pub(crate) fn piece_value(pc: Piece) -> i32 {
if pc.is_none() {
return 0;
}
use PieceType::*;
match pc.piece_type() {
Pawn => 90,
Lance => 315,
Knight => 405,
Silver => 495,
Gold | ProPawn | ProLance | ProKnight | ProSilver => 540,
Bishop => 855,
Rook => 990,
Horse => 945,
Dragon => 1395,
King => 15000,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every stage must map to its documented successor, and the last stage
    /// of each chain must map to itself.
    #[test]
    fn test_stage_next() {
        // Main chain.
        assert_eq!(Stage::MainTT.next(), Stage::CaptureInit);
        assert_eq!(Stage::CaptureInit.next(), Stage::GoodCapture);
        assert_eq!(Stage::GoodCapture.next(), Stage::QuietInit);
        assert_eq!(Stage::QuietInit.next(), Stage::GoodQuiet);
        assert_eq!(Stage::GoodQuiet.next(), Stage::BadCapture);
        assert_eq!(Stage::BadCapture.next(), Stage::BadQuiet);
        assert_eq!(Stage::BadQuiet.next(), Stage::BadQuiet);
        // ProbCut chain (previously not covered).
        assert_eq!(Stage::ProbCutTT.next(), Stage::ProbCutInit);
        assert_eq!(Stage::ProbCutInit.next(), Stage::ProbCut);
        assert_eq!(Stage::ProbCut.next(), Stage::ProbCut);
        // Evasion chain.
        assert_eq!(Stage::EvasionTT.next(), Stage::EvasionInit);
        assert_eq!(Stage::EvasionInit.next(), Stage::Evasion);
        assert_eq!(Stage::Evasion.next(), Stage::Evasion);
        // Quiescence chain.
        assert_eq!(Stage::QSearchTT.next(), Stage::QCaptureInit);
        assert_eq!(Stage::QCaptureInit.next(), Stage::QCapture);
        assert_eq!(Stage::QCapture.next(), Stage::QCapture);
    }

    /// Entries at or above the limit are sorted to the front in descending order.
    #[test]
    fn test_partial_insertion_sort() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, 50),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, 10),
            ExtMove::new(Move::NONE, 150),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 100);
        // The return value is the index of the last sorted entry.
        assert_eq!(sorted_end, 2);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 150);
        assert_eq!(moves[2].value, 100);
    }

    /// A value exactly equal to the limit counts as "good" (`>=`).
    #[test]
    fn test_partial_insertion_sort_boundary_value() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 99),
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, 101),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(sorted_end, 2);
        assert_eq!(moves[0].value, 101);
        assert_eq!(moves[1].value, 100);
        assert_eq!(moves[2].value, 99);
    }

    /// Ascending input of 20 entries, half of which reach the limit.
    #[test]
    fn test_partial_insertion_sort_large_array() {
        let mut moves: Vec<ExtMove> = (0..20).map(|i| ExtMove::new(Move::NONE, i * 10)).collect();
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(sorted_end, 10);
        assert_eq!(moves[0].value, 190);
        assert_eq!(moves[1].value, 180);
        assert_eq!(moves[9].value, 100);
    }

    /// Nothing reaches the limit, so the sorted prefix never grows.
    #[test]
    fn test_partial_insertion_sort_no_good_moves() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 10),
            ExtMove::new(Move::NONE, 20),
            ExtMove::new(Move::NONE, 30),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(sorted_end, 0);
    }

    /// Everything reaches the limit, so the whole slice is sorted.
    #[test]
    fn test_partial_insertion_sort_all_good_moves() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, 150),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 50);
        assert_eq!(sorted_end, 2);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 150);
        assert_eq!(moves[2].value, 100);
    }

    /// A limit of `i32::MIN` turns the routine into a full descending sort.
    #[test]
    fn test_partial_insertion_sort_full_sort() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 50),
            ExtMove::new(Move::NONE, -100),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, 0),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, i32::MIN);
        assert_eq!(sorted_end, 3);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 50);
        assert_eq!(moves[2].value, 0);
        assert_eq!(moves[3].value, -100);
    }

    /// Empty input is accepted and reports index 0.
    #[test]
    fn test_partial_insertion_sort_empty() {
        let mut moves: Vec<ExtMove> = vec![];
        let sorted_end = partial_insertion_sort(&mut moves, 0, 100);
        assert_eq!(sorted_end, 0);
    }

    /// A single element is left untouched whether or not it reaches the limit.
    #[test]
    fn test_partial_insertion_sort_single_element() {
        let mut moves = vec![ExtMove::new(Move::NONE, 150)];
        let sorted_end = partial_insertion_sort(&mut moves, 1, 100);
        assert_eq!(sorted_end, 0);
        assert_eq!(moves[0].value, 150);

        let mut moves2 = vec![ExtMove::new(Move::NONE, 50)];
        let sorted_end2 = partial_insertion_sort(&mut moves2, 1, 100);
        assert_eq!(sorted_end2, 0);
    }

    /// Spot-checks the piece value table.
    #[test]
    fn test_piece_value() {
        assert_eq!(piece_value(Piece::B_PAWN), 90);
        assert_eq!(piece_value(Piece::W_GOLD), 540);
        assert_eq!(piece_value(Piece::B_ROOK), 990);
        assert_eq!(piece_value(Piece::W_HORSE), 945);
        assert_eq!(piece_value(Piece::W_DRAGON), 1395);
    }

    /// Mixed positive and negative values: non-negative ones come first in
    /// descending order, negative ones stay behind them.
    #[test]
    fn test_partial_insertion_sort_order() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, -200),
            ExtMove::new(Move::NONE, 50),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, -100),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 0);
        assert_eq!(sorted_end, 2);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 100);
        assert_eq!(moves[2].value, 50);
        assert!(moves[3].value < 0);
        assert!(moves[4].value < 0);
    }
}