1#![warn(clippy::all, clippy::pedantic)]
2use anyhow::{bail, Context, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use super::core::{TrainingDataset, TrainingSample, DatasetStats, TrainingFormat};
/// Configuration describing how a HuggingFace dataset should be loaded and
/// post-processed. Deserialized from user/tool configuration via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HfDatasetConfig {
    // Dataset identifier; presumably a Hub repo id like "org/name" — not read
    // by the visible loader, which takes the name as a separate argument (TODO confirm).
    pub source: String,
    // Split label passed through to `HuggingFaceDataset::load` and stored on
    // the loaded dataset; rows are not actually filtered by split here.
    pub split: String,
    // Optional explicit format override; not referenced in this module.
    pub format: Option<String>,
    // Optional post-processing filters forwarded to `HfProcessor::apply_filters`
    // (currently a no-op there).
    pub rpl_filter: Option<HashMap<String, serde_json::Value>>,
    // Optional repo revision; not referenced in this module.
    pub revision: Option<String>,
    // Streaming flag; not referenced in this module.
    pub streaming: bool,
    // Whether to allow datasets that ship executable code; not referenced here.
    pub trust_remote_code: bool,
    // Parallelism hint; not referenced in this module.
    pub num_proc: Option<usize>,
}
/// An in-memory dataset read from the local HuggingFace Hub cache.
pub struct HuggingFaceDataset {
    // Dataset repository name as requested by the caller.
    pub name: String,
    // Split label stored verbatim; `load` does not filter rows by split.
    pub split: String,
    // Raw JSON rows, one `Value` per sample.
    pub data: Vec<serde_json::Value>,
    // Inferred schema: field name -> {"dtype": ..., "_type": "Value"} descriptor.
    pub features: HashMap<String, serde_json::Value>,
    // Repo-level metadata; currently placeholder values (see `extract_metadata`).
    pub metadata: HashMap<String, serde_json::Value>,
}
25impl HuggingFaceDataset {
26 pub async fn load(name: &str, split: &str, cache_dir: &Path) -> Result<Self> {
27 println!("🔄 Loading dataset {} from HuggingFace Hub...", name);
28 let cache = hf_hub::Cache::new(cache_dir.to_path_buf());
29 let repo = cache.dataset(name.to_string());
30 let json_files = Self::find_json_files(&repo, name).await?;
31 if json_files.is_empty() {
32 bail!("No JSON files found in dataset {}", name);
33 }
34 let data_file = &json_files[0];
35 println!("📁 Found data file: {}", data_file.display());
36 let data = Self::load_json_data(data_file).await?;
37 let features = Self::infer_features(&data)?;
38 let metadata = Self::extract_metadata(&repo).await?;
39 println!("✅ Successfully loaded {} samples from {}", data.len(), name);
40 Ok(HuggingFaceDataset {
41 name: name.to_string(),
42 split: split.to_string(),
43 data,
44 features,
45 metadata,
46 })
47 }
48 async fn find_json_files(
49 repo: &hf_hub::CacheRepo,
50 name: &str,
51 ) -> Result<Vec<PathBuf>> {
52 let mut json_files = Vec::new();
53 let possible_files = vec![
54 format!("{}.json", name.replace('/', "--")), "train.json".to_string(),
55 "validation.json".to_string(), "test.json".to_string(), "data.json"
56 .to_string(),
57 ];
58 for file_name in possible_files {
59 if let Some(path) = repo.get(&file_name) {
60 json_files.push(path);
61 }
62 }
63 if json_files.is_empty() {}
64 Ok(json_files)
65 }
66 async fn load_json_data(file_path: &Path) -> Result<Vec<serde_json::Value>> {
67 let content = tokio::fs::read_to_string(file_path)
68 .await
69 .with_context(|| {
70 format!("Failed to read JSON file: {}", file_path.display())
71 })?;
72 let json_value: serde_json::Value = serde_json::from_str(&content)
73 .with_context(|| {
74 format!("Failed to parse JSON from: {}", file_path.display())
75 })?;
76 match json_value {
77 serde_json::Value::Array(arr) => Ok(arr),
78 serde_json::Value::Object(obj) => {
79 if let Some(data) = obj.get("data") {
80 if let Some(arr) = data.as_array() {
81 return Ok(arr.clone());
82 }
83 }
84 if let Some(train) = obj.get("train") {
85 if let Some(arr) = train.as_array() {
86 return Ok(arr.clone());
87 }
88 }
89 Ok(vec![serde_json::Value::Object(obj)])
90 }
91 _ => bail!("Unsupported JSON structure in {}", file_path.display()),
92 }
93 }
94 fn infer_features(
95 data: &[serde_json::Value],
96 ) -> Result<HashMap<String, serde_json::Value>> {
97 let mut features = HashMap::new();
98 if data.is_empty() {
99 return Ok(features);
100 }
101 let sample_size = std::cmp::min(10, data.len());
102 let samples = &data[..sample_size];
103 let mut key_types = HashMap::new();
104 for sample in samples {
105 if let Some(obj) = sample.as_object() {
106 for (key, value) in obj {
107 let type_str = match value {
108 serde_json::Value::String(_) => "string",
109 serde_json::Value::Number(_) => "number",
110 serde_json::Value::Bool(_) => "boolean",
111 serde_json::Value::Array(_) => "array",
112 serde_json::Value::Object(_) => "object",
113 serde_json::Value::Null => "null",
114 };
115 key_types.insert(key.clone(), type_str.to_string());
116 }
117 }
118 }
119 for (key, type_str) in key_types {
120 features
121 .insert(
122 key,
123 serde_json::json!({ "dtype" : type_str, "_type" : "Value" }),
124 );
125 }
126 Ok(features)
127 }
128 async fn extract_metadata(
129 _repo: &hf_hub::CacheRepo,
130 ) -> Result<HashMap<String, serde_json::Value>> {
131 let mut metadata = HashMap::new();
132 metadata.insert("dataset_name".to_string(), serde_json::json!("unknown"));
133 metadata.insert("split".to_string(), serde_json::json!("unknown"));
134 metadata.insert("num_samples".to_string(), serde_json::json!(0));
135 Ok(metadata)
136 }
137 pub fn get_features(&self) -> Result<Vec<String>> {
138 Ok(self.features.keys().map(|s| s.to_string()).collect())
139 }
140 pub fn info(&self) -> HashMap<String, serde_json::Value> {
141 let mut info = HashMap::new();
142 info.insert("name".to_string(), serde_json::json!(self.name));
143 info.insert("split".to_string(), serde_json::json!(self.split));
144 info.insert("num_samples".to_string(), serde_json::json!(self.data.len()));
145 info.insert("features".to_string(), serde_json::json!(self.features));
146 info.insert("metadata".to_string(), serde_json::json!(self.metadata));
147 info
148 }
149}
/// Converts a raw `HuggingFaceDataset` into a normalized `TrainingDataset`.
#[async_trait::async_trait]
pub trait DatasetProcessor {
    /// Consumes `dataset` and produces training samples plus statistics.
    async fn process(&self, dataset: HuggingFaceDataset) -> Result<TrainingDataset>;
}
/// Processor for preference datasets (`prompt` / `chosen` / `rejected` fields).
pub struct PreferenceProcessor;
155#[async_trait::async_trait]
156impl DatasetProcessor for PreferenceProcessor {
157 async fn process(&self, dataset: HuggingFaceDataset) -> Result<TrainingDataset> {
158 let mut samples = Vec::new();
159 for item in dataset.data {
160 if let Some(obj) = item.as_object() {
161 let sample = TrainingSample {
162 prompt: obj
163 .get("prompt")
164 .and_then(|v| v.as_str())
165 .map(|s| s.to_string()),
166 chosen: obj
167 .get("chosen")
168 .and_then(|v| v.as_str())
169 .map(|s| s.to_string()),
170 rejected: obj
171 .get("rejected")
172 .and_then(|v| v.as_str())
173 .map(|s| s.to_string()),
174 completion: None,
175 label: None,
176 meta: obj
177 .clone()
178 .into_iter()
179 .filter(|(k, _)| {
180 !matches!(k.as_str(), "prompt" | "chosen" | "rejected")
181 })
182 .map(|(k, v)| (k, v))
183 .collect(),
184 };
185 samples.push(sample);
186 }
187 }
188 let format = TrainingFormat::Preference {
189 chosen_field: "chosen".to_string(),
190 rejected_field: "rejected".to_string(),
191 };
192 let statistics = Self::compute_statistics(&samples);
193 Ok(TrainingDataset {
194 samples,
195 format,
196 statistics,
197 })
198 }
199}
200impl PreferenceProcessor {
201 fn compute_statistics(samples: &[TrainingSample]) -> DatasetStats {
202 let total_samples = samples.len();
203 let mut total_prompt_length = 0;
204 let mut prompt_count = 0;
205 let mut field_coverage = HashMap::new();
206 for sample in samples {
207 if sample.prompt.is_some() {
208 *field_coverage.entry("prompt".to_string()).or_insert(0.0) += 1.0;
209 total_prompt_length += sample.prompt.as_ref().unwrap().len();
210 prompt_count += 1;
211 }
212 if sample.chosen.is_some() {
213 *field_coverage.entry("chosen".to_string()).or_insert(0.0) += 1.0;
214 }
215 if sample.rejected.is_some() {
216 *field_coverage.entry("rejected".to_string()).or_insert(0.0) += 1.0;
217 }
218 }
219 for count in field_coverage.values_mut() {
220 *count = *count / total_samples as f64;
221 }
222 let avg_prompt_length = if prompt_count > 0 {
223 total_prompt_length as f64 / prompt_count as f64
224 } else {
225 0.0
226 };
227 let quality_score = Some(
228 (field_coverage.get("prompt").unwrap_or(&0.0)
229 + field_coverage.get("chosen").unwrap_or(&0.0)
230 + field_coverage.get("rejected").unwrap_or(&0.0)) / 3.0,
231 );
232 DatasetStats {
233 total_samples,
234 avg_prompt_length,
235 avg_completion_length: 0.0,
236 field_coverage,
237 quality_score,
238 }
239 }
240}
/// Processor for completion datasets (`prompt` / `completion` / optional `label`).
pub struct CompletionProcessor;
242#[async_trait::async_trait]
243impl DatasetProcessor for CompletionProcessor {
244 async fn process(&self, dataset: HuggingFaceDataset) -> Result<TrainingDataset> {
245 let mut samples = Vec::new();
246 for item in dataset.data {
247 if let Some(obj) = item.as_object() {
248 let sample = TrainingSample {
249 prompt: obj
250 .get("prompt")
251 .and_then(|v| v.as_str())
252 .map(|s| s.to_string()),
253 chosen: None,
254 rejected: None,
255 completion: obj
256 .get("completion")
257 .and_then(|v| v.as_str())
258 .map(|s| s.to_string()),
259 label: obj.get("label").and_then(|v| v.as_f64()).map(|f| f as f32),
260 meta: obj
261 .clone()
262 .into_iter()
263 .filter(|(k, _)| {
264 !matches!(k.as_str(), "prompt" | "completion" | "label")
265 })
266 .map(|(k, v)| (k, v))
267 .collect(),
268 };
269 samples.push(sample);
270 }
271 }
272 let format = TrainingFormat::Completion {
273 completion_field: "completion".to_string(),
274 label_field: Some("label".to_string()),
275 };
276 let statistics = Self::compute_statistics(&samples);
277 Ok(TrainingDataset {
278 samples,
279 format,
280 statistics,
281 })
282 }
283}
284impl CompletionProcessor {
285 fn compute_statistics(samples: &[TrainingSample]) -> DatasetStats {
286 let total_samples = samples.len();
287 let mut total_prompt_length = 0;
288 let mut total_completion_length = 0;
289 let mut prompt_count = 0;
290 let mut completion_count = 0;
291 let mut field_coverage = HashMap::new();
292 for sample in samples {
293 if sample.prompt.is_some() {
294 *field_coverage.entry("prompt".to_string()).or_insert(0.0) += 1.0;
295 total_prompt_length += sample.prompt.as_ref().unwrap().len();
296 prompt_count += 1;
297 }
298 if sample.completion.is_some() {
299 *field_coverage.entry("completion".to_string()).or_insert(0.0) += 1.0;
300 total_completion_length += sample.completion.as_ref().unwrap().len();
301 completion_count += 1;
302 }
303 if sample.label.is_some() {
304 *field_coverage.entry("label".to_string()).or_insert(0.0) += 1.0;
305 }
306 }
307 for count in field_coverage.values_mut() {
308 *count = *count / total_samples as f64;
309 }
310 let avg_prompt_length = if prompt_count > 0 {
311 total_prompt_length as f64 / prompt_count as f64
312 } else {
313 0.0
314 };
315 let avg_completion_length = if completion_count > 0 {
316 total_completion_length as f64 / completion_count as f64
317 } else {
318 0.0
319 };
320 let quality_score = Some(
321 (field_coverage.get("prompt").unwrap_or(&0.0)
322 + field_coverage.get("completion").unwrap_or(&0.0)) / 2.0,
323 );
324 DatasetStats {
325 total_samples,
326 avg_prompt_length,
327 avg_completion_length,
328 field_coverage,
329 quality_score,
330 }
331 }
332}
/// Processor for instruction datasets (`instruction` / `output` fields).
pub struct InstructionProcessor;
334#[async_trait::async_trait]
335impl DatasetProcessor for InstructionProcessor {
336 async fn process(&self, dataset: HuggingFaceDataset) -> Result<TrainingDataset> {
337 let mut samples = Vec::new();
338 for item in dataset.data {
339 if let Some(obj) = item.as_object() {
340 let sample = TrainingSample {
341 prompt: obj
342 .get("instruction")
343 .and_then(|v| v.as_str())
344 .map(|s| s.to_string()),
345 chosen: None,
346 rejected: None,
347 completion: obj
348 .get("output")
349 .and_then(|v| v.as_str())
350 .map(|s| s.to_string()),
351 label: None,
352 meta: obj
353 .clone()
354 .into_iter()
355 .filter(|(k, _)| !matches!(k.as_str(), "instruction" | "output"))
356 .map(|(k, v)| (k, v))
357 .collect(),
358 };
359 samples.push(sample);
360 }
361 }
362 let format = TrainingFormat::Instruction {
363 instruction_field: "instruction".to_string(),
364 output_field: "output".to_string(),
365 };
366 let statistics = Self::compute_statistics(&samples);
367 Ok(TrainingDataset {
368 samples,
369 format,
370 statistics,
371 })
372 }
373}
374impl InstructionProcessor {
375 fn compute_statistics(samples: &[TrainingSample]) -> DatasetStats {
376 let total_samples = samples.len();
377 let mut total_prompt_length = 0;
378 let mut total_completion_length = 0;
379 let mut prompt_count = 0;
380 let mut completion_count = 0;
381 let mut field_coverage = HashMap::new();
382 for sample in samples {
383 if sample.prompt.is_some() {
384 *field_coverage.entry("instruction".to_string()).or_insert(0.0) += 1.0;
385 total_prompt_length += sample.prompt.as_ref().unwrap().len();
386 prompt_count += 1;
387 }
388 if sample.completion.is_some() {
389 *field_coverage.entry("output".to_string()).or_insert(0.0) += 1.0;
390 total_completion_length += sample.completion.as_ref().unwrap().len();
391 completion_count += 1;
392 }
393 }
394 for count in field_coverage.values_mut() {
395 *count = *count / total_samples as f64;
396 }
397 let avg_prompt_length = if prompt_count > 0 {
398 total_prompt_length as f64 / prompt_count as f64
399 } else {
400 0.0
401 };
402 let avg_completion_length = if completion_count > 0 {
403 total_completion_length as f64 / completion_count as f64
404 } else {
405 0.0
406 };
407 let quality_score = Some(
408 (field_coverage.get("instruction").unwrap_or(&0.0)
409 + field_coverage.get("output").unwrap_or(&0.0)) / 2.0,
410 );
411 DatasetStats {
412 total_samples,
413 avg_prompt_length,
414 avg_completion_length,
415 field_coverage,
416 quality_score,
417 }
418 }
419}
/// Facade that loads a dataset from the HF cache, detects its type from the
/// feature names, and dispatches to the matching `DatasetProcessor`.
pub struct HfProcessor {
    // Local cache directory handed to `HuggingFaceDataset::load`.
    cache_dir: PathBuf,
    // Registered processors keyed by detected dataset type
    // ("preference" | "completion" | "instruction").
    processors: HashMap<String, Box<dyn DatasetProcessor + Send + Sync>>,
}
424impl HfProcessor {
425 pub fn new(cache_dir: PathBuf) -> Self {
426 let mut processors: HashMap<String, Box<dyn DatasetProcessor + Send + Sync>> = HashMap::new();
427 processors.insert("preference".to_string(), Box::new(PreferenceProcessor));
428 processors.insert("completion".to_string(), Box::new(CompletionProcessor));
429 processors.insert("instruction".to_string(), Box::new(InstructionProcessor));
430 Self { cache_dir, processors }
431 }
432 pub async fn process_dataset(
433 &self,
434 dataset_name: &str,
435 config: &HfDatasetConfig,
436 ) -> Result<TrainingDataset> {
437 let dataset = HuggingFaceDataset::load(
438 dataset_name,
439 &config.split,
440 &self.cache_dir,
441 )
442 .await?;
443 let dataset_type = self.detect_dataset_type(&dataset)?;
444 let processor = self
445 .processors
446 .get(&dataset_type)
447 .ok_or_else(|| {
448 anyhow::anyhow!("No processor for dataset type: {}", dataset_type)
449 })?;
450 let mut processed = processor.process(dataset).await?;
451 if let Some(filters) = &config.rpl_filter {
452 processed = self.apply_filters(processed, filters)?;
453 }
454 Ok(processed)
455 }
456 fn detect_dataset_type(&self, dataset: &HuggingFaceDataset) -> Result<String> {
457 let features = dataset.get_features()?;
458 if features.contains(&"chosen".to_string())
459 && features.contains(&"rejected".to_string())
460 {
461 Ok("preference".to_string())
462 } else if features.contains(&"completion".to_string())
463 && features.contains(&"label".to_string())
464 {
465 Ok("completion".to_string())
466 } else if features.contains(&"instruction".to_string())
467 && features.contains(&"output".to_string())
468 {
469 Ok("instruction".to_string())
470 } else {
471 bail!("Cannot determine dataset type from features: {:?}", features)
472 }
473 }
474 fn apply_filters(
475 &self,
476 dataset: TrainingDataset,
477 _filters: &HashMap<String, serde_json::Value>,
478 ) -> Result<TrainingDataset> {
479 Ok(dataset)
480 }
481}
impl Default for HfProcessor {
    /// Builds a processor using `./hf_cache` (relative to the working
    /// directory) as the cache location.
    fn default() -> Self {
        Self::new(PathBuf::from("./hf_cache"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a HuggingFace-style feature descriptor for the given dtype.
    fn value_feature(dtype: &str) -> serde_json::Value {
        serde_json::json!({ "dtype" : dtype, "_type" : "Value" })
    }

    /// Builds a feature map from (name, dtype) pairs.
    fn feature_map(fields: &[(&str, &str)]) -> HashMap<String, serde_json::Value> {
        fields
            .iter()
            .map(|(name, dtype)| (name.to_string(), value_feature(dtype)))
            .collect()
    }

    #[tokio::test]
    async fn test_hf_processor_creation() {
        // The default processor must have all three built-in handlers.
        let processor = HfProcessor::default();
        for key in ["preference", "completion", "instruction"] {
            assert!(processor.processors.contains_key(key));
        }
    }

    #[tokio::test]
    async fn test_preference_processor() {
        let dataset = HuggingFaceDataset {
            name: "test".to_string(),
            split: "train".to_string(),
            data: vec![serde_json::json!({
                "prompt" : "Test prompt",
                "chosen" : "Good response",
                "rejected" : "Bad response"
            })],
            features: feature_map(&[
                ("prompt", "string"),
                ("chosen", "string"),
                ("rejected", "string"),
            ]),
            metadata: HashMap::new(),
        };
        let result = PreferenceProcessor.process(dataset).await.unwrap();
        assert_eq!(result.samples.len(), 1);
        let sample = &result.samples[0];
        assert_eq!(sample.prompt.as_ref().unwrap(), "Test prompt");
        assert_eq!(sample.chosen.as_ref().unwrap(), "Good response");
        assert_eq!(sample.rejected.as_ref().unwrap(), "Bad response");
    }

    #[tokio::test]
    async fn test_completion_processor() {
        let dataset = HuggingFaceDataset {
            name: "test".to_string(),
            split: "train".to_string(),
            data: vec![serde_json::json!({
                "prompt" : "Test prompt",
                "completion" : "Test completion",
                "label" : 1.0
            })],
            features: feature_map(&[
                ("prompt", "string"),
                ("completion", "string"),
                ("label", "number"),
            ]),
            metadata: HashMap::new(),
        };
        let result = CompletionProcessor.process(dataset).await.unwrap();
        assert_eq!(result.samples.len(), 1);
        let sample = &result.samples[0];
        assert_eq!(sample.prompt.as_ref().unwrap(), "Test prompt");
        assert_eq!(sample.completion.as_ref().unwrap(), "Test completion");
        assert_eq!(sample.label.unwrap(), 1.0);
    }
}