chess_vector_engine/streaming_loader.rs

use chess::Board;
use indicatif::{ProgressBar, ProgressStyle};
use serde_json;
use std::collections::HashSet;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

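/// Loader that streams chess position data (newline-delimited JSON or an
/// LZ4/bincode binary dump) into a `ChessVectorEngine`, adding positions in
/// batches and counting duplicates it skips.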
pub struct StreamingLoader {
    pub loaded_count: usize,
    pub duplicate_count: usize,
    pub total_processed: usize,
}

impl StreamingLoader {
    pub fn new() -> Self {
        Self {
            loaded_count: 0,
            duplicate_count: 0,
            total_processed: 0,
        }
    }

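    /// Streams newline-delimited JSON from `path`, adding previously unseen
    /// positions to `engine` in batches of `batch_size`.
    ///
    /// Hypothetical usage sketch (the `ChessVectorEngine::new()` constructor
    /// shown here is an assumption, not taken from this file):
    ///
    /// ```ignore
    /// let mut engine = chess_vector_engine::ChessVectorEngine::new();
    /// let mut loader = StreamingLoader::new();
    /// loader.stream_load_json("positions.jsonl", &mut engine, 5_000)?;
    /// println!("{} new positions loaded", loader.loaded_count);
    /// ```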
    pub fn stream_load_json<P: AsRef<Path>>(
        &mut self,
        path: P,
        engine: &mut crate::ChessVectorEngine,
        batch_size: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path_ref = path.as_ref();
        println!("🌊 Streaming JSON load from {}", path_ref.display());

        let file = File::open(path_ref)?;
        let reader = BufReader::with_capacity(64 * 1024, file);
        let total_lines = self.estimate_line_count(path_ref)?;
        println!("📊 Estimated {total_lines} lines to process");

        let pb = ProgressBar::new(total_lines as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("⚡ Streaming [{elapsed_precise}] [{bar:40.green/blue}] {pos}/{len} ({percent}%) {msg}")?
                .progress_chars("██░"),
        );

        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();
        // Use the vector length (not the deduplicated set size) so the final
        // "new positions" delta stays accurate even if the engine held duplicates.
        let initial_size = engine.position_boards.len();

        let mut batch_boards = Vec::with_capacity(batch_size);
        let mut batch_evaluations = Vec::with_capacity(batch_size);
        let mut line_count = 0;

        for line_result in reader.lines() {
            let line = line_result?;
            line_count += 1;

            if line.trim().is_empty() {
                continue;
            }

            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
                if let Some((board, evaluation)) = self.extract_position_data(&json)? {
                    if !existing_boards.contains(&board) {
                        batch_boards.push(board);
                        batch_evaluations.push(evaluation);

                        if batch_boards.len() >= batch_size {
                            self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;

                            pb.set_message(format!(
                                "{loaded} loaded, {dupes} dupes",
                                loaded = self.loaded_count,
                                dupes = self.duplicate_count
                            ));
                        }
                    } else {
                        self.duplicate_count += 1;
                    }
                }
            }

            self.total_processed += 1;

            if line_count % 1000 == 0 {
                pb.set_position(line_count as u64);
            }
        }

        if !batch_boards.is_empty() {
            self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
        }

        pb.finish_with_message(format!(
            "✅ Complete: {} loaded, {} duplicates from {} lines",
            self.loaded_count, self.duplicate_count, line_count
        ));

        let new_positions = engine.position_boards.len() - initial_size;
        println!("🎯 Added {new_positions} new positions to engine");

        Ok(())
    }

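    /// Loads positions from a bincode-serialized `Vec<(String, f32)>` of
    /// (FEN, evaluation) pairs, transparently handling LZ4 size-prepended
    /// compression.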
    pub fn stream_load_binary<P: AsRef<Path>>(
        &mut self,
        path: P,
        engine: &mut crate::ChessVectorEngine,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let path_ref = path.as_ref();
        println!("📦 Binary load from {}", path_ref.display());

        let data = std::fs::read(path_ref)?;
        println!("📦 Read {} bytes", data.len());

        // Try LZ4 size-prepended decompression first; fall back to using the
        // raw bytes as uncompressed bincode.
        let decompressed_data = if let Ok(decompressed) = lz4_flex::decompress_size_prepended(&data)
        {
            println!(
                "🗜️ LZ4 decompressed: {} → {} bytes",
                data.len(),
                decompressed.len()
            );
            decompressed
        } else {
            data
        };

        let positions: Vec<(String, f32)> = bincode::deserialize(&decompressed_data)?;
        let total_positions = positions.len();
        println!("📊 Loaded {total_positions} positions from binary");

        if total_positions == 0 {
            return Ok(());
        }

        let pb = ProgressBar::new(total_positions as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("⚡ Binary loading [{elapsed_precise}] [{bar:40.blue/green}] {pos}/{len} ({percent}%) {msg}")?
                .progress_chars("██░"),
        );

        let existing_boards: HashSet<Board> = engine.position_boards.iter().cloned().collect();

        const BATCH_SIZE: usize = 10000;
        let mut processed = 0;

        for chunk in positions.chunks(BATCH_SIZE) {
            let mut batch_boards = Vec::with_capacity(BATCH_SIZE);
            let mut batch_evaluations = Vec::with_capacity(BATCH_SIZE);

            for (fen, evaluation) in chunk {
                if let Ok(board) = fen.parse::<Board>() {
                    if !existing_boards.contains(&board) {
                        batch_boards.push(board);
                        batch_evaluations.push(*evaluation);
                    } else {
                        self.duplicate_count += 1;
                    }
                }
                processed += 1;
            }

            if !batch_boards.is_empty() {
                self.process_batch(engine, &mut batch_boards, &mut batch_evaluations)?;
            }

            pb.set_position(processed as u64);
            pb.set_message(format!("{count} loaded", count = self.loaded_count));
        }

        pb.finish_with_message(format!(
            "✅ Loaded {count} positions",
            count = self.loaded_count
        ));

        Ok(())
    }

    /// Adds a batch of positions to the engine and clears the batch buffers.
    fn process_batch(
        &mut self,
        engine: &mut crate::ChessVectorEngine,
        boards: &mut Vec<Board>,
        evaluations: &mut Vec<f32>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        for (board, evaluation) in boards.iter().zip(evaluations.iter()) {
            engine.add_position(board, *evaluation);
            self.loaded_count += 1;
        }

        boards.clear();
        evaluations.clear();

        Ok(())
    }

    /// Tries to extract a `(Board, evaluation)` pair from a JSON object,
    /// accepting the field pairs `fen`/`evaluation`, `board`/`eval`, or
    /// `position`/`score`.
    fn extract_position_data(
        &self,
        json: &serde_json::Value,
    ) -> Result<Option<(Board, f32)>, Box<dyn std::error::Error>> {
        if let (Some(fen), Some(eval)) = (
            json.get("fen").and_then(|v| v.as_str()),
            json.get("evaluation").and_then(|v| v.as_f64()),
        ) {
            if let Ok(board) = fen.parse::<Board>() {
                return Ok(Some((board, eval as f32)));
            }
        }

        if let (Some(fen), Some(eval)) = (
            json.get("board").and_then(|v| v.as_str()),
            json.get("eval").and_then(|v| v.as_f64()),
        ) {
            if let Ok(board) = fen.parse::<Board>() {
                return Ok(Some((board, eval as f32)));
            }
        }

        if let (Some(fen), Some(eval)) = (
            json.get("position").and_then(|v| v.as_str()),
            json.get("score").and_then(|v| v.as_f64()),
        ) {
            if let Ok(board) = fen.parse::<Board>() {
                return Ok(Some((board, eval as f32)));
            }
        }

        Ok(None)
    }

    /// Estimates the line count by counting newlines in the first 1 MiB and
    /// extrapolating from the total file size.
    fn estimate_line_count<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<usize, Box<dyn std::error::Error>> {
        use std::io::Read;

        let path_ref = path.as_ref();
        let file = File::open(path_ref)?;
        let mut reader = BufReader::new(file);

        let mut sample = vec![0u8; 1024 * 1024];
        let bytes_read = reader.read(&mut sample)?;

        if bytes_read == 0 {
            return Ok(0);
        }

        let newlines_in_sample = sample[..bytes_read].iter().filter(|&&b| b == b'\n').count();

        let total_size = std::fs::metadata(path_ref)?.len() as usize;

        // If the whole file fit into the sample, the count is exact.
        if bytes_read >= total_size {
            return Ok(newlines_in_sample);
        }

        let estimated_lines = (newlines_in_sample * total_size) / bytes_read;
        Ok(estimated_lines)
    }
}

impl Default for StreamingLoader {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_json_extraction() {
        let loader = StreamingLoader::new();

        let json = serde_json::json!({
            "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
            "evaluation": 0.25
        });

        let result = loader.extract_position_data(&json).unwrap();
        assert!(result.is_some());

        let (board, eval) = result.unwrap();
        assert_eq!(board, Board::default());
        assert_eq!(eval, 0.25);
    }

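    // Covers the alternative field pairs accepted by extract_position_data
    // besides "fen"/"evaluation".
    #[test]
    fn test_alternative_json_schemas() {
        let loader = StreamingLoader::new();

        let json = serde_json::json!({
            "board": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
            "eval": -0.5
        });
        let (board, eval) = loader.extract_position_data(&json).unwrap().unwrap();
        assert_eq!(board, Board::default());
        assert_eq!(eval, -0.5);

        // Unrecognized schemas yield Ok(None) rather than an error.
        let json = serde_json::json!({ "unknown": 1 });
        assert!(loader.extract_position_data(&json).unwrap().is_none());
    }
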
    #[test]
    fn test_line_estimation() {
        let loader = StreamingLoader::new();

        let mut temp_file = NamedTempFile::new().unwrap();
        for i in 0..100 {
            writeln!(temp_file, "test line {i}").unwrap();
        }

        let estimated = loader.estimate_line_count(temp_file.path()).unwrap();
        assert!((80..=120).contains(&estimated));
    }
}