chess_vector_engine/streaming_loader.rs

use chess::Board;
use indicatif::{ProgressBar, ProgressStyle};
use serde_json;
use std::collections::HashSet;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

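/// Streaming loader for bulk position data. Tracks how many positions were
/// loaded into the engine, skipped as duplicates, and processed in total.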
pub struct StreamingLoader {
    pub loaded_count: usize,
    pub duplicate_count: usize,
    pub total_processed: usize,
}

impl StreamingLoader {
    pub fn new() -> Self {
        Self {
            loaded_count: 0,
            duplicate_count: 0,
            total_processed: 0,
        }
    }

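    /// Streams newline-delimited JSON from `path` into `engine`, adding
    /// positions in batches of `batch_size`. Empty lines, unparseable
    /// records, and positions already known to the engine are skipped.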
    pub fn stream_load_json<P: AsRef<Path>>(
        &mut self,
        path: P,
        engine: &mut crate::ChessVectorEngine,
        batch_size: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path_ref = path.as_ref();
        println!("⚡ Streaming positions from {}", path_ref.display());

        let file = File::open(path_ref)?;
        let reader = BufReader::with_capacity(64 * 1024, file);

        let total_lines = self.estimate_line_count(path_ref)?;
        println!("📊 Estimated {} lines to process", total_lines);

        let pb = ProgressBar::new(total_lines as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("⚡ Streaming [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")?
                .progress_chars("██░")
        );

        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();
        let initial_size = existing_boards.len();

        let mut batch_boards = Vec::with_capacity(batch_size);
        let mut batch_evaluations = Vec::with_capacity(batch_size);
        let mut line_count = 0;

        for line_result in reader.lines() {
            let line = line_result?;
            line_count += 1;

            if line.trim().is_empty() {
                continue;
            }

            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
                if let Some((board, evaluation)) = self.extract_position_data(&json)? {
                    if !existing_boards.contains(&board) {
                        batch_boards.push(board);
                        batch_evaluations.push(evaluation);

                        if batch_boards.len() >= batch_size {
                            self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;

                            pb.set_message(format!(
                                "{} loaded, {} dupes",
                                self.loaded_count, self.duplicate_count
                            ));
                        }
                    } else {
                        self.duplicate_count += 1;
                    }
                }
            }

            self.total_processed += 1;

            if line_count % 1000 == 0 {
                pb.set_position(line_count as u64);
            }
        }

        if !batch_boards.is_empty() {
            self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
        }

        pb.finish_with_message(format!(
            "✅ Complete: {} loaded, {} duplicates from {} lines",
            self.loaded_count, self.duplicate_count, line_count
        ));

        let new_positions = engine.position_boards.len() - initial_size;
        println!("🎯 Added {} new positions to engine", new_positions);

        Ok(())
    }

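    /// Loads positions from a binary file containing a bincode-serialized
    /// `Vec<(String, f32)>` of FEN/evaluation pairs, optionally compressed
    /// with size-prepended LZ4.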
    pub fn stream_load_binary<P: AsRef<Path>>(
        &mut self,
        path: P,
        engine: &mut crate::ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path_ref = path.as_ref();
        println!("⚡ Binary loading from {}", path_ref.display());

        let data = std::fs::read(path_ref)?;
        println!("📦 Read {} bytes", data.len());

        let decompressed_data = if let Ok(decompressed) = lz4_flex::decompress_size_prepended(&data)
        {
            println!(
                "🗜️ LZ4 decompressed: {} → {} bytes",
                data.len(),
                decompressed.len()
            );
            decompressed
        } else {
            data
        };

        let positions: Vec<(String, f32)> = bincode::deserialize(&decompressed_data)?;
        let total_positions = positions.len();
        println!("📊 Loaded {} positions from binary", total_positions);

        if total_positions == 0 {
            return Ok(());
        }

        let pb = ProgressBar::new(total_positions as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("⚡ Binary loading [{elapsed_precise}] [{bar:40.blue/green}] {pos}/{len} ({percent}%) {msg}")?
                .progress_chars("██░")
        );

        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();

        const BATCH_SIZE: usize = 10000;
        let mut processed = 0;

        for chunk in positions.chunks(BATCH_SIZE) {
            let mut batch_boards = Vec::with_capacity(BATCH_SIZE);
            let mut batch_evaluations = Vec::with_capacity(BATCH_SIZE);

            for (fen, evaluation) in chunk {
                if let Ok(board) = fen.parse::<Board>() {
                    if !existing_boards.contains(&board) {
                        batch_boards.push(board);
                        batch_evaluations.push(*evaluation);
                    } else {
                        self.duplicate_count += 1;
                    }
                }
                processed += 1;
            }

            if !batch_boards.is_empty() {
                self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
            }

            pb.set_position(processed as u64);
            pb.set_message(format!("{} loaded", self.loaded_count));
        }

        pb.finish_with_message(format!("✅ Loaded {} positions", self.loaded_count));

        Ok(())
    }

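    /// Adds the buffered boards and evaluations to the engine, then clears
    /// both buffers so they can be reused for the next batch.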
    fn process_batch(
        &mut self,
        engine: &mut crate::ChessVectorEngine,
        boards: &mut Vec<Board>,
        evaluations: &mut Vec<f32>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        for (board, evaluation) in boards.iter().zip(evaluations.iter()) {
            engine.add_position(board, *evaluation);
            self.loaded_count += 1;
        }

        boards.clear();
        evaluations.clear();

        Ok(())
    }

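    /// Extracts a `(Board, f32)` pair from one JSON record, accepting the
    /// key pairs `"fen"`/`"evaluation"`, `"board"`/`"eval"`, or
    /// `"position"`/`"score"`. Returns `Ok(None)` if no recognized pair parses.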
    fn extract_position_data(
        &self,
        json: &serde_json::Value,
    ) -> Result<Option<(Board, f32)>, Box<dyn std::error::Error>> {
        if let (Some(fen), Some(eval)) = (
            json.get("fen").and_then(|v| v.as_str()),
            json.get("evaluation").and_then(|v| v.as_f64()),
        ) {
            if let Ok(board) = fen.parse::<Board>() {
                return Ok(Some((board, eval as f32)));
            }
        }

        if let (Some(fen), Some(eval)) = (
            json.get("board").and_then(|v| v.as_str()),
            json.get("eval").and_then(|v| v.as_f64()),
        ) {
            if let Ok(board) = fen.parse::<Board>() {
                return Ok(Some((board, eval as f32)));
            }
        }

        if let (Some(fen), Some(eval)) = (
            json.get("position").and_then(|v| v.as_str()),
            json.get("score").and_then(|v| v.as_f64()),
        ) {
            if let Ok(board) = fen.parse::<Board>() {
                return Ok(Some((board, eval as f32)));
            }
        }

        Ok(None)
    }

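    /// Estimates the file's line count by counting newlines in the first
    /// 1 MiB and extrapolating from the total file size.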
    fn estimate_line_count<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<usize, Box<dyn std::error::Error>> {
        use std::io::Read;

        let path_ref = path.as_ref();
        let file = File::open(path_ref)?;
        let mut reader = BufReader::new(file);

        let mut sample = vec![0u8; 1024 * 1024];
        let bytes_read = reader.read(&mut sample)?;

        if bytes_read == 0 {
            return Ok(0);
        }

        let newlines_in_sample = sample[..bytes_read].iter().filter(|&&b| b == b'\n').count();

        let total_size = std::fs::metadata(path_ref)?.len() as usize;

        if bytes_read >= total_size {
            return Ok(newlines_in_sample);
        }

        let estimated_lines = (newlines_in_sample * total_size) / bytes_read;
        Ok(estimated_lines)
    }
}

impl Default for StreamingLoader {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_json_extraction() {
        let loader = StreamingLoader::new();

        let json = serde_json::json!({
            "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
            "evaluation": 0.25
        });

        let result = loader.extract_position_data(&json).unwrap();
        assert!(result.is_some());

        let (board, eval) = result.unwrap();
        assert_eq!(board, Board::default());
        assert_eq!(eval, 0.25);
    }
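
    #[test]
    fn test_json_extraction_alternate_keys() {
        // extract_position_data also accepts the "board"/"eval" and
        // "position"/"score" key pairs; exercise the "board"/"eval" variant.
        let loader = StreamingLoader::new();

        let json = serde_json::json!({
            "board": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
            "eval": -0.5
        });

        let (board, eval) = loader.extract_position_data(&json).unwrap().unwrap();
        assert_eq!(board, Board::default());
        assert_eq!(eval, -0.5);
    }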

    #[test]
    fn test_line_estimation() {
        let loader = StreamingLoader::new();

        let mut temp_file = NamedTempFile::new().unwrap();
        for _i in 0..100 {
            writeln!(temp_file, "Loading complete").unwrap();
        }

        let estimated = loader.estimate_line_count(temp_file.path()).unwrap();
        assert!((80..=120).contains(&estimated));
    }
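
    #[test]
    fn test_binary_roundtrip_format() {
        // Sanity check of the on-disk layout stream_load_binary expects:
        // a bincode-serialized Vec<(String, f32)> of FEN/evaluation pairs,
        // optionally wrapped in size-prepended LZ4. This only exercises the
        // serialization layer, not the engine itself.
        let positions: Vec<(String, f32)> = vec![(
            "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1".to_string(),
            0.25,
        )];

        let raw = bincode::serialize(&positions).unwrap();
        let compressed = lz4_flex::compress_prepend_size(&raw);

        let decompressed = lz4_flex::decompress_size_prepended(&compressed).unwrap();
        let decoded: Vec<(String, f32)> = bincode::deserialize(&decompressed).unwrap();
        assert_eq!(decoded, positions);
    }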
}