use super::{Action, Board};
use rayon::prelude::*;
use rustc_hash::FxHasher;
use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
impl Board {
    /// Counts the leaf nodes of the game tree rooted at `self`, searched
    /// `depth` plies deep (a "perft" move-generator check).
    ///
    /// A position with no legal actions is counted as a single leaf, even when
    /// it is reached before `depth` plies have been played. `perft(0)` is `1`.
    #[must_use]
    pub fn perft(&self, depth: u64) -> u64 {
        if depth == 0 {
            return 1;
        }
        let mut count_scratch = Vec::with_capacity(SCRATCH_CAPACITY);
        let mut scratches = scratch_stack(depth);
        self.perft_inner(depth, &mut scratches, &mut count_scratch)
    }

    /// Recursive worker for [`Board::perft`]; requires `depth >= 1`.
    ///
    /// `scratches[depth - 2]` holds this ply's action list and lower indices
    /// belong to deeper plies, so `scratches` must contain at least
    /// `depth - 1` buffers. `count_scratch` is only used at the last ply,
    /// where actions are counted via `count_actions` instead of expanded.
    #[inline(always)]
    fn perft_inner(
        &self,
        depth: u64,
        scratches: &mut [Vec<Action>],
        count_scratch: &mut Vec<Action>,
    ) -> u64 {
        if depth == 1 {
            // A position with no actions still counts as one leaf.
            return self.count_actions(count_scratch).max(1);
        }
        // Split off this ply's buffer so it can be iterated while the deeper
        // buffers are lent mutably to the recursive calls.
        let (deeper, current) = scratches.split_at_mut(depth as usize - 2);
        let moves = &mut current[0];
        self.actions_into(moves);
        if moves.is_empty() {
            return 1;
        }
        let mut nodes = 0u64;
        for action in moves.iter() {
            let mut board = self.apply(action);
            board.swap_turn_();
            nodes += board.perft_inner(depth - 1, deeper, count_scratch);
        }
        nodes
    }

    /// Same count as [`Board::perft`], memoising subtree counts in a
    /// single-threaded transposition table sized from `tt_size_mb`.
    ///
    /// Falls back to the plain search when `depth <= 2` or `tt_size_mb == 0`.
    #[must_use]
    pub fn perft_tt(&self, depth: u64, tt_size_mb: usize) -> u64 {
        if depth <= 2 || tt_size_mb == 0 {
            return self.perft(depth);
        }
        let tt = TranspositionTable::new(tt_capacity_for(tt_size_mb));
        let mut count_scratch = Vec::with_capacity(SCRATCH_CAPACITY);
        let mut scratches = scratch_stack(depth);
        self.perft_tt_seq_inner(depth, &mut scratches, &mut count_scratch, &tt)
    }

    /// Recursive worker for [`Board::perft_tt`]; same buffer contract as
    /// `perft_inner`.
    ///
    /// The table is only consulted and filled for `depth >= 3`.
    /// NOTE(review): `depth as u8` truncates, so depths above 255 would alias
    /// in the table key.
    #[inline(always)]
    fn perft_tt_seq_inner(
        &self,
        depth: u64,
        scratches: &mut [Vec<Action>],
        count_scratch: &mut Vec<Action>,
        tt: &TranspositionTable,
    ) -> u64 {
        if depth == 1 {
            return self.count_actions(count_scratch).max(1);
        }
        if depth >= 3 {
            if let Some(nodes) = tt.get(self, depth as u8) {
                return nodes;
            }
        }
        let (deeper, current) = scratches.split_at_mut(depth as usize - 2);
        let moves = &mut current[0];
        self.actions_into(moves);
        if moves.is_empty() {
            return 1;
        }
        let mut nodes = 0u64;
        for action in moves.iter() {
            let mut board = self.apply(action);
            board.swap_turn_();
            nodes += board.perft_tt_seq_inner(depth - 1, deeper, count_scratch, tt);
        }
        if depth >= 3 {
            tt.insert(self, depth as u8, nodes);
        }
        nodes
    }

    /// Same count as [`Board::perft`], searching each root action's subtree in
    /// parallel with rayon.
    ///
    /// With `tt_size_mb > 0` all workers share one transposition table sized
    /// from `tt_size_mb`, and the table's hit statistics are printed to stderr
    /// when at least one lookup happened. With `tt_size_mb == 0` each root
    /// subtree is searched with the plain sequential [`Board::perft`] and
    /// nothing is printed.
    #[must_use]
    pub fn perft_parallel(&self, depth: u64, tt_size_mb: usize) -> u64 {
        if depth <= 2 {
            return self.perft(depth);
        }
        let actions = self.actions();
        if actions.is_empty() {
            return 1;
        }
        if tt_size_mb == 0 {
            // No table requested: don't count lookups into a disabled table
            // or report a meaningless 0% hit rate.
            return actions
                .par_iter()
                .map(|action| {
                    let mut board = self.apply(action);
                    board.swap_turn_();
                    board.perft(depth - 1)
                })
                .sum();
        }
        let tt = TranspositionTable::new(tt_capacity_for(tt_size_mb));
        // NOTE(review): every worker increments these two shared counters;
        // if profiling shows contention, consider per-thread counters.
        let tt_hits = AtomicU64::new(0);
        let tt_lookups = AtomicU64::new(0);
        let nodes: u64 = actions
            .par_iter()
            .map(|action| {
                let mut board = self.apply(action);
                board.swap_turn_();
                let mut count_scratch = Vec::with_capacity(SCRATCH_CAPACITY);
                let mut scratches = scratch_stack(depth - 1);
                board.perft_tt_inner(
                    depth - 1,
                    &mut scratches,
                    &mut count_scratch,
                    &tt,
                    &tt_hits,
                    &tt_lookups,
                )
            })
            .sum();
        let hits = tt_hits.load(Ordering::Relaxed);
        let lookups = tt_lookups.load(Ordering::Relaxed);
        if lookups > 0 {
            let hit_rate = (hits as f64 / lookups as f64) * 100.0;
            eprintln!(
                "TT hits: {} / {} lookups ({:.2}% hit rate)",
                hits, lookups, hit_rate
            );
        }
        nodes
    }

    /// Recursive worker for [`Board::perft_parallel`].
    ///
    /// Same as `perft_tt_seq_inner`, and additionally records every table
    /// lookup in `tt_lookups` and every successful one in `tt_hits`.
    #[inline(always)]
    fn perft_tt_inner(
        &self,
        depth: u64,
        scratches: &mut [Vec<Action>],
        count_scratch: &mut Vec<Action>,
        tt: &TranspositionTable,
        tt_hits: &AtomicU64,
        tt_lookups: &AtomicU64,
    ) -> u64 {
        if depth == 1 {
            return self.count_actions(count_scratch).max(1);
        }
        if depth >= 3 {
            tt_lookups.fetch_add(1, Ordering::Relaxed);
            if let Some(nodes) = tt.get(self, depth as u8) {
                tt_hits.fetch_add(1, Ordering::Relaxed);
                return nodes;
            }
        }
        let (deeper, current) = scratches.split_at_mut(depth as usize - 2);
        let moves = &mut current[0];
        self.actions_into(moves);
        if moves.is_empty() {
            return 1;
        }
        let mut nodes = 0u64;
        for action in moves.iter() {
            let mut board = self.apply(action);
            board.swap_turn_();
            nodes +=
                board.perft_tt_inner(depth - 1, deeper, count_scratch, tt, tt_hits, tt_lookups);
        }
        if depth >= 3 {
            tt.insert(self, depth as u8, nodes);
        }
        nodes
    }
}

/// Initial capacity reserved for every action buffer used during perft.
///
/// NOTE(review): `48` is carried over from the original code; presumably it
/// bounds the typical number of legal actions per position — confirm.
const SCRATCH_CAPACITY: usize = 48;

/// Allocates the per-ply action buffers for a search of `depth` plies.
///
/// Every ply at depth >= 2 needs its own buffer (the last ply only counts
/// actions), so this returns `depth - 1` empty buffers, or none when
/// `depth <= 1`.
fn scratch_stack(depth: u64) -> Vec<Vec<Action>> {
    (1..depth)
        .map(|_| Vec::with_capacity(SCRATCH_CAPACITY))
        .collect()
}

/// Converts a transposition-table budget in megabytes into the slot count
/// requested from `TranspositionTable::new`.
///
/// NOTE(review): this budgets 64 bytes per slot, but a slot is two
/// `AtomicU64`s (16 bytes). The table therefore uses roughly a quarter of the
/// requested budget before `TranspositionTable::new` rounds the count up to a
/// power of two. Confirm whether that headroom is intended.
fn tt_capacity_for(tt_size_mb: usize) -> usize {
    // Saturate instead of overflowing on targets with a 32-bit `usize`.
    tt_size_mb.saturating_mul(1024 * 1024) / 64
}
/// Fixed-size, lock-free cache of perft subtree counts.
///
/// `perft_parallel` shares one table between rayon workers. Each slot holds a
/// single result, and an insert always overwrites whatever occupied the slot.
/// A table with no entries is disabled: lookups miss and inserts are ignored.
struct TranspositionTable {
    // Power-of-two number of slots, or empty when the table is disabled.
    entries: Vec<AtomicEntry>,
    // `entries.len() - 1`, used to map a hash to a slot index (`0` when disabled).
    mask: usize,
}
/// One transposition-table slot, read and written with independent atomics.
struct AtomicEntry {
    // The verification tag XOR-ed with `value`. The tag is the board hash with
    // its low 8 bits replaced by the search depth. Storing it XOR-ed lets a
    // reader reject a slot whose key and value came from different writes.
    key: AtomicU64,
    // Cached leaf-node count for the position/depth encoded in `key`.
    value: AtomicU64,
}
impl TranspositionTable {
fn new(capacity: usize) -> Self {
if capacity == 0 {
return Self {
entries: Vec::new(),
mask: 0,
};
}
let capacity = capacity.next_power_of_two();
let mask = capacity - 1;
let mut entries = Vec::with_capacity(capacity);
for _ in 0..capacity {
entries.push(AtomicEntry {
key: AtomicU64::new(0),
value: AtomicU64::new(0),
});
}
Self { entries, mask }
}
#[inline]
fn get(&self, board: &Board, depth: u8) -> Option<u64> {
if self.entries.is_empty() {
return None;
}
let hash = self.hash_board(board);
let index = (hash as usize) & self.mask;
let entry = &self.entries[index];
let value = entry.value.load(Ordering::Relaxed);
let stored_key_xored = entry.key.load(Ordering::Relaxed);
let recovered_key = stored_key_xored ^ value;
let expected_key = (hash & 0xFFFF_FFFF_FFFF_FF00) | (depth as u64);
if recovered_key == expected_key {
Some(value)
} else {
None
}
}
#[inline]
fn insert(&self, board: &Board, depth: u8, nodes: u64) {
if self.entries.is_empty() {
return;
}
let hash = self.hash_board(board);
let index = (hash as usize) & self.mask;
let entry = &self.entries[index];
let key = (hash & 0xFFFF_FFFF_FFFF_FF00) | (depth as u64);
entry.key.store(key ^ nodes, Ordering::Relaxed);
entry.value.store(nodes, Ordering::Relaxed);
}
#[inline]
fn hash_board(&self, board: &Board) -> u64 {
let build_hasher = BuildHasherDefault::<FxHasher>::default();
let mut hasher = build_hasher.build_hasher();
board.hash(&mut hasher);
hasher.finish()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Square, Team};

    #[test]
    fn perft_depth_0() {
        let board = Board::new_default();
        assert_eq!(board.perft(0), 1);
    }

    #[test]
    fn perft_depth_1() {
        let board = Board::new_default();
        assert_eq!(board.perft(1), 8);
    }

    #[test]
    fn perft_depth_2() {
        let board = Board::new_default();
        assert_eq!(board.perft(2), 64);
    }

    #[test]
    fn perft_depth_3() {
        let board = Board::new_default();
        assert_eq!(board.perft(3), 708);
    }

    #[test]
    fn perft_depth_4() {
        let board = Board::new_default();
        assert_eq!(board.perft(4), 7538);
    }

    #[test]
    fn perft_depth_5() {
        let board = Board::new_default();
        assert_eq!(board.perft(5), 85090);
    }

    #[test]
    fn perft_depth_6() {
        let board = Board::new_default();
        assert_eq!(board.perft(6), 931_312);
    }

    #[test]
    fn perft_depth_7() {
        let board = Board::new_default();
        assert_eq!(board.perft(7), 10_782_382);
    }

    #[test]
    fn perft_depth_8() {
        let board = Board::new_default();
        assert_eq!(board.perft(8), 123_290_300);
    }

    #[test]
    fn perft_parallel_matches_sequential() {
        let board = Board::new_default();
        let seq = board.perft(7);
        let par = board.perft_parallel(7, 64);
        assert_eq!(seq, par);
    }

    #[test]
    fn perft_tt_matches_sequential() {
        let board = Board::new_default();
        let seq = board.perft(7);
        let tt = board.perft_tt(7, 64);
        assert_eq!(seq, tt);
    }

    /// With no table requested, both variants must still produce exact counts.
    #[test]
    fn perft_variants_without_tt_match_sequential() {
        let board = Board::new_default();
        assert_eq!(board.perft_parallel(5, 0), 85090);
        assert_eq!(board.perft_tt(5, 0), 85090);
    }

    /// Results must not depend on how large the table is.
    #[test]
    fn perft_tt_small_table_matches_sequential() {
        let board = Board::new_default();
        assert_eq!(board.perft_tt(6, 1), 931_312);
        assert_eq!(board.perft_parallel(6, 1), 931_312);
    }

    /// Depths that bypass the table entirely.
    #[test]
    fn perft_variants_shallow_depths() {
        let board = Board::new_default();
        for (depth, expected) in [(0, 1), (1, 8), (2, 64)] {
            assert_eq!(board.perft_tt(depth, 64), expected, "perft_tt({depth})");
            assert_eq!(
                board.perft_parallel(depth, 64),
                expected,
                "perft_parallel({depth})"
            );
        }
    }

    #[test]
    fn perft_single_pawn() {
        let board = Board::from_squares(Team::White, &[Square::D4], &[Square::H8], &[]);
        let perft1 = board.perft(1);
        assert_eq!(perft1, 3, "Single pawn at D4 should have 3 moves");
    }

    #[test]
    fn perft_single_king() {
        let board = Board::from_squares(Team::White, &[Square::D4], &[Square::H8], &[Square::D4]);
        let perft1 = board.perft(1);
        assert_eq!(perft1, 14, "King at D4 should have 14 moves");
    }

    #[test]
    fn perft_forced_capture() {
        let board = Board::from_squares(Team::White, &[Square::D4], &[Square::D5], &[]);
        let perft1 = board.perft(1);
        assert_eq!(perft1, 1, "Should have exactly 1 capture");
    }

    #[test]
    fn perft_terminal_position() {
        let board = Board::from_squares(Team::White, &[], &[Square::D4], &[]);
        let perft1 = board.perft(1);
        assert_eq!(perft1, 1, "Terminal position should count as 1");
    }

    /// A terminal root is a single leaf at any depth, for every perft variant.
    #[test]
    fn perft_terminal_position_deeper() {
        let board = Board::from_squares(Team::White, &[], &[Square::D4], &[]);
        assert_eq!(board.perft(3), 1);
        assert_eq!(board.perft_tt(3, 64), 1);
        assert_eq!(board.perft_parallel(3, 64), 1);
    }
}