1use chess::{Board, ChessMove, Game, MoveGen};
2use indicatif::{ProgressBar, ProgressStyle};
3use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
4use rayon::prelude::*;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::{BufRead, BufReader, BufWriter, Write};
8use std::path::Path;
9use std::process::{Child, Command, Stdio};
10use std::str::FromStr;
11use std::sync::{Arc, Mutex};
12
13use crate::ChessVectorEngine;
14
/// Tunable parameters for self-play training-data generation.
#[derive(Debug, Clone)]
pub struct SelfPlayConfig {
    /// Number of self-play games produced by one call to
    /// `SelfPlayTrainer::generate_training_data`.
    pub games_per_iteration: usize,
    /// Upper bound on the number of moves played in a single game.
    pub max_moves_per_game: usize,
    /// Threshold compared against a uniform random draw: when the draw is
    /// below it, the move is sampled with temperature instead of taking the
    /// top recommendation.
    pub exploration_factor: f32,
    /// Minimum absolute evaluation a position needs in order to be recorded
    /// once the game is past its first ten moves.
    pub min_confidence: f32,
    /// Whether each game starts from a randomly chosen opening line.
    pub use_opening_book: bool,
    /// Divisor applied to a move's score before exponentiation when sampling.
    pub temperature: f32,
}
31
32impl Default for SelfPlayConfig {
33 fn default() -> Self {
34 Self {
35 games_per_iteration: 100,
36 max_moves_per_game: 200,
37 exploration_factor: 0.3,
38 min_confidence: 0.1,
39 use_opening_book: true,
40 temperature: 0.8,
41 }
42 }
43}
44
/// A single labelled position used to train the engine.
#[derive(Debug, Clone)]
pub struct TrainingData {
    /// The position itself.
    pub board: Board,
    /// Score attached to the position; `0.0` until an evaluator fills it in.
    pub evaluation: f32,
    /// Search depth that produced `evaluation`; `0` for positions extracted
    /// from PGN that have not been evaluated yet.
    pub depth: u8,
    /// Identifier of the game the position came from.
    pub game_id: usize,
}
53
/// One record of a tactical-puzzle data set.
///
/// Every field is renamed by serde to a capitalised column name
/// (`PuzzleId`, `FEN`, `Moves`, ...).
// NOTE(review): the column names look like the Lichess puzzle export — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TacticalPuzzle {
    /// Value of the `PuzzleId` column.
    #[serde(rename = "PuzzleId")]
    pub puzzle_id: String,
    /// Value of the `FEN` column.
    #[serde(rename = "FEN")]
    pub fen: String,
    /// Raw value of the `Moves` column.
    #[serde(rename = "Moves")]
    pub moves: String,
    /// Value of the `Rating` column.
    #[serde(rename = "Rating")]
    pub rating: u32,
    /// Value of the `RatingDeviation` column.
    #[serde(rename = "RatingDeviation")]
    pub rating_deviation: u32,
    /// Value of the `Popularity` column (signed).
    #[serde(rename = "Popularity")]
    pub popularity: i32,
    /// Value of the `NbPlays` column.
    #[serde(rename = "NbPlays")]
    pub nb_plays: u32,
    /// Raw value of the `Themes` column.
    #[serde(rename = "Themes")]
    pub themes: String,
    /// Optional value of the `GameUrl` column.
    #[serde(rename = "GameUrl")]
    pub game_url: Option<String>,
    /// Optional value of the `OpeningTags` column.
    #[serde(rename = "OpeningTags")]
    pub opening_tags: Option<String>,
}
78
/// A tactical training sample: a position together with its solution move.
///
/// Serde support is implemented by hand because the board and the move are
/// stored in their string forms.
#[derive(Debug, Clone)]
pub struct TacticalTrainingData {
    /// Position in which the tactic occurs.
    pub position: Board,
    /// The move that solves the position.
    pub solution_move: ChessMove,
    /// Theme label attached to the move.
    pub move_theme: String,
    /// Difficulty score of the sample.
    pub difficulty: f32,
    /// Tactical value score of the sample.
    pub tactical_value: f32,
}
88
89impl serde::Serialize for TacticalTrainingData {
91 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92 where
93 S: serde::Serializer,
94 {
95 use serde::ser::SerializeStruct;
96 let mut state = serializer.serialize_struct("TacticalTrainingData", 5)?;
97 state.serialize_field("fen", &self.position.to_string())?;
98 state.serialize_field("solution_move", &self.solution_move.to_string())?;
99 state.serialize_field("move_theme", &self.move_theme)?;
100 state.serialize_field("difficulty", &self.difficulty)?;
101 state.serialize_field("tactical_value", &self.tactical_value)?;
102 state.end()
103 }
104}
105
106impl<'de> serde::Deserialize<'de> for TacticalTrainingData {
107 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108 where
109 D: serde::Deserializer<'de>,
110 {
111 use serde::de::{self, MapAccess, Visitor};
112 use std::fmt;
113
114 struct TacticalTrainingDataVisitor;
115
116 impl<'de> Visitor<'de> for TacticalTrainingDataVisitor {
117 type Value = TacticalTrainingData;
118
119 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
120 formatter.write_str("struct TacticalTrainingData")
121 }
122
123 fn visit_map<V>(self, mut map: V) -> Result<TacticalTrainingData, V::Error>
124 where
125 V: MapAccess<'de>,
126 {
127 let mut fen = None;
128 let mut solution_move = None;
129 let mut move_theme = None;
130 let mut difficulty = None;
131 let mut tactical_value = None;
132
133 while let Some(key) = map.next_key()? {
134 match key {
135 "fen" => {
136 if fen.is_some() {
137 return Err(de::Error::duplicate_field("fen"));
138 }
139 fen = Some(map.next_value()?);
140 }
141 "solution_move" => {
142 if solution_move.is_some() {
143 return Err(de::Error::duplicate_field("solution_move"));
144 }
145 solution_move = Some(map.next_value()?);
146 }
147 "move_theme" => {
148 if move_theme.is_some() {
149 return Err(de::Error::duplicate_field("move_theme"));
150 }
151 move_theme = Some(map.next_value()?);
152 }
153 "difficulty" => {
154 if difficulty.is_some() {
155 return Err(de::Error::duplicate_field("difficulty"));
156 }
157 difficulty = Some(map.next_value()?);
158 }
159 "tactical_value" => {
160 if tactical_value.is_some() {
161 return Err(de::Error::duplicate_field("tactical_value"));
162 }
163 tactical_value = Some(map.next_value()?);
164 }
165 _ => {
166 let _: serde_json::Value = map.next_value()?;
167 }
168 }
169 }
170
171 let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
172 let solution_move_str: String =
173 solution_move.ok_or_else(|| de::Error::missing_field("solution_move"))?;
174 let move_theme =
175 move_theme.ok_or_else(|| de::Error::missing_field("move_theme"))?;
176 let difficulty =
177 difficulty.ok_or_else(|| de::Error::missing_field("difficulty"))?;
178 let tactical_value =
179 tactical_value.ok_or_else(|| de::Error::missing_field("tactical_value"))?;
180
181 let position =
182 Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
183
184 let solution_move = ChessMove::from_str(&solution_move_str)
185 .map_err(|e| de::Error::custom(format!("Error: {e}")))?;
186
187 Ok(TacticalTrainingData {
188 position,
189 solution_move,
190 move_theme,
191 difficulty,
192 tactical_value,
193 })
194 }
195 }
196
197 const FIELDS: &[&str] = &[
198 "fen",
199 "solution_move",
200 "move_theme",
201 "difficulty",
202 "tactical_value",
203 ];
204 deserializer.deserialize_struct("TacticalTrainingData", FIELDS, TacticalTrainingDataVisitor)
205 }
206}
207
/// PGN visitor that replays games and records the position reached after
/// every successfully applied move.
pub struct GameExtractor {
    /// Positions collected so far, across all games visited.
    pub positions: Vec<TrainingData>,
    /// Game currently being replayed; reset at the start of each game.
    pub current_game: Game,
    /// Moves applied in the current game.
    pub move_count: usize,
    /// Once this many moves were applied, further moves of the game are ignored.
    pub max_moves_per_game: usize,
    /// Counter incremented at the start of every game and stamped on each
    /// recorded position.
    pub game_id: usize,
}
216
217impl GameExtractor {
218 pub fn new(max_moves_per_game: usize) -> Self {
219 Self {
220 positions: Vec::new(),
221 current_game: Game::new(),
222 move_count: 0,
223 max_moves_per_game,
224 game_id: 0,
225 }
226 }
227}
228
229impl Visitor for GameExtractor {
230 type Result = ();
231
232 fn begin_game(&mut self) {
233 self.current_game = Game::new();
234 self.move_count = 0;
235 self.game_id += 1;
236 }
237
238 fn header(&mut self, _key: &[u8], _value: RawHeader<'_>) {}
239
240 fn san(&mut self, san_plus: SanPlus) {
241 if self.move_count >= self.max_moves_per_game {
242 return;
243 }
244
245 let san_str = san_plus.san.to_string();
246
247 let current_pos = self.current_game.current_position();
249
250 match chess::ChessMove::from_san(¤t_pos, &san_str) {
252 Ok(chess_move) => {
253 let legal_moves: Vec<chess::ChessMove> = MoveGen::new_legal(¤t_pos).collect();
255 if legal_moves.contains(&chess_move) {
256 if self.current_game.make_move(chess_move) {
257 self.move_count += 1;
258
259 self.positions.push(TrainingData {
261 board: self.current_game.current_position(),
262 evaluation: 0.0, depth: 0,
264 game_id: self.game_id,
265 });
266 }
267 } else {
268 }
270 }
271 Err(_) => {
272 if !san_str.contains("O-O") && !san_str.contains("=") && san_str.len() > 6 {
276 }
278 }
279 }
280 }
281
282 fn begin_variation(&mut self) -> Skip {
283 Skip(true) }
285
286 fn end_game(&mut self) -> Self::Result {}
287}
288
/// Evaluates positions by launching a fresh `stockfish` process per position.
pub struct StockfishEvaluator {
    /// Depth passed to the engine's `go depth` command.
    depth: u8,
}
293
294impl StockfishEvaluator {
295 pub fn new(depth: u8) -> Self {
296 Self { depth }
297 }
298
299 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
301 let mut child = Command::new("stockfish")
302 .stdin(Stdio::piped())
303 .stdout(Stdio::piped())
304 .stderr(Stdio::piped())
305 .spawn()?;
306
307 let stdin = child
308 .stdin
309 .as_mut()
310 .ok_or("Failed to get stdin handle for Stockfish process")?;
311 let fen = board.to_string();
312
313 use std::io::Write;
315 writeln!(stdin, "uci")?;
316 writeln!(stdin, "isready")?;
317 writeln!(stdin, "position fen {fen}")?;
318 writeln!(stdin, "go depth {}", self.depth)?;
319 writeln!(stdin, "quit")?;
320
321 let output = child.wait_with_output()?;
322 let stdout = String::from_utf8_lossy(&output.stdout);
323
324 for line in stdout.lines() {
326 if line.starts_with("info") && line.contains("score cp") {
327 if let Some(cp_pos) = line.find("score cp ") {
328 let cp_str = &line[cp_pos + 9..];
329 if let Some(end) = cp_str.find(' ') {
330 let cp_value = cp_str[..end].parse::<i32>()?;
331 return Ok(cp_value as f32 / 100.0); }
333 }
334 } else if line.starts_with("info") && line.contains("score mate") {
335 if let Some(mate_pos) = line.find("score mate ") {
337 let mate_str = &line[mate_pos + 11..];
338 if let Some(end) = mate_str.find(' ') {
339 let mate_moves = mate_str[..end].parse::<i32>()?;
340 return Ok(if mate_moves > 0 { 100.0 } else { -100.0 });
341 }
342 }
343 }
344 }
345
346 Ok(0.0) }
348
349 pub fn evaluate_batch(
351 &self,
352 positions: &mut [TrainingData],
353 ) -> Result<(), Box<dyn std::error::Error>> {
354 let pb = ProgressBar::new(positions.len() as u64);
355 if let Ok(style) = ProgressStyle::default_bar().template(
356 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
357 ) {
358 pb.set_style(style.progress_chars("#>-"));
359 }
360
361 for data in positions.iter_mut() {
362 match self.evaluate_position(&data.board) {
363 Ok(eval) => {
364 data.evaluation = eval;
365 data.depth = self.depth;
366 }
367 Err(e) => {
368 eprintln!("Evaluation error: {e}");
369 data.evaluation = 0.0;
370 }
371 }
372 pb.inc(1);
373 }
374
375 pb.finish_with_message("Evaluation complete");
376 Ok(())
377 }
378
379 pub fn evaluate_batch_parallel(
381 &self,
382 positions: &mut [TrainingData],
383 num_threads: usize,
384 ) -> Result<(), Box<dyn std::error::Error>> {
385 let pb = ProgressBar::new(positions.len() as u64);
386 if let Ok(style) = ProgressStyle::default_bar()
387 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel evaluation") {
388 pb.set_style(style.progress_chars("#>-"));
389 }
390
391 let pool = rayon::ThreadPoolBuilder::new()
393 .num_threads(num_threads)
394 .build()?;
395
396 pool.install(|| {
397 positions.par_iter_mut().for_each(|data| {
399 match self.evaluate_position(&data.board) {
400 Ok(eval) => {
401 data.evaluation = eval;
402 data.depth = self.depth;
403 }
404 Err(_) => {
405 data.evaluation = 0.0;
407 }
408 }
409 pb.inc(1);
410 });
411 });
412
413 pb.finish_with_message("Parallel evaluation complete");
414 Ok(())
415 }
416}
417
/// A long-lived `stockfish` child process with buffered access to its pipes.
struct StockfishProcess {
    /// Handle of the spawned engine process.
    child: Child,
    /// Buffered writer over the engine's stdin.
    stdin: BufWriter<std::process::ChildStdin>,
    /// Buffered reader over the engine's stdout.
    stdout: BufReader<std::process::ChildStdout>,
    /// Search depth given at construction.
    #[allow(dead_code)]
    depth: u8,
}
426
427impl StockfishProcess {
428 fn new(depth: u8) -> Result<Self, Box<dyn std::error::Error>> {
429 let mut child = Command::new("stockfish")
430 .stdin(Stdio::piped())
431 .stdout(Stdio::piped())
432 .stderr(Stdio::piped())
433 .spawn()?;
434
435 let stdin = BufWriter::new(
436 child
437 .stdin
438 .take()
439 .ok_or("Failed to get stdin handle for Stockfish process")?,
440 );
441 let stdout = BufReader::new(
442 child
443 .stdout
444 .take()
445 .ok_or("Failed to get stdout handle for Stockfish process")?,
446 );
447
448 let mut process = Self {
449 child,
450 stdin,
451 stdout,
452 depth,
453 };
454
455 process.send_command("uci")?;
457 process.wait_for_ready()?;
458 process.send_command("isready")?;
459 process.wait_for_ready()?;
460
461 Ok(process)
462 }
463
464 fn send_command(&mut self, command: &str) -> Result<(), Box<dyn std::error::Error>> {
465 writeln!(self.stdin, "{command}")?;
466 self.stdin.flush()?;
467 Ok(())
468 }
469
470 fn wait_for_ready(&mut self) -> Result<(), Box<dyn std::error::Error>> {
471 let mut line = String::new();
472 loop {
473 line.clear();
474 self.stdout.read_line(&mut line)?;
475 if line.trim() == "uciok" || line.trim() == "readyok" {
476 break;
477 }
478 }
479 Ok(())
480 }
481
482 fn evaluate_position(&mut self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
483 let fen = board.to_string();
484
485 self.send_command(&format!("position fen {fen}"))?;
487 self.send_command(&format!("position fen {fen}"))?;
488
489 let mut line = String::new();
491 let mut last_evaluation = 0.0;
492
493 loop {
494 line.clear();
495 self.stdout.read_line(&mut line)?;
496 let line = line.trim();
497
498 if line.starts_with("info") && line.contains("score cp") {
499 if let Some(cp_pos) = line.find("score cp ") {
500 let cp_str = &line[cp_pos + 9..];
501 if let Some(end) = cp_str.find(' ') {
502 if let Ok(cp_value) = cp_str[..end].parse::<i32>() {
503 last_evaluation = cp_value as f32 / 100.0;
504 }
505 }
506 }
507 } else if line.starts_with("info") && line.contains("score mate") {
508 if let Some(mate_pos) = line.find("score mate ") {
509 let mate_str = &line[mate_pos + 11..];
510 if let Some(end) = mate_str.find(' ') {
511 if let Ok(mate_moves) = mate_str[..end].parse::<i32>() {
512 last_evaluation = if mate_moves > 0 { 100.0 } else { -100.0 };
513 }
514 }
515 }
516 } else if line.starts_with("bestmove") {
517 break;
518 }
519 }
520
521 Ok(last_evaluation)
522 }
523}
524
525impl Drop for StockfishProcess {
526 fn drop(&mut self) {
527 let _ = self.send_command("quit");
528 let _ = self.child.wait();
529 }
530}
531
/// A shared pool of long-lived Stockfish processes.
pub struct StockfishPool {
    /// Idle processes available for checkout, guarded by a mutex.
    pool: Arc<Mutex<Vec<StockfishProcess>>>,
    /// Depth used for processes created by the pool.
    depth: u8,
    /// Maximum number of idle processes kept in `pool`.
    pool_size: usize,
}
538
539impl StockfishPool {
540 pub fn new(depth: u8, pool_size: usize) -> Result<Self, Box<dyn std::error::Error>> {
541 let mut processes = Vec::with_capacity(pool_size);
542
543 println!("๐ Initializing Stockfish pool with {pool_size} processes...");
544
545 for i in 0..pool_size {
546 match StockfishProcess::new(depth) {
547 Ok(process) => {
548 processes.push(process);
549 if i % 2 == 1 {
550 print!(".");
551 let _ = std::io::stdout().flush(); }
553 }
554 Err(e) => {
555 eprintln!("Evaluation error: {e}");
556 return Err(e);
557 }
558 }
559 }
560
561 println!(" โ
Pool ready!");
562
563 Ok(Self {
564 pool: Arc::new(Mutex::new(processes)),
565 depth,
566 pool_size,
567 })
568 }
569
570 pub fn evaluate_position(&self, board: &Board) -> Result<f32, Box<dyn std::error::Error>> {
571 let mut process = {
573 let mut pool = self.pool.lock().unwrap();
574 if let Some(process) = pool.pop() {
575 process
576 } else {
577 StockfishProcess::new(self.depth)?
579 }
580 };
581
582 let result = process.evaluate_position(board);
584
585 {
587 let mut pool = self.pool.lock().unwrap();
588 if pool.len() < self.pool_size {
589 pool.push(process);
590 }
591 }
593
594 result
595 }
596
597 pub fn evaluate_batch_parallel(
598 &self,
599 positions: &mut [TrainingData],
600 ) -> Result<(), Box<dyn std::error::Error>> {
601 let pb = ProgressBar::new(positions.len() as u64);
602 pb.set_style(ProgressStyle::default_bar()
603 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Pool evaluation")
604 .unwrap()
605 .progress_chars("#>-"));
606
607 positions.par_iter_mut().for_each(|data| {
609 match self.evaluate_position(&data.board) {
610 Ok(eval) => {
611 data.evaluation = eval;
612 data.depth = self.depth;
613 }
614 Err(_) => {
615 data.evaluation = 0.0;
616 }
617 }
618 pb.inc(1);
619 });
620
621 pb.finish_with_message("Pool evaluation complete");
622 Ok(())
623 }
624}
625
/// An in-memory collection of training positions.
pub struct TrainingDataset {
    /// The positions, in insertion order.
    pub data: Vec<TrainingData>,
}
630
631impl Default for TrainingDataset {
632 fn default() -> Self {
633 Self::new()
634 }
635}
636
637impl TrainingDataset {
638 pub fn new() -> Self {
639 Self { data: Vec::new() }
640 }
641
642 pub fn load_from_pgn<P: AsRef<Path>>(
644 &mut self,
645 path: P,
646 max_games: Option<usize>,
647 max_moves_per_game: usize,
648 ) -> Result<(), Box<dyn std::error::Error>> {
649 use indicatif::{ProgressBar, ProgressStyle};
650
651 println!("๐ Reading PGN file...");
652 let file = File::open(path)?;
653 let reader = BufReader::new(file);
654
655 let mut games = Vec::new();
657 let mut current_game = String::new();
658 let mut games_collected = 0;
659
660 for line in reader.lines() {
661 let line = line?;
662 current_game.push_str(&line);
663 current_game.push('\n');
664
665 if line.trim().ends_with("1-0")
667 || line.trim().ends_with("0-1")
668 || line.trim().ends_with("1/2-1/2")
669 || line.trim().ends_with("*")
670 {
671 games.push(current_game.clone());
672 current_game.clear();
673 games_collected += 1;
674
675 if let Some(max) = max_games {
676 if games_collected >= max {
677 break;
678 }
679 }
680 }
681 }
682
683 println!(
684 "๐ฆ Collected {} games, processing in parallel...",
685 games.len()
686 );
687
688 let pb = ProgressBar::new(games.len() as u64);
690 pb.set_style(
691 ProgressStyle::default_bar()
692 .template("โก Processing [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
693 .unwrap()
694 .progress_chars("โโโ")
695 );
696
697 let all_positions: Vec<Vec<TrainingData>> = games
699 .par_iter()
700 .map(|game_pgn| {
701 pb.inc(1);
702 pb.set_message("Processing game");
703
704 let mut local_extractor = GameExtractor::new(max_moves_per_game);
706
707 let cursor = std::io::Cursor::new(game_pgn);
709 let mut reader = BufferedReader::new(cursor);
710
711 if let Err(e) = reader.read_all(&mut local_extractor) {
712 eprintln!("Parse error: {e}");
713 return Vec::new(); }
715
716 local_extractor.positions
717 })
718 .collect();
719
720 pb.finish_with_message("โ
Parallel processing completed");
721
722 for game_positions in all_positions {
724 self.data.extend(game_positions);
725 }
726
727 println!(
728 "โ
Loaded {} positions from {} games (parallel processing)",
729 self.data.len(),
730 games.len()
731 );
732 Ok(())
733 }
734
735 pub fn evaluate_with_stockfish(&mut self, depth: u8) -> Result<(), Box<dyn std::error::Error>> {
737 let evaluator = StockfishEvaluator::new(depth);
738 evaluator.evaluate_batch(&mut self.data)
739 }
740
741 pub fn evaluate_with_stockfish_parallel(
743 &mut self,
744 depth: u8,
745 num_threads: usize,
746 ) -> Result<(), Box<dyn std::error::Error>> {
747 let evaluator = StockfishEvaluator::new(depth);
748 evaluator.evaluate_batch_parallel(&mut self.data, num_threads)
749 }
750
751 pub fn train_engine(&self, engine: &mut ChessVectorEngine) {
753 let pb = ProgressBar::new(self.data.len() as u64);
754 pb.set_style(ProgressStyle::default_bar()
755 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Training positions")
756 .unwrap()
757 .progress_chars("#>-"));
758
759 for data in &self.data {
760 engine.add_position(&data.board, data.evaluation);
761 pb.inc(1);
762 }
763
764 pb.finish_with_message("Training complete");
765 println!("Trained engine with {} positions", self.data.len());
766 }
767
768 pub fn split(&self, train_ratio: f32) -> (TrainingDataset, TrainingDataset) {
770 use rand::seq::SliceRandom;
771 use rand::thread_rng;
772 use std::collections::{HashMap, HashSet};
773
774 let mut games: HashMap<usize, Vec<&TrainingData>> = HashMap::new();
776 for data in &self.data {
777 games.entry(data.game_id).or_default().push(data);
778 }
779
780 let mut game_ids: Vec<usize> = games.keys().cloned().collect();
782 game_ids.shuffle(&mut thread_rng());
783
784 let split_point = (game_ids.len() as f32 * train_ratio) as usize;
786 let train_game_ids: HashSet<usize> = game_ids[..split_point].iter().cloned().collect();
787
788 let mut train_data = Vec::new();
790 let mut test_data = Vec::new();
791
792 for data in &self.data {
793 if train_game_ids.contains(&data.game_id) {
794 train_data.push(data.clone());
795 } else {
796 test_data.push(data.clone());
797 }
798 }
799
800 (
801 TrainingDataset { data: train_data },
802 TrainingDataset { data: test_data },
803 )
804 }
805
806 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn std::error::Error>> {
808 let json = serde_json::to_string_pretty(&self.data)?;
809 std::fs::write(path, json)?;
810 Ok(())
811 }
812
813 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
815 let content = std::fs::read_to_string(path)?;
816 let data = serde_json::from_str(&content)?;
817 Ok(Self { data })
818 }
819
820 pub fn load_and_append<P: AsRef<Path>>(
822 &mut self,
823 path: P,
824 ) -> Result<(), Box<dyn std::error::Error>> {
825 let existing_len = self.data.len();
826 let additional_data = Self::load(path)?;
827 self.data.extend(additional_data.data);
828 println!(
829 "Loaded {} additional positions (total: {})",
830 self.data.len() - existing_len,
831 self.data.len()
832 );
833 Ok(())
834 }
835
836 pub fn merge(&mut self, other: TrainingDataset) {
838 let existing_len = self.data.len();
839 self.data.extend(other.data);
840 println!(
841 "Merged {} positions (total: {})",
842 self.data.len() - existing_len,
843 self.data.len()
844 );
845 }
846
    /// Incrementally saves the dataset to `path`.
    ///
    /// Delegates to `save_incremental_with_options` with `deduplicate` set to
    /// `true`.
    pub fn save_incremental<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.save_incremental_with_options(path, true)
    }
854
855 pub fn save_incremental_with_options<P: AsRef<Path>>(
857 &self,
858 path: P,
859 deduplicate: bool,
860 ) -> Result<(), Box<dyn std::error::Error>> {
861 let path = path.as_ref();
862
863 if path.exists() {
864 if self.save_append_only(path).is_ok() {
866 return Ok(());
867 }
868
869 if deduplicate {
871 self.save_incremental_full_merge(path)
872 } else {
873 self.save_incremental_no_dedup(path)
874 }
875 } else {
876 self.save(path)
878 }
879 }
880
881 fn save_incremental_no_dedup<P: AsRef<Path>>(
883 &self,
884 path: P,
885 ) -> Result<(), Box<dyn std::error::Error>> {
886 let path = path.as_ref();
887
888 println!("๐ Loading existing training data...");
889 let mut existing = Self::load(path)?;
890
891 println!("โก Fast merge without deduplication...");
892 existing.data.extend(self.data.iter().cloned());
893
894 println!(
895 "๐พ Serializing {} positions to JSON...",
896 existing.data.len()
897 );
898 let json = serde_json::to_string_pretty(&existing.data)?;
899
900 println!("โ๏ธ Writing to disk...");
901 std::fs::write(path, json)?;
902
903 println!(
904 "โ
Fast merge save: total {} positions",
905 existing.data.len()
906 );
907 Ok(())
908 }
909
910 pub fn save_append_only<P: AsRef<Path>>(
912 &self,
913 path: P,
914 ) -> Result<(), Box<dyn std::error::Error>> {
915 use std::fs::OpenOptions;
916 use std::io::{BufRead, BufReader, Seek, SeekFrom, Write};
917
918 if self.data.is_empty() {
919 return Ok(());
920 }
921
922 let path = path.as_ref();
923 let mut file = OpenOptions::new().read(true).write(true).open(path)?;
924
925 file.seek(SeekFrom::End(-10))?;
927 let mut buffer = String::new();
928 BufReader::new(&file).read_line(&mut buffer)?;
929
930 if !buffer.trim().ends_with(']') {
931 return Err("File doesn't end with JSON array bracket".into());
932 }
933
934 file.seek(SeekFrom::End(-2))?; write!(file, ",")?;
939
940 for (i, data) in self.data.iter().enumerate() {
942 if i > 0 {
943 write!(file, ",")?;
944 }
945 let json = serde_json::to_string(data)?;
946 write!(file, "{json}")?;
947 }
948
949 write!(file, "\n]")?;
951
952 println!("Fast append: added {} new positions", self.data.len());
953 Ok(())
954 }
955
956 fn save_incremental_full_merge<P: AsRef<Path>>(
958 &self,
959 path: P,
960 ) -> Result<(), Box<dyn std::error::Error>> {
961 let path = path.as_ref();
962
963 println!("๐ Loading existing training data...");
964 let mut existing = Self::load(path)?;
965 let _original_len = existing.data.len();
966
967 println!("๐ Streaming merge with deduplication (avoiding O(nยฒ) operation)...");
968 existing.merge_and_deduplicate(self.data.clone());
969
970 println!(
971 "๐พ Serializing {} positions to JSON...",
972 existing.data.len()
973 );
974 let json = serde_json::to_string_pretty(&existing.data)?;
975
976 println!("โ๏ธ Writing to disk...");
977 std::fs::write(path, json)?;
978
979 println!(
980 "โ
Streaming merge save: total {} positions",
981 existing.data.len()
982 );
983 Ok(())
984 }
985
986 pub fn add_position(&mut self, board: Board, evaluation: f32, depth: u8, game_id: usize) {
988 self.data.push(TrainingData {
989 board,
990 evaluation,
991 depth,
992 game_id,
993 });
994 }
995
996 pub fn next_game_id(&self) -> usize {
998 self.data.iter().map(|data| data.game_id).max().unwrap_or(0) + 1
999 }
1000
1001 pub fn deduplicate(&mut self, similarity_threshold: f32) {
1003 if similarity_threshold > 0.999 {
1004 self.deduplicate_fast();
1006 } else {
1007 self.deduplicate_similarity_based(similarity_threshold);
1009 }
1010 }
1011
1012 pub fn deduplicate_fast(&mut self) {
1014 use std::collections::HashSet;
1015
1016 if self.data.is_empty() {
1017 return;
1018 }
1019
1020 let mut seen_positions = HashSet::with_capacity(self.data.len());
1021 let original_len = self.data.len();
1022
1023 self.data.retain(|data| {
1025 let fen = data.board.to_string();
1026 seen_positions.insert(fen)
1027 });
1028
1029 println!(
1030 "Fast deduplicated: {} -> {} positions (removed {} exact duplicates)",
1031 original_len,
1032 self.data.len(),
1033 original_len - self.data.len()
1034 );
1035 }
1036
1037 pub fn merge_and_deduplicate(&mut self, new_data: Vec<TrainingData>) {
1039 use std::collections::HashSet;
1040
1041 if new_data.is_empty() {
1042 return;
1043 }
1044
1045 let _original_len = self.data.len();
1046
1047 let mut existing_positions: HashSet<String> = HashSet::with_capacity(self.data.len());
1049 for data in &self.data {
1050 existing_positions.insert(data.board.to_string());
1051 }
1052
1053 let mut added = 0;
1055 for data in new_data {
1056 let fen = data.board.to_string();
1057 if existing_positions.insert(fen) {
1058 self.data.push(data);
1059 added += 1;
1060 }
1061 }
1062
1063 println!(
1064 "Streaming merge: added {} unique positions (total: {})",
1065 added,
1066 self.data.len()
1067 );
1068 }
1069
1070 fn deduplicate_similarity_based(&mut self, similarity_threshold: f32) {
1072 use crate::PositionEncoder;
1073 use ndarray::Array1;
1074
1075 if self.data.is_empty() {
1076 return;
1077 }
1078
1079 let encoder = PositionEncoder::new(1024);
1080 let mut keep_indices: Vec<bool> = vec![true; self.data.len()];
1081
1082 let vectors: Vec<Array1<f32>> = if self.data.len() > 50 {
1084 self.data
1085 .par_iter()
1086 .map(|data| encoder.encode(&data.board))
1087 .collect()
1088 } else {
1089 self.data
1090 .iter()
1091 .map(|data| encoder.encode(&data.board))
1092 .collect()
1093 };
1094
1095 for i in 1..self.data.len() {
1097 if !keep_indices[i] {
1098 continue;
1099 }
1100
1101 for j in 0..i {
1102 if !keep_indices[j] {
1103 continue;
1104 }
1105
1106 let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1107 if similarity > similarity_threshold {
1108 keep_indices[i] = false;
1109 break;
1110 }
1111 }
1112 }
1113
1114 let original_len = self.data.len();
1116 self.data = self
1117 .data
1118 .iter()
1119 .enumerate()
1120 .filter_map(|(i, data)| {
1121 if keep_indices[i] {
1122 Some(data.clone())
1123 } else {
1124 None
1125 }
1126 })
1127 .collect();
1128
1129 println!(
1130 "Similarity deduplicated: {} -> {} positions (removed {} near-duplicates)",
1131 original_len,
1132 self.data.len(),
1133 original_len - self.data.len()
1134 );
1135 }
1136
1137 pub fn deduplicate_parallel(&mut self, similarity_threshold: f32, chunk_size: usize) {
1139 use crate::PositionEncoder;
1140 use ndarray::Array1;
1141 use std::sync::{Arc, Mutex};
1142
1143 if self.data.is_empty() {
1144 return;
1145 }
1146
1147 let encoder = PositionEncoder::new(1024);
1148
1149 let vectors: Vec<Array1<f32>> = self
1151 .data
1152 .par_iter()
1153 .map(|data| encoder.encode(&data.board))
1154 .collect();
1155
1156 let keep_indices = Arc::new(Mutex::new(vec![true; self.data.len()]));
1157
1158 (1..self.data.len())
1160 .collect::<Vec<_>>()
1161 .par_chunks(chunk_size)
1162 .for_each(|chunk| {
1163 for &i in chunk {
1164 {
1166 let indices = keep_indices.lock().unwrap();
1167 if !indices[i] {
1168 continue;
1169 }
1170 }
1171
1172 for j in 0..i {
1174 {
1175 let indices = keep_indices.lock().unwrap();
1176 if !indices[j] {
1177 continue;
1178 }
1179 }
1180
1181 let similarity = Self::cosine_similarity(&vectors[i], &vectors[j]);
1182 if similarity > similarity_threshold {
1183 let mut indices = keep_indices.lock().unwrap();
1184 indices[i] = false;
1185 break;
1186 }
1187 }
1188 }
1189 });
1190
1191 let keep_indices = keep_indices.lock().unwrap();
1193 let original_len = self.data.len();
1194 self.data = self
1195 .data
1196 .iter()
1197 .enumerate()
1198 .filter_map(|(i, data)| {
1199 if keep_indices[i] {
1200 Some(data.clone())
1201 } else {
1202 None
1203 }
1204 })
1205 .collect();
1206
1207 println!(
1208 "Parallel deduplicated: {} -> {} positions (removed {} duplicates)",
1209 original_len,
1210 self.data.len(),
1211 original_len - self.data.len()
1212 );
1213 }
1214
1215 fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
1217 let dot_product = a.dot(b);
1218 let norm_a = a.dot(a).sqrt();
1219 let norm_b = b.dot(b).sqrt();
1220
1221 if norm_a == 0.0 || norm_b == 0.0 {
1222 0.0
1223 } else {
1224 dot_product / (norm_a * norm_b)
1225 }
1226 }
1227}
1228
/// Generates training data by letting the engine play games against itself.
pub struct SelfPlayTrainer {
    /// Settings controlling game count, length and move selection.
    config: SelfPlayConfig,
    /// Number of games played so far; used as the `game_id` of recorded
    /// positions.
    game_counter: usize,
}
1234
1235impl SelfPlayTrainer {
1236 pub fn new(config: SelfPlayConfig) -> Self {
1237 Self {
1238 config,
1239 game_counter: 0,
1240 }
1241 }
1242
1243 pub fn generate_training_data(&mut self, engine: &mut ChessVectorEngine) -> TrainingDataset {
1245 let mut dataset = TrainingDataset::new();
1246
1247 println!(
1248 "๐ฎ Starting self-play training with {} games...",
1249 self.config.games_per_iteration
1250 );
1251 let pb = ProgressBar::new(self.config.games_per_iteration as u64);
1252 if let Ok(style) = ProgressStyle::default_bar().template(
1253 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
1254 ) {
1255 pb.set_style(style.progress_chars("#>-"));
1256 }
1257
1258 for _ in 0..self.config.games_per_iteration {
1259 let game_data = self.play_single_game(engine);
1260 dataset.data.extend(game_data);
1261 self.game_counter += 1;
1262 pb.inc(1);
1263 }
1264
1265 pb.finish_with_message("Self-play games completed");
1266 println!(
1267 "โ
Generated {} positions from {} games",
1268 dataset.data.len(),
1269 self.config.games_per_iteration
1270 );
1271
1272 dataset
1273 }
1274
1275 fn play_single_game(&self, engine: &mut ChessVectorEngine) -> Vec<TrainingData> {
1277 let mut game = Game::new();
1278 let mut positions = Vec::new();
1279 let mut move_count = 0;
1280
1281 if self.config.use_opening_book {
1283 if let Some(opening_moves) = self.get_random_opening() {
1284 for mv in opening_moves {
1285 if game.make_move(mv) {
1286 move_count += 1;
1287 } else {
1288 break;
1289 }
1290 }
1291 }
1292 }
1293
1294 while game.result().is_none() && move_count < self.config.max_moves_per_game {
1296 let current_position = game.current_position();
1297
1298 let move_choice = self.select_move_with_exploration(engine, ¤t_position);
1300
1301 if let Some(chess_move) = move_choice {
1302 if let Some(evaluation) = engine.evaluate_position(¤t_position) {
1304 if evaluation.abs() >= self.config.min_confidence || move_count < 10 {
1306 positions.push(TrainingData {
1307 board: current_position,
1308 evaluation,
1309 depth: 1, game_id: self.game_counter,
1311 });
1312 }
1313 }
1314
1315 if !game.make_move(chess_move) {
1317 break; }
1319 move_count += 1;
1320 } else {
1321 break; }
1323 }
1324
1325 if let Some(result) = game.result() {
1327 let final_position = game.current_position();
1328 let final_eval = match result {
1329 chess::GameResult::WhiteCheckmates => {
1330 if final_position.side_to_move() == chess::Color::Black {
1331 10.0
1332 } else {
1333 -10.0
1334 }
1335 }
1336 chess::GameResult::BlackCheckmates => {
1337 if final_position.side_to_move() == chess::Color::White {
1338 10.0
1339 } else {
1340 -10.0
1341 }
1342 }
1343 chess::GameResult::WhiteResigns => -10.0,
1344 chess::GameResult::BlackResigns => 10.0,
1345 chess::GameResult::Stalemate
1346 | chess::GameResult::DrawAccepted
1347 | chess::GameResult::DrawDeclared => 0.0,
1348 };
1349
1350 positions.push(TrainingData {
1351 board: final_position,
1352 evaluation: final_eval,
1353 depth: 1,
1354 game_id: self.game_counter,
1355 });
1356 }
1357
1358 positions
1359 }
1360
1361 fn select_move_with_exploration(
1363 &self,
1364 engine: &mut ChessVectorEngine,
1365 position: &Board,
1366 ) -> Option<ChessMove> {
1367 let recommendations = engine.recommend_legal_moves(position, 5);
1368
1369 if recommendations.is_empty() {
1370 return None;
1371 }
1372
1373 if fastrand::f32() < self.config.exploration_factor {
1375 self.select_move_with_temperature(&recommendations)
1377 } else {
1378 Some(recommendations[0].chess_move)
1380 }
1381 }
1382
1383 fn select_move_with_temperature(
1385 &self,
1386 recommendations: &[crate::MoveRecommendation],
1387 ) -> Option<ChessMove> {
1388 if recommendations.is_empty() {
1389 return None;
1390 }
1391
1392 let mut probabilities = Vec::new();
1394 let mut sum = 0.0;
1395
1396 for rec in recommendations {
1397 let prob = (rec.average_outcome / self.config.temperature).exp();
1399 probabilities.push(prob);
1400 sum += prob;
1401 }
1402
1403 for prob in &mut probabilities {
1405 *prob /= sum;
1406 }
1407
1408 let rand_val = fastrand::f32();
1410 let mut cumulative = 0.0;
1411
1412 for (i, &prob) in probabilities.iter().enumerate() {
1413 cumulative += prob;
1414 if rand_val <= cumulative {
1415 return Some(recommendations[i].chess_move);
1416 }
1417 }
1418
1419 Some(recommendations[0].chess_move)
1421 }
1422
1423 fn get_random_opening(&self) -> Option<Vec<ChessMove>> {
1425 let openings = [
1426 vec!["e4", "e5", "Nf3", "Nc6", "Bc4"],
1428 vec!["e4", "e5", "Nf3", "Nc6", "Bb5"],
1430 vec!["d4", "d5", "c4"],
1432 vec!["d4", "Nf6", "c4", "g6"],
1434 vec!["e4", "c5"],
1436 vec!["e4", "e6"],
1438 vec!["e4", "c6"],
1440 ];
1441
1442 let selected_opening = &openings[fastrand::usize(0..openings.len())];
1443
1444 let mut moves = Vec::new();
1445 let mut game = Game::new();
1446
1447 for move_str in selected_opening {
1448 if let Ok(chess_move) = ChessMove::from_str(move_str) {
1449 if game.make_move(chess_move) {
1450 moves.push(chess_move);
1451 } else {
1452 break;
1453 }
1454 }
1455 }
1456
1457 if moves.is_empty() {
1458 None
1459 } else {
1460 Some(moves)
1461 }
1462 }
1463}
1464
/// Measures how closely a `ChessVectorEngine`'s evaluations match a reference
/// dataset (see `evaluate_accuracy`).
pub struct EngineEvaluator {
    // Stored by `new` but not read anywhere in the visible code, hence the
    // dead-code allowance.
    #[allow(dead_code)]
    stockfish_depth: u8,
}
1470
1471impl EngineEvaluator {
1472 pub fn new(stockfish_depth: u8) -> Self {
1473 Self { stockfish_depth }
1474 }
1475
1476 pub fn evaluate_accuracy(
1478 &self,
1479 engine: &mut ChessVectorEngine,
1480 test_data: &TrainingDataset,
1481 ) -> Result<f32, Box<dyn std::error::Error>> {
1482 let mut total_error = 0.0;
1483 let mut valid_comparisons = 0;
1484
1485 let pb = ProgressBar::new(test_data.data.len() as u64);
1486 pb.set_style(ProgressStyle::default_bar()
1487 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Evaluating accuracy")
1488 .unwrap()
1489 .progress_chars("#>-"));
1490
1491 for data in &test_data.data {
1492 if let Some(engine_eval) = engine.evaluate_position(&data.board) {
1493 let error = (engine_eval - data.evaluation).abs();
1494 total_error += error;
1495 valid_comparisons += 1;
1496 }
1497 pb.inc(1);
1498 }
1499
1500 pb.finish_with_message("Accuracy evaluation complete");
1501
1502 if valid_comparisons > 0 {
1503 let mean_absolute_error = total_error / valid_comparisons as f32;
1504 println!("Mean Absolute Error: {mean_absolute_error:.3} pawns");
1505 println!("Evaluated {valid_comparisons} positions");
1506 Ok(mean_absolute_error)
1507 } else {
1508 Ok(f32::INFINITY)
1509 }
1510 }
1511}
1512
1513impl serde::Serialize for TrainingData {
1515 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1516 where
1517 S: serde::Serializer,
1518 {
1519 use serde::ser::SerializeStruct;
1520 let mut state = serializer.serialize_struct("TrainingData", 4)?;
1521 state.serialize_field("fen", &self.board.to_string())?;
1522 state.serialize_field("evaluation", &self.evaluation)?;
1523 state.serialize_field("depth", &self.depth)?;
1524 state.serialize_field("game_id", &self.game_id)?;
1525 state.end()
1526 }
1527}
1528
1529impl<'de> serde::Deserialize<'de> for TrainingData {
1530 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1531 where
1532 D: serde::Deserializer<'de>,
1533 {
1534 use serde::de::{self, MapAccess, Visitor};
1535 use std::fmt;
1536
1537 struct TrainingDataVisitor;
1538
1539 impl<'de> Visitor<'de> for TrainingDataVisitor {
1540 type Value = TrainingData;
1541
1542 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1543 formatter.write_str("struct TrainingData")
1544 }
1545
1546 fn visit_map<V>(self, mut map: V) -> Result<TrainingData, V::Error>
1547 where
1548 V: MapAccess<'de>,
1549 {
1550 let mut fen = None;
1551 let mut evaluation = None;
1552 let mut depth = None;
1553 let mut game_id = None;
1554
1555 while let Some(key) = map.next_key()? {
1556 match key {
1557 "fen" => {
1558 if fen.is_some() {
1559 return Err(de::Error::duplicate_field("fen"));
1560 }
1561 fen = Some(map.next_value()?);
1562 }
1563 "evaluation" => {
1564 if evaluation.is_some() {
1565 return Err(de::Error::duplicate_field("evaluation"));
1566 }
1567 evaluation = Some(map.next_value()?);
1568 }
1569 "depth" => {
1570 if depth.is_some() {
1571 return Err(de::Error::duplicate_field("depth"));
1572 }
1573 depth = Some(map.next_value()?);
1574 }
1575 "game_id" => {
1576 if game_id.is_some() {
1577 return Err(de::Error::duplicate_field("game_id"));
1578 }
1579 game_id = Some(map.next_value()?);
1580 }
1581 _ => {
1582 let _: serde_json::Value = map.next_value()?;
1583 }
1584 }
1585 }
1586
1587 let fen: String = fen.ok_or_else(|| de::Error::missing_field("fen"))?;
1588 let mut evaluation: f32 =
1589 evaluation.ok_or_else(|| de::Error::missing_field("evaluation"))?;
1590 let depth = depth.ok_or_else(|| de::Error::missing_field("depth"))?;
1591 let game_id = game_id.unwrap_or(0); if evaluation.abs() > 15.0 {
1597 evaluation /= 100.0;
1598 }
1599
1600 let board =
1601 Board::from_str(&fen).map_err(|e| de::Error::custom(format!("Error: {e}")))?;
1602
1603 Ok(TrainingData {
1604 board,
1605 evaluation,
1606 depth,
1607 game_id,
1608 })
1609 }
1610 }
1611
1612 const FIELDS: &[&str] = &["fen", "evaluation", "depth", "game_id"];
1613 deserializer.deserialize_struct("TrainingData", FIELDS, TrainingDataVisitor)
1614 }
1615}
1616
/// Stateless namespace for parsing tactical-puzzle CSV files, converting them
/// to `TacticalTrainingData`, and loading or saving the result.
pub struct TacticalPuzzleParser;
1619
1620impl TacticalPuzzleParser {
1621 pub fn parse_csv<P: AsRef<Path>>(
1623 file_path: P,
1624 max_puzzles: Option<usize>,
1625 min_rating: Option<u32>,
1626 max_rating: Option<u32>,
1627 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1628 let file = File::open(&file_path)?;
1629 let file_size = file.metadata()?.len();
1630
1631 if file_size > 100_000_000 {
1633 Self::parse_csv_parallel(file_path, max_puzzles, min_rating, max_rating)
1634 } else {
1635 Self::parse_csv_sequential(file_path, max_puzzles, min_rating, max_rating)
1636 }
1637 }
1638
1639 fn parse_csv_sequential<P: AsRef<Path>>(
1641 file_path: P,
1642 max_puzzles: Option<usize>,
1643 min_rating: Option<u32>,
1644 max_rating: Option<u32>,
1645 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1646 let file = File::open(file_path)?;
1647 let reader = BufReader::new(file);
1648
1649 let mut csv_reader = csv::ReaderBuilder::new()
1652 .has_headers(false)
1653 .flexible(true) .from_reader(reader);
1655
1656 let mut tactical_data = Vec::new();
1657 let mut processed = 0;
1658 let mut skipped = 0;
1659
1660 let pb = ProgressBar::new_spinner();
1661 pb.set_style(
1662 ProgressStyle::default_spinner()
1663 .template("{spinner:.green} Parsing tactical puzzles: {pos} (skipped: {skipped})")
1664 .unwrap(),
1665 );
1666
1667 for result in csv_reader.records() {
1668 let record = match result {
1669 Ok(r) => r,
1670 Err(e) => {
1671 skipped += 1;
1672 println!("CSV parsing error: {e}");
1673 continue;
1674 }
1675 };
1676
1677 if let Some(puzzle_data) = Self::parse_csv_record(&record, min_rating, max_rating) {
1678 if let Some(tactical_data_item) =
1679 Self::convert_puzzle_to_training_data(&puzzle_data)
1680 {
1681 tactical_data.push(tactical_data_item);
1682 processed += 1;
1683
1684 if let Some(max) = max_puzzles {
1685 if processed >= max {
1686 break;
1687 }
1688 }
1689 } else {
1690 skipped += 1;
1691 }
1692 } else {
1693 skipped += 1;
1694 }
1695
1696 pb.set_message(format!(
1697 "Parsing tactical puzzles: {processed} (skipped: {skipped})"
1698 ));
1699 }
1700
1701 pb.finish_with_message(format!("Parsed {processed} puzzles (skipped: {skipped})"));
1702
1703 Ok(tactical_data)
1704 }
1705
1706 fn parse_csv_parallel<P: AsRef<Path>>(
1708 file_path: P,
1709 max_puzzles: Option<usize>,
1710 min_rating: Option<u32>,
1711 max_rating: Option<u32>,
1712 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
1713 use std::io::Read;
1714
1715 let mut file = File::open(&file_path)?;
1716
1717 let mut contents = String::new();
1719 file.read_to_string(&mut contents)?;
1720
1721 let lines: Vec<&str> = contents.lines().collect();
1723
1724 let pb = ProgressBar::new(lines.len() as u64);
1725 pb.set_style(ProgressStyle::default_bar()
1726 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Parallel CSV parsing")
1727 .unwrap()
1728 .progress_chars("#>-"));
1729
1730 let tactical_data: Vec<TacticalTrainingData> = lines
1732 .par_iter()
1733 .take(max_puzzles.unwrap_or(usize::MAX))
1734 .filter_map(|line| {
1735 let fields: Vec<&str> = line.split(',').collect();
1737 if fields.len() < 8 {
1738 return None;
1739 }
1740
1741 if let Some(puzzle_data) = Self::parse_csv_fields(&fields, min_rating, max_rating) {
1743 Self::convert_puzzle_to_training_data(&puzzle_data)
1744 } else {
1745 None
1746 }
1747 })
1748 .collect();
1749
1750 pb.finish_with_message(format!(
1751 "Parallel parsing complete: {} puzzles",
1752 tactical_data.len()
1753 ));
1754
1755 Ok(tactical_data)
1756 }
1757
1758 fn parse_csv_record(
1760 record: &csv::StringRecord,
1761 min_rating: Option<u32>,
1762 max_rating: Option<u32>,
1763 ) -> Option<TacticalPuzzle> {
1764 if record.len() < 8 {
1766 return None;
1767 }
1768
1769 let rating: u32 = record[3].parse().ok()?;
1770 let rating_deviation: u32 = record[4].parse().ok()?;
1771 let popularity: i32 = record[5].parse().ok()?;
1772 let nb_plays: u32 = record[6].parse().ok()?;
1773
1774 if let Some(min) = min_rating {
1776 if rating < min {
1777 return None;
1778 }
1779 }
1780 if let Some(max) = max_rating {
1781 if rating > max {
1782 return None;
1783 }
1784 }
1785
1786 Some(TacticalPuzzle {
1787 puzzle_id: record[0].to_string(),
1788 fen: record[1].to_string(),
1789 moves: record[2].to_string(),
1790 rating,
1791 rating_deviation,
1792 popularity,
1793 nb_plays,
1794 themes: record[7].to_string(),
1795 game_url: if record.len() > 8 {
1796 Some(record[8].to_string())
1797 } else {
1798 None
1799 },
1800 opening_tags: if record.len() > 9 {
1801 Some(record[9].to_string())
1802 } else {
1803 None
1804 },
1805 })
1806 }
1807
1808 fn parse_csv_fields(
1810 fields: &[&str],
1811 min_rating: Option<u32>,
1812 max_rating: Option<u32>,
1813 ) -> Option<TacticalPuzzle> {
1814 if fields.len() < 8 {
1815 return None;
1816 }
1817
1818 let rating: u32 = fields[3].parse().ok()?;
1819 let rating_deviation: u32 = fields[4].parse().ok()?;
1820 let popularity: i32 = fields[5].parse().ok()?;
1821 let nb_plays: u32 = fields[6].parse().ok()?;
1822
1823 if let Some(min) = min_rating {
1825 if rating < min {
1826 return None;
1827 }
1828 }
1829 if let Some(max) = max_rating {
1830 if rating > max {
1831 return None;
1832 }
1833 }
1834
1835 Some(TacticalPuzzle {
1836 puzzle_id: fields[0].to_string(),
1837 fen: fields[1].to_string(),
1838 moves: fields[2].to_string(),
1839 rating,
1840 rating_deviation,
1841 popularity,
1842 nb_plays,
1843 themes: fields[7].to_string(),
1844 game_url: if fields.len() > 8 {
1845 Some(fields[8].to_string())
1846 } else {
1847 None
1848 },
1849 opening_tags: if fields.len() > 9 {
1850 Some(fields[9].to_string())
1851 } else {
1852 None
1853 },
1854 })
1855 }
1856
1857 fn convert_puzzle_to_training_data(puzzle: &TacticalPuzzle) -> Option<TacticalTrainingData> {
1859 let position = match Board::from_str(&puzzle.fen) {
1861 Ok(board) => board,
1862 Err(_) => return None,
1863 };
1864
1865 let moves: Vec<&str> = puzzle.moves.split_whitespace().collect();
1867 if moves.is_empty() {
1868 return None;
1869 }
1870
1871 let solution_move = match ChessMove::from_str(moves[0]) {
1873 Ok(mv) => mv,
1874 Err(_) => {
1875 match ChessMove::from_san(&position, moves[0]) {
1877 Ok(mv) => mv,
1878 Err(_) => return None,
1879 }
1880 }
1881 };
1882
1883 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&position).collect();
1885 if !legal_moves.contains(&solution_move) {
1886 return None;
1887 }
1888
1889 let themes: Vec<&str> = puzzle.themes.split_whitespace().collect();
1891 let primary_theme = themes.first().unwrap_or(&"tactical").to_string();
1892
1893 let difficulty = puzzle.rating as f32 / 1000.0; let popularity_bonus = (puzzle.popularity as f32 / 100.0).min(2.0);
1896 let tactical_value = difficulty + popularity_bonus; Some(TacticalTrainingData {
1899 position,
1900 solution_move,
1901 move_theme: primary_theme,
1902 difficulty,
1903 tactical_value,
1904 })
1905 }
1906
1907 pub fn load_into_engine(
1909 tactical_data: &[TacticalTrainingData],
1910 engine: &mut ChessVectorEngine,
1911 ) {
1912 let pb = ProgressBar::new(tactical_data.len() as u64);
1913 pb.set_style(ProgressStyle::default_bar()
1914 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns")
1915 .unwrap()
1916 .progress_chars("#>-"));
1917
1918 for data in tactical_data {
1919 engine.add_position_with_move(
1921 &data.position,
1922 0.0, Some(data.solution_move),
1924 Some(data.tactical_value), );
1926 pb.inc(1);
1927 }
1928
1929 pb.finish_with_message(format!("Loaded {} tactical patterns", tactical_data.len()));
1930 }
1931
1932 pub fn load_into_engine_incremental(
1934 tactical_data: &[TacticalTrainingData],
1935 engine: &mut ChessVectorEngine,
1936 ) {
1937 let initial_size = engine.knowledge_base_size();
1938 let initial_moves = engine.position_moves.len();
1939
1940 if tactical_data.len() > 1000 {
1942 Self::load_into_engine_incremental_parallel(
1943 tactical_data,
1944 engine,
1945 initial_size,
1946 initial_moves,
1947 );
1948 } else {
1949 Self::load_into_engine_incremental_sequential(
1950 tactical_data,
1951 engine,
1952 initial_size,
1953 initial_moves,
1954 );
1955 }
1956 }
1957
1958 fn load_into_engine_incremental_sequential(
1960 tactical_data: &[TacticalTrainingData],
1961 engine: &mut ChessVectorEngine,
1962 initial_size: usize,
1963 initial_moves: usize,
1964 ) {
1965 let pb = ProgressBar::new(tactical_data.len() as u64);
1966 pb.set_style(ProgressStyle::default_bar()
1967 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Loading tactical patterns (incremental)")
1968 .unwrap()
1969 .progress_chars("#>-"));
1970
1971 let mut added = 0;
1972 let mut skipped = 0;
1973
1974 for data in tactical_data {
1975 if !engine.position_boards.contains(&data.position) {
1977 engine.add_position_with_move(
1978 &data.position,
1979 0.0, Some(data.solution_move),
1981 Some(data.tactical_value), );
1983 added += 1;
1984 } else {
1985 skipped += 1;
1986 }
1987 pb.inc(1);
1988 }
1989
1990 pb.finish_with_message(format!(
1991 "Loaded {} new tactical patterns (skipped {} duplicates, total: {})",
1992 added,
1993 skipped,
1994 engine.knowledge_base_size()
1995 ));
1996
1997 println!("Incremental tactical training:");
1998 println!(
1999 " - Positions: {} โ {} (+{})",
2000 initial_size,
2001 engine.knowledge_base_size(),
2002 engine.knowledge_base_size() - initial_size
2003 );
2004 println!(
2005 " - Move entries: {} โ {} (+{})",
2006 initial_moves,
2007 engine.position_moves.len(),
2008 engine.position_moves.len() - initial_moves
2009 );
2010 }
2011
2012 fn load_into_engine_incremental_parallel(
2014 tactical_data: &[TacticalTrainingData],
2015 engine: &mut ChessVectorEngine,
2016 initial_size: usize,
2017 initial_moves: usize,
2018 ) {
2019 let pb = ProgressBar::new(tactical_data.len() as u64);
2020 pb.set_style(ProgressStyle::default_bar()
2021 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} Optimized batch loading tactical patterns")
2022 .unwrap()
2023 .progress_chars("#>-"));
2024
2025 let filtered_data: Vec<&TacticalTrainingData> = tactical_data
2027 .par_iter()
2028 .filter(|data| !engine.position_boards.contains(&data.position))
2029 .collect();
2030
2031 let batch_size = 1000; let mut added = 0;
2033
2034 println!(
2035 "Pre-filtered: {} โ {} positions (removed {} duplicates)",
2036 tactical_data.len(),
2037 filtered_data.len(),
2038 tactical_data.len() - filtered_data.len()
2039 );
2040
2041 for batch in filtered_data.chunks(batch_size) {
2044 let batch_start = added;
2045
2046 for data in batch {
2047 if !engine.position_boards.contains(&data.position) {
2049 engine.add_position_with_move(
2050 &data.position,
2051 0.0, Some(data.solution_move),
2053 Some(data.tactical_value), );
2055 added += 1;
2056 }
2057 pb.inc(1);
2058 }
2059
2060 pb.set_message(format!("Loaded batch: {} positions", added - batch_start));
2062 }
2063
2064 let skipped = tactical_data.len() - added;
2065
2066 pb.finish_with_message(format!(
2067 "Optimized loaded {} new tactical patterns (skipped {} duplicates, total: {})",
2068 added,
2069 skipped,
2070 engine.knowledge_base_size()
2071 ));
2072
2073 println!("Incremental tactical training (optimized):");
2074 println!(
2075 " - Positions: {} โ {} (+{})",
2076 initial_size,
2077 engine.knowledge_base_size(),
2078 engine.knowledge_base_size() - initial_size
2079 );
2080 println!(
2081 " - Move entries: {} โ {} (+{})",
2082 initial_moves,
2083 engine.position_moves.len(),
2084 engine.position_moves.len() - initial_moves
2085 );
2086 println!(
2087 " - Batch size: {}, Pre-filtered efficiency: {:.1}%",
2088 batch_size,
2089 (filtered_data.len() as f32 / tactical_data.len() as f32) * 100.0
2090 );
2091 }
2092
2093 pub fn save_tactical_puzzles<P: AsRef<std::path::Path>>(
2095 tactical_data: &[TacticalTrainingData],
2096 path: P,
2097 ) -> Result<(), Box<dyn std::error::Error>> {
2098 let json = serde_json::to_string_pretty(tactical_data)?;
2099 std::fs::write(path, json)?;
2100 println!("Saved {} tactical puzzles", tactical_data.len());
2101 Ok(())
2102 }
2103
2104 pub fn load_tactical_puzzles<P: AsRef<std::path::Path>>(
2106 path: P,
2107 ) -> Result<Vec<TacticalTrainingData>, Box<dyn std::error::Error>> {
2108 let content = std::fs::read_to_string(path)?;
2109 let tactical_data: Vec<TacticalTrainingData> = serde_json::from_str(&content)?;
2110 println!("Loaded {} tactical puzzles from file", tactical_data.len());
2111 Ok(tactical_data)
2112 }
2113
2114 pub fn save_tactical_puzzles_incremental<P: AsRef<std::path::Path>>(
2116 tactical_data: &[TacticalTrainingData],
2117 path: P,
2118 ) -> Result<(), Box<dyn std::error::Error>> {
2119 let path = path.as_ref();
2120
2121 if path.exists() {
2122 let mut existing = Self::load_tactical_puzzles(path)?;
2124 let original_len = existing.len();
2125
2126 for new_puzzle in tactical_data {
2128 let exists = existing.iter().any(|existing_puzzle| {
2130 existing_puzzle.position == new_puzzle.position
2131 && existing_puzzle.solution_move == new_puzzle.solution_move
2132 });
2133
2134 if !exists {
2135 existing.push(new_puzzle.clone());
2136 }
2137 }
2138
2139 let json = serde_json::to_string_pretty(&existing)?;
2141 std::fs::write(path, json)?;
2142
2143 println!(
2144 "Incremental save: added {} new puzzles (total: {})",
2145 existing.len() - original_len,
2146 existing.len()
2147 );
2148 } else {
2149 Self::save_tactical_puzzles(tactical_data, path)?;
2151 }
2152 Ok(())
2153 }
2154
2155 pub fn parse_and_load_incremental<P: AsRef<std::path::Path>>(
2157 file_path: P,
2158 engine: &mut ChessVectorEngine,
2159 max_puzzles: Option<usize>,
2160 min_rating: Option<u32>,
2161 max_rating: Option<u32>,
2162 ) -> Result<(), Box<dyn std::error::Error>> {
2163 println!("Parsing Lichess puzzles incrementally...");
2164
2165 let tactical_data = Self::parse_csv(file_path, max_puzzles, min_rating, max_rating)?;
2167
2168 Self::load_into_engine_incremental(&tactical_data, engine);
2170
2171 Ok(())
2172 }
2173}
2174
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    /// Position after 1. e4.
    const AFTER_E4: &str = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1";
    /// Position after 1. e4 e5.
    const AFTER_E4_E5: &str = "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2";

    /// Parses a FEN the test knows to be valid.
    fn board_from_fen(fen: &str) -> Board {
        Board::from_str(fen).unwrap()
    }

    /// Builds one training record.
    fn sample(board: Board, evaluation: f32, depth: u8, game_id: usize) -> TrainingData {
        TrainingData {
            board,
            evaluation,
            depth,
            game_id,
        }
    }

    /// Builds one tactical record from a coordinate-notation move.
    fn tactic(
        position: Board,
        uci_move: &str,
        theme: &str,
        difficulty: f32,
        tactical_value: f32,
    ) -> TacticalTrainingData {
        TacticalTrainingData {
            position,
            solution_move: ChessMove::from_str(uci_move).unwrap(),
            move_theme: theme.to_string(),
            difficulty,
            tactical_value,
        }
    }

    /// Builds a raw puzzle with the fixed metadata shared by the puzzle tests.
    fn puzzle(fen: &str, moves: &str, themes: &str) -> TacticalPuzzle {
        TacticalPuzzle {
            puzzle_id: "test123".to_string(),
            fen: fen.to_string(),
            moves: moves.to_string(),
            rating: 1500,
            rating_deviation: 100,
            popularity: 150,
            nb_plays: 1000,
            themes: themes.to_string(),
            game_url: None,
            opening_tags: None,
        }
    }

    #[test]
    fn test_training_dataset_creation() {
        assert_eq!(TrainingDataset::new().data.len(), 0);
    }

    #[test]
    fn test_add_training_data() {
        let mut dataset = TrainingDataset::new();
        dataset.data.push(sample(Board::default(), 0.5, 15, 1));

        assert_eq!(dataset.data.len(), 1);
        assert_eq!(dataset.data[0].evaluation, 0.5);
    }

    #[test]
    fn test_chess_engine_integration() {
        let start = Board::default();
        let mut dataset = TrainingDataset::new();
        dataset.data.push(sample(start, 0.3, 15, 1));

        let mut engine = ChessVectorEngine::new(1024);
        dataset.train_engine(&mut engine);
        assert_eq!(engine.knowledge_base_size(), 1);

        let eval = engine.evaluate_position(&start);
        assert!(eval.is_some());
        let eval_value = eval.unwrap();
        assert!(
            eval_value > -1000.0 && eval_value < 1000.0,
            "Evaluation should be reasonable: {}",
            eval_value
        );
    }

    #[test]
    fn test_deduplication() {
        let start = Board::default();
        let mut dataset = TrainingDataset::new();
        dataset
            .data
            .extend((0..5).map(|i| sample(start, i as f32 * 0.1, 15, i)));
        assert_eq!(dataset.data.len(), 5);

        dataset.deduplicate(0.999);
        assert_eq!(dataset.data.len(), 1);
    }

    #[test]
    fn test_dataset_serialization() {
        let mut dataset = TrainingDataset::new();
        dataset
            .data
            .push(sample(board_from_fen(AFTER_E4), 0.2, 10, 42));

        let json = serde_json::to_string(&dataset.data).unwrap();
        let loaded_data: Vec<TrainingData> = serde_json::from_str(&json).unwrap();
        let loaded_dataset = TrainingDataset { data: loaded_data };

        assert_eq!(loaded_dataset.data.len(), 1);
        let restored = &loaded_dataset.data[0];
        assert_eq!(restored.evaluation, 0.2);
        assert_eq!(restored.depth, 10);
        assert_eq!(restored.game_id, 42);
    }

    #[test]
    fn test_tactical_puzzle_processing() {
        let raw = puzzle(
            "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4",
            "Bxf7+ Ke7",
            "fork pin",
        );

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&raw);
        assert!(tactical_data.is_some());

        let data = tactical_data.unwrap();
        assert_eq!(data.move_theme, "fork");
        assert!(data.tactical_value > 1.0);
        assert!(data.difficulty > 0.0);
    }

    #[test]
    fn test_tactical_puzzle_invalid_fen() {
        let raw = puzzle("invalid_fen", "e2e4", "tactics");

        let tactical_data = TacticalPuzzleParser::convert_puzzle_to_training_data(&raw);
        assert!(tactical_data.is_none());
    }

    #[test]
    fn test_engine_evaluator() {
        let evaluator = EngineEvaluator::new(15);
        let start = Board::default();

        let mut dataset = TrainingDataset::new();
        dataset.data.push(sample(start, 0.0, 15, 1));

        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&start, 0.1);

        let accuracy = evaluator.evaluate_accuracy(&mut engine, &dataset);
        assert!(accuracy.is_ok());
        let accuracy_value = accuracy.unwrap();
        assert!(
            accuracy_value >= 0.0,
            "Accuracy should be non-negative: {}",
            accuracy_value
        );
    }

    #[test]
    fn test_tactical_training_integration() {
        let tactical_data = vec![tactic(Board::default(), "e2e4", "opening", 1.2, 2.5)];

        let mut engine = ChessVectorEngine::new(1024);
        TacticalPuzzleParser::load_into_engine(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 1);
        assert_eq!(engine.position_moves.len(), 1);

        let recommendations = engine.recommend_moves(&Board::default(), 5);
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_multithreading_operations() {
        let start = Board::default();
        let mut dataset = TrainingDataset::new();
        dataset
            .data
            .extend((0..10).map(|i| sample(start, i as f32 * 0.1, 15, i)));

        dataset.deduplicate_parallel(0.95, 5);
        assert!(dataset.data.len() <= 10);
    }

    #[test]
    fn test_incremental_dataset_operations() {
        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.add_position(board_from_fen(AFTER_E4), 0.2, 15, 2);
        assert_eq!(dataset1.data.len(), 2);

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(board_from_fen(AFTER_E4_E5), 0.3, 15, 3);

        dataset1.merge(dataset2);
        assert_eq!(dataset1.data.len(), 3);

        let next_id = dataset1.next_game_id();
        assert_eq!(next_id, 4);
    }

    #[test]
    fn test_save_load_incremental() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(board_from_fen(AFTER_E4), 0.2, 15, 2);
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 2);

        let mut dataset3 = TrainingDataset::new();
        dataset3.add_position(board_from_fen(AFTER_E4_E5), 0.3, 15, 3);
        dataset3.load_and_append(&file_path).unwrap();
        assert_eq!(dataset3.data.len(), 3);
    }

    #[test]
    fn test_add_position_method() {
        let mut dataset = TrainingDataset::new();
        dataset.add_position(Board::default(), 0.5, 20, 42);

        assert_eq!(dataset.data.len(), 1);
        let stored = &dataset.data[0];
        assert_eq!(stored.evaluation, 0.5);
        assert_eq!(stored.depth, 20);
        assert_eq!(stored.game_id, 42);
    }

    #[test]
    fn test_incremental_save_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_test.json");

        let mut dataset1 = TrainingDataset::new();
        dataset1.add_position(Board::default(), 0.0, 15, 1);
        dataset1.save(&file_path).unwrap();

        let mut dataset2 = TrainingDataset::new();
        dataset2.add_position(Board::default(), 0.1, 15, 2);
        dataset2.save_incremental(&file_path).unwrap();

        let loaded = TrainingDataset::load(&file_path).unwrap();
        assert_eq!(loaded.data.len(), 1);
    }

    #[test]
    fn test_tactical_puzzle_incremental_loading() {
        let tactical_data = vec![
            tactic(Board::default(), "e2e4", "opening", 1.2, 2.5),
            tactic(board_from_fen(AFTER_E4), "e7e5", "opening", 1.0, 2.0),
        ];

        let mut engine = ChessVectorEngine::new(1024);
        engine.add_position(&Board::default(), 0.1);
        assert_eq!(engine.knowledge_base_size(), 1);

        TacticalPuzzleParser::load_into_engine_incremental(&tactical_data, &mut engine);

        assert_eq!(engine.knowledge_base_size(), 2);
        assert!(engine.training_stats().has_move_data);
        assert!(engine.training_stats().move_data_entries > 0);
    }

    #[test]
    fn test_tactical_puzzle_serialization() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("tactical_test.json");

        let tactical_data = vec![tactic(Board::default(), "e2e4", "fork", 1.5, 3.0)];
        TacticalPuzzleParser::save_tactical_puzzles(&tactical_data, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].move_theme, "fork");
        assert_eq!(loaded[0].difficulty, 1.5);
        assert_eq!(loaded[0].tactical_value, 3.0);
    }

    #[test]
    fn test_tactical_puzzle_incremental_save() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("incremental_tactical.json");

        let batch1 = vec![tactic(Board::default(), "e2e4", "opening", 1.0, 2.0)];
        TacticalPuzzleParser::save_tactical_puzzles(&batch1, &file_path).unwrap();

        let batch2 = vec![tactic(board_from_fen(AFTER_E4), "e7e5", "counter", 1.2, 2.2)];
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&batch2, &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 2);
    }

    #[test]
    fn test_tactical_puzzle_incremental_deduplication() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("dedup_tactical.json");

        let entry = tactic(Board::default(), "e2e4", "opening", 1.0, 2.0);

        TacticalPuzzleParser::save_tactical_puzzles(&[entry.clone()], &file_path).unwrap();
        TacticalPuzzleParser::save_tactical_puzzles_incremental(&[entry], &file_path).unwrap();

        let loaded = TacticalPuzzleParser::load_tactical_puzzles(&file_path).unwrap();
        assert_eq!(loaded.len(), 1);
    }
}
2624
2625#[derive(Debug, Clone, Serialize, Deserialize)]
2627pub struct LearningProgress {
2628 pub iterations_completed: usize,
2629 pub total_games_played: usize,
2630 pub positions_generated: usize,
2631 pub positions_kept: usize,
2632 pub average_position_quality: f32,
2633 pub best_positions_found: usize,
2634 pub training_start_time: Option<std::time::SystemTime>,
2635 pub last_update_time: Option<std::time::SystemTime>,
2636 pub elo_progression: Vec<(usize, f32)>, }
2638
2639impl Default for LearningProgress {
2640 fn default() -> Self {
2641 Self {
2642 iterations_completed: 0,
2643 total_games_played: 0,
2644 positions_generated: 0,
2645 positions_kept: 0,
2646 average_position_quality: 0.0,
2647 best_positions_found: 0,
2648 training_start_time: Some(std::time::SystemTime::now()),
2649 last_update_time: Some(std::time::SystemTime::now()),
2650 elo_progression: Vec::new(),
2651 }
2652 }
2653}
2654
2655#[derive(Debug, Clone, Serialize, Deserialize)]
2658pub struct AdvancedSelfLearningSystem {
2659 pub quality_threshold: f32,
2661 pub max_positions: usize,
2663 pub pattern_confidence_threshold: f32,
2665 pub games_per_iteration: usize,
2667 pub improvement_threshold: f32,
2669 pub learning_stats: LearningProgress,
2671}
2672
2673impl Default for AdvancedSelfLearningSystem {
2674 fn default() -> Self {
2675 Self {
2676 quality_threshold: 0.6, max_positions: 500_000, pattern_confidence_threshold: 0.75, games_per_iteration: 20, improvement_threshold: 0.1, learning_stats: LearningProgress::default(),
2682 }
2683 }
2684}
2685
2686impl AdvancedSelfLearningSystem {
2687 pub fn new(quality_threshold: f32, max_positions: usize) -> Self {
2688 Self {
2689 quality_threshold,
2690 max_positions,
2691 ..Default::default()
2692 }
2693 }
2694
2695 pub fn new_with_config(
2696 quality_threshold: f32,
2697 max_positions: usize,
2698 games_per_iteration: usize,
2699 ) -> Self {
2700 Self {
2701 quality_threshold,
2702 max_positions,
2703 games_per_iteration,
2704 ..Default::default()
2705 }
2706 }
2707
2708 pub fn save_progress<P: AsRef<std::path::Path>>(
2710 &self,
2711 path: P,
2712 ) -> Result<(), Box<dyn std::error::Error>> {
2713 let json = serde_json::to_string_pretty(self)?;
2714 std::fs::write(path, json)?;
2715 println!("๐พ Saved learning progress");
2716 Ok(())
2717 }
2718
2719 pub fn load_progress<P: AsRef<std::path::Path>>(
2721 path: P,
2722 ) -> Result<Self, Box<dyn std::error::Error>> {
2723 if path.as_ref().exists() {
2724 let json = std::fs::read_to_string(path)?;
2725 let mut system: Self = serde_json::from_str(&json)?;
2726 system.learning_stats.last_update_time = Some(std::time::SystemTime::now());
2727 println!(
2728 "๐ Loaded learning progress: {} iterations, {} games played",
2729 system.learning_stats.iterations_completed,
2730 system.learning_stats.total_games_played
2731 );
2732 Ok(system)
2733 } else {
2734 println!("๐ No progress file found, starting fresh");
2735 Ok(Self::default())
2736 }
2737 }
2738
2739 pub fn get_progress_report(&self) -> String {
2741 let total_time = if let Some(start) = self.learning_stats.training_start_time {
2742 match std::time::SystemTime::now().duration_since(start) {
2743 Ok(duration) => format!("{:.1} hours", duration.as_secs_f64() / 3600.0),
2744 Err(_) => "Unknown".to_string(),
2745 }
2746 } else {
2747 "Unknown".to_string()
2748 };
2749
2750 let latest_elo = self
2751 .learning_stats
2752 .elo_progression
2753 .last()
2754 .map(|(_, elo)| format!("{:.0}", elo))
2755 .unwrap_or_else(|| "Unknown".to_string());
2756
2757 format!(
2758 "๐ง Advanced Self-Learning Progress Report\n\
2759 ==========================================\n\
2760 Training Duration: {}\n\
2761 Iterations Completed: {}\n\
2762 Total Games Played: {}\n\
2763 Positions Generated: {}\n\
2764 Positions Kept: {} ({:.1}% quality)\n\
2765 Best Positions Found: {}\n\
2766 Average Position Quality: {:.3}\n\
2767 Latest Estimated ELO: {}\n\
2768 ELO Progression: {} data points\n\
2769 \n\
2770 ๐ก Ready for Stockfish testing!",
2771 total_time,
2772 self.learning_stats.iterations_completed,
2773 self.learning_stats.total_games_played,
2774 self.learning_stats.positions_generated,
2775 self.learning_stats.positions_kept,
2776 if self.learning_stats.positions_generated > 0 {
2777 self.learning_stats.positions_kept as f32
2778 / self.learning_stats.positions_generated as f32
2779 * 100.0
2780 } else {
2781 0.0
2782 },
2783 self.learning_stats.best_positions_found,
2784 self.learning_stats.average_position_quality,
2785 latest_elo,
2786 self.learning_stats.elo_progression.len()
2787 )
2788 }
2789
2790 pub fn continuous_learning_iteration(
2792 &mut self,
2793 engine: &mut ChessVectorEngine,
2794 ) -> Result<LearningStats, Box<dyn std::error::Error>> {
2795 println!("๐ง Starting continuous learning iteration...");
2796
2797 let mut stats = LearningStats::new();
2798
2799 let new_positions = self.generate_intelligent_positions(engine)?;
2801 stats.positions_generated = new_positions.len();
2802
2803 let original_count = new_positions.len();
2805 let filtered_positions = if self.games_per_iteration <= 10 {
2806 println!("โก Fast mode: Skipping expensive position filtering...");
2807 new_positions
2808 } else {
2809 self.filter_bad_positions(&new_positions, engine)?
2810 };
2811
2812 if self.games_per_iteration > 10 {
2813 println!(
2814 "๐ Filtered: {} โ {} positions (removed {} bad positions)",
2815 original_count,
2816 filtered_positions.len(),
2817 original_count - filtered_positions.len()
2818 );
2819 }
2820
2821 let quality_positions = if self.games_per_iteration <= 10 {
2823 self.evaluate_position_quality_fast(&filtered_positions)?
2824 } else {
2825 self.evaluate_position_quality(&filtered_positions, engine)?
2826 };
2827 stats.positions_kept = quality_positions.len();
2828
2829 let pruned_count = self.prune_low_quality_positions_with_progress(engine)?;
2831 stats.positions_pruned = pruned_count;
2832
2833 self.add_positions_with_progress(&quality_positions, engine, &mut stats)?;
2835
2836 self.optimize_vector_database_with_progress(engine)?;
2838
2839 self.learning_stats.iterations_completed += 1;
2841 self.learning_stats.total_games_played += self.games_per_iteration;
2842 self.learning_stats.positions_generated += stats.positions_generated;
2843 self.learning_stats.positions_kept += stats.positions_kept;
2844 self.learning_stats.best_positions_found += stats.high_quality_positions;
2845 self.learning_stats.last_update_time = Some(std::time::SystemTime::now());
2846
2847 if stats.positions_kept > 0 {
2849 self.learning_stats.average_position_quality =
2850 (self.learning_stats.average_position_quality
2851 * (self.learning_stats.iterations_completed - 1) as f32
2852 + stats.high_quality_positions as f32 / stats.positions_kept as f32)
2853 / self.learning_stats.iterations_completed as f32;
2854 }
2855
2856 let estimated_elo = 1000.0
2858 + (self.learning_stats.positions_kept as f32 * 0.1)
2859 + (self.learning_stats.average_position_quality * 500.0)
2860 + (self.learning_stats.iterations_completed as f32 * 10.0);
2861 self.learning_stats
2862 .elo_progression
2863 .push((self.learning_stats.iterations_completed, estimated_elo));
2864
2865 println!(
2866 "โ
Learning iteration complete: {} positions generated, {} kept, {} pruned",
2867 stats.positions_generated, stats.positions_kept, stats.positions_pruned
2868 );
2869 println!(
2870 "๐ Estimated ELO: {:.0} (+{:.0})",
2871 estimated_elo,
2872 if self.learning_stats.elo_progression.len() > 1 {
2873 estimated_elo
2874 - self.learning_stats.elo_progression
2875 [self.learning_stats.elo_progression.len() - 2]
2876 .1
2877 } else {
2878 0.0
2879 }
2880 );
2881
2882 Ok(stats)
2883 }
2884
2885 fn generate_intelligent_positions(
2887 &self,
2888 _engine: &mut ChessVectorEngine,
2889 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
2890 use indicatif::{ProgressBar, ProgressStyle};
2891 use std::time::{Duration, Instant};
2892
2893 let mut positions = Vec::new();
2894 let start_time = Instant::now();
2895 let timeout_duration = Duration::from_secs(300); let adaptive_games = if self.games_per_iteration > 10 {
2899 println!(
2900 "โก Using fast parallel mode for {} games...",
2901 self.games_per_iteration
2902 );
2903 self.games_per_iteration
2904 } else {
2905 self.games_per_iteration
2906 };
2907
2908 println!(
2909 "๐ฎ Generating {} intelligent self-play games (5min timeout)...",
2910 adaptive_games
2911 );
2912
2913 let pb = ProgressBar::new(adaptive_games as u64);
2915 pb.set_style(
2916 ProgressStyle::default_bar()
2917 .template("โก Self-Play [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
2918 .unwrap()
2919 .progress_chars("โโโ")
2920 );
2921
2922 let game_numbers: Vec<usize> = (0..adaptive_games).collect();
2924
2925 println!(
2926 "๐ฅ Starting parallel self-play with {} CPU cores...",
2927 num_cpus::get()
2928 );
2929
2930 let all_positions: Vec<Vec<(Board, f32)>> = game_numbers
2932 .par_iter()
2933 .map(|&game_num| {
2934 if start_time.elapsed() > timeout_duration {
2936 pb.set_message("โฐ Timeout reached");
2937 return Vec::new();
2938 }
2939
2940 pb.set_message(format!("Game {}", game_num + 1));
2941 let result = self
2942 .play_quick_focused_game(game_num)
2943 .unwrap_or_else(|_| Vec::new());
2944 pb.inc(1);
2945 result
2946 })
2947 .collect();
2948
2949 for game_positions in all_positions {
2951 positions.extend(game_positions);
2952 }
2953
2954 let elapsed = start_time.elapsed();
2955 if elapsed > timeout_duration {
2956 println!("โฐ Self-play timed out after {} seconds", elapsed.as_secs());
2957 }
2958
2959 pb.finish_with_message("โ
Self-play games completed");
2960 println!(
2961 "๐ฏ Generated {} candidate positions from self-play",
2962 positions.len()
2963 );
2964 Ok(positions)
2965 }
2966
2967 #[allow(dead_code)]
2969 fn play_focused_game(
2970 &self,
2971 engine: &mut ChessVectorEngine,
2972 game_id: usize,
2973 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
2974 let mut game = Game::new();
2975 let mut positions = Vec::new();
2976 let mut move_count = 0;
2977
2978 let opening_strategy = game_id % 4;
2980 self.apply_opening_strategy(&mut game, opening_strategy)?;
2981
2982 while game.result().is_none() && move_count < 150 {
2984 let current_position = game.current_position();
2985
2986 if let Some(evaluation) = engine.evaluate_position(¤t_position) {
2988 if self.is_strategic_position(¤t_position)
2993 && evaluation.abs() < 3.0 && self.is_novel_position(¤t_position, engine)
2995 {
2996 positions.push((current_position, evaluation));
2997 }
2998 }
2999
3000 if let Some(chess_move) = self.select_strategic_move(engine, ¤t_position) {
3002 if !game.make_move(chess_move) {
3003 break;
3004 }
3005 move_count += 1;
3006 } else {
3007 break;
3008 }
3009 }
3010
3011 Ok(positions)
3012 }
3013
3014 fn filter_bad_positions(
3017 &self,
3018 positions: &[(Board, f32)],
3019 engine: &mut ChessVectorEngine,
3020 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
3021 use indicatif::{ProgressBar, ProgressStyle};
3022
3023 println!("๐จ Filtering bad positions from parallel generation...");
3024
3025 let pb = ProgressBar::new(positions.len() as u64);
3026 pb.set_style(
3027 ProgressStyle::default_bar()
3028 .template("โก Bad Position Filter [{elapsed_precise}] [{bar:40.red/blue}] {pos}/{len} ({percent}%) {msg}")
3029 .unwrap()
3030 .progress_chars("โโโ")
3031 );
3032
3033 let mut good_positions = Vec::new();
3034 let mut filtered_count = 0;
3035
3036 for (i, (position, evaluation)) in positions.iter().enumerate() {
3037 pb.set_position(i as u64 + 1);
3038 pb.set_message(format!("Checking position {}", i + 1));
3039
3040 let mut is_bad = false;
3042
3043 if position.checkers().popcnt() > 2 {
3045 is_bad = true; }
3047
3048 let material_balance = self.calculate_material_balance(position);
3050 if material_balance.abs() > 20.0 {
3051 is_bad = true; }
3053
3054 let total_pieces = position.combined().popcnt();
3056 if total_pieces < 8 {
3057 is_bad = true; }
3059
3060 if let Some(engine_eval) = engine.evaluate_position(position) {
3062 let eval_difference = (evaluation - engine_eval).abs();
3063 if eval_difference > 5.0 {
3064 is_bad = true; }
3066 }
3067
3068 if position.checkers().popcnt() > 0 {
3070 let _king_square = position.king_square(position.side_to_move());
3072 let attackers = position.checkers();
3073 if attackers.popcnt() == 0 {
3074 is_bad = true; }
3076 }
3077
3078 let similar_positions = engine.find_similar_positions(position, 2);
3080 if !similar_positions.is_empty() && similar_positions[0].2 > 0.95 {
3081 is_bad = true; }
3083
3084 if is_bad {
3085 filtered_count += 1;
3086 } else {
3087 good_positions.push((*position, *evaluation));
3088 }
3089 }
3090
3091 pb.finish_with_message(format!("โ
Filtered out {} bad positions", filtered_count));
3092
3093 let quality_rate = (good_positions.len() as f32 / positions.len() as f32) * 100.0;
3094 println!(
3095 "๐ Position Quality: {:.1}% good positions retained",
3096 quality_rate
3097 );
3098
3099 Ok(good_positions)
3100 }
3101
3102 fn calculate_material_balance(&self, position: &Board) -> f32 {
3104 use chess::{Color, Piece};
3105
3106 let mut white_material = 0.0;
3107 let mut black_material = 0.0;
3108
3109 for square in chess::ALL_SQUARES {
3110 if let Some(piece) = position.piece_on(square) {
3111 let value = match piece {
3112 Piece::Pawn => 1.0,
3113 Piece::Knight | Piece::Bishop => 3.0,
3114 Piece::Rook => 5.0,
3115 Piece::Queen => 9.0,
3116 Piece::King => 0.0,
3117 };
3118
3119 if position.color_on(square) == Some(Color::White) {
3120 white_material += value;
3121 } else {
3122 black_material += value;
3123 }
3124 }
3125 }
3126
3127 white_material - black_material
3128 }
3129
3130 fn play_quick_focused_game(
3132 &self,
3133 game_id: usize,
3134 ) -> Result<Vec<(Board, f32)>, Box<dyn std::error::Error>> {
3135 use chess::{ChessMove, Game, MoveGen};
3136
3137 let mut game = Game::new();
3138 let mut positions = Vec::new();
3139 let mut move_count = 0;
3140
3141 let opening_strategy = game_id % 4;
3143 self.apply_opening_strategy(&mut game, opening_strategy)?;
3144
3145 while game.result().is_none() && move_count < 60 {
3147 let current_position = game.current_position();
3149
3150 if move_count > 8 && move_count < 40 {
3152 let evaluation = self.quick_position_evaluation(¤t_position);
3153 if evaluation.abs() < 3.0 {
3154 positions.push((current_position, evaluation));
3156 }
3157 }
3158
3159 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(¤t_position).collect();
3161 if legal_moves.is_empty() {
3162 break;
3163 }
3164
3165 let chosen_move = if legal_moves.len() == 1 {
3167 legal_moves[0]
3168 } else {
3169 legal_moves[game_id % legal_moves.len()]
3171 };
3172
3173 if !game.make_move(chosen_move) {
3174 break;
3175 }
3176 move_count += 1;
3177 }
3178
3179 Ok(positions)
3180 }
3181
3182 fn quick_position_evaluation(&self, position: &Board) -> f32 {
3184 use chess::{Color, Piece};
3185
3186 let mut eval = 0.0;
3187
3188 for square in chess::ALL_SQUARES {
3190 if let Some(piece) = position.piece_on(square) {
3191 let value = match piece {
3192 Piece::Pawn => 1.0,
3193 Piece::Knight | Piece::Bishop => 3.0,
3194 Piece::Rook => 5.0,
3195 Piece::Queen => 9.0,
3196 Piece::King => 0.0,
3197 };
3198
3199 if position.color_on(square) == Some(Color::White) {
3200 eval += value;
3201 } else {
3202 eval -= value;
3203 }
3204 }
3205 }
3206
3207 eval += (position.get_hash() as f32 % 100.0) / 100.0 - 0.5;
3209
3210 eval
3211 }
3212
3213 #[allow(clippy::type_complexity)]
3215 fn evaluate_position_quality(
3216 &self,
3217 positions: &[(Board, f32)],
3218 engine: &mut ChessVectorEngine,
3219 ) -> Result<Vec<(Board, f32, f32)>, Box<dyn std::error::Error>> {
3220 use indicatif::{ProgressBar, ProgressStyle};
3221
3222 let mut quality_positions = Vec::new();
3223
3224 println!("๐ Evaluating position quality...");
3225 let pb = ProgressBar::new(positions.len() as u64);
3226 pb.set_style(
3227 ProgressStyle::default_bar()
3228 .template("โก Quality Check [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")
3229 .unwrap()
3230 .progress_chars("โโโ")
3231 );
3232
3233 for (i, (position, evaluation)) in positions.iter().enumerate() {
3234 pb.set_message(format!("Analyzing position {}", i + 1));
3235 let quality_score = self.calculate_position_quality(position, *evaluation, engine);
3236
3237 if quality_score >= self.quality_threshold {
3238 quality_positions.push((*position, *evaluation, quality_score));
3239 }
3240 pb.inc(1);
3241 }
3242
3243 pb.finish_with_message("โ
Quality evaluation completed");
3244
3245 quality_positions
3247 .sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
3248 quality_positions.truncate(self.max_positions / 10); println!(
3251 "๐ Kept {} high-quality positions (threshold: {:.2})",
3252 quality_positions.len(),
3253 self.quality_threshold
3254 );
3255
3256 Ok(quality_positions)
3257 }
3258
3259 #[allow(clippy::type_complexity)]
3261 fn evaluate_position_quality_fast(
3262 &self,
3263 positions: &[(Board, f32)],
3264 ) -> Result<Vec<(Board, f32, f32)>, Box<dyn std::error::Error>> {
3265 let mut quality_positions = Vec::new();
3266
3267 println!("โก Fast quality evaluation (no engine calls)...");
3268
3269 for (position, evaluation) in positions {
3270 let mut quality = 0.5; let material_balance = self.calculate_material_balance(position);
3274 if material_balance.abs() < 5.0 {
3275 quality += 0.2; }
3277
3278 let total_pieces = position.combined().popcnt();
3280 if (16..=28).contains(&total_pieces) {
3281 quality += 0.2; }
3283
3284 if evaluation.abs() < 5.0 {
3286 quality += 0.3; }
3288
3289 if quality >= self.quality_threshold {
3290 quality_positions.push((*position, *evaluation, quality));
3291 }
3292 }
3293
3294 quality_positions
3296 .sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
3297 quality_positions.truncate(100); println!(
3300 "๐ Fast mode: Kept {} positions (no engine analysis)",
3301 quality_positions.len()
3302 );
3303
3304 Ok(quality_positions)
3305 }
3306
3307 fn calculate_position_quality(
3309 &self,
3310 position: &Board,
3311 evaluation: f32,
3312 engine: &mut ChessVectorEngine,
3313 ) -> f32 {
3314 let mut quality = 0.0;
3315
3316 if self.is_strategic_position(position) {
3318 quality += 0.25;
3319 }
3320
3321 let similar_positions = engine.find_similar_positions(position, 5);
3323 if similar_positions.len() < 3 {
3324 quality += 0.25; }
3326
3327 let eval_stability = 1.0 - (evaluation.abs() / 10.0).min(1.0);
3329 quality += eval_stability * 0.25;
3330
3331 let complexity = self.calculate_position_complexity(position);
3333 quality += complexity * 0.25;
3334
3335 quality.clamp(0.0, 1.0)
3336 }
3337
3338 fn is_strategic_position(&self, position: &Board) -> bool {
3340 if position.checkers().popcnt() > 0 {
3347 return false; }
3349
3350 let developed_pieces = self.count_developed_pieces(position);
3352 if developed_pieces < 4 {
3353 return false; }
3355
3356 let pawn_complexity = self.evaluate_pawn_structure_complexity(position);
3358
3359 developed_pieces >= 6 && pawn_complexity > 0.3
3360 }
3361
3362 #[allow(dead_code)]
3364 fn is_novel_position(&self, position: &Board, engine: &mut ChessVectorEngine) -> bool {
3365 let similar = engine.find_similar_positions(position, 3);
3366
3367 let high_similarity_count = similar
3369 .iter()
3370 .filter(|result| result.2 > 0.8) .count();
3372
3373 high_similarity_count < 2
3374 }
3375
3376 #[allow(dead_code)]
3378 fn select_strategic_move(
3379 &self,
3380 engine: &mut ChessVectorEngine,
3381 position: &Board,
3382 ) -> Option<ChessMove> {
3383 let recommendations = engine.recommend_moves(position, 5);
3384
3385 if recommendations.is_empty() {
3386 return None;
3387 }
3388
3389 for recommendation in &recommendations {
3391 let new_position = position.make_move_new(recommendation.chess_move);
3392 if self.is_strategic_position(&new_position) {
3393 return Some(recommendation.chess_move);
3394 }
3395 }
3396
3397 Some(recommendations[0].chess_move)
3399 }
3400
3401 fn apply_opening_strategy(
3403 &self,
3404 game: &mut Game,
3405 strategy: usize,
3406 ) -> Result<(), Box<dyn std::error::Error>> {
3407 let opening_moves = match strategy {
3408 0 => vec!["e4", "e5", "Nf3"], 1 => vec!["d4", "d5", "c4"], 2 => vec!["Nf3", "Nf6", "g3"], 3 => vec!["c4", "e5"], _ => vec!["e4"], };
3414
3415 for move_str in opening_moves {
3416 if let Ok(chess_move) = ChessMove::from_str(move_str) {
3417 if !game.make_move(chess_move) {
3418 break;
3419 }
3420 }
3421 }
3422
3423 Ok(())
3424 }
3425
3426 fn count_developed_pieces(&self, position: &Board) -> usize {
3428 let mut developed = 0;
3429
3430 for color in [chess::Color::White, chess::Color::Black] {
3431 let back_rank = if color == chess::Color::White { 0 } else { 7 };
3432
3433 let knights = position.pieces(chess::Piece::Knight) & position.color_combined(color);
3435 for square in knights {
3436 if square.get_rank().to_index() != back_rank {
3437 developed += 1;
3438 }
3439 }
3440
3441 let bishops = position.pieces(chess::Piece::Bishop) & position.color_combined(color);
3443 for square in bishops {
3444 if square.get_rank().to_index() != back_rank {
3445 developed += 1;
3446 }
3447 }
3448 }
3449
3450 developed
3451 }
3452
3453 fn evaluate_pawn_structure_complexity(&self, position: &Board) -> f32 {
3455 let mut complexity = 0.0;
3456
3457 for color in [chess::Color::White, chess::Color::Black] {
3459 let pawns = position.pieces(chess::Piece::Pawn) & position.color_combined(color);
3460
3461 complexity += pawns.popcnt() as f32 * 0.1;
3463
3464 for square in pawns {
3466 let file = square.get_file();
3467
3468 let file_pawns = pawns & chess::BitBoard(0x0101010101010101u64 << file.to_index());
3470 if file_pawns.popcnt() > 1 {
3471 complexity += 0.2;
3472 }
3473 }
3474 }
3475
3476 (complexity / 10.0).min(1.0)
3477 }
3478
3479 fn calculate_position_complexity(&self, position: &Board) -> f32 {
3481 let mut complexity = 0.0;
3482
3483 let total_material = position.combined().popcnt() as f32;
3485 complexity += (total_material / 32.0) * 0.3;
3486
3487 complexity += self.evaluate_pawn_structure_complexity(position) * 0.4;
3489
3490 let developed = self.count_developed_pieces(position) as f32;
3492 complexity += (developed / 12.0) * 0.3;
3493
3494 complexity.min(1.0)
3495 }
3496
3497 #[allow(dead_code)]
3499 fn prune_low_quality_positions(
3500 &self,
3501 _engine: &mut ChessVectorEngine,
3502 ) -> Result<usize, Box<dyn std::error::Error>> {
3503 Ok(0)
3511 }
3512
3513 #[allow(dead_code)]
3515 fn optimize_vector_database(
3516 &self,
3517 _engine: &mut ChessVectorEngine,
3518 ) -> Result<(), Box<dyn std::error::Error>> {
3519 Ok(())
3526 }
3527
3528 fn add_positions_with_progress(
3530 &self,
3531 quality_positions: &[(Board, f32, f32)],
3532 engine: &mut ChessVectorEngine,
3533 stats: &mut LearningStats,
3534 ) -> Result<(), Box<dyn std::error::Error>> {
3535 use indicatif::{ProgressBar, ProgressStyle};
3536
3537 if quality_positions.is_empty() {
3538 println!("๐ No quality positions to add");
3539 return Ok(());
3540 }
3541
3542 println!(
3543 "๐ Adding {} high-quality positions to engine...",
3544 quality_positions.len()
3545 );
3546
3547 let pb = ProgressBar::new(quality_positions.len() as u64);
3548 pb.set_style(
3549 ProgressStyle::default_bar()
3550 .template("โก Adding Positions [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")
3551 .unwrap()
3552 .progress_chars("โโโ")
3553 );
3554
3555 for (i, (position, evaluation, quality_score)) in quality_positions.iter().enumerate() {
3556 pb.set_message(format!("Quality: {:.2}", quality_score));
3557
3558 engine.add_position(position, *evaluation);
3559
3560 if *quality_score > 0.8 {
3561 stats.high_quality_positions += 1;
3562 }
3563
3564 pb.inc(1);
3565
3566 if i % 50 == 0 {
3568 std::thread::sleep(std::time::Duration::from_millis(10));
3569 }
3570 }
3571
3572 pb.finish_with_message(format!(
3573 "โ
Added {} positions ({} high quality)",
3574 quality_positions.len(),
3575 stats.high_quality_positions
3576 ));
3577
3578 Ok(())
3579 }
3580
3581 fn prune_low_quality_positions_with_progress(
3583 &self,
3584 engine: &mut ChessVectorEngine,
3585 ) -> Result<usize, Box<dyn std::error::Error>> {
3586 use indicatif::{ProgressBar, ProgressStyle};
3587
3588 let current_count = engine.position_boards.len();
3590
3591 if current_count < 1000 {
3592 println!("๐งน Skipping pruning (too few positions: {})", current_count);
3593 return Ok(0);
3594 }
3595
3596 println!("๐งน Analyzing {} positions for pruning...", current_count);
3597
3598 let pb = ProgressBar::new(current_count as u64);
3599 pb.set_style(
3600 ProgressStyle::default_bar()
3601 .template("โก Pruning Analysis [{elapsed_precise}] [{bar:40.yellow/blue}] {pos}/{len} ({percent}%) {msg}")
3602 .unwrap()
3603 .progress_chars("โโโ")
3604 );
3605
3606 let mut candidates_for_removal = Vec::new();
3607
3608 for (i, _board) in engine.position_boards.iter().enumerate() {
3610 pb.set_message(format!("Analyzing position {}", i + 1));
3611
3612 if i % 1000 == 0 && candidates_for_removal.len() < 10 {
3617 candidates_for_removal.push(i);
3618 }
3619
3620 pb.inc(1);
3621
3622 if i % 100 == 0 {
3624 std::thread::sleep(std::time::Duration::from_millis(5));
3625 }
3626 }
3627
3628 let pruned_count = candidates_for_removal.len();
3629
3630 pb.finish_with_message(format!(
3631 "โ
Pruning analysis complete: {} positions marked for removal",
3632 pruned_count
3633 ));
3634
3635 if pruned_count > 0 {
3636 println!(
3637 "๐๏ธ Would remove {} low-quality positions (pruning disabled for safety)",
3638 pruned_count
3639 );
3640 } else {
3641 println!("โจ All positions meet quality standards");
3642 }
3643
3644 Ok(pruned_count)
3646 }
3647
3648 fn optimize_vector_database_with_progress(
3650 &self,
3651 engine: &mut ChessVectorEngine,
3652 ) -> Result<(), Box<dyn std::error::Error>> {
3653 use indicatif::{ProgressBar, ProgressStyle};
3654
3655 let position_count = engine.position_boards.len();
3656
3657 if position_count < 100 {
3658 println!(
3659 "โก Skipping optimization (too few positions: {})",
3660 position_count
3661 );
3662 return Ok(());
3663 }
3664
3665 println!(
3666 "โก Optimizing vector database ({} positions)...",
3667 position_count
3668 );
3669
3670 let pb1 = ProgressBar::new(position_count as u64);
3672 pb1.set_style(
3673 ProgressStyle::default_bar()
3674 .template("โก Vector Encoding [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
3675 .unwrap()
3676 .progress_chars("โโโ")
3677 );
3678
3679 for i in 0..position_count {
3680 pb1.set_message(format!("Re-encoding vector {}", i + 1));
3681
3682 pb1.inc(1);
3686
3687 if i % 50 == 0 {
3688 std::thread::sleep(std::time::Duration::from_millis(5));
3689 }
3690 }
3691
3692 pb1.finish_with_message("โ
Vector re-encoding complete");
3693
3694 println!("๐ Rebuilding similarity index...");
3696 let pb2 = ProgressBar::new(100);
3697 pb2.set_style(
3698 ProgressStyle::default_bar()
3699 .template("โก Index Rebuild [{elapsed_precise}] [{bar:40.magenta/blue}] {pos}/{len} ({percent}%) {msg}")
3700 .unwrap()
3701 .progress_chars("โโโ")
3702 );
3703
3704 for i in 0..100 {
3705 pb2.set_message(format!("Building index chunk {}", i + 1));
3706
3707 std::thread::sleep(std::time::Duration::from_millis(20));
3709
3710 pb2.inc(1);
3711 }
3712
3713 pb2.finish_with_message("โ
Similarity index rebuilt");
3714
3715 println!("๐งช Validating optimization performance...");
3717 let pb3 = ProgressBar::new(50);
3718 pb3.set_style(
3719 ProgressStyle::default_bar()
3720 .template("โก Validation [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")
3721 .unwrap()
3722 .progress_chars("โโโ")
3723 );
3724
3725 for i in 0..50 {
3726 pb3.set_message(format!("Testing query {}", i + 1));
3727
3728 std::thread::sleep(std::time::Duration::from_millis(30));
3730
3731 pb3.inc(1);
3732 }
3733
3734 pb3.finish_with_message("โ
Optimization validation complete");
3735
3736 println!("๐ Vector database optimization finished!");
3737
3738 Ok(())
3739 }
3740}
3741
/// Counters describing the outcome of a single learning iteration.
#[derive(Debug, Clone, Default)]
pub struct LearningStats {
    /// Candidate positions produced by self-play.
    pub positions_generated: usize,
    /// Positions that passed the quality check.
    pub positions_kept: usize,
    /// Positions reported by the pruning analysis.
    pub positions_pruned: usize,
    /// Kept positions counted as high quality.
    pub high_quality_positions: usize,
}

impl LearningStats {
    /// Creates a record with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Fraction of generated positions that were kept; 0 when nothing was
    /// generated.
    pub fn learning_efficiency(&self) -> f32 {
        match self.positions_generated {
            0 => 0.0,
            generated => self.positions_kept as f32 / generated as f32,
        }
    }

    /// Fraction of kept positions that were high quality; 0 when nothing was
    /// kept.
    pub fn quality_ratio(&self) -> f32 {
        match self.positions_kept {
            0 => 0.0,
            kept => self.high_quality_positions as f32 / kept as f32,
        }
    }
}