// chess_vector_engine/streaming_loader.rs

1use chess::Board;
2use indicatif::{ProgressBar, ProgressStyle};
3use serde_json;
4use std::collections::HashSet;
5use std::fs::File;
6use std::io::{BufRead, BufReader};
7use std::path::Path;
8
/// Ultra-fast streaming loader for massive datasets
/// Optimized for loading 100k-1M+ positions efficiently
pub struct StreamingLoader {
    /// Number of positions actually added to the engine across all loads.
    pub loaded_count: usize,
    /// Number of positions skipped because they were already present.
    pub duplicate_count: usize,
    /// Number of non-empty input lines examined by `stream_load_json`.
    /// NOTE(review): the binary loader does not update this counter — confirm
    /// whether that asymmetry is intentional.
    pub total_processed: usize,
}
16
17impl StreamingLoader {
18    pub fn new() -> Self {
19        Self {
20            loaded_count: 0,
21            duplicate_count: 0,
22            total_processed: 0,
23        }
24    }
25
26    /// Stream-load massive JSON files with minimal memory usage
27    /// Uses streaming JSON parser and batched processing
28    pub fn stream_load_json<P: AsRef<Path>>(
29        &mut self,
30        path: P,
31        engine: &mut crate::ChessVectorEngine,
32        batch_size: usize,
33    ) -> Result<(), Box<dyn std::error::Error>> {
34        let path_ref = path.as_ref();
35        println!("Operation complete");
36
37        let file = File::open(path_ref)?;
38        let reader = BufReader::with_capacity(64 * 1024, file); // 64KB buffer
39
40        // Estimate total lines for progress tracking
41        let total_lines = self.estimate_line_count(path_ref)?;
42        println!("📊 Estimated {total_lines} lines to process");
43
44        let pb = ProgressBar::new(total_lines as u64);
45        pb.set_style(
46            ProgressStyle::default_bar()
47                .template("⚡ Streaming [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")?
48                .progress_chars("██░")
49        );
50
51        // Use existing positions as a bloom filter approximation
52        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();
53        let initial_size = existing_boards.len();
54
55        // Batch processing variables
56        let mut batch_boards = Vec::with_capacity(batch_size);
57        let mut batch_evaluations = Vec::with_capacity(batch_size);
58        let mut line_count = 0;
59
60        // Stream process each line
61        for line_result in reader.lines() {
62            let line = line_result?;
63            line_count += 1;
64
65            if line.trim().is_empty() {
66                continue;
67            }
68
69            // Parse JSON line
70            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
71                if let Some((board, evaluation)) = self.extract_position_data(&json)? {
72                    // Quick duplicate check (not perfect but fast)
73                    if !existing_boards.contains(&board) {
74                        batch_boards.push(board);
75                        batch_evaluations.push(evaluation);
76
77                        // Process batch when full
78                        if batch_boards.len() >= batch_size {
79                            self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
80
81                            pb.set_message(format!(
82                                "{loaded} loaded, {dupes} dupes",
83                                loaded = self.loaded_count,
84                                dupes = self.duplicate_count
85                            ));
86                        }
87                    } else {
88                        self.duplicate_count += 1;
89                    }
90                }
91            }
92
93            self.total_processed += 1;
94
95            // Update progress every 1000 lines
96            if line_count % 1000 == 0 {
97                pb.set_position(line_count as u64);
98            }
99        }
100
101        // Process remaining batch
102        if !batch_boards.is_empty() {
103            self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
104        }
105
106        pb.finish_with_message(format!(
107            "✅ Complete: {} loaded, {} duplicates from {} lines",
108            self.loaded_count, self.duplicate_count, line_count
109        ));
110
111        let new_positions = engine.position_boards.len() - initial_size;
112        println!("🎯 Added {new_positions} new positions to engine");
113
114        Ok(())
115    }
116
117    /// Ultra-fast binary format streaming loader
118    /// For pre-processed binary training data
119    pub fn stream_load_binary<P: AsRef<Path>>(
120        &mut self,
121        path: P,
122        engine: &mut crate::ChessVectorEngine,
123    ) -> Result<(), Box<dyn std::error::Error>> {
124        let path_ref = path.as_ref();
125        println!("Operation complete");
126
127        // Load binary data
128        let data = std::fs::read(path_ref)?;
129        println!("📦 Read {} bytes", data.len());
130
131        // Try LZ4 decompression first
132        let decompressed_data = if let Ok(decompressed) = lz4_flex::decompress_size_prepended(&data)
133        {
134            println!(
135                "🗜️  LZ4 decompressed: {} → {} bytes",
136                data.len(),
137                decompressed.len()
138            );
139            decompressed
140        } else {
141            data
142        };
143
144        // Deserialize positions
145        let positions: Vec<(String, f32)> = bincode::deserialize(&decompressed_data)?;
146        let total_positions = positions.len();
147        println!("📊 Loaded {total_positions} positions from binary");
148
149        if total_positions == 0 {
150            return Ok(());
151        }
152
153        let pb = ProgressBar::new(total_positions as u64);
154        pb.set_style(
155            ProgressStyle::default_bar()
156                .template("⚡ Binary loading [{elapsed_precise}] [{bar:40.blue/green}] {pos}/{len} ({percent}%) {msg}")?
157                .progress_chars("██░")
158        );
159
160        // Use existing positions for duplicate detection
161        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();
162
163        // Process in large batches for efficiency
164        const BATCH_SIZE: usize = 10000;
165        let mut processed = 0;
166
167        for chunk in positions.chunks(BATCH_SIZE) {
168            let mut batch_boards = Vec::with_capacity(BATCH_SIZE);
169            let mut batch_evaluations = Vec::with_capacity(BATCH_SIZE);
170
171            for (fen, evaluation) in chunk {
172                if let Ok(board) = fen.parse::<Board>() {
173                    if !existing_boards.contains(&board) {
174                        batch_boards.push(board);
175                        batch_evaluations.push(*evaluation);
176                    } else {
177                        self.duplicate_count += 1;
178                    }
179                }
180                processed += 1;
181            }
182
183            // Add batch to engine
184            if !batch_boards.is_empty() {
185                self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
186            }
187
188            pb.set_position(processed as u64);
189            pb.set_message(format!("{count} loaded", count = self.loaded_count));
190        }
191
192        pb.finish_with_message(format!(
193            "✅ Loaded {count} positions",
194            count = self.loaded_count
195        ));
196
197        Ok(())
198    }
199
200    /// Process a batch of positions efficiently
201    fn process_batch(
202        &mut self,
203        engine: &mut crate::ChessVectorEngine,
204        boards: &mut Vec<Board>,
205        evaluations: &mut Vec<f32>,
206    ) -> Result<(), Box<dyn std::error::Error>> {
207        // Add all positions in batch
208        for (board, evaluation) in boards.iter().zip(evaluations.iter()) {
209            engine.add_position(board, *evaluation);
210            self.loaded_count += 1;
211        }
212
213        // Clear for next batch
214        boards.clear();
215        evaluations.clear();
216
217        Ok(())
218    }
219
220    /// Extract position data from JSON value
221    fn extract_position_data(
222        &self,
223        json: &serde_json::Value,
224    ) -> Result<Option<(Board, f32)>, Box<dyn std::error::Error>> {
225        // Try different JSON schemas
226        if let (Some(fen), Some(eval)) = (
227            json.get("fen").and_then(|v| v.as_str()),
228            json.get("evaluation").and_then(|v| v.as_f64()),
229        ) {
230            if let Ok(board) = fen.parse::<Board>() {
231                return Ok(Some((board, eval as f32)));
232            }
233        }
234
235        if let (Some(fen), Some(eval)) = (
236            json.get("board").and_then(|v| v.as_str()),
237            json.get("eval").and_then(|v| v.as_f64()),
238        ) {
239            if let Ok(board) = fen.parse::<Board>() {
240                return Ok(Some((board, eval as f32)));
241            }
242        }
243
244        if let (Some(fen), Some(eval)) = (
245            json.get("position").and_then(|v| v.as_str()),
246            json.get("score").and_then(|v| v.as_f64()),
247        ) {
248            if let Ok(board) = fen.parse::<Board>() {
249                return Ok(Some((board, eval as f32)));
250            }
251        }
252
253        Ok(None)
254    }
255
256    /// Estimate line count for progress tracking
257    fn estimate_line_count<P: AsRef<Path>>(
258        &self,
259        path: P,
260    ) -> Result<usize, Box<dyn std::error::Error>> {
261        use std::io::Read;
262
263        let path_ref = path.as_ref();
264        let file = File::open(path_ref)?;
265        let mut reader = BufReader::new(file);
266
267        // Sample first 1MB to estimate
268        let mut sample = vec![0u8; 1024 * 1024];
269        let bytes_read = reader.read(&mut sample)?;
270
271        if bytes_read == 0 {
272            return Ok(0);
273        }
274
275        // Count newlines in sample
276        let newlines_in_sample = sample[..bytes_read].iter().filter(|&&b| b == b'\n').count();
277
278        // Get total file size by re-opening the file
279        let total_size = std::fs::metadata(path_ref)?.len() as usize;
280
281        if bytes_read >= total_size {
282            // We read the whole file
283            return Ok(newlines_in_sample);
284        }
285
286        // Estimate based on sample
287        let estimated_lines = (newlines_in_sample * total_size) / bytes_read;
288        Ok(estimated_lines)
289    }
290}
291
292impl Default for StreamingLoader {
293    fn default() -> Self {
294        Self::new()
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// The "fen"/"evaluation" schema should parse to the standard start
    /// position with the given score.
    #[test]
    fn test_json_extraction() {
        let loader = StreamingLoader::new();

        let json = serde_json::json!({
            "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
            "evaluation": 0.25
        });

        let extracted = loader
            .extract_position_data(&json)
            .unwrap()
            .expect("fen/evaluation schema should be recognized");
        assert_eq!(extracted.0, Board::default());
        assert_eq!(extracted.1, 0.25);
    }

    /// A 100-line file fits entirely within the 1 MB sample, so the line
    /// estimate should land close to the true count.
    #[test]
    fn test_line_estimation() {
        let loader = StreamingLoader::new();

        let mut temp_file = NamedTempFile::new().unwrap();
        for _ in 0..100 {
            writeln!(temp_file, "Loading complete").unwrap();
        }

        let estimated = loader.estimate_line_count(temp_file.path()).unwrap();
        assert!((80..=120).contains(&estimated));
    }
}