// chess_vector_engine/lichess_loader.rs

#![allow(clippy::type_complexity)]
use crate::TrainingData;
use chess::{Board, ChessMove, Color, MoveGen};
use rayon::prelude::*;
use serde::Deserialize;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::time::Instant;

13/// Lichess puzzle entry from CSV
14#[derive(Debug, Deserialize)]
15struct LichessPuzzle {
16    #[serde(rename = "PuzzleId")]
17    #[allow(dead_code)]
18    puzzle_id: String,
19    #[serde(rename = "FEN")]
20    #[allow(dead_code)]
21    fen: String,
22    #[serde(rename = "Moves")]
23    #[allow(dead_code)]
24    moves: String,
25    #[serde(rename = "Rating")]
26    #[allow(dead_code)]
27    rating: u32,
28    #[serde(rename = "RatingDeviation")]
29    #[allow(dead_code)]
30    rating_deviation: u32,
31    #[serde(rename = "Popularity")]
32    #[allow(dead_code)]
33    popularity: i32,
34    #[serde(rename = "NbPlays")]
35    #[allow(dead_code)]
36    nb_plays: u32,
37    #[serde(rename = "Themes")]
38    #[allow(dead_code)]
39    themes: String,
40    #[serde(rename = "GameUrl")]
41    #[allow(dead_code)]
42    game_url: String,
43}
44
/// High-performance Lichess puzzle database loader.
///
/// Build one with [`LichessLoader::new`] or [`LichessLoader::new_premium`], adjust it
/// with the `with_*` builder methods, then call one of the `load_parallel*` methods.
#[derive(Debug, Clone)]
pub struct LichessLoader {
    /// Minimum puzzle rating to include (inclusive).
    min_rating: u32,
    /// Maximum puzzle rating to include (inclusive).
    max_rating: u32,
    /// Number of CSV rows read before each parallel parsing pass.
    batch_size: usize,
    /// Number of worker threads in the rayon pool used for parsing.
    num_threads: usize,
    /// If set, keep only puzzles whose Themes column contains at least one of
    /// these strings (substring match), e.g. "fork", "pin", "mateIn2".
    theme_filter: Option<Vec<String>>,
}

59impl LichessLoader {
60    /// Create a new Lichess loader with default settings
61    pub fn new() -> Self {
62        Self {
63            min_rating: 800,                      // Exclude beginner puzzles
64            max_rating: 2800,                     // Include all reasonable puzzles
65            batch_size: 10_000,                   // Process 10k puzzles per batch
66            num_threads: num_cpus::get().min(16), // Use available cores (max 16)
67            theme_filter: None,
68        }
69    }
70
71    /// Create a premium loader with optimized settings
72    pub fn new_premium() -> Self {
73        Self {
74            min_rating: 1000,                     // Focus on intermediate+ puzzles
75            max_rating: 2500,                     // Exclude super-GM puzzles
76            batch_size: 50_000,                   // Larger batches for premium performance
77            num_threads: num_cpus::get().min(32), // Use more cores for premium
78            theme_filter: Some(vec![
79                "checkmate".to_string(),
80                "mateIn2".to_string(),
81                "mateIn3".to_string(),
82                "fork".to_string(),
83                "pin".to_string(),
84                "skewer".to_string(),
85                "discovery".to_string(),
86                "sacrifice".to_string(),
87                "deflection".to_string(),
88                "attraction".to_string(),
89            ]),
90        }
91    }
92
93    /// Set rating range filter
94    pub fn with_rating_range(mut self, min: u32, max: u32) -> Self {
95        self.min_rating = min;
96        self.max_rating = max;
97        self
98    }
99
100    /// Set theme filter
101    pub fn with_themes(mut self, themes: Vec<String>) -> Self {
102        self.theme_filter = Some(themes);
103        self
104    }
105
106    /// Set batch size for memory control
107    pub fn with_batch_size(mut self, size: usize) -> Self {
108        self.batch_size = size;
109        self
110    }
111
112    /// Load training data from Lichess puzzle CSV with lightning speed
113    pub fn load_parallel<P: AsRef<Path>>(
114        &self,
115        csv_path: P,
116    ) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
117        let start_time = Instant::now();
118        let path = csv_path.as_ref();
119
120        println!("🔥 Lightning-fast Lichess puzzle loader starting...");
121        println!("Loading from file: {}", path.display());
122        println!("⚡ Parallel processing with {} threads", self.num_threads);
123        println!("Processing data...");
124
125        // Configure rayon thread pool
126        let pool = rayon::ThreadPoolBuilder::new()
127            .num_threads(self.num_threads)
128            .build()?;
129
130        let results = Arc::new(Mutex::new(Vec::new()));
131        let total_processed = Arc::new(Mutex::new(0usize));
132        let valid_puzzles = Arc::new(Mutex::new(0usize));
133
134        // Read file in streaming chunks to control memory
135        let file = File::open(path)?;
136        let reader = BufReader::with_capacity(1024 * 1024, file); // 1MB buffer
137
138        // Skip header line
139        let mut lines = reader.lines();
140        lines.next(); // Skip CSV header
141
142        // Process in batches
143        let mut batch = Vec::with_capacity(self.batch_size);
144        let mut batch_count = 0;
145
146        for line in lines {
147            let line = line?;
148            batch.push(line);
149
150            if batch.len() >= self.batch_size {
151                batch_count += 1;
152                let batch_data = std::mem::take(&mut batch);
153
154                // Process batch in parallel
155                let batch_results = self.process_batch_parallel(&pool, batch_data)?;
156
157                // Accumulate results
158                {
159                    let mut results_guard = results.lock().unwrap();
160                    let mut processed_guard = total_processed.lock().unwrap();
161                    let mut valid_guard = valid_puzzles.lock().unwrap();
162
163                    *processed_guard += self.batch_size;
164                    *valid_guard += batch_results.len();
165                    results_guard.extend(batch_results);
166
167                    if batch_count % 10 == 0 {
168                        println!(
169                            "📈 Batch {}: Processed {}k puzzles, {} valid positions",
170                            batch_count,
171                            *processed_guard / 1000,
172                            *valid_guard
173                        );
174                    }
175                }
176
177                batch = Vec::with_capacity(self.batch_size);
178            }
179        }
180
181        // Process final partial batch
182        if !batch.is_empty() {
183            let batch_results = self.process_batch_parallel(&pool, batch)?;
184            let mut results_guard = results.lock().unwrap();
185            results_guard.extend(batch_results);
186        }
187
188        let final_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
189        let elapsed = start_time.elapsed();
190
191        println!("🎉 Lightning loading complete!");
192        println!("⏱️  Time: {:.2}s", elapsed.as_secs_f64());
193        println!("📊 Loaded {} training positions", final_results.len());
194        println!(
195            "🚀 Speed: {:.0} puzzles/second",
196            final_results.len() as f64 / elapsed.as_secs_f64()
197        );
198
199        Ok(final_results)
200    }
201
202    /// Load training data with moves from Lichess puzzle CSV for pattern recognition
203    pub fn load_parallel_with_moves<P: AsRef<Path>>(
204        &self,
205        csv_path: P,
206    ) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
207        let start_time = Instant::now();
208        let path = csv_path.as_ref();
209
210        println!("🧠 Lightning-fast Lichess puzzle loader (with moves) starting...");
211        println!("Loading from file: {}", path.display());
212        println!("⚡ Parallel processing with {} threads", self.num_threads);
213        println!("Processing data...");
214
215        // Configure rayon thread pool
216        let pool = rayon::ThreadPoolBuilder::new()
217            .num_threads(self.num_threads)
218            .build()?;
219
220        let results = Arc::new(Mutex::new(Vec::new()));
221        let total_processed = Arc::new(Mutex::new(0usize));
222        let valid_puzzles = Arc::new(Mutex::new(0usize));
223
224        // Read file in streaming chunks to control memory
225        let file = File::open(path)?;
226        let reader = BufReader::with_capacity(1024 * 1024, file); // 1MB buffer
227
228        // Skip header line
229        let mut lines = reader.lines();
230        lines.next(); // Skip CSV header
231
232        // Process in batches
233        let mut batch = Vec::with_capacity(self.batch_size);
234        let mut batch_count = 0;
235
236        for line in lines {
237            let line = line?;
238            batch.push(line);
239
240            if batch.len() >= self.batch_size {
241                batch_count += 1;
242                let batch_data = std::mem::take(&mut batch);
243
244                // Process batch in parallel
245                let batch_results = self.process_batch_parallel_with_moves(&pool, batch_data)?;
246
247                // Accumulate results
248                {
249                    let mut results_guard = results.lock().unwrap();
250                    let mut processed_guard = total_processed.lock().unwrap();
251                    let mut valid_guard = valid_puzzles.lock().unwrap();
252
253                    *processed_guard += self.batch_size;
254                    *valid_guard += batch_results.len();
255                    results_guard.extend(batch_results);
256
257                    if batch_count % 10 == 0 {
258                        println!(
259                            "📈 Batch {}: Processed {}k puzzles, {} valid moves",
260                            batch_count,
261                            *processed_guard / 1000,
262                            *valid_guard
263                        );
264                    }
265                }
266
267                batch = Vec::with_capacity(self.batch_size);
268            }
269        }
270
271        // Process final partial batch
272        if !batch.is_empty() {
273            let batch_results = self.process_batch_parallel_with_moves(&pool, batch)?;
274            let mut results_guard = results.lock().unwrap();
275            results_guard.extend(batch_results);
276        }
277
278        let final_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
279        let elapsed = start_time.elapsed();
280
281        println!("🎉 Lightning loading with moves complete!");
282        println!("⏱️  Time: {:.2}s", elapsed.as_secs_f64());
283        println!("🧠 Loaded {} tactical moves", final_results.len());
284        println!(
285            "🚀 Speed: {:.0} puzzles/second",
286            final_results.len() as f64 / elapsed.as_secs_f64()
287        );
288
289        Ok(final_results)
290    }
291
292    /// Process a batch of CSV lines in parallel
293    fn process_batch_parallel(
294        &self,
295        pool: &rayon::ThreadPool,
296        batch: Vec<String>,
297    ) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
298        let loader = self;
299        let batch_results: Vec<_> = pool.install(|| {
300            batch
301                .par_iter()
302                .filter_map(|line| loader.parse_puzzle_line(line).ok().flatten())
303                .collect()
304        });
305
306        Ok(batch_results)
307    }
308
309    /// Process a batch of CSV lines in parallel with moves for pattern recognition
310    fn process_batch_parallel_with_moves(
311        &self,
312        pool: &rayon::ThreadPool,
313        batch: Vec<String>,
314    ) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
315        let loader = self;
316        let batch_results: Vec<_> = pool.install(|| {
317            batch
318                .par_iter()
319                .filter_map(|line| loader.parse_puzzle_line_with_move(line).ok().flatten())
320                .collect()
321        });
322
323        Ok(batch_results)
324    }
325
326    /// Parse a single CSV line into training data
327    fn parse_puzzle_line(
328        &self,
329        line: &str,
330    ) -> Result<Option<TrainingData>, Box<dyn std::error::Error>> {
331        // Wrap entire parsing in panic protection to catch any chess library panics
332        match std::panic::catch_unwind(
333            || -> Result<Option<TrainingData>, Box<dyn std::error::Error>> {
334                // Use proper CSV parsing to handle quotes and commas correctly
335                let mut reader = csv::ReaderBuilder::new()
336                    .has_headers(false)
337                    .from_reader(line.as_bytes());
338
339                let record = match reader.records().next() {
340                    Some(Ok(record)) => record,
341                    _ => return Ok(None), // Skip malformed lines
342                };
343
344                if record.len() < 8 {
345                    return Ok(None); // Skip malformed lines
346                }
347
348                // Extract key fields by position with proper CSV parsing
349                let fen = record.get(1).unwrap_or("").trim();
350                let moves = record.get(2).unwrap_or("").trim();
351                let rating: u32 = record.get(3).unwrap_or("0").parse().unwrap_or(0);
352                let themes = record.get(7).unwrap_or("").trim();
353
354                // Apply filters for performance
355                if rating < self.min_rating || rating > self.max_rating {
356                    return Ok(None);
357                }
358
359                if let Some(ref theme_filter) = self.theme_filter {
360                    let has_target_theme = theme_filter.iter().any(|theme| themes.contains(theme));
361                    if !has_target_theme {
362                        return Ok(None);
363                    }
364                }
365
366                // Parse board position
367                let board = match Board::from_str(fen) {
368                    Ok(b) => b,
369                    Err(_) => return Ok(None), // Skip invalid FEN
370                };
371
372                // Parse moves and create training data
373                let move_sequence: Vec<&str> = moves.split_whitespace().collect();
374                if move_sequence.is_empty() {
375                    return Ok(None);
376                }
377
378                // Use the first move as the target move - validate it's legal and board is valid
379                let _target_move = match ChessMove::from_str(move_sequence[0]) {
380                    Ok(m) => {
381                        // Verify the move is legal for this position
382                        use chess::MoveGen;
383                        let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&board).collect();
384
385                        // Skip positions with no legal moves (checkmate/stalemate)
386                        if legal_moves.is_empty() {
387                            return Ok(None);
388                        }
389
390                        if legal_moves.contains(&m) {
391                            m
392                        } else {
393                            return Ok(None); // Skip illegal moves
394                        }
395                    }
396                    Err(_) => return Ok(None), // Skip invalid moves
397                };
398
399                // Calculate evaluation based on puzzle rating and themes
400                let evaluation = self.calculate_puzzle_evaluation(rating, themes, &board);
401
402                Ok(Some(TrainingData {
403                    board,
404                    evaluation,
405                    depth: 1,                 // Puzzle depth
406                    game_id: rating as usize, // Use puzzle rating as game_id for uniqueness
407                }))
408            },
409        ) {
410            Ok(result) => result,
411            Err(_) => Ok(None), // Skip any position that causes panic
412        }
413    }
414
415    /// Calculate position evaluation based on puzzle characteristics
416    fn calculate_puzzle_evaluation(&self, rating: u32, themes: &str, board: &Board) -> f32 {
417        let mut eval = 0.0;
418
419        // Base evaluation from puzzle difficulty (normalized to pawn units)
420        eval += (rating as f32 - 1500.0) / 1000.0; // Much smaller scaling for reasonable range
421
422        // Moderate tactical theme adjustments (in pawn units)
423        if themes.contains("checkmate") || themes.contains("mateIn") {
424            eval += if board.side_to_move() == Color::White {
425                5.0 // Strong advantage (5 pawns)
426            } else {
427                -5.0 // Strong disadvantage (5 pawns)
428            };
429        } else if themes.contains("fork") || themes.contains("pin") {
430            eval += if board.side_to_move() == Color::White {
431                2.0 // Moderate advantage (2 pawns)
432            } else {
433                -2.0 // Moderate disadvantage (2 pawns)
434            };
435        } else if themes.contains("sacrifice") {
436            eval += if board.side_to_move() == Color::White {
437                1.5 // Small advantage (1.5 pawns)
438            } else {
439                -1.5 // Small disadvantage (1.5 pawns)
440            };
441        }
442
443        // Clamp to reasonable evaluation range
444        eval.clamp(-8.0, 8.0)
445    }
446
447    /// Parse a single CSV line into position, evaluation, and best move
448    fn parse_puzzle_line_with_move(
449        &self,
450        line: &str,
451    ) -> Result<Option<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
452        // Wrap entire parsing in panic protection to catch any chess library panics
453        match std::panic::catch_unwind(
454            || -> Result<Option<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
455                // Use proper CSV parsing to handle quotes and commas correctly
456                let mut reader = csv::ReaderBuilder::new()
457                    .has_headers(false)
458                    .from_reader(line.as_bytes());
459
460                let record = match reader.records().next() {
461                    Some(Ok(record)) => record,
462                    _ => return Ok(None), // Skip malformed lines
463                };
464
465                if record.len() < 8 {
466                    return Ok(None); // Skip malformed lines
467                }
468
469                // Extract key fields by position with proper CSV parsing
470                let fen = record.get(1).unwrap_or("").trim();
471                let moves = record.get(2).unwrap_or("").trim();
472                let rating: u32 = record.get(3).unwrap_or("0").parse().unwrap_or(0);
473                let themes = record.get(7).unwrap_or("").trim();
474
475                // Apply filters for performance
476                if rating < self.min_rating || rating > self.max_rating {
477                    return Ok(None);
478                }
479
480                if let Some(ref theme_filter) = self.theme_filter {
481                    let has_target_theme = theme_filter.iter().any(|theme| themes.contains(theme));
482                    if !has_target_theme {
483                        return Ok(None);
484                    }
485                }
486
487                // Parse board position
488                let board = match Board::from_str(fen) {
489                    Ok(b) => b,
490                    Err(_) => return Ok(None), // Skip invalid FEN
491                };
492
493                // Parse moves and create training data
494                let move_sequence: Vec<&str> = moves.split_whitespace().collect();
495                if move_sequence.is_empty() {
496                    return Ok(None);
497                }
498
499                // Use the first move as the target move - validate it's legal and board is valid
500                let target_move = match ChessMove::from_str(move_sequence[0]) {
501                    Ok(m) => {
502                        // Verify the move is legal for this position
503                        use chess::MoveGen;
504                        let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&board).collect();
505
506                        // Skip positions with no legal moves (checkmate/stalemate)
507                        if legal_moves.is_empty() {
508                            return Ok(None);
509                        }
510
511                        if legal_moves.contains(&m) {
512                            m
513                        } else {
514                            // For debugging: could add logging here to track rejection reasons
515                            return Ok(None); // Skip illegal moves
516                        }
517                    }
518                    Err(_) => return Ok(None), // Skip invalid moves
519                };
520
521                // Calculate evaluation based on puzzle rating and themes
522                let evaluation = self.calculate_puzzle_evaluation(rating, themes, &board);
523
524                Ok(Some((board, evaluation, target_move)))
525            },
526        ) {
527            Ok(result) => result,
528            Err(_) => Ok(None), // Skip any position that causes panic
529        }
530    }
531}
532
533impl Default for LichessLoader {
534    fn default() -> Self {
535        Self::new()
536    }
537}
538
539/// Premium feature: Load Lichess puzzles with maximum performance
540pub fn load_lichess_puzzles_premium<P: AsRef<Path>>(
541    csv_path: P,
542) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
543    let loader = LichessLoader::new_premium()
544        .with_rating_range(1200, 2400) // Focus on strong tactical puzzles
545        .with_batch_size(100_000); // Large batches for premium speed
546
547    loader.load_parallel(csv_path)
548}
549
550/// Open source feature: Load limited Lichess puzzles
551pub fn load_lichess_puzzles_basic<P: AsRef<Path>>(
552    csv_path: P,
553    max_puzzles: usize,
554) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
555    let loader = LichessLoader::new()
556        .with_rating_range(1000, 2000) // Basic tactical puzzles
557        .with_batch_size(10_000); // Smaller batches for basic tier
558
559    let mut results = loader.load_parallel(csv_path)?;
560    results.truncate(max_puzzles); // Limit for open source
561    Ok(results)
562}
563
564/// Premium feature: Load Lichess puzzles with moves for pattern recognition
565pub fn load_lichess_puzzles_premium_with_moves<P: AsRef<Path>>(
566    csv_path: P,
567) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
568    let loader = LichessLoader::new_premium()
569        .with_rating_range(1200, 2400) // Focus on strong tactical puzzles
570        .with_batch_size(100_000); // Large batches for premium speed
571
572    loader.load_parallel_with_moves(csv_path)
573}
574
575/// Open source feature: Load limited Lichess puzzles with moves
576pub fn load_lichess_puzzles_basic_with_moves<P: AsRef<Path>>(
577    csv_path: P,
578    max_puzzles: usize,
579) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
580    let loader = LichessLoader::new()
581        .with_rating_range(1000, 2000) // Basic tactical puzzles
582        .with_batch_size(10_000); // Smaller batches for basic tier
583
584    let mut results = loader.load_parallel_with_moves(csv_path)?;
585    results.truncate(max_puzzles); // Limit for open source
586    Ok(results)
587}