1use std::collections::HashMap;
2use std::fs;
3use std::io::{BufRead, BufReader};
4use std::path::{Path, PathBuf};
5
/// Preference ranking for on-disk training-data formats.
///
/// Lower value = more preferred. The derived `Ord` follows declaration
/// order, so `MemoryMapped` compares less than (i.e. beats) `Json`;
/// `consolidate_by_base_name` and the cleanup logic rely on this ordering.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum FormatPriority {
    /// `.mmap` files — most preferred (see `detect_format`).
    MemoryMapped = 1,
    /// `.msgpack` files.
    MessagePack = 2,
    /// `.bin` files whose content passes `verify_binary_training_data`.
    Binary = 3,
    /// `.zst` files.
    Zstd = 4,
    /// `.json` / name-heuristic JSONL files — least preferred fallback.
    Json = 5,
}
15
/// A training-data file found on disk, tagged with its detected format.
#[derive(Debug, Clone)]
pub struct TrainingFile {
    // Location of the file as discovered during the scan.
    pub path: PathBuf,
    // Short format label produced by `detect_format`, e.g. "MMAP", "JSON".
    pub format: String,
    // Preference rank used to pick the best file per base name.
    pub priority: FormatPriority,
    // File name with known training extensions removed; files sharing a
    // base name are treated as the same dataset in different formats.
    pub base_name: String,
    // Size on disk, in bytes.
    pub size_bytes: u64,
}
25
/// Stateless namespace for discovering, ranking, and cleaning up
/// training-data files; all functionality lives in associated functions.
pub struct AutoDiscovery;
28
29impl AutoDiscovery {
30 pub fn discover_training_files<P: AsRef<Path>>(
32 base_path: P,
33 recursive: bool,
34 ) -> Result<Vec<TrainingFile>, Box<dyn std::error::Error>> {
35 let mut discovered_files = Vec::new();
36 let base_path = base_path.as_ref();
37
38 println!(
39 "๐ Discovering training data files in {}...",
40 base_path.display()
41 );
42
43 Self::scan_directory(base_path, recursive, &mut discovered_files)?;
44
45 discovered_files.sort_by(|a, b| {
47 a.base_name
48 .cmp(&b.base_name)
49 .then_with(|| a.priority.cmp(&b.priority))
50 });
51
52 println!(
53 "๐ Discovered {} training data files",
54 discovered_files.len()
55 );
56 for file in &discovered_files {
57 println!(
58 " {} - {} ({})",
59 file.format,
60 file.path.display(),
61 Self::format_bytes(file.size_bytes)
62 );
63 }
64
65 Ok(discovered_files)
66 }
67
68 pub fn consolidate_by_base_name(files: Vec<TrainingFile>) -> HashMap<String, TrainingFile> {
70 let mut groups: HashMap<String, Vec<TrainingFile>> = HashMap::new();
71
72 for file in files {
74 groups.entry(file.base_name.clone()).or_default().push(file);
75 }
76
77 let mut consolidated = HashMap::new();
78
79 for (base_name, mut group) in groups {
81 group.sort_by(|a, b| a.priority.cmp(&b.priority));
82
83 if let Some(best_file) = group.into_iter().next() {
84 consolidated.insert(base_name, best_file);
85 }
86 }
87
88 consolidated
89 }
90
91 pub fn get_cleanup_candidates(files: &[TrainingFile]) -> Vec<PathBuf> {
93 let mut cleanup_files = Vec::new();
94 let consolidated = Self::consolidate_by_base_name(files.to_vec());
95
96 for file in files {
97 if let Some(best_file) = consolidated.get(&file.base_name) {
98 if file.path != best_file.path && file.priority > best_file.priority {
100 cleanup_files.push(file.path.clone());
101 }
102 }
103 }
104
105 cleanup_files
106 }
107
108 pub fn cleanup_old_formats(
110 files_to_remove: &[PathBuf],
111 dry_run: bool,
112 ) -> Result<(), Box<dyn std::error::Error>> {
113 if dry_run {
114 println!(
115 "๐งน DRY RUN - Would remove {} old format files:",
116 files_to_remove.len()
117 );
118 for _path in files_to_remove {
119 println!("Discovery complete");
120 }
121 return Ok(());
122 }
123
124 println!(
125 "๐งน Cleaning up {} old format files...",
126 files_to_remove.len()
127 );
128
129 for path in files_to_remove {
130 match fs::remove_file(path) {
131 Ok(()) => println!("Removed file: {}", path.display()),
132 Err(e) => println!("Error removing file: {e}"),
133 }
134 }
135
136 Ok(())
137 }
138
139 fn scan_directory(
141 dir: &Path,
142 recursive: bool,
143 files: &mut Vec<TrainingFile>,
144 ) -> Result<(), Box<dyn std::error::Error>> {
145 if !dir.is_dir() {
146 return Ok(());
147 }
148
149 for entry in fs::read_dir(dir)? {
150 let entry = entry?;
151 let path = entry.path();
152
153 if path.is_dir() && recursive {
154 Self::scan_directory(&path, recursive, files)?;
155 } else if path.is_file() {
156 if let Some(training_file) = Self::analyze_file(&path)? {
157 files.push(training_file);
158 }
159 }
160 }
161
162 Ok(())
163 }
164
165 fn analyze_file(path: &Path) -> Result<Option<TrainingFile>, Box<dyn std::error::Error>> {
167 let metadata = fs::metadata(path)?;
168 let size_bytes = metadata.len();
169
170 if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
172 if file_name.starts_with('.')
173 || file_name.starts_with("~")
174 || file_name.contains("lock")
175 || file_name.contains("tmp")
176 || size_bytes < 10
177 {
178 return Ok(None);
179 }
180 }
181
182 let base_name = Self::extract_base_name(path);
184
185 if let Some((format, priority)) = Self::detect_format(path)? {
187 return Ok(Some(TrainingFile {
188 path: path.to_path_buf(),
189 format,
190 priority,
191 base_name,
192 size_bytes,
193 }));
194 }
195
196 Ok(None)
197 }
198
199 fn extract_base_name(path: &Path) -> String {
201 let file_name = path
202 .file_name()
203 .and_then(|n| n.to_str())
204 .unwrap_or("unknown");
205
206 file_name
209 .replace(".mmap", "")
210 .replace(".msgpack", "")
211 .replace(".bin", "")
212 .replace(".zst", "")
213 .replace(".json", "")
214 }
215
216 fn detect_format(
218 path: &Path,
219 ) -> Result<Option<(String, FormatPriority)>, Box<dyn std::error::Error>> {
220 let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
221
222 if file_name.ends_with(".mmap") {
224 return Ok(Some(("MMAP".to_string(), FormatPriority::MemoryMapped)));
225 }
226
227 if file_name.ends_with(".msgpack") {
228 return Ok(Some(("MSGPACK".to_string(), FormatPriority::MessagePack)));
229 }
230
231 if file_name.ends_with(".bin") && Self::verify_binary_training_data(path)? {
232 return Ok(Some(("BINARY".to_string(), FormatPriority::Binary)));
233 }
234
235 if file_name.ends_with(".zst") {
236 return Ok(Some(("ZSTD".to_string(), FormatPriority::Zstd)));
237 }
238
239 if file_name.ends_with(".json") && Self::verify_json_training_data(path)? {
240 return Ok(Some(("JSON".to_string(), FormatPriority::Json)));
241 }
242
243 if (file_name.contains("training")
245 || file_name.contains("position")
246 || file_name.contains("tactical"))
247 && Self::verify_json_training_data(path)?
248 {
249 return Ok(Some(("JSON".to_string(), FormatPriority::Json)));
250 }
251
252 Ok(None)
253 }
254
255 fn verify_json_training_data(path: &Path) -> Result<bool, Box<dyn std::error::Error>> {
257 if let Ok(metadata) = std::fs::metadata(path) {
259 let size = metadata.len();
260 if size == 0 || size > 10_000_000 {
261 return Ok(false);
263 }
264 }
265
266 let file = std::fs::File::open(path)?;
267 let reader = BufReader::new(file);
268
269 for (i, line_result) in reader.lines().enumerate() {
271 if i >= 10 {
272 break;
273 } let line = match line_result {
277 Ok(line) => line,
278 Err(_) => {
279 return Ok(false);
281 }
282 };
283
284 if line.trim().is_empty() {
285 continue;
286 }
287
288 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
289 if json.get("fen").is_some() && json.get("evaluation").is_some() {
291 return Ok(true);
292 }
293 if json.get("board").is_some() && json.get("eval").is_some() {
294 return Ok(true);
295 }
296 if json.get("position").is_some() && json.get("score").is_some() {
297 return Ok(true);
298 }
299 }
300 }
301
302 Ok(false)
303 }
304
    /// Heuristically check whether `path` looks like binary training data.
    ///
    /// Returns `Ok(true)` when the first KiB of the file deserializes as
    /// `Vec<(String, f32)>` via bincode, either directly or after
    /// lz4_flex size-prepended decompression.
    ///
    /// NOTE(review): only the first 1024 bytes are read, so for files
    /// larger than that both deserialization attempts run on a truncated
    /// buffer and may reject valid data — confirm this sampling strategy
    /// is acceptable for the expected file sizes.
    fn verify_binary_training_data(path: &Path) -> Result<bool, Box<dyn std::error::Error>> {
        use std::io::Read;

        // Size gate: reject files outside 100 B ..= 100 MB as implausible.
        if let Ok(metadata) = std::fs::metadata(path) {
            let size = metadata.len();
            if !(100..=100_000_000).contains(&size) {
                return Ok(false);
            }
        }

        // Sample only the first KiB of the file.
        let mut file = std::fs::File::open(path)?;
        let mut buffer = vec![0u8; 1024]; let bytes_read = file.read(&mut buffer)?;

        if bytes_read == 0 {
            return Ok(false);
        }

        buffer.truncate(bytes_read);

        // Attempt 1: raw bincode-encoded Vec<(String, f32)>.
        if bincode::deserialize::<Vec<(String, f32)>>(&buffer).is_ok() {
            return Ok(true);
        }

        // Attempt 2: the same payload behind size-prepended LZ4 compression.
        if let Ok(decompressed) = lz4_flex::decompress_size_prepended(&buffer) {
            if bincode::deserialize::<Vec<(String, f32)>>(&decompressed).is_ok() {
                return Ok(true);
            }
        }

        Ok(false)
    }
342
343 fn format_bytes(bytes: u64) -> String {
345 const UNITS: &[&str] = &["B", "KB", "MB", "GB"];
346 let mut size = bytes as f64;
347 let mut unit_index = 0;
348
349 while size >= 1024.0 && unit_index < UNITS.len() - 1 {
350 size /= 1024.0;
351 unit_index += 1;
352 }
353
354 "Processing files...".to_string()
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    // Stripping known extensions should yield the shared dataset base name.
    #[test]
    fn test_base_name_extraction() {
        assert_eq!(
            AutoDiscovery::extract_base_name(Path::new("training_data.json")),
            "training_data"
        );
        assert_eq!(
            AutoDiscovery::extract_base_name(Path::new("training_data.mmap")),
            "training_data"
        );
        assert_eq!(
            AutoDiscovery::extract_base_name(Path::new("tactical_training_data.msgpack")),
            "tactical_training_data"
        );
    }

    // The derived `Ord` must follow declaration order: preferred formats
    // compare as "less than" fallback formats.
    #[test]
    fn test_format_priority() {
        assert!(FormatPriority::MemoryMapped < FormatPriority::MessagePack);
        assert!(FormatPriority::MessagePack < FormatPriority::Binary);
        assert!(FormatPriority::Binary < FormatPriority::Json);
    }

    // Consolidation keeps only the best-format file for each base name:
    // the MMAP entry should win over the JSON entry.
    #[test]
    fn test_consolidation() {
        let files = vec![
            TrainingFile {
                path: PathBuf::from("training_data.json"),
                format: "JSON".to_string(),
                priority: FormatPriority::Json,
                base_name: "training_data".to_string(),
                size_bytes: 1000,
            },
            TrainingFile {
                path: PathBuf::from("training_data.mmap"),
                format: "MMAP".to_string(),
                priority: FormatPriority::MemoryMapped,
                base_name: "training_data".to_string(),
                size_bytes: 800,
            },
        ];

        let consolidated = AutoDiscovery::consolidate_by_base_name(files);
        let best = consolidated.get("training_data").unwrap();

        assert_eq!(best.format, "MMAP");
        assert_eq!(best.priority, FormatPriority::MemoryMapped);
    }
}