1#![allow(clippy::type_complexity)]
2use crate::TrainingData;
3use chess::{Board, ChessMove, Color};
4use rayon::prelude::*;
5use serde::Deserialize;
6use std::fs::File;
7use std::io::{BufRead, BufReader};
8use std::path::Path;
9use std::str::FromStr;
10use std::sync::{Arc, Mutex};
11use std::time::Instant;
12
13#[derive(Debug, Deserialize)]
15struct LichessPuzzle {
16 #[serde(rename = "PuzzleId")]
17 #[allow(dead_code)]
18 puzzle_id: String,
19 #[serde(rename = "FEN")]
20 #[allow(dead_code)]
21 fen: String,
22 #[serde(rename = "Moves")]
23 #[allow(dead_code)]
24 moves: String,
25 #[serde(rename = "Rating")]
26 #[allow(dead_code)]
27 rating: u32,
28 #[serde(rename = "RatingDeviation")]
29 #[allow(dead_code)]
30 rating_deviation: u32,
31 #[serde(rename = "Popularity")]
32 #[allow(dead_code)]
33 popularity: i32,
34 #[serde(rename = "NbPlays")]
35 #[allow(dead_code)]
36 nb_plays: u32,
37 #[serde(rename = "Themes")]
38 #[allow(dead_code)]
39 themes: String,
40 #[serde(rename = "GameUrl")]
41 #[allow(dead_code)]
42 game_url: String,
43}
44
/// Streams a Lichess puzzle CSV export and turns rows that pass the rating/theme filters
/// into training samples, parsing each batch of lines on a dedicated rayon thread pool.
#[derive(Debug, Clone)]
pub struct LichessLoader {
    /// Lowest puzzle rating kept (inclusive).
    min_rating: u32,
    /// Highest puzzle rating kept (inclusive).
    max_rating: u32,
    /// Number of CSV lines collected before a batch is parsed.
    batch_size: usize,
    /// Number of threads in the rayon pool used to parse each batch.
    num_threads: usize,
    /// When set, a puzzle is kept only if its Themes column contains at least one of these
    /// strings (plain substring match).
    theme_filter: Option<Vec<String>>,
}
58
59impl LichessLoader {
60 pub fn new() -> Self {
62 Self {
63 min_rating: 800, max_rating: 2800, batch_size: 10_000, num_threads: num_cpus::get().min(16), theme_filter: None,
68 }
69 }
70
71 pub fn new_premium() -> Self {
73 Self {
74 min_rating: 1000, max_rating: 2500, batch_size: 50_000, num_threads: num_cpus::get().min(32), theme_filter: Some(vec![
79 "checkmate".to_string(),
80 "mateIn2".to_string(),
81 "mateIn3".to_string(),
82 "fork".to_string(),
83 "pin".to_string(),
84 "skewer".to_string(),
85 "discovery".to_string(),
86 "sacrifice".to_string(),
87 "deflection".to_string(),
88 "attraction".to_string(),
89 ]),
90 }
91 }
92
93 pub fn with_rating_range(mut self, min: u32, max: u32) -> Self {
95 self.min_rating = min;
96 self.max_rating = max;
97 self
98 }
99
100 pub fn with_themes(mut self, themes: Vec<String>) -> Self {
102 self.theme_filter = Some(themes);
103 self
104 }
105
106 pub fn with_batch_size(mut self, size: usize) -> Self {
108 self.batch_size = size;
109 self
110 }
111
112 pub fn load_parallel<P: AsRef<Path>>(
114 &self,
115 csv_path: P,
116 ) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
117 let start_time = Instant::now();
118 let path = csv_path.as_ref();
119
120 println!("🔥 Lightning-fast Lichess puzzle loader starting...");
121 println!("Loading from file: {}", path.display());
122 println!("⚡ Parallel processing with {} threads", self.num_threads);
123 println!("Processing data...");
124
125 let pool = rayon::ThreadPoolBuilder::new()
127 .num_threads(self.num_threads)
128 .build()?;
129
130 let results = Arc::new(Mutex::new(Vec::new()));
131 let total_processed = Arc::new(Mutex::new(0usize));
132 let valid_puzzles = Arc::new(Mutex::new(0usize));
133
134 let file = File::open(path)?;
136 let reader = BufReader::with_capacity(1024 * 1024, file); let mut lines = reader.lines();
140 lines.next(); let mut batch = Vec::with_capacity(self.batch_size);
144 let mut batch_count = 0;
145
146 for line in lines {
147 let line = line?;
148 batch.push(line);
149
150 if batch.len() >= self.batch_size {
151 batch_count += 1;
152 let batch_data = std::mem::take(&mut batch);
153
154 let batch_results = self.process_batch_parallel(&pool, batch_data)?;
156
157 {
159 let mut results_guard = results.lock().unwrap();
160 let mut processed_guard = total_processed.lock().unwrap();
161 let mut valid_guard = valid_puzzles.lock().unwrap();
162
163 *processed_guard += self.batch_size;
164 *valid_guard += batch_results.len();
165 results_guard.extend(batch_results);
166
167 if batch_count % 10 == 0 {
168 println!(
169 "📈 Batch {}: Processed {}k puzzles, {} valid positions",
170 batch_count,
171 *processed_guard / 1000,
172 *valid_guard
173 );
174 }
175 }
176
177 batch = Vec::with_capacity(self.batch_size);
178 }
179 }
180
181 if !batch.is_empty() {
183 let batch_results = self.process_batch_parallel(&pool, batch)?;
184 let mut results_guard = results.lock().unwrap();
185 results_guard.extend(batch_results);
186 }
187
188 let final_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
189 let elapsed = start_time.elapsed();
190
191 println!("🎉 Lightning loading complete!");
192 println!("⏱️ Time: {:.2}s", elapsed.as_secs_f64());
193 println!("📊 Loaded {} training positions", final_results.len());
194 println!(
195 "🚀 Speed: {:.0} puzzles/second",
196 final_results.len() as f64 / elapsed.as_secs_f64()
197 );
198
199 Ok(final_results)
200 }
201
202 pub fn load_parallel_with_moves<P: AsRef<Path>>(
204 &self,
205 csv_path: P,
206 ) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
207 let start_time = Instant::now();
208 let path = csv_path.as_ref();
209
210 println!("🧠 Lightning-fast Lichess puzzle loader (with moves) starting...");
211 println!("Loading from file: {}", path.display());
212 println!("⚡ Parallel processing with {} threads", self.num_threads);
213 println!("Processing data...");
214
215 let pool = rayon::ThreadPoolBuilder::new()
217 .num_threads(self.num_threads)
218 .build()?;
219
220 let results = Arc::new(Mutex::new(Vec::new()));
221 let total_processed = Arc::new(Mutex::new(0usize));
222 let valid_puzzles = Arc::new(Mutex::new(0usize));
223
224 let file = File::open(path)?;
226 let reader = BufReader::with_capacity(1024 * 1024, file); let mut lines = reader.lines();
230 lines.next(); let mut batch = Vec::with_capacity(self.batch_size);
234 let mut batch_count = 0;
235
236 for line in lines {
237 let line = line?;
238 batch.push(line);
239
240 if batch.len() >= self.batch_size {
241 batch_count += 1;
242 let batch_data = std::mem::take(&mut batch);
243
244 let batch_results = self.process_batch_parallel_with_moves(&pool, batch_data)?;
246
247 {
249 let mut results_guard = results.lock().unwrap();
250 let mut processed_guard = total_processed.lock().unwrap();
251 let mut valid_guard = valid_puzzles.lock().unwrap();
252
253 *processed_guard += self.batch_size;
254 *valid_guard += batch_results.len();
255 results_guard.extend(batch_results);
256
257 if batch_count % 10 == 0 {
258 println!(
259 "📈 Batch {}: Processed {}k puzzles, {} valid moves",
260 batch_count,
261 *processed_guard / 1000,
262 *valid_guard
263 );
264 }
265 }
266
267 batch = Vec::with_capacity(self.batch_size);
268 }
269 }
270
271 if !batch.is_empty() {
273 let batch_results = self.process_batch_parallel_with_moves(&pool, batch)?;
274 let mut results_guard = results.lock().unwrap();
275 results_guard.extend(batch_results);
276 }
277
278 let final_results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
279 let elapsed = start_time.elapsed();
280
281 println!("🎉 Lightning loading with moves complete!");
282 println!("⏱️ Time: {:.2}s", elapsed.as_secs_f64());
283 println!("🧠 Loaded {} tactical moves", final_results.len());
284 println!(
285 "🚀 Speed: {:.0} puzzles/second",
286 final_results.len() as f64 / elapsed.as_secs_f64()
287 );
288
289 Ok(final_results)
290 }
291
292 fn process_batch_parallel(
294 &self,
295 pool: &rayon::ThreadPool,
296 batch: Vec<String>,
297 ) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
298 let loader = self;
299 let batch_results: Vec<_> = pool.install(|| {
300 batch
301 .par_iter()
302 .filter_map(|line| loader.parse_puzzle_line(line).ok().flatten())
303 .collect()
304 });
305
306 Ok(batch_results)
307 }
308
309 fn process_batch_parallel_with_moves(
311 &self,
312 pool: &rayon::ThreadPool,
313 batch: Vec<String>,
314 ) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
315 let loader = self;
316 let batch_results: Vec<_> = pool.install(|| {
317 batch
318 .par_iter()
319 .filter_map(|line| loader.parse_puzzle_line_with_move(line).ok().flatten())
320 .collect()
321 });
322
323 Ok(batch_results)
324 }
325
326 fn parse_puzzle_line(
328 &self,
329 line: &str,
330 ) -> Result<Option<TrainingData>, Box<dyn std::error::Error>> {
331 match std::panic::catch_unwind(
333 || -> Result<Option<TrainingData>, Box<dyn std::error::Error>> {
334 let mut reader = csv::ReaderBuilder::new()
336 .has_headers(false)
337 .from_reader(line.as_bytes());
338
339 let record = match reader.records().next() {
340 Some(Ok(record)) => record,
341 _ => return Ok(None), };
343
344 if record.len() < 8 {
345 return Ok(None); }
347
348 let fen = record.get(1).unwrap_or("").trim();
350 let moves = record.get(2).unwrap_or("").trim();
351 let rating: u32 = record.get(3).unwrap_or("0").parse().unwrap_or(0);
352 let themes = record.get(7).unwrap_or("").trim();
353
354 if rating < self.min_rating || rating > self.max_rating {
356 return Ok(None);
357 }
358
359 if let Some(ref theme_filter) = self.theme_filter {
360 let has_target_theme = theme_filter.iter().any(|theme| themes.contains(theme));
361 if !has_target_theme {
362 return Ok(None);
363 }
364 }
365
366 let board = match Board::from_str(fen) {
368 Ok(b) => b,
369 Err(_) => return Ok(None), };
371
372 let move_sequence: Vec<&str> = moves.split_whitespace().collect();
374 if move_sequence.is_empty() {
375 return Ok(None);
376 }
377
378 let _target_move = match ChessMove::from_str(move_sequence[0]) {
380 Ok(m) => {
381 use chess::MoveGen;
383 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&board).collect();
384
385 if legal_moves.is_empty() {
387 return Ok(None);
388 }
389
390 if legal_moves.contains(&m) {
391 m
392 } else {
393 return Ok(None); }
395 }
396 Err(_) => return Ok(None), };
398
399 let evaluation = self.calculate_puzzle_evaluation(rating, themes, &board);
401
402 Ok(Some(TrainingData {
403 board,
404 evaluation,
405 depth: 1, game_id: rating as usize, }))
408 },
409 ) {
410 Ok(result) => result,
411 Err(_) => Ok(None), }
413 }
414
415 fn calculate_puzzle_evaluation(&self, rating: u32, themes: &str, board: &Board) -> f32 {
417 let mut eval = 0.0;
418
419 eval += (rating as f32 - 1500.0) / 400.0; if themes.contains("checkmate") || themes.contains("mateIn") {
424 eval = if board.side_to_move() == Color::White {
425 500.0
426 } else {
427 -500.0
428 };
429 } else if themes.contains("fork") || themes.contains("pin") {
430 eval += if board.side_to_move() == Color::White {
431 300.0
432 } else {
433 -300.0
434 };
435 } else if themes.contains("sacrifice") {
436 eval += if board.side_to_move() == Color::White {
437 200.0
438 } else {
439 -200.0
440 };
441 }
442
443 eval
444 }
445
446 fn parse_puzzle_line_with_move(
448 &self,
449 line: &str,
450 ) -> Result<Option<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
451 match std::panic::catch_unwind(
453 || -> Result<Option<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
454 let mut reader = csv::ReaderBuilder::new()
456 .has_headers(false)
457 .from_reader(line.as_bytes());
458
459 let record = match reader.records().next() {
460 Some(Ok(record)) => record,
461 _ => return Ok(None), };
463
464 if record.len() < 8 {
465 return Ok(None); }
467
468 let fen = record.get(1).unwrap_or("").trim();
470 let moves = record.get(2).unwrap_or("").trim();
471 let rating: u32 = record.get(3).unwrap_or("0").parse().unwrap_or(0);
472 let themes = record.get(7).unwrap_or("").trim();
473
474 if rating < self.min_rating || rating > self.max_rating {
476 return Ok(None);
477 }
478
479 if let Some(ref theme_filter) = self.theme_filter {
480 let has_target_theme = theme_filter.iter().any(|theme| themes.contains(theme));
481 if !has_target_theme {
482 return Ok(None);
483 }
484 }
485
486 let board = match Board::from_str(fen) {
488 Ok(b) => b,
489 Err(_) => return Ok(None), };
491
492 let move_sequence: Vec<&str> = moves.split_whitespace().collect();
494 if move_sequence.is_empty() {
495 return Ok(None);
496 }
497
498 let target_move = match ChessMove::from_str(move_sequence[0]) {
500 Ok(m) => {
501 use chess::MoveGen;
503 let legal_moves: Vec<ChessMove> = MoveGen::new_legal(&board).collect();
504
505 if legal_moves.is_empty() {
507 return Ok(None);
508 }
509
510 if legal_moves.contains(&m) {
511 m
512 } else {
513 return Ok(None); }
516 }
517 Err(_) => return Ok(None), };
519
520 let evaluation = self.calculate_puzzle_evaluation(rating, themes, &board);
522
523 Ok(Some((board, evaluation, target_move)))
524 },
525 ) {
526 Ok(result) => result,
527 Err(_) => Ok(None), }
529 }
530}
531
532impl Default for LichessLoader {
533 fn default() -> Self {
534 Self::new()
535 }
536}
537
538pub fn load_lichess_puzzles_premium<P: AsRef<Path>>(
540 csv_path: P,
541) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
542 let loader = LichessLoader::new_premium()
543 .with_rating_range(1200, 2400) .with_batch_size(100_000); loader.load_parallel(csv_path)
547}
548
549pub fn load_lichess_puzzles_basic<P: AsRef<Path>>(
551 csv_path: P,
552 max_puzzles: usize,
553) -> Result<Vec<TrainingData>, Box<dyn std::error::Error>> {
554 let loader = LichessLoader::new()
555 .with_rating_range(1000, 2000) .with_batch_size(10_000); let mut results = loader.load_parallel(csv_path)?;
559 results.truncate(max_puzzles); Ok(results)
561}
562
563pub fn load_lichess_puzzles_premium_with_moves<P: AsRef<Path>>(
565 csv_path: P,
566) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
567 let loader = LichessLoader::new_premium()
568 .with_rating_range(1200, 2400) .with_batch_size(100_000); loader.load_parallel_with_moves(csv_path)
572}
573
574pub fn load_lichess_puzzles_basic_with_moves<P: AsRef<Path>>(
576 csv_path: P,
577 max_puzzles: usize,
578) -> Result<Vec<(Board, f32, ChessMove)>, Box<dyn std::error::Error>> {
579 let loader = LichessLoader::new()
580 .with_rating_range(1000, 2000) .with_batch_size(10_000); let mut results = loader.load_parallel_with_moves(csv_path)?;
584 results.truncate(max_puzzles); Ok(results)
586}