chess_vector_engine/utils/
mmap_loader.rs1use crate::errors::ChessEngineError;
2use memmap2::Mmap;
3use ndarray::Array1;
4use serde::{Deserialize, Serialize};
5use std::fs::File;
6use std::io::{self, BufRead, BufReader, Result};
7use std::path::Path;
8
9pub struct MmapLoader {
11 file: File,
12 mmap: Mmap,
13}
14
15impl MmapLoader {
16 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
18 let file = File::open(path)?;
19 let mmap = unsafe { Mmap::map(&file)? };
20
21 Ok(Self { file, mmap })
22 }
23
24 pub fn data(&self) -> &[u8] {
26 &self.mmap
27 }
28
29 pub fn size(&self) -> usize {
31 self.mmap.len()
32 }
33
34 pub fn load_at_offset<T>(&self, offset: usize) -> Result<&T>
36 where
37 T: Sized,
38 {
39 let size = std::mem::size_of::<T>();
40 if offset + size > self.size() {
41 return Err(io::Error::new(
42 io::ErrorKind::UnexpectedEof,
43 "Offset exceeds file size",
44 ));
45 }
46
47 let ptr = unsafe { self.mmap.as_ptr().add(offset) as *const T };
48 Ok(unsafe { &*ptr })
49 }
50
51 pub fn load_slice_at_offset<T>(&self, offset: usize, count: usize) -> Result<&[T]>
53 where
54 T: Sized,
55 {
56 let size = std::mem::size_of::<T>() * count;
57 if offset + size > self.size() {
58 return Err(io::Error::new(
59 io::ErrorKind::UnexpectedEof,
60 "Offset and size exceed file size",
61 ));
62 }
63
64 let ptr = unsafe { self.mmap.as_ptr().add(offset) as *const T };
65 Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
66 }
67}
68
/// Namespace type for the position-file loaders and the binary writer.
pub struct FastPositionLoader;
71
72impl FastPositionLoader {
73 pub fn load_positions_mmap<P: AsRef<Path>>(
75 path: P,
76 ) -> crate::errors::Result<Vec<(Array1<f32>, f32)>> {
77 let loader = MmapLoader::new(path)
78 .map_err(|e| ChessEngineError::IoError(format!("Failed to memory-map file: {}", e)))?;
79
80 let data = loader.data();
81 if data.len() < 8 {
82 return Err(ChessEngineError::IoError(
83 "File too small to contain header".to_string(),
84 ));
85 }
86
87 let version = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
89 let count = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
90
91 if version != 1 {
92 return Err(ChessEngineError::IoError(format!(
93 "Unsupported version: {}",
94 version
95 )));
96 }
97
98 let mut positions = Vec::with_capacity(count as usize);
99 let mut offset = 8;
100
101 for _ in 0..count {
102 if offset + 4 > data.len() {
104 return Err(ChessEngineError::IoError(
105 "Unexpected end of file".to_string(),
106 ));
107 }
108 let vector_size = u32::from_le_bytes([
109 data[offset],
110 data[offset + 1],
111 data[offset + 2],
112 data[offset + 3],
113 ]) as usize;
114 offset += 4;
115
116 let vector_bytes = vector_size * 4; if offset + vector_bytes > data.len() {
119 return Err(ChessEngineError::IoError(
120 "Unexpected end of file".to_string(),
121 ));
122 }
123
124 let vector_data = loader
125 .load_slice_at_offset::<f32>(offset, vector_size)
126 .map_err(|e| {
127 ChessEngineError::IoError(format!("Failed to load vector data: {}", e))
128 })?;
129
130 let vector = Array1::from_vec(vector_data.to_vec());
131 offset += vector_bytes;
132
133 if offset + 4 > data.len() {
135 return Err(ChessEngineError::IoError(
136 "Unexpected end of file".to_string(),
137 ));
138 }
139 let evaluation = f32::from_le_bytes([
140 data[offset],
141 data[offset + 1],
142 data[offset + 2],
143 data[offset + 3],
144 ]);
145 offset += 4;
146
147 positions.push((vector, evaluation));
148 }
149
150 Ok(positions)
151 }
152
153 pub fn save_positions_binary<P: AsRef<Path>>(
155 path: P,
156 positions: &[(Array1<f32>, f32)],
157 ) -> crate::errors::Result<()> {
158 use std::fs::OpenOptions;
159 use std::io::Write;
160
161 let mut file = OpenOptions::new()
162 .create(true)
163 .write(true)
164 .truncate(true)
165 .open(path)
166 .map_err(|e| ChessEngineError::IoError(format!("Failed to create file: {}", e)))?;
167
168 let version = 1u32;
170 let count = positions.len() as u32;
171 file.write_all(&version.to_le_bytes())
172 .map_err(|e| ChessEngineError::IoError(format!("Failed to write header: {}", e)))?;
173 file.write_all(&count.to_le_bytes())
174 .map_err(|e| ChessEngineError::IoError(format!("Failed to write count: {}", e)))?;
175
176 for (vector, evaluation) in positions {
177 let vector_size = vector.len() as u32;
179 file.write_all(&vector_size.to_le_bytes()).map_err(|e| {
180 ChessEngineError::IoError(format!("Failed to write vector size: {}", e))
181 })?;
182
183 let vector_bytes = unsafe {
185 std::slice::from_raw_parts(vector.as_ptr() as *const u8, vector.len() * 4)
186 };
187 file.write_all(vector_bytes).map_err(|e| {
188 ChessEngineError::IoError(format!("Failed to write vector data: {}", e))
189 })?;
190
191 file.write_all(&evaluation.to_le_bytes()).map_err(|e| {
193 ChessEngineError::IoError(format!("Failed to write evaluation: {}", e))
194 })?;
195 }
196
197 file.flush()
198 .map_err(|e| ChessEngineError::IoError(format!("Failed to flush file: {}", e)))?;
199
200 Ok(())
201 }
202
203 pub fn load_positions_json_streaming<P: AsRef<Path>>(
205 path: P,
206 ) -> crate::errors::Result<Vec<(Array1<f32>, f32)>> {
207 let file = File::open(path)
208 .map_err(|e| ChessEngineError::IoError(format!("Failed to open file: {}", e)))?;
209
210 let reader = BufReader::new(file);
211 let mut positions = Vec::new();
212
213 for (line_num, line) in reader.lines().enumerate() {
214 let line = line.map_err(|e| {
215 ChessEngineError::IoError(format!("Failed to read line {}: {}", line_num, e))
216 })?;
217
218 if line.trim().is_empty() || line.trim().starts_with('#') {
220 continue;
221 }
222
223 let position: JsonPosition = serde_json::from_str(&line).map_err(|e| {
224 ChessEngineError::IoError(format!(
225 "Failed to parse JSON at line {}: {}",
226 line_num, e
227 ))
228 })?;
229
230 let vector = Array1::from_vec(position.vector);
231 positions.push((vector, position.evaluation));
232 }
233
234 Ok(positions)
235 }
236}
237
238#[derive(Debug, Serialize, Deserialize)]
240struct JsonPosition {
241 vector: Vec<f32>,
242 evaluation: f32,
243}
244
/// Reads a binary position file in fixed-size batches.
pub struct ChunkedLoader {
    /// Maximum number of positions handed to the processor per batch.
    chunk_size: usize,
}
249
250impl ChunkedLoader {
251 pub fn new(chunk_size: usize) -> Self {
253 Self { chunk_size }
254 }
255
256 pub fn process_in_chunks<P, F, R>(
258 &self,
259 path: P,
260 mut processor: F,
261 ) -> crate::errors::Result<Vec<R>>
262 where
263 P: AsRef<Path>,
264 F: FnMut(&[(Array1<f32>, f32)]) -> crate::errors::Result<R>,
265 {
266 let loader = MmapLoader::new(path)
267 .map_err(|e| ChessEngineError::IoError(format!("Failed to memory-map file: {}", e)))?;
268
269 let data = loader.data();
270 if data.len() < 8 {
271 return Err(ChessEngineError::IoError(
272 "File too small to contain header".to_string(),
273 ));
274 }
275
276 let version = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
278 let count = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
279
280 if version != 1 {
281 return Err(ChessEngineError::IoError(format!(
282 "Unsupported version: {}",
283 version
284 )));
285 }
286
287 let mut results = Vec::new();
288 let mut offset = 8;
289 let mut processed = 0;
290
291 while processed < count {
292 let chunk_end = ((processed + self.chunk_size as u32).min(count)) as usize;
293 let chunk_count = chunk_end - processed as usize;
294
295 let mut chunk = Vec::with_capacity(chunk_count);
296
297 for _ in 0..chunk_count {
299 if offset + 4 > data.len() {
301 return Err(ChessEngineError::IoError(
302 "Unexpected end of file".to_string(),
303 ));
304 }
305 let vector_size = u32::from_le_bytes([
306 data[offset],
307 data[offset + 1],
308 data[offset + 2],
309 data[offset + 3],
310 ]) as usize;
311 offset += 4;
312
313 let vector_bytes = vector_size * 4;
315 if offset + vector_bytes > data.len() {
316 return Err(ChessEngineError::IoError(
317 "Unexpected end of file".to_string(),
318 ));
319 }
320
321 let vector_data = loader
322 .load_slice_at_offset::<f32>(offset, vector_size)
323 .map_err(|e| {
324 ChessEngineError::IoError(format!("Failed to load vector data: {}", e))
325 })?;
326
327 let vector = Array1::from_vec(vector_data.to_vec());
328 offset += vector_bytes;
329
330 if offset + 4 > data.len() {
332 return Err(ChessEngineError::IoError(
333 "Unexpected end of file".to_string(),
334 ));
335 }
336 let evaluation = f32::from_le_bytes([
337 data[offset],
338 data[offset + 1],
339 data[offset + 2],
340 data[offset + 3],
341 ]);
342 offset += 4;
343
344 chunk.push((vector, evaluation));
345 }
346
347 let result = processor(&chunk)?;
349 results.push(result);
350
351 processed += chunk_count as u32;
352 }
353
354 Ok(results)
355 }
356}
357
/// Namespace type for loading position files by compression extension.
pub struct CompressedLoader;
360
361impl CompressedLoader {
362 pub fn load_compressed<P: AsRef<Path>>(
364 path: P,
365 ) -> crate::errors::Result<Vec<(Array1<f32>, f32)>> {
366 let path = path.as_ref();
367 let extension = path.extension().and_then(|ext| ext.to_str()).unwrap_or("");
368
369 match extension {
370 "gz" => Self::load_gzip(path),
371 "zst" => Self::load_zstd(path),
372 "lz4" => Self::load_lz4(path),
373 _ => Err(ChessEngineError::IoError(format!(
374 "Unsupported compression format: {}",
375 extension
376 ))),
377 }
378 }
379
380 fn load_gzip<P: AsRef<Path>>(_path: P) -> crate::errors::Result<Vec<(Array1<f32>, f32)>> {
382 Err(ChessEngineError::IoError(
385 "Gzip support not implemented".to_string(),
386 ))
387 }
397
398 fn load_zstd<P: AsRef<Path>>(_path: P) -> crate::errors::Result<Vec<(Array1<f32>, f32)>> {
400 Err(ChessEngineError::IoError(
402 "Zstd support not implemented".to_string(),
403 ))
404 }
405
406 fn load_lz4<P: AsRef<Path>>(_path: P) -> crate::errors::Result<Vec<(Array1<f32>, f32)>> {
408 Err(ChessEngineError::IoError(
410 "LZ4 support not implemented".to_string(),
411 ))
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Positions written with the binary writer come back unchanged from the
    /// memory-mapped loader.
    #[test]
    fn test_binary_save_load() {
        let scratch = NamedTempFile::new().unwrap();
        let scratch_path = scratch.path();

        let expected = vec![
            (Array1::from_vec(vec![1.0, 2.0, 3.0]), 0.5),
            (Array1::from_vec(vec![4.0, 5.0, 6.0]), -0.3),
            (Array1::from_vec(vec![7.0, 8.0, 9.0]), 0.1),
        ];

        FastPositionLoader::save_positions_binary(scratch_path, &expected).unwrap();
        let actual = FastPositionLoader::load_positions_mmap(scratch_path).unwrap();

        assert_eq!(actual.len(), expected.len());

        for ((got_vector, got_eval), (want_vector, want_eval)) in actual.iter().zip(&expected) {
            assert_eq!(got_vector.len(), want_vector.len());
            for (got, want) in got_vector.iter().zip(want_vector.iter()) {
                assert!((got - want).abs() < 1e-6);
            }
            assert!((got_eval - want_eval).abs() < 1e-6);
        }
    }

    /// Four stored positions with a chunk size of two yield two full chunks.
    #[test]
    fn test_chunked_processing() {
        let scratch = NamedTempFile::new().unwrap();
        let scratch_path = scratch.path();

        let stored = vec![
            (Array1::from_vec(vec![1.0, 2.0]), 0.5),
            (Array1::from_vec(vec![3.0, 4.0]), -0.3),
            (Array1::from_vec(vec![5.0, 6.0]), 0.1),
            (Array1::from_vec(vec![7.0, 8.0]), 0.7),
        ];

        FastPositionLoader::save_positions_binary(scratch_path, &stored).unwrap();

        let chunk_lengths = ChunkedLoader::new(2)
            .process_in_chunks(scratch_path, |chunk| Ok(chunk.len()))
            .unwrap();

        assert_eq!(chunk_lengths.len(), 2);
        assert_eq!(chunk_lengths[0], 2);
        assert_eq!(chunk_lengths[1], 2);
    }

    /// Raw little-endian `u32` values can be read back individually and as a
    /// slice through the memory map.
    #[test]
    fn test_mmap_loader() {
        use std::io::Write;

        let scratch = NamedTempFile::new().unwrap();
        let scratch_path = scratch.path();

        let test_data = [1u32, 2u32, 3u32, 4u32];
        {
            let mut out = std::fs::OpenOptions::new()
                .write(true)
                .open(scratch_path)
                .unwrap();
            for word in test_data.iter() {
                out.write_all(&word.to_le_bytes()).unwrap();
            }
            out.flush().unwrap();
            // `out` is closed at the end of this scope, before mapping.
        }

        let loader = MmapLoader::new(scratch_path).unwrap();
        // Four `u32` values occupy sixteen bytes.
        assert_eq!(loader.size(), 16);

        let first = loader.load_at_offset::<u32>(0).unwrap();
        assert_eq!(*first, 1);

        let second = loader.load_at_offset::<u32>(4).unwrap();
        assert_eq!(*second, 2);

        let all = loader.load_slice_at_offset::<u32>(0, 4).unwrap();
        assert_eq!(all, &test_data);
    }
}