chess_vector_engine/ultra_fast_loader.rs

use chess::Board;
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use std::collections::HashSet;
use std::fs::File;
use std::path::Path;
use std::sync::{Arc, Mutex};

/// Ultra-fast loader specifically designed for massive datasets (100k-10M+ positions).
/// Uses aggressive optimizations: memory mapping, parallel parsing, and batched
/// duplicate filtering (a bloom filter is planned; see `use_bloom_filter`).
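///
/// # Example
///
/// A minimal usage sketch (illustrative; assumes `UltraFastLoader` and a
/// `ChessVectorEngine::new()` constructor are exported from the crate root,
/// and `positions.bin` is a placeholder path):
///
/// ```ignore
/// use chess_vector_engine::{ChessVectorEngine, UltraFastLoader};
///
/// let mut engine = ChessVectorEngine::new();
/// let mut loader = UltraFastLoader::new_for_massive_datasets();
/// loader.ultra_load_binary("positions.bin", &mut engine).unwrap();
/// println!("success rate: {:.1}%", loader.get_stats().success_rate() * 100.0);
/// ```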
pub struct UltraFastLoader {
    pub loaded_count: usize,
    pub duplicate_count: usize,
    pub error_count: usize,
    batch_size: usize,
    #[allow(dead_code)]
    use_bloom_filter: bool,
}

impl UltraFastLoader {
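    /// Creates a loader tuned for datasets of 100k+ positions (50k-entry batches).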
    pub fn new_for_massive_datasets() -> Self {
        Self {
            loaded_count: 0,
            duplicate_count: 0,
            error_count: 0,
            batch_size: 50_000, // Large batches for massive datasets
            use_bloom_filter: true,
        }
    }

    /// Ultra-fast binary loader with memory mapping and parallel processing.
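    /// Files larger than 500 MB are routed to the memory-mapped streaming path.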
    pub fn ultra_load_binary<P: AsRef<Path>>(
        &mut self,
        path: P,
        engine: &mut crate::ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path_ref = path.as_ref();
        println!("πŸš€ Ultra-fast loading: {}", path_ref.display());

        let file_size = std::fs::metadata(path_ref)?.len();
        println!("πŸ“Š File size: {:.1} MB", file_size as f64 / 1_000_000.0);

        if file_size > 500_000_000 {
            // > 500 MB
            println!("⚑ Large file detected - using memory-mapped loading");
            return self.memory_mapped_load(path_ref, engine);
        }

        // Standard file loading with optimizations
        let data = std::fs::read(path_ref)?;

        // Try LZ4 decompression
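        // (lz4_flex's size-prepended framing stores the uncompressed length as a
        // little-endian u32 header; a file that was never compressed fails this
        // call and falls through to the raw branch below)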
        let decompressed_data = if let Ok(decompressed) = lz4_flex::decompress_size_prepended(&data)
        {
            println!(
                "πŸ—œοΈ  LZ4 decompressed: {} β†’ {} bytes",
                data.len(),
                decompressed.len()
            );
            decompressed
        } else {
            data
        };

        // Deserialize with error handling
        let positions: Vec<(String, f32)> = match bincode::deserialize(&decompressed_data) {
            Ok(pos) => pos,
            Err(e) => {
                println!("❌ Binary deserialization failed: {e}");
                return Err(e.into());
            }
        };

        let total_positions = positions.len();
        println!("πŸ“¦ Loaded {total_positions} positions from binary");

        if total_positions == 0 {
            return Ok(());
        }

        // Use optimized loading strategy based on size
        if total_positions > 100_000 {
            self.parallel_batch_load(positions, engine)
        } else {
            self.sequential_load(positions, engine)
        }
    }

    /// Memory-mapped loading for very large files
    fn memory_mapped_load<P: AsRef<Path>>(
        &mut self,
        path: P,
        engine: &mut crate::ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use memmap2::Mmap;

        let file = File::open(path)?;
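        // SAFETY: the mapping is only read. We assume the dataset file is not
        // truncated or modified by another process while mapped; concurrent
        // modification of a mapped file is undefined behavior.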
        let mmap = unsafe { Mmap::map(&file)? };

        println!("πŸ—ΊοΈ  Memory-mapped {} bytes", mmap.len());

        // Deserializing into a copied buffer would roughly double peak memory for
        // multi-GB files, so detect the on-disk format directly over the mapped bytes.
        self.stream_parse_memory_mapped(&mmap, engine)
    }

    /// Stream parse memory-mapped data
    fn stream_parse_memory_mapped(
        &mut self,
        mmap: &memmap2::Mmap,
        engine: &mut crate::ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Try candidate formats in order: LZ4-compressed bincode, raw bincode, then text

        // 1. Try LZ4 decompression of entire file
        if let Ok(decompressed) = lz4_flex::decompress_size_prepended(mmap) {
            println!("πŸ—œοΈ  Full file LZ4 decompressed");
            return self.parse_decompressed_data(&decompressed, engine);
        }

        // 2. Try direct deserialization
        if let Ok(positions) = bincode::deserialize::<Vec<(String, f32)>>(mmap) {
            println!("πŸ“¦ Direct memory-mapped deserialization");
            return self.parallel_batch_load(positions, engine);
        }

        // 3. Try as raw text (fallback)
        if let Ok(text) = std::str::from_utf8(mmap) {
            println!("πŸ“ Treating as text format");
            return self.parse_text_data(text, engine);
        }

        Err("Unable to parse memory-mapped file in any known format".into())
    }

    /// Parse decompressed binary data
    fn parse_decompressed_data(
        &mut self,
        data: &[u8],
        engine: &mut crate::ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let positions: Vec<(String, f32)> = bincode::deserialize(data)?;
        self.parallel_batch_load(positions, engine)
    }

    /// Parse text data (JSON or similar)
    fn parse_text_data(
        &mut self,
        text: &str,
        engine: &mut crate::ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        println!("πŸ“ Parsing text data...");

        let lines: Vec<&str> = text.lines().collect();
        let total_lines = lines.len();

        if total_lines == 0 {
            return Ok(());
        }

        println!("πŸ“Š Processing {total_lines} lines");

        let pb = ProgressBar::new(total_lines as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("⚑ Parsing [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")?
                .progress_chars("β–ˆβ–ˆβ–‘")
        );

        // Use parallel processing for text parsing
        let batch_size = 10_000;
        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();
        let existing_boards = Arc::new(existing_boards);

        let results: Arc<Mutex<Vec<(Board, f32)>>> = Arc::new(Mutex::new(Vec::new()));
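        // Design note: a `par_chunks(...).flat_map(...).collect()` would avoid the
        // Mutex entirely; the Mutex is acceptable here because each chunk takes the
        // lock only once (one append per 10k-line chunk), so contention is low.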

        lines
            .par_chunks(batch_size)
            .enumerate()
            .for_each(|(chunk_idx, chunk)| {
                let mut local_results = Vec::new();

                for (line_idx, line) in chunk.iter().enumerate() {
                    if line.trim().is_empty() {
                        continue;
                    }

                    // Try to parse as JSON
                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(line) {
                        if let Some((board, eval)) = self.extract_from_json(&json) {
                            if !existing_boards.contains(&board) {
                                local_results.push((board, eval));
                            }
                        }
                    }

                    // Update progress periodically
                    if line_idx % 1000 == 0 {
                        pb.set_position((chunk_idx * batch_size + line_idx) as u64);
                    }
                }

                // Add local results to global results
                if !local_results.is_empty() {
                    if let Ok(mut results) = results.lock() {
                        results.extend(local_results);
                    }
                }
            });

        pb.finish_with_message("βœ… Text parsing complete");

        // Extract results and add to engine. The Arc is unique once `for_each`
        // returns, and the mutex can only be poisoned by a panicking worker.
        let final_results = Arc::try_unwrap(results)
            .expect("no outstanding Arc references after par_chunks completes")
            .into_inner()
            .expect("mutex poisoned by a panicking worker");
        self.loaded_count = final_results.len();

        println!("πŸ“¦ Parsed {} valid positions", self.loaded_count);

        // Add parsed positions to the engine one by one
        for (board, eval) in final_results {
            engine.add_position(&board, eval);
        }

        Ok(())
    }

    /// Extract a position from one JSON record
    fn extract_from_json(&self, json: &serde_json::Value) -> Option<(Board, f32)> {
        // Currently only one schema is recognized, e.g.:
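        //   {"fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", "evaluation": 0.3}
        // (illustrative record; any object with a string `fen` and numeric `evaluation` matches)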
        if let (Some(fen), Some(eval)) = (
            json.get("fen").and_then(|v| v.as_str()),
            json.get("evaluation").and_then(|v| v.as_f64()),
        ) {
            if let Ok(board) = fen.parse::<Board>() {
                return Some((board, eval as f32));
            }
        }

        None
    }

    /// Batched loading for large datasets. Batches are processed sequentially,
    /// since engine insertion needs exclusive access; the large batch size keeps
    /// duplicate checks and progress updates cheap.
    fn parallel_batch_load(
        &mut self,
        positions: Vec<(String, f32)>,
        engine: &mut crate::ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let total_positions = positions.len();
        println!("πŸ”„ Parallel batch loading {total_positions} positions");

        let pb = ProgressBar::new(total_positions as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("⚑ Loading [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")?
                .progress_chars("β–ˆβ–ˆβ–‘")
        );

        // Snapshot existing positions for duplicate detection. A HashSet is used
        // today; the `use_bloom_filter` flag is reserved for a probabilistic filter.
        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();

        // Process in large sequential chunks
        let chunk_size = self.batch_size;
        let chunks: Vec<_> = positions.chunks(chunk_size).collect();

        let mut total_loaded = 0;
        let mut total_duplicates = 0;

        for (chunk_idx, chunk) in chunks.iter().enumerate() {
            let mut batch_boards = Vec::new();
            let mut batch_evaluations = Vec::new();

            // Process chunk
            for (fen, evaluation) in chunk.iter() {
                match fen.parse::<Board>() {
                    Ok(board) => {
                        if !existing_boards.contains(&board) {
                            batch_boards.push(board);
                            batch_evaluations.push(*evaluation);
                        } else {
                            total_duplicates += 1;
                        }
                    }
                    Err(_) => {
                        self.error_count += 1;
                    }
                }
            }

            // Add batch to engine
            for (board, eval) in batch_boards.iter().zip(batch_evaluations.iter()) {
                engine.add_position(board, *eval);
                total_loaded += 1;
            }

            // Update progress
            pb.set_position(((chunk_idx + 1) * chunk_size).min(total_positions) as u64);
            pb.set_message(format!("{total_loaded} loaded, {total_duplicates} dupes"));
        }

        pb.finish_with_message(format!("βœ… Loaded {total_loaded} positions"));

        self.loaded_count = total_loaded;
        self.duplicate_count = total_duplicates;

        println!("πŸ“Š Final stats:");
        println!("   Loaded: {count} positions", count = self.loaded_count);
        println!("   Duplicates: {count} positions", count = self.duplicate_count);
        println!("   Errors: {count} positions", count = self.error_count);

        Ok(())
    }

    /// Sequential loading for smaller datasets
    fn sequential_load(
        &mut self,
        positions: Vec<(String, f32)>,
        engine: &mut crate::ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        println!("πŸ“¦ Sequential loading {} positions", positions.len());

        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();

        for (fen, evaluation) in positions {
            match fen.parse::<Board>() {
                Ok(board) => {
                    if !existing_boards.contains(&board) {
                        engine.add_position(&board, evaluation);
                        self.loaded_count += 1;
                    } else {
                        self.duplicate_count += 1;
                    }
                }
                Err(_) => {
                    self.error_count += 1;
                }
            }
        }

        Ok(())
    }

    /// Get loading statistics
    pub fn get_stats(&self) -> LoadingStats {
        LoadingStats {
            loaded: self.loaded_count,
            duplicates: self.duplicate_count,
            errors: self.error_count,
            total_processed: self.loaded_count + self.duplicate_count + self.error_count,
        }
    }
}

/// Loading statistics
#[derive(Debug, Clone)]
pub struct LoadingStats {
    pub loaded: usize,
    pub duplicates: usize,
    pub errors: usize,
    pub total_processed: usize,
}

impl LoadingStats {
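    /// Fraction of processed entries that loaded successfully.
    /// Returns 1.0 when nothing has been processed yet.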
    pub fn success_rate(&self) -> f64 {
        if self.total_processed == 0 {
            return 1.0;
        }
        self.loaded as f64 / self.total_processed as f64
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ultra_fast_loader_creation() {
        let loader = UltraFastLoader::new_for_massive_datasets();
        assert_eq!(loader.loaded_count, 0);
        assert_eq!(loader.batch_size, 50_000);
        assert!(loader.use_bloom_filter);
    }

    #[test]
    fn test_loading_stats() {
        let mut loader = UltraFastLoader::new_for_massive_datasets();
        loader.loaded_count = 8000;
        loader.duplicate_count = 1500;
        loader.error_count = 500;

        let stats = loader.get_stats();
        assert_eq!(stats.loaded, 8000);
        assert_eq!(stats.total_processed, 10000);
        assert_eq!(stats.success_rate(), 0.8);
    }
}
401}