//! helix/dna/map/hf.rs — HuggingFace dataset loading and processing.

1#![warn(clippy::all, clippy::pedantic)]
2use anyhow::{bail, Context, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use super::core::{TrainingDataset, TrainingSample, DatasetStats, TrainingFormat};
/// Configuration describing how to fetch and interpret a HuggingFace dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HfDatasetConfig {
    /// Dataset identifier on the Hub (e.g. "org/name").
    pub source: String,
    /// Split to record on the loaded dataset (e.g. "train").
    pub split: String,
    /// Optional explicit format override; presumably one of the processor keys
    /// ("preference"/"completion"/"instruction") — TODO confirm against callers.
    pub format: Option<String>,
    /// Optional post-processing filters; currently passed through unmodified
    /// (see `HfProcessor::apply_filters`).
    pub rpl_filter: Option<HashMap<String, serde_json::Value>>,
    /// Optional git revision of the dataset repo (not read by code in this file).
    pub revision: Option<String>,
    /// Streaming flag (not read by code in this file).
    pub streaming: bool,
    /// Mirrors the HF `datasets` flag of the same name (not read by code in this file).
    pub trust_remote_code: bool,
    /// Parallelism hint (not read by code in this file).
    pub num_proc: Option<usize>,
}
/// An in-memory HuggingFace dataset: raw JSON rows plus an inferred schema.
pub struct HuggingFaceDataset {
    /// Dataset identifier this was loaded from.
    pub name: String,
    /// Split label recorded at load time.
    pub split: String,
    /// Raw sample rows, one JSON value per sample.
    pub data: Vec<serde_json::Value>,
    /// Feature schema inferred from the data (key -> `{"dtype", "_type"}`).
    pub features: HashMap<String, serde_json::Value>,
    /// Repo-level metadata (currently placeholder values; see `extract_metadata`).
    pub metadata: HashMap<String, serde_json::Value>,
}
25impl HuggingFaceDataset {
26    pub async fn load(name: &str, split: &str, cache_dir: &Path) -> Result<Self> {
27        println!("🔄 Loading dataset {} from HuggingFace Hub...", name);
28        let cache = hf_hub::Cache::new(cache_dir.to_path_buf());
29        let repo = cache.dataset(name.to_string());
30        let json_files = Self::find_json_files(&repo, name).await?;
31        if json_files.is_empty() {
32            bail!("No JSON files found in dataset {}", name);
33        }
34        let data_file = &json_files[0];
35        println!("📁 Found data file: {}", data_file.display());
36        let data = Self::load_json_data(data_file).await?;
37        let features = Self::infer_features(&data)?;
38        let metadata = Self::extract_metadata(&repo).await?;
39        println!("✅ Successfully loaded {} samples from {}", data.len(), name);
40        Ok(HuggingFaceDataset {
41            name: name.to_string(),
42            split: split.to_string(),
43            data,
44            features,
45            metadata,
46        })
47    }
48    async fn find_json_files(
49        repo: &hf_hub::CacheRepo,
50        name: &str,
51    ) -> Result<Vec<PathBuf>> {
52        let mut json_files = Vec::new();
53        let possible_files = vec![
54            format!("{}.json", name.replace('/', "--")), "train.json".to_string(),
55            "validation.json".to_string(), "test.json".to_string(), "data.json"
56            .to_string(),
57        ];
58        for file_name in possible_files {
59            if let Some(path) = repo.get(&file_name) {
60                json_files.push(path);
61            }
62        }
63        if json_files.is_empty() {}
64        Ok(json_files)
65    }
66    async fn load_json_data(file_path: &Path) -> Result<Vec<serde_json::Value>> {
67        let content = tokio::fs::read_to_string(file_path)
68            .await
69            .with_context(|| {
70                format!("Failed to read JSON file: {}", file_path.display())
71            })?;
72        let json_value: serde_json::Value = serde_json::from_str(&content)
73            .with_context(|| {
74                format!("Failed to parse JSON from: {}", file_path.display())
75            })?;
76        match json_value {
77            serde_json::Value::Array(arr) => Ok(arr),
78            serde_json::Value::Object(obj) => {
79                if let Some(data) = obj.get("data") {
80                    if let Some(arr) = data.as_array() {
81                        return Ok(arr.clone());
82                    }
83                }
84                if let Some(train) = obj.get("train") {
85                    if let Some(arr) = train.as_array() {
86                        return Ok(arr.clone());
87                    }
88                }
89                Ok(vec![serde_json::Value::Object(obj)])
90            }
91            _ => bail!("Unsupported JSON structure in {}", file_path.display()),
92        }
93    }
94    fn infer_features(
95        data: &[serde_json::Value],
96    ) -> Result<HashMap<String, serde_json::Value>> {
97        let mut features = HashMap::new();
98        if data.is_empty() {
99            return Ok(features);
100        }
101        let sample_size = std::cmp::min(10, data.len());
102        let samples = &data[..sample_size];
103        let mut key_types = HashMap::new();
104        for sample in samples {
105            if let Some(obj) = sample.as_object() {
106                for (key, value) in obj {
107                    let type_str = match value {
108                        serde_json::Value::String(_) => "string",
109                        serde_json::Value::Number(_) => "number",
110                        serde_json::Value::Bool(_) => "boolean",
111                        serde_json::Value::Array(_) => "array",
112                        serde_json::Value::Object(_) => "object",
113                        serde_json::Value::Null => "null",
114                    };
115                    key_types.insert(key.clone(), type_str.to_string());
116                }
117            }
118        }
119        for (key, type_str) in key_types {
120            features
121                .insert(
122                    key,
123                    serde_json::json!({ "dtype" : type_str, "_type" : "Value" }),
124                );
125        }
126        Ok(features)
127    }
128    async fn extract_metadata(
129        _repo: &hf_hub::CacheRepo,
130    ) -> Result<HashMap<String, serde_json::Value>> {
131        let mut metadata = HashMap::new();
132        metadata.insert("dataset_name".to_string(), serde_json::json!("unknown"));
133        metadata.insert("split".to_string(), serde_json::json!("unknown"));
134        metadata.insert("num_samples".to_string(), serde_json::json!(0));
135        Ok(metadata)
136    }
137    pub fn get_features(&self) -> Result<Vec<String>> {
138        Ok(self.features.keys().map(|s| s.to_string()).collect())
139    }
140    pub fn info(&self) -> HashMap<String, serde_json::Value> {
141        let mut info = HashMap::new();
142        info.insert("name".to_string(), serde_json::json!(self.name));
143        info.insert("split".to_string(), serde_json::json!(self.split));
144        info.insert("num_samples".to_string(), serde_json::json!(self.data.len()));
145        info.insert("features".to_string(), serde_json::json!(self.features));
146        info.insert("metadata".to_string(), serde_json::json!(self.metadata));
147        info
148    }
149}
/// Converts a loaded `HuggingFaceDataset` into a normalized `TrainingDataset`.
#[async_trait::async_trait]
pub trait DatasetProcessor {
    /// Consumes the raw dataset and produces training samples plus statistics.
    async fn process(&self, dataset: HuggingFaceDataset) -> Result<TrainingDataset>;
}
/// Processor for preference data with `prompt`/`chosen`/`rejected` columns.
pub struct PreferenceProcessor;
155#[async_trait::async_trait]
156impl DatasetProcessor for PreferenceProcessor {
157    async fn process(&self, dataset: HuggingFaceDataset) -> Result<TrainingDataset> {
158        let mut samples = Vec::new();
159        for item in dataset.data {
160            if let Some(obj) = item.as_object() {
161                let sample = TrainingSample {
162                    prompt: obj
163                        .get("prompt")
164                        .and_then(|v| v.as_str())
165                        .map(|s| s.to_string()),
166                    chosen: obj
167                        .get("chosen")
168                        .and_then(|v| v.as_str())
169                        .map(|s| s.to_string()),
170                    rejected: obj
171                        .get("rejected")
172                        .and_then(|v| v.as_str())
173                        .map(|s| s.to_string()),
174                    completion: None,
175                    label: None,
176                    meta: obj
177                        .clone()
178                        .into_iter()
179                        .filter(|(k, _)| {
180                            !matches!(k.as_str(), "prompt" | "chosen" | "rejected")
181                        })
182                        .map(|(k, v)| (k, v))
183                        .collect(),
184                };
185                samples.push(sample);
186            }
187        }
188        let format = TrainingFormat::Preference {
189            chosen_field: "chosen".to_string(),
190            rejected_field: "rejected".to_string(),
191        };
192        let statistics = Self::compute_statistics(&samples);
193        Ok(TrainingDataset {
194            samples,
195            format,
196            statistics,
197        })
198    }
199}
200impl PreferenceProcessor {
201    fn compute_statistics(samples: &[TrainingSample]) -> DatasetStats {
202        let total_samples = samples.len();
203        let mut total_prompt_length = 0;
204        let mut prompt_count = 0;
205        let mut field_coverage = HashMap::new();
206        for sample in samples {
207            if sample.prompt.is_some() {
208                *field_coverage.entry("prompt".to_string()).or_insert(0.0) += 1.0;
209                total_prompt_length += sample.prompt.as_ref().unwrap().len();
210                prompt_count += 1;
211            }
212            if sample.chosen.is_some() {
213                *field_coverage.entry("chosen".to_string()).or_insert(0.0) += 1.0;
214            }
215            if sample.rejected.is_some() {
216                *field_coverage.entry("rejected".to_string()).or_insert(0.0) += 1.0;
217            }
218        }
219        for count in field_coverage.values_mut() {
220            *count = *count / total_samples as f64;
221        }
222        let avg_prompt_length = if prompt_count > 0 {
223            total_prompt_length as f64 / prompt_count as f64
224        } else {
225            0.0
226        };
227        let quality_score = Some(
228            (field_coverage.get("prompt").unwrap_or(&0.0)
229                + field_coverage.get("chosen").unwrap_or(&0.0)
230                + field_coverage.get("rejected").unwrap_or(&0.0)) / 3.0,
231        );
232        DatasetStats {
233            total_samples,
234            avg_prompt_length,
235            avg_completion_length: 0.0,
236            field_coverage,
237            quality_score,
238        }
239    }
240}
/// Processor for completion data with `prompt`/`completion`/`label` columns.
pub struct CompletionProcessor;
242#[async_trait::async_trait]
243impl DatasetProcessor for CompletionProcessor {
244    async fn process(&self, dataset: HuggingFaceDataset) -> Result<TrainingDataset> {
245        let mut samples = Vec::new();
246        for item in dataset.data {
247            if let Some(obj) = item.as_object() {
248                let sample = TrainingSample {
249                    prompt: obj
250                        .get("prompt")
251                        .and_then(|v| v.as_str())
252                        .map(|s| s.to_string()),
253                    chosen: None,
254                    rejected: None,
255                    completion: obj
256                        .get("completion")
257                        .and_then(|v| v.as_str())
258                        .map(|s| s.to_string()),
259                    label: obj.get("label").and_then(|v| v.as_f64()).map(|f| f as f32),
260                    meta: obj
261                        .clone()
262                        .into_iter()
263                        .filter(|(k, _)| {
264                            !matches!(k.as_str(), "prompt" | "completion" | "label")
265                        })
266                        .map(|(k, v)| (k, v))
267                        .collect(),
268                };
269                samples.push(sample);
270            }
271        }
272        let format = TrainingFormat::Completion {
273            completion_field: "completion".to_string(),
274            label_field: Some("label".to_string()),
275        };
276        let statistics = Self::compute_statistics(&samples);
277        Ok(TrainingDataset {
278            samples,
279            format,
280            statistics,
281        })
282    }
283}
284impl CompletionProcessor {
285    fn compute_statistics(samples: &[TrainingSample]) -> DatasetStats {
286        let total_samples = samples.len();
287        let mut total_prompt_length = 0;
288        let mut total_completion_length = 0;
289        let mut prompt_count = 0;
290        let mut completion_count = 0;
291        let mut field_coverage = HashMap::new();
292        for sample in samples {
293            if sample.prompt.is_some() {
294                *field_coverage.entry("prompt".to_string()).or_insert(0.0) += 1.0;
295                total_prompt_length += sample.prompt.as_ref().unwrap().len();
296                prompt_count += 1;
297            }
298            if sample.completion.is_some() {
299                *field_coverage.entry("completion".to_string()).or_insert(0.0) += 1.0;
300                total_completion_length += sample.completion.as_ref().unwrap().len();
301                completion_count += 1;
302            }
303            if sample.label.is_some() {
304                *field_coverage.entry("label".to_string()).or_insert(0.0) += 1.0;
305            }
306        }
307        for count in field_coverage.values_mut() {
308            *count = *count / total_samples as f64;
309        }
310        let avg_prompt_length = if prompt_count > 0 {
311            total_prompt_length as f64 / prompt_count as f64
312        } else {
313            0.0
314        };
315        let avg_completion_length = if completion_count > 0 {
316            total_completion_length as f64 / completion_count as f64
317        } else {
318            0.0
319        };
320        let quality_score = Some(
321            (field_coverage.get("prompt").unwrap_or(&0.0)
322                + field_coverage.get("completion").unwrap_or(&0.0)) / 2.0,
323        );
324        DatasetStats {
325            total_samples,
326            avg_prompt_length,
327            avg_completion_length,
328            field_coverage,
329            quality_score,
330        }
331    }
332}
/// Processor for instruction-tuning data with `instruction`/`output` columns.
pub struct InstructionProcessor;
334#[async_trait::async_trait]
335impl DatasetProcessor for InstructionProcessor {
336    async fn process(&self, dataset: HuggingFaceDataset) -> Result<TrainingDataset> {
337        let mut samples = Vec::new();
338        for item in dataset.data {
339            if let Some(obj) = item.as_object() {
340                let sample = TrainingSample {
341                    prompt: obj
342                        .get("instruction")
343                        .and_then(|v| v.as_str())
344                        .map(|s| s.to_string()),
345                    chosen: None,
346                    rejected: None,
347                    completion: obj
348                        .get("output")
349                        .and_then(|v| v.as_str())
350                        .map(|s| s.to_string()),
351                    label: None,
352                    meta: obj
353                        .clone()
354                        .into_iter()
355                        .filter(|(k, _)| !matches!(k.as_str(), "instruction" | "output"))
356                        .map(|(k, v)| (k, v))
357                        .collect(),
358                };
359                samples.push(sample);
360            }
361        }
362        let format = TrainingFormat::Instruction {
363            instruction_field: "instruction".to_string(),
364            output_field: "output".to_string(),
365        };
366        let statistics = Self::compute_statistics(&samples);
367        Ok(TrainingDataset {
368            samples,
369            format,
370            statistics,
371        })
372    }
373}
374impl InstructionProcessor {
375    fn compute_statistics(samples: &[TrainingSample]) -> DatasetStats {
376        let total_samples = samples.len();
377        let mut total_prompt_length = 0;
378        let mut total_completion_length = 0;
379        let mut prompt_count = 0;
380        let mut completion_count = 0;
381        let mut field_coverage = HashMap::new();
382        for sample in samples {
383            if sample.prompt.is_some() {
384                *field_coverage.entry("instruction".to_string()).or_insert(0.0) += 1.0;
385                total_prompt_length += sample.prompt.as_ref().unwrap().len();
386                prompt_count += 1;
387            }
388            if sample.completion.is_some() {
389                *field_coverage.entry("output".to_string()).or_insert(0.0) += 1.0;
390                total_completion_length += sample.completion.as_ref().unwrap().len();
391                completion_count += 1;
392            }
393        }
394        for count in field_coverage.values_mut() {
395            *count = *count / total_samples as f64;
396        }
397        let avg_prompt_length = if prompt_count > 0 {
398            total_prompt_length as f64 / prompt_count as f64
399        } else {
400            0.0
401        };
402        let avg_completion_length = if completion_count > 0 {
403            total_completion_length as f64 / completion_count as f64
404        } else {
405            0.0
406        };
407        let quality_score = Some(
408            (field_coverage.get("instruction").unwrap_or(&0.0)
409                + field_coverage.get("output").unwrap_or(&0.0)) / 2.0,
410        );
411        DatasetStats {
412            total_samples,
413            avg_prompt_length,
414            avg_completion_length,
415            field_coverage,
416            quality_score,
417        }
418    }
419}
/// Orchestrates loading a HuggingFace dataset from a local cache and routing
/// it to the matching `DatasetProcessor` implementation.
pub struct HfProcessor {
    // Root directory of the local HuggingFace Hub cache.
    cache_dir: PathBuf,
    // Processors keyed by dataset type ("preference", "completion", "instruction").
    processors: HashMap<String, Box<dyn DatasetProcessor + Send + Sync>>,
}
424impl HfProcessor {
425    pub fn new(cache_dir: PathBuf) -> Self {
426        let mut processors: HashMap<String, Box<dyn DatasetProcessor + Send + Sync>> = HashMap::new();
427        processors.insert("preference".to_string(), Box::new(PreferenceProcessor));
428        processors.insert("completion".to_string(), Box::new(CompletionProcessor));
429        processors.insert("instruction".to_string(), Box::new(InstructionProcessor));
430        Self { cache_dir, processors }
431    }
432    pub async fn process_dataset(
433        &self,
434        dataset_name: &str,
435        config: &HfDatasetConfig,
436    ) -> Result<TrainingDataset> {
437        let dataset = HuggingFaceDataset::load(
438                dataset_name,
439                &config.split,
440                &self.cache_dir,
441            )
442            .await?;
443        let dataset_type = self.detect_dataset_type(&dataset)?;
444        let processor = self
445            .processors
446            .get(&dataset_type)
447            .ok_or_else(|| {
448                anyhow::anyhow!("No processor for dataset type: {}", dataset_type)
449            })?;
450        let mut processed = processor.process(dataset).await?;
451        if let Some(filters) = &config.rpl_filter {
452            processed = self.apply_filters(processed, filters)?;
453        }
454        Ok(processed)
455    }
456    fn detect_dataset_type(&self, dataset: &HuggingFaceDataset) -> Result<String> {
457        let features = dataset.get_features()?;
458        if features.contains(&"chosen".to_string())
459            && features.contains(&"rejected".to_string())
460        {
461            Ok("preference".to_string())
462        } else if features.contains(&"completion".to_string())
463            && features.contains(&"label".to_string())
464        {
465            Ok("completion".to_string())
466        } else if features.contains(&"instruction".to_string())
467            && features.contains(&"output".to_string())
468        {
469            Ok("instruction".to_string())
470        } else {
471            bail!("Cannot determine dataset type from features: {:?}", features)
472        }
473    }
474    fn apply_filters(
475        &self,
476        dataset: TrainingDataset,
477        _filters: &HashMap<String, serde_json::Value>,
478    ) -> Result<TrainingDataset> {
479        Ok(dataset)
480    }
481}
482impl Default for HfProcessor {
483    fn default() -> Self {
484        Self::new(PathBuf::from("./hf_cache"))
485    }
486}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `{ "dtype": ..., "_type": "Value" }` feature descriptor.
    fn value_feature(dtype: &str) -> serde_json::Value {
        serde_json::json!({ "dtype" : dtype, "_type" : "Value" })
    }

    /// Assembles an in-memory dataset from rows plus a (name, dtype) schema.
    fn dataset_with(
        rows: Vec<serde_json::Value>,
        schema: &[(&str, &str)],
    ) -> HuggingFaceDataset {
        let features = schema
            .iter()
            .map(|(name, dtype)| ((*name).to_string(), value_feature(dtype)))
            .collect();
        HuggingFaceDataset {
            name: "test".to_string(),
            split: "train".to_string(),
            data: rows,
            features,
            metadata: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_hf_processor_creation() {
        let processor = HfProcessor::default();
        for key in ["preference", "completion", "instruction"] {
            assert!(processor.processors.contains_key(key));
        }
    }

    #[tokio::test]
    async fn test_preference_processor() {
        let rows = vec![serde_json::json!({
            "prompt" : "Test prompt",
            "chosen" : "Good response",
            "rejected" : "Bad response"
        })];
        let dataset = dataset_with(
            rows,
            &[("prompt", "string"), ("chosen", "string"), ("rejected", "string")],
        );
        let result = PreferenceProcessor.process(dataset).await.unwrap();
        assert_eq!(result.samples.len(), 1);
        let sample = &result.samples[0];
        assert_eq!(sample.prompt.as_deref(), Some("Test prompt"));
        assert_eq!(sample.chosen.as_deref(), Some("Good response"));
        assert_eq!(sample.rejected.as_deref(), Some("Bad response"));
    }

    #[tokio::test]
    async fn test_completion_processor() {
        let rows = vec![serde_json::json!({
            "prompt" : "Test prompt",
            "completion" : "Test completion",
            "label" : 1.0
        })];
        let dataset = dataset_with(
            rows,
            &[("prompt", "string"), ("completion", "string"), ("label", "number")],
        );
        let result = CompletionProcessor.process(dataset).await.unwrap();
        assert_eq!(result.samples.len(), 1);
        let sample = &result.samples[0];
        assert_eq!(sample.prompt.as_deref(), Some("Test prompt"));
        assert_eq!(sample.completion.as_deref(), Some("Test completion"));
        assert_eq!(sample.label, Some(1.0));
    }
}
567}