use std::io;
use std::path::Path;
/// Errors produced while loading and parsing Hugging Face dataset cards.
#[derive(Debug)]
pub enum HfError {
    /// Underlying filesystem / IO failure (e.g. README.md not readable).
    Io(io::Error),
    /// Input text could not be parsed; the payload describes the problem.
    Parse(String),
    /// A required field or section was absent from the card.
    MissingField(&'static str),
}
impl std::fmt::Display for HfError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HfError::Io(e) => write!(f, "IO error: {e}"),
HfError::Parse(msg) => write!(f, "parse error: {msg}"),
HfError::MissingField(field) => write!(f, "missing field: {field}"),
}
}
}
impl std::error::Error for HfError {
    /// The underlying cause: only the `Io` variant wraps another error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let HfError::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl From<io::Error> for HfError {
fn from(e: io::Error) -> Self {
HfError::Io(e)
}
}
/// Row and byte statistics for a single dataset split.
#[derive(Debug, Clone, PartialEq)]
pub struct HfSplitInfo {
    /// Split identifier, e.g. "train" or "validation".
    pub name: String,
    /// Number of examples in the split.
    pub num_rows: usize,
    /// Size of the split in bytes.
    pub num_bytes: usize,
}
/// Metadata parsed from (or written to) a Hugging Face dataset card's
/// YAML frontmatter. All list fields default to empty; optional scalars
/// default to `None`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct HfDatasetCard {
    /// Value of the `dataset_name` key (empty string when absent).
    pub dataset_name: String,
    /// Values of the `task_categories` list.
    pub task_categories: Vec<String>,
    /// Values of the `language` list (language codes, e.g. "en").
    pub language: Vec<String>,
    /// Values of the `size_categories` list (e.g. "100K<n<1M").
    pub size_categories: Vec<String>,
    /// Optional `license` identifier.
    pub license: Option<String>,
    /// Optional human-friendly `pretty_name`.
    pub pretty_name: Option<String>,
    /// Splits parsed from the `splits:` section.
    pub splits: Vec<HfSplitInfo>,
}
/// Trim a YAML scalar and drop one matching pair of surrounding quotes
/// (single or double). Unmatched or lone quote characters are kept as-is.
fn parse_scalar(s: &str) -> String {
    let trimmed = s.trim();
    for quote in ['"', '\''] {
        // strip_prefix + strip_suffix only both succeed when the string is at
        // least two chars long and wrapped in the same quote character.
        if let Some(inner) = trimmed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
        {
            return inner.to_owned();
        }
    }
    trimmed.to_owned()
}
/// Number of leading space characters (tabs and other whitespace count as 0).
fn indent_of(line: &str) -> usize {
    line.chars().take_while(|&c| c == ' ').count()
}
/// Byte index of the first `:` that is not inside a single- or double-quoted
/// run. Quote characters of the other kind are ignored while inside a quote.
fn find_colon(s: &str) -> Option<usize> {
    let mut open_quote: Option<char> = None;
    for (idx, ch) in s.char_indices() {
        match open_quote {
            // Inside a quoted run: only the matching quote closes it.
            Some(q) if ch == q => open_quote = None,
            Some(_) => {}
            None => match ch {
                '\'' | '"' => open_quote = Some(ch),
                ':' => return Some(idx),
                _ => {}
            },
        }
    }
    None
}
/// Parse the flat HF-card YAML subset into `(key, values)` pairs.
///
/// Recognized top-level shapes:
/// - `key: value`                -> `(key, [value])`
/// - `key: [a, b, c]`            -> `(key, [a, b, c])`; `[]` yields no items
/// - `key:` followed by `- item` -> `(key, [items...])`
///
/// Blank lines, `#` comments, and `---` document markers are skipped.
/// Nested (indented) mappings are ignored. Block-list items may be indented
/// or sit at column 0 — both are valid YAML and HF cards use both styles.
fn parse_hf_yaml(yaml: &str) -> Vec<(String, Vec<String>)> {
    let mut result: Vec<(String, Vec<String>)> = Vec::new();
    let lines: Vec<&str> = yaml.lines().collect();
    let mut i = 0;
    while i < lines.len() {
        let line = lines[i];
        let trimmed = line.trim();
        if trimmed.is_empty() || trimmed.starts_with('#') || trimmed == "---" {
            i += 1;
            continue;
        }
        // Only top-level keys are recognized; indented content is skipped.
        if indent_of(line) != 0 {
            i += 1;
            continue;
        }
        if let Some(colon) = find_colon(line) {
            let key = line[..colon].trim().to_owned();
            let rest = line[colon + 1..].trim();
            if rest.is_empty() {
                // Block-style list: collect subsequent `- item` lines.
                i += 1;
                let mut items: Vec<String> = Vec::new();
                while i < lines.len() {
                    let sub = lines[i];
                    let sub_trimmed = sub.trim();
                    let is_item = sub_trimmed.starts_with("- ") || sub_trimmed == "-";
                    // A non-item line back at column 0 ends the list. Items at
                    // column 0 stay in the list (zero-indent lists are valid
                    // YAML and common in HF cards).
                    if !sub_trimmed.is_empty()
                        && !sub_trimmed.starts_with('#')
                        && indent_of(sub) == 0
                        && !is_item
                    {
                        break;
                    }
                    if let Some(item) = sub_trimmed.strip_prefix("- ") {
                        items.push(parse_scalar(item));
                    } else if sub_trimmed == "-" {
                        // Bare dash: an explicitly empty item.
                        items.push(String::new());
                    }
                    i += 1;
                }
                result.push((key, items));
                continue;
            } else if rest.starts_with('[') && rest.ends_with(']') {
                // Flow-style list. An empty `[]` must yield no items — a plain
                // `"".split(',')` would wrongly produce one empty string.
                let inner = rest[1..rest.len() - 1].trim();
                let items: Vec<String> = if inner.is_empty() {
                    Vec::new()
                } else {
                    inner.split(',').map(parse_scalar).collect()
                };
                result.push((key, items));
            } else {
                result.push((key, vec![parse_scalar(rest)]));
            }
        }
        i += 1;
    }
    result
}
/// Scan `yaml` for a top-level `splits:` section and parse each `- name:` item
/// together with its indented `num_rows` / `num_bytes` fields.
///
/// Items must be indented under `splits:`; the section ends at the first
/// non-comment line back at column 0. Missing or unparseable numeric fields
/// default to 0. Only the first `splits:` section is read.
fn parse_splits_from_yaml(yaml: &str) -> Vec<HfSplitInfo> {
    let mut splits: Vec<HfSplitInfo> = Vec::new();
    let lines: Vec<&str> = yaml.lines().collect();
    let mut i = 0;
    while i < lines.len() {
        let line = lines[i];
        let trimmed = line.trim();
        if indent_of(line) == 0 && trimmed.starts_with("splits:") {
            i += 1;
            // Walk the indented block under `splits:`.
            while i < lines.len() {
                let sub = lines[i];
                let sub_trimmed = sub.trim();
                // Any non-comment content back at column 0 ends the section.
                if !sub_trimmed.is_empty() && !sub_trimmed.starts_with('#') && indent_of(sub) == 0 {
                    break;
                }
                if sub_trimmed.starts_with("- name:") || sub_trimmed == "-" {
                    // A bare `-` starts an item with an empty name.
                    let name_part = if let Some(rest) = sub_trimmed.strip_prefix("- name:") {
                        parse_scalar(rest)
                    } else {
                        String::new()
                    };
                    let mut num_rows = 0usize;
                    let mut num_bytes = 0usize;
                    let item_indent = indent_of(sub);
                    i += 1;
                    // Consume this item's field lines until the next item or
                    // a line back at column 0.
                    while i < lines.len() {
                        let inner = lines[i];
                        let inner_trimmed = inner.trim();
                        if inner_trimmed.is_empty() || inner_trimmed.starts_with('#') {
                            i += 1;
                            continue;
                        }
                        let inner_indent = indent_of(inner);
                        // Next `-` item at the same (or lesser) indent, or any
                        // top-level line, terminates this item's fields.
                        if inner_indent <= item_indent
                            && (inner_trimmed.starts_with('-') || inner_indent == 0)
                        {
                            break;
                        }
                        if let Some(colon) = find_colon(inner_trimmed) {
                            let k = inner_trimmed[..colon].trim();
                            let v = parse_scalar(&inner_trimmed[colon + 1..]);
                            match k {
                                "num_rows" => {
                                    // Non-numeric values silently become 0.
                                    num_rows = v.parse().unwrap_or(0);
                                }
                                "num_bytes" => {
                                    num_bytes = v.parse().unwrap_or(0);
                                }
                                _ => {}
                            }
                        }
                        i += 1;
                    }
                    splits.push(HfSplitInfo {
                        name: name_part,
                        num_rows,
                        num_bytes,
                    });
                } else {
                    i += 1;
                }
            }
            // Only the first `splits:` section is considered.
            return splits;
        }
        i += 1;
    }
    splits
}
/// Return the YAML body between the opening `---` line and the next line
/// starting with `---`, or `None` when the input has no frontmatter block.
fn extract_frontmatter(input: &str) -> Option<&str> {
    let doc = input.trim_start();
    if !doc.starts_with("---") {
        return None;
    }
    // The body begins right after the opening marker's newline.
    let body_start = doc.find('\n')? + 1;
    let body = &doc[body_start..];
    // The closing marker must begin a new line.
    let close = body.find("\n---")?;
    Some(&body[..close])
}
/// Build an `HfDatasetCard` from a README string or raw YAML.
///
/// When the input carries `---` frontmatter, only that block is parsed;
/// otherwise the whole input is treated as YAML. Unknown keys are ignored.
/// Currently always returns `Ok` (unrecognized content is simply skipped).
pub fn parse_dataset_card(yaml_str: &str) -> Result<HfDatasetCard, HfError> {
    let body = extract_frontmatter(yaml_str).unwrap_or(yaml_str);
    let mut card = HfDatasetCard::default();
    for (key, values) in parse_hf_yaml(body) {
        match key.as_str() {
            "dataset_name" => {
                card.dataset_name = values.into_iter().next().unwrap_or_default();
            }
            "task_categories" => card.task_categories = values,
            "language" => card.language = values,
            "size_categories" => card.size_categories = values,
            "license" => {
                // Only a non-empty first value sets the field.
                if let Some(s) = values.into_iter().next().filter(|s| !s.is_empty()) {
                    card.license = Some(s);
                }
            }
            "pretty_name" => {
                if let Some(s) = values.into_iter().next().filter(|s| !s.is_empty()) {
                    card.pretty_name = Some(s);
                }
            }
            _ => {}
        }
    }
    card.splits = parse_splits_from_yaml(body);
    Ok(card)
}
/// Read `<dir>/README.md` and parse its YAML frontmatter into a card.
///
/// # Errors
/// `HfError::Io` when the file cannot be read; `HfError::MissingField` when
/// the README has no `---` frontmatter block.
pub fn load_dataset_card(dir: &Path) -> Result<HfDatasetCard, HfError> {
    let content = std::fs::read_to_string(dir.join("README.md"))?;
    match extract_frontmatter(&content) {
        Some(_) => parse_dataset_card(&content),
        None => Err(HfError::MissingField("YAML frontmatter (---) in README.md")),
    }
}
/// Build a minimal English-language card for a dataset with a single
/// "train" split of `n_rows` rows.
pub fn to_hf_card(name: &str, n_rows: usize, task: &str) -> HfDatasetCard {
    let train = HfSplitInfo {
        name: "train".to_owned(),
        num_rows: n_rows,
        // Rough size estimate: 64 bytes per row.
        num_bytes: n_rows * 64,
    };
    HfDatasetCard {
        dataset_name: name.to_owned(),
        task_categories: vec![task.to_owned()],
        language: vec!["en".to_owned()],
        size_categories: vec![size_category(n_rows)],
        license: None,
        pretty_name: Some(name.to_owned()),
        splits: vec![train],
    }
}
/// Render a card as a README string: YAML frontmatter followed by a short
/// Markdown body. Empty lists and `None` fields are omitted entirely.
pub fn card_to_readme(card: &HfDatasetCard) -> String {
    // Emit one block-style YAML list section; skipped when there are no items.
    fn push_list(out: &mut String, key: &str, items: &[String]) {
        if items.is_empty() {
            return;
        }
        out.push_str(key);
        out.push_str(":\n");
        for item in items {
            out.push_str(&format!(" - {}\n", yaml_str(item)));
        }
    }

    let mut readme = String::from("---\n");
    readme.push_str(&format!("dataset_name: {}\n", yaml_str(&card.dataset_name)));
    push_list(&mut readme, "task_categories", &card.task_categories);
    push_list(&mut readme, "language", &card.language);
    push_list(&mut readme, "size_categories", &card.size_categories);
    if let Some(license) = &card.license {
        readme.push_str(&format!("license: {}\n", yaml_str(license)));
    }
    if let Some(pretty) = &card.pretty_name {
        readme.push_str(&format!("pretty_name: {}\n", yaml_str(pretty)));
    }
    if !card.splits.is_empty() {
        readme.push_str("splits:\n");
        for split in &card.splits {
            readme.push_str(&format!(
                " - name: {}\n num_rows: {}\n num_bytes: {}\n",
                yaml_str(&split.name),
                split.num_rows,
                split.num_bytes,
            ));
        }
    }
    readme.push_str("---\n\n");
    readme.push_str(&format!("# {}\n\n", card.dataset_name));
    if let Some(pretty) = &card.pretty_name {
        readme.push_str(&format!("{pretty}\n\n"));
    }
    if !card.task_categories.is_empty() {
        readme.push_str(&format!("Tasks: {}\n", card.task_categories.join(", ")));
    }
    readme
}
/// Map a row count to its Hugging Face `size_categories` bucket label.
///
/// The previous version collapsed everything at or above 100M rows into
/// "10M<n<100M"; the match now covers the full set of HF buckets up to
/// "n>1T". Results for n < 100M are unchanged.
fn size_category(n: usize) -> String {
    // Widen to u64 so the large range literals compile on 32-bit targets too.
    let n = n as u64;
    match n {
        0..=999 => "n<1K",
        1_000..=9_999 => "1K<n<10K",
        10_000..=99_999 => "10K<n<100K",
        100_000..=999_999 => "100K<n<1M",
        1_000_000..=9_999_999 => "1M<n<10M",
        10_000_000..=99_999_999 => "10M<n<100M",
        100_000_000..=999_999_999 => "100M<n<1B",
        1_000_000_000..=9_999_999_999 => "1B<n<10B",
        10_000_000_000..=99_999_999_999 => "10B<n<100B",
        100_000_000_000..=999_999_999_999 => "100B<n<1T",
        _ => "n>1T",
    }
    .to_owned()
}
/// Quote a scalar for YAML output when it contains characters that could
/// change how it is re-parsed; plain strings pass through unchanged.
fn yaml_str(s: &str) -> String {
    let needs_quoting = s.contains(|c| matches!(c, ':' | '#' | '"' | '\''));
    if needs_quoting {
        // NOTE(review): backslashes are not escaped here, so values already
        // containing `\"` sequences may not round-trip exactly — confirm if
        // that matters for the intended inputs.
        format!("\"{}\"", s.replace('"', "\\\""))
    } else {
        s.to_owned()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    // Fixture fix: inside a multi-line string literal only the first
    // backslash-newline continuation strips whitespace; every later line's
    // leading spaces are literal string content. The list items and split
    // fields below must therefore be indented here in the source — with them
    // at column 0, `parse_hf_yaml`'s list loop and `parse_splits_from_yaml`
    // terminate immediately and the parsing tests fail.
    const SAMPLE_YAML: &str = "\
dataset_name: squad
task_categories:
  - question-answering
language:
  - en
size_categories:
  - 100K<n<1M
license: cc-by-4.0
pretty_name: Stanford Question Answering Dataset
splits:
  - name: train
    num_rows: 87599
    num_bytes: 29344551
  - name: validation
    num_rows: 10570
    num_bytes: 3519936
";

    #[test]
    fn test_parse_dataset_card_basic() {
        let card = parse_dataset_card(SAMPLE_YAML).expect("should parse");
        assert_eq!(card.dataset_name, "squad");
        assert_eq!(card.task_categories, vec!["question-answering"]);
        assert_eq!(card.language, vec!["en"]);
        assert_eq!(card.size_categories, vec!["100K<n<1M"]);
        assert_eq!(card.license, Some("cc-by-4.0".to_owned()));
        assert_eq!(
            card.pretty_name,
            Some("Stanford Question Answering Dataset".to_owned())
        );
    }

    #[test]
    fn test_parse_splits() {
        let card = parse_dataset_card(SAMPLE_YAML).expect("should parse");
        assert_eq!(card.splits.len(), 2);
        assert_eq!(card.splits[0].name, "train");
        assert_eq!(card.splits[0].num_rows, 87599);
        assert_eq!(card.splits[0].num_bytes, 29344551);
        assert_eq!(card.splits[1].name, "validation");
        assert_eq!(card.splits[1].num_rows, 10570);
    }

    #[test]
    fn test_to_hf_card_n_rows() {
        let card = to_hf_card("my-ds", 5000, "classification");
        assert_eq!(card.dataset_name, "my-ds");
        assert_eq!(card.task_categories, vec!["classification"]);
        assert!(!card.splits.is_empty());
        let train_split = card.splits.iter().find(|s| s.name == "train");
        assert!(train_split.is_some(), "should have a train split");
        assert_eq!(train_split.expect("verified above").num_rows, 5000);
    }

    #[test]
    fn test_card_to_readme_contains_name() {
        let card = to_hf_card("awesome-dataset", 100, "text-classification");
        let readme = card_to_readme(&card);
        assert!(
            readme.contains("awesome-dataset"),
            "README should contain the dataset name"
        );
    }

    #[test]
    fn test_load_dataset_card_nonexistent() {
        let result = load_dataset_card(Path::new("/nonexistent/path/that/does/not/exist"));
        assert!(result.is_err(), "should fail for non-existent path");
    }

    #[test]
    fn test_roundtrip_dataset_name() {
        let original = to_hf_card("roundtrip-test", 2000, "regression");
        let readme = card_to_readme(&original);
        let parsed = parse_dataset_card(&readme).expect("round-trip parse should succeed");
        assert_eq!(
            parsed.dataset_name, original.dataset_name,
            "dataset_name should survive round-trip"
        );
    }

    #[test]
    fn test_load_dataset_card_from_temp_dir() {
        let tmp_dir = std::env::temp_dir().join("scirs2_hf_test_load_card");
        std::fs::create_dir_all(&tmp_dir).expect("create temp dir");
        let yaml_fm = "---\ndataset_name: temp-dataset\ntask_categories:\n - classification\nlanguage:\n - en\n---\n# temp-dataset\n";
        let readme_path = tmp_dir.join("README.md");
        let mut f = std::fs::File::create(&readme_path).expect("create README.md");
        f.write_all(yaml_fm.as_bytes()).expect("write");
        let card = load_dataset_card(&tmp_dir).expect("load card");
        assert_eq!(card.dataset_name, "temp-dataset");
        assert_eq!(card.task_categories, vec!["classification"]);
        let _ = std::fs::remove_file(&readme_path);
        let _ = std::fs::remove_dir(&tmp_dir);
    }

    #[test]
    fn test_load_dataset_card_no_frontmatter() {
        let tmp_dir = std::env::temp_dir().join("scirs2_hf_test_no_fm");
        std::fs::create_dir_all(&tmp_dir).expect("create temp dir");
        let readme_path = tmp_dir.join("README.md");
        let mut f = std::fs::File::create(&readme_path).expect("create README.md");
        f.write_all(b"# Plain README\n\nNo frontmatter here.\n")
            .expect("write");
        let result = load_dataset_card(&tmp_dir);
        assert!(
            matches!(result, Err(HfError::MissingField(_))),
            "expected MissingField, got: {:?}",
            result
        );
        let _ = std::fs::remove_file(&readme_path);
        let _ = std::fs::remove_dir(&tmp_dir);
    }

    #[test]
    fn test_size_categories() {
        assert_eq!(size_category(500), "n<1K");
        assert_eq!(size_category(5000), "1K<n<10K");
        assert_eq!(size_category(50_000), "10K<n<100K");
        assert_eq!(size_category(500_000), "100K<n<1M");
        assert_eq!(size_category(5_000_000), "1M<n<10M");
        assert_eq!(size_category(50_000_000), "10M<n<100M");
    }

    #[test]
    fn test_parse_inline_list() {
        let yaml = "dataset_name: inline-test\nlanguage: [en, fr, de]\n";
        let card = parse_dataset_card(yaml).expect("parse");
        assert_eq!(card.dataset_name, "inline-test");
        assert_eq!(card.language, vec!["en", "fr", "de"]);
    }
}