//! chess_vector_engine — streaming_loader.rs

1use chess::Board;
2use indicatif::{ProgressBar, ProgressStyle};
3use serde_json;
4use std::collections::HashSet;
5use std::fs::File;
6use std::io::{BufRead, BufReader};
7use std::path::Path;
8
/// Ultra-fast streaming loader for massive datasets
/// Optimized for loading 100k-1M+ positions efficiently
pub struct StreamingLoader {
    // Number of positions actually inserted into the engine via `process_batch`.
    pub loaded_count: usize,
    // Number of positions skipped because the board was already known.
    pub duplicate_count: usize,
    // Number of input lines examined by the JSON streaming path.
    // NOTE(review): the binary loader tracks progress in a local counter and
    // never updates this field — confirm whether that is intentional.
    pub total_processed: usize,
}
16
17impl StreamingLoader {
18    pub fn new() -> Self {
19        Self {
20            loaded_count: 0,
21            duplicate_count: 0,
22            total_processed: 0,
23        }
24    }
25
26    /// Stream-load massive JSON files with minimal memory usage
27    /// Uses streaming JSON parser and batched processing
28    pub fn stream_load_json<P: AsRef<Path>>(
29        &mut self,
30        path: P,
31        engine: &mut crate::ChessVectorEngine,
32        batch_size: usize,
33    ) -> Result<(), Box<dyn std::error::Error>> {
34        let path_ref = path.as_ref();
35        println!("Operation complete");
36
37        let file = File::open(path_ref)?;
38        let reader = BufReader::with_capacity(64 * 1024, file); // 64KB buffer
39
40        // Estimate total lines for progress tracking
41        let total_lines = self.estimate_line_count(path_ref)?;
42        println!("📊 Estimated {total_lines} lines to process");
43
44        let pb = ProgressBar::new(total_lines as u64);
45        pb.set_style(
46            ProgressStyle::default_bar()
47                .template("⚡ Streaming [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")?
48                .progress_chars("██░")
49        );
50
51        // Use existing positions as a bloom filter approximation
52        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();
53        let initial_size = existing_boards.len();
54
55        // Batch processing variables
56        let mut batch_boards = Vec::with_capacity(batch_size);
57        let mut batch_evaluations = Vec::with_capacity(batch_size);
58        let mut line_count = 0;
59
60        // Stream process each line
61        for line_result in reader.lines() {
62            let line = line_result?;
63            line_count += 1;
64
65            if line.trim().is_empty() {
66                continue;
67            }
68
69            // Parse JSON line
70            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
71                if let Some((board, evaluation)) = self.extract_position_data(&json)? {
72                    // Quick duplicate check (not perfect but fast)
73                    if !existing_boards.contains(&board) {
74                        batch_boards.push(board);
75                        batch_evaluations.push(evaluation);
76
77                        // Process batch when full
78                        if batch_boards.len() >= batch_size {
79                            self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
80
81                            pb.set_message(format!(
82                                "{loaded} loaded, {dupes} dupes",
83                                loaded = self.loaded_count, dupes = self.duplicate_count
84                            ));
85                        }
86                    } else {
87                        self.duplicate_count += 1;
88                    }
89                }
90            }
91
92            self.total_processed += 1;
93
94            // Update progress every 1000 lines
95            if line_count % 1000 == 0 {
96                pb.set_position(line_count as u64);
97            }
98        }
99
100        // Process remaining batch
101        if !batch_boards.is_empty() {
102            self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
103        }
104
105        pb.finish_with_message(format!(
106            "✅ Complete: {} loaded, {} duplicates from {} lines",
107            self.loaded_count, self.duplicate_count, line_count
108        ));
109
110        let new_positions = engine.position_boards.len() - initial_size;
111        println!("🎯 Added {new_positions} new positions to engine");
112
113        Ok(())
114    }
115
116    /// Ultra-fast binary format streaming loader
117    /// For pre-processed binary training data
118    pub fn stream_load_binary<P: AsRef<Path>>(
119        &mut self,
120        path: P,
121        engine: &mut crate::ChessVectorEngine,
122    ) -> Result<(), Box<dyn std::error::Error>> {
123        let path_ref = path.as_ref();
124        println!("Operation complete");
125
126        // Load binary data
127        let data = std::fs::read(path_ref)?;
128        println!("📦 Read {} bytes", data.len());
129
130        // Try LZ4 decompression first
131        let decompressed_data = if let Ok(decompressed) = lz4_flex::decompress_size_prepended(&data)
132        {
133            println!(
134                "🗜️  LZ4 decompressed: {} → {} bytes",
135                data.len(),
136                decompressed.len()
137            );
138            decompressed
139        } else {
140            data
141        };
142
143        // Deserialize positions
144        let positions: Vec<(String, f32)> = bincode::deserialize(&decompressed_data)?;
145        let total_positions = positions.len();
146        println!("📊 Loaded {total_positions} positions from binary");
147
148        if total_positions == 0 {
149            return Ok(());
150        }
151
152        let pb = ProgressBar::new(total_positions as u64);
153        pb.set_style(
154            ProgressStyle::default_bar()
155                .template("⚡ Binary loading [{elapsed_precise}] [{bar:40.blue/green}] {pos}/{len} ({percent}%) {msg}")?
156                .progress_chars("██░")
157        );
158
159        // Use existing positions for duplicate detection
160        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();
161
162        // Process in large batches for efficiency
163        const BATCH_SIZE: usize = 10000;
164        let mut processed = 0;
165
166        for chunk in positions.chunks(BATCH_SIZE) {
167            let mut batch_boards = Vec::with_capacity(BATCH_SIZE);
168            let mut batch_evaluations = Vec::with_capacity(BATCH_SIZE);
169
170            for (fen, evaluation) in chunk {
171                if let Ok(board) = fen.parse::<Board>() {
172                    if !existing_boards.contains(&board) {
173                        batch_boards.push(board);
174                        batch_evaluations.push(*evaluation);
175                    } else {
176                        self.duplicate_count += 1;
177                    }
178                }
179                processed += 1;
180            }
181
182            // Add batch to engine
183            if !batch_boards.is_empty() {
184                self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
185            }
186
187            pb.set_position(processed as u64);
188            pb.set_message(format!("{count} loaded", count = self.loaded_count));
189        }
190
191        pb.finish_with_message(format!("✅ Loaded {count} positions", count = self.loaded_count));
192
193        Ok(())
194    }
195
196    /// Process a batch of positions efficiently
197    fn process_batch(
198        &mut self,
199        engine: &mut crate::ChessVectorEngine,
200        boards: &mut Vec<Board>,
201        evaluations: &mut Vec<f32>,
202    ) -> Result<(), Box<dyn std::error::Error>> {
203        // Add all positions in batch
204        for (board, evaluation) in boards.iter().zip(evaluations.iter()) {
205            engine.add_position(board, *evaluation);
206            self.loaded_count += 1;
207        }
208
209        // Clear for next batch
210        boards.clear();
211        evaluations.clear();
212
213        Ok(())
214    }
215
216    /// Extract position data from JSON value
217    fn extract_position_data(
218        &self,
219        json: &serde_json::Value,
220    ) -> Result<Option<(Board, f32)>, Box<dyn std::error::Error>> {
221        // Try different JSON schemas
222        if let (Some(fen), Some(eval)) = (
223            json.get("fen").and_then(|v| v.as_str()),
224            json.get("evaluation").and_then(|v| v.as_f64()),
225        ) {
226            if let Ok(board) = fen.parse::<Board>() {
227                return Ok(Some((board, eval as f32)));
228            }
229        }
230
231        if let (Some(fen), Some(eval)) = (
232            json.get("board").and_then(|v| v.as_str()),
233            json.get("eval").and_then(|v| v.as_f64()),
234        ) {
235            if let Ok(board) = fen.parse::<Board>() {
236                return Ok(Some((board, eval as f32)));
237            }
238        }
239
240        if let (Some(fen), Some(eval)) = (
241            json.get("position").and_then(|v| v.as_str()),
242            json.get("score").and_then(|v| v.as_f64()),
243        ) {
244            if let Ok(board) = fen.parse::<Board>() {
245                return Ok(Some((board, eval as f32)));
246            }
247        }
248
249        Ok(None)
250    }
251
252    /// Estimate line count for progress tracking
253    fn estimate_line_count<P: AsRef<Path>>(
254        &self,
255        path: P,
256    ) -> Result<usize, Box<dyn std::error::Error>> {
257        use std::io::Read;
258
259        let path_ref = path.as_ref();
260        let file = File::open(path_ref)?;
261        let mut reader = BufReader::new(file);
262
263        // Sample first 1MB to estimate
264        let mut sample = vec![0u8; 1024 * 1024];
265        let bytes_read = reader.read(&mut sample)?;
266
267        if bytes_read == 0 {
268            return Ok(0);
269        }
270
271        // Count newlines in sample
272        let newlines_in_sample = sample[..bytes_read].iter().filter(|&&b| b == b'\n').count();
273
274        // Get total file size by re-opening the file
275        let total_size = std::fs::metadata(path_ref)?.len() as usize;
276
277        if bytes_read >= total_size {
278            // We read the whole file
279            return Ok(newlines_in_sample);
280        }
281
282        // Estimate based on sample
283        let estimated_lines = (newlines_in_sample * total_size) / bytes_read;
284        Ok(estimated_lines)
285    }
286}
287
impl Default for StreamingLoader {
    /// Equivalent to [`StreamingLoader::new`]: all counters start at zero.
    fn default() -> Self {
        Self::new()
    }
}
293
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_json_extraction() {
        // The {"fen", "evaluation"} schema should be recognized and the FEN
        // for the starting position should round-trip to the default board.
        let loader = StreamingLoader::new();

        let value = serde_json::json!({
            "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
            "evaluation": 0.25
        });

        let extracted = loader
            .extract_position_data(&value)
            .unwrap()
            .expect("starting position should be extracted");

        assert_eq!(extracted.0, Board::default());
        assert_eq!(extracted.1, 0.25);
    }

    #[test]
    fn test_line_estimation() {
        let loader = StreamingLoader::new();

        // Build a fixture file containing exactly 100 newline-terminated lines.
        let mut fixture = NamedTempFile::new().unwrap();
        for _ in 0..100 {
            writeln!(fixture, "Loading complete").unwrap();
        }

        // The estimator samples the file, so accept a ±20% band around 100.
        let estimated = loader.estimate_line_count(fixture.path()).unwrap();
        assert!(estimated >= 80 && estimated <= 120);
    }
}
331}