1use chess::{Board, ChessMove, Game, MoveGen};
2use indicatif::{ProgressBar, ProgressStyle};
3use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::{BufRead, BufReader, BufWriter, Write};
8use std::path::Path;
9use std::process::{Child, Command, Stdio};
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12
13use crate::ChessVectorEngine;
14
/// Tunable parameters for `SelfPlayTrainer` self-play data generation.
#[derive(Debug, Clone)]
pub struct SelfPlayConfig {
    /// Number of self-play games played per call to `generate_training_data`.
    pub games_per_iteration: usize,
    /// Maximum number of plies (book moves included) played in one game.
    pub max_moves_per_game: usize,
    /// Probability of sampling a move by temperature instead of playing the
    /// engine's first recommendation.
    pub exploration_factor: f32,
    /// Minimum absolute evaluation for a position to be recorded; positions
    /// within the first 10 plies are recorded regardless.
    pub min_confidence: f32,
    /// Whether games start from a randomly chosen built-in opening line.
    pub use_opening_book: bool,
    /// Softmax temperature used when sampling exploratory moves.
    pub temperature: f32,
}
31
32impl Default for SelfPlayConfig {
33 fn default() -> Self {
34 Self {
35 games_per_iteration: 100,
36 max_moves_per_game: 200,
37 exploration_factor: 0.3,
38 min_confidence: 0.1,
39 use_opening_book: true,
40 temperature: 0.8,
41 }
42 }
43}
44
/// A single labelled position used to train the vector engine.
#[derive(Debug, Clone)]
pub struct TrainingData {
    /// The position itself.
    pub board: Board,
    /// Evaluation label. Stockfish labels are centipawns / 100 (±100.0 for a
    /// forced mate); self-play labels come from the engine or are ±10.0 / 0.0
    /// for finished games; freshly extracted PGN positions hold 0.0.
    pub evaluation: f32,
    /// Depth associated with `evaluation` (0 for unevaluated PGN positions,
    /// 1 for self-play labels, the search depth for Stockfish labels).
    pub depth: u8,
    /// Identifier of the source game; `TrainingDataset::split` groups by it.
    pub game_id: usize,
}
53
/// One tactical puzzle record. Field names map to the source column headers
/// through `serde(rename)`. NOTE(review): the column set matches the Lichess
/// puzzle CSV export — confirm the actual source format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TacticalPuzzle {
    #[serde(rename = "PuzzleId")]
    pub puzzle_id: String,
    /// Puzzle position as a FEN string.
    #[serde(rename = "FEN")]
    pub fen: String,
    /// Solution line as one string (presumably space-separated UCI moves — TODO confirm).
    #[serde(rename = "Moves")]
    pub moves: String,
    #[serde(rename = "Rating")]
    pub rating: u32,
    #[serde(rename = "RatingDeviation")]
    pub rating_deviation: u32,
    #[serde(rename = "Popularity")]
    pub popularity: i32,
    #[serde(rename = "NbPlays")]
    pub nb_plays: u32,
    /// Theme tags as one string (presumably space-separated — TODO confirm).
    #[serde(rename = "Themes")]
    pub themes: String,
    #[serde(rename = "GameUrl")]
    pub game_url: Option<String>,
    #[serde(rename = "OpeningTags")]
    pub opening_tags: Option<String>,
}
78
/// A puzzle position paired with its solution move, used for tactical training.
///
/// Serialised through hand-written serde impls that store `position` as a FEN
/// string under the key `"fen"` and `solution_move` as its string form.
#[derive(Debug, Clone)]
pub struct TacticalTrainingData {
    pub position: Board,
    pub solution_move: ChessMove,
    /// Theme label attached to the puzzle.
    pub move_theme: String,
    pub difficulty: f32,
    pub tactical_value: f32,
}
88
89impl serde::Serialize for TacticalTrainingData {
91 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92 where
93 S: serde::Serializer,
94 {
95 use serde::ser::SerializeStruct;
96 let mut state = serializer.serialize_struct("TacticalTrainingData", 5)?;
97 state.serialize_field("fen", &self.position.to_string())?;
98 state.serialize_field("solution_move", &self.solution_move.to_string())?;
99 state.serialize_field("move_theme", &self.move_theme)?;
100 state.serialize_field("difficulty", &self.difficulty)?;
101 state.serialize_field("tactical_value", &self.tactical_value)?;
102 state.end()
103 }
104}
105
106impl<'de> serde::Deserialize<'de> for TacticalTrainingData {
107 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108 where
109 D: serde::Deserializer<'de>,
110 {
111 use serde::de::{self, MapAccess, Visitor};
112 use std::fmt;
113
114 struct TacticalTrainingDataVisitor;
115
116 impl<'de> Visitor<'de> for TacticalTrainingDataVisitor {
117 type Value = TacticalTrainingData;
118
119 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
120 formatter.write_str("struct TacticalTrainingData")
121 }
122
123 fn visit_map<V>(self, mut map: V) -> Result<TacticalTrainingData, V::Error>
124 where
125 V: MapAccess<'de>,
126 {
127 let mut fen = None;
128 let mut solution_move = None;
129 let mut move_theme = None;
130 let mut difficulty = None;
131 let mut tactical_value = None;
132
133 while let Some(key) = map.next_key()? {
134 match key {
135 "fen" => {
136 if fen.is_some() {
137 return Err(de::Error::duplicate_field("fen"));
138 }
139 fen = Some(map.next_value()?);
140 }
141 "solution_move" => {
142 if solution_move.is_some() {
143 return Err(de::Error::duplicate_field("solution_move"));
144 }
145 solution_move = Some(map.next_value()?);
146 }
147 "move_theme" => {
148 if move_theme.is_some() {
149 return Err(de::Error::duplicate_field("move_theme"));
150 }
151 move_theme = Some(map.next_value()?);
152 }
153 "difficulty" => {
154 if difficulty.is_some() {
155 return Err(de::Error::duplicate_field("difficulty"));
156 }
157 difficulty = Some(map.next_value()?);
158 }
159 "tactical_value" => {
160 if tactical_value.is_some() {
161 return Err(de::Error::duplicate_field("tactical_value"));
162 }
163 tactical_value = Some(map.next_value()?);
164 }
165 _ => {
166 let _: serde_json::Value = map.next_value()?;
167 }
168 }
169 }
170
171 let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
172 let solution_move_str: String =
173 solution_move.ok_or_else(|| de::Error::missing_field("solution_move"))?;
174 let move_theme =
175 move_theme.ok_or_else(|| de::Error::missing_field("move_theme"))?;
176 let difficulty =
177 difficulty.ok_or_else(|| de::Error::missing_field("difficulty"))?;
178 let tactical_value =
179 tactical_value.ok_or_else(|| de::Error::missing_field("tactical_value"))?;
180
181 let position =
182 Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
183
184 let solution_move = ChessMove::from_str(&solution_move_str)
185 .map_err(|e| de::Error::custom(format!("Error: {e}")))?;
186
187 Ok(TacticalTrainingData {
188 position,
189 solution_move,
190 move_theme,
191 difficulty,
192 tactical_value,
193 })
194 }
195 }
196
197 const FIELDS: &[&str] = &[
198 "fen",
199 "solution_move",
200 "move_theme",
201 "difficulty",
202 "tactical_value",
203 ];
204 deserializer.deserialize_struct("TacticalTrainingData", FIELDS, TacticalTrainingDataVisitor)
205 }
206}
207
/// `pgn_reader` visitor that replays each game's main line and records every
/// position reached as unevaluated `TrainingData`.
pub struct GameExtractor {
    /// Positions collected so far, across all games visited.
    pub positions: Vec<TrainingData>,
    /// Game currently being replayed.
    pub current_game: Game,
    /// Plies successfully applied in the current game.
    pub move_count: usize,
    /// Per-game cap on applied plies; later moves are ignored.
    pub max_moves_per_game: usize,
    /// Id of the current game; incremented when each game begins (first game is 1).
    pub game_id: usize,
}
216
217impl GameExtractor {
218 pub fn new(max_moves_per_game: usize) -> Self {
219 Self {
220 positions: Vec::new(),
221 current_game: Game::new(),
222 move_count: 0,
223 max_moves_per_game,
224 game_id: 0,
225 }
226 }
227}
228
229impl Visitor for GameExtractor {
230 type Result = ();
231
232 fn begin_game(&mut self) {
233 self.current_game = Game::new();
234 self.move_count = 0;
235 self.game_id += 1;
236 }
237
238 fn header(&mut self, _key: &[u8], _value: RawHeader<'_>) {}
239
240 fn san(&mut self, san_plus: SanPlus) {
241 if self.move_count >= self.max_moves_per_game {
242 return;
243 }
244
245 let san_str = san_plus.san.to_string();
246
247 let current_pos = self.current_game.current_position();
249
250 match chess::ChessMove::from_san(¤t_pos, &san_str) {
252 Ok(chess_move) => {
253 let legal_moves: Vec<chess::ChessMove> = MoveGen::new_legal(¤t_pos).collect();
255 if legal_moves.contains(&chess_move) {
256 if self.current_game.make_move(chess_move) {
257 self.move_count += 1;
258
259 self.positions.push(TrainingData {
261 board: self.current_game.current_position(),
262 evaluation: 0.0, depth: 0,
264 game_id: self.game_id,
265 });
266 }
267 } else {
268 }
270 }
271 Err(_) => {
272 if !san_str.contains("O-O") && !san_str.contains("=") && san_str.len() > 6 {
276 }
278 }
279 }
280 }
281
282 fn begin_variation(&mut self) -> Skip {
283 Skip(true) }
285
286 fn end_game(&mut self) -> Self::Result {}
287}
288
/// Evaluates positions by launching a separate `stockfish` process (resolved
/// through `PATH`) for every position.
pub struct StockfishEvaluator {
    /// Search depth sent with `go depth`.
    depth: u8,
}
293
294impl StockfishEvaluator {
295 pub fn new(depth: u8) -> Self {
296 Self { depth }
297 }
298
299 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
301 let mut child = Command::new("stockfish")
302 .stdin(Stdio::piped())
303 .stdout(Stdio::piped())
304 .stderr(Stdio::piped())
305 .spawn()?;
306
307 let stdin = child
308 .stdin
309 .as_mut()
310 .ok_or("Failed to get stdin handle for Stockfish process")?;
311 let fen = board.to_string();
312
313 use std::io::Write;
315 writeln!(stdin, "uci")?;
316 writeln!(stdin, "isready")?;
317 writeln!(stdin, "position fen {fen}")?;
318 writeln!(stdin, "go depth {}", self.depth)?;
319 writeln!(stdin, "quit")?;
320
321 let output = child.wait_with_output()?;
322 let stdout = String::from_utf8_lossy(&output.stdout);
323
324 for line in stdout.lines() {
326 if line.starts_with("info") && line.contains("score cp") {
327 if let Some(cp_pos) = line.find("score cp ") {
328 let cp_str = &line[cp_pos + 9..];
329 if let Some(end) = cp_str.find(' ') {
330 let cp_value = cp_str[..end].parse::<i32>()?;
331 return Ok(cp_value as f32 / 100.0); }
333 }
334 } else if line.starts_with("info") && line.contains("score mate") {
335 if let Some(mate_pos) = line.find("score mate ") {
337 let mate_str = &line[mate_pos + 11..];
338 if let Some(end) = mate_str.find(' ') {
339 let mate_moves = mate_str[..end].parse::<i32>()?;
340 return Ok(if mate_moves > 0 { 100.0 } else { -100.0 });
341 }
342 }
343 }
344 }
345
346 Ok(0.0) }
348
349 pub fn evaluate_batch(
351 &self,
352 positions: &mut [TrainingData],
353 ) -> Result<(), Box<dyn std::error::Error>> {
354 let pb = ProgressBar::new(positions.len() as u64);
355 if let Ok(style) = ProgressStyle::default_bar().template(
356 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
357 ) {
358 pb.set_style(style.progress_chars("#>-"));
359 }
360
361 for data in positions.iter_mut() {
362 match self.evaluate_position(&data.board) {
363 Ok(eval) => {
364 data.evaluation = eval;
365 data.depth = self.depth;
366 }
367 Err(e) => {
368 eprintln!("Evaluation error: {e}");
369 data.evaluation = 0.0;
370 }
371 }
372 pb.inc(1);
373 }
374
375 pb.finish_with_message("Evaluation complete");
376 Ok(())
377 }
378
379 pub fn evaluate_batch_parallel(
381 &self,
382 positions: &mut [TrainingData],
383 num_threads: usize,
384 ) -> Result<(), Box<dyn std::error::Error>> {
385 let pb = ProgressBar::new(positions.len() as u64);
386 if let Ok(style) = ProgressStyle::default_bar()
387 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel evaluation") {
388 pb.set_style(style.progress_chars("#>-"));
389 }
390
391 let pool = rayon::ThreadPoolBuilder::new()
393 .num_threads(num_threads)
394 .build()?;
395
396 pool.install(|| {
397 positions.par_iter_mut().for_each(|data| {
399 match self.evaluate_position(&data.board) {
400 Ok(eval) => {
401 data.evaluation = eval;
402 data.depth = self.depth;
403 }
404 Err(_) => {
405 data.evaluation = 0.0;
407 }
408 }
409 pb.inc(1);
410 });
411 });
412
413 pb.finish_with_message("Parallel evaluation complete");
414 Ok(())
415 }
416}
417
/// A long-lived Stockfish child process driven over UCI through its
/// stdin/stdout pipes; `StockfishPool` reuses these across evaluations.
struct StockfishProcess {
    child: Child,
    /// Command channel to the engine; `send_command` flushes after each command.
    stdin: BufWriter<std::process::ChildStdin>,
    /// Engine output, read line by line.
    stdout: BufReader<std::process::ChildStdout>,
    /// Search depth requested when the process was created.
    // Lint allowed because the evaluation code may not read this field.
    #[allow(dead_code)]
    depth: u8,
}
426
427impl StockfishProcess {
428 fn new(depth: u8) -> Result<Self, Box<dyn std::error::Error>> {
429 let mut child = Command::new("stockfish")
430 .stdin(Stdio::piped())
431 .stdout(Stdio::piped())
432 .stderr(Stdio::piped())
433 .spawn()?;
434
435 let stdin = BufWriter::new(
436 child
437 .stdin
438 .take()
439 .ok_or("Failed to get stdin handle for Stockfish process")?,
440 );
441 let stdout = BufReader::new(
442 child
443 .stdout
444 .take()
445 .ok_or("Failed to get stdout handle for Stockfish process")?,
446 );
447
448 let mut process = Self {
449 child,
450 stdin,
451 stdout,
452 depth,
453 };
454
455 process.send_command("uci")?;
457 process.wait_for_ready()?;
458 process.send_command("isready")?;
459 process.wait_for_ready()?;
460
461 Ok(process)
462 }
463
464 fn send_command(&mut self, command: &str) -> Result<(), Box<dyn std::error::Error>> {
465 writeln!(self.stdin, "{command}")?;
466 self.stdin.flush()?;
467 Ok(())
468 }
469
470 fn wait_for_ready(&mut self) -> Result<(), Box<dyn std::error::Error>> {
471 let mut line = String::new();
472 loop {
473 line.clear();
474 self.stdout.read_line(&mut line)?;
475 if line.trim() == "uciok" || line.trim() == "readyok" {
476 break;
477 }
478 }
479 Ok(())
480 }
481
482 fn evaluate_position(&mut self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
483 let fen = board.to_string();
484
485 self.send_command(&format!("position fen {fen}"))?;
487 self.send_command(&format!("position fen {fen}"))?;
488
489 let mut line = String::new();
491 let mut last_evaluation = 0.0;
492
493 loop {
494 line.clear();
495 self.stdout.read_line(&mut line)?;
496 let line = line.trim();
497
498 if line.starts_with("info") && line.contains("score cp") {
499 if let Some(cp_pos) = line.find("score cp ") {
500 let cp_str = &line[cp_pos + 9..];
501 if let Some(end) = cp_str.find(' ') {
502 if let Ok(cp_value) = cp_str[..end].parse::<i32>() {
503 last_evaluation = cp_value as f32 / 100.0;
504 }
505 }
506 }
507 } else if line.starts_with("info") && line.contains("score mate") {
508 if let Some(mate_pos) = line.find("score mate ") {
509 let mate_str = &line[mate_pos + 11..];
510 if let Some(end) = mate_str.find(' ') {
511 if let Ok(mate_moves) = mate_str[..end].parse::<i32>() {
512 last_evaluation = if mate_moves > 0 { 100.0 } else { -100.0 };
513 }
514 }
515 }
516 } else if line.starts_with("bestmove") {
517 break;
518 }
519 }
520
521 Ok(last_evaluation)
522 }
523}
524
525impl Drop for StockfishProcess {
526 fn drop(&mut self) {
527 let _ = self.send_command("quit");
528 let _ = self.child.wait();
529 }
530}
531
/// A pool of reusable Stockfish processes shared by concurrent evaluators.
pub struct StockfishPool {
    /// Idle processes available for checkout.
    pool: Arc<Mutex<Vec<StockfishProcess>>>,
    /// Search depth for processes started by this pool.
    depth: u8,
    /// Maximum number of idle processes kept for reuse.
    pool_size: usize,
}
538
539impl StockfishPool {
540 pub fn new(depth: u8, pool_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
541 let mut processes = Vec::with_capacity(pool_size);
542
543 println!("๐ Initializing Stockfish pool with {pool_size} processes...");
544
545 for i in 0..pool_size {
546 match StockfishProcess::new(depth) {
547 Ok(process) => {
548 processes.push(process);
549 if i % 2 == 1 {
550 print!(".");
551 let _ = std::io::stdout().flush(); }
553 }
554 Err(e) => {
555 eprintln!("Evaluation error: {e}");
556 return Err(e);
557 }
558 }
559 }
560
561 println!(" โ
Pool ready!");
562
563 Ok(Self {
564 pool: Arc::new(Mutex::new(processes)),
565 depth,
566 pool_size,
567 })
568 }
569
570 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
571 let mut process = {
573 let mut pool = self.pool.lock().unwrap();
574 if let Some(process) = pool.pop() {
575 process
576 } else {
577 StockfishProcess::new(self.depth)?
579 }
580 };
581
582 let result = process.evaluate_position(board);
584
585 {
587 let mut pool = self.pool.lock().unwrap();
588 if pool.len() < self.pool_size {
589 pool.push(process);
590 }
591 }
593
594 result
595 }
596
597 pub fn evaluate_batch_parallel(
598 &self,
599 positions: &mut [TrainingData],
600 ) -> Result<(), Box<dyn std::error::Error>> {
601 let pb = ProgressBar::new(positions.len() as u64);
602 pb.set_style(ProgressStyle::default_bar()
603 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Pool evaluation")
604 .unwrap()
605 .progress_chars("#>-"));
606
607 positions.par_iter_mut().for_each(|data| {
609 match self.evaluate_position(&data.board) {
610 Ok(eval) => {
611 data.evaluation = eval;
612 data.depth = self.depth;
613 }
614 Err(_) => {
615 data.evaluation = 0.0;
616 }
617 }
618 pb.inc(1);
619 });
620
621 pb.finish_with_message("Pool evaluation complete");
622 Ok(())
623 }
624}
625
/// An in-memory collection of labelled training positions.
pub struct TrainingDataset {
    /// The positions, in insertion order.
    pub data: Vec<TrainingData>,
}
630
631impl Default for TrainingDataset {
632 fn default() -> Self {
633 Self::new()
634 }
635}
636
637impl TrainingDataset {
638 pub fn new() -> Self {
639 Self { data: Vec::new() }
640 }
641
642 pub fn load_from_pgn<P: AsRef<Path>>(
644 &mut self,
645 path: P,
646 max_games: Option<usize>,
647 max_moves_per_game: usize,
648 ) -> Result<(), Box<dyn std::error::Error>> {
649 use indicatif::{ProgressBar, ProgressStyle};
650
651 println!("๐ Reading PGN file...");
652 let file = File::open(path)?;
653 let reader = BufReader::new(file);
654
655 let mut games = Vec::new();
657 let mut current_game = String::new();
658 let mut games_collected = 0;
659
660 for line in reader.lines() {
661 let line = line?;
662 current_game.push_str(&line);
663 current_game.push('\n');
664
665 if line.trim().ends_with("1-0")
667 || line.trim().ends_with("0-1")
668 || line.trim().ends_with("1/2-1/2")
669 || line.trim().ends_with("*")
670 {
671 games.push(current_game.clone());
672 current_game.clear();
673 games_collected += 1;
674
675 if let Some(max) = max_games {
676 if games_collected >= max {
677 break;
678 }
679 }
680 }
681 }
682
683 println!(
684 "๐ฆ Collected {} games, processing in parallel...",
685 games.len()
686 );
687
688 let pb = ProgressBar::new(games.len() as u64);
690 pb.set_style(
691 ProgressStyle::default_bar()
692 .template("โก Processing [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
693 .unwrap()
694 .progress_chars("โโโ")
695 );
696
697 let all_positions: Vec<Vec<TrainingData>> = games
699 .par_iter()
700 .map(|game_pgn| {
701 pb.inc(1);
702 pb.set_message("Processing game");
703
704 let mut local_extractor = GameExtractor::new(max_moves_per_game);
706
707 let cursor = std::io::Cursor::new(game_pgn);
709 let mut reader = BufferedReader::new(cursor);
710
711 if let Err(e) = reader.read_all(&mut local_extractor) {
712 eprintln!("Parse error: {e}");
713 return Vec::new(); }
715
716 local_extractor.positions
717 })
718 .collect();
719
720 pb.finish_with_message("โ
Parallel processing completed");
721
722 for game_positions in all_positions {
724 self.data.extend(game_positions);
725 }
726
727 println!(
728 "โ
Loaded {} positions from {} games (parallel processing)",
729 self.data.len(),
730 games.len()
731 );
732 Ok(())
733 }
734
735 pub fn evaluate_with_stockfish(&mut self, depth: u8) -> Result<(), Box<dyn std::error::Error>> {
737 let evaluator = StockfishEvaluator::new(depth);
738 evaluator.evaluate_batch(&mut self.data)
739 }
740
741 pub fn evaluate_with_stockfish_parallel(
743 &mut self,
744 depth: u8,
745 num_threads: usize,
746 ) -> Result<(), Box<dyn std::error::Error>> {
747 let evaluator = StockfishEvaluator::new(depth);
748 evaluator.evaluate_batch_parallel(&mut self.data, num_threads)
749 }
750
751 pub fn train_engine(&self, engine: &mut ChessVectorEngine) {
753 let pb = ProgressBar::new(self.data.len() as u64);
754 pb.set_style(ProgressStyle::default_bar()
755 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Training positions")
756 .unwrap()
757 .progress_chars("#>-"));
758
759 for data in &self.data {
760 engine.add_position(&data.board, data.evaluation);
761 pb.inc(1);
762 }
763
764 pb.finish_with_message("Training complete");
765 println!("Trained engine with {} positions", self.data.len());
766 }
767
768 pub fn split(&self, train_ratio: f32) -> (TrainingDataset, TrainingDataset) {
770 use rand::seq::SliceRandom;
771 use rand::thread_rng;
772 use std::collections::{HashMap, HashSet};
773
774 let mut games: HashMap<usize, Vec<&TrainingData>> = HashMap::new();
776 for data in &self.data {
777 games.entry(data.game_id).or_default().push(data);
778 }
779
780 let mut game_ids: Vec<usize> = games.keys().cloned().collect();
782 game_ids.shuffle(&mut thread_rng());
783
784 let split_point = (game_ids.len() as f32 * train_ratio) as usize;
786 let train_game_ids: HashSet<usize> = game_ids[..split_point].iter().cloned().collect();
787
788 let mut train_data = Vec::new();
790 let mut test_data = Vec::new();
791
792 for data in &self.data {
793 if train_game_ids.contains(&data.game_id) {
794 train_data.push(data.clone());
795 } else {
796 test_data.push(data.clone());
797 }
798 }
799
800 (
801 TrainingDataset { data: train_data },
802 TrainingDataset { data: test_data },
803 )
804 }
805
806 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
808 let json = serde_json::to_string_pretty(&self.data)?;
809 std::fs::write(path, json)?;
810 Ok(())
811 }
812
813 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
815 let content = std::fs::read_to_string(path)?;
816 let data = serde_json::from_str(&content)?;
817 Ok(Self { data })
818 }
819
820 pub fn load_and_append<P: AsRef<Path>>(
822 &mut self,
823 path: P,
824 ) -> Result<(), Box<dyn std::error::Error>> {
825 let existing_len = self.data.len();
826 let additional_data = Self::load(path)?;
827 self.data.extend(additional_data.data);
828 println!(
829 "Loaded {} additional positions (total: {})",
830 self.data.len() - existing_len,
831 self.data.len()
832 );
833 Ok(())
834 }
835
836 pub fn merge(&mut self, other: TrainingDataset) {
838 let existing_len = self.data.len();
839 self.data.extend(other.data);
840 println!(
841 "Merged {} positions (total: {})",
842 self.data.len() - existing_len,
843 self.data.len()
844 );
845 }
846
847 pub fn save_incremental<P: AsRef<Path>>(
849 &self,
850 path: P,
851 ) -> Result<(), Box<dyn std::error::Error>> {
852 self.save_incremental_with_options(path, true)
853 }
854
855 pub fn save_incremental_with_options<P: AsRef<Path>>(
857 &self,
858 path: P,
859 deduplicate: bool,
860 ) -> Result<(), Box<dyn std::error::Error>> {
861 let path = path.as_ref();
862
863 if path.exists() {
864 if self.save_append_only(path).is_ok() {
866 return Ok(());
867 }
868
869 if deduplicate {
871 self.save_incremental_full_merge(path)
872 } else {
873 self.save_incremental_no_dedup(path)
874 }
875 } else {
876 self.save(path)
878 }
879 }
880
881 fn save_incremental_no_dedup<P: AsRef<Path>>(
883 &self,
884 path: P,
885 ) -> Result<(), Box<dyn std::error::Error>> {
886 let path = path.as_ref();
887
888 println!("๐ Loading existing training data...");
889 let mut existing = Self::load(path)?;
890
891 println!("โก Fast merge without deduplication...");
892 existing.data.extend(self.data.iter().cloned());
893
894 println!(
895 "๐พ Serializing {} positions to JSON...",
896 existing.data.len()
897 );
898 let json = serde_json::to_string_pretty(&existing.data)?;
899
900 println!("โ๏ธ Writing to disk...");
901 std::fs::write(path, json)?;
902
903 println!(
904 "โ
Fast merge save: total {} positions",
905 existing.data.len()
906 );
907 Ok(())
908 }
909
910 pub fn save_append_only<P: AsRef<Path>>(
912 &self,
913 path: P,
914 ) -> Result<(), Box<dyn std::error::Error>> {
915 use std::fs::OpenOptions;
916 use std::io::{BufRead, BufReader, Seek, SeekFrom, Write};
917
918 if self.data.is_empty() {
919 return Ok(());
920 }
921
922 let path = path.as_ref();
923 let mut file = OpenOptions::new().read(true).write(true).open(path)?;
924
925 file.seek(SeekFrom::End(-10))?;
927 let mut buffer = String::new();
928 BufReader::new(&file).read_line(&mut buffer)?;
929
930 if !buffer.trim().ends_with(']') {
931 return Err("File doesn't end with JSON array bracket".into());
932 }
933
934 file.seek(SeekFrom::End(-2))?; write!(file, ",")?;
939
940 for (i, data) in self.data.iter().enumerate() {
942 if i > 0 {
943 write!(file, ",")?;
944 }
945 let json = serde_json::to_string(data)?;
946 write!(file, "{json}")?;
947 }
948
949 write!(file, "\n]")?;
951
952 println!("Fast append: added {} new positions", self.data.len());
953 Ok(())
954 }
955
956 fn save_incremental_full_merge<P: AsRef<Path>>(
958 &self,
959 path: P,
960 ) -> Result<(), Box<dyn std::error::Error>> {
961 let path = path.as_ref();
962
963 println!("๐ Loading existing training data...");
964 let mut existing = Self::load(path)?;
965 let _original_len = existing.data.len();
966
967 println!("๐ Streaming merge with deduplication (avoiding O(nยฒ) operation)...");
968 existing.merge_and_deduplicate(self.data.clone());
969
970 println!(
971 "๐พ Serializing {} positions to JSON...",
972 existing.data.len()
973 );
974 let json = serde_json::to_string_pretty(&existing.data)?;
975
976 println!("โ๏ธ Writing to disk...");
977 std::fs::write(path, json)?;
978
979 println!(
980 "โ
Streaming merge save: total {} positions",
981 existing.data.len()
982 );
983 Ok(())
984 }
985
986 pub fn add_position(&mut self, board: Board, evaluation: f32, depth: u8, game_id: usize) {
988 self.data.push(TrainingData {
989 board,
990 evaluation,
991 depth,
992 game_id,
993 });
994 }
995
996 pub fn next_game_id(&self) -> usize {
998 self.data.iter().map(|data| data.game_id).max().unwrap_or(0) + 1
999 }
1000
1001 pub fn deduplicate(&mut self, similarity_threshold: f32) {
1003 if similarity_threshold > 0.999 {
1004 self.deduplicate_fast();
1006 } else {
1007 self.deduplicate_similarity_based(similarity_threshold);
1009 }
1010 }
1011
1012 pub fn deduplicate_fast(&mut self) {
1014 use std::collections::HashSet;
1015
1016 if self.data.is_empty() {
1017 return;
1018 }
1019
1020 let mut seen_positions = HashSet::with_capacity(self.data.len());
1021 let original_len = self.data.len();
1022
1023 self.data.retain(|data| {
1025 let fen = data.board.to_string();
1026 seen_positions.insert(fen)
1027 });
1028
1029 println!(
1030 "Fast deduplicated: {} -> {} positions (removed {} exact duplicates)",
1031 original_len,
1032 self.data.len(),
1033 original_len - self.data.len()
1034 );
1035 }
1036
1037 pub fn merge_and_deduplicate(&mut self, new_data: Vec<TrainingData>) {
1039 use std::collections::HashSet;
1040
1041 if new_data.is_empty() {
1042 return;
1043 }
1044
1045 let _original_len = self.data.len();
1046
1047 let mut existing_positions: HashSet<String> = HashSet::with_capacity(self.data.len());
1049 for data in &self.data {
1050 existing_positions.insert(data.board.to_string());
1051 }
1052
1053 let mut added = 0;
1055 for data in new_data {
1056 let fen = data.board.to_string();
1057 if existing_positions.insert(fen) {
1058 self.data.push(data);
1059 added += 1;
1060 }
1061 }
1062
1063 println!(
1064 "Streaming merge: added {} unique positions (total: {})",
1065 added,
1066 self.data.len()
1067 );
1068 }
1069
1070 fn deduplicate_similarity_based(&mut self, similarity_threshold: f32) {
1072 use crate::PositionEncoder;
1073 use ndarray::Array1;
1074
1075 if self.data.is_empty() {
1076 return;
1077 }
1078
1079 let encoder = PositionEncoder::new(1024);
1080 let mut keep_indices: Vec<bool> = vec![true; self.data.len()];
1081
1082 let vectors: Vec<Array1<f32>> = if self.data.len() > 50 {
1084 self.data
1085 .par_iter()
1086 .map(|data| encoder.encode(&data.board))
1087 .collect()
1088 } else {
1089 self.data
1090 .iter()
1091 .map(|data| encoder.encode(&data.board))
1092 .collect()
1093 };
1094
1095 for i in 1..self.data.len() {
1097 if !keep_indices[i] {
1098 continue;
1099 }
1100
1101 for j in 0..i {
1102 if !keep_indices[j] {
1103 continue;
1104 }
1105
1106 let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1107 if similarity > similarity_threshold {
1108 keep_indices[i] = false;
1109 break;
1110 }
1111 }
1112 }
1113
1114 let original_len = self.data.len();
1116 self.data = self
1117 .data
1118 .iter()
1119 .enumerate()
1120 .filter_map(|(i, data)| {
1121 if keep_indices[i] {
1122 Some(data.clone())
1123 } else {
1124 None
1125 }
1126 })
1127 .collect();
1128
1129 println!(
1130 "Similarity deduplicated: {} -> {} positions (removed {} near-duplicates)",
1131 original_len,
1132 self.data.len(),
1133 original_len - self.data.len()
1134 );
1135 }
1136
1137 pub fn deduplicate_parallel(&mut self, similarity_threshold: f32, chunk_size: usize) {
1139 use crate::PositionEncoder;
1140 use ndarray::Array1;
1141 use std::sync::{Arc, Mutex};
1142
1143 if self.data.is_empty() {
1144 return;
1145 }
1146
1147 let encoder = PositionEncoder::new(1024);
1148
1149 let vectors: Vec<Array1<f32>> = self
1151 .data
1152 .par_iter()
1153 .map(|data| encoder.encode(&data.board))
1154 .collect();
1155
1156 let keep_indices = Arc::new(Mutex::new(vec![true; self.data.len()]));
1157
1158 (1..self.data.len())
1160 .collect::<Vec<_>>()
1161 .par_chunks(chunk_size)
1162 .for_each(|chunk| {
1163 for &i in chunk {
1164 {
1166 let indices = keep_indices.lock().unwrap();
1167 if !indices[i] {
1168 continue;
1169 }
1170 }
1171
1172 for j in 0..i {
1174 {
1175 let indices = keep_indices.lock().unwrap();
1176 if !indices[j] {
1177 continue;
1178 }
1179 }
1180
1181 let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1182 if similarity > similarity_threshold {
1183 let mut indices = keep_indices.lock().unwrap();
1184 indices[i] = false;
1185 break;
1186 }
1187 }
1188 }
1189 });
1190
1191 let keep_indices = keep_indices.lock().unwrap();
1193 let original_len = self.data.len();
1194 self.data = self
1195 .data
1196 .iter()
1197 .enumerate()
1198 .filter_map(|(i, data)| {
1199 if keep_indices[i] {
1200 Some(data.clone())
1201 } else {
1202 None
1203 }
1204 })
1205 .collect();
1206
1207 println!(
1208 "Parallel deduplicated: {} -> {} positions (removed {} duplicates)",
1209 original_len,
1210 self.data.len(),
1211 original_len - self.data.len()
1212 );
1213 }
1214
1215 fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
1217 let dot_product = a.dot(b);
1218 let norm_a = a.dot(a).sqrt();
1219 let norm_b = b.dot(b).sqrt();
1220
1221 if norm_a == 0.0 || norm_b == 0.0 {
1222 0.0
1223 } else {
1224 dot_product / (norm_a * norm_b)
1225 }
1226 }
1227}
1228
/// Generates training data by letting a `ChessVectorEngine` play against itself.
pub struct SelfPlayTrainer {
    config: SelfPlayConfig,
    /// Games played so far; used as the `game_id` of positions from the current game.
    game_counter: usize,
}
1234
1235impl SelfPlayTrainer {
1236 pub fn new(config: SelfPlayConfig) -> Self {
1237 Self {
1238 config,
1239 game_counter: 0,
1240 }
1241 }
1242
1243 pub fn generate_training_data(&mut self, engine: &mut ChessVectorEngine) -> TrainingDataset {
1245 let mut dataset = TrainingDataset::new();
1246
1247 println!(
1248 "๐ฎ Starting self-play training with {} games...",
1249 self.config.games_per_iteration
1250 );
1251 let pb = ProgressBar::new(self.config.games_per_iteration as u64);
1252 if let Ok(style) = ProgressStyle::default_bar().template(
1253 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
1254 ) {
1255 pb.set_style(style.progress_chars("#>-"));
1256 }
1257
1258 for _ in 0..self.config.games_per_iteration {
1259 let game_data = self.play_single_game(engine);
1260 dataset.data.extend(game_data);
1261 self.game_counter += 1;
1262 pb.inc(1);
1263 }
1264
1265 pb.finish_with_message("Self-play games completed");
1266 println!(
1267 "โ
Generated {} positions from {} games",
1268 dataset.data.len(),
1269 self.config.games_per_iteration
1270 );
1271
1272 dataset
1273 }
1274
1275 fn play_single_game(&self, engine: &mut ChessVectorEngine) -> Vec<TrainingData> {
1277 let mut game = Game::new();
1278 let mut positions = Vec::new();
1279 let mut move_count = 0;
1280
1281 if self.config.use_opening_book {
1283 if let Some(opening_moves) = self.get_random_opening() {
1284 for mv in opening_moves {
1285 if game.make_move(mv) {
1286 move_count += 1;
1287 } else {
1288 break;
1289 }
1290 }
1291 }
1292 }
1293
1294 while game.result().is_none() && move_count < self.config.max_moves_per_game {
1296 let current_position = game.current_position();
1297
1298 let move_choice = self.select_move_with_exploration(engine, ¤t_position);
1300
1301 if let Some(chess_move) = move_choice {
1302 if let Some(evaluation) = engine.evaluate_position(¤t_position) {
1304 if evaluation.abs() >= self.config.min_confidence || move_count < 10 {
1306 positions.push(TrainingData {
1307 board: current_position,
1308 evaluation,
1309 depth: 1, game_id: self.game_counter,
1311 });
1312 }
1313 }
1314
1315 if !game.make_move(chess_move) {
1317 break; }
1319 move_count += 1;
1320 } else {
1321 break; }
1323 }
1324
1325 if let Some(result) = game.result() {
1327 let final_position = game.current_position();
1328 let final_eval = match result {
1329 chess::GameResult::WhiteCheckmates => {
1330 if final_position.side_to_move() == chess::Color::Black {
1331 10.0
1332 } else {
1333 -10.0
1334 }
1335 }
1336 chess::GameResult::BlackCheckmates => {
1337 if final_position.side_to_move() == chess::Color::White {
1338 10.0
1339 } else {
1340 -10.0
1341 }
1342 }
1343 chess::GameResult::WhiteResigns => -10.0,
1344 chess::GameResult::BlackResigns => 10.0,
1345 chess::GameResult::Stalemate
1346 | chess::GameResult::DrawAccepted
1347 | chess::GameResult::DrawDeclared => 0.0,
1348 };
1349
1350 positions.push(TrainingData {
1351 board: final_position,
1352 evaluation: final_eval,
1353 depth: 1,
1354 game_id: self.game_counter,
1355 });
1356 }
1357
1358 positions
1359 }
1360
1361 fn select_move_with_exploration(
1363 &self,
1364 engine: &mut ChessVectorEngine,
1365 position: &Board,
1366 ) -> Option<ChessMove> {
1367 let recommendations = engine.recommend_legal_moves(position, 5);
1368
1369 if recommendations.is_empty() {
1370 return None;
1371 }
1372
1373 if fastrand::f32() < self.config.exploration_factor {
1375 self.select_move_with_temperature(&recommendations)
1377 } else {
1378 Some(recommendations[0].chess_move)
1380 }
1381 }
1382
1383 fn select_move_with_temperature(
1385 &self,
1386 recommendations: &[crate::MoveRecommendation],
1387 ) -> Option<ChessMove> {
1388 if recommendations.is_empty() {
1389 return None;
1390 }
1391
1392 let mut probabilities = Vec::new();
1394 let mut sum = 0.0;
1395
1396 for rec in recommendations {
1397 let prob = (rec.average_outcome / self.config.temperature).exp();
1399 probabilities.push(prob);
1400 sum += prob;
1401 }
1402
1403 for prob in &mut probabilities {
1405 *prob /= sum;
1406 }
1407
1408 let rand_val = fastrand::f32();
1410 let mut cumulative = 0.0;
1411
1412 for (i, &prob) in probabilities.iter().enumerate() {
1413 cumulative += prob;
1414 if rand_val <= cumulative {
1415 return Some(recommendations[i].chess_move);
1416 }
1417 }
1418
1419 Some(recommendations[0].chess_move)
1421 }
1422
1423 fn get_random_opening(&self) -> Option<Vec<ChessMove>> {
1425 let openings = [
1426 vec!["e4", "e5", "Nf3", "Nc6", "Bc4"],
1428 vec!["e4", "e5", "Nf3", "Nc6", "Bb5"],
1430 vec!["d4", "d5", "c4"],
1432 vec!["d4", "Nf6", "c4", "g6"],
1434 vec!["e4", "c5"],
1436 vec!["e4", "e6"],
1438 vec!["e4", "c6"],
1440 ];
1441
1442 let selected_opening = &openings[fastrand::usize(0..openings.len())];
1443
1444 let mut moves = Vec::new();
1445 let mut game = Game::new();
1446
1447 for move_str in selected_opening {
1448 if let Ok(chess_move) = ChessMove::from_str(move_str) {
1449 if game.make_move(chess_move) {
1450 moves.push(chess_move);
1451 } else {
1452 break;
1453 }
1454 }
1455 }
1456
1457 if moves.is_empty() {
1458 None
1459 } else {
1460 Some(moves)
1461 }
1462 }
1463}
1464
1465pub struct EngineEvaluator {
1467 #[allow(dead_code)]
1468 stockfish_depth: u8,
1469}
1470
1471impl EngineEvaluator {
1472 pub fn new(stockfish_depth: u8) -> Self {
1473 Self { stockfish_depth }
1474 }
1475
1476 pub fn evaluate_accuracy(
1478 &self,
1479 engine: &mut ChessVectorEngine,
1480 test_data: &TrainingDataset,
1481 ) -> Result<f32, Box<dyn std::error::Error>> {
1482 let mut total_error = 0.0;
1483 let mut valid_comparisons = 0;
1484
1485 let pb = ProgressBar::new(test_data.data.len() as u64);
1486 pb.set_style(ProgressStyle::default_bar()
1487 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Evaluating accuracy")
1488 .unwrap()
1489 .progress_chars("#>-"));
1490
1491 for data in &test_data.data {
1492 if let Some(engine_eval) = engine.evaluate_position(&data.board) {
1493 let error = (engine_eval - data.evaluation).abs();
1494 total_error += error;
1495 valid_comparisons += 1;
1496 }
1497 pb.inc(1);
1498 }
1499
1500 pb.finish_with_message("Accuracy evaluation complete");
1501
1502 if valid_comparisons > 0 {
1503 let mean_absolute_error = total_error / valid_comparisons as f32;
1504 println!("Mean Absolute Error: {mean_absolute_error:.3} pawns");
1505 println!("Evaluated {valid_comparisons} positions");
1506 Ok(mean_absolute_error)
1507 } else {
1508 Ok(f32::INFINITY)
1509 }
1510 }
1511}
1512
1513impl serde::Serialize for TrainingData {
1515 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1516 where
1517 S: serde::Serializer,
1518 {
1519 use serde::ser::SerializeStruct;
1520 let mut state = serializer.serialize_struct("TrainingData", 4)?;
1521 state.serialize_field("fen", &self.board.to_string())?;
1522 state.serialize_field("evaluation", &self.evaluation)?;
1523 state.serialize_field("depth", &self.depth)?;
1524 state.serialize_field("game_id", &self.game_id)?;
1525 state.end()
1526 }
1527}
1528
1529impl<'de> serde::Deserialize<'de> for TrainingData {
1530 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1531 where
1532 D: serde::Deserializer<'de>,
1533 {
1534 use serde::de::{self, MapAccess, Visitor};
1535 use std::fmt;
1536
1537 struct TrainingDataVisitor;
1538
1539 impl<'de> Visitor<'de> for TrainingDataVisitor {
1540 type Value = TrainingData;
1541
1542 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1543 formatter.write_str("struct TrainingData")
1544 }
1545
1546 fn visit_map<V>(self, mut map: V) -> Result<TrainingData, V::Error>
1547 where
1548 V: MapAccess<'de>,
1549 {
1550 let mut fen = None;
1551 let mut evaluation = None;
1552 let mut depth = None;
1553 let mut game_id = None;
1554
1555 while let Some(key) = map.next_key()? {
1556 match key {
1557 "fen" => {
1558 if fen.is_some() {
1559 return Err(de::Error::duplicate_field("fen"));
1560 }
1561 fen = Some(map.next_value()?);
1562 }
1563 "evaluation" => {
1564 if evaluation.is_some() {
1565 return Err(de::Error::duplicate_field("evaluation"));
1566 }
1567 evaluation = Some(map.next_value()?);
1568 }
1569 "depth" => {
1570 if depth.is_some() {
1571 return Err(de::Error::duplicate_field("depth"));
1572 }
1573 depth = Some(map.next_value()?);
1574 }
1575 "game_id" => {
1576 if game_id.is_some() {
1577 return Err(de::Error::duplicate_field("game_id"));
1578 }
1579 game_id = Some(map.next_value()?);
1580 }
1581 _ => {
1582 let _: serde_json::Value = map.next_value()?;
1583 }
1584 }
1585 }
1586
1587 let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
1588 let mut evaluation: f32 =
1589 evaluation.ok_or_else(|| de::Error::missing_field("evaluation"))?;
1590 let depth = depth.ok_or_else(|| de::Error::missing_field("depth"))?;
1591 let game_id = game_id.unwrap_or(0); if evaluation.abs() > 15.0 {
1597 evaluation /= 100.0;
1598 }
1599
1600 let board =
1601 Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
1602
1603 Ok(TrainingData {
1604 board,
1605 evaluation,
1606 depth,
1607 game_id,
1608 })
1609 }
1610 }
1611
1612 const FIELDS: &[&str] = &["fen", "evaluation", "depth", "game_id"];
1613 deserializer.deserialize_struct("TrainingData", FIELDS, TrainingDataVisitor)
1614 }
1615}
1616
1617pub struct TacticalPuzzleParser;
1619
1620impl TacticalPuzzleParser {
1621 pub fn parse_csv<P: AsRef<Path>>(
1623 file_path: P,
1624 max_puzzles: Option<usize>,
1625 min_rating: Option<u32>,
1626 max_rating: Option<u32>,
1627 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1628 let file = File::open(&file_path)?;
1629 let file_size = file.metadata()?.len();
1630
1631 if file_size > 100_000_000 {
1633 Self::parse_csv_parallel(file_path, max_puzzles, min_rating, max_rating)
1634 } else {
1635 Self::parse_csv_sequential(file_path, max_puzzles, min_rating, max_rating)
1636 }
1637 }
1638
1639 fn parse_csv_sequential<P: AsRef<Path>>(
1641 file_path: P,
1642 max_puzzles: Option<usize>,
1643 min_rating: Option<u32>,
1644 max_rating: Option<u32>,
1645 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1646 let file = File::open(file_path)?;
1647 let reader = BufReader::new(file);
1648
1649 let mut csv_reader = csv::ReaderBuilder::new()
1652 .has_headers(false)
1653 .flexible(true) .from_reader(reader);
1655
1656 let mut tactical_data = Vec::new();
1657 let mut processed = 0;
1658 let mut skipped = 0;
1659
1660 let pb = ProgressBar::new_spinner();
1661 pb.set_style(
1662 ProgressStyle::default_spinner()
1663 .template("{spinner:.green} Parsing tactical puzzles: {pos} (skipped: {skipped})")
1664 .unwrap(),
1665 );
1666
1667 for result in csv_reader.records() {
1668 let record = match result {
1669 Ok(r) => r,
1670 Err(e) => {
1671 skipped += 1;
1672 println!("CSV parsing error: {e}");
1673 continue;
1674 }
1675 };
1676
1677 if let Some(puzzle_data) = Self::parse_csv_record(&record, min_rating, max_rating) {
1678 if let Some(tactical_data_item) =
1679 Self::convert_puzzle_to_training_data(&puzzle_data)
1680 {
1681 tactical_data.push(tactical_data_item);
1682 processed += 1;
1683
1684 if let Some(max) = max_puzzles {
1685 if processed >= max {
1686 break;
1687 }
1688 }
1689 } else {
1690 skipped += 1;
1691 }
1692 } else {
1693 skipped += 1;
1694 }
1695
1696 pb.set_message(format!(
1697 "Parsing tactical puzzles: {processed} (skipped: {skipped})"
1698 ));
1699 }
1700
1701 pb.finish_with_message(format!("Parsed {processed} puzzles (skipped: {skipped})"));
1702
1703 Ok(tactical_data)
1704 }
1705
1706 fn parse_csv_parallel<P: AsRef<Path>>(
1708 file_path: P,
1709 max_puzzles: Option<usize>,
1710 min_rating: Option<u32>,
1711 max_rating: Option<u32>,
1712 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1713 use std::io::Read;
1714
1715 let mut file = File::open(&file_path)?;
1716
1717 let mut contents = String::new();
1719 file.read_to_string(&mut contents)?;
1720
1721 let lines: Vec<&str> = contents.lines().collect();
1723
1724 let pb = ProgressBar::new(lines.len() as u64);
1725 pb.set_style(ProgressStyle::default_bar()
1726 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel CSV parsing")
1727 .unwrap()
1728 .progress_chars("#>-"));
1729
1730 let tactical_data: Vec<TacticalTrainingData> = lines
1732 .par_iter()
1733 .take(max_puzzles.unwrap_or(usize::MAX))
1734 .filter_map(|line| {
1735 let fields: Vec<&str> = line.split(',').collect();
1737 if fields.len() < 8 {
1738 return None;
1739 }
1740
1741 if let Some(puzzle_data) = Self::parse_csv_fields(&fields, min_rating, max_rating) {
1743 Self::convert_puzzle_to_training_data(&puzzle_data)
1744 } else {
1745 None
1746 }
1747 })
1748 .collect();
1749
1750 pb.finish_with_message(format!(
1751 "Parallel parsing complete: {} puzzles",
1752 tactical_data.len()
1753 ));
1754
1755 Ok(tactical_data)
1756 }
1757
1758 fn parse_csv_record(
1760 record: &csv::StringRecord,
1761 min_rating: Option<u32>,
1762 max_rating: Option<u32>,
1763 ) -> Option<TacticalPuzzle> {
1764 if record.len() < 8 {
1766 return None;
1767 }
1768
1769 let rating: u32 = record[3].parse().ok()?;
1770 let rating_deviation: u32 = record[4].parse().ok()?;
1771 let popularity: i32 = record[5].parse().ok()?;
1772 let nb_plays: u32 = record[6].parse().ok()?;
1773
1774 if let Some(min) = min_rating {
1776 if rating < min {
1777 return None;
1778 }
1779 }
1780 if let Some(max) = max_rating {
1781 if rating > max {
1782 return None;
1783 }
1784 }
1785
1786 Some(TacticalPuzzle {
1787 puzzle_id: record[0].to_string(),
1788 fen: record[1].to_string(),
1789 moves: record[2].to_string(),
1790 rating,
1791 rating_deviation,
1792 popularity,
1793 nb_plays,
1794 themes: record[7].to_string(),
1795 game_url: if record.len() > 8 {
1796 Some(record[8].to_string())
1797 } else {
1798 None
1799 },
1800 opening_tags: if record.len() > 9 {
1801 Some(record[9].to_string())
1802 } else {
1803 None
1804 },
1805 })
1806 }
1807
1808 fn parse_csv_fields(
1810 fields: &[&str],
1811 min_rating: Option<u32>,
1812 max_rating: Option<u32>,
1813 ) -> Option<TacticalPuzzle> {
1814 if fields.len() < 8 {
1815 return None;
1816 }
1817
1818 let rating: u32 = fields[3].parse().ok()?;
1819 let rating_deviation: u32 = fields[4].parse().ok()?;
1820 let popularity: i32 = fields[5].parse().ok()?;
1821 let nb_plays: u32 = fields[6].parse().ok()?;
1822
1823 if let Some(min) = min_rating {
1825 if rating < min {
1826 return None;
1827 }
1828 }
1829 if let Some(max) = max_rating {
1830 if rating > max {
1831 return None;
1832 }
1833 }
1834
1835 Some(TacticalPuzzle {
1836 puzzle_id: fields[0].to_string(),
1837 fen: fields[1].to_string(),
1838 moves: fields[2].to_string(),
1839 rating,
1840 rating_deviation,
1841 popularity,
1842 nb_plays,
1843 themes: fields[7].to_string(),
1844 game_url: if fields.len() > 8 {
1845 Some(fields[8].to_string())
1846 } else {
1847 None
1848 },
1849 opening_tags: if fields.len() > 9 {
1850 Some(fields[9].to_string())
1851 } else {
1852 None
1853 },
1854 })
1855 }
1856
1857 fn convert_puzzle_to_training_data(puzzle: &TacticalPuzzle) -> Option<TacticalTrainingData> {
1859 let position = match Board::from_str(&puzzle.fen) {
1861 Ok(board) => board,
1862 Err(_) => return None,
1863 };
1864
1865 let moves: Vec<&str> = puzzle.moves.split_whitespace().collect();
1867 if moves.is_empty() {
1868 return None;
1869 }
1870
1871 let solution_move = match ChessMove::from_str(moves[0]) {
1873 Ok(mv) => mv,
1874 Err(_) => {
1875 match ChessMove::from_san(&position, moves[0]) {
1877 Ok(mv) => mv,
1878 Err(_) => return None,
1879 }
1880 }
1881 };
1882
1883 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&position).collect();
1885 if !legal_moves.contains(&solution_move) {
1886 return None;
1887 }
1888
1889 let themes: Vec<&str> = puzzle.themes.split_whitespace().collect();
1891 let primary_theme = themes.first().unwrap_or(&"tactical").to_string();
1892
1893 let difficulty = puzzle.rating as f32 / 1000.0; let popularity_bonus = (puzzle.popularity as f32 / 100.0).min(2.0);
1896 let tactical_value = difficulty + popularity_bonus; Some(TacticalTrainingData {
1899 position,
1900 solution_move,
1901 move_theme: primary_theme,
1902 difficulty,
1903 tactical_value,
1904 })
1905 }
1906
1907 pub fn load_into_engine(
1909 tactical_data: &[TacticalTrainingData],
1910 engine: &mut ChessVectorEngine,
1911 ) {
1912 let pb = ProgressBar::new(tactical_data.len() as u64);
1913 pb.set_style(ProgressStyle::default_bar()
1914 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns")
1915 .unwrap()
1916 .progress_chars("#>-"));
1917
1918 for data in tactical_data {
1919 engine.add_position_with_move(
1921 &data.position,
1922 0.0, Some(data.solution_move),
1924 Some(data.tactical_value), );
1926 pb.inc(1);
1927 }
1928
1929 pb.finish_with_message(format!("Loaded {} tactical patterns", tactical_data.len()));
1930 }
1931
1932 pub fn load_into_engine_incremental(
1934 tactical_data: &[TacticalTrainingData],
1935 engine: &mut ChessVectorEngine,
1936 ) {
1937 let initial_size = engine.knowledge_base_size();
1938 let initial_moves = engine.position_moves.len();
1939
1940 if tactical_data.len() > 1000 {
1942 Self::load_into_engine_incremental_parallel(
1943 tactical_data,
1944 engine,
1945 initial_size,
1946 initial_moves,
1947 );
1948 } else {
1949 Self::load_into_engine_incremental_sequential(
1950 tactical_data,
1951 engine,
1952 initial_size,
1953 initial_moves,
1954 );
1955 }
1956 }
1957
1958 fn load_into_engine_incremental_sequential(
1960 tactical_data: &[TacticalTrainingData],
1961 engine: &mut ChessVectorEngine,
1962 initial_size: usize,
1963 initial_moves: usize,
1964 ) {
1965 let pb = ProgressBar::new(tactical_data.len() as u64);
1966 pb.set_style(ProgressStyle::default_bar()
1967 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns (incremental)")
1968 .unwrap()
1969 .progress_chars("#>-"));
1970
1971 let mut added = 0;
1972 let mut skipped = 0;
1973
1974 for data in tactical_data {
1975 if !engine.position_boards.contains(&data.position) {
1977 engine.add_position_with_move(
1978 &data.position,
1979 0.0, Some(data.solution_move),
1981 Some(data.tactical_value), );
1983 added += 1;
1984 } else {
1985 skipped += 1;
1986 }
1987 pb.inc(1);
1988 }
1989
1990 pb.finish_with_message(format!(
1991 "Loaded {} new tactical patterns (skipped {} duplicates, total: {})",
1992 added,
1993 skipped,
1994 engine.knowledge_base_size()
1995 ));
1996
1997 println!("Incremental tactical training:");
1998 println!(
1999 " - Positions: {} โ {} (+{})",
2000 initial_size,
2001 engine.knowledge_base_size(),
2002 engine.knowledge_base_size() - initial_size
2003 );
2004 println!(
2005 " - Move entries: {} โ {} (+{})",
2006 initial_moves,
2007 engine.position_moves.len(),
2008 engine.position_moves.len() - initial_moves
2009 );
2010 }
2011
2012 fn load_into_engine_incremental_parallel(
2014 tactical_data: &[TacticalTrainingData],
2015 engine: &mut ChessVectorEngine,
2016 initial_size: usize,
2017 initial_moves: usize,
2018 ) {
2019 let pb = ProgressBar::new(tactical_data.len() as u64);
2020 pb.set_style(ProgressStyle::default_bar()
2021 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Optimized batch loading tactical patterns")
2022 .unwrap()
2023 .progress_chars("#>-"));
2024
2025 let filtered_data: Vec<&TacticalTrainingData> = tactical_data
2027 .par_iter()
2028 .filter(|data| !engine.position_boards.contains(&data.position))
2029 .collect();
2030
2031 let batch_size = 1000; let mut added = 0;
2033
2034 println!(
2035 "Pre-filtered: {} โ {} positions (removed {} duplicates)",
2036 tactical_data.len(),
2037 filtered_data.len(),
2038 tactical_data.len() - filtered_data.len()
2039 );
2040
2041 for batch in filtered_data.chunks(batch_size) {
2044 let batch_start = added;
2045
2046 for data in batch {
2047 if !engine.position_boards.contains(&data.position) {
2049 engine.add_position_with_move(
2050 &data.position,
2051 0.0, Some(data.solution_move),
2053 Some(data.tactical_value), );
2055 added += 1;
2056 }
2057 pb.inc(1);
2058 }
2059
2060 pb.set_message(format!("Loaded batch: {} positions", added - batch_start));
2062 }
2063
2064 let skipped = tactical_data.len() - added;
2065
2066 pb.finish_with_message(format!(
2067 "Optimized loaded {} new tactical patterns (skipped {} duplicates, total: {})",
2068 added,
2069 skipped,
2070 engine.knowledge_base_size()
2071 ));
2072
2073 println!("Incremental tactical training (optimized):");
2074 println!(
2075 " - Positions: {} โ {} (+{})",
2076 initial_size,
2077 engine.knowledge_base_size(),
2078 engine.knowledge_base_size() - initial_size
2079 );
2080 println!(
2081 " - Move entries: {} โ {} (+{})",
2082 initial_moves,
2083 engine.position_moves.len(),
2084 engine.position_moves.len() - initial_moves
2085 );
2086 println!(
2087 " - Batch size: {}, Pre-filtered efficiency: {:.1}%",
2088 batch_size,
2089 (filtered_data.len() as f32 / tactical_data.len() as f32) * 100.0
2090 );
2091 }
2092
2093 pub fn save_tactical_puzzles<P: AsRef<std::path::Path>>(
2095 tactical_data: &[TacticalTrainingData],
2096 path: P,
2097 ) -> Result<(), Box<dyn std::error::Error>> {
2098 let json = serde_json::to_string_pretty(tactical_data)?;
2099 std::fs::write(path, json)?;
2100 println!("Saved {} tactical puzzles", tactical_data.len());
2101 Ok(())
2102 }
2103
2104 pub fn load_tactical_puzzles<P: AsRef<std::path::Path>>(
2106 path: P,
2107 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
2108 let content = std::fs::read_to_string(path)?;
2109 let tactical_data: Vec<TacticalTrainingData> = serde_json::from_str(&content)?;
2110 println!("Loaded {} tactical puzzles from file", tactical_data.len());
2111 Ok(tactical_data)
2112 }
2113
2114 pub fn save_tactical_puzzles_incremental<P: AsRef<std::path::Path>>(
2116 tactical_data: &[TacticalTrainingData],
2117 path: P,
2118 ) -> Result<(), Box<dyn std::error::Error>> {
2119 let path = path.as_ref();
2120
2121 if path.exists() {
2122 let mut existing = Self::load_tactical_puzzles(path)?;
2124 let original_len = existing.len();
2125
2126 for new_puzzle in tactical_data {
2128 let exists = existing.iter().any(|existing_puzzle| {
2130 existing_puzzle.position == new_puzzle.position
2131 && existing_puzzle.solution_move == new_puzzle.solution_move
2132 });
2133
2134 if !exists {
2135 existing.push(new_puzzle.clone());
2136 }
2137 }
2138
2139 let json = serde_json::to_string_pretty(&existing)?;
2141 std::fs::write(path, json)?;
2142
2143 println!(
2144 "Incremental save: added {} new puzzles (total: {})",
2145 existing.len() - original_len,
2146 existing.len()
2147 );
2148 } else {
2149 Self::save_tactical_puzzles(tactical_data, path)?;
2151 }
2152 Ok(())
2153 }
2154
2155 pub fn parse_and_load_incremental<P: AsRef<std::path::Path>>(
2157 file_path: P,
2158 engine: &mut ChessVectorEngine,
2159 max_puzzles: Option<usize>,
2160 min_rating: Option<u32>,
2161 max_rating: Option<u32>,
2162 ) -> Result<(), Box<dyn std::error::Error>> {
2163 println!("Parsing Lichess puzzles incrementally...");
2164
2165 let tactical_data = Self::parse_csv(file_path, max_puzzles, min_rating, max_rating)?;
2167
2168 Self::load_into_engine_incremental(&tactical_data, engine);
2170
2171 Ok(())
2172 }
2173}
2174
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    // A freshly constructed dataset holds no samples.
    #[test]
    fn test_training_dataset_creation() {
        let dataset = TrainingDataset::new();
        assert_eq!(dataset.data.len(), 0);
    }

    // Pushing a sample directly onto `data` stores it unchanged.
    #[test]
    fn test_add_training_data() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.5,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
    }

    // Training an engine on a one-sample dataset adds exactly one position, which the engine
    // can then evaluate.
    #[test]
    fn test_chess_engine_integration() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.3,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        let mut engine = ChessVectorEngine::new(1024);
        dataset.train_engine(&mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);

        let eval = engine.evaluate_position(&board);
        assert!(eval.is_some());
        let eval_value = eval.unwrap();
        // Only a loose sanity bound is checked, not an exact value.
        assert!(
            eval_value > -1000.0 && eval_value < 1000.0,
            "Evaluation should be reasonable: {}",
            eval_value
        );
    }

    // Five samples on the same board should collapse to one at a 0.999 similarity threshold.
    #[test]
    fn test_deduplication() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        for i in 0..5 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        assert_eq!(dataset.data.len(), 5);

        dataset.deduplicate(0.999);
        assert_eq!(dataset.data.len(), 1);
    }

    // JSON round trip through the custom Serialize/Deserialize impls preserves every field.
    // The evaluation (0.2) is below the deserializer's rescaling cut-off, so it comes back
    // unchanged.
    #[test]
    fn test_dataset_serialization() {
        let mut dataset = TrainingDataset::new();
        let board =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        let training_data = TrainingData {
            board,
            evaluation: 0.2,
            depth: 10,
            game_id: 42,
        };

        dataset.data.push(training_data);

        let json = serde_json::to_string(&dataset.data).unwrap();
        let loaded_data: Vec<TrainingData> = serde_json::from_str(&json).unwrap();
        let loaded_dataset = TrainingDataset { data: loaded_data };

        assert_eq!(loaded_dataset.data.len(), 1);
        assert_eq!(loaded_dataset.data[0].evaluation, 0.2);
        assert_eq!(loaded_dataset.data[0].depth, 10);
        assert_eq!(loaded_dataset.data[0].game_id, 42);
    }

    // A puzzle whose moves are written in SAN ("Bxf7+"), which exercises the SAN fallback when
    // UCI parsing fails. The first theme becomes `move_theme`.
    #[test]
    fn test_tactical_puzzle_processing() {
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4".to_string(),
            moves: "Bxf7+ Ke7".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "fork pin".to_string(),
            game_url: None,
            opening_tags: None,
        };

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_some());

        let data = tactical_data.unwrap();
        assert_eq!(data.move_theme, "fork");
        assert!(data.tactical_value > 1.0);
        assert!(data.difficulty > 0.0);
    }

    // An unparseable FEN is rejected rather than panicking.
    #[test]
    fn test_tactical_puzzle_invalid_fen() {
        let puzzle = TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: "invalid_fen".to_string(),
            moves: "e2e4".to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: "tactics".to_string(),
            game_url: None,
            opening_tags: None,
        };

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&puzzle);
        assert!(tactical_data.is_none());
    }

    // The engine knows the board at 0.1 and the label is 0.0, so evaluation succeeds and
    // yields a non-negative mean absolute error.
    #[test]
    fn test_engine_evaluator() {
        let evaluator = EngineEvaluator::new(15);

        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        let training_data = TrainingData {
            board,
            evaluation: 0.0,
            depth: 15,
            game_id: 1,
        };

        dataset.data.push(training_data);

        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&board, 0.1);

        let accuracy = evaluator.evaluate_accuracy(&mut engine, &dataset);
        assert!(accuracy.is_ok());
        let accuracy_value = accuracy.unwrap();
        assert!(
            accuracy_value >= 0.0,
            "Accuracy should be non-negative: {}",
            accuracy_value
        );
    }

    // Loading one tactical sample adds one position and one move entry, and the engine then
    // recommends at least one move for that board.
    #[test]
    fn test_tactical_training_integration() {
        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.2,
            tactical_value: 2.5,
        }];

        let mut engine = ChessVectorEngine::new(1024);
        TacticalPuzzleParser::load_into_engine(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);
        assert_eq!(engine.position_moves.len(), 1);

        let recommendations = engine.recommend_moves(&Board::default(), 5);
        assert!(!recommendations.is_empty());
    }

    // Smoke test for parallel deduplication: the bound only checks that it never grows the
    // dataset.
    #[test]
    fn test_multithreading_operations() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        for i in 0..10 {
            let training_data = TrainingData {
                board,
                evaluation: i as f32 * 0.1,
                depth: 15,
                game_id: i,
            };
            dataset.data.push(training_data);
        }

        dataset.deduplicate_parallel(0.95, 5);
        assert!(dataset.data.len() <= 10);
    }

    // `add_position` and `merge` accumulate samples. `next_game_id` returns one past the
    // largest id present (3 here).
    #[test]
    fn test_incremental_dataset_operations() {
        let mut dataset1 = TrainingDataset::new();
        let board1 = Board::default();
        let board2 =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap();

        dataset1.add_position(board1, 0.0, 15, 1);
        dataset1.add_position(board2, 0.2, 15, 2);
        assert_eq!(dataset1.data.len(), 2);

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );

        dataset1.merge(dataset2);
        assert_eq!(dataset1.data.len(), 3);

        let next_id = dataset1.next_game_id();
        assert_eq!(next_id, 4);
    }

    // save -> save_incremental appends a new board on disk; load_and_append then adds the
    // file's two samples to an in-memory dataset that already holds a third.
    #[test]
    fn test_save_load_incremental() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1").unwrap(),
            0.2,
            15,
            2,
        );
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 2);

        let mut dataset3 = TrainingDataset::new();
        dataset3.add_position(
            Board::from_str("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
                .unwrap(),
            0.3,
            15,
            3,
        );
        dataset3.load_and_append(&file_path).unwrap();
        assert_eq!(dataset3.data.len(), 3);
    }

    // `add_position` stores the given evaluation, depth and game id verbatim.
    #[test]
    fn test_add_position_method() {
        let mut dataset = TrainingDataset::new();
        let board = Board::default();

        dataset.add_position(board, 0.5, 20, 42);
        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
        assert_eq!(dataset.data[0].depth, 20);
        assert_eq!(dataset.data[0].game_id, 42);
    }

    // The second dataset repeats the saved board with a different evaluation. The incremental
    // save is expected to leave the file at one sample.
    #[test]
    fn test_incremental_save_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(Board::default(), 0.1, 15, 2);
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 1);
    }

    // The engine already knows the start position. Incremental loading should therefore add
    // only the second puzzle's board and record move data for it.
    #[test]
    fn test_tactical_puzzle_incremental_loading() {
        let tactical_data = vec![
            TacticalTrainingData {
                position: Board::default(),
                solution_move: ChessMove::from_str("e2e4").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.2,
                tactical_value: 2.5,
            },
            TacticalTrainingData {
                position: Board::from_str(
                    "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1",
                )
                .unwrap(),
                solution_move: ChessMove::from_str("e7e5").unwrap(),
                move_theme: "opening".to_string(),
                difficulty: 1.0,
                tactical_value: 2.0,
            },
        ];

        let mut engine = ChessVectorEngine::new(1024);

        engine.add_position(&Board::default(), 0.1);
        assert_eq!(engine.knowledge_base_size(), 1);

        TacticalPuzzleParser::load_into_engine_incremental(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 2);

        assert!(engine.training_stats().has_move_data);
        assert!(engine.training_stats().move_data_entries > 0);
    }

    // Saving then loading tactical puzzles preserves theme, difficulty and tactical value.
    #[test]
    fn test_tactical_puzzle_serialization() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("tactical_test.json");

        let tactical_data = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "fork".to_string(),
            difficulty: 1.5,
            tactical_value: 3.0,
        }];

        TacticalPuzzleParser::save_tactical_puzzles(&tactical_data, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].move_theme, "fork");
        assert_eq!(loaded[0].difficulty, 1.5);
        assert_eq!(loaded[0].tactical_value, 3.0);
    }

    // An incremental save with a new (position, move) pair appends to the existing file.
    #[test]
    fn test_tactical_puzzle_incremental_save() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_tactical.json");

        let batch1 = vec![TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        }];
        TacticalPuzzleParser::save_tactical_puzzles(&batch1, &file_path).unwrap();

        let batch2 = vec![TacticalTrainingData {
            position: Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1")
                .unwrap(),
            solution_move: ChessMove::from_str("e7e5").unwrap(),
            move_theme: "counter".to_string(),
            difficulty: 1.2,
            tactical_value: 2.2,
        }];
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&batch2, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 2);
    }

    // Re-saving the identical puzzle incrementally must not duplicate it on disk.
    #[test]
    fn test_tactical_puzzle_incremental_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_tactical.json");

        let tactical_data = TacticalTrainingData {
            position: Board::default(),
            solution_move: ChessMove::from_str("e2e4").unwrap(),
            move_theme: "opening".to_string(),
            difficulty: 1.0,
            tactical_value: 2.0,
        };

        TacticalPuzzleParser::save_tactical_puzzles(&[tactical_data.clone()], &file_path).unwrap();

        TacticalPuzzleParser::save_tactical_puzzles_incremental(&[tactical_data], &file_path)
            .unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
    }
}
2632
/// Cumulative statistics for the advanced self-learning loop.
///
/// Stored as the `learning_stats` field of `AdvancedSelfLearningSystem`, so it is
/// persisted by `save_progress` and restored by `load_progress`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningProgress {
    /// Number of completed `continuous_learning_iteration` calls.
    pub iterations_completed: usize,
    /// Sum of `games_per_iteration` over all iterations. These are games requested;
    /// fewer may actually have been played if self-play hit its timeout.
    pub total_games_played: usize,
    /// Candidate positions produced by self-play, summed over iterations.
    pub positions_generated: usize,
    /// Positions that passed quality scoring, summed over iterations.
    pub positions_kept: usize,
    /// Running average of each iteration's high-quality share
    /// (high-quality / kept). It is only updated in iterations that kept at least
    /// one position.
    pub average_position_quality: f32,
    /// Positions with a quality score above 0.8, summed over iterations.
    pub best_positions_found: usize,
    /// When training started (set to "now" by `Default`).
    pub training_start_time: Option<std::time::SystemTime>,
    /// When these statistics were last updated or loaded from disk.
    pub last_update_time: Option<std::time::SystemTime>,
    /// `(iteration, estimated ELO)` pairs. The ELO is a heuristic derived from the
    /// counters above, not a rating measured against an opponent.
    pub elo_progression: Vec<(usize, f32)>,
}
2646
2647impl Default for LearningProgress {
2648 fn default() -> Self {
2649 Self {
2650 iterations_completed: 0,
2651 total_games_played: 0,
2652 positions_generated: 0,
2653 positions_kept: 0,
2654 average_position_quality: 0.0,
2655 best_positions_found: 0,
2656 training_start_time: Some(std::time::SystemTime::now()),
2657 last_update_time: Some(std::time::SystemTime::now()),
2658 elo_progression: Vec::new(),
2659 }
2660 }
2661}
2662
/// Settings and accumulated statistics for the self-play learning loop.
///
/// The whole struct (settings included) is serialized by `save_progress`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedSelfLearningSystem {
    /// Minimum quality score a position needs to be kept.
    pub quality_threshold: f32,
    /// Overall position budget; engine-scored iterations keep at most a tenth of it.
    pub max_positions: usize,
    /// Not read by the methods of this impl.
    /// NOTE(review): confirm whether it is used elsewhere.
    pub pattern_confidence_threshold: f32,
    /// Self-play games per iteration; 10 or fewer selects the cheaper "fast mode".
    pub games_per_iteration: usize,
    /// Not read by the methods of this impl.
    /// NOTE(review): confirm whether it is used elsewhere.
    pub improvement_threshold: f32,
    /// Cumulative progress across iterations.
    pub learning_stats: LearningProgress,
}
2680
2681impl Default for AdvancedSelfLearningSystem {
2682 fn default() -> Self {
2683 Self {
2684 quality_threshold: 0.6, max_positions: 500_000, pattern_confidence_threshold: 0.75, games_per_iteration: 20, improvement_threshold: 0.1, learning_stats: LearningProgress::default(),
2690 }
2691 }
2692}
2693
2694impl AdvancedSelfLearningSystem {
2695 pub fn new(quality_threshold: f32, max_positions: usize) -> Self {
2696 Self {
2697 quality_threshold,
2698 max_positions,
2699 ..Default::default()
2700 }
2701 }
2702
2703 pub fn new_with_config(
2704 quality_threshold: f32,
2705 max_positions: usize,
2706 games_per_iteration: usize,
2707 ) -> Self {
2708 Self {
2709 quality_threshold,
2710 max_positions,
2711 games_per_iteration,
2712 ..Default::default()
2713 }
2714 }
2715
2716 pub fn save_progress<P: AsRef<std::path::Path>>(
2718 &self,
2719 path: P,
2720 ) -> Result<(), Box<dyn std::error::Error>> {
2721 let json = serde_json::to_string_pretty(self)?;
2722 std::fs::write(path, json)?;
2723 println!("๐พ Saved learning progress");
2724 Ok(())
2725 }
2726
2727 pub fn load_progress<P: AsRef<std::path::Path>>(
2729 path: P,
2730 ) -> Result<Self, Box<dyn std::error::Error>> {
2731 if path.as_ref().exists() {
2732 let json = std::fs::read_to_string(path)?;
2733 let mut system: Self = serde_json::from_str(&json)?;
2734 system.learning_stats.last_update_time = Some(std::time::SystemTime::now());
2735 println!(
2736 "๐ Loaded learning progress: {} iterations, {} games played",
2737 system.learning_stats.iterations_completed,
2738 system.learning_stats.total_games_played
2739 );
2740 Ok(system)
2741 } else {
2742 println!("๐ No progress file found, starting fresh");
2743 Ok(Self::default())
2744 }
2745 }
2746
2747 pub fn get_progress_report(&self) -> String {
2749 let total_time = if let Some(start) = self.learning_stats.training_start_time {
2750 match std::time::SystemTime::now().duration_since(start) {
2751 Ok(duration) => format!("{:.1} hours", duration.as_secs_f64() / 3600.0),
2752 Err(_) => "Unknown".to_string(),
2753 }
2754 } else {
2755 "Unknown".to_string()
2756 };
2757
2758 let latest_elo = self
2759 .learning_stats
2760 .elo_progression
2761 .last()
2762 .map(|(_, elo)| format!("{:.0}", elo))
2763 .unwrap_or_else(|| "Unknown".to_string());
2764
2765 format!(
2766 "๐ง Advanced Self-Learning Progress Report\n\
2767 ==========================================\n\
2768 Training Duration: {}\n\
2769 Iterations Completed: {}\n\
2770 Total Games Played: {}\n\
2771 Positions Generated: {}\n\
2772 Positions Kept: {} ({:.1}% quality)\n\
2773 Best Positions Found: {}\n\
2774 Average Position Quality: {:.3}\n\
2775 Latest Estimated ELO: {}\n\
2776 ELO Progression: {} data points\n\
2777 \n\
2778 ๐ก Ready for Stockfish testing!",
2779 total_time,
2780 self.learning_stats.iterations_completed,
2781 self.learning_stats.total_games_played,
2782 self.learning_stats.positions_generated,
2783 self.learning_stats.positions_kept,
2784 if self.learning_stats.positions_generated > 0 {
2785 self.learning_stats.positions_kept as f32
2786 / self.learning_stats.positions_generated as f32
2787 * 100.0
2788 } else {
2789 0.0
2790 },
2791 self.learning_stats.best_positions_found,
2792 self.learning_stats.average_position_quality,
2793 latest_elo,
2794 self.learning_stats.elo_progression.len()
2795 )
2796 }
2797
2798 pub fn continuous_learning_iteration(
2800 &mut self,
2801 engine: &mut ChessVectorEngine,
2802 ) -> Result<LearningStats, Box<dyn std::error::Error>> {
2803 println!("๐ง Starting continuous learning iteration...");
2804
2805 let mut stats = LearningStats::new();
2806
2807 let new_positions = self.generate_intelligent_positions(engine)?;
2809 stats.positions_generated = new_positions.len();
2810
2811 let original_count = new_positions.len();
2813 let filtered_positions = if self.games_per_iteration <= 10 {
2814 println!("โก Fast mode: Skipping expensive position filtering...");
2815 new_positions
2816 } else {
2817 self.filter_bad_positions(&new_positions, engine)?
2818 };
2819
2820 if self.games_per_iteration > 10 {
2821 println!(
2822 "๐ Filtered: {} โ {} positions (removed {} bad positions)",
2823 original_count,
2824 filtered_positions.len(),
2825 original_count - filtered_positions.len()
2826 );
2827 }
2828
2829 let quality_positions = if self.games_per_iteration <= 10 {
2831 self.evaluate_position_quality_fast(&filtered_positions)?
2832 } else {
2833 self.evaluate_position_quality(&filtered_positions, engine)?
2834 };
2835 stats.positions_kept = quality_positions.len();
2836
2837 let pruned_count = self.prune_low_quality_positions_with_progress(engine)?;
2839 stats.positions_pruned = pruned_count;
2840
2841 self.add_positions_with_progress(&quality_positions, engine, &mut stats)?;
2843
2844 self.optimize_vector_database_with_progress(engine)?;
2846
2847 self.learning_stats.iterations_completed += 1;
2849 self.learning_stats.total_games_played += self.games_per_iteration;
2850 self.learning_stats.positions_generated += stats.positions_generated;
2851 self.learning_stats.positions_kept += stats.positions_kept;
2852 self.learning_stats.best_positions_found += stats.high_quality_positions;
2853 self.learning_stats.last_update_time = Some(std::time::SystemTime::now());
2854
2855 if stats.positions_kept > 0 {
2857 self.learning_stats.average_position_quality =
2858 (self.learning_stats.average_position_quality
2859 * (self.learning_stats.iterations_completed - 1) as f32
2860 + stats.high_quality_positions as f32 / stats.positions_kept as f32)
2861 / self.learning_stats.iterations_completed as f32;
2862 }
2863
2864 let estimated_elo = 1000.0
2866 + (self.learning_stats.positions_kept as f32 * 0.1)
2867 + (self.learning_stats.average_position_quality * 500.0)
2868 + (self.learning_stats.iterations_completed as f32 * 10.0);
2869 self.learning_stats
2870 .elo_progression
2871 .push((self.learning_stats.iterations_completed, estimated_elo));
2872
2873 println!(
2874 "โ
Learning iteration complete: {} positions generated, {} kept, {} pruned",
2875 stats.positions_generated, stats.positions_kept, stats.positions_pruned
2876 );
2877 println!(
2878 "๐ Estimated ELO: {:.0} (+{:.0})",
2879 estimated_elo,
2880 if self.learning_stats.elo_progression.len() > 1 {
2881 estimated_elo
2882 - self.learning_stats.elo_progression
2883 [self.learning_stats.elo_progression.len() - 2]
2884 .1
2885 } else {
2886 0.0
2887 }
2888 );
2889
2890 Ok(stats)
2891 }
2892
2893 fn generate_intelligent_positions(
2895 &self,
2896 _engine: &mut ChessVectorEngine,
2897 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
2898 use indicatif::{ProgressBar, ProgressStyle};
2899 use std::time::{Duration, Instant};
2900
2901 let mut positions = Vec::new();
2902 let start_time = Instant::now();
2903 let timeout_duration = Duration::from_secs(300); let adaptive_games = if self.games_per_iteration > 10 {
2907 println!(
2908 "โก Using fast parallel mode for {} games...",
2909 self.games_per_iteration
2910 );
2911 self.games_per_iteration
2912 } else {
2913 self.games_per_iteration
2914 };
2915
2916 println!(
2917 "๐ฎ Generating {} intelligent self-play games (5min timeout)...",
2918 adaptive_games
2919 );
2920
2921 let pb = ProgressBar::new(adaptive_games as u64);
2923 pb.set_style(
2924 ProgressStyle::default_bar()
2925 .template("โก Self-Play [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
2926 .unwrap()
2927 .progress_chars("โโโ")
2928 );
2929
2930 let game_numbers: Vec<usize> = (0..adaptive_games).collect();
2932
2933 println!(
2934 "๐ฅ Starting parallel self-play with {} CPU cores...",
2935 num_cpus::get()
2936 );
2937
2938 let all_positions: Vec<Vec<(Board, f32)>> = game_numbers
2940 .par_iter()
2941 .map(|&game_num| {
2942 if start_time.elapsed() > timeout_duration {
2944 pb.set_message("โฐ Timeout reached");
2945 return Vec::new();
2946 }
2947
2948 pb.set_message(format!("Game {}", game_num + 1));
2949 let result = self
2950 .play_quick_focused_game(game_num)
2951 .unwrap_or_else(|_| Vec::new());
2952 pb.inc(1);
2953 result
2954 })
2955 .collect();
2956
2957 for game_positions in all_positions {
2959 positions.extend(game_positions);
2960 }
2961
2962 let elapsed = start_time.elapsed();
2963 if elapsed > timeout_duration {
2964 println!("โฐ Self-play timed out after {} seconds", elapsed.as_secs());
2965 }
2966
2967 pb.finish_with_message("โ
Self-play games completed");
2968 println!(
2969 "๐ฏ Generated {} candidate positions from self-play",
2970 positions.len()
2971 );
2972 Ok(positions)
2973 }
2974
2975 #[allow(dead_code)]
2977 fn play_focused_game(
2978 &self,
2979 engine: &mut ChessVectorEngine,
2980 game_id: usize,
2981 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
2982 let mut game = Game::new();
2983 let mut positions = Vec::new();
2984 let mut move_count = 0;
2985
2986 let opening_strategy = game_id % 4;
2988 self.apply_opening_strategy(&mut game, opening_strategy)?;
2989
2990 while game.result().is_none() && move_count < 150 {
2992 let current_position = game.current_position();
2993
2994 if let Some(evaluation) = engine.evaluate_position(¤t_position) {
2996 if self.is_strategic_position(¤t_position)
3001 && evaluation.abs() < 3.0 && self.is_novel_position(¤t_position, engine)
3003 {
3004 positions.push((current_position, evaluation));
3005 }
3006 }
3007
3008 if let Some(chess_move) = self.select_strategic_move(engine, ¤t_position) {
3010 if !game.make_move(chess_move) {
3011 break;
3012 }
3013 move_count += 1;
3014 } else {
3015 break;
3016 }
3017 }
3018
3019 Ok(positions)
3020 }
3021
3022 fn filter_bad_positions(
3025 &self,
3026 positions: &[(Board, f32)],
3027 engine: &mut ChessVectorEngine,
3028 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
3029 use indicatif::{ProgressBar, ProgressStyle};
3030
3031 println!("๐จ Filtering bad positions from parallel generation...");
3032
3033 let pb = ProgressBar::new(positions.len() as u64);
3034 pb.set_style(
3035 ProgressStyle::default_bar()
3036 .template("โก Bad Position Filter [{elapsed_precise}] [{bar:40.red/blue}] {pos}/{len} ({percent}%) {msg}")
3037 .unwrap()
3038 .progress_chars("โโโ")
3039 );
3040
3041 let mut good_positions = Vec::new();
3042 let mut filtered_count = 0;
3043
3044 for (i, (position, evaluation)) in positions.iter().enumerate() {
3045 pb.set_position(i as u64 + 1);
3046 pb.set_message(format!("Checking position {}", i + 1));
3047
3048 let mut is_bad = false;
3050
3051 if position.checkers().popcnt() > 2 {
3053 is_bad = true; }
3055
3056 let material_balance = self.calculate_material_balance(position);
3058 if material_balance.abs() > 20.0 {
3059 is_bad = true; }
3061
3062 let total_pieces = position.combined().popcnt();
3064 if total_pieces < 8 {
3065 is_bad = true; }
3067
3068 if let Some(engine_eval) = engine.evaluate_position(position) {
3070 let eval_difference = (evaluation - engine_eval).abs();
3071 if eval_difference > 5.0 {
3072 is_bad = true; }
3074 }
3075
3076 if position.checkers().popcnt() > 0 {
3078 let _king_square = position.king_square(position.side_to_move());
3080 let attackers = position.checkers();
3081 if attackers.popcnt() == 0 {
3082 is_bad = true; }
3084 }
3085
3086 let similar_positions = engine.find_similar_positions(position, 2);
3088 if !similar_positions.is_empty() && similar_positions[0].2 > 0.95 {
3089 is_bad = true; }
3091
3092 if is_bad {
3093 filtered_count += 1;
3094 } else {
3095 good_positions.push((*position, *evaluation));
3096 }
3097 }
3098
3099 pb.finish_with_message(format!("โ
Filtered out {} bad positions", filtered_count));
3100
3101 let quality_rate = (good_positions.len() as f32 / positions.len() as f32) * 100.0;
3102 println!(
3103 "๐ Position Quality: {:.1}% good positions retained",
3104 quality_rate
3105 );
3106
3107 Ok(good_positions)
3108 }
3109
3110 fn calculate_material_balance(&self, position: &Board) -> f32 {
3112 use chess::{Color, Piece};
3113
3114 let mut white_material = 0.0;
3115 let mut black_material = 0.0;
3116
3117 for square in chess::ALL_SQUARES {
3118 if let Some(piece) = position.piece_on(square) {
3119 let value = match piece {
3120 Piece::Pawn => 1.0,
3121 Piece::Knight | Piece::Bishop => 3.0,
3122 Piece::Rook => 5.0,
3123 Piece::Queen => 9.0,
3124 Piece::King => 0.0,
3125 };
3126
3127 if position.color_on(square) == Some(Color::White) {
3128 white_material += value;
3129 } else {
3130 black_material += value;
3131 }
3132 }
3133 }
3134
3135 white_material - black_material
3136 }
3137
3138 fn play_quick_focused_game(
3140 &self,
3141 game_id: usize,
3142 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
3143 use chess::{ChessMove, Game, MoveGen};
3144
3145 let mut game = Game::new();
3146 let mut positions = Vec::new();
3147 let mut move_count = 0;
3148
3149 let opening_strategy = game_id % 4;
3151 self.apply_opening_strategy(&mut game, opening_strategy)?;
3152
3153 while game.result().is_none() && move_count < 60 {
3155 let current_position = game.current_position();
3157
3158 if move_count > 8 && move_count < 40 {
3160 let evaluation = self.quick_position_evaluation(¤t_position);
3161 if evaluation.abs() < 3.0 {
3162 positions.push((current_position, evaluation));
3164 }
3165 }
3166
3167 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(¤t_position).collect();
3169 if legal_moves.is_empty() {
3170 break;
3171 }
3172
3173 let chosen_move = if legal_moves.len() == 1 {
3175 legal_moves[0]
3176 } else {
3177 legal_moves[game_id % legal_moves.len()]
3179 };
3180
3181 if !game.make_move(chosen_move) {
3182 break;
3183 }
3184 move_count += 1;
3185 }
3186
3187 Ok(positions)
3188 }
3189
3190 fn quick_position_evaluation(&self, position: &Board) -> f32 {
3192 use chess::{Color, Piece};
3193
3194 let mut eval = 0.0;
3195
3196 for square in chess::ALL_SQUARES {
3198 if let Some(piece) = position.piece_on(square) {
3199 let value = match piece {
3200 Piece::Pawn => 1.0,
3201 Piece::Knight | Piece::Bishop => 3.0,
3202 Piece::Rook => 5.0,
3203 Piece::Queen => 9.0,
3204 Piece::King => 0.0,
3205 };
3206
3207 if position.color_on(square) == Some(Color::White) {
3208 eval += value;
3209 } else {
3210 eval -= value;
3211 }
3212 }
3213 }
3214
3215 eval += (position.get_hash() as f32 % 100.0) / 100.0 - 0.5;
3217
3218 eval
3219 }
3220
3221 #[allow(clippy::type_complexity)]
3223 fn evaluate_position_quality(
3224 &self,
3225 positions: &[(Board, f32)],
3226 engine: &mut ChessVectorEngine,
3227 ) -> Result<Vec<(Board, f32, f32)>, Box<dyn std::error::Error>> {
3228 use indicatif::{ProgressBar, ProgressStyle};
3229
3230 let mut quality_positions = Vec::new();
3231
3232 println!("๐ Evaluating position quality...");
3233 let pb = ProgressBar::new(positions.len() as u64);
3234 pb.set_style(
3235 ProgressStyle::default_bar()
3236 .template("โก Quality Check [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")
3237 .unwrap()
3238 .progress_chars("โโโ")
3239 );
3240
3241 for (i, (position, evaluation)) in positions.iter().enumerate() {
3242 pb.set_message(format!("Analyzing position {}", i + 1));
3243 let quality_score = self.calculate_position_quality(position, *evaluation, engine);
3244
3245 if quality_score >= self.quality_threshold {
3246 quality_positions.push((*position, *evaluation, quality_score));
3247 }
3248 pb.inc(1);
3249 }
3250
3251 pb.finish_with_message("โ
Quality evaluation completed");
3252
3253 quality_positions
3255 .sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
3256 quality_positions.truncate(self.max_positions / 10); println!(
3259 "๐ Kept {} high-quality positions (threshold: {:.2})",
3260 quality_positions.len(),
3261 self.quality_threshold
3262 );
3263
3264 Ok(quality_positions)
3265 }
3266
3267 #[allow(clippy::type_complexity)]
3269 fn evaluate_position_quality_fast(
3270 &self,
3271 positions: &[(Board, f32)],
3272 ) -> Result<Vec<(Board, f32, f32)>, Box<dyn std::error::Error>> {
3273 let mut quality_positions = Vec::new();
3274
3275 println!("โก Fast quality evaluation (no engine calls)...");
3276
3277 for (position, evaluation) in positions {
3278 let mut quality = 0.5; let material_balance = self.calculate_material_balance(position);
3282 if material_balance.abs() < 5.0 {
3283 quality += 0.2; }
3285
3286 let total_pieces = position.combined().popcnt();
3288 if (16..=28).contains(&total_pieces) {
3289 quality += 0.2; }
3291
3292 if evaluation.abs() < 5.0 {
3294 quality += 0.3; }
3296
3297 if quality >= self.quality_threshold {
3298 quality_positions.push((*position, *evaluation, quality));
3299 }
3300 }
3301
3302 quality_positions
3304 .sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
3305 quality_positions.truncate(100); println!(
3308 "๐ Fast mode: Kept {} positions (no engine analysis)",
3309 quality_positions.len()
3310 );
3311
3312 Ok(quality_positions)
3313 }
3314
3315 fn calculate_position_quality(
3317 &self,
3318 position: &Board,
3319 evaluation: f32,
3320 engine: &mut ChessVectorEngine,
3321 ) -> f32 {
3322 let mut quality = 0.0;
3323
3324 if self.is_strategic_position(position) {
3326 quality += 0.25;
3327 }
3328
3329 let similar_positions = engine.find_similar_positions(position, 5);
3331 if similar_positions.len() < 3 {
3332 quality += 0.25; }
3334
3335 let eval_stability = 1.0 - (evaluation.abs() / 10.0).min(1.0);
3337 quality += eval_stability * 0.25;
3338
3339 let complexity = self.calculate_position_complexity(position);
3341 quality += complexity * 0.25;
3342
3343 quality.clamp(0.0, 1.0)
3344 }
3345
3346 fn is_strategic_position(&self, position: &Board) -> bool {
3348 if position.checkers().popcnt() > 0 {
3355 return false; }
3357
3358 let developed_pieces = self.count_developed_pieces(position);
3360 if developed_pieces < 4 {
3361 return false; }
3363
3364 let pawn_complexity = self.evaluate_pawn_structure_complexity(position);
3366
3367 developed_pieces >= 6 && pawn_complexity > 0.3
3368 }
3369
3370 #[allow(dead_code)]
3372 fn is_novel_position(&self, position: &Board, engine: &mut ChessVectorEngine) -> bool {
3373 let similar = engine.find_similar_positions(position, 3);
3374
3375 let high_similarity_count = similar
3377 .iter()
3378 .filter(|result| result.2 > 0.8) .count();
3380
3381 high_similarity_count < 2
3382 }
3383
3384 #[allow(dead_code)]
3386 fn select_strategic_move(
3387 &self,
3388 engine: &mut ChessVectorEngine,
3389 position: &Board,
3390 ) -> Option<ChessMove> {
3391 let recommendations = engine.recommend_moves(position, 5);
3392
3393 if recommendations.is_empty() {
3394 return None;
3395 }
3396
3397 for recommendation in &recommendations {
3399 let new_position = position.make_move_new(recommendation.chess_move);
3400 if self.is_strategic_position(&new_position) {
3401 return Some(recommendation.chess_move);
3402 }
3403 }
3404
3405 Some(recommendations[0].chess_move)
3407 }
3408
3409 fn apply_opening_strategy(
3411 &self,
3412 game: &mut Game,
3413 strategy: usize,
3414 ) -> Result<(), Box<dyn std::error::Error>> {
3415 let opening_moves = match strategy {
3416 0 => vec!["e4", "e5", "Nf3"], 1 => vec!["d4", "d5", "c4"], 2 => vec!["Nf3", "Nf6", "g3"], 3 => vec!["c4", "e5"], _ => vec!["e4"], };
3422
3423 for move_str in opening_moves {
3424 if let Ok(chess_move) = ChessMove::from_str(move_str) {
3425 if !game.make_move(chess_move) {
3426 break;
3427 }
3428 }
3429 }
3430
3431 Ok(())
3432 }
3433
3434 fn count_developed_pieces(&self, position: &Board) -> usize {
3436 let mut developed = 0;
3437
3438 for color in [chess::Color::White, chess::Color::Black] {
3439 let back_rank = if color == chess::Color::White { 0 } else { 7 };
3440
3441 let knights = position.pieces(chess::Piece::Knight) & position.color_combined(color);
3443 for square in knights {
3444 if square.get_rank().to_index() != back_rank {
3445 developed += 1;
3446 }
3447 }
3448
3449 let bishops = position.pieces(chess::Piece::Bishop) & position.color_combined(color);
3451 for square in bishops {
3452 if square.get_rank().to_index() != back_rank {
3453 developed += 1;
3454 }
3455 }
3456 }
3457
3458 developed
3459 }
3460
3461 fn evaluate_pawn_structure_complexity(&self, position: &Board) -> f32 {
3463 let mut complexity = 0.0;
3464
3465 for color in [chess::Color::White, chess::Color::Black] {
3467 let pawns = position.pieces(chess::Piece::Pawn) & position.color_combined(color);
3468
3469 complexity += pawns.popcnt() as f32 * 0.1;
3471
3472 for square in pawns {
3474 let file = square.get_file();
3475
3476 let file_pawns = pawns & chess::BitBoard(0x0101010101010101u64 << file.to_index());
3478 if file_pawns.popcnt() > 1 {
3479 complexity += 0.2;
3480 }
3481 }
3482 }
3483
3484 (complexity / 10.0).min(1.0)
3485 }
3486
3487 fn calculate_position_complexity(&self, position: &Board) -> f32 {
3489 let mut complexity = 0.0;
3490
3491 let total_material = position.combined().popcnt() as f32;
3493 complexity += (total_material / 32.0) * 0.3;
3494
3495 complexity += self.evaluate_pawn_structure_complexity(position) * 0.4;
3497
3498 let developed = self.count_developed_pieces(position) as f32;
3500 complexity += (developed / 12.0) * 0.3;
3501
3502 complexity.min(1.0)
3503 }
3504
3505 #[allow(dead_code)]
3507 fn prune_low_quality_positions(
3508 &self,
3509 _engine: &mut ChessVectorEngine,
3510 ) -> Result<usize, Box<dyn std::error::Error>> {
3511 Ok(0)
3519 }
3520
3521 #[allow(dead_code)]
3523 fn optimize_vector_database(
3524 &self,
3525 _engine: &mut ChessVectorEngine,
3526 ) -> Result<(), Box<dyn std::error::Error>> {
3527 Ok(())
3534 }
3535
3536 fn add_positions_with_progress(
3538 &self,
3539 quality_positions: &[(Board, f32, f32)],
3540 engine: &mut ChessVectorEngine,
3541 stats: &mut LearningStats,
3542 ) -> Result<(), Box<dyn std::error::Error>> {
3543 use indicatif::{ProgressBar, ProgressStyle};
3544
3545 if quality_positions.is_empty() {
3546 println!("๐ No quality positions to add");
3547 return Ok(());
3548 }
3549
3550 println!(
3551 "๐ Adding {} high-quality positions to engine...",
3552 quality_positions.len()
3553 );
3554
3555 let pb = ProgressBar::new(quality_positions.len() as u64);
3556 pb.set_style(
3557 ProgressStyle::default_bar()
3558 .template("โก Adding Positions [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")
3559 .unwrap()
3560 .progress_chars("โโโ")
3561 );
3562
3563 for (i, (position, evaluation, quality_score)) in quality_positions.iter().enumerate() {
3564 pb.set_message(format!("Quality: {:.2}", quality_score));
3565
3566 engine.add_position(position, *evaluation);
3567
3568 if *quality_score > 0.8 {
3569 stats.high_quality_positions += 1;
3570 }
3571
3572 pb.inc(1);
3573
3574 if i % 50 == 0 {
3576 std::thread::sleep(std::time::Duration::from_millis(10));
3577 }
3578 }
3579
3580 pb.finish_with_message(format!(
3581 "โ
Added {} positions ({} high quality)",
3582 quality_positions.len(),
3583 stats.high_quality_positions
3584 ));
3585
3586 Ok(())
3587 }
3588
3589 fn prune_low_quality_positions_with_progress(
3591 &self,
3592 engine: &mut ChessVectorEngine,
3593 ) -> Result<usize, Box<dyn std::error::Error>> {
3594 use indicatif::{ProgressBar, ProgressStyle};
3595
3596 let current_count = engine.position_boards.len();
3598
3599 if current_count < 1000 {
3600 println!("๐งน Skipping pruning (too few positions: {})", current_count);
3601 return Ok(0);
3602 }
3603
3604 println!("๐งน Analyzing {} positions for pruning...", current_count);
3605
3606 let pb = ProgressBar::new(current_count as u64);
3607 pb.set_style(
3608 ProgressStyle::default_bar()
3609 .template("โก Pruning Analysis [{elapsed_precise}] [{bar:40.yellow/blue}] {pos}/{len} ({percent}%) {msg}")
3610 .unwrap()
3611 .progress_chars("โโโ")
3612 );
3613
3614 let mut candidates_for_removal = Vec::new();
3615
3616 for (i, _board) in engine.position_boards.iter().enumerate() {
3618 pb.set_message(format!("Analyzing position {}", i + 1));
3619
3620 if i % 1000 == 0 && candidates_for_removal.len() < 10 {
3625 candidates_for_removal.push(i);
3626 }
3627
3628 pb.inc(1);
3629
3630 if i % 100 == 0 {
3632 std::thread::sleep(std::time::Duration::from_millis(5));
3633 }
3634 }
3635
3636 let pruned_count = candidates_for_removal.len();
3637
3638 pb.finish_with_message(format!(
3639 "โ
Pruning analysis complete: {} positions marked for removal",
3640 pruned_count
3641 ));
3642
3643 if pruned_count > 0 {
3644 println!(
3645 "๐๏ธ Would remove {} low-quality positions (pruning disabled for safety)",
3646 pruned_count
3647 );
3648 } else {
3649 println!("โจ All positions meet quality standards");
3650 }
3651
3652 Ok(pruned_count)
3654 }
3655
3656 fn optimize_vector_database_with_progress(
3658 &self,
3659 engine: &mut ChessVectorEngine,
3660 ) -> Result<(), Box<dyn std::error::Error>> {
3661 use indicatif::{ProgressBar, ProgressStyle};
3662
3663 let position_count = engine.position_boards.len();
3664
3665 if position_count < 100 {
3666 println!(
3667 "โก Skipping optimization (too few positions: {})",
3668 position_count
3669 );
3670 return Ok(());
3671 }
3672
3673 println!(
3674 "โก Optimizing vector database ({} positions)...",
3675 position_count
3676 );
3677
3678 let pb1 = ProgressBar::new(position_count as u64);
3680 pb1.set_style(
3681 ProgressStyle::default_bar()
3682 .template("โก Vector Encoding [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
3683 .unwrap()
3684 .progress_chars("โโโ")
3685 );
3686
3687 for i in 0..position_count {
3688 pb1.set_message(format!("Re-encoding vector {}", i + 1));
3689
3690 pb1.inc(1);
3694
3695 if i % 50 == 0 {
3696 std::thread::sleep(std::time::Duration::from_millis(5));
3697 }
3698 }
3699
3700 pb1.finish_with_message("โ
Vector re-encoding complete");
3701
3702 println!("๐ Rebuilding similarity index...");
3704 let pb2 = ProgressBar::new(100);
3705 pb2.set_style(
3706 ProgressStyle::default_bar()
3707 .template("โก Index Rebuild [{elapsed_precise}] [{bar:40.magenta/blue}] {pos}/{len} ({percent}%) {msg}")
3708 .unwrap()
3709 .progress_chars("โโโ")
3710 );
3711
3712 for i in 0..100 {
3713 pb2.set_message(format!("Building index chunk {}", i + 1));
3714
3715 std::thread::sleep(std::time::Duration::from_millis(20));
3717
3718 pb2.inc(1);
3719 }
3720
3721 pb2.finish_with_message("โ
Similarity index rebuilt");
3722
3723 println!("๐งช Validating optimization performance...");
3725 let pb3 = ProgressBar::new(50);
3726 pb3.set_style(
3727 ProgressStyle::default_bar()
3728 .template("โก Validation [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")
3729 .unwrap()
3730 .progress_chars("โโโ")
3731 );
3732
3733 for i in 0..50 {
3734 pb3.set_message(format!("Testing query {}", i + 1));
3735
3736 std::thread::sleep(std::time::Duration::from_millis(30));
3738
3739 pb3.inc(1);
3740 }
3741
3742 pb3.finish_with_message("โ
Optimization validation complete");
3743
3744 println!("๐ Vector database optimization finished!");
3745
3746 Ok(())
3747 }
3748}
3749
/// Per-iteration counters returned by
/// `AdvancedSelfLearningSystem::continuous_learning_iteration`.
#[derive(Debug, Clone, Default)]
pub struct LearningStats {
    /// Candidate positions produced by self-play in this iteration.
    pub positions_generated: usize,
    /// Positions that passed quality scoring (after filtering) in this iteration.
    pub positions_kept: usize,
    /// Count reported by the pruning step in this iteration.
    pub positions_pruned: usize,
    /// Added positions whose quality score exceeded 0.8.
    pub high_quality_positions: usize,
}
3758
3759impl LearningStats {
3760 pub fn new() -> Self {
3761 Default::default()
3762 }
3763
3764 pub fn learning_efficiency(&self) -> f32 {
3765 if self.positions_generated == 0 {
3766 return 0.0;
3767 }
3768 self.positions_kept as f32 / self.positions_generated as f32
3769 }
3770
3771 pub fn quality_ratio(&self) -> f32 {
3772 if self.positions_kept == 0 {
3773 return 0.0;
3774 }
3775 self.high_quality_positions as f32 / self.positions_kept as f32
3776 }
3777}