chess_vector_engine/
auto_discovery.rs

1use std::collections::HashMap;
2use std::fs;
3use std::io::{BufRead, BufReader};
4use std::path::{Path, PathBuf};
5
/// Storage format of a training data file, ordered from most to least preferred.
///
/// The derived ordering follows the declaration (and discriminant) order, so a *lower*
/// value is a *better* format: `MemoryMapped < MessagePack < Binary < Zstd < Json`.
/// When several files share a base name, the one with the smallest priority is kept.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum FormatPriority {
    /// `.mmap` files.
    MemoryMapped = 1,
    /// `.msgpack` files.
    MessagePack = 2,
    /// `.bin` files whose contents decode as bincode training data.
    Binary = 3,
    /// `.zst` files.
    Zstd = 4,
    /// JSON training data (a `.json` file, or a content-detected file).
    Json = 5,
}
15
16/// Training data file information
17#[derive(Debug, Clone)]
18pub struct TrainingFile {
19    pub path: PathBuf,
20    pub format: String,
21    pub priority: FormatPriority,
22    pub base_name: String,
23    pub size_bytes: u64,
24}
25
/// Stateless entry point for discovering training data files, selecting the best
/// available format for each data set, and cleaning up inferior copies.
///
/// All functionality is exposed through associated functions; the type holds no data.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AutoDiscovery;
28
29impl AutoDiscovery {
30    /// Discover all training data files in a directory
31    pub fn discover_training_files<P: AsRef<Path>>(
32        base_path: P,
33        recursive: bool,
34    ) -> Result<Vec<TrainingFile>, Box<dyn std::error::Error>> {
35        let mut discovered_files = Vec::new();
36        let base_path = base_path.as_ref();
37
38        println!(
39            "๐Ÿ” Discovering training data files in {}...",
40            base_path.display()
41        );
42
43        Self::scan_directory(base_path, recursive, &mut discovered_files)?;
44
45        // Sort by base name, then by priority
46        discovered_files.sort_by(|a, b| {
47            a.base_name
48                .cmp(&b.base_name)
49                .then_with(|| a.priority.cmp(&b.priority))
50        });
51
52        println!(
53            "๐Ÿ“ Discovered {} training data files",
54            discovered_files.len()
55        );
56        for file in &discovered_files {
57            println!(
58                "   {} - {} ({})",
59                file.format,
60                file.path.display(),
61                Self::format_bytes(file.size_bytes)
62            );
63        }
64
65        Ok(discovered_files)
66    }
67
68    /// Group files by base name and select best format for each
69    pub fn consolidate_by_base_name(files: Vec<TrainingFile>) -> HashMap<String, TrainingFile> {
70        let mut groups: HashMap<String, Vec<TrainingFile>> = HashMap::new();
71
72        // Group by base name
73        for file in files {
74            groups.entry(file.base_name.clone()).or_default().push(file);
75        }
76
77        let mut consolidated = HashMap::new();
78
79        // Select best format for each group
80        for (base_name, mut group) in groups {
81            group.sort_by(|a, b| a.priority.cmp(&b.priority));
82
83            if let Some(best_file) = group.into_iter().next() {
84                consolidated.insert(base_name, best_file);
85            }
86        }
87
88        consolidated
89    }
90
91    /// Get list of inferior formats that can be cleaned up
92    pub fn get_cleanup_candidates(files: &[TrainingFile]) -> Vec<PathBuf> {
93        let mut cleanup_files = Vec::new();
94        let consolidated = Self::consolidate_by_base_name(files.to_vec());
95
96        for file in files {
97            if let Some(best_file) = consolidated.get(&file.base_name) {
98                // If this file is not the best format for its base name, mark for cleanup
99                if file.path != best_file.path && file.priority > best_file.priority {
100                    cleanup_files.push(file.path.clone());
101                }
102            }
103        }
104
105        cleanup_files
106    }
107
108    /// Clean up old format files
109    pub fn cleanup_old_formats(
110        files_to_remove: &[PathBuf],
111        dry_run: bool,
112    ) -> Result<(), Box<dyn std::error::Error>> {
113        if dry_run {
114            println!(
115                "๐Ÿงน DRY RUN - Would remove {} old format files:",
116                files_to_remove.len()
117            );
118            for _path in files_to_remove {
119                println!("Discovery complete");
120            }
121            return Ok(());
122        }
123
124        println!(
125            "๐Ÿงน Cleaning up {} old format files...",
126            files_to_remove.len()
127        );
128
129        for path in files_to_remove {
130            match fs::remove_file(path) {
131                Ok(()) => println!("Removed file: {}", path.display()),
132                Err(e) => println!("Error removing file: {e}"),
133            }
134        }
135
136        Ok(())
137    }
138
139    /// Scan directory recursively for training files
140    fn scan_directory(
141        dir: &Path,
142        recursive: bool,
143        files: &mut Vec<TrainingFile>,
144    ) -> Result<(), Box<dyn std::error::Error>> {
145        if !dir.is_dir() {
146            return Ok(());
147        }
148
149        for entry in fs::read_dir(dir)? {
150            let entry = entry?;
151            let path = entry.path();
152
153            if path.is_dir() && recursive {
154                Self::scan_directory(&path, recursive, files)?;
155            } else if path.is_file() {
156                if let Some(training_file) = Self::analyze_file(&path)? {
157                    files.push(training_file);
158                }
159            }
160        }
161
162        Ok(())
163    }
164
165    /// Analyze a file to determine if it's training data
166    fn analyze_file(path: &Path) -> Result<Option<TrainingFile>, Box<dyn std::error::Error>> {
167        let metadata = fs::metadata(path)?;
168        let size_bytes = metadata.len();
169
170        // Skip system files, hidden files, and very small files
171        if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
172            if file_name.starts_with('.')
173                || file_name.starts_with("~")
174                || file_name.contains("lock")
175                || file_name.contains("tmp")
176                || size_bytes < 10
177            {
178                return Ok(None);
179            }
180        }
181
182        // Get base name (without extension)
183        let base_name = Self::extract_base_name(path);
184
185        // Detect format by extension and content
186        if let Some((format, priority)) = Self::detect_format(path)? {
187            return Ok(Some(TrainingFile {
188                path: path.to_path_buf(),
189                format,
190                priority,
191                base_name,
192                size_bytes,
193            }));
194        }
195
196        Ok(None)
197    }
198
199    /// Extract base name from path, removing format extensions
200    fn extract_base_name(path: &Path) -> String {
201        let file_name = path
202            .file_name()
203            .and_then(|n| n.to_str())
204            .unwrap_or("unknown");
205
206        // Remove known extensions to get base name
207
208        file_name
209            .replace(".mmap", "")
210            .replace(".msgpack", "")
211            .replace(".bin", "")
212            .replace(".zst", "")
213            .replace(".json", "")
214    }
215
216    /// Detect file format and priority
217    fn detect_format(
218        path: &Path,
219    ) -> Result<Option<(String, FormatPriority)>, Box<dyn std::error::Error>> {
220        let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
221
222        // Check by extension
223        if file_name.ends_with(".mmap") {
224            return Ok(Some(("MMAP".to_string(), FormatPriority::MemoryMapped)));
225        }
226
227        if file_name.ends_with(".msgpack") {
228            return Ok(Some(("MSGPACK".to_string(), FormatPriority::MessagePack)));
229        }
230
231        if file_name.ends_with(".bin") && Self::verify_binary_training_data(path)? {
232            return Ok(Some(("BINARY".to_string(), FormatPriority::Binary)));
233        }
234
235        if file_name.ends_with(".zst") {
236            return Ok(Some(("ZSTD".to_string(), FormatPriority::Zstd)));
237        }
238
239        if file_name.ends_with(".json") && Self::verify_json_training_data(path)? {
240            return Ok(Some(("JSON".to_string(), FormatPriority::Json)));
241        }
242
243        // Check by content for files that might not have proper extensions
244        if (file_name.contains("training")
245            || file_name.contains("position")
246            || file_name.contains("tactical"))
247            && Self::verify_json_training_data(path)?
248        {
249            return Ok(Some(("JSON".to_string(), FormatPriority::Json)));
250        }
251
252        Ok(None)
253    }
254
255    /// Verify that a JSON file contains training data
256    fn verify_json_training_data(path: &Path) -> Result<bool, Box<dyn std::error::Error>> {
257        // Check file size first
258        if let Ok(metadata) = std::fs::metadata(path) {
259            let size = metadata.len();
260            if size == 0 || size > 10_000_000 {
261                // Skip empty files or files > 10MB
262                return Ok(false);
263            }
264        }
265
266        let file = std::fs::File::open(path)?;
267        let reader = BufReader::new(file);
268
269        // Check first few lines for training data structure
270        for (i, line_result) in reader.lines().enumerate() {
271            if i >= 10 {
272                break;
273            } // Only check first 10 lines
274
275            // Handle UTF-8 errors gracefully
276            let line = match line_result {
277                Ok(line) => line,
278                Err(_) => {
279                    // If we can't read as UTF-8, it's not a JSON training file
280                    return Ok(false);
281                }
282            };
283
284            if line.trim().is_empty() {
285                continue;
286            }
287
288            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
289                // Look for common training data fields
290                if json.get("fen").is_some() && json.get("evaluation").is_some() {
291                    return Ok(true);
292                }
293                if json.get("board").is_some() && json.get("eval").is_some() {
294                    return Ok(true);
295                }
296                if json.get("position").is_some() && json.get("score").is_some() {
297                    return Ok(true);
298                }
299            }
300        }
301
302        Ok(false)
303    }
304
305    /// Verify that a binary file contains training data
306    fn verify_binary_training_data(path: &Path) -> Result<bool, Box<dyn std::error::Error>> {
307        use std::io::Read;
308
309        // Skip very small files and very large files for safety
310        if let Ok(metadata) = std::fs::metadata(path) {
311            let size = metadata.len();
312            if !(100..=100_000_000).contains(&size) {
313                // Skip files < 100B or > 100MB
314                return Ok(false);
315            }
316        }
317
318        let mut file = std::fs::File::open(path)?;
319        let mut buffer = vec![0u8; 1024]; // Read first 1KB
320        let bytes_read = file.read(&mut buffer)?;
321
322        if bytes_read == 0 {
323            return Ok(false);
324        }
325
326        buffer.truncate(bytes_read);
327
328        // Try to deserialize as training data (catch all errors)
329        if bincode::deserialize::<Vec<(String, f32)>>(&buffer).is_ok() {
330            return Ok(true);
331        }
332
333        // Try LZ4 decompression first (catch all errors)
334        if let Ok(decompressed) = lz4_flex::decompress_size_prepended(&buffer) {
335            if bincode::deserialize::<Vec<(String, f32)>>(&decompressed).is_ok() {
336                return Ok(true);
337            }
338        }
339
340        Ok(false)
341    }
342
343    /// Format bytes for display
344    fn format_bytes(bytes: u64) -> String {
345        const UNITS: &[&str] = &["B", "KB", "MB", "GB"];
346        let mut size = bytes as f64;
347        let mut unit_index = 0;
348
349        while size >= 1024.0 && unit_index < UNITS.len() - 1 {
350            size /= 1024.0;
351            unit_index += 1;
352        }
353
354        "Processing files...".to_string()
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TrainingFile` fixture for the consolidation and cleanup tests.
    fn training_file(
        path: &str,
        format: &str,
        priority: FormatPriority,
        base_name: &str,
        size_bytes: u64,
    ) -> TrainingFile {
        TrainingFile {
            path: PathBuf::from(path),
            format: format.to_string(),
            priority,
            base_name: base_name.to_string(),
            size_bytes,
        }
    }

    #[test]
    fn test_base_name_extraction() {
        let cases = [
            ("training_data.json", "training_data"),
            ("training_data.mmap", "training_data"),
            ("tactical_training_data.msgpack", "tactical_training_data"),
            ("positions.bin", "positions"),
            ("positions.zst", "positions"),
            // Stacked format extensions collapse to the same base name.
            ("positions.bin.zst", "positions"),
        ];
        for &(file_name, expected) in &cases {
            assert_eq!(
                AutoDiscovery::extract_base_name(Path::new(file_name)),
                expected,
                "base name of {}",
                file_name
            );
        }
    }

    #[test]
    fn test_format_priority() {
        assert!(FormatPriority::MemoryMapped < FormatPriority::MessagePack);
        assert!(FormatPriority::MessagePack < FormatPriority::Binary);
        assert!(FormatPriority::Binary < FormatPriority::Zstd);
        assert!(FormatPriority::Zstd < FormatPriority::Json);
    }

    #[test]
    fn test_consolidation() {
        let files = vec![
            training_file(
                "training_data.json",
                "JSON",
                FormatPriority::Json,
                "training_data",
                1000,
            ),
            training_file(
                "training_data.mmap",
                "MMAP",
                FormatPriority::MemoryMapped,
                "training_data",
                800,
            ),
        ];

        let consolidated = AutoDiscovery::consolidate_by_base_name(files);
        let best = consolidated.get("training_data").unwrap();

        assert_eq!(best.format, "MMAP");
        assert_eq!(best.priority, FormatPriority::MemoryMapped);
    }

    #[test]
    fn test_cleanup_candidates() {
        let files = vec![
            training_file(
                "training_data.json",
                "JSON",
                FormatPriority::Json,
                "training_data",
                1000,
            ),
            training_file(
                "training_data.mmap",
                "MMAP",
                FormatPriority::MemoryMapped,
                "training_data",
                800,
            ),
            training_file("other.json", "JSON", FormatPriority::Json, "other", 500),
        ];

        assert_eq!(
            AutoDiscovery::get_cleanup_candidates(&files),
            vec![PathBuf::from("training_data.json")]
        );
    }

    #[test]
    fn test_cleanup_candidates_ignore_priority_ties() {
        let files = vec![
            training_file(
                "a/training_data.json",
                "JSON",
                FormatPriority::Json,
                "training_data",
                100,
            ),
            training_file(
                "b/training_data.json",
                "JSON",
                FormatPriority::Json,
                "training_data",
                100,
            ),
        ];

        assert!(AutoDiscovery::get_cleanup_candidates(&files).is_empty());
    }

    #[test]
    fn test_cleanup_dry_run_succeeds() {
        let missing = PathBuf::from("definitely_missing_training_data.json");
        assert!(AutoDiscovery::cleanup_old_formats(&[missing], true).is_ok());
    }
}