use std::{collections::HashMap, str::FromStr};
use exn::{Exn, OptionExt, ResultExt};
use reqwest::{
header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT},
ClientBuilder,
};
use serde_json::Value as JsonValue;
use url::Url;
use crate::{
datasets::{
Arxiv, DataDryad, Dataone, DataverseDataset, DataverseFile, GitHub, HalScience,
HuggingFace, MaterialsCloud, Zenodo, OSF,
},
repo::Dataset,
};
use crate::helper::json_extract;
use std::collections::HashSet;
use std::sync::LazyLock;
#[derive(Debug)]
pub struct DispatchError {
pub message: String,
}
impl std::fmt::Display for DispatchError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl std::error::Error for DispatchError {}
#[derive(Debug)]
pub struct ResolveError {
pub message: String,
}
impl std::fmt::Display for ResolveError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl std::error::Error for ResolveError {}
static DATAONE_DOMAINS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
HashSet::from([
"arcticdata.io",
"knb.ecoinformatics.org",
"data.pndb.fr",
"opc.dataone.org",
"portal.edirepository.org",
"goa.nceas.ucsb.edu",
"data.piscoweb.org",
"adc.arm.gov",
"scidb.cn",
"data.ess-dive.lbl.gov",
"hydroshare.org",
"ecl.earthchem.org",
"get.iedadata.org",
"usap-dc.org",
"iys.hakai.org",
"doi.pangaea.de",
"rvdata.us",
"sead-published.ncsa.illinois.edu",
])
});
static HAL_DOMAINS: LazyLock<HashSet<&'static str>> =
LazyLock::new(|| HashSet::from(["hal.science", "inrae.fr"]));
static DATAVERSE_DOMAINS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
HashSet::from([
"www.march.es",
"www.murray.harvard.edu",
"abacus.library.ubc.ca",
"ada.edu.au",
"adattar.unideb.hu",
"archive.data.jhu.edu",
"borealisdata.ca",
"dados.ipb.pt",
"dadosdepesquisa.fiocruz.br",
"darus.uni-stuttgart.de",
"data.aussda.at",
"data.cimmyt.org",
"data.fz-juelich.de",
"data.goettingen-research-online.de",
"data.inrae.fr",
"data.scielo.org",
"data.sciencespo.fr",
"data.tdl.org",
"data.univ-gustave-eiffel.fr",
"datarepositorium.uminho.pt",
"datasets.iisg.amsterdam",
"dataspace.ust.hk",
"dataverse.asu.edu",
"dataverse.cirad.fr",
"dataverse.csuc.cat",
"dataverse.harvard.edu",
"dataverse.iit.it",
"dataverse.uliege.be",
"dataverse.ird.fr",
"dataverse.lib.umanitoba.ca",
"dataverse.lib.unb.ca",
"dataverse.lib.virginia.edu",
"dataverse.nl",
"dataverse.no",
"dataverse.openforestdata.pl",
"dataverse.scholarsportal.info",
"dataverse.theacss.org",
"dataverse.ucla.edu",
"dataverse.unc.edu",
"dataverse.unimi.it",
"dataverse.yale-nus.edu.sg",
"dorel.univ-lorraine.fr",
"dvn.fudan.edu.cn",
"edatos.consorciomadrono.es",
"edmond.mpdl.mpg.de",
"heidata.uni-heidelberg.de",
"lida.dataverse.lt",
"mxrdr.icm.edu.pl",
"osnadata.ub.uni-osnabrueck.de",
"planetary-data-portal.org",
"qdr.syr.edu",
"rdm.aau.edu.et",
"rdr.kuleuven.be",
"rds.icm.edu.pl",
"recherche.data.gouv.fr",
"redu.unicamp.br",
"repod.icm.edu.pl",
"repositoriopesquisas.ibict.br",
"research-data.urosario.edu.co",
"researchdata.cuhk.edu.hk",
"researchdata.ntu.edu.sg",
"rin.lipi.go.id",
"ssri.is",
"www.seanoe.org",
"trolling.uit.no",
"www.sodha.be",
"www.uni-hildesheim.de",
"dataverse.acg.maine.edu",
"dataverse.icrisat.org",
"datos.pucp.edu.pe",
"datos.uchile.cl",
"opendata.pku.edu.cn",
"archaeology.datastations.nl",
"ssh.datastations.nl",
"lifesciences.datastations.nl",
"phys-techsciences.datastations.nl",
"dataverse.nl",
])
});
async fn github_get_default_branch_commit(
owner: &str,
repo: &str,
) -> Result<String, Exn<DispatchError>> {
let user_agent = format!("datahugger-cli/{}", env!("CARGO_PKG_VERSION"));
let mut headers = HeaderMap::new();
if let Ok(token) = std::env::var("GITHUB_TOKEN") {
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("token {token}")).unwrap(),
);
}
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent).unwrap());
let client = ClientBuilder::new()
.user_agent(&user_agent)
.default_headers(headers)
.use_native_tls()
.build()
.unwrap();
let repo_url = format!("https://api.github.com/repos/{owner}/{repo}");
let resp: JsonValue = client
.get(&repo_url)
.send()
.await
.unwrap()
.error_for_status()
.unwrap()
.json()
.await
.unwrap();
let default_branch: String =
json_extract(&resp, "default_branch").map_err(|_| DispatchError {
message: "not able to get default branch".to_string(),
})?;
let commits_url =
format!("https://api.github.com/repos/{owner}/{repo}/commits/{default_branch}");
let resp: JsonValue = client
.get(&commits_url)
.header("User-Agent", user_agent.clone())
.send()
.await
.unwrap()
.error_for_status()
.unwrap()
.json()
.await
.unwrap();
let commit_sha: String = json_extract(&resp, "sha").map_err(|_| DispatchError {
message: "not able to get default branch".to_string(),
})?;
Ok(commit_sha)
}
async fn resolve_doi_to_url_with_base(
client: &reqwest::Client,
doi: &str,
base_url: Option<&str>,
follow_redirects: bool,
) -> Result<String, Exn<ResolveError>> {
if !(doi.starts_with("10.") && doi.contains('/')) {
exn::bail!(ResolveError {
message: format!("Invalid DOI: '{doi}'"),
});
}
let base_url = base_url.unwrap_or("https://doi.org/api/handles");
let res = match client
.get(format!("{}/{}", base_url, doi))
.query(&[("type", "URL")])
.send()
.await
{
Ok(res) => res,
Err(err) => {
exn::bail!(ResolveError {
message: format!("failed to resolve '{doi}': {err:?}")
})
}
};
let status = res.status();
if !status.is_success() {
exn::bail!(ResolveError {
message: format!("failed to resolve '{doi}': status {status}")
});
}
let json: serde_json::Value = match res.json().await {
Ok(json) => json,
Err(err) => {
exn::bail!(ResolveError {
message: format!("failed to parse response for '{doi}': {err:?}")
})
}
};
let url = match json.get("responseCode").and_then(|v| v.as_i64()) {
Some(1) => match json.get("values").and_then(|v| v.as_array()) {
Some(values) if !values.is_empty() => {
match values[0]
.get("data")
.and_then(|d| d.get("value"))
.and_then(|v| v.as_str())
{
Some(url) => Ok::<String, Exn<ResolveError>>(url.to_string()),
None => exn::bail!(ResolveError {
message: format!("missing data.value for '{doi}'")
}),
}
}
_ => exn::bail!(ResolveError {
message: format!("empty or missing values for '{doi}'")
}),
},
Some(code) => exn::bail!(ResolveError {
message: format!("unexpected responseCode {code} for '{doi}'")
}),
None => exn::bail!(ResolveError {
message: format!("missing responseCode for '{doi}'")
}),
}?;
if follow_redirects {
let res = match client.head(&url).send().await {
Ok(res) => res,
Err(err) => exn::bail!(ResolveError {
message: format!("failed to follow redirect for '{url}': {err:?}")
}),
};
Ok(res.url().to_string())
} else {
Ok(url)
}
}
pub async fn resolve_doi_to_url(
client: &reqwest::Client,
doi: &str,
follow_redirects: bool,
) -> Result<String, Exn<ResolveError>> {
resolve_doi_to_url_with_base(client, doi, None, follow_redirects).await
}
#[allow(clippy::too_many_lines)]
pub async fn resolve(url: &str) -> Result<Dataset, Exn<DispatchError>> {
let url = Url::from_str(url).or_raise(|| DispatchError {
message: format!("'{url}' not a valid url"),
})?;
let scheme = url.scheme();
let domain = url.domain().ok_or_else(|| DispatchError {
message: "domain unresolved".to_string(),
})?;
let host_str = url.host_str().ok_or_else(|| DispatchError {
message: format!("host_str unresolved from '{url}'"),
})?;
if DATAONE_DOMAINS.contains(domain) {
let base_url = format!("{scheme}://{host_str}");
let base_url = Url::from_str(&base_url).or_raise(|| DispatchError {
message: format!("'{base_url}' is not valid url"),
})?;
let mut segments = url.path_segments().ok_or_else(|| DispatchError {
message: format!("'{url}' cannot be base"),
})?;
let id = segments
.find(|pat| pat.starts_with("doi"))
.ok_or_raise(|| DispatchError {
message: format!("expect 'doi' in '{url}'"),
})?;
let dataset = Dataset::new(Dataone::new(&base_url, id));
return Ok(dataset);
}
if DATAVERSE_DOMAINS.contains(domain) {
let mut segments = url.path_segments().ok_or_else(|| DispatchError {
message: format!("'{url}' cannot be base"),
})?;
let typ = segments.next().ok_or_else(|| DispatchError {
message: format!("'{url}' no segments found"),
})?;
let queries = url.query_pairs();
let queries = queries.collect::<HashMap<_, _>>();
let Some(id) = queries.get("persistentId") else {
exn::bail!(DispatchError {
message: "query don't contains 'persistentId'".to_string()
})
};
let typ = typ.strip_suffix(".xhtml").ok_or_else(|| DispatchError {
message: "segment not in format *.xhtml".to_string(),
})?;
let base_url = format!("{scheme}://{host_str}");
let base_url = Url::from_str(&base_url).or_raise(|| DispatchError {
message: format!("'{base_url}' is not valid url"),
})?;
let version = ":latest-published".to_string();
match typ {
"dataset" => {
let dataset = Dataset::new(DataverseDataset::new(id.as_ref(), &base_url, &version));
return Ok(dataset);
}
"file" => {
let dataset = Dataset::new(DataverseFile::new(id.as_ref(), &base_url, &version));
return Ok(dataset);
}
ty => exn::bail!(DispatchError {
message: format!("{ty} is not valid type, can only be 'dataset' or 'file'")
}),
}
}
match domain {
"arxiv.org" => {
let mut segments = url.path_segments().ok_or_else(|| DispatchError {
message: format!("cannot get path segments of url '{}'", url.as_str()),
})?;
let id = segments
.next()
.and_then(|_| segments.next())
.ok_or(DispatchError {
message: format!("connot get record id from '{url}'"),
})?;
let dataset = Dataset::new(Arxiv::new(id));
Ok(dataset)
}
d if HAL_DOMAINS
.iter()
.any(|&hal_domain| d.ends_with(hal_domain)) =>
{
let mut segments = url.path_segments().ok_or_else(|| DispatchError {
message: format!("cannot get path segments of url '{}'", url.as_str()),
})?;
let id = segments.next().ok_or(DispatchError {
message: format!("connot get record id from '{url}'"),
})?;
let id = if let Some(pos) = id.rfind('v') {
&id[..pos] } else {
id };
let dataset = Dataset::new(HalScience::new(id));
Ok(dataset)
}
"huggingface.co" => {
eprintln!(
"\x1b[33mwarning:\x1b[0m for reliable downloads, consider using the official Hugging Face APIs:\n\
\x1b[36m - Rust:\x1b[0m hf_hub\n\
\x1b[36m - Python:\x1b[0m datasets\n\
\n\
\x1b[2mFor example, datahugger would handle caching, retries, and consistency for you.\x1b[0m"
);
let mut segments = url.path_segments().ok_or_else(|| DispatchError {
message: format!("cannot get path segments of url '{}'", url.as_str()),
})?;
let kind = segments.next().ok_or_else(|| DispatchError {
message: format!("missing repo kind in url '{}'", url.as_str()),
})?;
if kind != "datasets" {
exn::bail!(DispatchError {
message: format!("unsupported Hugging Face repo kind '{kind}'"),
});
}
let owner = segments.next().ok_or_else(|| DispatchError {
message: format!("missing owner in url '{}'", url.as_str()),
})?;
let repo = segments.next().ok_or_else(|| DispatchError {
message: format!("missing repo name in url '{}'", url.as_str()),
})?;
let (revision, _subpath) = match segments.next() {
Some("tree") => {
let rev = segments.next().ok_or_else(|| DispatchError {
message: format!("missing revision in url '{}'", url.as_str()),
})?;
let rest: Vec<&str> = segments.collect();
(rev, rest.join("/"))
}
_ => ("main", String::new()),
};
let dataset = Dataset::new(HuggingFace::new(owner, repo, revision));
Ok(dataset)
}
"archive.materialscloud.org" => {
let segments = url
.path_segments()
.ok_or_else(|| DispatchError {
message: format!("cannot get path segments of url '{}'", url.as_str()),
})?
.collect::<Vec<&str>>();
let record_id = if segments.len() >= 2 {
segments[1]
} else {
exn::bail!(DispatchError {
message: format!("unable to parse dryad dataset id from '{url}'",)
})
};
let dataset = Dataset::new(MaterialsCloud::new(record_id));
Ok(dataset)
}
"zenodo.org" => {
let segments = url
.path_segments()
.ok_or_else(|| DispatchError {
message: format!("cannot get path segments of url '{}'", url.as_str()),
})?
.collect::<Vec<&str>>();
let record_id = if segments.len() >= 2 {
segments[1]
} else {
exn::bail!(DispatchError {
message: format!("unable to parse dryad dataset id from '{url}'",)
})
};
let dataset = Dataset::new(Zenodo::new(record_id));
Ok(dataset)
}
"github.com" => {
let mut segments = url.path_segments().ok_or_else(|| DispatchError {
message: format!("cannot get path segments of url '{}'", url.as_str()),
})?;
let owner = segments.next().ok_or_else(|| DispatchError {
message: format!("missing owner in url '{}'", url.as_str()),
})?;
let repo_name = segments.next().ok_or_else(|| DispatchError {
message: format!("missing repo in url '{}'", url.as_str()),
})?;
let dataset = if let Some(branch_or_commit) =
segments.next().and_then(|_| segments.next())
{
Dataset::new(GitHub::new(owner, repo_name, branch_or_commit))
} else {
let branch_or_commit = github_get_default_branch_commit(owner, repo_name).await?;
Dataset::new(GitHub::new(owner, repo_name, branch_or_commit))
};
Ok(dataset)
}
"datadryad.org" => {
let segments = url
.path_segments()
.ok_or_else(|| DispatchError {
message: format!("cannot get path segments of url '{}'", url.as_str()),
})?
.collect::<Vec<&str>>();
let record_id = if segments.len() >= 3 && segments[0] == "dataset" {
format!("{}/{}", segments[1], segments[2])
} else {
exn::bail!(DispatchError {
message: format!("unable to parse dryad dataset id from '{url}'",)
})
};
let base_url = Url::from_str("https://datadryad.org/").or_raise(|| DispatchError {
message: "invalid base url".to_string(),
})?;
let dataset = Dataset::new(DataDryad::new(record_id, &base_url));
Ok(dataset)
}
"osf.io" => {
let mut segments = url.path_segments().ok_or_else(|| DispatchError {
message: format!("cannot get path segments of url '{}'", url.as_str()),
})?;
let id = segments.next().ok_or_else(|| DispatchError {
message: format!("no segments path in url '{}'", url.as_str()),
})?;
let dataset = Dataset::new(OSF::new(id));
Ok(dataset)
}
"data.mendeley.com" => {
unimplemented!("help us! open an issue to request or PR to help us.")
}
"data.4tu.nl" => {
unimplemented!("help us! open an issue to request or PR to help us.")
}
"b2share.eudat.eu" | "data.europa.eu" => {
unimplemented!("help us! open an issue to request or PR to help us.")
}
_ => {
exn::bail!(DispatchError {
message: format!("unknown domain: {domain}")
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use wiremock::matchers::{method, path, query_param};
use wiremock::{Mock, MockServer, ResponseTemplate};
#[tokio::test]
async fn test_resolve_dataverse_default() {
let url = "https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/KBHLOD";
let qr = resolve(url).await.unwrap();
let qr = qr
.backend
.as_any()
.downcast_ref::<DataverseDataset>()
.unwrap();
assert_eq!(qr.id.as_str(), "doi:10.7910/DVN/KBHLOD");
let url =
"https://dataverse.harvard.edu/file.xhtml?persistentId=doi:10.7910/DVN/KBHLOD/DHJ45U";
let qr = resolve(url).await.unwrap();
let qr = qr.backend.as_any().downcast_ref::<DataverseFile>().unwrap();
assert_eq!(qr.id.as_str(), "doi:10.7910/DVN/KBHLOD/DHJ45U");
}
#[tokio::test]
async fn test_resolve_default() {
for url in ["https://osf.io/dezms/overview", "https://osf.io/dezms/"] {
let qr = resolve(url).await.unwrap();
let qr = qr.backend.as_any().downcast_ref::<OSF>().unwrap();
assert_eq!(qr.id.as_str(), "dezms");
}
let url = "https://arxiv.org/abs/2101.00001v1";
let qr = resolve(url).await.unwrap();
let qr = qr.backend.as_any().downcast_ref::<Arxiv>().unwrap();
assert_eq!(qr.id.as_str(), "2101.00001v1");
let url = "https://arcticdata.io/catalog/view/doi%3A10.18739%2FA2542JB2X";
let qr = resolve(url).await.unwrap();
let qr = qr.backend.as_any().downcast_ref::<Dataone>().unwrap();
assert_eq!(qr.id.as_str(), "doi%3A10.18739%2FA2542JB2X");
assert_eq!(qr.base_url.as_str(), "https://arcticdata.io/");
let url = "https://datadryad.org/dataset/doi:10.5061/dryad.mj8m0";
let qr = resolve(url).await.unwrap();
let qr = qr.backend.as_any().downcast_ref::<DataDryad>().unwrap();
assert_eq!(qr.id.as_str(), "doi:10.5061/dryad.mj8m0");
let url = "https://hal.science/cel-01830944";
let qr = resolve(url).await.unwrap();
let qr = qr.backend.as_any().downcast_ref::<HalScience>().unwrap();
assert_eq!(qr.id.as_str(), "cel-01830944");
let url = "https://huggingface.co/datasets/HuggingFaceFW/finepdfs";
let qr = resolve(url).await.unwrap();
let qr = qr.backend.as_any().downcast_ref::<HuggingFace>().unwrap();
assert_eq!(qr.owner.as_str(), "HuggingFaceFW");
assert_eq!(qr.repo.as_str(), "finepdfs");
assert_eq!(qr.revision.as_str(), "main");
let url = "https://zenodo.org/records/17867222";
let qr = resolve(url).await.unwrap();
let qr = qr.backend.as_any().downcast_ref::<Zenodo>().unwrap();
assert_eq!(qr.id.as_str(), "17867222");
}
#[tokio::test]
async fn test_resolve_doi_to_url() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/10.34894/0B7ZLK"))
.and(query_param("type", "URL"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"responseCode": 1,
"values": [
{
"index": 1,
"type": "URL",
"data": {
"format": "string",
"value": "https://dataverse.nl/citation?persistentId=doi:10.34894/0B7ZLK"
},
"ttl": 86400,
"timestamp": "2021-12-23T16:59:30Z"
}
]
})))
.mount(&mock_server)
.await;
let client = reqwest::Client::builder()
.use_native_tls()
.timeout(Duration::from_secs(5))
.build()
.unwrap();
let res = resolve_doi_to_url_with_base(
&client,
"10.34894/0B7ZLK",
Some(&mock_server.uri()),
false,
)
.await;
assert!(res.is_ok());
let url = res.unwrap();
assert_eq!(
url,
"https://dataverse.nl/citation?persistentId=doi:10.34894/0B7ZLK"
);
let res = resolve_doi_to_url_with_base(
&client,
"https://doi.org/10.34894/0B7ZLK",
Some(&mock_server.uri()),
false,
)
.await;
assert!(res.is_err());
assert_eq!(
res.unwrap_err().message,
"Invalid DOI: 'https://doi.org/10.34894/0B7ZLK'"
);
}
}