1#![allow(clippy::type_complexity)]
2use crate::TrainingData;
3use chess::{Board, ChessMove, Color};
4use rayon::prelude::*;
5use serde::Deserialize;
6use std::fs::File;
7use std::io::{BufRead, BufReader};
8use std::path::Path;
9use std::str::FromStr;
10use std::sync::{Arc, Mutex};
11use std::time::Instant;
12
13#[derive(Debug, Deserialize)]
15struct LichessPuzzle {
16 #[serde(rename = "PuzzleId")]
17 #[allow(dead_code)]
18 puzzle_id: String,
19 #[serde(rename = "FEN")]
20 #[allow(dead_code)]
21 fen: String,
22 #[serde(rename = "Moves")]
23 #[allow(dead_code)]
24 moves: String,
25 #[serde(rename = "Rating")]
26 #[allow(dead_code)]
27 rating: u32,
28 #[serde(rename = "RatingDeviation")]
29 #[allow(dead_code)]
30 rating_deviation: u32,
31 #[serde(rename = "Popularity")]
32 #[allow(dead_code)]
33 popularity: i32,
34 #[serde(rename = "NbPlays")]
35 #[allow(dead_code)]
36 nb_plays: u32,
37 #[serde(rename = "Themes")]
38 #[allow(dead_code)]
39 themes: String,
40 #[serde(rename = "GameUrl")]
41 #[allow(dead_code)]
42 game_url: String,
43}
44
/// Configurable, batch-oriented loader for Lichess puzzle CSV files.
///
/// Rows are filtered by an inclusive rating range and, optionally, by theme
/// before being converted into training samples.
#[derive(Debug, Clone)]
pub struct LichessLoader {
    /// Lowest puzzle rating that is accepted (inclusive).
    min_rating: u32,
    /// Highest puzzle rating that is accepted (inclusive).
    max_rating: u32,
    /// Number of CSV lines handed to the thread pool at once.
    batch_size: usize,
    /// Worker count of the rayon thread pool built for each load.
    num_threads: usize,
    /// When set, a row is kept only if its theme column contains at least one
    /// of these strings as a substring.
    theme_filter: Option<Vec<String>>,
}
58
59impl LichessLoader {
60 pub fn new() -> Self {
62 Self {
63 min_rating: 800, max_rating: 2800, batch_size: 10_000, num_threads: num_cpus::get().min(16), theme_filter: None,
68 }
69 }
70
71 pub fn new_premium() -> Self {
73 Self {
74 min_rating: 1000, max_rating: 2500, batch_size: 50_000, num_threads: num_cpus::get().min(32), theme_filter: Some(vec![
79 "checkmate".to_string(),
80 "mateIn2".to_string(),
81 "mateIn3".to_string(),
82 "fork".to_string(),
83 "pin".to_string(),
84 "skewer".to_string(),
85 "discovery".to_string(),
86 "sacrifice".to_string(),
87 "deflection".to_string(),
88 "attraction".to_string(),
89 ]),
90 }
91 }
92
93 pub fn with_rating_range(mut self, min: u32, max: u32) -> Self {
95 self.min_rating = min;
96 self.max_rating = max;
97 self
98 }
99
100 pub fn with_themes(mut self, themes: Vec<String>) -> Self {
102 self.theme_filter = Some(themes);
103 self
104 }
105
106 pub fn with_batch_size(mut self, size: usize) -> Self {
108 self.batch_size = size;
109 self
110 }
111
112 pub fn load_parallel<P: AsRef<Path>>(
114 &self,
115 csv_path: P,
116 ) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
117 let start_time = Instant::now();
118 let path = csv_path.as_ref();
119
120 println!("🔥 Lightning-fast Lichess puzzle loader starting...");
121 println!("Loading from file: {}", path.display());
122 println!("⚡ Parallel processing with {} threads", self.num_threads);
123 println!("Processing data...");
124
125 let pool = rayon::ThreadPoolBuilder::new()
127 .num_threads(self.num_threads)
128 .build()?;
129
130 let results = Arc::new(Mutex::new(Vec::new()));
131 let total_processed = Arc::new(Mutex::new(0usize));
132 let valid_puzzles = Arc::new(Mutex::new(0usize));
133
134 let file = File::open(path)?;
136 let reader = BufReader::with_capacity(1024 * 1024, file); let mut lines = reader.lines();
140 lines.next(); let mut batch = Vec::with_capacity(self.batch_size);
144 let mut batch_count = 0;
145
146 for line in lines {
147 let line = line?;
148 batch.push(line);
149
150 if batch.len() >= self.batch_size {
151 batch_count += 1;
152 let batch_data = std::mem::take(&mut batch);
153
154 let batch_results = self.process_batch_parallel(&pool, batch_data)?;
156
157 {
159 let mut results_guard = results.lock().unwrap();
160 let mut processed_guard = total_processed.lock().unwrap();
161 let mut valid_guard = valid_puzzles.lock().unwrap();
162
163 *processed_guard += self.batch_size;
164 *valid_guard += batch_results.len();
165 results_guard.extend(batch_results);
166
167 if batch_count % 10 == 0 {
168 println!(
169 "📈 Batch {}: Processed {}k puzzles, {} valid positions",
170 batch_count,
171 *processed_guard / 1000,
172 *valid_guard
173 );
174 }
175 }
176
177 batch = Vec::with_capacity(self.batch_size);
178 }
179 }
180
181 if !batch.is_empty() {
183 let batch_results = self.process_batch_parallel(&pool, batch)?;
184 let mut results_guard = results.lock().unwrap();
185 results_guard.extend(batch_results);
186 }
187
188 let final_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
189 let elapsed = start_time.elapsed();
190
191 println!("🎉 Lightning loading complete!");
192 println!("⏱️ Time: {:.2}s", elapsed.as_secs_f64());
193 println!("📊 Loaded {} training positions", final_results.len());
194 println!(
195 "🚀 Speed: {:.0} puzzles/second",
196 final_results.len() as f64 / elapsed.as_secs_f64()
197 );
198
199 Ok(final_results)
200 }
201
202 pub fn load_parallel_with_moves<P: AsRef<Path>>(
204 &self,
205 csv_path: P,
206 ) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
207 let start_time = Instant::now();
208 let path = csv_path.as_ref();
209
210 println!("🧠 Lightning-fast Lichess puzzle loader (with moves) starting...");
211 println!("Loading from file: {}", path.display());
212 println!("⚡ Parallel processing with {} threads", self.num_threads);
213 println!("Processing data...");
214
215 let pool = rayon::ThreadPoolBuilder::new()
217 .num_threads(self.num_threads)
218 .build()?;
219
220 let results = Arc::new(Mutex::new(Vec::new()));
221 let total_processed = Arc::new(Mutex::new(0usize));
222 let valid_puzzles = Arc::new(Mutex::new(0usize));
223
224 let file = File::open(path)?;
226 let reader = BufReader::with_capacity(1024 * 1024, file); let mut lines = reader.lines();
230 lines.next(); let mut batch = Vec::with_capacity(self.batch_size);
234 let mut batch_count = 0;
235
236 for line in lines {
237 let line = line?;
238 batch.push(line);
239
240 if batch.len() >= self.batch_size {
241 batch_count += 1;
242 let batch_data = std::mem::take(&mut batch);
243
244 let batch_results = self.process_batch_parallel_with_moves(&pool, batch_data)?;
246
247 {
249 let mut results_guard = results.lock().unwrap();
250 let mut processed_guard = total_processed.lock().unwrap();
251 let mut valid_guard = valid_puzzles.lock().unwrap();
252
253 *processed_guard += self.batch_size;
254 *valid_guard += batch_results.len();
255 results_guard.extend(batch_results);
256
257 if batch_count % 10 == 0 {
258 println!(
259 "📈 Batch {}: Processed {}k puzzles, {} valid moves",
260 batch_count,
261 *processed_guard / 1000,
262 *valid_guard
263 );
264 }
265 }
266
267 batch = Vec::with_capacity(self.batch_size);
268 }
269 }
270
271 if !batch.is_empty() {
273 let batch_results = self.process_batch_parallel_with_moves(&pool, batch)?;
274 let mut results_guard = results.lock().unwrap();
275 results_guard.extend(batch_results);
276 }
277
278 let final_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
279 let elapsed = start_time.elapsed();
280
281 println!("🎉 Lightning loading with moves complete!");
282 println!("⏱️ Time: {:.2}s", elapsed.as_secs_f64());
283 println!("🧠 Loaded {} tactical moves", final_results.len());
284 println!(
285 "🚀 Speed: {:.0} puzzles/second",
286 final_results.len() as f64 / elapsed.as_secs_f64()
287 );
288
289 Ok(final_results)
290 }
291
292 fn process_batch_parallel(
294 &self,
295 pool: &rayon::ThreadPool,
296 batch: Vec<String>,
297 ) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
298 let loader = self;
299 let batch_results: Vec<_> = pool.install(|| {
300 batch
301 .par_iter()
302 .filter_map(|line| loader.parse_puzzle_line(line).ok().flatten())
303 .collect()
304 });
305
306 Ok(batch_results)
307 }
308
309 fn process_batch_parallel_with_moves(
311 &self,
312 pool: &rayon::ThreadPool,
313 batch: Vec<String>,
314 ) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
315 let loader = self;
316 let batch_results: Vec<_> = pool.install(|| {
317 batch
318 .par_iter()
319 .filter_map(|line| loader.parse_puzzle_line_with_move(line).ok().flatten())
320 .collect()
321 });
322
323 Ok(batch_results)
324 }
325
326 fn parse_puzzle_line(
328 &self,
329 line: &str,
330 ) -> Result<Option<TrainingData>, Box<dyn std::error::Error>> {
331 match std::panic::catch_unwind(
333 || -> Result<Option<TrainingData>, Box<dyn std::error::Error>> {
334 let mut reader = csv::ReaderBuilder::new()
336 .has_headers(false)
337 .from_reader(line.as_bytes());
338
339 let record = match reader.records().next() {
340 Some(Ok(record)) => record,
341 _ => return Ok(None), };
343
344 if record.len() < 8 {
345 return Ok(None); }
347
348 let fen = record.get(1).unwrap_or("").trim();
350 let moves = record.get(2).unwrap_or("").trim();
351 let rating: u32 = record.get(3).unwrap_or("0").parse().unwrap_or(0);
352 let themes = record.get(7).unwrap_or("").trim();
353
354 if rating < self.min_rating || rating > self.max_rating {
356 return Ok(None);
357 }
358
359 if let Some(ref theme_filter) = self.theme_filter {
360 let has_target_theme = theme_filter.iter().any(|theme| themes.contains(theme));
361 if !has_target_theme {
362 return Ok(None);
363 }
364 }
365
366 let board = match Board::from_str(fen) {
368 Ok(b) => b,
369 Err(_) => return Ok(None), };
371
372 let move_sequence: Vec<&str> = moves.split_whitespace().collect();
374 if move_sequence.is_empty() {
375 return Ok(None);
376 }
377
378 let _target_move = match ChessMove::from_str(move_sequence[0]) {
380 Ok(m) => {
381 use chess::MoveGen;
383 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&board).collect();
384
385 if legal_moves.is_empty() {
387 return Ok(None);
388 }
389
390 if legal_moves.contains(&m) {
391 m
392 } else {
393 return Ok(None); }
395 }
396 Err(_) => return Ok(None), };
398
399 let evaluation = self.calculate_puzzle_evaluation(rating, themes, &board);
401
402 Ok(Some(TrainingData {
403 board,
404 evaluation,
405 depth: 1, game_id: rating as usize, }))
408 },
409 ) {
410 Ok(result) => result,
411 Err(_) => Ok(None), }
413 }
414
415 fn calculate_puzzle_evaluation(&self, rating: u32, themes: &str, board: &Board) -> f32 {
417 let mut eval = 0.0;
418
419 eval += (rating as f32 - 1500.0) / 1000.0; if themes.contains("checkmate") || themes.contains("mateIn") {
424 eval += if board.side_to_move() == Color::White {
425 5.0 } else {
427 -5.0 };
429 } else if themes.contains("fork") || themes.contains("pin") {
430 eval += if board.side_to_move() == Color::White {
431 2.0 } else {
433 -2.0 };
435 } else if themes.contains("sacrifice") {
436 eval += if board.side_to_move() == Color::White {
437 1.5 } else {
439 -1.5 };
441 }
442
443 eval.clamp(-8.0, 8.0)
445 }
446
447 fn parse_puzzle_line_with_move(
449 &self,
450 line: &str,
451 ) -> Result<Option<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
452 match std::panic::catch_unwind(
454 || -> Result<Option<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
455 let mut reader = csv::ReaderBuilder::new()
457 .has_headers(false)
458 .from_reader(line.as_bytes());
459
460 let record = match reader.records().next() {
461 Some(Ok(record)) => record,
462 _ => return Ok(None), };
464
465 if record.len() < 8 {
466 return Ok(None); }
468
469 let fen = record.get(1).unwrap_or("").trim();
471 let moves = record.get(2).unwrap_or("").trim();
472 let rating: u32 = record.get(3).unwrap_or("0").parse().unwrap_or(0);
473 let themes = record.get(7).unwrap_or("").trim();
474
475 if rating < self.min_rating || rating > self.max_rating {
477 return Ok(None);
478 }
479
480 if let Some(ref theme_filter) = self.theme_filter {
481 let has_target_theme = theme_filter.iter().any(|theme| themes.contains(theme));
482 if !has_target_theme {
483 return Ok(None);
484 }
485 }
486
487 let board = match Board::from_str(fen) {
489 Ok(b) => b,
490 Err(_) => return Ok(None), };
492
493 let move_sequence: Vec<&str> = moves.split_whitespace().collect();
495 if move_sequence.is_empty() {
496 return Ok(None);
497 }
498
499 let target_move = match ChessMove::from_str(move_sequence[0]) {
501 Ok(m) => {
502 use chess::MoveGen;
504 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&board).collect();
505
506 if legal_moves.is_empty() {
508 return Ok(None);
509 }
510
511 if legal_moves.contains(&m) {
512 m
513 } else {
514 return Ok(None); }
517 }
518 Err(_) => return Ok(None), };
520
521 let evaluation = self.calculate_puzzle_evaluation(rating, themes, &board);
523
524 Ok(Some((board, evaluation, target_move)))
525 },
526 ) {
527 Ok(result) => result,
528 Err(_) => Ok(None), }
530 }
531}
532
533impl Default for LichessLoader {
534 fn default() -> Self {
535 Self::new()
536 }
537}
538
539pub fn load_lichess_puzzles_premium<P: AsRef<Path>>(
541 csv_path: P,
542) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
543 let loader = LichessLoader::new_premium()
544 .with_rating_range(1200, 2400) .with_batch_size(100_000); loader.load_parallel(csv_path)
548}
549
550pub fn load_lichess_puzzles_basic<P: AsRef<Path>>(
552 csv_path: P,
553 max_puzzles: usize,
554) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
555 let loader = LichessLoader::new()
556 .with_rating_range(1000, 2000) .with_batch_size(10_000); let mut results = loader.load_parallel(csv_path)?;
560 results.truncate(max_puzzles); Ok(results)
562}
563
564pub fn load_lichess_puzzles_premium_with_moves<P: AsRef<Path>>(
566 csv_path: P,
567) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
568 let loader = LichessLoader::new_premium()
569 .with_rating_range(1200, 2400) .with_batch_size(100_000); loader.load_parallel_with_moves(csv_path)
573}
574
575pub fn load_lichess_puzzles_basic_with_moves<P: AsRef<Path>>(
577 csv_path: P,
578 max_puzzles: usize,
579) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
580 let loader = LichessLoader::new()
581 .with_rating_range(1000, 2000) .with_batch_size(10_000); let mut results = loader.load_parallel_with_moves(csv_path)?;
585 results.truncate(max_puzzles); Ok(results)
587}