1use anyhow::{bail, Context, Result};
2use clap::Args;
3use indicatif::{ProgressBar, ProgressStyle};
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6use std::collections::HashMap;
7use std::fs;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use walkdir::WalkDir;
/// Input JSON layout accepted by the loader.
///
/// `Auto` asks `GenericJSONDataset::detect_format` to sniff the layout from
/// the first sample; the other variants force a specific layout.
#[derive(Debug, Clone, PartialEq)]
pub enum DataFormat {
    Auto,
    Legacy,
    Molds,
    Custom,
}
19
/// Training-objective layout inferred from a dataset's field names
/// (see `GenericJSONDataset::detect_training_format`).
#[derive(Debug, Clone, PartialEq)]
pub enum TrainingFormat {
    /// Pairwise preference data: chosen vs. rejected responses.
    Preference { chosen_field: String, rejected_field: String },
    /// Single completion, optionally with a numeric/boolean label.
    Completion { completion_field: String, label_field: Option<String> },
    /// Instruction/output pairs (SFT-style).
    Instruction { instruction_field: String, output_field: String },
    /// Chat transcripts stored under a messages array field.
    Chat { messages_field: String },
    /// Unrecognized layout; all field names are preserved verbatim.
    Custom { fields: Vec<String> },
}
28impl std::str::FromStr for DataFormat {
29 type Err = anyhow::Error;
30 fn from_str(s: &str) -> Result<Self> {
31 match s.to_lowercase().as_str() {
32 "auto" => Ok(DataFormat::Auto),
33 "legacy" => Ok(DataFormat::Legacy),
34 "molds" => Ok(DataFormat::Molds),
35 "custom" => Ok(DataFormat::Custom),
36 _ => bail!("Invalid format: {}. Must be auto, legacy, molds, or custom", s),
37 }
38 }
39}
/// Legacy numeric sample: raw feature vector `x` and target vector `y`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LegacySample {
    pub x: Vec<f32>,
    pub y: Vec<f32>,
}
/// "Molds" sample: a source module paired with its documentation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MoldsSample {
    pub module_name: String,
    pub file_name: String,
    pub implementation: String,
    pub documentation: String,
    // Optional; serialized under the explicit key "system_context".
    #[serde(rename = "system_context")]
    pub system_context: Option<String>,
}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct TrainingSample {
58 pub prompt: Option<String>,
59 pub chosen: Option<String>, pub rejected: Option<String>, pub completion: Option<String>, pub label: Option<f32>, pub meta: HashMap<String, Value>, }
65
/// A converted dataset: normalized samples plus the detected training
/// format and summary statistics.
#[derive(Debug, Clone)]
pub struct TrainingDataset {
    pub samples: Vec<TrainingSample>,
    pub format: TrainingFormat,
    pub statistics: DatasetStats,
}
73
/// Aggregate statistics computed over a `TrainingDataset`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStats {
    pub total_samples: usize,
    // Average lengths are in characters, over samples that have the field.
    pub avg_prompt_length: f64,
    pub avg_completion_length: f64,
    // Fraction (0.0-1.0) of samples carrying each core field.
    pub field_coverage: HashMap<String, f64>,
    pub quality_score: Option<f64>,
}
/// Raw JSON samples loaded from disk, before conversion to training form.
#[derive(Debug)]
pub struct GenericJSONDataset {
    pub data: Vec<Value>,
    pub format: DataFormat,
    // Optional JSON schema (user-supplied or built-in) used for validation.
    pub schema: Option<Value>,
}
89impl GenericJSONDataset {
90 pub fn new(
91 json_paths: &[PathBuf],
92 schema_path: Option<&Path>,
93 data_format: DataFormat,
94 ) -> Result<Self> {
95 if json_paths.is_empty() {
96 bail!("No JSON files provided");
97 }
98 for path in json_paths {
99 if !path.is_file() {
100 bail!("JSON file not found: {}", path.display());
101 }
102 }
103 let mut raw_data = Vec::new();
104 for path in json_paths {
105 let content = fs::read_to_string(path)
106 .with_context(|| format!("Failed to read {}", path.display()))?;
107 let parsed: Value = serde_json::from_str(&content)
108 .with_context(|| format!("Failed to parse JSON in {}", path.display()))?;
109 match parsed {
110 Value::Array(arr) => raw_data.extend(arr),
111 Value::Object(obj) => raw_data.push(Value::Object(obj)),
112 _ => {
113 bail!(
114 "Root object in {} must be an array or object, got {}", path
115 .display(), parsed
116 )
117 }
118 }
119 }
120 if raw_data.is_empty() {
121 bail!("All input files are empty");
122 }
123 let format = match data_format {
124 DataFormat::Auto => Self::detect_format(&raw_data[0])?,
125 _ => data_format,
126 };
127 let schema = if let Some(schema_path) = schema_path {
128 let schema_content = fs::read_to_string(schema_path)
129 .with_context(|| {
130 format!("Failed to read schema {}", schema_path.display())
131 })?;
132 Some(
133 serde_json::from_str(&schema_content)
134 .with_context(|| format!("Failed to parse schema JSON"))?,
135 )
136 } else {
137 Self::get_builtin_schema(&format)
138 };
139 if let Some(ref schema) = schema {
140 Self::validate_data(&raw_data, schema, &format)?;
141 }
142 Ok(GenericJSONDataset {
143 data: raw_data,
144 format,
145 schema,
146 })
147 }
148 fn detect_format(first_sample: &Value) -> Result<DataFormat> {
149 if let Some(obj) = first_sample.as_object() {
150 if obj.contains_key("module_name") {
151 Ok(DataFormat::Molds)
152 } else if obj.contains_key("x") && obj.contains_key("y") {
153 Ok(DataFormat::Legacy)
154 } else {
155 Ok(DataFormat::Custom)
156 }
157 } else {
158 bail!("First sample is not an object - cannot auto-detect format");
159 }
160 }
161 fn get_builtin_schema(format: &DataFormat) -> Option<Value> {
162 match format {
163 DataFormat::Legacy => {
164 Some(
165 json!(
166 { "type" : "array", "items" : { "type" : "object", "required" :
167 ["x", "y"] } }
168 ),
169 )
170 }
171 DataFormat::Molds => {
172 Some(
173 json!(
174 { "type" : "array", "items" : { "type" : "object", "required" :
175 ["module_name", "file_name", "implementation", "documentation"] }
176 }
177 ),
178 )
179 }
180 _ => None,
181 }
182 }
183 fn validate_data(
184 data: &[Value],
185 _schema: &Value,
186 format: &DataFormat,
187 ) -> Result<()> {
188 let required_keys = match format {
189 DataFormat::Legacy => vec!["x", "y"],
190 DataFormat::Molds => {
191 vec!["module_name", "file_name", "implementation", "documentation"]
192 }
193 _ => return Ok(()),
194 };
195 for (i, sample) in data.iter().enumerate() {
196 if let Some(obj) = sample.as_object() {
197 for key in &required_keys {
198 if !obj.contains_key(*key) {
199 bail!(
200 "Sample {} is missing required key '{}' for {} format", i,
201 key, format!("{:?}", format) .to_lowercase()
202 );
203 }
204 }
205 } else {
206 bail!("Sample {} is not an object", i);
207 }
208 }
209 Ok(())
210 }
    /// Number of raw samples loaded.
    pub fn len(&self) -> usize {
        self.data.len()
    }
    /// True when no samples were loaded.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
217 pub fn get_random_sample(&self) -> Option<&Value> {
218 if self.data.is_empty() {
219 None
220 } else {
221 use rand::Rng;
222 let mut rng = rand::thread_rng();
223 let idx = rng.gen_range(0..self.data.len());
224 Some(&self.data[idx])
225 }
226 }
227 pub fn stats(&self) -> HashMap<String, Value> {
228 let mut stats = HashMap::new();
229 stats.insert("num_samples".to_string(), json!(self.len()));
230 stats
231 .insert(
232 "format".to_string(),
233 json!(format!("{:?}", self.format) .to_lowercase()),
234 );
235 stats.insert("has_schema".to_string(), json!(self.schema.is_some()));
236 if !self.data.is_empty() {
237 if let Some(obj) = self.data[0].as_object() {
238 stats
239 .insert(
240 "sample_keys".to_string(),
241 json!(obj.keys().collect::< Vec < _ >> ()),
242 );
243 }
244 }
245 stats
246 }
247
248 pub fn detect_training_format(&self) -> Result<TrainingFormat> {
250 if self.data.is_empty() {
251 bail!("Cannot detect training format from empty dataset");
252 }
253
254 let first_sample = &self.data[0];
255 let fields = if let Some(obj) = first_sample.as_object() {
256 obj.keys().map(|s| s.as_str()).collect::<Vec<_>>()
257 } else {
258 bail!("First sample is not an object");
259 };
260
261 if fields.contains(&"chosen") && fields.contains(&"rejected") {
263 return Ok(TrainingFormat::Preference {
264 chosen_field: "chosen".to_string(),
265 rejected_field: "rejected".to_string(),
266 });
267 }
268
269 if fields.contains(&"completion") {
271 let label_field = if fields.contains(&"label") {
272 Some("label".to_string())
273 } else {
274 None
275 };
276 return Ok(TrainingFormat::Completion {
277 completion_field: "completion".to_string(),
278 label_field,
279 });
280 }
281
282 if fields.contains(&"instruction") && fields.contains(&"output") {
284 return Ok(TrainingFormat::Instruction {
285 instruction_field: "instruction".to_string(),
286 output_field: "output".to_string(),
287 });
288 }
289
290 if fields.contains(&"messages") {
292 return Ok(TrainingFormat::Chat {
293 messages_field: "messages".to_string(),
294 });
295 }
296
297 Ok(TrainingFormat::Custom {
299 fields: fields.into_iter().map(|s| s.to_string()).collect(),
300 })
301 }
302
303 pub fn to_training_dataset(&self) -> Result<TrainingDataset> {
305 let training_format = self.detect_training_format()?;
306 let mut samples = Vec::new();
307
308 for (i, sample) in self.data.iter().enumerate() {
309 if let Some(obj) = sample.as_object() {
310 let training_sample = self.convert_sample_to_training(obj, &training_format)
311 .with_context(|| format!("Failed to convert sample {}", i))?;
312 samples.push(training_sample);
313 } else {
314 bail!("Sample {} is not an object", i);
315 }
316 }
317
318 let statistics = self.compute_statistics(&samples)?;
319
320 Ok(TrainingDataset {
321 samples,
322 format: training_format,
323 statistics,
324 })
325 }
326
327 fn convert_sample_to_training(&self, obj: &serde_json::Map<String, Value>, format: &TrainingFormat) -> Result<TrainingSample> {
328 let mut sample = TrainingSample {
329 prompt: None,
330 chosen: None,
331 rejected: None,
332 completion: None,
333 label: None,
334 meta: HashMap::new(),
335 };
336
337 if let Some(prompt_val) = obj.get("prompt") {
339 if let Some(prompt_str) = prompt_val.as_str() {
340 sample.prompt = Some(prompt_str.to_string());
341 }
342 }
343
344 match format {
345 TrainingFormat::Preference { chosen_field, rejected_field } => {
346 if let Some(chosen_val) = obj.get(chosen_field) {
347 if let Some(chosen_str) = chosen_val.as_str() {
348 sample.chosen = Some(chosen_str.to_string());
349 }
350 }
351 if let Some(rejected_val) = obj.get(rejected_field) {
352 if let Some(rejected_str) = rejected_val.as_str() {
353 sample.rejected = Some(rejected_str.to_string());
354 }
355 }
356 }
357 TrainingFormat::Completion { completion_field, label_field } => {
358 if let Some(completion_val) = obj.get(completion_field) {
359 if let Some(completion_str) = completion_val.as_str() {
360 sample.completion = Some(completion_str.to_string());
361 }
362 }
363 if let Some(label_field) = label_field {
364 if let Some(label_val) = obj.get(label_field) {
365 if let Some(label_num) = label_val.as_f64() {
366 sample.label = Some(label_num as f32);
367 } else if let Some(label_bool) = label_val.as_bool() {
368 sample.label = Some(if label_bool { 1.0 } else { 0.0 });
369 }
370 }
371 }
372 }
373 TrainingFormat::Instruction { instruction_field, output_field } => {
374 if let Some(instruction_val) = obj.get(instruction_field) {
375 if let Some(instruction_str) = instruction_val.as_str() {
376 sample.prompt = Some(instruction_str.to_string());
377 }
378 }
379 if let Some(output_val) = obj.get(output_field) {
380 if let Some(output_str) = output_val.as_str() {
381 sample.completion = Some(output_str.to_string());
382 }
383 }
384 }
385 TrainingFormat::Chat { messages_field } => {
386 if let Some(messages_val) = obj.get(messages_field) {
388 sample.meta.insert("messages".to_string(), messages_val.clone());
389 if let Some(messages) = messages_val.as_array() {
391 if let Some(first_msg) = messages.first() {
392 if let Some(content) = first_msg.get("content").and_then(|c| c.as_str()) {
393 sample.prompt = Some(content.to_string());
394 }
395 }
396 if let Some(last_msg) = messages.last() {
398 if let Some(content) = last_msg.get("content").and_then(|c| c.as_str()) {
399 sample.completion = Some(content.to_string());
400 }
401 }
402 }
403 }
404 }
405 TrainingFormat::Custom { fields } => {
406 for field in fields {
408 if let Some(value) = obj.get(field) {
409 sample.meta.insert(field.clone(), value.clone());
410 }
411 }
412 }
413 }
414
415 for (key, value) in obj {
417 if !matches!(key.as_str(),
418 "prompt" | "chosen" | "rejected" | "completion" | "label" |
419 "instruction" | "output" | "messages") {
420 sample.meta.insert(key.clone(), value.clone());
421 }
422 }
423
424 Ok(sample)
425 }
426
427 fn compute_statistics(&self, samples: &[TrainingSample]) -> Result<DatasetStats> {
428 let total_samples = samples.len();
429 let mut total_prompt_length = 0;
430 let mut total_completion_length = 0;
431 let mut prompt_count = 0;
432 let mut completion_count = 0;
433 let mut field_coverage = HashMap::new();
434
435 for sample in samples {
437 if sample.prompt.is_some() {
438 *field_coverage.entry("prompt".to_string()).or_insert(0.0) += 1.0;
439 total_prompt_length += sample.prompt.as_ref().unwrap().len();
440 prompt_count += 1;
441 }
442 if sample.chosen.is_some() {
443 *field_coverage.entry("chosen".to_string()).or_insert(0.0) += 1.0;
444 }
445 if sample.rejected.is_some() {
446 *field_coverage.entry("rejected".to_string()).or_insert(0.0) += 1.0;
447 }
448 if sample.completion.is_some() {
449 *field_coverage.entry("completion".to_string()).or_insert(0.0) += 1.0;
450 total_completion_length += sample.completion.as_ref().unwrap().len();
451 completion_count += 1;
452 }
453 if sample.label.is_some() {
454 *field_coverage.entry("label".to_string()).or_insert(0.0) += 1.0;
455 }
456 }
457
458 for count in field_coverage.values_mut() {
460 *count = *count / total_samples as f64;
461 }
462
463 let avg_prompt_length = if prompt_count > 0 {
464 total_prompt_length as f64 / prompt_count as f64
465 } else {
466 0.0
467 };
468
469 let avg_completion_length = if completion_count > 0 {
470 total_completion_length as f64 / completion_count as f64
471 } else {
472 0.0
473 };
474
475 let quality_score = Some(
477 (field_coverage.get("prompt").unwrap_or(&0.0) +
478 field_coverage.get("chosen").unwrap_or(&0.0) +
479 field_coverage.get("rejected").unwrap_or(&0.0) +
480 field_coverage.get("completion").unwrap_or(&0.0)) / 4.0
481 );
482
483 Ok(DatasetStats {
484 total_samples,
485 avg_prompt_length,
486 avg_completion_length,
487 field_coverage,
488 quality_score,
489 })
490 }
491}
492
493impl TrainingDataset {
494 pub fn to_algorithm_format(&self, algorithm: &str) -> Result<Box<dyn std::any::Any>> {
496 match algorithm.to_lowercase().as_str() {
497 "bco" => {
498 let bco_data = self.to_bco_format()?;
499 Ok(Box::new(bco_data))
500 }
501 "dpo" => {
502 let dpo_data = self.to_dpo_format()?;
503 Ok(Box::new(dpo_data))
504 }
505 "ppo" => {
506 let ppo_data = self.to_ppo_format()?;
507 Ok(Box::new(ppo_data))
508 }
509 "sft" => {
510 let sft_data = self.to_sft_format()?;
511 Ok(Box::new(sft_data))
512 }
513 _ => bail!("Unsupported algorithm: {}", algorithm),
514 }
515 }
516
517 fn to_bco_format(&self) -> Result<BCODataset> {
518 let mut bco_samples = Vec::new();
519
520 for sample in &self.samples {
521 match &self.format {
522 TrainingFormat::Preference { .. } => {
523 if let (Some(chosen), Some(rejected)) = (&sample.chosen, &sample.rejected) {
525 let (completion, label) = if rand::random::<bool>() {
527 (chosen.clone(), true)
528 } else {
529 (rejected.clone(), false)
530 };
531
532 bco_samples.push(BCOSample {
533 prompt: sample.prompt.clone().unwrap_or_default(),
534 completion,
535 label,
536 });
537 }
538 }
539 TrainingFormat::Completion { .. } => {
540 if let Some(completion) = &sample.completion {
542 let label = sample.label.map(|l| l > 0.5).unwrap_or(true);
543 bco_samples.push(BCOSample {
544 prompt: sample.prompt.clone().unwrap_or_default(),
545 completion: completion.clone(),
546 label,
547 });
548 }
549 }
550 _ => {
551 if let Some(completion) = &sample.completion {
553 bco_samples.push(BCOSample {
554 prompt: sample.prompt.clone().unwrap_or_default(),
555 completion: completion.clone(),
556 label: sample.label.map(|l| l > 0.5).unwrap_or(true),
557 });
558 }
559 }
560 }
561 }
562
563 Ok(BCODataset { samples: bco_samples })
564 }
565
566 fn to_dpo_format(&self) -> Result<DPODataset> {
567 let mut dpo_samples = Vec::new();
568
569 for sample in &self.samples {
570 if let TrainingFormat::Preference { .. } = &self.format {
571 if let (Some(chosen), Some(rejected)) = (&sample.chosen, &sample.rejected) {
572 dpo_samples.push(DPOSample {
573 prompt: sample.prompt.clone().unwrap_or_default(),
574 chosen: chosen.clone(),
575 rejected: rejected.clone(),
576 });
577 }
578 } else {
579 bail!("DPO format requires preference-style data (chosen/rejected fields)");
580 }
581 }
582
583 Ok(DPODataset { samples: dpo_samples })
584 }
585
586 fn to_ppo_format(&self) -> Result<PPODataset> {
587 let mut ppo_samples = Vec::new();
589
590 for sample in &self.samples {
591 if let Some(completion) = &sample.completion {
592 ppo_samples.push(PPOSample {
593 prompt: sample.prompt.clone().unwrap_or_default(),
594 completion: completion.clone(),
595 reward: sample.label.unwrap_or(0.0),
596 });
597 }
598 }
599
600 Ok(PPODataset { samples: ppo_samples })
601 }
602
603 fn to_sft_format(&self) -> Result<SFTDataset> {
604 let mut sft_samples = Vec::new();
605
606 for sample in &self.samples {
607 if let Some(completion) = &sample.completion {
608 sft_samples.push(SFTSample {
609 prompt: sample.prompt.clone().unwrap_or_default(),
610 completion: completion.clone(),
611 });
612 }
613 }
614
615 Ok(SFTDataset { samples: sft_samples })
616 }
617
618 pub fn quality_assessment(&self) -> DatasetQualityReport {
620 let mut report = DatasetQualityReport {
621 overall_score: 0.0,
622 issues: Vec::new(),
623 recommendations: Vec::new(),
624 };
625
626 let prompt_coverage = self.statistics.field_coverage.get("prompt").unwrap_or(&0.0);
628 let completion_coverage = self.statistics.field_coverage.get("completion").unwrap_or(&0.0);
629 let chosen_coverage = self.statistics.field_coverage.get("chosen").unwrap_or(&0.0);
630 let rejected_coverage = self.statistics.field_coverage.get("rejected").unwrap_or(&0.0);
631
632 match &self.format {
633 TrainingFormat::Preference { .. } => {
634 if *chosen_coverage < 0.9 {
635 report.issues.push(format!("Low chosen field coverage: {:.1}%", chosen_coverage * 100.0));
636 }
637 if *rejected_coverage < 0.9 {
638 report.issues.push(format!("Low rejected field coverage: {:.1}%", rejected_coverage * 100.0));
639 }
640 report.overall_score = (chosen_coverage + rejected_coverage + prompt_coverage) / 3.0;
641 }
642 TrainingFormat::Completion { .. } => {
643 if *completion_coverage < 0.9 {
644 report.issues.push(format!("Low completion field coverage: {:.1}%", completion_coverage * 100.0));
645 }
646 report.overall_score = (completion_coverage + prompt_coverage) / 2.0;
647 }
648 _ => {
649 report.overall_score = self.statistics.quality_score.unwrap_or(0.0);
650 }
651 }
652
653 if self.statistics.avg_prompt_length < 10.0 {
655 report.issues.push("Very short average prompt length".to_string());
656 }
657 if self.statistics.avg_completion_length < 10.0 {
658 report.issues.push("Very short average completion length".to_string());
659 }
660
661 if report.issues.is_empty() {
663 report.recommendations.push("Dataset quality looks good!".to_string());
664 } else {
665 report.recommendations.push("Consider filtering or augmenting low-quality samples".to_string());
666 }
667
668 report
669 }
670}
671
/// One BCO example: a completion with a binary desirability label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BCOSample {
    pub prompt: String,
    pub completion: String,
    pub label: bool,
}
679
/// Dataset view for BCO-style binary classification training.
#[derive(Debug, Clone)]
pub struct BCODataset {
    pub samples: Vec<BCOSample>,
}
684
/// One DPO preference pair: chosen vs. rejected response for a prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DPOSample {
    pub prompt: String,
    pub chosen: String,
    pub rejected: String,
}
691
/// Dataset view for DPO preference training.
#[derive(Debug, Clone)]
pub struct DPODataset {
    pub samples: Vec<DPOSample>,
}
696
/// One PPO example: a completion with its scalar reward.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PPOSample {
    pub prompt: String,
    pub completion: String,
    pub reward: f32,
}
703
/// Dataset view for PPO reinforcement-learning training.
#[derive(Debug, Clone)]
pub struct PPODataset {
    pub samples: Vec<PPOSample>,
}
708
/// One supervised fine-tuning example: prompt plus target completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SFTSample {
    pub prompt: String,
    pub completion: String,
}
714
/// Dataset view for supervised fine-tuning.
#[derive(Debug, Clone)]
pub struct SFTDataset {
    pub samples: Vec<SFTSample>,
}
719
/// Output of `TrainingDataset::quality_assessment`.
#[derive(Debug, Clone)]
pub struct DatasetQualityReport {
    // Coverage-based score in the 0.0-1.0 range.
    pub overall_score: f64,
    pub issues: Vec<String>,
    pub recommendations: Vec<String>,
}
727
/// Strips trailing whitespace from every line and trims leading/trailing
/// whitespace from the whole result.
fn clean_whitespace(text: &str) -> String {
    let mut cleaned = String::with_capacity(text.len());
    for (i, line) in text.lines().enumerate() {
        if i > 0 {
            cleaned.push('\n');
        }
        cleaned.push_str(line.trim_end());
    }
    cleaned.trim().to_string()
}
736fn process_file(
737 src_path: &Path,
738 dst_path: &Path,
739 schema_path: Option<&Path>,
740 format_override: &DataFormat,
741) -> Result<()> {
742 let content = fs::read_to_string(src_path)
743 .with_context(|| format!("Failed to read {}", src_path.display()))?;
744 let raw: Value = serde_json::from_str(&content)
745 .with_context(|| format!("Failed to parse JSON in {}", src_path.display()))?;
746 let _temp_dataset = GenericJSONDataset::new(
747 &[src_path.to_path_buf()],
748 schema_path,
749 format_override.clone(),
750 )?;
751 let cleaned = if let Value::Array(arr) = raw {
752 let cleaned_arr: Vec<Value> = arr
753 .into_iter()
754 .map(|mut entry| {
755 if let Value::Object(ref mut obj) = entry {
756 for (_key, value) in obj.iter_mut() {
757 if let Value::String(ref mut s) = value {
758 *s = clean_whitespace(s);
759 }
760 }
761 }
762 entry
763 })
764 .collect();
765 Value::Array(cleaned_arr)
766 } else {
767 raw
768 };
769 dst_path
770 .parent()
771 .map(|p| fs::create_dir_all(p))
772 .transpose()
773 .with_context(|| {
774 format!("Failed to create directory for {}", dst_path.display())
775 })?;
776 let cleaned_json = serde_json::to_string_pretty(&cleaned)
777 .with_context(|| "Failed to serialize cleaned JSON")?;
778 fs::write(dst_path, cleaned_json)
779 .with_context(|| format!("Failed to write to {}", dst_path.display()))?;
780 Ok(())
781}
782pub async fn run_multi_process_clean(
783 src_files: Vec<PathBuf>,
784 dst_root: &Path,
785 schema_dir: Option<&Path>,
786 format_override: &DataFormat,
787 _jobs: usize,
788) -> Result<()> {
789 if src_files.is_empty() {
790 bail!("No source files provided");
791 }
792 let tasks: Vec<_> = src_files
793 .iter()
794 .map(|src| {
795 let dst = dst_root.join(src.file_name().unwrap());
796 let schema_path = schema_dir
797 .and_then(|dir| {
798 let candidate = dir
799 .join(format!("{}.schema.json", src.file_stem() ?.to_str() ?));
800 if candidate.is_file() { Some(candidate) } else { None }
801 });
802 (src.clone(), dst, schema_path)
803 })
804 .collect();
805 let pb = ProgressBar::new(tasks.len() as u64);
806 pb.set_style(
807 ProgressStyle::default_bar()
808 .template(
809 "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta}) {msg}",
810 )
811 .unwrap()
812 .progress_chars("#>-"),
813 );
814 pb.set_message("Cleaning & validating");
815 let pb = Arc::new(Mutex::new(pb));
816 let results: Vec<Result<()>> = tasks
817 .iter()
818 .map(|(src, dst, schema_path)| {
819 let result = process_file(
820 src,
821 dst,
822 schema_path.as_deref(),
823 format_override,
824 );
825 if let Ok(pb) = pb.try_lock() {
826 pb.inc(1);
827 }
828 result
829 })
830 .collect();
831 if let Ok(pb) = pb.try_lock() {
832 pb.finish_with_message("Complete");
833 }
834 let errors: Vec<_> = results
835 .into_iter()
836 .enumerate()
837 .filter_map(|(i, r)| r.err().map(|e| (i, e)))
838 .collect();
839 if !errors.is_empty() {
840 for (i, e) in &errors {
841 eprintln!("Error processing file {}: {}", i, e);
842 }
843 bail!("{} files failed processing", errors.len());
844 }
845 Ok(())
846}
847pub fn gather_json_paths(
848 files: &[String],
849 data_dirs: &[String],
850) -> Result<Vec<PathBuf>> {
851 let mut paths = Vec::new();
852 for file in files {
853 let path = PathBuf::from(file);
854 if !path.is_file() {
855 bail!("Specified file does not exist: {}", path.display());
856 }
857 paths.push(path);
858 }
859 for dir in data_dirs {
860 let dir_path = PathBuf::from(dir);
861 if !dir_path.is_dir() {
862 bail!("Data directory does not exist: {}", dir_path.display());
863 }
864 for entry in WalkDir::new(&dir_path).into_iter().filter_map(|e| e.ok()) {
865 if entry.path().extension().map_or(false, |ext| ext == "json") {
866 paths.push(entry.path().to_path_buf());
867 }
868 }
869 }
870 if paths.is_empty() {
871 bail!("No JSON files found in specified paths");
872 }
873 paths.sort();
874 paths.dedup();
875 Ok(paths)
876}
/// Looks up `<stem>.schema.json` in `schema_dir` for the given JSON file,
/// returning the path only when that schema file actually exists.
pub fn find_schema_for_file(
    json_path: &Path,
    schema_dir: Option<&Path>,
) -> Option<PathBuf> {
    let dir = schema_dir?;
    let stem = json_path.file_stem()?.to_str()?;
    let candidate = dir.join(format!("{}.schema.json", stem));
    if candidate.is_file() { Some(candidate) } else { None }
}
/// CLI entry point for the JSON subcommand.
///
/// Two modes:
/// * `--multi-process`: clean/validate every `*.json` under `--input-folder`
///   into `--output` via `run_multi_process_clean`, then return.
/// * default: load each input file (with its own optional schema), optionally
///   merge all samples into `--merge-output`, and optionally print stats.
pub async fn run_json_cmd(args: JsonArgs) -> Result<()> {
    if args.multi_process {
        if args.input_folder.is_none() || args.output.is_none() {
            bail!("Multi-process mode requires both --input-folder and --output");
        }
        // Safe: both options were checked for presence just above.
        let src_root = args.input_folder.unwrap();
        let dst_root = args.output.unwrap();
        if !src_root.is_dir() {
            bail!("Input folder does not exist: {}", src_root.display());
        }
        // Recursively collect *.json files under the input folder.
        let src_files: Vec<_> = WalkDir::new(&src_root)
            .into_iter()
            .filter_map(|e| e.ok())
            .filter(|e| e.path().extension().map_or(false, |ext| ext == "json"))
            .map(|e| e.path().to_path_buf())
            .collect();
        if src_files.is_empty() {
            bail!("No JSON files found in {}", src_root.display());
        }
        let schema_dir = args.schema_dir.as_ref();
        let format_override = args.format.clone();
        println!(
            "🔧 Starting multi-process clean-validate:\n source: {}\n destination: {}\n workers: {}\n format: {:?}",
            src_root.display(), dst_root.display(), args.jobs, format_override
        );
        run_multi_process_clean(
            src_files,
            &dst_root,
            schema_dir.map(|p| p.as_path()),
            &format_override,
            args.jobs,
        )
        .await?;
        println!("✅ All files cleaned and validated successfully.");
        return Ok(());
    }
    let json_paths = gather_json_paths(&args.file, &args.data_dir)?;
    let mut all_samples = Vec::new();
    // Each file is loaded individually so it can pair with its own schema.
    for path in &json_paths {
        let schema_path = find_schema_for_file(
            path,
            args.schema_dir.as_ref().map(|p| p.as_path()),
        );
        println!(
            "Loading {} (schema: {})", path.file_name().unwrap().to_string_lossy(),
            schema_path.as_ref().map(| p | p.file_name().unwrap().to_string_lossy())
                .unwrap_or(std::borrow::Cow::Borrowed("built-in"))
        );
        let dataset = GenericJSONDataset::new(
            &[path.clone()],
            schema_path.as_ref().map(|p| p.as_path()),
            args.format.clone(),
        )?;
        all_samples.extend(dataset.data);
    }
    if let Some(merge_output) = &args.merge_output {
        // Create the parent directory only when the path has one.
        merge_output
            .parent()
            .map(|p| fs::create_dir_all(p))
            .transpose()
            .with_context(|| {
                format!("Failed to create directory for {}", merge_output.display())
            })?;
        let merged_json = serde_json::to_string_pretty(&all_samples)
            .with_context(|| "Failed to serialize merged JSON")?;
        fs::write(merge_output, merged_json)
            .with_context(|| {
                format!("Failed to write merged output to {}", merge_output.display())
            })?;
        println!(
            "✅ Merged {} samples into {}", all_samples.len(), merge_output.display()
        );
    }
    if args.show_stats {
        if !json_paths.is_empty() {
            // Re-load all files as one dataset purely for aggregate stats.
            let temp_dataset = GenericJSONDataset::new(
                &json_paths,
                args.schema_dir.as_ref().map(|p| p.as_path()),
                args.format,
            )?;
            println!(
                "\n--- Dataset statistics ------------------------------------------------"
            );
            for (k, v) in temp_dataset.stats() {
                println!("{:20}: {}", k, v);
            }
            println!(
                "--------------------------------------------------------------------"
            );
        }
    }
    if args.merge_output.is_none() {
        println!("🎉 Validation finished – no merged output requested.");
    }
    Ok(())
}
/// Arguments for the JSON subcommand.
///
/// NOTE: fields are annotated with `//` comments rather than `///` on
/// purpose — clap turns doc comments into CLI help text, which would change
/// the program's user-visible output.
#[derive(Args, Debug, Clone)]
pub struct JsonArgs {
    // Directories scanned recursively for *.json files.
    #[arg(long)]
    pub data_dir: Vec<String>,
    // Individual JSON files to load.
    #[arg(long, short = 'f')]
    pub file: Vec<String>,
    // Directory containing `<stem>.schema.json` validation schemas.
    #[arg(long)]
    pub schema_dir: Option<PathBuf>,
    // Input data format; parsed via DataFormat::from_str ("auto" default).
    #[arg(long, default_value = "auto")]
    pub format: DataFormat,
    // Optional path for the merged output JSON.
    #[arg(long)]
    pub merge_output: Option<PathBuf>,
    // Print aggregate dataset statistics after loading.
    #[arg(long)]
    pub show_stats: bool,
    // RNG seed; not referenced anywhere in this file — TODO confirm use.
    #[arg(long, default_value = "42")]
    pub seed: u64,
    // Clean/validate a folder tree instead of loading datasets.
    #[arg(long)]
    pub multi_process: bool,
    // Source folder for multi-process mode.
    #[arg(long)]
    pub input_folder: Option<PathBuf>,
    // Destination folder for multi-process mode.
    #[arg(long, short = 'o')]
    pub output: Option<PathBuf>,
    // Worker count (processing is currently sequential; see run_multi_process_clean).
    #[arg(long, short = 'j', default_value_t = num_cpus::get())]
    pub jobs: usize,
}