1use anyhow::{bail, Context, Result};
2use clap::Args;
3use indicatif::{ProgressBar, ProgressStyle};
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6use std::collections::HashMap;
7use std::fs;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use walkdir::WalkDir;
/// Input data formats accepted on the command line or detected from JSON content.
#[derive(Debug, Clone, PartialEq)]
pub enum DataFormat {
    /// Detect the format from the first sample (see `GenericJSONDataset::detect_format`).
    Auto,
    /// Numeric `{x, y}` feature/target vector pairs.
    Legacy,
    /// Code-documentation samples (`module_name`, `file_name`, `implementation`, ...).
    Molds,
    /// Anything else; no built-in schema validation is applied.
    Custom,
}
/// Shape of the training data, inferred from the field names present in samples.
#[derive(Debug, Clone, PartialEq)]
pub enum TrainingFormat {
    /// Pairwise preference data: chosen vs. rejected responses.
    Preference { chosen_field: String, rejected_field: String },
    /// Single completion, optionally with a numeric/boolean label.
    Completion { completion_field: String, label_field: Option<String> },
    /// Instruction/output pairs (mapped to prompt/completion downstream).
    Instruction { instruction_field: String, output_field: String },
    /// Chat transcripts stored under a single messages field.
    Chat { messages_field: String },
    /// Unrecognized layout; the raw field names are preserved.
    Custom { fields: Vec<String> },
}
27impl std::str::FromStr for DataFormat {
28 type Err = anyhow::Error;
29 fn from_str(s: &str) -> Result<Self> {
30 match s.to_lowercase().as_str() {
31 "auto" => Ok(DataFormat::Auto),
32 "legacy" => Ok(DataFormat::Legacy),
33 "molds" => Ok(DataFormat::Molds),
34 "custom" => Ok(DataFormat::Custom),
35 _ => bail!("Invalid format: {}. Must be auto, legacy, molds, or custom", s),
36 }
37 }
38}
/// Sample shape for `DataFormat::Legacy`: parallel numeric feature/target vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LegacySample {
    pub x: Vec<f32>,
    pub y: Vec<f32>,
}
/// Sample shape for `DataFormat::Molds`: one documented code module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MoldsSample {
    pub module_name: String,
    pub file_name: String,
    pub implementation: String,
    pub documentation: String,
    // Rename is a no-op today (field name matches) but kept explicit so the
    // wire name survives any future field rename.
    #[serde(rename = "system_context")]
    pub system_context: Option<String>,
}
/// Normalized training sample. Any field may be absent depending on the
/// detected `TrainingFormat`; unrecognized source fields land in `meta`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSample {
    pub prompt: Option<String>,
    pub chosen: Option<String>,
    pub rejected: Option<String>,
    pub completion: Option<String>,
    pub label: Option<f32>,
    pub meta: HashMap<String, Value>,
}
/// A converted dataset plus its detected format and summary statistics.
#[derive(Debug, Clone)]
pub struct TrainingDataset {
    pub samples: Vec<TrainingSample>,
    pub format: TrainingFormat,
    pub statistics: DatasetStats,
}
/// Aggregate statistics computed by `GenericJSONDataset::compute_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStats {
    pub total_samples: usize,
    // Mean lengths are UTF-8 byte lengths (String::len), not character counts.
    pub avg_prompt_length: f64,
    pub avg_completion_length: f64,
    // Fraction (0.0..=1.0) of samples in which each field was present.
    pub field_coverage: HashMap<String, f64>,
    pub quality_score: Option<f64>,
}
/// Raw JSON samples loaded from disk, before conversion to training form.
#[derive(Debug)]
pub struct GenericJSONDataset {
    pub data: Vec<Value>,
    pub format: DataFormat,
    pub schema: Option<Value>,
}
82impl GenericJSONDataset {
83 pub fn new(
84 json_paths: &[PathBuf],
85 schema_path: Option<&Path>,
86 data_format: DataFormat,
87 ) -> Result<Self> {
88 if json_paths.is_empty() {
89 bail!("No JSON files provided");
90 }
91 for path in json_paths {
92 if !path.is_file() {
93 bail!("JSON file not found: {}", path.display());
94 }
95 }
96 let mut raw_data = Vec::new();
97 for path in json_paths {
98 let content = fs::read_to_string(path)
99 .with_context(|| format!("Failed to read {}", path.display()))?;
100 let parsed: Value = serde_json::from_str(&content)
101 .with_context(|| format!("Failed to parse JSON in {}", path.display()))?;
102 match parsed {
103 Value::Array(arr) => raw_data.extend(arr),
104 Value::Object(obj) => raw_data.push(Value::Object(obj)),
105 _ => {
106 bail!(
107 "Root object in {} must be an array or object, got {}", path
108 .display(), parsed
109 )
110 }
111 }
112 }
113 if raw_data.is_empty() {
114 bail!("All input files are empty");
115 }
116 let format = match data_format {
117 DataFormat::Auto => Self::detect_format(&raw_data[0])?,
118 _ => data_format,
119 };
120 let schema = if let Some(schema_path) = schema_path {
121 let schema_content = fs::read_to_string(schema_path)
122 .with_context(|| {
123 format!("Failed to read schema {}", schema_path.display())
124 })?;
125 Some(
126 serde_json::from_str(&schema_content)
127 .with_context(|| format!("Failed to parse schema JSON"))?,
128 )
129 } else {
130 Self::get_builtin_schema(&format)
131 };
132 if let Some(ref schema) = schema {
133 Self::validate_data(&raw_data, schema, &format)?;
134 }
135 Ok(GenericJSONDataset {
136 data: raw_data,
137 format,
138 schema,
139 })
140 }
141 fn detect_format(first_sample: &Value) -> Result<DataFormat> {
142 if let Some(obj) = first_sample.as_object() {
143 if obj.contains_key("module_name") {
144 Ok(DataFormat::Molds)
145 } else if obj.contains_key("x") && obj.contains_key("y") {
146 Ok(DataFormat::Legacy)
147 } else {
148 Ok(DataFormat::Custom)
149 }
150 } else {
151 bail!("First sample is not an object - cannot auto-detect format");
152 }
153 }
154 fn get_builtin_schema(format: &DataFormat) -> Option<Value> {
155 match format {
156 DataFormat::Legacy => {
157 Some(
158 json!(
159 { "type" : "array", "items" : { "type" : "object", "required" :
160 ["x", "y"] } }
161 ),
162 )
163 }
164 DataFormat::Molds => {
165 Some(
166 json!(
167 { "type" : "array", "items" : { "type" : "object", "required" :
168 ["module_name", "file_name", "implementation", "documentation"] }
169 }
170 ),
171 )
172 }
173 _ => None,
174 }
175 }
176 fn validate_data(
177 data: &[Value],
178 _schema: &Value,
179 format: &DataFormat,
180 ) -> Result<()> {
181 let required_keys = match format {
182 DataFormat::Legacy => vec!["x", "y"],
183 DataFormat::Molds => {
184 vec!["module_name", "file_name", "implementation", "documentation"]
185 }
186 _ => return Ok(()),
187 };
188 for (i, sample) in data.iter().enumerate() {
189 if let Some(obj) = sample.as_object() {
190 for key in &required_keys {
191 if !obj.contains_key(*key) {
192 bail!(
193 "Sample {} is missing required key '{}' for {} format", i,
194 key, format!("{:?}", format) .to_lowercase()
195 );
196 }
197 }
198 } else {
199 bail!("Sample {} is not an object", i);
200 }
201 }
202 Ok(())
203 }
204 pub fn len(&self) -> usize {
205 self.data.len()
206 }
207 pub fn is_empty(&self) -> bool {
208 self.data.is_empty()
209 }
210 pub fn get_random_sample(&self) -> Option<&Value> {
211 if self.data.is_empty() {
212 None
213 } else {
214 use rand::Rng;
215 let mut rng = rand::thread_rng();
216 let idx = rng.gen_range(0..self.data.len());
217 Some(&self.data[idx])
218 }
219 }
220 pub fn stats(&self) -> HashMap<String, Value> {
221 let mut stats = HashMap::new();
222 stats.insert("num_samples".to_string(), json!(self.len()));
223 stats
224 .insert(
225 "format".to_string(),
226 json!(format!("{:?}", self.format) .to_lowercase()),
227 );
228 stats.insert("has_schema".to_string(), json!(self.schema.is_some()));
229 if !self.data.is_empty() {
230 if let Some(obj) = self.data[0].as_object() {
231 stats
232 .insert(
233 "sample_keys".to_string(),
234 json!(obj.keys().collect::< Vec < _ >> ()),
235 );
236 }
237 }
238 stats
239 }
240 pub fn detect_training_format(&self) -> Result<TrainingFormat> {
241 if self.data.is_empty() {
242 bail!("Cannot detect training format from empty dataset");
243 }
244 let first_sample = &self.data[0];
245 let fields = if let Some(obj) = first_sample.as_object() {
246 obj.keys().map(|s| s.as_str()).collect::<Vec<_>>()
247 } else {
248 bail!("First sample is not an object");
249 };
250 if fields.contains(&"chosen") && fields.contains(&"rejected") {
251 return Ok(TrainingFormat::Preference {
252 chosen_field: "chosen".to_string(),
253 rejected_field: "rejected".to_string(),
254 });
255 }
256 if fields.contains(&"completion") {
257 let label_field = if fields.contains(&"label") {
258 Some("label".to_string())
259 } else {
260 None
261 };
262 return Ok(TrainingFormat::Completion {
263 completion_field: "completion".to_string(),
264 label_field,
265 });
266 }
267 if fields.contains(&"instruction") && fields.contains(&"output") {
268 return Ok(TrainingFormat::Instruction {
269 instruction_field: "instruction".to_string(),
270 output_field: "output".to_string(),
271 });
272 }
273 if fields.contains(&"messages") {
274 return Ok(TrainingFormat::Chat {
275 messages_field: "messages".to_string(),
276 });
277 }
278 Ok(TrainingFormat::Custom {
279 fields: fields.into_iter().map(|s| s.to_string()).collect(),
280 })
281 }
282 pub fn to_training_dataset(&self) -> Result<TrainingDataset> {
283 let training_format = self.detect_training_format()?;
284 let mut samples = Vec::new();
285 for (i, sample) in self.data.iter().enumerate() {
286 if let Some(obj) = sample.as_object() {
287 let training_sample = self
288 .convert_sample_to_training(obj, &training_format)
289 .with_context(|| format!("Failed to convert sample {}", i))?;
290 samples.push(training_sample);
291 } else {
292 bail!("Sample {} is not an object", i);
293 }
294 }
295 let statistics = self.compute_statistics(&samples)?;
296 Ok(TrainingDataset {
297 samples,
298 format: training_format,
299 statistics,
300 })
301 }
    /// Maps one raw JSON object onto the normalized `TrainingSample` shape
    /// according to `format`; keys outside the reserved set are copied into
    /// `meta` at the end.
    fn convert_sample_to_training(
        &self,
        obj: &serde_json::Map<String, Value>,
        format: &TrainingFormat,
    ) -> Result<TrainingSample> {
        let mut sample = TrainingSample {
            prompt: None,
            chosen: None,
            rejected: None,
            completion: None,
            label: None,
            meta: HashMap::new(),
        };
        // A top-level string "prompt" is honored for every format. Note the
        // Instruction arm below overwrites this with the instruction text
        // when both are present.
        if let Some(prompt_val) = obj.get("prompt") {
            if let Some(prompt_str) = prompt_val.as_str() {
                sample.prompt = Some(prompt_str.to_string());
            }
        }
        match format {
            TrainingFormat::Preference { chosen_field, rejected_field } => {
                // Values are copied only when they are strings; non-string
                // chosen/rejected values are silently ignored.
                if let Some(chosen_val) = obj.get(chosen_field) {
                    if let Some(chosen_str) = chosen_val.as_str() {
                        sample.chosen = Some(chosen_str.to_string());
                    }
                }
                if let Some(rejected_val) = obj.get(rejected_field) {
                    if let Some(rejected_str) = rejected_val.as_str() {
                        sample.rejected = Some(rejected_str.to_string());
                    }
                }
            }
            TrainingFormat::Completion { completion_field, label_field } => {
                if let Some(completion_val) = obj.get(completion_field) {
                    if let Some(completion_str) = completion_val.as_str() {
                        sample.completion = Some(completion_str.to_string());
                    }
                }
                // Labels may be numeric (taken as-is, truncated to f32) or
                // boolean (mapped to 1.0 / 0.0); other types are ignored.
                if let Some(label_field) = label_field {
                    if let Some(label_val) = obj.get(label_field) {
                        if let Some(label_num) = label_val.as_f64() {
                            sample.label = Some(label_num as f32);
                        } else if let Some(label_bool) = label_val.as_bool() {
                            sample.label = Some(if label_bool { 1.0 } else { 0.0 });
                        }
                    }
                }
            }
            TrainingFormat::Instruction { instruction_field, output_field } => {
                // Instruction/output are remapped onto prompt/completion.
                if let Some(instruction_val) = obj.get(instruction_field) {
                    if let Some(instruction_str) = instruction_val.as_str() {
                        sample.prompt = Some(instruction_str.to_string());
                    }
                }
                if let Some(output_val) = obj.get(output_field) {
                    if let Some(output_str) = output_val.as_str() {
                        sample.completion = Some(output_str.to_string());
                    }
                }
            }
            TrainingFormat::Chat { messages_field } => {
                // Full transcript is preserved in meta; prompt/completion are
                // taken from the first/last message's content. This assumes a
                // user-first, assistant-last transcript — TODO confirm this
                // holds for multi-turn data from all producers.
                if let Some(messages_val) = obj.get(messages_field) {
                    sample.meta.insert("messages".to_string(), messages_val.clone());
                    if let Some(messages) = messages_val.as_array() {
                        if let Some(first_msg) = messages.first() {
                            if let Some(content) = first_msg
                                .get("content")
                                .and_then(|c| c.as_str())
                            {
                                sample.prompt = Some(content.to_string());
                            }
                        }
                        if let Some(last_msg) = messages.last() {
                            if let Some(content) = last_msg
                                .get("content")
                                .and_then(|c| c.as_str())
                            {
                                sample.completion = Some(content.to_string());
                            }
                        }
                    }
                }
            }
            TrainingFormat::Custom { fields } => {
                // Known fields go into meta; the loop below re-inserts any
                // non-reserved key with the identical value (harmless overwrite).
                for field in fields {
                    if let Some(value) = obj.get(field) {
                        sample.meta.insert(field.clone(), value.clone());
                    }
                }
            }
        }
        // Copy every non-reserved key into meta so no information is lost.
        for (key, value) in obj {
            if !matches!(
                key.as_str(), "prompt" | "chosen" | "rejected" | "completion" | "label" |
                "instruction" | "output" | "messages"
            ) {
                sample.meta.insert(key.clone(), value.clone());
            }
        }
        Ok(sample)
    }
402 fn compute_statistics(&self, samples: &[TrainingSample]) -> Result<DatasetStats> {
403 let total_samples = samples.len();
404 let mut total_prompt_length = 0;
405 let mut total_completion_length = 0;
406 let mut prompt_count = 0;
407 let mut completion_count = 0;
408 let mut field_coverage = HashMap::new();
409 for sample in samples {
410 if sample.prompt.is_some() {
411 *field_coverage.entry("prompt".to_string()).or_insert(0.0) += 1.0;
412 total_prompt_length += sample.prompt.as_ref().unwrap().len();
413 prompt_count += 1;
414 }
415 if sample.chosen.is_some() {
416 *field_coverage.entry("chosen".to_string()).or_insert(0.0) += 1.0;
417 }
418 if sample.rejected.is_some() {
419 *field_coverage.entry("rejected".to_string()).or_insert(0.0) += 1.0;
420 }
421 if sample.completion.is_some() {
422 *field_coverage.entry("completion".to_string()).or_insert(0.0) += 1.0;
423 total_completion_length += sample.completion.as_ref().unwrap().len();
424 completion_count += 1;
425 }
426 if sample.label.is_some() {
427 *field_coverage.entry("label".to_string()).or_insert(0.0) += 1.0;
428 }
429 }
430 for count in field_coverage.values_mut() {
431 *count = *count / total_samples as f64;
432 }
433 let avg_prompt_length = if prompt_count > 0 {
434 total_prompt_length as f64 / prompt_count as f64
435 } else {
436 0.0
437 };
438 let avg_completion_length = if completion_count > 0 {
439 total_completion_length as f64 / completion_count as f64
440 } else {
441 0.0
442 };
443 let quality_score = Some(
444 (field_coverage.get("prompt").unwrap_or(&0.0)
445 + field_coverage.get("chosen").unwrap_or(&0.0)
446 + field_coverage.get("rejected").unwrap_or(&0.0)
447 + field_coverage.get("completion").unwrap_or(&0.0)) / 4.0,
448 );
449 Ok(DatasetStats {
450 total_samples,
451 avg_prompt_length,
452 avg_completion_length,
453 field_coverage,
454 quality_score,
455 })
456 }
457}
458impl TrainingDataset {
459 pub fn to_algorithm_format(
460 &self,
461 algorithm: &str,
462 ) -> Result<Box<dyn std::any::Any>> {
463 match algorithm.to_lowercase().as_str() {
464 "bco" => {
465 let bco_data = self.to_bco_format()?;
466 Ok(Box::new(bco_data))
467 }
468 "dpo" => {
469 let dpo_data = self.to_dpo_format()?;
470 Ok(Box::new(dpo_data))
471 }
472 "ppo" => {
473 let ppo_data = self.to_ppo_format()?;
474 Ok(Box::new(ppo_data))
475 }
476 "sft" => {
477 let sft_data = self.to_sft_format()?;
478 Ok(Box::new(sft_data))
479 }
480 _ => bail!("Unsupported algorithm: {}", algorithm),
481 }
482 }
    /// BCO conversion: produces (prompt, completion, bool label) triples.
    ///
    /// * Preference data: ONE of chosen/rejected is sampled per pair with a
    ///   fair coin (`rand::random`), labeled true for chosen — so the output
    ///   is nondeterministic across runs.
    /// * Completion data: numeric label > 0.5 maps to true; a missing label
    ///   defaults to true (positive).
    /// * Any other format: same as the completion handling, applied whenever
    ///   a completion exists.
    fn to_bco_format(&self) -> Result<BCODataset> {
        let mut bco_samples = Vec::new();
        for sample in &self.samples {
            match &self.format {
                TrainingFormat::Preference { .. } => {
                    // Pairs missing either side are skipped entirely.
                    if let (Some(chosen), Some(rejected)) = (
                        &sample.chosen,
                        &sample.rejected,
                    ) {
                        // Coin flip picks which side becomes the completion;
                        // the label records whether it was the chosen one.
                        let (completion, label) = if rand::random::<bool>() {
                            (chosen.clone(), true)
                        } else {
                            (rejected.clone(), false)
                        };
                        bco_samples
                            .push(BCOSample {
                                prompt: sample.prompt.clone().unwrap_or_default(),
                                completion,
                                label,
                            });
                    }
                }
                TrainingFormat::Completion { .. } => {
                    if let Some(completion) = &sample.completion {
                        let label = sample.label.map(|l| l > 0.5).unwrap_or(true);
                        bco_samples
                            .push(BCOSample {
                                prompt: sample.prompt.clone().unwrap_or_default(),
                                completion: completion.clone(),
                                label,
                            });
                    }
                }
                _ => {
                    if let Some(completion) = &sample.completion {
                        bco_samples
                            .push(BCOSample {
                                prompt: sample.prompt.clone().unwrap_or_default(),
                                completion: completion.clone(),
                                label: sample.label.map(|l| l > 0.5).unwrap_or(true),
                            });
                    }
                }
            }
        }
        Ok(BCODataset { samples: bco_samples })
    }
530 fn to_dpo_format(&self) -> Result<DPODataset> {
531 let mut dpo_samples = Vec::new();
532 for sample in &self.samples {
533 if let TrainingFormat::Preference { .. } = &self.format {
534 if let (Some(chosen), Some(rejected)) = (
535 &sample.chosen,
536 &sample.rejected,
537 ) {
538 dpo_samples
539 .push(DPOSample {
540 prompt: sample.prompt.clone().unwrap_or_default(),
541 chosen: chosen.clone(),
542 rejected: rejected.clone(),
543 });
544 }
545 } else {
546 bail!(
547 "DPO format requires preference-style data (chosen/rejected fields)"
548 );
549 }
550 }
551 Ok(DPODataset { samples: dpo_samples })
552 }
553 fn to_ppo_format(&self) -> Result<PPODataset> {
554 let mut ppo_samples = Vec::new();
555 for sample in &self.samples {
556 if let Some(completion) = &sample.completion {
557 ppo_samples
558 .push(PPOSample {
559 prompt: sample.prompt.clone().unwrap_or_default(),
560 completion: completion.clone(),
561 reward: sample.label.unwrap_or(0.0),
562 });
563 }
564 }
565 Ok(PPODataset { samples: ppo_samples })
566 }
567 fn to_sft_format(&self) -> Result<SFTDataset> {
568 let mut sft_samples = Vec::new();
569 for sample in &self.samples {
570 if let Some(completion) = &sample.completion {
571 sft_samples
572 .push(SFTSample {
573 prompt: sample.prompt.clone().unwrap_or_default(),
574 completion: completion.clone(),
575 });
576 }
577 }
578 Ok(SFTDataset { samples: sft_samples })
579 }
580 pub fn quality_assessment(&self) -> DatasetQualityReport {
581 let mut report = DatasetQualityReport {
582 overall_score: 0.0,
583 issues: Vec::new(),
584 recommendations: Vec::new(),
585 };
586 let prompt_coverage = self
587 .statistics
588 .field_coverage
589 .get("prompt")
590 .unwrap_or(&0.0);
591 let completion_coverage = self
592 .statistics
593 .field_coverage
594 .get("completion")
595 .unwrap_or(&0.0);
596 let chosen_coverage = self
597 .statistics
598 .field_coverage
599 .get("chosen")
600 .unwrap_or(&0.0);
601 let rejected_coverage = self
602 .statistics
603 .field_coverage
604 .get("rejected")
605 .unwrap_or(&0.0);
606 match &self.format {
607 TrainingFormat::Preference { .. } => {
608 if *chosen_coverage < 0.9 {
609 report
610 .issues
611 .push(
612 format!(
613 "Low chosen field coverage: {:.1}%", chosen_coverage * 100.0
614 ),
615 );
616 }
617 if *rejected_coverage < 0.9 {
618 report
619 .issues
620 .push(
621 format!(
622 "Low rejected field coverage: {:.1}%", rejected_coverage *
623 100.0
624 ),
625 );
626 }
627 report.overall_score = (chosen_coverage + rejected_coverage
628 + prompt_coverage) / 3.0;
629 }
630 TrainingFormat::Completion { .. } => {
631 if *completion_coverage < 0.9 {
632 report
633 .issues
634 .push(
635 format!(
636 "Low completion field coverage: {:.1}%", completion_coverage
637 * 100.0
638 ),
639 );
640 }
641 report.overall_score = (completion_coverage + prompt_coverage) / 2.0;
642 }
643 _ => {
644 report.overall_score = self.statistics.quality_score.unwrap_or(0.0);
645 }
646 }
647 if self.statistics.avg_prompt_length < 10.0 {
648 report.issues.push("Very short average prompt length".to_string());
649 }
650 if self.statistics.avg_completion_length < 10.0 {
651 report.issues.push("Very short average completion length".to_string());
652 }
653 if report.issues.is_empty() {
654 report.recommendations.push("Dataset quality looks good!".to_string());
655 } else {
656 report
657 .recommendations
658 .push(
659 "Consider filtering or augmenting low-quality samples".to_string(),
660 );
661 }
662 report
663 }
664}
/// One BCO training example: a binary-labeled completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BCOSample {
    pub prompt: String,
    pub completion: String,
    // true when the completion is the "chosen"/positive response.
    pub label: bool,
}
/// Dataset consumed by the BCO algorithm.
#[derive(Debug, Clone)]
pub struct BCODataset {
    pub samples: Vec<BCOSample>,
}
/// One DPO training example: a preference pair sharing a prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DPOSample {
    pub prompt: String,
    pub chosen: String,
    pub rejected: String,
}
/// Dataset consumed by the DPO algorithm.
#[derive(Debug, Clone)]
pub struct DPODataset {
    pub samples: Vec<DPOSample>,
}
/// One PPO training example with a scalar reward.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PPOSample {
    pub prompt: String,
    pub completion: String,
    pub reward: f32,
}
/// Dataset consumed by the PPO algorithm.
#[derive(Debug, Clone)]
pub struct PPODataset {
    pub samples: Vec<PPOSample>,
}
/// One supervised fine-tuning example.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SFTSample {
    pub prompt: String,
    pub completion: String,
}
/// Dataset consumed by supervised fine-tuning.
#[derive(Debug, Clone)]
pub struct SFTDataset {
    pub samples: Vec<SFTSample>,
}
/// Result of `TrainingDataset::quality_assessment`.
#[derive(Debug, Clone)]
pub struct DatasetQualityReport {
    pub overall_score: f64,
    pub issues: Vec<String>,
    pub recommendations: Vec<String>,
}
/// Normalizes a text blob: strips trailing whitespace from every line and
/// trims surrounding whitespace from the whole result. Interior blank lines
/// and leading indentation within lines are preserved.
fn clean_whitespace(text: &str) -> String {
    let stripped: Vec<&str> = text.lines().map(str::trim_end).collect();
    stripped.join("\n").trim().to_string()
}
/// Cleans one JSON file and writes the pretty-printed result to `dst_path`.
///
/// The file is parsed twice: once here for cleaning, and once inside
/// `GenericJSONDataset::new`, which is invoked purely for format/schema
/// validation (the resulting dataset is discarded). For an array root,
/// trailing whitespace is stripped from every *top-level* string value of
/// each object element — nested strings are left untouched; any other root
/// is written back unchanged apart from pretty-printing.
fn process_file(
    src_path: &Path,
    dst_path: &Path,
    schema_path: Option<&Path>,
    format_override: &DataFormat,
) -> Result<()> {
    let content = fs::read_to_string(src_path)
        .with_context(|| format!("Failed to read {}", src_path.display()))?;
    let raw: Value = serde_json::from_str(&content)
        .with_context(|| format!("Failed to parse JSON in {}", src_path.display()))?;
    // Validation only; an error here aborts before anything is written.
    let _temp_dataset = GenericJSONDataset::new(
        &[src_path.to_path_buf()],
        schema_path,
        format_override.clone(),
    )?;
    let cleaned = if let Value::Array(arr) = raw {
        let cleaned_arr: Vec<Value> = arr
            .into_iter()
            .map(|mut entry| {
                if let Value::Object(ref mut obj) = entry {
                    for (_key, value) in obj.iter_mut() {
                        if let Value::String(ref mut s) = value {
                            *s = clean_whitespace(s);
                        }
                    }
                }
                entry
            })
            .collect();
        Value::Array(cleaned_arr)
    } else {
        raw
    };
    // Create the destination directory only when dst_path has a parent
    // component (map + transpose turns Option<Result<_>> into Result<Option<_>>).
    dst_path
        .parent()
        .map(|p| fs::create_dir_all(p))
        .transpose()
        .with_context(|| {
            format!("Failed to create directory for {}", dst_path.display())
        })?;
    let cleaned_json = serde_json::to_string_pretty(&cleaned)
        .with_context(|| "Failed to serialize cleaned JSON")?;
    fs::write(dst_path, cleaned_json)
        .with_context(|| format!("Failed to write to {}", dst_path.display()))?;
    Ok(())
}
764pub async fn run_multi_process_clean(
765 src_files: Vec<PathBuf>,
766 dst_root: &Path,
767 schema_dir: Option<&Path>,
768 format_override: &DataFormat,
769 _jobs: usize,
770) -> Result<()> {
771 if src_files.is_empty() {
772 bail!("No source files provided");
773 }
774 let tasks: Vec<_> = src_files
775 .iter()
776 .map(|src| {
777 let dst = dst_root.join(src.file_name().unwrap());
778 let schema_path = schema_dir
779 .and_then(|dir| {
780 let candidate = dir
781 .join(format!("{}.schema.json", src.file_stem() ?.to_str() ?));
782 if candidate.is_file() { Some(candidate) } else { None }
783 });
784 (src.clone(), dst, schema_path)
785 })
786 .collect();
787 let pb = ProgressBar::new(tasks.len() as u64);
788 pb.set_style(
789 ProgressStyle::default_bar()
790 .template(
791 "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta}) {msg}",
792 )
793 .unwrap()
794 .progress_chars("#>-"),
795 );
796 pb.set_message("Cleaning & validating");
797 let pb = Arc::new(Mutex::new(pb));
798 let results: Vec<Result<()>> = tasks
799 .iter()
800 .map(|(src, dst, schema_path)| {
801 let result = process_file(src, dst, schema_path.as_deref(), format_override);
802 if let Ok(pb) = pb.try_lock() {
803 pb.inc(1);
804 }
805 result
806 })
807 .collect();
808 if let Ok(pb) = pb.try_lock() {
809 pb.finish_with_message("Complete");
810 }
811 let errors: Vec<_> = results
812 .into_iter()
813 .enumerate()
814 .filter_map(|(i, r)| r.err().map(|e| (i, e)))
815 .collect();
816 if !errors.is_empty() {
817 for (i, e) in &errors {
818 eprintln!("Error processing file {}: {}", i, e);
819 }
820 bail!("{} files failed processing", errors.len());
821 }
822 Ok(())
823}
824pub fn gather_json_paths(
825 files: &[String],
826 data_dirs: &[String],
827) -> Result<Vec<PathBuf>> {
828 let mut paths = Vec::new();
829 for file in files {
830 let path = PathBuf::from(file);
831 if !path.is_file() {
832 bail!("Specified file does not exist: {}", path.display());
833 }
834 paths.push(path);
835 }
836 for dir in data_dirs {
837 let dir_path = PathBuf::from(dir);
838 if !dir_path.is_dir() {
839 bail!("Data directory does not exist: {}", dir_path.display());
840 }
841 for entry in WalkDir::new(&dir_path).into_iter().filter_map(|e| e.ok()) {
842 if entry.path().extension().map_or(false, |ext| ext == "json") {
843 paths.push(entry.path().to_path_buf());
844 }
845 }
846 }
847 if paths.is_empty() {
848 bail!("No JSON files found in specified paths");
849 }
850 paths.sort();
851 paths.dedup();
852 Ok(paths)
853}
/// Looks for `<stem>.schema.json` inside `schema_dir` for the given JSON
/// file; returns the candidate path only when that file actually exists.
pub fn find_schema_for_file(
    json_path: &Path,
    schema_dir: Option<&Path>,
) -> Option<PathBuf> {
    let dir = schema_dir?;
    let stem = json_path.file_stem()?.to_str()?;
    let candidate = dir.join(format!("{}.schema.json", stem));
    candidate.is_file().then_some(candidate)
}
865pub async fn run_json_cmd(args: JsonArgs) -> Result<()> {
866 if args.multi_process {
867 if args.input_folder.is_none() || args.output.is_none() {
868 bail!("Multi-process mode requires both --input-folder and --output");
869 }
870 let src_root = args.input_folder.unwrap();
871 let dst_root = args.output.unwrap();
872 if !src_root.is_dir() {
873 bail!("Input folder does not exist: {}", src_root.display());
874 }
875 let src_files: Vec<_> = WalkDir::new(&src_root)
876 .into_iter()
877 .filter_map(|e| e.ok())
878 .filter(|e| e.path().extension().map_or(false, |ext| ext == "json"))
879 .map(|e| e.path().to_path_buf())
880 .collect();
881 if src_files.is_empty() {
882 bail!("No JSON files found in {}", src_root.display());
883 }
884 let schema_dir = args.schema_dir.as_ref();
885 let format_override = args.format.clone();
886 println!(
887 "🔧 Starting multi-process clean-validate:\n source: {}\n destination: {}\n workers: {}\n format: {:?}",
888 src_root.display(), dst_root.display(), args.jobs, format_override
889 );
890 run_multi_process_clean(
891 src_files,
892 &dst_root,
893 schema_dir.map(|p| p.as_path()),
894 &format_override,
895 args.jobs,
896 )
897 .await?;
898 println!("✅ All files cleaned and validated successfully.");
899 return Ok(());
900 }
901 let json_paths = gather_json_paths(&args.file, &args.data_dir)?;
902 let mut all_samples = Vec::new();
903 for path in &json_paths {
904 let schema_path = find_schema_for_file(
905 path,
906 args.schema_dir.as_ref().map(|p| p.as_path()),
907 );
908 println!(
909 "Loading {} (schema: {})", path.file_name().unwrap().to_string_lossy(),
910 schema_path.as_ref().map(| p | p.file_name().unwrap().to_string_lossy())
911 .unwrap_or(std::borrow::Cow::Borrowed("built-in"))
912 );
913 let dataset = GenericJSONDataset::new(
914 &[path.clone()],
915 schema_path.as_ref().map(|p| p.as_path()),
916 args.format.clone(),
917 )?;
918 all_samples.extend(dataset.data);
919 }
920 if let Some(merge_output) = &args.merge_output {
921 merge_output
922 .parent()
923 .map(|p| fs::create_dir_all(p))
924 .transpose()
925 .with_context(|| {
926 format!("Failed to create directory for {}", merge_output.display())
927 })?;
928 let merged_json = serde_json::to_string_pretty(&all_samples)
929 .with_context(|| "Failed to serialize merged JSON")?;
930 fs::write(merge_output, merged_json)
931 .with_context(|| {
932 format!("Failed to write merged output to {}", merge_output.display())
933 })?;
934 println!(
935 "✅ Merged {} samples into {}", all_samples.len(), merge_output.display()
936 );
937 }
938 if args.show_stats {
939 if !json_paths.is_empty() {
940 let temp_dataset = GenericJSONDataset::new(
941 &json_paths,
942 args.schema_dir.as_ref().map(|p| p.as_path()),
943 args.format,
944 )?;
945 println!(
946 "\n--- Dataset statistics ------------------------------------------------"
947 );
948 for (k, v) in temp_dataset.stats() {
949 println!("{:20}: {}", k, v);
950 }
951 println!(
952 "--------------------------------------------------------------------"
953 );
954 }
955 }
956 if args.merge_output.is_none() {
957 println!("🎉 Validation finished – no merged output requested.");
958 }
959 Ok(())
960}
// CLI arguments for the JSON subcommand.
//
// NOTE(review): comments here are deliberately `//` rather than `///` —
// clap turns doc comments on Args fields into --help text, which would be
// a behavior change.
#[derive(Args, Debug, Clone)]
pub struct JsonArgs {
    // Directories scanned recursively for *.json files.
    #[arg(long)]
    pub data_dir: Vec<String>,
    // Individual JSON files to load.
    #[arg(long, short = 'f')]
    pub file: Vec<String>,
    // Directory holding `<stem>.schema.json` files for per-file validation.
    #[arg(long)]
    pub schema_dir: Option<PathBuf>,
    // Input format, parsed via DataFormat::from_str ("auto" sniffs content).
    #[arg(long, default_value = "auto")]
    pub format: DataFormat,
    // When set, all loaded samples are merged and written to this path.
    #[arg(long)]
    pub merge_output: Option<PathBuf>,
    // Print a dataset statistics summary after loading.
    #[arg(long)]
    pub show_stats: bool,
    // RNG seed; accepted but not read anywhere in this file (as visible).
    #[arg(long, default_value = "42")]
    pub seed: u64,
    // Clean-validate mode over --input-folder / --output.
    #[arg(long)]
    pub multi_process: bool,
    // Source folder for multi-process mode.
    #[arg(long)]
    pub input_folder: Option<PathBuf>,
    // Destination folder for multi-process mode.
    #[arg(long, short = 'o')]
    pub output: Option<PathBuf>,
    // Worker count (reserved; processing is currently sequential).
    #[arg(long, short = 'j', default_value_t = num_cpus::get())]
    pub jobs: usize,
}