1use chess::{Board, ChessMove, Game, MoveGen};
2use indicatif::{ProgressBar, ProgressStyle};
3use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::{BufRead, BufReader, BufWriter, Write};
8use std::path::Path;
9use std::process::{Child, Command, Stdio};
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12
13use crate::ChessVectorEngine;
14
/// Configuration for self-play training game generation.
#[derive(Debug, Clone)]
pub struct SelfPlayConfig {
    /// Number of self-play games generated per training iteration.
    pub games_per_iteration: usize,
    /// Hard cap on plies per game before the game is abandoned.
    pub max_moves_per_game: usize,
    /// Probability (0.0..=1.0) of sampling an exploratory move instead of the engine's top choice.
    pub exploration_factor: f32,
    /// Minimum |evaluation| required for a position to be recorded (the first 10 plies are exempt).
    pub min_confidence: f32,
    /// Whether games start from a randomly chosen book opening line.
    pub use_opening_book: bool,
    /// Softmax temperature used when sampling exploratory moves.
    pub temperature: f32,
}
31
32impl Default for SelfPlayConfig {
33 fn default() -> Self {
34 Self {
35 games_per_iteration: 100,
36 max_moves_per_game: 200,
37 exploration_factor: 0.3,
38 min_confidence: 0.1,
39 use_opening_book: true,
40 temperature: 0.8,
41 }
42 }
43}
44
/// One labeled training position.
#[derive(Debug, Clone)]
pub struct TrainingData {
    /// The chess position itself.
    pub board: Board,
    /// Evaluation label in pawns (0.0 until filled in by an evaluator pass).
    pub evaluation: f32,
    /// Search depth that produced `evaluation` (0 = not yet evaluated).
    pub depth: u8,
    /// Id of the game this position came from; used for game-level train/test splits.
    pub game_id: usize,
}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct TacticalPuzzle {
57 #[serde(rename = "PuzzleId")]
58 pub puzzle_id: String,
59 #[serde(rename = "FEN")]
60 pub fen: String,
61 #[serde(rename = "Moves")]
62 pub moves: String, #[serde(rename = "Rating")]
64 pub rating: u32,
65 #[serde(rename = "RatingDeviation")]
66 pub rating_deviation: u32,
67 #[serde(rename = "Popularity")]
68 pub popularity: i32,
69 #[serde(rename = "NbPlays")]
70 pub nb_plays: u32,
71 #[serde(rename = "Themes")]
72 pub themes: String, #[serde(rename = "GameUrl")]
74 pub game_url: Option<String>,
75 #[serde(rename = "OpeningTags")]
76 pub opening_tags: Option<String>,
77}
78
79#[derive(Debug, Clone)]
81pub struct TacticalTrainingData {
82 pub position: Board,
83 pub solution_move: ChessMove,
84 pub move_theme: String,
85 pub difficulty: f32, pub tactical_value: f32, }
88
89impl serde::Serialize for TacticalTrainingData {
91 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92 where
93 S: serde::Serializer,
94 {
95 use serde::ser::SerializeStruct;
96 let mut state = serializer.serialize_struct("TacticalTrainingData", 5)?;
97 state.serialize_field("fen", &self.position.to_string())?;
98 state.serialize_field("solution_move", &self.solution_move.to_string())?;
99 state.serialize_field("move_theme", &self.move_theme)?;
100 state.serialize_field("difficulty", &self.difficulty)?;
101 state.serialize_field("tactical_value", &self.tactical_value)?;
102 state.end()
103 }
104}
105
impl<'de> serde::Deserialize<'de> for TacticalTrainingData {
    // Hand-written Deserialize: the stored form keeps "fen" and
    // "solution_move" as strings (see the matching Serialize impl), which must
    // be re-parsed into `Board` and `ChessMove` after all fields are read.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct TacticalTrainingDataVisitor;

        impl<'de> Visitor<'de> for TacticalTrainingDataVisitor {
            type Value = TacticalTrainingData;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct TacticalTrainingData")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TacticalTrainingData, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut fen = None;
                let mut solution_move = None;
                let mut move_theme = None;
                let mut difficulty = None;
                let mut tactical_value = None;

                // Collect known fields, rejecting duplicates the way derived
                // impls do.
                while let Some(key) = map.next_key()? {
                    match key {
                        "fen" => {
                            if fen.is_some() {
                                return Err(de::Error::duplicate_field("fen"));
                            }
                            fen = Some(map.next_value()?);
                        }
                        "solution_move" => {
                            if solution_move.is_some() {
                                return Err(de::Error::duplicate_field("solution_move"));
                            }
                            solution_move = Some(map.next_value()?);
                        }
                        "move_theme" => {
                            if move_theme.is_some() {
                                return Err(de::Error::duplicate_field("move_theme"));
                            }
                            move_theme = Some(map.next_value()?);
                        }
                        "difficulty" => {
                            if difficulty.is_some() {
                                return Err(de::Error::duplicate_field("difficulty"));
                            }
                            difficulty = Some(map.next_value()?);
                        }
                        "tactical_value" => {
                            if tactical_value.is_some() {
                                return Err(de::Error::duplicate_field("tactical_value"));
                            }
                            tactical_value = Some(map.next_value()?);
                        }
                        _ => {
                            // Drain and ignore unknown fields.
                            // NOTE(review): decoding into serde_json::Value
                            // assumes a self-describing format (JSON); confirm
                            // if this is ever used with non-JSON formats.
                            let _: serde_json::Value = map.next_value()?;
                        }
                    }
                }

                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
                let solution_move_str: String =
                    solution_move.ok_or_else(|| de::Error::missing_field("solution_move"))?;
                let move_theme =
                    move_theme.ok_or_else(|| de::Error::missing_field("move_theme"))?;
                let difficulty =
                    difficulty.ok_or_else(|| de::Error::missing_field("difficulty"))?;
                let tactical_value =
                    tactical_value.ok_or_else(|| de::Error::missing_field("tactical_value"))?;

                // Re-parse the string forms back into chess types.
                let position =
                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;

                let solution_move = ChessMove::from_str(&solution_move_str)
                    .map_err(|e| de::Error::custom(format!("Error: {e}")))?;

                Ok(TacticalTrainingData {
                    position,
                    solution_move,
                    move_theme,
                    difficulty,
                    tactical_value,
                })
            }
        }

        const FIELDS: &[&str] = &[
            "fen",
            "solution_move",
            "move_theme",
            "difficulty",
            "tactical_value",
        ];
        deserializer.deserialize_struct("TacticalTrainingData", FIELDS, TacticalTrainingDataVisitor)
    }
}
207
/// `pgn_reader` visitor state that replays games and collects positions.
pub struct GameExtractor {
    /// Positions collected so far (evaluation 0.0 until a Stockfish pass fills it).
    pub positions: Vec<TrainingData>,
    /// Game currently being replayed.
    pub current_game: Game,
    /// Plies applied in the current game.
    pub move_count: usize,
    /// Cap on plies recorded per game.
    pub max_moves_per_game: usize,
    /// 1-based id of the current game (incremented in `begin_game`).
    pub game_id: usize,
}
216
217impl GameExtractor {
218 pub fn new(max_moves_per_game: usize) -> Self {
219 Self {
220 positions: Vec::new(),
221 current_game: Game::new(),
222 move_count: 0,
223 max_moves_per_game,
224 game_id: 0,
225 }
226 }
227}
228
229impl Visitor for GameExtractor {
230 type Result = ();
231
232 fn begin_game(&mut self) {
233 self.current_game = Game::new();
234 self.move_count = 0;
235 self.game_id += 1;
236 }
237
238 fn header(&mut self, _key: &[u8], _value: RawHeader<'_>) {}
239
240 fn san(&mut self, san_plus: SanPlus) {
241 if self.move_count >= self.max_moves_per_game {
242 return;
243 }
244
245 let san_str = san_plus.san.to_string();
246
247 let current_pos = self.current_game.current_position();
249
250 match chess::ChessMove::from_san(¤t_pos, &san_str) {
252 Ok(chess_move) => {
253 let legal_moves: Vec<chess::ChessMove> = MoveGen::new_legal(¤t_pos).collect();
255 if legal_moves.contains(&chess_move) {
256 if self.current_game.make_move(chess_move) {
257 self.move_count += 1;
258
259 self.positions.push(TrainingData {
261 board: self.current_game.current_position(),
262 evaluation: 0.0, depth: 0,
264 game_id: self.game_id,
265 });
266 }
267 } else {
268 }
270 }
271 Err(_) => {
272 if !san_str.contains("O-O") && !san_str.contains("=") && san_str.len() > 6 {
276 }
278 }
279 }
280 }
281
282 fn begin_variation(&mut self) -> Skip {
283 Skip(true) }
285
286 fn end_game(&mut self) -> Self::Result {}
287}
288
/// Evaluates positions by spawning a fresh Stockfish process per position.
pub struct StockfishEvaluator {
    // Fixed search depth passed to "go depth".
    depth: u8,
}
293
294impl StockfishEvaluator {
295 pub fn new(depth: u8) -> Self {
296 Self { depth }
297 }
298
299 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
301 let mut child = Command::new("stockfish")
302 .stdin(Stdio::piped())
303 .stdout(Stdio::piped())
304 .stderr(Stdio::piped())
305 .spawn()?;
306
307 let stdin = child
308 .stdin
309 .as_mut()
310 .ok_or("Failed to get stdin handle for Stockfish process")?;
311 let fen = board.to_string();
312
313 use std::io::Write;
315 writeln!(stdin, "uci")?;
316 writeln!(stdin, "isready")?;
317 writeln!(stdin, "position fen {fen}")?;
318 writeln!(stdin, "go depth {}", self.depth)?;
319 writeln!(stdin, "quit")?;
320
321 let output = child.wait_with_output()?;
322 let stdout = String::from_utf8_lossy(&output.stdout);
323
324 for line in stdout.lines() {
326 if line.starts_with("info") && line.contains("score cp") {
327 if let Some(cp_pos) = line.find("score cp ") {
328 let cp_str = &line[cp_pos + 9..];
329 if let Some(end) = cp_str.find(' ') {
330 let cp_value = cp_str[..end].parse::<i32>()?;
331 return Ok(cp_value as f32 / 100.0); }
333 }
334 } else if line.starts_with("info") && line.contains("score mate") {
335 if let Some(mate_pos) = line.find("score mate ") {
337 let mate_str = &line[mate_pos + 11..];
338 if let Some(end) = mate_str.find(' ') {
339 let mate_moves = mate_str[..end].parse::<i32>()?;
340 return Ok(if mate_moves > 0 { 100.0 } else { -100.0 });
341 }
342 }
343 }
344 }
345
346 Ok(0.0) }
348
349 pub fn evaluate_batch(
351 &self,
352 positions: &mut [TrainingData],
353 ) -> Result<(), Box<dyn std::error::Error>> {
354 let pb = ProgressBar::new(positions.len() as u64);
355 if let Ok(style) = ProgressStyle::default_bar().template(
356 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
357 ) {
358 pb.set_style(style.progress_chars("#>-"));
359 }
360
361 for data in positions.iter_mut() {
362 match self.evaluate_position(&data.board) {
363 Ok(eval) => {
364 data.evaluation = eval;
365 data.depth = self.depth;
366 }
367 Err(e) => {
368 eprintln!("Evaluation error: {e}");
369 data.evaluation = 0.0;
370 }
371 }
372 pb.inc(1);
373 }
374
375 pb.finish_with_message("Evaluation complete");
376 Ok(())
377 }
378
379 pub fn evaluate_batch_parallel(
381 &self,
382 positions: &mut [TrainingData],
383 num_threads: usize,
384 ) -> Result<(), Box<dyn std::error::Error>> {
385 let pb = ProgressBar::new(positions.len() as u64);
386 if let Ok(style) = ProgressStyle::default_bar()
387 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel evaluation") {
388 pb.set_style(style.progress_chars("#>-"));
389 }
390
391 let pool = rayon::ThreadPoolBuilder::new()
393 .num_threads(num_threads)
394 .build()?;
395
396 pool.install(|| {
397 positions.par_iter_mut().for_each(|data| {
399 match self.evaluate_position(&data.board) {
400 Ok(eval) => {
401 data.evaluation = eval;
402 data.depth = self.depth;
403 }
404 Err(_) => {
405 data.evaluation = 0.0;
407 }
408 }
409 pb.inc(1);
410 });
411 });
412
413 pb.finish_with_message("Parallel evaluation complete");
414 Ok(())
415 }
416}
417
/// A persistent, reusable Stockfish child process with buffered pipes.
struct StockfishProcess {
    /// Child process handle, reaped in `Drop`.
    child: Child,
    /// Buffered writer over the engine's stdin (commands are flushed per line).
    stdin: BufWriter<std::process::ChildStdin>,
    /// Buffered reader over the engine's stdout.
    stdout: BufReader<std::process::ChildStdout>,
    // Search depth for "go depth" commands.
    #[allow(dead_code)]
    depth: u8,
}
426
427impl StockfishProcess {
428 fn new(depth: u8) -> Result<Self, Box<dyn std::error::Error>> {
429 let mut child = Command::new("stockfish")
430 .stdin(Stdio::piped())
431 .stdout(Stdio::piped())
432 .stderr(Stdio::piped())
433 .spawn()?;
434
435 let stdin = BufWriter::new(
436 child
437 .stdin
438 .take()
439 .ok_or("Failed to get stdin handle for Stockfish process")?,
440 );
441 let stdout = BufReader::new(
442 child
443 .stdout
444 .take()
445 .ok_or("Failed to get stdout handle for Stockfish process")?,
446 );
447
448 let mut process = Self {
449 child,
450 stdin,
451 stdout,
452 depth,
453 };
454
455 process.send_command("uci")?;
457 process.wait_for_ready()?;
458 process.send_command("isready")?;
459 process.wait_for_ready()?;
460
461 Ok(process)
462 }
463
464 fn send_command(&mut self, command: &str) -> Result<(), Box<dyn std::error::Error>> {
465 writeln!(self.stdin, "{command}")?;
466 self.stdin.flush()?;
467 Ok(())
468 }
469
470 fn wait_for_ready(&mut self) -> Result<(), Box<dyn std::error::Error>> {
471 let mut line = String::new();
472 loop {
473 line.clear();
474 self.stdout.read_line(&mut line)?;
475 if line.trim() == "uciok" || line.trim() == "readyok" {
476 break;
477 }
478 }
479 Ok(())
480 }
481
482 fn evaluate_position(&mut self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
483 let fen = board.to_string();
484
485 self.send_command(&format!("position fen {fen}"))?;
487 self.send_command(&format!("position fen {fen}"))?;
488
489 let mut line = String::new();
491 let mut last_evaluation = 0.0;
492
493 loop {
494 line.clear();
495 self.stdout.read_line(&mut line)?;
496 let line = line.trim();
497
498 if line.starts_with("info") && line.contains("score cp") {
499 if let Some(cp_pos) = line.find("score cp ") {
500 let cp_str = &line[cp_pos + 9..];
501 if let Some(end) = cp_str.find(' ') {
502 if let Ok(cp_value) = cp_str[..end].parse::<i32>() {
503 last_evaluation = cp_value as f32 / 100.0;
504 }
505 }
506 }
507 } else if line.starts_with("info") && line.contains("score mate") {
508 if let Some(mate_pos) = line.find("score mate ") {
509 let mate_str = &line[mate_pos + 11..];
510 if let Some(end) = mate_str.find(' ') {
511 if let Ok(mate_moves) = mate_str[..end].parse::<i32>() {
512 last_evaluation = if mate_moves > 0 { 100.0 } else { -100.0 };
513 }
514 }
515 }
516 } else if line.starts_with("bestmove") {
517 break;
518 }
519 }
520
521 Ok(last_evaluation)
522 }
523}
524
525impl Drop for StockfishProcess {
526 fn drop(&mut self) {
527 let _ = self.send_command("quit");
528 let _ = self.child.wait();
529 }
530}
531
/// A checkout/checkin pool of persistent Stockfish processes for parallel use.
pub struct StockfishPool {
    /// Idle processes available for checkout.
    pool: Arc<Mutex<Vec<StockfishProcess>>>,
    /// Search depth for processes spawned on demand.
    depth: u8,
    /// Maximum number of idle processes retained in the pool.
    pool_size: usize,
}
538
539impl StockfishPool {
540 pub fn new(depth: u8, pool_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
541 let mut processes = Vec::with_capacity(pool_size);
542
543 println!(
544 "🚀 Initializing Stockfish pool with {pool_size} processes..."
545 );
546
547 for i in 0..pool_size {
548 match StockfishProcess::new(depth) {
549 Ok(process) => {
550 processes.push(process);
551 if i % 2 == 1 {
552 print!(".");
553 let _ = std::io::stdout().flush(); }
555 }
556 Err(e) => {
557 eprintln!("Evaluation error: {e}");
558 return Err(e);
559 }
560 }
561 }
562
563 println!(" ✅ Pool ready!");
564
565 Ok(Self {
566 pool: Arc::new(Mutex::new(processes)),
567 depth,
568 pool_size,
569 })
570 }
571
572 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
573 let mut process = {
575 let mut pool = self.pool.lock().unwrap();
576 if let Some(process) = pool.pop() {
577 process
578 } else {
579 StockfishProcess::new(self.depth)?
581 }
582 };
583
584 let result = process.evaluate_position(board);
586
587 {
589 let mut pool = self.pool.lock().unwrap();
590 if pool.len() < self.pool_size {
591 pool.push(process);
592 }
593 }
595
596 result
597 }
598
599 pub fn evaluate_batch_parallel(
600 &self,
601 positions: &mut [TrainingData],
602 ) -> Result<(), Box<dyn std::error::Error>> {
603 let pb = ProgressBar::new(positions.len() as u64);
604 pb.set_style(ProgressStyle::default_bar()
605 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Pool evaluation")
606 .unwrap()
607 .progress_chars("#>-"));
608
609 positions.par_iter_mut().for_each(|data| {
611 match self.evaluate_position(&data.board) {
612 Ok(eval) => {
613 data.evaluation = eval;
614 data.depth = self.depth;
615 }
616 Err(_) => {
617 data.evaluation = 0.0;
618 }
619 }
620 pb.inc(1);
621 });
622
623 pb.finish_with_message("Pool evaluation complete");
624 Ok(())
625 }
626}
627
/// An in-memory collection of labeled training positions with load/save,
/// merge and deduplication helpers.
pub struct TrainingDataset {
    /// The stored positions, in insertion order.
    pub data: Vec<TrainingData>,
}
632
633impl Default for TrainingDataset {
634 fn default() -> Self {
635 Self::new()
636 }
637}
638
639impl TrainingDataset {
640 pub fn new() -> Self {
641 Self { data: Vec::new() }
642 }
643
    /// Extracts positions from a PGN file into this dataset.
    ///
    /// The file is split into games heuristically: lines accumulate until one
    /// ends with a result marker ("1-0", "0-1", "1/2-1/2" or "*"), and the
    /// accumulated chunk is then fed through `pgn_reader` into a
    /// [`GameExtractor`]. Positions are stored with evaluation 0.0 / depth 0,
    /// to be filled in by a later Stockfish pass.
    ///
    /// NOTE(review): a result-like token in mid-line or inside a comment
    /// would confuse this splitter — TODO confirm against the PGN corpora
    /// actually used.
    pub fn load_from_pgn<P: AsRef<Path>>(
        &mut self,
        path: P,
        max_games: Option<usize>,
        max_moves_per_game: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);

        let mut extractor = GameExtractor::new(max_moves_per_game);
        let mut games_processed = 0;

        let mut pgn_content = String::new();
        for line in reader.lines() {
            let line = line?;
            pgn_content.push_str(&line);
            pgn_content.push('\n');

            // A PGN game text ends with one of the four result markers.
            if line.trim().ends_with("1-0")
                || line.trim().ends_with("0-1")
                || line.trim().ends_with("1/2-1/2")
                || line.trim().ends_with("*")
            {
                // Parse the buffered game; parse failures are reported and the
                // game is skipped.
                let cursor = std::io::Cursor::new(&pgn_content);
                let mut reader = BufferedReader::new(cursor);
                if let Err(e) = reader.read_all(&mut extractor) {
                    eprintln!("Evaluation error: {e}");
                }

                games_processed += 1;
                pgn_content.clear();

                if let Some(max) = max_games {
                    if games_processed >= max {
                        break;
                    }
                }

                // Periodic progress report.
                if games_processed % 100 == 0 {
                    println!(
                        "Processed {} games, extracted {} positions",
                        games_processed,
                        extractor.positions.len()
                    );
                }
            }
        }

        self.data.extend(extractor.positions);
        println!(
            "Loaded {} positions from {} games",
            self.data.len(),
            games_processed
        );
        Ok(())
    }
704
705 pub fn evaluate_with_stockfish(&mut self, depth: u8) -> Result<(), Box<dyn std::error::Error>> {
707 let evaluator = StockfishEvaluator::new(depth);
708 evaluator.evaluate_batch(&mut self.data)
709 }
710
711 pub fn evaluate_with_stockfish_parallel(
713 &mut self,
714 depth: u8,
715 num_threads: usize,
716 ) -> Result<(), Box<dyn std::error::Error>> {
717 let evaluator = StockfishEvaluator::new(depth);
718 evaluator.evaluate_batch_parallel(&mut self.data, num_threads)
719 }
720
721 pub fn train_engine(&self, engine: &mut ChessVectorEngine) {
723 let pb = ProgressBar::new(self.data.len() as u64);
724 pb.set_style(ProgressStyle::default_bar()
725 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Training positions")
726 .unwrap()
727 .progress_chars("#>-"));
728
729 for data in &self.data {
730 engine.add_position(&data.board, data.evaluation);
731 pb.inc(1);
732 }
733
734 pb.finish_with_message("Training complete");
735 println!("Trained engine with {} positions", self.data.len());
736 }
737
738 pub fn split(&self, train_ratio: f32) -> (TrainingDataset, TrainingDataset) {
740 use rand::seq::SliceRandom;
741 use rand::thread_rng;
742 use std::collections::{HashMap, HashSet};
743
744 let mut games: HashMap<usize, Vec<&TrainingData>> = HashMap::new();
746 for data in &self.data {
747 games.entry(data.game_id).or_default().push(data);
748 }
749
750 let mut game_ids: Vec<usize> = games.keys().cloned().collect();
752 game_ids.shuffle(&mut thread_rng());
753
754 let split_point = (game_ids.len() as f32 * train_ratio) as usize;
756 let train_game_ids: HashSet<usize> = game_ids[..split_point].iter().cloned().collect();
757
758 let mut train_data = Vec::new();
760 let mut test_data = Vec::new();
761
762 for data in &self.data {
763 if train_game_ids.contains(&data.game_id) {
764 train_data.push(data.clone());
765 } else {
766 test_data.push(data.clone());
767 }
768 }
769
770 (
771 TrainingDataset { data: train_data },
772 TrainingDataset { data: test_data },
773 )
774 }
775
776 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
778 let json = serde_json::to_string_pretty(&self.data)?;
779 std::fs::write(path, json)?;
780 Ok(())
781 }
782
783 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
785 let content = std::fs::read_to_string(path)?;
786 let data = serde_json::from_str(&content)?;
787 Ok(Self { data })
788 }
789
790 pub fn load_and_append<P: AsRef<Path>>(
792 &mut self,
793 path: P,
794 ) -> Result<(), Box<dyn std::error::Error>> {
795 let existing_len = self.data.len();
796 let additional_data = Self::load(path)?;
797 self.data.extend(additional_data.data);
798 println!(
799 "Loaded {} additional positions (total: {})",
800 self.data.len() - existing_len,
801 self.data.len()
802 );
803 Ok(())
804 }
805
806 pub fn merge(&mut self, other: TrainingDataset) {
808 let existing_len = self.data.len();
809 self.data.extend(other.data);
810 println!(
811 "Merged {} positions (total: {})",
812 self.data.len() - existing_len,
813 self.data.len()
814 );
815 }
816
    /// Incremental save with deduplication enabled; see
    /// [`TrainingDataset::save_incremental_with_options`] for the strategy.
    pub fn save_incremental<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.save_incremental_with_options(path, true)
    }
824
    /// Saves to `path`, choosing the cheapest strategy available:
    /// 1. fast in-place JSON append when the file already exists and is
    ///    appendable ([`TrainingDataset::save_append_only`]),
    /// 2. otherwise a load-merge-rewrite, with or without deduplication,
    /// 3. or a plain full save when the file does not exist yet.
    pub fn save_incremental_with_options<P: AsRef<Path>>(
        &self,
        path: P,
        deduplicate: bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path = path.as_ref();

        if path.exists() {
            // Try the O(new data) append first; any failure falls back below.
            if self.save_append_only(path).is_ok() {
                return Ok(());
            }

            if deduplicate {
                self.save_incremental_full_merge(path)
            } else {
                self.save_incremental_no_dedup(path)
            }
        } else {
            // First save: just write the whole dataset.
            self.save(path)
        }
    }
850
851 fn save_incremental_no_dedup<P: AsRef<Path>>(
853 &self,
854 path: P,
855 ) -> Result<(), Box<dyn std::error::Error>> {
856 let path = path.as_ref();
857
858 println!("📂 Loading existing training data...");
859 let mut existing = Self::load(path)?;
860
861 println!("⚡ Fast merge without deduplication...");
862 existing.data.extend(self.data.iter().cloned());
863
864 println!(
865 "💾 Serializing {} positions to JSON...",
866 existing.data.len()
867 );
868 let json = serde_json::to_string_pretty(&existing.data)?;
869
870 println!("✍️ Writing to disk...");
871 std::fs::write(path, json)?;
872
873 println!(
874 "✅ Fast merge save: total {} positions",
875 existing.data.len()
876 );
877 Ok(())
878 }
879
    /// Fast-path incremental save: appends this dataset's entries into an
    /// existing JSON array file without re-reading or re-serializing the
    /// whole document.
    ///
    /// Works by overwriting the file's closing bracket region with a comma,
    /// streaming the new entries as compact JSON, and re-closing the array.
    /// Errors (so callers can fall back to a full merge) when the file is too
    /// small to probe or does not end with `]`.
    pub fn save_append_only<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use std::fs::OpenOptions;
        use std::io::{BufRead, BufReader, Seek, SeekFrom, Write};

        // Nothing to append.
        if self.data.is_empty() {
            return Ok(());
        }

        let path = path.as_ref();
        let mut file = OpenOptions::new().read(true).write(true).open(path)?;

        // Probe the tail of the file to confirm it looks like a JSON array.
        // NOTE(review): this seek errors for files under 10 bytes (e.g. an
        // empty "[]" array), which routes callers to the full-merge fallback.
        file.seek(SeekFrom::End(-10))?;
        let mut buffer = String::new();
        BufReader::new(&file).read_line(&mut buffer)?;

        if !buffer.trim().ends_with(']') {
            return Err("File doesn't end with JSON array bracket".into());
        }

        // Start overwriting 2 bytes before EOF: this consumes the closing
        // bracket whether the file ends with "]" or "\n]". Assumes no other
        // trailing bytes after "]" — TODO confirm for files not produced by
        // `save`.
        file.seek(SeekFrom::End(-2))?; write!(file, ",")?;

        // Stream each new entry as compact JSON, comma-separated.
        for (i, data) in self.data.iter().enumerate() {
            if i > 0 {
                write!(file, ",")?;
            }
            let json = serde_json::to_string(data)?;
            write!(file, "{json}")?;
        }

        // Re-close the array.
        write!(file, "\n]")?;

        println!("Fast append: added {} new positions", self.data.len());
        Ok(())
    }
925
926 fn save_incremental_full_merge<P: AsRef<Path>>(
928 &self,
929 path: P,
930 ) -> Result<(), Box<dyn std::error::Error>> {
931 let path = path.as_ref();
932
933 println!("📂 Loading existing training data...");
934 let mut existing = Self::load(path)?;
935 let _original_len = existing.data.len();
936
937 println!("🔄 Streaming merge with deduplication (avoiding O(n²) operation)...");
938 existing.merge_and_deduplicate(self.data.clone());
939
940 println!(
941 "💾 Serializing {} positions to JSON...",
942 existing.data.len()
943 );
944 let json = serde_json::to_string_pretty(&existing.data)?;
945
946 println!("✍️ Writing to disk...");
947 std::fs::write(path, json)?;
948
949 println!(
950 "✅ Streaming merge save: total {} positions",
951 existing.data.len()
952 );
953 Ok(())
954 }
955
956 pub fn add_position(&mut self, board: Board, evaluation: f32, depth: u8, game_id: usize) {
958 self.data.push(TrainingData {
959 board,
960 evaluation,
961 depth,
962 game_id,
963 });
964 }
965
966 pub fn next_game_id(&self) -> usize {
968 self.data.iter().map(|data| data.game_id).max().unwrap_or(0) + 1
969 }
970
971 pub fn deduplicate(&mut self, similarity_threshold: f32) {
973 if similarity_threshold > 0.999 {
974 self.deduplicate_fast();
976 } else {
977 self.deduplicate_similarity_based(similarity_threshold);
979 }
980 }
981
982 pub fn deduplicate_fast(&mut self) {
984 use std::collections::HashSet;
985
986 if self.data.is_empty() {
987 return;
988 }
989
990 let mut seen_positions = HashSet::with_capacity(self.data.len());
991 let original_len = self.data.len();
992
993 self.data.retain(|data| {
995 let fen = data.board.to_string();
996 seen_positions.insert(fen)
997 });
998
999 println!(
1000 "Fast deduplicated: {} -> {} positions (removed {} exact duplicates)",
1001 original_len,
1002 self.data.len(),
1003 original_len - self.data.len()
1004 );
1005 }
1006
1007 pub fn merge_and_deduplicate(&mut self, new_data: Vec<TrainingData>) {
1009 use std::collections::HashSet;
1010
1011 if new_data.is_empty() {
1012 return;
1013 }
1014
1015 let _original_len = self.data.len();
1016
1017 let mut existing_positions: HashSet<String> = HashSet::with_capacity(self.data.len());
1019 for data in &self.data {
1020 existing_positions.insert(data.board.to_string());
1021 }
1022
1023 let mut added = 0;
1025 for data in new_data {
1026 let fen = data.board.to_string();
1027 if existing_positions.insert(fen) {
1028 self.data.push(data);
1029 added += 1;
1030 }
1031 }
1032
1033 println!(
1034 "Streaming merge: added {} unique positions (total: {})",
1035 added,
1036 self.data.len()
1037 );
1038 }
1039
1040 fn deduplicate_similarity_based(&mut self, similarity_threshold: f32) {
1042 use crate::PositionEncoder;
1043 use ndarray::Array1;
1044
1045 if self.data.is_empty() {
1046 return;
1047 }
1048
1049 let encoder = PositionEncoder::new(1024);
1050 let mut keep_indices: Vec<bool> = vec![true; self.data.len()];
1051
1052 let vectors: Vec<Array1<f32>> = if self.data.len() > 50 {
1054 self.data
1055 .par_iter()
1056 .map(|data| encoder.encode(&data.board))
1057 .collect()
1058 } else {
1059 self.data
1060 .iter()
1061 .map(|data| encoder.encode(&data.board))
1062 .collect()
1063 };
1064
1065 for i in 1..self.data.len() {
1067 if !keep_indices[i] {
1068 continue;
1069 }
1070
1071 for j in 0..i {
1072 if !keep_indices[j] {
1073 continue;
1074 }
1075
1076 let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1077 if similarity > similarity_threshold {
1078 keep_indices[i] = false;
1079 break;
1080 }
1081 }
1082 }
1083
1084 let original_len = self.data.len();
1086 self.data = self
1087 .data
1088 .iter()
1089 .enumerate()
1090 .filter_map(|(i, data)| {
1091 if keep_indices[i] {
1092 Some(data.clone())
1093 } else {
1094 None
1095 }
1096 })
1097 .collect();
1098
1099 println!(
1100 "Similarity deduplicated: {} -> {} positions (removed {} near-duplicates)",
1101 original_len,
1102 self.data.len(),
1103 original_len - self.data.len()
1104 );
1105 }
1106
    /// Parallel near-duplicate removal over `chunk_size`-sized index chunks.
    ///
    /// NOTE(review): chunks race on `keep_indices` (two near-duplicates in
    /// different chunks may both survive, and results can differ run to run),
    /// and the mutex is locked once per pair check, which is expensive. The
    /// output is an approximate dedup, not byte-identical to the sequential
    /// version — confirm callers accept that.
    pub fn deduplicate_parallel(&mut self, similarity_threshold: f32, chunk_size: usize) {
        use crate::PositionEncoder;
        use ndarray::Array1;
        use std::sync::{Arc, Mutex};

        if self.data.is_empty() {
            return;
        }

        let encoder = PositionEncoder::new(1024);

        // Encode all positions in parallel up front.
        let vectors: Vec<Array1<f32>> = self
            .data
            .par_iter()
            .map(|data| encoder.encode(&data.board))
            .collect();

        // Shared keep/drop flags, one per position.
        let keep_indices = Arc::new(Mutex::new(vec![true; self.data.len()]));

        (1..self.data.len())
            .collect::<Vec<_>>()
            .par_chunks(chunk_size)
            .for_each(|chunk| {
                for &i in chunk {
                    // Skip positions already marked as duplicates.
                    {
                        let indices = keep_indices.lock().unwrap();
                        if !indices[i] {
                            continue;
                        }
                    }

                    // Compare against every earlier still-kept position.
                    for j in 0..i {
                        {
                            let indices = keep_indices.lock().unwrap();
                            if !indices[j] {
                                continue;
                            }
                        }

                        let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
                        if similarity > similarity_threshold {
                            let mut indices = keep_indices.lock().unwrap();
                            indices[i] = false;
                            break;
                        }
                    }
                }
            });

        // Rebuild the data vector with only the kept positions.
        let keep_indices = keep_indices.lock().unwrap();
        let original_len = self.data.len();
        self.data = self
            .data
            .iter()
            .enumerate()
            .filter_map(|(i, data)| {
                if keep_indices[i] {
                    Some(data.clone())
                } else {
                    None
                }
            })
            .collect();

        println!(
            "Parallel deduplicated: {} -> {} positions (removed {} duplicates)",
            original_len,
            self.data.len(),
            original_len - self.data.len()
        );
    }
1184
1185 fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
1187 let dot_product = a.dot(b);
1188 let norm_a = a.dot(a).sqrt();
1189 let norm_b = b.dot(b).sqrt();
1190
1191 if norm_a == 0.0 || norm_b == 0.0 {
1192 0.0
1193 } else {
1194 dot_product / (norm_a * norm_b)
1195 }
1196 }
1197}
1198
/// Generates training data by having the engine play games against itself.
pub struct SelfPlayTrainer {
    /// Game-generation settings.
    config: SelfPlayConfig,
    /// Monotonic id assigned to each game played; used as `game_id` labels.
    game_counter: usize,
}
1204
1205impl SelfPlayTrainer {
1206 pub fn new(config: SelfPlayConfig) -> Self {
1207 Self {
1208 config,
1209 game_counter: 0,
1210 }
1211 }
1212
1213 pub fn generate_training_data(&mut self, engine: &mut ChessVectorEngine) -> TrainingDataset {
1215 let mut dataset = TrainingDataset::new();
1216
1217 println!(
1218 "🎮 Starting self-play training with {} games...",
1219 self.config.games_per_iteration
1220 );
1221 let pb = ProgressBar::new(self.config.games_per_iteration as u64);
1222 if let Ok(style) = ProgressStyle::default_bar().template(
1223 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
1224 ) {
1225 pb.set_style(style.progress_chars("#>-"));
1226 }
1227
1228 for _ in 0..self.config.games_per_iteration {
1229 let game_data = self.play_single_game(engine);
1230 dataset.data.extend(game_data);
1231 self.game_counter += 1;
1232 pb.inc(1);
1233 }
1234
1235 pb.finish_with_message("Self-play games completed");
1236 println!(
1237 "✅ Generated {} positions from {} games",
1238 dataset.data.len(),
1239 self.config.games_per_iteration
1240 );
1241
1242 dataset
1243 }
1244
1245 fn play_single_game(&self, engine: &mut ChessVectorEngine) -> Vec<TrainingData> {
1247 let mut game = Game::new();
1248 let mut positions = Vec::new();
1249 let mut move_count = 0;
1250
1251 if self.config.use_opening_book {
1253 if let Some(opening_moves) = self.get_random_opening() {
1254 for mv in opening_moves {
1255 if game.make_move(mv) {
1256 move_count += 1;
1257 } else {
1258 break;
1259 }
1260 }
1261 }
1262 }
1263
1264 while game.result().is_none() && move_count < self.config.max_moves_per_game {
1266 let current_position = game.current_position();
1267
1268 let move_choice = self.select_move_with_exploration(engine, ¤t_position);
1270
1271 if let Some(chess_move) = move_choice {
1272 if let Some(evaluation) = engine.evaluate_position(¤t_position) {
1274 if evaluation.abs() >= self.config.min_confidence || move_count < 10 {
1276 positions.push(TrainingData {
1277 board: current_position,
1278 evaluation,
1279 depth: 1, game_id: self.game_counter,
1281 });
1282 }
1283 }
1284
1285 if !game.make_move(chess_move) {
1287 break; }
1289 move_count += 1;
1290 } else {
1291 break; }
1293 }
1294
1295 if let Some(result) = game.result() {
1297 let final_position = game.current_position();
1298 let final_eval = match result {
1299 chess::GameResult::WhiteCheckmates => {
1300 if final_position.side_to_move() == chess::Color::Black {
1301 10.0
1302 } else {
1303 -10.0
1304 }
1305 }
1306 chess::GameResult::BlackCheckmates => {
1307 if final_position.side_to_move() == chess::Color::White {
1308 10.0
1309 } else {
1310 -10.0
1311 }
1312 }
1313 chess::GameResult::WhiteResigns => -10.0,
1314 chess::GameResult::BlackResigns => 10.0,
1315 chess::GameResult::Stalemate
1316 | chess::GameResult::DrawAccepted
1317 | chess::GameResult::DrawDeclared => 0.0,
1318 };
1319
1320 positions.push(TrainingData {
1321 board: final_position,
1322 evaluation: final_eval,
1323 depth: 1,
1324 game_id: self.game_counter,
1325 });
1326 }
1327
1328 positions
1329 }
1330
1331 fn select_move_with_exploration(
1333 &self,
1334 engine: &mut ChessVectorEngine,
1335 position: &Board,
1336 ) -> Option<ChessMove> {
1337 let recommendations = engine.recommend_legal_moves(position, 5);
1338
1339 if recommendations.is_empty() {
1340 return None;
1341 }
1342
1343 if fastrand::f32() < self.config.exploration_factor {
1345 self.select_move_with_temperature(&recommendations)
1347 } else {
1348 Some(recommendations[0].chess_move)
1350 }
1351 }
1352
1353 fn select_move_with_temperature(
1355 &self,
1356 recommendations: &[crate::MoveRecommendation],
1357 ) -> Option<ChessMove> {
1358 if recommendations.is_empty() {
1359 return None;
1360 }
1361
1362 let mut probabilities = Vec::new();
1364 let mut sum = 0.0;
1365
1366 for rec in recommendations {
1367 let prob = (rec.average_outcome / self.config.temperature).exp();
1369 probabilities.push(prob);
1370 sum += prob;
1371 }
1372
1373 for prob in &mut probabilities {
1375 *prob /= sum;
1376 }
1377
1378 let rand_val = fastrand::f32();
1380 let mut cumulative = 0.0;
1381
1382 for (i, &prob) in probabilities.iter().enumerate() {
1383 cumulative += prob;
1384 if rand_val <= cumulative {
1385 return Some(recommendations[i].chess_move);
1386 }
1387 }
1388
1389 Some(recommendations[0].chess_move)
1391 }
1392
1393 fn get_random_opening(&self) -> Option<Vec<ChessMove>> {
1395 let openings = [
1396 vec!["e4", "e5", "Nf3", "Nc6", "Bc4"],
1398 vec!["e4", "e5", "Nf3", "Nc6", "Bb5"],
1400 vec!["d4", "d5", "c4"],
1402 vec!["d4", "Nf6", "c4", "g6"],
1404 vec!["e4", "c5"],
1406 vec!["e4", "e6"],
1408 vec!["e4", "c6"],
1410 ];
1411
1412 let selected_opening = &openings[fastrand::usize(0..openings.len())];
1413
1414 let mut moves = Vec::new();
1415 let mut game = Game::new();
1416
1417 for move_str in selected_opening {
1418 if let Ok(chess_move) = ChessMove::from_str(move_str) {
1419 if game.make_move(chess_move) {
1420 moves.push(chess_move);
1421 } else {
1422 break;
1423 }
1424 }
1425 }
1426
1427 if moves.is_empty() {
1428 None
1429 } else {
1430 Some(moves)
1431 }
1432 }
1433}
1434
/// Measures how closely a `ChessVectorEngine`'s evaluations track the
/// reference evaluations stored in a labelled test dataset.
pub struct EngineEvaluator {
    // Target reference-engine (Stockfish) search depth. Currently unused:
    // the reference evaluations are supplied pre-computed in the dataset.
    #[allow(dead_code)]
    stockfish_depth: u8,
}
1440
1441impl EngineEvaluator {
1442 pub fn new(stockfish_depth: u8) -> Self {
1443 Self { stockfish_depth }
1444 }
1445
1446 pub fn evaluate_accuracy(
1448 &self,
1449 engine: &mut ChessVectorEngine,
1450 test_data: &TrainingDataset,
1451 ) -> Result<f32, Box<dyn std::error::Error>> {
1452 let mut total_error = 0.0;
1453 let mut valid_comparisons = 0;
1454
1455 let pb = ProgressBar::new(test_data.data.len() as u64);
1456 pb.set_style(ProgressStyle::default_bar()
1457 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Evaluating accuracy")
1458 .unwrap()
1459 .progress_chars("#>-"));
1460
1461 for data in &test_data.data {
1462 if let Some(engine_eval) = engine.evaluate_position(&data.board) {
1463 let error = (engine_eval - data.evaluation).abs();
1464 total_error += error;
1465 valid_comparisons += 1;
1466 }
1467 pb.inc(1);
1468 }
1469
1470 pb.finish_with_message("Accuracy evaluation complete");
1471
1472 if valid_comparisons > 0 {
1473 let mean_absolute_error = total_error / valid_comparisons as f32;
1474 println!("Mean Absolute Error: {mean_absolute_error:.3} pawns");
1475 println!("Evaluated {valid_comparisons} positions");
1476 Ok(mean_absolute_error)
1477 } else {
1478 Ok(f32::INFINITY)
1479 }
1480 }
1481}
1482
1483impl serde::Serialize for TrainingData {
1485 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1486 where
1487 S: serde::Serializer,
1488 {
1489 use serde::ser::SerializeStruct;
1490 let mut state = serializer.serialize_struct("TrainingData", 4)?;
1491 state.serialize_field("fen", &self.board.to_string())?;
1492 state.serialize_field("evaluation", &self.evaluation)?;
1493 state.serialize_field("depth", &self.depth)?;
1494 state.serialize_field("game_id", &self.game_id)?;
1495 state.end()
1496 }
1497}
1498
/// Deserialize `TrainingData` from the JSON object written by the matching
/// `Serialize` impl: `{"fen", "evaluation", "depth", "game_id"}`.
///
/// Tolerates missing `game_id` (defaults to 0) and unknown extra fields,
/// and rescales implausibly large evaluations that look like centipawns.
impl<'de> serde::Deserialize<'de> for TrainingData {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct TrainingDataVisitor;

        impl<'de> Visitor<'de> for TrainingDataVisitor {
            type Value = TrainingData;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct TrainingData")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TrainingData, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut fen = None;
                let mut evaluation = None;
                let mut depth = None;
                let mut game_id = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        "fen" => {
                            if fen.is_some() {
                                return Err(de::Error::duplicate_field("fen"));
                            }
                            fen = Some(map.next_value()?);
                        }
                        "evaluation" => {
                            if evaluation.is_some() {
                                return Err(de::Error::duplicate_field("evaluation"));
                            }
                            evaluation = Some(map.next_value()?);
                        }
                        "depth" => {
                            if depth.is_some() {
                                return Err(de::Error::duplicate_field("depth"));
                            }
                            depth = Some(map.next_value()?);
                        }
                        "game_id" => {
                            if game_id.is_some() {
                                return Err(de::Error::duplicate_field("game_id"));
                            }
                            game_id = Some(map.next_value()?);
                        }
                        _ => {
                            // Unknown field: consume and discard its value so
                            // the map stays in sync, instead of erroring.
                            let _: serde_json::Value = map.next_value()?;
                        }
                    }
                }

                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
                let mut evaluation: f32 =
                    evaluation.ok_or_else(|| de::Error::missing_field("evaluation"))?;
                let depth = depth.ok_or_else(|| de::Error::missing_field("depth"))?;
                // `game_id` is optional for backward compatibility with older
                // files that did not store it.
                let game_id = game_id.unwrap_or(0);

                // Heuristic unit fix: |eval| > 15 pawns is implausible for a
                // pawn-scale value, so treat it as centipawns and rescale.
                if evaluation.abs() > 15.0 {
                    evaluation /= 100.0;
                }

                let board =
                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;

                Ok(TrainingData {
                    board,
                    evaluation,
                    depth,
                    game_id,
                })
            }
        }

        const FIELDS: &[&str] = &["fen", "evaluation", "depth", "game_id"];
        deserializer.deserialize_struct("TrainingData", FIELDS, TrainingDataVisitor)
    }
}
1586
1587pub struct TacticalPuzzleParser;
1589
1590impl TacticalPuzzleParser {
1591 pub fn parse_csv<P: AsRef<Path>>(
1593 file_path: P,
1594 max_puzzles: Option<usize>,
1595 min_rating: Option<u32>,
1596 max_rating: Option<u32>,
1597 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1598 let file = File::open(&file_path)?;
1599 let file_size = file.metadata()?.len();
1600
1601 if file_size > 100_000_000 {
1603 Self::parse_csv_parallel(file_path, max_puzzles, min_rating, max_rating)
1604 } else {
1605 Self::parse_csv_sequential(file_path, max_puzzles, min_rating, max_rating)
1606 }
1607 }
1608
    /// Sequential streaming parse of the puzzle CSV via the `csv` crate.
    ///
    /// Rows that fail to parse, fall outside the rating window, or cannot be
    /// converted into training data are counted as skipped rather than
    /// aborting the whole import. Stops once `max_puzzles` are collected.
    fn parse_csv_sequential<P: AsRef<Path>>(
        file_path: P,
        max_puzzles: Option<usize>,
        min_rating: Option<u32>,
        max_rating: Option<u32>,
    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
        let file = File::open(file_path)?;
        let reader = BufReader::new(file);

        // `has_headers(false)`: a header row, if present, simply fails the
        // numeric parses below and is counted as skipped. `flexible(true)`
        // tolerates rows with a varying number of trailing optional columns.
        let mut csv_reader = csv::ReaderBuilder::new()
            .has_headers(false)
            .flexible(true)
            .from_reader(reader);

        let mut tactical_data = Vec::new();
        let mut processed = 0;
        let mut skipped = 0;

        // Spinner rather than a bar: the row count is unknown while streaming.
        let pb = ProgressBar::new_spinner();
        pb.set_style(
            ProgressStyle::default_spinner()
                .template("{spinner:.green} Parsing tactical puzzles: {pos} (skipped: {skipped})")
                .unwrap(),
        );

        for result in csv_reader.records() {
            let record = match result {
                Ok(r) => r,
                Err(e) => {
                    // Malformed row: report and keep going.
                    skipped += 1;
                    println!("CSV parsing error: {e}");
                    continue;
                }
            };

            if let Some(puzzle_data) = Self::parse_csv_record(&record, min_rating, max_rating) {
                if let Some(tactical_data_item) =
                    Self::convert_puzzle_to_training_data(&puzzle_data)
                {
                    tactical_data.push(tactical_data_item);
                    processed += 1;

                    // Stop as soon as the requested cap is reached.
                    if let Some(max) = max_puzzles {
                        if processed >= max {
                            break;
                        }
                    }
                } else {
                    skipped += 1;
                }
            } else {
                skipped += 1;
            }

            pb.set_message(format!(
                "Parsing tactical puzzles: {processed} (skipped: {skipped})"
            ));
        }

        pb.finish_with_message(format!(
            "Parsed {processed} puzzles (skipped: {skipped})"
        ));

        Ok(tactical_data)
    }
1677
1678 fn parse_csv_parallel<P: AsRef<Path>>(
1680 file_path: P,
1681 max_puzzles: Option<usize>,
1682 min_rating: Option<u32>,
1683 max_rating: Option<u32>,
1684 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1685 use std::io::Read;
1686
1687 let mut file = File::open(&file_path)?;
1688
1689 let mut contents = String::new();
1691 file.read_to_string(&mut contents)?;
1692
1693 let lines: Vec<&str> = contents.lines().collect();
1695
1696 let pb = ProgressBar::new(lines.len() as u64);
1697 pb.set_style(ProgressStyle::default_bar()
1698 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel CSV parsing")
1699 .unwrap()
1700 .progress_chars("#>-"));
1701
1702 let tactical_data: Vec<TacticalTrainingData> = lines
1704 .par_iter()
1705 .take(max_puzzles.unwrap_or(usize::MAX))
1706 .filter_map(|line| {
1707 let fields: Vec<&str> = line.split(',').collect();
1709 if fields.len() < 8 {
1710 return None;
1711 }
1712
1713 if let Some(puzzle_data) = Self::parse_csv_fields(&fields, min_rating, max_rating) {
1715 Self::convert_puzzle_to_training_data(&puzzle_data)
1716 } else {
1717 None
1718 }
1719 })
1720 .collect();
1721
1722 pb.finish_with_message(format!(
1723 "Parallel parsing complete: {} puzzles",
1724 tactical_data.len()
1725 ));
1726
1727 Ok(tactical_data)
1728 }
1729
1730 fn parse_csv_record(
1732 record: &csv::StringRecord,
1733 min_rating: Option<u32>,
1734 max_rating: Option<u32>,
1735 ) -> Option<TacticalPuzzle> {
1736 if record.len() < 8 {
1738 return None;
1739 }
1740
1741 let rating: u32 = record[3].parse().ok()?;
1742 let rating_deviation: u32 = record[4].parse().ok()?;
1743 let popularity: i32 = record[5].parse().ok()?;
1744 let nb_plays: u32 = record[6].parse().ok()?;
1745
1746 if let Some(min) = min_rating {
1748 if rating < min {
1749 return None;
1750 }
1751 }
1752 if let Some(max) = max_rating {
1753 if rating > max {
1754 return None;
1755 }
1756 }
1757
1758 Some(TacticalPuzzle {
1759 puzzle_id: record[0].to_string(),
1760 fen: record[1].to_string(),
1761 moves: record[2].to_string(),
1762 rating,
1763 rating_deviation,
1764 popularity,
1765 nb_plays,
1766 themes: record[7].to_string(),
1767 game_url: if record.len() > 8 {
1768 Some(record[8].to_string())
1769 } else {
1770 None
1771 },
1772 opening_tags: if record.len() > 9 {
1773 Some(record[9].to_string())
1774 } else {
1775 None
1776 },
1777 })
1778 }
1779
1780 fn parse_csv_fields(
1782 fields: &[&str],
1783 min_rating: Option<u32>,
1784 max_rating: Option<u32>,
1785 ) -> Option<TacticalPuzzle> {
1786 if fields.len() < 8 {
1787 return None;
1788 }
1789
1790 let rating: u32 = fields[3].parse().ok()?;
1791 let rating_deviation: u32 = fields[4].parse().ok()?;
1792 let popularity: i32 = fields[5].parse().ok()?;
1793 let nb_plays: u32 = fields[6].parse().ok()?;
1794
1795 if let Some(min) = min_rating {
1797 if rating < min {
1798 return None;
1799 }
1800 }
1801 if let Some(max) = max_rating {
1802 if rating > max {
1803 return None;
1804 }
1805 }
1806
1807 Some(TacticalPuzzle {
1808 puzzle_id: fields[0].to_string(),
1809 fen: fields[1].to_string(),
1810 moves: fields[2].to_string(),
1811 rating,
1812 rating_deviation,
1813 popularity,
1814 nb_plays,
1815 themes: fields[7].to_string(),
1816 game_url: if fields.len() > 8 {
1817 Some(fields[8].to_string())
1818 } else {
1819 None
1820 },
1821 opening_tags: if fields.len() > 9 {
1822 Some(fields[9].to_string())
1823 } else {
1824 None
1825 },
1826 })
1827 }
1828
1829 fn convert_puzzle_to_training_data(puzzle: &TacticalPuzzle) -> Option<TacticalTrainingData> {
1831 let position = match Board::from_str(&puzzle.fen) {
1833 Ok(board) => board,
1834 Err(_) => return None,
1835 };
1836
1837 let moves: Vec<&str> = puzzle.moves.split_whitespace().collect();
1839 if moves.is_empty() {
1840 return None;
1841 }
1842
1843 let solution_move = match ChessMove::from_str(moves[0]) {
1845 Ok(mv) => mv,
1846 Err(_) => {
1847 match ChessMove::from_san(&position, moves[0]) {
1849 Ok(mv) => mv,
1850 Err(_) => return None,
1851 }
1852 }
1853 };
1854
1855 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&position).collect();
1857 if !legal_moves.contains(&solution_move) {
1858 return None;
1859 }
1860
1861 let themes: Vec<&str> = puzzle.themes.split_whitespace().collect();
1863 let primary_theme = themes.first().unwrap_or(&"tactical").to_string();
1864
1865 let difficulty = puzzle.rating as f32 / 1000.0; let popularity_bonus = (puzzle.popularity as f32 / 100.0).min(2.0);
1868 let tactical_value = difficulty + popularity_bonus; Some(TacticalTrainingData {
1871 position,
1872 solution_move,
1873 move_theme: primary_theme,
1874 difficulty,
1875 tactical_value,
1876 })
1877 }
1878
1879 pub fn load_into_engine(
1881 tactical_data: &[TacticalTrainingData],
1882 engine: &mut ChessVectorEngine,
1883 ) {
1884 let pb = ProgressBar::new(tactical_data.len() as u64);
1885 pb.set_style(ProgressStyle::default_bar()
1886 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns")
1887 .unwrap()
1888 .progress_chars("#>-"));
1889
1890 for data in tactical_data {
1891 engine.add_position_with_move(
1893 &data.position,
1894 0.0, Some(data.solution_move),
1896 Some(data.tactical_value), );
1898 pb.inc(1);
1899 }
1900
1901 pb.finish_with_message(format!("Loaded {} tactical patterns", tactical_data.len()));
1902 }
1903
1904 pub fn load_into_engine_incremental(
1906 tactical_data: &[TacticalTrainingData],
1907 engine: &mut ChessVectorEngine,
1908 ) {
1909 let initial_size = engine.knowledge_base_size();
1910 let initial_moves = engine.position_moves.len();
1911
1912 if tactical_data.len() > 1000 {
1914 Self::load_into_engine_incremental_parallel(
1915 tactical_data,
1916 engine,
1917 initial_size,
1918 initial_moves,
1919 );
1920 } else {
1921 Self::load_into_engine_incremental_sequential(
1922 tactical_data,
1923 engine,
1924 initial_size,
1925 initial_moves,
1926 );
1927 }
1928 }
1929
    /// Sequentially add only those patterns whose position is not already in
    /// the engine, then report before/after statistics.
    ///
    /// `initial_size`/`initial_moves` are the engine's sizes snapshotted by
    /// the caller before loading began.
    fn load_into_engine_incremental_sequential(
        tactical_data: &[TacticalTrainingData],
        engine: &mut ChessVectorEngine,
        initial_size: usize,
        initial_moves: usize,
    ) {
        let pb = ProgressBar::new(tactical_data.len() as u64);
        pb.set_style(ProgressStyle::default_bar()
            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns (incremental)")
            .unwrap()
            .progress_chars("#>-"));

        let mut added = 0;
        let mut skipped = 0;

        for data in tactical_data {
            // NOTE(review): if `position_boards` is a Vec, this `contains`
            // per item makes the loop O(n*m) — confirm it is set-like before
            // loading very large batches through this path.
            if !engine.position_boards.contains(&data.position) {
                engine.add_position_with_move(
                    &data.position,
                    // Neutral eval; the move and tactical value carry the signal.
                    0.0,
                    Some(data.solution_move),
                    Some(data.tactical_value),
                );
                added += 1;
            } else {
                skipped += 1;
            }
            pb.inc(1);
        }

        pb.finish_with_message(format!(
            "Loaded {} new tactical patterns (skipped {} duplicates, total: {})",
            added,
            skipped,
            engine.knowledge_base_size()
        ));

        println!("Incremental tactical training:");
        println!(
            " - Positions: {} → {} (+{})",
            initial_size,
            engine.knowledge_base_size(),
            engine.knowledge_base_size() - initial_size
        );
        println!(
            " - Move entries: {} → {} (+{})",
            initial_moves,
            engine.position_moves.len(),
            engine.position_moves.len() - initial_moves
        );
    }
1983
1984 fn load_into_engine_incremental_parallel(
1986 tactical_data: &[TacticalTrainingData],
1987 engine: &mut ChessVectorEngine,
1988 initial_size: usize,
1989 initial_moves: usize,
1990 ) {
1991 let pb = ProgressBar::new(tactical_data.len() as u64);
1992 pb.set_style(ProgressStyle::default_bar()
1993 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Optimized batch loading tactical patterns")
1994 .unwrap()
1995 .progress_chars("#>-"));
1996
1997 let filtered_data: Vec<&TacticalTrainingData> = tactical_data
1999 .par_iter()
2000 .filter(|data| !engine.position_boards.contains(&data.position))
2001 .collect();
2002
2003 let batch_size = 1000; let mut added = 0;
2005
2006 println!(
2007 "Pre-filtered: {} → {} positions (removed {} duplicates)",
2008 tactical_data.len(),
2009 filtered_data.len(),
2010 tactical_data.len() - filtered_data.len()
2011 );
2012
2013 for batch in filtered_data.chunks(batch_size) {
2016 let batch_start = added;
2017
2018 for data in batch {
2019 if !engine.position_boards.contains(&data.position) {
2021 engine.add_position_with_move(
2022 &data.position,
2023 0.0, Some(data.solution_move),
2025 Some(data.tactical_value), );
2027 added += 1;
2028 }
2029 pb.inc(1);
2030 }
2031
2032 pb.set_message(format!("Loaded batch: {} positions", added - batch_start));
2034 }
2035
2036 let skipped = tactical_data.len() - added;
2037
2038 pb.finish_with_message(format!(
2039 "Optimized loaded {} new tactical patterns (skipped {} duplicates, total: {})",
2040 added,
2041 skipped,
2042 engine.knowledge_base_size()
2043 ));
2044
2045 println!("Incremental tactical training (optimized):");
2046 println!(
2047 " - Positions: {} → {} (+{})",
2048 initial_size,
2049 engine.knowledge_base_size(),
2050 engine.knowledge_base_size() - initial_size
2051 );
2052 println!(
2053 " - Move entries: {} → {} (+{})",
2054 initial_moves,
2055 engine.position_moves.len(),
2056 engine.position_moves.len() - initial_moves
2057 );
2058 println!(
2059 " - Batch size: {}, Pre-filtered efficiency: {:.1}%",
2060 batch_size,
2061 (filtered_data.len() as f32 / tactical_data.len() as f32) * 100.0
2062 );
2063 }
2064
2065 pub fn save_tactical_puzzles<P: AsRef<std::path::Path>>(
2067 tactical_data: &[TacticalTrainingData],
2068 path: P,
2069 ) -> Result<(), Box<dyn std::error::Error>> {
2070 let json = serde_json::to_string_pretty(tactical_data)?;
2071 std::fs::write(path, json)?;
2072 println!("Saved {} tactical puzzles", tactical_data.len());
2073 Ok(())
2074 }
2075
2076 pub fn load_tactical_puzzles<P: AsRef<std::path::Path>>(
2078 path: P,
2079 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
2080 let content = std::fs::read_to_string(path)?;
2081 let tactical_data: Vec<TacticalTrainingData> = serde_json::from_str(&content)?;
2082 println!("Loaded {} tactical puzzles from file", tactical_data.len());
2083 Ok(tactical_data)
2084 }
2085
2086 pub fn save_tactical_puzzles_incremental<P: AsRef<std::path::Path>>(
2088 tactical_data: &[TacticalTrainingData],
2089 path: P,
2090 ) -> Result<(), Box<dyn std::error::Error>> {
2091 let path = path.as_ref();
2092
2093 if path.exists() {
2094 let mut existing = Self::load_tactical_puzzles(path)?;
2096 let original_len = existing.len();
2097
2098 for new_puzzle in tactical_data {
2100 let exists = existing.iter().any(|existing_puzzle| {
2102 existing_puzzle.position == new_puzzle.position
2103 && existing_puzzle.solution_move == new_puzzle.solution_move
2104 });
2105
2106 if !exists {
2107 existing.push(new_puzzle.clone());
2108 }
2109 }
2110
2111 let json = serde_json::to_string_pretty(&existing)?;
2113 std::fs::write(path, json)?;
2114
2115 println!(
2116 "Incremental save: added {} new puzzles (total: {})",
2117 existing.len() - original_len,
2118 existing.len()
2119 );
2120 } else {
2121 Self::save_tactical_puzzles(tactical_data, path)?;
2123 }
2124 Ok(())
2125 }
2126
2127 pub fn parse_and_load_incremental<P: AsRef<std::path::Path>>(
2129 file_path: P,
2130 engine: &mut ChessVectorEngine,
2131 max_puzzles: Option<usize>,
2132 min_rating: Option<u32>,
2133 max_rating: Option<u32>,
2134 ) -> Result<(), Box<dyn std::error::Error>> {
2135 println!("Parsing Lichess puzzles incrementally...");
2136
2137 let tactical_data = Self::parse_csv(file_path, max_puzzles, min_rating, max_rating)?;
2139
2140 Self::load_into_engine_incremental(&tactical_data, engine);
2142
2143 Ok(())
2144 }
2145}
2146
/// Unit tests for the training pipeline: dataset bookkeeping,
/// (de)serialization round-trips, tactical-puzzle conversion, and engine
/// loading.
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    // A fresh dataset starts empty.
    #[test]
    fn test_training_dataset_creation() {
        let dataset = TrainingDataset::new();
        assert_eq!(dataset.data.len(), 0);
    }

    // Pushed entries are stored and readable back.
    #[test]
    fn test_add_training_data() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.5,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
    }

    // Training the engine from a dataset makes the stored evaluation
    // retrievable for the same position.
    #[test]
    fn test_chess_engine_integration() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.3,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        let mut engine = ChessVectorEngine::new(1024);
        dataset.train_engine(&mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);

        let eval = engine.evaluate_position(&board);
        assert!(eval.is_some());
        assert!((eval.unwrap() - 0.3).abs() < 1e-6);
    }

    // Five copies of the same position collapse to one at high similarity
    // threshold.
    #[test]
    fn test_deduplication() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        for i in 0..5 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        assert_eq!(dataset.data.len(), 5);

        dataset.deduplicate(0.999);
        assert_eq!(dataset.data.len(), 1);
    }

    // JSON round-trip preserves evaluation, depth, and game id (board is
    // stored as FEN).
    #[test]
    fn test_dataset_serialization() {
        let mut dataset = TrainingDataset::new();
        let board =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        let training_data = TrainingData {
            board,
            evaluation: 0.2,
            depth: 10,
            game_id: 42,
        };

        dataset.data.push(training_data);

        let json = serde_json::to_string(&dataset.data).unwrap();
        let loaded_data: Vec<TrainingData> = serde_json::from_str(&json).unwrap();
        let loaded_dataset = TrainingDataset { data: loaded_data };

        assert_eq!(loaded_dataset.data.len(), 1);
        assert_eq!(loaded_dataset.data[0].evaluation, 0.2);
        assert_eq!(loaded_dataset.data[0].depth, 10);
        assert_eq!(loaded_dataset.data[0].game_id, 42);
    }

    // A valid puzzle converts; the first theme becomes the primary theme and
    // rating/popularity feed difficulty and tactical value.
    #[test]
    fn test_tactical_puzzle_processing() {
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4".to_string(),
            moves: "Bxf7+ Ke7".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "fork pin".to_string(),
            game_url: None,
            opening_tags: None,
        };

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_some());

        let data = tactical_data.unwrap();
        assert_eq!(data.move_theme, "fork");
        assert!(data.tactical_value > 1.0);
        assert!(data.difficulty > 0.0);
    }

    // An unparseable FEN makes conversion fail cleanly with None.
    #[test]
    fn test_tactical_puzzle_invalid_fen() {
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "invalid_fen".to_string(),
            moves: "e2e4".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "tactics".to_string(),
            game_url: None,
            opening_tags: None,
        };

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_none());
    }

    // MAE between a stored 0.1 eval and a 0.0 reference must be small.
    #[test]
    fn test_engine_evaluator() {
        let evaluator = EngineEvaluator::new(15);

        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.0,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&board, 0.1);

        let accuracy = evaluator.evaluate_accuracy(&mut engine, &dataset);
        assert!(accuracy.is_ok());
        assert!(accuracy.unwrap() < 1.0);
    }

    // Loading tactical data registers both the position and its move entry,
    // and makes the move recommendable.
    #[test]
    fn test_tactical_training_integration() {
        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.2,
            tactical_value: 2.5,
        }];

        let mut engine = ChessVectorEngine::new(1024);
        TacticalPuzzleParser::load_into_engine(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);
        assert_eq!(engine.position_moves.len(), 1);

        let recommendations = engine.recommend_moves(&Board::default(), 5);
        assert!(!recommendations.is_empty());
    }

    // Parallel deduplication never grows the dataset.
    #[test]
    fn test_multithreading_operations() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        for i in 0..10 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        dataset.deduplicate_parallel(0.95, 5);
        assert!(dataset.data.len() <= 10);
    }

    // merge() concatenates datasets and next_game_id() is max game id + 1.
    #[test]
    fn test_incremental_dataset_operations() {
        let mut dataset1 = TrainingDataset::new();
        let board1 = Board::default();
        let board2 =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        dataset1.add_position(board1, 0.0, 15, 1);
        dataset1.add_position(board2, 0.2, 15, 2);
        assert_eq!(dataset1.data.len(), 2);

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );

        dataset1.merge(dataset2);
        assert_eq!(dataset1.data.len(), 3);

        let next_id = dataset1.next_game_id();
        assert_eq!(next_id, 4);
    }

    // save_incremental appends to an existing file; load_and_append merges
    // file contents into an in-memory dataset.
    #[test]
    fn test_save_load_incremental() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap(),
            0.2,
            15,
            2,
        );
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 2);

        let mut dataset3 = TrainingDataset::new();
        dataset3.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );
        dataset3.load_and_append(&file_path).unwrap();
        assert_eq!(dataset3.data.len(), 3);
    }

    // add_position stores every field as given.
    #[test]
    fn test_add_position_method() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        dataset.add_position(board, 0.5, 20, 42);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
        assert_eq!(dataset.data[0].depth, 20);
        assert_eq!(dataset.data[0].game_id, 42);
    }

    // Incrementally saving the same position again must not duplicate it.
    #[test]
    fn test_incremental_save_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(Board::default(), 0.1, 15, 2);
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 1);
    }

    // Incremental engine loading skips positions already known and records
    // move data for the new ones.
    #[test]
    fn test_tactical_puzzle_incremental_loading() {
        let tactical_data = vec![
            TacticalTrainingData {
                position: Board::default(),
                solution_move: ChessMove::from_str("e2e4").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.2,
                tactical_value: 2.5,
            },
            TacticalTrainingData {
                position: Board::from_str(
                    "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1",
                )
                .unwrap(),
                solution_move: ChessMove::from_str("e7e5").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.0,
                tactical_value: 2.0,
            },
        ];

        let mut engine = ChessVectorEngine::new(1024);

        engine.add_position(&Board::default(), 0.1);
        assert_eq!(engine.knowledge_base_size(), 1);

        TacticalPuzzleParser::load_into_engine_incremental(&tactical_data, &mut engine);

        // Only the second position is new; the default board is a duplicate.
        assert_eq!(engine.knowledge_base_size(), 2);

        assert!(engine.training_stats().has_move_data);
        assert!(engine.training_stats().move_data_entries > 0);
    }

    // Tactical puzzles round-trip through JSON with all fields intact.
    #[test]
    fn test_tactical_puzzle_serialization() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("tactical_test.json");

        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "fork".to_string(),
            difficulty: 1.5,
            tactical_value: 3.0,
        }];

        TacticalPuzzleParser::save_tactical_puzzles(&tactical_data, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].move_theme, "fork");
        assert_eq!(loaded[0].difficulty, 1.5);
        assert_eq!(loaded[0].tactical_value, 3.0);
    }

    // Incrementally saving a distinct second batch appends it to the file.
    #[test]
    fn test_tactical_puzzle_incremental_save() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_tactical.json");

        let batch1 = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        }];
        TacticalPuzzleParser::save_tactical_puzzles(&batch1, &file_path).unwrap();

        let batch2 = vec![TacticalTrainingData {
            position: Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1")
                .unwrap(),
            solution_move: ChessMove::from_str("e7e5").unwrap(),
            move_theme: "counter".to_string(),
            difficulty: 1.2,
            tactical_value: 2.2,
        }];
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&batch2, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 2);
    }

    // Re-saving an identical puzzle incrementally must not duplicate it.
    #[test]
    fn test_tactical_puzzle_incremental_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_tactical.json");

        let tactical_data = TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        };

        TacticalPuzzleParser::save_tactical_puzzles(&[tactical_data.clone()], &file_path).unwrap();

        TacticalPuzzleParser::save_tactical_puzzles_incremental(&[tactical_data], &file_path)
            .unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
    }
}