// helix/dna/map/core.rs

use anyhow::{bail, Context, Result};
use clap::Args;
use indicatif::{ProgressBar, ProgressStyle};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;
use walkdir::WalkDir;
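/// On-disk layouts accepted for input JSON files. `Auto` sniffs the layout
/// from the first sample (see `detect_format`); the other variants force a
/// specific interpretation.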
#[derive(Debug, Clone, PartialEq)]
pub enum DataFormat {
    Auto,
    Legacy,
    Molds,
    Custom,
}
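/// Training-task layouts detected from the fields present in a sample:
/// preference pairs, labelled completions, instruction/output pairs,
/// chat transcripts, or an arbitrary custom field set.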
#[derive(Debug, Clone, PartialEq)]
pub enum TrainingFormat {
    Preference { chosen_field: String, rejected_field: String },
    Completion { completion_field: String, label_field: Option<String> },
    Instruction { instruction_field: String, output_field: String },
    Chat { messages_field: String },
    Custom { fields: Vec<String> },
}
impl std::str::FromStr for DataFormat {
    type Err = anyhow::Error;
    fn from_str(s: &str) -> Result<Self> {
        match s.to_lowercase().as_str() {
            "auto" => Ok(DataFormat::Auto),
            "legacy" => Ok(DataFormat::Legacy),
            "molds" => Ok(DataFormat::Molds),
            "custom" => Ok(DataFormat::Custom),
            _ => bail!("Invalid format: {}. Must be auto, legacy, molds, or custom", s),
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LegacySample {
    pub x: Vec<f32>,
    pub y: Vec<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MoldsSample {
    pub module_name: String,
    pub file_name: String,
    pub implementation: String,
    pub documentation: String,
    #[serde(rename = "system_context")]
    pub system_context: Option<String>,
}
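/// A single sample normalised into the fields the training pipelines expect;
/// any unrecognised source fields are preserved in `meta`.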
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSample {
    pub prompt: Option<String>,
    pub chosen: Option<String>,
    pub rejected: Option<String>,
    pub completion: Option<String>,
    pub label: Option<f32>,
    pub meta: HashMap<String, Value>,
}
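/// A collection of normalised samples together with the detected training
/// format and aggregate statistics.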
#[derive(Debug, Clone)]
pub struct TrainingDataset {
    pub samples: Vec<TrainingSample>,
    pub format: TrainingFormat,
    pub statistics: DatasetStats,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStats {
    pub total_samples: usize,
    pub avg_prompt_length: f64,
    pub avg_completion_length: f64,
    pub field_coverage: HashMap<String, f64>,
    pub quality_score: Option<f64>,
}
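/// Raw JSON samples loaded from disk, plus the resolved `DataFormat` and the
/// schema (user-supplied or built-in) used during validation.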
#[derive(Debug)]
pub struct GenericJSONDataset {
    pub data: Vec<Value>,
    pub format: DataFormat,
    pub schema: Option<Value>,
}
impl GenericJSONDataset {
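    /// Loads one or more JSON files, resolves the data format (auto-detecting
    /// from the first sample when `DataFormat::Auto` is given), and checks
    /// that each sample carries the keys required by that format.
    ///
    /// A minimal usage sketch (module paths omitted, file path illustrative):
    ///
    /// ```ignore
    /// let paths = vec![PathBuf::from("data/train.json")];
    /// let dataset = GenericJSONDataset::new(&paths, None, DataFormat::Auto)?;
    /// let training = dataset.to_training_dataset()?;
    /// println!("{} samples loaded", training.statistics.total_samples);
    /// ```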
    pub fn new(
        json_paths: &[PathBuf],
        schema_path: Option<&Path>,
        data_format: DataFormat,
    ) -> Result<Self> {
        if json_paths.is_empty() {
            bail!("No JSON files provided");
        }
        for path in json_paths {
            if !path.is_file() {
                bail!("JSON file not found: {}", path.display());
            }
        }
        let mut raw_data = Vec::new();
        for path in json_paths {
            let content = fs::read_to_string(path)
                .with_context(|| format!("Failed to read {}", path.display()))?;
            let parsed: Value = serde_json::from_str(&content)
                .with_context(|| format!("Failed to parse JSON in {}", path.display()))?;
            match parsed {
                Value::Array(arr) => raw_data.extend(arr),
                Value::Object(obj) => raw_data.push(Value::Object(obj)),
                _ => {
                    bail!(
                        "Root object in {} must be an array or object, got {}",
                        path.display(),
                        parsed
                    )
                }
            }
        }
        if raw_data.is_empty() {
            bail!("All input files are empty");
        }
        let format = match data_format {
            DataFormat::Auto => Self::detect_format(&raw_data[0])?,
            _ => data_format,
        };
        let schema = if let Some(schema_path) = schema_path {
            let schema_content = fs::read_to_string(schema_path)
                .with_context(|| {
                    format!("Failed to read schema {}", schema_path.display())
                })?;
            Some(
                serde_json::from_str(&schema_content)
                    .context("Failed to parse schema JSON")?,
            )
        } else {
            Self::get_builtin_schema(&format)
        };
        if let Some(ref schema) = schema {
            Self::validate_data(&raw_data, schema, &format)?;
        }
        Ok(GenericJSONDataset {
            data: raw_data,
            format,
            schema,
        })
    }
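    /// Infers the data format from the first sample: a `module_name` key marks
    /// Molds data, paired `x`/`y` keys mark the legacy layout, and anything
    /// else falls back to `Custom`.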
    fn detect_format(first_sample: &Value) -> Result<DataFormat> {
        if let Some(obj) = first_sample.as_object() {
            if obj.contains_key("module_name") {
                Ok(DataFormat::Molds)
            } else if obj.contains_key("x") && obj.contains_key("y") {
                Ok(DataFormat::Legacy)
            } else {
                Ok(DataFormat::Custom)
            }
        } else {
            bail!("First sample is not an object - cannot auto-detect format");
        }
    }
    fn get_builtin_schema(format: &DataFormat) -> Option<Value> {
        match format {
            DataFormat::Legacy => Some(json!({
                "type": "array",
                "items": { "type": "object", "required": ["x", "y"] }
            })),
            DataFormat::Molds => Some(json!({
                "type": "array",
                "items": {
                    "type": "object",
                    "required": ["module_name", "file_name", "implementation", "documentation"]
                }
            })),
            _ => None,
        }
    }
    fn validate_data(
        data: &[Value],
        _schema: &Value,
        format: &DataFormat,
    ) -> Result<()> {
        let required_keys = match format {
            DataFormat::Legacy => vec!["x", "y"],
            DataFormat::Molds => {
                vec!["module_name", "file_name", "implementation", "documentation"]
            }
            _ => return Ok(()),
        };
        for (i, sample) in data.iter().enumerate() {
            if let Some(obj) = sample.as_object() {
                for key in &required_keys {
                    if !obj.contains_key(*key) {
                        bail!(
                            "Sample {} is missing required key '{}' for {} format",
                            i,
                            key,
                            format!("{:?}", format).to_lowercase()
                        );
                    }
                }
            } else {
                bail!("Sample {} is not an object", i);
            }
        }
        Ok(())
    }
    pub fn len(&self) -> usize {
        self.data.len()
    }
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
    pub fn get_random_sample(&self) -> Option<&Value> {
        if self.data.is_empty() {
            None
        } else {
            use rand::Rng;
            let mut rng = rand::thread_rng();
            let idx = rng.gen_range(0..self.data.len());
            Some(&self.data[idx])
        }
    }
    pub fn stats(&self) -> HashMap<String, Value> {
        let mut stats = HashMap::new();
        stats.insert("num_samples".to_string(), json!(self.len()));
        stats.insert(
            "format".to_string(),
            json!(format!("{:?}", self.format).to_lowercase()),
        );
        stats.insert("has_schema".to_string(), json!(self.schema.is_some()));
        if !self.data.is_empty() {
            if let Some(obj) = self.data[0].as_object() {
                stats.insert(
                    "sample_keys".to_string(),
                    json!(obj.keys().collect::<Vec<_>>()),
                );
            }
        }
        stats
    }
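    /// Detects the training format from the keys of the first sample, in
    /// precedence order: `chosen`/`rejected` (preference), `completion`
    /// (optionally with `label`), `instruction`/`output`, `messages` (chat),
    /// and finally a catch-all `Custom` variant listing the observed fields.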
    pub fn detect_training_format(&self) -> Result<TrainingFormat> {
        if self.data.is_empty() {
            bail!("Cannot detect training format from empty dataset");
        }
        let first_sample = &self.data[0];
        let fields = if let Some(obj) = first_sample.as_object() {
            obj.keys().map(|s| s.as_str()).collect::<Vec<_>>()
        } else {
            bail!("First sample is not an object");
        };
        if fields.contains(&"chosen") && fields.contains(&"rejected") {
            return Ok(TrainingFormat::Preference {
                chosen_field: "chosen".to_string(),
                rejected_field: "rejected".to_string(),
            });
        }
        if fields.contains(&"completion") {
            let label_field = if fields.contains(&"label") {
                Some("label".to_string())
            } else {
                None
            };
            return Ok(TrainingFormat::Completion {
                completion_field: "completion".to_string(),
                label_field,
            });
        }
        if fields.contains(&"instruction") && fields.contains(&"output") {
            return Ok(TrainingFormat::Instruction {
                instruction_field: "instruction".to_string(),
                output_field: "output".to_string(),
            });
        }
        if fields.contains(&"messages") {
            return Ok(TrainingFormat::Chat {
                messages_field: "messages".to_string(),
            });
        }
        Ok(TrainingFormat::Custom {
            fields: fields.into_iter().map(|s| s.to_string()).collect(),
        })
    }
    pub fn to_training_dataset(&self) -> Result<TrainingDataset> {
        let training_format = self.detect_training_format()?;
        let mut samples = Vec::new();
        for (i, sample) in self.data.iter().enumerate() {
            if let Some(obj) = sample.as_object() {
                let training_sample = self
                    .convert_sample_to_training(obj, &training_format)
                    .with_context(|| format!("Failed to convert sample {}", i))?;
                samples.push(training_sample);
            } else {
                bail!("Sample {} is not an object", i);
            }
        }
        let statistics = self.compute_statistics(&samples)?;
        Ok(TrainingDataset {
            samples,
            format: training_format,
            statistics,
        })
    }
    fn convert_sample_to_training(
        &self,
        obj: &serde_json::Map<String, Value>,
        format: &TrainingFormat,
    ) -> Result<TrainingSample> {
        let mut sample = TrainingSample {
            prompt: None,
            chosen: None,
            rejected: None,
            completion: None,
            label: None,
            meta: HashMap::new(),
        };
        if let Some(prompt_val) = obj.get("prompt") {
            if let Some(prompt_str) = prompt_val.as_str() {
                sample.prompt = Some(prompt_str.to_string());
            }
        }
        match format {
            TrainingFormat::Preference { chosen_field, rejected_field } => {
                if let Some(chosen_val) = obj.get(chosen_field) {
                    if let Some(chosen_str) = chosen_val.as_str() {
                        sample.chosen = Some(chosen_str.to_string());
                    }
                }
                if let Some(rejected_val) = obj.get(rejected_field) {
                    if let Some(rejected_str) = rejected_val.as_str() {
                        sample.rejected = Some(rejected_str.to_string());
                    }
                }
            }
            TrainingFormat::Completion { completion_field, label_field } => {
                if let Some(completion_val) = obj.get(completion_field) {
                    if let Some(completion_str) = completion_val.as_str() {
                        sample.completion = Some(completion_str.to_string());
                    }
                }
                if let Some(label_field) = label_field {
                    if let Some(label_val) = obj.get(label_field) {
                        if let Some(label_num) = label_val.as_f64() {
                            sample.label = Some(label_num as f32);
                        } else if let Some(label_bool) = label_val.as_bool() {
                            sample.label = Some(if label_bool { 1.0 } else { 0.0 });
                        }
                    }
                }
            }
            TrainingFormat::Instruction { instruction_field, output_field } => {
                if let Some(instruction_val) = obj.get(instruction_field) {
                    if let Some(instruction_str) = instruction_val.as_str() {
                        sample.prompt = Some(instruction_str.to_string());
                    }
                }
                if let Some(output_val) = obj.get(output_field) {
                    if let Some(output_str) = output_val.as_str() {
                        sample.completion = Some(output_str.to_string());
                    }
                }
            }
            TrainingFormat::Chat { messages_field } => {
                if let Some(messages_val) = obj.get(messages_field) {
                    sample.meta.insert("messages".to_string(), messages_val.clone());
                    if let Some(messages) = messages_val.as_array() {
                        if let Some(first_msg) = messages.first() {
                            if let Some(content) =
                                first_msg.get("content").and_then(|c| c.as_str())
                            {
                                sample.prompt = Some(content.to_string());
                            }
                        }
                        if let Some(last_msg) = messages.last() {
                            if let Some(content) =
                                last_msg.get("content").and_then(|c| c.as_str())
                            {
                                sample.completion = Some(content.to_string());
                            }
                        }
                    }
                }
            }
            TrainingFormat::Custom { fields } => {
                for field in fields {
                    if let Some(value) = obj.get(field) {
                        sample.meta.insert(field.clone(), value.clone());
                    }
                }
            }
        }
        for (key, value) in obj {
            if !matches!(
                key.as_str(),
                "prompt" | "chosen" | "rejected" | "completion" | "label"
                    | "instruction" | "output" | "messages"
            ) {
                sample.meta.insert(key.clone(), value.clone());
            }
        }
        Ok(sample)
    }
    fn compute_statistics(&self, samples: &[TrainingSample]) -> Result<DatasetStats> {
        let total_samples = samples.len();
        let mut total_prompt_length = 0;
        let mut total_completion_length = 0;
        let mut prompt_count = 0;
        let mut completion_count = 0;
        let mut field_coverage = HashMap::new();
        for sample in samples {
            if let Some(prompt) = &sample.prompt {
                *field_coverage.entry("prompt".to_string()).or_insert(0.0) += 1.0;
                total_prompt_length += prompt.len();
                prompt_count += 1;
            }
            if sample.chosen.is_some() {
                *field_coverage.entry("chosen".to_string()).or_insert(0.0) += 1.0;
            }
            if sample.rejected.is_some() {
                *field_coverage.entry("rejected".to_string()).or_insert(0.0) += 1.0;
            }
            if let Some(completion) = &sample.completion {
                *field_coverage.entry("completion".to_string()).or_insert(0.0) += 1.0;
                total_completion_length += completion.len();
                completion_count += 1;
            }
            if sample.label.is_some() {
                *field_coverage.entry("label".to_string()).or_insert(0.0) += 1.0;
            }
        }
        for count in field_coverage.values_mut() {
            *count /= total_samples as f64;
        }
        let avg_prompt_length = if prompt_count > 0 {
            total_prompt_length as f64 / prompt_count as f64
        } else {
            0.0
        };
        let avg_completion_length = if completion_count > 0 {
            total_completion_length as f64 / completion_count as f64
        } else {
            0.0
        };
        let quality_score = Some(
            (field_coverage.get("prompt").unwrap_or(&0.0)
                + field_coverage.get("chosen").unwrap_or(&0.0)
                + field_coverage.get("rejected").unwrap_or(&0.0)
                + field_coverage.get("completion").unwrap_or(&0.0)) / 4.0,
        );
        Ok(DatasetStats {
            total_samples,
            avg_prompt_length,
            avg_completion_length,
            field_coverage,
            quality_score,
        })
    }
}
impl TrainingDataset {
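    /// Converts the dataset into the sample layout expected by a given
    /// training algorithm (`"bco"`, `"dpo"`, `"ppo"`, or `"sft"`), returned as
    /// a type-erased `Box<dyn Any>`.
    ///
    /// A sketch of how a caller might recover the concrete type (illustrative):
    ///
    /// ```ignore
    /// let boxed = training.to_algorithm_format("dpo")?;
    /// if let Some(dpo) = boxed.downcast_ref::<DPODataset>() {
    ///     println!("{} preference pairs", dpo.samples.len());
    /// }
    /// ```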
    pub fn to_algorithm_format(
        &self,
        algorithm: &str,
    ) -> Result<Box<dyn std::any::Any>> {
        match algorithm.to_lowercase().as_str() {
            "bco" => {
                let bco_data = self.to_bco_format()?;
                Ok(Box::new(bco_data))
            }
            "dpo" => {
                let dpo_data = self.to_dpo_format()?;
                Ok(Box::new(dpo_data))
            }
            "ppo" => {
                let ppo_data = self.to_ppo_format()?;
                Ok(Box::new(ppo_data))
            }
            "sft" => {
                let sft_data = self.to_sft_format()?;
                Ok(Box::new(sft_data))
            }
            _ => bail!("Unsupported algorithm: {}", algorithm),
        }
    }
    fn to_bco_format(&self) -> Result<BCODataset> {
        let mut bco_samples = Vec::new();
        for sample in &self.samples {
            match &self.format {
                TrainingFormat::Preference { .. } => {
                    if let (Some(chosen), Some(rejected)) =
                        (&sample.chosen, &sample.rejected)
                    {
                        let (completion, label) = if rand::random::<bool>() {
                            (chosen.clone(), true)
                        } else {
                            (rejected.clone(), false)
                        };
                        bco_samples.push(BCOSample {
                            prompt: sample.prompt.clone().unwrap_or_default(),
                            completion,
                            label,
                        });
                    }
                }
                TrainingFormat::Completion { .. } => {
                    if let Some(completion) = &sample.completion {
                        let label = sample.label.map(|l| l > 0.5).unwrap_or(true);
                        bco_samples.push(BCOSample {
                            prompt: sample.prompt.clone().unwrap_or_default(),
                            completion: completion.clone(),
                            label,
                        });
                    }
                }
                _ => {
                    if let Some(completion) = &sample.completion {
                        bco_samples.push(BCOSample {
                            prompt: sample.prompt.clone().unwrap_or_default(),
                            completion: completion.clone(),
                            label: sample.label.map(|l| l > 0.5).unwrap_or(true),
                        });
                    }
                }
            }
        }
        Ok(BCODataset { samples: bco_samples })
    }
    fn to_dpo_format(&self) -> Result<DPODataset> {
        let mut dpo_samples = Vec::new();
        for sample in &self.samples {
            if let TrainingFormat::Preference { .. } = &self.format {
                if let (Some(chosen), Some(rejected)) =
                    (&sample.chosen, &sample.rejected)
                {
                    dpo_samples.push(DPOSample {
                        prompt: sample.prompt.clone().unwrap_or_default(),
                        chosen: chosen.clone(),
                        rejected: rejected.clone(),
                    });
                }
            } else {
                bail!("DPO format requires preference-style data (chosen/rejected fields)");
            }
        }
        Ok(DPODataset { samples: dpo_samples })
    }
    fn to_ppo_format(&self) -> Result<PPODataset> {
        let mut ppo_samples = Vec::new();
        for sample in &self.samples {
            if let Some(completion) = &sample.completion {
                ppo_samples.push(PPOSample {
                    prompt: sample.prompt.clone().unwrap_or_default(),
                    completion: completion.clone(),
                    reward: sample.label.unwrap_or(0.0),
                });
            }
        }
        Ok(PPODataset { samples: ppo_samples })
    }
    fn to_sft_format(&self) -> Result<SFTDataset> {
        let mut sft_samples = Vec::new();
        for sample in &self.samples {
            if let Some(completion) = &sample.completion {
                sft_samples.push(SFTSample {
                    prompt: sample.prompt.clone().unwrap_or_default(),
                    completion: completion.clone(),
                });
            }
        }
        Ok(SFTDataset { samples: sft_samples })
    }
    pub fn quality_assessment(&self) -> DatasetQualityReport {
        let mut report = DatasetQualityReport {
            overall_score: 0.0,
            issues: Vec::new(),
            recommendations: Vec::new(),
        };
        let prompt_coverage = self
            .statistics
            .field_coverage
            .get("prompt")
            .unwrap_or(&0.0);
        let completion_coverage = self
            .statistics
            .field_coverage
            .get("completion")
            .unwrap_or(&0.0);
        let chosen_coverage = self
            .statistics
            .field_coverage
            .get("chosen")
            .unwrap_or(&0.0);
        let rejected_coverage = self
            .statistics
            .field_coverage
            .get("rejected")
            .unwrap_or(&0.0);
        match &self.format {
            TrainingFormat::Preference { .. } => {
                if *chosen_coverage < 0.9 {
                    report.issues.push(format!(
                        "Low chosen field coverage: {:.1}%",
                        chosen_coverage * 100.0
                    ));
                }
                if *rejected_coverage < 0.9 {
                    report.issues.push(format!(
                        "Low rejected field coverage: {:.1}%",
                        rejected_coverage * 100.0
                    ));
                }
                report.overall_score =
                    (chosen_coverage + rejected_coverage + prompt_coverage) / 3.0;
            }
            TrainingFormat::Completion { .. } => {
                if *completion_coverage < 0.9 {
                    report.issues.push(format!(
                        "Low completion field coverage: {:.1}%",
                        completion_coverage * 100.0
                    ));
                }
                report.overall_score = (completion_coverage + prompt_coverage) / 2.0;
            }
            _ => {
                report.overall_score = self.statistics.quality_score.unwrap_or(0.0);
            }
        }
        if self.statistics.avg_prompt_length < 10.0 {
            report.issues.push("Very short average prompt length".to_string());
        }
        if self.statistics.avg_completion_length < 10.0 {
            report.issues.push("Very short average completion length".to_string());
        }
        if report.issues.is_empty() {
            report.recommendations.push("Dataset quality looks good!".to_string());
        } else {
            report
                .recommendations
                .push("Consider filtering or augmenting low-quality samples".to_string());
        }
        report
    }
}
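/// Algorithm-specific sample and dataset containers produced by
/// `TrainingDataset::to_algorithm_format`.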
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BCOSample {
    pub prompt: String,
    pub completion: String,
    pub label: bool,
}
#[derive(Debug, Clone)]
pub struct BCODataset {
    pub samples: Vec<BCOSample>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DPOSample {
    pub prompt: String,
    pub chosen: String,
    pub rejected: String,
}
#[derive(Debug, Clone)]
pub struct DPODataset {
    pub samples: Vec<DPOSample>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PPOSample {
    pub prompt: String,
    pub completion: String,
    pub reward: f32,
}
#[derive(Debug, Clone)]
pub struct PPODataset {
    pub samples: Vec<PPOSample>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SFTSample {
    pub prompt: String,
    pub completion: String,
}
#[derive(Debug, Clone)]
pub struct SFTDataset {
    pub samples: Vec<SFTSample>,
}
#[derive(Debug, Clone)]
pub struct DatasetQualityReport {
    pub overall_score: f64,
    pub issues: Vec<String>,
    pub recommendations: Vec<String>,
}
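/// Trims trailing whitespace from each line, rejoins with `\n`, and trims
/// leading/trailing whitespace from the result.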
fn clean_whitespace(text: &str) -> String {
    text.lines()
        .map(|line| line.trim_end())
        .collect::<Vec<_>>()
        .join("\n")
        .trim()
        .to_string()
}
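/// Validates a single JSON file via `GenericJSONDataset::new`, normalises
/// whitespace in the string fields of array entries, and writes the
/// pretty-printed result to `dst_path`.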
fn process_file(
    src_path: &Path,
    dst_path: &Path,
    schema_path: Option<&Path>,
    format_override: &DataFormat,
) -> Result<()> {
    let content = fs::read_to_string(src_path)
        .with_context(|| format!("Failed to read {}", src_path.display()))?;
    let raw: Value = serde_json::from_str(&content)
        .with_context(|| format!("Failed to parse JSON in {}", src_path.display()))?;
    let _temp_dataset = GenericJSONDataset::new(
        &[src_path.to_path_buf()],
        schema_path,
        format_override.clone(),
    )?;
    let cleaned = if let Value::Array(arr) = raw {
        let cleaned_arr: Vec<Value> = arr
            .into_iter()
            .map(|mut entry| {
                if let Value::Object(ref mut obj) = entry {
                    for (_key, value) in obj.iter_mut() {
                        if let Value::String(ref mut s) = value {
                            *s = clean_whitespace(s);
                        }
                    }
                }
                entry
            })
            .collect();
        Value::Array(cleaned_arr)
    } else {
        raw
    };
    dst_path
        .parent()
        .map(|p| fs::create_dir_all(p))
        .transpose()
        .with_context(|| {
            format!("Failed to create directory for {}", dst_path.display())
        })?;
    let cleaned_json = serde_json::to_string_pretty(&cleaned)
        .with_context(|| "Failed to serialize cleaned JSON")?;
    fs::write(dst_path, cleaned_json)
        .with_context(|| format!("Failed to write to {}", dst_path.display()))?;
    Ok(())
}
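/// Cleans and validates a batch of JSON files, writing results under
/// `dst_root` and reporting progress with a progress bar. Despite the name,
/// files are currently processed sequentially; the `_jobs` parameter is
/// accepted but not yet used for parallelism.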
pub async fn run_multi_process_clean(
    src_files: Vec<PathBuf>,
    dst_root: &Path,
    schema_dir: Option<&Path>,
    format_override: &DataFormat,
    _jobs: usize,
) -> Result<()> {
    if src_files.is_empty() {
        bail!("No source files provided");
    }
    let tasks: Vec<_> = src_files
        .iter()
        .map(|src| {
            let dst = dst_root.join(src.file_name().unwrap());
            let schema_path = schema_dir.and_then(|dir| {
                let candidate = dir
                    .join(format!("{}.schema.json", src.file_stem()?.to_str()?));
                if candidate.is_file() { Some(candidate) } else { None }
            });
            (src.clone(), dst, schema_path)
        })
        .collect();
    let pb = ProgressBar::new(tasks.len() as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template(
                "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta}) {msg}",
            )
            .unwrap()
            .progress_chars("#>-"),
    );
    pb.set_message("Cleaning & validating");
    let pb = Arc::new(Mutex::new(pb));
    let results: Vec<Result<()>> = tasks
        .iter()
        .map(|(src, dst, schema_path)| {
            let result = process_file(src, dst, schema_path.as_deref(), format_override);
            if let Ok(pb) = pb.try_lock() {
                pb.inc(1);
            }
            result
        })
        .collect();
    if let Ok(pb) = pb.try_lock() {
        pb.finish_with_message("Complete");
    }
    let errors: Vec<_> = results
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| r.err().map(|e| (i, e)))
        .collect();
    if !errors.is_empty() {
        for (i, e) in &errors {
            eprintln!("Error processing file {}: {}", i, e);
        }
        bail!("{} files failed processing", errors.len());
    }
    Ok(())
}
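/// Collects JSON paths from explicitly listed files and from recursive walks
/// of the given data directories, then sorts and de-duplicates the result.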
pub fn gather_json_paths(
    files: &[String],
    data_dirs: &[String],
) -> Result<Vec<PathBuf>> {
    let mut paths = Vec::new();
    for file in files {
        let path = PathBuf::from(file);
        if !path.is_file() {
            bail!("Specified file does not exist: {}", path.display());
        }
        paths.push(path);
    }
    for dir in data_dirs {
        let dir_path = PathBuf::from(dir);
        if !dir_path.is_dir() {
            bail!("Data directory does not exist: {}", dir_path.display());
        }
        for entry in WalkDir::new(&dir_path).into_iter().filter_map(|e| e.ok()) {
            if entry.path().extension().map_or(false, |ext| ext == "json") {
                paths.push(entry.path().to_path_buf());
            }
        }
    }
    if paths.is_empty() {
        bail!("No JSON files found in specified paths");
    }
    paths.sort();
    paths.dedup();
    Ok(paths)
}
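/// Looks for a schema named `<file_stem>.schema.json` in `schema_dir` and
/// returns its path if it exists.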
pub fn find_schema_for_file(
    json_path: &Path,
    schema_dir: Option<&Path>,
) -> Option<PathBuf> {
    schema_dir.and_then(|dir| {
        let candidate = dir
            .join(format!("{}.schema.json", json_path.file_stem()?.to_str()?));
        if candidate.is_file() { Some(candidate) } else { None }
    })
}
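/// Entry point for the JSON subcommand: either runs the batch clean-validate
/// pipeline (`--multi-process`) or loads, optionally merges, and reports
/// statistics for the given files and directories.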
pub async fn run_json_cmd(args: JsonArgs) -> Result<()> {
    if args.multi_process {
        if args.input_folder.is_none() || args.output.is_none() {
            bail!("Multi-process mode requires both --input-folder and --output");
        }
        let src_root = args.input_folder.unwrap();
        let dst_root = args.output.unwrap();
        if !src_root.is_dir() {
            bail!("Input folder does not exist: {}", src_root.display());
        }
        let src_files: Vec<_> = WalkDir::new(&src_root)
            .into_iter()
            .filter_map(|e| e.ok())
            .filter(|e| e.path().extension().map_or(false, |ext| ext == "json"))
            .map(|e| e.path().to_path_buf())
            .collect();
        if src_files.is_empty() {
            bail!("No JSON files found in {}", src_root.display());
        }
        let schema_dir = args.schema_dir.as_ref();
        let format_override = args.format.clone();
        println!(
            "🔧 Starting multi-process clean-validate:\n  source: {}\n  destination: {}\n  workers: {}\n  format: {:?}",
            src_root.display(), dst_root.display(), args.jobs, format_override
        );
        run_multi_process_clean(
            src_files,
            &dst_root,
            schema_dir.map(|p| p.as_path()),
            &format_override,
            args.jobs,
        )
        .await?;
        println!("✅ All files cleaned and validated successfully.");
        return Ok(());
    }
    let json_paths = gather_json_paths(&args.file, &args.data_dir)?;
    let mut all_samples = Vec::new();
    for path in &json_paths {
        let schema_path = find_schema_for_file(
            path,
            args.schema_dir.as_ref().map(|p| p.as_path()),
        );
        println!(
            "Loading {} (schema: {})",
            path.file_name().unwrap().to_string_lossy(),
            schema_path
                .as_ref()
                .map(|p| p.file_name().unwrap().to_string_lossy())
                .unwrap_or(std::borrow::Cow::Borrowed("built-in"))
        );
        let dataset = GenericJSONDataset::new(
            &[path.clone()],
            schema_path.as_ref().map(|p| p.as_path()),
            args.format.clone(),
        )?;
        all_samples.extend(dataset.data);
    }
    if let Some(merge_output) = &args.merge_output {
        merge_output
            .parent()
            .map(|p| fs::create_dir_all(p))
            .transpose()
            .with_context(|| {
                format!("Failed to create directory for {}", merge_output.display())
            })?;
        let merged_json = serde_json::to_string_pretty(&all_samples)
            .with_context(|| "Failed to serialize merged JSON")?;
        fs::write(merge_output, merged_json)
            .with_context(|| {
                format!("Failed to write merged output to {}", merge_output.display())
            })?;
        println!(
            "✅ Merged {} samples into {}",
            all_samples.len(),
            merge_output.display()
        );
    }
    if args.show_stats {
        if !json_paths.is_empty() {
            let temp_dataset = GenericJSONDataset::new(
                &json_paths,
                args.schema_dir.as_ref().map(|p| p.as_path()),
                args.format,
            )?;
            println!(
                "\n--- Dataset statistics ------------------------------------------------"
            );
            for (k, v) in temp_dataset.stats() {
                println!("{:20}: {}", k, v);
            }
            println!(
                "--------------------------------------------------------------------"
            );
        }
    }
    if args.merge_output.is_none() {
        println!("🎉 Validation finished – no merged output requested.");
    }
    Ok(())
}
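/// CLI arguments for the JSON data-preparation subcommand.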
#[derive(Args, Debug, Clone)]
pub struct JsonArgs {
    #[arg(long)]
    pub data_dir: Vec<String>,
    #[arg(long, short = 'f')]
    pub file: Vec<String>,
    #[arg(long)]
    pub schema_dir: Option<PathBuf>,
    #[arg(long, default_value = "auto")]
    pub format: DataFormat,
    #[arg(long)]
    pub merge_output: Option<PathBuf>,
    #[arg(long)]
    pub show_stats: bool,
    #[arg(long, default_value = "42")]
    pub seed: u64,
    #[arg(long)]
    pub multi_process: bool,
    #[arg(long)]
    pub input_folder: Option<PathBuf>,
    #[arg(long, short = 'o')]
    pub output: Option<PathBuf>,
    #[arg(long, short = 'j', default_value_t = num_cpus::get())]
    pub jobs: usize,
}