chess_vector_engine/
training.rs

1use chess::{Board, ChessMove, Game, MoveGen};
2use indicatif::{ProgressBar, ProgressStyle};
3use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::{BufRead, BufReader, BufWriter, Write};
8use std::path::Path;
9use std::process::{Child, Command, Stdio};
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12
13use crate::ChessVectorEngine;
14
15/// Self-play training configuration
16#[derive(Debug, Clone)]
17pub struct SelfPlayConfig {
18    /// Number of games to play per training iteration
19    pub games_per_iteration: usize,
20    /// Maximum moves per game (to prevent infinite games)
21    pub max_moves_per_game: usize,
22    /// Exploration factor for move selection (0.0 = greedy, 1.0 = random)
23    pub exploration_factor: f32,
24    /// Minimum evaluation confidence to include position
25    pub min_confidence: f32,
26    /// Whether to use opening book for game starts
27    pub use_opening_book: bool,
28    /// Temperature for move selection (higher = more random)
29    pub temperature: f32,
30}
31
32impl Default for SelfPlayConfig {
33    fn default() -> Self {
34        Self {
35            games_per_iteration: 100,
36            max_moves_per_game: 200,
37            exploration_factor: 0.3,
38            min_confidence: 0.1,
39            use_opening_book: true,
40            temperature: 0.8,
41        }
42    }
43}
44
45/// Training data point containing a position and its evaluation
46#[derive(Debug, Clone)]
47pub struct TrainingData {
48    pub board: Board,
49    pub evaluation: f32,
50    pub depth: u8,
51    pub game_id: usize,
52}
53
54/// Tactical puzzle data from Lichess puzzle database
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct TacticalPuzzle {
57    #[serde(rename = "PuzzleId")]
58    pub puzzle_id: String,
59    #[serde(rename = "FEN")]
60    pub fen: String,
61    #[serde(rename = "Moves")]
62    pub moves: String, // Space-separated move sequence
63    #[serde(rename = "Rating")]
64    pub rating: u32,
65    #[serde(rename = "RatingDeviation")]
66    pub rating_deviation: u32,
67    #[serde(rename = "Popularity")]
68    pub popularity: i32,
69    #[serde(rename = "NbPlays")]
70    pub nb_plays: u32,
71    #[serde(rename = "Themes")]
72    pub themes: String, // Space-separated themes
73    #[serde(rename = "GameUrl")]
74    pub game_url: Option<String>,
75    #[serde(rename = "OpeningTags")]
76    pub opening_tags: Option<String>,
77}
78
79/// Processed tactical training data
80#[derive(Debug, Clone)]
81pub struct TacticalTrainingData {
82    pub position: Board,
83    pub solution_move: ChessMove,
84    pub move_theme: String,
85    pub difficulty: f32,     // Rating as difficulty
86    pub tactical_value: f32, // High value for move outcome
87}
88
89// Make TacticalTrainingData serializable
90impl serde::Serialize for TacticalTrainingData {
91    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92    where
93        S: serde::Serializer,
94    {
95        use serde::ser::SerializeStruct;
96        let mut state = serializer.serialize_struct("TacticalTrainingData", 5)?;
97        state.serialize_field("fen", &self.position.to_string())?;
98        state.serialize_field("solution_move", &self.solution_move.to_string())?;
99        state.serialize_field("move_theme", &self.move_theme)?;
100        state.serialize_field("difficulty", &self.difficulty)?;
101        state.serialize_field("tactical_value", &self.tactical_value)?;
102        state.end()
103    }
104}
105
106impl<'de> serde::Deserialize<'de> for TacticalTrainingData {
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        use serde::de::{self, MapAccess, Visitor};
112        use std::fmt;
113
114        struct TacticalTrainingDataVisitor;
115
116        impl<'de> Visitor<'de> for TacticalTrainingDataVisitor {
117            type Value = TacticalTrainingData;
118
119            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
120                formatter.write_str("struct TacticalTrainingData")
121            }
122
123            fn visit_map<V>(self, mut map: V) -> Result<TacticalTrainingData, V::Error>
124            where
125                V: MapAccess<'de>,
126            {
127                let mut fen = None;
128                let mut solution_move = None;
129                let mut move_theme = None;
130                let mut difficulty = None;
131                let mut tactical_value = None;
132
133                while let Some(key) = map.next_key()? {
134                    match key {
135                        "fen" => {
136                            if fen.is_some() {
137                                return Err(de::Error::duplicate_field("fen"));
138                            }
139                            fen = Some(map.next_value()?);
140                        }
141                        "solution_move" => {
142                            if solution_move.is_some() {
143                                return Err(de::Error::duplicate_field("solution_move"));
144                            }
145                            solution_move = Some(map.next_value()?);
146                        }
147                        "move_theme" => {
148                            if move_theme.is_some() {
149                                return Err(de::Error::duplicate_field("move_theme"));
150                            }
151                            move_theme = Some(map.next_value()?);
152                        }
153                        "difficulty" => {
154                            if difficulty.is_some() {
155                                return Err(de::Error::duplicate_field("difficulty"));
156                            }
157                            difficulty = Some(map.next_value()?);
158                        }
159                        "tactical_value" => {
160                            if tactical_value.is_some() {
161                                return Err(de::Error::duplicate_field("tactical_value"));
162                            }
163                            tactical_value = Some(map.next_value()?);
164                        }
165                        _ => {
166                            let _: serde_json::Value = map.next_value()?;
167                        }
168                    }
169                }
170
171                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
172                let solution_move_str: String =
173                    solution_move.ok_or_else(|| de::Error::missing_field("solution_move"))?;
174                let move_theme =
175                    move_theme.ok_or_else(|| de::Error::missing_field("move_theme"))?;
176                let difficulty =
177                    difficulty.ok_or_else(|| de::Error::missing_field("difficulty"))?;
178                let tactical_value =
179                    tactical_value.ok_or_else(|| de::Error::missing_field("tactical_value"))?;
180
181                let position =
182                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
183
184                let solution_move = ChessMove::from_str(&solution_move_str)
185                    .map_err(|e| de::Error::custom(format!("Error: {e}")))?;
186
187                Ok(TacticalTrainingData {
188                    position,
189                    solution_move,
190                    move_theme,
191                    difficulty,
192                    tactical_value,
193                })
194            }
195        }
196
197        const FIELDS: &[&str] = &[
198            "fen",
199            "solution_move",
200            "move_theme",
201            "difficulty",
202            "tactical_value",
203        ];
204        deserializer.deserialize_struct("TacticalTrainingData", FIELDS, TacticalTrainingDataVisitor)
205    }
206}
207
208/// PGN game visitor for extracting positions
209pub struct GameExtractor {
210    pub positions: Vec<TrainingData>,
211    pub current_game: Game,
212    pub move_count: usize,
213    pub max_moves_per_game: usize,
214    pub game_id: usize,
215}
216
217impl GameExtractor {
218    pub fn new(max_moves_per_game: usize) -> Self {
219        Self {
220            positions: Vec::new(),
221            current_game: Game::new(),
222            move_count: 0,
223            max_moves_per_game,
224            game_id: 0,
225        }
226    }
227}
228
229impl Visitor for GameExtractor {
230    type Result = ();
231
232    fn begin_game(&mut self) {
233        self.current_game = Game::new();
234        self.move_count = 0;
235        self.game_id += 1;
236    }
237
238    fn header(&mut self, _key: &[u8], _value: RawHeader<'_>) {}
239
240    fn san(&mut self, san_plus: SanPlus) {
241        if self.move_count >= self.max_moves_per_game {
242            return;
243        }
244
245        let san_str = san_plus.san.to_string();
246
247        // First validate that we have a legal position to work with
248        let current_pos = self.current_game.current_position();
249
250        // Try to parse and make the move
251        match chess::ChessMove::from_san(&current_pos, &san_str) {
252            Ok(chess_move) => {
253                // Verify the move is legal before making it
254                let legal_moves: Vec<chess::ChessMove> = MoveGen::new_legal(&current_pos).collect();
255                if legal_moves.contains(&chess_move) {
256                    if self.current_game.make_move(chess_move) {
257                        self.move_count += 1;
258
259                        // Store position (we'll evaluate it later with Stockfish)
260                        self.positions.push(TrainingData {
261                            board: self.current_game.current_position(),
262                            evaluation: 0.0, // Will be filled by Stockfish
263                            depth: 0,
264                            game_id: self.game_id,
265                        });
266                    }
267                } else {
268                    // Move parsed but isn't legal - skip silently to avoid spam
269                }
270            }
271            Err(_) => {
272                // Failed to parse move - could be notation issues, corruption, etc.
273                // Skip silently to avoid excessive error output
274                // Only log if it's not a common problematic pattern
275                if !san_str.contains("O-O") && !san_str.contains("=") && san_str.len() > 6 {
276                    // Only log unusual failed moves to reduce noise
277                }
278            }
279        }
280    }
281
282    fn begin_variation(&mut self) -> Skip {
283        Skip(true) // Skip variations for now
284    }
285
286    fn end_game(&mut self) -> Self::Result {}
287}
288
289/// Stockfish engine wrapper for position evaluation
290pub struct StockfishEvaluator {
291    depth: u8,
292}
293
294impl StockfishEvaluator {
295    pub fn new(depth: u8) -> Self {
296        Self { depth }
297    }
298
299    /// Evaluate a single position using Stockfish
300    pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
301        let mut child = Command::new("stockfish")
302            .stdin(Stdio::piped())
303            .stdout(Stdio::piped())
304            .stderr(Stdio::piped())
305            .spawn()?;
306
307        let stdin = child
308            .stdin
309            .as_mut()
310            .ok_or("Failed to get stdin handle for Stockfish process")?;
311        let fen = board.to_string();
312
313        // Send UCI commands
314        use std::io::Write;
315        writeln!(stdin, "uci")?;
316        writeln!(stdin, "isready")?;
317        writeln!(stdin, "position fen {fen}")?;
318        writeln!(stdin, "go depth {}", self.depth)?;
319        writeln!(stdin, "quit")?;
320
321        let output = child.wait_with_output()?;
322        let stdout = String::from_utf8_lossy(&output.stdout);
323
324        // Parse the evaluation from Stockfish output
325        for line in stdout.lines() {
326            if line.starts_with("info") && line.contains("score cp") {
327                if let Some(cp_pos) = line.find("score cp ") {
328                    let cp_str = &line[cp_pos + 9..];
329                    if let Some(end) = cp_str.find(' ') {
330                        let cp_value = cp_str[..end].parse::<i32>()?;
331                        return Ok(cp_value as f32 / 100.0); // Convert centipawns to pawns
332                    }
333                }
334            } else if line.starts_with("info") && line.contains("score mate") {
335                // Handle mate scores
336                if let Some(mate_pos) = line.find("score mate ") {
337                    let mate_str = &line[mate_pos + 11..];
338                    if let Some(end) = mate_str.find(' ') {
339                        let mate_moves = mate_str[..end].parse::<i32>()?;
340                        return Ok(if mate_moves > 0 { 100.0 } else { -100.0 });
341                    }
342                }
343            }
344        }
345
346        Ok(0.0) // Default to 0 if no evaluation found
347    }
348
349    /// Batch evaluate multiple positions
350    pub fn evaluate_batch(
351        &self,
352        positions: &mut [TrainingData],
353    ) -> Result<(), Box<dyn std::error::Error>> {
354        let pb = ProgressBar::new(positions.len() as u64);
355        if let Ok(style) = ProgressStyle::default_bar().template(
356            "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
357        ) {
358            pb.set_style(style.progress_chars("#>-"));
359        }
360
361        for data in positions.iter_mut() {
362            match self.evaluate_position(&data.board) {
363                Ok(eval) => {
364                    data.evaluation = eval;
365                    data.depth = self.depth;
366                }
367                Err(e) => {
368                    eprintln!("Evaluation error: {e}");
369                    data.evaluation = 0.0;
370                }
371            }
372            pb.inc(1);
373        }
374
375        pb.finish_with_message("Evaluation complete");
376        Ok(())
377    }
378
379    /// Evaluate multiple positions in parallel using concurrent Stockfish instances
380    pub fn evaluate_batch_parallel(
381        &self,
382        positions: &mut [TrainingData],
383        num_threads: usize,
384    ) -> Result<(), Box<dyn std::error::Error>> {
385        let pb = ProgressBar::new(positions.len() as u64);
386        if let Ok(style) = ProgressStyle::default_bar()
387            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel evaluation") {
388            pb.set_style(style.progress_chars("#>-"));
389        }
390
391        // Set the thread pool size
392        let pool = rayon::ThreadPoolBuilder::new()
393            .num_threads(num_threads)
394            .build()?;
395
396        pool.install(|| {
397            // Use parallel iterator to evaluate positions
398            positions.par_iter_mut().for_each(|data| {
399                match self.evaluate_position(&data.board) {
400                    Ok(eval) => {
401                        data.evaluation = eval;
402                        data.depth = self.depth;
403                    }
404                    Err(_) => {
405                        // Silently fail individual positions to avoid spamming output
406                        data.evaluation = 0.0;
407                    }
408                }
409                pb.inc(1);
410            });
411        });
412
413        pb.finish_with_message("Parallel evaluation complete");
414        Ok(())
415    }
416}
417
418/// Persistent Stockfish process for fast UCI communication
419struct StockfishProcess {
420    child: Child,
421    stdin: BufWriter<std::process::ChildStdin>,
422    stdout: BufReader<std::process::ChildStdout>,
423    #[allow(dead_code)]
424    depth: u8,
425}
426
427impl StockfishProcess {
428    fn new(depth: u8) -> Result<Self, Box<dyn std::error::Error>> {
429        let mut child = Command::new("stockfish")
430            .stdin(Stdio::piped())
431            .stdout(Stdio::piped())
432            .stderr(Stdio::piped())
433            .spawn()?;
434
435        let stdin = BufWriter::new(
436            child
437                .stdin
438                .take()
439                .ok_or("Failed to get stdin handle for Stockfish process")?,
440        );
441        let stdout = BufReader::new(
442            child
443                .stdout
444                .take()
445                .ok_or("Failed to get stdout handle for Stockfish process")?,
446        );
447
448        let mut process = Self {
449            child,
450            stdin,
451            stdout,
452            depth,
453        };
454
455        // Initialize UCI
456        process.send_command("uci")?;
457        process.wait_for_ready()?;
458        process.send_command("isready")?;
459        process.wait_for_ready()?;
460
461        Ok(process)
462    }
463
464    fn send_command(&mut self, command: &str) -> Result<(), Box<dyn std::error::Error>> {
465        writeln!(self.stdin, "{command}")?;
466        self.stdin.flush()?;
467        Ok(())
468    }
469
470    fn wait_for_ready(&mut self) -> Result<(), Box<dyn std::error::Error>> {
471        let mut line = String::new();
472        loop {
473            line.clear();
474            self.stdout.read_line(&mut line)?;
475            if line.trim() == "uciok" || line.trim() == "readyok" {
476                break;
477            }
478        }
479        Ok(())
480    }
481
482    fn evaluate_position(&mut self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
483        let fen = board.to_string();
484
485        // Send position and evaluation commands
486        self.send_command(&format!("position fen {fen}"))?;
487        self.send_command(&format!("position fen {fen}"))?;
488
489        // Read response until we get bestmove
490        let mut line = String::new();
491        let mut last_evaluation = 0.0;
492
493        loop {
494            line.clear();
495            self.stdout.read_line(&mut line)?;
496            let line = line.trim();
497
498            if line.starts_with("info") && line.contains("score cp") {
499                if let Some(cp_pos) = line.find("score cp ") {
500                    let cp_str = &line[cp_pos + 9..];
501                    if let Some(end) = cp_str.find(' ') {
502                        if let Ok(cp_value) = cp_str[..end].parse::<i32>() {
503                            last_evaluation = cp_value as f32 / 100.0;
504                        }
505                    }
506                }
507            } else if line.starts_with("info") && line.contains("score mate") {
508                if let Some(mate_pos) = line.find("score mate ") {
509                    let mate_str = &line[mate_pos + 11..];
510                    if let Some(end) = mate_str.find(' ') {
511                        if let Ok(mate_moves) = mate_str[..end].parse::<i32>() {
512                            last_evaluation = if mate_moves > 0 { 100.0 } else { -100.0 };
513                        }
514                    }
515                }
516            } else if line.starts_with("bestmove") {
517                break;
518            }
519        }
520
521        Ok(last_evaluation)
522    }
523}
524
525impl Drop for StockfishProcess {
526    fn drop(&mut self) {
527        let _ = self.send_command("quit");
528        let _ = self.child.wait();
529    }
530}
531
532/// High-performance Stockfish process pool
533pub struct StockfishPool {
534    pool: Arc<Mutex<Vec<StockfishProcess>>>,
535    depth: u8,
536    pool_size: usize,
537}
538
539impl StockfishPool {
540    pub fn new(depth: u8, pool_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
541        let mut processes = Vec::with_capacity(pool_size);
542
543        println!(
544            "🚀 Initializing Stockfish pool with {pool_size} processes..."
545        );
546
547        for i in 0..pool_size {
548            match StockfishProcess::new(depth) {
549                Ok(process) => {
550                    processes.push(process);
551                    if i % 2 == 1 {
552                        print!(".");
553                        let _ = std::io::stdout().flush(); // Ignore flush errors
554                    }
555                }
556                Err(e) => {
557                    eprintln!("Evaluation error: {e}");
558                    return Err(e);
559                }
560            }
561        }
562
563        println!(" ✅ Pool ready!");
564
565        Ok(Self {
566            pool: Arc::new(Mutex::new(processes)),
567            depth,
568            pool_size,
569        })
570    }
571
572    pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
573        // Get a process from the pool
574        let mut process = {
575            let mut pool = self.pool.lock().unwrap();
576            if let Some(process) = pool.pop() {
577                process
578            } else {
579                // Pool is empty, create temporary process
580                StockfishProcess::new(self.depth)?
581            }
582        };
583
584        // Evaluate position
585        let result = process.evaluate_position(board);
586
587        // Return process to pool
588        {
589            let mut pool = self.pool.lock().unwrap();
590            if pool.len() < self.pool_size {
591                pool.push(process);
592            }
593            // Otherwise drop the process (in case of pool size changes)
594        }
595
596        result
597    }
598
599    pub fn evaluate_batch_parallel(
600        &self,
601        positions: &mut [TrainingData],
602    ) -> Result<(), Box<dyn std::error::Error>> {
603        let pb = ProgressBar::new(positions.len() as u64);
604        pb.set_style(ProgressStyle::default_bar()
605            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Pool evaluation")
606            .unwrap()
607            .progress_chars("#>-"));
608
609        // Use rayon for parallel evaluation
610        positions.par_iter_mut().for_each(|data| {
611            match self.evaluate_position(&data.board) {
612                Ok(eval) => {
613                    data.evaluation = eval;
614                    data.depth = self.depth;
615                }
616                Err(_) => {
617                    data.evaluation = 0.0;
618                }
619            }
620            pb.inc(1);
621        });
622
623        pb.finish_with_message("Pool evaluation complete");
624        Ok(())
625    }
626}
627
628/// Training dataset manager
629pub struct TrainingDataset {
630    pub data: Vec<TrainingData>,
631}
632
633impl Default for TrainingDataset {
634    fn default() -> Self {
635        Self::new()
636    }
637}
638
639impl TrainingDataset {
640    pub fn new() -> Self {
641        Self { data: Vec::new() }
642    }
643
644    /// Load positions from a PGN file
645    pub fn load_from_pgn<P: AsRef<Path>>(
646        &mut self,
647        path: P,
648        max_games: Option<usize>,
649        max_moves_per_game: usize,
650    ) -> Result<(), Box<dyn std::error::Error>> {
651        let file = File::open(path)?;
652        let reader = BufReader::new(file);
653
654        let mut extractor = GameExtractor::new(max_moves_per_game);
655        let mut games_processed = 0;
656
657        // Create a simple PGN parser
658        let mut pgn_content = String::new();
659        for line in reader.lines() {
660            let line = line?;
661            pgn_content.push_str(&line);
662            pgn_content.push('\n');
663
664            // Check if this is the end of a game
665            if line.trim().ends_with("1-0")
666                || line.trim().ends_with("0-1")
667                || line.trim().ends_with("1/2-1/2")
668                || line.trim().ends_with("*")
669            {
670                // Parse this game
671                let cursor = std::io::Cursor::new(&pgn_content);
672                let mut reader = BufferedReader::new(cursor);
673                if let Err(e) = reader.read_all(&mut extractor) {
674                    eprintln!("Evaluation error: {e}");
675                }
676
677                games_processed += 1;
678                pgn_content.clear();
679
680                if let Some(max) = max_games {
681                    if games_processed >= max {
682                        break;
683                    }
684                }
685
686                if games_processed % 100 == 0 {
687                    println!(
688                        "Processed {} games, extracted {} positions",
689                        games_processed,
690                        extractor.positions.len()
691                    );
692                }
693            }
694        }
695
696        self.data.extend(extractor.positions);
697        println!(
698            "Loaded {} positions from {} games",
699            self.data.len(),
700            games_processed
701        );
702        Ok(())
703    }
704
705    /// Evaluate all positions using Stockfish
706    pub fn evaluate_with_stockfish(&mut self, depth: u8) -> Result<(), Box<dyn std::error::Error>> {
707        let evaluator = StockfishEvaluator::new(depth);
708        evaluator.evaluate_batch(&mut self.data)
709    }
710
711    /// Evaluate all positions using Stockfish in parallel
712    pub fn evaluate_with_stockfish_parallel(
713        &mut self,
714        depth: u8,
715        num_threads: usize,
716    ) -> Result<(), Box<dyn std::error::Error>> {
717        let evaluator = StockfishEvaluator::new(depth);
718        evaluator.evaluate_batch_parallel(&mut self.data, num_threads)
719    }
720
721    /// Train the vector engine with this dataset
722    pub fn train_engine(&self, engine: &mut ChessVectorEngine) {
723        let pb = ProgressBar::new(self.data.len() as u64);
724        pb.set_style(ProgressStyle::default_bar()
725            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Training positions")
726            .unwrap()
727            .progress_chars("#>-"));
728
729        for data in &self.data {
730            engine.add_position(&data.board, data.evaluation);
731            pb.inc(1);
732        }
733
734        pb.finish_with_message("Training complete");
735        println!("Trained engine with {} positions", self.data.len());
736    }
737
738    /// Split dataset into train/test sets by games to prevent data leakage
739    pub fn split(&self, train_ratio: f32) -> (TrainingDataset, TrainingDataset) {
740        use rand::seq::SliceRandom;
741        use rand::thread_rng;
742        use std::collections::{HashMap, HashSet};
743
744        // Group positions by game_id
745        let mut games: HashMap<usize, Vec<&TrainingData>> = HashMap::new();
746        for data in &self.data {
747            games.entry(data.game_id).or_default().push(data);
748        }
749
750        // Get unique game IDs and shuffle them
751        let mut game_ids: Vec<usize> = games.keys().cloned().collect();
752        game_ids.shuffle(&mut thread_rng());
753
754        // Split games by ratio
755        let split_point = (game_ids.len() as f32 * train_ratio) as usize;
756        let train_game_ids: HashSet<usize> = game_ids[..split_point].iter().cloned().collect();
757
758        // Separate positions based on game membership
759        let mut train_data = Vec::new();
760        let mut test_data = Vec::new();
761
762        for data in &self.data {
763            if train_game_ids.contains(&data.game_id) {
764                train_data.push(data.clone());
765            } else {
766                test_data.push(data.clone());
767            }
768        }
769
770        (
771            TrainingDataset { data: train_data },
772            TrainingDataset { data: test_data },
773        )
774    }
775
776    /// Save dataset to file
777    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
778        let json = serde_json::to_string_pretty(&self.data)?;
779        std::fs::write(path, json)?;
780        Ok(())
781    }
782
783    /// Load dataset from file
784    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
785        let content = std::fs::read_to_string(path)?;
786        let data = serde_json::from_str(&content)?;
787        Ok(Self { data })
788    }
789
790    /// Load and append data from file to existing dataset (incremental training)
791    pub fn load_and_append<P: AsRef<Path>>(
792        &mut self,
793        path: P,
794    ) -> Result<(), Box<dyn std::error::Error>> {
795        let existing_len = self.data.len();
796        let additional_data = Self::load(path)?;
797        self.data.extend(additional_data.data);
798        println!(
799            "Loaded {} additional positions (total: {})",
800            self.data.len() - existing_len,
801            self.data.len()
802        );
803        Ok(())
804    }
805
806    /// Merge another dataset into this one
807    pub fn merge(&mut self, other: TrainingDataset) {
808        let existing_len = self.data.len();
809        self.data.extend(other.data);
810        println!(
811            "Merged {} positions (total: {})",
812            self.data.len() - existing_len,
813            self.data.len()
814        );
815    }
816
817    /// Save incrementally (append to existing file if it exists)
818    pub fn save_incremental<P: AsRef<Path>>(
819        &self,
820        path: P,
821    ) -> Result<(), Box<dyn std::error::Error>> {
822        self.save_incremental_with_options(path, true)
823    }
824
825    /// Save incrementally with option to skip deduplication
826    pub fn save_incremental_with_options<P: AsRef<Path>>(
827        &self,
828        path: P,
829        deduplicate: bool,
830    ) -> Result<(), Box<dyn std::error::Error>> {
831        let path = path.as_ref();
832
833        if path.exists() {
834            // Try fast append-only save first
835            if self.save_append_only(path).is_ok() {
836                return Ok(());
837            }
838
839            // Fall back to full merge
840            if deduplicate {
841                self.save_incremental_full_merge(path)
842            } else {
843                self.save_incremental_no_dedup(path)
844            }
845        } else {
846            // File doesn't exist, just save normally
847            self.save(path)
848        }
849    }
850
851    /// Fast merge without deduplication (for trusted unique data)
852    fn save_incremental_no_dedup<P: AsRef<Path>>(
853        &self,
854        path: P,
855    ) -> Result<(), Box<dyn std::error::Error>> {
856        let path = path.as_ref();
857
858        println!("📂 Loading existing training data...");
859        let mut existing = Self::load(path)?;
860
861        println!("⚡ Fast merge without deduplication...");
862        existing.data.extend(self.data.iter().cloned());
863
864        println!(
865            "💾 Serializing {} positions to JSON...",
866            existing.data.len()
867        );
868        let json = serde_json::to_string_pretty(&existing.data)?;
869
870        println!("✍️  Writing to disk...");
871        std::fs::write(path, json)?;
872
873        println!(
874            "✅ Fast merge save: total {} positions",
875            existing.data.len()
876        );
877        Ok(())
878    }
879
880    /// Fast append-only save (no deduplication, just append new positions)
881    pub fn save_append_only<P: AsRef<Path>>(
882        &self,
883        path: P,
884    ) -> Result<(), Box<dyn std::error::Error>> {
885        use std::fs::OpenOptions;
886        use std::io::{BufRead, BufReader, Seek, SeekFrom, Write};
887
888        if self.data.is_empty() {
889            return Ok(());
890        }
891
892        let path = path.as_ref();
893        let mut file = OpenOptions::new().read(true).write(true).open(path)?;
894
895        // Check if file is valid JSON array by reading last few bytes
896        file.seek(SeekFrom::End(-10))?;
897        let mut buffer = String::new();
898        BufReader::new(&file).read_line(&mut buffer)?;
899
900        if !buffer.trim().ends_with(']') {
901            return Err("File doesn't end with JSON array bracket".into());
902        }
903
904        // Seek back to overwrite the closing bracket
905        file.seek(SeekFrom::End(-2))?; // Go back 2 chars to overwrite "]\n"
906
907        // Append comma and new positions
908        write!(file, ",")?;
909
910        // Serialize and append new positions (without array brackets)
911        for (i, data) in self.data.iter().enumerate() {
912            if i > 0 {
913                write!(file, ",")?;
914            }
915            let json = serde_json::to_string(data)?;
916            write!(file, "{json}")?;
917        }
918
919        // Close the JSON array
920        write!(file, "\n]")?;
921
922        println!("Fast append: added {} new positions", self.data.len());
923        Ok(())
924    }
925
926    /// Full merge save with deduplication (slower but thorough)
927    fn save_incremental_full_merge<P: AsRef<Path>>(
928        &self,
929        path: P,
930    ) -> Result<(), Box<dyn std::error::Error>> {
931        let path = path.as_ref();
932
933        println!("📂 Loading existing training data...");
934        let mut existing = Self::load(path)?;
935        let _original_len = existing.data.len();
936
937        println!("🔄 Streaming merge with deduplication (avoiding O(n²) operation)...");
938        existing.merge_and_deduplicate(self.data.clone());
939
940        println!(
941            "💾 Serializing {} positions to JSON...",
942            existing.data.len()
943        );
944        let json = serde_json::to_string_pretty(&existing.data)?;
945
946        println!("✍️  Writing to disk...");
947        std::fs::write(path, json)?;
948
949        println!(
950            "✅ Streaming merge save: total {} positions",
951            existing.data.len()
952        );
953        Ok(())
954    }
955
956    /// Add a single training data point
957    pub fn add_position(&mut self, board: Board, evaluation: f32, depth: u8, game_id: usize) {
958        self.data.push(TrainingData {
959            board,
960            evaluation,
961            depth,
962            game_id,
963        });
964    }
965
966    /// Get the next available game ID for incremental training
967    pub fn next_game_id(&self) -> usize {
968        self.data.iter().map(|data| data.game_id).max().unwrap_or(0) + 1
969    }
970
971    /// Remove near-duplicate positions to reduce overfitting
972    pub fn deduplicate(&mut self, similarity_threshold: f32) {
973        if similarity_threshold > 0.999 {
974            // Use fast hash-based deduplication for exact duplicates
975            self.deduplicate_fast();
976        } else {
977            // Use slower similarity-based deduplication for near-duplicates
978            self.deduplicate_similarity_based(similarity_threshold);
979        }
980    }
981
982    /// Fast hash-based deduplication for exact duplicates (O(n))
983    pub fn deduplicate_fast(&mut self) {
984        use std::collections::HashSet;
985
986        if self.data.is_empty() {
987            return;
988        }
989
990        let mut seen_positions = HashSet::with_capacity(self.data.len());
991        let original_len = self.data.len();
992
993        // Keep positions with unique FEN strings
994        self.data.retain(|data| {
995            let fen = data.board.to_string();
996            seen_positions.insert(fen)
997        });
998
999        println!(
1000            "Fast deduplicated: {} -> {} positions (removed {} exact duplicates)",
1001            original_len,
1002            self.data.len(),
1003            original_len - self.data.len()
1004        );
1005    }
1006
1007    /// Streaming deduplication when merging with existing data (faster for large datasets)
1008    pub fn merge_and_deduplicate(&mut self, new_data: Vec<TrainingData>) {
1009        use std::collections::HashSet;
1010
1011        if new_data.is_empty() {
1012            return;
1013        }
1014
1015        let _original_len = self.data.len();
1016
1017        // Build hashset of existing positions for fast lookup
1018        let mut existing_positions: HashSet<String> = HashSet::with_capacity(self.data.len());
1019        for data in &self.data {
1020            existing_positions.insert(data.board.to_string());
1021        }
1022
1023        // Only add new positions that don't already exist
1024        let mut added = 0;
1025        for data in new_data {
1026            let fen = data.board.to_string();
1027            if existing_positions.insert(fen) {
1028                self.data.push(data);
1029                added += 1;
1030            }
1031        }
1032
1033        println!(
1034            "Streaming merge: added {} unique positions (total: {})",
1035            added,
1036            self.data.len()
1037        );
1038    }
1039
1040    /// Similarity-based deduplication for near-duplicates (O(n²) but optimized)
1041    fn deduplicate_similarity_based(&mut self, similarity_threshold: f32) {
1042        use crate::PositionEncoder;
1043        use ndarray::Array1;
1044
1045        if self.data.is_empty() {
1046            return;
1047        }
1048
1049        let encoder = PositionEncoder::new(1024);
1050        let mut keep_indices: Vec<bool> = vec![true; self.data.len()];
1051
1052        // Encode all positions in parallel
1053        let vectors: Vec<Array1<f32>> = if self.data.len() > 50 {
1054            self.data
1055                .par_iter()
1056                .map(|data| encoder.encode(&data.board))
1057                .collect()
1058        } else {
1059            self.data
1060                .iter()
1061                .map(|data| encoder.encode(&data.board))
1062                .collect()
1063        };
1064
1065        // Compare each position with all previous ones
1066        for i in 1..self.data.len() {
1067            if !keep_indices[i] {
1068                continue;
1069            }
1070
1071            for j in 0..i {
1072                if !keep_indices[j] {
1073                    continue;
1074                }
1075
1076                let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1077                if similarity > similarity_threshold {
1078                    keep_indices[i] = false;
1079                    break;
1080                }
1081            }
1082        }
1083
1084        // Filter data based on keep_indices
1085        let original_len = self.data.len();
1086        self.data = self
1087            .data
1088            .iter()
1089            .enumerate()
1090            .filter_map(|(i, data)| {
1091                if keep_indices[i] {
1092                    Some(data.clone())
1093                } else {
1094                    None
1095                }
1096            })
1097            .collect();
1098
1099        println!(
1100            "Similarity deduplicated: {} -> {} positions (removed {} near-duplicates)",
1101            original_len,
1102            self.data.len(),
1103            original_len - self.data.len()
1104        );
1105    }
1106
1107    /// Remove near-duplicate positions using parallel comparison (faster for large datasets)
1108    pub fn deduplicate_parallel(&mut self, similarity_threshold: f32, chunk_size: usize) {
1109        use crate::PositionEncoder;
1110        use ndarray::Array1;
1111        use std::sync::{Arc, Mutex};
1112
1113        if self.data.is_empty() {
1114            return;
1115        }
1116
1117        let encoder = PositionEncoder::new(1024);
1118
1119        // Encode all positions in parallel
1120        let vectors: Vec<Array1<f32>> = self
1121            .data
1122            .par_iter()
1123            .map(|data| encoder.encode(&data.board))
1124            .collect();
1125
1126        let keep_indices = Arc::new(Mutex::new(vec![true; self.data.len()]));
1127
1128        // Process in chunks to balance parallelism and memory usage
1129        (1..self.data.len())
1130            .collect::<Vec<_>>()
1131            .par_chunks(chunk_size)
1132            .for_each(|chunk| {
1133                for &i in chunk {
1134                    // Check if this position is still being kept
1135                    {
1136                        let indices = keep_indices.lock().unwrap();
1137                        if !indices[i] {
1138                            continue;
1139                        }
1140                    }
1141
1142                    // Compare with all previous positions
1143                    for j in 0..i {
1144                        {
1145                            let indices = keep_indices.lock().unwrap();
1146                            if !indices[j] {
1147                                continue;
1148                            }
1149                        }
1150
1151                        let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1152                        if similarity > similarity_threshold {
1153                            let mut indices = keep_indices.lock().unwrap();
1154                            indices[i] = false;
1155                            break;
1156                        }
1157                    }
1158                }
1159            });
1160
1161        // Filter data based on keep_indices
1162        let keep_indices = keep_indices.lock().unwrap();
1163        let original_len = self.data.len();
1164        self.data = self
1165            .data
1166            .iter()
1167            .enumerate()
1168            .filter_map(|(i, data)| {
1169                if keep_indices[i] {
1170                    Some(data.clone())
1171                } else {
1172                    None
1173                }
1174            })
1175            .collect();
1176
1177        println!(
1178            "Parallel deduplicated: {} -> {} positions (removed {} duplicates)",
1179            original_len,
1180            self.data.len(),
1181            original_len - self.data.len()
1182        );
1183    }
1184
1185    /// Calculate cosine similarity between two vectors
1186    fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
1187        let dot_product = a.dot(b);
1188        let norm_a = a.dot(a).sqrt();
1189        let norm_b = b.dot(b).sqrt();
1190
1191        if norm_a == 0.0 || norm_b == 0.0 {
1192            0.0
1193        } else {
1194            dot_product / (norm_a * norm_b)
1195        }
1196    }
1197}
1198
1199/// Self-play training system for generating new positions
1200pub struct SelfPlayTrainer {
1201    config: SelfPlayConfig,
1202    game_counter: usize,
1203}
1204
1205impl SelfPlayTrainer {
1206    pub fn new(config: SelfPlayConfig) -> Self {
1207        Self {
1208            config,
1209            game_counter: 0,
1210        }
1211    }
1212
1213    /// Generate training data through self-play games
1214    pub fn generate_training_data(&mut self, engine: &mut ChessVectorEngine) -> TrainingDataset {
1215        let mut dataset = TrainingDataset::new();
1216
1217        println!(
1218            "🎮 Starting self-play training with {} games...",
1219            self.config.games_per_iteration
1220        );
1221        let pb = ProgressBar::new(self.config.games_per_iteration as u64);
1222        if let Ok(style) = ProgressStyle::default_bar().template(
1223            "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
1224        ) {
1225            pb.set_style(style.progress_chars("#>-"));
1226        }
1227
1228        for _ in 0..self.config.games_per_iteration {
1229            let game_data = self.play_single_game(engine);
1230            dataset.data.extend(game_data);
1231            self.game_counter += 1;
1232            pb.inc(1);
1233        }
1234
1235        pb.finish_with_message("Self-play games completed");
1236        println!(
1237            "✅ Generated {} positions from {} games",
1238            dataset.data.len(),
1239            self.config.games_per_iteration
1240        );
1241
1242        dataset
1243    }
1244
1245    /// Play a single self-play game and extract training positions
1246    fn play_single_game(&self, engine: &mut ChessVectorEngine) -> Vec<TrainingData> {
1247        let mut game = Game::new();
1248        let mut positions = Vec::new();
1249        let mut move_count = 0;
1250
1251        // Use opening book for variety if enabled
1252        if self.config.use_opening_book {
1253            if let Some(opening_moves) = self.get_random_opening() {
1254                for mv in opening_moves {
1255                    if game.make_move(mv) {
1256                        move_count += 1;
1257                    } else {
1258                        break;
1259                    }
1260                }
1261            }
1262        }
1263
1264        // Play the game
1265        while game.result().is_none() && move_count < self.config.max_moves_per_game {
1266            let current_position = game.current_position();
1267
1268            // Get engine's move recommendation with exploration
1269            let move_choice = self.select_move_with_exploration(engine, &current_position);
1270
1271            if let Some(chess_move) = move_choice {
1272                // Evaluate the position before making the move
1273                if let Some(evaluation) = engine.evaluate_position(&current_position) {
1274                    // Only include positions with sufficient confidence
1275                    if evaluation.abs() >= self.config.min_confidence || move_count < 10 {
1276                        positions.push(TrainingData {
1277                            board: current_position,
1278                            evaluation,
1279                            depth: 1, // Self-play depth
1280                            game_id: self.game_counter,
1281                        });
1282                    }
1283                }
1284
1285                // Make the move
1286                if !game.make_move(chess_move) {
1287                    break; // Invalid move, end game
1288                }
1289                move_count += 1;
1290            } else {
1291                break; // No legal moves
1292            }
1293        }
1294
1295        // Add final position evaluation based on game result
1296        if let Some(result) = game.result() {
1297            let final_position = game.current_position();
1298            let final_eval = match result {
1299                chess::GameResult::WhiteCheckmates => {
1300                    if final_position.side_to_move() == chess::Color::Black {
1301                        10.0
1302                    } else {
1303                        -10.0
1304                    }
1305                }
1306                chess::GameResult::BlackCheckmates => {
1307                    if final_position.side_to_move() == chess::Color::White {
1308                        10.0
1309                    } else {
1310                        -10.0
1311                    }
1312                }
1313                chess::GameResult::WhiteResigns => -10.0,
1314                chess::GameResult::BlackResigns => 10.0,
1315                chess::GameResult::Stalemate
1316                | chess::GameResult::DrawAccepted
1317                | chess::GameResult::DrawDeclared => 0.0,
1318            };
1319
1320            positions.push(TrainingData {
1321                board: final_position,
1322                evaluation: final_eval,
1323                depth: 1,
1324                game_id: self.game_counter,
1325            });
1326        }
1327
1328        positions
1329    }
1330
1331    /// Select a move with exploration vs exploitation balance
1332    fn select_move_with_exploration(
1333        &self,
1334        engine: &mut ChessVectorEngine,
1335        position: &Board,
1336    ) -> Option<ChessMove> {
1337        let recommendations = engine.recommend_legal_moves(position, 5);
1338
1339        if recommendations.is_empty() {
1340            return None;
1341        }
1342
1343        // Use temperature-based selection for exploration
1344        if fastrand::f32() < self.config.exploration_factor {
1345            // Exploration: weighted random selection based on evaluations
1346            self.select_move_with_temperature(&recommendations)
1347        } else {
1348            // Exploitation: take the best move
1349            Some(recommendations[0].chess_move)
1350        }
1351    }
1352
1353    /// Temperature-based move selection for exploration
1354    fn select_move_with_temperature(
1355        &self,
1356        recommendations: &[crate::MoveRecommendation],
1357    ) -> Option<ChessMove> {
1358        if recommendations.is_empty() {
1359            return None;
1360        }
1361
1362        // Convert evaluations to probabilities using temperature
1363        let mut probabilities = Vec::new();
1364        let mut sum = 0.0;
1365
1366        for rec in recommendations {
1367            // Use average_outcome as evaluation score for temperature selection
1368            let prob = (rec.average_outcome / self.config.temperature).exp();
1369            probabilities.push(prob);
1370            sum += prob;
1371        }
1372
1373        // Normalize probabilities
1374        for prob in &mut probabilities {
1375            *prob /= sum;
1376        }
1377
1378        // Random selection based on probabilities
1379        let rand_val = fastrand::f32();
1380        let mut cumulative = 0.0;
1381
1382        for (i, &prob) in probabilities.iter().enumerate() {
1383            cumulative += prob;
1384            if rand_val <= cumulative {
1385                return Some(recommendations[i].chess_move);
1386            }
1387        }
1388
1389        // Fallback to first move
1390        Some(recommendations[0].chess_move)
1391    }
1392
1393    /// Get random opening moves for variety
1394    fn get_random_opening(&self) -> Option<Vec<ChessMove>> {
1395        let openings = [
1396            // Italian Game
1397            vec!["e4", "e5", "Nf3", "Nc6", "Bc4"],
1398            // Ruy Lopez
1399            vec!["e4", "e5", "Nf3", "Nc6", "Bb5"],
1400            // Queen's Gambit
1401            vec!["d4", "d5", "c4"],
1402            // King's Indian Defense
1403            vec!["d4", "Nf6", "c4", "g6"],
1404            // Sicilian Defense
1405            vec!["e4", "c5"],
1406            // French Defense
1407            vec!["e4", "e6"],
1408            // Caro-Kann Defense
1409            vec!["e4", "c6"],
1410        ];
1411
1412        let selected_opening = &openings[fastrand::usize(0..openings.len())];
1413
1414        let mut moves = Vec::new();
1415        let mut game = Game::new();
1416
1417        for move_str in selected_opening {
1418            if let Ok(chess_move) = ChessMove::from_str(move_str) {
1419                if game.make_move(chess_move) {
1420                    moves.push(chess_move);
1421                } else {
1422                    break;
1423                }
1424            }
1425        }
1426
1427        if moves.is_empty() {
1428            None
1429        } else {
1430            Some(moves)
1431        }
1432    }
1433}
1434
1435/// Engine performance evaluator
1436pub struct EngineEvaluator {
1437    #[allow(dead_code)]
1438    stockfish_depth: u8,
1439}
1440
1441impl EngineEvaluator {
1442    pub fn new(stockfish_depth: u8) -> Self {
1443        Self { stockfish_depth }
1444    }
1445
1446    /// Compare engine evaluations against Stockfish on test set
1447    pub fn evaluate_accuracy(
1448        &self,
1449        engine: &mut ChessVectorEngine,
1450        test_data: &TrainingDataset,
1451    ) -> Result<f32, Box<dyn std::error::Error>> {
1452        let mut total_error = 0.0;
1453        let mut valid_comparisons = 0;
1454
1455        let pb = ProgressBar::new(test_data.data.len() as u64);
1456        pb.set_style(ProgressStyle::default_bar()
1457            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Evaluating accuracy")
1458            .unwrap()
1459            .progress_chars("#>-"));
1460
1461        for data in &test_data.data {
1462            if let Some(engine_eval) = engine.evaluate_position(&data.board) {
1463                let error = (engine_eval - data.evaluation).abs();
1464                total_error += error;
1465                valid_comparisons += 1;
1466            }
1467            pb.inc(1);
1468        }
1469
1470        pb.finish_with_message("Accuracy evaluation complete");
1471
1472        if valid_comparisons > 0 {
1473            let mean_absolute_error = total_error / valid_comparisons as f32;
1474            println!("Mean Absolute Error: {mean_absolute_error:.3} pawns");
1475            println!("Evaluated {valid_comparisons} positions");
1476            Ok(mean_absolute_error)
1477        } else {
1478            Ok(f32::INFINITY)
1479        }
1480    }
1481}
1482
1483// Make TrainingData serializable
1484impl serde::Serialize for TrainingData {
1485    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1486    where
1487        S: serde::Serializer,
1488    {
1489        use serde::ser::SerializeStruct;
1490        let mut state = serializer.serialize_struct("TrainingData", 4)?;
1491        state.serialize_field("fen", &self.board.to_string())?;
1492        state.serialize_field("evaluation", &self.evaluation)?;
1493        state.serialize_field("depth", &self.depth)?;
1494        state.serialize_field("game_id", &self.game_id)?;
1495        state.end()
1496    }
1497}
1498
1499impl<'de> serde::Deserialize<'de> for TrainingData {
1500    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1501    where
1502        D: serde::Deserializer<'de>,
1503    {
1504        use serde::de::{self, MapAccess, Visitor};
1505        use std::fmt;
1506
1507        struct TrainingDataVisitor;
1508
1509        impl<'de> Visitor<'de> for TrainingDataVisitor {
1510            type Value = TrainingData;
1511
1512            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1513                formatter.write_str("struct TrainingData")
1514            }
1515
1516            fn visit_map<V>(self, mut map: V) -> Result<TrainingData, V::Error>
1517            where
1518                V: MapAccess<'de>,
1519            {
1520                let mut fen = None;
1521                let mut evaluation = None;
1522                let mut depth = None;
1523                let mut game_id = None;
1524
1525                while let Some(key) = map.next_key()? {
1526                    match key {
1527                        "fen" => {
1528                            if fen.is_some() {
1529                                return Err(de::Error::duplicate_field("fen"));
1530                            }
1531                            fen = Some(map.next_value()?);
1532                        }
1533                        "evaluation" => {
1534                            if evaluation.is_some() {
1535                                return Err(de::Error::duplicate_field("evaluation"));
1536                            }
1537                            evaluation = Some(map.next_value()?);
1538                        }
1539                        "depth" => {
1540                            if depth.is_some() {
1541                                return Err(de::Error::duplicate_field("depth"));
1542                            }
1543                            depth = Some(map.next_value()?);
1544                        }
1545                        "game_id" => {
1546                            if game_id.is_some() {
1547                                return Err(de::Error::duplicate_field("game_id"));
1548                            }
1549                            game_id = Some(map.next_value()?);
1550                        }
1551                        _ => {
1552                            let _: serde_json::Value = map.next_value()?;
1553                        }
1554                    }
1555                }
1556
1557                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
1558                let mut evaluation: f32 =
1559                    evaluation.ok_or_else(|| de::Error::missing_field("evaluation"))?;
1560                let depth = depth.ok_or_else(|| de::Error::missing_field("depth"))?;
1561                let game_id = game_id.unwrap_or(0); // Default to 0 for backward compatibility
1562
1563                // Convert evaluation from centipawns to pawns if needed
1564                // If evaluation is outside typical pawn range (-10 to +10),
1565                // assume it's in centipawns and convert to pawns
1566                if evaluation.abs() > 15.0 {
1567                    evaluation /= 100.0;
1568                }
1569
1570                let board =
1571                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
1572
1573                Ok(TrainingData {
1574                    board,
1575                    evaluation,
1576                    depth,
1577                    game_id,
1578                })
1579            }
1580        }
1581
1582        const FIELDS: &[&str] = &["fen", "evaluation", "depth", "game_id"];
1583        deserializer.deserialize_struct("TrainingData", FIELDS, TrainingDataVisitor)
1584    }
1585}
1586
/// Tactical puzzle parser for Lichess puzzle database
///
/// A stateless unit struct: all functionality is provided by associated functions.
/// The entry point is `parse_csv`, which reads a puzzle CSV file and converts
/// matching rows into `TacticalTrainingData`.
pub struct TacticalPuzzleParser;
1589
1590impl TacticalPuzzleParser {
1591    /// Parse Lichess puzzle CSV file with parallel processing
1592    pub fn parse_csv<P: AsRef<Path>>(
1593        file_path: P,
1594        max_puzzles: Option<usize>,
1595        min_rating: Option<u32>,
1596        max_rating: Option<u32>,
1597    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1598        let file = File::open(&file_path)?;
1599        let file_size = file.metadata()?.len();
1600
1601        // For large files (>100MB), use parallel processing
1602        if file_size > 100_000_000 {
1603            Self::parse_csv_parallel(file_path, max_puzzles, min_rating, max_rating)
1604        } else {
1605            Self::parse_csv_sequential(file_path, max_puzzles, min_rating, max_rating)
1606        }
1607    }
1608
1609    /// Sequential CSV parsing for smaller files
1610    fn parse_csv_sequential<P: AsRef<Path>>(
1611        file_path: P,
1612        max_puzzles: Option<usize>,
1613        min_rating: Option<u32>,
1614        max_rating: Option<u32>,
1615    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1616        let file = File::open(file_path)?;
1617        let reader = BufReader::new(file);
1618
1619        // Create CSV reader without headers since Lichess CSV has no header row
1620        // Set flexible field count to handle inconsistent CSV structure
1621        let mut csv_reader = csv::ReaderBuilder::new()
1622            .has_headers(false)
1623            .flexible(true) // Allow variable number of fields
1624            .from_reader(reader);
1625
1626        let mut tactical_data = Vec::new();
1627        let mut processed = 0;
1628        let mut skipped = 0;
1629
1630        let pb = ProgressBar::new_spinner();
1631        pb.set_style(
1632            ProgressStyle::default_spinner()
1633                .template("{spinner:.green} Parsing tactical puzzles: {pos} (skipped: {skipped})")
1634                .unwrap(),
1635        );
1636
1637        for result in csv_reader.records() {
1638            let record = match result {
1639                Ok(r) => r,
1640                Err(e) => {
1641                    skipped += 1;
1642                    println!("CSV parsing error: {e}");
1643                    continue;
1644                }
1645            };
1646
1647            if let Some(puzzle_data) = Self::parse_csv_record(&record, min_rating, max_rating) {
1648                if let Some(tactical_data_item) =
1649                    Self::convert_puzzle_to_training_data(&puzzle_data)
1650                {
1651                    tactical_data.push(tactical_data_item);
1652                    processed += 1;
1653
1654                    if let Some(max) = max_puzzles {
1655                        if processed >= max {
1656                            break;
1657                        }
1658                    }
1659                } else {
1660                    skipped += 1;
1661                }
1662            } else {
1663                skipped += 1;
1664            }
1665
1666            pb.set_message(format!(
1667                "Parsing tactical puzzles: {processed} (skipped: {skipped})"
1668            ));
1669        }
1670
1671        pb.finish_with_message(format!(
1672            "Parsed {processed} puzzles (skipped: {skipped})"
1673        ));
1674
1675        Ok(tactical_data)
1676    }
1677
1678    /// Parallel CSV parsing for large files
1679    fn parse_csv_parallel<P: AsRef<Path>>(
1680        file_path: P,
1681        max_puzzles: Option<usize>,
1682        min_rating: Option<u32>,
1683        max_rating: Option<u32>,
1684    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1685        use std::io::Read;
1686
1687        let mut file = File::open(&file_path)?;
1688
1689        // Read entire file into memory for parallel processing
1690        let mut contents = String::new();
1691        file.read_to_string(&mut contents)?;
1692
1693        // Split into lines for parallel processing
1694        let lines: Vec<&str> = contents.lines().collect();
1695
1696        let pb = ProgressBar::new(lines.len() as u64);
1697        pb.set_style(ProgressStyle::default_bar()
1698            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel CSV parsing")
1699            .unwrap()
1700            .progress_chars("#>-"));
1701
1702        // Process lines in parallel
1703        let tactical_data: Vec<TacticalTrainingData> = lines
1704            .par_iter()
1705            .take(max_puzzles.unwrap_or(usize::MAX))
1706            .filter_map(|line| {
1707                // Parse CSV line manually
1708                let fields: Vec<&str> = line.split(',').collect();
1709                if fields.len() < 8 {
1710                    return None;
1711                }
1712
1713                // Build puzzle from fields
1714                if let Some(puzzle_data) = Self::parse_csv_fields(&fields, min_rating, max_rating) {
1715                    Self::convert_puzzle_to_training_data(&puzzle_data)
1716                } else {
1717                    None
1718                }
1719            })
1720            .collect();
1721
1722        pb.finish_with_message(format!(
1723            "Parallel parsing complete: {} puzzles",
1724            tactical_data.len()
1725        ));
1726
1727        Ok(tactical_data)
1728    }
1729
1730    /// Parse CSV record into TacticalPuzzle
1731    fn parse_csv_record(
1732        record: &csv::StringRecord,
1733        min_rating: Option<u32>,
1734        max_rating: Option<u32>,
1735    ) -> Option<TacticalPuzzle> {
1736        // Need at least 8 fields (PuzzleId, FEN, Moves, Rating, RatingDeviation, Popularity, NbPlays, Themes)
1737        if record.len() < 8 {
1738            return None;
1739        }
1740
1741        let rating: u32 = record[3].parse().ok()?;
1742        let rating_deviation: u32 = record[4].parse().ok()?;
1743        let popularity: i32 = record[5].parse().ok()?;
1744        let nb_plays: u32 = record[6].parse().ok()?;
1745
1746        // Apply rating filters
1747        if let Some(min) = min_rating {
1748            if rating < min {
1749                return None;
1750            }
1751        }
1752        if let Some(max) = max_rating {
1753            if rating > max {
1754                return None;
1755            }
1756        }
1757
1758        Some(TacticalPuzzle {
1759            puzzle_id: record[0].to_string(),
1760            fen: record[1].to_string(),
1761            moves: record[2].to_string(),
1762            rating,
1763            rating_deviation,
1764            popularity,
1765            nb_plays,
1766            themes: record[7].to_string(),
1767            game_url: if record.len() > 8 {
1768                Some(record[8].to_string())
1769            } else {
1770                None
1771            },
1772            opening_tags: if record.len() > 9 {
1773                Some(record[9].to_string())
1774            } else {
1775                None
1776            },
1777        })
1778    }
1779
1780    /// Parse CSV fields into TacticalPuzzle (for parallel processing)
1781    fn parse_csv_fields(
1782        fields: &[&str],
1783        min_rating: Option<u32>,
1784        max_rating: Option<u32>,
1785    ) -> Option<TacticalPuzzle> {
1786        if fields.len() < 8 {
1787            return None;
1788        }
1789
1790        let rating: u32 = fields[3].parse().ok()?;
1791        let rating_deviation: u32 = fields[4].parse().ok()?;
1792        let popularity: i32 = fields[5].parse().ok()?;
1793        let nb_plays: u32 = fields[6].parse().ok()?;
1794
1795        // Apply rating filters
1796        if let Some(min) = min_rating {
1797            if rating < min {
1798                return None;
1799            }
1800        }
1801        if let Some(max) = max_rating {
1802            if rating > max {
1803                return None;
1804            }
1805        }
1806
1807        Some(TacticalPuzzle {
1808            puzzle_id: fields[0].to_string(),
1809            fen: fields[1].to_string(),
1810            moves: fields[2].to_string(),
1811            rating,
1812            rating_deviation,
1813            popularity,
1814            nb_plays,
1815            themes: fields[7].to_string(),
1816            game_url: if fields.len() > 8 {
1817                Some(fields[8].to_string())
1818            } else {
1819                None
1820            },
1821            opening_tags: if fields.len() > 9 {
1822                Some(fields[9].to_string())
1823            } else {
1824                None
1825            },
1826        })
1827    }
1828
1829    /// Convert puzzle to training data
1830    fn convert_puzzle_to_training_data(puzzle: &TacticalPuzzle) -> Option<TacticalTrainingData> {
1831        // Parse FEN position
1832        let position = match Board::from_str(&puzzle.fen) {
1833            Ok(board) => board,
1834            Err(_) => return None,
1835        };
1836
1837        // Parse move sequence - first move is the solution
1838        let moves: Vec<&str> = puzzle.moves.split_whitespace().collect();
1839        if moves.is_empty() {
1840            return None;
1841        }
1842
1843        // Parse the solution move (first move in sequence)
1844        let solution_move = match ChessMove::from_str(moves[0]) {
1845            Ok(mv) => mv,
1846            Err(_) => {
1847                // Try parsing as SAN
1848                match ChessMove::from_san(&position, moves[0]) {
1849                    Ok(mv) => mv,
1850                    Err(_) => return None,
1851                }
1852            }
1853        };
1854
1855        // Verify move is legal
1856        let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&position).collect();
1857        if !legal_moves.contains(&solution_move) {
1858            return None;
1859        }
1860
1861        // Extract primary theme
1862        let themes: Vec<&str> = puzzle.themes.split_whitespace().collect();
1863        let primary_theme = themes.first().unwrap_or(&"tactical").to_string();
1864
1865        // Calculate tactical value based on rating and popularity
1866        let difficulty = puzzle.rating as f32 / 1000.0; // Normalize to 0.8-3.0 range
1867        let popularity_bonus = (puzzle.popularity as f32 / 100.0).min(2.0);
1868        let tactical_value = difficulty + popularity_bonus; // High value for move outcome
1869
1870        Some(TacticalTrainingData {
1871            position,
1872            solution_move,
1873            move_theme: primary_theme,
1874            difficulty,
1875            tactical_value,
1876        })
1877    }
1878
1879    /// Load tactical training data into chess engine
1880    pub fn load_into_engine(
1881        tactical_data: &[TacticalTrainingData],
1882        engine: &mut ChessVectorEngine,
1883    ) {
1884        let pb = ProgressBar::new(tactical_data.len() as u64);
1885        pb.set_style(ProgressStyle::default_bar()
1886            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns")
1887            .unwrap()
1888            .progress_chars("#>-"));
1889
1890        for data in tactical_data {
1891            // Add position with high-value tactical move
1892            engine.add_position_with_move(
1893                &data.position,
1894                0.0, // Position evaluation (neutral for puzzles)
1895                Some(data.solution_move),
1896                Some(data.tactical_value), // High tactical value
1897            );
1898            pb.inc(1);
1899        }
1900
1901        pb.finish_with_message(format!("Loaded {} tactical patterns", tactical_data.len()));
1902    }
1903
1904    /// Load tactical training data into chess engine incrementally (preserves existing data)
1905    pub fn load_into_engine_incremental(
1906        tactical_data: &[TacticalTrainingData],
1907        engine: &mut ChessVectorEngine,
1908    ) {
1909        let initial_size = engine.knowledge_base_size();
1910        let initial_moves = engine.position_moves.len();
1911
1912        // For large datasets, use parallel batch processing
1913        if tactical_data.len() > 1000 {
1914            Self::load_into_engine_incremental_parallel(
1915                tactical_data,
1916                engine,
1917                initial_size,
1918                initial_moves,
1919            );
1920        } else {
1921            Self::load_into_engine_incremental_sequential(
1922                tactical_data,
1923                engine,
1924                initial_size,
1925                initial_moves,
1926            );
1927        }
1928    }
1929
1930    /// Sequential loading for smaller datasets
1931    fn load_into_engine_incremental_sequential(
1932        tactical_data: &[TacticalTrainingData],
1933        engine: &mut ChessVectorEngine,
1934        initial_size: usize,
1935        initial_moves: usize,
1936    ) {
1937        let pb = ProgressBar::new(tactical_data.len() as u64);
1938        pb.set_style(ProgressStyle::default_bar()
1939            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns (incremental)")
1940            .unwrap()
1941            .progress_chars("#>-"));
1942
1943        let mut added = 0;
1944        let mut skipped = 0;
1945
1946        for data in tactical_data {
1947            // Check if this position already exists to avoid duplicates
1948            if !engine.position_boards.contains(&data.position) {
1949                engine.add_position_with_move(
1950                    &data.position,
1951                    0.0, // Position evaluation (neutral for puzzles)
1952                    Some(data.solution_move),
1953                    Some(data.tactical_value), // High tactical value
1954                );
1955                added += 1;
1956            } else {
1957                skipped += 1;
1958            }
1959            pb.inc(1);
1960        }
1961
1962        pb.finish_with_message(format!(
1963            "Loaded {} new tactical patterns (skipped {} duplicates, total: {})",
1964            added,
1965            skipped,
1966            engine.knowledge_base_size()
1967        ));
1968
1969        println!("Incremental tactical training:");
1970        println!(
1971            "  - Positions: {} → {} (+{})",
1972            initial_size,
1973            engine.knowledge_base_size(),
1974            engine.knowledge_base_size() - initial_size
1975        );
1976        println!(
1977            "  - Move entries: {} → {} (+{})",
1978            initial_moves,
1979            engine.position_moves.len(),
1980            engine.position_moves.len() - initial_moves
1981        );
1982    }
1983
1984    /// Parallel batch loading for large datasets
1985    fn load_into_engine_incremental_parallel(
1986        tactical_data: &[TacticalTrainingData],
1987        engine: &mut ChessVectorEngine,
1988        initial_size: usize,
1989        initial_moves: usize,
1990    ) {
1991        let pb = ProgressBar::new(tactical_data.len() as u64);
1992        pb.set_style(ProgressStyle::default_bar()
1993            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Optimized batch loading tactical patterns")
1994            .unwrap()
1995            .progress_chars("#>-"));
1996
1997        // Parallel pre-filtering to avoid duplicates (this is thread-safe for read operations)
1998        let filtered_data: Vec<&TacticalTrainingData> = tactical_data
1999            .par_iter()
2000            .filter(|data| !engine.position_boards.contains(&data.position))
2001            .collect();
2002
2003        let batch_size = 1000; // Larger batches for better performance
2004        let mut added = 0;
2005
2006        println!(
2007            "Pre-filtered: {} → {} positions (removed {} duplicates)",
2008            tactical_data.len(),
2009            filtered_data.len(),
2010            tactical_data.len() - filtered_data.len()
2011        );
2012
2013        // Process in sequential batches (engine operations aren't thread-safe)
2014        // But use optimized batch processing
2015        for batch in filtered_data.chunks(batch_size) {
2016            let batch_start = added;
2017
2018            for data in batch {
2019                // Final duplicate check (should be minimal after pre-filtering)
2020                if !engine.position_boards.contains(&data.position) {
2021                    engine.add_position_with_move(
2022                        &data.position,
2023                        0.0, // Position evaluation (neutral for puzzles)
2024                        Some(data.solution_move),
2025                        Some(data.tactical_value), // High tactical value
2026                    );
2027                    added += 1;
2028                }
2029                pb.inc(1);
2030            }
2031
2032            // Update progress message every batch
2033            pb.set_message(format!("Loaded batch: {} positions", added - batch_start));
2034        }
2035
2036        let skipped = tactical_data.len() - added;
2037
2038        pb.finish_with_message(format!(
2039            "Optimized loaded {} new tactical patterns (skipped {} duplicates, total: {})",
2040            added,
2041            skipped,
2042            engine.knowledge_base_size()
2043        ));
2044
2045        println!("Incremental tactical training (optimized):");
2046        println!(
2047            "  - Positions: {} → {} (+{})",
2048            initial_size,
2049            engine.knowledge_base_size(),
2050            engine.knowledge_base_size() - initial_size
2051        );
2052        println!(
2053            "  - Move entries: {} → {} (+{})",
2054            initial_moves,
2055            engine.position_moves.len(),
2056            engine.position_moves.len() - initial_moves
2057        );
2058        println!(
2059            "  - Batch size: {}, Pre-filtered efficiency: {:.1}%",
2060            batch_size,
2061            (filtered_data.len() as f32 / tactical_data.len() as f32) * 100.0
2062        );
2063    }
2064
2065    /// Save tactical puzzles to file for incremental loading later
2066    pub fn save_tactical_puzzles<P: AsRef<std::path::Path>>(
2067        tactical_data: &[TacticalTrainingData],
2068        path: P,
2069    ) -> Result<(), Box<dyn std::error::Error>> {
2070        let json = serde_json::to_string_pretty(tactical_data)?;
2071        std::fs::write(path, json)?;
2072        println!("Saved {} tactical puzzles", tactical_data.len());
2073        Ok(())
2074    }
2075
2076    /// Load tactical puzzles from file
2077    pub fn load_tactical_puzzles<P: AsRef<std::path::Path>>(
2078        path: P,
2079    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
2080        let content = std::fs::read_to_string(path)?;
2081        let tactical_data: Vec<TacticalTrainingData> = serde_json::from_str(&content)?;
2082        println!("Loaded {} tactical puzzles from file", tactical_data.len());
2083        Ok(tactical_data)
2084    }
2085
2086    /// Save tactical puzzles incrementally (appends to existing file)
2087    pub fn save_tactical_puzzles_incremental<P: AsRef<std::path::Path>>(
2088        tactical_data: &[TacticalTrainingData],
2089        path: P,
2090    ) -> Result<(), Box<dyn std::error::Error>> {
2091        let path = path.as_ref();
2092
2093        if path.exists() {
2094            // Load existing puzzles
2095            let mut existing = Self::load_tactical_puzzles(path)?;
2096            let original_len = existing.len();
2097
2098            // Add new puzzles, checking for duplicates by puzzle ID if available
2099            for new_puzzle in tactical_data {
2100                // Check if this puzzle already exists (by position)
2101                let exists = existing.iter().any(|existing_puzzle| {
2102                    existing_puzzle.position == new_puzzle.position
2103                        && existing_puzzle.solution_move == new_puzzle.solution_move
2104                });
2105
2106                if !exists {
2107                    existing.push(new_puzzle.clone());
2108                }
2109            }
2110
2111            // Save merged data
2112            let json = serde_json::to_string_pretty(&existing)?;
2113            std::fs::write(path, json)?;
2114
2115            println!(
2116                "Incremental save: added {} new puzzles (total: {})",
2117                existing.len() - original_len,
2118                existing.len()
2119            );
2120        } else {
2121            // File doesn't exist, just save normally
2122            Self::save_tactical_puzzles(tactical_data, path)?;
2123        }
2124        Ok(())
2125    }
2126
2127    /// Parse Lichess puzzles incrementally (preserves existing engine state)
2128    pub fn parse_and_load_incremental<P: AsRef<std::path::Path>>(
2129        file_path: P,
2130        engine: &mut ChessVectorEngine,
2131        max_puzzles: Option<usize>,
2132        min_rating: Option<u32>,
2133        max_rating: Option<u32>,
2134    ) -> Result<(), Box<dyn std::error::Error>> {
2135        println!("Parsing Lichess puzzles incrementally...");
2136
2137        // Parse puzzles
2138        let tactical_data = Self::parse_csv(file_path, max_puzzles, min_rating, max_rating)?;
2139
2140        // Load into engine incrementally
2141        Self::load_into_engine_incremental(&tactical_data, engine);
2142
2143        Ok(())
2144    }
2145}
2146
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    /// Position after 1. e4 (black to move).
    const AFTER_E4: &str = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1";
    /// Position after 1. e4 e5 (white to move).
    const AFTER_E4_E5: &str = "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2";

    /// Parse a FEN the test knows to be valid.
    fn parse_fen(fen: &str) -> Board {
        Board::from_str(fen).unwrap()
    }

    /// Build a `TrainingData` record.
    fn training_point(board: Board, evaluation: f32, depth: u8, game_id: usize) -> TrainingData {
        TrainingData {
            board,
            evaluation,
            depth,
            game_id,
        }
    }

    /// Build a `TacticalTrainingData` record from a UCI move string.
    fn tactical_example(
        position: Board,
        uci: &str,
        theme: &str,
        difficulty: f32,
        tactical_value: f32,
    ) -> TacticalTrainingData {
        TacticalTrainingData {
            position,
            solution_move: ChessMove::from_str(uci).unwrap(),
            move_theme: theme.to_string(),
            difficulty,
            tactical_value,
        }
    }

    /// Build a `TacticalPuzzle` with fixed rating/popularity metadata.
    fn puzzle_fixture(fen: &str, moves: &str, themes: &str) -> TacticalPuzzle {
        TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: fen.to_string(),
            moves: moves.to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: themes.to_string(),
            game_url: None,
            opening_tags: None,
        }
    }

    #[test]
    fn test_training_dataset_creation() {
        assert!(TrainingDataset::new().data.is_empty());
    }

    #[test]
    fn test_add_training_data() {
        let mut dataset = TrainingDataset::new();
        dataset.data.push(training_point(Board::default(), 0.5, 15, 1));

        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
    }

    #[test]
    fn test_chess_engine_integration() {
        let board = Board::default();
        let mut dataset = TrainingDataset::new();
        dataset.data.push(training_point(board, 0.3, 15, 1));

        let mut engine = ChessVectorEngine::new(1024);
        dataset.train_engine(&mut engine);
        assert_eq!(engine.knowledge_base_size(), 1);

        let eval = engine
            .evaluate_position(&board)
            .expect("a trained position should be evaluable");
        assert!((eval - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_deduplication() {
        let board = Board::default();
        let mut dataset = TrainingDataset::new();

        // Five copies of the same position with different evaluations.
        dataset
            .data
            .extend((0..5).map(|game_id| training_point(board, game_id as f32 * 0.1, 15, game_id)));
        assert_eq!(dataset.data.len(), 5);

        // A near-1.0 similarity threshold collapses identical boards to one entry.
        dataset.deduplicate(0.999);
        assert_eq!(dataset.data.len(), 1);
    }

    #[test]
    fn test_dataset_serialization() {
        let mut dataset = TrainingDataset::new();
        dataset.data.push(training_point(parse_fen(AFTER_E4), 0.2, 10, 42));

        // Round-trip through JSON.
        let json = serde_json::to_string(&dataset.data).unwrap();
        let loaded_dataset = TrainingDataset {
            data: serde_json::from_str::<Vec<TrainingData>>(&json).unwrap(),
        };

        assert_eq!(loaded_dataset.data.len(), 1);
        let record = &loaded_dataset.data[0];
        assert_eq!(record.evaluation, 0.2);
        assert_eq!(record.depth, 10);
        assert_eq!(record.game_id, 42);
    }

    #[test]
    fn test_tactical_puzzle_processing() {
        let puzzle = puzzle_fixture(
            "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4",
            "Bxf7+ Ke7",
            "fork pin",
        );

        let data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle)
            .expect("a valid puzzle should convert");
        assert_eq!(data.move_theme, "fork");
        assert!(data.tactical_value > 1.0);
        assert!(data.difficulty > 0.0);
    }

    #[test]
    fn test_tactical_puzzle_invalid_fen() {
        let puzzle = puzzle_fixture("invalid_fen", "e2e4", "tactics");
        assert!(TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle).is_none());
    }

    #[test]
    fn test_engine_evaluator() {
        let evaluator = EngineEvaluator::new(15);
        let board = Board::default();

        let mut dataset = TrainingDataset::new();
        dataset.data.push(training_point(board, 0.0, 15, 1));

        // Engine knows the same board with a slightly different evaluation.
        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&board, 0.1);

        let accuracy = evaluator.evaluate_accuracy(&mut engine, &dataset);
        assert!(accuracy.is_ok());
        // NOTE(review): only checks the reported metric stays below 1.0 for a 0.1 gap.
        assert!(accuracy.unwrap() < 1.0);
    }

    #[test]
    fn test_tactical_training_integration() {
        let tactical_data = vec![tactical_example(Board::default(), "e2e4", "opening", 1.2, 2.5)];

        let mut engine = ChessVectorEngine::new(1024);
        TacticalPuzzleParser::load_into_engine(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);
        assert_eq!(engine.position_moves.len(), 1);

        // The loaded tactical move should surface as a recommendation.
        let recommendations = engine.recommend_moves(&Board::default(), 5);
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_multithreading_operations() {
        let board = Board::default();
        let mut dataset = TrainingDataset::new();
        for game_id in 0..10 {
            dataset
                .data
                .push(training_point(board, game_id as f32 * 0.1, 15, game_id));
        }

        // Smoke test: parallel deduplication must not crash or grow the dataset.
        dataset.deduplicate_parallel(0.95, 5);
        assert!(dataset.data.len() <= 10);
    }

    #[test]
    fn test_incremental_dataset_operations() {
        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.add_position(parse_fen(AFTER_E4), 0.2, 15, 2);
        assert_eq!(dataset1.data.len(), 2);

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(parse_fen(AFTER_E4_E5), 0.3, 15, 3);

        dataset1.merge(dataset2);
        assert_eq!(dataset1.data.len(), 3);

        // One past the largest game id present (1, 2, 3).
        assert_eq!(dataset1.next_game_id(), 4);
    }

    #[test]
    fn test_save_load_incremental() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_test.json");

        let mut first = TrainingDataset::new();
        first.add_position(Board::default(), 0.0, 15, 1);
        first.save(&file_path).unwrap();

        let mut second = TrainingDataset::new();
        second.add_position(parse_fen(AFTER_E4), 0.2, 15, 2);
        second.save_incremental(&file_path).unwrap();

        // Both positions are now on disk.
        assert_eq!(TrainingDataset::load(&file_path).unwrap().data.len(), 2);

        let mut third = TrainingDataset::new();
        third.add_position(parse_fen(AFTER_E4_E5), 0.3, 15, 3);
        third.load_and_append(&file_path).unwrap();
        assert_eq!(third.data.len(), 3); // 1 original + 2 from file
    }

    #[test]
    fn test_add_position_method() {
        let mut dataset = TrainingDataset::new();
        dataset.add_position(Board::default(), 0.5, 20, 42);

        assert_eq!(dataset.data.len(), 1);
        let record = &dataset.data[0];
        assert_eq!(record.evaluation, 0.5);
        assert_eq!(record.depth, 20);
        assert_eq!(record.game_id, 42);
    }

    #[test]
    fn test_incremental_save_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_test.json");

        let mut first = TrainingDataset::new();
        first.add_position(Board::default(), 0.0, 15, 1);
        first.save(&file_path).unwrap();

        // Same position, different evaluation.
        let mut second = TrainingDataset::new();
        second.add_position(Board::default(), 0.1, 15, 2);
        second.save_incremental(&file_path).unwrap();

        assert_eq!(TrainingDataset::load(&file_path).unwrap().data.len(), 1);
    }

    #[test]
    fn test_tactical_puzzle_incremental_loading() {
        let tactical_data = vec![
            tactical_example(Board::default(), "e2e4", "opening", 1.2, 2.5),
            tactical_example(parse_fen(AFTER_E4), "e7e5", "opening", 1.0, 2.0),
        ];

        // Engine already knows the starting position.
        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&Board::default(), 0.1);
        assert_eq!(engine.knowledge_base_size(), 1);

        TacticalPuzzleParser::load_into_engine_incremental(&tactical_data, &mut engine);

        // The starting position is a duplicate and is skipped; only 1. e4 is added.
        assert_eq!(engine.knowledge_base_size(), 2);

        // Move data exists (from the newly added puzzle).
        let stats = engine.training_stats();
        assert!(stats.has_move_data);
        assert!(stats.move_data_entries > 0);
    }

    #[test]
    fn test_tactical_puzzle_serialization() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("tactical_test.json");

        let tactical_data = vec![tactical_example(Board::default(), "e2e4", "fork", 1.5, 3.0)];
        TacticalPuzzleParser::save_tactical_puzzles(&tactical_data, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].move_theme, "fork");
        assert_eq!(loaded[0].difficulty, 1.5);
        assert_eq!(loaded[0].tactical_value, 3.0);
    }

    #[test]
    fn test_tactical_puzzle_incremental_save() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_tactical.json");

        let batch1 = vec![tactical_example(Board::default(), "e2e4", "opening", 1.0, 2.0)];
        TacticalPuzzleParser::save_tactical_puzzles(&batch1, &file_path).unwrap();

        let batch2 = vec![tactical_example(parse_fen(AFTER_E4), "e7e5", "counter", 1.2, 2.2)];
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&batch2, &file_path).unwrap();

        // Both batches end up in the file.
        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 2);
    }

    #[test]
    fn test_tactical_puzzle_incremental_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_tactical.json");

        let tactical_data = tactical_example(Board::default(), "e2e4", "opening", 1.0, 2.0);

        TacticalPuzzleParser::save_tactical_puzzles(&[tactical_data.clone()], &file_path).unwrap();
        // Saving the identical puzzle again must not create a second entry.
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&[tactical_data], &file_path)
            .unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
    }
}