use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use rho_core::{
RhoResult, arg_value, from_yaml, has_flag, normalize_actor_id, uuid_like, yaml_quote,
};
use serde::Deserialize;
/// Prints the `rho dataset set` usage line to stderr and terminates the
/// process with exit status 2. Never returns.
fn usage() -> ! {
    const USAGE: &str = "usage: rho dataset set <name> --public <source> [--root <repo>] [--owner <owner>] [--revision <rev>] [--commit] [--push] [--pr]";
    eprintln!("{}", USAGE);
    std::process::exit(2)
}
pub fn run(args: &[String]) -> RhoResult<()> {
let Some(name) = args.first().filter(|value| !value.starts_with('-')) else {
usage();
};
let dataset_slug = dataset_path_slug(name)?;
let root = arg_value(args, "--root")
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("."));
let owner_arg = arg_value(args, "--owner")
.or_else(active_handle)
.unwrap_or_else(|| "user1".to_string());
let owner_id = normalize_actor_id(&owner_arg)?;
let dataset_dir = root.join("datasets").join(&dataset_slug);
let manifest_path = dataset_dir.join("dataset.yaml");
let existing = read_existing_dataset(&manifest_path)?;
let dataset_uuid = arg_value(args, "--uuid")
.or(existing.uuid)
.unwrap_or_else(uuid_like);
let description = arg_value(args, "--description")
.or(existing.description)
.unwrap_or_default();
let public_source = parse_public_source(args)?;
if should_branch(args) {
switch_branch(
&root,
&arg_value(args, "--branch").unwrap_or_else(|| {
format!(
"{}/dataset-{}",
github_handle_from_identity(&owner_id).unwrap_or(owner_arg.clone()),
dataset_slug
)
}),
)?;
}
fs::create_dir_all(&dataset_dir)?;
fs::write(
&manifest_path,
format!(
concat!(
"version: 1\n",
"dataset:\n",
" uuid: {}\n",
" name: {}\n",
" owner: {}\n",
" description: {}\n",
" variants:\n",
" public:\n",
" tier: \"public\"\n",
" source:\n",
" kind: {}\n",
" repo: {}\n",
" url: {}\n",
" revision: {}\n",
" materialization:\n",
" mode: \"on_demand\"\n",
" path: {}\n",
),
yaml_quote(&dataset_uuid),
yaml_quote(name),
yaml_quote(&owner_id),
yaml_quote(&description),
yaml_quote(&public_source.kind),
yaml_quote(&public_source.repo),
yaml_quote(&public_source.url),
yaml_quote(&public_source.revision),
yaml_quote(&format!(".rho/external/datasets/{dataset_slug}/public")),
),
)?;
println!("dataset: {name}");
println!("uuid: {dataset_uuid}");
println!("public source: {}", public_source.url);
println!("manifest: {}", manifest_path.display());
finish_dataset_set(args, &root, &dataset_slug)?;
Ok(())
}
/// Resolves the public source described by the command-line arguments.
///
/// `--public <source>` takes precedence and is parsed as a shorthand or URL.
/// Otherwise the source is assembled from `--source-repo` (required),
/// `--source-kind` (default `huggingface`) and `--source-url` (default derived
/// from kind and repo). `--revision` defaults to `main` in both forms.
///
/// # Errors
/// Fails when neither `--public` nor `--source-repo` is given, or when the
/// `--public` value is not a supported source.
fn parse_public_source(args: &[String]) -> RhoResult<PublicSource> {
    let revision = arg_value(args, "--revision");

    match arg_value(args, "--public") {
        Some(shorthand) => public_source_from_value(&shorthand, revision),
        None => {
            let kind = match arg_value(args, "--source-kind") {
                Some(explicit) => explicit,
                None => "huggingface".to_string(),
            };
            let Some(repo) = arg_value(args, "--source-repo") else {
                return Err("missing --public or --source-repo".into());
            };
            let url = match arg_value(args, "--source-url") {
                Some(explicit) => explicit,
                None => default_source_url(&kind, &repo),
            };
            let revision = revision.unwrap_or_else(|| "main".to_string());
            Ok(PublicSource {
                kind,
                repo,
                url,
                revision,
            })
        }
    }
}
/// Parses a `--public` value into a [`PublicSource`].
///
/// Accepted forms, all resolving to a Hugging Face dataset:
/// - `repo:huggingface:<repo>`, `huggingface:<repo>`, `hf:<repo>`
/// - `https://huggingface.co/datasets/<repo>` (the URL is kept verbatim as the
///   source URL; leading/trailing `/` are trimmed from the repo)
///
/// `revision` defaults to `main` when `None`.
///
/// # Errors
/// Fails when `value` matches none of the forms above, or when the repository
/// part is empty (for example a bare `hf:`).
fn public_source_from_value(value: &str, revision: Option<String>) -> RhoResult<PublicSource> {
    // Longest prefix first, matching the order in which they are tried.
    const SHORTHAND_PREFIXES: [&str; 3] = ["repo:huggingface:", "huggingface:", "hf:"];
    const DATASETS_URL_PREFIX: &str = "https://huggingface.co/datasets/";

    let shorthand_repo = SHORTHAND_PREFIXES
        .iter()
        .find_map(|prefix| value.strip_prefix(*prefix));

    let (repo, explicit_url) = if let Some(repo) = shorthand_repo {
        (repo, None)
    } else if let Some(path) = value.strip_prefix(DATASETS_URL_PREFIX) {
        (path.trim_matches('/'), Some(value.to_string()))
    } else {
        return Err(format!("unsupported public dataset source: {value}").into());
    };

    // A bare prefix would otherwise produce a manifest with an empty repo and
    // a URL pointing at the datasets index page.
    if repo.is_empty() {
        return Err(format!("public dataset source is missing a repository: {value}").into());
    }

    let source = huggingface_source(repo, revision);
    Ok(match explicit_url {
        Some(url) => source.with_url(url),
        None => source,
    })
}
/// Builds a `huggingface` source for `repo`, using the default dataset URL and
/// falling back to revision `main` when none is supplied.
fn huggingface_source(repo: &str, revision: Option<String>) -> PublicSource {
    let kind = "huggingface";
    let revision = match revision {
        Some(explicit) => explicit,
        None => String::from("main"),
    };
    PublicSource {
        kind: kind.to_owned(),
        repo: repo.to_owned(),
        url: default_source_url(kind, repo),
        revision,
    }
}
/// Derives the source URL for a repository of the given kind.
///
/// `huggingface` repos map to their dataset page; for any other kind the repo
/// string itself is returned unchanged.
fn default_source_url(kind: &str, repo: &str) -> String {
    if kind == "huggingface" {
        return format!("https://huggingface.co/datasets/{repo}");
    }
    repo.to_owned()
}
/// The upstream location a dataset's public variant is fetched from.
#[derive(Debug, Clone)]
struct PublicSource {
    /// Source kind written to the manifest, e.g. `huggingface`.
    kind: String,
    /// Repository identifier within that kind.
    repo: String,
    /// URL of the source.
    url: String,
    /// Revision of the source; callers in this file default it to `main`.
    revision: String,
}

impl PublicSource {
    /// Returns this source with its URL replaced by `url`; all other fields
    /// are kept as they were.
    fn with_url(self, url: String) -> Self {
        Self { url, ..self }
    }
}
fn read_existing_dataset(path: &Path) -> RhoResult<ExistingDataset> {
if !path.is_file() {
return Ok(ExistingDataset::default());
}
let text = fs::read_to_string(path)?;
let manifest: ExistingDatasetManifest = from_yaml(&text)?;
Ok(ExistingDataset {
uuid: Some(manifest.dataset.uuid),
description: manifest.dataset.description,
})
}
/// Values carried over from a manifest that is already on disk.
///
/// The `Default` value (both fields `None`) is what `read_existing_dataset`
/// returns when there is no manifest file.
#[derive(Debug, Default)]
struct ExistingDataset {
/// `dataset.uuid` of the existing manifest, when one was read.
uuid: Option<String>,
/// `dataset.description` of the existing manifest, when it had one.
description: Option<String>,
}
/// Deserialization view of an existing `dataset.yaml`.
///
/// Only the top-level `dataset` key is declared here; it is required.
#[derive(Debug, Deserialize)]
struct ExistingDatasetManifest {
/// The `dataset` mapping of the manifest.
dataset: ExistingDatasetRecord,
}
/// The fields of the manifest's `dataset` mapping that are reused when the
/// manifest is rewritten.
#[derive(Debug, Deserialize)]
struct ExistingDatasetRecord {
/// Dataset uuid; required, so a manifest without it fails to deserialize.
uuid: String,
/// Optional description; `#[serde(default)]` makes a missing key `None`.
#[serde(default)]
description: Option<String>,
}
/// Reads the owner handle from the `RHO_ENV_HANDLE` environment variable.
///
/// Surrounding whitespace is trimmed. Returns `None` when the variable is
/// unset, not valid Unicode, or blank after trimming.
fn active_handle() -> Option<String> {
    let raw = env::var("RHO_ENV_HANDLE").ok()?;
    let handle = raw.trim();
    if handle.is_empty() {
        None
    } else {
        Some(handle.to_owned())
    }
}
/// Validates that a dataset name can be used as a single directory name and
/// returns it unchanged.
///
/// A name is accepted when it is non-empty, consists only of ASCII letters,
/// digits, `-`, `_` and `.`, and is not one of the special path components
/// `.` or `..`.
///
/// # Errors
/// Fails with `dataset name is not path-safe: <name>` otherwise.
fn dataset_path_slug(value: &str) -> RhoResult<String> {
    let uses_safe_chars = !value.is_empty()
        && value
            .chars()
            .all(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.'));
    // `.` and `..` pass the character whitelist but name the current and parent
    // directory, so joining them onto `datasets/` would write the manifest
    // outside the per-dataset directory.
    let is_dot_component = matches!(value, "." | "..");

    if !uses_safe_chars || is_dot_component {
        return Err(format!("dataset name is not path-safe: {value}").into());
    }
    Ok(value.to_string())
}
/// Decides whether the command should switch to a dataset branch before
/// writing: `--no-branch` always disables it, otherwise any of `--pr`,
/// `--push` or `--branch` enables it.
fn should_branch(args: &[String]) -> bool {
    if has_flag(args, "--no-branch") {
        return false;
    }
    let wants_remote_work = has_flag(args, "--pr") || has_flag(args, "--push");
    wants_remote_work || has_flag(args, "--branch")
}
/// Runs the optional Git follow-up steps after the manifest has been written.
///
/// `--commit` stages `datasets/<slug>` and commits if anything is staged,
/// `--push` pushes the current branch, and `--pr` implies both and then opens
/// a pull request. With none of these flags the function does nothing.
///
/// # Errors
/// Propagates the first failing Git or `rho` subprocess.
fn finish_dataset_set(args: &[String], root: &Path, dataset_slug: &str) -> RhoResult<()> {
    let open_pr = has_flag(args, "--pr");
    let do_commit = open_pr || has_flag(args, "--commit");
    let do_push = open_pr || has_flag(args, "--push");
    // Shared by the commit message and the pull-request title.
    let summary = format!("Set {dataset_slug} dataset public source");

    if do_commit {
        let staged_path = format!("datasets/{dataset_slug}");
        git_add_relative(root, &staged_path)?;
        if git_has_staged_changes(root)? {
            git_commit(root, &summary)?;
            println!("committed: {summary}");
        } else {
            println!("commit: no staged dataset changes");
        }
    }

    if do_push {
        git_push_current_branch(root)?;
        let branch = current_branch(root)?;
        println!("pushed: {branch}");
    }

    if open_pr {
        let body = format!("Declare the public source for dataset {dataset_slug}.");
        create_pr(root, &summary, &body)?;
    }

    Ok(())
}
/// Runs `git -C <root> switch -C <branch>`.
///
/// `-C` makes Git create the branch, or reset it to the current commit when a
/// branch of that name already exists.
///
/// # Errors
/// Fails when Git cannot be spawned or exits unsuccessfully.
fn switch_branch(root: &Path, branch: &str) -> RhoResult<()> {
    let mut git = Command::new("git");
    git.arg("-C").arg(root).arg("switch").arg("-C").arg(branch);
    if git.status()?.success() {
        return Ok(());
    }
    Err(format!("git switch failed for branch {branch}").into())
}
/// Stages `path` with `git -C <root> add -- <path>`.
///
/// # Errors
/// Fails when Git cannot be spawned or exits unsuccessfully.
fn git_add_relative(root: &Path, path: &str) -> RhoResult<()> {
    let mut git = Command::new("git");
    // `--` keeps the path from being interpreted as an option.
    git.arg("-C").arg(root).arg("add").arg("--").arg(path);
    if git.status()?.success() {
        return Ok(());
    }
    Err(format!("git add failed for {path}").into())
}
/// Reports whether the index in `root` holds any staged changes.
///
/// Uses the exit status of `git diff --cached --quiet`: 0 means nothing is
/// staged, 1 means there are staged changes.
///
/// # Errors
/// Fails when Git cannot be spawned, exits with any other status, or is
/// terminated without an exit code.
fn git_has_staged_changes(root: &Path) -> RhoResult<bool> {
    let exit_code = Command::new("git")
        .arg("-C")
        .arg(root)
        .arg("diff")
        .arg("--cached")
        .arg("--quiet")
        .status()?
        .code();

    if exit_code == Some(0) {
        return Ok(false);
    }
    if exit_code == Some(1) {
        return Ok(true);
    }
    Err(format!("git diff --cached --quiet failed in {}", root.display()).into())
}
/// Commits the staged changes by re-invoking the current executable as
/// `<exe> commit -C <root> -m <message>`.
///
/// # Errors
/// Fails when the current executable path cannot be determined, the child
/// cannot be spawned, or it exits unsuccessfully.
fn git_commit(root: &Path, message: &str) -> RhoResult<()> {
    let rho = env::current_exe()?;
    let mut commit = Command::new(rho);
    commit
        .arg("commit")
        .arg("-C")
        .arg(root)
        .arg("-m")
        .arg(message);
    if commit.status()?.success() {
        return Ok(());
    }
    Err(format!("rho commit failed in {}", root.display()).into())
}
/// Pushes the currently checked-out branch to `origin` and sets it as the
/// upstream (`git push -u origin <branch>`).
///
/// # Errors
/// Fails when the current branch cannot be determined, Git cannot be spawned,
/// or the push exits unsuccessfully.
fn git_push_current_branch(root: &Path) -> RhoResult<()> {
    let branch = current_branch(root)?;
    let pushed = Command::new("git")
        .arg("-C")
        .arg(root)
        .arg("push")
        .arg("-u")
        .arg("origin")
        .arg(&branch)
        .status()?
        .success();
    if pushed {
        Ok(())
    } else {
        Err(format!("git push failed for {branch}").into())
    }
}
/// Returns the name of the branch checked out in `root`, as printed by
/// `git branch --show-current` with surrounding whitespace removed.
///
/// # Errors
/// Fails when Git cannot be spawned or exits unsuccessfully, when its output
/// is not valid UTF-8, or when it prints no branch name.
fn current_branch(root: &Path) -> RhoResult<String> {
    let mut git = Command::new("git");
    git.arg("-C").arg(root).arg("branch").arg("--show-current");
    let result = git.output()?;
    if !result.status.success() {
        return Err(format!("git branch --show-current failed in {}", root.display()).into());
    }

    let stdout = String::from_utf8(result.stdout)?;
    match stdout.trim() {
        "" => Err("current Git branch is empty".into()),
        name => Ok(name.to_owned()),
    }
}
/// Opens a pull request by re-invoking the current executable as
/// `<exe> repo create-pr --root <root> --title <title> --body <body>`.
///
/// # Errors
/// Fails when the current executable path cannot be determined, the child
/// cannot be spawned, or it exits unsuccessfully.
fn create_pr(root: &Path, title: &str, body: &str) -> RhoResult<()> {
    let rho = env::current_exe()?;
    let mut create = Command::new(rho);
    create.arg("repo").arg("create-pr");
    create.arg("--root").arg(root);
    create.arg("--title").arg(title);
    create.arg("--body").arg(body);
    if create.status()?.success() {
        return Ok(());
    }
    Err(format!("rho repo create-pr failed in {}", root.display()).into())
}
/// Extracts the GitHub handle from an identity id of the form
/// `rho://id/github/<handle>`.
///
/// # Errors
/// Fails with `unsupported identity id: <id>` when the prefix is missing, the
/// handle is empty, or the handle contains a `/`.
fn github_handle_from_identity(identity_id: &str) -> RhoResult<String> {
    match identity_id.strip_prefix("rho://id/github/") {
        Some(handle) if !handle.is_empty() && !handle.contains('/') => Ok(handle.to_owned()),
        _ => Err(format!("unsupported identity id: {identity_id}").into()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `repo:huggingface:<repo>` value resolves to a Hugging Face source
    /// with the default dataset URL and the `main` revision.
    #[test]
    fn parses_huggingface_public_source() {
        let parsed = public_source_from_value("repo:huggingface:madhavajay/1kgp-bv-all", None)
            .unwrap();
        let PublicSource {
            kind,
            repo,
            url,
            revision,
        } = parsed;

        assert_eq!(kind, "huggingface");
        assert_eq!(repo, "madhavajay/1kgp-bv-all");
        assert_eq!(
            url,
            "https://huggingface.co/datasets/madhavajay/1kgp-bv-all"
        );
        assert_eq!(revision, "main");
    }
}