1use chess::{Board, ChessMove, Game, MoveGen};
2use indicatif::{ProgressBar, ProgressStyle};
3use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::{BufRead, BufReader, BufWriter, Write};
8use std::path::Path;
9use std::process::{Child, Command, Stdio};
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12
13use crate::ChessVectorEngine;
14
/// Configuration for self-play training data generation.
#[derive(Debug, Clone)]
pub struct SelfPlayConfig {
    // Number of self-play games generated per training iteration.
    pub games_per_iteration: usize,
    // Hard cap on plies per game before the game is abandoned.
    pub max_moves_per_game: usize,
    // Probability of picking a temperature-sampled move instead of the top recommendation.
    pub exploration_factor: f32,
    // Minimum |evaluation| required to record a position (early moves are exempt).
    pub min_confidence: f32,
    // Whether to seed games from the small built-in opening book.
    pub use_opening_book: bool,
    // Softmax temperature used when sampling exploratory moves.
    pub temperature: f32,
}
31
32impl Default for SelfPlayConfig {
33 fn default() -> Self {
34 Self {
35 games_per_iteration: 100,
36 max_moves_per_game: 200,
37 exploration_factor: 0.3,
38 min_confidence: 0.1,
39 use_opening_book: true,
40 temperature: 0.8,
41 }
42 }
43}
44
/// One labelled training example: a position plus its engine evaluation.
///
/// NOTE(review): `TrainingDataset::save`/`load` pass `Vec<TrainingData>`
/// through serde_json, but no Serialize/Deserialize derive or impl is visible
/// in this chunk — presumably implemented elsewhere in the file; confirm.
#[derive(Debug, Clone)]
pub struct TrainingData {
    // The chess position this example describes.
    pub board: Board,
    // Score in pawns (Stockfish centipawns / 100); mates clamped to +/-100.0.
    pub evaluation: f32,
    // Search depth that produced `evaluation` (0 = not yet evaluated).
    pub depth: u8,
    // Identifier of the source game; used for game-level train/test splits.
    pub game_id: usize,
}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct TacticalPuzzle {
57 #[serde(rename = "PuzzleId")]
58 pub puzzle_id: String,
59 #[serde(rename = "FEN")]
60 pub fen: String,
61 #[serde(rename = "Moves")]
62 pub moves: String, #[serde(rename = "Rating")]
64 pub rating: u32,
65 #[serde(rename = "RatingDeviation")]
66 pub rating_deviation: u32,
67 #[serde(rename = "Popularity")]
68 pub popularity: i32,
69 #[serde(rename = "NbPlays")]
70 pub nb_plays: u32,
71 #[serde(rename = "Themes")]
72 pub themes: String, #[serde(rename = "GameUrl")]
74 pub game_url: Option<String>,
75 #[serde(rename = "OpeningTags")]
76 pub opening_tags: Option<String>,
77}
78
79#[derive(Debug, Clone)]
81pub struct TacticalTrainingData {
82 pub position: Board,
83 pub solution_move: ChessMove,
84 pub move_theme: String,
85 pub difficulty: f32, pub tactical_value: f32, }
88
impl serde::Serialize for TacticalTrainingData {
    /// Serializes the position as its FEN string and the move as its string
    /// form, so the type round-trips through plain JSON maps.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;
        let mut state = serializer.serialize_struct("TacticalTrainingData", 5)?;
        // Board and ChessMove are stored as strings; the Deserialize impl
        // re-parses them.
        state.serialize_field("fen", &self.position.to_string())?;
        state.serialize_field("solution_move", &self.solution_move.to_string())?;
        state.serialize_field("move_theme", &self.move_theme)?;
        state.serialize_field("difficulty", &self.difficulty)?;
        state.serialize_field("tactical_value", &self.tactical_value)?;
        state.end()
    }
}
105
impl<'de> serde::Deserialize<'de> for TacticalTrainingData {
    /// Deserializes from the map form written by the hand-rolled `Serialize`
    /// impl above: `fen` and `solution_move` are strings and are re-parsed
    /// into `Board` / `ChessMove` here.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct TacticalTrainingDataVisitor;

        impl<'de> Visitor<'de> for TacticalTrainingDataVisitor {
            type Value = TacticalTrainingData;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct TacticalTrainingData")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TacticalTrainingData, V::Error>
            where
                V: MapAccess<'de>,
            {
                // Each known field may appear at most once; duplicates error.
                let mut fen = None;
                let mut solution_move = None;
                let mut move_theme = None;
                let mut difficulty = None;
                let mut tactical_value = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        "fen" => {
                            if fen.is_some() {
                                return Err(de::Error::duplicate_field("fen"));
                            }
                            fen = Some(map.next_value()?);
                        }
                        "solution_move" => {
                            if solution_move.is_some() {
                                return Err(de::Error::duplicate_field("solution_move"));
                            }
                            solution_move = Some(map.next_value()?);
                        }
                        "move_theme" => {
                            if move_theme.is_some() {
                                return Err(de::Error::duplicate_field("move_theme"));
                            }
                            move_theme = Some(map.next_value()?);
                        }
                        "difficulty" => {
                            if difficulty.is_some() {
                                return Err(de::Error::duplicate_field("difficulty"));
                            }
                            difficulty = Some(map.next_value()?);
                        }
                        "tactical_value" => {
                            if tactical_value.is_some() {
                                return Err(de::Error::duplicate_field("tactical_value"));
                            }
                            tactical_value = Some(map.next_value()?);
                        }
                        _ => {
                            // Unknown fields are consumed and ignored.
                            let _: serde_json::Value = map.next_value()?;
                        }
                    }
                }

                // All five fields are required.
                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
                let solution_move_str: String =
                    solution_move.ok_or_else(|| de::Error::missing_field("solution_move"))?;
                let move_theme =
                    move_theme.ok_or_else(|| de::Error::missing_field("move_theme"))?;
                let difficulty =
                    difficulty.ok_or_else(|| de::Error::missing_field("difficulty"))?;
                let tactical_value =
                    tactical_value.ok_or_else(|| de::Error::missing_field("tactical_value"))?;

                // Re-parse the string forms written by Serialize.
                let position =
                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;

                let solution_move = ChessMove::from_str(&solution_move_str)
                    .map_err(|e| de::Error::custom(format!("Error: {e}")))?;

                Ok(TacticalTrainingData {
                    position,
                    solution_move,
                    move_theme,
                    difficulty,
                    tactical_value,
                })
            }
        }

        const FIELDS: &[&str] = &[
            "fen",
            "solution_move",
            "move_theme",
            "difficulty",
            "tactical_value",
        ];
        deserializer.deserialize_struct("TacticalTrainingData", FIELDS, TacticalTrainingDataVisitor)
    }
}
207
/// `pgn_reader::Visitor` that replays PGN main lines and collects one
/// `TrainingData` entry per applied move (evaluations filled in later).
pub struct GameExtractor {
    // Positions collected so far across all visited games.
    pub positions: Vec<TrainingData>,
    // Game state being replayed for the current PGN game.
    pub current_game: Game,
    // Plies applied in the current game.
    pub move_count: usize,
    // Per-game cap on recorded plies.
    pub max_moves_per_game: usize,
    // Incremented on every `begin_game`; stamped onto extracted positions.
    pub game_id: usize,
}
216
217impl GameExtractor {
218 pub fn new(max_moves_per_game: usize) -> Self {
219 Self {
220 positions: Vec::new(),
221 current_game: Game::new(),
222 move_count: 0,
223 max_moves_per_game,
224 game_id: 0,
225 }
226 }
227}
228
229impl Visitor for GameExtractor {
230 type Result = ();
231
232 fn begin_game(&mut self) {
233 self.current_game = Game::new();
234 self.move_count = 0;
235 self.game_id += 1;
236 }
237
238 fn header(&mut self, _key: &[u8], _value: RawHeader<'_>) {}
239
240 fn san(&mut self, san_plus: SanPlus) {
241 if self.move_count >= self.max_moves_per_game {
242 return;
243 }
244
245 let san_str = san_plus.san.to_string();
246
247 let current_pos = self.current_game.current_position();
249
250 match chess::ChessMove::from_san(¤t_pos, &san_str) {
252 Ok(chess_move) => {
253 let legal_moves: Vec<chess::ChessMove> = MoveGen::new_legal(¤t_pos).collect();
255 if legal_moves.contains(&chess_move) {
256 if self.current_game.make_move(chess_move) {
257 self.move_count += 1;
258
259 self.positions.push(TrainingData {
261 board: self.current_game.current_position(),
262 evaluation: 0.0, depth: 0,
264 game_id: self.game_id,
265 });
266 }
267 } else {
268 }
270 }
271 Err(_) => {
272 if !san_str.contains("O-O") && !san_str.contains("=") && san_str.len() > 6 {
276 }
278 }
279 }
280 }
281
282 fn begin_variation(&mut self) -> Skip {
283 Skip(true) }
285
286 fn end_game(&mut self) -> Self::Result {}
287}
288
/// Stateless evaluator that spawns a fresh Stockfish process per query.
pub struct StockfishEvaluator {
    // Search depth passed to the UCI "go depth" command.
    depth: u8,
}
293
294impl StockfishEvaluator {
295 pub fn new(depth: u8) -> Self {
296 Self { depth }
297 }
298
299 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
301 let mut child = Command::new("stockfish")
302 .stdin(Stdio::piped())
303 .stdout(Stdio::piped())
304 .stderr(Stdio::piped())
305 .spawn()?;
306
307 let stdin = child
308 .stdin
309 .as_mut()
310 .ok_or("Failed to get stdin handle for Stockfish process")?;
311 let fen = board.to_string();
312
313 use std::io::Write;
315 writeln!(stdin, "uci")?;
316 writeln!(stdin, "isready")?;
317 writeln!(stdin, "position fen {fen}")?;
318 writeln!(stdin, "go depth {}", self.depth)?;
319 writeln!(stdin, "quit")?;
320
321 let output = child.wait_with_output()?;
322 let stdout = String::from_utf8_lossy(&output.stdout);
323
324 for line in stdout.lines() {
326 if line.starts_with("info") && line.contains("score cp") {
327 if let Some(cp_pos) = line.find("score cp ") {
328 let cp_str = &line[cp_pos + 9..];
329 if let Some(end) = cp_str.find(' ') {
330 let cp_value = cp_str[..end].parse::<i32>()?;
331 return Ok(cp_value as f32 / 100.0); }
333 }
334 } else if line.starts_with("info") && line.contains("score mate") {
335 if let Some(mate_pos) = line.find("score mate ") {
337 let mate_str = &line[mate_pos + 11..];
338 if let Some(end) = mate_str.find(' ') {
339 let mate_moves = mate_str[..end].parse::<i32>()?;
340 return Ok(if mate_moves > 0 { 100.0 } else { -100.0 });
341 }
342 }
343 }
344 }
345
346 Ok(0.0) }
348
349 pub fn evaluate_batch(
351 &self,
352 positions: &mut [TrainingData],
353 ) -> Result<(), Box<dyn std::error::Error>> {
354 let pb = ProgressBar::new(positions.len() as u64);
355 if let Ok(style) = ProgressStyle::default_bar().template(
356 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
357 ) {
358 pb.set_style(style.progress_chars("#>-"));
359 }
360
361 for data in positions.iter_mut() {
362 match self.evaluate_position(&data.board) {
363 Ok(eval) => {
364 data.evaluation = eval;
365 data.depth = self.depth;
366 }
367 Err(e) => {
368 eprintln!("Evaluation error: {e}");
369 data.evaluation = 0.0;
370 }
371 }
372 pb.inc(1);
373 }
374
375 pb.finish_with_message("Evaluation complete");
376 Ok(())
377 }
378
379 pub fn evaluate_batch_parallel(
381 &self,
382 positions: &mut [TrainingData],
383 num_threads: usize,
384 ) -> Result<(), Box<dyn std::error::Error>> {
385 let pb = ProgressBar::new(positions.len() as u64);
386 if let Ok(style) = ProgressStyle::default_bar()
387 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel evaluation") {
388 pb.set_style(style.progress_chars("#>-"));
389 }
390
391 let pool = rayon::ThreadPoolBuilder::new()
393 .num_threads(num_threads)
394 .build()?;
395
396 pool.install(|| {
397 positions.par_iter_mut().for_each(|data| {
399 match self.evaluate_position(&data.board) {
400 Ok(eval) => {
401 data.evaluation = eval;
402 data.depth = self.depth;
403 }
404 Err(_) => {
405 data.evaluation = 0.0;
407 }
408 }
409 pb.inc(1);
410 });
411 });
412
413 pb.finish_with_message("Parallel evaluation complete");
414 Ok(())
415 }
416}
417
/// A persistent Stockfish child process with buffered UCI pipes.
struct StockfishProcess {
    // Handle to the child; reaped in `Drop`.
    child: Child,
    // Buffered writer over the engine's stdin (UCI commands).
    stdin: BufWriter<std::process::ChildStdin>,
    // Buffered reader over the engine's stdout (UCI replies).
    stdout: BufReader<std::process::ChildStdout>,
    #[allow(dead_code)]
    // Search depth for "go depth" requests.
    depth: u8,
}
426
427impl StockfishProcess {
428 fn new(depth: u8) -> Result<Self, Box<dyn std::error::Error>> {
429 let mut child = Command::new("stockfish")
430 .stdin(Stdio::piped())
431 .stdout(Stdio::piped())
432 .stderr(Stdio::piped())
433 .spawn()?;
434
435 let stdin = BufWriter::new(
436 child
437 .stdin
438 .take()
439 .ok_or("Failed to get stdin handle for Stockfish process")?,
440 );
441 let stdout = BufReader::new(
442 child
443 .stdout
444 .take()
445 .ok_or("Failed to get stdout handle for Stockfish process")?,
446 );
447
448 let mut process = Self {
449 child,
450 stdin,
451 stdout,
452 depth,
453 };
454
455 process.send_command("uci")?;
457 process.wait_for_ready()?;
458 process.send_command("isready")?;
459 process.wait_for_ready()?;
460
461 Ok(process)
462 }
463
464 fn send_command(&mut self, command: &str) -> Result<(), Box<dyn std::error::Error>> {
465 writeln!(self.stdin, "{command}")?;
466 self.stdin.flush()?;
467 Ok(())
468 }
469
470 fn wait_for_ready(&mut self) -> Result<(), Box<dyn std::error::Error>> {
471 let mut line = String::new();
472 loop {
473 line.clear();
474 self.stdout.read_line(&mut line)?;
475 if line.trim() == "uciok" || line.trim() == "readyok" {
476 break;
477 }
478 }
479 Ok(())
480 }
481
482 fn evaluate_position(&mut self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
483 let fen = board.to_string();
484
485 self.send_command(&format!("position fen {fen}"))?;
487 self.send_command(&format!("position fen {fen}"))?;
488
489 let mut line = String::new();
491 let mut last_evaluation = 0.0;
492
493 loop {
494 line.clear();
495 self.stdout.read_line(&mut line)?;
496 let line = line.trim();
497
498 if line.starts_with("info") && line.contains("score cp") {
499 if let Some(cp_pos) = line.find("score cp ") {
500 let cp_str = &line[cp_pos + 9..];
501 if let Some(end) = cp_str.find(' ') {
502 if let Ok(cp_value) = cp_str[..end].parse::<i32>() {
503 last_evaluation = cp_value as f32 / 100.0;
504 }
505 }
506 }
507 } else if line.starts_with("info") && line.contains("score mate") {
508 if let Some(mate_pos) = line.find("score mate ") {
509 let mate_str = &line[mate_pos + 11..];
510 if let Some(end) = mate_str.find(' ') {
511 if let Ok(mate_moves) = mate_str[..end].parse::<i32>() {
512 last_evaluation = if mate_moves > 0 { 100.0 } else { -100.0 };
513 }
514 }
515 }
516 } else if line.starts_with("bestmove") {
517 break;
518 }
519 }
520
521 Ok(last_evaluation)
522 }
523}
524
impl Drop for StockfishProcess {
    /// Best-effort shutdown: ask the engine to quit, then reap the child so
    /// it does not linger as a zombie. Errors are ignored during drop.
    fn drop(&mut self) {
        let _ = self.send_command("quit");
        let _ = self.child.wait();
    }
}
531
/// A checkout/check-in pool of persistent Stockfish processes, shared across
/// rayon worker threads.
pub struct StockfishPool {
    // Idle processes available for checkout.
    pool: Arc<Mutex<Vec<StockfishProcess>>>,
    // Search depth used for processes spawned on demand.
    depth: u8,
    // Upper bound on how many idle processes are retained.
    pool_size: usize,
}
538
539impl StockfishPool {
540 pub fn new(depth: u8, pool_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
541 let mut processes = Vec::with_capacity(pool_size);
542
543 println!(
544 "🚀 Initializing Stockfish pool with {} processes...",
545 pool_size
546 );
547
548 for i in 0..pool_size {
549 match StockfishProcess::new(depth) {
550 Ok(process) => {
551 processes.push(process);
552 if i % 2 == 1 {
553 print!(".");
554 let _ = std::io::stdout().flush(); }
556 }
557 Err(e) => {
558 eprintln!("Evaluation error: {e}");
559 return Err(e);
560 }
561 }
562 }
563
564 println!(" ✅ Pool ready!");
565
566 Ok(Self {
567 pool: Arc::new(Mutex::new(processes)),
568 depth,
569 pool_size,
570 })
571 }
572
573 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
574 let mut process = {
576 let mut pool = self.pool.lock().unwrap();
577 if let Some(process) = pool.pop() {
578 process
579 } else {
580 StockfishProcess::new(self.depth)?
582 }
583 };
584
585 let result = process.evaluate_position(board);
587
588 {
590 let mut pool = self.pool.lock().unwrap();
591 if pool.len() < self.pool_size {
592 pool.push(process);
593 }
594 }
596
597 result
598 }
599
600 pub fn evaluate_batch_parallel(
601 &self,
602 positions: &mut [TrainingData],
603 ) -> Result<(), Box<dyn std::error::Error>> {
604 let pb = ProgressBar::new(positions.len() as u64);
605 pb.set_style(ProgressStyle::default_bar()
606 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Pool evaluation")
607 .unwrap()
608 .progress_chars("#>-"));
609
610 positions.par_iter_mut().for_each(|data| {
612 match self.evaluate_position(&data.board) {
613 Ok(eval) => {
614 data.evaluation = eval;
615 data.depth = self.depth;
616 }
617 Err(_) => {
618 data.evaluation = 0.0;
619 }
620 }
621 pb.inc(1);
622 });
623
624 pb.finish_with_message("Pool evaluation complete");
625 Ok(())
626 }
627}
628
/// A collection of labelled training positions with load/save, merge, and
/// deduplication utilities.
pub struct TrainingDataset {
    // All training examples, in insertion order.
    pub data: Vec<TrainingData>,
}
633
634impl Default for TrainingDataset {
635 fn default() -> Self {
636 Self::new()
637 }
638}
639
640impl TrainingDataset {
641 pub fn new() -> Self {
642 Self { data: Vec::new() }
643 }
644
645 pub fn load_from_pgn<P: AsRef<Path>>(
647 &mut self,
648 path: P,
649 max_games: Option<usize>,
650 max_moves_per_game: usize,
651 ) -> Result<(), Box<dyn std::error::Error>> {
652 let file = File::open(path)?;
653 let reader = BufReader::new(file);
654
655 let mut extractor = GameExtractor::new(max_moves_per_game);
656 let mut games_processed = 0;
657
658 let mut pgn_content = String::new();
660 for line in reader.lines() {
661 let line = line?;
662 pgn_content.push_str(&line);
663 pgn_content.push('\n');
664
665 if line.trim().ends_with("1-0")
667 || line.trim().ends_with("0-1")
668 || line.trim().ends_with("1/2-1/2")
669 || line.trim().ends_with("*")
670 {
671 let cursor = std::io::Cursor::new(&pgn_content);
673 let mut reader = BufferedReader::new(cursor);
674 if let Err(e) = reader.read_all(&mut extractor) {
675 eprintln!("Evaluation error: {e}");
676 }
677
678 games_processed += 1;
679 pgn_content.clear();
680
681 if let Some(max) = max_games {
682 if games_processed >= max {
683 break;
684 }
685 }
686
687 if games_processed % 100 == 0 {
688 println!(
689 "Processed {} games, extracted {} positions",
690 games_processed,
691 extractor.positions.len()
692 );
693 }
694 }
695 }
696
697 self.data.extend(extractor.positions);
698 println!(
699 "Loaded {} positions from {} games",
700 self.data.len(),
701 games_processed
702 );
703 Ok(())
704 }
705
706 pub fn evaluate_with_stockfish(&mut self, depth: u8) -> Result<(), Box<dyn std::error::Error>> {
708 let evaluator = StockfishEvaluator::new(depth);
709 evaluator.evaluate_batch(&mut self.data)
710 }
711
712 pub fn evaluate_with_stockfish_parallel(
714 &mut self,
715 depth: u8,
716 num_threads: usize,
717 ) -> Result<(), Box<dyn std::error::Error>> {
718 let evaluator = StockfishEvaluator::new(depth);
719 evaluator.evaluate_batch_parallel(&mut self.data, num_threads)
720 }
721
722 pub fn train_engine(&self, engine: &mut ChessVectorEngine) {
724 let pb = ProgressBar::new(self.data.len() as u64);
725 pb.set_style(ProgressStyle::default_bar()
726 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Training positions")
727 .unwrap()
728 .progress_chars("#>-"));
729
730 for data in &self.data {
731 engine.add_position(&data.board, data.evaluation);
732 pb.inc(1);
733 }
734
735 pb.finish_with_message("Training complete");
736 println!("Trained engine with {} positions", self.data.len());
737 }
738
739 pub fn split(&self, train_ratio: f32) -> (TrainingDataset, TrainingDataset) {
741 use rand::seq::SliceRandom;
742 use rand::thread_rng;
743 use std::collections::{HashMap, HashSet};
744
745 let mut games: HashMap<usize, Vec<&TrainingData>> = HashMap::new();
747 for data in &self.data {
748 games.entry(data.game_id).or_default().push(data);
749 }
750
751 let mut game_ids: Vec<usize> = games.keys().cloned().collect();
753 game_ids.shuffle(&mut thread_rng());
754
755 let split_point = (game_ids.len() as f32 * train_ratio) as usize;
757 let train_game_ids: HashSet<usize> = game_ids[..split_point].iter().cloned().collect();
758
759 let mut train_data = Vec::new();
761 let mut test_data = Vec::new();
762
763 for data in &self.data {
764 if train_game_ids.contains(&data.game_id) {
765 train_data.push(data.clone());
766 } else {
767 test_data.push(data.clone());
768 }
769 }
770
771 (
772 TrainingDataset { data: train_data },
773 TrainingDataset { data: test_data },
774 )
775 }
776
777 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
779 let json = serde_json::to_string_pretty(&self.data)?;
780 std::fs::write(path, json)?;
781 Ok(())
782 }
783
784 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
786 let content = std::fs::read_to_string(path)?;
787 let data = serde_json::from_str(&content)?;
788 Ok(Self { data })
789 }
790
791 pub fn load_and_append<P: AsRef<Path>>(
793 &mut self,
794 path: P,
795 ) -> Result<(), Box<dyn std::error::Error>> {
796 let existing_len = self.data.len();
797 let additional_data = Self::load(path)?;
798 self.data.extend(additional_data.data);
799 println!(
800 "Loaded {} additional positions (total: {})",
801 self.data.len() - existing_len,
802 self.data.len()
803 );
804 Ok(())
805 }
806
807 pub fn merge(&mut self, other: TrainingDataset) {
809 let existing_len = self.data.len();
810 self.data.extend(other.data);
811 println!(
812 "Merged {} positions (total: {})",
813 self.data.len() - existing_len,
814 self.data.len()
815 );
816 }
817
818 pub fn save_incremental<P: AsRef<Path>>(
820 &self,
821 path: P,
822 ) -> Result<(), Box<dyn std::error::Error>> {
823 self.save_incremental_with_options(path, true)
824 }
825
    /// Saves incrementally to `path`: tries a fast in-place append first, then
    /// falls back to a full load-merge-write (with or without deduplication).
    /// Creates the file from scratch when it does not exist yet.
    pub fn save_incremental_with_options<P: AsRef<Path>>(
        &self,
        path: P,
        deduplicate: bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path = path.as_ref();

        if path.exists() {
            // Fast path: splice new records into the existing JSON array.
            if self.save_append_only(path).is_ok() {
                return Ok(());
            }

            // Fallback: load the whole file and merge in memory.
            if deduplicate {
                self.save_incremental_full_merge(path)
            } else {
                self.save_incremental_no_dedup(path)
            }
        } else {
            // First save: just write the array.
            self.save(path)
        }
    }
851
    /// Load-merge-write fallback without deduplication: reads the existing
    /// file, appends this dataset's positions, and rewrites the whole array.
    fn save_incremental_no_dedup<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path = path.as_ref();

        println!("📂 Loading existing training data...");
        let mut existing = Self::load(path)?;

        println!("⚡ Fast merge without deduplication...");
        existing.data.extend(self.data.iter().cloned());

        println!(
            "💾 Serializing {} positions to JSON...",
            existing.data.len()
        );
        let json = serde_json::to_string_pretty(&existing.data)?;

        println!("✍️ Writing to disk...");
        std::fs::write(path, json)?;

        println!(
            "✅ Fast merge save: total {} positions",
            existing.data.len()
        );
        Ok(())
    }
880
    /// Fast-path append: splices `self.data` into an existing JSON array file
    /// without re-reading and re-serializing the whole array. Works by
    /// overwriting the file's trailing "\n]" with "," followed by the new
    /// elements and a fresh "\n]".
    ///
    /// # Errors
    /// Fails if the file cannot be opened or seeked, or if its tail does not
    /// look like a JSON array terminator; callers fall back to a full merge.
    ///
    /// NOTE(review): `SeekFrom::End(-10)` errors on files shorter than 10
    /// bytes, and an empty array "[]" would be corrupted by the "," splice.
    /// Both currently surface as Err before any write in practice, but
    /// confirm the empty-array case cannot pass the tail check.
    pub fn save_append_only<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use std::fs::OpenOptions;
        use std::io::{BufRead, BufReader, Seek, SeekFrom, Write};

        // Nothing to append.
        if self.data.is_empty() {
            return Ok(());
        }

        let path = path.as_ref();
        let mut file = OpenOptions::new().read(true).write(true).open(path)?;

        // Sanity-check the last few bytes: the file must end in ']' for the
        // splice below to produce valid JSON.
        file.seek(SeekFrom::End(-10))?;
        let mut buffer = String::new();
        BufReader::new(&file).read_line(&mut buffer)?;

        if !buffer.trim().ends_with(']') {
            return Err("File doesn't end with JSON array bracket".into());
        }

        // Overwrite the trailing "\n]" with a separator, then stream in the
        // new elements (compact JSON, comma-separated).
        file.seek(SeekFrom::End(-2))?;
        write!(file, ",")?;

        for (i, data) in self.data.iter().enumerate() {
            if i > 0 {
                write!(file, ",")?;
            }
            let json = serde_json::to_string(data)?;
            write!(file, "{json}")?;
        }

        // Close the array again.
        write!(file, "\n]")?;

        println!("Fast append: added {} new positions", self.data.len());
        Ok(())
    }
926
    /// Load-merge-write fallback with FEN-based deduplication: reads the
    /// existing file, streams in only unseen positions, rewrites the array.
    fn save_incremental_full_merge<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path = path.as_ref();

        println!("📂 Loading existing training data...");
        let mut existing = Self::load(path)?;
        let _original_len = existing.data.len();

        println!("🔄 Streaming merge with deduplication (avoiding O(n²) operation)...");
        existing.merge_and_deduplicate(self.data.clone());

        println!(
            "💾 Serializing {} positions to JSON...",
            existing.data.len()
        );
        let json = serde_json::to_string_pretty(&existing.data)?;

        println!("✍️ Writing to disk...");
        std::fs::write(path, json)?;

        println!(
            "✅ Streaming merge save: total {} positions",
            existing.data.len()
        );
        Ok(())
    }
956
957 pub fn add_position(&mut self, board: Board, evaluation: f32, depth: u8, game_id: usize) {
959 self.data.push(TrainingData {
960 board,
961 evaluation,
962 depth,
963 game_id,
964 });
965 }
966
967 pub fn next_game_id(&self) -> usize {
969 self.data.iter().map(|data| data.game_id).max().unwrap_or(0) + 1
970 }
971
972 pub fn deduplicate(&mut self, similarity_threshold: f32) {
974 if similarity_threshold > 0.999 {
975 self.deduplicate_fast();
977 } else {
978 self.deduplicate_similarity_based(similarity_threshold);
980 }
981 }
982
983 pub fn deduplicate_fast(&mut self) {
985 use std::collections::HashSet;
986
987 if self.data.is_empty() {
988 return;
989 }
990
991 let mut seen_positions = HashSet::with_capacity(self.data.len());
992 let original_len = self.data.len();
993
994 self.data.retain(|data| {
996 let fen = data.board.to_string();
997 seen_positions.insert(fen)
998 });
999
1000 println!(
1001 "Fast deduplicated: {} -> {} positions (removed {} exact duplicates)",
1002 original_len,
1003 self.data.len(),
1004 original_len - self.data.len()
1005 );
1006 }
1007
1008 pub fn merge_and_deduplicate(&mut self, new_data: Vec<TrainingData>) {
1010 use std::collections::HashSet;
1011
1012 if new_data.is_empty() {
1013 return;
1014 }
1015
1016 let _original_len = self.data.len();
1017
1018 let mut existing_positions: HashSet<String> = HashSet::with_capacity(self.data.len());
1020 for data in &self.data {
1021 existing_positions.insert(data.board.to_string());
1022 }
1023
1024 let mut added = 0;
1026 for data in new_data {
1027 let fen = data.board.to_string();
1028 if existing_positions.insert(fen) {
1029 self.data.push(data);
1030 added += 1;
1031 }
1032 }
1033
1034 println!(
1035 "Streaming merge: added {} unique positions (total: {})",
1036 added,
1037 self.data.len()
1038 );
1039 }
1040
1041 fn deduplicate_similarity_based(&mut self, similarity_threshold: f32) {
1043 use crate::PositionEncoder;
1044 use ndarray::Array1;
1045
1046 if self.data.is_empty() {
1047 return;
1048 }
1049
1050 let encoder = PositionEncoder::new(1024);
1051 let mut keep_indices: Vec<bool> = vec![true; self.data.len()];
1052
1053 let vectors: Vec<Array1<f32>> = if self.data.len() > 50 {
1055 self.data
1056 .par_iter()
1057 .map(|data| encoder.encode(&data.board))
1058 .collect()
1059 } else {
1060 self.data
1061 .iter()
1062 .map(|data| encoder.encode(&data.board))
1063 .collect()
1064 };
1065
1066 for i in 1..self.data.len() {
1068 if !keep_indices[i] {
1069 continue;
1070 }
1071
1072 for j in 0..i {
1073 if !keep_indices[j] {
1074 continue;
1075 }
1076
1077 let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1078 if similarity > similarity_threshold {
1079 keep_indices[i] = false;
1080 break;
1081 }
1082 }
1083 }
1084
1085 let original_len = self.data.len();
1087 self.data = self
1088 .data
1089 .iter()
1090 .enumerate()
1091 .filter_map(|(i, data)| {
1092 if keep_indices[i] {
1093 Some(data.clone())
1094 } else {
1095 None
1096 }
1097 })
1098 .collect();
1099
1100 println!(
1101 "Similarity deduplicated: {} -> {} positions (removed {} near-duplicates)",
1102 original_len,
1103 self.data.len(),
1104 original_len - self.data.len()
1105 );
1106 }
1107
    /// Parallel near-duplicate removal over cosine similarity of encoded
    /// position vectors, processing indices in chunks of `chunk_size`.
    ///
    /// NOTE(review): the keep-flags are read under short-lived lock scopes
    /// that are released before the pairwise comparison, so two chunks can
    /// race and the surviving set may differ between runs (the sequential
    /// `deduplicate_similarity_based` is deterministic). Confirm this
    /// nondeterminism is acceptable for training data.
    pub fn deduplicate_parallel(&mut self, similarity_threshold: f32, chunk_size: usize) {
        use crate::PositionEncoder;
        use ndarray::Array1;
        use std::sync::{Arc, Mutex};

        if self.data.is_empty() {
            return;
        }

        let encoder = PositionEncoder::new(1024);

        // Encode every position up front, in parallel.
        let vectors: Vec<Array1<f32>> = self
            .data
            .par_iter()
            .map(|data| encoder.encode(&data.board))
            .collect();

        let keep_indices = Arc::new(Mutex::new(vec![true; self.data.len()]));

        // Index 0 is always kept; later indices are dropped when too similar
        // to any earlier index still marked as kept.
        (1..self.data.len())
            .collect::<Vec<_>>()
            .par_chunks(chunk_size)
            .for_each(|chunk| {
                for &i in chunk {
                    {
                        // Skip indices already discarded by another chunk.
                        let indices = keep_indices.lock().unwrap();
                        if !indices[i] {
                            continue;
                        }
                    }

                    for j in 0..i {
                        {
                            let indices = keep_indices.lock().unwrap();
                            if !indices[j] {
                                continue;
                            }
                        }

                        let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
                        if similarity > similarity_threshold {
                            let mut indices = keep_indices.lock().unwrap();
                            indices[i] = false;
                            break;
                        }
                    }
                }
            });

        // Rebuild the dataset from the surviving indices.
        let keep_indices = keep_indices.lock().unwrap();
        let original_len = self.data.len();
        self.data = self
            .data
            .iter()
            .enumerate()
            .filter_map(|(i, data)| {
                if keep_indices[i] {
                    Some(data.clone())
                } else {
                    None
                }
            })
            .collect();

        println!(
            "Parallel deduplicated: {} -> {} positions (removed {} duplicates)",
            original_len,
            self.data.len(),
            original_len - self.data.len()
        );
    }
1185
1186 fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
1188 let dot_product = a.dot(b);
1189 let norm_a = a.dot(a).sqrt();
1190 let norm_b = b.dot(b).sqrt();
1191
1192 if norm_a == 0.0 || norm_b == 0.0 {
1193 0.0
1194 } else {
1195 dot_product / (norm_a * norm_b)
1196 }
1197 }
1198}
1199
/// Generates training data by letting the engine play games against itself.
pub struct SelfPlayTrainer {
    // Knobs controlling game count, exploration, and recording thresholds.
    config: SelfPlayConfig,
    // Monotonic id stamped onto the positions of each game.
    game_counter: usize,
}
1205
1206impl SelfPlayTrainer {
1207 pub fn new(config: SelfPlayConfig) -> Self {
1208 Self {
1209 config,
1210 game_counter: 0,
1211 }
1212 }
1213
1214 pub fn generate_training_data(&mut self, engine: &mut ChessVectorEngine) -> TrainingDataset {
1216 let mut dataset = TrainingDataset::new();
1217
1218 println!(
1219 "🎮 Starting self-play training with {} games...",
1220 self.config.games_per_iteration
1221 );
1222 let pb = ProgressBar::new(self.config.games_per_iteration as u64);
1223 if let Ok(style) = ProgressStyle::default_bar().template(
1224 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
1225 ) {
1226 pb.set_style(style.progress_chars("#>-"));
1227 }
1228
1229 for _ in 0..self.config.games_per_iteration {
1230 let game_data = self.play_single_game(engine);
1231 dataset.data.extend(game_data);
1232 self.game_counter += 1;
1233 pb.inc(1);
1234 }
1235
1236 pb.finish_with_message("Self-play games completed");
1237 println!(
1238 "✅ Generated {} positions from {} games",
1239 dataset.data.len(),
1240 self.config.games_per_iteration
1241 );
1242
1243 dataset
1244 }
1245
1246 fn play_single_game(&self, engine: &mut ChessVectorEngine) -> Vec<TrainingData> {
1248 let mut game = Game::new();
1249 let mut positions = Vec::new();
1250 let mut move_count = 0;
1251
1252 if self.config.use_opening_book {
1254 if let Some(opening_moves) = self.get_random_opening() {
1255 for mv in opening_moves {
1256 if game.make_move(mv) {
1257 move_count += 1;
1258 } else {
1259 break;
1260 }
1261 }
1262 }
1263 }
1264
1265 while game.result().is_none() && move_count < self.config.max_moves_per_game {
1267 let current_position = game.current_position();
1268
1269 let move_choice = self.select_move_with_exploration(engine, ¤t_position);
1271
1272 if let Some(chess_move) = move_choice {
1273 if let Some(evaluation) = engine.evaluate_position(¤t_position) {
1275 if evaluation.abs() >= self.config.min_confidence || move_count < 10 {
1277 positions.push(TrainingData {
1278 board: current_position,
1279 evaluation,
1280 depth: 1, game_id: self.game_counter,
1282 });
1283 }
1284 }
1285
1286 if !game.make_move(chess_move) {
1288 break; }
1290 move_count += 1;
1291 } else {
1292 break; }
1294 }
1295
1296 if let Some(result) = game.result() {
1298 let final_position = game.current_position();
1299 let final_eval = match result {
1300 chess::GameResult::WhiteCheckmates => {
1301 if final_position.side_to_move() == chess::Color::Black {
1302 10.0
1303 } else {
1304 -10.0
1305 }
1306 }
1307 chess::GameResult::BlackCheckmates => {
1308 if final_position.side_to_move() == chess::Color::White {
1309 10.0
1310 } else {
1311 -10.0
1312 }
1313 }
1314 chess::GameResult::WhiteResigns => -10.0,
1315 chess::GameResult::BlackResigns => 10.0,
1316 chess::GameResult::Stalemate
1317 | chess::GameResult::DrawAccepted
1318 | chess::GameResult::DrawDeclared => 0.0,
1319 };
1320
1321 positions.push(TrainingData {
1322 board: final_position,
1323 evaluation: final_eval,
1324 depth: 1,
1325 game_id: self.game_counter,
1326 });
1327 }
1328
1329 positions
1330 }
1331
1332 fn select_move_with_exploration(
1334 &self,
1335 engine: &mut ChessVectorEngine,
1336 position: &Board,
1337 ) -> Option<ChessMove> {
1338 let recommendations = engine.recommend_legal_moves(position, 5);
1339
1340 if recommendations.is_empty() {
1341 return None;
1342 }
1343
1344 if fastrand::f32() < self.config.exploration_factor {
1346 self.select_move_with_temperature(&recommendations)
1348 } else {
1349 Some(recommendations[0].chess_move)
1351 }
1352 }
1353
1354 fn select_move_with_temperature(
1356 &self,
1357 recommendations: &[crate::MoveRecommendation],
1358 ) -> Option<ChessMove> {
1359 if recommendations.is_empty() {
1360 return None;
1361 }
1362
1363 let mut probabilities = Vec::new();
1365 let mut sum = 0.0;
1366
1367 for rec in recommendations {
1368 let prob = (rec.average_outcome / self.config.temperature).exp();
1370 probabilities.push(prob);
1371 sum += prob;
1372 }
1373
1374 for prob in &mut probabilities {
1376 *prob /= sum;
1377 }
1378
1379 let rand_val = fastrand::f32();
1381 let mut cumulative = 0.0;
1382
1383 for (i, &prob) in probabilities.iter().enumerate() {
1384 cumulative += prob;
1385 if rand_val <= cumulative {
1386 return Some(recommendations[i].chess_move);
1387 }
1388 }
1389
1390 Some(recommendations[0].chess_move)
1392 }
1393
1394 fn get_random_opening(&self) -> Option<Vec<ChessMove>> {
1396 let openings = [
1397 vec!["e4", "e5", "Nf3", "Nc6", "Bc4"],
1399 vec!["e4", "e5", "Nf3", "Nc6", "Bb5"],
1401 vec!["d4", "d5", "c4"],
1403 vec!["d4", "Nf6", "c4", "g6"],
1405 vec!["e4", "c5"],
1407 vec!["e4", "e6"],
1409 vec!["e4", "c6"],
1411 ];
1412
1413 let selected_opening = &openings[fastrand::usize(0..openings.len())];
1414
1415 let mut moves = Vec::new();
1416 let mut game = Game::new();
1417
1418 for move_str in selected_opening {
1419 if let Ok(chess_move) = ChessMove::from_str(move_str) {
1420 if game.make_move(chess_move) {
1421 moves.push(chess_move);
1422 } else {
1423 break;
1424 }
1425 }
1426 }
1427
1428 if moves.is_empty() {
1429 None
1430 } else {
1431 Some(moves)
1432 }
1433 }
1434}
1435
/// Compares the vector engine's evaluations against reference-labelled
/// training data to measure evaluation accuracy (see `evaluate_accuracy`).
pub struct EngineEvaluator {
    // Search depth intended for generating reference evaluations with an
    // external engine; currently unused (reference evals come from the
    // dataset itself), hence the dead_code allowance.
    #[allow(dead_code)]
    stockfish_depth: u8,
}
1441
1442impl EngineEvaluator {
1443 pub fn new(stockfish_depth: u8) -> Self {
1444 Self { stockfish_depth }
1445 }
1446
1447 pub fn evaluate_accuracy(
1449 &self,
1450 engine: &mut ChessVectorEngine,
1451 test_data: &TrainingDataset,
1452 ) -> Result<f32, Box<dyn std::error::Error>> {
1453 let mut total_error = 0.0;
1454 let mut valid_comparisons = 0;
1455
1456 let pb = ProgressBar::new(test_data.data.len() as u64);
1457 pb.set_style(ProgressStyle::default_bar()
1458 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Evaluating accuracy")
1459 .unwrap()
1460 .progress_chars("#>-"));
1461
1462 for data in &test_data.data {
1463 if let Some(engine_eval) = engine.evaluate_position(&data.board) {
1464 let error = (engine_eval - data.evaluation).abs();
1465 total_error += error;
1466 valid_comparisons += 1;
1467 }
1468 pb.inc(1);
1469 }
1470
1471 pb.finish_with_message("Accuracy evaluation complete");
1472
1473 if valid_comparisons > 0 {
1474 let mean_absolute_error = total_error / valid_comparisons as f32;
1475 println!("Mean Absolute Error: {:.3} pawns", mean_absolute_error);
1476 println!("Evaluated {} positions", valid_comparisons);
1477 Ok(mean_absolute_error)
1478 } else {
1479 Ok(f32::INFINITY)
1480 }
1481 }
1482}
1483
1484impl serde::Serialize for TrainingData {
1486 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1487 where
1488 S: serde::Serializer,
1489 {
1490 use serde::ser::SerializeStruct;
1491 let mut state = serializer.serialize_struct("TrainingData", 4)?;
1492 state.serialize_field("fen", &self.board.to_string())?;
1493 state.serialize_field("evaluation", &self.evaluation)?;
1494 state.serialize_field("depth", &self.depth)?;
1495 state.serialize_field("game_id", &self.game_id)?;
1496 state.end()
1497 }
1498}
1499
/// Deserializes `TrainingData` from the map shape written by the matching
/// `Serialize` impl: `fen` (string), `evaluation`, `depth`, `game_id`.
impl<'de> serde::Deserialize<'de> for TrainingData {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        struct TrainingDataVisitor;

        impl<'de> Visitor<'de> for TrainingDataVisitor {
            type Value = TrainingData;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct TrainingData")
            }

            // Collect the four known fields, rejecting duplicates and
            // silently discarding unknown keys.
            fn visit_map<V>(self, mut map: V) -> Result<TrainingData, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut fen = None;
                let mut evaluation = None;
                let mut depth = None;
                let mut game_id = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        "fen" => {
                            if fen.is_some() {
                                return Err(de::Error::duplicate_field("fen"));
                            }
                            fen = Some(map.next_value()?);
                        }
                        "evaluation" => {
                            if evaluation.is_some() {
                                return Err(de::Error::duplicate_field("evaluation"));
                            }
                            evaluation = Some(map.next_value()?);
                        }
                        "depth" => {
                            if depth.is_some() {
                                return Err(de::Error::duplicate_field("depth"));
                            }
                            depth = Some(map.next_value()?);
                        }
                        "game_id" => {
                            if game_id.is_some() {
                                return Err(de::Error::duplicate_field("game_id"));
                            }
                            game_id = Some(map.next_value()?);
                        }
                        _ => {
                            // Unknown key: consume and drop the value so
                            // extra fields do not abort deserialization.
                            let _: serde_json::Value = map.next_value()?;
                        }
                    }
                }

                let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
                let evaluation =
                    evaluation.ok_or_else(|| de::Error::missing_field("evaluation"))?;
                let depth = depth.ok_or_else(|| de::Error::missing_field("depth"))?;
                // `game_id` is tolerated as missing (defaults to 0) —
                // presumably for backward compatibility with older files.
                let game_id = game_id.unwrap_or(0);
                // Rebuild the board from its FEN representation.
                let board =
                    Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;

                Ok(TrainingData {
                    board,
                    evaluation,
                    depth,
                    game_id,
                })
            }
        }

        const FIELDS: &[&str] = &["fen", "evaluation", "depth", "game_id"];
        deserializer.deserialize_struct("TrainingData", FIELDS, TrainingDataVisitor)
    }
}
1580
1581pub struct TacticalPuzzleParser;
1583
1584impl TacticalPuzzleParser {
1585 pub fn parse_csv<P: AsRef<Path>>(
1587 file_path: P,
1588 max_puzzles: Option<usize>,
1589 min_rating: Option<u32>,
1590 max_rating: Option<u32>,
1591 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1592 let file = File::open(&file_path)?;
1593 let file_size = file.metadata()?.len();
1594
1595 if file_size > 100_000_000 {
1597 Self::parse_csv_parallel(file_path, max_puzzles, min_rating, max_rating)
1598 } else {
1599 Self::parse_csv_sequential(file_path, max_puzzles, min_rating, max_rating)
1600 }
1601 }
1602
1603 fn parse_csv_sequential<P: AsRef<Path>>(
1605 file_path: P,
1606 max_puzzles: Option<usize>,
1607 min_rating: Option<u32>,
1608 max_rating: Option<u32>,
1609 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1610 let file = File::open(file_path)?;
1611 let reader = BufReader::new(file);
1612
1613 let mut csv_reader = csv::ReaderBuilder::new()
1616 .has_headers(false)
1617 .flexible(true) .from_reader(reader);
1619
1620 let mut tactical_data = Vec::new();
1621 let mut processed = 0;
1622 let mut skipped = 0;
1623
1624 let pb = ProgressBar::new_spinner();
1625 pb.set_style(
1626 ProgressStyle::default_spinner()
1627 .template("{spinner:.green} Parsing tactical puzzles: {pos} (skipped: {skipped})")
1628 .unwrap(),
1629 );
1630
1631 for result in csv_reader.records() {
1632 let record = match result {
1633 Ok(r) => r,
1634 Err(e) => {
1635 skipped += 1;
1636 println!("CSV parsing error: {e}");
1637 continue;
1638 }
1639 };
1640
1641 if let Some(puzzle_data) = Self::parse_csv_record(&record, min_rating, max_rating) {
1642 if let Some(tactical_data_item) =
1643 Self::convert_puzzle_to_training_data(&puzzle_data)
1644 {
1645 tactical_data.push(tactical_data_item);
1646 processed += 1;
1647
1648 if let Some(max) = max_puzzles {
1649 if processed >= max {
1650 break;
1651 }
1652 }
1653 } else {
1654 skipped += 1;
1655 }
1656 } else {
1657 skipped += 1;
1658 }
1659
1660 pb.set_message(format!(
1661 "Parsing tactical puzzles: {} (skipped: {})",
1662 processed, skipped
1663 ));
1664 }
1665
1666 pb.finish_with_message(format!(
1667 "Parsed {} puzzles (skipped: {})",
1668 processed, skipped
1669 ));
1670
1671 Ok(tactical_data)
1672 }
1673
1674 fn parse_csv_parallel<P: AsRef<Path>>(
1676 file_path: P,
1677 max_puzzles: Option<usize>,
1678 min_rating: Option<u32>,
1679 max_rating: Option<u32>,
1680 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1681 use std::io::Read;
1682
1683 let mut file = File::open(&file_path)?;
1684
1685 let mut contents = String::new();
1687 file.read_to_string(&mut contents)?;
1688
1689 let lines: Vec<&str> = contents.lines().collect();
1691
1692 let pb = ProgressBar::new(lines.len() as u64);
1693 pb.set_style(ProgressStyle::default_bar()
1694 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel CSV parsing")
1695 .unwrap()
1696 .progress_chars("#>-"));
1697
1698 let tactical_data: Vec<TacticalTrainingData> = lines
1700 .par_iter()
1701 .take(max_puzzles.unwrap_or(usize::MAX))
1702 .filter_map(|line| {
1703 let fields: Vec<&str> = line.split(',').collect();
1705 if fields.len() < 8 {
1706 return None;
1707 }
1708
1709 if let Some(puzzle_data) = Self::parse_csv_fields(&fields, min_rating, max_rating) {
1711 Self::convert_puzzle_to_training_data(&puzzle_data)
1712 } else {
1713 None
1714 }
1715 })
1716 .collect();
1717
1718 pb.finish_with_message(format!(
1719 "Parallel parsing complete: {} puzzles",
1720 tactical_data.len()
1721 ));
1722
1723 Ok(tactical_data)
1724 }
1725
1726 fn parse_csv_record(
1728 record: &csv::StringRecord,
1729 min_rating: Option<u32>,
1730 max_rating: Option<u32>,
1731 ) -> Option<TacticalPuzzle> {
1732 if record.len() < 8 {
1734 return None;
1735 }
1736
1737 let rating: u32 = record[3].parse().ok()?;
1738 let rating_deviation: u32 = record[4].parse().ok()?;
1739 let popularity: i32 = record[5].parse().ok()?;
1740 let nb_plays: u32 = record[6].parse().ok()?;
1741
1742 if let Some(min) = min_rating {
1744 if rating < min {
1745 return None;
1746 }
1747 }
1748 if let Some(max) = max_rating {
1749 if rating > max {
1750 return None;
1751 }
1752 }
1753
1754 Some(TacticalPuzzle {
1755 puzzle_id: record[0].to_string(),
1756 fen: record[1].to_string(),
1757 moves: record[2].to_string(),
1758 rating,
1759 rating_deviation,
1760 popularity,
1761 nb_plays,
1762 themes: record[7].to_string(),
1763 game_url: if record.len() > 8 {
1764 Some(record[8].to_string())
1765 } else {
1766 None
1767 },
1768 opening_tags: if record.len() > 9 {
1769 Some(record[9].to_string())
1770 } else {
1771 None
1772 },
1773 })
1774 }
1775
1776 fn parse_csv_fields(
1778 fields: &[&str],
1779 min_rating: Option<u32>,
1780 max_rating: Option<u32>,
1781 ) -> Option<TacticalPuzzle> {
1782 if fields.len() < 8 {
1783 return None;
1784 }
1785
1786 let rating: u32 = fields[3].parse().ok()?;
1787 let rating_deviation: u32 = fields[4].parse().ok()?;
1788 let popularity: i32 = fields[5].parse().ok()?;
1789 let nb_plays: u32 = fields[6].parse().ok()?;
1790
1791 if let Some(min) = min_rating {
1793 if rating < min {
1794 return None;
1795 }
1796 }
1797 if let Some(max) = max_rating {
1798 if rating > max {
1799 return None;
1800 }
1801 }
1802
1803 Some(TacticalPuzzle {
1804 puzzle_id: fields[0].to_string(),
1805 fen: fields[1].to_string(),
1806 moves: fields[2].to_string(),
1807 rating,
1808 rating_deviation,
1809 popularity,
1810 nb_plays,
1811 themes: fields[7].to_string(),
1812 game_url: if fields.len() > 8 {
1813 Some(fields[8].to_string())
1814 } else {
1815 None
1816 },
1817 opening_tags: if fields.len() > 9 {
1818 Some(fields[9].to_string())
1819 } else {
1820 None
1821 },
1822 })
1823 }
1824
1825 fn convert_puzzle_to_training_data(puzzle: &TacticalPuzzle) -> Option<TacticalTrainingData> {
1827 let position = match Board::from_str(&puzzle.fen) {
1829 Ok(board) => board,
1830 Err(_) => return None,
1831 };
1832
1833 let moves: Vec<&str> = puzzle.moves.split_whitespace().collect();
1835 if moves.is_empty() {
1836 return None;
1837 }
1838
1839 let solution_move = match ChessMove::from_str(moves[0]) {
1841 Ok(mv) => mv,
1842 Err(_) => {
1843 match ChessMove::from_san(&position, moves[0]) {
1845 Ok(mv) => mv,
1846 Err(_) => return None,
1847 }
1848 }
1849 };
1850
1851 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&position).collect();
1853 if !legal_moves.contains(&solution_move) {
1854 return None;
1855 }
1856
1857 let themes: Vec<&str> = puzzle.themes.split_whitespace().collect();
1859 let primary_theme = themes.first().unwrap_or(&"tactical").to_string();
1860
1861 let difficulty = puzzle.rating as f32 / 1000.0; let popularity_bonus = (puzzle.popularity as f32 / 100.0).min(2.0);
1864 let tactical_value = difficulty + popularity_bonus; Some(TacticalTrainingData {
1867 position,
1868 solution_move,
1869 move_theme: primary_theme,
1870 difficulty,
1871 tactical_value,
1872 })
1873 }
1874
1875 pub fn load_into_engine(
1877 tactical_data: &[TacticalTrainingData],
1878 engine: &mut ChessVectorEngine,
1879 ) {
1880 let pb = ProgressBar::new(tactical_data.len() as u64);
1881 pb.set_style(ProgressStyle::default_bar()
1882 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns")
1883 .unwrap()
1884 .progress_chars("#>-"));
1885
1886 for data in tactical_data {
1887 engine.add_position_with_move(
1889 &data.position,
1890 0.0, Some(data.solution_move),
1892 Some(data.tactical_value), );
1894 pb.inc(1);
1895 }
1896
1897 pb.finish_with_message(format!("Loaded {} tactical patterns", tactical_data.len()));
1898 }
1899
1900 pub fn load_into_engine_incremental(
1902 tactical_data: &[TacticalTrainingData],
1903 engine: &mut ChessVectorEngine,
1904 ) {
1905 let initial_size = engine.knowledge_base_size();
1906 let initial_moves = engine.position_moves.len();
1907
1908 if tactical_data.len() > 1000 {
1910 Self::load_into_engine_incremental_parallel(
1911 tactical_data,
1912 engine,
1913 initial_size,
1914 initial_moves,
1915 );
1916 } else {
1917 Self::load_into_engine_incremental_sequential(
1918 tactical_data,
1919 engine,
1920 initial_size,
1921 initial_moves,
1922 );
1923 }
1924 }
1925
1926 fn load_into_engine_incremental_sequential(
1928 tactical_data: &[TacticalTrainingData],
1929 engine: &mut ChessVectorEngine,
1930 initial_size: usize,
1931 initial_moves: usize,
1932 ) {
1933 let pb = ProgressBar::new(tactical_data.len() as u64);
1934 pb.set_style(ProgressStyle::default_bar()
1935 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns (incremental)")
1936 .unwrap()
1937 .progress_chars("#>-"));
1938
1939 let mut added = 0;
1940 let mut skipped = 0;
1941
1942 for data in tactical_data {
1943 if !engine.position_boards.contains(&data.position) {
1945 engine.add_position_with_move(
1946 &data.position,
1947 0.0, Some(data.solution_move),
1949 Some(data.tactical_value), );
1951 added += 1;
1952 } else {
1953 skipped += 1;
1954 }
1955 pb.inc(1);
1956 }
1957
1958 pb.finish_with_message(format!(
1959 "Loaded {} new tactical patterns (skipped {} duplicates, total: {})",
1960 added,
1961 skipped,
1962 engine.knowledge_base_size()
1963 ));
1964
1965 println!("Incremental tactical training:");
1966 println!(
1967 " - Positions: {} → {} (+{})",
1968 initial_size,
1969 engine.knowledge_base_size(),
1970 engine.knowledge_base_size() - initial_size
1971 );
1972 println!(
1973 " - Move entries: {} → {} (+{})",
1974 initial_moves,
1975 engine.position_moves.len(),
1976 engine.position_moves.len() - initial_moves
1977 );
1978 }
1979
1980 fn load_into_engine_incremental_parallel(
1982 tactical_data: &[TacticalTrainingData],
1983 engine: &mut ChessVectorEngine,
1984 initial_size: usize,
1985 initial_moves: usize,
1986 ) {
1987 let pb = ProgressBar::new(tactical_data.len() as u64);
1988 pb.set_style(ProgressStyle::default_bar()
1989 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Optimized batch loading tactical patterns")
1990 .unwrap()
1991 .progress_chars("#>-"));
1992
1993 let filtered_data: Vec<&TacticalTrainingData> = tactical_data
1995 .par_iter()
1996 .filter(|data| !engine.position_boards.contains(&data.position))
1997 .collect();
1998
1999 let batch_size = 1000; let mut added = 0;
2001
2002 println!(
2003 "Pre-filtered: {} → {} positions (removed {} duplicates)",
2004 tactical_data.len(),
2005 filtered_data.len(),
2006 tactical_data.len() - filtered_data.len()
2007 );
2008
2009 for batch in filtered_data.chunks(batch_size) {
2012 let batch_start = added;
2013
2014 for data in batch {
2015 if !engine.position_boards.contains(&data.position) {
2017 engine.add_position_with_move(
2018 &data.position,
2019 0.0, Some(data.solution_move),
2021 Some(data.tactical_value), );
2023 added += 1;
2024 }
2025 pb.inc(1);
2026 }
2027
2028 pb.set_message(format!("Loaded batch: {} positions", added - batch_start));
2030 }
2031
2032 let skipped = tactical_data.len() - added;
2033
2034 pb.finish_with_message(format!(
2035 "Optimized loaded {} new tactical patterns (skipped {} duplicates, total: {})",
2036 added,
2037 skipped,
2038 engine.knowledge_base_size()
2039 ));
2040
2041 println!("Incremental tactical training (optimized):");
2042 println!(
2043 " - Positions: {} → {} (+{})",
2044 initial_size,
2045 engine.knowledge_base_size(),
2046 engine.knowledge_base_size() - initial_size
2047 );
2048 println!(
2049 " - Move entries: {} → {} (+{})",
2050 initial_moves,
2051 engine.position_moves.len(),
2052 engine.position_moves.len() - initial_moves
2053 );
2054 println!(
2055 " - Batch size: {}, Pre-filtered efficiency: {:.1}%",
2056 batch_size,
2057 (filtered_data.len() as f32 / tactical_data.len() as f32) * 100.0
2058 );
2059 }
2060
2061 pub fn save_tactical_puzzles<P: AsRef<std::path::Path>>(
2063 tactical_data: &[TacticalTrainingData],
2064 path: P,
2065 ) -> Result<(), Box<dyn std::error::Error>> {
2066 let json = serde_json::to_string_pretty(tactical_data)?;
2067 std::fs::write(path, json)?;
2068 println!("Saved {} tactical puzzles", tactical_data.len());
2069 Ok(())
2070 }
2071
2072 pub fn load_tactical_puzzles<P: AsRef<std::path::Path>>(
2074 path: P,
2075 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
2076 let content = std::fs::read_to_string(path)?;
2077 let tactical_data: Vec<TacticalTrainingData> = serde_json::from_str(&content)?;
2078 println!("Loaded {} tactical puzzles from file", tactical_data.len());
2079 Ok(tactical_data)
2080 }
2081
2082 pub fn save_tactical_puzzles_incremental<P: AsRef<std::path::Path>>(
2084 tactical_data: &[TacticalTrainingData],
2085 path: P,
2086 ) -> Result<(), Box<dyn std::error::Error>> {
2087 let path = path.as_ref();
2088
2089 if path.exists() {
2090 let mut existing = Self::load_tactical_puzzles(path)?;
2092 let original_len = existing.len();
2093
2094 for new_puzzle in tactical_data {
2096 let exists = existing.iter().any(|existing_puzzle| {
2098 existing_puzzle.position == new_puzzle.position
2099 && existing_puzzle.solution_move == new_puzzle.solution_move
2100 });
2101
2102 if !exists {
2103 existing.push(new_puzzle.clone());
2104 }
2105 }
2106
2107 let json = serde_json::to_string_pretty(&existing)?;
2109 std::fs::write(path, json)?;
2110
2111 println!(
2112 "Incremental save: added {} new puzzles (total: {})",
2113 existing.len() - original_len,
2114 existing.len()
2115 );
2116 } else {
2117 Self::save_tactical_puzzles(tactical_data, path)?;
2119 }
2120 Ok(())
2121 }
2122
2123 pub fn parse_and_load_incremental<P: AsRef<std::path::Path>>(
2125 file_path: P,
2126 engine: &mut ChessVectorEngine,
2127 max_puzzles: Option<usize>,
2128 min_rating: Option<u32>,
2129 max_rating: Option<u32>,
2130 ) -> Result<(), Box<dyn std::error::Error>> {
2131 println!("Parsing Lichess puzzles incrementally...");
2132
2133 let tactical_data = Self::parse_csv(file_path, max_puzzles, min_rating, max_rating)?;
2135
2136 Self::load_into_engine_incremental(&tactical_data, engine);
2138
2139 Ok(())
2140 }
2141}
2142
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    /// Fixture: a `TacticalPuzzle` with typical metadata around the given
    /// FEN / move list / theme list.
    fn sample_puzzle(fen: &str, moves: &str, themes: &str) -> TacticalPuzzle {
        TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: fen.to_string(),
            moves: moves.to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: themes.to_string(),
            game_url: None,
            opening_tags: None,
        }
    }

    /// Fixture: a `TacticalTrainingData` entry; `fen = None` uses the
    /// starting position.
    fn sample_training(
        fen: Option<&str>,
        mv: &str,
        theme: &str,
        difficulty: f32,
        tactical_value: f32,
    ) -> TacticalTrainingData {
        let position = match fen {
            Some(f) => Board::from_str(f).unwrap(),
            None => Board::default(),
        };
        TacticalTrainingData {
            position,
            solution_move: ChessMove::from_str(mv).unwrap(),
            move_theme: theme.to_string(),
            difficulty,
            tactical_value,
        }
    }

    #[test]
    fn test_training_dataset_creation() {
        // A fresh dataset starts empty.
        assert_eq!(TrainingDataset::new().data.len(), 0);
    }

    #[test]
    fn test_add_training_data() {
        let mut dataset = TrainingDataset::new();
        dataset.data.push(TrainingData {
            board: Board::default(),
            evaluation: 0.5,
            depth: 15,
            game_id: 1,
        });

        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
    }

    #[test]
    fn test_chess_engine_integration() {
        let board = Board::default();
        let mut dataset = TrainingDataset::new();
        dataset.data.push(TrainingData {
            board,
            evaluation: 0.3,
            depth: 15,
            game_id: 1,
        });

        // Training should insert exactly the one position we provided…
        let mut engine = ChessVectorEngine::new(1024);
        dataset.train_engine(&mut engine);
        assert_eq!(engine.knowledge_base_size(), 1);

        // …and the engine should reproduce its stored evaluation.
        let eval = engine.evaluate_position(&board);
        assert!(eval.is_some());
        assert!((eval.unwrap() - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_deduplication() {
        // Five entries for the same board with different evaluations…
        let mut dataset = TrainingDataset::new();
        for i in 0..5 {
            dataset.data.push(TrainingData {
                board: Board::default(),
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            });
        }
        assert_eq!(dataset.data.len(), 5);

        // …collapse to one under a near-exact similarity threshold.
        dataset.deduplicate(0.999);
        assert_eq!(dataset.data.len(), 1);
    }

    #[test]
    fn test_dataset_serialization() {
        let board =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        let mut dataset = TrainingDataset::new();
        dataset.data.push(TrainingData {
            board,
            evaluation: 0.2,
            depth: 10,
            game_id: 42,
        });

        // Round-trip the entries through JSON.
        let json = serde_json::to_string(&dataset.data).unwrap();
        let restored: Vec<TrainingData> = serde_json::from_str(&json).unwrap();
        let loaded_dataset = TrainingDataset { data: restored };

        assert_eq!(loaded_dataset.data.len(), 1);
        assert_eq!(loaded_dataset.data[0].evaluation, 0.2);
        assert_eq!(loaded_dataset.data[0].depth, 10);
        assert_eq!(loaded_dataset.data[0].game_id, 42);
    }

    #[test]
    fn test_tactical_puzzle_processing() {
        let puzzle = sample_puzzle(
            "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4",
            "Bxf7+ Ke7",
            "fork pin",
        );

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_some());

        // First theme becomes the primary label; scores derive from
        // rating (1500 -> difficulty 1.5) plus a popularity bonus.
        let data = tactical_data.unwrap();
        assert_eq!(data.move_theme, "fork");
        assert!(data.tactical_value > 1.0);
        assert!(data.difficulty > 0.0);
    }

    #[test]
    fn test_tactical_puzzle_invalid_fen() {
        // An unparsable FEN must be rejected, not panic.
        let puzzle = sample_puzzle("invalid_fen", "e2e4", "tactics");
        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_none());
    }

    #[test]
    fn test_engine_evaluator() {
        let evaluator = EngineEvaluator::new(15);
        let board = Board::default();

        let mut dataset = TrainingDataset::new();
        dataset.data.push(TrainingData {
            board,
            evaluation: 0.0,
            depth: 15,
            game_id: 1,
        });

        // Engine stores a slightly different evaluation for the same board.
        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&board, 0.1);

        // MAE is |0.1 - 0.0| = 0.1, comfortably below 1.0.
        let accuracy = evaluator.evaluate_accuracy(&mut engine, &dataset);
        assert!(accuracy.is_ok());
        assert!(accuracy.unwrap() < 1.0);
    }

    #[test]
    fn test_tactical_training_integration() {
        let tactical_data = vec![sample_training(None, "e2e4", "opening", 1.2, 2.5)];

        let mut engine = ChessVectorEngine::new(1024);
        TacticalPuzzleParser::load_into_engine(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);
        assert_eq!(engine.position_moves.len(), 1);

        // The stored move should surface as a recommendation.
        let recommendations = engine.recommend_moves(&Board::default(), 5);
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_multithreading_operations() {
        let mut dataset = TrainingDataset::new();
        for i in 0..10 {
            dataset.data.push(TrainingData {
                board: Board::default(),
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            });
        }

        // Parallel deduplication must never grow the dataset.
        dataset.deduplicate_parallel(0.95, 5);
        assert!(dataset.data.len() <= 10);
    }

    #[test]
    fn test_incremental_dataset_operations() {
        let after_e4 =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.add_position(after_e4, 0.2, 15, 2);
        assert_eq!(dataset1.data.len(), 2);

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );

        // Merging appends the second dataset's entries.
        dataset1.merge(dataset2);
        assert_eq!(dataset1.data.len(), 3);

        // Next game id follows the highest id seen so far (3 -> 4).
        assert_eq!(dataset1.next_game_id(), 4);
    }

    #[test]
    fn test_save_load_incremental() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_test.json");

        // Initial save with one position.
        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        // Incremental save appends a different position.
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap(),
            0.2,
            15,
            2,
        );
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 2);

        // load_and_append keeps the in-memory entry and pulls in the file's.
        let mut dataset3 = TrainingDataset::new();
        dataset3.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );
        dataset3.load_and_append(&file_path).unwrap();
        assert_eq!(dataset3.data.len(), 3);
    }

    #[test]
    fn test_add_position_method() {
        let mut dataset = TrainingDataset::new();
        dataset.add_position(Board::default(), 0.5, 20, 42);

        // add_position should record all fields verbatim.
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
        assert_eq!(dataset.data[0].depth, 20);
        assert_eq!(dataset.data[0].game_id, 42);
    }

    #[test]
    fn test_incremental_save_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        // Same board again (different evaluation/game id) must not duplicate.
        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(Board::default(), 0.1, 15, 2);
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 1);
    }

    #[test]
    fn test_tactical_puzzle_incremental_loading() {
        let tactical_data = vec![
            sample_training(None, "e2e4", "opening", 1.2, 2.5),
            sample_training(
                Some("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1"),
                "e7e5",
                "opening",
                1.0,
                2.0,
            ),
        ];

        let mut engine = ChessVectorEngine::new(1024);

        // Seed the engine with the starting position already known.
        engine.add_position(&Board::default(), 0.1);
        assert_eq!(engine.knowledge_base_size(), 1);

        TacticalPuzzleParser::load_into_engine_incremental(&tactical_data, &mut engine);

        // Only the unseen position was added.
        assert_eq!(engine.knowledge_base_size(), 2);

        // Move data was recorded for the loaded patterns.
        assert!(engine.training_stats().has_move_data);
        assert!(engine.training_stats().move_data_entries > 0);
    }

    #[test]
    fn test_tactical_puzzle_serialization() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("tactical_test.json");

        let tactical_data = vec![sample_training(None, "e2e4", "fork", 1.5, 3.0)];

        TacticalPuzzleParser::save_tactical_puzzles(&tactical_data, &file_path).unwrap();

        // Everything round-trips through JSON.
        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].move_theme, "fork");
        assert_eq!(loaded[0].difficulty, 1.5);
        assert_eq!(loaded[0].tactical_value, 3.0);
    }

    #[test]
    fn test_tactical_puzzle_incremental_save() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_tactical.json");

        let batch1 = vec![sample_training(None, "e2e4", "opening", 1.0, 2.0)];
        TacticalPuzzleParser::save_tactical_puzzles(&batch1, &file_path).unwrap();

        // A distinct puzzle appended incrementally should be kept.
        let batch2 = vec![sample_training(
            Some("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1"),
            "e7e5",
            "counter",
            1.2,
            2.2,
        )];
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&batch2, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 2);
    }

    #[test]
    fn test_tactical_puzzle_incremental_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_tactical.json");

        let tactical_data = sample_training(None, "e2e4", "opening", 1.0, 2.0);

        TacticalPuzzleParser::save_tactical_puzzles(&[tactical_data.clone()], &file_path).unwrap();

        // Saving the identical puzzle again must not create a duplicate.
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&[tactical_data], &file_path)
            .unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
    }
}