1use crate::error::{Error, Result};
4
/// Identifiers accepted as entries of the `task_categories` front-matter field.
///
/// Membership checks against this table are exact and case-sensitive.
/// The order of the entries matters: the suggestion helpers scan the table
/// front to back and report the first matches they encounter.
pub const VALID_TASK_CATEGORIES: &[&str] = &[
    "text-classification", "token-classification",
    "table-question-answering", "question-answering",
    "zero-shot-classification",
    "translation", "summarization", "feature-extraction",
    "text-generation", "text2text-generation",
    "fill-mask", "sentence-similarity",
    "text-to-speech", "text-to-audio",
    "automatic-speech-recognition",
    "audio-to-audio", "audio-classification", "voice-activity-detection",
    "image-classification", "object-detection", "image-segmentation",
    "text-to-image", "image-to-text", "image-to-image", "image-to-video",
    "unconditional-image-generation",
    "video-classification",
    "reinforcement-learning", "robotics",
    "tabular-classification", "tabular-regression",
    "visual-question-answering", "document-question-answering",
    "zero-shot-image-classification",
    "graph-ml", "mask-generation", "zero-shot-object-detection",
    "text-to-3d", "image-to-3d", "image-feature-extraction",
    "other",
];
54
/// Identifiers accepted as entries of the `size_categories` front-matter field.
///
/// Entries are listed from the smallest bucket to the largest. Membership
/// checks against this table are exact and case-sensitive (`n<1K`, not `n<1k`).
pub const VALID_SIZE_CATEGORIES: &[&str] = &[
    "n<1K", "1K<n<10K", "10K<n<100K", "100K<n<1M",
    "1M<n<10M", "10M<n<100M", "100M<n<1B",
    "1B<n<10B", "10B<n<100B", "100B<n<1T",
    "n>1T",
];
69
/// License identifiers recognised by `DatasetCardValidator::is_valid_license`.
///
/// All entries are lowercase; the lookup lowercases its argument before
/// comparing, so callers may pass any casing.
pub const VALID_LICENSES: &[&str] = &[
    "apache-2.0", "mit",
    "gpl-3.0", "gpl-2.0",
    "bsd-3-clause", "bsd-2-clause",
    "cc-by-4.0", "cc-by-sa-4.0", "cc-by-nc-4.0", "cc-by-nc-sa-4.0", "cc0-1.0",
    "unlicense",
    "openrail", "openrail++", "bigscience-openrail-m", "creativeml-openrail-m",
    "llama2", "llama3", "llama3.1",
    "gemma",
    "other",
];
94
/// One rejected metadata value found while validating a dataset card.
///
/// The `Display` form is
/// `Invalid '<field>': '<value>' is not valid`, followed by
/// `. Did you mean: a, b?` when `suggestions` is non-empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the front-matter field that contained the rejected value.
    pub field: String,
    /// The rejected value, exactly as it appeared in the metadata.
    pub value: String,
    /// Accepted values that resemble `value`; may be empty.
    pub suggestions: Vec<String>,
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            field,
            value,
            suggestions,
        } = self;

        write!(out, "Invalid '{}': '{}' is not valid", field, value)?;
        match suggestions.as_slice() {
            // Nothing to recommend: the base message is the whole message.
            [] => Ok(()),
            hints => write!(out, ". Did you mean: {}?", hints.join(", ")),
        }
    }
}

// Lets the type be used with `?`, `Box<dyn Error>` and error-reporting crates.
impl std::error::Error for ValidationError {}
115
/// Stateless entry point for dataset-card metadata validation.
///
/// The type carries no data; its functionality is exposed through associated
/// functions such as `DatasetCardValidator::validate_readme`, so no instance
/// is needed to use it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct DatasetCardValidator;
138
139impl DatasetCardValidator {
140 pub fn validate_readme(content: &str) -> Vec<ValidationError> {
146 let mut errors = Vec::new();
147
148 let Some(yaml_content) = Self::extract_frontmatter(content) else {
150 return errors;
151 };
152
153 let Ok(yaml) = serde_yaml::from_str::<serde_yaml::Value>(&yaml_content) else {
155 return errors;
156 };
157
158 if let Some(categories) = yaml.get("task_categories") {
160 if let Some(arr) = categories.as_sequence() {
161 for cat in arr {
162 if let Some(cat_str) = cat.as_str() {
163 if !VALID_TASK_CATEGORIES.contains(&cat_str) {
164 errors.push(ValidationError {
165 field: "task_categories".to_string(),
166 value: cat_str.to_string(),
167 suggestions: Self::suggest_similar(cat_str, VALID_TASK_CATEGORIES),
168 });
169 }
170 }
171 }
172 }
173 }
174
175 if let Some(sizes) = yaml.get("size_categories") {
177 if let Some(arr) = sizes.as_sequence() {
178 for size in arr {
179 if let Some(size_str) = size.as_str() {
180 if !VALID_SIZE_CATEGORIES.contains(&size_str) {
181 errors.push(ValidationError {
182 field: "size_categories".to_string(),
183 value: size_str.to_string(),
184 suggestions: Self::suggest_similar(size_str, VALID_SIZE_CATEGORIES),
185 });
186 }
187 }
188 }
189 }
190 }
191
192 errors
193 }
194
195 pub fn validate_readme_strict(content: &str) -> Result<()> {
199 let errors = Self::validate_readme(content);
200 if errors.is_empty() {
201 Ok(())
202 } else {
203 let msg = errors
204 .iter()
205 .map(|e| e.to_string())
206 .collect::<Vec<_>>()
207 .join("; ");
208 Err(Error::invalid_config(msg))
209 }
210 }
211
/// Returns the text between the opening `---` and the closing `---` fence of
/// a document's front matter.
///
/// Leading whitespace before the opening fence is ignored. The returned text
/// starts right after the opening `---` (so it normally begins with a line
/// break) and ends just before the line break that precedes the closing fence.
///
/// The closing fence must be a line consisting of `---` and optional trailing
/// whitespace. A line that merely starts with three dashes (`----`,
/// `---foo`) is part of the content and does not end the front matter.
///
/// Returns `None` when the document does not start with `---` or no closing
/// fence is found.
fn extract_frontmatter(content: &str) -> Option<String> {
    const CLOSING_FENCE: &str = "\n---";

    let rest = content.trim_start().strip_prefix("---")?;

    let fence_start = rest
        .match_indices(CLOSING_FENCE)
        .map(|(idx, _)| idx)
        .find(|&idx| {
            // Whatever follows the dashes on the same line must be blank.
            let tail = rest[idx + CLOSING_FENCE.len()..]
                .lines()
                .next()
                .unwrap_or("");
            tail.trim().is_empty()
        })?;

    Some(rest[..fence_start].to_string())
}
223
224 pub(crate) fn suggest_similar(value: &str, valid: &[&str]) -> Vec<String> {
226 let value_lower = value.to_lowercase();
227 valid
228 .iter()
229 .filter(|v| {
230 let v_lower = v.to_lowercase();
231 v_lower.contains(&value_lower)
232 || value_lower.contains(&v_lower)
233 || Self::levenshtein(&value_lower, &v_lower) <= 3
234 })
235 .take(3)
236 .map(|s| (*s).to_string())
237 .collect()
238 }
239
/// Computes the Levenshtein edit distance between `a` and `b`.
///
/// The distance is the minimum number of single-character insertions,
/// deletions and substitutions needed to turn `a` into `b`. It is measured in
/// `char`s, not bytes, so a multi-byte character counts as one edit.
///
/// Only two rows of the dynamic-programming table are kept at any time, so
/// the extra memory is proportional to the length of `b` rather than to the
/// product of both lengths.
pub(crate) fn levenshtein(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();

    if a_chars.is_empty() {
        return b_chars.len();
    }
    if b_chars.is_empty() {
        return a_chars.len();
    }

    // `previous[j]` is the distance between the processed prefix of `a` and
    // the first `j` characters of `b`; initially the prefix of `a` is empty.
    let mut previous: Vec<usize> = (0..=b_chars.len()).collect();
    let mut current = vec![0; b_chars.len() + 1];

    for (i, &a_char) in a_chars.iter().enumerate() {
        // Turning `i + 1` characters into the empty string takes `i + 1` deletions.
        current[0] = i + 1;
        for (j, &b_char) in b_chars.iter().enumerate() {
            let substitution = previous[j] + usize::from(a_char != b_char);
            let deletion = previous[j + 1] + 1;
            let insertion = current[j] + 1;
            current[j + 1] = deletion.min(insertion).min(substitution);
        }
        std::mem::swap(&mut previous, &mut current);
    }

    // After the final swap the last completed row lives in `previous`.
    previous[b_chars.len()]
}
274
275 #[must_use]
277 pub fn is_valid_task_category(category: &str) -> bool {
278 VALID_TASK_CATEGORIES.contains(&category)
279 }
280
281 #[must_use]
283 pub fn is_valid_license(license: &str) -> bool {
284 let lower = license.to_lowercase();
285 VALID_LICENSES.contains(&lower.as_str())
286 }
287
288 #[must_use]
290 pub fn is_valid_size_category(size: &str) -> bool {
291 VALID_SIZE_CATEGORIES.contains(&size)
292 }
293
294 #[must_use]
296 pub fn suggest_task_category(invalid: &str) -> Option<&'static str> {
297 match invalid {
298 "text2text-generation" => Some("text2text-generation"), "code-generation" | "code" => Some("text-generation"),
300 "qa" | "QA" => Some("question-answering"),
301 "ner" | "NER" => Some("token-classification"),
302 "sentiment" => Some("text-classification"),
303 _ => VALID_TASK_CATEGORIES
304 .iter()
305 .find(|c| c.starts_with(invalid) || invalid.starts_with(*c))
306 .copied(),
307 }
308 }
309}