use super::{
ButterflyHistory, CapturePieceToHistory, LowPlyHistory, PawnHistory, PieceToHistory,
LOW_PLY_HISTORY_SIZE,
};
use crate::movegen::{ExtMove, ExtMoveBuffer};
use crate::position::Position;
use crate::types::{Depth, Move, Piece, PieceType, Value, DEPTH_QS};
/// Stages of the move picker's state machine.
///
/// Each constructor of `MovePicker` starts in one of four pipelines (main
/// search, ProbCut, check evasion, quiescence). Each pipeline begins with a
/// `*TT` stage and ends in a terminal stage that `Stage::next` maps to itself.
/// The discriminants are spelled out explicitly; they match declaration order.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum Stage {
    /// Main search: return the TT move.
    MainTT = 0,
    /// ProbCut: return the TT move (a capture that met the SEE threshold).
    ProbCutTT = 1,
    /// Main search / quiescence: generate, score and sort captures.
    CaptureInit = 2,
    /// ProbCut: generate, score and sort captures.
    ProbCutInit = 3,
    /// ProbCut: yield captures whose SEE meets the threshold (terminal).
    ProbCut = 4,
    /// Main search: yield captures passing the SEE test; defer the others.
    GoodCapture = 5,
    /// Main search: generate, score and partially sort quiet moves unless skipped.
    QuietInit = 6,
    /// Main search: yield the quiets that scored at or above the depth-based limit.
    GoodQuiet = 7,
    /// Main search: yield the captures deferred during `GoodCapture`.
    BadCapture = 8,
    /// Main search: yield the remaining quiets (terminal).
    BadQuiet = 9,
    /// In check: return the TT move.
    EvasionTT = 10,
    /// In check: generate, score and sort evasions.
    EvasionInit = 11,
    /// In check: yield evasions (terminal).
    Evasion = 12,
    /// Quiescence: return the TT move.
    QSearchTT = 13,
    /// Quiescence: generate, score and sort captures.
    QCaptureInit = 14,
    /// Quiescence: yield captures (terminal).
    QCapture = 15,
}
impl Stage {
pub fn next(self) -> Self {
match self {
Stage::MainTT => Stage::CaptureInit,
Stage::CaptureInit => Stage::GoodCapture,
Stage::ProbCutTT => Stage::ProbCutInit,
Stage::ProbCutInit => Stage::ProbCut,
Stage::ProbCut => Stage::ProbCut, Stage::GoodCapture => Stage::QuietInit,
Stage::QuietInit => Stage::GoodQuiet,
Stage::GoodQuiet => Stage::BadCapture,
Stage::BadCapture => Stage::BadQuiet,
Stage::BadQuiet => Stage::BadQuiet,
Stage::EvasionTT => Stage::EvasionInit,
Stage::EvasionInit => Stage::Evasion,
Stage::Evasion => Stage::Evasion,
Stage::QSearchTT => Stage::QCaptureInit,
Stage::QCaptureInit => Stage::QCapture,
Stage::QCapture => Stage::QCapture, }
}
}
/// Staged move picker: hands out the moves of a position one at a time, in
/// stage order (TT move first, then generated moves ordered by their scores).
pub struct MovePicker<'a> {
    // ---- Inputs fixed at construction ----
    /// Position whose moves are picked.
    pos: &'a Position,
    /// Transposition-table move: returned by the `*TT` stages and skipped
    /// whenever it shows up again among generated moves.
    tt_move: Move,
    /// Search depth; scales the limit that separates good from bad quiets.
    depth: Depth,
    /// Search ply; indexes `low_ply_history` when below `LOW_PLY_HISTORY_SIZE`.
    ply: i32,
    /// SEE threshold for ProbCut pickers, `None` for every other picker.
    probcut_threshold: Option<Value>,
    /// When set, the `*All` generation types and the matching legality
    /// check are used.
    generate_all_legal_moves: bool,

    // ---- History tables read while scoring ----
    main_history: &'a ButterflyHistory,
    low_ply_history: &'a LowPlyHistory,
    capture_history: &'a CapturePieceToHistory,
    /// Quiet scoring reads indices 0-3 and 5; evasion scoring reads index 0.
    continuation_history: [&'a PieceToHistory; 6],
    pawn_history: &'a PawnHistory,

    // ---- Picking state ----
    /// Current state of the stage machine.
    stage: Stage,
    /// Set via `skip_quiet_moves`; suppresses quiet generation and picking.
    skip_quiets: bool,
    /// Generated moves. In the main pipeline, captures occupy
    /// `[0, end_captures)`, and the ones failing SEE are compacted into
    /// `[0, end_bad_captures)` during `GoodCapture`. Quiets occupy
    /// `[end_captures, end_generated)`, with the good ones in
    /// `[end_captures, end_good_quiets)`. Evasions occupy `[0, end_generated)`.
    moves: ExtMoveBuffer,
    /// Next index of `moves` to examine.
    cur: usize,
    /// Exclusive end of the range currently being picked from.
    end_cur: usize,
    end_bad_captures: usize,
    end_captures: usize,
    end_good_quiets: usize,
    end_generated: usize,
}
impl<'a> MovePicker<'a> {
pub fn new(
pos: &'a Position,
tt_move: Move,
depth: Depth,
main_history: &'a ButterflyHistory,
low_ply_history: &'a LowPlyHistory,
capture_history: &'a CapturePieceToHistory,
continuation_history: [&'a PieceToHistory; 6],
pawn_history: &'a PawnHistory,
ply: i32,
generate_all_legal_moves: bool,
) -> Self {
let stage = if pos.in_check() {
if tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves) {
Stage::EvasionTT
} else {
Stage::EvasionInit
}
} else if depth > DEPTH_QS {
if tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves) {
Stage::MainTT
} else {
Stage::CaptureInit
}
} else {
if tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves) {
Stage::QSearchTT
} else {
Stage::QCaptureInit
}
};
Self {
pos,
main_history,
low_ply_history,
capture_history,
continuation_history,
pawn_history,
stage,
tt_move,
probcut_threshold: None,
depth,
ply,
skip_quiets: false,
generate_all_legal_moves,
moves: ExtMoveBuffer::new(),
cur: 0,
end_cur: 0,
end_bad_captures: 0,
end_captures: 0,
end_generated: 0,
end_good_quiets: 0,
}
}
pub fn new_evasions(
pos: &'a Position,
tt_move: Move,
main_history: &'a ButterflyHistory,
low_ply_history: &'a LowPlyHistory,
capture_history: &'a CapturePieceToHistory,
continuation_history: [&'a PieceToHistory; 6],
pawn_history: &'a PawnHistory,
ply: i32,
generate_all_legal_moves: bool,
) -> Self {
debug_assert!(pos.in_check());
let stage =
if tt_move.is_some() && pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves) {
Stage::EvasionTT
} else {
Stage::EvasionInit
};
Self {
pos,
main_history,
low_ply_history,
capture_history,
continuation_history,
pawn_history,
stage,
tt_move,
probcut_threshold: None,
depth: DEPTH_QS,
ply,
skip_quiets: false,
generate_all_legal_moves,
moves: ExtMoveBuffer::new(),
cur: 0,
end_cur: 0,
end_bad_captures: 0,
end_captures: 0,
end_generated: 0,
end_good_quiets: 0,
}
}
pub fn new_probcut(
pos: &'a Position,
tt_move: Move,
threshold: Value,
main_history: &'a ButterflyHistory,
low_ply_history: &'a LowPlyHistory,
capture_history: &'a CapturePieceToHistory,
continuation_history: [&'a PieceToHistory; 6],
pawn_history: &'a PawnHistory,
ply: i32,
generate_all_legal_moves: bool,
) -> Self {
debug_assert!(!pos.in_check());
let stage = if tt_move.is_some()
&& pos.is_capture(tt_move)
&& pos.pseudo_legal_with_all(tt_move, generate_all_legal_moves)
&& pos.see_ge(tt_move, threshold)
{
Stage::ProbCutTT
} else {
Stage::ProbCutInit
};
Self {
pos,
main_history,
low_ply_history,
capture_history,
continuation_history,
pawn_history,
stage,
tt_move,
probcut_threshold: Some(threshold),
depth: DEPTH_QS,
ply,
skip_quiets: false,
generate_all_legal_moves,
moves: ExtMoveBuffer::new(),
cur: 0,
end_cur: 0,
end_bad_captures: 0,
end_captures: 0,
end_generated: 0,
end_good_quiets: 0,
}
}
pub fn skip_quiet_moves(&mut self) {
self.skip_quiets = true;
}
pub fn next_move(&mut self) -> Move {
loop {
match self.stage {
Stage::MainTT | Stage::EvasionTT | Stage::QSearchTT | Stage::ProbCutTT => {
self.stage = self.stage.next();
return self.tt_move;
}
Stage::CaptureInit | Stage::QCaptureInit | Stage::ProbCutInit => {
self.cur = 0;
self.end_bad_captures = 0;
let count = if self.generate_all_legal_moves {
let gen_type = if matches!(self.stage, Stage::ProbCutInit) {
crate::movegen::GenType::CapturesAll
} else {
crate::movegen::GenType::CapturesProPlusAll
};
let mut buf = ExtMoveBuffer::new();
crate::movegen::generate_with_type(self.pos, gen_type, &mut buf, None);
let mut c = 0;
for ext in buf.iter() {
if self.pos.is_capture(ext.mv) {
self.moves.set(c, ExtMove::new(ext.mv, 0));
c += 1;
}
}
self.moves.set_len(c);
c
} else if matches!(self.stage, Stage::ProbCutInit) {
let mut buf = ExtMoveBuffer::new();
crate::movegen::generate_with_type(
self.pos,
crate::movegen::GenType::CapturesProPlus,
&mut buf,
None,
);
let mut tmp_count = 0usize;
for ext in buf.iter() {
if self.pos.is_capture(ext.mv) {
self.moves.set(tmp_count, ExtMove::new(ext.mv, 0));
tmp_count += 1;
}
}
self.moves.set_len(tmp_count);
tmp_count
} else {
self.pos.generate_captures(&mut self.moves)
};
self.end_cur = count;
self.end_captures = count;
self.score_captures();
partial_insertion_sort(self.moves.as_mut_slice(), self.end_cur, i32::MIN);
self.stage = self.stage.next();
}
Stage::GoodCapture => {
if let Some(m) = self.select_good_capture() {
return m;
}
self.stage = Stage::QuietInit;
}
Stage::QuietInit => {
if !self.skip_quiets {
let count = if self.generate_all_legal_moves {
let mut buf = ExtMoveBuffer::new();
crate::movegen::generate_with_type(
self.pos,
crate::movegen::GenType::QuietsAll,
&mut buf,
None,
);
let mut c = 0;
for ext in buf.iter() {
if !self.pos.is_capture(ext.mv) {
self.moves.set(self.end_captures + c, ExtMove::new(ext.mv, 0));
c += 1;
}
}
c
} else {
self.pos.generate_quiets(&mut self.moves, self.end_captures)
};
self.end_cur = self.end_captures + count;
self.end_generated = self.end_cur;
self.moves.set_len(self.end_cur);
self.cur = self.end_captures;
self.score_quiets();
let limit = -3560 * self.depth;
let quiet_count = self.end_cur - self.end_captures;
let good_count = partial_insertion_sort(
&mut self.moves.as_mut_slice()[self.end_captures..],
quiet_count,
limit,
);
self.end_good_quiets = self.end_captures + good_count;
} else {
self.end_good_quiets = self.end_captures;
}
self.stage = Stage::GoodQuiet;
}
Stage::GoodQuiet => {
if !self.skip_quiets {
self.end_cur = self.end_good_quiets;
if let Some(m) = self.select(|_, _| true) {
return m;
}
}
self.cur = 0;
self.end_cur = self.end_bad_captures;
self.stage = Stage::BadCapture;
}
Stage::BadCapture => {
if let Some(m) = self.select(|_, _| true) {
return m;
}
self.cur = self.end_good_quiets;
self.end_cur = self.end_generated;
self.stage = Stage::BadQuiet;
}
Stage::BadQuiet => {
if !self.skip_quiets {
if let Some(m) = self.select(|_, _| true) {
return m;
}
}
return Move::NONE;
}
Stage::EvasionInit => {
let count = if self.generate_all_legal_moves {
let mut buf = ExtMoveBuffer::new();
crate::movegen::generate_with_type(
self.pos,
crate::movegen::GenType::EvasionsAll,
&mut buf,
None,
);
let gen_count = buf.len();
for (i, ext) in buf.iter().enumerate() {
self.moves.set(i, ExtMove::new(ext.mv, 0));
}
self.moves.set_len(gen_count);
gen_count
} else {
self.pos.generate_evasions_ext(&mut self.moves)
};
self.cur = 0;
self.end_cur = count;
self.end_generated = count;
self.score_evasions();
partial_insertion_sort(self.moves.as_mut_slice(), self.end_cur, i32::MIN);
self.stage = Stage::Evasion;
}
Stage::Evasion => {
return self.select(|_, _| true).unwrap_or(Move::NONE);
}
Stage::QCapture => {
return self.select(|_, _| true).unwrap_or(Move::NONE);
}
Stage::ProbCut => {
if let Some(th) = self.probcut_threshold {
return self
.select(|s, ext| s.pos.see_ge(ext.mv, th))
.unwrap_or(Move::NONE);
}
return Move::NONE;
}
}
}
}
fn score_captures(&mut self) {
for i in self.cur..self.end_cur {
let ext = self.moves.get(i);
let m = ext.mv;
let to = m.to();
let pc = m.moved_piece_after();
let pt = pc.piece_type();
let captured = self.pos.piece_on(to);
let captured_pt = captured.piece_type();
let mut value = self.capture_history.get(pc, to, captured_pt) as i32;
value += 7 * piece_value(captured);
if self.pos.check_squares(pt).contains(to) {
value += 1024;
}
self.moves.set_value(i, value);
}
}
fn score_quiets(&mut self) {
let us = self.pos.side_to_move();
let pawn_idx = self.pos.pawn_history_index();
for i in self.cur..self.end_cur {
let ext = self.moves.get(i);
let m = ext.mv;
let to = m.to();
let pc = m.moved_piece_after();
let pt = pc.piece_type();
let mut value = 0i32;
value += 2 * self.main_history.get(us, m) as i32;
value += 2 * self.pawn_history.get(pawn_idx, pc, to) as i32;
for (idx, weight) in [(0, 1), (1, 1), (2, 1), (3, 1), (5, 1)] {
let ch = self.continuation_history[idx];
value += weight * ch.get(pc, to) as i32;
}
if self.pos.check_squares(pt).contains(to) && self.pos.see_ge(m, Value::new(-75)) {
value += 16384;
}
if self.ply < LOW_PLY_HISTORY_SIZE as i32 {
let ply_idx = self.ply as usize;
value += 8 * self.low_ply_history.get(ply_idx, m) as i32 / (1 + self.ply);
}
self.moves.set_value(i, value);
}
}
fn score_evasions(&mut self) {
let us = self.pos.side_to_move();
for i in self.cur..self.end_cur {
let ext = self.moves.get(i);
let m = ext.mv;
let to = m.to();
let pc = m.moved_piece_after();
if self.pos.is_capture(m) {
let captured = self.pos.piece_on(to);
self.moves.set_value(i, piece_value(captured) + (1 << 28));
} else {
let mut value = self.main_history.get(us, m) as i32;
let ch = self.continuation_history[0];
value += ch.get(pc, to) as i32;
if self.ply < LOW_PLY_HISTORY_SIZE as i32 {
let ply_idx = self.ply as usize;
value += 2 * self.low_ply_history.get(ply_idx, m) as i32 / (1 + self.ply);
}
self.moves.set_value(i, value);
}
}
}
fn select_good_capture(&mut self) -> Option<Move> {
while self.cur < self.end_cur {
let ext = self.moves.get(self.cur);
self.cur += 1;
if ext.mv == self.tt_move {
continue;
}
let threshold = Value::new(-ext.value / 18);
if self.pos.see_ge(ext.mv, threshold) {
return Some(ext.mv);
} else {
self.moves.swap(self.end_bad_captures, self.cur - 1);
self.end_bad_captures += 1;
}
}
None
}
fn select<F>(&mut self, filter: F) -> Option<Move>
where
F: Fn(&Self, &ExtMove) -> bool,
{
while self.cur < self.end_cur {
let ext = self.moves.get(self.cur);
self.cur += 1;
if ext.mv == self.tt_move {
continue;
}
if filter(self, &ext) {
return Some(ext.mv);
}
}
None
}
}
impl Iterator for MovePicker<'_> {
    type Item = Move;

    /// Adapts `MovePicker::next_move` to the iterator protocol: the
    /// `Move::NONE` end marker becomes `None`.
    fn next(&mut self) -> Option<Self::Item> {
        match self.next_move() {
            picked if picked == Move::NONE => None,
            picked => Some(picked),
        }
    }
}
const SORT_SWITCH_THRESHOLD: usize = 16;
fn partial_insertion_sort(moves: &mut [ExtMove], end: usize, limit: i32) -> usize {
if end == 0 {
return 0;
}
if end == 1 {
return if moves[0].value >= limit { 1 } else { 0 };
}
let slice = &mut moves[..end];
if limit == i32::MIN {
if end > SORT_SWITCH_THRESHOLD {
slice.sort_unstable_by(|a, b| b.value.cmp(&a.value));
} else {
for i in 1..end {
let tmp = slice[i];
let mut j = i;
while j > 0 && slice[j - 1].value < tmp.value {
slice[j] = slice[j - 1];
j -= 1;
}
slice[j] = tmp;
}
}
return end;
}
let mut good_count = 0;
for i in 0..end {
if slice[i].value >= limit {
slice.swap(i, good_count);
good_count += 1;
}
}
if good_count == 0 {
return 0;
}
let good_slice = &mut slice[..good_count];
if good_count > SORT_SWITCH_THRESHOLD {
good_slice.sort_unstable_by(|a, b| b.value.cmp(&a.value));
} else {
for i in 1..good_count {
let tmp = good_slice[i];
let mut j = i;
while j > 0 && good_slice[j - 1].value < tmp.value {
good_slice[j] = good_slice[j - 1];
j -= 1;
}
good_slice[j] = tmp;
}
}
good_count
}
pub(crate) fn piece_value(pc: Piece) -> i32 {
if pc.is_none() {
return 0;
}
use PieceType::*;
match pc.piece_type() {
Pawn => 90,
Lance => 315,
Knight => 405,
Silver => 495,
Gold | ProPawn | ProLance | ProKnight | ProSilver => 540,
Bishop => 855,
Rook => 990,
Horse => 1089, Dragon => 1224, King => 15000,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stage_next() {
        assert_eq!(Stage::MainTT.next(), Stage::CaptureInit);
        assert_eq!(Stage::CaptureInit.next(), Stage::GoodCapture);
        assert_eq!(Stage::GoodCapture.next(), Stage::QuietInit);
        assert_eq!(Stage::QuietInit.next(), Stage::GoodQuiet);
        assert_eq!(Stage::GoodQuiet.next(), Stage::BadCapture);
        assert_eq!(Stage::BadCapture.next(), Stage::BadQuiet);
        assert_eq!(Stage::BadQuiet.next(), Stage::BadQuiet);
        assert_eq!(Stage::EvasionTT.next(), Stage::EvasionInit);
        assert_eq!(Stage::EvasionInit.next(), Stage::Evasion);
        assert_eq!(Stage::Evasion.next(), Stage::Evasion);
        assert_eq!(Stage::QSearchTT.next(), Stage::QCaptureInit);
        assert_eq!(Stage::QCaptureInit.next(), Stage::QCapture);
        assert_eq!(Stage::QCapture.next(), Stage::QCapture);
    }

    /// The ProbCut pipeline was not covered by `test_stage_next`.
    #[test]
    fn test_stage_next_probcut() {
        assert_eq!(Stage::ProbCutTT.next(), Stage::ProbCutInit);
        assert_eq!(Stage::ProbCutInit.next(), Stage::ProbCut);
        assert_eq!(Stage::ProbCut.next(), Stage::ProbCut);
    }

    #[test]
    fn test_partial_insertion_sort() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, 50),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, 10),
            ExtMove::new(Move::NONE, 150),
        ];
        let len = moves.len();
        let good_count = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(good_count, 3);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 150);
        assert_eq!(moves[2].value, 100);
    }

    #[test]
    fn test_partial_insertion_sort_boundary_value() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 99),
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, 101),
        ];
        let len = moves.len();
        let good_count = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(good_count, 2);
        assert_eq!(moves[0].value, 101);
        assert_eq!(moves[1].value, 100);
        assert_eq!(moves[2].value, 99);
    }

    #[test]
    fn test_partial_insertion_sort_large_array() {
        let mut moves: Vec<ExtMove> = (0..20).map(|i| ExtMove::new(Move::NONE, i * 10)).collect();
        let len = moves.len();
        let good_count = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(good_count, 10);
        assert_eq!(moves[0].value, 190);
        assert_eq!(moves[1].value, 180);
        assert_eq!(moves[9].value, 100);
    }

    #[test]
    fn test_partial_insertion_sort_no_good_moves() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 10),
            ExtMove::new(Move::NONE, 20),
            ExtMove::new(Move::NONE, 30),
        ];
        let len = moves.len();
        let good_count = partial_insertion_sort(&mut moves, len, 100);
        assert_eq!(good_count, 0);
    }

    #[test]
    fn test_partial_insertion_sort_all_good_moves() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, 150),
        ];
        let len = moves.len();
        let good_count = partial_insertion_sort(&mut moves, len, 50);
        assert_eq!(good_count, 3);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 150);
        assert_eq!(moves[2].value, 100);
    }

    #[test]
    fn test_partial_insertion_sort_full_sort() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 50),
            ExtMove::new(Move::NONE, -100),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, 0),
        ];
        let len = moves.len();
        let good_count = partial_insertion_sort(&mut moves, len, i32::MIN);
        assert_eq!(good_count, 4);
        assert_eq!(moves[0].value, 200);
        assert_eq!(moves[1].value, 50);
        assert_eq!(moves[2].value, 0);
        assert_eq!(moves[3].value, -100);
    }

    /// Exercises the `> SORT_SWITCH_THRESHOLD` branch with `limit == i32::MIN`.
    #[test]
    fn test_partial_insertion_sort_full_sort_large() {
        // (i * 37) % 20 is a permutation of 0..20 because gcd(37, 20) == 1.
        let mut moves: Vec<ExtMove> = (0..20)
            .map(|i| ExtMove::new(Move::NONE, (i * 37) % 20 - 10))
            .collect();
        let len = moves.len();
        let good_count = partial_insertion_sort(&mut moves, len, i32::MIN);
        assert_eq!(good_count, 20);
        assert!(moves.windows(2).all(|w| w[0].value >= w[1].value));
        assert_eq!(moves[0].value, 9);
        assert_eq!(moves[19].value, -10);
    }

    /// Elements at or past `end` must be left untouched.
    #[test]
    fn test_partial_insertion_sort_ignores_tail_past_end() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 1),
            ExtMove::new(Move::NONE, 2),
            ExtMove::new(Move::NONE, 300),
        ];
        let good_count = partial_insertion_sort(&mut moves, 2, i32::MIN);
        assert_eq!(good_count, 2);
        assert_eq!(moves[0].value, 2);
        assert_eq!(moves[1].value, 1);
        assert_eq!(moves[2].value, 300);
    }

    #[test]
    fn test_partial_insertion_sort_empty() {
        let mut moves: Vec<ExtMove> = vec![];
        let good_count = partial_insertion_sort(&mut moves, 0, 100);
        assert_eq!(good_count, 0);
    }

    #[test]
    fn test_partial_insertion_sort_single_element() {
        let mut moves = vec![ExtMove::new(Move::NONE, 150)];
        let good_count = partial_insertion_sort(&mut moves, 1, 100);
        assert_eq!(good_count, 1);
        assert_eq!(moves[0].value, 150);

        let mut moves2 = vec![ExtMove::new(Move::NONE, 50)];
        let good_count2 = partial_insertion_sort(&mut moves2, 1, 100);
        assert_eq!(good_count2, 0);
    }

    #[test]
    fn test_piece_value() {
        assert_eq!(piece_value(Piece::B_PAWN), 90);
        assert_eq!(piece_value(Piece::W_GOLD), 540);
        assert_eq!(piece_value(Piece::B_ROOK), 990);
        assert_eq!(piece_value(Piece::W_DRAGON), 1224);
    }

    #[test]
    fn test_end_good_quiets_boundary() {
        let mut moves = vec![
            ExtMove::new(Move::NONE, 100),
            ExtMove::new(Move::NONE, -200),
            ExtMove::new(Move::NONE, 50),
            ExtMove::new(Move::NONE, 200),
            ExtMove::new(Move::NONE, -100),
        ];
        let len = moves.len();
        let good_count = partial_insertion_sort(&mut moves, len, 0);
        assert_eq!(good_count, 3);
        for (i, m) in moves.iter().enumerate().take(good_count) {
            assert!(m.value >= 0, "Move at index {i} should have value >= 0, got {}", m.value);
        }
        for (i, m) in moves.iter().enumerate().skip(good_count) {
            assert!(m.value < 0, "Move at index {i} should have value < 0, got {}", m.value);
        }
    }
}