chess_vector_engine/
training.rs

1use chess::{Board, ChessMove, Game, MoveGen};
2use indicatif::{ProgressBar, ProgressStyle};
3use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::{BufRead, BufReader, BufWriter, Write};
8use std::path::Path;
9use std::process::{Child, Command, Stdio};
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12
13use crate::ChessVectorEngine;
14
/// Tunable parameters that control how self-play training games are generated.
#[derive(Debug, Clone)]
pub struct SelfPlayConfig {
    /// How many games are played in a single training iteration
    pub games_per_iteration: usize,
    /// Upper bound on the number of moves in one game (guards against endless games)
    pub max_moves_per_game: usize,
    /// Exploration factor for move selection (0.0 = greedy, 1.0 = random)
    pub exploration_factor: f32,
    /// Lowest evaluation confidence a position needs in order to be kept
    pub min_confidence: f32,
    /// Whether games start from the opening book
    pub use_opening_book: bool,
    /// Move-selection temperature (larger values give more random play)
    pub temperature: f32,
}

impl Default for SelfPlayConfig {
    /// Baseline configuration: 100 games per iteration, at most 200 moves per game,
    /// opening book enabled.
    fn default() -> Self {
        // Volume and length limits
        let games_per_iteration = 100;
        let max_moves_per_game = 200;

        // Randomness of move selection
        let exploration_factor = 0.3;
        let temperature = 0.8;

        // Position filtering and game starts
        let min_confidence = 0.1;
        let use_opening_book = true;

        SelfPlayConfig {
            games_per_iteration,
            max_moves_per_game,
            exploration_factor,
            min_confidence,
            use_opening_book,
            temperature,
        }
    }
}
44
/// Training data point containing a position and its evaluation
#[derive(Debug, Clone)]
pub struct TrainingData {
    /// Board position this sample describes
    pub board: Board,
    /// Evaluation attached to `board`; positions extracted from PGN start at 0.0 and the
    /// Stockfish evaluators in this file overwrite it with a score in pawns
    pub evaluation: f32,
    /// Search depth recorded with the evaluation (0 for positions not yet evaluated)
    pub depth: u8,
    /// Identifier of the game the position was taken from; `TrainingDataset::split`
    /// groups positions by this value
    pub game_id: usize,
}
53
/// Tactical puzzle data from Lichess puzzle database
///
/// Every field is (de)serialized under the capitalised name given in its
/// `serde(rename)` attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TacticalPuzzle {
    /// Puzzle identifier (serialized as `PuzzleId`)
    #[serde(rename = "PuzzleId")]
    pub puzzle_id: String,
    /// Position text (serialized as `FEN`)
    #[serde(rename = "FEN")]
    pub fen: String,
    /// Moves of the puzzle (serialized as `Moves`)
    #[serde(rename = "Moves")]
    pub moves: String, // Space-separated move sequence
    /// Puzzle rating (serialized as `Rating`)
    #[serde(rename = "Rating")]
    pub rating: u32,
    /// Deviation of the rating (serialized as `RatingDeviation`)
    #[serde(rename = "RatingDeviation")]
    pub rating_deviation: u32,
    /// Popularity score; signed, so it may be negative (serialized as `Popularity`)
    #[serde(rename = "Popularity")]
    pub popularity: i32,
    /// Play counter (serialized as `NbPlays`)
    #[serde(rename = "NbPlays")]
    pub nb_plays: u32,
    /// Theme tags (serialized as `Themes`)
    #[serde(rename = "Themes")]
    pub themes: String, // Space-separated themes
    /// Optional URL of the source game (serialized as `GameUrl`)
    #[serde(rename = "GameUrl")]
    pub game_url: Option<String>,
    /// Optional opening tags (serialized as `OpeningTags`)
    #[serde(rename = "OpeningTags")]
    pub opening_tags: Option<String>,
}
78
/// Processed tactical training data
///
/// Serde support is implemented by hand in this file: the board and the move are
/// stored in their textual forms under the keys `fen` and `solution_move`.
#[derive(Debug, Clone)]
pub struct TacticalTrainingData {
    /// Position the solution move is played from
    pub position: Board,
    /// Move that solves the position
    pub solution_move: ChessMove,
    /// Theme label attached to the move
    pub move_theme: String,
    pub difficulty: f32,     // Rating as difficulty
    pub tactical_value: f32, // High value for move outcome
}
88
89// Make TacticalTrainingData serializable
90impl serde::Serialize for TacticalTrainingData {
91    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92    where
93        S: serde::Serializer,
94    {
95        use serde::ser::SerializeStruct;
96        let mut state = serializer.serialize_struct("TacticalTrainingData", 5)?;
97        state.serialize_field("fen", &self.position.to_string())?;
98        state.serialize_field("solution_move", &self.solution_move.to_string())?;
99        state.serialize_field("move_theme", &self.move_theme)?;
100        state.serialize_field("difficulty", &self.difficulty)?;
101        state.serialize_field("tactical_value", &self.tactical_value)?;
102        state.end()
103    }
104}
105
106impl<'de> serde::Deserialize<'de> for TacticalTrainingData {
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        use serde::de::{self, MapAccess, Visitor};
112        use std::fmt;
113
114        struct TacticalTrainingDataVisitor;
115
116        impl<'de> Visitor<'de> for TacticalTrainingDataVisitor {
117            type Value = TacticalTrainingData;
118
119            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
120                formatter.write_str("struct TacticalTrainingData")
121            }
122
123            fn visit_map<V>(self, mut map: V) -> Result<TacticalTrainingData, V::Error>
124            where
125                V: MapAccess<'de>,
126            {
127                let mut fen = None;
128                let mut solution_move = None;
129                let mut move_theme = None;
130                let mut difficulty = None;
131                let mut tactical_value = None;
132
133                while let Some(key) = map.next_key()? {
134                    match key {
135                        "fen" => {
136                            if fen.is_some() {
137                                return Err(de::Error::duplicate_field("fen"));
138                            }
139                            fen = Some(map.next_value()?);
140                        }
141                        "solution_move" => {
142                            if solution_move.is_some() {
143                                return Err(de::Error::duplicate_field("solution_move"));
144                            }
145                            solution_move = Some(map.next_value()?);
146                        }
147                        "move_theme" => {
148                            if move_theme.is_some() {
149                                return Err(de::Error::duplicate_field("move_theme"));
150                            }
151                            move_theme = Some(map.next_value()?);
152                        }
153                        "difficulty" => {
154                            if difficulty.is_some() {
155                                return Err(de::Error::duplicate_field("difficulty"));
156                            }
157                            difficulty = Some(map.next_value()?);
158                        }
159                        "tactical_value" => {
160                            if tactical_value.is_some() {
161                                return Err(de::Error::duplicate_field("tactical_value"));
162                            }
163                            tactical_value = Some(map.next_value()?);
164                        }
165                        _ => {
166                            let _: serde_json::Value = map.next_value()?;
167                        }
168                    }
169                }
170
171                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
172                let solution_move_str: String =
173                    solution_move.ok_or_else(|| de::Error::missing_field("solution_move"))?;
174                let move_theme =
175                    move_theme.ok_or_else(|| de::Error::missing_field("move_theme"))?;
176                let difficulty =
177                    difficulty.ok_or_else(|| de::Error::missing_field("difficulty"))?;
178                let tactical_value =
179                    tactical_value.ok_or_else(|| de::Error::missing_field("tactical_value"))?;
180
181                let position =
182                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
183
184                let solution_move = ChessMove::from_str(&solution_move_str)
185                    .map_err(|e| de::Error::custom(format!("Error: {e}")))?;
186
187                Ok(TacticalTrainingData {
188                    position,
189                    solution_move,
190                    move_theme,
191                    difficulty,
192                    tactical_value,
193                })
194            }
195        }
196
197        const FIELDS: &[&str] = &[
198            "fen",
199            "solution_move",
200            "move_theme",
201            "difficulty",
202            "tactical_value",
203        ];
204        deserializer.deserialize_struct("TacticalTrainingData", FIELDS, TacticalTrainingDataVisitor)
205    }
206}
207
/// PGN game visitor for extracting positions
pub struct GameExtractor {
    /// Positions collected so far, across all games visited
    pub positions: Vec<TrainingData>,
    /// Game currently being replayed; replaced with a fresh game in `begin_game`
    pub current_game: Game,
    /// Number of moves applied to `current_game`
    pub move_count: usize,
    /// Once `move_count` reaches this value, further moves of the game are ignored
    pub max_moves_per_game: usize,
    /// Counter incremented at the start of every game and stored on each extracted position
    pub game_id: usize,
}
216
217impl GameExtractor {
218    pub fn new(max_moves_per_game: usize) -> Self {
219        Self {
220            positions: Vec::new(),
221            current_game: Game::new(),
222            move_count: 0,
223            max_moves_per_game,
224            game_id: 0,
225        }
226    }
227}
228
229impl Visitor for GameExtractor {
230    type Result = ();
231
232    fn begin_game(&mut self) {
233        self.current_game = Game::new();
234        self.move_count = 0;
235        self.game_id += 1;
236    }
237
238    fn header(&mut self, _key: &[u8], _value: RawHeader<'_>) {}
239
240    fn san(&mut self, san_plus: SanPlus) {
241        if self.move_count >= self.max_moves_per_game {
242            return;
243        }
244
245        let san_str = san_plus.san.to_string();
246
247        // First validate that we have a legal position to work with
248        let current_pos = self.current_game.current_position();
249
250        // Try to parse and make the move
251        match chess::ChessMove::from_san(&current_pos, &san_str) {
252            Ok(chess_move) => {
253                // Verify the move is legal before making it
254                let legal_moves: Vec<chess::ChessMove> = MoveGen::new_legal(&current_pos).collect();
255                if legal_moves.contains(&chess_move) {
256                    if self.current_game.make_move(chess_move) {
257                        self.move_count += 1;
258
259                        // Store position (we'll evaluate it later with Stockfish)
260                        self.positions.push(TrainingData {
261                            board: self.current_game.current_position(),
262                            evaluation: 0.0, // Will be filled by Stockfish
263                            depth: 0,
264                            game_id: self.game_id,
265                        });
266                    }
267                } else {
268                    // Move parsed but isn't legal - skip silently to avoid spam
269                }
270            }
271            Err(_) => {
272                // Failed to parse move - could be notation issues, corruption, etc.
273                // Skip silently to avoid excessive error output
274                // Only log if it's not a common problematic pattern
275                if !san_str.contains("O-O") && !san_str.contains("=") && san_str.len() > 6 {
276                    // Only log unusual failed moves to reduce noise
277                }
278            }
279        }
280    }
281
282    fn begin_variation(&mut self) -> Skip {
283        Skip(true) // Skip variations for now
284    }
285
286    fn end_game(&mut self) -> Self::Result {}
287}
288
/// Stockfish engine wrapper for position evaluation
pub struct StockfishEvaluator {
    /// Depth passed to the engine's `go depth` command and recorded on evaluated positions
    depth: u8,
}
293
294impl StockfishEvaluator {
295    pub fn new(depth: u8) -> Self {
296        Self { depth }
297    }
298
299    /// Evaluate a single position using Stockfish
300    pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
301        let mut child = Command::new("stockfish")
302            .stdin(Stdio::piped())
303            .stdout(Stdio::piped())
304            .stderr(Stdio::piped())
305            .spawn()?;
306
307        let stdin = child
308            .stdin
309            .as_mut()
310            .ok_or("Failed to get stdin handle for Stockfish process")?;
311        let fen = board.to_string();
312
313        // Send UCI commands
314        use std::io::Write;
315        writeln!(stdin, "uci")?;
316        writeln!(stdin, "isready")?;
317        writeln!(stdin, "position fen {fen}")?;
318        writeln!(stdin, "go depth {}", self.depth)?;
319        writeln!(stdin, "quit")?;
320
321        let output = child.wait_with_output()?;
322        let stdout = String::from_utf8_lossy(&output.stdout);
323
324        // Parse the evaluation from Stockfish output
325        for line in stdout.lines() {
326            if line.starts_with("info") && line.contains("score cp") {
327                if let Some(cp_pos) = line.find("score cp ") {
328                    let cp_str = &line[cp_pos + 9..];
329                    if let Some(end) = cp_str.find(' ') {
330                        let cp_value = cp_str[..end].parse::<i32>()?;
331                        return Ok(cp_value as f32 / 100.0); // Convert centipawns to pawns
332                    }
333                }
334            } else if line.starts_with("info") && line.contains("score mate") {
335                // Handle mate scores
336                if let Some(mate_pos) = line.find("score mate ") {
337                    let mate_str = &line[mate_pos + 11..];
338                    if let Some(end) = mate_str.find(' ') {
339                        let mate_moves = mate_str[..end].parse::<i32>()?;
340                        return Ok(if mate_moves > 0 { 100.0 } else { -100.0 });
341                    }
342                }
343            }
344        }
345
346        Ok(0.0) // Default to 0 if no evaluation found
347    }
348
349    /// Batch evaluate multiple positions
350    pub fn evaluate_batch(
351        &self,
352        positions: &mut [TrainingData],
353    ) -> Result<(), Box<dyn std::error::Error>> {
354        let pb = ProgressBar::new(positions.len() as u64);
355        if let Ok(style) = ProgressStyle::default_bar().template(
356            "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
357        ) {
358            pb.set_style(style.progress_chars("#>-"));
359        }
360
361        for data in positions.iter_mut() {
362            match self.evaluate_position(&data.board) {
363                Ok(eval) => {
364                    data.evaluation = eval;
365                    data.depth = self.depth;
366                }
367                Err(e) => {
368                    eprintln!("Evaluation error: {e}");
369                    data.evaluation = 0.0;
370                }
371            }
372            pb.inc(1);
373        }
374
375        pb.finish_with_message("Evaluation complete");
376        Ok(())
377    }
378
379    /// Evaluate multiple positions in parallel using concurrent Stockfish instances
380    pub fn evaluate_batch_parallel(
381        &self,
382        positions: &mut [TrainingData],
383        num_threads: usize,
384    ) -> Result<(), Box<dyn std::error::Error>> {
385        let pb = ProgressBar::new(positions.len() as u64);
386        if let Ok(style) = ProgressStyle::default_bar()
387            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel evaluation") {
388            pb.set_style(style.progress_chars("#>-"));
389        }
390
391        // Set the thread pool size
392        let pool = rayon::ThreadPoolBuilder::new()
393            .num_threads(num_threads)
394            .build()?;
395
396        pool.install(|| {
397            // Use parallel iterator to evaluate positions
398            positions.par_iter_mut().for_each(|data| {
399                match self.evaluate_position(&data.board) {
400                    Ok(eval) => {
401                        data.evaluation = eval;
402                        data.depth = self.depth;
403                    }
404                    Err(_) => {
405                        // Silently fail individual positions to avoid spamming output
406                        data.evaluation = 0.0;
407                    }
408                }
409                pb.inc(1);
410            });
411        });
412
413        pb.finish_with_message("Parallel evaluation complete");
414        Ok(())
415    }
416}
417
/// Persistent Stockfish process for fast UCI communication
struct StockfishProcess {
    /// Handle of the spawned `stockfish` process; waited on when the value is dropped
    child: Child,
    /// Buffered writer over the engine's stdin; `send_command` flushes it after every command
    stdin: BufWriter<std::process::ChildStdin>,
    /// Buffered reader over the engine's stdout, read line by line
    stdout: BufReader<std::process::ChildStdout>,
    /// Depth this process was created with
    #[allow(dead_code)]
    depth: u8,
}
426
427impl StockfishProcess {
428    fn new(depth: u8) -> Result<Self, Box<dyn std::error::Error>> {
429        let mut child = Command::new("stockfish")
430            .stdin(Stdio::piped())
431            .stdout(Stdio::piped())
432            .stderr(Stdio::piped())
433            .spawn()?;
434
435        let stdin = BufWriter::new(
436            child
437                .stdin
438                .take()
439                .ok_or("Failed to get stdin handle for Stockfish process")?,
440        );
441        let stdout = BufReader::new(
442            child
443                .stdout
444                .take()
445                .ok_or("Failed to get stdout handle for Stockfish process")?,
446        );
447
448        let mut process = Self {
449            child,
450            stdin,
451            stdout,
452            depth,
453        };
454
455        // Initialize UCI
456        process.send_command("uci")?;
457        process.wait_for_ready()?;
458        process.send_command("isready")?;
459        process.wait_for_ready()?;
460
461        Ok(process)
462    }
463
464    fn send_command(&mut self, command: &str) -> Result<(), Box<dyn std::error::Error>> {
465        writeln!(self.stdin, "{command}")?;
466        self.stdin.flush()?;
467        Ok(())
468    }
469
470    fn wait_for_ready(&mut self) -> Result<(), Box<dyn std::error::Error>> {
471        let mut line = String::new();
472        loop {
473            line.clear();
474            self.stdout.read_line(&mut line)?;
475            if line.trim() == "uciok" || line.trim() == "readyok" {
476                break;
477            }
478        }
479        Ok(())
480    }
481
482    fn evaluate_position(&mut self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
483        let fen = board.to_string();
484
485        // Send position and evaluation commands
486        self.send_command(&format!("position fen {fen}"))?;
487        self.send_command(&format!("position fen {fen}"))?;
488
489        // Read response until we get bestmove
490        let mut line = String::new();
491        let mut last_evaluation = 0.0;
492
493        loop {
494            line.clear();
495            self.stdout.read_line(&mut line)?;
496            let line = line.trim();
497
498            if line.starts_with("info") && line.contains("score cp") {
499                if let Some(cp_pos) = line.find("score cp ") {
500                    let cp_str = &line[cp_pos + 9..];
501                    if let Some(end) = cp_str.find(' ') {
502                        if let Ok(cp_value) = cp_str[..end].parse::<i32>() {
503                            last_evaluation = cp_value as f32 / 100.0;
504                        }
505                    }
506                }
507            } else if line.starts_with("info") && line.contains("score mate") {
508                if let Some(mate_pos) = line.find("score mate ") {
509                    let mate_str = &line[mate_pos + 11..];
510                    if let Some(end) = mate_str.find(' ') {
511                        if let Ok(mate_moves) = mate_str[..end].parse::<i32>() {
512                            last_evaluation = if mate_moves > 0 { 100.0 } else { -100.0 };
513                        }
514                    }
515                }
516            } else if line.starts_with("bestmove") {
517                break;
518            }
519        }
520
521        Ok(last_evaluation)
522    }
523}
524
525impl Drop for StockfishProcess {
526    fn drop(&mut self) {
527        let _ = self.send_command("quit");
528        let _ = self.child.wait();
529    }
530}
531
/// High-performance Stockfish process pool
pub struct StockfishPool {
    /// Idle engine processes; one is popped for each evaluation and pushed back afterwards
    pool: Arc<Mutex<Vec<StockfishProcess>>>,
    /// Depth used when creating processes and recorded on evaluated positions
    depth: u8,
    /// Maximum number of idle processes kept in `pool`
    pool_size: usize,
}
538
539impl StockfishPool {
540    pub fn new(depth: u8, pool_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
541        let mut processes = Vec::with_capacity(pool_size);
542
543        println!("🚀 Initializing Stockfish pool with {pool_size} processes...");
544
545        for i in 0..pool_size {
546            match StockfishProcess::new(depth) {
547                Ok(process) => {
548                    processes.push(process);
549                    if i % 2 == 1 {
550                        print!(".");
551                        let _ = std::io::stdout().flush(); // Ignore flush errors
552                    }
553                }
554                Err(e) => {
555                    eprintln!("Evaluation error: {e}");
556                    return Err(e);
557                }
558            }
559        }
560
561        println!(" ✅ Pool ready!");
562
563        Ok(Self {
564            pool: Arc::new(Mutex::new(processes)),
565            depth,
566            pool_size,
567        })
568    }
569
570    pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
571        // Get a process from the pool
572        let mut process = {
573            let mut pool = self.pool.lock().unwrap();
574            if let Some(process) = pool.pop() {
575                process
576            } else {
577                // Pool is empty, create temporary process
578                StockfishProcess::new(self.depth)?
579            }
580        };
581
582        // Evaluate position
583        let result = process.evaluate_position(board);
584
585        // Return process to pool
586        {
587            let mut pool = self.pool.lock().unwrap();
588            if pool.len() < self.pool_size {
589                pool.push(process);
590            }
591            // Otherwise drop the process (in case of pool size changes)
592        }
593
594        result
595    }
596
597    pub fn evaluate_batch_parallel(
598        &self,
599        positions: &mut [TrainingData],
600    ) -> Result<(), Box<dyn std::error::Error>> {
601        let pb = ProgressBar::new(positions.len() as u64);
602        pb.set_style(ProgressStyle::default_bar()
603            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Pool evaluation")
604            .unwrap()
605            .progress_chars("#>-"));
606
607        // Use rayon for parallel evaluation
608        positions.par_iter_mut().for_each(|data| {
609            match self.evaluate_position(&data.board) {
610                Ok(eval) => {
611                    data.evaluation = eval;
612                    data.depth = self.depth;
613                }
614                Err(_) => {
615                    data.evaluation = 0.0;
616                }
617            }
618            pb.inc(1);
619        });
620
621        pb.finish_with_message("Pool evaluation complete");
622        Ok(())
623    }
624}
625
/// Training dataset manager
pub struct TrainingDataset {
    /// All positions held by the dataset, in insertion order
    pub data: Vec<TrainingData>,
}
630
631impl Default for TrainingDataset {
632    fn default() -> Self {
633        Self::new()
634    }
635}
636
637impl TrainingDataset {
638    pub fn new() -> Self {
639        Self { data: Vec::new() }
640    }
641
642    /// Load positions from a PGN file
643    pub fn load_from_pgn<P: AsRef<Path>>(
644        &mut self,
645        path: P,
646        max_games: Option<usize>,
647        max_moves_per_game: usize,
648    ) -> Result<(), Box<dyn std::error::Error>> {
649        let file = File::open(path)?;
650        let reader = BufReader::new(file);
651
652        let mut extractor = GameExtractor::new(max_moves_per_game);
653        let mut games_processed = 0;
654
655        // Create a simple PGN parser
656        let mut pgn_content = String::new();
657        for line in reader.lines() {
658            let line = line?;
659            pgn_content.push_str(&line);
660            pgn_content.push('\n');
661
662            // Check if this is the end of a game
663            if line.trim().ends_with("1-0")
664                || line.trim().ends_with("0-1")
665                || line.trim().ends_with("1/2-1/2")
666                || line.trim().ends_with("*")
667            {
668                // Parse this game
669                let cursor = std::io::Cursor::new(&pgn_content);
670                let mut reader = BufferedReader::new(cursor);
671                if let Err(e) = reader.read_all(&mut extractor) {
672                    eprintln!("Evaluation error: {e}");
673                }
674
675                games_processed += 1;
676                pgn_content.clear();
677
678                if let Some(max) = max_games {
679                    if games_processed >= max {
680                        break;
681                    }
682                }
683
684                if games_processed % 100 == 0 {
685                    println!(
686                        "Processed {} games, extracted {} positions",
687                        games_processed,
688                        extractor.positions.len()
689                    );
690                }
691            }
692        }
693
694        self.data.extend(extractor.positions);
695        println!(
696            "Loaded {} positions from {} games",
697            self.data.len(),
698            games_processed
699        );
700        Ok(())
701    }
702
703    /// Evaluate all positions using Stockfish
704    pub fn evaluate_with_stockfish(&mut self, depth: u8) -> Result<(), Box<dyn std::error::Error>> {
705        let evaluator = StockfishEvaluator::new(depth);
706        evaluator.evaluate_batch(&mut self.data)
707    }
708
709    /// Evaluate all positions using Stockfish in parallel
710    pub fn evaluate_with_stockfish_parallel(
711        &mut self,
712        depth: u8,
713        num_threads: usize,
714    ) -> Result<(), Box<dyn std::error::Error>> {
715        let evaluator = StockfishEvaluator::new(depth);
716        evaluator.evaluate_batch_parallel(&mut self.data, num_threads)
717    }
718
719    /// Train the vector engine with this dataset
720    pub fn train_engine(&self, engine: &mut ChessVectorEngine) {
721        let pb = ProgressBar::new(self.data.len() as u64);
722        pb.set_style(ProgressStyle::default_bar()
723            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Training positions")
724            .unwrap()
725            .progress_chars("#>-"));
726
727        for data in &self.data {
728            engine.add_position(&data.board, data.evaluation);
729            pb.inc(1);
730        }
731
732        pb.finish_with_message("Training complete");
733        println!("Trained engine with {} positions", self.data.len());
734    }
735
736    /// Split dataset into train/test sets by games to prevent data leakage
737    pub fn split(&self, train_ratio: f32) -> (TrainingDataset, TrainingDataset) {
738        use rand::seq::SliceRandom;
739        use rand::thread_rng;
740        use std::collections::{HashMap, HashSet};
741
742        // Group positions by game_id
743        let mut games: HashMap<usize, Vec<&TrainingData>> = HashMap::new();
744        for data in &self.data {
745            games.entry(data.game_id).or_default().push(data);
746        }
747
748        // Get unique game IDs and shuffle them
749        let mut game_ids: Vec<usize> = games.keys().cloned().collect();
750        game_ids.shuffle(&mut thread_rng());
751
752        // Split games by ratio
753        let split_point = (game_ids.len() as f32 * train_ratio) as usize;
754        let train_game_ids: HashSet<usize> = game_ids[..split_point].iter().cloned().collect();
755
756        // Separate positions based on game membership
757        let mut train_data = Vec::new();
758        let mut test_data = Vec::new();
759
760        for data in &self.data {
761            if train_game_ids.contains(&data.game_id) {
762                train_data.push(data.clone());
763            } else {
764                test_data.push(data.clone());
765            }
766        }
767
768        (
769            TrainingDataset { data: train_data },
770            TrainingDataset { data: test_data },
771        )
772    }
773
774    /// Save dataset to file
775    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
776        let json = serde_json::to_string_pretty(&self.data)?;
777        std::fs::write(path, json)?;
778        Ok(())
779    }
780
781    /// Load dataset from file
782    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
783        let content = std::fs::read_to_string(path)?;
784        let data = serde_json::from_str(&content)?;
785        Ok(Self { data })
786    }
787
788    /// Load and append data from file to existing dataset (incremental training)
789    pub fn load_and_append<P: AsRef<Path>>(
790        &mut self,
791        path: P,
792    ) -> Result<(), Box<dyn std::error::Error>> {
793        let existing_len = self.data.len();
794        let additional_data = Self::load(path)?;
795        self.data.extend(additional_data.data);
796        println!(
797            "Loaded {} additional positions (total: {})",
798            self.data.len() - existing_len,
799            self.data.len()
800        );
801        Ok(())
802    }
803
804    /// Merge another dataset into this one
805    pub fn merge(&mut self, other: TrainingDataset) {
806        let existing_len = self.data.len();
807        self.data.extend(other.data);
808        println!(
809            "Merged {} positions (total: {})",
810            self.data.len() - existing_len,
811            self.data.len()
812        );
813    }
814
815    /// Save incrementally (append to existing file if it exists)
816    pub fn save_incremental<P: AsRef<Path>>(
817        &self,
818        path: P,
819    ) -> Result<(), Box<dyn std::error::Error>> {
820        self.save_incremental_with_options(path, true)
821    }
822
823    /// Save incrementally with option to skip deduplication
824    pub fn save_incremental_with_options<P: AsRef<Path>>(
825        &self,
826        path: P,
827        deduplicate: bool,
828    ) -> Result<(), Box<dyn std::error::Error>> {
829        let path = path.as_ref();
830
831        if path.exists() {
832            // Try fast append-only save first
833            if self.save_append_only(path).is_ok() {
834                return Ok(());
835            }
836
837            // Fall back to full merge
838            if deduplicate {
839                self.save_incremental_full_merge(path)
840            } else {
841                self.save_incremental_no_dedup(path)
842            }
843        } else {
844            // File doesn't exist, just save normally
845            self.save(path)
846        }
847    }
848
849    /// Fast merge without deduplication (for trusted unique data)
850    fn save_incremental_no_dedup<P: AsRef<Path>>(
851        &self,
852        path: P,
853    ) -> Result<(), Box<dyn std::error::Error>> {
854        let path = path.as_ref();
855
856        println!("📂 Loading existing training data...");
857        let mut existing = Self::load(path)?;
858
859        println!("⚡ Fast merge without deduplication...");
860        existing.data.extend(self.data.iter().cloned());
861
862        println!(
863            "💾 Serializing {} positions to JSON...",
864            existing.data.len()
865        );
866        let json = serde_json::to_string_pretty(&existing.data)?;
867
868        println!("✍️  Writing to disk...");
869        std::fs::write(path, json)?;
870
871        println!(
872            "✅ Fast merge save: total {} positions",
873            existing.data.len()
874        );
875        Ok(())
876    }
877
878    /// Fast append-only save (no deduplication, just append new positions)
879    pub fn save_append_only<P: AsRef<Path>>(
880        &self,
881        path: P,
882    ) -> Result<(), Box<dyn std::error::Error>> {
883        use std::fs::OpenOptions;
884        use std::io::{BufRead, BufReader, Seek, SeekFrom, Write};
885
886        if self.data.is_empty() {
887            return Ok(());
888        }
889
890        let path = path.as_ref();
891        let mut file = OpenOptions::new().read(true).write(true).open(path)?;
892
893        // Check if file is valid JSON array by reading last few bytes
894        file.seek(SeekFrom::End(-10))?;
895        let mut buffer = String::new();
896        BufReader::new(&file).read_line(&mut buffer)?;
897
898        if !buffer.trim().ends_with(']') {
899            return Err("File doesn't end with JSON array bracket".into());
900        }
901
902        // Seek back to overwrite the closing bracket
903        file.seek(SeekFrom::End(-2))?; // Go back 2 chars to overwrite "]\n"
904
905        // Append comma and new positions
906        write!(file, ",")?;
907
908        // Serialize and append new positions (without array brackets)
909        for (i, data) in self.data.iter().enumerate() {
910            if i > 0 {
911                write!(file, ",")?;
912            }
913            let json = serde_json::to_string(data)?;
914            write!(file, "{json}")?;
915        }
916
917        // Close the JSON array
918        write!(file, "\n]")?;
919
920        println!("Fast append: added {} new positions", self.data.len());
921        Ok(())
922    }
923
924    /// Full merge save with deduplication (slower but thorough)
925    fn save_incremental_full_merge<P: AsRef<Path>>(
926        &self,
927        path: P,
928    ) -> Result<(), Box<dyn std::error::Error>> {
929        let path = path.as_ref();
930
931        println!("📂 Loading existing training data...");
932        let mut existing = Self::load(path)?;
933        let _original_len = existing.data.len();
934
935        println!("🔄 Streaming merge with deduplication (avoiding O(n²) operation)...");
936        existing.merge_and_deduplicate(self.data.clone());
937
938        println!(
939            "💾 Serializing {} positions to JSON...",
940            existing.data.len()
941        );
942        let json = serde_json::to_string_pretty(&existing.data)?;
943
944        println!("✍️  Writing to disk...");
945        std::fs::write(path, json)?;
946
947        println!(
948            "✅ Streaming merge save: total {} positions",
949            existing.data.len()
950        );
951        Ok(())
952    }
953
954    /// Add a single training data point
955    pub fn add_position(&mut self, board: Board, evaluation: f32, depth: u8, game_id: usize) {
956        self.data.push(TrainingData {
957            board,
958            evaluation,
959            depth,
960            game_id,
961        });
962    }
963
964    /// Get the next available game ID for incremental training
965    pub fn next_game_id(&self) -> usize {
966        self.data.iter().map(|data| data.game_id).max().unwrap_or(0) + 1
967    }
968
969    /// Remove near-duplicate positions to reduce overfitting
970    pub fn deduplicate(&mut self, similarity_threshold: f32) {
971        if similarity_threshold > 0.999 {
972            // Use fast hash-based deduplication for exact duplicates
973            self.deduplicate_fast();
974        } else {
975            // Use slower similarity-based deduplication for near-duplicates
976            self.deduplicate_similarity_based(similarity_threshold);
977        }
978    }
979
980    /// Fast hash-based deduplication for exact duplicates (O(n))
981    pub fn deduplicate_fast(&mut self) {
982        use std::collections::HashSet;
983
984        if self.data.is_empty() {
985            return;
986        }
987
988        let mut seen_positions = HashSet::with_capacity(self.data.len());
989        let original_len = self.data.len();
990
991        // Keep positions with unique FEN strings
992        self.data.retain(|data| {
993            let fen = data.board.to_string();
994            seen_positions.insert(fen)
995        });
996
997        println!(
998            "Fast deduplicated: {} -> {} positions (removed {} exact duplicates)",
999            original_len,
1000            self.data.len(),
1001            original_len - self.data.len()
1002        );
1003    }
1004
1005    /// Streaming deduplication when merging with existing data (faster for large datasets)
1006    pub fn merge_and_deduplicate(&mut self, new_data: Vec<TrainingData>) {
1007        use std::collections::HashSet;
1008
1009        if new_data.is_empty() {
1010            return;
1011        }
1012
1013        let _original_len = self.data.len();
1014
1015        // Build hashset of existing positions for fast lookup
1016        let mut existing_positions: HashSet<String> = HashSet::with_capacity(self.data.len());
1017        for data in &self.data {
1018            existing_positions.insert(data.board.to_string());
1019        }
1020
1021        // Only add new positions that don't already exist
1022        let mut added = 0;
1023        for data in new_data {
1024            let fen = data.board.to_string();
1025            if existing_positions.insert(fen) {
1026                self.data.push(data);
1027                added += 1;
1028            }
1029        }
1030
1031        println!(
1032            "Streaming merge: added {} unique positions (total: {})",
1033            added,
1034            self.data.len()
1035        );
1036    }
1037
1038    /// Similarity-based deduplication for near-duplicates (O(n²) but optimized)
1039    fn deduplicate_similarity_based(&mut self, similarity_threshold: f32) {
1040        use crate::PositionEncoder;
1041        use ndarray::Array1;
1042
1043        if self.data.is_empty() {
1044            return;
1045        }
1046
1047        let encoder = PositionEncoder::new(1024);
1048        let mut keep_indices: Vec<bool> = vec![true; self.data.len()];
1049
1050        // Encode all positions in parallel
1051        let vectors: Vec<Array1<f32>> = if self.data.len() > 50 {
1052            self.data
1053                .par_iter()
1054                .map(|data| encoder.encode(&data.board))
1055                .collect()
1056        } else {
1057            self.data
1058                .iter()
1059                .map(|data| encoder.encode(&data.board))
1060                .collect()
1061        };
1062
1063        // Compare each position with all previous ones
1064        for i in 1..self.data.len() {
1065            if !keep_indices[i] {
1066                continue;
1067            }
1068
1069            for j in 0..i {
1070                if !keep_indices[j] {
1071                    continue;
1072                }
1073
1074                let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1075                if similarity > similarity_threshold {
1076                    keep_indices[i] = false;
1077                    break;
1078                }
1079            }
1080        }
1081
1082        // Filter data based on keep_indices
1083        let original_len = self.data.len();
1084        self.data = self
1085            .data
1086            .iter()
1087            .enumerate()
1088            .filter_map(|(i, data)| {
1089                if keep_indices[i] {
1090                    Some(data.clone())
1091                } else {
1092                    None
1093                }
1094            })
1095            .collect();
1096
1097        println!(
1098            "Similarity deduplicated: {} -> {} positions (removed {} near-duplicates)",
1099            original_len,
1100            self.data.len(),
1101            original_len - self.data.len()
1102        );
1103    }
1104
1105    /// Remove near-duplicate positions using parallel comparison (faster for large datasets)
1106    pub fn deduplicate_parallel(&mut self, similarity_threshold: f32, chunk_size: usize) {
1107        use crate::PositionEncoder;
1108        use ndarray::Array1;
1109        use std::sync::{Arc, Mutex};
1110
1111        if self.data.is_empty() {
1112            return;
1113        }
1114
1115        let encoder = PositionEncoder::new(1024);
1116
1117        // Encode all positions in parallel
1118        let vectors: Vec<Array1<f32>> = self
1119            .data
1120            .par_iter()
1121            .map(|data| encoder.encode(&data.board))
1122            .collect();
1123
1124        let keep_indices = Arc::new(Mutex::new(vec![true; self.data.len()]));
1125
1126        // Process in chunks to balance parallelism and memory usage
1127        (1..self.data.len())
1128            .collect::<Vec<_>>()
1129            .par_chunks(chunk_size)
1130            .for_each(|chunk| {
1131                for &i in chunk {
1132                    // Check if this position is still being kept
1133                    {
1134                        let indices = keep_indices.lock().unwrap();
1135                        if !indices[i] {
1136                            continue;
1137                        }
1138                    }
1139
1140                    // Compare with all previous positions
1141                    for j in 0..i {
1142                        {
1143                            let indices = keep_indices.lock().unwrap();
1144                            if !indices[j] {
1145                                continue;
1146                            }
1147                        }
1148
1149                        let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1150                        if similarity > similarity_threshold {
1151                            let mut indices = keep_indices.lock().unwrap();
1152                            indices[i] = false;
1153                            break;
1154                        }
1155                    }
1156                }
1157            });
1158
1159        // Filter data based on keep_indices
1160        let keep_indices = keep_indices.lock().unwrap();
1161        let original_len = self.data.len();
1162        self.data = self
1163            .data
1164            .iter()
1165            .enumerate()
1166            .filter_map(|(i, data)| {
1167                if keep_indices[i] {
1168                    Some(data.clone())
1169                } else {
1170                    None
1171                }
1172            })
1173            .collect();
1174
1175        println!(
1176            "Parallel deduplicated: {} -> {} positions (removed {} duplicates)",
1177            original_len,
1178            self.data.len(),
1179            original_len - self.data.len()
1180        );
1181    }
1182
1183    /// Calculate cosine similarity between two vectors
1184    fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
1185        let dot_product = a.dot(b);
1186        let norm_a = a.dot(a).sqrt();
1187        let norm_b = b.dot(b).sqrt();
1188
1189        if norm_a == 0.0 || norm_b == 0.0 {
1190            0.0
1191        } else {
1192            dot_product / (norm_a * norm_b)
1193        }
1194    }
1195}
1196
/// Self-play training system for generating new positions
///
/// Plays engine-vs-engine games according to a [`SelfPlayConfig`] and records
/// the visited positions as training data.
pub struct SelfPlayTrainer {
    /// Settings for the self-play games (game count, move limit, exploration, ...).
    config: SelfPlayConfig,
    /// Number of games played so far; used as the `game_id` of recorded positions.
    game_counter: usize,
}
1202
1203impl SelfPlayTrainer {
1204    pub fn new(config: SelfPlayConfig) -> Self {
1205        Self {
1206            config,
1207            game_counter: 0,
1208        }
1209    }
1210
1211    /// Generate training data through self-play games
1212    pub fn generate_training_data(&mut self, engine: &mut ChessVectorEngine) -> TrainingDataset {
1213        let mut dataset = TrainingDataset::new();
1214
1215        println!(
1216            "🎮 Starting self-play training with {} games...",
1217            self.config.games_per_iteration
1218        );
1219        let pb = ProgressBar::new(self.config.games_per_iteration as u64);
1220        if let Ok(style) = ProgressStyle::default_bar().template(
1221            "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
1222        ) {
1223            pb.set_style(style.progress_chars("#>-"));
1224        }
1225
1226        for _ in 0..self.config.games_per_iteration {
1227            let game_data = self.play_single_game(engine);
1228            dataset.data.extend(game_data);
1229            self.game_counter += 1;
1230            pb.inc(1);
1231        }
1232
1233        pb.finish_with_message("Self-play games completed");
1234        println!(
1235            "✅ Generated {} positions from {} games",
1236            dataset.data.len(),
1237            self.config.games_per_iteration
1238        );
1239
1240        dataset
1241    }
1242
1243    /// Play a single self-play game and extract training positions
1244    fn play_single_game(&self, engine: &mut ChessVectorEngine) -> Vec<TrainingData> {
1245        let mut game = Game::new();
1246        let mut positions = Vec::new();
1247        let mut move_count = 0;
1248
1249        // Use opening book for variety if enabled
1250        if self.config.use_opening_book {
1251            if let Some(opening_moves) = self.get_random_opening() {
1252                for mv in opening_moves {
1253                    if game.make_move(mv) {
1254                        move_count += 1;
1255                    } else {
1256                        break;
1257                    }
1258                }
1259            }
1260        }
1261
1262        // Play the game
1263        while game.result().is_none() && move_count < self.config.max_moves_per_game {
1264            let current_position = game.current_position();
1265
1266            // Get engine's move recommendation with exploration
1267            let move_choice = self.select_move_with_exploration(engine, &current_position);
1268
1269            if let Some(chess_move) = move_choice {
1270                // Evaluate the position before making the move
1271                if let Some(evaluation) = engine.evaluate_position(&current_position) {
1272                    // Only include positions with sufficient confidence
1273                    if evaluation.abs() >= self.config.min_confidence || move_count < 10 {
1274                        positions.push(TrainingData {
1275                            board: current_position,
1276                            evaluation,
1277                            depth: 1, // Self-play depth
1278                            game_id: self.game_counter,
1279                        });
1280                    }
1281                }
1282
1283                // Make the move
1284                if !game.make_move(chess_move) {
1285                    break; // Invalid move, end game
1286                }
1287                move_count += 1;
1288            } else {
1289                break; // No legal moves
1290            }
1291        }
1292
1293        // Add final position evaluation based on game result
1294        if let Some(result) = game.result() {
1295            let final_position = game.current_position();
1296            let final_eval = match result {
1297                chess::GameResult::WhiteCheckmates => {
1298                    if final_position.side_to_move() == chess::Color::Black {
1299                        10.0
1300                    } else {
1301                        -10.0
1302                    }
1303                }
1304                chess::GameResult::BlackCheckmates => {
1305                    if final_position.side_to_move() == chess::Color::White {
1306                        10.0
1307                    } else {
1308                        -10.0
1309                    }
1310                }
1311                chess::GameResult::WhiteResigns => -10.0,
1312                chess::GameResult::BlackResigns => 10.0,
1313                chess::GameResult::Stalemate
1314                | chess::GameResult::DrawAccepted
1315                | chess::GameResult::DrawDeclared => 0.0,
1316            };
1317
1318            positions.push(TrainingData {
1319                board: final_position,
1320                evaluation: final_eval,
1321                depth: 1,
1322                game_id: self.game_counter,
1323            });
1324        }
1325
1326        positions
1327    }
1328
1329    /// Select a move with exploration vs exploitation balance
1330    fn select_move_with_exploration(
1331        &self,
1332        engine: &mut ChessVectorEngine,
1333        position: &Board,
1334    ) -> Option<ChessMove> {
1335        let recommendations = engine.recommend_legal_moves(position, 5);
1336
1337        if recommendations.is_empty() {
1338            return None;
1339        }
1340
1341        // Use temperature-based selection for exploration
1342        if fastrand::f32() < self.config.exploration_factor {
1343            // Exploration: weighted random selection based on evaluations
1344            self.select_move_with_temperature(&recommendations)
1345        } else {
1346            // Exploitation: take the best move
1347            Some(recommendations[0].chess_move)
1348        }
1349    }
1350
1351    /// Temperature-based move selection for exploration
1352    fn select_move_with_temperature(
1353        &self,
1354        recommendations: &[crate::MoveRecommendation],
1355    ) -> Option<ChessMove> {
1356        if recommendations.is_empty() {
1357            return None;
1358        }
1359
1360        // Convert evaluations to probabilities using temperature
1361        let mut probabilities = Vec::new();
1362        let mut sum = 0.0;
1363
1364        for rec in recommendations {
1365            // Use average_outcome as evaluation score for temperature selection
1366            let prob = (rec.average_outcome / self.config.temperature).exp();
1367            probabilities.push(prob);
1368            sum += prob;
1369        }
1370
1371        // Normalize probabilities
1372        for prob in &mut probabilities {
1373            *prob /= sum;
1374        }
1375
1376        // Random selection based on probabilities
1377        let rand_val = fastrand::f32();
1378        let mut cumulative = 0.0;
1379
1380        for (i, &prob) in probabilities.iter().enumerate() {
1381            cumulative += prob;
1382            if rand_val <= cumulative {
1383                return Some(recommendations[i].chess_move);
1384            }
1385        }
1386
1387        // Fallback to first move
1388        Some(recommendations[0].chess_move)
1389    }
1390
1391    /// Get random opening moves for variety
1392    fn get_random_opening(&self) -> Option<Vec<ChessMove>> {
1393        let openings = [
1394            // Italian Game
1395            vec!["e4", "e5", "Nf3", "Nc6", "Bc4"],
1396            // Ruy Lopez
1397            vec!["e4", "e5", "Nf3", "Nc6", "Bb5"],
1398            // Queen's Gambit
1399            vec!["d4", "d5", "c4"],
1400            // King's Indian Defense
1401            vec!["d4", "Nf6", "c4", "g6"],
1402            // Sicilian Defense
1403            vec!["e4", "c5"],
1404            // French Defense
1405            vec!["e4", "e6"],
1406            // Caro-Kann Defense
1407            vec!["e4", "c6"],
1408        ];
1409
1410        let selected_opening = &openings[fastrand::usize(0..openings.len())];
1411
1412        let mut moves = Vec::new();
1413        let mut game = Game::new();
1414
1415        for move_str in selected_opening {
1416            if let Ok(chess_move) = ChessMove::from_str(move_str) {
1417                if game.make_move(chess_move) {
1418                    moves.push(chess_move);
1419                } else {
1420                    break;
1421                }
1422            }
1423        }
1424
1425        if moves.is_empty() {
1426            None
1427        } else {
1428            Some(moves)
1429        }
1430    }
1431}
1432
/// Engine performance evaluator
///
/// Measures how closely the engine's evaluations match the reference
/// evaluations stored in a dataset.
pub struct EngineEvaluator {
    // Stored by `new`; the lint allowance is needed because the methods in the
    // `impl` block that follows never read it.
    #[allow(dead_code)]
    stockfish_depth: u8,
}
1438
1439impl EngineEvaluator {
1440    pub fn new(stockfish_depth: u8) -> Self {
1441        Self { stockfish_depth }
1442    }
1443
1444    /// Compare engine evaluations against Stockfish on test set
1445    pub fn evaluate_accuracy(
1446        &self,
1447        engine: &mut ChessVectorEngine,
1448        test_data: &TrainingDataset,
1449    ) -> Result<f32, Box<dyn std::error::Error>> {
1450        let mut total_error = 0.0;
1451        let mut valid_comparisons = 0;
1452
1453        let pb = ProgressBar::new(test_data.data.len() as u64);
1454        pb.set_style(ProgressStyle::default_bar()
1455            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Evaluating accuracy")
1456            .unwrap()
1457            .progress_chars("#>-"));
1458
1459        for data in &test_data.data {
1460            if let Some(engine_eval) = engine.evaluate_position(&data.board) {
1461                let error = (engine_eval - data.evaluation).abs();
1462                total_error += error;
1463                valid_comparisons += 1;
1464            }
1465            pb.inc(1);
1466        }
1467
1468        pb.finish_with_message("Accuracy evaluation complete");
1469
1470        if valid_comparisons > 0 {
1471            let mean_absolute_error = total_error / valid_comparisons as f32;
1472            println!("Mean Absolute Error: {mean_absolute_error:.3} pawns");
1473            println!("Evaluated {valid_comparisons} positions");
1474            Ok(mean_absolute_error)
1475        } else {
1476            Ok(f32::INFINITY)
1477        }
1478    }
1479}
1480
1481// Make TrainingData serializable
1482impl serde::Serialize for TrainingData {
1483    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1484    where
1485        S: serde::Serializer,
1486    {
1487        use serde::ser::SerializeStruct;
1488        let mut state = serializer.serialize_struct("TrainingData", 4)?;
1489        state.serialize_field("fen", &self.board.to_string())?;
1490        state.serialize_field("evaluation", &self.evaluation)?;
1491        state.serialize_field("depth", &self.depth)?;
1492        state.serialize_field("game_id", &self.game_id)?;
1493        state.end()
1494    }
1495}
1496
1497impl<'de> serde::Deserialize<'de> for TrainingData {
1498    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1499    where
1500        D: serde::Deserializer<'de>,
1501    {
1502        use serde::de::{self, MapAccess, Visitor};
1503        use std::fmt;
1504
1505        struct TrainingDataVisitor;
1506
1507        impl<'de> Visitor<'de> for TrainingDataVisitor {
1508            type Value = TrainingData;
1509
1510            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1511                formatter.write_str("struct TrainingData")
1512            }
1513
1514            fn visit_map<V>(self, mut map: V) -> Result<TrainingData, V::Error>
1515            where
1516                V: MapAccess<'de>,
1517            {
1518                let mut fen = None;
1519                let mut evaluation = None;
1520                let mut depth = None;
1521                let mut game_id = None;
1522
1523                while let Some(key) = map.next_key()? {
1524                    match key {
1525                        "fen" => {
1526                            if fen.is_some() {
1527                                return Err(de::Error::duplicate_field("fen"));
1528                            }
1529                            fen = Some(map.next_value()?);
1530                        }
1531                        "evaluation" => {
1532                            if evaluation.is_some() {
1533                                return Err(de::Error::duplicate_field("evaluation"));
1534                            }
1535                            evaluation = Some(map.next_value()?);
1536                        }
1537                        "depth" => {
1538                            if depth.is_some() {
1539                                return Err(de::Error::duplicate_field("depth"));
1540                            }
1541                            depth = Some(map.next_value()?);
1542                        }
1543                        "game_id" => {
1544                            if game_id.is_some() {
1545                                return Err(de::Error::duplicate_field("game_id"));
1546                            }
1547                            game_id = Some(map.next_value()?);
1548                        }
1549                        _ => {
1550                            let _: serde_json::Value = map.next_value()?;
1551                        }
1552                    }
1553                }
1554
1555                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
1556                let mut evaluation: f32 =
1557                    evaluation.ok_or_else(|| de::Error::missing_field("evaluation"))?;
1558                let depth = depth.ok_or_else(|| de::Error::missing_field("depth"))?;
1559                let game_id = game_id.unwrap_or(0); // Default to 0 for backward compatibility
1560
1561                // Convert evaluation from centipawns to pawns if needed
1562                // If evaluation is outside typical pawn range (-10 to +10),
1563                // assume it's in centipawns and convert to pawns
1564                if evaluation.abs() > 15.0 {
1565                    evaluation /= 100.0;
1566                }
1567
1568                let board =
1569                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
1570
1571                Ok(TrainingData {
1572                    board,
1573                    evaluation,
1574                    depth,
1575                    game_id,
1576                })
1577            }
1578        }
1579
1580        const FIELDS: &[&str] = &["fen", "evaluation", "depth", "game_id"];
1581        deserializer.deserialize_struct("TrainingData", FIELDS, TrainingDataVisitor)
1582    }
1583}
1584
/// Tactical puzzle parser for Lichess puzzle database
///
/// A field-less marker type: it carries no state and only groups the CSV
/// parsing functions as associated functions.
pub struct TacticalPuzzleParser;
1587
1588impl TacticalPuzzleParser {
1589    /// Parse Lichess puzzle CSV file with parallel processing
1590    pub fn parse_csv<P: AsRef<Path>>(
1591        file_path: P,
1592        max_puzzles: Option<usize>,
1593        min_rating: Option<u32>,
1594        max_rating: Option<u32>,
1595    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1596        let file = File::open(&file_path)?;
1597        let file_size = file.metadata()?.len();
1598
1599        // For large files (>100MB), use parallel processing
1600        if file_size > 100_000_000 {
1601            Self::parse_csv_parallel(file_path, max_puzzles, min_rating, max_rating)
1602        } else {
1603            Self::parse_csv_sequential(file_path, max_puzzles, min_rating, max_rating)
1604        }
1605    }
1606
1607    /// Sequential CSV parsing for smaller files
1608    fn parse_csv_sequential<P: AsRef<Path>>(
1609        file_path: P,
1610        max_puzzles: Option<usize>,
1611        min_rating: Option<u32>,
1612        max_rating: Option<u32>,
1613    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1614        let file = File::open(file_path)?;
1615        let reader = BufReader::new(file);
1616
1617        // Create CSV reader without headers since Lichess CSV has no header row
1618        // Set flexible field count to handle inconsistent CSV structure
1619        let mut csv_reader = csv::ReaderBuilder::new()
1620            .has_headers(false)
1621            .flexible(true) // Allow variable number of fields
1622            .from_reader(reader);
1623
1624        let mut tactical_data = Vec::new();
1625        let mut processed = 0;
1626        let mut skipped = 0;
1627
1628        let pb = ProgressBar::new_spinner();
1629        pb.set_style(
1630            ProgressStyle::default_spinner()
1631                .template("{spinner:.green} Parsing tactical puzzles: {pos} (skipped: {skipped})")
1632                .unwrap(),
1633        );
1634
1635        for result in csv_reader.records() {
1636            let record = match result {
1637                Ok(r) => r,
1638                Err(e) => {
1639                    skipped += 1;
1640                    println!("CSV parsing error: {e}");
1641                    continue;
1642                }
1643            };
1644
1645            if let Some(puzzle_data) = Self::parse_csv_record(&record, min_rating, max_rating) {
1646                if let Some(tactical_data_item) =
1647                    Self::convert_puzzle_to_training_data(&puzzle_data)
1648                {
1649                    tactical_data.push(tactical_data_item);
1650                    processed += 1;
1651
1652                    if let Some(max) = max_puzzles {
1653                        if processed >= max {
1654                            break;
1655                        }
1656                    }
1657                } else {
1658                    skipped += 1;
1659                }
1660            } else {
1661                skipped += 1;
1662            }
1663
1664            pb.set_message(format!(
1665                "Parsing tactical puzzles: {processed} (skipped: {skipped})"
1666            ));
1667        }
1668
1669        pb.finish_with_message(format!("Parsed {processed} puzzles (skipped: {skipped})"));
1670
1671        Ok(tactical_data)
1672    }
1673
1674    /// Parallel CSV parsing for large files
1675    fn parse_csv_parallel<P: AsRef<Path>>(
1676        file_path: P,
1677        max_puzzles: Option<usize>,
1678        min_rating: Option<u32>,
1679        max_rating: Option<u32>,
1680    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1681        use std::io::Read;
1682
1683        let mut file = File::open(&file_path)?;
1684
1685        // Read entire file into memory for parallel processing
1686        let mut contents = String::new();
1687        file.read_to_string(&mut contents)?;
1688
1689        // Split into lines for parallel processing
1690        let lines: Vec<&str> = contents.lines().collect();
1691
1692        let pb = ProgressBar::new(lines.len() as u64);
1693        pb.set_style(ProgressStyle::default_bar()
1694            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel CSV parsing")
1695            .unwrap()
1696            .progress_chars("#>-"));
1697
1698        // Process lines in parallel
1699        let tactical_data: Vec<TacticalTrainingData> = lines
1700            .par_iter()
1701            .take(max_puzzles.unwrap_or(usize::MAX))
1702            .filter_map(|line| {
1703                // Parse CSV line manually
1704                let fields: Vec<&str> = line.split(',').collect();
1705                if fields.len() < 8 {
1706                    return None;
1707                }
1708
1709                // Build puzzle from fields
1710                if let Some(puzzle_data) = Self::parse_csv_fields(&fields, min_rating, max_rating) {
1711                    Self::convert_puzzle_to_training_data(&puzzle_data)
1712                } else {
1713                    None
1714                }
1715            })
1716            .collect();
1717
1718        pb.finish_with_message(format!(
1719            "Parallel parsing complete: {} puzzles",
1720            tactical_data.len()
1721        ));
1722
1723        Ok(tactical_data)
1724    }
1725
1726    /// Parse CSV record into TacticalPuzzle
1727    fn parse_csv_record(
1728        record: &csv::StringRecord,
1729        min_rating: Option<u32>,
1730        max_rating: Option<u32>,
1731    ) -> Option<TacticalPuzzle> {
1732        // Need at least 8 fields (PuzzleId, FEN, Moves, Rating, RatingDeviation, Popularity, NbPlays, Themes)
1733        if record.len() < 8 {
1734            return None;
1735        }
1736
1737        let rating: u32 = record[3].parse().ok()?;
1738        let rating_deviation: u32 = record[4].parse().ok()?;
1739        let popularity: i32 = record[5].parse().ok()?;
1740        let nb_plays: u32 = record[6].parse().ok()?;
1741
1742        // Apply rating filters
1743        if let Some(min) = min_rating {
1744            if rating < min {
1745                return None;
1746            }
1747        }
1748        if let Some(max) = max_rating {
1749            if rating > max {
1750                return None;
1751            }
1752        }
1753
1754        Some(TacticalPuzzle {
1755            puzzle_id: record[0].to_string(),
1756            fen: record[1].to_string(),
1757            moves: record[2].to_string(),
1758            rating,
1759            rating_deviation,
1760            popularity,
1761            nb_plays,
1762            themes: record[7].to_string(),
1763            game_url: if record.len() > 8 {
1764                Some(record[8].to_string())
1765            } else {
1766                None
1767            },
1768            opening_tags: if record.len() > 9 {
1769                Some(record[9].to_string())
1770            } else {
1771                None
1772            },
1773        })
1774    }
1775
1776    /// Parse CSV fields into TacticalPuzzle (for parallel processing)
1777    fn parse_csv_fields(
1778        fields: &[&str],
1779        min_rating: Option<u32>,
1780        max_rating: Option<u32>,
1781    ) -> Option<TacticalPuzzle> {
1782        if fields.len() < 8 {
1783            return None;
1784        }
1785
1786        let rating: u32 = fields[3].parse().ok()?;
1787        let rating_deviation: u32 = fields[4].parse().ok()?;
1788        let popularity: i32 = fields[5].parse().ok()?;
1789        let nb_plays: u32 = fields[6].parse().ok()?;
1790
1791        // Apply rating filters
1792        if let Some(min) = min_rating {
1793            if rating < min {
1794                return None;
1795            }
1796        }
1797        if let Some(max) = max_rating {
1798            if rating > max {
1799                return None;
1800            }
1801        }
1802
1803        Some(TacticalPuzzle {
1804            puzzle_id: fields[0].to_string(),
1805            fen: fields[1].to_string(),
1806            moves: fields[2].to_string(),
1807            rating,
1808            rating_deviation,
1809            popularity,
1810            nb_plays,
1811            themes: fields[7].to_string(),
1812            game_url: if fields.len() > 8 {
1813                Some(fields[8].to_string())
1814            } else {
1815                None
1816            },
1817            opening_tags: if fields.len() > 9 {
1818                Some(fields[9].to_string())
1819            } else {
1820                None
1821            },
1822        })
1823    }
1824
1825    /// Convert puzzle to training data
1826    fn convert_puzzle_to_training_data(puzzle: &TacticalPuzzle) -> Option<TacticalTrainingData> {
1827        // Parse FEN position
1828        let position = match Board::from_str(&puzzle.fen) {
1829            Ok(board) => board,
1830            Err(_) => return None,
1831        };
1832
1833        // Parse move sequence - first move is the solution
1834        let moves: Vec<&str> = puzzle.moves.split_whitespace().collect();
1835        if moves.is_empty() {
1836            return None;
1837        }
1838
1839        // Parse the solution move (first move in sequence)
1840        let solution_move = match ChessMove::from_str(moves[0]) {
1841            Ok(mv) => mv,
1842            Err(_) => {
1843                // Try parsing as SAN
1844                match ChessMove::from_san(&position, moves[0]) {
1845                    Ok(mv) => mv,
1846                    Err(_) => return None,
1847                }
1848            }
1849        };
1850
1851        // Verify move is legal
1852        let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&position).collect();
1853        if !legal_moves.contains(&solution_move) {
1854            return None;
1855        }
1856
1857        // Extract primary theme
1858        let themes: Vec<&str> = puzzle.themes.split_whitespace().collect();
1859        let primary_theme = themes.first().unwrap_or(&"tactical").to_string();
1860
1861        // Calculate tactical value based on rating and popularity
1862        let difficulty = puzzle.rating as f32 / 1000.0; // Normalize to 0.8-3.0 range
1863        let popularity_bonus = (puzzle.popularity as f32 / 100.0).min(2.0);
1864        let tactical_value = difficulty + popularity_bonus; // High value for move outcome
1865
1866        Some(TacticalTrainingData {
1867            position,
1868            solution_move,
1869            move_theme: primary_theme,
1870            difficulty,
1871            tactical_value,
1872        })
1873    }
1874
1875    /// Load tactical training data into chess engine
1876    pub fn load_into_engine(
1877        tactical_data: &[TacticalTrainingData],
1878        engine: &mut ChessVectorEngine,
1879    ) {
1880        let pb = ProgressBar::new(tactical_data.len() as u64);
1881        pb.set_style(ProgressStyle::default_bar()
1882            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns")
1883            .unwrap()
1884            .progress_chars("#>-"));
1885
1886        for data in tactical_data {
1887            // Add position with high-value tactical move
1888            engine.add_position_with_move(
1889                &data.position,
1890                0.0, // Position evaluation (neutral for puzzles)
1891                Some(data.solution_move),
1892                Some(data.tactical_value), // High tactical value
1893            );
1894            pb.inc(1);
1895        }
1896
1897        pb.finish_with_message(format!("Loaded {} tactical patterns", tactical_data.len()));
1898    }
1899
1900    /// Load tactical training data into chess engine incrementally (preserves existing data)
1901    pub fn load_into_engine_incremental(
1902        tactical_data: &[TacticalTrainingData],
1903        engine: &mut ChessVectorEngine,
1904    ) {
1905        let initial_size = engine.knowledge_base_size();
1906        let initial_moves = engine.position_moves.len();
1907
1908        // For large datasets, use parallel batch processing
1909        if tactical_data.len() > 1000 {
1910            Self::load_into_engine_incremental_parallel(
1911                tactical_data,
1912                engine,
1913                initial_size,
1914                initial_moves,
1915            );
1916        } else {
1917            Self::load_into_engine_incremental_sequential(
1918                tactical_data,
1919                engine,
1920                initial_size,
1921                initial_moves,
1922            );
1923        }
1924    }
1925
1926    /// Sequential loading for smaller datasets
1927    fn load_into_engine_incremental_sequential(
1928        tactical_data: &[TacticalTrainingData],
1929        engine: &mut ChessVectorEngine,
1930        initial_size: usize,
1931        initial_moves: usize,
1932    ) {
1933        let pb = ProgressBar::new(tactical_data.len() as u64);
1934        pb.set_style(ProgressStyle::default_bar()
1935            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns (incremental)")
1936            .unwrap()
1937            .progress_chars("#>-"));
1938
1939        let mut added = 0;
1940        let mut skipped = 0;
1941
1942        for data in tactical_data {
1943            // Check if this position already exists to avoid duplicates
1944            if !engine.position_boards.contains(&data.position) {
1945                engine.add_position_with_move(
1946                    &data.position,
1947                    0.0, // Position evaluation (neutral for puzzles)
1948                    Some(data.solution_move),
1949                    Some(data.tactical_value), // High tactical value
1950                );
1951                added += 1;
1952            } else {
1953                skipped += 1;
1954            }
1955            pb.inc(1);
1956        }
1957
1958        pb.finish_with_message(format!(
1959            "Loaded {} new tactical patterns (skipped {} duplicates, total: {})",
1960            added,
1961            skipped,
1962            engine.knowledge_base_size()
1963        ));
1964
1965        println!("Incremental tactical training:");
1966        println!(
1967            "  - Positions: {} → {} (+{})",
1968            initial_size,
1969            engine.knowledge_base_size(),
1970            engine.knowledge_base_size() - initial_size
1971        );
1972        println!(
1973            "  - Move entries: {} → {} (+{})",
1974            initial_moves,
1975            engine.position_moves.len(),
1976            engine.position_moves.len() - initial_moves
1977        );
1978    }
1979
1980    /// Parallel batch loading for large datasets
1981    fn load_into_engine_incremental_parallel(
1982        tactical_data: &[TacticalTrainingData],
1983        engine: &mut ChessVectorEngine,
1984        initial_size: usize,
1985        initial_moves: usize,
1986    ) {
1987        let pb = ProgressBar::new(tactical_data.len() as u64);
1988        pb.set_style(ProgressStyle::default_bar()
1989            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Optimized batch loading tactical patterns")
1990            .unwrap()
1991            .progress_chars("#>-"));
1992
1993        // Parallel pre-filtering to avoid duplicates (this is thread-safe for read operations)
1994        let filtered_data: Vec<&TacticalTrainingData> = tactical_data
1995            .par_iter()
1996            .filter(|data| !engine.position_boards.contains(&data.position))
1997            .collect();
1998
1999        let batch_size = 1000; // Larger batches for better performance
2000        let mut added = 0;
2001
2002        println!(
2003            "Pre-filtered: {} → {} positions (removed {} duplicates)",
2004            tactical_data.len(),
2005            filtered_data.len(),
2006            tactical_data.len() - filtered_data.len()
2007        );
2008
2009        // Process in sequential batches (engine operations aren't thread-safe)
2010        // But use optimized batch processing
2011        for batch in filtered_data.chunks(batch_size) {
2012            let batch_start = added;
2013
2014            for data in batch {
2015                // Final duplicate check (should be minimal after pre-filtering)
2016                if !engine.position_boards.contains(&data.position) {
2017                    engine.add_position_with_move(
2018                        &data.position,
2019                        0.0, // Position evaluation (neutral for puzzles)
2020                        Some(data.solution_move),
2021                        Some(data.tactical_value), // High tactical value
2022                    );
2023                    added += 1;
2024                }
2025                pb.inc(1);
2026            }
2027
2028            // Update progress message every batch
2029            pb.set_message(format!("Loaded batch: {} positions", added - batch_start));
2030        }
2031
2032        let skipped = tactical_data.len() - added;
2033
2034        pb.finish_with_message(format!(
2035            "Optimized loaded {} new tactical patterns (skipped {} duplicates, total: {})",
2036            added,
2037            skipped,
2038            engine.knowledge_base_size()
2039        ));
2040
2041        println!("Incremental tactical training (optimized):");
2042        println!(
2043            "  - Positions: {} → {} (+{})",
2044            initial_size,
2045            engine.knowledge_base_size(),
2046            engine.knowledge_base_size() - initial_size
2047        );
2048        println!(
2049            "  - Move entries: {} → {} (+{})",
2050            initial_moves,
2051            engine.position_moves.len(),
2052            engine.position_moves.len() - initial_moves
2053        );
2054        println!(
2055            "  - Batch size: {}, Pre-filtered efficiency: {:.1}%",
2056            batch_size,
2057            (filtered_data.len() as f32 / tactical_data.len() as f32) * 100.0
2058        );
2059    }
2060
2061    /// Save tactical puzzles to file for incremental loading later
2062    pub fn save_tactical_puzzles<P: AsRef<std::path::Path>>(
2063        tactical_data: &[TacticalTrainingData],
2064        path: P,
2065    ) -> Result<(), Box<dyn std::error::Error>> {
2066        let json = serde_json::to_string_pretty(tactical_data)?;
2067        std::fs::write(path, json)?;
2068        println!("Saved {} tactical puzzles", tactical_data.len());
2069        Ok(())
2070    }
2071
2072    /// Load tactical puzzles from file
2073    pub fn load_tactical_puzzles<P: AsRef<std::path::Path>>(
2074        path: P,
2075    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
2076        let content = std::fs::read_to_string(path)?;
2077        let tactical_data: Vec<TacticalTrainingData> = serde_json::from_str(&content)?;
2078        println!("Loaded {} tactical puzzles from file", tactical_data.len());
2079        Ok(tactical_data)
2080    }
2081
2082    /// Save tactical puzzles incrementally (appends to existing file)
2083    pub fn save_tactical_puzzles_incremental<P: AsRef<std::path::Path>>(
2084        tactical_data: &[TacticalTrainingData],
2085        path: P,
2086    ) -> Result<(), Box<dyn std::error::Error>> {
2087        let path = path.as_ref();
2088
2089        if path.exists() {
2090            // Load existing puzzles
2091            let mut existing = Self::load_tactical_puzzles(path)?;
2092            let original_len = existing.len();
2093
2094            // Add new puzzles, checking for duplicates by puzzle ID if available
2095            for new_puzzle in tactical_data {
2096                // Check if this puzzle already exists (by position)
2097                let exists = existing.iter().any(|existing_puzzle| {
2098                    existing_puzzle.position == new_puzzle.position
2099                        && existing_puzzle.solution_move == new_puzzle.solution_move
2100                });
2101
2102                if !exists {
2103                    existing.push(new_puzzle.clone());
2104                }
2105            }
2106
2107            // Save merged data
2108            let json = serde_json::to_string_pretty(&existing)?;
2109            std::fs::write(path, json)?;
2110
2111            println!(
2112                "Incremental save: added {} new puzzles (total: {})",
2113                existing.len() - original_len,
2114                existing.len()
2115            );
2116        } else {
2117            // File doesn't exist, just save normally
2118            Self::save_tactical_puzzles(tactical_data, path)?;
2119        }
2120        Ok(())
2121    }
2122
2123    /// Parse Lichess puzzles incrementally (preserves existing engine state)
2124    pub fn parse_and_load_incremental<P: AsRef<std::path::Path>>(
2125        file_path: P,
2126        engine: &mut ChessVectorEngine,
2127        max_puzzles: Option<usize>,
2128        min_rating: Option<u32>,
2129        max_rating: Option<u32>,
2130    ) -> Result<(), Box<dyn std::error::Error>> {
2131        println!("Parsing Lichess puzzles incrementally...");
2132
2133        // Parse puzzles
2134        let tactical_data = Self::parse_csv(file_path, max_puzzles, min_rating, max_rating)?;
2135
2136        // Load into engine incrementally
2137        Self::load_into_engine_incremental(&tactical_data, engine);
2138
2139        Ok(())
2140    }
2141}
2142
// Unit tests for the training dataset, the engine evaluator and the tactical
// puzzle pipeline (conversion, engine loading, JSON persistence).
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    /// A freshly created dataset holds no entries.
    #[test]
    fn test_training_dataset_creation() {
        let dataset = TrainingDataset::new();
        assert_eq!(dataset.data.len(), 0);
    }

    /// Pushing a `TrainingData` value directly onto `data` stores it unchanged.
    #[test]
    fn test_add_training_data() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.5,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
    }

    /// Training an engine from a one-entry dataset yields one stored position
    /// whose evaluation matches the dataset value.
    #[test]
    fn test_chess_engine_integration() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.3,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        let mut engine = ChessVectorEngine::new(1024);
        dataset.train_engine(&mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);

        // The stored evaluation must come back within float tolerance.
        let eval = engine.evaluate_position(&board);
        assert!(eval.is_some());
        assert!((eval.unwrap() - 0.3).abs() < 1e-6);
    }

    /// Five entries for the same board collapse to a single entry after
    /// `deduplicate(0.999)`.
    #[test]
    fn test_deduplication() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        // Add duplicate positions (same board, different evaluations and game ids)
        for i in 0..5 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        assert_eq!(dataset.data.len(), 5);

        // Deduplicate with high threshold (should keep only 1)
        dataset.deduplicate(0.999);
        assert_eq!(dataset.data.len(), 1);
    }

    /// A `TrainingData` list survives a JSON round trip with its evaluation,
    /// depth and game id intact.
    #[test]
    fn test_dataset_serialization() {
        let mut dataset = TrainingDataset::new();
        let board =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        let training_data = TrainingData {
            board,
            evaluation: 0.2,
            depth: 10,
            game_id: 42,
        };

        dataset.data.push(training_data);

        // Test serialization/deserialization
        let json = serde_json::to_string(&dataset.data).unwrap();
        let loaded_data: Vec<TrainingData> = serde_json::from_str(&json).unwrap();
        let loaded_dataset = TrainingDataset { data: loaded_data };

        assert_eq!(loaded_dataset.data.len(), 1);
        assert_eq!(loaded_dataset.data[0].evaluation, 0.2);
        assert_eq!(loaded_dataset.data[0].depth, 10);
        assert_eq!(loaded_dataset.data[0].game_id, 42);
    }

    /// A well-formed puzzle whose moves are written in SAN converts
    /// successfully, and the first listed theme becomes `move_theme`.
    #[test]
    fn test_tactical_puzzle_processing() {
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4".to_string(),
            moves: "Bxf7+ Ke7".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "fork pin".to_string(),
            game_url: None,
            opening_tags: None,
        };

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_some());

        let data = tactical_data.unwrap();
        assert_eq!(data.move_theme, "fork");
        assert!(data.tactical_value > 1.0); // Should have high tactical value
        assert!(data.difficulty > 0.0);
    }

    /// A puzzle whose FEN cannot be parsed is rejected.
    #[test]
    fn test_tactical_puzzle_invalid_fen() {
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "invalid_fen".to_string(),
            moves: "e2e4".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "tactics".to_string(),
            game_url: None,
            opening_tags: None,
        };

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_none());
    }

    /// `evaluate_accuracy` succeeds on a one-position dataset and returns a
    /// value below 1.0.
    #[test]
    fn test_engine_evaluator() {
        let evaluator = EngineEvaluator::new(15);

        // Create test dataset
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.0,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        // Create engine with some data (stored value 0.1 vs. dataset value 0.0)
        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&board, 0.1);

        // Test accuracy evaluation
        let accuracy = evaluator.evaluate_accuracy(&mut engine, &dataset);
        assert!(accuracy.is_ok());
        // NOTE(review): only "below 1.0" is checked; whether the returned number
        // is an error measure or an accuracy score is not visible from here.
        assert!(accuracy.unwrap() < 1.0);
    }

    /// Loading one tactical sample adds one position and one move entry, and
    /// the engine then returns at least one recommendation for that position.
    #[test]
    fn test_tactical_training_integration() {
        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.2,
            tactical_value: 2.5,
        }];

        let mut engine = ChessVectorEngine::new(1024);
        TacticalPuzzleParser::load_into_engine(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);
        assert_eq!(engine.position_moves.len(), 1);

        // The engine must offer at least one recommendation for the loaded position
        let recommendations = engine.recommend_moves(&Board::default(), 5);
        assert!(!recommendations.is_empty());
    }

    /// Smoke test: `deduplicate_parallel` runs to completion and never grows
    /// the dataset.
    #[test]
    fn test_multithreading_operations() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        // Add test data
        for i in 0..10 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        // Test parallel deduplication doesn't crash
        dataset.deduplicate_parallel(0.95, 5);
        assert!(dataset.data.len() <= 10);
    }

    /// `merge` appends another dataset's entries and `next_game_id` returns
    /// one more than the largest game id present.
    #[test]
    fn test_incremental_dataset_operations() {
        let mut dataset1 = TrainingDataset::new();
        let board1 = Board::default();
        let board2 =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        // Add initial data
        dataset1.add_position(board1, 0.0, 15, 1);
        dataset1.add_position(board2, 0.2, 15, 2);
        assert_eq!(dataset1.data.len(), 2);

        // Create second dataset
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );

        // Merge datasets
        dataset1.merge(dataset2);
        assert_eq!(dataset1.data.len(), 3);

        // Test next_game_id
        let next_id = dataset1.next_game_id();
        assert_eq!(next_id, 4); // Should be max(1,2,3) + 1
    }

    /// `save_incremental` merges into an existing file, and `load_and_append`
    /// adds the file's entries to an in-memory dataset.
    #[test]
    fn test_save_load_incremental() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_test.json");

        // Create and save first dataset
        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        // Create second dataset and save incrementally
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap(),
            0.2,
            15,
            2,
        );
        dataset2.save_incremental(&file_path).unwrap();

        // Load and verify merged data
        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 2);

        // Test load_and_append
        let mut dataset3 = TrainingDataset::new();
        dataset3.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );
        dataset3.load_and_append(&file_path).unwrap();
        assert_eq!(dataset3.data.len(), 3); // 1 original + 2 from file
    }

    /// `add_position` stores the evaluation, depth and game id it was given.
    #[test]
    fn test_add_position_method() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        // Test add_position method
        dataset.add_position(board, 0.5, 20, 42);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
        assert_eq!(dataset.data[0].depth, 20);
        assert_eq!(dataset.data[0].game_id, 42);
    }

    /// Saving the same position incrementally a second time leaves a single
    /// entry in the file.
    #[test]
    fn test_incremental_save_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_test.json");

        // Create and save first dataset
        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        // Create second dataset with duplicate position
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(Board::default(), 0.1, 15, 2); // Same position, different eval
        dataset2.save_incremental(&file_path).unwrap();

        // Should deduplicate and keep only one
        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 1);
    }

    /// Incremental loading skips a puzzle whose position the engine already
    /// holds and adds the one it does not.
    #[test]
    fn test_tactical_puzzle_incremental_loading() {
        let tactical_data = vec![
            TacticalTrainingData {
                position: Board::default(),
                solution_move: ChessMove::from_str("e2e4").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.2,
                tactical_value: 2.5,
            },
            TacticalTrainingData {
                position: Board::from_str(
                    "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1",
                )
                .unwrap(),
                solution_move: ChessMove::from_str("e7e5").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.0,
                tactical_value: 2.0,
            },
        ];

        let mut engine = ChessVectorEngine::new(1024);

        // Add some existing data (same position as the first puzzle)
        engine.add_position(&Board::default(), 0.1);
        assert_eq!(engine.knowledge_base_size(), 1);

        // Load tactical puzzles incrementally (two entries, so the sequential path runs)
        TacticalPuzzleParser::load_into_engine_incremental(&tactical_data, &mut engine);

        // Should have added the new position but skipped the duplicate
        assert_eq!(engine.knowledge_base_size(), 2);

        // Move data must now exist. The skipped duplicate's move is never passed
        // to the engine, so only the newly added puzzle contributes it here.
        assert!(engine.training_stats().has_move_data);
        assert!(engine.training_stats().move_data_entries > 0);
    }

    /// Tactical samples survive a save/load round trip through a JSON file.
    #[test]
    fn test_tactical_puzzle_serialization() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("tactical_test.json");

        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "fork".to_string(),
            difficulty: 1.5,
            tactical_value: 3.0,
        }];

        // Save tactical puzzles
        TacticalPuzzleParser::save_tactical_puzzles(&tactical_data, &file_path).unwrap();

        // Load them back
        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].move_theme, "fork");
        assert_eq!(loaded[0].difficulty, 1.5);
        assert_eq!(loaded[0].tactical_value, 3.0);
    }

    /// An incremental save of a different puzzle appends it to the existing file.
    #[test]
    fn test_tactical_puzzle_incremental_save() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_tactical.json");

        // Save first batch
        let batch1 = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        }];
        TacticalPuzzleParser::save_tactical_puzzles(&batch1, &file_path).unwrap();

        // Save second batch incrementally
        let batch2 = vec![TacticalTrainingData {
            position: Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1")
                .unwrap(),
            solution_move: ChessMove::from_str("e7e5").unwrap(),
            move_theme: "counter".to_string(),
            difficulty: 1.2,
            tactical_value: 2.2,
        }];
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&batch2, &file_path).unwrap();

        // Load and verify merged data
        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 2);
    }

    /// An incremental save of a puzzle already in the file (same position and
    /// solution move) does not add a second copy.
    #[test]
    fn test_tactical_puzzle_incremental_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_tactical.json");

        let tactical_data = TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        };

        // Save first time
        TacticalPuzzleParser::save_tactical_puzzles(&[tactical_data.clone()], &file_path).unwrap();

        // Try to save the same puzzle again
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&[tactical_data], &file_path)
            .unwrap();

        // Should still have only one puzzle (deduplicated)
        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
    }
}
2586}