use super::{HistoryTables, LOW_PLY_HISTORY_SIZE, PieceToHistory};
use crate::movegen::{ExtMove, ExtMoveBuffer};
use crate::position::Position;
use crate::types::{Color, DEPTH_QS, Depth, Move, Piece, PieceType, Value};
/// Stages of the staged move picker; `MovePicker::next_move` dispatches on the
/// current stage and advances it as each class of moves is exhausted.
///
/// Four pipelines share this enum: main search, ProbCut, check evasions and
/// quiescence search. `#[repr(u8)]` makes each discriminant follow declaration
/// order, so reordering variants changes their numeric values.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum Stage {
    /// Main search: return the TT move once.
    MainTT,
    /// ProbCut: return the TT move once (only entered when it is a capture-stage move).
    ProbCutTT,
    /// Main search: generate, score and sort the captures.
    CaptureInit,
    /// ProbCut: generate, score and sort the captures.
    ProbCutInit,
    /// ProbCut: return captures that pass `see_ge` against the ProbCut threshold.
    ProbCut,
    /// Main search: return captures passing `see_ge(-value / 18)`; failing ones
    /// are swapped to the front of the buffer as bad captures.
    GoodCapture,
    /// Main search: generate, score and partially sort quiets (unless quiets are skipped).
    QuietInit,
    /// Main search: return quiets whose score is above the good-quiet threshold.
    GoodQuiet,
    /// Main search: return the bad captures set aside during `GoodCapture`.
    BadCapture,
    /// Main search: return the remaining quiets (score at or below the threshold).
    BadQuiet,
    /// Evasions: return the TT move once.
    EvasionTT,
    /// Evasions: generate, score and sort all evasions.
    EvasionInit,
    /// Evasions: return the generated evasions in sorted order.
    Evasion,
    /// Quiescence search: return the TT move once.
    QSearchTT,
    /// Quiescence search: generate, score and sort the captures.
    QCaptureInit,
    /// Quiescence search: return the generated captures in sorted order.
    QCapture,
}
impl Stage {
pub fn next(self) -> Self {
match self {
Stage::MainTT => Stage::CaptureInit,
Stage::CaptureInit => Stage::GoodCapture,
Stage::ProbCutTT => Stage::ProbCutInit,
Stage::ProbCutInit => Stage::ProbCut,
Stage::ProbCut => Stage::ProbCut, Stage::GoodCapture => Stage::QuietInit,
Stage::QuietInit => Stage::GoodQuiet,
Stage::GoodQuiet => Stage::BadCapture,
Stage::BadCapture => Stage::BadQuiet,
Stage::BadQuiet => Stage::BadQuiet,
Stage::EvasionTT => Stage::EvasionInit,
Stage::EvasionInit => Stage::Evasion,
Stage::Evasion => Stage::Evasion,
Stage::QSearchTT => Stage::QCaptureInit,
Stage::QCaptureInit => Stage::QCapture,
Stage::QCapture => Stage::QCapture, }
}
}
/// Staged move picker: `next_move` yields moves one at a time, generating and
/// scoring each class of moves only when its stage is reached.
pub struct MovePicker {
    /// Pointers to the six continuation-history tables given to the constructor.
    /// Quiet scoring reads indices 0, 1, 2, 3 and 5; evasion scoring reads index 0.
    /// NOTE(review): no lifetime ties these to the picker; callers must keep the
    /// tables alive for as long as the picker is used.
    continuation_history: [*const PieceToHistory; 6],
    /// Current stage of the pipeline.
    stage: Stage,
    /// Move returned by the `*TT` stages; later stages skip moves equal to it.
    tt_move: Move,
    /// SEE threshold for the `ProbCut` stage; `Some` only for ProbCut pickers.
    probcut_threshold: Option<Value>,
    /// Search depth; used for the quiet-sort limit `-3560 * depth`.
    depth: Depth,
    /// Distance from the root; selects whether low-ply history is applied.
    ply: i32,
    /// When set, quiet moves are no longer generated or returned.
    skip_quiets: bool,
    /// Selects the `*All` generator variants and is forwarded to `pseudo_legal_with_all`.
    generate_all_legal_moves: bool,
    /// Side to move at construction; indexes the main history.
    side_to_move: Color,
    /// `pos.pawn_history_index()` cached at construction; indexes the pawn history.
    pawn_history_index: usize,
    /// Move buffer: captures occupy `[0, end_captures)`, quiets follow up to `end_generated`.
    moves: ExtMoveBuffer,
    /// Read cursor into `moves` for the current stage.
    cur: usize,
    /// Exclusive end of the range the current stage reads.
    end_cur: usize,
    /// Exclusive end of the bad captures swapped to the front during `GoodCapture`.
    end_bad_captures: usize,
    /// Number of generated captures, i.e. where quiets start.
    end_captures: usize,
    /// Exclusive end of generated quiets (or of generated evasions).
    end_generated: usize,
    /// Written in `QuietInit` but not read anywhere in this file. It holds
    /// `end_captures` plus `partial_insertion_sort`'s return value, which is an
    /// index of the last sorted entry rather than an exclusive end.
    end_good_quiets: usize,
}
impl MovePicker {
pub fn new(
pos: &Position,
tt_move: Move,
depth: Depth,
ply: i32,
continuation_history: [&PieceToHistory; 6],
generate_all_legal_moves: bool,
) -> Self {
let stage = if pos.in_check() {
if tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves) {
Stage::EvasionTT
} else {
Stage::EvasionInit
}
} else if depth > DEPTH_QS {
if tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves) {
Stage::MainTT
} else {
Stage::CaptureInit
}
} else {
if tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves) {
Stage::QSearchTT
} else {
Stage::QCaptureInit
}
};
Self {
continuation_history: [
continuation_history[0] as *const _,
continuation_history[1] as *const _,
continuation_history[2] as *const _,
continuation_history[3] as *const _,
continuation_history[4] as *const _,
continuation_history[5] as *const _,
],
stage,
tt_move,
probcut_threshold: None,
depth,
ply,
skip_quiets: false,
generate_all_legal_moves,
side_to_move: pos.side_to_move(),
pawn_history_index: pos.pawn_history_index(),
moves: ExtMoveBuffer::new(),
cur: 0,
end_cur: 0,
end_bad_captures: 0,
end_captures: 0,
end_generated: 0,
end_good_quiets: 0,
}
}
pub fn new_evasions(
pos: &Position,
tt_move: Move,
ply: i32,
continuation_history: [&PieceToHistory; 6],
generate_all_legal_moves: bool,
) -> Self {
debug_assert!(pos.in_check());
let stage =
if tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves) {
Stage::EvasionTT
} else {
Stage::EvasionInit
};
Self {
continuation_history: [
continuation_history[0] as *const _,
continuation_history[1] as *const _,
continuation_history[2] as *const _,
continuation_history[3] as *const _,
continuation_history[4] as *const _,
continuation_history[5] as *const _,
],
stage,
tt_move,
probcut_threshold: None,
depth: DEPTH_QS,
ply,
skip_quiets: false,
generate_all_legal_moves,
side_to_move: pos.side_to_move(),
pawn_history_index: pos.pawn_history_index(),
moves: ExtMoveBuffer::new(),
cur: 0,
end_cur: 0,
end_bad_captures: 0,
end_captures: 0,
end_generated: 0,
end_good_quiets: 0,
}
}
pub fn new_probcut(
pos: &Position,
tt_move: Move,
threshold: Value,
ply: i32,
continuation_history: [&PieceToHistory; 6],
generate_all_legal_moves: bool,
) -> Self {
debug_assert!(!pos.in_check());
let stage = if tt_move.is_some()
&& pos.capture_stage(tt_move)
&& pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves)
{
Stage::ProbCutTT
} else {
Stage::ProbCutInit
};
Self {
continuation_history: [
continuation_history[0] as *const _,
continuation_history[1] as *const _,
continuation_history[2] as *const _,
continuation_history[3] as *const _,
continuation_history[4] as *const _,
continuation_history[5] as *const _,
],
stage,
tt_move,
probcut_threshold: Some(threshold),
depth: DEPTH_QS,
ply,
skip_quiets: false,
generate_all_legal_moves,
side_to_move: pos.side_to_move(),
pawn_history_index: pos.pawn_history_index(),
moves: ExtMoveBuffer::new(),
cur: 0,
end_cur: 0,
end_bad_captures: 0,
end_captures: 0,
end_generated: 0,
end_good_quiets: 0,
}
}
pub fn skip_quiets(&mut self) {
self.skip_quiets = true;
}
#[inline]
pub fn is_quiet_stage(&self) -> bool {
matches!(self.stage, Stage::QuietInit | Stage::GoodQuiet | Stage::BadQuiet)
}
#[inline]
pub fn stage(&self) -> Stage {
self.stage
}
pub fn next_move(&mut self, pos: &Position, history: &HistoryTables) -> Move {
loop {
match self.stage {
Stage::MainTT | Stage::EvasionTT | Stage::QSearchTT | Stage::ProbCutTT => {
self.stage = self.stage.next();
return self.tt_move;
}
Stage::CaptureInit | Stage::QCaptureInit | Stage::ProbCutInit => {
self.cur = 0;
self.end_bad_captures = 0;
self.moves.clear();
let count = if self.generate_all_legal_moves {
crate::movegen::generate_with_type(
pos,
crate::movegen::GenType::CapturesAll,
&mut self.moves,
None,
);
self.moves.len()
} else {
crate::movegen::generate_with_type(
pos,
crate::movegen::GenType::Captures,
&mut self.moves,
None,
);
self.moves.len()
};
self.end_cur = count;
self.end_captures = count;
self.score_captures(pos, history);
partial_insertion_sort(self.moves.as_mut_slice(), self.end_cur, i32::MIN);
self.stage = self.stage.next();
}
Stage::GoodCapture => {
if let Some(m) = self.select_good_capture(pos) {
return m;
}
self.stage = Stage::QuietInit;
}
Stage::QuietInit => {
if !self.skip_quiets {
let count = if self.generate_all_legal_moves {
crate::movegen::generate_with_type(
pos,
crate::movegen::GenType::QuietsAll,
&mut self.moves,
None,
);
self.moves.len() - self.end_captures
} else {
pos.generate_quiets(&mut self.moves, self.end_captures)
};
self.end_cur = self.end_captures + count;
self.end_generated = self.end_cur;
self.moves.set_len(self.end_cur);
self.cur = self.end_captures;
self.score_quiets(pos, history);
let limit = -3560 * self.depth;
let quiet_count = self.end_cur - self.end_captures;
let sorted_end = partial_insertion_sort(
&mut self.moves.as_mut_slice()[self.end_captures..],
quiet_count,
limit,
);
self.end_good_quiets = self.end_captures + sorted_end;
} else {
self.end_good_quiets = self.end_captures;
}
self.stage = Stage::GoodQuiet;
}
Stage::GoodQuiet => {
if !self.skip_quiets {
self.end_cur = self.end_generated;
if let Some(m) = self.select_good_quiet() {
return m;
}
}
self.cur = 0;
self.end_cur = self.end_bad_captures;
self.stage = Stage::BadCapture;
}
Stage::BadCapture => {
if let Some(m) = self.select_simple() {
return m;
}
self.cur = self.end_captures;
self.end_cur = self.end_generated;
self.stage = Stage::BadQuiet;
}
Stage::BadQuiet => {
if !self.skip_quiets {
if let Some(m) = self.select_bad_quiet() {
return m;
}
}
return Move::NONE;
}
Stage::EvasionInit => {
let count = if self.generate_all_legal_moves {
self.moves.clear();
crate::movegen::generate_with_type(
pos,
crate::movegen::GenType::EvasionsAll,
&mut self.moves,
None,
);
self.moves.len()
} else {
pos.generate_evasions_ext(&mut self.moves)
};
self.cur = 0;
self.end_cur = count;
self.end_generated = count;
self.score_evasions(pos, history);
partial_insertion_sort(self.moves.as_mut_slice(), self.end_cur, i32::MIN);
self.stage = Stage::Evasion;
}
Stage::Evasion => {
return self.select_simple().unwrap_or(Move::NONE);
}
Stage::QCapture => {
return self.select_simple().unwrap_or(Move::NONE);
}
Stage::ProbCut => {
if let Some(th) = self.probcut_threshold {
return self.select_probcut(pos, th).unwrap_or(Move::NONE);
}
return Move::NONE;
}
}
}
}
fn score_captures(&mut self, pos: &Position, history: &HistoryTables) {
for ext in &mut self.moves.as_mut_slice()[self.cur..self.end_cur] {
let m = ext.mv;
let to = m.to();
let pc = m.moved_piece_after();
let captured = pos.piece_on(to);
let captured_pt = captured.piece_type();
let mut value = history.capture_history.get(pc, to, captured_pt) as i32;
value += 7 * piece_value(captured);
ext.value = value;
}
}
fn score_quiets(&mut self, pos: &Position, history: &HistoryTables) {
let us = self.side_to_move;
let pawn_idx = self.pawn_history_index;
let moves = &mut self.moves.as_mut_slice()[self.cur..self.end_cur];
let (ch0, ch1, ch2, ch3, ch5) = unsafe {
(
&*self.continuation_history[0],
&*self.continuation_history[1],
&*self.continuation_history[2],
&*self.continuation_history[3],
&*self.continuation_history[5],
)
};
if self.ply < LOW_PLY_HISTORY_SIZE as i32 {
let low_ply_idx = self.ply as usize;
let low_ply_div = 1 + self.ply;
for ext in moves {
let m = ext.mv;
let to = m.to();
let pc = m.moved_piece_after();
let pt = pc.piece_type();
let mut value = 2 * history.main_history.get(us, m) as i32;
value += 2 * history.pawn_history.get(pawn_idx, pc, to) as i32;
value += ch0.get(pc, to) as i32;
value += ch1.get(pc, to) as i32;
value += ch2.get(pc, to) as i32;
value += ch3.get(pc, to) as i32;
value += ch5.get(pc, to) as i32;
if pos.check_squares(pt).contains(to) && pos.see_ge(m, Value::new(-75)) {
value += 16384;
}
value += 8 * history.low_ply_history.get(low_ply_idx, m) as i32 / low_ply_div;
ext.value = value;
}
} else {
for ext in moves {
let m = ext.mv;
let to = m.to();
let pc = m.moved_piece_after();
let pt = pc.piece_type();
let mut value = 2 * history.main_history.get(us, m) as i32;
value += 2 * history.pawn_history.get(pawn_idx, pc, to) as i32;
value += ch0.get(pc, to) as i32;
value += ch1.get(pc, to) as i32;
value += ch2.get(pc, to) as i32;
value += ch3.get(pc, to) as i32;
value += ch5.get(pc, to) as i32;
if pos.check_squares(pt).contains(to) && pos.see_ge(m, Value::new(-75)) {
value += 16384;
}
ext.value = value;
}
}
}
fn score_evasions(&mut self, pos: &Position, history: &HistoryTables) {
let us = self.side_to_move;
let ch = unsafe { &*self.continuation_history[0] };
for ext in &mut self.moves.as_mut_slice()[self.cur..self.end_cur] {
let m = ext.mv;
let to = m.to();
let pc = m.moved_piece_after();
if pos.capture_stage(m) {
let captured = pos.piece_on(to);
ext.value = piece_value(captured) + (1 << 28);
} else {
let mut value = history.main_history.get(us, m) as i32;
value += ch.get(pc, to) as i32;
ext.value = value;
}
}
}
fn select_good_capture(&mut self, pos: &Position) -> Option<Move> {
while self.cur < self.end_cur {
let ext = self.moves.get(self.cur);
self.cur += 1;
if ext.mv == self.tt_move {
continue;
}
let threshold = Value::new(-ext.value / 18);
if pos.see_ge(ext.mv, threshold) {
return Some(ext.mv);
} else {
self.moves.swap(self.end_bad_captures, self.cur - 1);
self.end_bad_captures += 1;
}
}
None
}
fn select_good_quiet(&mut self) -> Option<Move> {
const GOOD_QUIET_THRESHOLD: i32 = -14000;
while self.cur < self.end_cur {
let ext = self.moves.get(self.cur);
self.cur += 1;
if ext.mv == self.tt_move {
continue;
}
if ext.value > GOOD_QUIET_THRESHOLD {
return Some(ext.mv);
}
}
None
}
fn select_bad_quiet(&mut self) -> Option<Move> {
const GOOD_QUIET_THRESHOLD: i32 = -14000;
while self.cur < self.end_cur {
let ext = self.moves.get(self.cur);
self.cur += 1;
if ext.mv == self.tt_move {
continue;
}
if ext.value <= GOOD_QUIET_THRESHOLD {
return Some(ext.mv);
}
}
None
}
fn select_simple(&mut self) -> Option<Move> {
while self.cur < self.end_cur {
let ext = self.moves.get(self.cur);
self.cur += 1;
if ext.mv == self.tt_move {
continue;
}
return Some(ext.mv);
}
None
}
fn select_probcut(&mut self, pos: &Position, threshold: Value) -> Option<Move> {
while self.cur < self.end_cur {
let ext = self.moves.get(self.cur);
self.cur += 1;
if ext.mv == self.tt_move {
continue;
}
if pos.see_ge(ext.mv, threshold) {
return Some(ext.mv);
}
}
None
}
}
/// Partially sorts `moves[..end]` by descending `value`.
///
/// Entries with `value >= limit` are insertion-sorted into a prefix that starts
/// at index 0. The entry at index 0 always belongs to that prefix, whatever its
/// value. Entries below `limit` end up after the prefix in unspecified order,
/// and entries at `end..` are never touched.
///
/// Returns the index of the last entry of the sorted prefix: an index, not an
/// exclusive end. It is 0 when `end <= 1` or no entry after index 0 reaches
/// `limit`.
///
/// # Panics
/// Panics if `end > moves.len()`. Without this check the raw-pointer accesses
/// below would read and write out of bounds.
#[inline]
fn partial_insertion_sort(moves: &mut [ExtMove], end: usize, limit: i32) -> usize {
    assert!(
        end <= moves.len(),
        "partial_insertion_sort: end {} exceeds slice length {}",
        end,
        moves.len()
    );
    if end <= 1 {
        return 0;
    }
    let mut sorted_end: usize = 0;
    let base = moves.as_mut_ptr();
    for p in 1..end {
        // SAFETY: `p < end <= moves.len()` (asserted above). `sorted_end` grows by
        // at most one per iteration from 0, so `sorted_end <= p`, and `q` only
        // walks down from `sorted_end` while `q > 0`. Every offset used is
        // therefore in bounds of `moves`. `ExtMove` is `Copy`, so the reads and
        // writes are plain copies.
        unsafe {
            let tmp = *base.add(p);
            if tmp.value >= limit {
                sorted_end += 1;
                if p != sorted_end {
                    *base.add(p) = *base.add(sorted_end);
                }
                let mut q = sorted_end;
                while q > 0 && (*base.add(q - 1)).value < tmp.value {
                    *base.add(q) = *base.add(q - 1);
                    q -= 1;
                }
                *base.add(q) = tmp;
            }
        }
    }
    sorted_end
}
#[inline]
pub(crate) fn piece_value(pc: Piece) -> i32 {
if pc.is_none() {
return 0;
}
use PieceType::*;
match pc.piece_type() {
Pawn => 90,
Lance => 315,
Knight => 405,
Silver => 495,
Gold | ProPawn | ProLance | ProKnight | ProSilver => 540,
Bishop => 855,
Rook => 990,
Horse => 945,
Dragon => 1395,
King => 15000,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Stage::next` transition, including the ProbCut pipeline.
    #[test]
    fn test_stage_next() {
        assert_eq!(Stage::MainTT.next(), Stage::CaptureInit);
        assert_eq!(Stage::CaptureInit.next(), Stage::GoodCapture);
        assert_eq!(Stage::GoodCapture.next(), Stage::QuietInit);
        assert_eq!(Stage::QuietInit.next(), Stage::GoodQuiet);
        assert_eq!(Stage::GoodQuiet.next(), Stage::BadCapture);
        assert_eq!(Stage::BadCapture.next(), Stage::BadQuiet);
        assert_eq!(Stage::BadQuiet.next(), Stage::BadQuiet);
        assert_eq!(Stage::ProbCutTT.next(), Stage::ProbCutInit);
        assert_eq!(Stage::ProbCutInit.next(), Stage::ProbCut);
        assert_eq!(Stage::ProbCut.next(), Stage::ProbCut);
        assert_eq!(Stage::EvasionTT.next(), Stage::EvasionInit);
        assert_eq!(Stage::EvasionInit.next(), Stage::Evasion);
        assert_eq!(Stage::Evasion.next(), Stage::Evasion);
        assert_eq!(Stage::QSearchTT.next(), Stage::QCaptureInit);
        assert_eq!(Stage::QCaptureInit.next(), Stage::QCapture);
        assert_eq!(Stage::QCapture.next(), Stage::QCapture);
    }

    #[test]
    fn test_partial_insertion_sort() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, 50),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, 10),
            ExtMove::new(Move::NONE, 150),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(sorted_end, 2);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 150);
        assert_eq!(moves[2].value, 100);
    }

    #[test]
    fn test_partial_insertion_sort_boundary_value() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 99),
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, 101),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(sorted_end, 2);
        assert_eq!(moves[0].value, 101);
        assert_eq!(moves[1].value, 100);
        assert_eq!(moves[2].value, 99);
    }

    #[test]
    fn test_partial_insertion_sort_large_array() {
        let mut moves: Vec<ExtMove> =
            (0..20).map(|i| ExtMove::new(Move::NONE, i * 10)).collect();
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(sorted_end, 10);
        assert_eq!(moves[0].value, 190);
        assert_eq!(moves[1].value, 180);
        assert_eq!(moves[9].value, 100);
    }

    #[test]
    fn test_partial_insertion_sort_no_good_moves() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 10),
            ExtMove::new(Move::NONE, 20),
            ExtMove::new(Move::NONE, 30),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(sorted_end, 0);
    }

    #[test]
    fn test_partial_insertion_sort_all_good_moves() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, 150),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 50);
        assert_eq!(sorted_end, 2);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 150);
        assert_eq!(moves[2].value, 100);
    }

    #[test]
    fn test_partial_insertion_sort_full_sort() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 50),
            ExtMove::new(Move::NONE, -100),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, 0),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, i32::MIN);
        assert_eq!(sorted_end, 3);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 50);
        assert_eq!(moves[2].value, 0);
        assert_eq!(moves[3].value, -100);
    }

    /// Entries at or beyond `end` must be left exactly where they were.
    #[test]
    fn test_partial_insertion_sort_ignores_tail() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 1),
            ExtMove::new(Move::NONE, 2),
            ExtMove::new(Move::NONE, 3),
            ExtMove::new(Move::NONE, 4),
        ];
        let sorted_end = partial_insertion_sort(&mut moves, 2, i32::MIN);
        assert_eq!(sorted_end, 1);
        assert_eq!(moves[0].value, 2);
        assert_eq!(moves[1].value, 1);
        assert_eq!(moves[2].value, 3);
        assert_eq!(moves[3].value, 4);
    }

    #[test]
    fn test_partial_insertion_sort_empty() {
        let mut moves: Vec<ExtMove> = vec![];
        let sorted_end = partial_insertion_sort(&mut moves, 0, 100);
        assert_eq!(sorted_end, 0);
    }

    #[test]
    fn test_partial_insertion_sort_single_element() {
        let mut moves = vec![ExtMove::new(Move::NONE, 150)];
        let sorted_end = partial_insertion_sort(&mut moves, 1, 100);
        assert_eq!(sorted_end, 0);
        assert_eq!(moves[0].value, 150);
        let mut moves2 = vec![ExtMove::new(Move::NONE, 50)];
        let sorted_end2 = partial_insertion_sort(&mut moves2, 1, 100);
        assert_eq!(sorted_end2, 0);
    }

    #[test]
    fn test_piece_value() {
        assert_eq!(piece_value(Piece::B_PAWN), 90);
        assert_eq!(piece_value(Piece::W_GOLD), 540);
        assert_eq!(piece_value(Piece::B_ROOK), 990);
        assert_eq!(piece_value(Piece::W_HORSE), 945);
        assert_eq!(piece_value(Piece::W_DRAGON), 1395);
    }

    #[test]
    fn test_partial_insertion_sort_order() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, -200),
            ExtMove::new(Move::NONE, 50),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, -100),
        ];
        let len = moves.len();
        let sorted_end = partial_insertion_sort(&mut moves, len, 0);
        assert_eq!(sorted_end, 2);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 100);
        assert_eq!(moves[2].value, 50);
        assert!(moves[3].value < 0);
        assert!(moves[4].value < 0);
    }
}