1use chess::{Board, ChessMove, Game, MoveGen};
2use indicatif::{ProgressBar, ProgressStyle};
3use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::{BufRead, BufReader, BufWriter, Write};
8use std::path::Path;
9use std::process::{Child, Command, Stdio};
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12
13use crate::ChessVectorEngine;
14
/// Configuration for self-play training-game generation.
#[derive(Debug, Clone)]
pub struct SelfPlayConfig {
    /// Number of self-play games generated per training iteration.
    pub games_per_iteration: usize,
    /// Hard cap on plies played (and recorded) per game.
    pub max_moves_per_game: usize,
    /// Probability of sampling an exploratory move instead of the top recommendation.
    pub exploration_factor: f32,
    /// Minimum |evaluation| required to record a position (early opening plies are exempt).
    pub min_confidence: f32,
    /// Whether games are seeded from the small built-in opening table.
    pub use_opening_book: bool,
    /// Softmax temperature used when sampling exploratory moves (lower = greedier).
    pub temperature: f32,
}
31
32impl Default for SelfPlayConfig {
33 fn default() -> Self {
34 Self {
35 games_per_iteration: 100,
36 max_moves_per_game: 200,
37 exploration_factor: 0.3,
38 min_confidence: 0.1,
39 use_opening_book: true,
40 temperature: 0.8,
41 }
42 }
43}
44
/// One labelled training position.
#[derive(Debug, Clone)]
pub struct TrainingData {
    /// The chess position itself.
    pub board: Board,
    /// Evaluation label in pawns (centipawns / 100); 0.0 until a Stockfish pass fills it in.
    pub evaluation: f32,
    /// Search depth that produced `evaluation` (0 = not yet evaluated).
    pub depth: u8,
    /// Identifier of the source game; used to split train/test at game granularity.
    pub game_id: usize,
}
53
/// A tactical puzzle record. Field renames match a Lichess-style puzzle
/// export (PuzzleId, FEN, Moves, ...) — presumably the Lichess puzzle
/// database CSV; confirm against the actual data source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TacticalPuzzle {
    #[serde(rename = "PuzzleId")]
    pub puzzle_id: String,
    /// Starting position of the puzzle, in FEN notation.
    #[serde(rename = "FEN")]
    pub fen: String,
    /// Solution line as a single string (format defined by the data source).
    #[serde(rename = "Moves")]
    pub moves: String,
    #[serde(rename = "Rating")]
    pub rating: u32,
    #[serde(rename = "RatingDeviation")]
    pub rating_deviation: u32,
    #[serde(rename = "Popularity")]
    pub popularity: i32,
    #[serde(rename = "NbPlays")]
    pub nb_plays: u32,
    /// Theme tags for the puzzle (single string; delimiter defined by the source).
    #[serde(rename = "Themes")]
    pub themes: String,
    #[serde(rename = "GameUrl")]
    pub game_url: Option<String>,
    #[serde(rename = "OpeningTags")]
    pub opening_tags: Option<String>,
}
78
/// A tactical puzzle converted into a directly usable training sample.
#[derive(Debug, Clone)]
pub struct TacticalTrainingData {
    /// Puzzle position to solve.
    pub position: Board,
    /// Expected best move in this position.
    pub solution_move: ChessMove,
    /// Theme label carried over from the source puzzle.
    pub move_theme: String,
    /// Difficulty score (derivation not visible in this file).
    pub difficulty: f32,
    /// Training weight of this sample (derivation not visible in this file).
    pub tactical_value: f32,
}
88
/// Manual `Serialize`: the chess types are written in string form
/// (`Board` as FEN via `to_string`, `ChessMove` via `to_string`),
/// presumably because they do not implement serde traits themselves.
impl serde::Serialize for TacticalTrainingData {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;
        // Five fields: fen, solution_move, move_theme, difficulty, tactical_value.
        let mut state = serializer.serialize_struct("TacticalTrainingData", 5)?;
        state.serialize_field("fen", &self.position.to_string())?;
        state.serialize_field("solution_move", &self.solution_move.to_string())?;
        state.serialize_field("move_theme", &self.move_theme)?;
        state.serialize_field("difficulty", &self.difficulty)?;
        state.serialize_field("tactical_value", &self.tactical_value)?;
        state.end()
    }
}
105
/// Manual `Deserialize`, mirroring the custom `Serialize` impl above: the
/// position is rebuilt from its FEN string and the move from its string form.
impl<'de> serde::Deserialize<'de> for TacticalTrainingData {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        // Map-based visitor: walks the key/value pairs, collecting each
        // expected field at most once and skipping unknown keys.
        struct TacticalTrainingDataVisitor;

        impl<'de> Visitor<'de> for TacticalTrainingDataVisitor {
            type Value = TacticalTrainingData;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct TacticalTrainingData")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TacticalTrainingData, V::Error>
            where
                V: MapAccess<'de>,
            {
                // Each field stays None until its key is encountered.
                let mut fen = None;
                let mut solution_move = None;
                let mut move_theme = None;
                let mut difficulty = None;
                let mut tactical_value = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        "fen" => {
                            if fen.is_some() {
                                return Err(de::Error::duplicate_field("fen"));
                            }
                            fen = Some(map.next_value()?);
                        }
                        "solution_move" => {
                            if solution_move.is_some() {
                                return Err(de::Error::duplicate_field("solution_move"));
                            }
                            solution_move = Some(map.next_value()?);
                        }
                        "move_theme" => {
                            if move_theme.is_some() {
                                return Err(de::Error::duplicate_field("move_theme"));
                            }
                            move_theme = Some(map.next_value()?);
                        }
                        "difficulty" => {
                            if difficulty.is_some() {
                                return Err(de::Error::duplicate_field("difficulty"));
                            }
                            difficulty = Some(map.next_value()?);
                        }
                        "tactical_value" => {
                            if tactical_value.is_some() {
                                return Err(de::Error::duplicate_field("tactical_value"));
                            }
                            tactical_value = Some(map.next_value()?);
                        }
                        _ => {
                            // Unknown field: consume and discard its value so
                            // the map stays in sync.
                            let _: serde_json::Value = map.next_value()?;
                        }
                    }
                }

                // All five fields are required.
                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
                let solution_move_str: String =
                    solution_move.ok_or_else(|| de::Error::missing_field("solution_move"))?;
                let move_theme =
                    move_theme.ok_or_else(|| de::Error::missing_field("move_theme"))?;
                let difficulty =
                    difficulty.ok_or_else(|| de::Error::missing_field("difficulty"))?;
                let tactical_value =
                    tactical_value.ok_or_else(|| de::Error::missing_field("tactical_value"))?;

                // Rebuild the chess types from their string representations.
                let position =
                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;

                let solution_move = ChessMove::from_str(&solution_move_str)
                    .map_err(|e| de::Error::custom(format!("Error: {e}")))?;

                Ok(TacticalTrainingData {
                    position,
                    solution_move,
                    move_theme,
                    difficulty,
                    tactical_value,
                })
            }
        }

        const FIELDS: &[&str] = &[
            "fen",
            "solution_move",
            "move_theme",
            "difficulty",
            "tactical_value",
        ];
        deserializer.deserialize_struct("TacticalTrainingData", FIELDS, TacticalTrainingDataVisitor)
    }
}
207
/// pgn-reader visitor state: replays each game move by move and accumulates
/// the positions reached on the mainline.
pub struct GameExtractor {
    /// Positions extracted so far (evaluation/depth left at 0 for later labelling).
    pub positions: Vec<TrainingData>,
    /// The game currently being replayed.
    pub current_game: Game,
    /// Plies recorded for the current game.
    pub move_count: usize,
    /// Cap on plies recorded per game.
    pub max_moves_per_game: usize,
    /// Running counter, incremented at the start of every game.
    pub game_id: usize,
}
216
217impl GameExtractor {
218 pub fn new(max_moves_per_game: usize) -> Self {
219 Self {
220 positions: Vec::new(),
221 current_game: Game::new(),
222 move_count: 0,
223 max_moves_per_game,
224 game_id: 0,
225 }
226 }
227}
228
229impl Visitor for GameExtractor {
230 type Result = ();
231
232 fn begin_game(&mut self) {
233 self.current_game = Game::new();
234 self.move_count = 0;
235 self.game_id += 1;
236 }
237
238 fn header(&mut self, _key: &[u8], _value: RawHeader<'_>) {}
239
240 fn san(&mut self, san_plus: SanPlus) {
241 if self.move_count >= self.max_moves_per_game {
242 return;
243 }
244
245 let san_str = san_plus.san.to_string();
246
247 let current_pos = self.current_game.current_position();
249
250 match chess::ChessMove::from_san(¤t_pos, &san_str) {
252 Ok(chess_move) => {
253 let legal_moves: Vec<chess::ChessMove> = MoveGen::new_legal(¤t_pos).collect();
255 if legal_moves.contains(&chess_move) {
256 if self.current_game.make_move(chess_move) {
257 self.move_count += 1;
258
259 self.positions.push(TrainingData {
261 board: self.current_game.current_position(),
262 evaluation: 0.0, depth: 0,
264 game_id: self.game_id,
265 });
266 }
267 } else {
268 }
270 }
271 Err(_) => {
272 if !san_str.contains("O-O") && !san_str.contains("=") && san_str.len() > 6 {
276 }
278 }
279 }
280 }
281
282 fn begin_variation(&mut self) -> Skip {
283 Skip(true) }
285
286 fn end_game(&mut self) -> Self::Result {}
287}
288
/// Evaluates positions by invoking the external `stockfish` binary at a
/// fixed search depth (one process per evaluation; see `StockfishPool` for a
/// persistent-process variant).
pub struct StockfishEvaluator {
    /// UCI `go depth` value used for each search.
    depth: u8,
}
293
294impl StockfishEvaluator {
295 pub fn new(depth: u8) -> Self {
296 Self { depth }
297 }
298
299 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
301 let mut child = Command::new("stockfish")
302 .stdin(Stdio::piped())
303 .stdout(Stdio::piped())
304 .stderr(Stdio::piped())
305 .spawn()?;
306
307 let stdin = child
308 .stdin
309 .as_mut()
310 .ok_or("Failed to get stdin handle for Stockfish process")?;
311 let fen = board.to_string();
312
313 use std::io::Write;
315 writeln!(stdin, "uci")?;
316 writeln!(stdin, "isready")?;
317 writeln!(stdin, "position fen {fen}")?;
318 writeln!(stdin, "go depth {}", self.depth)?;
319 writeln!(stdin, "quit")?;
320
321 let output = child.wait_with_output()?;
322 let stdout = String::from_utf8_lossy(&output.stdout);
323
324 for line in stdout.lines() {
326 if line.starts_with("info") && line.contains("score cp") {
327 if let Some(cp_pos) = line.find("score cp ") {
328 let cp_str = &line[cp_pos + 9..];
329 if let Some(end) = cp_str.find(' ') {
330 let cp_value = cp_str[..end].parse::<i32>()?;
331 return Ok(cp_value as f32 / 100.0); }
333 }
334 } else if line.starts_with("info") && line.contains("score mate") {
335 if let Some(mate_pos) = line.find("score mate ") {
337 let mate_str = &line[mate_pos + 11..];
338 if let Some(end) = mate_str.find(' ') {
339 let mate_moves = mate_str[..end].parse::<i32>()?;
340 return Ok(if mate_moves > 0 { 100.0 } else { -100.0 });
341 }
342 }
343 }
344 }
345
346 Ok(0.0) }
348
349 pub fn evaluate_batch(
351 &self,
352 positions: &mut [TrainingData],
353 ) -> Result<(), Box<dyn std::error::Error>> {
354 let pb = ProgressBar::new(positions.len() as u64);
355 if let Ok(style) = ProgressStyle::default_bar().template(
356 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
357 ) {
358 pb.set_style(style.progress_chars("#>-"));
359 }
360
361 for data in positions.iter_mut() {
362 match self.evaluate_position(&data.board) {
363 Ok(eval) => {
364 data.evaluation = eval;
365 data.depth = self.depth;
366 }
367 Err(e) => {
368 eprintln!("Evaluation error: {e}");
369 data.evaluation = 0.0;
370 }
371 }
372 pb.inc(1);
373 }
374
375 pb.finish_with_message("Evaluation complete");
376 Ok(())
377 }
378
379 pub fn evaluate_batch_parallel(
381 &self,
382 positions: &mut [TrainingData],
383 num_threads: usize,
384 ) -> Result<(), Box<dyn std::error::Error>> {
385 let pb = ProgressBar::new(positions.len() as u64);
386 if let Ok(style) = ProgressStyle::default_bar()
387 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel evaluation") {
388 pb.set_style(style.progress_chars("#>-"));
389 }
390
391 let pool = rayon::ThreadPoolBuilder::new()
393 .num_threads(num_threads)
394 .build()?;
395
396 pool.install(|| {
397 positions.par_iter_mut().for_each(|data| {
399 match self.evaluate_position(&data.board) {
400 Ok(eval) => {
401 data.evaluation = eval;
402 data.depth = self.depth;
403 }
404 Err(_) => {
405 data.evaluation = 0.0;
407 }
408 }
409 pb.inc(1);
410 });
411 });
412
413 pb.finish_with_message("Parallel evaluation complete");
414 Ok(())
415 }
416}
417
/// A persistent, handshake-completed Stockfish child process with buffered
/// pipes, reusable across many evaluations (managed by `StockfishPool`).
struct StockfishProcess {
    /// Child handle; reaped in `Drop`.
    child: Child,
    /// Buffered writer over the engine's stdin (UCI commands).
    stdin: BufWriter<std::process::ChildStdin>,
    /// Buffered reader over the engine's stdout (UCI responses).
    stdout: BufReader<std::process::ChildStdout>,
    /// Requested search depth for evaluations.
    #[allow(dead_code)]
    depth: u8,
}
426
427impl StockfishProcess {
428 fn new(depth: u8) -> Result<Self, Box<dyn std::error::Error>> {
429 let mut child = Command::new("stockfish")
430 .stdin(Stdio::piped())
431 .stdout(Stdio::piped())
432 .stderr(Stdio::piped())
433 .spawn()?;
434
435 let stdin = BufWriter::new(
436 child
437 .stdin
438 .take()
439 .ok_or("Failed to get stdin handle for Stockfish process")?,
440 );
441 let stdout = BufReader::new(
442 child
443 .stdout
444 .take()
445 .ok_or("Failed to get stdout handle for Stockfish process")?,
446 );
447
448 let mut process = Self {
449 child,
450 stdin,
451 stdout,
452 depth,
453 };
454
455 process.send_command("uci")?;
457 process.wait_for_ready()?;
458 process.send_command("isready")?;
459 process.wait_for_ready()?;
460
461 Ok(process)
462 }
463
464 fn send_command(&mut self, command: &str) -> Result<(), Box<dyn std::error::Error>> {
465 writeln!(self.stdin, "{command}")?;
466 self.stdin.flush()?;
467 Ok(())
468 }
469
470 fn wait_for_ready(&mut self) -> Result<(), Box<dyn std::error::Error>> {
471 let mut line = String::new();
472 loop {
473 line.clear();
474 self.stdout.read_line(&mut line)?;
475 if line.trim() == "uciok" || line.trim() == "readyok" {
476 break;
477 }
478 }
479 Ok(())
480 }
481
482 fn evaluate_position(&mut self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
483 let fen = board.to_string();
484
485 self.send_command(&format!("position fen {fen}"))?;
487 self.send_command(&format!("position fen {fen}"))?;
488
489 let mut line = String::new();
491 let mut last_evaluation = 0.0;
492
493 loop {
494 line.clear();
495 self.stdout.read_line(&mut line)?;
496 let line = line.trim();
497
498 if line.starts_with("info") && line.contains("score cp") {
499 if let Some(cp_pos) = line.find("score cp ") {
500 let cp_str = &line[cp_pos + 9..];
501 if let Some(end) = cp_str.find(' ') {
502 if let Ok(cp_value) = cp_str[..end].parse::<i32>() {
503 last_evaluation = cp_value as f32 / 100.0;
504 }
505 }
506 }
507 } else if line.starts_with("info") && line.contains("score mate") {
508 if let Some(mate_pos) = line.find("score mate ") {
509 let mate_str = &line[mate_pos + 11..];
510 if let Some(end) = mate_str.find(' ') {
511 if let Ok(mate_moves) = mate_str[..end].parse::<i32>() {
512 last_evaluation = if mate_moves > 0 { 100.0 } else { -100.0 };
513 }
514 }
515 }
516 } else if line.starts_with("bestmove") {
517 break;
518 }
519 }
520
521 Ok(last_evaluation)
522 }
523}
524
/// Best-effort shutdown: ask the engine to quit, then reap the child so it
/// does not linger as a zombie. Errors are deliberately ignored — panicking
/// in `drop` is never acceptable.
impl Drop for StockfishProcess {
    fn drop(&mut self) {
        let _ = self.send_command("quit");
        let _ = self.child.wait();
    }
}
531
/// A fixed-size pool of persistent Stockfish processes shared across
/// threads, avoiding a process spawn per evaluated position.
pub struct StockfishPool {
    /// Idle processes; one is checked out per evaluation and returned after.
    pool: Arc<Mutex<Vec<StockfishProcess>>>,
    /// Search depth configured for every pooled process.
    depth: u8,
    /// Maximum number of processes kept in the pool.
    pool_size: usize,
}
538
539impl StockfishPool {
540 pub fn new(depth: u8, pool_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
541 let mut processes = Vec::with_capacity(pool_size);
542
543 println!("🚀 Initializing Stockfish pool with {pool_size} processes...");
544
545 for i in 0..pool_size {
546 match StockfishProcess::new(depth) {
547 Ok(process) => {
548 processes.push(process);
549 if i % 2 == 1 {
550 print!(".");
551 let _ = std::io::stdout().flush(); }
553 }
554 Err(e) => {
555 eprintln!("Evaluation error: {e}");
556 return Err(e);
557 }
558 }
559 }
560
561 println!(" ✅ Pool ready!");
562
563 Ok(Self {
564 pool: Arc::new(Mutex::new(processes)),
565 depth,
566 pool_size,
567 })
568 }
569
570 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
571 let mut process = {
573 let mut pool = self.pool.lock().unwrap();
574 if let Some(process) = pool.pop() {
575 process
576 } else {
577 StockfishProcess::new(self.depth)?
579 }
580 };
581
582 let result = process.evaluate_position(board);
584
585 {
587 let mut pool = self.pool.lock().unwrap();
588 if pool.len() < self.pool_size {
589 pool.push(process);
590 }
591 }
593
594 result
595 }
596
597 pub fn evaluate_batch_parallel(
598 &self,
599 positions: &mut [TrainingData],
600 ) -> Result<(), Box<dyn std::error::Error>> {
601 let pb = ProgressBar::new(positions.len() as u64);
602 pb.set_style(ProgressStyle::default_bar()
603 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Pool evaluation")
604 .unwrap()
605 .progress_chars("#>-"));
606
607 positions.par_iter_mut().for_each(|data| {
609 match self.evaluate_position(&data.board) {
610 Ok(eval) => {
611 data.evaluation = eval;
612 data.depth = self.depth;
613 }
614 Err(_) => {
615 data.evaluation = 0.0;
616 }
617 }
618 pb.inc(1);
619 });
620
621 pb.finish_with_message("Pool evaluation complete");
622 Ok(())
623 }
624}
625
/// A collection of labelled training positions with load/save, merge,
/// deduplication, and Stockfish-labelling utilities.
pub struct TrainingDataset {
    /// All training samples, in insertion order.
    pub data: Vec<TrainingData>,
}
630
631impl Default for TrainingDataset {
632 fn default() -> Self {
633 Self::new()
634 }
635}
636
637impl TrainingDataset {
638 pub fn new() -> Self {
639 Self { data: Vec::new() }
640 }
641
    /// Loads positions from a PGN file into `self.data`, replaying every
    /// game's mainline and recording each reached position (unevaluated).
    ///
    /// `max_games` caps the number of games processed; `max_moves_per_game`
    /// caps the plies recorded per game.
    ///
    /// NOTE(review): games are delimited heuristically by a line ENDING with
    /// a result token (`1-0`, `0-1`, `1/2-1/2`, `*`). A result embedded
    /// mid-line would merge two games into one parse — confirm the input
    /// PGNs always terminate games on their own line.
    pub fn load_from_pgn<P: AsRef<Path>>(
        &mut self,
        path: P,
        max_games: Option<usize>,
        max_moves_per_game: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);

        let mut extractor = GameExtractor::new(max_moves_per_game);
        let mut games_processed = 0;

        // Accumulate raw lines until a game-termination marker, then hand
        // the buffered game text to pgn-reader in one shot.
        let mut pgn_content = String::new();
        for line in reader.lines() {
            let line = line?;
            pgn_content.push_str(&line);
            pgn_content.push('\n');

            if line.trim().ends_with("1-0")
                || line.trim().ends_with("0-1")
                || line.trim().ends_with("1/2-1/2")
                || line.trim().ends_with("*")
            {
                // Parse the buffered game; a bad game is logged and skipped.
                let cursor = std::io::Cursor::new(&pgn_content);
                let mut reader = BufferedReader::new(cursor);
                if let Err(e) = reader.read_all(&mut extractor) {
                    eprintln!("Evaluation error: {e}");
                }

                games_processed += 1;
                pgn_content.clear();

                if let Some(max) = max_games {
                    if games_processed >= max {
                        break;
                    }
                }

                // Periodic progress output.
                if games_processed % 100 == 0 {
                    println!(
                        "Processed {} games, extracted {} positions",
                        games_processed,
                        extractor.positions.len()
                    );
                }
            }
        }

        self.data.extend(extractor.positions);
        println!(
            "Loaded {} positions from {} games",
            self.data.len(),
            games_processed
        );
        Ok(())
    }
702
703 pub fn evaluate_with_stockfish(&mut self, depth: u8) -> Result<(), Box<dyn std::error::Error>> {
705 let evaluator = StockfishEvaluator::new(depth);
706 evaluator.evaluate_batch(&mut self.data)
707 }
708
709 pub fn evaluate_with_stockfish_parallel(
711 &mut self,
712 depth: u8,
713 num_threads: usize,
714 ) -> Result<(), Box<dyn std::error::Error>> {
715 let evaluator = StockfishEvaluator::new(depth);
716 evaluator.evaluate_batch_parallel(&mut self.data, num_threads)
717 }
718
719 pub fn train_engine(&self, engine: &mut ChessVectorEngine) {
721 let pb = ProgressBar::new(self.data.len() as u64);
722 pb.set_style(ProgressStyle::default_bar()
723 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Training positions")
724 .unwrap()
725 .progress_chars("#>-"));
726
727 for data in &self.data {
728 engine.add_position(&data.board, data.evaluation);
729 pb.inc(1);
730 }
731
732 pb.finish_with_message("Training complete");
733 println!("Trained engine with {} positions", self.data.len());
734 }
735
736 pub fn split(&self, train_ratio: f32) -> (TrainingDataset, TrainingDataset) {
738 use rand::seq::SliceRandom;
739 use rand::thread_rng;
740 use std::collections::{HashMap, HashSet};
741
742 let mut games: HashMap<usize, Vec<&TrainingData>> = HashMap::new();
744 for data in &self.data {
745 games.entry(data.game_id).or_default().push(data);
746 }
747
748 let mut game_ids: Vec<usize> = games.keys().cloned().collect();
750 game_ids.shuffle(&mut thread_rng());
751
752 let split_point = (game_ids.len() as f32 * train_ratio) as usize;
754 let train_game_ids: HashSet<usize> = game_ids[..split_point].iter().cloned().collect();
755
756 let mut train_data = Vec::new();
758 let mut test_data = Vec::new();
759
760 for data in &self.data {
761 if train_game_ids.contains(&data.game_id) {
762 train_data.push(data.clone());
763 } else {
764 test_data.push(data.clone());
765 }
766 }
767
768 (
769 TrainingDataset { data: train_data },
770 TrainingDataset { data: test_data },
771 )
772 }
773
774 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
776 let json = serde_json::to_string_pretty(&self.data)?;
777 std::fs::write(path, json)?;
778 Ok(())
779 }
780
781 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
783 let content = std::fs::read_to_string(path)?;
784 let data = serde_json::from_str(&content)?;
785 Ok(Self { data })
786 }
787
788 pub fn load_and_append<P: AsRef<Path>>(
790 &mut self,
791 path: P,
792 ) -> Result<(), Box<dyn std::error::Error>> {
793 let existing_len = self.data.len();
794 let additional_data = Self::load(path)?;
795 self.data.extend(additional_data.data);
796 println!(
797 "Loaded {} additional positions (total: {})",
798 self.data.len() - existing_len,
799 self.data.len()
800 );
801 Ok(())
802 }
803
804 pub fn merge(&mut self, other: TrainingDataset) {
806 let existing_len = self.data.len();
807 self.data.extend(other.data);
808 println!(
809 "Merged {} positions (total: {})",
810 self.data.len() - existing_len,
811 self.data.len()
812 );
813 }
814
815 pub fn save_incremental<P: AsRef<Path>>(
817 &self,
818 path: P,
819 ) -> Result<(), Box<dyn std::error::Error>> {
820 self.save_incremental_with_options(path, true)
821 }
822
823 pub fn save_incremental_with_options<P: AsRef<Path>>(
825 &self,
826 path: P,
827 deduplicate: bool,
828 ) -> Result<(), Box<dyn std::error::Error>> {
829 let path = path.as_ref();
830
831 if path.exists() {
832 if self.save_append_only(path).is_ok() {
834 return Ok(());
835 }
836
837 if deduplicate {
839 self.save_incremental_full_merge(path)
840 } else {
841 self.save_incremental_no_dedup(path)
842 }
843 } else {
844 self.save(path)
846 }
847 }
848
849 fn save_incremental_no_dedup<P: AsRef<Path>>(
851 &self,
852 path: P,
853 ) -> Result<(), Box<dyn std::error::Error>> {
854 let path = path.as_ref();
855
856 println!("📂 Loading existing training data...");
857 let mut existing = Self::load(path)?;
858
859 println!("⚡ Fast merge without deduplication...");
860 existing.data.extend(self.data.iter().cloned());
861
862 println!(
863 "💾 Serializing {} positions to JSON...",
864 existing.data.len()
865 );
866 let json = serde_json::to_string_pretty(&existing.data)?;
867
868 println!("✍️ Writing to disk...");
869 std::fs::write(path, json)?;
870
871 println!(
872 "✅ Fast merge save: total {} positions",
873 existing.data.len()
874 );
875 Ok(())
876 }
877
    /// Fast-path incremental save: splices the new records into the existing
    /// JSON array in place, without reading and rewriting the whole file.
    ///
    /// Returns an error (so callers can fall back to a full merge) when the
    /// file does not look like it ends with a JSON array.
    ///
    /// NOTE(review): the `End(-10)` seek fails for files shorter than 10
    /// bytes (e.g. an empty `[]` array), and the single `read_line` may not
    /// reach the final bracket when the last 10 bytes span several lines.
    /// Both cases error out and trigger the caller's fallback path, which
    /// appears to be the intended best-effort behaviour — confirm.
    pub fn save_append_only<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use std::fs::OpenOptions;
        use std::io::{BufRead, BufReader, Seek, SeekFrom, Write};

        // Nothing to append.
        if self.data.is_empty() {
            return Ok(());
        }

        let path = path.as_ref();
        let mut file = OpenOptions::new().read(true).write(true).open(path)?;

        // Peek at the file tail to confirm it ends with ']'.
        file.seek(SeekFrom::End(-10))?;
        let mut buffer = String::new();
        BufReader::new(&file).read_line(&mut buffer)?;

        if !buffer.trim().ends_with(']') {
            return Err("File doesn't end with JSON array bracket".into());
        }

        // Overwrite the trailing bytes (the closing "\n]") with "," so the
        // new records continue the existing array.
        file.seek(SeekFrom::End(-2))?;
        write!(file, ",")?;

        for (i, data) in self.data.iter().enumerate() {
            if i > 0 {
                write!(file, ",")?;
            }
            let json = serde_json::to_string(data)?;
            write!(file, "{json}")?;
        }

        // Re-close the array.
        write!(file, "\n]")?;

        println!("Fast append: added {} new positions", self.data.len());
        Ok(())
    }
923
924 fn save_incremental_full_merge<P: AsRef<Path>>(
926 &self,
927 path: P,
928 ) -> Result<(), Box<dyn std::error::Error>> {
929 let path = path.as_ref();
930
931 println!("📂 Loading existing training data...");
932 let mut existing = Self::load(path)?;
933 let _original_len = existing.data.len();
934
935 println!("🔄 Streaming merge with deduplication (avoiding O(n²) operation)...");
936 existing.merge_and_deduplicate(self.data.clone());
937
938 println!(
939 "💾 Serializing {} positions to JSON...",
940 existing.data.len()
941 );
942 let json = serde_json::to_string_pretty(&existing.data)?;
943
944 println!("✍️ Writing to disk...");
945 std::fs::write(path, json)?;
946
947 println!(
948 "✅ Streaming merge save: total {} positions",
949 existing.data.len()
950 );
951 Ok(())
952 }
953
954 pub fn add_position(&mut self, board: Board, evaluation: f32, depth: u8, game_id: usize) {
956 self.data.push(TrainingData {
957 board,
958 evaluation,
959 depth,
960 game_id,
961 });
962 }
963
964 pub fn next_game_id(&self) -> usize {
966 self.data.iter().map(|data| data.game_id).max().unwrap_or(0) + 1
967 }
968
969 pub fn deduplicate(&mut self, similarity_threshold: f32) {
971 if similarity_threshold > 0.999 {
972 self.deduplicate_fast();
974 } else {
975 self.deduplicate_similarity_based(similarity_threshold);
977 }
978 }
979
980 pub fn deduplicate_fast(&mut self) {
982 use std::collections::HashSet;
983
984 if self.data.is_empty() {
985 return;
986 }
987
988 let mut seen_positions = HashSet::with_capacity(self.data.len());
989 let original_len = self.data.len();
990
991 self.data.retain(|data| {
993 let fen = data.board.to_string();
994 seen_positions.insert(fen)
995 });
996
997 println!(
998 "Fast deduplicated: {} -> {} positions (removed {} exact duplicates)",
999 original_len,
1000 self.data.len(),
1001 original_len - self.data.len()
1002 );
1003 }
1004
1005 pub fn merge_and_deduplicate(&mut self, new_data: Vec<TrainingData>) {
1007 use std::collections::HashSet;
1008
1009 if new_data.is_empty() {
1010 return;
1011 }
1012
1013 let _original_len = self.data.len();
1014
1015 let mut existing_positions: HashSet<String> = HashSet::with_capacity(self.data.len());
1017 for data in &self.data {
1018 existing_positions.insert(data.board.to_string());
1019 }
1020
1021 let mut added = 0;
1023 for data in new_data {
1024 let fen = data.board.to_string();
1025 if existing_positions.insert(fen) {
1026 self.data.push(data);
1027 added += 1;
1028 }
1029 }
1030
1031 println!(
1032 "Streaming merge: added {} unique positions (total: {})",
1033 added,
1034 self.data.len()
1035 );
1036 }
1037
    /// Removes near-duplicates whose encoded-vector cosine similarity to an
    /// earlier kept position exceeds `similarity_threshold`. The first
    /// occurrence wins; later near-matches are dropped.
    ///
    /// NOTE(review): pairwise comparison is O(n²) in the number of positions;
    /// `deduplicate_parallel` exists for larger datasets.
    fn deduplicate_similarity_based(&mut self, similarity_threshold: f32) {
        use crate::PositionEncoder;
        use ndarray::Array1;

        if self.data.is_empty() {
            return;
        }

        // 1024-dimensional position embedding used for similarity.
        let encoder = PositionEncoder::new(1024);
        let mut keep_indices: Vec<bool> = vec![true; self.data.len()];

        // Encode every board up front; parallelize only past a small cutoff
        // where rayon's scheduling overhead pays for itself.
        let vectors: Vec<Array1<f32>> = if self.data.len() > 50 {
            self.data
                .par_iter()
                .map(|data| encoder.encode(&data.board))
                .collect()
        } else {
            self.data
                .iter()
                .map(|data| encoder.encode(&data.board))
                .collect()
        };

        // Compare each position against all earlier kept positions.
        for i in 1..self.data.len() {
            if !keep_indices[i] {
                continue;
            }

            for j in 0..i {
                if !keep_indices[j] {
                    continue;
                }

                let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
                if similarity > similarity_threshold {
                    keep_indices[i] = false;
                    break;
                }
            }
        }

        // Compact the data down to the kept entries.
        let original_len = self.data.len();
        self.data = self
            .data
            .iter()
            .enumerate()
            .filter_map(|(i, data)| {
                if keep_indices[i] {
                    Some(data.clone())
                } else {
                    None
                }
            })
            .collect();

        println!(
            "Similarity deduplicated: {} -> {} positions (removed {} near-duplicates)",
            original_len,
            self.data.len(),
            original_len - self.data.len()
        );
    }
1104
    /// Parallel variant of similarity-based deduplication, processing index
    /// chunks of size `chunk_size` on the rayon pool.
    ///
    /// NOTE(review): the `keep_indices` reads and the later write happen
    /// under SEPARATE lock acquisitions (check-then-act), so one thread can
    /// compare against an entry another thread is concurrently marking as
    /// removed. The kept set may therefore differ slightly between runs —
    /// confirm this nondeterminism is acceptable before relying on exact
    /// outputs.
    pub fn deduplicate_parallel(&mut self, similarity_threshold: f32, chunk_size: usize) {
        use crate::PositionEncoder;
        use ndarray::Array1;
        use std::sync::{Arc, Mutex};

        if self.data.is_empty() {
            return;
        }

        let encoder = PositionEncoder::new(1024);

        // Encode all boards up front, in parallel.
        let vectors: Vec<Array1<f32>> = self
            .data
            .par_iter()
            .map(|data| encoder.encode(&data.board))
            .collect();

        // Shared keep/drop flags, one per position.
        let keep_indices = Arc::new(Mutex::new(vec![true; self.data.len()]));

        // Distribute the comparison work by index chunks.
        (1..self.data.len())
            .collect::<Vec<_>>()
            .par_chunks(chunk_size)
            .for_each(|chunk| {
                for &i in chunk {
                    {
                        // Skip entries already marked as duplicates.
                        let indices = keep_indices.lock().unwrap();
                        if !indices[i] {
                            continue;
                        }
                    }

                    for j in 0..i {
                        {
                            let indices = keep_indices.lock().unwrap();
                            if !indices[j] {
                                continue;
                            }
                        }

                        // Similarity computed outside the lock.
                        let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
                        if similarity > similarity_threshold {
                            let mut indices = keep_indices.lock().unwrap();
                            indices[i] = false;
                            break;
                        }
                    }
                }
            });

        // Compact the data down to the kept entries.
        let keep_indices = keep_indices.lock().unwrap();
        let original_len = self.data.len();
        self.data = self
            .data
            .iter()
            .enumerate()
            .filter_map(|(i, data)| {
                if keep_indices[i] {
                    Some(data.clone())
                } else {
                    None
                }
            })
            .collect();

        println!(
            "Parallel deduplicated: {} -> {} positions (removed {} duplicates)",
            original_len,
            self.data.len(),
            original_len - self.data.len()
        );
    }
1182
1183 fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
1185 let dot_product = a.dot(b);
1186 let norm_a = a.dot(a).sqrt();
1187 let norm_b = b.dot(b).sqrt();
1188
1189 if norm_a == 0.0 || norm_b == 0.0 {
1190 0.0
1191 } else {
1192 dot_product / (norm_a * norm_b)
1193 }
1194 }
1195}
1196
/// Generates training data by having the vector engine play against itself.
pub struct SelfPlayTrainer {
    /// Game-generation parameters.
    config: SelfPlayConfig,
    /// Monotonic counter used as the `game_id` of recorded positions.
    game_counter: usize,
}
1202
1203impl SelfPlayTrainer {
1204 pub fn new(config: SelfPlayConfig) -> Self {
1205 Self {
1206 config,
1207 game_counter: 0,
1208 }
1209 }
1210
1211 pub fn generate_training_data(&mut self, engine: &mut ChessVectorEngine) -> TrainingDataset {
1213 let mut dataset = TrainingDataset::new();
1214
1215 println!(
1216 "🎮 Starting self-play training with {} games...",
1217 self.config.games_per_iteration
1218 );
1219 let pb = ProgressBar::new(self.config.games_per_iteration as u64);
1220 if let Ok(style) = ProgressStyle::default_bar().template(
1221 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
1222 ) {
1223 pb.set_style(style.progress_chars("#>-"));
1224 }
1225
1226 for _ in 0..self.config.games_per_iteration {
1227 let game_data = self.play_single_game(engine);
1228 dataset.data.extend(game_data);
1229 self.game_counter += 1;
1230 pb.inc(1);
1231 }
1232
1233 pb.finish_with_message("Self-play games completed");
1234 println!(
1235 "✅ Generated {} positions from {} games",
1236 dataset.data.len(),
1237 self.config.games_per_iteration
1238 );
1239
1240 dataset
1241 }
1242
1243 fn play_single_game(&self, engine: &mut ChessVectorEngine) -> Vec<TrainingData> {
1245 let mut game = Game::new();
1246 let mut positions = Vec::new();
1247 let mut move_count = 0;
1248
1249 if self.config.use_opening_book {
1251 if let Some(opening_moves) = self.get_random_opening() {
1252 for mv in opening_moves {
1253 if game.make_move(mv) {
1254 move_count += 1;
1255 } else {
1256 break;
1257 }
1258 }
1259 }
1260 }
1261
1262 while game.result().is_none() && move_count < self.config.max_moves_per_game {
1264 let current_position = game.current_position();
1265
1266 let move_choice = self.select_move_with_exploration(engine, ¤t_position);
1268
1269 if let Some(chess_move) = move_choice {
1270 if let Some(evaluation) = engine.evaluate_position(¤t_position) {
1272 if evaluation.abs() >= self.config.min_confidence || move_count < 10 {
1274 positions.push(TrainingData {
1275 board: current_position,
1276 evaluation,
1277 depth: 1, game_id: self.game_counter,
1279 });
1280 }
1281 }
1282
1283 if !game.make_move(chess_move) {
1285 break; }
1287 move_count += 1;
1288 } else {
1289 break; }
1291 }
1292
1293 if let Some(result) = game.result() {
1295 let final_position = game.current_position();
1296 let final_eval = match result {
1297 chess::GameResult::WhiteCheckmates => {
1298 if final_position.side_to_move() == chess::Color::Black {
1299 10.0
1300 } else {
1301 -10.0
1302 }
1303 }
1304 chess::GameResult::BlackCheckmates => {
1305 if final_position.side_to_move() == chess::Color::White {
1306 10.0
1307 } else {
1308 -10.0
1309 }
1310 }
1311 chess::GameResult::WhiteResigns => -10.0,
1312 chess::GameResult::BlackResigns => 10.0,
1313 chess::GameResult::Stalemate
1314 | chess::GameResult::DrawAccepted
1315 | chess::GameResult::DrawDeclared => 0.0,
1316 };
1317
1318 positions.push(TrainingData {
1319 board: final_position,
1320 evaluation: final_eval,
1321 depth: 1,
1322 game_id: self.game_counter,
1323 });
1324 }
1325
1326 positions
1327 }
1328
1329 fn select_move_with_exploration(
1331 &self,
1332 engine: &mut ChessVectorEngine,
1333 position: &Board,
1334 ) -> Option<ChessMove> {
1335 let recommendations = engine.recommend_legal_moves(position, 5);
1336
1337 if recommendations.is_empty() {
1338 return None;
1339 }
1340
1341 if fastrand::f32() < self.config.exploration_factor {
1343 self.select_move_with_temperature(&recommendations)
1345 } else {
1346 Some(recommendations[0].chess_move)
1348 }
1349 }
1350
1351 fn select_move_with_temperature(
1353 &self,
1354 recommendations: &[crate::MoveRecommendation],
1355 ) -> Option<ChessMove> {
1356 if recommendations.is_empty() {
1357 return None;
1358 }
1359
1360 let mut probabilities = Vec::new();
1362 let mut sum = 0.0;
1363
1364 for rec in recommendations {
1365 let prob = (rec.average_outcome / self.config.temperature).exp();
1367 probabilities.push(prob);
1368 sum += prob;
1369 }
1370
1371 for prob in &mut probabilities {
1373 *prob /= sum;
1374 }
1375
1376 let rand_val = fastrand::f32();
1378 let mut cumulative = 0.0;
1379
1380 for (i, &prob) in probabilities.iter().enumerate() {
1381 cumulative += prob;
1382 if rand_val <= cumulative {
1383 return Some(recommendations[i].chess_move);
1384 }
1385 }
1386
1387 Some(recommendations[0].chess_move)
1389 }
1390
1391 fn get_random_opening(&self) -> Option<Vec<ChessMove>> {
1393 let openings = [
1394 vec!["e4", "e5", "Nf3", "Nc6", "Bc4"],
1396 vec!["e4", "e5", "Nf3", "Nc6", "Bb5"],
1398 vec!["d4", "d5", "c4"],
1400 vec!["d4", "Nf6", "c4", "g6"],
1402 vec!["e4", "c5"],
1404 vec!["e4", "e6"],
1406 vec!["e4", "c6"],
1408 ];
1409
1410 let selected_opening = &openings[fastrand::usize(0..openings.len())];
1411
1412 let mut moves = Vec::new();
1413 let mut game = Game::new();
1414
1415 for move_str in selected_opening {
1416 if let Ok(chess_move) = ChessMove::from_str(move_str) {
1417 if game.make_move(chess_move) {
1418 moves.push(chess_move);
1419 } else {
1420 break;
1421 }
1422 }
1423 }
1424
1425 if moves.is_empty() {
1426 None
1427 } else {
1428 Some(moves)
1429 }
1430 }
1431}
1432
/// Measures how closely the vector engine's evaluations track the
/// reference evaluations stored in a [`TrainingDataset`].
pub struct EngineEvaluator {
    /// Reference search depth the test data is assumed to come from.
    /// Currently informational only (no Stockfish process is driven from
    /// this type), hence `dead_code`.
    #[allow(dead_code)]
    stockfish_depth: u8,
}
1438
1439impl EngineEvaluator {
1440 pub fn new(stockfish_depth: u8) -> Self {
1441 Self { stockfish_depth }
1442 }
1443
1444 pub fn evaluate_accuracy(
1446 &self,
1447 engine: &mut ChessVectorEngine,
1448 test_data: &TrainingDataset,
1449 ) -> Result<f32, Box<dyn std::error::Error>> {
1450 let mut total_error = 0.0;
1451 let mut valid_comparisons = 0;
1452
1453 let pb = ProgressBar::new(test_data.data.len() as u64);
1454 pb.set_style(ProgressStyle::default_bar()
1455 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Evaluating accuracy")
1456 .unwrap()
1457 .progress_chars("#>-"));
1458
1459 for data in &test_data.data {
1460 if let Some(engine_eval) = engine.evaluate_position(&data.board) {
1461 let error = (engine_eval - data.evaluation).abs();
1462 total_error += error;
1463 valid_comparisons += 1;
1464 }
1465 pb.inc(1);
1466 }
1467
1468 pb.finish_with_message("Accuracy evaluation complete");
1469
1470 if valid_comparisons > 0 {
1471 let mean_absolute_error = total_error / valid_comparisons as f32;
1472 println!("Mean Absolute Error: {mean_absolute_error:.3} pawns");
1473 println!("Evaluated {valid_comparisons} positions");
1474 Ok(mean_absolute_error)
1475 } else {
1476 Ok(f32::INFINITY)
1477 }
1478 }
1479}
1480
1481impl serde::Serialize for TrainingData {
1483 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1484 where
1485 S: serde::Serializer,
1486 {
1487 use serde::ser::SerializeStruct;
1488 let mut state = serializer.serialize_struct("TrainingData", 4)?;
1489 state.serialize_field("fen", &self.board.to_string())?;
1490 state.serialize_field("evaluation", &self.evaluation)?;
1491 state.serialize_field("depth", &self.depth)?;
1492 state.serialize_field("game_id", &self.game_id)?;
1493 state.end()
1494 }
1495}
1496
/// Deserializes `TrainingData` from the `{fen, evaluation, depth, game_id}`
/// map format written by the matching `Serialize` impl.
///
/// Tolerant of legacy files: unknown fields are skipped, a missing
/// `game_id` defaults to 0, and centipawn-scale evaluations are normalized
/// down to pawns.
impl<'de> serde::Deserialize<'de> for TrainingData {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct TrainingDataVisitor;

        impl<'de> Visitor<'de> for TrainingDataVisitor {
            type Value = TrainingData;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct TrainingData")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TrainingData, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut fen = None;
                let mut evaluation = None;
                let mut depth = None;
                let mut game_id = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        "fen" => {
                            if fen.is_some() {
                                return Err(de::Error::duplicate_field("fen"));
                            }
                            fen = Some(map.next_value()?);
                        }
                        "evaluation" => {
                            if evaluation.is_some() {
                                return Err(de::Error::duplicate_field("evaluation"));
                            }
                            evaluation = Some(map.next_value()?);
                        }
                        "depth" => {
                            if depth.is_some() {
                                return Err(de::Error::duplicate_field("depth"));
                            }
                            depth = Some(map.next_value()?);
                        }
                        "game_id" => {
                            if game_id.is_some() {
                                return Err(de::Error::duplicate_field("game_id"));
                            }
                            game_id = Some(map.next_value()?);
                        }
                        _ => {
                            // Skip unknown fields so older/newer dataset
                            // formats still load instead of erroring.
                            let _: serde_json::Value = map.next_value()?;
                        }
                    }
                }

                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
                let mut evaluation: f32 =
                    evaluation.ok_or_else(|| de::Error::missing_field("evaluation"))?;
                let depth = depth.ok_or_else(|| de::Error::missing_field("depth"))?;
                // Legacy datasets may lack game_id; default instead of failing.
                let game_id = game_id.unwrap_or(0);

                // Magnitudes beyond +/-15 pawns are assumed to be
                // centipawns from an external engine; convert to pawns.
                if evaluation.abs() > 15.0 {
                    evaluation /= 100.0;
                }

                let board =
                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;

                Ok(TrainingData {
                    board,
                    evaluation,
                    depth,
                    game_id,
                })
            }
        }

        const FIELDS: &[&str] = &["fen", "evaluation", "depth", "game_id"];
        deserializer.deserialize_struct("TrainingData", FIELDS, TrainingDataVisitor)
    }
}
1584
1585pub struct TacticalPuzzleParser;
1587
/// Conversion pipeline from Lichess puzzle CSVs to engine training data:
/// parse -> validate/convert -> load into the engine, with both
/// unconditional and duplicate-aware (incremental) loading, plus JSON
/// persistence helpers.
impl TacticalPuzzleParser {
    /// Parse a Lichess puzzle CSV file, optionally capping the number of
    /// puzzles and filtering by rating range.
    ///
    /// Files over ~100 MB take the faster parallel line-splitting path;
    /// smaller files use the robust `csv`-crate reader.
    pub fn parse_csv<P: AsRef<Path>>(
        file_path: P,
        max_puzzles: Option<usize>,
        min_rating: Option<u32>,
        max_rating: Option<u32>,
    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
        let file = File::open(&file_path)?;
        let file_size = file.metadata()?.len();

        // Heuristic cut-over at 100 MB.
        if file_size > 100_000_000 {
            Self::parse_csv_parallel(file_path, max_puzzles, min_rating, max_rating)
        } else {
            Self::parse_csv_sequential(file_path, max_puzzles, min_rating, max_rating)
        }
    }

    /// Single-threaded parse using the `csv` crate (handles quoting and
    /// ragged rows correctly). Malformed or filtered-out rows are counted
    /// as skipped rather than aborting the whole parse.
    fn parse_csv_sequential<P: AsRef<Path>>(
        file_path: P,
        max_puzzles: Option<usize>,
        min_rating: Option<u32>,
        max_rating: Option<u32>,
    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
        let file = File::open(file_path)?;
        let reader = BufReader::new(file);

        // Lichess dumps ship without a header row; `flexible` tolerates a
        // varying number of trailing columns (GameUrl/OpeningTags optional).
        let mut csv_reader = csv::ReaderBuilder::new()
            .has_headers(false)
            .flexible(true)
            .from_reader(reader);

        let mut tactical_data = Vec::new();
        let mut processed = 0;
        let mut skipped = 0;

        let pb = ProgressBar::new_spinner();
        // NOTE(review): `{skipped}` is not a standard indicatif template
        // placeholder — the visible text is overwritten by `set_message`
        // below, but confirm this template parses (doesn't panic on
        // `.unwrap()`) with the indicatif version in use.
        pb.set_style(
            ProgressStyle::default_spinner()
                .template("{spinner:.green} Parsing tactical puzzles: {pos} (skipped: {skipped})")
                .unwrap(),
        );

        for result in csv_reader.records() {
            let record = match result {
                Ok(r) => r,
                Err(e) => {
                    skipped += 1;
                    println!("CSV parsing error: {e}");
                    continue;
                }
            };

            if let Some(puzzle_data) = Self::parse_csv_record(&record, min_rating, max_rating) {
                if let Some(tactical_data_item) =
                    Self::convert_puzzle_to_training_data(&puzzle_data)
                {
                    tactical_data.push(tactical_data_item);
                    processed += 1;

                    // The cap counts successfully converted puzzles,
                    // not raw CSV rows.
                    if let Some(max) = max_puzzles {
                        if processed >= max {
                            break;
                        }
                    }
                } else {
                    skipped += 1;
                }
            } else {
                skipped += 1;
            }

            pb.set_message(format!(
                "Parsing tactical puzzles: {processed} (skipped: {skipped})"
            ));
        }

        pb.finish_with_message(format!("Parsed {processed} puzzles (skipped: {skipped})"));

        Ok(tactical_data)
    }

    /// Fast path for very large files: slurp the whole file into memory
    /// and split lines across rayon workers.
    ///
    /// NOTE(review): this path splits on raw commas, so quoted fields
    /// containing commas are mis-parsed (the sequential path handles them
    /// via the `csv` crate). Also, `take(max_puzzles)` caps raw *lines*
    /// before filtering, whereas the sequential path caps *converted*
    /// puzzles — the two paths can return different counts for the same
    /// input. The progress bar here is created but never incremented.
    fn parse_csv_parallel<P: AsRef<Path>>(
        file_path: P,
        max_puzzles: Option<usize>,
        min_rating: Option<u32>,
        max_rating: Option<u32>,
    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
        use std::io::Read;

        let mut file = File::open(&file_path)?;

        let mut contents = String::new();
        file.read_to_string(&mut contents)?;

        let lines: Vec<&str> = contents.lines().collect();

        let pb = ProgressBar::new(lines.len() as u64);
        pb.set_style(ProgressStyle::default_bar()
            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel CSV parsing")
            .unwrap()
            .progress_chars("#>-"));

        let tactical_data: Vec<TacticalTrainingData> = lines
            .par_iter()
            .take(max_puzzles.unwrap_or(usize::MAX))
            .filter_map(|line| {
                // Naive comma split — see NOTE(review) above.
                let fields: Vec<&str> = line.split(',').collect();
                if fields.len() < 8 {
                    return None;
                }

                if let Some(puzzle_data) = Self::parse_csv_fields(&fields, min_rating, max_rating) {
                    Self::convert_puzzle_to_training_data(&puzzle_data)
                } else {
                    None
                }
            })
            .collect();

        pb.finish_with_message(format!(
            "Parallel parsing complete: {} puzzles",
            tactical_data.len()
        ));

        Ok(tactical_data)
    }

    /// Build a [`TacticalPuzzle`] from a parsed CSV record, applying the
    /// rating filters. Returns `None` for short rows, unparsable numeric
    /// fields, or out-of-range ratings.
    ///
    /// Expected column order (Lichess export): PuzzleId, FEN, Moves,
    /// Rating, RatingDeviation, Popularity, NbPlays, Themes
    /// [, GameUrl [, OpeningTags]].
    fn parse_csv_record(
        record: &csv::StringRecord,
        min_rating: Option<u32>,
        max_rating: Option<u32>,
    ) -> Option<TacticalPuzzle> {
        if record.len() < 8 {
            return None;
        }

        let rating: u32 = record[3].parse().ok()?;
        let rating_deviation: u32 = record[4].parse().ok()?;
        let popularity: i32 = record[5].parse().ok()?;
        let nb_plays: u32 = record[6].parse().ok()?;

        // Apply inclusive rating window, if configured.
        if let Some(min) = min_rating {
            if rating < min {
                return None;
            }
        }
        if let Some(max) = max_rating {
            if rating > max {
                return None;
            }
        }

        Some(TacticalPuzzle {
            puzzle_id: record[0].to_string(),
            fen: record[1].to_string(),
            moves: record[2].to_string(),
            rating,
            rating_deviation,
            popularity,
            nb_plays,
            themes: record[7].to_string(),
            game_url: if record.len() > 8 {
                Some(record[8].to_string())
            } else {
                None
            },
            opening_tags: if record.len() > 9 {
                Some(record[9].to_string())
            } else {
                None
            },
        })
    }

    /// Same as [`Self::parse_csv_record`] but over raw `&str` fields
    /// produced by the parallel comma splitter.
    fn parse_csv_fields(
        fields: &[&str],
        min_rating: Option<u32>,
        max_rating: Option<u32>,
    ) -> Option<TacticalPuzzle> {
        if fields.len() < 8 {
            return None;
        }

        let rating: u32 = fields[3].parse().ok()?;
        let rating_deviation: u32 = fields[4].parse().ok()?;
        let popularity: i32 = fields[5].parse().ok()?;
        let nb_plays: u32 = fields[6].parse().ok()?;

        // Apply inclusive rating window, if configured.
        if let Some(min) = min_rating {
            if rating < min {
                return None;
            }
        }
        if let Some(max) = max_rating {
            if rating > max {
                return None;
            }
        }

        Some(TacticalPuzzle {
            puzzle_id: fields[0].to_string(),
            fen: fields[1].to_string(),
            moves: fields[2].to_string(),
            rating,
            rating_deviation,
            popularity,
            nb_plays,
            themes: fields[7].to_string(),
            game_url: if fields.len() > 8 {
                Some(fields[8].to_string())
            } else {
                None
            },
            opening_tags: if fields.len() > 9 {
                Some(fields[9].to_string())
            } else {
                None
            },
        })
    }

    /// Validate a puzzle and convert it into training data: the FEN must
    /// parse, and the first entry of `Moves` must parse (coordinate
    /// notation first, then SAN) and be legal in the position.
    ///
    /// NOTE(review): in official Lichess exports the first entry of
    /// `Moves` is the opponent's setup move played *from* the FEN, and the
    /// solution is the second entry. Verify which convention this dataset
    /// uses before treating `moves[0]` as the solution.
    fn convert_puzzle_to_training_data(puzzle: &TacticalPuzzle) -> Option<TacticalTrainingData> {
        let position = match Board::from_str(&puzzle.fen) {
            Ok(board) => board,
            Err(_) => return None,
        };

        let moves: Vec<&str> = puzzle.moves.split_whitespace().collect();
        if moves.is_empty() {
            return None;
        }

        // Try coordinate/UCI-style notation first, then SAN relative to
        // the position.
        let solution_move = match ChessMove::from_str(moves[0]) {
            Ok(mv) => mv,
            Err(_) => {
                match ChessMove::from_san(&position, moves[0]) {
                    Ok(mv) => mv,
                    Err(_) => return None,
                }
            }
        };

        // Reject corrupt puzzles whose "solution" is not even legal here.
        let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&position).collect();
        if !legal_moves.contains(&solution_move) {
            return None;
        }

        let themes: Vec<&str> = puzzle.themes.split_whitespace().collect();
        let primary_theme = themes.first().unwrap_or(&"tactical").to_string();

        // Difficulty scales rating into ~1.0-3.0; popularity contributes
        // at most +2.0 to the overall tactical weight.
        let difficulty = puzzle.rating as f32 / 1000.0;
        let popularity_bonus = (puzzle.popularity as f32 / 100.0).min(2.0);
        let tactical_value = difficulty + popularity_bonus;

        Some(TacticalTrainingData {
            position,
            solution_move,
            move_theme: primary_theme,
            difficulty,
            tactical_value,
        })
    }

    /// Load every tactical pattern into the engine unconditionally (no
    /// duplicate check). The stored evaluation is 0.0 — the training
    /// signal is the solution move and its tactical weight, not a score.
    pub fn load_into_engine(
        tactical_data: &[TacticalTrainingData],
        engine: &mut ChessVectorEngine,
    ) {
        let pb = ProgressBar::new(tactical_data.len() as u64);
        pb.set_style(ProgressStyle::default_bar()
            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns")
            .unwrap()
            .progress_chars("#>-"));

        for data in tactical_data {
            engine.add_position_with_move(
                &data.position,
                0.0, // neutral evaluation; tactical_value carries the weight
                Some(data.solution_move),
                Some(data.tactical_value),
            );
            pb.inc(1);
        }

        pb.finish_with_message(format!("Loaded {} tactical patterns", tactical_data.len()));
    }

    /// Duplicate-aware loading: positions already known to the engine are
    /// skipped. Inputs larger than 1000 puzzles take the batched,
    /// parallel-prefiltered path.
    pub fn load_into_engine_incremental(
        tactical_data: &[TacticalTrainingData],
        engine: &mut ChessVectorEngine,
    ) {
        // Snapshot sizes so both paths can report the delta afterwards.
        let initial_size = engine.knowledge_base_size();
        let initial_moves = engine.position_moves.len();

        if tactical_data.len() > 1000 {
            Self::load_into_engine_incremental_parallel(
                tactical_data,
                engine,
                initial_size,
                initial_moves,
            );
        } else {
            Self::load_into_engine_incremental_sequential(
                tactical_data,
                engine,
                initial_size,
                initial_moves,
            );
        }
    }

    /// Sequential duplicate-aware load: one membership check against
    /// `engine.position_boards` per puzzle, inserting only unseen
    /// positions. Prints a before/after summary against the snapshots.
    fn load_into_engine_incremental_sequential(
        tactical_data: &[TacticalTrainingData],
        engine: &mut ChessVectorEngine,
        initial_size: usize,
        initial_moves: usize,
    ) {
        let pb = ProgressBar::new(tactical_data.len() as u64);
        pb.set_style(ProgressStyle::default_bar()
            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns (incremental)")
            .unwrap()
            .progress_chars("#>-"));

        let mut added = 0;
        let mut skipped = 0;

        for data in tactical_data {
            if !engine.position_boards.contains(&data.position) {
                engine.add_position_with_move(
                    &data.position,
                    0.0, // neutral evaluation; tactical_value carries the weight
                    Some(data.solution_move),
                    Some(data.tactical_value),
                );
                added += 1;
            } else {
                skipped += 1;
            }
            pb.inc(1);
        }

        pb.finish_with_message(format!(
            "Loaded {} new tactical patterns (skipped {} duplicates, total: {})",
            added,
            skipped,
            engine.knowledge_base_size()
        ));

        println!("Incremental tactical training:");
        println!(
            " - Positions: {} → {} (+{})",
            initial_size,
            engine.knowledge_base_size(),
            engine.knowledge_base_size() - initial_size
        );
        println!(
            " - Move entries: {} → {} (+{})",
            initial_moves,
            engine.position_moves.len(),
            engine.position_moves.len() - initial_moves
        );
    }

    /// Batched duplicate-aware load for large inputs: a parallel
    /// (read-only) pre-filter removes positions the engine already knows,
    /// then the survivors are inserted sequentially in batches.
    fn load_into_engine_incremental_parallel(
        tactical_data: &[TacticalTrainingData],
        engine: &mut ChessVectorEngine,
        initial_size: usize,
        initial_moves: usize,
    ) {
        let pb = ProgressBar::new(tactical_data.len() as u64);
        pb.set_style(ProgressStyle::default_bar()
            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Optimized batch loading tactical patterns")
            .unwrap()
            .progress_chars("#>-"));

        // Parallel pre-filter: membership tests only, no mutation, so
        // sharing `engine` immutably across rayon workers is sound.
        let filtered_data: Vec<&TacticalTrainingData> = tactical_data
            .par_iter()
            .filter(|data| !engine.position_boards.contains(&data.position))
            .collect();

        // NOTE(review): the bar length is tactical_data.len() but only
        // filtered_data items call pb.inc(), so the bar never reaches 100%
        // when duplicates were removed by the pre-filter.
        let batch_size = 1000;
        let mut added = 0;

        println!(
            "Pre-filtered: {} → {} positions (removed {} duplicates)",
            tactical_data.len(),
            filtered_data.len(),
            tactical_data.len() - filtered_data.len()
        );

        for batch in filtered_data.chunks(batch_size) {
            let batch_start = added;

            for data in batch {
                // Re-check: the input itself may contain duplicate
                // positions that only collide once earlier batch items
                // have been inserted.
                if !engine.position_boards.contains(&data.position) {
                    engine.add_position_with_move(
                        &data.position,
                        0.0, // neutral evaluation; tactical_value carries the weight
                        Some(data.solution_move),
                        Some(data.tactical_value),
                    );
                    added += 1;
                }
                pb.inc(1);
            }

            pb.set_message(format!("Loaded batch: {} positions", added - batch_start));
        }

        let skipped = tactical_data.len() - added;

        pb.finish_with_message(format!(
            "Optimized loaded {} new tactical patterns (skipped {} duplicates, total: {})",
            added,
            skipped,
            engine.knowledge_base_size()
        ));

        println!("Incremental tactical training (optimized):");
        println!(
            " - Positions: {} → {} (+{})",
            initial_size,
            engine.knowledge_base_size(),
            engine.knowledge_base_size() - initial_size
        );
        println!(
            " - Move entries: {} → {} (+{})",
            initial_moves,
            engine.position_moves.len(),
            engine.position_moves.len() - initial_moves
        );
        println!(
            " - Batch size: {}, Pre-filtered efficiency: {:.1}%",
            batch_size,
            (filtered_data.len() as f32 / tactical_data.len() as f32) * 100.0
        );
    }

    /// Write `tactical_data` to `path` as pretty-printed JSON,
    /// overwriting any existing file.
    pub fn save_tactical_puzzles<P: AsRef<std::path::Path>>(
        tactical_data: &[TacticalTrainingData],
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let json = serde_json::to_string_pretty(tactical_data)?;
        std::fs::write(path, json)?;
        println!("Saved {} tactical puzzles", tactical_data.len());
        Ok(())
    }

    /// Load tactical puzzles previously saved with
    /// [`Self::save_tactical_puzzles`].
    pub fn load_tactical_puzzles<P: AsRef<std::path::Path>>(
        path: P,
    ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
        let content = std::fs::read_to_string(path)?;
        let tactical_data: Vec<TacticalTrainingData> = serde_json::from_str(&content)?;
        println!("Loaded {} tactical puzzles from file", tactical_data.len());
        Ok(tactical_data)
    }

    /// Merge `tactical_data` into an existing JSON file, skipping entries
    /// whose (position, solution_move) pair is already present. Creates
    /// the file if it does not exist.
    pub fn save_tactical_puzzles_incremental<P: AsRef<std::path::Path>>(
        tactical_data: &[TacticalTrainingData],
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path = path.as_ref();

        if path.exists() {
            let mut existing = Self::load_tactical_puzzles(path)?;
            let original_len = existing.len();

            for new_puzzle in tactical_data {
                // O(n*m) duplicate scan keyed on (position, solution_move).
                let exists = existing.iter().any(|existing_puzzle| {
                    existing_puzzle.position == new_puzzle.position
                        && existing_puzzle.solution_move == new_puzzle.solution_move
                });

                if !exists {
                    existing.push(new_puzzle.clone());
                }
            }

            let json = serde_json::to_string_pretty(&existing)?;
            std::fs::write(path, json)?;

            println!(
                "Incremental save: added {} new puzzles (total: {})",
                existing.len() - original_len,
                existing.len()
            );
        } else {
            // No existing file: a plain save is equivalent.
            Self::save_tactical_puzzles(tactical_data, path)?;
        }
        Ok(())
    }

    /// Convenience entry point: parse a Lichess CSV and load the result
    /// into the engine with duplicate-aware (incremental) loading.
    pub fn parse_and_load_incremental<P: AsRef<std::path::Path>>(
        file_path: P,
        engine: &mut ChessVectorEngine,
        max_puzzles: Option<usize>,
        min_rating: Option<u32>,
        max_rating: Option<u32>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        println!("Parsing Lichess puzzles incrementally...");

        let tactical_data = Self::parse_csv(file_path, max_puzzles, min_rating, max_rating)?;

        Self::load_into_engine_incremental(&tactical_data, engine);

        Ok(())
    }
}
2142
#[cfg(test)]
mod tests {
    // Unit tests for the training pipeline: dataset bookkeeping,
    // (de)serialization round-trips, tactical puzzle conversion, and
    // incremental save/load with deduplication.
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    #[test]
    fn test_training_dataset_creation() {
        let dataset = TrainingDataset::new();
        assert_eq!(dataset.data.len(), 0);
    }

    #[test]
    fn test_add_training_data() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.5,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
    }

    #[test]
    fn test_chess_engine_integration() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.3,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        let mut engine = ChessVectorEngine::new(1024);
        dataset.train_engine(&mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);

        // The trained position should evaluate back to its stored score.
        let eval = engine.evaluate_position(&board);
        assert!(eval.is_some());
        assert!((eval.unwrap() - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_deduplication() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        // Five entries for the same board, differing only in evaluation.
        for i in 0..5 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        assert_eq!(dataset.data.len(), 5);

        // At a 0.999 similarity threshold, identical boards collapse to one.
        dataset.deduplicate(0.999);
        assert_eq!(dataset.data.len(), 1);
    }

    #[test]
    fn test_dataset_serialization() {
        let mut dataset = TrainingDataset::new();
        let board =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        let training_data = TrainingData {
            board,
            evaluation: 0.2,
            depth: 10,
            game_id: 42,
        };

        dataset.data.push(training_data);

        // JSON round-trip through the custom Serialize/Deserialize impls.
        let json = serde_json::to_string(&dataset.data).unwrap();
        let loaded_data: Vec<TrainingData> = serde_json::from_str(&json).unwrap();
        let loaded_dataset = TrainingDataset { data: loaded_data };

        assert_eq!(loaded_dataset.data.len(), 1);
        assert_eq!(loaded_dataset.data[0].evaluation, 0.2);
        assert_eq!(loaded_dataset.data[0].depth, 10);
        assert_eq!(loaded_dataset.data[0].game_id, 42);
    }

    #[test]
    fn test_tactical_puzzle_processing() {
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4".to_string(),
            moves: "Bxf7+ Ke7".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "fork pin".to_string(),
            game_url: None,
            opening_tags: None,
        };

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_some());

        let data = tactical_data.unwrap();
        // First theme wins; tactical value = rating/1000 + capped popularity.
        assert_eq!(data.move_theme, "fork");
        assert!(data.tactical_value > 1.0);
        assert!(data.difficulty > 0.0);
    }

    #[test]
    fn test_tactical_puzzle_invalid_fen() {
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "invalid_fen".to_string(),
            moves: "e2e4".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "tactics".to_string(),
            game_url: None,
            opening_tags: None,
        };

        // An unparsable FEN must be rejected, not panic.
        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_none());
    }

    #[test]
    fn test_engine_evaluator() {
        let evaluator = EngineEvaluator::new(15);

        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.0,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        // Engine stores 0.1 vs. reference 0.0 -> MAE should be ~0.1.
        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&board, 0.1);

        let accuracy = evaluator.evaluate_accuracy(&mut engine, &dataset);
        assert!(accuracy.is_ok());
        assert!(accuracy.unwrap() < 1.0);
    }

    #[test]
    fn test_tactical_training_integration() {
        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.2,
            tactical_value: 2.5,
        }];

        let mut engine = ChessVectorEngine::new(1024);
        TacticalPuzzleParser::load_into_engine(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);
        assert_eq!(engine.position_moves.len(), 1);

        // The loaded move should surface as a recommendation.
        let recommendations = engine.recommend_moves(&Board::default(), 5);
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_multithreading_operations() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        for i in 0..10 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        // Parallel dedup must never grow the dataset.
        dataset.deduplicate_parallel(0.95, 5);
        assert!(dataset.data.len() <= 10);
    }

    #[test]
    fn test_incremental_dataset_operations() {
        let mut dataset1 = TrainingDataset::new();
        let board1 = Board::default();
        let board2 =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        dataset1.add_position(board1, 0.0, 15, 1);
        dataset1.add_position(board2, 0.2, 15, 2);
        assert_eq!(dataset1.data.len(), 2);

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );

        dataset1.merge(dataset2);
        assert_eq!(dataset1.data.len(), 3);

        // Highest game_id seen is 3, so the next id should be 4.
        let next_id = dataset1.next_game_id();
        assert_eq!(next_id, 4);
    }

    #[test]
    fn test_save_load_incremental() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        // Appending a distinct position via save_incremental -> 2 on disk.
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap(),
            0.2,
            15,
            2,
        );
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 2);

        // load_and_append merges the file into an in-memory dataset.
        let mut dataset3 = TrainingDataset::new();
        dataset3.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );
        dataset3.load_and_append(&file_path).unwrap();
        assert_eq!(dataset3.data.len(), 3);
    }

    #[test]
    fn test_add_position_method() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        dataset.add_position(board, 0.5, 20, 42);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
        assert_eq!(dataset.data[0].depth, 20);
        assert_eq!(dataset.data[0].game_id, 42);
    }

    #[test]
    fn test_incremental_save_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        // Same board, different evaluation: should dedupe to one entry.
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(Board::default(), 0.1, 15, 2);
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 1);
    }

    #[test]
    fn test_tactical_puzzle_incremental_loading() {
        let tactical_data = vec![
            TacticalTrainingData {
                position: Board::default(),
                solution_move: ChessMove::from_str("e2e4").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.2,
                tactical_value: 2.5,
            },
            TacticalTrainingData {
                position: Board::from_str(
                    "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1",
                )
                .unwrap(),
                solution_move: ChessMove::from_str("e7e5").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.0,
                tactical_value: 2.0,
            },
        ];

        let mut engine = ChessVectorEngine::new(1024);

        // Pre-seed the default board so one of the two puzzles is a duplicate.
        engine.add_position(&Board::default(), 0.1);
        assert_eq!(engine.knowledge_base_size(), 1);

        TacticalPuzzleParser::load_into_engine_incremental(&tactical_data, &mut engine);

        // Only the unseen position should have been added (1 + 1 = 2).
        assert_eq!(engine.knowledge_base_size(), 2);

        assert!(engine.training_stats().has_move_data);
        assert!(engine.training_stats().move_data_entries > 0);
    }

    #[test]
    fn test_tactical_puzzle_serialization() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("tactical_test.json");

        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "fork".to_string(),
            difficulty: 1.5,
            tactical_value: 3.0,
        }];

        TacticalPuzzleParser::save_tactical_puzzles(&tactical_data, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].move_theme, "fork");
        assert_eq!(loaded[0].difficulty, 1.5);
        assert_eq!(loaded[0].tactical_value, 3.0);
    }

    #[test]
    fn test_tactical_puzzle_incremental_save() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_tactical.json");

        let batch1 = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        }];
        TacticalPuzzleParser::save_tactical_puzzles(&batch1, &file_path).unwrap();

        let batch2 = vec![TacticalTrainingData {
            position: Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1")
                .unwrap(),
            solution_move: ChessMove::from_str("e7e5").unwrap(),
            move_theme: "counter".to_string(),
            difficulty: 1.2,
            tactical_value: 2.2,
        }];
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&batch2, &file_path).unwrap();

        // Distinct puzzles from both batches must both survive on disk.
        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 2);
    }

    #[test]
    fn test_tactical_puzzle_incremental_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_tactical.json");

        let tactical_data = TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        };

        TacticalPuzzleParser::save_tactical_puzzles(&[tactical_data.clone()], &file_path).unwrap();

        // Re-saving the identical puzzle incrementally must not duplicate it.
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&[tactical_data], &file_path)
            .unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
    }
}
2586}