1use chess::{Board, ChessMove};
7use std::process::{Command, Stdio};
8use std::io::{BufRead, BufReader, Write};
9use std::str::FromStr;
10use std::time::{Duration, Instant};
11use crate::{ChessVectorEngine, CalibratedEvaluator, CalibrationConfig};
12
/// Settings that control how the Stockfish process is launched and searched.
///
/// Each optional field is only forwarded to the engine when it is `Some`;
/// a `None` leaves the corresponding engine setting untouched.
#[derive(Debug, Clone)]
pub struct StockfishTestConfig {
    /// Program name or path used to spawn the Stockfish executable.
    pub stockfish_path: String,
    /// Search depth added to the `go` command as `depth N`.
    pub depth_limit: Option<u8>,
    /// Search time in milliseconds added to the `go` command as `movetime N`.
    pub time_limit_ms: Option<u64>,
    /// Value sent for the `Skill Level` UCI option.
    pub skill_level: Option<u8>,
    /// Value sent for the `Threads` UCI option.
    pub num_threads: Option<u8>,
    /// Value sent for the `Hash` UCI option.
    pub hash_size_mb: Option<u32>,
}

impl Default for StockfishTestConfig {
    /// Returns a small, quick configuration: `stockfish` looked up by name,
    /// depth 6, 1000 ms per move, skill level 10, one thread, 64 MB hash.
    fn default() -> Self {
        Self {
            stockfish_path: String::from("stockfish"),
            depth_limit: Some(6),
            time_limit_ms: Some(1000),
            skill_level: Some(10),
            num_threads: Some(1),
            hash_size_mb: Some(64),
        }
    }
}
36
37#[derive(Debug, Clone)]
39pub struct EvaluationComparison {
40 pub fen: String,
41 pub our_evaluation_cp: i32,
42 pub stockfish_evaluation_cp: i32,
43 pub our_best_move: Option<ChessMove>,
44 pub stockfish_best_move: Option<ChessMove>,
45 pub evaluation_diff_cp: i32,
46 pub move_agreement: bool,
47 pub evaluation_category: EvaluationCategory,
48}
49
/// Agreement bucket for a pair of centipawn evaluations.
#[derive(Debug, Clone, PartialEq)]
pub enum EvaluationCategory {
    /// The evaluations differ by at most 25 centipawns.
    ExactMatch,
    /// The evaluations differ by 26 to 50 centipawns.
    CloseMatch,
    /// The evaluations differ by 51 to 100 centipawns.
    ReasonableMatch,
    /// The evaluations differ by more than 100 centipawns but are still
    /// considered to agree on the sign of the score.
    SignMatch,
    /// The evaluations differ by more than 100 centipawns and disagree on
    /// the sign of the score.
    Mismatch,
}

impl EvaluationCategory {
    /// Buckets a centipawn difference by its magnitude alone.
    ///
    /// Only the difference is known here, so the sign of the individual
    /// evaluations cannot be compared: every difference above 100 centipawns
    /// is reported as [`EvaluationCategory::SignMatch`] and
    /// [`EvaluationCategory::Mismatch`] is never returned by this function.
    fn from_difference(diff_cp: i32) -> Self {
        // `unsigned_abs` is total over `i32`; plain `abs` overflows for
        // `i32::MIN` and panics in debug builds.
        match diff_cp.unsigned_abs() {
            0..=25 => Self::ExactMatch,
            26..=50 => Self::CloseMatch,
            51..=100 => Self::ReasonableMatch,
            _ => Self::SignMatch,
        }
    }
}
74
75trait UCIEngine {
77 fn send_command(&mut self, command: &str) -> Result<(), StockfishTestError>;
78 fn read_response(&mut self) -> Result<String, StockfishTestError>;
79 fn evaluate_position(&mut self, fen: &str) -> Result<(i32, Option<ChessMove>), StockfishTestError>;
80 fn close(&mut self) -> Result<(), StockfishTestError>;
81}
82
83pub struct StockfishEngine {
85 process: std::process::Child,
86 stdin: std::process::ChildStdin,
87 stdout: BufReader<std::process::ChildStdout>,
88 config: StockfishTestConfig,
89}
90
91impl StockfishEngine {
92 pub fn new(config: StockfishTestConfig) -> Result<Self, StockfishTestError> {
93 let mut process = Command::new(&config.stockfish_path)
94 .stdin(Stdio::piped())
95 .stdout(Stdio::piped())
96 .stderr(Stdio::null())
97 .spawn()
98 .map_err(|e| StockfishTestError::LaunchError(format!("Failed to start Stockfish: {}", e)))?;
99
100 let stdin = process.stdin.take()
101 .ok_or_else(|| StockfishTestError::LaunchError("Failed to get stdin".to_string()))?;
102
103 let stdout = BufReader::new(process.stdout.take()
104 .ok_or_else(|| StockfishTestError::LaunchError("Failed to get stdout".to_string()))?);
105
106 let mut engine = Self {
107 process,
108 stdin,
109 stdout,
110 config,
111 };
112
113 engine.initialize()?;
115
116 Ok(engine)
117 }
118
119 fn initialize(&mut self) -> Result<(), StockfishTestError> {
120 self.send_command("uci")?;
122
123 loop {
125 let response = self.read_response()?;
126 if response.contains("uciok") {
127 break;
128 }
129 }
130
131 if let Some(skill_level) = self.config.skill_level {
133 self.send_command(&format!("setoption name Skill Level value {}", skill_level))?;
134 }
135
136 if let Some(threads) = self.config.num_threads {
137 self.send_command(&format!("setoption name Threads value {}", threads))?;
138 }
139
140 if let Some(hash_size) = self.config.hash_size_mb {
141 self.send_command(&format!("setoption name Hash value {}", hash_size))?;
142 }
143
144 self.send_command("setoption name Ponder value false")?;
146
147 self.send_command("isready")?;
148
149 loop {
151 let response = self.read_response()?;
152 if response.contains("readyok") {
153 break;
154 }
155 }
156
157 Ok(())
158 }
159}
160
161impl UCIEngine for StockfishEngine {
162 fn send_command(&mut self, command: &str) -> Result<(), StockfishTestError> {
163 writeln!(self.stdin, "{}", command)
164 .map_err(|e| StockfishTestError::CommunicationError(format!("Send failed: {}", e)))?;
165 self.stdin.flush()
166 .map_err(|e| StockfishTestError::CommunicationError(format!("Flush failed: {}", e)))?;
167 Ok(())
168 }
169
170 fn read_response(&mut self) -> Result<String, StockfishTestError> {
171 let mut line = String::new();
172 self.stdout.read_line(&mut line)
173 .map_err(|e| StockfishTestError::CommunicationError(format!("Read failed: {}", e)))?;
174 Ok(line.trim().to_string())
175 }
176
177 fn evaluate_position(&mut self, fen: &str) -> Result<(i32, Option<ChessMove>), StockfishTestError> {
178 self.send_command(&format!("position fen {}", fen))?;
180
181 let mut go_command = "go".to_string();
183
184 if let Some(depth) = self.config.depth_limit {
185 go_command.push_str(&format!(" depth {}", depth));
186 }
187
188 if let Some(time_ms) = self.config.time_limit_ms {
189 go_command.push_str(&format!(" movetime {}", time_ms));
190 }
191
192 self.send_command(&go_command)?;
193
194 let mut best_move = None;
195 let mut evaluation_cp = 0;
196
197 loop {
199 let response = self.read_response()?;
200
201 if response.starts_with("info") {
202 if let Some(cp_pos) = response.find(" cp ") {
204 if let Ok(cp) = response[cp_pos + 4..].split_whitespace().next()
205 .unwrap_or("0").parse::<i32>() {
206 evaluation_cp = cp;
207 }
208 }
209 } else if response.starts_with("bestmove") {
210 let parts: Vec<&str> = response.split_whitespace().collect();
212 if parts.len() >= 2 && parts[1] != "(none)" {
213 if parts[1].len() >= 4 {
216 let move_str = parts[1];
218 if move_str.chars().all(|c| c.is_ascii_alphanumeric()) {
219 }
223 }
224 }
225 break;
226 }
227 }
228
229 Ok((evaluation_cp, best_move))
230 }
231
232 fn close(&mut self) -> Result<(), StockfishTestError> {
233 self.send_command("quit")?;
234 self.process.wait()
235 .map_err(|e| StockfishTestError::LaunchError(format!("Failed to close: {}", e)))?;
236 Ok(())
237 }
238}
239
/// Failures that can occur while comparing against Stockfish.
///
/// Every variant carries a human-readable description of what went wrong.
#[derive(Debug)]
pub enum StockfishTestError {
    /// The engine process could not be started or shut down.
    LaunchError(String),
    /// Writing to or reading from the engine failed.
    CommunicationError(String),
    /// Input such as a FEN string could not be parsed.
    ParseError(String),
    /// An operation took too long.
    TimeoutError(String),
}

impl std::fmt::Display for StockfishTestError {
    /// Renders the error as `<kind> error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = match self {
            Self::LaunchError(detail) => format!("Launch error: {}", detail),
            Self::CommunicationError(detail) => format!("Communication error: {}", detail),
            Self::ParseError(detail) => format!("Parse error: {}", detail),
            Self::TimeoutError(detail) => format!("Timeout error: {}", detail),
        };
        f.write_str(&rendered)
    }
}

impl std::error::Error for StockfishTestError {}
261
262pub struct StockfishTester {
264 our_engine: ChessVectorEngine,
265 calibrated_evaluator: CalibratedEvaluator,
266 stockfish_config: StockfishTestConfig,
267}
268
269impl StockfishTester {
270 pub fn new(stockfish_config: StockfishTestConfig) -> Self {
271 let mut our_engine = ChessVectorEngine::new(1024);
272 our_engine.enable_opening_book();
273
274 let calibration_config = CalibrationConfig::default();
275 let calibrated_evaluator = CalibratedEvaluator::new(calibration_config);
276
277 Self {
278 our_engine,
279 calibrated_evaluator,
280 stockfish_config,
281 }
282 }
283
284 fn classify_evaluation_match(our_eval_cp: i32, stockfish_eval_cp: i32) -> EvaluationCategory {
286 let diff = (our_eval_cp - stockfish_eval_cp).abs();
287
288 match diff {
289 0..=25 => EvaluationCategory::ExactMatch,
290 26..=50 => EvaluationCategory::CloseMatch,
291 51..=100 => EvaluationCategory::ReasonableMatch,
292 _ => {
293 let our_sign = our_eval_cp > 0;
295 let stockfish_sign = stockfish_eval_cp > 0;
296
297 if our_sign == stockfish_sign || (our_eval_cp.abs() <= 50 && stockfish_eval_cp.abs() <= 50) {
298 EvaluationCategory::SignMatch
299 } else {
300 EvaluationCategory::Mismatch
301 }
302 }
303 }
304 }
305
306 pub fn test_position(&mut self, fen: &str) -> Result<EvaluationComparison, StockfishTestError> {
308 let board = Board::from_str(fen)
309 .map_err(|e| StockfishTestError::ParseError(format!("Invalid FEN: {}", e)))?;
310
311 let our_evaluation_cp = self.calibrated_evaluator.evaluate_centipawns(&board);
313
314 let mut stockfish = StockfishEngine::new(self.stockfish_config.clone())?;
316 let (stockfish_evaluation_cp, stockfish_best_move) = stockfish.evaluate_position(fen)?;
317 stockfish.close()?;
318
319 let evaluation_diff_cp = our_evaluation_cp - stockfish_evaluation_cp;
321 let evaluation_category = Self::classify_evaluation_match(our_evaluation_cp, stockfish_evaluation_cp);
322
323 let move_agreement = false; let our_best_move = None; Ok(EvaluationComparison {
328 fen: fen.to_string(),
329 our_evaluation_cp,
330 stockfish_evaluation_cp,
331 our_best_move,
332 stockfish_best_move,
333 evaluation_diff_cp,
334 move_agreement,
335 evaluation_category,
336 })
337 }
338
339 pub fn test_positions(&mut self, positions: &[&str]) -> Result<TestSuiteResults, StockfishTestError> {
341 let mut results = Vec::new();
342 let start_time = Instant::now();
343
344 for (i, fen) in positions.iter().enumerate() {
345 println!("Testing position {}/{}: {}", i + 1, positions.len(), fen);
346
347 match self.test_position(fen) {
348 Ok(comparison) => {
349 println!(" Our: {}cp, Stockfish: {}cp, Diff: {}cp, Category: {:?}",
350 comparison.our_evaluation_cp,
351 comparison.stockfish_evaluation_cp,
352 comparison.evaluation_diff_cp,
353 comparison.evaluation_category);
354 results.push(comparison);
355 }
356 Err(e) => {
357 println!(" Error: {}", e);
358 return Err(e);
359 }
360 }
361 }
362
363 let total_time = start_time.elapsed();
364
365 Ok(TestSuiteResults::new(results, total_time))
366 }
367}
368
369#[derive(Debug)]
371pub struct TestSuiteResults {
372 pub comparisons: Vec<EvaluationComparison>,
373 pub total_time: Duration,
374 pub statistics: TestStatistics,
375}
376
/// Aggregate agreement figures for a set of evaluation comparisons.
#[derive(Debug)]
pub struct TestStatistics {
    /// Number of positions compared.
    pub total_positions: usize,
    /// Positions classified as an exact match.
    pub exact_matches: usize,
    /// Positions classified as a close match.
    pub close_matches: usize,
    /// Positions classified as a reasonable match.
    pub reasonable_matches: usize,
    /// Positions classified as a sign match.
    pub sign_matches: usize,
    /// Positions classified as a mismatch.
    pub mismatches: usize,
    /// Mean of the signed evaluation differences, in centipawns.
    pub avg_evaluation_diff_cp: f32,
    /// Root-mean-square of the evaluation differences, in centipawns.
    pub rms_evaluation_diff_cp: f32,
    /// Fraction (0.0 to 1.0) of positions that were exact, close, or reasonable matches.
    pub success_rate: f32,
}
389
390impl TestSuiteResults {
391 fn new(comparisons: Vec<EvaluationComparison>, total_time: Duration) -> Self {
392 let statistics = Self::calculate_statistics(&comparisons);
393 Self {
394 comparisons,
395 total_time,
396 statistics,
397 }
398 }
399
400 fn calculate_statistics(comparisons: &[EvaluationComparison]) -> TestStatistics {
401 let total_positions = comparisons.len();
402
403 let mut exact_matches = 0;
404 let mut close_matches = 0;
405 let mut reasonable_matches = 0;
406 let mut sign_matches = 0;
407 let mut mismatches = 0;
408
409 let mut sum_diff = 0.0;
410 let mut sum_squared_diff = 0.0;
411
412 for comparison in comparisons {
413 match comparison.evaluation_category {
414 EvaluationCategory::ExactMatch => exact_matches += 1,
415 EvaluationCategory::CloseMatch => close_matches += 1,
416 EvaluationCategory::ReasonableMatch => reasonable_matches += 1,
417 EvaluationCategory::SignMatch => sign_matches += 1,
418 EvaluationCategory::Mismatch => mismatches += 1,
419 }
420
421 let diff = comparison.evaluation_diff_cp as f32;
422 sum_diff += diff;
423 sum_squared_diff += diff * diff;
424 }
425
426 let avg_evaluation_diff_cp = if total_positions > 0 {
427 sum_diff / total_positions as f32
428 } else {
429 0.0
430 };
431
432 let rms_evaluation_diff_cp = if total_positions > 0 {
433 (sum_squared_diff / total_positions as f32).sqrt()
434 } else {
435 0.0
436 };
437
438 let successful_matches = exact_matches + close_matches + reasonable_matches;
440 let success_rate = if total_positions > 0 {
441 successful_matches as f32 / total_positions as f32
442 } else {
443 0.0
444 };
445
446 TestStatistics {
447 total_positions,
448 exact_matches,
449 close_matches,
450 reasonable_matches,
451 sign_matches,
452 mismatches,
453 avg_evaluation_diff_cp,
454 rms_evaluation_diff_cp,
455 success_rate,
456 }
457 }
458
459 pub fn display_summary(&self) -> String {
460 let stats = &self.statistics;
461
462 format!(
463 "Stockfish Comparison Results\n\
464 =============================\n\
465 Total positions: {}\n\
466 Test duration: {:.2}s\n\n\
467 Evaluation Agreement:\n\
468 - Exact matches (±25cp): {} ({:.1}%)\n\
469 - Close matches (±50cp): {} ({:.1}%)\n\
470 - Reasonable matches (±100cp): {} ({:.1}%)\n\
471 - Sign matches: {} ({:.1}%)\n\
472 - Mismatches: {} ({:.1}%)\n\n\
473 Statistical Analysis:\n\
474 - Success rate: {:.1}%\n\
475 - Average difference: {:.1}cp\n\
476 - RMS difference: {:.1}cp\n\n\
477 Assessment: {}",
478 stats.total_positions,
479 self.total_time.as_secs_f32(),
480 stats.exact_matches, Self::percentage(stats.exact_matches, stats.total_positions),
481 stats.close_matches, Self::percentage(stats.close_matches, stats.total_positions),
482 stats.reasonable_matches, Self::percentage(stats.reasonable_matches, stats.total_positions),
483 stats.sign_matches, Self::percentage(stats.sign_matches, stats.total_positions),
484 stats.mismatches, Self::percentage(stats.mismatches, stats.total_positions),
485 stats.success_rate * 100.0,
486 stats.avg_evaluation_diff_cp,
487 stats.rms_evaluation_diff_cp,
488 Self::assessment(stats.success_rate)
489 )
490 }
491
492 fn percentage(count: usize, total: usize) -> f32 {
493 if total > 0 {
494 count as f32 / total as f32 * 100.0
495 } else {
496 0.0
497 }
498 }
499
500 fn assessment(success_rate: f32) -> &'static str {
501 match (success_rate * 100.0) as u32 {
502 80..=100 => "✅ Excellent agreement with Stockfish",
503 60..=79 => "✅ Good agreement with Stockfish",
504 40..=59 => "⚠️ Moderate agreement with Stockfish",
505 20..=39 => "⚠️ Poor agreement with Stockfish",
506 _ => "❌ Very poor agreement with Stockfish",
507 }
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration searches to depth 6 at skill level 10.
    #[test]
    fn test_stockfish_config_default() {
        let StockfishTestConfig {
            depth_limit,
            skill_level,
            ..
        } = StockfishTestConfig::default();

        assert_eq!(depth_limit, Some(6));
        assert_eq!(skill_level, Some(10));
    }

    /// One representative difference per magnitude bucket.
    #[test]
    fn test_evaluation_category_classification() {
        let expectations = [
            (10, EvaluationCategory::ExactMatch),
            (40, EvaluationCategory::CloseMatch),
            (80, EvaluationCategory::ReasonableMatch),
        ];

        for (diff_cp, expected) in expectations.iter() {
            assert_eq!(EvaluationCategory::from_difference(*diff_cp), *expected);
        }
    }
}