use hf_hub::api::tokio::ApiRepo;
use serde::Deserialize;
use crate::error::FetchError;
/// A single file entry from a Hugging Face repository listing.
///
/// Which optional fields are populated depends on the listing function that
/// produced the entry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RepoFile {
    /// Path of the file inside the repository, as reported by the Hub's
    /// `rfilename` field.
    pub filename: String,
    /// File size as reported by the Hub (presumably in bytes), or `None` when
    /// the listing did not include size information.
    pub size: Option<u64>,
    /// SHA-256 digest taken from the file's LFS metadata; `None` for files
    /// without LFS metadata or when the listing did not fetch it.
    pub sha256: Option<String>,
}
/// Lists the files of `repo` using `hf_hub`'s repository-info request.
///
/// Only file names are filled in; `size` and `sha256` are always `None`.
/// Use [`list_repo_files_with_metadata`] when those fields are needed.
///
/// # Errors
///
/// * [`FetchError::RepoNotFound`] (carrying `repo_id`) when the rendered
///   `hf_hub` error message contains `"404"`.
/// * [`FetchError::Api`] wrapping any other `hf_hub` error.
pub async fn list_repo_files(
    repo: &ApiRepo,
    repo_id: String,
) -> Result<Vec<RepoFile>, FetchError> {
    let repo_info = match repo.info().await {
        Ok(repo_info) => repo_info,
        // A missing repository is recognised from the error's text rather
        // than from a structured status field.
        Err(api_err) if api_err.to_string().contains("404") => {
            return Err(FetchError::RepoNotFound { repo_id });
        }
        Err(api_err) => return Err(FetchError::Api(api_err)),
    };

    Ok(repo_info
        .siblings
        .into_iter()
        .map(|sibling| RepoFile {
            filename: sibling.rfilename,
            size: None,
            sha256: None,
        })
        .collect())
}
/// The part of the Hub's model-info JSON response that this module reads.
#[derive(Deserialize, Debug)]
struct ApiModelInfo {
    /// One entry per file in the repository.
    siblings: Vec<ApiSibling>,
}

/// One element of the `siblings` array in the model-info response.
#[derive(Deserialize, Debug)]
struct ApiSibling {
    /// Path of the file inside the repository.
    rfilename: String,
    /// File size; `None` when the JSON entry has no `size` key.
    #[serde(default)]
    size: Option<u64>,
    /// LFS details; `None` when the JSON entry has no `lfs` object.
    #[serde(default)]
    lfs: Option<ApiLfs>,
}

/// The `lfs` object attached to a sibling entry.
#[derive(Deserialize, Debug)]
struct ApiLfs {
    /// SHA-256 digest reported for the LFS object.
    sha256: String,
    /// Size reported for the LFS object.
    size: u64,
}
/// Lists every file in a model repository together with its size and, for
/// LFS-tracked files, its SHA-256 digest.
///
/// Calls the Hub's model-info endpoint directly with `blobs=true`, so each
/// sibling entry carries `size` and `lfs` metadata.
///
/// * `repo_id` — repository id such as `"org/model"`.
/// * `token` — optional access token, sent as a bearer token.
/// * `revision` — optional branch, tag or commit to list; `None` lists the
///   Hub's default revision.
///
/// For entries with LFS metadata, both `size` and `sha256` come from that
/// metadata. For other entries, `size` is whatever the API reported and
/// `sha256` is `None`.
///
/// # Errors
///
/// * [`FetchError::RepoNotFound`] if the API answers `404 Not Found`.
/// * [`FetchError::Http`] if the request fails, the API returns any other
///   non-success status, or the response body cannot be decoded.
pub async fn list_repo_files_with_metadata(
    repo_id: &str,
    token: Option<&str>,
    revision: Option<&str>,
) -> Result<Vec<RepoFile>, FetchError> {
    let url = hf_model_info_url(repo_id, revision);
    let client = reqwest::Client::new();
    let mut request = client.get(url.as_str());
    if let Some(t) = token {
        request = request.bearer_auth(t);
    }
    let response = request
        .send()
        .await
        .map_err(|e| FetchError::Http(e.to_string()))?;

    let status = response.status();
    if status == reqwest::StatusCode::NOT_FOUND {
        return Err(FetchError::RepoNotFound {
            repo_id: repo_id.to_owned(),
        });
    }
    if !status.is_success() {
        return Err(FetchError::Http(format!(
            "HF API returned status {status}"
        )));
    }

    let info: ApiModelInfo = response
        .json()
        .await
        .map_err(|e| FetchError::Http(e.to_string()))?;

    let files = info
        .siblings
        .into_iter()
        .map(|s| {
            // Prefer the LFS entry's size and digest when present; plain
            // entries only contribute the size the API reported.
            let (size, sha256) = match s.lfs {
                Some(lfs) => (Some(lfs.size), Some(lfs.sha256)),
                None => (s.size, None),
            };
            RepoFile {
                filename: s.rfilename,
                size,
                sha256,
            }
        })
        .collect();
    Ok(files)
}

/// Builds the model-info URL for `repo_id`, pinned to `revision` when given.
///
/// The Hub selects a revision through a `/revision/{rev}` path segment, not
/// a query parameter. The revision is percent-encoded as one segment, so a
/// `/` in names such as `refs/pr/1` is sent as `%2F` rather than splitting
/// the path.
fn hf_model_info_url(repo_id: &str, revision: Option<&str>) -> String {
    match revision {
        Some(rev) => format!(
            "https://huggingface.co/api/models/{repo_id}/revision/{}?blobs=true",
            percent_encode_segment(rev)
        ),
        None => format!("https://huggingface.co/api/models/{repo_id}?blobs=true"),
    }
}

/// Percent-encodes `segment` for use as a single URL path segment.
///
/// Every UTF-8 byte outside the RFC 3986 "unreserved" set
/// (`A-Z a-z 0-9 - . _ ~`) is escaped as `%XX`.
fn percent_encode_segment(segment: &str) -> String {
    let mut encoded = String::with_capacity(segment.len());
    for byte in segment.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push_str(&format!("%{byte:02X}"));
        }
    }
    encoded
}