Skip to main content

alimentar/hf_hub/
validation.rs

1//! Dataset card validation for HuggingFace Hub.
2
3use crate::error::{Error, Result};
4
/// Task identifiers accepted in the `task_categories` field of a HuggingFace
/// dataset card.
///
/// Membership checks against this list are exact and case-sensitive. The list
/// is a hand-maintained snapshot, not a live copy of the Hub's registry, so it
/// needs extending whenever the Hub introduces new tasks.
///
/// Entry order matters: the suggestion helpers in this module scan the slice
/// front to back and report the first matches. Add new identifiers to the
/// trailing "later additions" section rather than splicing them into the
/// middle, so existing suggestions stay stable.
///
/// Source: <https://huggingface.co/docs/hub/datasets-cards#task-categories>
pub const VALID_TASK_CATEGORIES: &[&str] = &[
    // Natural language processing
    "text-classification",
    "token-classification",
    "table-question-answering",
    "question-answering",
    "zero-shot-classification",
    "translation",
    "summarization",
    "feature-extraction",
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "sentence-similarity",
    // Audio
    "text-to-speech",
    "text-to-audio",
    "automatic-speech-recognition",
    "audio-to-audio",
    "audio-classification",
    "voice-activity-detection",
    // Computer vision
    "image-classification",
    "object-detection",
    "image-segmentation",
    "text-to-image",
    "image-to-text",
    "image-to-image",
    "image-to-video",
    "unconditional-image-generation",
    "video-classification",
    // Reinforcement learning and robotics
    "reinforcement-learning",
    "robotics",
    // Tabular
    "tabular-classification",
    "tabular-regression",
    // Multimodal
    "visual-question-answering",
    "document-question-answering",
    "zero-shot-image-classification",
    // Graph
    "graph-ml",
    // Further vision and 3D tasks
    "mask-generation",
    "zero-shot-object-detection",
    "text-to-3d",
    "image-to-3d",
    "image-feature-extraction",
    // Later additions: Hub tasks that were missing from the original snapshot.
    // Kept together here so the order of the entries above is untouched.
    "depth-estimation",
    "tabular-to-text",
    "table-to-text",
    "multiple-choice",
    "text-retrieval",
    "time-series-forecasting",
    "text-to-video",
    // Catch-all (kept last)
    "other",
];
54
/// Bucket labels accepted in the `size_categories` field of a HuggingFace
/// dataset card.
///
/// Labels are listed from the smallest bucket to the largest. Apart from the
/// two open-ended ends (`n<1K` and `n>1T`), each label has the form
/// `LOW<n<HIGH`. Membership checks are exact and case-sensitive.
pub const VALID_SIZE_CATEGORIES: &[&str] = &[
    "n<1K", // open-ended lower bucket
    "1K<n<10K",
    "10K<n<100K",
    "100K<n<1M",
    "1M<n<10M",
    "10M<n<100M",
    "100M<n<1B",
    "1B<n<10B",
    "10B<n<100B",
    "100B<n<1T",
    "n>1T", // open-ended upper bucket
];
69
/// License identifiers commonly used on HuggingFace dataset cards.
///
/// The list mixes SPDX-style identifiers with identifiers that are not SPDX
/// ids (the OpenRAIL family, model-specific licenses and the catch-all
/// `other`). It is a convenience subset, not an exhaustive registry.
///
/// All entries are lowercase; callers that want case-insensitive matching
/// should lowercase their input before comparing.
pub const VALID_LICENSES: &[&str] = &[
    // General software licenses
    "apache-2.0",
    "mit",
    "gpl-3.0",
    "gpl-2.0",
    "bsd-3-clause",
    "bsd-2-clause",
    // Creative Commons
    "cc-by-4.0",
    "cc-by-sa-4.0",
    "cc-by-nc-4.0",
    "cc-by-nc-sa-4.0",
    // Public-domain dedications
    "cc0-1.0",
    "unlicense",
    // OpenRAIL family
    "openrail",
    "openrail++",
    "bigscience-openrail-m",
    "creativeml-openrail-m",
    // Model-specific licenses
    "llama2",
    "llama3",
    "llama3.1",
    "gemma",
    // Catch-all
    "other",
];
94
/// A single invalid value found in dataset card metadata.
///
/// The `Display` rendering names the offending field and value and, when
/// `suggestions` is non-empty, appends them as a "Did you mean" hint.
///
/// The type implements [`std::error::Error`], so it can be propagated with `?`
/// or boxed as a trait object, and `PartialEq`/`Eq`, so results can be
/// compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the metadata field that holds the rejected value.
    pub field: String,
    /// The rejected value.
    pub value: String,
    /// Accepted values resembling `value`; empty when there are none.
    pub suggestions: Vec<String>,
}

impl std::fmt::Display for ValidationError {
    /// Writes `Invalid '<field>': '<value>' is not valid`, followed by
    /// `. Did you mean: a, b?` when suggestions are available.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Invalid '{}': '{}' is not valid", self.field, self.value)?;
        if self.suggestions.is_empty() {
            return Ok(());
        }
        write!(f, ". Did you mean: {}?", self.suggestions.join(", "))
    }
}

// `Debug` + `Display` are all the trait requires; there is no underlying
// source error to expose.
impl std::error::Error for ValidationError {}
115
/// Stateless checker for the YAML metadata block of a HuggingFace dataset
/// card (`README.md`).
///
/// The type carries no data; all functionality is exposed through associated
/// functions, so no instance has to be constructed.
///
/// # Examples
///
/// ```
/// use alimentar::hf_hub::DatasetCardValidator;
///
/// let readme = r"---
/// license: mit
/// task_categories:
///   - translation
/// ---
/// # My Dataset
/// ";
///
/// let errors = DatasetCardValidator::validate_readme(readme);
/// assert!(errors.is_empty());
/// ```
#[derive(Debug, Default)]
pub struct DatasetCardValidator;
138
139impl DatasetCardValidator {
140    /// Validates a README.md content and returns any validation errors.
141    ///
142    /// Parses the YAML frontmatter and validates:
143    /// - `task_categories`: Must be from the official HuggingFace list
144    /// - `size_categories`: Must match the HuggingFace format
145    pub fn validate_readme(content: &str) -> Vec<ValidationError> {
146        let mut errors = Vec::new();
147
148        // Extract YAML frontmatter (between --- markers)
149        let Some(yaml_content) = Self::extract_frontmatter(content) else {
150            return errors;
151        };
152
153        // Parse YAML
154        let Ok(yaml) = serde_yaml::from_str::<serde_yaml::Value>(&yaml_content) else {
155            return errors;
156        };
157
158        // Validate task_categories
159        if let Some(categories) = yaml.get("task_categories") {
160            if let Some(arr) = categories.as_sequence() {
161                for cat in arr {
162                    if let Some(cat_str) = cat.as_str() {
163                        if !VALID_TASK_CATEGORIES.contains(&cat_str) {
164                            errors.push(ValidationError {
165                                field: "task_categories".to_string(),
166                                value: cat_str.to_string(),
167                                suggestions: Self::suggest_similar(cat_str, VALID_TASK_CATEGORIES),
168                            });
169                        }
170                    }
171                }
172            }
173        }
174
175        // Validate size_categories
176        if let Some(sizes) = yaml.get("size_categories") {
177            if let Some(arr) = sizes.as_sequence() {
178                for size in arr {
179                    if let Some(size_str) = size.as_str() {
180                        if !VALID_SIZE_CATEGORIES.contains(&size_str) {
181                            errors.push(ValidationError {
182                                field: "size_categories".to_string(),
183                                value: size_str.to_string(),
184                                suggestions: Self::suggest_similar(size_str, VALID_SIZE_CATEGORIES),
185                            });
186                        }
187                    }
188                }
189            }
190        }
191
192        errors
193    }
194
195    /// Validates a README file and returns a Result.
196    ///
197    /// Returns Ok(()) if valid, or Err with combined error messages.
198    pub fn validate_readme_strict(content: &str) -> Result<()> {
199        let errors = Self::validate_readme(content);
200        if errors.is_empty() {
201            Ok(())
202        } else {
203            let msg = errors
204                .iter()
205                .map(|e| e.to_string())
206                .collect::<Vec<_>>()
207                .join("; ");
208            Err(Error::invalid_config(msg))
209        }
210    }
211
212    /// Extracts YAML frontmatter from markdown content.
213    fn extract_frontmatter(content: &str) -> Option<String> {
214        let content = content.trim_start();
215        if !content.starts_with("---") {
216            return None;
217        }
218
219        let rest = &content[3..];
220        let end_idx = rest.find("\n---")?;
221        Some(rest[..end_idx].to_string())
222    }
223
224    /// Suggests similar valid values using simple substring matching.
225    pub(crate) fn suggest_similar(value: &str, valid: &[&str]) -> Vec<String> {
226        let value_lower = value.to_lowercase();
227        valid
228            .iter()
229            .filter(|v| {
230                let v_lower = v.to_lowercase();
231                v_lower.contains(&value_lower)
232                    || value_lower.contains(&v_lower)
233                    || Self::levenshtein(&value_lower, &v_lower) <= 3
234            })
235            .take(3)
236            .map(|s| (*s).to_string())
237            .collect()
238    }
239
240    /// Simple Levenshtein distance for fuzzy matching.
241    pub(crate) fn levenshtein(a: &str, b: &str) -> usize {
242        let a_chars: Vec<char> = a.chars().collect();
243        let b_chars: Vec<char> = b.chars().collect();
244        let m = a_chars.len();
245        let n = b_chars.len();
246
247        if m == 0 {
248            return n;
249        }
250        if n == 0 {
251            return m;
252        }
253
254        let mut dp = vec![vec![0; n + 1]; m + 1];
255
256        for (i, row) in dp.iter_mut().enumerate().take(m + 1) {
257            row[0] = i;
258        }
259        for (j, cell) in dp[0].iter_mut().enumerate().take(n + 1) {
260            *cell = j;
261        }
262
263        for i in 1..=m {
264            for j in 1..=n {
265                let cost = usize::from(a_chars[i - 1] != b_chars[j - 1]);
266                dp[i][j] = (dp[i - 1][j] + 1)
267                    .min(dp[i][j - 1] + 1)
268                    .min(dp[i - 1][j - 1] + cost);
269            }
270        }
271
272        dp[m][n]
273    }
274
275    /// Check if a task category is valid.
276    #[must_use]
277    pub fn is_valid_task_category(category: &str) -> bool {
278        VALID_TASK_CATEGORIES.contains(&category)
279    }
280
281    /// Check if a license is valid (case-insensitive).
282    #[must_use]
283    pub fn is_valid_license(license: &str) -> bool {
284        let lower = license.to_lowercase();
285        VALID_LICENSES.contains(&lower.as_str())
286    }
287
288    /// Check if a size category is valid.
289    #[must_use]
290    pub fn is_valid_size_category(size: &str) -> bool {
291        VALID_SIZE_CATEGORIES.contains(&size)
292    }
293
294    /// Suggest a similar valid task category for common mistakes.
295    #[must_use]
296    pub fn suggest_task_category(invalid: &str) -> Option<&'static str> {
297        match invalid {
298            "text2text-generation" => Some("text2text-generation"), // This is actually valid
299            "code-generation" | "code" => Some("text-generation"),
300            "qa" | "QA" => Some("question-answering"),
301            "ner" | "NER" => Some("token-classification"),
302            "sentiment" => Some("text-classification"),
303            _ => VALID_TASK_CATEGORIES
304                .iter()
305                .find(|c| c.starts_with(invalid) || invalid.starts_with(*c))
306                .copied(),
307        }
308    }
309}