use std::time::Instant;
use ruchess::{fen, mve::Move, position::Position};
/// Perft test suite: `(name, FEN, expected node counts)`.
///
/// `expected[i]` is the reference node count at depth `i + 1`, so the
/// length of each slice is the deepest depth `main` will run for that
/// position (a command-line depth can only lower it). The counts match the
/// published values on the Chess Programming Wiki "Perft Results" page
/// (initial position, "Kiwipete", and positions 3-6).
const POSITIONS: &[(&str, &str, &[u64])] = &[
// Initial position.
(
"start",
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
&[20, 400, 8_902, 197_281, 4_865_609, 119_060_324],
),
// "Kiwipete" (position 2).
(
"kiwipete",
"r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1",
&[48, 2_039, 97_862, 4_085_603, 193_690_690],
),
(
"pos3",
"8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 1",
&[14, 191, 2_812, 43_238, 674_624, 11_030_083],
),
(
"pos4",
"r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1",
&[6, 264, 9_467, 422_333, 15_833_292],
),
(
"pos5",
"rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8",
&[44, 1_486, 62_379, 2_103_487, 89_941_194],
),
(
"pos6",
"r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10",
&[46, 2_079, 89_890, 3_894_594, 164_075_551],
),
];
/// Counts the leaf nodes of the move tree rooted at `pos`, exactly `depth`
/// plies deep, using the moves reported by `Position::valid_moves`.
///
/// `pos` is mutated during the search, but every `make` is paired with an
/// `unmake`. Assuming `unmake` exactly reverses `make`, `pos` is therefore
/// handed back in its original state.
///
/// By convention `perft(pos, 0) == 1`: the position itself is the only leaf.
fn perft(pos: &mut Position, depth: u32) -> u64 {
    match depth {
        // Without this guard the recursion below would compute `0 - 1` on a
        // u32: a panic in debug builds, a ~4-billion-ply search in release.
        0 => return 1,
        // Bulk counting: one ply from the horizon, the number of moves *is*
        // the number of leaves, so there is no need to make/unmake each one.
        1 => return pos.valid_moves().count() as u64,
        _ => {}
    }

    // Materialise the move list before `pos` is mutated inside the loop.
    let mut moves: Vec<Move> = Vec::with_capacity(256);
    moves.extend(pos.valid_moves());

    let mut total: u64 = 0;
    for m in &moves {
        let undo = pos.make(m);
        total += perft(pos, depth - 1);
        pos.unmake(undo);
    }
    total
}
/// Runs the perft suite and prints a per-depth table followed by a summary.
///
/// Command-line arguments may appear in any order, and the last one of each
/// kind wins:
/// - an argument that parses as `u32` caps the depth for every position;
/// - any other argument restricts the run to the position with that name.
///
/// Exits with status 1 if any node count differs from the reference value.
/// Exits with status 2 if the requested position name is unknown.
fn main() {
    let args: Vec<String> = std::env::args().skip(1).collect();
    // Numeric arguments set the depth cap; everything else is a position name.
    let (max_depth, only): (Option<u32>, Option<&str>) =
        args.iter().fold((None, None), |(d, n), a| match a.parse() {
            Ok(parsed) => (Some(parsed), n),
            Err(_) => (d, Some(a.as_str())),
        });

    // A misspelt filter would otherwise run nothing and still exit 0.
    if let Some(o) = only
        && !POSITIONS.iter().any(|(name, _, _)| *name == o)
    {
        let known: Vec<&str> = POSITIONS.iter().map(|(name, _, _)| *name).collect();
        eprintln!("unknown position {o:?}; expected one of: {}", known.join(", "));
        std::process::exit(2);
    }

    println!(
        "{:<10} {:>6} {:>15} {:>12} {:>12} ok",
        "position", "depth", "nodes", "ms", "Mn/s"
    );
    println!("{}", "-".repeat(72));

    let mut grand_total: u64 = 0;
    // Accumulate full-precision durations. Summing each row's truncated
    // milliseconds would drop every sub-millisecond run and skew the total rate.
    let mut grand_elapsed = std::time::Duration::ZERO;
    let mut mismatches: Vec<(&str, u32, u64, u64)> = Vec::new();

    for (name, fen_str, expected) in POSITIONS {
        if let Some(o) = only
            && o != *name
        {
            continue;
        }
        let pos = fen::parse(fen_str)
            .expect("FEN should parse")
            .without_repetition();
        // The reference table bounds how deep we can verify.
        let depths = expected.len() as u32;
        let cap = max_depth.unwrap_or(depths).min(depths);
        for d in 1..=cap {
            let want = expected[(d - 1) as usize];
            let start = Instant::now();
            let mut working = pos.clone();
            let got = perft(&mut working, d);
            let elapsed = start.elapsed();

            let matches = got == want;
            // Display-only clamp so sub-millisecond rows do not show 0 ms.
            let ms = elapsed.as_millis().max(1);
            let mnps = (got as f64) / (elapsed.as_secs_f64().max(1e-9) * 1e6);
            let (ok, note) = if matches {
                ("✓", String::new())
            } else {
                ("✗", format!(" (expected {want})"))
            };
            println!(
                "{:<10} {:>6} {:>15} {:>12} {:>12.2} {}{}",
                name, d, got, ms, mnps, ok, note
            );

            grand_total += got;
            grand_elapsed += elapsed;
            if !matches {
                mismatches.push((*name, d, got, want));
            }
        }
    }

    println!("{}", "-".repeat(72));
    let total_s = grand_elapsed.as_secs_f64();
    let mnps = if total_s > 0.0 {
        (grand_total as f64) / (total_s * 1e6)
    } else {
        0.0
    };
    println!(
        "total: {} nodes in {:.2}s — {:.2} Mn/s",
        grand_total, total_s, mnps
    );

    if !mismatches.is_empty() {
        println!("\nmismatches (move-generation discrepancies vs canonical):");
        for (name, d, got, want) in &mismatches {
            // i128 so the signed difference of two u64 counts cannot overflow.
            let diff = *got as i128 - *want as i128;
            println!(" {name:<10} depth {d}: got {got}, expected {want} (diff {diff:+})");
        }
        std::process::exit(1);
    }
}