1use chess::{Board, ChessMove, Game, MoveGen};
2use indicatif::{ProgressBar, ProgressStyle};
3use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::{BufRead, BufReader, BufWriter, Write};
8use std::path::Path;
9use std::process::{Child, Command, Stdio};
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12
13use crate::ChessVectorEngine;
14
/// Tunable parameters for self-play data generation.
#[derive(Debug, Clone)]
pub struct SelfPlayConfig {
    /// Number of self-play games played per call to `generate_training_data`.
    pub games_per_iteration: usize,
    /// Upper bound on the number of moves played in a single game.
    pub max_moves_per_game: usize,
    /// Probability (compared against a uniform random `f32`) of sampling a
    /// move by temperature instead of taking the top recommendation.
    pub exploration_factor: f32,
    /// Minimum absolute evaluation a position needs to be recorded once the
    /// game is past its first ten moves.
    pub min_confidence: f32,
    /// Whether games start from a randomly chosen opening line.
    pub use_opening_book: bool,
    /// Divisor applied to move outcomes before exponentiation when sampling.
    pub temperature: f32,
}

impl Default for SelfPlayConfig {
    /// Returns the stock configuration: 100 games of at most 200 moves,
    /// 30% exploration, 0.1 confidence floor, opening book on, temperature 0.8.
    fn default() -> Self {
        let (games_per_iteration, max_moves_per_game) = (100, 200);
        let (exploration_factor, min_confidence, temperature) = (0.3, 0.1, 0.8);

        SelfPlayConfig {
            games_per_iteration,
            max_moves_per_game,
            exploration_factor,
            min_confidence,
            use_opening_book: true,
            temperature,
        }
    }
}
44
/// A single chess position together with its training label.
#[derive(Debug, Clone)]
pub struct TrainingData {
    /// The position being labelled.
    pub board: Board,
    /// Evaluation assigned to the position. The external-engine evaluators in
    /// this file store centipawns divided by 100; positions extracted from
    /// PGN start at `0.0` until they are evaluated.
    pub evaluation: f32,
    /// Search depth recorded with the evaluation (`0` for freshly extracted
    /// PGN positions).
    pub depth: u8,
    /// Identifier of the game the position came from; `TrainingDataset::split`
    /// groups positions by this value.
    pub game_id: usize,
}
53
/// One record of a tactical puzzle data set.
///
/// Each field is (de)serialized under the PascalCase column name given in its
/// `serde(rename)` attribute.
/// NOTE(review): the column names look like the Lichess puzzle database
/// export — confirm against the data source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TacticalPuzzle {
    /// Puzzle identifier (`PuzzleId` column).
    #[serde(rename = "PuzzleId")]
    pub puzzle_id: String,
    /// Position as a FEN string (`FEN` column).
    #[serde(rename = "FEN")]
    pub fen: String,
    /// Move list stored as a single string (`Moves` column); the separator
    /// and notation are not visible in this part of the file.
    #[serde(rename = "Moves")]
    pub moves: String,
    /// Puzzle rating (`Rating` column).
    #[serde(rename = "Rating")]
    pub rating: u32,
    /// Rating deviation (`RatingDeviation` column).
    #[serde(rename = "RatingDeviation")]
    pub rating_deviation: u32,
    /// Popularity score (`Popularity` column); signed, so it may be negative.
    #[serde(rename = "Popularity")]
    pub popularity: i32,
    /// Number of plays (`NbPlays` column).
    #[serde(rename = "NbPlays")]
    pub nb_plays: u32,
    /// Theme tags stored as a single string (`Themes` column).
    #[serde(rename = "Themes")]
    pub themes: String,
    /// Optional URL of the source game (`GameUrl` column).
    #[serde(rename = "GameUrl")]
    pub game_url: Option<String>,
    /// Optional opening tags (`OpeningTags` column).
    #[serde(rename = "OpeningTags")]
    pub opening_tags: Option<String>,
}
78
/// A tactical training sample: a position plus the move that solves it.
///
/// Serde support is implemented by hand in this file: the position is stored
/// under the key `fen` and the move under `solution_move`, both as strings.
#[derive(Debug, Clone)]
pub struct TacticalTrainingData {
    /// Position in which the tactic occurs.
    pub position: Board,
    /// The move that solves the position.
    pub solution_move: ChessMove,
    /// Theme label attached to the move.
    pub move_theme: String,
    /// Difficulty score. NOTE(review): scale is not visible here — confirm.
    pub difficulty: f32,
    /// Tactical value of the sample. NOTE(review): scale is not visible
    /// here — confirm.
    pub tactical_value: f32,
}
88
89impl serde::Serialize for TacticalTrainingData {
91 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92 where
93 S: serde::Serializer,
94 {
95 use serde::ser::SerializeStruct;
96 let mut state = serializer.serialize_struct("TacticalTrainingData", 5)?;
97 state.serialize_field("fen", &self.position.to_string())?;
98 state.serialize_field("solution_move", &self.solution_move.to_string())?;
99 state.serialize_field("move_theme", &self.move_theme)?;
100 state.serialize_field("difficulty", &self.difficulty)?;
101 state.serialize_field("tactical_value", &self.tactical_value)?;
102 state.end()
103 }
104}
105
106impl<'de> serde::Deserialize<'de> for TacticalTrainingData {
107 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108 where
109 D: serde::Deserializer<'de>,
110 {
111 use serde::de::{self, MapAccess, Visitor};
112 use std::fmt;
113
114 struct TacticalTrainingDataVisitor;
115
116 impl<'de> Visitor<'de> for TacticalTrainingDataVisitor {
117 type Value = TacticalTrainingData;
118
119 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
120 formatter.write_str("struct TacticalTrainingData")
121 }
122
123 fn visit_map<V>(self, mut map: V) -> Result<TacticalTrainingData, V::Error>
124 where
125 V: MapAccess<'de>,
126 {
127 let mut fen = None;
128 let mut solution_move = None;
129 let mut move_theme = None;
130 let mut difficulty = None;
131 let mut tactical_value = None;
132
133 while let Some(key) = map.next_key()? {
134 match key {
135 "fen" => {
136 if fen.is_some() {
137 return Err(de::Error::duplicate_field("fen"));
138 }
139 fen = Some(map.next_value()?);
140 }
141 "solution_move" => {
142 if solution_move.is_some() {
143 return Err(de::Error::duplicate_field("solution_move"));
144 }
145 solution_move = Some(map.next_value()?);
146 }
147 "move_theme" => {
148 if move_theme.is_some() {
149 return Err(de::Error::duplicate_field("move_theme"));
150 }
151 move_theme = Some(map.next_value()?);
152 }
153 "difficulty" => {
154 if difficulty.is_some() {
155 return Err(de::Error::duplicate_field("difficulty"));
156 }
157 difficulty = Some(map.next_value()?);
158 }
159 "tactical_value" => {
160 if tactical_value.is_some() {
161 return Err(de::Error::duplicate_field("tactical_value"));
162 }
163 tactical_value = Some(map.next_value()?);
164 }
165 _ => {
166 let _: serde_json::Value = map.next_value()?;
167 }
168 }
169 }
170
171 let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
172 let solution_move_str: String =
173 solution_move.ok_or_else(|| de::Error::missing_field("solution_move"))?;
174 let move_theme =
175 move_theme.ok_or_else(|| de::Error::missing_field("move_theme"))?;
176 let difficulty =
177 difficulty.ok_or_else(|| de::Error::missing_field("difficulty"))?;
178 let tactical_value =
179 tactical_value.ok_or_else(|| de::Error::missing_field("tactical_value"))?;
180
181 let position =
182 Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
183
184 let solution_move = ChessMove::from_str(&solution_move_str)
185 .map_err(|e| de::Error::custom(format!("Error: {e}")))?;
186
187 Ok(TacticalTrainingData {
188 position,
189 solution_move,
190 move_theme,
191 difficulty,
192 tactical_value,
193 })
194 }
195 }
196
197 const FIELDS: &[&str] = &[
198 "fen",
199 "solution_move",
200 "move_theme",
201 "difficulty",
202 "tactical_value",
203 ];
204 deserializer.deserialize_struct("TacticalTrainingData", FIELDS, TacticalTrainingDataVisitor)
205 }
206}
207
/// PGN visitor that replays games move by move and records every position
/// reached.
pub struct GameExtractor {
    /// Positions collected so far, across all games visited.
    pub positions: Vec<TrainingData>,
    /// The game currently being replayed.
    pub current_game: Game,
    /// Number of moves applied to `current_game`.
    pub move_count: usize,
    /// Moves beyond this count are ignored for each game.
    pub max_moves_per_game: usize,
    /// Counter incremented at the start of every game; stored on each
    /// collected position.
    pub game_id: usize,
}
216
217impl GameExtractor {
218 pub fn new(max_moves_per_game: usize) -> Self {
219 Self {
220 positions: Vec::new(),
221 current_game: Game::new(),
222 move_count: 0,
223 max_moves_per_game,
224 game_id: 0,
225 }
226 }
227}
228
229impl Visitor for GameExtractor {
230 type Result = ();
231
232 fn begin_game(&mut self) {
233 self.current_game = Game::new();
234 self.move_count = 0;
235 self.game_id += 1;
236 }
237
238 fn header(&mut self, _key: &[u8], _value: RawHeader<'_>) {}
239
240 fn san(&mut self, san_plus: SanPlus) {
241 if self.move_count >= self.max_moves_per_game {
242 return;
243 }
244
245 let san_str = san_plus.san.to_string();
246
247 let current_pos = self.current_game.current_position();
249
250 match chess::ChessMove::from_san(¤t_pos, &san_str) {
252 Ok(chess_move) => {
253 let legal_moves: Vec<chess::ChessMove> = MoveGen::new_legal(¤t_pos).collect();
255 if legal_moves.contains(&chess_move) {
256 if self.current_game.make_move(chess_move) {
257 self.move_count += 1;
258
259 self.positions.push(TrainingData {
261 board: self.current_game.current_position(),
262 evaluation: 0.0, depth: 0,
264 game_id: self.game_id,
265 });
266 }
267 } else {
268 }
270 }
271 Err(_) => {
272 if !san_str.contains("O-O") && !san_str.contains("=") && san_str.len() > 6 {
276 }
278 }
279 }
280 }
281
282 fn begin_variation(&mut self) -> Skip {
283 Skip(true) }
285
286 fn end_game(&mut self) -> Self::Result {}
287}
288
/// Evaluates positions by launching an external `stockfish` process for each
/// position.
pub struct ExternalEvaluator {
    /// Depth passed to the engine's `go depth` command.
    depth: u8,
}
293
294impl ExternalEvaluator {
295 pub fn new(depth: u8) -> Self {
296 Self { depth }
297 }
298
299 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
301 let mut child = Command::new("stockfish")
302 .stdin(Stdio::piped())
303 .stdout(Stdio::piped())
304 .stderr(Stdio::piped())
305 .spawn()?;
306
307 let stdin = child
308 .stdin
309 .as_mut()
310 .ok_or("Failed to get stdin handle for external engine process")?;
311 let fen = board.to_string();
312
313 use std::io::Write;
315 writeln!(stdin, "uci")?;
316 writeln!(stdin, "isready")?;
317 writeln!(stdin, "position fen {fen}")?;
318 writeln!(stdin, "go depth {}", self.depth)?;
319 writeln!(stdin, "quit")?;
320
321 let output = child.wait_with_output()?;
322 let stdout = String::from_utf8_lossy(&output.stdout);
323
324 for line in stdout.lines() {
326 if line.starts_with("info") && line.contains("score cp") {
327 if let Some(cp_pos) = line.find("score cp ") {
328 let cp_str = &line[cp_pos + 9..];
329 if let Some(end) = cp_str.find(' ') {
330 let cp_value = cp_str[..end].parse::<i32>()?;
331 return Ok(cp_value as f32 / 100.0); }
333 }
334 } else if line.starts_with("info") && line.contains("score mate") {
335 if let Some(mate_pos) = line.find("score mate ") {
337 let mate_str = &line[mate_pos + 11..];
338 if let Some(end) = mate_str.find(' ') {
339 let mate_moves = mate_str[..end].parse::<i32>()?;
340 return Ok(if mate_moves > 0 { 100.0 } else { -100.0 });
341 }
342 }
343 }
344 }
345
346 Ok(0.0) }
348
349 pub fn evaluate_batch(
351 &self,
352 positions: &mut [TrainingData],
353 ) -> Result<(), Box<dyn std::error::Error>> {
354 let pb = ProgressBar::new(positions.len() as u64);
355 if let Ok(style) = ProgressStyle::default_bar().template(
356 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
357 ) {
358 pb.set_style(style.progress_chars("#>-"));
359 }
360
361 for data in positions.iter_mut() {
362 match self.evaluate_position(&data.board) {
363 Ok(eval) => {
364 data.evaluation = eval;
365 data.depth = self.depth;
366 }
367 Err(e) => {
368 eprintln!("Evaluation error: {e}");
369 data.evaluation = 0.0;
370 }
371 }
372 pb.inc(1);
373 }
374
375 pb.finish_with_message("Evaluation complete");
376 Ok(())
377 }
378
379 pub fn evaluate_batch_parallel(
381 &self,
382 positions: &mut [TrainingData],
383 num_threads: usize,
384 ) -> Result<(), Box<dyn std::error::Error>> {
385 let pb = ProgressBar::new(positions.len() as u64);
386 if let Ok(style) = ProgressStyle::default_bar()
387 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel evaluation") {
388 pb.set_style(style.progress_chars("#>-"));
389 }
390
391 let pool = rayon::ThreadPoolBuilder::new()
393 .num_threads(num_threads)
394 .build()?;
395
396 pool.install(|| {
397 positions.par_iter_mut().for_each(|data| {
399 match self.evaluate_position(&data.board) {
400 Ok(eval) => {
401 data.evaluation = eval;
402 data.depth = self.depth;
403 }
404 Err(_) => {
405 data.evaluation = 0.0;
407 }
408 }
409 pb.inc(1);
410 });
411 });
412
413 pb.finish_with_message("Parallel evaluation complete");
414 Ok(())
415 }
416}
417
/// A long-lived `stockfish` child process together with buffered handles to
/// its standard input and output.
struct ExternalEngineProcess {
    /// Handle used to wait for the process when this value is dropped.
    child: Child,
    /// Buffered writer for sending UCI commands.
    stdin: BufWriter<std::process::ChildStdin>,
    /// Buffered reader for the engine's replies.
    stdout: BufReader<std::process::ChildStdout>,
    /// Search depth this process was created with.
    #[allow(dead_code)]
    depth: u8,
}
426
427impl ExternalEngineProcess {
428 fn new(depth: u8) -> Result<Self, Box<dyn std::error::Error>> {
429 let mut child = Command::new("stockfish")
430 .stdin(Stdio::piped())
431 .stdout(Stdio::piped())
432 .stderr(Stdio::piped())
433 .spawn()?;
434
435 let stdin = BufWriter::new(
436 child
437 .stdin
438 .take()
439 .ok_or("Failed to get stdin handle for external engine process")?,
440 );
441 let stdout = BufReader::new(
442 child
443 .stdout
444 .take()
445 .ok_or("Failed to get stdout handle for external engine process")?,
446 );
447
448 let mut process = Self {
449 child,
450 stdin,
451 stdout,
452 depth,
453 };
454
455 process.send_command("uci")?;
457 process.wait_for_ready()?;
458 process.send_command("isready")?;
459 process.wait_for_ready()?;
460
461 Ok(process)
462 }
463
464 fn send_command(&mut self, command: &str) -> Result<(), Box<dyn std::error::Error>> {
465 writeln!(self.stdin, "{command}")?;
466 self.stdin.flush()?;
467 Ok(())
468 }
469
470 fn wait_for_ready(&mut self) -> Result<(), Box<dyn std::error::Error>> {
471 let mut line = String::new();
472 loop {
473 line.clear();
474 self.stdout.read_line(&mut line)?;
475 if line.trim() == "uciok" || line.trim() == "readyok" {
476 break;
477 }
478 }
479 Ok(())
480 }
481
482 fn evaluate_position(&mut self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
483 let fen = board.to_string();
484
485 self.send_command(&format!("position fen {fen}"))?;
487 self.send_command(&format!("position fen {fen}"))?;
488
489 let mut line = String::new();
491 let mut last_evaluation = 0.0;
492
493 loop {
494 line.clear();
495 self.stdout.read_line(&mut line)?;
496 let line = line.trim();
497
498 if line.starts_with("info") && line.contains("score cp") {
499 if let Some(cp_pos) = line.find("score cp ") {
500 let cp_str = &line[cp_pos + 9..];
501 if let Some(end) = cp_str.find(' ') {
502 if let Ok(cp_value) = cp_str[..end].parse::<i32>() {
503 last_evaluation = cp_value as f32 / 100.0;
504 }
505 }
506 }
507 } else if line.starts_with("info") && line.contains("score mate") {
508 if let Some(mate_pos) = line.find("score mate ") {
509 let mate_str = &line[mate_pos + 11..];
510 if let Some(end) = mate_str.find(' ') {
511 if let Ok(mate_moves) = mate_str[..end].parse::<i32>() {
512 last_evaluation = if mate_moves > 0 { 100.0 } else { -100.0 };
513 }
514 }
515 }
516 } else if line.starts_with("bestmove") {
517 break;
518 }
519 }
520
521 Ok(last_evaluation)
522 }
523}
524
525impl Drop for ExternalEngineProcess {
526 fn drop(&mut self) {
527 let _ = self.send_command("quit");
528 let _ = self.child.wait();
529 }
530}
531
/// A shared pool of long-lived engine processes that can be borrowed for
/// single evaluations.
pub struct ExternalEnginePool {
    /// Idle engine processes, shared behind a mutex.
    pool: Arc<Mutex<Vec<ExternalEngineProcess>>>,
    /// Search depth used for every engine in the pool.
    depth: u8,
    /// Maximum number of idle processes kept in `pool`.
    pool_size: usize,
}
538
539impl ExternalEnginePool {
540 pub fn new(depth: u8, pool_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
541 let mut processes = Vec::with_capacity(pool_size);
542
543 println!("๐ Initializing external engine pool with {pool_size} processes...");
544
545 for i in 0..pool_size {
546 match ExternalEngineProcess::new(depth) {
547 Ok(process) => {
548 processes.push(process);
549 if i % 2 == 1 {
550 print!(".");
551 let _ = std::io::stdout().flush(); }
553 }
554 Err(e) => {
555 eprintln!("Evaluation error: {e}");
556 return Err(e);
557 }
558 }
559 }
560
561 println!(" โ
Pool ready!");
562
563 Ok(Self {
564 pool: Arc::new(Mutex::new(processes)),
565 depth,
566 pool_size,
567 })
568 }
569
570 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
571 let mut process = {
573 let mut pool = self.pool.lock().unwrap();
574 if let Some(process) = pool.pop() {
575 process
576 } else {
577 ExternalEngineProcess::new(self.depth)?
579 }
580 };
581
582 let result = process.evaluate_position(board);
584
585 {
587 let mut pool = self.pool.lock().unwrap();
588 if pool.len() < self.pool_size {
589 pool.push(process);
590 }
591 }
593
594 result
595 }
596
597 pub fn evaluate_batch_parallel(
598 &self,
599 positions: &mut [TrainingData],
600 ) -> Result<(), Box<dyn std::error::Error>> {
601 let pb = ProgressBar::new(positions.len() as u64);
602 pb.set_style(ProgressStyle::default_bar()
603 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Pool evaluation")
604 .unwrap()
605 .progress_chars("#>-"));
606
607 positions.par_iter_mut().for_each(|data| {
609 match self.evaluate_position(&data.board) {
610 Ok(eval) => {
611 data.evaluation = eval;
612 data.depth = self.depth;
613 }
614 Err(_) => {
615 data.evaluation = 0.0;
616 }
617 }
618 pb.inc(1);
619 });
620
621 pb.finish_with_message("Pool evaluation complete");
622 Ok(())
623 }
624}
625
/// An in-memory collection of labelled training positions.
pub struct TrainingDataset {
    /// The positions, in insertion order.
    pub data: Vec<TrainingData>,
}
630
631impl Default for TrainingDataset {
632 fn default() -> Self {
633 Self::new()
634 }
635}
636
637impl TrainingDataset {
638 pub fn new() -> Self {
639 Self { data: Vec::new() }
640 }
641
642 pub fn load_from_pgn<P: AsRef<Path>>(
644 &mut self,
645 path: P,
646 max_games: Option<usize>,
647 max_moves_per_game: usize,
648 ) -> Result<(), Box<dyn std::error::Error>> {
649 use indicatif::{ProgressBar, ProgressStyle};
650
651 println!("๐ Reading PGN file...");
652 let file = File::open(path)?;
653 let reader = BufReader::new(file);
654
655 let mut games = Vec::new();
657 let mut current_game = String::new();
658 let mut games_collected = 0;
659
660 for line in reader.lines() {
661 let line = line?;
662 current_game.push_str(&line);
663 current_game.push('\n');
664
665 if line.trim().ends_with("1-0")
667 || line.trim().ends_with("0-1")
668 || line.trim().ends_with("1/2-1/2")
669 || line.trim().ends_with("*")
670 {
671 games.push(current_game.clone());
672 current_game.clear();
673 games_collected += 1;
674
675 if let Some(max) = max_games {
676 if games_collected >= max {
677 break;
678 }
679 }
680 }
681 }
682
683 println!(
684 "๐ฆ Collected {} games, processing in parallel...",
685 games.len()
686 );
687
688 let pb = ProgressBar::new(games.len() as u64);
690 pb.set_style(
691 ProgressStyle::default_bar()
692 .template("โก Processing [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
693 .unwrap()
694 .progress_chars("โโโ")
695 );
696
697 let all_positions: Vec<Vec<TrainingData>> = games
699 .par_iter()
700 .map(|game_pgn| {
701 pb.inc(1);
702 pb.set_message("Processing game");
703
704 let mut local_extractor = GameExtractor::new(max_moves_per_game);
706
707 let cursor = std::io::Cursor::new(game_pgn);
709 let mut reader = BufferedReader::new(cursor);
710
711 if let Err(e) = reader.read_all(&mut local_extractor) {
712 eprintln!("Parse error: {e}");
713 return Vec::new(); }
715
716 local_extractor.positions
717 })
718 .collect();
719
720 pb.finish_with_message("โ
Parallel processing completed");
721
722 for game_positions in all_positions {
724 self.data.extend(game_positions);
725 }
726
727 println!(
728 "โ
Loaded {} positions from {} games (parallel processing)",
729 self.data.len(),
730 games.len()
731 );
732 Ok(())
733 }
734
735 pub fn evaluate_with_external_engine(
737 &mut self,
738 depth: u8,
739 ) -> Result<(), Box<dyn std::error::Error>> {
740 let evaluator = ExternalEvaluator::new(depth);
741 evaluator.evaluate_batch(&mut self.data)
742 }
743
744 pub fn evaluate_with_external_engine_parallel(
746 &mut self,
747 depth: u8,
748 num_threads: usize,
749 ) -> Result<(), Box<dyn std::error::Error>> {
750 let evaluator = ExternalEvaluator::new(depth);
751 evaluator.evaluate_batch_parallel(&mut self.data, num_threads)
752 }
753
754 pub fn train_engine(&self, engine: &mut ChessVectorEngine) {
756 let pb = ProgressBar::new(self.data.len() as u64);
757 pb.set_style(ProgressStyle::default_bar()
758 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Training positions")
759 .unwrap()
760 .progress_chars("#>-"));
761
762 for data in &self.data {
763 engine.add_position(&data.board, data.evaluation);
764 pb.inc(1);
765 }
766
767 pb.finish_with_message("Training complete");
768 println!("Trained engine with {} positions", self.data.len());
769 }
770
771 pub fn split(&self, train_ratio: f32) -> (TrainingDataset, TrainingDataset) {
773 use rand::seq::SliceRandom;
774 use rand::thread_rng;
775 use std::collections::{HashMap, HashSet};
776
777 let mut games: HashMap<usize, Vec<&TrainingData>> = HashMap::new();
779 for data in &self.data {
780 games.entry(data.game_id).or_default().push(data);
781 }
782
783 let mut game_ids: Vec<usize> = games.keys().cloned().collect();
785 game_ids.shuffle(&mut thread_rng());
786
787 let split_point = (game_ids.len() as f32 * train_ratio) as usize;
789 let train_game_ids: HashSet<usize> = game_ids[..split_point].iter().cloned().collect();
790
791 let mut train_data = Vec::new();
793 let mut test_data = Vec::new();
794
795 for data in &self.data {
796 if train_game_ids.contains(&data.game_id) {
797 train_data.push(data.clone());
798 } else {
799 test_data.push(data.clone());
800 }
801 }
802
803 (
804 TrainingDataset { data: train_data },
805 TrainingDataset { data: test_data },
806 )
807 }
808
809 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
811 let json = serde_json::to_string_pretty(&self.data)?;
812 std::fs::write(path, json)?;
813 Ok(())
814 }
815
816 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
818 let content = std::fs::read_to_string(path)?;
819 let data = serde_json::from_str(&content)?;
820 Ok(Self { data })
821 }
822
823 pub fn load_and_append<P: AsRef<Path>>(
825 &mut self,
826 path: P,
827 ) -> Result<(), Box<dyn std::error::Error>> {
828 let existing_len = self.data.len();
829 let additional_data = Self::load(path)?;
830 self.data.extend(additional_data.data);
831 println!(
832 "Loaded {} additional positions (total: {})",
833 self.data.len() - existing_len,
834 self.data.len()
835 );
836 Ok(())
837 }
838
839 pub fn merge(&mut self, other: TrainingDataset) {
841 let existing_len = self.data.len();
842 self.data.extend(other.data);
843 println!(
844 "Merged {} positions (total: {})",
845 self.data.len() - existing_len,
846 self.data.len()
847 );
848 }
849
850 pub fn save_incremental<P: AsRef<Path>>(
852 &self,
853 path: P,
854 ) -> Result<(), Box<dyn std::error::Error>> {
855 self.save_incremental_with_options(path, true)
856 }
857
858 pub fn save_incremental_with_options<P: AsRef<Path>>(
860 &self,
861 path: P,
862 deduplicate: bool,
863 ) -> Result<(), Box<dyn std::error::Error>> {
864 let path = path.as_ref();
865
866 if path.exists() {
867 if self.save_append_only(path).is_ok() {
869 return Ok(());
870 }
871
872 if deduplicate {
874 self.save_incremental_full_merge(path)
875 } else {
876 self.save_incremental_no_dedup(path)
877 }
878 } else {
879 self.save(path)
881 }
882 }
883
884 fn save_incremental_no_dedup<P: AsRef<Path>>(
886 &self,
887 path: P,
888 ) -> Result<(), Box<dyn std::error::Error>> {
889 let path = path.as_ref();
890
891 println!("๐ Loading existing training data...");
892 let mut existing = Self::load(path)?;
893
894 println!("โก Fast merge without deduplication...");
895 existing.data.extend(self.data.iter().cloned());
896
897 println!(
898 "๐พ Serializing {} positions to JSON...",
899 existing.data.len()
900 );
901 let json = serde_json::to_string_pretty(&existing.data)?;
902
903 println!("โ๏ธ Writing to disk...");
904 std::fs::write(path, json)?;
905
906 println!(
907 "โ
Fast merge save: total {} positions",
908 existing.data.len()
909 );
910 Ok(())
911 }
912
913 pub fn save_append_only<P: AsRef<Path>>(
915 &self,
916 path: P,
917 ) -> Result<(), Box<dyn std::error::Error>> {
918 use std::fs::OpenOptions;
919 use std::io::{BufRead, BufReader, Seek, SeekFrom, Write};
920
921 if self.data.is_empty() {
922 return Ok(());
923 }
924
925 let path = path.as_ref();
926 let mut file = OpenOptions::new().read(true).write(true).open(path)?;
927
928 file.seek(SeekFrom::End(-10))?;
930 let mut buffer = String::new();
931 BufReader::new(&file).read_line(&mut buffer)?;
932
933 if !buffer.trim().ends_with(']') {
934 return Err("File doesn't end with JSON array bracket".into());
935 }
936
937 file.seek(SeekFrom::End(-2))?; write!(file, ",")?;
942
943 for (i, data) in self.data.iter().enumerate() {
945 if i > 0 {
946 write!(file, ",")?;
947 }
948 let json = serde_json::to_string(data)?;
949 write!(file, "{json}")?;
950 }
951
952 write!(file, "\n]")?;
954
955 println!("Fast append: added {} new positions", self.data.len());
956 Ok(())
957 }
958
959 fn save_incremental_full_merge<P: AsRef<Path>>(
961 &self,
962 path: P,
963 ) -> Result<(), Box<dyn std::error::Error>> {
964 let path = path.as_ref();
965
966 println!("๐ Loading existing training data...");
967 let mut existing = Self::load(path)?;
968 let _original_len = existing.data.len();
969
970 println!("๐ Streaming merge with deduplication (avoiding O(nยฒ) operation)...");
971 existing.merge_and_deduplicate(self.data.clone());
972
973 println!(
974 "๐พ Serializing {} positions to JSON...",
975 existing.data.len()
976 );
977 let json = serde_json::to_string_pretty(&existing.data)?;
978
979 println!("โ๏ธ Writing to disk...");
980 std::fs::write(path, json)?;
981
982 println!(
983 "โ
Streaming merge save: total {} positions",
984 existing.data.len()
985 );
986 Ok(())
987 }
988
989 pub fn add_position(&mut self, board: Board, evaluation: f32, depth: u8, game_id: usize) {
991 self.data.push(TrainingData {
992 board,
993 evaluation,
994 depth,
995 game_id,
996 });
997 }
998
999 pub fn next_game_id(&self) -> usize {
1001 self.data.iter().map(|data| data.game_id).max().unwrap_or(0) + 1
1002 }
1003
1004 pub fn deduplicate(&mut self, similarity_threshold: f32) {
1006 if similarity_threshold > 0.999 {
1007 self.deduplicate_fast();
1009 } else {
1010 self.deduplicate_similarity_based(similarity_threshold);
1012 }
1013 }
1014
1015 pub fn deduplicate_fast(&mut self) {
1017 use std::collections::HashSet;
1018
1019 if self.data.is_empty() {
1020 return;
1021 }
1022
1023 let mut seen_positions = HashSet::with_capacity(self.data.len());
1024 let original_len = self.data.len();
1025
1026 self.data.retain(|data| {
1028 let fen = data.board.to_string();
1029 seen_positions.insert(fen)
1030 });
1031
1032 println!(
1033 "Fast deduplicated: {} -> {} positions (removed {} exact duplicates)",
1034 original_len,
1035 self.data.len(),
1036 original_len - self.data.len()
1037 );
1038 }
1039
1040 pub fn merge_and_deduplicate(&mut self, new_data: Vec<TrainingData>) {
1042 use std::collections::HashSet;
1043
1044 if new_data.is_empty() {
1045 return;
1046 }
1047
1048 let _original_len = self.data.len();
1049
1050 let mut existing_positions: HashSet<String> = HashSet::with_capacity(self.data.len());
1052 for data in &self.data {
1053 existing_positions.insert(data.board.to_string());
1054 }
1055
1056 let mut added = 0;
1058 for data in new_data {
1059 let fen = data.board.to_string();
1060 if existing_positions.insert(fen) {
1061 self.data.push(data);
1062 added += 1;
1063 }
1064 }
1065
1066 println!(
1067 "Streaming merge: added {} unique positions (total: {})",
1068 added,
1069 self.data.len()
1070 );
1071 }
1072
1073 fn deduplicate_similarity_based(&mut self, similarity_threshold: f32) {
1075 use crate::PositionEncoder;
1076 use ndarray::Array1;
1077
1078 if self.data.is_empty() {
1079 return;
1080 }
1081
1082 let encoder = PositionEncoder::new(1024);
1083 let mut keep_indices: Vec<bool> = vec![true; self.data.len()];
1084
1085 let vectors: Vec<Array1<f32>> = if self.data.len() > 50 {
1087 self.data
1088 .par_iter()
1089 .map(|data| encoder.encode(&data.board))
1090 .collect()
1091 } else {
1092 self.data
1093 .iter()
1094 .map(|data| encoder.encode(&data.board))
1095 .collect()
1096 };
1097
1098 for i in 1..self.data.len() {
1100 if !keep_indices[i] {
1101 continue;
1102 }
1103
1104 for j in 0..i {
1105 if !keep_indices[j] {
1106 continue;
1107 }
1108
1109 let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1110 if similarity > similarity_threshold {
1111 keep_indices[i] = false;
1112 break;
1113 }
1114 }
1115 }
1116
1117 let original_len = self.data.len();
1119 self.data = self
1120 .data
1121 .iter()
1122 .enumerate()
1123 .filter_map(|(i, data)| {
1124 if keep_indices[i] {
1125 Some(data.clone())
1126 } else {
1127 None
1128 }
1129 })
1130 .collect();
1131
1132 println!(
1133 "Similarity deduplicated: {} -> {} positions (removed {} near-duplicates)",
1134 original_len,
1135 self.data.len(),
1136 original_len - self.data.len()
1137 );
1138 }
1139
1140 pub fn deduplicate_parallel(&mut self, similarity_threshold: f32, chunk_size: usize) {
1142 use crate::PositionEncoder;
1143 use ndarray::Array1;
1144 use std::sync::{Arc, Mutex};
1145
1146 if self.data.is_empty() {
1147 return;
1148 }
1149
1150 let encoder = PositionEncoder::new(1024);
1151
1152 let vectors: Vec<Array1<f32>> = self
1154 .data
1155 .par_iter()
1156 .map(|data| encoder.encode(&data.board))
1157 .collect();
1158
1159 let keep_indices = Arc::new(Mutex::new(vec![true; self.data.len()]));
1160
1161 (1..self.data.len())
1163 .collect::<Vec<_>>()
1164 .par_chunks(chunk_size)
1165 .for_each(|chunk| {
1166 for &i in chunk {
1167 {
1169 let indices = keep_indices.lock().unwrap();
1170 if !indices[i] {
1171 continue;
1172 }
1173 }
1174
1175 for j in 0..i {
1177 {
1178 let indices = keep_indices.lock().unwrap();
1179 if !indices[j] {
1180 continue;
1181 }
1182 }
1183
1184 let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1185 if similarity > similarity_threshold {
1186 let mut indices = keep_indices.lock().unwrap();
1187 indices[i] = false;
1188 break;
1189 }
1190 }
1191 }
1192 });
1193
1194 let keep_indices = keep_indices.lock().unwrap();
1196 let original_len = self.data.len();
1197 self.data = self
1198 .data
1199 .iter()
1200 .enumerate()
1201 .filter_map(|(i, data)| {
1202 if keep_indices[i] {
1203 Some(data.clone())
1204 } else {
1205 None
1206 }
1207 })
1208 .collect();
1209
1210 println!(
1211 "Parallel deduplicated: {} -> {} positions (removed {} duplicates)",
1212 original_len,
1213 self.data.len(),
1214 original_len - self.data.len()
1215 );
1216 }
1217
1218 fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
1220 let dot_product = a.dot(b);
1221 let norm_a = a.dot(a).sqrt();
1222 let norm_b = b.dot(b).sqrt();
1223
1224 if norm_a == 0.0 || norm_b == 0.0 {
1225 0.0
1226 } else {
1227 dot_product / (norm_a * norm_b)
1228 }
1229 }
1230}
1231
/// Generates training data by letting the engine play games against itself.
pub struct SelfPlayTrainer {
    /// Parameters controlling game generation.
    config: SelfPlayConfig,
    /// Number of games generated so far; stored as the `game_id` of the
    /// positions recorded for each game.
    game_counter: usize,
}
1237
1238impl SelfPlayTrainer {
1239 pub fn new(config: SelfPlayConfig) -> Self {
1240 Self {
1241 config,
1242 game_counter: 0,
1243 }
1244 }
1245
1246 pub fn generate_training_data(&mut self, engine: &mut ChessVectorEngine) -> TrainingDataset {
1248 let mut dataset = TrainingDataset::new();
1249
1250 println!(
1251 "๐ฎ Starting self-play training with {} games...",
1252 self.config.games_per_iteration
1253 );
1254 let pb = ProgressBar::new(self.config.games_per_iteration as u64);
1255 if let Ok(style) = ProgressStyle::default_bar().template(
1256 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
1257 ) {
1258 pb.set_style(style.progress_chars("#>-"));
1259 }
1260
1261 for _ in 0..self.config.games_per_iteration {
1262 let game_data = self.play_single_game(engine);
1263 dataset.data.extend(game_data);
1264 self.game_counter += 1;
1265 pb.inc(1);
1266 }
1267
1268 pb.finish_with_message("Self-play games completed");
1269 println!(
1270 "โ
Generated {} positions from {} games",
1271 dataset.data.len(),
1272 self.config.games_per_iteration
1273 );
1274
1275 dataset
1276 }
1277
1278 fn play_single_game(&self, engine: &mut ChessVectorEngine) -> Vec<TrainingData> {
1280 let mut game = Game::new();
1281 let mut positions = Vec::new();
1282 let mut move_count = 0;
1283
1284 if self.config.use_opening_book {
1286 if let Some(opening_moves) = self.get_random_opening() {
1287 for mv in opening_moves {
1288 if game.make_move(mv) {
1289 move_count += 1;
1290 } else {
1291 break;
1292 }
1293 }
1294 }
1295 }
1296
1297 while game.result().is_none() && move_count < self.config.max_moves_per_game {
1299 let current_position = game.current_position();
1300
1301 let move_choice = self.select_move_with_exploration(engine, ¤t_position);
1303
1304 if let Some(chess_move) = move_choice {
1305 if let Some(evaluation) = engine.evaluate_position(¤t_position) {
1307 if evaluation.abs() >= self.config.min_confidence || move_count < 10 {
1309 positions.push(TrainingData {
1310 board: current_position,
1311 evaluation,
1312 depth: 1, game_id: self.game_counter,
1314 });
1315 }
1316 }
1317
1318 if !game.make_move(chess_move) {
1320 break; }
1322 move_count += 1;
1323 } else {
1324 break; }
1326 }
1327
1328 if let Some(result) = game.result() {
1330 let final_position = game.current_position();
1331 let final_eval = match result {
1332 chess::GameResult::WhiteCheckmates => {
1333 if final_position.side_to_move() == chess::Color::Black {
1334 10.0
1335 } else {
1336 -10.0
1337 }
1338 }
1339 chess::GameResult::BlackCheckmates => {
1340 if final_position.side_to_move() == chess::Color::White {
1341 10.0
1342 } else {
1343 -10.0
1344 }
1345 }
1346 chess::GameResult::WhiteResigns => -10.0,
1347 chess::GameResult::BlackResigns => 10.0,
1348 chess::GameResult::Stalemate
1349 | chess::GameResult::DrawAccepted
1350 | chess::GameResult::DrawDeclared => 0.0,
1351 };
1352
1353 positions.push(TrainingData {
1354 board: final_position,
1355 evaluation: final_eval,
1356 depth: 1,
1357 game_id: self.game_counter,
1358 });
1359 }
1360
1361 positions
1362 }
1363
1364 fn select_move_with_exploration(
1366 &self,
1367 engine: &mut ChessVectorEngine,
1368 position: &Board,
1369 ) -> Option<ChessMove> {
1370 let recommendations = engine.recommend_legal_moves(position, 5);
1371
1372 if recommendations.is_empty() {
1373 return None;
1374 }
1375
1376 if fastrand::f32() < self.config.exploration_factor {
1378 self.select_move_with_temperature(&recommendations)
1380 } else {
1381 Some(recommendations[0].chess_move)
1383 }
1384 }
1385
1386 fn select_move_with_temperature(
1388 &self,
1389 recommendations: &[crate::MoveRecommendation],
1390 ) -> Option<ChessMove> {
1391 if recommendations.is_empty() {
1392 return None;
1393 }
1394
1395 let mut probabilities = Vec::new();
1397 let mut sum = 0.0;
1398
1399 for rec in recommendations {
1400 let prob = (rec.average_outcome / self.config.temperature).exp();
1402 probabilities.push(prob);
1403 sum += prob;
1404 }
1405
1406 for prob in &mut probabilities {
1408 *prob /= sum;
1409 }
1410
1411 let rand_val = fastrand::f32();
1413 let mut cumulative = 0.0;
1414
1415 for (i, &prob) in probabilities.iter().enumerate() {
1416 cumulative += prob;
1417 if rand_val <= cumulative {
1418 return Some(recommendations[i].chess_move);
1419 }
1420 }
1421
1422 Some(recommendations[0].chess_move)
1424 }
1425
1426 fn get_random_opening(&self) -> Option<Vec<ChessMove>> {
1428 let openings = [
1429 vec!["e4", "e5", "Nf3", "Nc6", "Bc4"],
1431 vec!["e4", "e5", "Nf3", "Nc6", "Bb5"],
1433 vec!["d4", "d5", "c4"],
1435 vec!["d4", "Nf6", "c4", "g6"],
1437 vec!["e4", "c5"],
1439 vec!["e4", "e6"],
1441 vec!["e4", "c6"],
1443 ];
1444
1445 let selected_opening = &openings[fastrand::usize(0..openings.len())];
1446
1447 let mut moves = Vec::new();
1448 let mut game = Game::new();
1449
1450 for move_str in selected_opening {
1451 if let Ok(chess_move) = ChessMove::from_str(move_str) {
1452 if game.make_move(chess_move) {
1453 moves.push(chess_move);
1454 } else {
1455 break;
1456 }
1457 }
1458 }
1459
1460 if moves.is_empty() {
1461 None
1462 } else {
1463 Some(moves)
1464 }
1465 }
1466}
1467
1468pub struct EngineEvaluator {
1470 #[allow(dead_code)]
1471 engine_depth: u8,
1472}
1473
1474impl EngineEvaluator {
1475 pub fn new(engine_depth: u8) -> Self {
1476 Self { engine_depth }
1477 }
1478
1479 pub fn evaluate_accuracy(
1481 &self,
1482 engine: &mut ChessVectorEngine,
1483 test_data: &TrainingDataset,
1484 ) -> Result<f32, Box<dyn std::error::Error>> {
1485 let mut total_error = 0.0;
1486 let mut valid_comparisons = 0;
1487
1488 let pb = ProgressBar::new(test_data.data.len() as u64);
1489 pb.set_style(ProgressStyle::default_bar()
1490 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Evaluating accuracy")
1491 .unwrap()
1492 .progress_chars("#>-"));
1493
1494 for data in &test_data.data {
1495 if let Some(engine_eval) = engine.evaluate_position(&data.board) {
1496 let error = (engine_eval - data.evaluation).abs();
1497 total_error += error;
1498 valid_comparisons += 1;
1499 }
1500 pb.inc(1);
1501 }
1502
1503 pb.finish_with_message("Accuracy evaluation complete");
1504
1505 if valid_comparisons > 0 {
1506 let mean_absolute_error = total_error / valid_comparisons as f32;
1507 println!("Mean Absolute Error: {mean_absolute_error:.3} pawns");
1508 println!("Evaluated {valid_comparisons} positions");
1509 Ok(mean_absolute_error)
1510 } else {
1511 Ok(f32::INFINITY)
1512 }
1513 }
1514}
1515
1516impl serde::Serialize for TrainingData {
1518 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1519 where
1520 S: serde::Serializer,
1521 {
1522 use serde::ser::SerializeStruct;
1523 let mut state = serializer.serialize_struct("TrainingData", 4)?;
1524 state.serialize_field("fen", &self.board.to_string())?;
1525 state.serialize_field("evaluation", &self.evaluation)?;
1526 state.serialize_field("depth", &self.depth)?;
1527 state.serialize_field("game_id", &self.game_id)?;
1528 state.end()
1529 }
1530}
1531
1532impl<'de> serde::Deserialize<'de> for TrainingData {
1533 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1534 where
1535 D: serde::Deserializer<'de>,
1536 {
1537 use serde::de::{self, MapAccess, Visitor};
1538 use std::fmt;
1539
1540 struct TrainingDataVisitor;
1541
1542 impl<'de> Visitor<'de> for TrainingDataVisitor {
1543 type Value = TrainingData;
1544
1545 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1546 formatter.write_str("struct TrainingData")
1547 }
1548
1549 fn visit_map<V>(self, mut map: V) -> Result<TrainingData, V::Error>
1550 where
1551 V: MapAccess<'de>,
1552 {
1553 let mut fen = None;
1554 let mut evaluation = None;
1555 let mut depth = None;
1556 let mut game_id = None;
1557
1558 while let Some(key) = map.next_key()? {
1559 match key {
1560 "fen" => {
1561 if fen.is_some() {
1562 return Err(de::Error::duplicate_field("fen"));
1563 }
1564 fen = Some(map.next_value()?);
1565 }
1566 "evaluation" => {
1567 if evaluation.is_some() {
1568 return Err(de::Error::duplicate_field("evaluation"));
1569 }
1570 evaluation = Some(map.next_value()?);
1571 }
1572 "depth" => {
1573 if depth.is_some() {
1574 return Err(de::Error::duplicate_field("depth"));
1575 }
1576 depth = Some(map.next_value()?);
1577 }
1578 "game_id" => {
1579 if game_id.is_some() {
1580 return Err(de::Error::duplicate_field("game_id"));
1581 }
1582 game_id = Some(map.next_value()?);
1583 }
1584 _ => {
1585 let _: serde_json::Value = map.next_value()?;
1586 }
1587 }
1588 }
1589
1590 let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
1591 let mut evaluation: f32 =
1592 evaluation.ok_or_else(|| de::Error::missing_field("evaluation"))?;
1593 let depth = depth.ok_or_else(|| de::Error::missing_field("depth"))?;
1594 let game_id = game_id.unwrap_or(0); if evaluation.abs() > 15.0 {
1600 evaluation /= 100.0;
1601 }
1602
1603 let board =
1604 Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
1605
1606 Ok(TrainingData {
1607 board,
1608 evaluation,
1609 depth,
1610 game_id,
1611 })
1612 }
1613 }
1614
1615 const FIELDS: &[&str] = &["fen", "evaluation", "depth", "game_id"];
1616 deserializer.deserialize_struct("TrainingData", FIELDS, TrainingDataVisitor)
1617 }
1618}
1619
/// Namespace type for the tactical-puzzle helpers: CSV parsing, conversion
/// into `TacticalTrainingData`, loading into a `ChessVectorEngine`, and JSON
/// persistence. It carries no state; all functions are associated functions.
pub struct TacticalPuzzleParser;
1622
1623impl TacticalPuzzleParser {
1624 pub fn parse_csv<P: AsRef<Path>>(
1626 file_path: P,
1627 max_puzzles: Option<usize>,
1628 min_rating: Option<u32>,
1629 max_rating: Option<u32>,
1630 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1631 let file = File::open(&file_path)?;
1632 let file_size = file.metadata()?.len();
1633
1634 if file_size > 100_000_000 {
1636 Self::parse_csv_parallel(file_path, max_puzzles, min_rating, max_rating)
1637 } else {
1638 Self::parse_csv_sequential(file_path, max_puzzles, min_rating, max_rating)
1639 }
1640 }
1641
1642 fn parse_csv_sequential<P: AsRef<Path>>(
1644 file_path: P,
1645 max_puzzles: Option<usize>,
1646 min_rating: Option<u32>,
1647 max_rating: Option<u32>,
1648 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1649 let file = File::open(file_path)?;
1650 let reader = BufReader::new(file);
1651
1652 let mut csv_reader = csv::ReaderBuilder::new()
1655 .has_headers(false)
1656 .flexible(true) .from_reader(reader);
1658
1659 let mut tactical_data = Vec::new();
1660 let mut processed = 0;
1661 let mut skipped = 0;
1662
1663 let pb = ProgressBar::new_spinner();
1664 pb.set_style(
1665 ProgressStyle::default_spinner()
1666 .template("{spinner:.green} Parsing tactical puzzles: {pos} (skipped: {skipped})")
1667 .unwrap(),
1668 );
1669
1670 for result in csv_reader.records() {
1671 let record = match result {
1672 Ok(r) => r,
1673 Err(e) => {
1674 skipped += 1;
1675 println!("CSV parsing error: {e}");
1676 continue;
1677 }
1678 };
1679
1680 if let Some(puzzle_data) = Self::parse_csv_record(&record, min_rating, max_rating) {
1681 if let Some(tactical_data_item) =
1682 Self::convert_puzzle_to_training_data(&puzzle_data)
1683 {
1684 tactical_data.push(tactical_data_item);
1685 processed += 1;
1686
1687 if let Some(max) = max_puzzles {
1688 if processed >= max {
1689 break;
1690 }
1691 }
1692 } else {
1693 skipped += 1;
1694 }
1695 } else {
1696 skipped += 1;
1697 }
1698
1699 pb.set_message(format!(
1700 "Parsing tactical puzzles: {processed} (skipped: {skipped})"
1701 ));
1702 }
1703
1704 pb.finish_with_message(format!("Parsed {processed} puzzles (skipped: {skipped})"));
1705
1706 Ok(tactical_data)
1707 }
1708
1709 fn parse_csv_parallel<P: AsRef<Path>>(
1711 file_path: P,
1712 max_puzzles: Option<usize>,
1713 min_rating: Option<u32>,
1714 max_rating: Option<u32>,
1715 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1716 use std::io::Read;
1717
1718 let mut file = File::open(&file_path)?;
1719
1720 let mut contents = String::new();
1722 file.read_to_string(&mut contents)?;
1723
1724 let lines: Vec<&str> = contents.lines().collect();
1726
1727 let pb = ProgressBar::new(lines.len() as u64);
1728 pb.set_style(ProgressStyle::default_bar()
1729 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel CSV parsing")
1730 .unwrap()
1731 .progress_chars("#>-"));
1732
1733 let tactical_data: Vec<TacticalTrainingData> = lines
1735 .par_iter()
1736 .take(max_puzzles.unwrap_or(usize::MAX))
1737 .filter_map(|line| {
1738 let fields: Vec<&str> = line.split(',').collect();
1740 if fields.len() < 8 {
1741 return None;
1742 }
1743
1744 if let Some(puzzle_data) = Self::parse_csv_fields(&fields, min_rating, max_rating) {
1746 Self::convert_puzzle_to_training_data(&puzzle_data)
1747 } else {
1748 None
1749 }
1750 })
1751 .collect();
1752
1753 pb.finish_with_message(format!(
1754 "Parallel parsing complete: {} puzzles",
1755 tactical_data.len()
1756 ));
1757
1758 Ok(tactical_data)
1759 }
1760
1761 fn parse_csv_record(
1763 record: &csv::StringRecord,
1764 min_rating: Option<u32>,
1765 max_rating: Option<u32>,
1766 ) -> Option<TacticalPuzzle> {
1767 if record.len() < 8 {
1769 return None;
1770 }
1771
1772 let rating: u32 = record[3].parse().ok()?;
1773 let rating_deviation: u32 = record[4].parse().ok()?;
1774 let popularity: i32 = record[5].parse().ok()?;
1775 let nb_plays: u32 = record[6].parse().ok()?;
1776
1777 if let Some(min) = min_rating {
1779 if rating < min {
1780 return None;
1781 }
1782 }
1783 if let Some(max) = max_rating {
1784 if rating > max {
1785 return None;
1786 }
1787 }
1788
1789 Some(TacticalPuzzle {
1790 puzzle_id: record[0].to_string(),
1791 fen: record[1].to_string(),
1792 moves: record[2].to_string(),
1793 rating,
1794 rating_deviation,
1795 popularity,
1796 nb_plays,
1797 themes: record[7].to_string(),
1798 game_url: if record.len() > 8 {
1799 Some(record[8].to_string())
1800 } else {
1801 None
1802 },
1803 opening_tags: if record.len() > 9 {
1804 Some(record[9].to_string())
1805 } else {
1806 None
1807 },
1808 })
1809 }
1810
1811 fn parse_csv_fields(
1813 fields: &[&str],
1814 min_rating: Option<u32>,
1815 max_rating: Option<u32>,
1816 ) -> Option<TacticalPuzzle> {
1817 if fields.len() < 8 {
1818 return None;
1819 }
1820
1821 let rating: u32 = fields[3].parse().ok()?;
1822 let rating_deviation: u32 = fields[4].parse().ok()?;
1823 let popularity: i32 = fields[5].parse().ok()?;
1824 let nb_plays: u32 = fields[6].parse().ok()?;
1825
1826 if let Some(min) = min_rating {
1828 if rating < min {
1829 return None;
1830 }
1831 }
1832 if let Some(max) = max_rating {
1833 if rating > max {
1834 return None;
1835 }
1836 }
1837
1838 Some(TacticalPuzzle {
1839 puzzle_id: fields[0].to_string(),
1840 fen: fields[1].to_string(),
1841 moves: fields[2].to_string(),
1842 rating,
1843 rating_deviation,
1844 popularity,
1845 nb_plays,
1846 themes: fields[7].to_string(),
1847 game_url: if fields.len() > 8 {
1848 Some(fields[8].to_string())
1849 } else {
1850 None
1851 },
1852 opening_tags: if fields.len() > 9 {
1853 Some(fields[9].to_string())
1854 } else {
1855 None
1856 },
1857 })
1858 }
1859
1860 fn convert_puzzle_to_training_data(puzzle: &TacticalPuzzle) -> Option<TacticalTrainingData> {
1862 let position = match Board::from_str(&puzzle.fen) {
1864 Ok(board) => board,
1865 Err(_) => return None,
1866 };
1867
1868 let moves: Vec<&str> = puzzle.moves.split_whitespace().collect();
1870 if moves.is_empty() {
1871 return None;
1872 }
1873
1874 let solution_move = match ChessMove::from_str(moves[0]) {
1876 Ok(mv) => mv,
1877 Err(_) => {
1878 match ChessMove::from_san(&position, moves[0]) {
1880 Ok(mv) => mv,
1881 Err(_) => return None,
1882 }
1883 }
1884 };
1885
1886 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&position).collect();
1888 if !legal_moves.contains(&solution_move) {
1889 return None;
1890 }
1891
1892 let themes: Vec<&str> = puzzle.themes.split_whitespace().collect();
1894 let primary_theme = themes.first().unwrap_or(&"tactical").to_string();
1895
1896 let difficulty = puzzle.rating as f32 / 1000.0; let popularity_bonus = (puzzle.popularity as f32 / 100.0).min(2.0);
1899 let tactical_value = difficulty + popularity_bonus; Some(TacticalTrainingData {
1902 position,
1903 solution_move,
1904 move_theme: primary_theme,
1905 difficulty,
1906 tactical_value,
1907 })
1908 }
1909
1910 pub fn load_into_engine(
1912 tactical_data: &[TacticalTrainingData],
1913 engine: &mut ChessVectorEngine,
1914 ) {
1915 let pb = ProgressBar::new(tactical_data.len() as u64);
1916 pb.set_style(ProgressStyle::default_bar()
1917 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns")
1918 .unwrap()
1919 .progress_chars("#>-"));
1920
1921 for data in tactical_data {
1922 engine.add_position_with_move(
1924 &data.position,
1925 0.0, Some(data.solution_move),
1927 Some(data.tactical_value), );
1929 pb.inc(1);
1930 }
1931
1932 pb.finish_with_message(format!("Loaded {} tactical patterns", tactical_data.len()));
1933 }
1934
1935 pub fn load_into_engine_incremental(
1937 tactical_data: &[TacticalTrainingData],
1938 engine: &mut ChessVectorEngine,
1939 ) {
1940 let initial_size = engine.knowledge_base_size();
1941 let initial_moves = engine.position_moves.len();
1942
1943 if tactical_data.len() > 1000 {
1945 Self::load_into_engine_incremental_parallel(
1946 tactical_data,
1947 engine,
1948 initial_size,
1949 initial_moves,
1950 );
1951 } else {
1952 Self::load_into_engine_incremental_sequential(
1953 tactical_data,
1954 engine,
1955 initial_size,
1956 initial_moves,
1957 );
1958 }
1959 }
1960
1961 fn load_into_engine_incremental_sequential(
1963 tactical_data: &[TacticalTrainingData],
1964 engine: &mut ChessVectorEngine,
1965 initial_size: usize,
1966 initial_moves: usize,
1967 ) {
1968 let pb = ProgressBar::new(tactical_data.len() as u64);
1969 pb.set_style(ProgressStyle::default_bar()
1970 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns (incremental)")
1971 .unwrap()
1972 .progress_chars("#>-"));
1973
1974 let mut added = 0;
1975 let mut skipped = 0;
1976
1977 for data in tactical_data {
1978 if !engine.position_boards.contains(&data.position) {
1980 engine.add_position_with_move(
1981 &data.position,
1982 0.0, Some(data.solution_move),
1984 Some(data.tactical_value), );
1986 added += 1;
1987 } else {
1988 skipped += 1;
1989 }
1990 pb.inc(1);
1991 }
1992
1993 pb.finish_with_message(format!(
1994 "Loaded {} new tactical patterns (skipped {} duplicates, total: {})",
1995 added,
1996 skipped,
1997 engine.knowledge_base_size()
1998 ));
1999
2000 println!("Incremental tactical training:");
2001 println!(
2002 " - Positions: {} โ {} (+{})",
2003 initial_size,
2004 engine.knowledge_base_size(),
2005 engine.knowledge_base_size() - initial_size
2006 );
2007 println!(
2008 " - Move entries: {} โ {} (+{})",
2009 initial_moves,
2010 engine.position_moves.len(),
2011 engine.position_moves.len() - initial_moves
2012 );
2013 }
2014
2015 fn load_into_engine_incremental_parallel(
2017 tactical_data: &[TacticalTrainingData],
2018 engine: &mut ChessVectorEngine,
2019 initial_size: usize,
2020 initial_moves: usize,
2021 ) {
2022 let pb = ProgressBar::new(tactical_data.len() as u64);
2023 pb.set_style(ProgressStyle::default_bar()
2024 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Optimized batch loading tactical patterns")
2025 .unwrap()
2026 .progress_chars("#>-"));
2027
2028 let filtered_data: Vec<&TacticalTrainingData> = tactical_data
2030 .par_iter()
2031 .filter(|data| !engine.position_boards.contains(&data.position))
2032 .collect();
2033
2034 let batch_size = 1000; let mut added = 0;
2036
2037 println!(
2038 "Pre-filtered: {} โ {} positions (removed {} duplicates)",
2039 tactical_data.len(),
2040 filtered_data.len(),
2041 tactical_data.len() - filtered_data.len()
2042 );
2043
2044 for batch in filtered_data.chunks(batch_size) {
2047 let batch_start = added;
2048
2049 for data in batch {
2050 if !engine.position_boards.contains(&data.position) {
2052 engine.add_position_with_move(
2053 &data.position,
2054 0.0, Some(data.solution_move),
2056 Some(data.tactical_value), );
2058 added += 1;
2059 }
2060 pb.inc(1);
2061 }
2062
2063 pb.set_message(format!("Loaded batch: {} positions", added - batch_start));
2065 }
2066
2067 let skipped = tactical_data.len() - added;
2068
2069 pb.finish_with_message(format!(
2070 "Optimized loaded {} new tactical patterns (skipped {} duplicates, total: {})",
2071 added,
2072 skipped,
2073 engine.knowledge_base_size()
2074 ));
2075
2076 println!("Incremental tactical training (optimized):");
2077 println!(
2078 " - Positions: {} โ {} (+{})",
2079 initial_size,
2080 engine.knowledge_base_size(),
2081 engine.knowledge_base_size() - initial_size
2082 );
2083 println!(
2084 " - Move entries: {} โ {} (+{})",
2085 initial_moves,
2086 engine.position_moves.len(),
2087 engine.position_moves.len() - initial_moves
2088 );
2089 println!(
2090 " - Batch size: {}, Pre-filtered efficiency: {:.1}%",
2091 batch_size,
2092 (filtered_data.len() as f32 / tactical_data.len() as f32) * 100.0
2093 );
2094 }
2095
2096 pub fn save_tactical_puzzles<P: AsRef<std::path::Path>>(
2098 tactical_data: &[TacticalTrainingData],
2099 path: P,
2100 ) -> Result<(), Box<dyn std::error::Error>> {
2101 let json = serde_json::to_string_pretty(tactical_data)?;
2102 std::fs::write(path, json)?;
2103 println!("Saved {} tactical puzzles", tactical_data.len());
2104 Ok(())
2105 }
2106
2107 pub fn load_tactical_puzzles<P: AsRef<std::path::Path>>(
2109 path: P,
2110 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
2111 let content = std::fs::read_to_string(path)?;
2112 let tactical_data: Vec<TacticalTrainingData> = serde_json::from_str(&content)?;
2113 println!("Loaded {} tactical puzzles from file", tactical_data.len());
2114 Ok(tactical_data)
2115 }
2116
2117 pub fn save_tactical_puzzles_incremental<P: AsRef<std::path::Path>>(
2119 tactical_data: &[TacticalTrainingData],
2120 path: P,
2121 ) -> Result<(), Box<dyn std::error::Error>> {
2122 let path = path.as_ref();
2123
2124 if path.exists() {
2125 let mut existing = Self::load_tactical_puzzles(path)?;
2127 let original_len = existing.len();
2128
2129 for new_puzzle in tactical_data {
2131 let exists = existing.iter().any(|existing_puzzle| {
2133 existing_puzzle.position == new_puzzle.position
2134 && existing_puzzle.solution_move == new_puzzle.solution_move
2135 });
2136
2137 if !exists {
2138 existing.push(new_puzzle.clone());
2139 }
2140 }
2141
2142 let json = serde_json::to_string_pretty(&existing)?;
2144 std::fs::write(path, json)?;
2145
2146 println!(
2147 "Incremental save: added {} new puzzles (total: {})",
2148 existing.len() - original_len,
2149 existing.len()
2150 );
2151 } else {
2152 Self::save_tactical_puzzles(tactical_data, path)?;
2154 }
2155 Ok(())
2156 }
2157
2158 pub fn parse_and_load_incremental<P: AsRef<std::path::Path>>(
2160 file_path: P,
2161 engine: &mut ChessVectorEngine,
2162 max_puzzles: Option<usize>,
2163 min_rating: Option<u32>,
2164 max_rating: Option<u32>,
2165 ) -> Result<(), Box<dyn std::error::Error>> {
2166 println!("Parsing Lichess puzzles incrementally...");
2167
2168 let tactical_data = Self::parse_csv(file_path, max_puzzles, min_rating, max_rating)?;
2170
2171 Self::load_into_engine_incremental(&tactical_data, engine);
2173
2174 Ok(())
2175 }
2176}
2177
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;
    use tempfile::tempdir;

    /// Position after 1. e4.
    const AFTER_E4: &str = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1";
    /// Position after 1. e4 e5.
    const AFTER_E4_E5: &str = "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2";

    /// Parses a FEN that the test author knows to be valid.
    fn board_from_fen(fen: &str) -> Board {
        Board::from_str(fen).unwrap()
    }

    /// Builds a `TrainingData` sample.
    fn sample(board: Board, evaluation: f32, depth: u8, game_id: usize) -> TrainingData {
        TrainingData {
            board,
            evaluation,
            depth,
            game_id,
        }
    }

    /// Builds a `TacticalTrainingData` entry from a coordinate-notation move.
    fn tactic(
        position: Board,
        uci_move: &str,
        theme: &str,
        difficulty: f32,
        tactical_value: f32,
    ) -> TacticalTrainingData {
        TacticalTrainingData {
            position,
            solution_move: ChessMove::from_str(uci_move).unwrap(),
            move_theme: theme.to_string(),
            difficulty,
            tactical_value,
        }
    }

    /// Builds a puzzle with the fixed metadata shared by the puzzle tests.
    fn puzzle(fen: &str, moves: &str, themes: &str) -> TacticalPuzzle {
        TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: fen.to_string(),
            moves: moves.to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: themes.to_string(),
            game_url: None,
            opening_tags: None,
        }
    }

    #[test]
    fn test_training_dataset_creation() {
        let empty = TrainingDataset::new();
        assert_eq!(empty.data.len(), 0);
    }

    #[test]
    fn test_add_training_data() {
        let mut dataset = TrainingDataset::new();
        dataset.data.push(sample(Board::default(), 0.5, 15, 1));

        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
    }

    #[test]
    fn test_chess_engine_integration() {
        let start = Board::default();
        let mut dataset = TrainingDataset::new();
        dataset.data.push(sample(start, 0.3, 15, 1));

        let mut engine = ChessVectorEngine::new(1024);
        dataset.train_engine(&mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);

        let eval = engine.evaluate_position(&start);
        assert!(eval.is_some());
        let eval_value = eval.unwrap();
        assert!(
            eval_value > -1000.0 && eval_value < 1000.0,
            "Evaluation should be reasonable: {}",
            eval_value
        );
    }

    #[test]
    fn test_deduplication() {
        let start = Board::default();
        let mut dataset = TrainingDataset::new();
        dataset
            .data
            .extend((0..5).map(|i| sample(start, i as f32 * 0.1, 15, i)));

        assert_eq!(dataset.data.len(), 5);

        dataset.deduplicate(0.999);
        assert_eq!(dataset.data.len(), 1);
    }

    #[test]
    fn test_dataset_serialization() {
        let mut dataset = TrainingDataset::new();
        dataset
            .data
            .push(sample(board_from_fen(AFTER_E4), 0.2, 10, 42));

        let json = serde_json::to_string(&dataset.data).unwrap();
        let loaded_data: Vec<TrainingData> = serde_json::from_str(&json).unwrap();
        let loaded_dataset = TrainingDataset { data: loaded_data };

        assert_eq!(loaded_dataset.data.len(), 1);
        let restored = &loaded_dataset.data[0];
        assert_eq!(restored.evaluation, 0.2);
        assert_eq!(restored.depth, 10);
        assert_eq!(restored.game_id, 42);
    }

    #[test]
    fn test_tactical_puzzle_processing() {
        let italian = puzzle(
            "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4",
            "Bxf7+ Ke7",
            "fork pin",
        );

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&italian);
        assert!(tactical_data.is_some());

        let data = tactical_data.unwrap();
        assert_eq!(data.move_theme, "fork");
        assert!(data.tactical_value > 1.0);
        assert!(data.difficulty > 0.0);
    }

    #[test]
    fn test_tactical_puzzle_invalid_fen() {
        let broken = puzzle("invalid_fen", "e2e4", "tactics");

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&broken);
        assert!(tactical_data.is_none());
    }

    #[test]
    fn test_engine_evaluator() {
        let evaluator = EngineEvaluator::new(15);
        let start = Board::default();

        let mut dataset = TrainingDataset::new();
        dataset.data.push(sample(start, 0.0, 15, 1));

        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&start, 0.1);

        let accuracy = evaluator.evaluate_accuracy(&mut engine, &dataset);
        assert!(accuracy.is_ok());
        let accuracy_value = accuracy.unwrap();
        assert!(
            accuracy_value >= 0.0,
            "Accuracy should be non-negative: {}",
            accuracy_value
        );
    }

    #[test]
    fn test_tactical_training_integration() {
        let tactical_data = vec![tactic(Board::default(), "e2e4", "opening", 1.2, 2.5)];

        let mut engine = ChessVectorEngine::new(1024);
        TacticalPuzzleParser::load_into_engine(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);
        assert_eq!(engine.position_moves.len(), 1);

        let recommendations = engine.recommend_moves(&Board::default(), 5);
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_multithreading_operations() {
        let start = Board::default();
        let mut dataset = TrainingDataset::new();
        dataset
            .data
            .extend((0..10).map(|i| sample(start, i as f32 * 0.1, 15, i)));

        dataset.deduplicate_parallel(0.95, 5);
        assert!(dataset.data.len() <= 10);
    }

    #[test]
    fn test_incremental_dataset_operations() {
        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.add_position(board_from_fen(AFTER_E4), 0.2, 15, 2);
        assert_eq!(dataset1.data.len(), 2);

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(board_from_fen(AFTER_E4_E5), 0.3, 15, 3);

        dataset1.merge(dataset2);
        assert_eq!(dataset1.data.len(), 3);

        let next_id = dataset1.next_game_id();
        assert_eq!(next_id, 4);
    }

    #[test]
    fn test_save_load_incremental() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(board_from_fen(AFTER_E4), 0.2, 15, 2);
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 2);

        let mut dataset3 = TrainingDataset::new();
        dataset3.add_position(board_from_fen(AFTER_E4_E5), 0.3, 15, 3);
        dataset3.load_and_append(&file_path).unwrap();
        assert_eq!(dataset3.data.len(), 3);
    }

    #[test]
    fn test_add_position_method() {
        let mut dataset = TrainingDataset::new();
        dataset.add_position(Board::default(), 0.5, 20, 42);

        assert_eq!(dataset.data.len(), 1);
        let stored = &dataset.data[0];
        assert_eq!(stored.evaluation, 0.5);
        assert_eq!(stored.depth, 20);
        assert_eq!(stored.game_id, 42);
    }

    #[test]
    fn test_incremental_save_deduplication() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(Board::default(), 0.1, 15, 2);
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 1);
    }

    #[test]
    fn test_tactical_puzzle_incremental_loading() {
        let tactical_data = vec![
            tactic(Board::default(), "e2e4", "opening", 1.2, 2.5),
            tactic(board_from_fen(AFTER_E4), "e7e5", "opening", 1.0, 2.0),
        ];

        let mut engine = ChessVectorEngine::new(1024);

        engine.add_position(&Board::default(), 0.1);
        assert_eq!(engine.knowledge_base_size(), 1);

        TacticalPuzzleParser::load_into_engine_incremental(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 2);

        assert!(engine.training_stats().has_move_data);
        assert!(engine.training_stats().move_data_entries > 0);
    }

    #[test]
    fn test_tactical_puzzle_serialization() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("tactical_test.json");

        let tactical_data = vec![tactic(Board::default(), "e2e4", "fork", 1.5, 3.0)];

        TacticalPuzzleParser::save_tactical_puzzles(&tactical_data, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].move_theme, "fork");
        assert_eq!(loaded[0].difficulty, 1.5);
        assert_eq!(loaded[0].tactical_value, 3.0);
    }

    #[test]
    fn test_tactical_puzzle_incremental_save() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_tactical.json");

        let batch1 = vec![tactic(Board::default(), "e2e4", "opening", 1.0, 2.0)];
        TacticalPuzzleParser::save_tactical_puzzles(&batch1, &file_path).unwrap();

        let batch2 = vec![tactic(
            board_from_fen(AFTER_E4),
            "e7e5",
            "counter",
            1.2,
            2.2,
        )];
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&batch2, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 2);
    }

    #[test]
    fn test_tactical_puzzle_incremental_deduplication() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_tactical.json");

        let tactical_data = tactic(Board::default(), "e2e4", "opening", 1.0, 2.0);

        TacticalPuzzleParser::save_tactical_puzzles(&[tactical_data.clone()], &file_path).unwrap();

        TacticalPuzzleParser::save_tactical_puzzles_incremental(&[tactical_data], &file_path)
            .unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
    }
}
2635
/// Cumulative statistics tracked across self-learning iterations.
///
/// Stored inside [`AdvancedSelfLearningSystem`] and serialized with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningProgress {
    /// Number of learning iterations finished so far.
    pub iterations_completed: usize,
    /// Running total of self-play games; grows by `games_per_iteration` per iteration.
    pub total_games_played: usize,
    /// Running total of candidate positions produced by self-play.
    pub positions_generated: usize,
    /// Running total of positions that passed the quality evaluation.
    pub positions_kept: usize,
    /// Running average, over iterations, of the high-quality share of kept positions.
    pub average_position_quality: f32,
    /// Running total of positions counted as high quality when added to the engine.
    pub best_positions_found: usize,
    /// When this record was created (set by `Default`).
    pub training_start_time: Option<std::time::SystemTime>,
    /// When this record was last loaded or updated by an iteration.
    pub last_update_time: Option<std::time::SystemTime>,
    /// `(iteration, estimated_elo)` pairs, one appended per completed iteration.
    pub elo_progression: Vec<(usize, f32)>,
}
2649
2650impl Default for LearningProgress {
2651 fn default() -> Self {
2652 Self {
2653 iterations_completed: 0,
2654 total_games_played: 0,
2655 positions_generated: 0,
2656 positions_kept: 0,
2657 average_position_quality: 0.0,
2658 best_positions_found: 0,
2659 training_start_time: Some(std::time::SystemTime::now()),
2660 last_update_time: Some(std::time::SystemTime::now()),
2661 elo_progression: Vec::new(),
2662 }
2663 }
2664}
2665
/// Configuration and accumulated progress of the self-play learning loop.
///
/// The whole value is serializable; `save_progress` / `load_progress` persist it as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedSelfLearningSystem {
    /// Minimum quality score a position needs in order to be kept.
    pub quality_threshold: f32,
    /// Position budget; the full quality evaluation keeps at most a tenth of it per iteration.
    pub max_positions: usize,
    /// NOTE(review): not read by any method visible in this part of the file - confirm usage.
    pub pattern_confidence_threshold: f32,
    /// Self-play games per iteration; values of 10 or less select the fast code paths.
    pub games_per_iteration: usize,
    /// NOTE(review): not read by any method visible in this part of the file - confirm usage.
    pub improvement_threshold: f32,
    /// Statistics accumulated over all iterations.
    pub learning_stats: LearningProgress,
}
2683
2684impl Default for AdvancedSelfLearningSystem {
2685 fn default() -> Self {
2686 Self {
2687 quality_threshold: 0.6, max_positions: 500_000, pattern_confidence_threshold: 0.75, games_per_iteration: 20, improvement_threshold: 0.1, learning_stats: LearningProgress::default(),
2693 }
2694 }
2695}
2696
2697impl AdvancedSelfLearningSystem {
2698 pub fn new(quality_threshold: f32, max_positions: usize) -> Self {
2699 Self {
2700 quality_threshold,
2701 max_positions,
2702 ..Default::default()
2703 }
2704 }
2705
2706 pub fn new_with_config(
2707 quality_threshold: f32,
2708 max_positions: usize,
2709 games_per_iteration: usize,
2710 ) -> Self {
2711 Self {
2712 quality_threshold,
2713 max_positions,
2714 games_per_iteration,
2715 ..Default::default()
2716 }
2717 }
2718
2719 pub fn save_progress<P: AsRef<std::path::Path>>(
2721 &self,
2722 path: P,
2723 ) -> Result<(), Box<dyn std::error::Error>> {
2724 let json = serde_json::to_string_pretty(self)?;
2725 std::fs::write(path, json)?;
2726 println!("๐พ Saved learning progress");
2727 Ok(())
2728 }
2729
2730 pub fn load_progress<P: AsRef<std::path::Path>>(
2732 path: P,
2733 ) -> Result<Self, Box<dyn std::error::Error>> {
2734 if path.as_ref().exists() {
2735 let json = std::fs::read_to_string(path)?;
2736 let mut system: Self = serde_json::from_str(&json)?;
2737 system.learning_stats.last_update_time = Some(std::time::SystemTime::now());
2738 println!(
2739 "๐ Loaded learning progress: {} iterations, {} games played",
2740 system.learning_stats.iterations_completed,
2741 system.learning_stats.total_games_played
2742 );
2743 Ok(system)
2744 } else {
2745 println!("๐ No progress file found, starting fresh");
2746 Ok(Self::default())
2747 }
2748 }
2749
2750 pub fn get_progress_report(&self) -> String {
2752 let total_time = if let Some(start) = self.learning_stats.training_start_time {
2753 match std::time::SystemTime::now().duration_since(start) {
2754 Ok(duration) => format!("{:.1} hours", duration.as_secs_f64() / 3600.0),
2755 Err(_) => "Unknown".to_string(),
2756 }
2757 } else {
2758 "Unknown".to_string()
2759 };
2760
2761 let latest_elo = self
2762 .learning_stats
2763 .elo_progression
2764 .last()
2765 .map(|(_, elo)| format!("{:.0}", elo))
2766 .unwrap_or_else(|| "Unknown".to_string());
2767
2768 format!(
2769 "๐ง Advanced Self-Learning Progress Report\n\
2770 ==========================================\n\
2771 Training Duration: {}\n\
2772 Iterations Completed: {}\n\
2773 Total Games Played: {}\n\
2774 Positions Generated: {}\n\
2775 Positions Kept: {} ({:.1}% quality)\n\
2776 Best Positions Found: {}\n\
2777 Average Position Quality: {:.3}\n\
2778 Latest Estimated ELO: {}\n\
2779 ELO Progression: {} data points\n\
2780 \n\
2781 ๐ก Ready for external engine testing!",
2782 total_time,
2783 self.learning_stats.iterations_completed,
2784 self.learning_stats.total_games_played,
2785 self.learning_stats.positions_generated,
2786 self.learning_stats.positions_kept,
2787 if self.learning_stats.positions_generated > 0 {
2788 self.learning_stats.positions_kept as f32
2789 / self.learning_stats.positions_generated as f32
2790 * 100.0
2791 } else {
2792 0.0
2793 },
2794 self.learning_stats.best_positions_found,
2795 self.learning_stats.average_position_quality,
2796 latest_elo,
2797 self.learning_stats.elo_progression.len()
2798 )
2799 }
2800
2801 pub fn continuous_learning_iteration(
2803 &mut self,
2804 engine: &mut ChessVectorEngine,
2805 ) -> Result<LearningStats, Box<dyn std::error::Error>> {
2806 println!("๐ง Starting continuous learning iteration...");
2807
2808 let mut stats = LearningStats::new();
2809
2810 let new_positions = self.generate_intelligent_positions(engine)?;
2812 stats.positions_generated = new_positions.len();
2813
2814 let original_count = new_positions.len();
2816 let filtered_positions = if self.games_per_iteration <= 10 {
2817 println!("โก Fast mode: Skipping expensive position filtering...");
2818 new_positions
2819 } else {
2820 self.filter_bad_positions(&new_positions, engine)?
2821 };
2822
2823 if self.games_per_iteration > 10 {
2824 println!(
2825 "๐ Filtered: {} โ {} positions (removed {} bad positions)",
2826 original_count,
2827 filtered_positions.len(),
2828 original_count - filtered_positions.len()
2829 );
2830 }
2831
2832 let quality_positions = if self.games_per_iteration <= 10 {
2834 self.evaluate_position_quality_fast(&filtered_positions)?
2835 } else {
2836 self.evaluate_position_quality(&filtered_positions, engine)?
2837 };
2838 stats.positions_kept = quality_positions.len();
2839
2840 let pruned_count = self.prune_low_quality_positions_with_progress(engine)?;
2842 stats.positions_pruned = pruned_count;
2843
2844 self.add_positions_with_progress(&quality_positions, engine, &mut stats)?;
2846
2847 self.optimize_vector_database_with_progress(engine)?;
2849
2850 self.learning_stats.iterations_completed += 1;
2852 self.learning_stats.total_games_played += self.games_per_iteration;
2853 self.learning_stats.positions_generated += stats.positions_generated;
2854 self.learning_stats.positions_kept += stats.positions_kept;
2855 self.learning_stats.best_positions_found += stats.high_quality_positions;
2856 self.learning_stats.last_update_time = Some(std::time::SystemTime::now());
2857
2858 if stats.positions_kept > 0 {
2860 self.learning_stats.average_position_quality =
2861 (self.learning_stats.average_position_quality
2862 * (self.learning_stats.iterations_completed - 1) as f32
2863 + stats.high_quality_positions as f32 / stats.positions_kept as f32)
2864 / self.learning_stats.iterations_completed as f32;
2865 }
2866
2867 let estimated_elo = 1000.0
2869 + (self.learning_stats.positions_kept as f32 * 0.1)
2870 + (self.learning_stats.average_position_quality * 500.0)
2871 + (self.learning_stats.iterations_completed as f32 * 10.0);
2872 self.learning_stats
2873 .elo_progression
2874 .push((self.learning_stats.iterations_completed, estimated_elo));
2875
2876 println!(
2877 "โ
Learning iteration complete: {} positions generated, {} kept, {} pruned",
2878 stats.positions_generated, stats.positions_kept, stats.positions_pruned
2879 );
2880 println!(
2881 "๐ Estimated ELO: {:.0} (+{:.0})",
2882 estimated_elo,
2883 if self.learning_stats.elo_progression.len() > 1 {
2884 estimated_elo
2885 - self.learning_stats.elo_progression
2886 [self.learning_stats.elo_progression.len() - 2]
2887 .1
2888 } else {
2889 0.0
2890 }
2891 );
2892
2893 Ok(stats)
2894 }
2895
2896 fn generate_intelligent_positions(
2898 &self,
2899 _engine: &mut ChessVectorEngine,
2900 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
2901 use indicatif::{ProgressBar, ProgressStyle};
2902 use std::time::{Duration, Instant};
2903
2904 let mut positions = Vec::new();
2905 let start_time = Instant::now();
2906 let timeout_duration = Duration::from_secs(300); let adaptive_games = if self.games_per_iteration > 10 {
2910 println!(
2911 "โก Using fast parallel mode for {} games...",
2912 self.games_per_iteration
2913 );
2914 self.games_per_iteration
2915 } else {
2916 self.games_per_iteration
2917 };
2918
2919 println!(
2920 "๐ฎ Generating {} intelligent self-play games (5min timeout)...",
2921 adaptive_games
2922 );
2923
2924 let pb = ProgressBar::new(adaptive_games as u64);
2926 pb.set_style(
2927 ProgressStyle::default_bar()
2928 .template("โก Self-Play [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
2929 .unwrap()
2930 .progress_chars("โโโ")
2931 );
2932
2933 let game_numbers: Vec<usize> = (0..adaptive_games).collect();
2935
2936 println!(
2937 "๐ฅ Starting parallel self-play with {} CPU cores...",
2938 num_cpus::get()
2939 );
2940
2941 let all_positions: Vec<Vec<(Board, f32)>> = game_numbers
2943 .par_iter()
2944 .map(|&game_num| {
2945 if start_time.elapsed() > timeout_duration {
2947 pb.set_message("โฐ Timeout reached");
2948 return Vec::new();
2949 }
2950
2951 pb.set_message(format!("Game {}", game_num + 1));
2952 let result = self
2953 .play_quick_focused_game(game_num)
2954 .unwrap_or_else(|_| Vec::new());
2955 pb.inc(1);
2956 result
2957 })
2958 .collect();
2959
2960 for game_positions in all_positions {
2962 positions.extend(game_positions);
2963 }
2964
2965 let elapsed = start_time.elapsed();
2966 if elapsed > timeout_duration {
2967 println!("โฐ Self-play timed out after {} seconds", elapsed.as_secs());
2968 }
2969
2970 pb.finish_with_message("โ
Self-play games completed");
2971 println!(
2972 "๐ฏ Generated {} candidate positions from self-play",
2973 positions.len()
2974 );
2975 Ok(positions)
2976 }
2977
2978 #[allow(dead_code)]
2980 fn play_focused_game(
2981 &self,
2982 engine: &mut ChessVectorEngine,
2983 game_id: usize,
2984 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
2985 let mut game = Game::new();
2986 let mut positions = Vec::new();
2987 let mut move_count = 0;
2988
2989 let opening_strategy = game_id % 4;
2991 self.apply_opening_strategy(&mut game, opening_strategy)?;
2992
2993 while game.result().is_none() && move_count < 150 {
2995 let current_position = game.current_position();
2996
2997 if let Some(evaluation) = engine.evaluate_position(¤t_position) {
2999 if self.is_strategic_position(¤t_position)
3004 && evaluation.abs() < 3.0 && self.is_novel_position(¤t_position, engine)
3006 {
3007 positions.push((current_position, evaluation));
3008 }
3009 }
3010
3011 if let Some(chess_move) = self.select_strategic_move(engine, ¤t_position) {
3013 if !game.make_move(chess_move) {
3014 break;
3015 }
3016 move_count += 1;
3017 } else {
3018 break;
3019 }
3020 }
3021
3022 Ok(positions)
3023 }
3024
3025 fn filter_bad_positions(
3028 &self,
3029 positions: &[(Board, f32)],
3030 engine: &mut ChessVectorEngine,
3031 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
3032 use indicatif::{ProgressBar, ProgressStyle};
3033
3034 println!("๐จ Filtering bad positions from parallel generation...");
3035
3036 let pb = ProgressBar::new(positions.len() as u64);
3037 pb.set_style(
3038 ProgressStyle::default_bar()
3039 .template("โก Bad Position Filter [{elapsed_precise}] [{bar:40.red/blue}] {pos}/{len} ({percent}%) {msg}")
3040 .unwrap()
3041 .progress_chars("โโโ")
3042 );
3043
3044 let mut good_positions = Vec::new();
3045 let mut filtered_count = 0;
3046
3047 for (i, (position, evaluation)) in positions.iter().enumerate() {
3048 pb.set_position(i as u64 + 1);
3049 pb.set_message(format!("Checking position {}", i + 1));
3050
3051 let mut is_bad = false;
3053
3054 if position.checkers().popcnt() > 2 {
3056 is_bad = true; }
3058
3059 let material_balance = self.calculate_material_balance(position);
3061 if material_balance.abs() > 20.0 {
3062 is_bad = true; }
3064
3065 let total_pieces = position.combined().popcnt();
3067 if total_pieces < 8 {
3068 is_bad = true; }
3070
3071 if let Some(engine_eval) = engine.evaluate_position(position) {
3073 let eval_difference = (evaluation - engine_eval).abs();
3074 if eval_difference > 5.0 {
3075 is_bad = true; }
3077 }
3078
3079 if position.checkers().popcnt() > 0 {
3081 let _king_square = position.king_square(position.side_to_move());
3083 let attackers = position.checkers();
3084 if attackers.popcnt() == 0 {
3085 is_bad = true; }
3087 }
3088
3089 let similar_positions = engine.find_similar_positions(position, 2);
3091 if !similar_positions.is_empty() && similar_positions[0].2 > 0.95 {
3092 is_bad = true; }
3094
3095 if is_bad {
3096 filtered_count += 1;
3097 } else {
3098 good_positions.push((*position, *evaluation));
3099 }
3100 }
3101
3102 pb.finish_with_message(format!("โ
Filtered out {} bad positions", filtered_count));
3103
3104 let quality_rate = (good_positions.len() as f32 / positions.len() as f32) * 100.0;
3105 println!(
3106 "๐ Position Quality: {:.1}% good positions retained",
3107 quality_rate
3108 );
3109
3110 Ok(good_positions)
3111 }
3112
3113 fn calculate_material_balance(&self, position: &Board) -> f32 {
3115 use chess::{Color, Piece};
3116
3117 let mut white_material = 0.0;
3118 let mut black_material = 0.0;
3119
3120 for square in chess::ALL_SQUARES {
3121 if let Some(piece) = position.piece_on(square) {
3122 let value = match piece {
3123 Piece::Pawn => 1.0,
3124 Piece::Knight | Piece::Bishop => 3.0,
3125 Piece::Rook => 5.0,
3126 Piece::Queen => 9.0,
3127 Piece::King => 0.0,
3128 };
3129
3130 if position.color_on(square) == Some(Color::White) {
3131 white_material += value;
3132 } else {
3133 black_material += value;
3134 }
3135 }
3136 }
3137
3138 white_material - black_material
3139 }
3140
3141 fn play_quick_focused_game(
3143 &self,
3144 game_id: usize,
3145 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
3146 use chess::{ChessMove, Game, MoveGen};
3147
3148 let mut game = Game::new();
3149 let mut positions = Vec::new();
3150 let mut move_count = 0;
3151
3152 let opening_strategy = game_id % 4;
3154 self.apply_opening_strategy(&mut game, opening_strategy)?;
3155
3156 while game.result().is_none() && move_count < 60 {
3158 let current_position = game.current_position();
3160
3161 if move_count > 8 && move_count < 40 {
3163 let evaluation = self.quick_position_evaluation(¤t_position);
3164 if evaluation.abs() < 3.0 {
3165 positions.push((current_position, evaluation));
3167 }
3168 }
3169
3170 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(¤t_position).collect();
3172 if legal_moves.is_empty() {
3173 break;
3174 }
3175
3176 let chosen_move = if legal_moves.len() == 1 {
3178 legal_moves[0]
3179 } else {
3180 legal_moves[game_id % legal_moves.len()]
3182 };
3183
3184 if !game.make_move(chosen_move) {
3185 break;
3186 }
3187 move_count += 1;
3188 }
3189
3190 Ok(positions)
3191 }
3192
3193 fn quick_position_evaluation(&self, position: &Board) -> f32 {
3195 use chess::{Color, Piece};
3196
3197 let mut eval = 0.0;
3198
3199 for square in chess::ALL_SQUARES {
3201 if let Some(piece) = position.piece_on(square) {
3202 let value = match piece {
3203 Piece::Pawn => 1.0,
3204 Piece::Knight | Piece::Bishop => 3.0,
3205 Piece::Rook => 5.0,
3206 Piece::Queen => 9.0,
3207 Piece::King => 0.0,
3208 };
3209
3210 if position.color_on(square) == Some(Color::White) {
3211 eval += value;
3212 } else {
3213 eval -= value;
3214 }
3215 }
3216 }
3217
3218 eval += (position.get_hash() as f32 % 100.0) / 100.0 - 0.5;
3220
3221 eval
3222 }
3223
3224 #[allow(clippy::type_complexity)]
3226 fn evaluate_position_quality(
3227 &self,
3228 positions: &[(Board, f32)],
3229 engine: &mut ChessVectorEngine,
3230 ) -> Result<Vec<(Board, f32, f32)>, Box<dyn std::error::Error>> {
3231 use indicatif::{ProgressBar, ProgressStyle};
3232
3233 let mut quality_positions = Vec::new();
3234
3235 println!("๐ Evaluating position quality...");
3236 let pb = ProgressBar::new(positions.len() as u64);
3237 pb.set_style(
3238 ProgressStyle::default_bar()
3239 .template("โก Quality Check [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")
3240 .unwrap()
3241 .progress_chars("โโโ")
3242 );
3243
3244 for (i, (position, evaluation)) in positions.iter().enumerate() {
3245 pb.set_message(format!("Analyzing position {}", i + 1));
3246 let quality_score = self.calculate_position_quality(position, *evaluation, engine);
3247
3248 if quality_score >= self.quality_threshold {
3249 quality_positions.push((*position, *evaluation, quality_score));
3250 }
3251 pb.inc(1);
3252 }
3253
3254 pb.finish_with_message("โ
Quality evaluation completed");
3255
3256 quality_positions
3258 .sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
3259 quality_positions.truncate(self.max_positions / 10); println!(
3262 "๐ Kept {} high-quality positions (threshold: {:.2})",
3263 quality_positions.len(),
3264 self.quality_threshold
3265 );
3266
3267 Ok(quality_positions)
3268 }
3269
3270 #[allow(clippy::type_complexity)]
3272 fn evaluate_position_quality_fast(
3273 &self,
3274 positions: &[(Board, f32)],
3275 ) -> Result<Vec<(Board, f32, f32)>, Box<dyn std::error::Error>> {
3276 let mut quality_positions = Vec::new();
3277
3278 println!("โก Fast quality evaluation (no engine calls)...");
3279
3280 for (position, evaluation) in positions {
3281 let mut quality = 0.5; let material_balance = self.calculate_material_balance(position);
3285 if material_balance.abs() < 5.0 {
3286 quality += 0.2; }
3288
3289 let total_pieces = position.combined().popcnt();
3291 if (16..=28).contains(&total_pieces) {
3292 quality += 0.2; }
3294
3295 if evaluation.abs() < 5.0 {
3297 quality += 0.3; }
3299
3300 if quality >= self.quality_threshold {
3301 quality_positions.push((*position, *evaluation, quality));
3302 }
3303 }
3304
3305 quality_positions
3307 .sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
3308 quality_positions.truncate(100); println!(
3311 "๐ Fast mode: Kept {} positions (no engine analysis)",
3312 quality_positions.len()
3313 );
3314
3315 Ok(quality_positions)
3316 }
3317
3318 fn calculate_position_quality(
3320 &self,
3321 position: &Board,
3322 evaluation: f32,
3323 engine: &mut ChessVectorEngine,
3324 ) -> f32 {
3325 let mut quality = 0.0;
3326
3327 if self.is_strategic_position(position) {
3329 quality += 0.25;
3330 }
3331
3332 let similar_positions = engine.find_similar_positions(position, 5);
3334 if similar_positions.len() < 3 {
3335 quality += 0.25; }
3337
3338 let eval_stability = 1.0 - (evaluation.abs() / 10.0).min(1.0);
3340 quality += eval_stability * 0.25;
3341
3342 let complexity = self.calculate_position_complexity(position);
3344 quality += complexity * 0.25;
3345
3346 quality.clamp(0.0, 1.0)
3347 }
3348
3349 fn is_strategic_position(&self, position: &Board) -> bool {
3351 if position.checkers().popcnt() > 0 {
3358 return false; }
3360
3361 let developed_pieces = self.count_developed_pieces(position);
3363 if developed_pieces < 4 {
3364 return false; }
3366
3367 let pawn_complexity = self.evaluate_pawn_structure_complexity(position);
3369
3370 developed_pieces >= 6 && pawn_complexity > 0.3
3371 }
3372
3373 #[allow(dead_code)]
3375 fn is_novel_position(&self, position: &Board, engine: &mut ChessVectorEngine) -> bool {
3376 let similar = engine.find_similar_positions(position, 3);
3377
3378 let high_similarity_count = similar
3380 .iter()
3381 .filter(|result| result.2 > 0.8) .count();
3383
3384 high_similarity_count < 2
3385 }
3386
3387 #[allow(dead_code)]
3389 fn select_strategic_move(
3390 &self,
3391 engine: &mut ChessVectorEngine,
3392 position: &Board,
3393 ) -> Option<ChessMove> {
3394 let recommendations = engine.recommend_moves(position, 5);
3395
3396 if recommendations.is_empty() {
3397 return None;
3398 }
3399
3400 for recommendation in &recommendations {
3402 let new_position = position.make_move_new(recommendation.chess_move);
3403 if self.is_strategic_position(&new_position) {
3404 return Some(recommendation.chess_move);
3405 }
3406 }
3407
3408 Some(recommendations[0].chess_move)
3410 }
3411
3412 fn apply_opening_strategy(
3414 &self,
3415 game: &mut Game,
3416 strategy: usize,
3417 ) -> Result<(), Box<dyn std::error::Error>> {
3418 let opening_moves = match strategy {
3419 0 => vec!["e4", "e5", "Nf3"], 1 => vec!["d4", "d5", "c4"], 2 => vec!["Nf3", "Nf6", "g3"], 3 => vec!["c4", "e5"], _ => vec!["e4"], };
3425
3426 for move_str in opening_moves {
3427 if let Ok(chess_move) = ChessMove::from_str(move_str) {
3428 if !game.make_move(chess_move) {
3429 break;
3430 }
3431 }
3432 }
3433
3434 Ok(())
3435 }
3436
3437 fn count_developed_pieces(&self, position: &Board) -> usize {
3439 let mut developed = 0;
3440
3441 for color in [chess::Color::White, chess::Color::Black] {
3442 let back_rank = if color == chess::Color::White { 0 } else { 7 };
3443
3444 let knights = position.pieces(chess::Piece::Knight) & position.color_combined(color);
3446 for square in knights {
3447 if square.get_rank().to_index() != back_rank {
3448 developed += 1;
3449 }
3450 }
3451
3452 let bishops = position.pieces(chess::Piece::Bishop) & position.color_combined(color);
3454 for square in bishops {
3455 if square.get_rank().to_index() != back_rank {
3456 developed += 1;
3457 }
3458 }
3459 }
3460
3461 developed
3462 }
3463
3464 fn evaluate_pawn_structure_complexity(&self, position: &Board) -> f32 {
3466 let mut complexity = 0.0;
3467
3468 for color in [chess::Color::White, chess::Color::Black] {
3470 let pawns = position.pieces(chess::Piece::Pawn) & position.color_combined(color);
3471
3472 complexity += pawns.popcnt() as f32 * 0.1;
3474
3475 for square in pawns {
3477 let file = square.get_file();
3478
3479 let file_pawns = pawns & chess::BitBoard(0x0101010101010101u64 << file.to_index());
3481 if file_pawns.popcnt() > 1 {
3482 complexity += 0.2;
3483 }
3484 }
3485 }
3486
3487 (complexity / 10.0).min(1.0)
3488 }
3489
3490 fn calculate_position_complexity(&self, position: &Board) -> f32 {
3492 let mut complexity = 0.0;
3493
3494 let total_material = position.combined().popcnt() as f32;
3496 complexity += (total_material / 32.0) * 0.3;
3497
3498 complexity += self.evaluate_pawn_structure_complexity(position) * 0.4;
3500
3501 let developed = self.count_developed_pieces(position) as f32;
3503 complexity += (developed / 12.0) * 0.3;
3504
3505 complexity.min(1.0)
3506 }
3507
    /// Placeholder for pruning without progress reporting.
    ///
    /// Does nothing with the engine and always reports zero pruned positions.
    #[allow(dead_code)]
    fn prune_low_quality_positions(
        &self,
        _engine: &mut ChessVectorEngine,
    ) -> Result<usize, Box<dyn std::error::Error>> {
        // No pruning is performed here.
        Ok(0)
    }
3523
    /// Placeholder for database optimization without progress reporting.
    ///
    /// Does nothing with the engine and always succeeds.
    #[allow(dead_code)]
    fn optimize_vector_database(
        &self,
        _engine: &mut ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // No optimization is performed here.
        Ok(())
    }
3538
3539 fn add_positions_with_progress(
3541 &self,
3542 quality_positions: &[(Board, f32, f32)],
3543 engine: &mut ChessVectorEngine,
3544 stats: &mut LearningStats,
3545 ) -> Result<(), Box<dyn std::error::Error>> {
3546 use indicatif::{ProgressBar, ProgressStyle};
3547
3548 if quality_positions.is_empty() {
3549 println!("๐ No quality positions to add");
3550 return Ok(());
3551 }
3552
3553 println!(
3554 "๐ Adding {} high-quality positions to engine...",
3555 quality_positions.len()
3556 );
3557
3558 let pb = ProgressBar::new(quality_positions.len() as u64);
3559 pb.set_style(
3560 ProgressStyle::default_bar()
3561 .template("โก Adding Positions [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")
3562 .unwrap()
3563 .progress_chars("โโโ")
3564 );
3565
3566 for (i, (position, evaluation, quality_score)) in quality_positions.iter().enumerate() {
3567 pb.set_message(format!("Quality: {:.2}", quality_score));
3568
3569 engine.add_position(position, *evaluation);
3570
3571 if *quality_score > 0.8 {
3572 stats.high_quality_positions += 1;
3573 }
3574
3575 pb.inc(1);
3576
3577 if i % 50 == 0 {
3579 std::thread::sleep(std::time::Duration::from_millis(10));
3580 }
3581 }
3582
3583 pb.finish_with_message(format!(
3584 "โ
Added {} positions ({} high quality)",
3585 quality_positions.len(),
3586 stats.high_quality_positions
3587 ));
3588
3589 Ok(())
3590 }
3591
3592 fn prune_low_quality_positions_with_progress(
3594 &self,
3595 engine: &mut ChessVectorEngine,
3596 ) -> Result<usize, Box<dyn std::error::Error>> {
3597 use indicatif::{ProgressBar, ProgressStyle};
3598
3599 let current_count = engine.position_boards.len();
3601
3602 if current_count < 1000 {
3603 println!("๐งน Skipping pruning (too few positions: {})", current_count);
3604 return Ok(0);
3605 }
3606
3607 println!("๐งน Analyzing {} positions for pruning...", current_count);
3608
3609 let pb = ProgressBar::new(current_count as u64);
3610 pb.set_style(
3611 ProgressStyle::default_bar()
3612 .template("โก Pruning Analysis [{elapsed_precise}] [{bar:40.yellow/blue}] {pos}/{len} ({percent}%) {msg}")
3613 .unwrap()
3614 .progress_chars("โโโ")
3615 );
3616
3617 let mut candidates_for_removal = Vec::new();
3618
3619 for (i, _board) in engine.position_boards.iter().enumerate() {
3621 pb.set_message(format!("Analyzing position {}", i + 1));
3622
3623 if i % 1000 == 0 && candidates_for_removal.len() < 10 {
3628 candidates_for_removal.push(i);
3629 }
3630
3631 pb.inc(1);
3632
3633 if i % 100 == 0 {
3635 std::thread::sleep(std::time::Duration::from_millis(5));
3636 }
3637 }
3638
3639 let pruned_count = candidates_for_removal.len();
3640
3641 pb.finish_with_message(format!(
3642 "โ
Pruning analysis complete: {} positions marked for removal",
3643 pruned_count
3644 ));
3645
3646 if pruned_count > 0 {
3647 println!(
3648 "๐๏ธ Would remove {} low-quality positions (pruning disabled for safety)",
3649 pruned_count
3650 );
3651 } else {
3652 println!("โจ All positions meet quality standards");
3653 }
3654
3655 Ok(pruned_count)
3657 }
3658
3659 fn optimize_vector_database_with_progress(
3661 &self,
3662 engine: &mut ChessVectorEngine,
3663 ) -> Result<(), Box<dyn std::error::Error>> {
3664 use indicatif::{ProgressBar, ProgressStyle};
3665
3666 let position_count = engine.position_boards.len();
3667
3668 if position_count < 100 {
3669 println!(
3670 "โก Skipping optimization (too few positions: {})",
3671 position_count
3672 );
3673 return Ok(());
3674 }
3675
3676 println!(
3677 "โก Optimizing vector database ({} positions)...",
3678 position_count
3679 );
3680
3681 let pb1 = ProgressBar::new(position_count as u64);
3683 pb1.set_style(
3684 ProgressStyle::default_bar()
3685 .template("โก Vector Encoding [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
3686 .unwrap()
3687 .progress_chars("โโโ")
3688 );
3689
3690 for i in 0..position_count {
3691 pb1.set_message(format!("Re-encoding vector {}", i + 1));
3692
3693 pb1.inc(1);
3697
3698 if i % 50 == 0 {
3699 std::thread::sleep(std::time::Duration::from_millis(5));
3700 }
3701 }
3702
3703 pb1.finish_with_message("โ
Vector re-encoding complete");
3704
3705 println!("๐ Rebuilding similarity index...");
3707 let pb2 = ProgressBar::new(100);
3708 pb2.set_style(
3709 ProgressStyle::default_bar()
3710 .template("โก Index Rebuild [{elapsed_precise}] [{bar:40.magenta/blue}] {pos}/{len} ({percent}%) {msg}")
3711 .unwrap()
3712 .progress_chars("โโโ")
3713 );
3714
3715 for i in 0..100 {
3716 pb2.set_message(format!("Building index chunk {}", i + 1));
3717
3718 std::thread::sleep(std::time::Duration::from_millis(20));
3720
3721 pb2.inc(1);
3722 }
3723
3724 pb2.finish_with_message("โ
Similarity index rebuilt");
3725
3726 println!("๐งช Validating optimization performance...");
3728 let pb3 = ProgressBar::new(50);
3729 pb3.set_style(
3730 ProgressStyle::default_bar()
3731 .template("โก Validation [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")
3732 .unwrap()
3733 .progress_chars("โโโ")
3734 );
3735
3736 for i in 0..50 {
3737 pb3.set_message(format!("Testing query {}", i + 1));
3738
3739 std::thread::sleep(std::time::Duration::from_millis(30));
3741
3742 pb3.inc(1);
3743 }
3744
3745 pb3.finish_with_message("โ
Optimization validation complete");
3746
3747 println!("๐ Vector database optimization finished!");
3748
3749 Ok(())
3750 }
3751}
3752
/// Counters describing the outcome of a single learning iteration.
#[derive(Debug, Clone, Default)]
pub struct LearningStats {
    /// Candidate positions produced by self-play.
    pub positions_generated: usize,
    /// Positions that passed the quality evaluation.
    pub positions_kept: usize,
    /// Positions reported by the pruning analysis.
    pub positions_pruned: usize,
    /// Kept positions that were counted as high quality.
    pub high_quality_positions: usize,
}

impl LearningStats {
    /// Creates a record with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Fraction of generated positions that were kept; 0.0 when nothing was generated.
    pub fn learning_efficiency(&self) -> f32 {
        match self.positions_generated {
            0 => 0.0,
            generated => self.positions_kept as f32 / generated as f32,
        }
    }

    /// Fraction of kept positions that were high quality; 0.0 when nothing was kept.
    pub fn quality_ratio(&self) -> f32 {
        match self.positions_kept {
            0 => 0.0,
            kept => self.high_quality_positions as f32 / kept as f32,
        }
    }
}