chess_vector_engine/
training.rs

1use chess::{Board, ChessMove, Game, MoveGen};
2use indicatif::{ProgressBar, ProgressStyle};
3use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::{BufRead, BufReader, BufWriter, Write};
8use std::path::Path;
9use std::process::{Child, Command, Stdio};
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12
13use crate::ChessVectorEngine;
14
/// Self-play training configuration.
///
/// Plain settings bag; `Default` supplies the baseline values.
/// NOTE(review): the code that consumes these settings is not visible in this
/// section, so the field descriptions below state intent only.
#[derive(Debug, Clone)]
pub struct SelfPlayConfig {
    /// Number of games to play per training iteration
    pub games_per_iteration: usize,
    /// Maximum moves per game (to prevent infinite games)
    pub max_moves_per_game: usize,
    /// Exploration factor for move selection (0.0 = greedy, 1.0 = random)
    pub exploration_factor: f32,
    /// Minimum evaluation confidence to include position
    pub min_confidence: f32,
    /// Whether to use opening book for game starts
    pub use_opening_book: bool,
    /// Temperature for move selection (higher = more random)
    pub temperature: f32,
}
31
32impl Default for SelfPlayConfig {
33    fn default() -> Self {
34        Self {
35            games_per_iteration: 100,
36            max_moves_per_game: 200,
37            exploration_factor: 0.3,
38            min_confidence: 0.1,
39            use_opening_book: true,
40            temperature: 0.8,
41        }
42    }
43}
44
/// Training data point containing a position and its evaluation.
///
/// The PGN extractor creates these with `evaluation = 0.0` and `depth = 0`;
/// the Stockfish evaluators overwrite both on a successful evaluation (on
/// failure only `evaluation` is reset to `0.0`). `game_id` groups positions
/// from the same game so the dataset can be split by game.
/// NOTE(review): `TrainingDataset::save`/`load` serialize this type, but no
/// `Serialize`/`Deserialize` impl is visible in this section — presumably it
/// is implemented further down the file.
#[derive(Debug, Clone)]
pub struct TrainingData {
    /// Position reached after the move that produced this entry.
    pub board: Board,
    /// Score in pawns (centipawns / 100); mate scores are mapped to ±100.0.
    pub evaluation: f32,
    /// Search depth of `evaluation`; left at 0 until an evaluation succeeds.
    pub depth: u8,
    /// Identifier of the source game, shared by all positions from that game.
    pub game_id: usize,
}
53
/// Tactical puzzle data from Lichess puzzle database.
///
/// Each field is mapped to the database's column name via `serde(rename)`.
/// NOTE(review): presumably deserialized from the Lichess CSV export; the
/// loader is not visible in this section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TacticalPuzzle {
    /// Lichess puzzle identifier.
    #[serde(rename = "PuzzleId")]
    pub puzzle_id: String,
    /// Puzzle position in FEN notation.
    #[serde(rename = "FEN")]
    pub fen: String,
    /// Space-separated move sequence.
    #[serde(rename = "Moves")]
    pub moves: String,
    /// Puzzle rating.
    #[serde(rename = "Rating")]
    pub rating: u32,
    /// Uncertainty of `rating`.
    #[serde(rename = "RatingDeviation")]
    pub rating_deviation: u32,
    /// Popularity score (signed).
    #[serde(rename = "Popularity")]
    pub popularity: i32,
    /// Number of times the puzzle was played.
    #[serde(rename = "NbPlays")]
    pub nb_plays: u32,
    /// Space-separated themes.
    #[serde(rename = "Themes")]
    pub themes: String,
    /// URL of the source game, when present.
    #[serde(rename = "GameUrl")]
    pub game_url: Option<String>,
    /// Opening tags, when present.
    #[serde(rename = "OpeningTags")]
    pub opening_tags: Option<String>,
}
78
/// Processed tactical training data: a position and the move that solves it.
///
/// Serialization is hand-written: the board is stored as its FEN string and
/// the move as its `to_string()` form.
#[derive(Debug, Clone)]
pub struct TacticalTrainingData {
    /// Position in which `solution_move` is to be played.
    pub position: Board,
    /// The move that solves the puzzle.
    pub solution_move: ChessMove,
    /// Tactical theme label for the move.
    pub move_theme: String,
    /// Rating as difficulty
    pub difficulty: f32,
    /// High value for move outcome
    pub tactical_value: f32,
}
88
89// Make TacticalTrainingData serializable
90impl serde::Serialize for TacticalTrainingData {
91    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92    where
93        S: serde::Serializer,
94    {
95        use serde::ser::SerializeStruct;
96        let mut state = serializer.serialize_struct("TacticalTrainingData", 5)?;
97        state.serialize_field("fen", &self.position.to_string())?;
98        state.serialize_field("solution_move", &self.solution_move.to_string())?;
99        state.serialize_field("move_theme", &self.move_theme)?;
100        state.serialize_field("difficulty", &self.difficulty)?;
101        state.serialize_field("tactical_value", &self.tactical_value)?;
102        state.end()
103    }
104}
105
106impl<'de> serde::Deserialize<'de> for TacticalTrainingData {
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        use serde::de::{self, MapAccess, Visitor};
112        use std::fmt;
113
114        struct TacticalTrainingDataVisitor;
115
116        impl<'de> Visitor<'de> for TacticalTrainingDataVisitor {
117            type Value = TacticalTrainingData;
118
119            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
120                formatter.write_str("struct TacticalTrainingData")
121            }
122
123            fn visit_map<V>(self, mut map: V) -> Result<TacticalTrainingData, V::Error>
124            where
125                V: MapAccess<'de>,
126            {
127                let mut fen = None;
128                let mut solution_move = None;
129                let mut move_theme = None;
130                let mut difficulty = None;
131                let mut tactical_value = None;
132
133                while let Some(key) = map.next_key()? {
134                    match key {
135                        "fen" => {
136                            if fen.is_some() {
137                                return Err(de::Error::duplicate_field("fen"));
138                            }
139                            fen = Some(map.next_value()?);
140                        }
141                        "solution_move" => {
142                            if solution_move.is_some() {
143                                return Err(de::Error::duplicate_field("solution_move"));
144                            }
145                            solution_move = Some(map.next_value()?);
146                        }
147                        "move_theme" => {
148                            if move_theme.is_some() {
149                                return Err(de::Error::duplicate_field("move_theme"));
150                            }
151                            move_theme = Some(map.next_value()?);
152                        }
153                        "difficulty" => {
154                            if difficulty.is_some() {
155                                return Err(de::Error::duplicate_field("difficulty"));
156                            }
157                            difficulty = Some(map.next_value()?);
158                        }
159                        "tactical_value" => {
160                            if tactical_value.is_some() {
161                                return Err(de::Error::duplicate_field("tactical_value"));
162                            }
163                            tactical_value = Some(map.next_value()?);
164                        }
165                        _ => {
166                            let _: serde_json::Value = map.next_value()?;
167                        }
168                    }
169                }
170
171                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
172                let solution_move_str: String =
173                    solution_move.ok_or_else(|| de::Error::missing_field("solution_move"))?;
174                let move_theme =
175                    move_theme.ok_or_else(|| de::Error::missing_field("move_theme"))?;
176                let difficulty =
177                    difficulty.ok_or_else(|| de::Error::missing_field("difficulty"))?;
178                let tactical_value =
179                    tactical_value.ok_or_else(|| de::Error::missing_field("tactical_value"))?;
180
181                let position =
182                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
183
184                let solution_move = ChessMove::from_str(&solution_move_str)
185                    .map_err(|e| de::Error::custom(format!("Error: {e}")))?;
186
187                Ok(TacticalTrainingData {
188                    position,
189                    solution_move,
190                    move_theme,
191                    difficulty,
192                    tactical_value,
193                })
194            }
195        }
196
197        const FIELDS: &[&str] = &[
198            "fen",
199            "solution_move",
200            "move_theme",
201            "difficulty",
202            "tactical_value",
203        ];
204        deserializer.deserialize_struct("TacticalTrainingData", FIELDS, TacticalTrainingDataVisitor)
205    }
206}
207
/// PGN game visitor for extracting positions.
///
/// Used as a `pgn_reader::Visitor`. It replays each game's mainline SAN moves
/// (variations are skipped) and records the position after every applied move.
pub struct GameExtractor {
    /// Positions collected so far, across every game visited.
    pub positions: Vec<TrainingData>,
    /// Game being replayed; reset at the start of every game.
    pub current_game: Game,
    /// Moves counted against `max_moves_per_game` in the current game.
    pub move_count: usize,
    /// Per-game cap on how many moves are replayed.
    pub max_moves_per_game: usize,
    /// Id of the current game; incremented at each game start, so the first game is 1.
    pub game_id: usize,
}
216
217impl GameExtractor {
218    pub fn new(max_moves_per_game: usize) -> Self {
219        Self {
220            positions: Vec::new(),
221            current_game: Game::new(),
222            move_count: 0,
223            max_moves_per_game,
224            game_id: 0,
225        }
226    }
227}
228
229impl Visitor for GameExtractor {
230    type Result = ();
231
232    fn begin_game(&mut self) {
233        self.current_game = Game::new();
234        self.move_count = 0;
235        self.game_id += 1;
236    }
237
238    fn header(&mut self, _key: &[u8], _value: RawHeader<'_>) {}
239
240    fn san(&mut self, san_plus: SanPlus) {
241        if self.move_count >= self.max_moves_per_game {
242            return;
243        }
244
245        let san_str = san_plus.san.to_string();
246
247        // First validate that we have a legal position to work with
248        let current_pos = self.current_game.current_position();
249
250        // Try to parse and make the move
251        match chess::ChessMove::from_san(&current_pos, &san_str) {
252            Ok(chess_move) => {
253                // Verify the move is legal before making it
254                let legal_moves: Vec<chess::ChessMove> = MoveGen::new_legal(&current_pos).collect();
255                if legal_moves.contains(&chess_move) {
256                    if self.current_game.make_move(chess_move) {
257                        self.move_count += 1;
258
259                        // Store position (we'll evaluate it later with Stockfish)
260                        self.positions.push(TrainingData {
261                            board: self.current_game.current_position(),
262                            evaluation: 0.0, // Will be filled by Stockfish
263                            depth: 0,
264                            game_id: self.game_id,
265                        });
266                    }
267                } else {
268                    // Move parsed but isn't legal - skip silently to avoid spam
269                }
270            }
271            Err(_) => {
272                // Failed to parse move - could be notation issues, corruption, etc.
273                // Skip silently to avoid excessive error output
274                // Only log if it's not a common problematic pattern
275                if !san_str.contains("O-O") && !san_str.contains("=") && san_str.len() > 6 {
276                    // Only log unusual failed moves to reduce noise
277                }
278            }
279        }
280    }
281
282    fn begin_variation(&mut self) -> Skip {
283        Skip(true) // Skip variations for now
284    }
285
286    fn end_game(&mut self) -> Self::Result {}
287}
288
/// Stockfish engine wrapper for position evaluation.
///
/// Spawns a fresh `stockfish` process (resolved like any `Command` program
/// name, i.e. via `PATH`) for every position it evaluates. `StockfishPool`
/// is the variant that reuses long-lived processes.
pub struct StockfishEvaluator {
    /// Search depth sent to the engine with `go depth`.
    depth: u8,
}
293
294impl StockfishEvaluator {
295    pub fn new(depth: u8) -> Self {
296        Self { depth }
297    }
298
299    /// Evaluate a single position using Stockfish
300    pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
301        let mut child = Command::new("stockfish")
302            .stdin(Stdio::piped())
303            .stdout(Stdio::piped())
304            .stderr(Stdio::piped())
305            .spawn()?;
306
307        let stdin = child
308            .stdin
309            .as_mut()
310            .ok_or("Failed to get stdin handle for Stockfish process")?;
311        let fen = board.to_string();
312
313        // Send UCI commands
314        use std::io::Write;
315        writeln!(stdin, "uci")?;
316        writeln!(stdin, "isready")?;
317        writeln!(stdin, "position fen {fen}")?;
318        writeln!(stdin, "go depth {}", self.depth)?;
319        writeln!(stdin, "quit")?;
320
321        let output = child.wait_with_output()?;
322        let stdout = String::from_utf8_lossy(&output.stdout);
323
324        // Parse the evaluation from Stockfish output
325        for line in stdout.lines() {
326            if line.starts_with("info") && line.contains("score cp") {
327                if let Some(cp_pos) = line.find("score cp ") {
328                    let cp_str = &line[cp_pos + 9..];
329                    if let Some(end) = cp_str.find(' ') {
330                        let cp_value = cp_str[..end].parse::<i32>()?;
331                        return Ok(cp_value as f32 / 100.0); // Convert centipawns to pawns
332                    }
333                }
334            } else if line.starts_with("info") && line.contains("score mate") {
335                // Handle mate scores
336                if let Some(mate_pos) = line.find("score mate ") {
337                    let mate_str = &line[mate_pos + 11..];
338                    if let Some(end) = mate_str.find(' ') {
339                        let mate_moves = mate_str[..end].parse::<i32>()?;
340                        return Ok(if mate_moves > 0 { 100.0 } else { -100.0 });
341                    }
342                }
343            }
344        }
345
346        Ok(0.0) // Default to 0 if no evaluation found
347    }
348
349    /// Batch evaluate multiple positions
350    pub fn evaluate_batch(
351        &self,
352        positions: &mut [TrainingData],
353    ) -> Result<(), Box<dyn std::error::Error>> {
354        let pb = ProgressBar::new(positions.len() as u64);
355        if let Ok(style) = ProgressStyle::default_bar().template(
356            "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
357        ) {
358            pb.set_style(style.progress_chars("#>-"));
359        }
360
361        for data in positions.iter_mut() {
362            match self.evaluate_position(&data.board) {
363                Ok(eval) => {
364                    data.evaluation = eval;
365                    data.depth = self.depth;
366                }
367                Err(e) => {
368                    eprintln!("Evaluation error: {e}");
369                    data.evaluation = 0.0;
370                }
371            }
372            pb.inc(1);
373        }
374
375        pb.finish_with_message("Evaluation complete");
376        Ok(())
377    }
378
379    /// Evaluate multiple positions in parallel using concurrent Stockfish instances
380    pub fn evaluate_batch_parallel(
381        &self,
382        positions: &mut [TrainingData],
383        num_threads: usize,
384    ) -> Result<(), Box<dyn std::error::Error>> {
385        let pb = ProgressBar::new(positions.len() as u64);
386        if let Ok(style) = ProgressStyle::default_bar()
387            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel evaluation") {
388            pb.set_style(style.progress_chars("#>-"));
389        }
390
391        // Set the thread pool size
392        let pool = rayon::ThreadPoolBuilder::new()
393            .num_threads(num_threads)
394            .build()?;
395
396        pool.install(|| {
397            // Use parallel iterator to evaluate positions
398            positions.par_iter_mut().for_each(|data| {
399                match self.evaluate_position(&data.board) {
400                    Ok(eval) => {
401                        data.evaluation = eval;
402                        data.depth = self.depth;
403                    }
404                    Err(_) => {
405                        // Silently fail individual positions to avoid spamming output
406                        data.evaluation = 0.0;
407                    }
408                }
409                pb.inc(1);
410            });
411        });
412
413        pb.finish_with_message("Parallel evaluation complete");
414        Ok(())
415    }
416}
417
/// Persistent Stockfish process for fast UCI communication.
///
/// Owns one running `stockfish` child plus buffered handles to its pipes, so
/// many positions can be evaluated without re-spawning the engine. `Drop`
/// sends `quit` and then waits for the child.
struct StockfishProcess {
    /// The running engine process.
    child: Child,
    /// Buffered engine stdin; `send_command` flushes after every command.
    stdin: BufWriter<std::process::ChildStdin>,
    /// Buffered engine stdout, read line by line.
    stdout: BufReader<std::process::ChildStdout>,
    /// Search depth given at construction.
    // `allow(dead_code)` silences the warning emitted when this field is never read.
    #[allow(dead_code)]
    depth: u8,
}
426
427impl StockfishProcess {
428    fn new(depth: u8) -> Result<Self, Box<dyn std::error::Error>> {
429        let mut child = Command::new("stockfish")
430            .stdin(Stdio::piped())
431            .stdout(Stdio::piped())
432            .stderr(Stdio::piped())
433            .spawn()?;
434
435        let stdin = BufWriter::new(
436            child
437                .stdin
438                .take()
439                .ok_or("Failed to get stdin handle for Stockfish process")?,
440        );
441        let stdout = BufReader::new(
442            child
443                .stdout
444                .take()
445                .ok_or("Failed to get stdout handle for Stockfish process")?,
446        );
447
448        let mut process = Self {
449            child,
450            stdin,
451            stdout,
452            depth,
453        };
454
455        // Initialize UCI
456        process.send_command("uci")?;
457        process.wait_for_ready()?;
458        process.send_command("isready")?;
459        process.wait_for_ready()?;
460
461        Ok(process)
462    }
463
464    fn send_command(&mut self, command: &str) -> Result<(), Box<dyn std::error::Error>> {
465        writeln!(self.stdin, "{command}")?;
466        self.stdin.flush()?;
467        Ok(())
468    }
469
470    fn wait_for_ready(&mut self) -> Result<(), Box<dyn std::error::Error>> {
471        let mut line = String::new();
472        loop {
473            line.clear();
474            self.stdout.read_line(&mut line)?;
475            if line.trim() == "uciok" || line.trim() == "readyok" {
476                break;
477            }
478        }
479        Ok(())
480    }
481
482    fn evaluate_position(&mut self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
483        let fen = board.to_string();
484
485        // Send position and evaluation commands
486        self.send_command(&format!("position fen {fen}"))?;
487        self.send_command(&format!("position fen {fen}"))?;
488
489        // Read response until we get bestmove
490        let mut line = String::new();
491        let mut last_evaluation = 0.0;
492
493        loop {
494            line.clear();
495            self.stdout.read_line(&mut line)?;
496            let line = line.trim();
497
498            if line.starts_with("info") && line.contains("score cp") {
499                if let Some(cp_pos) = line.find("score cp ") {
500                    let cp_str = &line[cp_pos + 9..];
501                    if let Some(end) = cp_str.find(' ') {
502                        if let Ok(cp_value) = cp_str[..end].parse::<i32>() {
503                            last_evaluation = cp_value as f32 / 100.0;
504                        }
505                    }
506                }
507            } else if line.starts_with("info") && line.contains("score mate") {
508                if let Some(mate_pos) = line.find("score mate ") {
509                    let mate_str = &line[mate_pos + 11..];
510                    if let Some(end) = mate_str.find(' ') {
511                        if let Ok(mate_moves) = mate_str[..end].parse::<i32>() {
512                            last_evaluation = if mate_moves > 0 { 100.0 } else { -100.0 };
513                        }
514                    }
515                }
516            } else if line.starts_with("bestmove") {
517                break;
518            }
519        }
520
521        Ok(last_evaluation)
522    }
523}
524
525impl Drop for StockfishProcess {
526    fn drop(&mut self) {
527        let _ = self.send_command("quit");
528        let _ = self.child.wait();
529    }
530}
531
/// High-performance Stockfish process pool.
///
/// Keeps up to `pool_size` idle `StockfishProcess`es behind a mutex. Each
/// evaluation checks one out and returns it afterwards. When none is idle, a
/// temporary process is spawned.
pub struct StockfishPool {
    /// Idle processes available for checkout.
    pool: Arc<Mutex<Vec<StockfishProcess>>>,
    /// Depth for processes spawned on demand; also recorded on evaluated positions.
    depth: u8,
    /// Maximum number of idle processes kept in `pool`.
    pool_size: usize,
}
538
539impl StockfishPool {
540    pub fn new(depth: u8, pool_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
541        let mut processes = Vec::with_capacity(pool_size);
542
543        println!(
544            "🚀 Initializing Stockfish pool with {} processes...",
545            pool_size
546        );
547
548        for i in 0..pool_size {
549            match StockfishProcess::new(depth) {
550                Ok(process) => {
551                    processes.push(process);
552                    if i % 2 == 1 {
553                        print!(".");
554                        let _ = std::io::stdout().flush(); // Ignore flush errors
555                    }
556                }
557                Err(e) => {
558                    eprintln!("Evaluation error: {e}");
559                    return Err(e);
560                }
561            }
562        }
563
564        println!(" ✅ Pool ready!");
565
566        Ok(Self {
567            pool: Arc::new(Mutex::new(processes)),
568            depth,
569            pool_size,
570        })
571    }
572
573    pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
574        // Get a process from the pool
575        let mut process = {
576            let mut pool = self.pool.lock().unwrap();
577            if let Some(process) = pool.pop() {
578                process
579            } else {
580                // Pool is empty, create temporary process
581                StockfishProcess::new(self.depth)?
582            }
583        };
584
585        // Evaluate position
586        let result = process.evaluate_position(board);
587
588        // Return process to pool
589        {
590            let mut pool = self.pool.lock().unwrap();
591            if pool.len() < self.pool_size {
592                pool.push(process);
593            }
594            // Otherwise drop the process (in case of pool size changes)
595        }
596
597        result
598    }
599
600    pub fn evaluate_batch_parallel(
601        &self,
602        positions: &mut [TrainingData],
603    ) -> Result<(), Box<dyn std::error::Error>> {
604        let pb = ProgressBar::new(positions.len() as u64);
605        pb.set_style(ProgressStyle::default_bar()
606            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Pool evaluation")
607            .unwrap()
608            .progress_chars("#>-"));
609
610        // Use rayon for parallel evaluation
611        positions.par_iter_mut().for_each(|data| {
612            match self.evaluate_position(&data.board) {
613                Ok(eval) => {
614                    data.evaluation = eval;
615                    data.depth = self.depth;
616                }
617                Err(_) => {
618                    data.evaluation = 0.0;
619                }
620            }
621            pb.inc(1);
622        });
623
624        pb.finish_with_message("Pool evaluation complete");
625        Ok(())
626    }
627}
628
/// Training dataset manager.
///
/// A flat list of positions (possibly from many games) together with helpers
/// to load them from PGN, evaluate them with Stockfish, split them by game,
/// train an engine, and persist them as JSON.
pub struct TrainingDataset {
    /// All positions in the dataset, in insertion order.
    pub data: Vec<TrainingData>,
}
633
634impl Default for TrainingDataset {
635    fn default() -> Self {
636        Self::new()
637    }
638}
639
640impl TrainingDataset {
641    pub fn new() -> Self {
642        Self { data: Vec::new() }
643    }
644
645    /// Load positions from a PGN file
646    pub fn load_from_pgn<P: AsRef<Path>>(
647        &mut self,
648        path: P,
649        max_games: Option<usize>,
650        max_moves_per_game: usize,
651    ) -> Result<(), Box<dyn std::error::Error>> {
652        let file = File::open(path)?;
653        let reader = BufReader::new(file);
654
655        let mut extractor = GameExtractor::new(max_moves_per_game);
656        let mut games_processed = 0;
657
658        // Create a simple PGN parser
659        let mut pgn_content = String::new();
660        for line in reader.lines() {
661            let line = line?;
662            pgn_content.push_str(&line);
663            pgn_content.push('\n');
664
665            // Check if this is the end of a game
666            if line.trim().ends_with("1-0")
667                || line.trim().ends_with("0-1")
668                || line.trim().ends_with("1/2-1/2")
669                || line.trim().ends_with("*")
670            {
671                // Parse this game
672                let cursor = std::io::Cursor::new(&pgn_content);
673                let mut reader = BufferedReader::new(cursor);
674                if let Err(e) = reader.read_all(&mut extractor) {
675                    eprintln!("Evaluation error: {e}");
676                }
677
678                games_processed += 1;
679                pgn_content.clear();
680
681                if let Some(max) = max_games {
682                    if games_processed >= max {
683                        break;
684                    }
685                }
686
687                if games_processed % 100 == 0 {
688                    println!(
689                        "Processed {} games, extracted {} positions",
690                        games_processed,
691                        extractor.positions.len()
692                    );
693                }
694            }
695        }
696
697        self.data.extend(extractor.positions);
698        println!(
699            "Loaded {} positions from {} games",
700            self.data.len(),
701            games_processed
702        );
703        Ok(())
704    }
705
706    /// Evaluate all positions using Stockfish
707    pub fn evaluate_with_stockfish(&mut self, depth: u8) -> Result<(), Box<dyn std::error::Error>> {
708        let evaluator = StockfishEvaluator::new(depth);
709        evaluator.evaluate_batch(&mut self.data)
710    }
711
712    /// Evaluate all positions using Stockfish in parallel
713    pub fn evaluate_with_stockfish_parallel(
714        &mut self,
715        depth: u8,
716        num_threads: usize,
717    ) -> Result<(), Box<dyn std::error::Error>> {
718        let evaluator = StockfishEvaluator::new(depth);
719        evaluator.evaluate_batch_parallel(&mut self.data, num_threads)
720    }
721
722    /// Train the vector engine with this dataset
723    pub fn train_engine(&self, engine: &mut ChessVectorEngine) {
724        let pb = ProgressBar::new(self.data.len() as u64);
725        pb.set_style(ProgressStyle::default_bar()
726            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Training positions")
727            .unwrap()
728            .progress_chars("#>-"));
729
730        for data in &self.data {
731            engine.add_position(&data.board, data.evaluation);
732            pb.inc(1);
733        }
734
735        pb.finish_with_message("Training complete");
736        println!("Trained engine with {} positions", self.data.len());
737    }
738
739    /// Split dataset into train/test sets by games to prevent data leakage
740    pub fn split(&self, train_ratio: f32) -> (TrainingDataset, TrainingDataset) {
741        use rand::seq::SliceRandom;
742        use rand::thread_rng;
743        use std::collections::{HashMap, HashSet};
744
745        // Group positions by game_id
746        let mut games: HashMap<usize, Vec<&TrainingData>> = HashMap::new();
747        for data in &self.data {
748            games.entry(data.game_id).or_default().push(data);
749        }
750
751        // Get unique game IDs and shuffle them
752        let mut game_ids: Vec<usize> = games.keys().cloned().collect();
753        game_ids.shuffle(&mut thread_rng());
754
755        // Split games by ratio
756        let split_point = (game_ids.len() as f32 * train_ratio) as usize;
757        let train_game_ids: HashSet<usize> = game_ids[..split_point].iter().cloned().collect();
758
759        // Separate positions based on game membership
760        let mut train_data = Vec::new();
761        let mut test_data = Vec::new();
762
763        for data in &self.data {
764            if train_game_ids.contains(&data.game_id) {
765                train_data.push(data.clone());
766            } else {
767                test_data.push(data.clone());
768            }
769        }
770
771        (
772            TrainingDataset { data: train_data },
773            TrainingDataset { data: test_data },
774        )
775    }
776
777    /// Save dataset to file
778    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
779        let json = serde_json::to_string_pretty(&self.data)?;
780        std::fs::write(path, json)?;
781        Ok(())
782    }
783
784    /// Load dataset from file
785    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
786        let content = std::fs::read_to_string(path)?;
787        let data = serde_json::from_str(&content)?;
788        Ok(Self { data })
789    }
790
791    /// Load and append data from file to existing dataset (incremental training)
792    pub fn load_and_append<P: AsRef<Path>>(
793        &mut self,
794        path: P,
795    ) -> Result<(), Box<dyn std::error::Error>> {
796        let existing_len = self.data.len();
797        let additional_data = Self::load(path)?;
798        self.data.extend(additional_data.data);
799        println!(
800            "Loaded {} additional positions (total: {})",
801            self.data.len() - existing_len,
802            self.data.len()
803        );
804        Ok(())
805    }
806
807    /// Merge another dataset into this one
808    pub fn merge(&mut self, other: TrainingDataset) {
809        let existing_len = self.data.len();
810        self.data.extend(other.data);
811        println!(
812            "Merged {} positions (total: {})",
813            self.data.len() - existing_len,
814            self.data.len()
815        );
816    }
817
818    /// Save incrementally (append to existing file if it exists)
819    pub fn save_incremental<P: AsRef<Path>>(
820        &self,
821        path: P,
822    ) -> Result<(), Box<dyn std::error::Error>> {
823        self.save_incremental_with_options(path, true)
824    }
825
826    /// Save incrementally with option to skip deduplication
827    pub fn save_incremental_with_options<P: AsRef<Path>>(
828        &self,
829        path: P,
830        deduplicate: bool,
831    ) -> Result<(), Box<dyn std::error::Error>> {
832        let path = path.as_ref();
833
834        if path.exists() {
835            // Try fast append-only save first
836            if self.save_append_only(path).is_ok() {
837                return Ok(());
838            }
839
840            // Fall back to full merge
841            if deduplicate {
842                self.save_incremental_full_merge(path)
843            } else {
844                self.save_incremental_no_dedup(path)
845            }
846        } else {
847            // File doesn't exist, just save normally
848            self.save(path)
849        }
850    }
851
852    /// Fast merge without deduplication (for trusted unique data)
853    fn save_incremental_no_dedup<P: AsRef<Path>>(
854        &self,
855        path: P,
856    ) -> Result<(), Box<dyn std::error::Error>> {
857        let path = path.as_ref();
858
859        println!("📂 Loading existing training data...");
860        let mut existing = Self::load(path)?;
861
862        println!("⚡ Fast merge without deduplication...");
863        existing.data.extend(self.data.iter().cloned());
864
865        println!(
866            "💾 Serializing {} positions to JSON...",
867            existing.data.len()
868        );
869        let json = serde_json::to_string_pretty(&existing.data)?;
870
871        println!("✍️  Writing to disk...");
872        std::fs::write(path, json)?;
873
874        println!(
875            "✅ Fast merge save: total {} positions",
876            existing.data.len()
877        );
878        Ok(())
879    }
880
881    /// Fast append-only save (no deduplication, just append new positions)
882    pub fn save_append_only<P: AsRef<Path>>(
883        &self,
884        path: P,
885    ) -> Result<(), Box<dyn std::error::Error>> {
886        use std::fs::OpenOptions;
887        use std::io::{BufRead, BufReader, Seek, SeekFrom, Write};
888
889        if self.data.is_empty() {
890            return Ok(());
891        }
892
893        let path = path.as_ref();
894        let mut file = OpenOptions::new().read(true).write(true).open(path)?;
895
896        // Check if file is valid JSON array by reading last few bytes
897        file.seek(SeekFrom::End(-10))?;
898        let mut buffer = String::new();
899        BufReader::new(&file).read_line(&mut buffer)?;
900
901        if !buffer.trim().ends_with(']') {
902            return Err("File doesn't end with JSON array bracket".into());
903        }
904
905        // Seek back to overwrite the closing bracket
906        file.seek(SeekFrom::End(-2))?; // Go back 2 chars to overwrite "]\n"
907
908        // Append comma and new positions
909        write!(file, ",")?;
910
911        // Serialize and append new positions (without array brackets)
912        for (i, data) in self.data.iter().enumerate() {
913            if i > 0 {
914                write!(file, ",")?;
915            }
916            let json = serde_json::to_string(data)?;
917            write!(file, "{json}")?;
918        }
919
920        // Close the JSON array
921        write!(file, "\n]")?;
922
923        println!("Fast append: added {} new positions", self.data.len());
924        Ok(())
925    }
926
927    /// Full merge save with deduplication (slower but thorough)
928    fn save_incremental_full_merge<P: AsRef<Path>>(
929        &self,
930        path: P,
931    ) -> Result<(), Box<dyn std::error::Error>> {
932        let path = path.as_ref();
933
934        println!("📂 Loading existing training data...");
935        let mut existing = Self::load(path)?;
936        let _original_len = existing.data.len();
937
938        println!("🔄 Streaming merge with deduplication (avoiding O(n²) operation)...");
939        existing.merge_and_deduplicate(self.data.clone());
940
941        println!(
942            "💾 Serializing {} positions to JSON...",
943            existing.data.len()
944        );
945        let json = serde_json::to_string_pretty(&existing.data)?;
946
947        println!("✍️  Writing to disk...");
948        std::fs::write(path, json)?;
949
950        println!(
951            "✅ Streaming merge save: total {} positions",
952            existing.data.len()
953        );
954        Ok(())
955    }
956
957    /// Add a single training data point
958    pub fn add_position(&mut self, board: Board, evaluation: f32, depth: u8, game_id: usize) {
959        self.data.push(TrainingData {
960            board,
961            evaluation,
962            depth,
963            game_id,
964        });
965    }
966
967    /// Get the next available game ID for incremental training
968    pub fn next_game_id(&self) -> usize {
969        self.data.iter().map(|data| data.game_id).max().unwrap_or(0) + 1
970    }
971
972    /// Remove near-duplicate positions to reduce overfitting
973    pub fn deduplicate(&mut self, similarity_threshold: f32) {
974        if similarity_threshold > 0.999 {
975            // Use fast hash-based deduplication for exact duplicates
976            self.deduplicate_fast();
977        } else {
978            // Use slower similarity-based deduplication for near-duplicates
979            self.deduplicate_similarity_based(similarity_threshold);
980        }
981    }
982
983    /// Fast hash-based deduplication for exact duplicates (O(n))
984    pub fn deduplicate_fast(&mut self) {
985        use std::collections::HashSet;
986
987        if self.data.is_empty() {
988            return;
989        }
990
991        let mut seen_positions = HashSet::with_capacity(self.data.len());
992        let original_len = self.data.len();
993
994        // Keep positions with unique FEN strings
995        self.data.retain(|data| {
996            let fen = data.board.to_string();
997            seen_positions.insert(fen)
998        });
999
1000        println!(
1001            "Fast deduplicated: {} -> {} positions (removed {} exact duplicates)",
1002            original_len,
1003            self.data.len(),
1004            original_len - self.data.len()
1005        );
1006    }
1007
1008    /// Streaming deduplication when merging with existing data (faster for large datasets)
1009    pub fn merge_and_deduplicate(&mut self, new_data: Vec<TrainingData>) {
1010        use std::collections::HashSet;
1011
1012        if new_data.is_empty() {
1013            return;
1014        }
1015
1016        let _original_len = self.data.len();
1017
1018        // Build hashset of existing positions for fast lookup
1019        let mut existing_positions: HashSet<String> = HashSet::with_capacity(self.data.len());
1020        for data in &self.data {
1021            existing_positions.insert(data.board.to_string());
1022        }
1023
1024        // Only add new positions that don't already exist
1025        let mut added = 0;
1026        for data in new_data {
1027            let fen = data.board.to_string();
1028            if existing_positions.insert(fen) {
1029                self.data.push(data);
1030                added += 1;
1031            }
1032        }
1033
1034        println!(
1035            "Streaming merge: added {} unique positions (total: {})",
1036            added,
1037            self.data.len()
1038        );
1039    }
1040
1041    /// Similarity-based deduplication for near-duplicates (O(n²) but optimized)
1042    fn deduplicate_similarity_based(&mut self, similarity_threshold: f32) {
1043        use crate::PositionEncoder;
1044        use ndarray::Array1;
1045
1046        if self.data.is_empty() {
1047            return;
1048        }
1049
1050        let encoder = PositionEncoder::new(1024);
1051        let mut keep_indices: Vec<bool> = vec![true; self.data.len()];
1052
1053        // Encode all positions in parallel
1054        let vectors: Vec<Array1<f32>> = if self.data.len() > 50 {
1055            self.data
1056                .par_iter()
1057                .map(|data| encoder.encode(&data.board))
1058                .collect()
1059        } else {
1060            self.data
1061                .iter()
1062                .map(|data| encoder.encode(&data.board))
1063                .collect()
1064        };
1065
1066        // Compare each position with all previous ones
1067        for i in 1..self.data.len() {
1068            if !keep_indices[i] {
1069                continue;
1070            }
1071
1072            for j in 0..i {
1073                if !keep_indices[j] {
1074                    continue;
1075                }
1076
1077                let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1078                if similarity > similarity_threshold {
1079                    keep_indices[i] = false;
1080                    break;
1081                }
1082            }
1083        }
1084
1085        // Filter data based on keep_indices
1086        let original_len = self.data.len();
1087        self.data = self
1088            .data
1089            .iter()
1090            .enumerate()
1091            .filter_map(|(i, data)| {
1092                if keep_indices[i] {
1093                    Some(data.clone())
1094                } else {
1095                    None
1096                }
1097            })
1098            .collect();
1099
1100        println!(
1101            "Similarity deduplicated: {} -> {} positions (removed {} near-duplicates)",
1102            original_len,
1103            self.data.len(),
1104            original_len - self.data.len()
1105        );
1106    }
1107
1108    /// Remove near-duplicate positions using parallel comparison (faster for large datasets)
1109    pub fn deduplicate_parallel(&mut self, similarity_threshold: f32, chunk_size: usize) {
1110        use crate::PositionEncoder;
1111        use ndarray::Array1;
1112        use std::sync::{Arc, Mutex};
1113
1114        if self.data.is_empty() {
1115            return;
1116        }
1117
1118        let encoder = PositionEncoder::new(1024);
1119
1120        // Encode all positions in parallel
1121        let vectors: Vec<Array1<f32>> = self
1122            .data
1123            .par_iter()
1124            .map(|data| encoder.encode(&data.board))
1125            .collect();
1126
1127        let keep_indices = Arc::new(Mutex::new(vec![true; self.data.len()]));
1128
1129        // Process in chunks to balance parallelism and memory usage
1130        (1..self.data.len())
1131            .collect::<Vec<_>>()
1132            .par_chunks(chunk_size)
1133            .for_each(|chunk| {
1134                for &i in chunk {
1135                    // Check if this position is still being kept
1136                    {
1137                        let indices = keep_indices.lock().unwrap();
1138                        if !indices[i] {
1139                            continue;
1140                        }
1141                    }
1142
1143                    // Compare with all previous positions
1144                    for j in 0..i {
1145                        {
1146                            let indices = keep_indices.lock().unwrap();
1147                            if !indices[j] {
1148                                continue;
1149                            }
1150                        }
1151
1152                        let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1153                        if similarity > similarity_threshold {
1154                            let mut indices = keep_indices.lock().unwrap();
1155                            indices[i] = false;
1156                            break;
1157                        }
1158                    }
1159                }
1160            });
1161
1162        // Filter data based on keep_indices
1163        let keep_indices = keep_indices.lock().unwrap();
1164        let original_len = self.data.len();
1165        self.data = self
1166            .data
1167            .iter()
1168            .enumerate()
1169            .filter_map(|(i, data)| {
1170                if keep_indices[i] {
1171                    Some(data.clone())
1172                } else {
1173                    None
1174                }
1175            })
1176            .collect();
1177
1178        println!(
1179            "Parallel deduplicated: {} -> {} positions (removed {} duplicates)",
1180            original_len,
1181            self.data.len(),
1182            original_len - self.data.len()
1183        );
1184    }
1185
1186    /// Calculate cosine similarity between two vectors
1187    fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
1188        let dot_product = a.dot(b);
1189        let norm_a = a.dot(a).sqrt();
1190        let norm_b = b.dot(b).sqrt();
1191
1192        if norm_a == 0.0 || norm_b == 0.0 {
1193            0.0
1194        } else {
1195            dot_product / (norm_a * norm_b)
1196        }
1197    }
1198}
1199
1200/// Self-play training system for generating new positions
1201pub struct SelfPlayTrainer {
1202    config: SelfPlayConfig,
1203    game_counter: usize,
1204}
1205
1206impl SelfPlayTrainer {
1207    pub fn new(config: SelfPlayConfig) -> Self {
1208        Self {
1209            config,
1210            game_counter: 0,
1211        }
1212    }
1213
1214    /// Generate training data through self-play games
1215    pub fn generate_training_data(&mut self, engine: &mut ChessVectorEngine) -> TrainingDataset {
1216        let mut dataset = TrainingDataset::new();
1217
1218        println!(
1219            "🎮 Starting self-play training with {} games...",
1220            self.config.games_per_iteration
1221        );
1222        let pb = ProgressBar::new(self.config.games_per_iteration as u64);
1223        if let Ok(style) = ProgressStyle::default_bar().template(
1224            "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
1225        ) {
1226            pb.set_style(style.progress_chars("#>-"));
1227        }
1228
1229        for _ in 0..self.config.games_per_iteration {
1230            let game_data = self.play_single_game(engine);
1231            dataset.data.extend(game_data);
1232            self.game_counter += 1;
1233            pb.inc(1);
1234        }
1235
1236        pb.finish_with_message("Self-play games completed");
1237        println!(
1238            "✅ Generated {} positions from {} games",
1239            dataset.data.len(),
1240            self.config.games_per_iteration
1241        );
1242
1243        dataset
1244    }
1245
1246    /// Play a single self-play game and extract training positions
1247    fn play_single_game(&self, engine: &mut ChessVectorEngine) -> Vec<TrainingData> {
1248        let mut game = Game::new();
1249        let mut positions = Vec::new();
1250        let mut move_count = 0;
1251
1252        // Use opening book for variety if enabled
1253        if self.config.use_opening_book {
1254            if let Some(opening_moves) = self.get_random_opening() {
1255                for mv in opening_moves {
1256                    if game.make_move(mv) {
1257                        move_count += 1;
1258                    } else {
1259                        break;
1260                    }
1261                }
1262            }
1263        }
1264
1265        // Play the game
1266        while game.result().is_none() && move_count < self.config.max_moves_per_game {
1267            let current_position = game.current_position();
1268
1269            // Get engine's move recommendation with exploration
1270            let move_choice = self.select_move_with_exploration(engine, &current_position);
1271
1272            if let Some(chess_move) = move_choice {
1273                // Evaluate the position before making the move
1274                if let Some(evaluation) = engine.evaluate_position(&current_position) {
1275                    // Only include positions with sufficient confidence
1276                    if evaluation.abs() >= self.config.min_confidence || move_count < 10 {
1277                        positions.push(TrainingData {
1278                            board: current_position,
1279                            evaluation,
1280                            depth: 1, // Self-play depth
1281                            game_id: self.game_counter,
1282                        });
1283                    }
1284                }
1285
1286                // Make the move
1287                if !game.make_move(chess_move) {
1288                    break; // Invalid move, end game
1289                }
1290                move_count += 1;
1291            } else {
1292                break; // No legal moves
1293            }
1294        }
1295
1296        // Add final position evaluation based on game result
1297        if let Some(result) = game.result() {
1298            let final_position = game.current_position();
1299            let final_eval = match result {
1300                chess::GameResult::WhiteCheckmates => {
1301                    if final_position.side_to_move() == chess::Color::Black {
1302                        10.0
1303                    } else {
1304                        -10.0
1305                    }
1306                }
1307                chess::GameResult::BlackCheckmates => {
1308                    if final_position.side_to_move() == chess::Color::White {
1309                        10.0
1310                    } else {
1311                        -10.0
1312                    }
1313                }
1314                chess::GameResult::WhiteResigns => -10.0,
1315                chess::GameResult::BlackResigns => 10.0,
1316                chess::GameResult::Stalemate
1317                | chess::GameResult::DrawAccepted
1318                | chess::GameResult::DrawDeclared => 0.0,
1319            };
1320
1321            positions.push(TrainingData {
1322                board: final_position,
1323                evaluation: final_eval,
1324                depth: 1,
1325                game_id: self.game_counter,
1326            });
1327        }
1328
1329        positions
1330    }
1331
1332    /// Select a move with exploration vs exploitation balance
1333    fn select_move_with_exploration(
1334        &self,
1335        engine: &mut ChessVectorEngine,
1336        position: &Board,
1337    ) -> Option<ChessMove> {
1338        let recommendations = engine.recommend_legal_moves(position, 5);
1339
1340        if recommendations.is_empty() {
1341            return None;
1342        }
1343
1344        // Use temperature-based selection for exploration
1345        if fastrand::f32() < self.config.exploration_factor {
1346            // Exploration: weighted random selection based on evaluations
1347            self.select_move_with_temperature(&recommendations)
1348        } else {
1349            // Exploitation: take the best move
1350            Some(recommendations[0].chess_move)
1351        }
1352    }
1353
1354    /// Temperature-based move selection for exploration
1355    fn select_move_with_temperature(
1356        &self,
1357        recommendations: &[crate::MoveRecommendation],
1358    ) -> Option<ChessMove> {
1359        if recommendations.is_empty() {
1360            return None;
1361        }
1362
1363        // Convert evaluations to probabilities using temperature
1364        let mut probabilities = Vec::new();
1365        let mut sum = 0.0;
1366
1367        for rec in recommendations {
1368            // Use average_outcome as evaluation score for temperature selection
1369            let prob = (rec.average_outcome / self.config.temperature).exp();
1370            probabilities.push(prob);
1371            sum += prob;
1372        }
1373
1374        // Normalize probabilities
1375        for prob in &mut probabilities {
1376            *prob /= sum;
1377        }
1378
1379        // Random selection based on probabilities
1380        let rand_val = fastrand::f32();
1381        let mut cumulative = 0.0;
1382
1383        for (i, &prob) in probabilities.iter().enumerate() {
1384            cumulative += prob;
1385            if rand_val <= cumulative {
1386                return Some(recommendations[i].chess_move);
1387            }
1388        }
1389
1390        // Fallback to first move
1391        Some(recommendations[0].chess_move)
1392    }
1393
1394    /// Get random opening moves for variety
1395    fn get_random_opening(&self) -> Option<Vec<ChessMove>> {
1396        let openings = [
1397            // Italian Game
1398            vec!["e4", "e5", "Nf3", "Nc6", "Bc4"],
1399            // Ruy Lopez
1400            vec!["e4", "e5", "Nf3", "Nc6", "Bb5"],
1401            // Queen's Gambit
1402            vec!["d4", "d5", "c4"],
1403            // King's Indian Defense
1404            vec!["d4", "Nf6", "c4", "g6"],
1405            // Sicilian Defense
1406            vec!["e4", "c5"],
1407            // French Defense
1408            vec!["e4", "e6"],
1409            // Caro-Kann Defense
1410            vec!["e4", "c6"],
1411        ];
1412
1413        let selected_opening = &openings[fastrand::usize(0..openings.len())];
1414
1415        let mut moves = Vec::new();
1416        let mut game = Game::new();
1417
1418        for move_str in selected_opening {
1419            if let Ok(chess_move) = ChessMove::from_str(move_str) {
1420                if game.make_move(chess_move) {
1421                    moves.push(chess_move);
1422                } else {
1423                    break;
1424                }
1425            }
1426        }
1427
1428        if moves.is_empty() {
1429            None
1430        } else {
1431            Some(moves)
1432        }
1433    }
1434}
1435
1436/// Engine performance evaluator
1437pub struct EngineEvaluator {
1438    #[allow(dead_code)]
1439    stockfish_depth: u8,
1440}
1441
1442impl EngineEvaluator {
1443    pub fn new(stockfish_depth: u8) -> Self {
1444        Self { stockfish_depth }
1445    }
1446
1447    /// Compare engine evaluations against Stockfish on test set
1448    pub fn evaluate_accuracy(
1449        &self,
1450        engine: &mut ChessVectorEngine,
1451        test_data: &TrainingDataset,
1452    ) -> Result<f32, Box<dyn std::error::Error>> {
1453        let mut total_error = 0.0;
1454        let mut valid_comparisons = 0;
1455
1456        let pb = ProgressBar::new(test_data.data.len() as u64);
1457        pb.set_style(ProgressStyle::default_bar()
1458            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Evaluating accuracy")
1459            .unwrap()
1460            .progress_chars("#>-"));
1461
1462        for data in &test_data.data {
1463            if let Some(engine_eval) = engine.evaluate_position(&data.board) {
1464                let error = (engine_eval - data.evaluation).abs();
1465                total_error += error;
1466                valid_comparisons += 1;
1467            }
1468            pb.inc(1);
1469        }
1470
1471        pb.finish_with_message("Accuracy evaluation complete");
1472
1473        if valid_comparisons > 0 {
1474            let mean_absolute_error = total_error / valid_comparisons as f32;
1475            println!("Mean Absolute Error: {:.3} pawns", mean_absolute_error);
1476            println!("Evaluated {} positions", valid_comparisons);
1477            Ok(mean_absolute_error)
1478        } else {
1479            Ok(f32::INFINITY)
1480        }
1481    }
1482}
1483
1484// Make TrainingData serializable
1485impl serde::Serialize for TrainingData {
1486    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1487    where
1488        S: serde::Serializer,
1489    {
1490        use serde::ser::SerializeStruct;
1491        let mut state = serializer.serialize_struct("TrainingData", 4)?;
1492        state.serialize_field("fen", &self.board.to_string())?;
1493        state.serialize_field("evaluation", &self.evaluation)?;
1494        state.serialize_field("depth", &self.depth)?;
1495        state.serialize_field("game_id", &self.game_id)?;
1496        state.end()
1497    }
1498}
1499
1500impl<'de> serde::Deserialize<'de> for TrainingData {
1501    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1502    where
1503        D: serde::Deserializer<'de>,
1504    {
1505        use serde::de::{self, MapAccess, Visitor};
1506        use std::fmt;
1507
1508        struct TrainingDataVisitor;
1509
1510        impl<'de> Visitor<'de> for TrainingDataVisitor {
1511            type Value = TrainingData;
1512
1513            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1514                formatter.write_str("struct TrainingData")
1515            }
1516
1517            fn visit_map<V>(self, mut map: V) -> Result<TrainingData, V::Error>
1518            where
1519                V: MapAccess<'de>,
1520            {
1521                let mut fen = None;
1522                let mut evaluation = None;
1523                let mut depth = None;
1524                let mut game_id = None;
1525
1526                while let Some(key) = map.next_key()? {
1527                    match key {
1528                        "fen" => {
1529                            if fen.is_some() {
1530                                return Err(de::Error::duplicate_field("fen"));
1531                            }
1532                            fen = Some(map.next_value()?);
1533                        }
1534                        "evaluation" => {
1535                            if evaluation.is_some() {
1536                                return Err(de::Error::duplicate_field("evaluation"));
1537                            }
1538                            evaluation = Some(map.next_value()?);
1539                        }
1540                        "depth" => {
1541                            if depth.is_some() {
1542                                return Err(de::Error::duplicate_field("depth"));
1543                            }
1544                            depth = Some(map.next_value()?);
1545                        }
1546                        "game_id" => {
1547                            if game_id.is_some() {
1548                                return Err(de::Error::duplicate_field("game_id"));
1549                            }
1550                            game_id = Some(map.next_value()?);
1551                        }
1552                        _ => {
1553                            let _: serde_json::Value = map.next_value()?;
1554                        }
1555                    }
1556                }
1557
1558                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
1559                let mut evaluation: f32 =
1560                    evaluation.ok_or_else(|| de::Error::missing_field("evaluation"))?;
1561                let depth = depth.ok_or_else(|| de::Error::missing_field("depth"))?;
1562                let game_id = game_id.unwrap_or(0); // Default to 0 for backward compatibility
1563
1564                // Convert evaluation from centipawns to pawns if needed
1565                // If evaluation is outside typical pawn range (-10 to +10),
1566                // assume it's in centipawns and convert to pawns
1567                if evaluation.abs() > 15.0 {
1568                    evaluation /= 100.0;
1569                }
1570
1571                let board =
1572                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
1573
1574                Ok(TrainingData {
1575                    board,
1576                    evaluation,
1577                    depth,
1578                    game_id,
1579                })
1580            }
1581        }
1582
1583        const FIELDS: &[&str] = &["fen", "evaluation", "depth", "game_id"];
1584        deserializer.deserialize_struct("TrainingData", FIELDS, TrainingDataVisitor)
1585    }
1586}
1587
/// Parser for the Lichess tactical puzzle database (CSV export).
///
/// Holds no state; parsing is done through associated functions such as `parse_csv`,
/// which produces `TacticalTrainingData` entries.
pub struct TacticalPuzzleParser;
1590
1591impl TacticalPuzzleParser {
1592    /// Parse Lichess puzzle CSV file with parallel processing
1593    pub fn parse_csv<P: AsRef<Path>>(
1594        file_path: P,
1595        max_puzzles: Option<usize>,
1596        min_rating: Option<u32>,
1597        max_rating: Option<u32>,
1598    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1599        let file = File::open(&file_path)?;
1600        let file_size = file.metadata()?.len();
1601
1602        // For large files (>100MB), use parallel processing
1603        if file_size > 100_000_000 {
1604            Self::parse_csv_parallel(file_path, max_puzzles, min_rating, max_rating)
1605        } else {
1606            Self::parse_csv_sequential(file_path, max_puzzles, min_rating, max_rating)
1607        }
1608    }
1609
1610    /// Sequential CSV parsing for smaller files
1611    fn parse_csv_sequential<P: AsRef<Path>>(
1612        file_path: P,
1613        max_puzzles: Option<usize>,
1614        min_rating: Option<u32>,
1615        max_rating: Option<u32>,
1616    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1617        let file = File::open(file_path)?;
1618        let reader = BufReader::new(file);
1619
1620        // Create CSV reader without headers since Lichess CSV has no header row
1621        // Set flexible field count to handle inconsistent CSV structure
1622        let mut csv_reader = csv::ReaderBuilder::new()
1623            .has_headers(false)
1624            .flexible(true) // Allow variable number of fields
1625            .from_reader(reader);
1626
1627        let mut tactical_data = Vec::new();
1628        let mut processed = 0;
1629        let mut skipped = 0;
1630
1631        let pb = ProgressBar::new_spinner();
1632        pb.set_style(
1633            ProgressStyle::default_spinner()
1634                .template("{spinner:.green} Parsing tactical puzzles: {pos} (skipped: {skipped})")
1635                .unwrap(),
1636        );
1637
1638        for result in csv_reader.records() {
1639            let record = match result {
1640                Ok(r) => r,
1641                Err(e) => {
1642                    skipped += 1;
1643                    println!("CSV parsing error: {e}");
1644                    continue;
1645                }
1646            };
1647
1648            if let Some(puzzle_data) = Self::parse_csv_record(&record, min_rating, max_rating) {
1649                if let Some(tactical_data_item) =
1650                    Self::convert_puzzle_to_training_data(&puzzle_data)
1651                {
1652                    tactical_data.push(tactical_data_item);
1653                    processed += 1;
1654
1655                    if let Some(max) = max_puzzles {
1656                        if processed >= max {
1657                            break;
1658                        }
1659                    }
1660                } else {
1661                    skipped += 1;
1662                }
1663            } else {
1664                skipped += 1;
1665            }
1666
1667            pb.set_message(format!(
1668                "Parsing tactical puzzles: {} (skipped: {})",
1669                processed, skipped
1670            ));
1671        }
1672
1673        pb.finish_with_message(format!(
1674            "Parsed {} puzzles (skipped: {})",
1675            processed, skipped
1676        ));
1677
1678        Ok(tactical_data)
1679    }
1680
1681    /// Parallel CSV parsing for large files
1682    fn parse_csv_parallel<P: AsRef<Path>>(
1683        file_path: P,
1684        max_puzzles: Option<usize>,
1685        min_rating: Option<u32>,
1686        max_rating: Option<u32>,
1687    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1688        use std::io::Read;
1689
1690        let mut file = File::open(&file_path)?;
1691
1692        // Read entire file into memory for parallel processing
1693        let mut contents = String::new();
1694        file.read_to_string(&mut contents)?;
1695
1696        // Split into lines for parallel processing
1697        let lines: Vec<&str> = contents.lines().collect();
1698
1699        let pb = ProgressBar::new(lines.len() as u64);
1700        pb.set_style(ProgressStyle::default_bar()
1701            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel CSV parsing")
1702            .unwrap()
1703            .progress_chars("#>-"));
1704
1705        // Process lines in parallel
1706        let tactical_data: Vec<TacticalTrainingData> = lines
1707            .par_iter()
1708            .take(max_puzzles.unwrap_or(usize::MAX))
1709            .filter_map(|line| {
1710                // Parse CSV line manually
1711                let fields: Vec<&str> = line.split(',').collect();
1712                if fields.len() < 8 {
1713                    return None;
1714                }
1715
1716                // Build puzzle from fields
1717                if let Some(puzzle_data) = Self::parse_csv_fields(&fields, min_rating, max_rating) {
1718                    Self::convert_puzzle_to_training_data(&puzzle_data)
1719                } else {
1720                    None
1721                }
1722            })
1723            .collect();
1724
1725        pb.finish_with_message(format!(
1726            "Parallel parsing complete: {} puzzles",
1727            tactical_data.len()
1728        ));
1729
1730        Ok(tactical_data)
1731    }
1732
1733    /// Parse CSV record into TacticalPuzzle
1734    fn parse_csv_record(
1735        record: &csv::StringRecord,
1736        min_rating: Option<u32>,
1737        max_rating: Option<u32>,
1738    ) -> Option<TacticalPuzzle> {
1739        // Need at least 8 fields (PuzzleId, FEN, Moves, Rating, RatingDeviation, Popularity, NbPlays, Themes)
1740        if record.len() < 8 {
1741            return None;
1742        }
1743
1744        let rating: u32 = record[3].parse().ok()?;
1745        let rating_deviation: u32 = record[4].parse().ok()?;
1746        let popularity: i32 = record[5].parse().ok()?;
1747        let nb_plays: u32 = record[6].parse().ok()?;
1748
1749        // Apply rating filters
1750        if let Some(min) = min_rating {
1751            if rating < min {
1752                return None;
1753            }
1754        }
1755        if let Some(max) = max_rating {
1756            if rating > max {
1757                return None;
1758            }
1759        }
1760
1761        Some(TacticalPuzzle {
1762            puzzle_id: record[0].to_string(),
1763            fen: record[1].to_string(),
1764            moves: record[2].to_string(),
1765            rating,
1766            rating_deviation,
1767            popularity,
1768            nb_plays,
1769            themes: record[7].to_string(),
1770            game_url: if record.len() > 8 {
1771                Some(record[8].to_string())
1772            } else {
1773                None
1774            },
1775            opening_tags: if record.len() > 9 {
1776                Some(record[9].to_string())
1777            } else {
1778                None
1779            },
1780        })
1781    }
1782
1783    /// Parse CSV fields into TacticalPuzzle (for parallel processing)
1784    fn parse_csv_fields(
1785        fields: &[&str],
1786        min_rating: Option<u32>,
1787        max_rating: Option<u32>,
1788    ) -> Option<TacticalPuzzle> {
1789        if fields.len() < 8 {
1790            return None;
1791        }
1792
1793        let rating: u32 = fields[3].parse().ok()?;
1794        let rating_deviation: u32 = fields[4].parse().ok()?;
1795        let popularity: i32 = fields[5].parse().ok()?;
1796        let nb_plays: u32 = fields[6].parse().ok()?;
1797
1798        // Apply rating filters
1799        if let Some(min) = min_rating {
1800            if rating < min {
1801                return None;
1802            }
1803        }
1804        if let Some(max) = max_rating {
1805            if rating > max {
1806                return None;
1807            }
1808        }
1809
1810        Some(TacticalPuzzle {
1811            puzzle_id: fields[0].to_string(),
1812            fen: fields[1].to_string(),
1813            moves: fields[2].to_string(),
1814            rating,
1815            rating_deviation,
1816            popularity,
1817            nb_plays,
1818            themes: fields[7].to_string(),
1819            game_url: if fields.len() > 8 {
1820                Some(fields[8].to_string())
1821            } else {
1822                None
1823            },
1824            opening_tags: if fields.len() > 9 {
1825                Some(fields[9].to_string())
1826            } else {
1827                None
1828            },
1829        })
1830    }
1831
1832    /// Convert puzzle to training data
1833    fn convert_puzzle_to_training_data(puzzle: &TacticalPuzzle) -> Option<TacticalTrainingData> {
1834        // Parse FEN position
1835        let position = match Board::from_str(&puzzle.fen) {
1836            Ok(board) => board,
1837            Err(_) => return None,
1838        };
1839
1840        // Parse move sequence - first move is the solution
1841        let moves: Vec<&str> = puzzle.moves.split_whitespace().collect();
1842        if moves.is_empty() {
1843            return None;
1844        }
1845
1846        // Parse the solution move (first move in sequence)
1847        let solution_move = match ChessMove::from_str(moves[0]) {
1848            Ok(mv) => mv,
1849            Err(_) => {
1850                // Try parsing as SAN
1851                match ChessMove::from_san(&position, moves[0]) {
1852                    Ok(mv) => mv,
1853                    Err(_) => return None,
1854                }
1855            }
1856        };
1857
1858        // Verify move is legal
1859        let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&position).collect();
1860        if !legal_moves.contains(&solution_move) {
1861            return None;
1862        }
1863
1864        // Extract primary theme
1865        let themes: Vec<&str> = puzzle.themes.split_whitespace().collect();
1866        let primary_theme = themes.first().unwrap_or(&"tactical").to_string();
1867
1868        // Calculate tactical value based on rating and popularity
1869        let difficulty = puzzle.rating as f32 / 1000.0; // Normalize to 0.8-3.0 range
1870        let popularity_bonus = (puzzle.popularity as f32 / 100.0).min(2.0);
1871        let tactical_value = difficulty + popularity_bonus; // High value for move outcome
1872
1873        Some(TacticalTrainingData {
1874            position,
1875            solution_move,
1876            move_theme: primary_theme,
1877            difficulty,
1878            tactical_value,
1879        })
1880    }
1881
1882    /// Load tactical training data into chess engine
1883    pub fn load_into_engine(
1884        tactical_data: &[TacticalTrainingData],
1885        engine: &mut ChessVectorEngine,
1886    ) {
1887        let pb = ProgressBar::new(tactical_data.len() as u64);
1888        pb.set_style(ProgressStyle::default_bar()
1889            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns")
1890            .unwrap()
1891            .progress_chars("#>-"));
1892
1893        for data in tactical_data {
1894            // Add position with high-value tactical move
1895            engine.add_position_with_move(
1896                &data.position,
1897                0.0, // Position evaluation (neutral for puzzles)
1898                Some(data.solution_move),
1899                Some(data.tactical_value), // High tactical value
1900            );
1901            pb.inc(1);
1902        }
1903
1904        pb.finish_with_message(format!("Loaded {} tactical patterns", tactical_data.len()));
1905    }
1906
1907    /// Load tactical training data into chess engine incrementally (preserves existing data)
1908    pub fn load_into_engine_incremental(
1909        tactical_data: &[TacticalTrainingData],
1910        engine: &mut ChessVectorEngine,
1911    ) {
1912        let initial_size = engine.knowledge_base_size();
1913        let initial_moves = engine.position_moves.len();
1914
1915        // For large datasets, use parallel batch processing
1916        if tactical_data.len() > 1000 {
1917            Self::load_into_engine_incremental_parallel(
1918                tactical_data,
1919                engine,
1920                initial_size,
1921                initial_moves,
1922            );
1923        } else {
1924            Self::load_into_engine_incremental_sequential(
1925                tactical_data,
1926                engine,
1927                initial_size,
1928                initial_moves,
1929            );
1930        }
1931    }
1932
1933    /// Sequential loading for smaller datasets
1934    fn load_into_engine_incremental_sequential(
1935        tactical_data: &[TacticalTrainingData],
1936        engine: &mut ChessVectorEngine,
1937        initial_size: usize,
1938        initial_moves: usize,
1939    ) {
1940        let pb = ProgressBar::new(tactical_data.len() as u64);
1941        pb.set_style(ProgressStyle::default_bar()
1942            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns (incremental)")
1943            .unwrap()
1944            .progress_chars("#>-"));
1945
1946        let mut added = 0;
1947        let mut skipped = 0;
1948
1949        for data in tactical_data {
1950            // Check if this position already exists to avoid duplicates
1951            if !engine.position_boards.contains(&data.position) {
1952                engine.add_position_with_move(
1953                    &data.position,
1954                    0.0, // Position evaluation (neutral for puzzles)
1955                    Some(data.solution_move),
1956                    Some(data.tactical_value), // High tactical value
1957                );
1958                added += 1;
1959            } else {
1960                skipped += 1;
1961            }
1962            pb.inc(1);
1963        }
1964
1965        pb.finish_with_message(format!(
1966            "Loaded {} new tactical patterns (skipped {} duplicates, total: {})",
1967            added,
1968            skipped,
1969            engine.knowledge_base_size()
1970        ));
1971
1972        println!("Incremental tactical training:");
1973        println!(
1974            "  - Positions: {} → {} (+{})",
1975            initial_size,
1976            engine.knowledge_base_size(),
1977            engine.knowledge_base_size() - initial_size
1978        );
1979        println!(
1980            "  - Move entries: {} → {} (+{})",
1981            initial_moves,
1982            engine.position_moves.len(),
1983            engine.position_moves.len() - initial_moves
1984        );
1985    }
1986
1987    /// Parallel batch loading for large datasets
1988    fn load_into_engine_incremental_parallel(
1989        tactical_data: &[TacticalTrainingData],
1990        engine: &mut ChessVectorEngine,
1991        initial_size: usize,
1992        initial_moves: usize,
1993    ) {
1994        let pb = ProgressBar::new(tactical_data.len() as u64);
1995        pb.set_style(ProgressStyle::default_bar()
1996            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Optimized batch loading tactical patterns")
1997            .unwrap()
1998            .progress_chars("#>-"));
1999
2000        // Parallel pre-filtering to avoid duplicates (this is thread-safe for read operations)
2001        let filtered_data: Vec<&TacticalTrainingData> = tactical_data
2002            .par_iter()
2003            .filter(|data| !engine.position_boards.contains(&data.position))
2004            .collect();
2005
2006        let batch_size = 1000; // Larger batches for better performance
2007        let mut added = 0;
2008
2009        println!(
2010            "Pre-filtered: {} → {} positions (removed {} duplicates)",
2011            tactical_data.len(),
2012            filtered_data.len(),
2013            tactical_data.len() - filtered_data.len()
2014        );
2015
2016        // Process in sequential batches (engine operations aren't thread-safe)
2017        // But use optimized batch processing
2018        for batch in filtered_data.chunks(batch_size) {
2019            let batch_start = added;
2020
2021            for data in batch {
2022                // Final duplicate check (should be minimal after pre-filtering)
2023                if !engine.position_boards.contains(&data.position) {
2024                    engine.add_position_with_move(
2025                        &data.position,
2026                        0.0, // Position evaluation (neutral for puzzles)
2027                        Some(data.solution_move),
2028                        Some(data.tactical_value), // High tactical value
2029                    );
2030                    added += 1;
2031                }
2032                pb.inc(1);
2033            }
2034
2035            // Update progress message every batch
2036            pb.set_message(format!("Loaded batch: {} positions", added - batch_start));
2037        }
2038
2039        let skipped = tactical_data.len() - added;
2040
2041        pb.finish_with_message(format!(
2042            "Optimized loaded {} new tactical patterns (skipped {} duplicates, total: {})",
2043            added,
2044            skipped,
2045            engine.knowledge_base_size()
2046        ));
2047
2048        println!("Incremental tactical training (optimized):");
2049        println!(
2050            "  - Positions: {} → {} (+{})",
2051            initial_size,
2052            engine.knowledge_base_size(),
2053            engine.knowledge_base_size() - initial_size
2054        );
2055        println!(
2056            "  - Move entries: {} → {} (+{})",
2057            initial_moves,
2058            engine.position_moves.len(),
2059            engine.position_moves.len() - initial_moves
2060        );
2061        println!(
2062            "  - Batch size: {}, Pre-filtered efficiency: {:.1}%",
2063            batch_size,
2064            (filtered_data.len() as f32 / tactical_data.len() as f32) * 100.0
2065        );
2066    }
2067
2068    /// Save tactical puzzles to file for incremental loading later
2069    pub fn save_tactical_puzzles<P: AsRef<std::path::Path>>(
2070        tactical_data: &[TacticalTrainingData],
2071        path: P,
2072    ) -> Result<(), Box<dyn std::error::Error>> {
2073        let json = serde_json::to_string_pretty(tactical_data)?;
2074        std::fs::write(path, json)?;
2075        println!("Saved {} tactical puzzles", tactical_data.len());
2076        Ok(())
2077    }
2078
2079    /// Load tactical puzzles from file
2080    pub fn load_tactical_puzzles<P: AsRef<std::path::Path>>(
2081        path: P,
2082    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
2083        let content = std::fs::read_to_string(path)?;
2084        let tactical_data: Vec<TacticalTrainingData> = serde_json::from_str(&content)?;
2085        println!("Loaded {} tactical puzzles from file", tactical_data.len());
2086        Ok(tactical_data)
2087    }
2088
2089    /// Save tactical puzzles incrementally (appends to existing file)
2090    pub fn save_tactical_puzzles_incremental<P: AsRef<std::path::Path>>(
2091        tactical_data: &[TacticalTrainingData],
2092        path: P,
2093    ) -> Result<(), Box<dyn std::error::Error>> {
2094        let path = path.as_ref();
2095
2096        if path.exists() {
2097            // Load existing puzzles
2098            let mut existing = Self::load_tactical_puzzles(path)?;
2099            let original_len = existing.len();
2100
2101            // Add new puzzles, checking for duplicates by puzzle ID if available
2102            for new_puzzle in tactical_data {
2103                // Check if this puzzle already exists (by position)
2104                let exists = existing.iter().any(|existing_puzzle| {
2105                    existing_puzzle.position == new_puzzle.position
2106                        && existing_puzzle.solution_move == new_puzzle.solution_move
2107                });
2108
2109                if !exists {
2110                    existing.push(new_puzzle.clone());
2111                }
2112            }
2113
2114            // Save merged data
2115            let json = serde_json::to_string_pretty(&existing)?;
2116            std::fs::write(path, json)?;
2117
2118            println!(
2119                "Incremental save: added {} new puzzles (total: {})",
2120                existing.len() - original_len,
2121                existing.len()
2122            );
2123        } else {
2124            // File doesn't exist, just save normally
2125            Self::save_tactical_puzzles(tactical_data, path)?;
2126        }
2127        Ok(())
2128    }
2129
2130    /// Parse Lichess puzzles incrementally (preserves existing engine state)
2131    pub fn parse_and_load_incremental<P: AsRef<std::path::Path>>(
2132        file_path: P,
2133        engine: &mut ChessVectorEngine,
2134        max_puzzles: Option<usize>,
2135        min_rating: Option<u32>,
2136        max_rating: Option<u32>,
2137    ) -> Result<(), Box<dyn std::error::Error>> {
2138        println!("Parsing Lichess puzzles incrementally...");
2139
2140        // Parse puzzles
2141        let tactical_data = Self::parse_csv(file_path, max_puzzles, min_rating, max_rating)?;
2142
2143        // Load into engine incrementally
2144        Self::load_into_engine_incremental(&tactical_data, engine);
2145
2146        Ok(())
2147    }
2148}
2149
// Unit tests covering training datasets (creation, dedup, merge, persistence),
// engine integration, and tactical-puzzle conversion, loading and persistence.
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    #[test]
    fn test_training_dataset_creation() {
        let dataset = TrainingDataset::new();
        assert_eq!(dataset.data.len(), 0);
    }

    #[test]
    fn test_add_training_data() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.5,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
    }

    #[test]
    fn test_chess_engine_integration() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.3,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        let mut engine = ChessVectorEngine::new(1024);
        dataset.train_engine(&mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);

        // The only stored position is the one queried, so its evaluation should come back.
        let eval = engine.evaluate_position(&board);
        assert!(eval.is_some());
        assert!((eval.unwrap() - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_deduplication() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        // Add duplicate positions
        for i in 0..5 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        assert_eq!(dataset.data.len(), 5);

        // Deduplicate with high threshold (should keep only 1)
        dataset.deduplicate(0.999);
        assert_eq!(dataset.data.len(), 1);
    }

    #[test]
    fn test_dataset_serialization() {
        let mut dataset = TrainingDataset::new();
        let board =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        let training_data = TrainingData {
            board,
            evaluation: 0.2,
            depth: 10,
            game_id: 42,
        };

        dataset.data.push(training_data);

        // Test serialization/deserialization
        let json = serde_json::to_string(&dataset.data).unwrap();
        let loaded_data: Vec<TrainingData> = serde_json::from_str(&json).unwrap();
        let loaded_dataset = TrainingDataset { data: loaded_data };

        assert_eq!(loaded_dataset.data.len(), 1);
        assert_eq!(loaded_dataset.data[0].evaluation, 0.2);
        assert_eq!(loaded_dataset.data[0].depth, 10);
        assert_eq!(loaded_dataset.data[0].game_id, 42);
    }

    #[test]
    fn test_tactical_puzzle_processing() {
        // Moves are written in SAN here ("Bxf7+", "Ke7"), so conversion depends on the
        // SAN fallback rather than UCI parsing.
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4".to_string(),
            moves: "Bxf7+ Ke7".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "fork pin".to_string(),
            game_url: None,
            opening_tags: None,
        };

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_some());

        let data = tactical_data.unwrap();
        // The first listed theme becomes the primary theme.
        assert_eq!(data.move_theme, "fork");
        assert!(data.tactical_value > 1.0); // Should have high tactical value
        assert!(data.difficulty > 0.0);
    }

    #[test]
    fn test_tactical_puzzle_invalid_fen() {
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "invalid_fen".to_string(),
            moves: "e2e4".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "tactics".to_string(),
            game_url: None,
            opening_tags: None,
        };

        // An unparseable FEN must be rejected rather than producing a sample.
        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_none());
    }

    #[test]
    fn test_engine_evaluator() {
        let evaluator = EngineEvaluator::new(15);

        // Create test dataset
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.0,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        // Create engine with some data
        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&board, 0.1);

        // Test accuracy evaluation
        let accuracy = evaluator.evaluate_accuracy(&mut engine, &dataset);
        assert!(accuracy.is_ok());
        // NOTE(review): presumably `evaluate_accuracy` returns an error measure (lower is
        // better); the engine's 0.1 vs. the dataset's 0.0 should stay below 1.0 — confirm.
        assert!(accuracy.unwrap() < 1.0); // Should have some accuracy
    }

    #[test]
    fn test_tactical_training_integration() {
        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.2,
            tactical_value: 2.5,
        }];

        let mut engine = ChessVectorEngine::new(1024);
        TacticalPuzzleParser::load_into_engine(&tactical_data, &mut engine);

        // One puzzle contributes one position and one move entry.
        assert_eq!(engine.knowledge_base_size(), 1);
        assert_eq!(engine.position_moves.len(), 1);

        // Test that tactical move is available in recommendations
        let recommendations = engine.recommend_moves(&Board::default(), 5);
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_multithreading_operations() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        // Add test data
        for i in 0..10 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        // Test parallel deduplication doesn't crash
        dataset.deduplicate_parallel(0.95, 5);
        assert!(dataset.data.len() <= 10);
    }

    #[test]
    fn test_incremental_dataset_operations() {
        let mut dataset1 = TrainingDataset::new();
        let board1 = Board::default();
        let board2 =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        // Add initial data
        dataset1.add_position(board1, 0.0, 15, 1);
        dataset1.add_position(board2, 0.2, 15, 2);
        assert_eq!(dataset1.data.len(), 2);

        // Create second dataset
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );

        // Merge datasets
        dataset1.merge(dataset2);
        assert_eq!(dataset1.data.len(), 3);

        // Test next_game_id
        let next_id = dataset1.next_game_id();
        assert_eq!(next_id, 4); // Should be max(1,2,3) + 1
    }

    #[test]
    fn test_save_load_incremental() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_test.json");

        // Create and save first dataset
        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        // Create second dataset and save incrementally
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap(),
            0.2,
            15,
            2,
        );
        dataset2.save_incremental(&file_path).unwrap();

        // Load and verify merged data
        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 2);

        // Test load_and_append
        let mut dataset3 = TrainingDataset::new();
        dataset3.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );
        dataset3.load_and_append(&file_path).unwrap();
        assert_eq!(dataset3.data.len(), 3); // 1 original + 2 from file
    }

    #[test]
    fn test_add_position_method() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        // Test add_position method
        dataset.add_position(board, 0.5, 20, 42);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
        assert_eq!(dataset.data[0].depth, 20);
        assert_eq!(dataset.data[0].game_id, 42);
    }

    #[test]
    fn test_incremental_save_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_test.json");

        // Create and save first dataset
        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        // Create second dataset with duplicate position
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(Board::default(), 0.1, 15, 2); // Same position, different eval
        dataset2.save_incremental(&file_path).unwrap();

        // Should deduplicate and keep only one
        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 1);
    }

    #[test]
    fn test_tactical_puzzle_incremental_loading() {
        let tactical_data = vec![
            TacticalTrainingData {
                position: Board::default(),
                solution_move: ChessMove::from_str("e2e4").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.2,
                tactical_value: 2.5,
            },
            TacticalTrainingData {
                position: Board::from_str(
                    "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1",
                )
                .unwrap(),
                solution_move: ChessMove::from_str("e7e5").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.0,
                tactical_value: 2.0,
            },
        ];

        let mut engine = ChessVectorEngine::new(1024);

        // Add some existing data
        engine.add_position(&Board::default(), 0.1);
        assert_eq!(engine.knowledge_base_size(), 1);

        // Load tactical puzzles incrementally
        TacticalPuzzleParser::load_into_engine_incremental(&tactical_data, &mut engine);

        // Should have added the new position but skipped the duplicate
        assert_eq!(engine.knowledge_base_size(), 2);

        // The first puzzle's position was already known and is skipped entirely, so move
        // data comes from the second puzzle only.
        assert!(engine.training_stats().has_move_data);
        assert!(engine.training_stats().move_data_entries > 0);
    }

    #[test]
    fn test_tactical_puzzle_serialization() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("tactical_test.json");

        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "fork".to_string(),
            difficulty: 1.5,
            tactical_value: 3.0,
        }];

        // Save tactical puzzles
        TacticalPuzzleParser::save_tactical_puzzles(&tactical_data, &file_path).unwrap();

        // Load them back
        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].move_theme, "fork");
        assert_eq!(loaded[0].difficulty, 1.5);
        assert_eq!(loaded[0].tactical_value, 3.0);
    }

    #[test]
    fn test_tactical_puzzle_incremental_save() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_tactical.json");

        // Save first batch
        let batch1 = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        }];
        TacticalPuzzleParser::save_tactical_puzzles(&batch1, &file_path).unwrap();

        // Save second batch incrementally
        let batch2 = vec![TacticalTrainingData {
            position: Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1")
                .unwrap(),
            solution_move: ChessMove::from_str("e7e5").unwrap(),
            move_theme: "counter".to_string(),
            difficulty: 1.2,
            tactical_value: 2.2,
        }];
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&batch2, &file_path).unwrap();

        // Load and verify merged data
        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 2);
    }

    #[test]
    fn test_tactical_puzzle_incremental_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_tactical.json");

        let tactical_data = TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        };

        // Save first time
        TacticalPuzzleParser::save_tactical_puzzles(&[tactical_data.clone()], &file_path).unwrap();

        // Try to save the same puzzle again (same position and solution move)
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&[tactical_data], &file_path)
            .unwrap();

        // Should still have only one puzzle (deduplicated)
        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
    }
}